An ML toolkit with code snippets useful for our day-to-day research
ASTRA
"AI for Sustainability" Toolkit for Research and Analysis. ASTRA (अस्त्र) means "a tool" or "a weapon" in Sanskrit.
Install
Stable version:
pip install astra-lib
Latest version:
pip install git+https://github.com/sustainability-lab/ASTRA
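To confirm which install is active, one option is to query the installed distribution's metadata. This is a minimal sketch using only the standard library. It assumes the distribution name is `astra-lib`, as in the pip command above; `installed_version` is a hypothetical helper and not part of ASTRA.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string, or None if the distribution is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# "astra-lib" is the distribution name used in the pip command above.
print(installed_version("astra-lib"))
```

If this prints `None`, the package is not installed in the current environment.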
Contributing
Please go through the contributing guidelines before making a contribution.
Useful Code Snippets
Data
Load Data
from astra.torch.data import load_mnist, load_cifar_10
ds, ds_name = load_cifar_10()
Models
MLPs
from astra.torch.models import MLP
mlp = MLP(input_dim=100, hidden_dims=[128, 64], output_dim=10, activation="relu", dropout=0.1)
CNNs
from astra.torch.models import CNN
cnn = CNN(
    image_dim=32,
    kernel_size=5,
    n_channels=3,
    conv_hidden_dims=[32, 64],
    dense_hidden_dims=[128, 64],
    output_dim=10,
)
EfficientNets
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from astra.torch.models import EfficientNet
model = EfficientNet(efficientnet_b0, EfficientNet_B0_Weights.DEFAULT, output_dim=10)
ViT
from torchvision.models import vit_b_16, ViT_B_16_Weights
from astra.torch.models import ViT
model = ViT(vit_b_16, ViT_B_16_Weights.DEFAULT, output_dim=10)
Training
Quickly train a model
from astra.torch.utils import train_fn
result = train_fn(model, inputs, outputs, loss_fn, lr, n_epochs, batch_size, enable_tqdm=True)
print(result.keys()) # dict_keys(['epoch_losses', 'iter_losses'])
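The internals of `train_fn` are not shown here, so the following is only a plain-Python sketch of a mini-batch training loop that returns a dict with the same two keys. The function name `train_sketch`, the linear model, and the choice to report each epoch loss as the mean of that epoch's iteration losses are all assumptions made for illustration.

```python
import random


def train_sketch(xs, ys, lr, n_epochs, batch_size, seed=0):
    """Fit y ~ w * x + b with mini-batch SGD on squared error.

    Returns a dict with 'epoch_losses' and 'iter_losses'. The per-epoch
    mean used below is an assumption, not ASTRA's documented behaviour.
    """
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    iter_losses, epoch_losses = [], []
    indices = list(range(len(xs)))
    for _ in range(n_epochs):
        rng.shuffle(indices)
        losses_this_epoch = []
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            # Mean squared error and its gradients over the batch.
            errors = [w * xs[i] + b - ys[i] for i in batch]
            loss = sum(e * e for e in errors) / len(batch)
            grad_w = sum(2 * e * xs[i] for e, i in zip(errors, batch)) / len(batch)
            grad_b = sum(2 * e for e in errors) / len(batch)
            w -= lr * grad_w
            b -= lr * grad_b
            iter_losses.append(loss)
            losses_this_epoch.append(loss)
        epoch_losses.append(sum(losses_this_epoch) / len(losses_this_epoch))
    return {"epoch_losses": epoch_losses, "iter_losses": iter_losses}


xs = [i / 10 for i in range(20)]
ys = [3.0 * x + 1.0 for x in xs]
result = train_sketch(xs, ys, lr=0.05, n_epochs=50, batch_size=5)
print(len(result["epoch_losses"]), len(result["iter_losses"]))  # prints: 50 200
```

With 20 samples and a batch size of 5, each epoch contributes four entries to `iter_losses` and one to `epoch_losses`.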
Ad hoc
Count number of parameters in a model
from astra.torch.utils import count_params
n_params = count_params(mlp)
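As a sanity check on the reported number, the count for a plain stack of fully connected layers can be worked out by hand. The sketch below assumes the MLP consists only of dense layers with biases and that activation and dropout add no parameters; `dense_param_count` is a hypothetical helper, and the real `count_params(mlp)` may differ if the implementation includes other layers.

```python
def dense_param_count(input_dim, hidden_dims, output_dim):
    """Count weights and biases in a stack of fully connected layers."""
    dims = [input_dim, *hidden_dims, output_dim]
    # Each layer has fan_in * fan_out weights plus fan_out biases.
    return sum(fan_in * fan_out + fan_out for fan_in, fan_out in zip(dims, dims[1:]))


# Same dimensions as the MLP example above.
print(dense_param_count(100, [128, 64], 10))  # 21834
```

The total breaks down as 12928 + 8256 + 650 for the three layers.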
Flatten/Unflatten the weights of a model
import torch
from astra.torch.models import ViT
from torchvision.models import vit_b_16, ViT_B_16_Weights
from astra.torch.utils import ravel_pytree
import optree
model = ViT(vit_b_16, ViT_B_16_Weights.DEFAULT, output_dim=10)
params = dict(model.named_parameters())
flat_params, unravel_fn = ravel_pytree(params)
unraveled_params = unravel_fn(flat_params) # returns the original params
# check if the tree structure is preserved
assert optree.tree_structure(params) == optree.tree_structure(unraveled_params)
# check if the values are preserved
for before_leaf, after_leaf in zip(optree.tree_leaves(params), optree.tree_leaves(unraveled_params)):
    assert torch.all(before_leaf == after_leaf)
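The round trip above can be illustrated without tensors. This is a plain-Python sketch of the flatten/unflatten idea on nested dicts and lists of numbers. It is not how `ravel_pytree` is implemented in ASTRA; `ravel_nested` is a hypothetical name, and it traverses dicts in insertion order, which may differ from the ordering optree uses.

```python
def ravel_nested(tree):
    """Flatten a nested dict/list of numbers into a flat list plus an unravel function."""
    flat = []

    def record(node):
        # Build a structure template while appending leaves to `flat`.
        if isinstance(node, dict):
            return {key: record(value) for key, value in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(record(value) for value in node)
        flat.append(node)
        return None  # placeholder marking a leaf position

    template = record(tree)

    def unravel(values):
        it = iter(values)

        def rebuild(node):
            # Walk the template in the same order and refill the leaves.
            if isinstance(node, dict):
                return {key: rebuild(value) for key, value in node.items()}
            if isinstance(node, (list, tuple)):
                return type(node)(rebuild(value) for value in node)
            return next(it)

        return rebuild(template)

    return flat, unravel


params = {
    "layer1": {"weight": [[1.0, 2.0], [3.0, 4.0]], "bias": [0.5, 0.5]},
    "layer2": {"weight": [[5.0, 6.0]], "bias": [0.1]},
}
flat, unravel = ravel_nested(params)
assert unravel(flat) == params  # structure and values survive the round trip
print(flat)
```

Because `unravel` only depends on the recorded structure, it can also rebuild a tree from a modified flat list, which is the usual reason for flattening parameters in the first place.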