SwiftTD: A Fast and Robust Algorithm for Temporal Difference Learning

SwiftTD is an algorithm for learning value functions. It combines step-size adaptation with a bound on the rate of learning. The implementations in this repository use linear function approximation.
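For background, a plain TD(lambda) update with linear function approximation looks roughly like the sketch below. This is illustrative only, not the library's implementation: the fixed scalar step size `alpha` here is precisely what SwiftTD replaces with adapted, bounded step sizes.

```python
def td_lambda_step(w, z, x, x_next, reward, alpha=0.1, gamma=0.99, lam=0.95):
    """One TD(lambda) update for a linear value function w . x.

    Returns the updated weights and eligibility trace.
    """
    # TD error: reward plus discounted next-state value minus current value.
    delta = (reward
             + gamma * sum(wi * xi for wi, xi in zip(w, x_next))
             - sum(wi * xi for wi, xi in zip(w, x)))
    # Accumulating eligibility trace.
    z = [gamma * lam * zi + xi for zi, xi in zip(z, x)]
    # Weight update with a fixed step size (the part SwiftTD adapts and bounds).
    w = [wi + alpha * delta * zi for wi, zi in zip(w, z)]
    return w, z

# Toy check: one always-on feature, constant reward, gamma = 0,
# so the true value is 1.0 and w[0] should converge to it.
w, z = [0.0], [0.0]
for _ in range(200):
    w, z = td_lambda_step(w, z, [1.0], [1.0], reward=1.0, alpha=0.1, gamma=0.0)
print(round(w[0], 3))  # → 1.0
```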

Installation

pip install SwiftTD

Usage

After installation, the three implementations of SwiftTD can be used in Python as follows:

import swifttd

# Version of SwiftTD that expects the full feature vector as input.
# Use it only when the feature representation is not sparse; otherwise
# the sparse versions are more efficient.
td_dense = swifttd.SwiftTDNonSparse(
    num_of_features=5,     # Number of input features
    lambda_=0.95,          # Lambda parameter for eligibility traces
    alpha=1e-2,            # Initial learning rate
    gamma=0.99,            # Discount factor
    epsilon=1e-5,          # Small constant for numerical stability
    eta=0.1,               # Maximum allowed step size (bound on the rate of learning)
    decay=0.999,           # Step-size decay rate
    meta_step_size=1e-3,   # Meta learning rate
    eta_min=1e-10,         # Minimum value of the step-size parameter
)

# Feature vector
features = [1.0, 0.0, 0.5, 0.2, 0.0] 
reward = 1.0
prediction = td_dense.step(features, reward)
print("Dense prediction:", prediction)

# Version of SwiftTD that expects the feature indices as input. This
# version assumes that the features are binary (0 or 1); for learning,
# the indices of the features that are 1 are provided.
td_sparse = swifttd.SwiftTDBinaryFeatures(
    num_of_features=1000,  # Number of input features
    lambda_=0.95,          # Lambda parameter for eligibility traces
    alpha=1e-2,            # Initial learning rate
    gamma=0.99,            # Discount factor
    epsilon=1e-5,          # Small constant for numerical stability
    eta=0.1,               # Maximum allowed step size (bound on the rate of learning)
    decay=0.999,           # Step-size decay rate
    meta_step_size=1e-3,   # Meta learning rate
    eta_min=1e-10,         # Minimum value of the step-size parameter
)

# Specify the indices of the features that are 1.
active_features = [1, 42, 999]  # Indices of active features
reward = 1.0
prediction = td_sparse.step(active_features, reward)
print("Sparse binary prediction:", prediction)
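If your binary features live in a dense vector, the list of active indices can be extracted with a small helper. This is a convenience function of our own, not part of the swifttd API:

```python
def active_indices(x):
    """Return the indices of the features that are 1 in a dense binary vector."""
    return [i for i, v in enumerate(x) if v == 1.0]

print(active_indices([1.0, 0.0, 0.0, 1.0, 1.0]))  # → [0, 3, 4]
```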

# Version of SwiftTD that expects the feature indices and values as
# input. This version does not assume that the features are binary; for
# learning, it expects a list of (index, value) pairs, and only the
# non-zero features need to be provided.
td_sparse_nonbinary = swifttd.SwiftTD(
    num_of_features=1000,  # Number of input features
    lambda_=0.95,          # Lambda parameter for eligibility traces
    alpha=1e-2,            # Initial learning rate
    gamma=0.99,            # Discount factor
    epsilon=1e-5,          # Small constant for numerical stability
    eta=0.1,               # Maximum allowed step size (bound on the rate of learning)
    decay=0.999,           # Step-size decay rate
    meta_step_size=1e-3,   # Meta learning rate
    eta_min=1e-10,         # Minimum value of the step-size parameter
)

# Specify the indices and values of the features that are non-zero.
feature_values = [(1, 0.8), (42, 0.3), (999, 1.2)]  # (index, value) pairs
reward = 1.0
prediction = td_sparse_nonbinary.step(feature_values, reward)
print("Sparse non-binary prediction:", prediction)
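Likewise, (index, value) pairs can be extracted from a dense vector with a one-liner (again a convenience helper of our own, not part of the swifttd API):

```python
def sparse_pairs(x):
    """Return (index, value) pairs for the non-zero entries of a dense vector."""
    return [(i, v) for i, v in enumerate(x) if v != 0.0]

print(sparse_pairs([0.8, 0.0, 0.3, 0.0, 1.2]))  # → [(0, 0.8), (2, 0.3), (4, 1.2)]
```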

