Vectorized Reinforcement Learning Library

Samsara RL

A vectorized NumPy implementation of foundational reinforcement learning algorithms, following David Silver's RL lecture series, the Sutton and Barto book, and other papers cited where they are used. Built for clarity and learning as well as speed.

Applications of RL include robotic manipulation, LLM fine-tuning, financial portfolio management, and control systems.


Table of Contents

  1. Installation
  2. Quick Start
  3. Planning
  4. Model-Free Prediction
  5. Model-Free Control
  6. Function Approximation

Installation

pip install samsara-rl

Quick Start

from samsara_rl.mdp.grid_world.grid_world_mdp import GridWorldMDP
from samsara_rl.planning.policy_iteration import PolicyIteration

mdp = GridWorldMDP()
pi = PolicyIteration(mdp)
policy = pi.find_optimal_policy()

Planning

Planning algorithms assume full knowledge of the environment dynamics (transition probabilities and reward function). This is not reinforcement learning in the strict sense, since an agent rarely has access to the true dynamics in practice, but planning provides the theoretical foundation that the RL algorithms below build on.

MDP Structure

MDPs are represented as NumPy arrays. The included GridWorldMDP implements the 4x4 grid world from David Silver's Lecture 3.

| Attribute | Shape | Description |
| --- | --- | --- |
| state_action_transition_matrix | (S, A, S') | T(s, a, s'): transition probabilities |
| reward_matrix | (S, A, S') | R(s, a, s'): reward for each transition |
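
As an illustration of this layout, here is a minimal sketch of a hand-built MDP stored as two arrays of shape (S, A, S'). The two-state MDP and the variable names `T`, `R`, and `expected_reward` are hypothetical and are not part of the library; the sketch only shows how the shapes fit together.

```python
import numpy as np

# Hypothetical 2-state, 2-action MDP (not the library's GridWorldMDP).
n_states, n_actions = 2, 2

# T[s, a, s'] = probability of landing in s' after taking action a in state s.
T = np.zeros((n_states, n_actions, n_states))
T[0, 0] = [0.9, 0.1]   # action 0 in state 0 mostly stays put
T[0, 1] = [0.2, 0.8]   # action 1 in state 0 mostly moves to state 1
T[1, :] = [0.0, 1.0]   # state 1 is absorbing under both actions

# R[s, a, s'] = reward received on that transition.
R = np.full((n_states, n_actions, n_states), -1.0)
R[1, :, 1] = 0.0       # no cost once inside the absorbing state

# Every (s, a) slice must be a probability distribution over s'.
assert np.allclose(T.sum(axis=2), 1.0)

# Expected immediate reward r(s, a) = sum over s' of T(s, a, s') * R(s, a, s').
expected_reward = (T * R).sum(axis=2)
```

Collapsing the last axis this way gives an (S, A) array, which is the form most Bellman backups work with.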

Policy Iteration

Alternates between evaluating the current policy using the Bellman expectation equation and improving it greedily until the policy stops changing.

PolicyIteration(mdp, bellman_tolerance)

| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| mdp | MDP | | MDP instance to solve |
| bellman_tolerance | float | 0.01 | Convergence threshold for policy evaluation |

find_optimal_policy(max_iter)

| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| max_iter | int | 99 | Maximum number of policy iteration steps |
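
The evaluate-then-improve loop can be sketched in plain NumPy as follows. This is a generic textbook version under stated assumptions, not the library's implementation: the function name `policy_iteration`, the `gamma` argument, the deterministic policy representation, and the toy two-state MDP are all hypothetical.

```python
import numpy as np

def policy_iteration(T, R, gamma=0.9, tol=0.01, max_iter=99):
    """Generic policy iteration on arrays T, R of shape (S, A, S')."""
    n_states = T.shape[0]
    idx = np.arange(n_states)
    r = (T * R).sum(axis=2)                  # expected reward, shape (S, A)
    policy = np.zeros(n_states, dtype=int)   # one action index per state
    v = np.zeros(n_states)
    for _ in range(max_iter):
        # Policy evaluation: repeat the Bellman expectation backup
        # until the largest change in V drops below the tolerance.
        while True:
            v_new = r[idx, policy] + gamma * T[idx, policy] @ v
            delta = np.max(np.abs(v_new - v))
            v = v_new
            if delta < tol:
                break
        # Policy improvement: act greedily with respect to Q(s, a).
        q = r + gamma * T @ v                # shape (S, A)
        new_policy = q.argmax(axis=1)
        if np.array_equal(new_policy, policy):
            break                            # policy is stable
        policy = new_policy
    return policy, v

# Toy MDP: state 1 is absorbing and free, every step from state 0 costs 1.
T = np.zeros((2, 2, 2))
T[0, 0] = [0.9, 0.1]
T[0, 1] = [0.2, 0.8]
T[1, :] = [0.0, 1.0]
R = np.full((2, 2, 2), -1.0)
R[1, :, 1] = 0.0

policy, v = policy_iteration(T, R)
```

On this toy problem the loop settles on action 1 in state 0, the action that reaches the absorbing state fastest.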

Value Iteration

Applies the Bellman optimality equation directly on each iteration. This is equivalent to policy iteration with a single (k=1) evaluation sweep per improvement step. The greedy policy is extracted once, at convergence.

ValueIteration(mdp, bellman_tolerance)

| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| mdp | MDP | | MDP instance to solve |
| bellman_tolerance | float | 0.01 | Convergence threshold for value iteration |
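
A minimal sketch of the Bellman optimality backup, for comparison with policy iteration. As before, this is a generic version and not the library's code: the function name `value_iteration`, the `gamma` argument, and the toy MDP are assumptions made for the example.

```python
import numpy as np

def value_iteration(T, R, gamma=0.9, tol=0.01):
    """Generic value iteration on arrays T, R of shape (S, A, S')."""
    r = (T * R).sum(axis=2)              # expected reward, shape (S, A)
    v = np.zeros(T.shape[0])
    while True:
        # Bellman optimality backup: V(s) <- max_a [ r(s, a) + gamma * E[V(s')] ]
        v_new = (r + gamma * T @ v).max(axis=1)
        delta = np.max(np.abs(v_new - v))
        v = v_new
        if delta < tol:
            break
    # The greedy policy is read off once, after the values have settled.
    policy = (r + gamma * T @ v).argmax(axis=1)
    return policy, v

# Same hypothetical toy MDP: state 1 is absorbing, steps from state 0 cost 1.
T = np.zeros((2, 2, 2))
T[0, 0] = [0.9, 0.1]
T[0, 1] = [0.2, 0.8]
T[1, :] = [0.0, 1.0]
R = np.full((2, 2, 2), -1.0)
R[1, :, 1] = 0.0

policy, v = value_iteration(T, R)
```

No explicit policy is kept during the sweeps; only the value vector is updated.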

Examples

from samsara_rl.mdp.grid_world.grid_world_mdp import GridWorldMDP
from samsara_rl.planning.policy_iteration import PolicyIteration
from samsara_rl.planning.value_iteration import ValueIteration

mdp = GridWorldMDP()

policy = PolicyIteration(mdp, bellman_tolerance=0.001).find_optimal_policy(max_iter=50)
policy = ValueIteration(mdp, bellman_tolerance=0.001).find_optimal_policy()

Model-Free Prediction

Model-free methods learn value functions directly from experience (sampled episodes) without access to the MDP's transition or reward dynamics.

Monte Carlo

Every-visit Monte Carlo prediction estimates Q(s, a) from sampled returns. After each episode, the return G_t (discounted cumulative reward from time step t onward) is computed for every visited state-action pair, and the Q-table is updated using constant-alpha learning:

Q(s, a) ← Q(s, a) + α (G_t − Q(s, a))

If the same (s, a) pair appears multiple times in an episode, each occurrence triggers an update. Returns are computed in a fully vectorized pass using a cumulative-sum trick that avoids the standard reverse loop over time steps.
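
One common way to compute all returns of an episode without a reverse loop is a reversed cumulative sum over pre-discounted rewards. The sketch below shows that idea; whether the library uses exactly this formulation is an assumption, and the names `discounted_returns`, `states`, and `actions` are hypothetical. The division by γ^t requires γ > 0 and can lose precision on very long episodes with small γ.

```python
import numpy as np

def discounted_returns(rewards, gamma):
    """Return G_t = sum_{k >= t} gamma^(k - t) * r_k for every t, with gamma > 0."""
    rewards = np.asarray(rewards, dtype=float)
    discounts = gamma ** np.arange(len(rewards))          # gamma^0, gamma^1, ...
    # Suffix sums of gamma^k * r_k via a reversed cumulative sum.
    suffix = np.cumsum((discounts * rewards)[::-1])[::-1]
    # Dividing by gamma^t re-bases each suffix sum so it starts at step t.
    return suffix / discounts

# A three-step episode that visits (s=0, a=1) twice.
states = np.array([0, 1, 0])
actions = np.array([1, 0, 1])
returns = discounted_returns([-1.0, -1.0, -1.0], gamma=0.5)

# Every-visit, constant-alpha update: each occurrence triggers its own update.
alpha = 0.1
q = np.zeros((2, 2))
for s, a, g in zip(states, actions, returns):
    q[s, a] += alpha * (g - q[s, a])
```

Here `returns` is `[-1.75, -1.5, -1.0]`, and the pair (0, 1) is updated twice because it occurs twice in the episode.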

MonteCarloPrediction(mdp, policy, alpha, gamma)

| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| mdp | MDP | | MDP instance to sample episodes from |
| policy | array | | Stochastic policy of shape (S, A) |
| alpha | float | 0.01 | Learning rate for incremental Q updates |
| gamma | float | 1 | Discount factor |

evaluate(max_iter)

| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| max_iter | int | 10000 | Number of episodes to sample |

TD(λ)

TD(λ) learns Q(s, a) online using one-step bootstrapping with eligibility traces. After each step, the TD error is computed against the expected Q-value of the next state under the current policy, and all previously visited state-action pairs are updated proportionally to their eligibility:

δ = R + γ E_π[Q(S', ·)] − Q(S, A)

Q(s, a) ← Q(s, a) + α δ e(s, a)

Eligibility traces use the replacing variant — on each visit to (s, a), the trace is set to 1 rather than incremented. All traces decay by γλ at each time step. Traces are reset to zero between episodes.
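
A single update step with replacing traces might look like the sketch below. It follows the two equations above, but the function name `td_lambda_step`, its signature, and the zero bootstrap value at terminal states are assumptions for illustration rather than the library's API.

```python
import numpy as np

def td_lambda_step(q, traces, policy, s, a, r, s_next, done,
                   alpha=0.01, gamma=0.9, lam=0.4):
    """Apply one TD(lambda) update in place and return the TD error."""
    # Expected Q-value of the next state under the policy (0 at episode end).
    expected_next = 0.0 if done else policy[s_next] @ q[s_next]
    delta = r + gamma * expected_next - q[s, a]
    traces[s, a] = 1.0            # replacing trace: set to 1, do not accumulate
    q += alpha * delta * traces   # update every pair in proportion to its trace
    traces *= gamma * lam         # decay all traces for the next time step
    return delta

q = np.zeros((2, 2))
traces = np.zeros((2, 2))         # reset to zero at the start of each episode
policy = np.full((2, 2), 0.5)     # uniform random policy

d1 = td_lambda_step(q, traces, policy, s=0, a=1, r=-1.0, s_next=1, done=False)
d2 = td_lambda_step(q, traces, policy, s=1, a=0, r=-1.0, s_next=1, done=True)
```

After the second step, the earlier pair (0, 1) receives a smaller share of the new TD error because its trace has decayed to γλ = 0.36.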

TemporalDifference(mdp, policy, alpha, gamma, _lambda)

| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| mdp | MDP | | MDP instance to sample episodes from |
| policy | array | | Stochastic policy of shape (S, A) |
| alpha | float | 0.01 | Learning rate for incremental Q updates |
| gamma | float | 0.9 | Discount factor |
| _lambda | float | 0.4 | Trace decay parameter (0 = TD(0), 1 = TD(1)) |

evaluate(max_iter)

| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| max_iter | int | 1000 | Number of episodes to sample |

Examples

from samsara_rl.mdp.grid_world.grid_world_mdp import GridWorldMDP
from samsara_rl.prediction.monte_carlo import MonteCarloPrediction
from samsara_rl.prediction.td import TemporalDifference
from samsara_rl.utils.policy.policy_utils import init_uniform_random

mdp = GridWorldMDP()
policy = init_uniform_random(mdp)

mc = MonteCarloPrediction(mdp, policy, alpha=0.01, gamma=0.9)
mc.evaluate(max_iter=10000)

td = TemporalDifference(mdp, policy, alpha=0.01, gamma=0.9, _lambda=0.4)
td.evaluate(max_iter=10000)

# V(s) for the random policy (expected value over actions)
v_mc = mc.q_table.mean(axis=1).reshape(4, 4)
v_td = td.q_table.mean(axis=1).reshape(4, 4)

Model-Free Control

Control algorithms learn an optimal policy by interleaving evaluation and improvement on every step. Both SARSA and Q-Learning build on the TD(λ) engine, using ε-greedy exploration to balance exploitation with discovery of new state-action pairs.
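
For reference, an ε-greedy policy can be derived from a Q-table as in the sketch below. The function name `epsilon_greedy_policy` and the `epsilon` parameter are hypothetical; the documented constructors do not expose an exploration rate, so how the library sets or schedules ε is not shown here.

```python
import numpy as np

def epsilon_greedy_policy(q, epsilon=0.1):
    """Build a stochastic policy of shape (S, A) from a Q-table of shape (S, A)."""
    n_states, n_actions = q.shape
    # Every action gets the exploration share epsilon / A ...
    policy = np.full((n_states, n_actions), epsilon / n_actions)
    # ... and the greedy action additionally gets the remaining 1 - epsilon.
    policy[np.arange(n_states), q.argmax(axis=1)] += 1.0 - epsilon
    return policy

q = np.array([[0.0, 1.0],
              [2.0, 0.0]])
policy = epsilon_greedy_policy(q, epsilon=0.2)
```

Each row of the result sums to 1, so it has the same (S, A) stochastic-policy shape that the constructors in this section accept.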

SARSA

On-policy TD control. Bootstraps from a sampled next action A' drawn from the current policy — the name comes from the quintuple (S, A, R, S', A'). Because the bootstrap target reflects the exploratory policy, SARSA's Q values account for the cost of occasional random actions.

δ = R + γ Q(S', A') − Q(S, A)

Sarsa(mdp, policy, alpha, gamma)

| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| mdp | MDP | | MDP instance to sample episodes from |
| policy | array | | Initial stochastic policy of shape (S, A) |
| alpha | float | 0.01 | Learning rate for incremental Q updates |
| gamma | float | 0.9 | Discount factor |

evaluate(max_iter)

| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| max_iter | int | 5000 | Number of episodes to run |

Q-Learning

Off-policy TD control. Bootstraps from the greedy value max_a Q(S', a), regardless of the action actually taken next. As a result, Q-Learning can converge to the optimal Q* even while following an exploratory ε-greedy policy, provided every state-action pair keeps being visited and the learning rate is suitably small.

δ = R + γ max_a Q(S', a) − Q(S, A)

QLearning(mdp, policy, alpha, gamma)

| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| mdp | MDP | | MDP instance to sample episodes from |
| policy | array | | Initial stochastic policy of shape (S, A) |
| alpha | float | 0.01 | Learning rate for incremental Q updates |
| gamma | float | 0.9 | Discount factor |

evaluate(max_iter)

| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| max_iter | int | 5000 | Number of episodes to run |
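
The only difference between the SARSA and Q-Learning updates is the bootstrap target. The sketch below computes both TD errors for the same transition; the helper names `sarsa_delta` and `q_learning_delta` and the toy Q-table are hypothetical and only illustrate the two formulas given above.

```python
import numpy as np

def sarsa_delta(q, s, a, r, s_next, a_next, gamma=0.9):
    """On-policy TD error: bootstrap from the action actually sampled next."""
    return r + gamma * q[s_next, a_next] - q[s, a]

def q_learning_delta(q, s, a, r, s_next, gamma=0.9):
    """Off-policy TD error: bootstrap from the best action in the next state."""
    return r + gamma * q[s_next].max() - q[s, a]

q = np.array([[0.0, 0.0],
              [1.0, 3.0]])

# Same transition (s=0, a=0, r=-1, s'=1). The sampled next action is the
# exploratory a'=0, which is not the greedy one.
d_sarsa = sarsa_delta(q, s=0, a=0, r=-1.0, s_next=1, a_next=0)
d_ql = q_learning_delta(q, s=0, a=0, r=-1.0, s_next=1)
```

With an exploratory next action, SARSA sees a target of -1 + 0.9 · 1 while Q-Learning sees -1 + 0.9 · 3, which is why SARSA's estimates reflect the cost of exploration and Q-Learning's do not.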

Examples

from samsara_rl.mdp.grid_world.grid_world_mdp import GridWorldMDP
from samsara_rl.control.tabular.sarsa import Sarsa
from samsara_rl.control.tabular.q_learning import QLearning
from samsara_rl.utils.policy.policy_utils import init_uniform_random

mdp = GridWorldMDP()
policy = init_uniform_random(mdp)

sarsa = Sarsa(mdp, policy, alpha=0.01, gamma=0.9)
sarsa.evaluate(max_iter=5000)

ql = QLearning(mdp, policy, alpha=0.01, gamma=0.9)
ql.evaluate(max_iter=5000)

# Greedy state values: max over actions of each learned Q-table
v_sarsa = sarsa.agent.q_table.max(axis=1).reshape(4, 4)
v_ql = ql.agent.q_table.max(axis=1).reshape(4, 4)

Function Approximation

TODO
