Unofficial JAX implementations of deep learning research papers

JAX Models

Table of Contents

  1. About The Project
  2. Getting Started
       • Prerequisites
       • Installation
       • Usage
  3. Contributing
  4. License
  5. Contact

About The Project

The JAX Models repository aims to provide open-source JAX/Flax implementations of research papers that were released either without code or with code written in frameworks other than JAX. The goal of this project is to build a collection of the models, layers, activations and other utilities most commonly used in research. All papers and any derived or translated code are cited either in this README or in the docstrings. If you think a citation is missing, please raise an issue.

All implementations provided here are available on Papers With Code.


Available model implementations for JAX are:

  1. MetaFormer is Actually What You Need for Vision (Weihao Yu et al., 2021)

  2. Augmenting Convolutional networks with attention-based aggregation (Hugo Touvron et al., 2021)

  3. MPViT: Multi-Path Vision Transformer for Dense Prediction (Youngwan Lee et al., 2021)

  4. MLP-Mixer: An all-MLP Architecture for Vision (Ilya Tolstikhin et al., 2021)

  5. Patches Are All You Need? (Asher Trockman and J. Zico Kolter, 2022)

  6. SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers (Enze Xie et al., 2021)

  7. A ConvNet for the 2020s (Zhuang Liu et al., 2022)

  8. Masked Autoencoders Are Scalable Vision Learners (Kaiming He et al., 2021)

  9. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows (Ze Liu et al., 2021)


Available layers for out-of-the-box integration:

  1. DropPath (Stochastic Depth) (Gao Huang et al., 2016)

  2. Squeeze-and-Excitation Layer (Jie Hu et al., 2019)

  3. Depthwise Convolution (François Chollet, 2017)
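The layers above are listed without a description. As a hedged illustration of what DropPath (stochastic depth) does, here is a minimal pure-Python sketch, not taken from jax_models: during training, each sample's residual branch is zeroed with probability drop_rate, and surviving branches are scaled by 1 / (1 - drop_rate) so the expected output is unchanged; at inference the branch passes through untouched. The names drop_path, residual_block and Batch are hypothetical, and nested lists stand in for arrays. Note that Huang et al. drop a block for the whole mini-batch and instead rescale by the survival probability at test time; the per-sample, train-time-rescaled form shown here is the variant commonly called DropPath (for example in the timm library). In JAX the same idea is usually written by sampling a mask of shape (batch, 1, ..., 1) with jax.random.bernoulli and broadcasting it over the remaining axes.

```python
import random
from typing import Callable, List

Batch = List[List[float]]  # hypothetical stand-in for a (batch, features) array


def drop_path(branch: Batch, drop_rate: float, rng: random.Random,
              deterministic: bool = False) -> Batch:
    """Per-sample stochastic depth on a batch of residual-branch outputs.

    Training: each sample's branch is zeroed with probability drop_rate,
    otherwise scaled by 1 / (1 - drop_rate) so its expected value is unchanged.
    Inference (deterministic=True): the branch is returned unchanged.
    """
    if not 0.0 <= drop_rate < 1.0:
        raise ValueError("drop_rate must be in [0, 1)")
    if deterministic or drop_rate == 0.0:
        return [list(sample) for sample in branch]
    keep_prob = 1.0 - drop_rate
    out = []
    for sample in branch:
        # One Bernoulli draw per sample, shared by all of that sample's features.
        keep = rng.random() < keep_prob
        scale = 1.0 / keep_prob if keep else 0.0
        out.append([v * scale for v in sample])
    return out


def residual_block(x: Batch, branch_fn: Callable[[List[float]], List[float]],
                   drop_rate: float, rng: random.Random,
                   deterministic: bool = False) -> Batch:
    """Compute x + DropPath(branch_fn(x)) sample by sample."""
    branch = drop_path([branch_fn(s) for s in x], drop_rate, rng, deterministic)
    return [[a + b for a, b in zip(xs, bs)] for xs, bs in zip(x, branch)]
```

For example, with drop_rate=0.5 and a branch that doubles its input, each sample returned by residual_block comes out either equal to the input (branch dropped) or five times the input (branch kept and scaled by 2), while deterministic=True always yields three times the input.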

Prerequisites

The prerequisites can be installed separately from the requirements.txt file in the repository's main directory:

pip install -r requirements.txt

Using a virtual environment is highly recommended to avoid version incompatibilities.

Installation

This project is built with Python 3 for recent JAX/Flax versions and can be installed directly via pip:

pip install jax-models

If you want the latest development version, you can clone the repository instead:

git clone https://github.com/DarshanDeshpande/jax-models.git

Usage

To see all model architectures available:

from jax_models.models.model_registry import list_models
from pprint import pprint

pprint(list_models())

To load your desired model:

from jax_models.models.model_registry import load_model

load_model('mpvit-base', attach_head=True, num_classes=1000, dropout=0.1)

Contributing

Please raise an issue if any implementation gives incorrect results or crashes unexpectedly during training or inference, or if a citation is missing.

You can also contribute to jax_models by supporting me with compute resources, or by using your own compute resources to provide pretrained weights.

If you wish to donate to this initiative, please contact me by email (see Contact below).


License

Distributed under the Apache 2.0 License. See LICENSE for more information.

Contact

Feel free to reach out with any issues or requests related to these implementations.

Darshan Deshpande - Email | Twitter | LinkedIn

Download files

Download the file for your platform.

Source Distribution

jax_models-0.0.3.tar.gz (25.6 kB)

Uploaded Source

Built Distribution

jax_models-0.0.3-py3-none-any.whl (31.7 kB)

Uploaded Python 3

File details

Details for the file jax_models-0.0.3.tar.gz.

File metadata

  • Download URL: jax_models-0.0.3.tar.gz
  • Upload date:
  • Size: 25.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.8.2 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for jax_models-0.0.3.tar.gz
Algorithm     Hash digest
SHA256        43dc86a39d69b9c6f7189e3d9973a5a48d1e0a416ac9e2f1cd7eef9c15064fa3
MD5           ab64629f636050d8c065ca35b0a9461e
BLAKE2b-256   e76a3fbf5cdd45e56d61de523d42955342c9bdd8a0bc674917660098223af091

See more details on using hashes here.

File details

Details for the file jax_models-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: jax_models-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 31.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.8.2 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for jax_models-0.0.3-py3-none-any.whl
Algorithm     Hash digest
SHA256        bc9f7e04189efbd4f38c44ddf7059027bd126818f85a54e678c056b7f8dc1bbf
MD5           a293522e7f520afafd471cea2af11a07
BLAKE2b-256   40c944d3da7a8531067afa0635a03caa147969213395307671a7b9650aff492e

See more details on using hashes here.
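The SHA256 digests listed for jax_models 0.0.3 can be checked locally before installing. The following is a minimal sketch using only Python's standard hashlib module, assuming the downloaded files keep their published names; the helper names sha256_of and verify_download are hypothetical. pip can also enforce this check itself in hash-checking mode (pip install --require-hashes -r requirements.txt, with the requirement pinned as jax-models==0.0.3 --hash=sha256:<digest>).

```python
import hashlib
from pathlib import Path

# SHA256 digests copied from the "File hashes" tables for release 0.0.3.
EXPECTED_SHA256 = {
    "jax_models-0.0.3.tar.gz":
        "43dc86a39d69b9c6f7189e3d9973a5a48d1e0a416ac9e2f1cd7eef9c15064fa3",
    "jax_models-0.0.3-py3-none-any.whl":
        "bc9f7e04189efbd4f38c44ddf7059027bd126818f85a54e678c056b7f8dc1bbf",
}


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, expected=EXPECTED_SHA256):
    """Return True if the file's SHA256 matches the published digest for its name."""
    name = Path(path).name
    if name not in expected:
        raise KeyError(f"no published digest for {name!r}")
    return sha256_of(path) == expected[name]
```

For example, verify_download("jax_models-0.0.3.tar.gz") should return True for an intact download of the source distribution; False means the file differs from what was published and should not be installed.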
