Skip to main content

No project description provided

Project description

Torch Lure

Chandelure

Installation

pip install torchlure

Usage

import torchlure as lure

# Optimizers
# `model` is a placeholder for your torch.nn.Module; like other torch
# optimizers, SophiaG takes the parameters to optimize as its first argument.
lure.SophiaG(model.parameters(), lr=1e-3, weight_decay=0.2)

# Functions
lure.tanh_exp(x)
lure.TanhExp()

lure.quantile_loss(y_pred, y_target, quantile=0.5)
lure.QuantileLoss(quantile=0.5)

lure.RMSNorm(dim=256, eps=1e-6)

# Noise Scheduler
lure.LinearNoiseScheduler(beta=1e-4, beta_end=0.02, num_timesteps=1000)
lure.CosineNoiseScheduler(max_beta=0.999, s=0.008, num_timesteps=1000)

Dataset

import gymnasium as gym
import numpy as np
import torch
from torchlure.datasets import MinariEpisodeDataset, MinariTrajectoryDataset
from torchtyping import TensorType

def return_to_go(rewards: TensorType[..., "T"], gamma: float) -> TensorType[..., "T"]:
    """Compute the discounted return-to-go along the last dimension.

    For each position ``t`` on the last axis the result is
    ``sum_{k >= t} gamma ** (k - t) * rewards[..., k]``. All leading
    dimensions are treated as independent batch dimensions.

    Args:
        rewards: Tensor of rewards whose last dimension is time.
        gamma: Discount factor.

    Returns:
        Tensor with the same shape as ``rewards``. When ``gamma != 1.0``,
        integer/bool rewards are promoted to a floating dtype so the
        discounted sums are not truncated.
    """
    if gamma == 1.0:
        # Undiscounted: a reversed cumulative sum yields every suffix sum at once.
        return rewards.flip(-1).cumsum(-1).flip(-1)

    # Promote to the dtype that ``rewards * gamma`` would have; otherwise an
    # integer output buffer would truncate each discounted float sum.
    rewards = rewards.to(torch.result_type(rewards, gamma))

    rtgs = torch.zeros_like(rewards)
    # One time step's worth of shape, built from ``shape[:-1]`` instead of
    # ``rewards[..., 0]`` so an empty time axis does not raise IndexError.
    rtg = rewards.new_zeros(rewards.shape[:-1])

    # Sweep backwards in time: G_t = r_t + gamma * G_{t+1}.
    for i in reversed(range(rewards.shape[-1])):
        rtg = rewards[..., i] + gamma * rtg
        rtgs[..., i] = rtg

    return rtgs


env = gym.make("Hopper-v4")
minari_dataset = MinariEpisodeDataset("Hopper-random-v0")
minari_dataset.create(env, n_episodes=100, exist_ok=True)
minari_dataset.info()
# Observation space: Box(-inf, inf, (11,), float64)
# Action space: Box(-1.0, 1.0, (3,), float32)
# Total episodes: 100
# Total steps: 2,182

# Build a dataset of length-20 trajectory windows and attach an extra
# "returns" field computed per episode with return_to_go.
# A positional argument may not follow a keyword argument, so traj_len is
# passed positionally to keep the mapping as the third argument.
# NOTE(review): assumes the signature is (dataset, traj_len, <field mapping>);
# confirm against MinariTrajectoryDataset.
traj_dataset = MinariTrajectoryDataset(minari_dataset, 20, {
    "returns": lambda ep: return_to_go(torch.tensor(ep.rewards), 0.99),
})

# Indexing accepts an int, a list, a NumPy array, a tensor of indices, or a slice.
traj = traj_dataset[2]
traj = traj_dataset[[3, 8, 15]]
traj = traj_dataset[np.arange(16)]
traj = traj_dataset[torch.arange(16)]
traj = traj_dataset[-16:]
traj["observations"].shape, traj["actions"].shape, traj["rewards"].shape, traj[
    "terminated"
].shape, traj["truncated"].shape, traj["timesteps"].shape
# Expected shapes for 16 windows of 20 steps (observation dim 11 and action
# dim 3, per the Hopper spaces printed by info() above):
# (torch.Size([16, 20, 11]),
#  torch.Size([16, 20, 3]),
#  torch.Size([16, 20]),
#  torch.Size([16, 20]),
#  torch.Size([16, 20]),
#  torch.Size([16, 20]))

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchlure-0.2407.11.tar.gz (19.1 kB view details)

Uploaded Source

Built Distribution

torchlure-0.2407.11-py3-none-any.whl (19.0 kB view details)

Uploaded Python 3

File details

Details for the file torchlure-0.2407.11.tar.gz.

File metadata

  • Download URL: torchlure-0.2407.11.tar.gz
  • Upload date:
  • Size: 19.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.12.2

File hashes

Hashes for torchlure-0.2407.11.tar.gz
Algorithm Hash digest
SHA256 b77e7c02a6737d8f7dbaa14e7f2ab4bd7c3dfd870dd688da2155e48c3baf5546
MD5 57de66da78ae9c327c3149ba676ed098
BLAKE2b-256 d01785c77a6b8940393cb91feb5fe5179b7a4bc8796ac1214ea2651dd76c9c23

See more details on using hashes here.

File details

Details for the file torchlure-0.2407.11-py3-none-any.whl.

File metadata

File hashes

Hashes for torchlure-0.2407.11-py3-none-any.whl
Algorithm Hash digest
SHA256 d178c4b2d31e3acc822924afcf4aa60c4ecda1a06f9441bb4dbeff5cff35bdc5
MD5 03888019bcf260a267558d3b52a1c3c9
BLAKE2b-256 801e3b9b1160b8e0511876987565277fbe91f0b59e435fd01678b5f731815e4a

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page