A toolbox for learning biclustering and triclustering tasks using Ray's RLlib and PyTorch

Project description

NclustRL

NclustRL is a toolkit that implements some functionalities to help train agents for n-clustering tasks. It works with Ray's RLlib to train DRL agents.

Ray is a general-purpose framework for distributed computing. It includes Tune, a well-known library for hyperparameter tuning, and RLlib, a DRL framework that supports distributed training and extensive customization.
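To illustrate the kind of search Tune automates (a pure-stdlib toy sketch with a made-up objective function, not Tune's actual API), a hyperparameter search boils down to sampling configurations and keeping the best-scoring one:

```python
import random

def toy_objective(lr: float) -> float:
    """Stand-in for one training run: higher is better, peaks near lr = 0.01."""
    return -abs(lr - 0.01)

def random_search(n_trials: int, seed: int = 0):
    """Sample learning rates log-uniformly and keep the best one.

    This is the loop Tune runs for you, in parallel and with smarter
    schedulers such as Population Based Training."""
    rng = random.Random(seed)
    best_lr, best_score = None, float("-inf")
    for _ in range(n_trials):
        lr = 10 ** rng.uniform(-4, -1)  # log-uniform in [1e-4, 1e-1]
        score = toy_objective(lr)
        if score > best_score:
            best_lr, best_score = lr, score
    return best_lr, best_score

best_lr, best_score = random_search(50)
```

Tune replaces both the sampling strategy and the bookkeeping, and additionally distributes the trials across workers.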

NclustRL implements a trainer API for n-clustering that handles all training tasks for the user, a set of default models and metrics, and other helpful functions. It also provides a set of default configurations for n-clustering tasks, available in "configs".

[Diagram exemplifying NclustEnv's architecture]

The trainer API aims to provide a simple way of training and testing DRL agents for n-clustering tasks. This class handles all of RLlib's logic and exposes only user-friendly methods.

Once initialized, the trainer exposes four primary methods:

  • Train: The primary training method. It receives the training parameters to pass on to Tune, starts the training process, manages multiple samples of the same trial, and parses the results, returning the best performance obtained;
  • Load: Imports an agent from a checkpoint for testing;
  • Test: Evaluates accuracy and mean reward, returning the mean and standard deviation of each metric across n episodes;
  • Test Dataset: Evaluates performance in the same way as Test, but receives as input a specific dataset from which episodes should be sampled.

Installation

This tool can be installed from PyPI:

pip install nclustRL

Getting started

Here are the basics; for more information, check the experiments available in "Exp".

## Train basic agent

from nclustRL.trainer import Trainer
from nclustRL.configs.default_configs import PPO_PBT, DEFAULT_CONFIG
from ray.rllib.agents.ppo import PPOTrainer
from nclustenv.configs import biclustering

# Initialize Trainer

config = DEFAULT_CONFIG.copy()
config['env_config'] = biclustering.binary.basic_v2

trainer = Trainer(
    trainer=PPOTrainer,
    env='BiclusterEnv-v0',
    save_dir='nclustRL/Exp/test',
    name='test',
    config=config
)

## Tune agent

best_checkpoint = trainer.train(
    num_samples=8, 
    scheduler=PPO_PBT,
    stop_iters=500,
)
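Behind the call above, train launches `num_samples` trials and returns the best checkpoint found. The selection step amounts to something like this (a standalone sketch over dummy `(checkpoint, mean_reward)` pairs, not Tune's actual analysis objects):

```python
def best_trial(results):
    """Pick the trial with the highest mean reward.

    Each result is a (checkpoint_path, mean_reward) pair, standing in
    for the per-trial results Tune collects."""
    return max(results, key=lambda r: r[1])

# Dummy results from three trials of the same configuration:
trials = [('ckpt_0', 4.2), ('ckpt_1', 7.9), ('ckpt_2', 6.1)]
checkpoint, reward = best_trial(trials)
```

Running several samples of the same trial guards against a single unlucky initialization dominating the reported result.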

Model

By default, this tool implements a model for the hybrid proximal policy optimization algorithm, available in "models". This model can be customized, or other models may be implemented and passed in the configs.

License

GPLv3

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

nclustRL-1.0.0.tar.gz (19.1 kB view details)

Uploaded Source

Built Distribution

nclustRL-1.0.0-py3-none-any.whl (33.2 kB view details)

Uploaded Python 3

File details

Details for the file nclustRL-1.0.0.tar.gz.

File metadata

  • Download URL: nclustRL-1.0.0.tar.gz
  • Upload date:
  • Size: 19.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.8.10

File hashes

Hashes for nclustRL-1.0.0.tar.gz
Algorithm Hash digest
SHA256 3ed70bb1ca4101926790058a51f3cbcc845c2101e0f40244233e822f1b8d4572
MD5 aff4d3be3e8a3e451b13ac58210bcaac
BLAKE2b-256 8af13b85960b05faddff58584ebfc03ea0e0cce3197f07fd68fc9ec393fe6e80

See more details on using hashes here.

File details

Details for the file nclustRL-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: nclustRL-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 33.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.8.10

File hashes

Hashes for nclustRL-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 1aeaabb8355b99ef696686229f0a9d5fcdcb10d352fa2ce4a75ee5bcdc419f82
MD5 d579aadaacf34b1e97d81ef91e6ae035
BLAKE2b-256 7a9e331dc22947292d63f44b47e7775babca68db3e7abf20dd3669b2f5ac621c

