

torchreinforce

A pythonic implementation of the REINFORCE algorithm that is actually fun to use

Installation

You can install it with pip, like any other Python package:

pip install torchreinforce

Quickstart

To use the REINFORCE algorithm with your model, you only need to do two things:

  • Use the ReinforceModule class as your base class
  • Decorate your forward() method with @ReinforceModule.forward

That's it!

import torch
from torchreinforce import ReinforceModule

class Model(ReinforceModule):
    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)
        self.net = torch.nn.Sequential(
            torch.nn.Linear(20, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 2),
            torch.nn.Softmax(dim=-1),
        )

    @ReinforceModule.forward
    def forward(self, x):
        return self.net(x)
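Because the network ends in a Softmax, each forward pass yields a probability vector over the two actions, which the decorator then uses (presumably as the probabilities) to parameterize the default Categorical distribution. The sketch below reproduces that idea in plain PyTorch so you can see what get() conceptually returns; it is an illustration under that assumption, not torchreinforce's internal code.

```python
import torch

torch.manual_seed(0)

# Plain-PyTorch illustration, not torchreinforce's internal code.
# Same architecture as the Quickstart model, without the ReinforceModule wrapper.
net = torch.nn.Sequential(
    torch.nn.Linear(20, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 2),
    torch.nn.Softmax(dim=-1),
)

x = torch.randn(20)   # a dummy 20-dimensional observation
probs = net(x)        # probability of each of the 2 actions

# Use the probabilities as the parameters of a Categorical distribution
dist = torch.distributions.Categorical(probs=probs)
action = dist.sample()             # an integer action: 0 or 1
log_prob = dist.log_prob(action)   # the quantity REINFORCE needs for its loss

print(probs, action.item(), log_prob.item())
```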

Your model will now output ReinforceOutput objects.

These objects have two important methods:

  • get()
  • reward(value)

Use output.get() to draw an actual sample from the underlying distribution, and output.reward(value) to assign a reward to that specific output.

If net is your model, you would do something like this:

action = net(observation)
observation, reward, done, info = env.step(action.get())
action.reward(reward)
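Conceptually, every get() has to remember the log-probability of the sampled action and every reward() has to remember the reward that followed, so that the loss can later pair them up. The class below is a hypothetical, simplified stand-in (EpisodeRecorder, act and reward are made-up names, not part of torchreinforce) that shows this bookkeeping with plain PyTorch.

```python
import torch


class EpisodeRecorder:
    """Hypothetical helper: stores (log_prob, reward) pairs for one episode."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the buffers at the start of every episode
        self.log_probs = []
        self.rewards = []

    def act(self, probs):
        # Sample an action and remember its log-probability (cf. output.get())
        dist = torch.distributions.Categorical(probs=probs)
        action = dist.sample()
        self.log_probs.append(dist.log_prob(action))
        return action.item()

    def reward(self, value):
        # Remember the reward obtained after the last action (cf. output.reward())
        self.rewards.append(float(value))


recorder = EpisodeRecorder()
probs = torch.tensor([0.3, 0.7])
for _ in range(3):
    action = recorder.act(probs)
    recorder.reward(1.0)  # e.g. CartPole gives +1 for every step the pole stays up

print(len(recorder.log_probs), recorder.rewards)
```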

Wait, did you just say distribution?

Yes! As the REINFORCE algorithm prescribes, the outputs of your model are used as the parameters of a probability distribution.

In fact, you can use any probability distribution you want: the ReinforceModule constructor accepts the following parameters:

  • gamma: the gamma (discount factor) parameter of the REINFORCE algorithm (default: 0.99)
  • distribution: any ReinforceDistribution or torch.distributions distribution (default: Categorical)

For example:

net = Model(distribution=torch.distributions.Beta, gamma=0.99)
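In standard REINFORCE, gamma discounts future rewards: the return credited to the action at step t is G_t = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ..., which can be computed backwards as G_t = r_t + gamma * G_{t+1}. The pure-Python function below sketches that recursion, assuming torchreinforce applies gamma in this standard way; discounted_returns is an illustrative name, not a library function.

```python
def discounted_returns(rewards, gamma=0.99):
    """Illustrative helper: compute G_t = r_t + gamma * G_{t+1} for each step."""
    returns = []
    g = 0.0
    # Walk the episode backwards so each return reuses the one after it
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    returns.reverse()
    return returns


print(discounted_returns([1.0, 1.0, 1.0], gamma=0.5))  # → [1.75, 1.5, 1.0]
```

With gamma close to 1 (such as the default 0.99), rewards far in the future still weigh almost as much as immediate ones.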

Keep in mind that the outputs of your decorated forward(x) will be used as the parameters of the distribution. If your distribution needs more than one parameter, just return a list.

I've added the option for distributions to behave deterministically at test time, but so far I've implemented it only for the Categorical distribution. If you want to implement your own deterministic logic, check the file distributions/categorical.py; it is pretty straightforward.

For example, to use the torch.distributions.Beta distribution you would do something like this:
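A common way to make a Categorical policy deterministic at test time is to pick the most probable action (argmax) instead of sampling. The function below sketches that idea; select_action and its deterministic flag are hypothetical names, and the actual logic in distributions/categorical.py may differ.

```python
import torch


def select_action(probs, deterministic=False):
    """Hypothetical helper: greedy action at test time, sampled action otherwise."""
    if deterministic:
        # Testing: always take the most likely action
        return int(torch.argmax(probs, dim=-1))
    # Training: sample, so the policy keeps exploring
    return int(torch.distributions.Categorical(probs=probs).sample())


probs = torch.tensor([0.2, 0.5, 0.3])
print(select_action(probs, deterministic=True))  # → 1
```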

class Model(ReinforceModule):
    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)
        ...  # define self.net1 and self.net2 here

    @ReinforceModule.forward
    def forward(self, x):
        return [self.net1(x), self.net2(x)] # the Beta distribution accepts two parameters

net = Model(distribution=torch.distributions.Beta, gamma=0.99)

action = net(inp)
env.step(action.get())
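Both parameters of torch.distributions.Beta (concentration1 and concentration0) must be strictly positive, so each head needs a positive activation. The plain-PyTorch sketch below shows one way to build such a two-headed network and turn its list of outputs into a Beta distribution; the class name BetaHeads, the layer sizes and the softplus-plus-one activation are assumptions for illustration, with net1/net2 named after the snippet above.

```python
import torch
import torch.nn.functional as F


class BetaHeads(torch.nn.Module):
    """Hypothetical two-headed network returning [alpha, beta] for a Beta policy."""

    def __init__(self, in_features=4, hidden=64):
        super().__init__()
        self.body = torch.nn.Sequential(torch.nn.Linear(in_features, hidden), torch.nn.ReLU())
        self.net1 = torch.nn.Linear(hidden, 1)
        self.net2 = torch.nn.Linear(hidden, 1)

    def forward(self, x):
        h = self.body(x)
        # softplus(.) > 0, and adding 1 keeps both concentrations above 1
        alpha = F.softplus(self.net1(h)) + 1.0
        beta = F.softplus(self.net2(h)) + 1.0
        return [alpha, beta]  # a list, one entry per distribution parameter


torch.manual_seed(0)
params = BetaHeads()(torch.randn(4))
dist = torch.distributions.Beta(*params)  # unpack the list into the constructor
action = dist.sample()                    # a value between 0 and 1
print(action)
```

Keeping both concentrations above 1 is a common choice because it makes the Beta density unimodal, which tends to give smoother continuous actions.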

Nice! What about training?

You can compute the REINFORCE loss by calling the loss() method of ReinforceModule and then treat it as you would any other PyTorch loss:

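net = ...
optimizer = ...

while training:
    net.reset()  # start a new episode
    for step in range(max_steps):  # max_steps: hypothetical episode length
        ...  # call net(observation), env.step(action.get()) and action.reward(reward)

    loss = net.loss(normalize=True)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()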

You have to call the reset() method of ReinforceModule at the beginning of each episode. You can also pass normalize=True to loss() if you want the rewards to be normalized.
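For reference, the textbook REINFORCE loss is the negative sum over the episode of log pi(a_t | s_t) * G_t, and normalization is commonly done by standardizing the returns to zero mean and unit variance. The function below sketches that computation in plain PyTorch under those assumptions; reinforce_loss is a hypothetical name, and torchreinforce's own loss() may differ in details such as the epsilon or whether it sums or averages.

```python
import torch


def reinforce_loss(log_probs, rewards, gamma=0.99, normalize=False, eps=1e-8):
    """Hypothetical REINFORCE loss: -sum_t log_prob_t * G_t."""
    # Discounted returns, computed backwards: G_t = r_t + gamma * G_{t+1}
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.insert(0, g)
    returns = torch.tensor(returns)
    if normalize:
        # Standardize returns to zero mean and unit variance to reduce variance
        returns = (returns - returns.mean()) / (returns.std() + eps)
    return -(torch.stack(log_probs) * returns).sum()


# Toy episode: two actions whose log-probabilities require gradients
logits = torch.zeros(2, requires_grad=True)
dist = torch.distributions.Categorical(logits=logits)
log_probs = [dist.log_prob(torch.tensor(0)), dist.log_prob(torch.tensor(1))]

loss = reinforce_loss(log_probs, rewards=[1.0, 1.0], gamma=0.5)
loss.backward()  # gradients flow back into logits, as with any PyTorch loss
print(loss.item())
```

Here both actions have probability 0.5 and the returns are [1.5, 1.0], so the loss equals 2.5 * ln 2.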

Putting it all together

A complete example looks like this:

import gym
import torch
from torchreinforce import ReinforceModule

class Model(ReinforceModule):
    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)
        self.net = torch.nn.Sequential(
            torch.nn.Linear(4, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 2),
            torch.nn.Softmax(dim=-1),
        )

    @ReinforceModule.forward
    def forward(self, x):
        return self.net(x)


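EPISODES = 1000  # number of training episodes (arbitrary value)

# Classic gym API (gym < 0.26): reset() returns the observation, step() returns 4 values
env = gym.make('CartPole-v0')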
net = Model()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

for i in range(EPISODES):
    done = False
    net.reset()
    observation = env.reset()
    while not done:
        action = net(torch.tensor(observation, dtype=torch.float32))

        observation, reward, done, info = env.step(action.get())
        action.reward(reward)

    loss = net.loss(normalize=False)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

You can find a running example in the examples/ folder.


Download files


Source Distributions

No source distribution files are available for this release.

Built Distributions

torchreinforce-0.1.0-py3.6.egg (12.8 kB)

Uploaded Source

torchreinforce-0.1.0-py3-none-any.whl (19.0 kB)

Uploaded Python 3

File details

Details for the file torchreinforce-0.1.0-py3.6.egg.

File metadata

  • Download URL: torchreinforce-0.1.0-py3.6.egg
  • Upload date:
  • Size: 12.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.5.0.1 requests/2.21.0 setuptools/40.6.3 requests-toolbelt/0.8.0 tqdm/4.29.1 CPython/3.6.7

File hashes

Hashes for torchreinforce-0.1.0-py3.6.egg
  • SHA256: d7023facdea8f79409c5e58526bec9e182540af3f392b83257f258b8766f31f5
  • MD5: b18372214156e1546ba1485301275401
  • BLAKE2b-256: fea604fb485d82a7ba41711190a0da8fb246fd0f7812d09ce58e5f6c15daa86b


File details

Details for the file torchreinforce-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: torchreinforce-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 19.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.5.0.1 requests/2.21.0 setuptools/40.6.3 requests-toolbelt/0.8.0 tqdm/4.29.1 CPython/3.6.7

File hashes

Hashes for torchreinforce-0.1.0-py3-none-any.whl
  • SHA256: 69b205c21b0044f82991d550f950ed1e99ee79aef7f981c18753815823fe4c83
  • MD5: d021cfa11e21ff3ea825f44717ef9d92
  • BLAKE2b-256: 47825d3c2eb2dad9f2a5cb65b74ba29b23ae21322832c8c8c167dab05c8c7128

