An RL agent training assistant

Crete

               _       
  ___ _ __ ___| |_ ___ 
 / __| '__/ _ \ __/ _ \
| (__| | |  __/ ||  __/
 \___|_|  \___|\__\___|
Reinforcement Learning CLI assistant.

Installation

Crete is now officially on PyPI!

pip install crete

Introduction

Crete is a CLI for RL training, named after the island that Talos, the first AI in human mythology, walked upon.

It can be used as easily as:

python -m crete --help

Registering Agents, Environments and Wrappers

To start using Crete, you must register at least one agent. Crete comes preloaded with all the OpenAI Gym(nasium) environments. You can register your own environments too, as long as they use the OpenAI Gym(nasium) API. More information on making environments can be found on the Gym(nasium) website.
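
As a rough illustration of what "using the Gym(nasium) API" means, here is a minimal sketch of a toy environment. It assumes the Gymnasium-style call shapes, where reset returns (observation, info) and step returns (observation, reward, terminated, truncated, info). CoinFlipEnv is a hypothetical name, and the class deliberately does not subclass gymnasium.Env or define observation_space and action_space, so that it runs with the standard library alone; a real environment would do both.

```python
import random


class CoinFlipEnv:
    """Toy environment that mimics the Gymnasium reset/step call shapes.

    Assumption: a real environment would subclass gymnasium.Env and define
    observation_space and action_space; both are omitted in this sketch.
    """

    def __init__(self, max_steps: int = 10):
        self.max_steps = max_steps
        self._rng = random.Random()
        self._steps = 0
        self._coin = 0

    def reset(self, *, seed=None, options=None):
        # Gymnasium-style reset: returns (observation, info).
        if seed is not None:
            self._rng.seed(seed)
        self._steps = 0
        self._coin = self._rng.randint(0, 1)
        return self._coin, {}

    def step(self, action: int):
        # Reward 1.0 when the action matches the coin that was last observed.
        reward = 1.0 if action == self._coin else 0.0
        self._steps += 1
        self._coin = self._rng.randint(0, 1)
        terminated = False                          # the task has no natural end state
        truncated = self._steps >= self.max_steps   # time limit reached
        return self._coin, reward, terminated, truncated, {}


# Play one episode with a policy that simply copies the observed coin.
env = CoinFlipEnv(max_steps=5)
obs, info = env.reset(seed=0)
total = 0.0
done = False
while not done:
    obs, reward, terminated, truncated, info = env.step(obs)
    total += reward
    done = terminated or truncated
```

The registration call itself is not shown here because this README does not spell out the arguments of register_env.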

Agents must inherit from the crete.Agent class. For example:

 cool_rl_project
 +- crete_profile.yaml  <- config file, more on this later.
 +- module_a 
 |  +- agents
 |  |  +- __init__.py
 |  |  +- agent_1.py
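 |  +- etc...

# module_a/agents/agent_1.py

from typing import Callable, Dict, Tuple

import gymnasium as gym  # or `import gym`, depending on which package is installed

from crete import Agent, ExtraState, ProfileConfig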


class DQNAgent(Agent):
    def __init__(self, obs, n_actions, device='cpu') -> None:
        super().__init__("DQN")
        ...

    def get_action(self, obs, extra_state: ExtraState = None) -> Tuple[int, ExtraState]:
        """Request an action."""
        ...

    def save(self) -> bytes:
        """Extract all policy data."""
        ...

    def load(self, agent_data: bytes):
        """Load policy data."""
        ...


def dqn_training_wrapper(
        env_factory: Callable[[int], gym.Env],  # Function that creates the environment.
        agent: DQNAgent,  # The agent itself.
        config: ProfileConfig,  # Object with config values.
        artifacts: Dict,  # Empty map for storing training information.
        save_callback  # Function for saving agents, for instance at peak rewards.
):
    """The DQN training procedure."""
    # Note that this function is outside the agent definition.
    ...

Then register the agent in the module's initialisation file. There are also register_env and register_wrapper functions, which work the same way. Environments and wrappers should use the OpenAI Gym(nasium) API.

# module_a/agents/__init__.py

from crete import register_agent
from .agent_1 import DQNAgent, dqn_training_wrapper

register_agent(
    agent_id="DQN",
    agent_factory=lambda obs, n_actions, device: DQNAgent(obs, n_actions, device=device),
    training_wrapper=dqn_training_wrapper
)

Then register the module with Crete (this creates a .crete.yaml file):

python -m crete module add module_a.agents

Verify the module has been loaded correctly by running the following, which should list DQN.

python -m crete list agents

Training Agents

The training wrapper shown earlier for the example DQN agent takes five parameters:

  • env_factory: Callable[[int], gym.Env]. This is the function that creates the environment.
    • A factory is given instead of the environment itself to allow fresh environments to be created for evaluation.
  • agent: Agent. The agent itself.
  • config: ProfileConfig. The configuration profile containing agent and training parameters.
  • artifacts: Dict. An empty dictionary for storing training artifacts.
    • Training artifacts are things like loss over time. These artifacts are saved in the concrete file.
  • save_callback: SaveCallback. Utility function for saving snapshots of the agent.
    • Used like save_callback(agent, artifacts, step_iteration, "autosave-name")
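
To make the role of each parameter concrete, here is a minimal sketch of a training wrapper body. It is not the DQN procedure and not Crete's own code: ToyEnv, RandomAgent and fake_save_callback are hypothetical stand-ins so the example runs without Crete or Gym(nasium) installed. Two further assumptions are labelled in the comments: that the int passed to env_factory is something like a seed or index, and that config values are read as attributes (this README does not show how ProfileConfig exposes them).

```python
import random
from types import SimpleNamespace


class ToyEnv:
    """Hypothetical stand-in with Gymnasium-style reset/step (not a real gym.Env)."""

    def __init__(self, _unused: int):
        self._t = 0

    def reset(self):
        self._t = 0
        return 0, {}

    def step(self, action):
        self._t += 1
        reward = float(action)            # action 1 pays 1.0, action 0 pays nothing
        truncated = self._t >= 10         # every episode lasts exactly 10 steps
        return 0, reward, False, truncated, {}


class RandomAgent:
    """Hypothetical stand-in for a crete.Agent subclass; only get_action is sketched."""

    def __init__(self, seed: int = 0):
        self._rng = random.Random(seed)

    def get_action(self, obs, extra_state=None):
        return self._rng.randint(0, 1), extra_state


def toy_training_wrapper(env_factory, agent, config, artifacts, save_callback):
    env = env_factory(0)                      # assumption: the int is a seed or index
    artifacts["episode_rewards"] = []         # artifacts starts out as an empty dict
    best = float("-inf")
    episode_reward = 0.0
    obs, _ = env.reset()
    for step in range(config.total_steps):    # assumption: attribute access on config
        action, _ = agent.get_action(obs)
        obs, reward, terminated, truncated, _ = env.step(action)
        episode_reward += reward
        if terminated or truncated:
            artifacts["episode_rewards"].append(episode_reward)
            if episode_reward > best:         # snapshot the agent at peak reward
                best = episode_reward
                save_callback(agent, artifacts, step, "best-reward")
            episode_reward = 0.0
            obs, _ = env.reset()


saves = []


def fake_save_callback(agent, artifacts, step_iteration, name):
    # Records the call instead of writing anything to disk.
    saves.append((step_iteration, name))


artifacts = {}
toy_training_wrapper(
    env_factory=lambda n: ToyEnv(n),
    agent=RandomAgent(),
    config=SimpleNamespace(total_steps=50),
    artifacts=artifacts,
    save_callback=fake_save_callback,
)
```

Passing a factory rather than an environment means the wrapper could call env_factory again to build a separate environment for evaluation, as noted in the list above.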

Agents are trained with the train or batch command, which train one or several 'profiles' respectively. Profiles are defined in a .yaml file, which may look like the following:

---
# crete_profile.yaml

defaults: # Contains default `config` values that can be overridden. 
  batch_size: 32
  init_epsilon: 1
  final_epsilon: 0.1
  eval_freq: 1000
  gather_freq: 50

final_dqn_map_0: # Arbitrary profile ID, can be anything.
  agent_id: DQN # The ID of the registered agent to use.
  env_id: CartPole-v1 # The ID of the environment to use.
  env_args: # Arbitrary environment arguments that will be passed to the constructor
    render_mode: human
  config: # Config values that can be accessed through the ProfileConfig object via the training wrapper.
    total_steps: 40000
    decay_steps: 38000
    replay_buffer_size: 10000
    batch_size: 128
    refresh_target_network_freq: 1000
    hidden_layers:
      - 64
      - 64
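
The relationship between defaults and a profile's config can be sketched in plain Python. This is only an illustration of "defaults that can be overridden", assuming a simple key-by-key override; resolve_config is a hypothetical helper, not a Crete function, and Crete's actual merge logic may differ. The dictionary mirrors part of the YAML above.

```python
profiles = {
    "defaults": {
        "batch_size": 32,
        "init_epsilon": 1,
        "final_epsilon": 0.1,
        "eval_freq": 1000,
        "gather_freq": 50,
    },
    "final_dqn_map_0": {
        "agent_id": "DQN",
        "env_id": "CartPole-v1",
        "config": {
            "total_steps": 40000,
            "decay_steps": 38000,
            "batch_size": 128,
        },
    },
}


def resolve_config(profiles: dict, profile_id: str) -> dict:
    """Hypothetical helper: start from defaults, then let the profile's config win."""
    merged = dict(profiles.get("defaults", {}))
    merged.update(profiles[profile_id].get("config", {}))
    return merged


config = resolve_config(profiles, "final_dqn_map_0")
```

Under this assumption, batch_size resolves to the profile's 128 rather than the default 32, while eval_freq and gather_freq fall through from defaults.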

Start training with:

python -m crete train crete_profile.yaml final_dqn_map_0  # train a specific profile
python -m crete batch crete_profile.yaml                  # train all profiles within a file

Saving Agents and their Performances

Crete can save the entire training process and its outcomes for replaying, evaluating and comparing. All that is required is for the agent to implement the save and load methods, which serialise the agent's weight data to bytes and restore it again. Everything else is handled by Crete. One would usually use pickle to achieve this.

For example, with a PyTorch-based agent:

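import pickle
from operator import itemgetter


# Both functions below are methods of the Agent subclass.
def save(self) -> bytes: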
    data = {
        "data": self.net.state_dict(),
        "layers": self.net.hidden_layers
    }

    return pickle.dumps(data)


def load(self, agent_data: bytes):
    agent_dict = pickle.loads(agent_data)
    data, layers = itemgetter("data", "layers")(agent_dict)

    self.net.set_hidden_layers(layers)
    self.net.load_state_dict(data)
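
The save/load contract can be checked with a quick round trip. The sketch below swaps PyTorch for a plain dictionary of floats so it runs with the standard library alone; TinyAgent is a hypothetical stand-in and does not inherit from crete.Agent.

```python
import pickle
from operator import itemgetter


class TinyAgent:
    """Hypothetical stand-in agent; a dict of float lists plays the role of network weights."""

    def __init__(self, hidden_layers):
        self.hidden_layers = list(hidden_layers)
        self.weights = {
            f"layer_{i}": [0.0] * size for i, size in enumerate(self.hidden_layers)
        }

    def save(self) -> bytes:
        # Same payload layout as the PyTorch example: weights plus layer sizes.
        return pickle.dumps({"data": self.weights, "layers": self.hidden_layers})

    def load(self, agent_data: bytes) -> None:
        data, layers = itemgetter("data", "layers")(pickle.loads(agent_data))
        self.hidden_layers = layers
        self.weights = data


trained = TinyAgent([4, 2])
trained.weights["layer_0"][0] = 0.5     # pretend training changed a weight
blob = trained.save()

restored = TinyAgent([1])               # built with a different architecture
restored.load(blob)                     # load restores both layers and weights
```

Storing the layer sizes next to the weights, as the PyTorch example does, is what lets load rebuild an agent whose architecture differs from the one it was constructed with.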

Crete automatically saves the agent on completion of the training wrapper, or on keyboard interrupt.
