A flexible and interactive grid world environment for reinforcement learning
Project description
GridWorldPy
一个灵活且交互式的网格世界环境,专为强化学习实验和教育目的设计。
✨ 特性
- 🎯 灵活的环境配置:支持自定义网格大小、奖励函数和终止条件
- 🎨 实时可视化:基于tkinter的直观图形界面,支持实时渲染
- ⚡ 交互控制:支持键盘控制和自动执行模式
- 🎮 策略可视化:直观显示动作概率分布和状态价值
- 🚫 状态禁用:支持禁用特定状态创建障碍物
- 📊 颜色映射:根据奖励值和状态价值自动调整颜色显示
🚀 安装
使用pip从PyPI安装:
pip install gridworldpy
或者从源码安装:
git clone https://github.com/hitlic/gridworldpy.git
cd gridworldpy
pip install -e .
📖 快速开始
基本使用
from gridworldpy import GridWorldEnv
# 创建一个5x5的网格世界
env = GridWorldEnv(grid_size=(5, 5), keyboard_control=True)
# 设置随机奖励
env.set_rewards('random')
# 设置随机策略
env.set_policy('random')
# 渲染环境
env.render()
# 执行一个动作
obs, reward, done, is_effective, info = env.step(1) # 向上移动
自定义配置
from gridworldpy import GridWorldEnv
# 创建带自定义配置的环境
env = GridWorldEnv(
grid_size=(4, 4),
keyboard_control=True, # 启用键盘控制
terminal_condition=(3, 3), # 目标位置为(3,3)
cell_size=150, # 每个格子150像素
circle_radius=40, # 奖励圆圈半径40像素
)
# 设置特定位置的奖励
rewards = [
((0, 0), -0.1), # 起始位置小负奖励
((1, 1), -1.0), # 陷阱:大负奖励
((3, 3), 1.0), # 目标:大正奖励
]
env.set_rewards(rewards)
# 设置特定的策略
policy = [
((0, 0), [0.1, 0.2, 0.2, 0.2, 0.3]), # 位置(0,0)的动作概率
((1, 0), [0.0, 0.0, 0.5, 0.0, 0.5]), # 位置(1,0)的动作概率
]
env.set_policy(policy)
# 禁用某些状态(创建障碍物)
env.disable_states([(1, 2), (2, 1)])
# 渲染环境
env.render()
带状态价值的可视化
# 创建状态价值列表
state_values = [
((0, 0), 0.1), ((0, 1), 0.2), ((0, 2), 0.3),
((1, 0), 0.4), ((1, 1), -0.5), ((1, 2), 0.6),
((2, 0), 0.7), ((2, 1), 0.8), ((2, 2), 0.9),
]
# 渲染时包含状态价值
env.render(state_values=state_values)
🎮 控制方式
键盘控制模式
当keyboard_control=True时,使用空格键控制执行:
- 按空格键执行下一步
- 窗口会显示当前步数和环境状态
- 程序会等待用户输入后继续
自动模式
当keyboard_control=False时,可以连续调用render()方法:
for step in range(10):
env.render()
action = np.random.randint(0, 5) # 随机选择动作
obs, reward, done, is_effective,info = env.step(action)
if done:
break
🎯 API 参考
GridWorldEnv
主要的环境类,提供网格世界的完整功能。
初始化参数
GridWorldEnv(
grid_size=(5, 5), # 网格大小 (行, 列)
render_mode="human", # 渲染模式
keyboard_control=True, # 是否启用键盘控制
terminal_condition=None, # 终止条件
cell_size=130, # 每个格子的像素大小
circle_radius=35, # 奖励圆圈半径
font_size=16, # 字体大小
max_arrow_length=50, # 策略箭头最大长度
show_cell_pos=True, # 在每个单元上显示位置坐标
color_alpha=0.0 # 控制色彩敏感度
)
主要方法
step(action): 执行动作,返回(observation, reward, done, info)render(state_values=None, policy_config=None, reward_config=None): 渲染环境set_rewards(reward_config): 设置奖励配置set_policy(policy_config): 设置策略配置disable_states(disabled_poses): 禁用指定状态close(): 关闭环境
动作空间
- 0: 停留在当前位置
- 1: 向上移动
- 2: 向下移动
- 3: 向左移动
- 4: 向右移动
🎨 可视化说明
颜色编码
- 蓝色: 负值(奖励或状态价值)
- 橙色: 正值(奖励或状态价值)
- 浅蓝色: 零值或中性值
- 黄色: 目标状态
- 灰色: 禁用状态
- 红色边框: 智能体当前位置
显示元素
- 圆圈: 显示奖励值和状态价值
- 箭头长短: 当前策略中转移至相邻状态的动作概率
- 圆心大小:停留在当前状态的概率
- 文本: 显示具体的数值
📝 示例
查看examples/目录获取更多使用示例:
basic_usage.py:基本使用示例policy_evaluation.py:蒙特卡罗法策略评估value_iteration.py:价值迭代寻找最优策略
📄 许可
本项目使用MIT许可证 - 查看 LICENSE 文件了解详情。
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
gridworldpy-0.1.2.tar.gz
(30.6 kB
view details)
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file gridworldpy-0.1.2.tar.gz.
File metadata
- Download URL: gridworldpy-0.1.2.tar.gz
- Upload date:
- Size: 30.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.13
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
fd617e96a5e17a202c319f99e753ef8fb813f6372c20f626c031d913e149f5ff
|
|
| MD5 |
31c4e5e1216f5a5719584064549cabe8
|
|
| BLAKE2b-256 |
d4af153df2bbc0b7198050086bbf94866b45147ee971c0eca4a94718fc37a731
|
File details
Details for the file gridworldpy-0.1.2-py3-none-any.whl.
File metadata
- Download URL: gridworldpy-0.1.2-py3-none-any.whl
- Upload date:
- Size: 16.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.13
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
07e17c9ca8f515b4b479706af7e8ad5062a229149d08ac62315310cbd44828af
|
|
| MD5 |
634bb440b412da953f507a0cf52be186
|
|
| BLAKE2b-256 |
17568febe7bb720e6b0c06d40a59431f1e335138fecdafdf20cbad8d07e8d428
|