Objax is a machine learning framework that provides an Object Oriented layer for JAX.

Objax

This is not an officially supported Google product.

Objax is an object-oriented (OO) layer for JAX. Its design strives for simplicity and flexibility, with the goal of facilitating experimentation and research in machine learning. The code is intended to be easy to understand and fork, so it is particularly aimed at students and researchers.

This is the developer repository of Objax. It contains very little user documentation; the full documentation is coming soon.

You can find more information on:

User installation guide

Installing is done using pip with the following command:

pip install --upgrade objax

For GPU support, we assume you already have some version of CUDA installed. Here are the extra steps:

# Specify your installed CUDA version.
CUDA_VERSION=11.0
pip install --upgrade https://storage.googleapis.com/jax-releases/cuda`echo $CUDA_VERSION | sed s:\\\.::g`/jaxlib-`python3 -c 'import jaxlib; print(jaxlib.__version__)'`-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-manylinux2010_x86_64.whl
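The one-line command above assembles the wheel URL from three pieces: the CUDA version with the dot removed, the version of the jaxlib already installed by the previous pip step, and the CPython tag of the current interpreter. The following is a minimal sketch of the same string construction in Python, written only to make the pieces visible. The function name `jaxlib_wheel_url` and the jaxlib version `0.1.55` are hypothetical, chosen for illustration; the URL layout is copied from the command above and is not checked against the server.

```python
import sys


def jaxlib_wheel_url(cuda_version, jaxlib_version, python_version=None):
    """Build the wheel URL the same way the shell one-liner does.

    Hypothetical helper for illustration: it only formats a string and
    does not check that the wheel exists.
    """
    if python_version is None:
        # Default to the running interpreter, e.g. (3, 8).
        python_version = sys.version_info[:2]
    # "11.0" -> "cuda110", mirroring the first sed expression.
    cuda_tag = "cuda" + cuda_version.replace(".", "")
    # (3, 8) -> "cp38", mirroring the second sed expression.
    py_tag = "cp{}{}".format(*python_version)
    return (
        "https://storage.googleapis.com/jax-releases/"
        f"{cuda_tag}/jaxlib-{jaxlib_version}-{py_tag}"
        "-none-manylinux2010_x86_64.whl"
    )


# The jaxlib version here is a placeholder, not a recommendation.
print(jaxlib_wheel_url("11.0", "0.1.55", (3, 8)))
```

If the pip command fails with a "not found" error, printing the URL this way can help show which of the three pieces does not match an available wheel.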

Useful shell configurations

Here are a few useful options:

# Prevent JAX from taking the whole GPU memory
# (useful if you want to run several programs on a single GPU)
export XLA_PYTHON_CLIENT_PREALLOCATE=false
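The same setting can also be applied from inside a Python script or notebook instead of the shell. The sketch below assumes the variable is set before JAX is first imported in the process; setting it after JAX has initialized its backend is not expected to have an effect.

```python
import os

# Set this before the first `import jax` / `import objax` in the process,
# so that the value is visible when the backend starts up.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

print(os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"])  # false
```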

Testing your installation

You can run a few basic operations:

import jax
import objax

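# jax.device_count() counts devices on the default backend; without a working GPU setup it counts CPU devices instead.
print(f'Number of devices {jax.device_count()}')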

x = objax.random.normal((100, 4))
m = objax.nn.Linear(4, 5)
print('Matrix product shape', m(x).shape)  # (100, 5)

x = objax.random.normal((100, 3, 32, 32))
m = objax.nn.Conv2D(3, 4, k=3)
print('Conv2D return shape', m(x).shape)  # (100, 4, 32, 32)

If you get errors when running this with CUDA, it most likely means there is a problem with your CUDA or CuDNN installation.
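A broken GPU setup does not always raise an error: JAX may fall back to the CPU. One way to see which backend is in use is to look at the `platform` attribute of the objects returned by `jax.devices()`. The sketch below is an illustration, not part of Objax; `count_platforms` is a hypothetical helper, and the import is guarded so the snippet still runs where JAX is not installed.

```python
from collections import Counter
from types import SimpleNamespace


def count_platforms(devices):
    """Count devices per platform name (for example 'cpu' or 'gpu')."""
    return dict(Counter(d.platform for d in devices))


# Stand-in objects, used here only to show the shape of the result.
fake_devices = [SimpleNamespace(platform="gpu"), SimpleNamespace(platform="gpu")]
print(count_platforms(fake_devices))  # {'gpu': 2}

try:
    import jax

    # Real devices: only CPU entries here suggests the GPU is not being used.
    print(count_platforms(jax.devices()))
except ImportError:
    print("JAX is not installed in this environment")
```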

Installing examples

Clone the code repository:

git clone https://github.com/google/objax.git
cd objax/examples

Developer installation guide

For development purposes, we recommend using virtualenv. On Ubuntu or a similar Linux distribution, the setup is as follows:

# Install virtualenv if you haven't done so already
sudo apt install python3-dev python3-virtualenv python3-tk imagemagick virtualenv pandoc
# Create a virtual environment (for example ~/jax3, you can use your own name here)
virtualenv -p python3 --system-site-packages ~/jax3
# Start the virtual environment
. ~/jax3/bin/activate

# Clone objax git repository.
git clone https://github.com/google/objax.git
cd objax

# Install python dependencies.
pip install --upgrade -r requirements.txt
pip install --upgrade -r docs/requirements.txt
pip install --upgrade -r examples/requirements.txt

# If you have CUDA installed, specify your installed CUDA version.
CUDA_VERSION=11.0
pip install --upgrade https://storage.googleapis.com/jax-releases/cuda`echo $CUDA_VERSION | sed s:\\\.::g`/jaxlib-`python3 -c 'import jaxlib; print(jaxlib.__version__)'`-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-manylinux2010_x86_64.whl

The current folder must be in PYTHONPATH. This can be done with the following command:

export PYTHONPATH=$PYTHONPATH:.
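The export only affects the current shell session. When working from a script or notebook started elsewhere, a similar effect can be had by adding the repository folder to `sys.path` at runtime. This is a minimal sketch under that assumption; `ensure_on_sys_path` is a hypothetical helper, not something provided by Objax.

```python
import os
import sys


def ensure_on_sys_path(directory="."):
    """Append `directory` (as an absolute path) to sys.path if it is missing."""
    path = os.path.abspath(directory)
    if path not in sys.path:
        sys.path.append(path)
    return path


# Run from the root of the cloned objax repository.
repo_root = ensure_on_sys_path(".")
print(repo_root in sys.path)  # True
```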

Useful shell configurations

Here are a few useful options:

# Prevent JAX from taking the whole GPU memory
# (useful if you want to run several programs on a single GPU)
export XLA_PYTHON_CLIENT_PREALLOCATE=false

Testing

./tests/run.sh
