
Accelerated replay buffers in JAX.

Project description

dejax

An implementation of replay buffer data structures in JAX. Operations on dejax replay buffers can be jitted and run on both CPU and GPU.

Package contents

  • dejax.circular_buffer — an implementation of a circular buffer data structure that can store pytrees of arbitrary structure, provided that all stored pytrees share the same structure and tensor shapes.
  • dejax.uniform — a fixed-size FIFO replay buffer that samples stored items uniformly at random.
  • dejax.clustered — a replay buffer that assigns every trajectory to a cluster and maintains a separate replay buffer per cluster. Sampling is performed uniformly over all clusters. This kind of replay buffer is helpful when, for instance, low- and high-reward trajectories need to be replayed at the same rate.
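To make the clustered sampling rule concrete, here is a minimal JAX sketch, not dejax's implementation. It assumes scalar items, a caller-supplied integer cluster id for each item, and one fixed-size circular buffer per cluster stored as a single stacked array. The names `ClusteredState`, `init`, `add` and `sample` are hypothetical. Sampling first picks a cluster uniformly, then an item uniformly within that cluster, so a cluster holding a single item is replayed as often as a full one.

```python
import functools
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ClusteredState(NamedTuple):
    # Hypothetical layout: one circular buffer per cluster, stacked on axis 0.
    data: jax.Array   # (num_clusters, capacity, *item_shape)
    sizes: jax.Array  # (num_clusters,): number of valid items in each cluster
    heads: jax.Array  # (num_clusters,): next write position in each cluster


def init(num_clusters, capacity, item_shape=(), dtype=jnp.float32):
    return ClusteredState(
        data=jnp.zeros((num_clusters, capacity) + tuple(item_shape), dtype),
        sizes=jnp.zeros((num_clusters,), jnp.int32),
        heads=jnp.zeros((num_clusters,), jnp.int32),
    )


@jax.jit
def add(state, item, cluster_id):
    # Write into the chosen cluster's buffer only, overwriting its oldest item when full.
    capacity = state.data.shape[1]
    head = state.heads[cluster_id]
    return ClusteredState(
        data=state.data.at[cluster_id, head].set(item),
        sizes=state.sizes.at[cluster_id].set(
            jnp.minimum(state.sizes[cluster_id] + 1, capacity)),
        heads=state.heads.at[cluster_id].set((head + 1) % capacity),
    )


@functools.partial(jax.jit, static_argnums=2)
def sample(state, rng, batch_size):
    cluster_rng, item_rng = jax.random.split(rng)
    # Step 1: pick a cluster uniformly among non-empty clusters (empty ones get -inf logits).
    logits = jnp.where(state.sizes > 0, 0.0, -jnp.inf)
    clusters = jax.random.categorical(cluster_rng, logits, shape=(batch_size,))
    # Step 2: pick an item uniformly among the valid slots of each chosen cluster.
    idx = jax.random.randint(item_rng, (batch_size,), 0, state.sizes[clusters])
    return state.data[clusters, idx]


# Cluster 0 is full of 1.0s, cluster 1 holds a single 2.0, cluster 2 stays empty.
state = init(num_clusters=3, capacity=8)
for _ in range(8):
    state = add(state, jnp.float32(1.0), 0)
state = add(state, jnp.float32(2.0), 1)
batch = sample(state, jax.random.PRNGKey(0), 2000)
```

Skipping empty clusters is a choice made in this sketch so that sampling never reads unwritten slots; it assumes at least one cluster is non-empty. With the data above, roughly half of the batch should be the single item from cluster 1.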

How to use dejax replay buffers

import dejax
from dejax.uniform import uniform_replay

First, instantiate a buffer object. Buffer objects hold no state themselves; instead, they provide functions that initialize and manipulate a separate state object.
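The stateless design can be illustrated with a hypothetical `make_fifo` factory, which is not part of dejax. The returned object only bundles pure functions, borrowing the `init_fn`/`add_fn` names used in this README, while the buffer contents live in an explicit state tuple that is passed in and returned. Because nothing is mutated, one object can drive several independent states, and `add_fn` can be wrapped in `jax.jit`.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class BufferFns(NamedTuple):
    # Hypothetical container: it holds functions only, never buffer contents.
    init_fn: Callable
    add_fn: Callable


def make_fifo(max_size: int) -> BufferFns:
    """Stateless functions for a FIFO buffer of fixed-shape arrays (illustrative only)."""

    def init_fn(prototype):
        prototype = jnp.asarray(prototype)
        # State: (storage, next write position, number of valid items).
        storage = jnp.zeros((max_size,) + prototype.shape, prototype.dtype)
        return storage, jnp.zeros((), jnp.int32), jnp.zeros((), jnp.int32)

    def add_fn(state, item):
        storage, head, size = state
        storage = storage.at[head].set(item)  # returns a new array; nothing is mutated
        return storage, (head + 1) % max_size, jnp.minimum(size + 1, max_size)

    return BufferFns(init_fn=init_fn, add_fn=jax.jit(add_fn))


# One buffer object, two independent states.
buffer = make_fifo(max_size=3)
state_a = buffer.init_fn(jnp.float32(0.0))
state_b = buffer.init_fn(jnp.float32(0.0))
for x in [1.0, 2.0, 3.0, 4.0]:
    state_a = buffer.add_fn(state_a, x)  # the fourth item overwrites the oldest one
state_b = buffer.add_fn(state_b, 10.0)
```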

buffer = uniform_replay(max_size=10000)

With a buffer object in hand, we can initialize the replay buffer state. This requires a prototype item, which determines the structure of the storage: it must have the same pytree structure and tensor shapes as the items that will later be stored in the buffer.

buffer_state = buffer.init_fn(item_prototype)
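One plausible way to turn a prototype into storage, assumed here rather than taken from dejax, is to preallocate a zero array with an extra leading `max_size` axis for every leaf of the prototype pytree. This is why the prototype's shapes must match the stored items: every slot has a fixed shape and dtype. The helper `allocate_storage` and the field names below are hypothetical.

```python
import jax
import jax.numpy as jnp


def allocate_storage(item_prototype, max_size: int):
    """Hypothetical helper: one preallocated slot per item for every prototype leaf.

    A leaf of shape S and dtype D becomes a zero array of shape (max_size, *S) and
    dtype D, so the storage pytree mirrors the prototype's structure exactly.
    """
    return jax.tree_util.tree_map(
        lambda leaf: jnp.zeros((max_size,) + jnp.shape(leaf), jnp.asarray(leaf).dtype),
        item_prototype,
    )


# A transition-like prototype; the field names are illustrative, not part of dejax.
item_prototype = {
    "obs": jnp.zeros((4,), jnp.float32),
    "action": jnp.zeros((), jnp.int32),
    "reward": jnp.zeros((), jnp.float32),
}
storage = allocate_storage(item_prototype, max_size=10000)
shapes = jax.tree_util.tree_map(lambda x: x.shape, storage)
```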

Now we can fill the buffer:

for item in items:
    buffer_state = buffer.add_fn(buffer_state, make_item(item))
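A Python loop dispatches one call per item. When the items are already stacked along a leading axis, a common JAX alternative is `jax.lax.scan`, which runs all insertions inside one compiled computation. The sketch below assumes a hypothetical state layout `(storage, head, size)` and a hand-written `add`; it is not dejax's `add_fn`. It also shows FIFO overwriting: adding six items to a four-slot buffer replaces the two oldest.

```python
import jax
import jax.numpy as jnp


def add(state, item):
    """Write one pytree item into a fixed-capacity circular buffer (hypothetical layout)."""
    storage, head, size = state
    capacity = jax.tree_util.tree_leaves(storage)[0].shape[0]
    storage = jax.tree_util.tree_map(lambda buf, x: buf.at[head].set(x), storage, item)
    return storage, (head + 1) % capacity, jnp.minimum(size + 1, capacity)


@jax.jit
def add_many(state, items):
    """Add items stacked along axis 0, in order, within a single compiled computation."""

    def step(carry, item):
        return add(carry, item), None

    state, _ = jax.lax.scan(step, state, items)
    return state


capacity = 4
state = (
    {"obs": jnp.zeros((capacity, 2)), "reward": jnp.zeros((capacity,))},
    jnp.zeros((), jnp.int32),  # head
    jnp.zeros((), jnp.int32),  # size
)
items = {"obs": jnp.arange(12.0).reshape(6, 2), "reward": jnp.arange(6.0)}
storage, head, size = add_many(state, items)
```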

And sample from it:

batch = buffer.sample_fn(buffer_state, rng, batch_size)
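Uniform sampling can be sketched as drawing random slot indices with `jax.random.randint` and gathering the same indices from every leaf of the stored pytree, so the fields of one item stay together. The function `sample_uniform` and its `(storage, size)` arguments are assumptions for illustration; `rng` plays the role of a JAX PRNG key as in the call above, and sampling is with replacement.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=3)
def sample_uniform(storage, size, rng, batch_size):
    """Draw batch_size items uniformly, with replacement, from slots [0, size).

    Assumes every leaf of `storage` has a leading slot axis. In a circular buffer,
    valid items occupy [0, size) before wrap-around, and size == capacity after it,
    so this range is always correct regardless of the write position.
    """
    idx = jax.random.randint(rng, (batch_size,), 0, size)
    return jax.tree_util.tree_map(lambda buf: buf[idx], storage)


# Ten slots, of which only the first three have been written.
storage = {"obs": jnp.arange(20.0).reshape(10, 2), "reward": jnp.arange(10.0)}
batch = sample_uniform(storage, jnp.int32(3), jax.random.PRNGKey(0), 64)
```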

Or apply an update op to the items in the buffer:

def item_update_fn(item):
    # Possibly update an item
    return item
buffer_state = buffer.update_fn(buffer_state, item_update_fn)
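An update of this kind maps naturally onto `jax.vmap`, which vectorizes the per-item `item_update_fn` over the leading slot axis of the storage. The hypothetical `update_all` below also leaves never-written slots (index >= `size`) untouched; that masking is a choice of this sketch, not documented dejax behavior. The example update, halving rewards, is illustrative only.

```python
import jax
import jax.numpy as jnp


def update_all(storage, size, item_update_fn):
    """Apply item_update_fn to every valid stored item (hypothetical sketch)."""
    # vmap maps the per-item function over axis 0 of every leaf.
    updated = jax.vmap(item_update_fn)(storage)
    capacity = jax.tree_util.tree_leaves(storage)[0].shape[0]
    valid = jnp.arange(capacity) < size

    def select(new, old):
        # Broadcast the per-slot mask over the trailing item dimensions.
        mask = valid.reshape((capacity,) + (1,) * (new.ndim - 1))
        return jnp.where(mask, new, old)

    return jax.tree_util.tree_map(select, updated, storage)


def item_update_fn(item):
    # Example update: halve the reward, keep the observation unchanged.
    return {"obs": item["obs"], "reward": item["reward"] * 0.5}


storage = {"obs": jnp.ones((4, 2)), "reward": jnp.array([2.0, 4.0, 6.0, 8.0])}
# Only the first two slots hold valid items; the function argument is static under jit.
new_storage = jax.jit(update_all, static_argnums=2)(storage, jnp.int32(2), item_update_fn)
```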



Download files


Source Distribution

dejax-0.1.3.tar.gz (9.8 kB)

File details

Details for the file dejax-0.1.3.tar.gz.

File metadata

  • Download URL: dejax-0.1.3.tar.gz
  • Size: 9.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for dejax-0.1.3.tar.gz
  • SHA256: 0511ef769bf9603415cd6f7bc24a42c11048658b826fe75c1ae0656e202f2aac
  • MD5: e5ee9d66d3cbccaa842f09aa7dd50e01
  • BLAKE2b-256: 7ae99175a2d6ae90640b3e816fbf78098fa7031d0d03ea0704dc726e2f3fed00

