Skip to main content

No project description provided

Project description

Torch Lure

Chandelure

Installations

pip install torchlure

Usage

import torchlure as lure

# Optimizers
lure.SophiaG(lr=1e-3, weight_decay=0.2)

# Functions
lure.tanh_exp(x)
lure.TanhExp()

lure.quantile_loss(y_pred, y_target, quantile=0.5)
lure.QuantileLoss(quantile=0.5)

lure.RMSNrom(dim=256, eps=1e-6)

# Noise Scheduler
lure.LinearNoiseScheduler(beta=1e-4, beta_end=0.02, num_timesteps=1000)
lure.CosineNoiseScheduler(max_beta=0.999, s=0.008, num_timesteps=1000):

Dataset

import gymnasium as gym
import numpy as np
import torch
from torchlure.datasets import MinariEpisodeDataset, MinariTrajectoryDataset
from torchtyping import TensorType

def return_to_go(rewards: TensorType[..., "T"], gamma: float) -> TensorType[..., "T"]:
    if gamma == 1.0:
        return rewards.flip(-1).cumsum(-1).flip(-1)

    seq_len = rewards.shape[-1]
    rtgs = torch.zeros_like(rewards)
    rtg = torch.zeros_like(rewards[..., 0])

    for i in range(seq_len - 1, -1, -1):
        rtg = rewards[..., i] + gamma * rtg
        rtgs[..., i] = rtg

    return rtgs


env = gym.make("Hopper-v4")
minari_dataset = MinariEpisodeDataset("Hopper-random-v0")
minari_dataset.create(env, n_episodes=100, exist_ok=True)
minari_dataset.info()
# Observation space: Box(-inf, inf, (11,), float64)
# Action space: Box(-1.0, 1.0, (3,), float32)
# Total episodes: 100
# Total steps: 2,182

traj_dataset = MinariTrajectoryDataset(minari_dataset, traj_len=20, {
    "returns": lambda ep: return_to_go(torch.tensor(ep.rewards), 0.99),
})

traj = traj_dataset[2]
traj = traj_dataset[[3, 8, 15]]
traj = traj_dataset[np.arange(16)]
traj = traj_dataset[torch.arange(16)]
traj = traj_dataset[-16:]
traj["observations"].shape, traj["actions"].shape, traj["rewards"].shape, traj[
    "terminated"
].shape, traj["truncated"].shape, traj["timesteps"].shape
# (torch.Size([16, 20, 4, 4, 16]),
#  torch.Size([16, 20]),
#  torch.Size([16, 20]),
#  torch.Size([16, 20]),
#  torch.Size([16, 20]),
#  torch.Size([16, 20]))

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchlure-0.2407.12.tar.gz (19.1 kB view details)

Uploaded Source

Built Distribution

torchlure-0.2407.12-py3-none-any.whl (19.0 kB view details)

Uploaded Python 3

File details

Details for the file torchlure-0.2407.12.tar.gz.

File metadata

  • Download URL: torchlure-0.2407.12.tar.gz
  • Upload date:
  • Size: 19.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.12.2

File hashes

Hashes for torchlure-0.2407.12.tar.gz
Algorithm Hash digest
SHA256 76cec3a08dbb95759f32560190fa719ffe0778f4815c85ea2bb73109b839aaba
MD5 5d5bcafbc69d567df8b176c6485e045a
BLAKE2b-256 4c0ab5b098f895057d229e33cf043e4c6e12215ed219098f98798722b449aabb

See more details on using hashes here.

File details

Details for the file torchlure-0.2407.12-py3-none-any.whl.

File metadata

File hashes

Hashes for torchlure-0.2407.12-py3-none-any.whl
Algorithm Hash digest
SHA256 106dd2e75370eed38a62995cf73664d39708cc07b65d184ce4d9edbd0780de6a
MD5 c446ef58c1963f1b888e3df35f534876
BLAKE2b-256 47757bb0fcfe7b080a3b0bc62a09927d0aa087983846be75b15565ffed1ae369

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page