A flexible and interactive grid world environment for reinforcement learning
Project description
GridWorldPy
一个灵活且交互式的网格世界环境,专为强化学习实验和教育目的设计。
✨ 特性
- 🎯 灵活的环境配置:支持自定义网格大小、奖励函数和终止条件
- 🎨 实时可视化:基于tkinter的直观图形界面,支持实时渲染
- ⚡ 交互控制:支持键盘控制和自动执行模式
- 🎮 策略可视化:直观显示动作概率分布和状态价值
- 🚫 状态禁用:支持禁用特定状态创建障碍物
- 📊 颜色映射:根据奖励值和状态价值自动调整颜色显示
🚀 安装
使用pip从PyPI安装:
pip install gridworldpy
或者从源码安装:
git clone https://github.com/hitlic/gridworldpy.git
cd gridworldpy
pip install -e .
📖 快速开始
基本使用
from gridworldpy import GridWorldEnv
# 创建一个5x5的网格世界
env = GridWorldEnv(grid_size=(5, 5), keyboard_control=True)
# 设置随机奖励
env.set_rewards('random')
# 设置随机策略
env.set_policy('random')
# 渲染环境
env.render()
# 执行一个动作
obs, reward, done, is_effective, info = env.step(1) # 向上移动
自定义配置
from gridworldpy import GridWorldEnv
# 创建带自定义配置的环境
env = GridWorldEnv(
grid_size=(4, 4),
keyboard_control=True, # 启用键盘控制
enable_keep= True, # 启用停留动作
target_state=(3, 3), # 目标位置为(3,3)
cell_size=150, # 每个格子150像素
circle_radius=40, # 奖励圆圈半径40像素
)
# 设置特定位置的奖励
rewards = [
((0, 0), -0.1), # 起始位置小负奖励
((1, 1), -1.0), # 陷阱:大负奖励
((3, 3), 1.0), # 目标:大正奖励
]
env.set_rewards(rewards)
# 设置特定的策略
policy = [
((0, 0), [0.1, 0.2, 0.2, 0.2, 0.3]), # 位置(0,0)的动作概率
((1, 0), [0.0, 0.0, 0.5, 0.0, 0.5]), # 位置(1,0)的动作概率
]
env.set_policy(policy)
# 禁用某些状态(创建障碍物)
env.set_disable_states([(1, 2), (2, 1)])
# 渲染环境
env.render()
带状态价值的可视化
# 创建状态价值列表
state_values = [
((0, 0), 0.1), ((0, 1), 0.2), ((0, 2), 0.3),
((1, 0), 0.4), ((1, 1), -0.5), ((1, 2), 0.6),
((2, 0), 0.7), ((2, 1), 0.8), ((2, 2), 0.9),
]
# 渲染时包含状态价值
env.render(state_values=state_values)
🎮 控制方式
键盘控制模式
当keyboard_control=True时,使用空格键控制执行:
- 按空格键执行下一步
- 窗口会显示当前步数和环境状态
- 程序会等待用户输入后继续
自动模式
当keyboard_control=False时,可以连续调用render()方法:
for step in range(10):
env.render()
action = np.random.randint(0, 5) # 随机选择动作
obs, reward, done, is_effective,info = env.step(action)
if done:
break
🎯 API 参考
GridWorldEnv
主要的环境类,提供网格世界的完整功能。
初始化参数
GridWorldEnv(
grid_size=(5, 5), # 网格大小 (行, 列)
enable_keep=False, # 是否允许停留当前状态
start_state=(0, 0), # 初始状态
render_mode="human", # 渲染模式
keyboard_control=True, # 是否启用键盘控制
max_steps=100000, # 最大步数
cell_size=130, # 每个格子的像素大小
circle_radius=35, # 奖励圆圈半径
font_size=16, # 字体大小
max_arrow_length=50, # 策略箭头最大长度
color_alpha=0.0, # 取值[0, 1),显示颜色对数值的灵敏度
show_cell_pos=True, # 在每个单元上显示位置坐标
show_step_num=False # 是否显示步数
)
主要方法
step(action): 执行动作,返回(next_state, reward, done, valid_action, info)reset():重置环境,返回初始状态render(state_values=None, policy_config=None, reward_config=None): 渲染环境set_rewards(reward_config): 设置奖励配置set_policy(policy_config): 设置策略配置set_disable_states(disabled_poses): 禁用指定状态close(): 关闭环境
动作空间
- 当
enable_keep=True时- 0: 停留在当前位置
- 1: 向上移动
- 2: 向下移动
- 3: 向左移动
- 4: 向右移动
- 当
enable_keep=False时- 0: 向上移动
- 1: 向下移动
- 2: 向左移动
- 3: 向右移动
🎨 可视化说明
颜色编码
- 蓝色: 负值(奖励或状态价值)
- 橙色: 正值(奖励或状态价值)
- 浅蓝色: 零值或中性值
- 黄色: 目标状态
- 灰色: 禁用状态
- 红色边框: 智能体当前位置
显示元素
- 圆圈: 显示奖励值和状态价值
- 箭头长短: 当前策略中转移至相邻状态的动作概率
- 圆心大小:停留在当前状态的概率
- 文本: 显示具体的数值
📝 示例
查看examples/目录获取更多使用示例:
basic_usage.py:基本使用示例policy_evaluation.py:蒙特卡罗法策略评估value_iteration.py:价值迭代寻找最优策略sarsa.py:SARSA算法实现q_learning.py:Q-Learning算法实现
📄 许可
本项目使用MIT许可证 - 查看 LICENSE 文件了解详情。
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
gridworldpy-0.2.0.tar.gz
(31.8 kB
view details)
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file gridworldpy-0.2.0.tar.gz.
File metadata
- Download URL: gridworldpy-0.2.0.tar.gz
- Upload date:
- Size: 31.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.13
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
ad3b80e994f191497a7deeae6df33a31114f1be3e9e84e8218983e7a8571ea27
|
|
| MD5 |
19908524d98493f954b08da495131a03
|
|
| BLAKE2b-256 |
88c39d4cb59b465b0a0e6ce5154e418152e8cb7802c6e0eafeb999aaba19bc8d
|
File details
Details for the file gridworldpy-0.2.0-py3-none-any.whl.
File metadata
- Download URL: gridworldpy-0.2.0-py3-none-any.whl
- Upload date:
- Size: 16.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.13
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
287c1dba1e09a9c30804b65110b0e933ca6382c53b59f92eed0e379af571718c
|
|
| MD5 |
b3b27d5c0eadf3498586c8958bc97961
|
|
| BLAKE2b-256 |
8dbeeb8d3350f786655bb997cfa5e7ec66e4339b47b9903768c9982bd3d0a86e
|