Skip to main content

A flexible and interactive grid world environment for reinforcement learning

Project description

GridWorldPy

PyPI version Python 3.8+ License: MIT

一个灵活且交互式的网格世界环境,专为强化学习实验和教育目的设计。

✨ 特性

  • 🎯 灵活的环境配置:支持自定义网格大小、奖励函数和终止条件
  • 🎨 实时可视化:基于tkinter的直观图形界面,支持实时渲染
  • 交互控制:支持键盘控制和自动执行模式
  • 🎮 策略可视化:直观显示动作概率分布和状态价值
  • 🚫 状态禁用:支持禁用特定状态创建障碍物
  • 📊 颜色映射:根据奖励值和状态价值自动调整颜色显示

🚀 安装

使用pip从PyPI安装:

pip install gridworldpy

或者从源码安装:

git clone https://github.com/hitlic/gridworldpy.git
cd gridworldpy
pip install -e .

📖 快速开始

基本使用

import numpy as np
from gridworldpy import GridWorldEnv

# 创建一个5x5的网格世界
env = GridWorldEnv(grid_size=(5, 5))

# 设置随机奖励
env.set_rewards('random')

# 设置随机策略
env.set_policy('random')

# 渲染环境
env.render()

# 执行一个动作
obs, reward, done, info = env.step(1)  # 向上移动

自定义配置

# 创建带自定义配置的环境
env = GridWorldEnv(
    grid_size=(4, 4),
    keyboard_control=True,  # 启用键盘控制
    terminal_condition=(3, 3),  # 目标位置为(3,3)
    cell_size=150,  # 每个格子150像素
    circle_radius=40,  # 奖励圆圈半径40像素
)

# 设置特定位置的奖励
rewards = [
    (0, 0, -0.1),  # 起始位置小负奖励
    (1, 1, -1.0),  # 陷阱:大负奖励
    (3, 3, 1.0),   # 目标:大正奖励
]
env.set_rewards(rewards)

# 设置特定的策略
policy = [
    (0, 0, [0.1, 0.2, 0.2, 0.2, 0.3]),  # 位置(0,0)的动作概率
    (1, 0, [0.0, 0.0, 0.5, 0.0, 0.5]),  # 位置(1,0)的动作概率
]
env.set_policy(policy)

# 禁用某些状态(创建障碍物)
env.disable_states([(1, 2), (2, 1)])

# 渲染环境
env.render()

带状态价值的可视化

# 创建状态价值列表
state_values = [
    (0, 0, 0.1), (0, 1, 0.2), (0, 2, 0.3),
    (1, 0, 0.4), (1, 1, -0.5), (1, 2, 0.6),
    (2, 0, 0.7), (2, 1, 0.8), (2, 2, 0.9),
]

# 渲染时包含状态价值
env.render(state_values=state_values)

🎮 控制方式

键盘控制模式

keyboard_control=True时,使用空格键控制执行:

  • 空格键执行下一步
  • 窗口会显示当前步数和环境状态
  • 程序会等待用户输入后继续

自动模式

keyboard_control=False时,可以连续调用render()方法:

for step in range(10):
    action = np.random.randint(0, 5)  # 随机选择动作
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        break

🎯 API 参考

GridWorldEnv

主要的环境类,提供网格世界的完整功能。

构造函数

GridWorldEnv(
    grid_size=(5, 5),           # 网格大小 (行, 列)
    render_mode="human",        # 渲染模式
    keyboard_control=True,      # 是否启用键盘控制
    terminal_condition=None,    # 终止条件
    cell_size=130,             # 每个格子的像素大小
    circle_radius=35,          # 奖励圆圈半径
    font_size=16,              # 字体大小
    max_arrow_length=50        # 策略箭头最大长度
)

主要方法

  • step(action): 执行动作,返回(observation, reward, done, info)
  • render(state_values=None, policy_config=None, reward_config=None): 渲染环境
  • set_rewards(reward_config): 设置奖励配置
  • set_policy(policy_config): 设置策略配置
  • disable_states(disabled_poses): 禁用指定状态
  • close(): 关闭环境

动作空间

  • 0: 停留在当前位置
  • 1: 向上移动
  • 2: 向下移动
  • 3: 向左移动
  • 4: 向右移动

🎨 可视化说明

颜色编码

  • 冰蓝色: 负值(奖励或状态价值)
  • 火橙色: 正值(奖励或状态价值)
  • 浅蓝色: 零值或中性值
  • 黄色: 目标状态
  • 灰色: 禁用状态
  • 红色边框: 智能体当前位置

显示元素

  • 圆圈: 显示奖励值和状态价值
  • 箭头: 显示策略的动作概率分布
  • 文本: 显示具体的数值

📝 示例

查看examples/目录获取更多使用示例:

  • basic_usage.py: 基本使用示例
  • custom_environment.py: 自定义环境配置
  • reinforcement_learning.py: 强化学习训练示例

📄 许可证

本项目使用MIT许可证 - 查看 LICENSE 文件了解详情。

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

gridworldpy-0.1.0.tar.gz (18.3 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

gridworldpy-0.1.0-py3-none-any.whl (11.4 kB view details)

Uploaded Python 3

File details

Details for the file gridworldpy-0.1.0.tar.gz.

File metadata

  • Download URL: gridworldpy-0.1.0.tar.gz
  • Upload date:
  • Size: 18.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.13

File hashes

Hashes for gridworldpy-0.1.0.tar.gz
Algorithm Hash digest
SHA256 8aecf59c92c2b233d7165457ca69ea6cee57d8870e78c6d9a6db88d73e67fd9f
MD5 1b8c90542c6dbe86c7629a7f92bfd7be
BLAKE2b-256 7440f4b4b363d6c6b7168852e001eaccac01249b1f154d7a6430eff7ca546c75

See more details on using hashes here.

File details

Details for the file gridworldpy-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: gridworldpy-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 11.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.13

File hashes

Hashes for gridworldpy-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 49ef732ac497d2d3f677e1071990d519448eaf93f612f71a571aa7420f62ee34
MD5 19f5b7375df305725834e37928112eed
BLAKE2b-256 d1d4fde1b1ae003dac5f8084c7f85e367674eae9a76e4dec6c664e4a03029bc3

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page