
PyTorch implementation of the Stigmergic Memory


torchsm

PyTorch implementation of the Stigmergic Memory as presented in the paper "Using stigmergy as a computational memory in the design of recurrent neural networks".

You can use this package to easily integrate our model into existing ones.

You can safely mix native PyTorch modules with ours.

But do not forget to reset() them before starting each new time sequence.

Implementing our proposed architecture to solve MNIST becomes as easy as:

import torch
import torchsm

net = torchsm.Sequential(
    torchsm.RecurrentStigmergicMemoryLayer(28, 15, hidden_layers=1, hidden_dim=20),  # 28 inputs per tick, 15 outputs
    torch.nn.Linear(15, 10),
    torch.nn.PReLU(),
    torch.nn.Linear(10, 10),
    torch.nn.PReLU()   # final 10-way output, one per MNIST class
)

You can train the time-unfolded model by computing the loss on the desired temporal output:

optimizer = torch.optim.Adam(net.parameters(), lr = 0.001)
loss_fn = torch.nn.MSELoss()

for epoch in range(N):
    for X, Y in zip(dataset_X, dataset_Y):
        net.reset()    # new time sequence: reset the stigmergic memory
        out = None
        for t in range(X.shape[1]):
            out = net(torch.tensor(X[:, t], dtype=torch.float32))

        loss = loss_fn(out, Y)    # loss on the output at the last tick

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
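
After training, inference follows the same tick-by-tick pattern. A minimal sketch, assuming a single sequence X shaped like the training samples and reading the predicted class as the argmax of the final output (our assumption, not part of the library API):

with torch.no_grad():
    net.reset()    # new sequence: reset the stigmergic state
    out = None
    for t in range(X.shape[1]):
        out = net(torch.tensor(X[:, t], dtype=torch.float32))
    prediction = out.argmax(dim=-1)    # assumed readout: index of the largest output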

Does it support batch inputs?

Yes! Inputs have to be batched along the first dimension: at every tick you feed a tensor of shape (batch_size, input_dim) and get back one output row per batch element.

# batch_in: tensor of shape (batch_size, num_ticks, input_dim)
for t in range(num_ticks):
    batch_out = net(batch_in[:, t])    # batch_out: one row per sequence in the batch
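
For instance, a minimal end-to-end sketch; the batch size of 32 and the 28-feature, 28-tick shapes are placeholders borrowed from the MNIST example above:

batch_in = torch.randn(32, 28, 28)    # hypothetical batch: (batch_size, num_ticks, input_dim)
net.reset()                           # reset before each new batch of sequences
out = None
for t in range(batch_in.shape[1]):
    out = net(batch_in[:, t])         # out: (batch_size, 10) with the MNIST model above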

Can it run on CUDA?

Yes, just as you would expect from a PyTorch Module!
You just need to call the to(device) method on a model to move it to GPU memory:

device = torch.device("cuda")

net = net.to(device)

net(torch.tensor(..., device=device))
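
A complete round trip might look like the following minimal sketch; the input shape is a placeholder sized for the MNIST model above, and moving the result back with cpu() is optional:

device = torch.device("cuda")
net = net.to(device)

x = torch.randn(1, 28, device=device)    # hypothetical single-sample input on the GPU
net.reset()
out = net(x)      # the output lives on the same device as the model
out = out.cpu()   # move it back to host memory if needed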

Documentation

torchsm.Sequential

Wrapper of torch.nn.Sequential that adds the reset() method and forwards the call to each torchsm.BaseLayer child.

If you want to use a Sequential container to build your models with one or more torchsm layers, you have to use torchsm.Sequential instead of torch.nn.Sequential in order to be able to reset() them.
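
For example, a minimal sketch (the layer sizes are arbitrary placeholders):

net = torchsm.Sequential(
    torchsm.RecurrentStigmergicMemoryLayer(8, 4, hidden_layers=1, hidden_dim=10),
    torch.nn.Linear(4, 2)    # native modules can be freely mixed in
)
net.reset()    # forwarded to every torchsm.BaseLayer child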

torchsm.StigmergicMemoryLayer

This layer has two hidden ANNs that take the layer's inputs and whose outputs respectively determine the marks and ticks of a multi-monodimensional stigmergic space.

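A hypothetical instantiation, assuming the constructor mirrors the (input_dim, output_dim, hidden_layers, hidden_dim) signature that RecurrentStigmergicMemoryLayer uses above:

layer = torchsm.StigmergicMemoryLayer(28, 15, hidden_layers=1, hidden_dim=20)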

torchsm.RecurrentStigmergicMemoryLayer

This layer is a StigmergicMemoryLayer whose output is normalized by a linear layer and recurrently fed back as input to the two hidden ANNs.


Citing

We can't wait to see what you will build with torchsm!
When you publish your work, you can use this BibTeX to cite us :)

@article{galatolo_snn
,	author	= {Galatolo, Federico A and Cimino, Mario GCA and Vaglini, Gigliola}
,	title	= {Using stigmergy as a computational memory in the design of recurrent neural networks}
,	journal	= {ICPRAM 2019}
,	year	= {2019}
,	pages	= {}
}

Contributing

This code is released under the GNU GPLv3, so feel free to fork it and submit your changes; every PR helps.
If you need help using it or have any questions, please reach me at federico.galatolo@ing.unipi.it or on Telegram at @galatolo.
