
Improved Exploring Starts by Kernel Density Estimation-Based State-Space Coverage Acceleration in Reinforcement Learning.


DESSCA

Density Estimation-based State-Space Coverage Acceleration

Read the paper: arXiv:2105.08990

The provided DESSCA algorithm was designed to aid state-space exploration in reinforcement learning applications. In many cases where standard exploring starts can be used, DESSCA makes better use of the degree of freedom provided by the choice of the initial state. Regular, unsupervised exploring starts often lead to an unfavorable distribution of sample points in the state space, because the underlying system dynamics typically have particularly attractive regions. DESSCA instead analyzes the previous sample distribution and, at the beginning of each episode, targets regions that are underrepresented compared to a target probability density distribution.
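The selection principle described above can be sketched in a few lines of NumPy. This is a simplified illustration, not the DESSCA implementation: it assumes a Gaussian kernel, a uniform target density on the box, and a brute-force search over a regular candidate grid, and the function names `kde_coverage` and `next_start_state` are hypothetical.

```python
import numpy as np


def kde_coverage(points, samples, bandwidth):
    """Unnormalized Gaussian kernel density of `samples` evaluated at `points`.

    points:  (m, d) array of query locations
    samples: (n, d) array of previously visited states
    """
    if len(samples) == 0:
        return np.zeros(len(points))
    # Pairwise squared distances between query points and samples, shape (m, n)
    sq_dist = ((points[:, None, :] - samples[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dist / (2.0 * bandwidth**2)).sum(axis=1)


def next_start_state(samples, bounds, bandwidth=0.5, resolution=41):
    """Pick the grid point where coverage falls furthest short of a uniform target."""
    axes = [np.linspace(lo, hi, resolution) for lo, hi in bounds]
    grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, len(bounds))
    coverage = kde_coverage(grid, samples, bandwidth)
    # Normalize the coverage over the grid so it is comparable to the target (assumption)
    coverage = coverage / coverage.sum() if coverage.sum() > 0 else coverage
    target = np.full(len(grid), 1.0 / len(grid))  # uniform target density on the grid
    deficit = target - coverage
    return grid[np.argmax(deficit)]


# Visited states cluster in the lower-left region, e.g. around an attractor
visited = np.random.default_rng(0).normal(loc=-0.7, scale=0.1, size=(50, 2))
start = next_start_state(visited, bounds=[(-1, 1), (-1, 1)])
print(start)  # lies far away from the cluster, toward the upper right
```

Because the target is uniform here, the largest deficit coincides with the lowest coverage; a non-uniform target would shift the choice toward regions where the target density is high.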

Suggestions or experiences concerning applications of DESSCA outside reinforcement learning are welcome!

Citing

An in-depth explanation of the principle, realization and improvement capabilities of DESSCA can be found in the article "Improved Exploring Starts by Kernel Density Estimation-Based State-Space Coverage Acceleration in Reinforcement Learning". Please cite it when using the provided code:

@misc{schenke2021improved,
      title={Improved Exploring Starts by Kernel Density Estimation-Based State-Space Coverage Acceleration in Reinforcement Learning}, 
      author={Maximilian Schenke and Oliver Wallscheid},
      year={2021},
      eprint={2105.08990},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Usage

This code snippet serves as a minimal usage example for DESSCA. First, import dessca_model from DESSCA.py and create a corresponding object. Make sure DESSCA.py is in the same folder as the file that uses it.

import numpy as np
from DESSCA import dessca_model
my_dessca_instance0 = dessca_model(box_constraints=[[-1, 1],
                                                    [-1, 1]],
                                   state_names=["x1", "x2"],
                                   bandwidth=0.5)

This model instance operates on a two-dimensional state space bounded by [-1, 1] in each dimension. Let's use it to view the state-space coverage of a small dataset. Here are some samples:

samples_2d = np.array([[-0.8, -0.8],
                       [0.8, -0.8],
                       [-0.8, 0.8],
                       [0, 0]])

my_dessca_instance0.update_coverage_pdf(data=np.transpose(samples_2d))
my_dessca_instance0.plot_scatter()

Output: [scatter plot of the four samples]

And the corresponding coverage heatmap:

my_dessca_instance0.plot_heatmap()

Output: [coverage heatmap]
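In the example above, samples_2d is transposed before it is passed to update_coverage_pdf, and later a single suggestion is wrapped as np.transpose([next_sample_suggest]). This suggests that the method expects one row per state variable and one column per sample; that reading is inferred from the examples, not from DESSCA's documentation. The hypothetical helper `to_state_major` below only illustrates this shape convention:

```python
import numpy as np


def to_state_major(samples):
    """Convert (n_samples, n_states) data, or a single (n_states,) sample,
    into the (n_states, n_samples) layout used in the examples."""
    samples = np.atleast_2d(np.asarray(samples, dtype=float))  # (n_samples, n_states)
    return samples.T


samples_2d = np.array([[-0.8, -0.8],
                       [0.8, -0.8],
                       [-0.8, 0.8],
                       [0, 0]])
batch = to_state_major(samples_2d)
single = to_state_major([0.85, 0.94])

print(batch.shape)   # (2, 4)
print(single.shape)  # (2, 1)
```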

The coverage probability density function (PDF) has now been updated with the given samples, so DESSCA can suggest where to place the next sample.

next_sample_suggest = my_dessca_instance0.sample_optimally()
print(next_sample_suggest)

Output: [0.85517754 0.94340648] (note: with very few samples, the suggestion varies noticeably from run to run)
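Why the upper-right corner? A back-of-the-envelope check, assuming a Gaussian kernel with the bandwidth of 0.5 used above (the kernel actually used inside DESSCA may differ), sums the kernel contributions of the four samples at each corner of the box:

```python
import numpy as np

samples = np.array([[-0.8, -0.8], [0.8, -0.8], [-0.8, 0.8], [0.0, 0.0]])
corners = np.array([[-1, -1], [1, -1], [-1, 1], [1, 1]], dtype=float)
bandwidth = 0.5  # same value as in the dessca_model example

# Squared distance from every corner to every sample, shape (4 corners, 4 samples)
sq_dist = ((corners[:, None, :] - samples[None, :, :]) ** 2).sum(axis=-1)
# Unnormalized Gaussian kernel sum per corner
coverage = np.exp(-sq_dist / (2 * bandwidth**2)).sum(axis=1)

for corner, value in zip(corners, coverage):
    print(corner, f"{value:.3f}")
print("least covered corner:", corners[np.argmin(coverage)])
```

The upper-right corner comes out at roughly 0.02, compared with about 0.87 for each of the other three corners, because it is the only corner without a nearby sample.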

As expected, the suggestion lies in the upper-right corner of the state space, the only corner not yet covered by a sample. Update the coverage density and view the new distribution:

my_dessca_instance0.update_coverage_pdf(data=np.transpose([next_sample_suggest]))
my_dessca_instance0.plot_scatter()

Output: [scatter plot including the suggested sample]

Let's have a look at the density:

my_dessca_instance0.plot_heatmap()

Output: [updated coverage heatmap]

More Features

The scatter plot can also be rendered online, i.e., updated while the samples are generated (here, 100 samples):

my_dessca_instance1 = dessca_model(box_constraints=[[-1, 1],
                                                    [-1, 1]],
                                   state_names=["x1", "x2"],
                                   bandwidth=0.1,
                                   render_online=True)

next_sample_suggest = my_dessca_instance1.update_and_sample()
for _ in range(100):
    next_sample_suggest = my_dessca_instance1.update_and_sample(np.transpose([next_sample_suggest]))

Output: [scatter plot of the 100 samples, rendered online]
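The loop above feeds each suggestion straight back into the model. A toy re-creation of that feedback cycle shows how greedily choosing the least-covered point spreads the samples over the box. It assumes a Gaussian kernel, a uniform target and a grid search, so it is not the DESSCA algorithm itself; the class `ToyCoverageSampler` is hypothetical and only mimics the call pattern of update_and_sample.

```python
import numpy as np


class ToyCoverageSampler:
    """Hypothetical stand-in mimicking the update_and_sample() call pattern."""

    def __init__(self, bounds, bandwidth=0.1, resolution=51):
        axes = [np.linspace(lo, hi, resolution) for lo, hi in bounds]
        self.grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, len(bounds))
        self.bandwidth = bandwidth
        self.coverage = np.zeros(len(self.grid))  # running kernel sum on the grid
        self.history = []

    def update_and_sample(self, data=None):
        # Assumed (n_states, n_samples) layout, as in the README's np.transpose calls
        if data is not None:
            for sample in np.asarray(data, dtype=float).T:
                sq_dist = ((self.grid - sample) ** 2).sum(axis=1)
                self.coverage += np.exp(-sq_dist / (2 * self.bandwidth**2))
                self.history.append(sample)
        # With a uniform target, the largest deficit is the smallest coverage
        return self.grid[np.argmin(self.coverage)]


sampler = ToyCoverageSampler(bounds=[(-1, 1), (-1, 1)], bandwidth=0.1)
next_sample_suggest = sampler.update_and_sample()
for _ in range(100):
    next_sample_suggest = sampler.update_and_sample(np.transpose([next_sample_suggest]))

visited = np.array(sampler.history)
print(visited.shape)  # (100, 2)
```

Updating the coverage incrementally on a fixed grid keeps each step cheap; a real implementation may instead re-evaluate the density or optimize over continuous states.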

Furthermore, a memory buffer can be parameterized so that only a limited number of past samples (here, 25) is remembered:

my_dessca_instance2 = dessca_model(box_constraints=[[-1, 1],
                                                    [-1, 1]],
                                   state_names=["x1", "x2"],
                                   bandwidth=0.1,
                                   render_online=True,
                                   buffer_size=25)

next_sample_suggest = my_dessca_instance2.update_and_sample()
for _ in range(100):
    next_sample_suggest = my_dessca_instance2.update_and_sample(np.transpose([next_sample_suggest]))

Output: [scatter plot with a 25-sample memory buffer]
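One simple way to picture the buffer_size option is a first-in, first-out memory: once it is full, every new sample evicts the oldest one, so the coverage estimate only reflects recent samples. The sketch below uses Python's `collections.deque` with `maxlen` for this; whether DESSCA stores its buffer this way internally is an assumption.

```python
from collections import deque

import numpy as np

buffer = deque(maxlen=25)  # keep at most the 25 most recent samples
rng = np.random.default_rng(0)

for step in range(100):
    sample = rng.uniform(-1, 1, size=2)
    buffer.append((step, sample))  # the oldest entry is dropped automatically

remembered_steps = [step for step, _ in buffer]
print(len(buffer), remembered_steps[0], remembered_steps[-1])  # 25 75 99
```

Regions that were visited more than 25 samples ago drop out of such a memory and look uncovered again, which can make later suggestions return to similar areas.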

Note how forgetting past samples leads to groups of samples in similar areas. Finally, we can also use a specific reference coverage density, here a uniform density on the unit disk:

def reference_coverage(X):
    # Indicator of the unit disk: 1 inside, 0 outside.
    # For a uniform distribution on a given shape, the value range of the reference coverage does not matter.
    x0 = X[0]
    x1 = X[1]
    return np.less(x0**2 + x1**2, 1).astype(float)

my_dessca_instance3 = dessca_model(box_constraints=[[-1, 1],
                                                    [-1, 1]],
                                   state_names=["x1", "x2"],
                                   bandwidth=0.1,
                                   render_online=True,
                                   reference_pdf=reference_coverage)

next_sample_suggest = my_dessca_instance3.update_and_sample()
for _ in range(100):
    next_sample_suggest = my_dessca_instance3.update_and_sample(np.transpose([next_sample_suggest]))

Output: [scatter plot with the unit-disk reference coverage]
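The comment in reference_coverage states that the value range does not matter for a uniform distribution on a given shape. That follows if the reference is normalized into a probability density before it is compared with the coverage (an assumption about DESSCA's internals), because any constant scale factor cancels. A quick check on a grid, with the hypothetical helper `normalize_on_grid`:

```python
import numpy as np


def reference_coverage(X):
    # Indicator of the unit disk: 1 inside, 0 outside
    x0 = X[0]
    x1 = X[1]
    return np.less(x0**2 + x1**2, 1).astype(float)


def normalize_on_grid(values):
    """Scale non-negative grid values so they sum to one (a discrete PDF)."""
    return values / values.sum()


axis = np.linspace(-1, 1, 201)
X = np.stack(np.meshgrid(axis, axis, indexing="ij"))  # shape (2, 201, 201)

ref = normalize_on_grid(reference_coverage(X))
ref_scaled = normalize_on_grid(5.0 * reference_coverage(X))

print(np.allclose(ref, ref_scaled))  # True: the scale factor cancels
```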

DESSCA can also be used for downsampling. First, let's generate a large dataset that we would like to reduce in size:

x1_samples = np.random.triangular(-1, 0.75, 1, size=(1000, 1))
x2_samples = np.random.triangular(-1, -0.75, 1, size=(1000, 1))
samples_2d = np.append(x1_samples, x2_samples, axis=1)

my_dessca_instance4 = dessca_model(box_constraints=[[-1, 1],
                                                    [-1, 1]],
                                   state_names=["x1", "x2"],
                                   bandwidth=0.5)
my_dessca_instance4.update_coverage_pdf(data=np.transpose(samples_2d))
my_dessca_instance4.plot_scatter()

Output: [scatter plot of the 1000 original samples]

Now, DESSCA can be used to reduce the set down to a specified number of remaining samples while trying to preserve the original distribution:

my_dessca_instance5 = dessca_model(box_constraints=[[-1, 1],
                                                    [-1, 1]],
                                   state_names=["x1", "x2"],
                                   bandwidth=0.5)

my_dessca_instance5.downsample(data=np.transpose(samples_2d), target_size=100)
my_dessca_instance5.plot_scatter()

Output: [scatter plot of the 100 remaining samples]

Comparing the edge (marginal) distributions shows that they remain almost the same, even though 90 % of the dataset's samples were removed.
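The following is not DESSCA's downsampling routine. As a rough, self-contained stand-in, it uses greedy kernel herding (a different, well-known technique, swapped in here for illustration) restricted to the dataset's own points: at each step it keeps the sample whose kernel similarity to the full dataset most exceeds its similarity to the samples already kept. The Gaussian kernel, the bandwidth of 0.5 and the function name `herd_downsample` are assumptions. Printing marginal means and standard deviations gives a numeric counterpart to comparing the marginal distributions visually.

```python
import numpy as np


def herd_downsample(data, target_size, bandwidth=0.5):
    """Greedy kernel herding over the rows of `data` (shape (n_samples, n_states)).

    Returns the indices of the selected samples.
    """
    sq_dist = ((data[:, None, :] - data[None, :, :]) ** 2).sum(axis=-1)
    kernel = np.exp(-sq_dist / (2 * bandwidth**2))  # (n, n) Gaussian kernel matrix
    mean_embedding = kernel.mean(axis=1)            # similarity of each point to the full set
    kept_similarity = np.zeros(len(data))           # summed similarity to the kept points
    available = np.ones(len(data), dtype=bool)
    selected = []
    for t in range(target_size):
        score = mean_embedding - kept_similarity / (t + 1)
        score[~available] = -np.inf                 # select without replacement
        idx = int(np.argmax(score))
        selected.append(idx)
        available[idx] = False
        kept_similarity += kernel[idx]
    return np.array(selected)


rng = np.random.default_rng(0)
x1_samples = rng.triangular(-1, 0.75, 1, size=(1000, 1))
x2_samples = rng.triangular(-1, -0.75, 1, size=(1000, 1))
samples_2d = np.append(x1_samples, x2_samples, axis=1)

indices = herd_downsample(samples_2d, target_size=100)
subset = samples_2d[indices]
print("full   mean/std:", samples_2d.mean(axis=0).round(2), samples_2d.std(axis=0).round(2))
print("subset mean/std:", subset.mean(axis=0).round(2), subset.std(axis=0).round(2))
```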


Download files

Source Distribution

dessca-0.1.0.tar.gz (7.2 kB)

Built Distribution

dessca-0.1.0-py3-none-any.whl (7.3 kB)

File details

Details for the file dessca-0.1.0.tar.gz.

File metadata

  • Download URL: dessca-0.1.0.tar.gz
  • Size: 7.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.5

File hashes

Hashes for dessca-0.1.0.tar.gz
Algorithm Hash digest
SHA256 91c7d3f1c23ed80d96f312ef45a12d06777a6638f494e2241e1ef8d83cfdb038
MD5 49a7f9da320ebede90347fa95fc29392
BLAKE2b-256 1c6023f8624d172b14e95a3ed435d4710dcdf52248c4c2c7acb65bb33b36afc7


File details

Details for the file dessca-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: dessca-0.1.0-py3-none-any.whl
  • Size: 7.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.5

File hashes

Hashes for dessca-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 ff7b18d68658c6d8b308ee0fd9b9d2899aa01f2a6804e912eb1adf7ec133a87f
MD5 14e748f3c25185d3f8083be132840fae
BLAKE2b-256 3e23fe48f2291089aebefbd928ecb44795e2f3e183631dfecc74bf2583417da8

