Minimal Flax implementation of FNet: Mixing Tokens with Fourier Transforms

In this paper, the authors introduce FNet, a Transformer-inspired architecture in which the learnable self-attention layer is replaced by an unparameterized Fourier Transform. The work highlights the potential of linear transformations as a drop-in replacement for the attention mechanism. The authors found Fourier Transforms to be a particularly effective mixing mechanism, in part because of the highly efficient FFT. Remarkably, this unparameterized mixing mechanism can yield relatively competitive models. This work, among many others, is particularly important for deploying Transformer-like models in low-compute scenarios.
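To make the mixing step concrete, here is a minimal JAX sketch of the Fourier mixing described in the paper: a 1D DFT along the hidden dimension, a 1D DFT along the sequence dimension, and only the real part kept. The function name `fourier_mix` and the `(batch, seq_len, hidden_dim)` layout are assumptions made for illustration; this is not necessarily how this package implements the layer.

```python
import jax.numpy as jnp
from jax import random


def fourier_mix(x):
    """Mix tokens with a parameter-free 2D DFT, keeping only the real part.

    x: array of assumed shape (batch, seq_len, hidden_dim).
    """
    # 1D DFT along the hidden dimension, then 1D DFT along the sequence dimension
    mixed = jnp.fft.fft(jnp.fft.fft(x, axis=-1), axis=-2)
    # Drop the imaginary part so downstream layers stay real-valued
    return mixed.real


x = random.normal(random.PRNGKey(0), (2, 8, 32))
y = fourier_mix(x)
print(y.shape)  # (2, 8, 32)
```

The function has no trainable parameters, which is why the mixing sublayer adds nothing to the model's parameter count.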

Installation

You can install this package from PyPI:

pip install fnet-flax

Or directly from GitHub:

pip install --upgrade git+https://github.com/SauravMaheshkar/FNet-Flax.git

Usage

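import numpy as np
from jax import random
from fnet_flax import FNet

# Dummy input, presumably laid out as (batch, sequence length, hidden dim)
x = np.random.randn(2, 8, 32)
# Separate PRNG streams for parameter initialization and dropout
init_rngs = {"params": random.PRNGKey(0), "dropout": random.PRNGKey(1)}
# Flax's init returns the initialized variables (parameters), not a callable model
variables = FNet(depth=2, dim=32).init(init_rngs, x)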

Development

1. Conda Approach

conda env create --name <env-name> sauravmaheshkar/fnet
conda activate <env-name>

2. Docker Approach

docker pull ghcr.io/sauravmaheshkar/fnet-dev:latest
docker run -it -d --name <container_name> ghcr.io/sauravmaheshkar/fnet-dev

Use the Remote Containers extension in VS Code to attach to the running container. The code resides in the code/ directory.

Alternatively, you can pull the image from Docker Hub:

docker pull sauravmaheshkar/fnet-dev

Citations

@misc{leethorp2021fnet,
      title={FNet: Mixing Tokens with Fourier Transforms},
      author={James Lee-Thorp and Joshua Ainslie and Ilya Eckstein and Santiago Ontanon},
      year={2021},
      eprint={2105.03824},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
