Eqxvision
Eqxvision is a package of popular computer vision model architectures built using Equinox.
Installation
Use the package manager pip to install eqxvision.
pip install eqxvision
Requires Python >= 3.7.
Usage
Example: importing a model and running a forward pass is as simple as:
```python
import jax
import jax.random as jr
import equinox as eqx
from eqxvision.models import alexnet

@eqx.filter_jit
def forward(net, images, key):
    # One PRNG key per image in the batch (used e.g. by dropout)
    keys = jr.split(key, images.shape[0])
    # Map the single-image model over the leading batch dimension
    output = jax.vmap(net, axis_name='batch')(images, key=keys)
    return output

net = alexnet(num_classes=1000)
images = jr.uniform(jr.PRNGKey(0), shape=(1, 3, 224, 224))
output = forward(net, images, jr.PRNGKey(0))
```
What's New?
- [Experimental] Now supports loading PyTorch weights from `torchvision` for models without BatchNorm.

Note: due to slight differences in the implementation of the underlying operations, the outputs of pretrained networks can differ slightly from those of the original `torchvision` models.
Tips
- Prefer `@equinox.filter_jit` over `@jax.jit`.
- Use `jax.{v,p}map` with `axis_name='batch'` for all models.
- Don't forget to switch to inference mode for evaluation.
- Wrap the model with `eqx.filter(net, eqx.is_array)` when initialising an Optax optimiser.
Contributing
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
Please make sure to update tests as appropriate.
Acknowledgements
License