An ML toolkit with code snippets useful for our day-to-day research
ASTRA
"AI for Sustainability" Toolkit for Research and Analysis. ASTRA (अस्त्र) means "a weapon" in Sanskrit, Hindi and a few other Bharatiya languages.
Install
pip install astra-lib
Contributing
Please go through the contributing guidelines before making a contribution.
Useful Code Snippets
Data
Load Data
from astra.torch.data import load_mnist, load_cifar_10
ds, ds_name = load_cifar_10()
Models
MLPs
from astra.torch.models import MLP
mlp = MLP(input_dim=100, hidden_dims=[128, 64], output_dim=10, activation="relu", dropout=0.1)
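The MLP arguments most likely chain into one fully connected layer per consecutive pair of `[input_dim, *hidden_dims, output_dim]`, with the activation and dropout applied between layers. That layout is an assumption (the source does not describe `MLP` internals), but it is the conventional one. The plain-Python helper `mlp_layer_shapes` below is hypothetical and only lists the `(in_features, out_features)` pairs this layout would produce:

```python
def mlp_layer_shapes(input_dim, hidden_dims, output_dim):
    """Return (in_features, out_features) for each fully connected layer,
    assuming one layer between consecutive entries of
    [input_dim, *hidden_dims, output_dim]."""
    dims = [input_dim, *hidden_dims, output_dim]
    return list(zip(dims[:-1], dims[1:]))


# Same sizes as the MLP(...) call above.
shapes = mlp_layer_shapes(100, [128, 64], 10)
print(shapes)  # [(100, 128), (128, 64), (64, 10)]
```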
CNNs
from astra.torch.models import CNN
cnn = CNN(image_dim=32,
          kernel_size=5,
          n_channels=3,
          conv_hidden_dims=[32, 64],
          dense_hidden_dims=[128, 64],
          output_dim=10)
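Before the dense layers, a CNN of this shape flattens the last convolutional feature map, so the size of that vector depends on `image_dim`, `kernel_size` and the number of conv layers. The sketch below applies the standard convolution output-size formula, `(size + 2 * padding - kernel_size) // stride + 1`, per layer. Whether astra's `CNN` uses padding, stride or pooling is not stated in the source, so the `stride` and `padding` values and the hypothetical helper `flattened_dim` are assumptions for illustration only:

```python
def conv_out_size(size, kernel_size, stride=1, padding=0):
    # Standard output size of a 2D convolution along one spatial axis.
    return (size + 2 * padding - kernel_size) // stride + 1


def flattened_dim(image_dim, kernel_size, conv_hidden_dims, stride=1, padding=0):
    """Length of the flattened feature vector after the conv stack,
    assuming square images, one conv per entry of conv_hidden_dims
    and no pooling (an assumption, not astra's documented behaviour)."""
    size = image_dim
    for _ in conv_hidden_dims:
        size = conv_out_size(size, kernel_size, stride, padding)
    # Channels of the last conv layer times the remaining spatial area.
    return conv_hidden_dims[-1] * size * size


print(flattened_dim(32, 5, [32, 64]))             # 36864 (no padding: 32 -> 28 -> 24)
print(flattened_dim(32, 5, [32, 64], padding=2))  # 65536 ("same" padding keeps 32 x 32)
```

Note that `n_channels=3` only sets the input channel count and does not change the spatial size.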
EfficientNets
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from astra.torch.models import EfficientNet
model = EfficientNet(efficientnet_b0, EfficientNet_B0_Weights.DEFAULT, output_dim=10)
ViT
from torchvision.models import vit_b_16, ViT_B_16_Weights
from astra.torch.models import ViT
model = ViT(vit_b_16, ViT_B_16_Weights.DEFAULT, output_dim=10)
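torchvision's `vit_b_16` cuts each image into 16 x 16 patches and prepends a class token, and its pretrained `ViT_B_16_Weights.DEFAULT` weights are for 224 x 224 inputs; torchvision's implementation checks that the input size matches. Small images such as CIFAR-10's 32 x 32 will therefore probably need resizing, unless astra's `ViT` wrapper handles this (the source does not say). The hypothetical helper below just does the token-count arithmetic:

```python
def vit_num_tokens(image_size, patch_size, class_token=True):
    """Sequence length seen by a ViT encoder for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    n_patches = (image_size // patch_size) ** 2
    return n_patches + (1 if class_token else 0)


print(vit_num_tokens(224, 16))  # 197 (14 x 14 patches + 1 class token)
```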
Training
Quickly train a model
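from astra.torch.utils import train_fn
# model, inputs, outputs, loss_fn, lr, n_epochs and batch_size stand in for your own model, data and hyperparameters
result = train_fn(model, inputs, outputs, loss_fn, lr, n_epochs, batch_size, enable_tqdm=True)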
print(result.keys()) # dict_keys(['epoch_losses', 'iter_losses'])
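The two keys in `result` suggest one loss per mini-batch (`iter_losses`) and one per epoch (`epoch_losses`). The dependency-free sketch below shows that bookkeeping with mini-batch SGD on a 1D linear regression. It is not astra's `train_fn`: the function `train_linear` is hypothetical, and computing each epoch loss as the mean of that epoch's mini-batch losses is an assumption:

```python
import random


def train_linear(xs, ys, lr, n_epochs, batch_size, seed=0):
    """Mini-batch SGD for y ~ w * x + b with a mean-squared-error loss."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    order = list(range(len(xs)))
    epoch_losses, iter_losses = [], []
    for _ in range(n_epochs):
        rng.shuffle(order)  # new mini-batch split every epoch
        losses_this_epoch = []
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]
            errors = [(w * xs[i] + b) - ys[i] for i in batch]
            loss = sum(e * e for e in errors) / len(batch)
            # Gradients of the batch MSE with respect to w and b.
            grad_w = 2 * sum(e * xs[i] for e, i in zip(errors, batch)) / len(batch)
            grad_b = 2 * sum(errors) / len(batch)
            w -= lr * grad_w
            b -= lr * grad_b
            iter_losses.append(loss)
            losses_this_epoch.append(loss)
        # Assumption: an epoch's loss is the mean of its mini-batch losses.
        epoch_losses.append(sum(losses_this_epoch) / len(losses_this_epoch))
    return {"epoch_losses": epoch_losses, "iter_losses": iter_losses}, (w, b)


xs = [i / 10 for i in range(20)]
ys = [2 * x + 1 for x in xs]
result, (w, b) = train_linear(xs, ys, lr=0.1, n_epochs=200, batch_size=4)
print(result.keys())               # dict_keys(['epoch_losses', 'iter_losses'])
print(len(result["iter_losses"]))  # 1000 (200 epochs x 5 mini-batches)
```

Plotting `iter_losses` shows mini-batch noise, while `epoch_losses` gives a smoother curve for checking convergence.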
Ad hoc
Count number of parameters in a model
from astra.torch.utils import count_params
n_params = count_params(mlp)
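A quick sanity check for `count_params(mlp)`: a fully connected layer with bias has `d_in * d_out + d_out` parameters, while ReLU and dropout have none. Assuming `count_params` counts every parameter element and `MLP` consists only of biased linear layers (neither is stated in the source), the `mlp` above would be expected to report 21834 parameters. The helper `linear_param_count` is hypothetical:

```python
def linear_param_count(dims, bias=True):
    """Parameters in a stack of fully connected layers dims[0] -> dims[1] -> ..."""
    total = 0
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        total += d_in * d_out + (d_out if bias else 0)
    return total


# 100*128 + 128 + 128*64 + 64 + 64*10 + 10
print(linear_param_count([100, 128, 64, 10]))  # 21834
```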
Flatten/Unflatten the weights of a model
import torch
from astra.torch.models import ViT
from torchvision.models import vit_b_16, ViT_B_16_Weights
from astra.torch.utils import ravel_pytree
import optree
model = ViT(vit_b_16, ViT_B_16_Weights.DEFAULT, output_dim=10)
params = dict(model.named_parameters())
flat_params, unravel_fn = ravel_pytree(params)
unraveled_params = unravel_fn(flat_params) # returns the original params
# check if the tree structure is preserved
assert optree.tree_structure(params) == optree.tree_structure(unraveled_params)
# check if the values are preserved
for before_leaf, after_leaf in zip(optree.tree_leaves(params), optree.tree_leaves(unraveled_params)):
    assert torch.all(before_leaf == after_leaf)
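`ravel_pytree` returns the same pair as JAX's `jax.flatten_util.ravel_pytree`: one flat vector plus a function that rebuilds the original structure. The plain-Python sketch below illustrates the idea for a flat dict of 1D lists only; it is not astra's implementation, which works on tensors with arbitrary shapes. The name `ravel_dict` is hypothetical:

```python
def ravel_dict(params):
    """Flatten a dict of name -> list[float] into one list plus an inverse function."""
    names = list(params)  # dict insertion order fixes the layout
    sizes = [len(params[name]) for name in names]
    flat = [v for name in names for v in params[name]]

    def unravel(flat_values):
        if len(flat_values) != sum(sizes):
            raise ValueError("flat vector has the wrong length")
        out, offset = {}, 0
        for name, size in zip(names, sizes):
            out[name] = list(flat_values[offset:offset + size])
            offset += size
        return out

    return flat, unravel


params = {"layer.weight": [0.1, -0.2, 0.3, 0.4], "layer.bias": [0.0, 1.0]}
flat, unravel = ravel_dict(params)
print(flat)  # [0.1, -0.2, 0.3, 0.4, 0.0, 1.0]
assert unravel(flat) == params  # round trip restores names and values
```

A single flat vector is convenient when an operation should treat all weights at once, for example computing a global norm or adding one perturbation vector to every parameter.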
Download files
Source Distribution
astra-lib-0.0.1.tar.gz (15.9 kB)
File details
Details for the file astra-lib-0.0.1.tar.gz.
File metadata
- Download URL: astra-lib-0.0.1.tar.gz
- Upload date:
- Size: 15.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.12.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | f4d0ce7551bfa3f5c73c2a5efa286e40e1b295936f91752c570bcb937015f62a
MD5 | 54821bd4c1dd745b1933864291550d82
BLAKE2b-256 | 98063d1335385ecb423aaa72b1c419747e53c6bce717830f5b7f30c8139464dd