A toolbox for using complex-valued standard network modules in PyTorch.

ComplexNN: Complex Neural Network Modules

This repository provides complex-valued versions of standard PyTorch network modules, without introducing any extra trainable parameters. The parameters and calling conventions of these modules match their PyTorch counterparts, so there is no additional learning curve. The implementation builds on PyTorch's support for complex gradients; please refer to the PyTorch documentation for details.
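
For background, PyTorch's autograd can differentiate through complex tensors via Wirtinger calculus, which is what makes complex-valued layers trainable in the first place. A minimal sketch in plain PyTorch, independent of this library:

import torch

# A real-valued scalar loss computed from complex tensors can be
# backpropagated; gradients are conjugate Wirtinger derivatives.
z = torch.randn(4, dtype=torch.cfloat, requires_grad=True)
w = torch.randn(4, dtype=torch.cfloat)
loss = (w * z).abs().sum()  # loss is real-valued
loss.backward()
print(z.grad)  # complex-valued gradient, dtype torch.cfloat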

Install

To install complexNN for the first time:

pip install complexNN

To upgrade a previous installation of complexNN to the most recent version:

pip install --upgrade complexNN

Versions

v0.0.1 Provides complex-valued versions of the basic standard PyTorch network modules.

v0.1.1 Adds support for the Linear Recurrent Unit (LRU).

v0.1.2 Bug fixes and new module support.

v0.2.1 Bug fixes and new module support.

v0.3.1 Code structure optimization, bug fixes, and new module support.

Modules

The complex-valued modules include:

complexLayer   complexRNNcell   complexActivation   complexFunction        complexRNN
Linear         RNN Cell         Relu                BatchNorm 1d/2d/3d     RNN
MLP            GRU Cell         Gelu                LayerNorm              GRU
Conv 1d/2d     LSTM Cell        Tanh                Dropout 1d/2d          LSTM
               LRU Cell [1]     Sigmoid             avg/max pool

Other modules will be considered for updates in the future.
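
For instance, the activation and normalization utilities can be dropped into an ordinary forward pass. A minimal sketch; the import names (complexRelu, complexBatchNorm1d) and their call conventions are inferred from the table above, so verify them against your installed version:

import torch
from complexNN.complexActivation import complexRelu       # name inferred from the table above
from complexNN.complexFunction import complexBatchNorm1d  # name inferred from the table above


if __name__ == '__main__':
    batch_size, num_features = 10, 16
    x = torch.rand((batch_size, num_features), dtype=torch.cfloat)

    bn = complexBatchNorm1d(num_features)  # assumed to be a module, like its PyTorch counterpart
    y = complexRelu(bn(x))                 # assumed to be a function applied elementwise
    print(y.shape)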

Examples

Convolutional neural network

import torch
from complexNN.complexLayer import complexConv1d, complexConv2d


if __name__ == '__main__':
    batch_size, in_channels, out_channels, seq_len = 10, 3, 16, 10
    # Complex-valued input for the 1d convolution
    conv1d_tensor = torch.rand((batch_size, in_channels, seq_len), dtype=torch.cfloat)
    conv1d = complexConv1d(in_channels, out_channels, padding='same')
    print(conv1d(conv1d_tensor).shape)

    H, W = 256, 256
    # Complex-valued input for the 2d convolution
    conv2d_tensor = torch.rand((batch_size, in_channels, H, W), dtype=torch.cfloat)
    conv2d = complexConv2d(in_channels, out_channels, padding=1)
    print(conv2d(conv2d_tensor).shape)

Multilayer perceptron

import torch
from complexNN.complexLayer import complexMLP


if __name__ == '__main__':
    batch_size, input_size, hidden_size, output_size = 10, 10, 20, 15
    input_tensor = torch.rand((batch_size, input_size), dtype=torch.cfloat)
    mlp = complexMLP(input_size, hidden_size, output_size, num_layers=3)
    out = mlp(input_tensor)
    print(out.shape)
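
Because complexMLP exposes ordinary PyTorch parameters, the standard optimizer loop applies unchanged. A minimal training-step sketch (the random target and the loss choice here are illustrative, not part of the library):

import torch
from complexNN.complexLayer import complexMLP


if __name__ == '__main__':
    batch_size, input_size, hidden_size, output_size = 10, 10, 20, 15
    mlp = complexMLP(input_size, hidden_size, output_size, num_layers=3)
    optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)

    x = torch.rand((batch_size, input_size), dtype=torch.cfloat)
    target = torch.rand((batch_size, output_size), dtype=torch.cfloat)  # illustrative target

    optimizer.zero_grad()
    out = mlp(x)
    loss = (out - target).abs().pow(2).mean()  # real-valued MSE on complex outputs
    loss.backward()                            # complex autograd handles the rest
    optimizer.step()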

Recurrent neural networks

import torch
from complexNN.complexRNN import complexRNN, complexGRU, complexLSTM


if __name__ == '__main__':
    batch_size, input_size, hidden_size, seq_len, num_layers = 10, 10, 20, 15, 3
    input_tensor = torch.rand((seq_len, batch_size, input_size), dtype=torch.cfloat)
    # Initial hidden (and cell) states, complex-valued to match the input
    h0 = torch.zeros((num_layers, batch_size, hidden_size), dtype=torch.cfloat)
    c0 = torch.zeros((num_layers, batch_size, hidden_size), dtype=torch.cfloat)

    rnn = complexRNN(input_size, hidden_size, num_layers)
    gru = complexGRU(input_size, hidden_size, num_layers)
    lstm = complexLSTM(input_size, hidden_size, num_layers)

    rnn_out, _ = rnn(input_tensor, h0)
    gru_out, _ = gru(input_tensor, h0)
    lstm_out, _ = lstm(input_tensor, (h0, c0))

    print(rnn_out.shape, gru_out.shape, lstm_out.shape)

Cite as

@misc{ComplexNN,
      title={ComplexNN: Complex Neural Network Modules},
      author={Xinyuan Liao},
      url={https://github.com/XinyuanLiao/ComplexNN},
      year={2023}
}

Reference

[1] Orvieto, Antonio, et al. "Resurrecting recurrent neural networks for long sequences." arXiv preprint arXiv:2303.06349 (2023).
