A toolbox for using complex-valued standard network modules in PyTorch.
Since recent versions of PyTorch support matrix operations and gradient descent on complex-valued parameters, this repository provides complex-valued versions of several standard PyTorch network modules. Compared with using two sets of real parameters to represent the real and imaginary parts of each complex parameter, using complex numbers directly as network parameters halves the number of trainable parameters, which can speed up training.
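The parameter-count argument can be checked with a small sketch in plain Python (no PyTorch or complexNN; the helpers `complex_linear`, `split_real_linear`, and `param_counts` are hypothetical). It computes one linear map y = Wx + b two ways: with complex weights, and with separate real and imaginary weight sets using (W_r + iW_i)(x_r + ix_i) = (W_r x_r - W_i x_i) + i(W_r x_i + W_i x_r). Both give the same output; the complex form has half as many parameter entries, although each complex entry still stores two real numbers.

```python
import random


def complex_linear(W, b, x):
    """y = W x + b with complex weights W (out x in), bias b (out) and input x (in)."""
    return [sum(w * xi for w, xi in zip(row, x)) + bj for row, bj in zip(W, b)]


def matvec(M, v):
    """Plain real matrix-vector product."""
    return [sum(m * vi for m, vi in zip(row, v)) for row in M]


def split_real_linear(Wr, Wi, br, bi, xr, xi):
    """Same map with separate real/imaginary parameter sets:
    (Wr + iWi)(xr + i xi) + (br + i bi)
      = (Wr xr - Wi xi + br) + i (Wr xi + Wi xr + bi)
    """
    yr = [p - q + c for p, q, c in zip(matvec(Wr, xr), matvec(Wi, xi), br)]
    yi = [p + q + c for p, q, c in zip(matvec(Wr, xi), matvec(Wi, xr), bi)]
    return yr, yi


def param_counts(in_features, out_features):
    """(complex entries, real entries) for one linear layer with bias."""
    n_complex = in_features * out_features + out_features
    return n_complex, 2 * n_complex


# Random 4 -> 3 complex layer and a random complex input
random.seed(0)
n_in, n_out = 4, 3
W = [[complex(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(n_in)] for _ in range(n_out)]
b = [complex(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(n_out)]
x = [complex(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(n_in)]

y = complex_linear(W, b, x)
yr, yi = split_real_linear(
    [[w.real for w in row] for row in W], [[w.imag for w in row] for row in W],
    [c.real for c in b], [c.imag for c in b],
    [v.real for v in x], [v.imag for v in x],
)
# Both parametrizations compute the same outputs
assert all(abs(yc - complex(r, im)) < 1e-9 for yc, r, im in zip(y, yr, yi))
print(param_counts(n_in, n_out))  # (15, 30)
```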
Install
To install complexNN for the first time:
pip install complexNN
To upgrade a previous installation of complexNN to the most recent version:
pip install --upgrade complexNN
Versions
v0.0.1
Provides complex-valued versions of the basic standard PyTorch network modules.
v0.1.1
Adds support for the Linear Recurrent Unit (LRU).
v0.1.2
Bug fixes. Adds support for BatchNorm2d and BatchNorm3d.
Modules
The complex-valued modules include:
| complexLayer | complexRNNcell | complexActivation | complexFunction |
|---|---|---|---|
| Linear | RNN Cell | Relu | BatchNorm |
|  | GRU Cell | Gelu | LayerNorm |
|  | LSTM Cell | Tanh | Dropout |
|  | LRU Cell [1] | Sigmoid |  |
Note that the native torch.nn.Dropout is supported! Other modules will be considered in future updates.
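The documentation does not say how complexRelu, complexTanh and the other activations act on complex inputs. One common convention is a "split" activation that applies the real-valued function to the real and imaginary parts independently. The sketch below assumes that convention; it is not taken from complexNN's source, and the function names are hypothetical.

```python
import cmath
import math


def split_tanh(z: complex) -> complex:
    """Split activation: tanh applied separately to the real and imaginary parts."""
    return complex(math.tanh(z.real), math.tanh(z.imag))


def split_relu(z: complex) -> complex:
    """Split activation: ReLU applied separately to the real and imaginary parts."""
    return complex(max(z.real, 0.0), max(z.imag, 0.0))


def split_sigmoid(z: complex) -> complex:
    """Split activation: logistic sigmoid applied separately to each part."""
    def sig(t):
        # Numerically stable identity: sigmoid(t) = (1 + tanh(t / 2)) / 2
        return 0.5 * (1.0 + math.tanh(0.5 * t))
    return complex(sig(z.real), sig(z.imag))


z = complex(1.0, -2.0)
print(split_relu(z))  # (1+0j)
# The holomorphic cmath.tanh(z) is a different function from split_tanh(z)
print(split_tanh(z), cmath.tanh(z))
```

Split activations keep the real and imaginary channels bounded independently, whereas the holomorphic tanh has poles on the imaginary axis; which behavior a given library uses should be checked in its source.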
Examples
Multilayer perceptron
```python
import torch.nn as nn
from complexNN.complexActivation import complexTanh
from complexNN.complexLayer import complexLinear


class complexMLP(nn.Module):
    """
    Complex Multilayer Perceptron
    """
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(complexMLP, self).__init__()
        self.num_layers = num_layers
        self.input_layer = complexLinear(input_size, hidden_size)
        # num_layers - 1 hidden-to-hidden layers
        self.hidden_layers = nn.ModuleList([complexLinear(hidden_size, hidden_size) for _ in range(num_layers - 1)])
        self.output_layer = complexLinear(hidden_size, output_size)
        # Native torch.nn.Dropout applied to complex activations
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = complexTanh(self.input_layer(x))
        x = self.dropout(x)
        for layer in self.hidden_layers:
            x = complexTanh(layer(x))
            x = self.dropout(x)
        output = self.output_layer(x)
        return output
```
Recurrent neural network
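```python
import torch.nn as nn
from complexNN.complexRNNcell import complexRNNCell


class complexRNN(nn.Module):
    """
    Stacked complex RNN cells, applied to a single time step
    """
    def __init__(self, input_size, hidden_size, num_layers):
        super(complexRNN, self).__init__()
        self.num_layers = num_layers
        self.rnn_layers = nn.ModuleList()
        for _ in range(num_layers):
            self.rnn_layers.append(complexRNNCell(input_size, hidden_size))
            # Layers after the first receive the previous layer's hidden state as input
            input_size = hidden_size

    def forward(self, x, h_0):
        # h_0[i] is the previous hidden state of layer i
        # (a list of tensors or a tensor of shape (num_layers, ...))
        h_next = []
        for i in range(self.num_layers):
            x = self.rnn_layers[i](x, h_0[i])  # output of layer i feeds layer i + 1
            h_next.append(x)
        # h_next[-1] is the top layer's new hidden state
        return h_next
```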
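The LRU cell listed in the modules table is based on the Linear Recurrent Unit of [1], whose state update is a diagonal, complex-valued linear recurrence rather than a nonlinear RNN step. The plain-Python sketch below follows the parametrization described in [1]: eigenvalues λ = exp(-exp(ν) + i·exp(θ)) so that |λ| < 1, an input normalization γ = sqrt(1 - |λ|²), and output Re(C x) + D u. It is not complexNN's LRU cell; the helper names `lru_lambda` and `lru_scan` are hypothetical, and γ is kept fixed here although [1] treats it as trainable.

```python
import cmath
import math


def lru_lambda(nu, theta):
    """Diagonal eigenvalues lambda_j = exp(-exp(nu_j) + i*exp(theta_j)); |lambda_j| < 1."""
    return [cmath.exp(complex(-math.exp(n), math.exp(t))) for n, t in zip(nu, theta)]


def lru_scan(u_seq, nu, theta, B, C, D):
    """Run the LRU recurrence over a sequence of real input vectors u_k (length H).

    B: N x H complex, C: H x N complex, D: length-H real (elementwise skip).
    x_k = lambda * x_{k-1} + gamma * (B u_k)   (elementwise over the N states)
    y_k = Re(C x_k) + D * u_k
    """
    lam = lru_lambda(nu, theta)
    # Input normalization; [1] initializes gamma to this value, here it stays fixed
    gamma = [math.sqrt(1.0 - abs(l) ** 2) for l in lam]
    x = [0j] * len(lam)
    ys = []
    for u in u_seq:
        Bu = [sum(b * ui for b, ui in zip(row, u)) for row in B]
        x = [l * xk + g * bu for l, xk, g, bu in zip(lam, x, gamma, Bu)]
        y = [sum(c * xk for c, xk in zip(row, x)).real + d * ui
             for row, d, ui in zip(C, D, u)]
        ys.append(y)
    return ys


# Impulse response of a 1-state, 1-channel LRU: y_k = Re(lambda^(k-1) * gamma)
ys = lru_scan([[1.0], [0.0], [0.0]], nu=[0.0], theta=[0.0],
              B=[[1 + 0j]], C=[[1 + 0j]], D=[0.0])
print(ys)
```

Because the recurrence is linear and diagonal, each state channel decays geometrically with rate |λ_j|, which is what lets [1] train it stably on long sequences.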
Cite as
```bibtex
@misc{ComplexNN,
  title={ComplexNN: Complex Neural Network Modules},
  author={Xinyuan Liao},
  url={https://github.com/XinyuanLiao/ComplexNN},
  year={2023}
}
```
Reference
[1] Orvieto, Antonio, et al. "Resurrecting recurrent neural networks for long sequences." arXiv preprint arXiv:2303.06349 (2023).