Skip to main content

No project description provided

Project description

Torch Lure

Chandelure

Installations

pip install torchlure

Usage

import torchlure as lure

# Optimizers
lure.SophiaG(lr=1e-3, weight_decay=0.2)

# Functions
lure.tanh_exp(x)
lure.TanhExp()

lure.quantile_loss(y_pred, y_target, quantile=0.5)
lure.QuantileLoss(quantile=0.5)

lure.RMSNrom(dim=256, eps=1e-6)

# Noise Scheduler
lure.LinearNoiseScheduler(beta=1e-4, beta_end=0.02, num_timesteps=1000)
lure.CosineNoiseScheduler(max_beta=0.999, s=0.008, num_timesteps=1000)


lure.ReLUKAN(width=[11, 16, 16, 2], grid=5, k=3)

lure.create_relukan_network(
    input_dim=11,
    output_dim=2,
    hidden_dim=32,
    num_layers=3,
    grid=5,
    k=3,
)
import torchlure as lure

# Optimizers
lure.SophiaG(lr=1e-3, weight_decay=0.2)

# Functions
lure.tanh_exp(x)
lure.TanhExp()

lure.quantile_loss(y_pred, y_target, quantile=0.5)
lure.QuantileLoss(quantile=0.5)

lure.RMSNrom(dim=256, eps=1e-6)

# Noise Scheduler
lure.LinearNoiseScheduler(beta=1e-4, beta_end=0.02, num_timesteps=1000)
lure.CosineNoiseScheduler(max_beta=0.999, s=0.008, num_timesteps=1000)

Dataset

import gymnasium as gym
import numpy as np
import torch
from torchlure.datasets import MinariEpisodeDataset, MinariTrajectoryDataset
from torchtyping import TensorType

def return_to_go(rewards: TensorType[..., "T"], gamma: float) -> TensorType[..., "T"]:
    """Return the discounted return-to-go along the last axis of *rewards*.

    Element ``t`` of the result is
    ``rewards[..., t] + gamma * rewards[..., t + 1] + gamma**2 * ...``,
    accumulated from the end of the last axis back to its start.

    Args:
        rewards: Tensor whose last axis is the time axis; any leading axes
            are processed independently.
        gamma: Discount factor. ``1.0`` takes an undiscounted fast path
            (a reversed cumulative sum).

    Returns:
        A tensor with the same shape as *rewards*. On the discounted path
        the dtype is the promotion of the rewards dtype with *gamma*, so
        integer rewards yield floating-point returns instead of being
        truncated back to integers.
    """
    if gamma == 1.0:
        return rewards.flip(-1).cumsum(-1).flip(-1)

    # Allocate the output in the promoted dtype: writing a float running
    # return into an integer buffer would silently drop the fractional part.
    out_dtype = torch.result_type(rewards, gamma)
    rtgs = torch.zeros_like(rewards, dtype=out_dtype)
    # Build the accumulator from the leading shape instead of indexing
    # rewards[..., 0], which raises when the time axis is empty.
    rtg = rewards.new_zeros(rewards.shape[:-1], dtype=out_dtype)

    # Walk backwards so each step reuses the return of the following step.
    for i in reversed(range(rewards.shape[-1])):
        rtg = rewards[..., i] + gamma * rtg
        rtgs[..., i] = rtg

    return rtgs


# Example: create a Gymnasium environment, build a Minari episode dataset
# named "Hopper-random-v0" from it, then index fixed-length trajectories.
env = gym.make("Hopper-v4")
minari_dataset = MinariEpisodeDataset("Hopper-random-v0")
# presumably collects 100 episodes from `env`, with `exist_ok=True` tolerating
# an already-existing dataset — confirm against the torchlure documentation
minari_dataset.create(env, n_episodes=100, exist_ok=True)
minari_dataset.info()
# Observation space: Box(-inf, inf, (11,), float64)
# Action space: Box(-1.0, 1.0, (3,), float32)
# Total episodes: 100
# Total steps: 2,182

# NOTE(review): as written, the dict below is a positional argument placed
# after the keyword argument `traj_len=20`, which Python rejects as a syntax
# error. It presumably needs to be passed by keyword (or moved before
# `traj_len`) — confirm the parameter name in MinariTrajectoryDataset.
# The "returns" entry maps an episode to its discounted return-to-go
# (gamma = 0.99) computed from `ep.rewards`.
traj_dataset = MinariTrajectoryDataset(minari_dataset, traj_len=20, {
    "returns": lambda ep: return_to_go(torch.tensor(ep.rewards), 0.99),
})

# The dataset is indexed below with an int, a list of ints, a NumPy array,
# a torch tensor and a slice; each assignment overwrites `traj`, so only the
# last one (the slice) feeds the shape inspection that follows.
traj = traj_dataset[2]
traj = traj_dataset[[3, 8, 15]]
traj = traj_dataset[np.arange(16)]
traj = traj_dataset[torch.arange(16)]
traj = traj_dataset[-16:]
traj["observations"].shape, traj["actions"].shape, traj["rewards"].shape, traj[
    "terminated"
].shape, traj["truncated"].shape, traj["timesteps"].shape
# NOTE(review): the observation shape printed below does not match the (11,)
# observation space reported by info() above — it was possibly captured from
# a different environment; verify before relying on it.
# (torch.Size([16, 20, 4, 4, 16]),
#  torch.Size([16, 20]),
#  torch.Size([16, 20]),
#  torch.Size([16, 20]),
#  torch.Size([16, 20]),
#  torch.Size([16, 20]))

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchlure-0.2408.2.tar.gz (67.5 kB view details)

Uploaded Source

Built Distribution

torchlure-0.2408.2-py3-none-any.whl (21.0 kB view details)

Uploaded Python 3

File details

Details for the file torchlure-0.2408.2.tar.gz.

File metadata

  • Download URL: torchlure-0.2408.2.tar.gz
  • Upload date:
  • Size: 67.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.12.2

File hashes

Hashes for torchlure-0.2408.2.tar.gz
Algorithm Hash digest
SHA256 16ce0c0bd3075fbe6c7e821137d9e29ec57c015b56aa6d4de89d3af3913a76c6
MD5 f346e9fe13579362d16b0e6731876143
BLAKE2b-256 c23ec9e6d12ea91d706f110138d15a9bd3a9f5f186c1fc3a5f10997f5f6c3e06

See more details on using hashes here.

File details

Details for the file torchlure-0.2408.2-py3-none-any.whl.

File metadata

  • Download URL: torchlure-0.2408.2-py3-none-any.whl
  • Upload date:
  • Size: 21.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.12.2

File hashes

Hashes for torchlure-0.2408.2-py3-none-any.whl
Algorithm Hash digest
SHA256 93f60de71a53bf740d086242ba36601be383e76b7ee1ce91c3b585cb21b8dba6
MD5 0d28e345947cbda651eb9f444653aa53
BLAKE2b-256 7c8e1bf24422894251fc2adea013ed5483b0e8a01a9b4bf765b732b98faede80

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page