Skip to main content

No project description provided

Project description

Torch Lure

Chandelure

Installations

pip install torchlure

Usage

import torchlure as lure

# Optimizers
lure.SophiaG(lr=1e-3, weight_decay=0.2)

# Functions
lure.tanh_exp(x)
lure.TanhExp()

lure.quantile_loss(y_pred, y_target, quantile=0.5)
lure.QuantileLoss(quantile=0.5)

lure.RMSNrom(dim=256, eps=1e-6)

# Noise Scheduler
lure.LinearNoiseScheduler(beta=1e-4, beta_end=0.02, num_timesteps=1000)
lure.CosineNoiseScheduler(max_beta=0.999, s=0.008, num_timesteps=1000):


lure.ReLUKAN(width=[11, 16, 16, 2], grid=5, k=3)

lure.create_relukan_network(
    input_dim=11,
    output_dim=2,
    hidden_dim=32,
    num_layers=3,
    grid=5,
    k=3,
)
import torchlure as lure

# Optimizers
lure.SophiaG(lr=1e-3, weight_decay=0.2)

# Functions
lure.tanh_exp(x)
lure.TanhExp()

lure.quantile_loss(y_pred, y_target, quantile=0.5)
lure.QuantileLoss(quantile=0.5)

lure.RMSNrom(dim=256, eps=1e-6)

# Noise Scheduler
lure.LinearNoiseScheduler(beta=1e-4, beta_end=0.02, num_timesteps=1000)
lure.CosineNoiseScheduler(max_beta=0.999, s=0.008, num_timesteps=1000):

Dataset

import gymnasium as gym
import numpy as np
import torch
from torchlure.datasets import MinariEpisodeDataset, MinariTrajectoryDataset
from torchtyping import TensorType

def return_to_go(rewards: TensorType[..., "T"], gamma: float) -> TensorType[..., "T"]:
    if gamma == 1.0:
        return rewards.flip(-1).cumsum(-1).flip(-1)

    seq_len = rewards.shape[-1]
    rtgs = torch.zeros_like(rewards)
    rtg = torch.zeros_like(rewards[..., 0])

    for i in range(seq_len - 1, -1, -1):
        rtg = rewards[..., i] + gamma * rtg
        rtgs[..., i] = rtg

    return rtgs


env = gym.make("Hopper-v4")
minari_dataset = MinariEpisodeDataset("Hopper-random-v0")
minari_dataset.create(env, n_episodes=100, exist_ok=True)
minari_dataset.info()
# Observation space: Box(-inf, inf, (11,), float64)
# Action space: Box(-1.0, 1.0, (3,), float32)
# Total episodes: 100
# Total steps: 2,182

traj_dataset = MinariTrajectoryDataset(minari_dataset, traj_len=20, {
    "returns": lambda ep: return_to_go(torch.tensor(ep.rewards), 0.99),
})

traj = traj_dataset[2]
traj = traj_dataset[[3, 8, 15]]
traj = traj_dataset[np.arange(16)]
traj = traj_dataset[torch.arange(16)]
traj = traj_dataset[-16:]
traj["observations"].shape, traj["actions"].shape, traj["rewards"].shape, traj[
    "terminated"
].shape, traj["truncated"].shape, traj["timesteps"].shape
# (torch.Size([16, 20, 4, 4, 16]),
#  torch.Size([16, 20]),
#  torch.Size([16, 20]),
#  torch.Size([16, 20]),
#  torch.Size([16, 20]),
#  torch.Size([16, 20]))

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchlure-0.2408.0.tar.gz (20.8 kB view details)

Uploaded Source

Built Distribution

torchlure-0.2408.0-py3-none-any.whl (20.9 kB view details)

Uploaded Python 3

File details

Details for the file torchlure-0.2408.0.tar.gz.

File metadata

  • Download URL: torchlure-0.2408.0.tar.gz
  • Upload date:
  • Size: 20.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.12.2

File hashes

Hashes for torchlure-0.2408.0.tar.gz
Algorithm Hash digest
SHA256 2aa8ee883d78a440a6bd5cf2042c8c83aa4a5b7611ae98b0c3fc59fcfcb8610a
MD5 2fa5579b7bc7b867b44336bfd38d0c16
BLAKE2b-256 0c919984fa04eb8856a256bf78bf68fac06295bd410b6f5628a84e6c2716e6c3

See more details on using hashes here.

File details

Details for the file torchlure-0.2408.0-py3-none-any.whl.

File metadata

  • Download URL: torchlure-0.2408.0-py3-none-any.whl
  • Upload date:
  • Size: 20.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.12.2

File hashes

Hashes for torchlure-0.2408.0-py3-none-any.whl
Algorithm Hash digest
SHA256 4d3afb06fdad2473a3ea08f01655cd21dc4085d990eaca2ce4ef4e666a5d2ebf
MD5 76680486b7dd9d98554b03b5b37f3b22
BLAKE2b-256 8727efea9901f8d13aa71b8a21de574d6d02b7ef1253ec46a1e1e364db61a8d9

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page