A toolbox for using complex-valued standard network modules in PyTorch.
ComplexNN: Complex Neural Network Modules
This repository provides complex-valued versions of standard PyTorch modules without any extra trainable parameters. The parameters and calling conventions of these modules match those of their PyTorch counterparts, so there is no additional learning cost. This is made possible by PyTorch's built-in support for complex-valued gradients (autograd). Please refer to the documentation for details.
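Complex-valued autograd is built into PyTorch itself. As a minimal sketch using only plain PyTorch (no ComplexNN calls), a real-valued loss of a complex weight matrix yields a complex gradient, and stepping against that gradient lowers the loss. The tensor names and the squared-magnitude loss are illustrative choices, not taken from ComplexNN.

```python
import torch

torch.manual_seed(0)

# Complex weight matrix (trainable) and complex input vector.
W = torch.randn(4, 3, dtype=torch.cfloat, requires_grad=True)
x = torch.randn(3, dtype=torch.cfloat)

# A real-valued loss of complex quantities: squared magnitude of W @ x.
y = W @ x
loss = y.abs().pow(2).sum()
loss.backward()

# For this loss, PyTorch's convention gives dL/dW = 2 * y * conj(x)^T.
expected = 2 * torch.outer(y.detach(), x.conj())
print(W.grad.dtype)  # torch.complex64
print(torch.allclose(W.grad, expected, atol=1e-5))

# One plain gradient-descent step on the complex weights.
with torch.no_grad():
    W -= 0.01 * W.grad
new_loss = (W @ x).abs().pow(2).sum()
print(new_loss.item() < loss.item())
```

PyTorch computes the conjugate Wirtinger gradient for complex inputs, whose negative is a descent direction for a real-valued loss; this is why ordinary gradient-based training applies to complex parameters.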
Install
To install complexNN for the first time:
```
pip install complexNN
```
To upgrade a previous installation of complexNN to the most recent version:
```
pip install --upgrade complexNN
```
Versions
v0.0.1
Provides complex-valued versions of the basic standard PyTorch network modules.
v0.1.1
Adds support for the Linear Recurrent Unit (LRU).
v0.1.2
Bug fixes and new features.
v0.2.1
Bug fixes and new features.
v0.3.1
Code restructuring, bug fixes, and new features.
Modules
The complex-valued modules include:
| complexLayer | complexRNNcell | complexActivation | complexFunction | complexRNN |
|---|---|---|---|---|
| Linear | RNN Cell | ReLU | BatchNorm 1d/2d/3d | RNN |
| MLP | GRU Cell | GELU | LayerNorm | GRU |
| Conv 1d/2d | LSTM Cell | Tanh | Dropout 1d/2d | LSTM |
| | LRU Cell [1] | Sigmoid | Avg/Max Pool | |
Support for other modules may be added in future releases.
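The module table does not say how the activations treat complex inputs. A common scheme, assumed here and not confirmed for ComplexNN, applies the real activation to the real and imaginary parts separately (often called a split activation). The helper name `split_activation` is hypothetical.

```python
import torch
import torch.nn.functional as F


def split_activation(z: torch.Tensor, act) -> torch.Tensor:
    """Apply a real activation to the real and imaginary parts independently.

    This split scheme is an assumption for illustration; ComplexNN may define
    its activations differently.
    """
    return torch.complex(act(z.real), act(z.imag))


z = torch.tensor([1.0 - 2.0j, -0.5 + 0.5j, 3.0 + 0.0j])  # complex64

for name, act in [("relu", F.relu), ("gelu", F.gelu), ("tanh", torch.tanh), ("sigmoid", torch.sigmoid)]:
    print(name, split_activation(z, act))

# ReLU zeroes each negative component on its own: real -0.5 and imag -2.0 become 0.
relu_out = split_activation(z, F.relu)
```

Some PyTorch functions, such as `torch.tanh`, also accept complex tensors directly and then compute the complex-analytic function, which generally differs from the split form above.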
Examples
Convolutional neural network
```python
import torch
from complexNN.complexLayer import complexConv1d, complexConv2d

if __name__ == '__main__':
    batch_size, in_channels, out_channels, seq_len = 10, 3, 16, 10

    # 1-D convolution over a complex sequence: (batch, channels, length)
    conv_tensor = torch.rand((batch_size, in_channels, seq_len), dtype=torch.cfloat)
    conv1d = complexConv1d(in_channels, out_channels, padding='same')
    print(conv1d(conv_tensor).shape)

    # 2-D convolution over a complex image: (batch, channels, height, width)
    H, W = 256, 256
    conv2d_tensor = torch.rand((batch_size, in_channels, H, W), dtype=torch.cfloat)
    conv2d = complexConv2d(in_channels, out_channels, padding=1)
    print(conv2d(conv2d_tensor).shape)
```
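The example does not show how a complex convolution is evaluated. One standard construction, sketched here as an assumption rather than ComplexNN's actual code, splits the kernel into W = A + iB and uses real convolutions: W * x = (A * x_r - B * x_i) + i(A * x_i + B * x_r). The class name `SplitComplexConv1d` is hypothetical; bias is omitted so the result can be checked against a direct complex cross-correlation built from `unfold` and `einsum`.

```python
import torch
import torch.nn as nn


class SplitComplexConv1d(nn.Module):
    """Hypothetical complex Conv1d built from two real convolutions (bias omitted)."""

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__()
        self.conv_re = nn.Conv1d(in_channels, out_channels, kernel_size, bias=False, **kwargs)  # A
        self.conv_im = nn.Conv1d(in_channels, out_channels, kernel_size, bias=False, **kwargs)  # B

    def forward(self, x):
        # (A + iB) * (x_r + i x_i) = (A*x_r - B*x_i) + i(A*x_i + B*x_r)
        real = self.conv_re(x.real) - self.conv_im(x.imag)
        imag = self.conv_re(x.imag) + self.conv_im(x.real)
        return torch.complex(real, imag)


torch.manual_seed(0)
conv = SplitComplexConv1d(3, 16, kernel_size=3)
x = torch.randn(10, 3, 20, dtype=torch.cfloat)

with torch.no_grad():
    out = conv(x)
    # Reference: direct complex cross-correlation, the operation nn.Conv1d computes.
    w = torch.complex(conv.conv_re.weight, conv.conv_im.weight)  # (out, in, k)
    patches = x.unfold(2, 3, 1)                                   # (batch, in, L_out, k)
    ref = torch.einsum("bclk,ock->bol", patches, w)

print(out.shape)  # torch.Size([10, 16, 18])
print(torch.allclose(out, ref, atol=1e-5))
```

The same pattern carries over to 2-D by swapping in `nn.Conv2d`. The two real weight tensors hold exactly the real and imaginary parts of one complex kernel.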
Multilayer perceptron
```python
import torch
from complexNN.complexLayer import complexMLP

if __name__ == '__main__':
    batch_size, input_size, hidden_size, output_size = 10, 10, 20, 15
    # Complex input of shape (batch, input_size)
    input_tensor = torch.rand((batch_size, input_size), dtype=torch.cfloat)
    mlp = complexMLP(input_size, hidden_size, output_size, num_layers=3)
    out = mlp(input_tensor)
    print(out.shape)
```
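To make the `num_layers` argument concrete, here is a minimal sketch of what a complex MLP of this shape could look like in plain PyTorch. It assumes `num_layers` counts linear layers, stores each weight as a single complex `nn.Parameter`, and places a split ReLU between layers. The class names `ComplexLinearSketch` and `ComplexMLPSketch` are hypothetical, and this is not ComplexNN's implementation.

```python
import math

import torch
import torch.nn as nn


class ComplexLinearSketch(nn.Module):
    """Hypothetical complex affine layer y = x W^T + b with complex W and b."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Complex normal init scaled by 1/sqrt(fan_in); an illustrative choice.
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features, dtype=torch.cfloat) / math.sqrt(in_features)
        )
        self.bias = nn.Parameter(torch.zeros(out_features, dtype=torch.cfloat))

    def forward(self, x):
        return x @ self.weight.T + self.bias


class ComplexMLPSketch(nn.Module):
    """Hypothetical MLP: num_layers complex linear layers with split ReLU in between."""

    def __init__(self, input_size, hidden_size, output_size, num_layers=3):
        super().__init__()
        sizes = [input_size] + [hidden_size] * (num_layers - 1) + [output_size]
        self.layers = nn.ModuleList(
            ComplexLinearSketch(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])
        )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                # Split ReLU: applied to real and imaginary parts independently.
                x = torch.complex(torch.relu(x.real), torch.relu(x.imag))
        return x


torch.manual_seed(0)
mlp = ComplexMLPSketch(10, 20, 15, num_layers=3)
x = torch.randn(10, 10, dtype=torch.cfloat)
out = mlp(x)
print(out.shape)  # torch.Size([10, 15])

# A real-valued loss lets autograd produce complex gradients for every parameter.
loss = out.abs().pow(2).mean()
loss.backward()
print(all(p.grad is not None and p.grad.is_complex() for p in mlp.parameters()))
```

Storing one complex parameter per weight is equivalent to storing its real and imaginary parts as two real tensors; the choice mainly affects how initialization and optimizers see the parameters.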
Recurrent neural networks
```python
import torch
from complexNN.complexRNN import complexRNN, complexGRU, complexLSTM

if __name__ == '__main__':
    batch_size, input_size, hidden_size, seq_len, num_layers = 10, 10, 20, 15, 3
    # Input layout (seq_len, batch, input_size), as in torch.nn.RNN with batch_first=False
    input_tensor = torch.rand((seq_len, batch_size, input_size), dtype=torch.cfloat)
    # Complex initial hidden and cell states: (num_layers, batch, hidden_size)
    h0 = torch.zeros((num_layers, batch_size, hidden_size), dtype=torch.cfloat)
    c0 = torch.zeros((num_layers, batch_size, hidden_size), dtype=torch.cfloat)

    rnn = complexRNN(input_size, hidden_size, num_layers)
    gru = complexGRU(input_size, hidden_size, num_layers)
    lstm = complexLSTM(input_size, hidden_size, num_layers)

    rnn_out, _ = rnn(input_tensor, h0)
    gru_out, _ = gru(input_tensor, h0)
    lstm_out, _ = lstm(input_tensor, (h0, c0))  # LSTM takes an (h0, c0) tuple
    print(rnn_out.shape, gru_out.shape, lstm_out.shape)
```
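The RNN example does not exercise the LRU cell listed in the module table. As a rough sketch of the recurrence described in [1], and not ComplexNN's implementation, each hidden unit has a complex diagonal eigenvalue λ = exp(-exp(ν) + i·exp(θ)), so |λ| < 1 by construction. The state evolves as h_t = λ ⊙ h_{t-1} + γ ⊙ (B x_t) with output y_t = Re(C h_t) + D ⊙ x_t, where γ = sqrt(1 - |λ|²). The class name `LRUSketch` is hypothetical, the initialization only loosely follows [1], and γ is recomputed from λ rather than trained separately, which is a simplification.

```python
import math

import torch
import torch.nn as nn


class LRUSketch(nn.Module):
    """Hypothetical single LRU layer following the recurrence in [1].

    h_t = lam * h_{t-1} + gamma * (B x_t)
    y_t = Re(C h_t) + D * x_t
    """

    def __init__(self, input_size, hidden_size, r_min=0.4, r_max=0.9, max_phase=2 * math.pi):
        super().__init__()
        # Sample eigenvalues with r_min <= |lam| < r_max and phase in [0, max_phase).
        u1, u2 = torch.rand(hidden_size), torch.rand(hidden_size)
        nu = -0.5 * torch.log(u1 * (r_max ** 2 - r_min ** 2) + r_min ** 2)
        theta = max_phase * u2
        # Exponential parameterization keeps |lam| = exp(-exp(nu_log)) strictly below 1.
        self.nu_log = nn.Parameter(torch.log(nu))
        self.theta_log = nn.Parameter(torch.log(theta))
        # Complex input/output projections and a real skip connection.
        self.B = nn.Parameter(torch.randn(hidden_size, input_size, dtype=torch.cfloat) / math.sqrt(input_size))
        self.C = nn.Parameter(torch.randn(input_size, hidden_size, dtype=torch.cfloat) / math.sqrt(hidden_size))
        self.D = nn.Parameter(torch.randn(input_size))

    def eigenvalues(self):
        return torch.exp(-torch.exp(self.nu_log) + 1j * torch.exp(self.theta_log))

    def forward(self, x):
        # x: real tensor of shape (seq_len, batch, input_size)
        lam = self.eigenvalues()
        gamma = torch.sqrt(1 - lam.abs() ** 2)  # simplification: derived from lam, not trained
        h = torch.zeros(x.shape[1], lam.shape[0], dtype=torch.cfloat, device=x.device)
        outputs = []
        for x_t in x:
            h = lam * h + gamma * (x_t.to(torch.cfloat) @ self.B.T)
            outputs.append((h @ self.C.T).real + self.D * x_t)
        return torch.stack(outputs), h


torch.manual_seed(0)
lru = LRUSketch(input_size=10, hidden_size=20)
x = torch.rand(15, 4, 10)          # (seq_len, batch, input_size)
y, h_last = lru(x)
print(y.shape, h_last.shape)       # torch.Size([15, 4, 10]) torch.Size([4, 20])
y.sum().backward()                 # gradients reach the complex parameters B and C
```

Because the recurrence is diagonal, [1] evaluates it with a parallel scan over time; the sequential loop here favors readability over speed.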
Cite as
```bibtex
@misc{ComplexNN,
  title  = {ComplexNN: Complex Neural Network Modules},
  author = {Xinyuan Liao},
  url    = {https://github.com/XinyuanLiao/ComplexNN},
  year   = {2023}
}
```
Reference
[1] Orvieto, Antonio, et al. "Resurrecting recurrent neural networks for long sequences." arXiv preprint arXiv:2303.06349 (2023).