Skip to main content

No project description provided

Project description

Torch Lure

Chandelure

Installation

pip install torchlure

Usage

import torchlure as lure

# `model`, `x`, `y_pred` and `y_target` below are placeholders for your own
# nn.Module and tensors.

# Optimizers
# NOTE(review): assumes SophiaG follows the torch.optim convention of taking
# the parameters to optimize as its first argument — confirm.
lure.SophiaG(model.parameters(), lr=1e-3, weight_decay=0.2)

# Functions
lure.tanh_exp(x)
lure.TanhExp()

lure.quantile_loss(y_pred, y_target, quantile=0.5)
lure.QuantileLoss(quantile=0.5)

lure.RMSNorm(dim=256, eps=1e-6)

# Noise Scheduler
lure.LinearNoiseScheduler(beta=1e-4, beta_end=0.02, num_timesteps=1000)
lure.CosineNoiseScheduler(max_beta=0.999, s=0.008, num_timesteps=1000)

Dataset

import gymnasium as gym
import numpy as np
import torch
from torchlure.datasets import MinariEpisodeDataset, MinariTrajectoryDataset
from torchtyping import TensorType

def return_to_go(rewards: "TensorType[..., 'T']", gamma: float) -> "TensorType[..., 'T']":
    """Compute the discounted return-to-go along the last dimension.

    For every time index ``t`` the result holds
    ``sum_{k >= t} gamma ** (k - t) * rewards[..., k]``.

    Annotations are written as strings so they are not evaluated when the
    function is defined; torchtyping is then only needed by type checkers.

    Args:
        rewards: Tensor whose last dimension is time. Any leading dimensions
            are treated independently.
        gamma: Discount factor applied once per time step.

    Returns:
        A tensor with the same shape as ``rewards``. When ``gamma == 1.0`` the
        dtype of ``rewards`` is preserved. Otherwise, non-floating (integer/bool)
        rewards are promoted to ``torch.get_default_dtype()`` so discounted
        values are not truncated.
    """
    if gamma == 1.0:
        # Undiscounted case: a reversed cumulative sum, fully vectorized.
        return rewards.flip(-1).cumsum(-1).flip(-1)

    # Discounting produces fractional values; an integer output buffer would
    # silently truncate them on assignment.
    if not (rewards.is_floating_point() or rewards.is_complex()):
        rewards = rewards.to(torch.get_default_dtype())

    seq_len = rewards.shape[-1]
    rtgs = torch.zeros_like(rewards)
    if seq_len == 0:
        # Empty trajectory: nothing to accumulate, and rewards[..., 0] would
        # raise IndexError.
        return rtgs

    # Running return, carried backwards in time (same shape as one time slice).
    rtg = torch.zeros_like(rewards[..., 0])
    for i in reversed(range(seq_len)):
        rtg = rewards[..., i] + gamma * rtg
        rtgs[..., i] = rtg

    return rtgs


# Build (or reuse, via exist_ok=True) an episode dataset from the Hopper env.
env = gym.make("Hopper-v4")
minari_dataset = MinariEpisodeDataset("Hopper-random-v0")
minari_dataset.create(env, n_episodes=100, exist_ok=True)
minari_dataset.info()
# Observation space: Box(-inf, inf, (11,), float64)
# Action space: Box(-1.0, 1.0, (3,), float32)
# Total episodes: 100
# Total steps: 2,182

# Slice episodes into fixed-length windows of 20 steps. The dict adds an extra
# "returns" field computed per episode as the discounted return-to-go.
# Python forbids a positional argument after a keyword argument, so both
# traj_len and the dict are passed positionally here.
# NOTE(review): assumes the constructor's positional order is
# (dataset, traj_len, <extra-fields dict>) — confirm against MinariTrajectoryDataset.
traj_dataset = MinariTrajectoryDataset(
    minari_dataset,
    20,  # traj_len
    {
        "returns": lambda ep: return_to_go(torch.tensor(ep.rewards), 0.99),
    },
)

# Indexing with an int, a list, a NumPy array, a tensor, or a slice.
traj = traj_dataset[2]
traj = traj_dataset[[3, 8, 15]]
traj = traj_dataset[np.arange(16)]
traj = traj_dataset[torch.arange(16)]
traj = traj_dataset[-16:]
print(
    traj["observations"].shape,
    traj["actions"].shape,
    traj["rewards"].shape,
    traj["terminated"].shape,
    traj["truncated"].shape,
    traj["timesteps"].shape,
)
# Expected for 16 windows of 20 steps, given the (11,) observation and (3,)
# action spaces reported above (NOTE(review): verify against a real run):
# torch.Size([16, 20, 11])  observations
# torch.Size([16, 20, 3])   actions
# torch.Size([16, 20])      rewards
# torch.Size([16, 20])      terminated
# torch.Size([16, 20])      truncated
# torch.Size([16, 20])      timesteps

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchlure-0.2407.13.tar.gz (19.8 kB view details)

Uploaded Source

Built Distribution

torchlure-0.2407.13-py3-none-any.whl (20.0 kB view details)

Uploaded Python 3

File details

Details for the file torchlure-0.2407.13.tar.gz.

File metadata

  • Download URL: torchlure-0.2407.13.tar.gz
  • Upload date:
  • Size: 19.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.12.2

File hashes

Hashes for torchlure-0.2407.13.tar.gz
Algorithm Hash digest
SHA256 55ca7e6e1a67a332d964aae9a174e54a595a0ad6ce00d6ce4279d3ef9c54fc78
MD5 0bb4ae0e8e91378c40f780b681d68f10
BLAKE2b-256 0e9b87da6340693eb5fe5e2f97622b00ad67ba68ea9a62488690aa73c78e4386

See more details on using hashes here.

File details

Details for the file torchlure-0.2407.13-py3-none-any.whl.

File metadata

File hashes

Hashes for torchlure-0.2407.13-py3-none-any.whl
Algorithm Hash digest
SHA256 3c39c80e886f6b3a9c27cb786c6e99a8c1d5b0b22a68edffb953a4bf63b43a8a
MD5 51cb71e1572568b0305d34bb95c3ad10
BLAKE2b-256 fa2206e4d0d114bc09ead8c651a903578f6dc8b1a3b5e5962dd858ef4c180856

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page