
A library providing the tools to solve complex environments in Minigrid using LgTS

Project description

master-minigrid-agent

A Python module for training an RL agent on any Minigrid environment using LgTS.

Installation

pip install master-agent

Description

A Python library providing an all-in-one implementation of the GTRI research paper "LgTS: Dynamic Task Sampling using LLM-generated sub-goals for Reinforcement Learning Agents".

Includes

  • Prebuilt Minigrid Environments
  • LLM-based providers for subtask generation and evaluation
  • Teacher-Student algorithm implementation using PPO policies
  • Automatic Minigrid Tileset Identification

Methodology

The methodology follows the GTRI research paper.

Brief Overview

Generate a 2D array of subtask paths with the LLM (SubtasksGenerator.gen_subtask_paths()) -> merge the paths into a DAG -> use the DAG to train a set of policies with the Teacher-Student algorithm.
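The middle step, merging subtask paths into a DAG, can be pictured with a minimal sketch. It assumes each path is a list of subtask strings, as in the 2D array example, and adds a hypothetical Start root node for the initial state. paths_to_dag and is_acyclic are illustrative helpers, not part of the master-agent API, and the library's own DAG construction may differ.

```python
from collections import defaultdict

START = "Start"  # hypothetical root node standing for the initial state


def paths_to_dag(subtask_paths):
    """Merge subtask paths into a DAG: nodes are subtasks, edges link consecutive steps."""
    edges = defaultdict(set)
    for path in subtask_paths:
        prev = START
        for subtask in path:
            edges[prev].add(subtask)
            prev = subtask
    # Sorted child lists keep the output deterministic
    return {node: sorted(children) for node, children in edges.items()}


def is_acyclic(dag):
    """Kahn's algorithm: the graph is a DAG iff every node can be topologically ordered."""
    indegree = defaultdict(int)
    nodes = set(dag)
    for children in dag.values():
        for child in children:
            indegree[child] += 1
            nodes.add(child)
    queue = [node for node in nodes if indegree[node] == 0]
    visited = 0
    while queue:
        node = queue.pop()
        visited += 1
        for child in dag.get(node, []):
            indegree[child] -= 1
            if indegree[child] == 0:
                queue.append(child)
    return visited == len(nodes)


paths = [
    ['At(OutsideRoom)', 'Holding(Key1)', 'Unlocked(Door1)', 'At(Green_Goal)'],
    ['At(OutsideRoom)', 'Holding(Key3)', 'At(Green_Goal)'],
]
dag = paths_to_dag(paths)
for node, children in dag.items():
    print(node, "->", children)
print("Acyclic:", is_acyclic(dag))
```

Shared prefixes such as At(OutsideRoom) collapse into a single node, so the point where paths branch becomes explicit. The acyclicity check matters because independently generated LLM paths could order two subtasks inconsistently, which would create a cycle.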

Prebuilt Minigrid Environments

master_agent.llm provides 7 customized environments, based on the research paper and designed to evaluate RL success on specific obstacles:

  • Complex Env (a copy of the example environment from the GTRI research paper)
  • KeyOne Env
  • KeyTwo Env
  • LavaIsWall Env
  • + No Lava Variants

Subtask Generation

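Generate a 2D array of subtask paths using the SubtasksGenerator class. Each inner list is one candidate sequence of subtasks leading from the start state to the goal.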

Example of 2D Array

[
    ['At(OutsideRoom)', 'Holding(Key1)', 'Unlocked(Door1)', 'At(Green_Goal)'], 
    ['At(OutsideRoom)', 'Holding(Key2)', 'Unlocked(Door2)', 'At(Green_Goal)'], 
    ['At(OutsideRoom)', 'Holding(Key3)', 'At(Green_Goal)'], 
    ['At(OutsideRoom)', 'At(Wall)', 'At(Green_Goal)'],
]

Generation + Validation

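import os
from master_agent.llm.client import LlmClient
from master_agent.llm.subtasks import SubtasksGenerator, validate_subtask_paths

# LLM settings (these environment variable names are placeholders; use your own)
llm_api_key = os.getenv("LLM_API_KEY")
llm_model = os.getenv("LLM_MODEL", "openai/gpt-4o-mini")
llm_base_url = os.getenv("LLM_BASE_URL")

# Create the LLM client
llm_client = LlmClient(llm_api_key, llm_model, llm_base_url)
# Create the subtasks generator
subtasks_gen = SubtasksGenerator(llm_client)
# Objects present in the environment
objects = ["Key1", "Key2", "Key3", "Door1", "Door2"]
# Generate paths (2D array output)
subtask_paths = subtasks_gen.gen_subtask_paths(objects)
# Validate paths
try:
    validate_subtask_paths(subtask_paths, objects)
except Exception as e:
    print(f"Validation failed: {e}")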
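validate_subtask_paths ships with the package, and its exact checks are not documented here. To illustrate the kind of checks such a validator might run, here is a hypothetical check_subtask_paths. It assumes every subtask has the form Predicate(Argument), every path ends at At(Green_Goal), and every argument is either a listed object or one of a few assumed location names (KNOWN_LOCATIONS is an assumption taken from the example array).

```python
import re

# Assumed subtask format: Predicate(Argument), e.g. "Holding(Key1)"
SUBTASK_PATTERN = re.compile(r"^(\w+)\((\w+)\)$")
# Assumed goal subtask and location names that do not appear in the objects list
GOAL = "At(Green_Goal)"
KNOWN_LOCATIONS = {"OutsideRoom", "Green_Goal", "Wall"}


def check_subtask_paths(subtask_paths, objects):
    """Hypothetical validator: raise ValueError on the first problem found."""
    allowed = set(objects) | KNOWN_LOCATIONS
    if not subtask_paths:
        raise ValueError("No paths were generated")
    for i, path in enumerate(subtask_paths):
        if not path or path[-1] != GOAL:
            raise ValueError(f"Path {i} does not end at {GOAL}")
        for subtask in path:
            match = SUBTASK_PATTERN.match(subtask)
            if match is None:
                raise ValueError(f"Path {i}: malformed subtask {subtask!r}")
            if match.group(2) not in allowed:
                raise ValueError(f"Path {i}: unknown object {match.group(2)!r}")


objects = ["Key1", "Key2", "Key3", "Door1", "Door2"]
# A well-formed path passes silently
check_subtask_paths([["At(OutsideRoom)", "Holding(Key1)", "Unlocked(Door1)", "At(Green_Goal)"]], objects)
# A path that mentions an object missing from the environment is rejected
try:
    check_subtask_paths([["Holding(Key9)", "At(Green_Goal)"]], objects)
except ValueError as e:
    print(f"Validation failed: {e}")
```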

Teacher Student Training

Use the generated 2D array of paths to train an RL agent to master the environment with the TeacherStudent class.

Create Teacher Student Algorithm

from master_agent.rl.teacher_student import TeacherStudent

ts = TeacherStudent(subtask_paths)
print("Training the model...")
ts.train()
print("Training complete.")

print("Demonstrating learned path...")
ts.demo_learned_path()
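How the teacher chooses which subtask to train next is the core of a Teacher-Student curriculum. The sketch below is a generic learning-progress teacher in the spirit of Teacher-Student Curriculum Learning, not the TeacherStudent class's actual implementation. It assumes each task (for example, a DAG edge between two subtasks) reports success or failure after every episode, and it samples the task whose recent success rate is changing fastest, with epsilon-greedy exploration. LearningProgressTeacher and its parameters are hypothetical.

```python
import random
from collections import deque


class LearningProgressTeacher:
    """Hypothetical teacher that favours tasks whose success rate is changing fastest."""

    def __init__(self, tasks, window=10, epsilon=0.1, seed=0):
        self.tasks = list(tasks)
        # Most recent episode outcomes per task (1.0 = success, 0.0 = failure)
        self.history = {task: deque(maxlen=window) for task in self.tasks}
        self.epsilon = epsilon
        self.rng = random.Random(seed)

    def learning_progress(self, task):
        """Absolute difference between the newer and older halves of the window."""
        outcomes = list(self.history[task])
        if len(outcomes) < 2:
            return float("inf")  # barely tried tasks are sampled first
        half = len(outcomes) // 2
        older, recent = outcomes[:half], outcomes[half:]
        return abs(sum(recent) / len(recent) - sum(older) / len(older))

    def sample(self):
        """Epsilon-greedy choice over learning progress."""
        if self.rng.random() < self.epsilon:
            return self.rng.choice(self.tasks)
        return max(self.tasks, key=self.learning_progress)

    def update(self, task, success):
        self.history[task].append(1.0 if success else 0.0)


# Tasks as DAG edges (source subtask, target subtask)
edges = [("At(OutsideRoom)", "Holding(Key1)"), ("Holding(Key1)", "Unlocked(Door1)")]
teacher = LearningProgressTeacher(edges)
task = teacher.sample()
# ...train the student policy on `task` for one episode, then report the outcome
teacher.update(task, success=True)
```

Using the absolute change in success rate lets the teacher return to a task whose performance is dropping, which helps counter forgetting, while tasks that are already mastered or still hopeless get sampled less often.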

VLM Identification

This project also automates object detection within Minigrid environments. The master-agent package currently provides the TileIdentifier class for this purpose. We recommend a GPT-based, vision-capable model such as openai/gpt-4o-mini.

Unidentified Tileset Identification

import os
from dotenv import load_dotenv
from master_agent.llm.identify import TileIdentifier
from master_agent.llm.client import LlmClient
from envs.complexEnv import ComplexEnv

# Load LLM settings from a .env file (the variable names are placeholders)
load_dotenv()
llm_api_key = os.getenv("LLM_API_KEY")
llm_model = os.getenv("LLM_MODEL", "openai/gpt-4o-mini")
llm_base_url = os.getenv("LLM_BASE_URL")

llm_client = LlmClient(llm_api_key, llm_model, llm_base_url)
# Create tileset identifier
identifier = TileIdentifier(llm_client)
# Disable highlighting so the rendered tiles are accurate
env = ComplexEnv(render_mode='rgb_array', highlight=False)
env.reset()
# Generate unidentified tileset
unidentified_tileset = identifier.parse_tileset(env.render())
# Validate tileset
identifier.validate_unidentified_tileset(unidentified_tileset, env)
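The Display Tileset example reshapes the parsed result into 32x32x3 tiles, which suggests parse_tileset returns the rendered frame cut into a grid of RGB tiles. Minigrid renders each grid cell as a 32-pixel square by default, so such a split needs only one NumPy reshape and transpose. split_into_tiles is a hypothetical helper and may not match what parse_tileset does internally.

```python
import numpy as np

TILE_SIZE = 32  # assumed: Minigrid's default pixels per grid cell


def split_into_tiles(frame, tile_size=TILE_SIZE):
    """Cut an (H, W, C) frame into a (rows, cols, tile_size, tile_size, C) grid of tiles."""
    height, width, channels = frame.shape
    if height % tile_size or width % tile_size:
        raise ValueError("Frame dimensions must be multiples of tile_size")
    rows, cols = height // tile_size, width // tile_size
    # Split each axis into (cell index, pixel within cell), then move the cell indices first
    return frame.reshape(rows, tile_size, cols, tile_size, channels).transpose(0, 2, 1, 3, 4)


# Usage with a synthetic 2x3-cell frame; with Minigrid this could be env.render()
frame = np.zeros((2 * TILE_SIZE, 3 * TILE_SIZE, 3), dtype=np.uint8)
frame[TILE_SIZE:, 2 * TILE_SIZE:] = 255  # paint the bottom-right cell white
tiles = split_into_tiles(frame)
print(tiles.shape)  # (2, 3, 32, 32, 3)
```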

Display Tileset

import numpy as np
import matplotlib.pyplot as plt

# Uses `unidentified_tileset` from the previous example (a grid of 32x32 RGB tiles)
unique_tiles = np.unique(unidentified_tileset.reshape(-1, 32, 32, 3), axis=0)
print(f"Number of unique tiles: {len(unique_tiles)}")

# Map each tile ID to its (row, col) positions in the grid
tile_positions = {}
for tile_id, tile in enumerate(unique_tiles):
    tile_positions[tile_id] = []
    for row_idx, row in enumerate(unidentified_tileset):
        for col_idx, grid_tile in enumerate(row):
            if np.array_equal(grid_tile, tile):
                tile_positions[tile_id].append((row_idx, col_idx))

# Create a figure with one subplot per unique tile
num_tiles = len(unique_tiles)
num_cols = 5
num_rows = (num_tiles + num_cols - 1) // num_cols

fig, axs = plt.subplots(num_rows, num_cols, figsize=(15, 3 * num_rows))

# Flatten the axes array for easier indexing
axs = axs.flatten()

# Plot each unique tile in its own subplot, titled with its ID
for i, tile in enumerate(unique_tiles):
    axs[i].imshow(tile)
    axs[i].set_title(f"Tile ID: {i}")
    axs[i].set_xticks([])
    axs[i].set_yticks([])

# Hide unused subplots in the last row
for ax in axs[num_tiles:]:
    ax.axis("off")

# Adjust spacing between subplots
plt.subplots_adjust(wspace=0.1, hspace=0.1)

# Show the figure
plt.show()

# Print each tile ID with its positions as 1-indexed (x, y) coordinates, origin at the bottom-left
num_grid_rows = unidentified_tileset.shape[0]
for tile_id, positions in tile_positions.items():
    coords = [f"({col + 1}, {num_grid_rows - row})" for row, col in positions]
    print(f"Tile ID: {tile_id}, Coordinate Positions: {coords}")
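The nested loop above compares every unique tile against every grid cell. As an alternative sketch, assuming the same (rows, cols, tile_h, tile_w, 3) layout, np.unique with return_inverse=True labels every cell in a single pass. group_tile_positions is a hypothetical helper, not part of master-agent.

```python
import numpy as np


def group_tile_positions(tileset):
    """Return (unique_tiles, {tile_id: [(row, col), ...]}) for a (rows, cols, h, w, c) tileset."""
    rows, cols = tileset.shape[:2]
    # One flattened pixel vector per grid cell
    flat = tileset.reshape(rows * cols, -1)
    unique_flat, inverse = np.unique(flat, axis=0, return_inverse=True)
    # Put the per-cell tile IDs back onto the grid
    labels = inverse.reshape(rows, cols)
    positions = {
        tile_id: [(int(r), int(c)) for r, c in np.argwhere(labels == tile_id)]
        for tile_id in range(len(unique_flat))
    }
    return unique_flat.reshape(-1, *tileset.shape[2:]), positions


# Usage with a tiny synthetic 2x3 grid of 4x4 tiles
tileset = np.zeros((2, 3, 4, 4, 3), dtype=np.uint8)
tileset[0, 1] = 255
tileset[1, 2] = 255
unique_tiles, tile_positions = group_tile_positions(tileset)
print(len(unique_tiles), tile_positions)
```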

Identify Tileset

import os
from dotenv import load_dotenv
from master_agent.llm.identify import TileIdentifier
from master_agent.llm.client import LlmClient
from envs.complexEnv import ComplexEnv

# Load LLM settings from a .env file (the variable names are placeholders)
load_dotenv()
llm_api_key = os.getenv("LLM_API_KEY")
llm_model = os.getenv("LLM_MODEL", "openai/gpt-4o-mini")
llm_base_url = os.getenv("LLM_BASE_URL")

llm_client = LlmClient(llm_api_key, llm_model, llm_base_url)
# Create tileset identifier
identifier = TileIdentifier(llm_client)
# Disable highlighting so the rendered tiles are accurate
env = ComplexEnv(render_mode='rgb_array', highlight=False)
env.reset()
# Generate unidentified tileset
unidentified_tileset = identifier.parse_tileset(env.render())
# Validate tileset
identifier.validate_unidentified_tileset(unidentified_tileset, env)
# Identify tileset
tileset = identifier.identify_tiles(unidentified_tileset)

# Print each identified tile's name, Minigrid object, and grid positions
for tile in tileset.tiles:
    print(tile.name, tile.world_obj, tile.positions)



Download files


Source Distribution

master_agent-0.0.67.tar.gz (28.5 kB)

Uploaded Source

Built Distribution


master_agent-0.0.67-py3-none-any.whl (41.9 kB)

Uploaded Python 3

File details

Details for the file master_agent-0.0.67.tar.gz.

File metadata

  • Download URL: master_agent-0.0.67.tar.gz
  • Upload date:
  • Size: 28.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for master_agent-0.0.67.tar.gz
Algorithm Hash digest
SHA256 e9eceaf1b7bd84de6932bf13d0e9d4ec0e5dc4b4c30c51d4b2d5fd0da0607288
MD5 53fc798837deb8ab6c9948b5db2be247
BLAKE2b-256 e2f0c744670d88abf7781f3ce8b9f818f8d747b58f7f6c0dc4fa7f32468b0635


File details

Details for the file master_agent-0.0.67-py3-none-any.whl.

File metadata

  • Download URL: master_agent-0.0.67-py3-none-any.whl
  • Upload date:
  • Size: 41.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for master_agent-0.0.67-py3-none-any.whl
Algorithm Hash digest
SHA256 21e01facdfe6e5378303ae7dce40f7990fb74ecb8d57cf26d75b0ded679d276b
MD5 184e50829e3d66feb51c27bf2840936a
BLAKE2b-256 2f296cbe584365573d8e09aa38da09db7782566aa1eb7f432bb3a5a514f9df65

