DNI, for Pytorch

# Decoupled Neural Interfaces Using Synthetic Gradients

[![Build Status](https://travis-ci.org/ixaxaar/pytorch-dni.svg?branch=master)](https://travis-ci.org/ixaxaar/pytorch-dni) [![PyPI version](https://badge.fury.io/py/dni.svg)](https://badge.fury.io/py/pytorch-dni)

<!-- START doctoc generated TOC please keep comment here to allow auto update -->
<!-- DON'T EDIT THIS SECTION, INSTEAD RE-RUN doctoc TO UPDATE -->


- [Install](#install)
  - [From source](#from-source)
- [Architecture](#architecture)
- [Usage](#usage)
- [DNI Networks](#dni-networks)
- [Custom DNI Networks](#custom-dni-networks)
- [Tasks](#tasks)
- [Notable stuff](#notable-stuff)

<!-- END doctoc generated TOC please keep comment here to allow auto update -->

This is a PyTorch implementation of [Decoupled Neural Interfaces using Synthetic Gradients, Jaderberg et al.](https://arxiv.org/abs/1608.05343).

## Install

```bash
pip install pytorch-dni
```

### From source

```bash
git clone https://github.com/ixaxaar/pytorch-dni
cd pytorch-dni
pip install -r ./requirements.txt
pip install -e .
```

## Architecture

<img src="./docs/3-6.gif" />

## Usage

```python
import torch.optim as optim
from dni import DNI

# Custom network, can be anything extending nn.Module
net = WhateverNetwork(**kwargs)
opt = optim.Adam(net.parameters(), lr=0.001)

# use DNI to optimize this network
net = DNI(net, optim=opt)

# after that we go about our business as usual
for e in range(epoch):
    opt.zero_grad()
    output = net(input, *args)
    loss = criterion(output, target_output)
    loss.backward()
    opt.step()

...
```
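
For intuition, the synthetic-gradient idea from the paper can be sketched in plain PyTorch without this package. The snippet below is an illustration under assumptions, not a description of how `DNI` is implemented internally: the names `layer`, `head` and `sg`, the layer sizes, and the random data are all hypothetical. A small module `sg` predicts the gradient of the loss with respect to the activations `h`, the upstream layer is updated from that prediction right away, and `sg` is then trained to match the true gradient once it becomes available.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical two-part model: `layer` feeds `head`; `sg` maps h to an estimate of dL/dh.
layer = nn.Linear(4, 8)
head = nn.Linear(8, 1)
sg = nn.Linear(8, 8)

layer_opt = torch.optim.SGD(layer.parameters(), lr=0.01)
head_opt = torch.optim.SGD(head.parameters(), lr=0.01)
sg_opt = torch.optim.SGD(sg.parameters(), lr=0.01)

x, y = torch.randn(16, 4), torch.randn(16, 1)  # random stand-in data

# 1. Update `layer` immediately from the predicted gradient,
#    without waiting for the loss to be computed downstream.
h = layer(x)
predicted_grad = sg(h.detach())
layer_opt.zero_grad()
h.backward(predicted_grad.detach())
layer_opt.step()

# 2. Downstream, compute the loss and the true gradient w.r.t. h.
h_in = h.detach().requires_grad_()
loss = F.mse_loss(head(h_in), y)
head_opt.zero_grad()
loss.backward()
head_opt.step()
true_grad = h_in.grad

# 3. Train the synthetic gradient module to regress onto the true gradient.
sg_loss = F.mse_loss(predicted_grad, true_grad)
sg_opt.zero_grad()
sg_loss.backward()
sg_opt.step()
```

In this sketch the three steps run back to back in one process; the point of decoupling is that step 1 does not depend on steps 2 and 3 having finished.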

## DNI Networks

This package ships with 3 types of DNI networks:

- RNN_DNI: stacked `LSTM`s, `GRU`s or `RNN`s
- Linear_DNI: 2-layer `Linear` modules
- Linear_Sigmoid_DNI: 2-layer `Linear` followed by `Sigmoid`

## Custom DNI Networks

Custom DNI nets can be created using the `DNI_Network` interface:

```python
class MyDNI(DNI_Network):
    def __init__(self, input_size, hidden_size, output_size, **kwargs):
        super(MyDNI, self).__init__(input_size, hidden_size, output_size)

        self.net = { ... your custom net }
        ...

    def forward(self, input, hidden):
        # return (output, hidden); hidden can be None
        return self.net(input), None
```
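
As one possible way to fill in the template, here is a small stand-in that runs on its own. It is a sketch under assumptions: it subclasses `nn.Module` rather than `DNI_Network` so that it does not depend on this package, the class name `MLPSynthGrad` and the layer sizes are hypothetical, and only the constructor arguments and the `(output, hidden)` return convention are taken from the template above. To use it with this package you would presumably subclass `DNI_Network` instead, as shown in the template.

```python
import torch
import torch.nn as nn


class MLPSynthGrad(nn.Module):
    """Hypothetical synthetic gradient net: a two-layer MLP with a ReLU in between."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, input, hidden):
        # Stateless network, so the returned hidden state is None.
        return self.net(input), None


sg_net = MLPSynthGrad(input_size=8, hidden_size=16, output_size=8)
output, hidden = sg_net(torch.randn(4, 8), None)
```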

## Tasks

The tasks included in this project are the same as those in [pytorch-dnc](https://github.com/ixaxaar/pytorch-dnc#tasks), except that they're trained here using DNI.

## Notable stuff

- Using a linear synthetic gradient (SG) module implicitly assumes that the loss is a quadratic function of the activations, since only a quadratic has a gradient that is linear in its input.
- For best performance, adapt the SG module architecture to the loss function used. For MSE, a linear SG is a reasonable choice; for log loss, use an architecture that applies a pointwise sigmoid to a linear SG.
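
The two points above can be checked numerically in a simplified setting. The sketch below assumes a single scalar activation `h` that feeds the loss directly (for log loss, `h` is treated as the logit); the values of `h` and `y` are arbitrary. Under those assumptions the MSE gradient is `h - y`, which is linear in `h`, while the log-loss gradient is `sigmoid(h) - y`, which is why a sigmoid in the SG module is a natural fit. For activations further upstream this only holds approximately.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def mse(h, y):
    return 0.5 * (h - y) ** 2


def log_loss(h, y):
    # Binary cross-entropy with h treated as the logit.
    p = sigmoid(h)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def numeric_grad(loss, h, y, eps=1e-6):
    # Central finite difference of the loss with respect to h.
    return (loss(h + eps, y) - loss(h - eps, y)) / (2 * eps)


h, y = 0.7, 1.0
mse_grad = numeric_grad(mse, h, y)       # analytic form: h - y
log_grad = numeric_grad(log_loss, h, y)  # analytic form: sigmoid(h) - y
```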

