A toolbox for using complex-valued standard network modules in PyTorch.

ComplexNN: Complex Neural Network Modules

Since recent versions of PyTorch support matrix operations and gradient descent on complex-valued parameters, this repository provides complex-valued versions of several standard PyTorch network modules. Compared with using two sets of real parameters to represent the real and imaginary parts of a network's complex parameters separately, using complex numbers directly as network parameters halves the number of trainable parameters, which speeds up training.
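
For example, PyTorch's autograd can optimize a complex tensor directly. A minimal sketch, independent of complexNN and with arbitrary sizes:

import torch

# One complex parameter tensor replaces a separate (real, imaginary) pair,
# halving the number of trainable parameter tensors.
w = torch.randn(4, 4, dtype=torch.cfloat, requires_grad=True)
x = torch.randn(4, dtype=torch.cfloat)

loss = (w @ x).abs().sum()  # real-valued loss over complex outputs
loss.backward()             # PyTorch computes Wirtinger-style gradients
print(w.grad.dtype)         # torch.complex64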

Install

To install complexNN for the first time:

pip install complexNN

To upgrade a previous installation of complexNN to the most recent version:

pip install --upgrade complexNN
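
To confirm the installation worked, import one of the modules using an import path from the examples below:

from complexNN.complexLinear import complexLinear
print(complexLinear)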

Versions

v0.0.1 Provides complex-valued versions of basic standard PyTorch network modules.

v0.1.1 Adds support for the Linear Recurrent Unit (LRU).

v0.1.2 Bug fixes. Adds support for BatchNorm2d and BatchNorm3d.

Modules

The complex-valued modules include:

| complexLayer | complexRNNcell | complexActivation | complexFunction |
| ------------ | -------------- | ----------------- | --------------- |
| Linear       | RNN Cell       | Relu              | BatchNorm       |
|              | GRU Cell       | Gelu              | LayerNorm       |
|              | LSTM Cell      | Tanh              | Dropout         |
|              | LRU Cell [1]   | Sigmoid           |                 |

Note that the native version of torch.nn.Dropout is supported. Other modules will be considered for updates in the future.
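
As a quick smoke test, a single layer from the table can be applied to a complex tensor. This is a minimal sketch using the import paths from the examples below; the nn.Linear-like (in_features, out_features) signature and torch.cfloat inputs are assumptions:

import torch
from complexNN.complexLinear import complexLinear
from complexNN.complexActivation import complexTanh

layer = complexLinear(8, 4)                # assumed nn.Linear-like signature
x = torch.randn(2, 8, dtype=torch.cfloat)  # complex-valued input batch
y = complexTanh(layer(x))
print(y.shape, y.dtype)                    # expected: (2, 4), a complex dtype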

Examples

Multilayer perceptron

import torch.nn as nn
from complexNN.complexActivation import complexTanh
from complexNN.complexLinear import complexLinear


class complexMLP(nn.Module):
    """
    Complex Multilayer Perceptron
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(complexMLP, self).__init__()
        self.num_layers = num_layers
        self.input_layer = complexLinear(input_size, hidden_size)
        self.hidden_layers = nn.ModuleList([complexLinear(hidden_size, hidden_size) for _ in range(num_layers - 1)])
        self.output_layer = complexLinear(hidden_size, output_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = complexTanh(self.input_layer(x))
        x = self.dropout(x)
        for i in range(self.num_layers - 1):
            x = complexTanh(self.hidden_layers[i](x))
            x = self.dropout(x)
        output = self.output_layer(x)
        return output
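
A usage sketch for the MLP above, with hypothetical sizes and a complex input batch:

import torch

mlp = complexMLP(input_size=16, hidden_size=32, output_size=4, num_layers=3)
x = torch.randn(8, 16, dtype=torch.cfloat)  # batch of 8 complex samples
y = mlp(x)
print(y.shape)  # expected: torch.Size([8, 4])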

Recurrent neural network

import torch.nn as nn
from complexNN.complexRNNcell import complexRNNCell


class complexRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(complexRNN, self).__init__()
        self.num_layers = num_layers
        self.rnn_layers = nn.ModuleList()
        for _ in range(num_layers):
            self.rnn_layers.append(complexRNNCell(input_size, hidden_size))
            input_size = hidden_size

    def forward(self, x, h_0):
        h_prev = h_0
        for i in range(self.num_layers):
            # Each cell takes the previous layer's output as its input;
            # layer 0 consumes x directly.
            h_prev = self.rnn_layers[i](x, h_prev)
            x = h_prev
        return h_prev
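
A usage sketch for the RNN above, processing a single time step with hypothetical sizes; the (input, hidden) call signature of complexRNNCell is an assumption:

import torch

rnn = complexRNN(input_size=16, hidden_size=32, num_layers=2)
x = torch.randn(8, 16, dtype=torch.cfloat)    # one time step, batch of 8
h_0 = torch.zeros(8, 32, dtype=torch.cfloat)  # initial hidden state
h = rnn(x, h_0)
print(h.shape)  # expected: torch.Size([8, 32])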

Cite as

@misc{ComplexNN,
      title={ComplexNN: Complex Neural Network Modules},
      author={Xinyuan Liao},
      url={https://github.com/XinyuanLiao/ComplexNN},
      year={2023}
}

Reference

[1] Orvieto, Antonio, et al. "Resurrecting recurrent neural networks for long sequences." arXiv preprint arXiv:2303.06349 (2023).
