
A toolbox for using complex-valued standard network modules in PyTorch.


ComplexNN: Complex Neural Network Modules


Recent versions of PyTorch support matrix operations and gradient descent on complex-valued parameters, so this repository provides complex-valued versions of several standard PyTorch network modules. Compared with using two sets of real parameters to represent the real and imaginary parts separately, using complex numbers directly as network parameters halves the number of trainable parameters, which results in faster training.
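
As a minimal plain-PyTorch sketch of this point (it does not use complexNN itself, and the sizes are arbitrary), a single complex weight matrix is differentiable end to end and counts as one set of parameters:

import torch

# One cfloat weight matrix is a single set of trainable parameters,
# whereas a separate real and an imaginary matrix would be two.
w = torch.nn.Parameter(torch.randn(3, 3, dtype=torch.cfloat))
x = torch.randn(4, 3, dtype=torch.cfloat)

y = x @ w                    # complex matrix multiplication
loss = y.abs().pow(2).sum()  # real-valued loss so backward() is well defined
loss.backward()

print(w.numel())             # 9 complex parameters
print(w.grad.dtype)          # torch.complex64 -- the gradient is complex too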

Install

pip install complexNN

Module

The complex form modules include

complexLinear   complexRNNcell   complexActivation   complexFunction
Linear          RNN Cell         Relu                BatchNorm1d
                GRU Cell         Gelu                LayerNorm1d
                LSTM Cell        Tanh
                                 Sigmoid

Other modules will be added in future updates.
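
As a quick taste of the modules listed above, here is a minimal sketch combining a complex linear layer with a complex activation on a complex-valued batch. The import paths and call signatures follow the examples below; the batch and feature sizes are arbitrary, and the complexFunction layers are omitted because their exact signatures are not shown here.

import torch
from complexLinear import complexLinear
from complexActivation import complexTanh

x = torch.randn(8, 16, dtype=torch.cfloat)  # batch of 8 complex feature vectors

layer = complexLinear(16, 32)  # complex counterpart of nn.Linear(16, 32)
y = complexTanh(layer(x))      # complex activation applied element-wise

print(y.shape)  # expected: torch.Size([8, 32]), with a complex dtype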

Examples

Multilayer perceptron

import torch.nn as nn
from complexActivation import complexTanh
from complexLinear import complexLinear


class complexMLP(nn.Module):
    """
    Complex Multilayer Perceptron
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(complexMLP, self).__init__()
        self.num_layers = num_layers
        self.input_layer = complexLinear(input_size, hidden_size)
        self.hidden_layers = nn.ModuleList([complexLinear(hidden_size, hidden_size) for _ in range(num_layers - 1)])
        self.output_layer = complexLinear(hidden_size, output_size)

    def forward(self, x):
        x = complexTanh(self.input_layer(x))
        for i in range(self.num_layers - 1):
            x = complexTanh(self.hidden_layers[i](x))
        output = self.output_layer(x)
        return output
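
A usage sketch for the class above (the sizes are illustrative assumptions, not defaults of the library):

import torch

mlp = complexMLP(input_size=16, hidden_size=32, output_size=4, num_layers=3)
x = torch.randn(8, 16, dtype=torch.cfloat)  # batch of complex inputs
out = mlp(x)
print(out.shape)  # expected: torch.Size([8, 4])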

Recurrent neural network

import torch.nn as nn
from complexRNNcell import complexRNNCell


class complexRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(complexRNN, self).__init__()
        self.num_layers = num_layers
        self.rnn_layers = nn.ModuleList()
        for _ in range(num_layers):
            self.rnn_layers.append(complexRNNCell(input_size, hidden_size))
            input_size = hidden_size

    def forward(self, x, h_0):
        h_prev = h_0
        for i in range(self.num_layers):
            h_prev = self.rnn_layers[i](x, h_prev)
            x = h_prev  # each layer's output becomes the next layer's input
        return h_prev
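
And a corresponding usage sketch, assuming complexRNNCell takes (input, hidden) and returns the next hidden state, as the forward pass above implies; the sizes are illustrative assumptions:

import torch

rnn = complexRNN(input_size=16, hidden_size=32, num_layers=2)
x = torch.randn(8, 16, dtype=torch.cfloat)    # one time step of complex inputs
h_0 = torch.zeros(8, 32, dtype=torch.cfloat)  # initial complex hidden state
h_n = rnn(x, h_0)
print(h_n.shape)  # expected: torch.Size([8, 32])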

Cite as

@misc{ComplexNN,
      title={ComplexNN: Complex Neural Network Modules},
      author={Xinyuan Liao},
      url={https://github.com/XinyuanLiao/ComplexNN},
      year={2023}
}
