
An ML toolkit with code snippets useful for our day-to-day research


ASTRA

"AI for Sustainability" Toolkit for Research and Analysis. ASTRA (अस्त्र) means "tool" or "weapon" in Sanskrit.


Install

Stable version:

pip install astra-lib

Latest development version (from GitHub):

pip install git+https://github.com/sustainability-lab/ASTRA

Contributing

Please go through the contributing guidelines before making a contribution.

Useful Code Snippets

Data

Load Data

from astra.torch.data import load_mnist, load_cifar_10
ds, ds_name = load_cifar_10()

Models

MLPs

from astra.torch.models import MLP

mlp = MLP(input_dim=100, hidden_dims=[128, 64], output_dim=10, activation="relu", dropout=0.1)

CNNs

from astra.torch.models import CNN
cnn = CNN(image_dim=32, 
          kernel_size=5, 
          n_channels=3, 
          conv_hidden_dims=[32, 64], 
          dense_hidden_dims=[128, 64], 
          output_dim=10)
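
The `CNN` above stacks convolutions with `kernel_size=5` on 32x32 images. To reason about how the feature maps shrink, the standard output-size formula `floor((n + 2p - k) / s) + 1` helps. The sketch below assumes stride 1, no padding, and an optional 2x2 max-pool after each convolution; these are illustrative assumptions only, since the snippet does not say how ASTRA's `CNN` pads or pools internally.

```python
def conv_output_size(in_size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial size after a conv or pooling layer: floor((n + 2p - k) / s) + 1."""
    return (in_size + 2 * padding - kernel_size) // stride + 1


def feature_map_sizes(image_dim: int, kernel_size: int, n_conv_layers: int, pool: bool = False) -> list:
    """Track the spatial size through a stack of conv layers.

    Hypothetical architecture: each conv uses stride 1 and no padding,
    optionally followed by a 2x2 max-pool with stride 2.
    """
    sizes = [image_dim]
    size = image_dim
    for _ in range(n_conv_layers):
        size = conv_output_size(size, kernel_size)
        if pool:
            size = conv_output_size(size, kernel_size=2, stride=2)
        sizes.append(size)
    return sizes


# Settings mirroring the CNN snippet: 32x32 images, 5x5 kernels, two conv layers
print(feature_map_sizes(32, 5, 2))             # [32, 28, 24]
print(feature_map_sizes(32, 5, 2, pool=True))  # [32, 14, 5]
```

Under the same assumptions, the last entry times the last conv width (64) times itself gives the flattened size fed to the first dense layer.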

EfficientNets

from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from astra.torch.models import EfficientNet

model = EfficientNet(efficientnet_b0, EfficientNet_B0_Weights.DEFAULT, output_dim=10)

ViT

from torchvision.models import vit_b_16, ViT_B_16_Weights
from astra.torch.models import ViT

model = ViT(vit_b_16, ViT_B_16_Weights.DEFAULT, output_dim=10)
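
In `vit_b_16`, the image is cut into non-overlapping 16x16 patches, each patch becomes one token, and a learnable class token is prepended. torchvision's `vit_b_16` is configured for 224x224 inputs (the resolution `ViT_B_16_Weights.DEFAULT` was trained at), so 32x32 CIFAR-10 images would likely need resizing unless the ASTRA wrapper handles that (the snippet does not say). The helper below is a hypothetical illustration of the token arithmetic, not part of ASTRA.

```python
def num_vit_tokens(image_size: int, patch_size: int = 16, class_token: bool = True) -> int:
    """Number of tokens a ViT sees for a square image split into square patches."""
    if image_size % patch_size != 0:
        raise ValueError(f"image_size={image_size} is not divisible by patch_size={patch_size}")
    n_patches = (image_size // patch_size) ** 2
    return n_patches + (1 if class_token else 0)


print(num_vit_tokens(224))  # 197 -> 14 * 14 = 196 patches + 1 class token
print(num_vit_tokens(32))   # 5   -> 2 * 2 = 4 patches + 1 class token
```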

Training

Quickly train a model

from astra.torch.utils import train_fn

# model, inputs, outputs, loss_fn, lr, n_epochs and batch_size must be defined beforehand
result = train_fn(model, inputs, outputs, loss_fn, lr, n_epochs, batch_size, enable_tqdm=True)
print(result.keys())  # dict_keys(['epoch_losses', 'iter_losses'])
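
`train_fn` returns both per-epoch and per-iteration losses. As a rough, dependency-free illustration of how a mini-batch loop can produce both lists, the sketch below fits a 1-D linear model `y = w * x + b` with plain gradient descent on mean squared error. The name `toy_train_fn` and the choice to define each epoch loss as the mean of that epoch's iteration losses are assumptions for illustration; they are not taken from ASTRA's `train_fn`.

```python
import random


def toy_train_fn(xs, ys, lr, n_epochs, batch_size, seed=0):
    """Mini-batch gradient descent on MSE for y = w * x + b (hypothetical sketch, not ASTRA's train_fn)."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    indices = list(range(len(xs)))
    epoch_losses, iter_losses = [], []
    for _ in range(n_epochs):
        rng.shuffle(indices)  # visit the samples in a new order every epoch
        losses_this_epoch = []
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            n = len(batch)
            # residuals of the current model on this mini-batch
            errors = [(w * xs[i] + b) - ys[i] for i in batch]
            loss = sum(e * e for e in errors) / n
            # gradients of the batch MSE with respect to w and b
            grad_w = 2 * sum(e * xs[i] for e, i in zip(errors, batch)) / n
            grad_b = 2 * sum(errors) / n
            w -= lr * grad_w
            b -= lr * grad_b
            iter_losses.append(loss)
            losses_this_epoch.append(loss)
        # assumption: an epoch's loss is the mean of its iteration losses
        epoch_losses.append(sum(losses_this_epoch) / len(losses_this_epoch))
    return {"epoch_losses": epoch_losses, "iter_losses": iter_losses}, (w, b)


# Noise-free toy data generated from y = 3x + 1
xs = [i / 10 for i in range(50)]
ys = [3 * x + 1 for x in xs]
result, (w, b) = toy_train_fn(xs, ys, lr=0.02, n_epochs=500, batch_size=10)
print(result.keys())  # dict_keys(['epoch_losses', 'iter_losses'])
print(len(result["epoch_losses"]), len(result["iter_losses"]))  # 500 2500
print(round(w, 2), round(b, 2))  # should end up close to 3.0 and 1.0
```

With 50 samples and `batch_size=10` there are 5 iterations per epoch, which is why `iter_losses` is five times longer than `epoch_losses`.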

Ad hoc

Count the number of parameters in a model

from astra.torch.utils import count_params
n_params = count_params(mlp)
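
As a sanity check on `count_params`, the expected count for the `MLP` defined earlier can be worked out by hand, assuming it is built only from fully connected layers with biases (activations and dropout have no trainable parameters). A layer mapping `n_in` to `n_out` features then contributes `n_in * n_out + n_out` parameters. This layer layout is an assumption; ASTRA's `MLP` may differ, for example if it adds normalization layers. `mlp_param_count` is a hypothetical helper.

```python
def mlp_param_count(input_dim: int, hidden_dims: list, output_dim: int, bias: bool = True) -> int:
    """Trainable parameters of a plain stack of fully connected layers (hypothetical layout)."""
    dims = [input_dim, *hidden_dims, output_dim]
    total = 0
    for n_in, n_out in zip(dims[:-1], dims[1:]):
        total += n_in * n_out  # weight matrix
        if bias:
            total += n_out  # bias vector
    return total


# Same sizes as the MLP snippet: 100 -> 128 -> 64 -> 10
print(mlp_param_count(100, [128, 64], 10))  # 21834
```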

Flatten/Unflatten the weights of a model

import torch
from astra.torch.models import ViT
from torchvision.models import vit_b_16, ViT_B_16_Weights
from astra.torch.utils import ravel_pytree
import optree

model = ViT(vit_b_16, ViT_B_16_Weights.DEFAULT, output_dim=10)
params = dict(model.named_parameters())

flat_params, unravel_fn = ravel_pytree(params)
unraveled_params = unravel_fn(flat_params) # returns the original params

# check if the tree structure is preserved
assert optree.tree_structure(params) == optree.tree_structure(unraveled_params)

# check if the values are preserved
for before_leaf, after_leaf in zip(optree.tree_leaves(params), optree.tree_leaves(unraveled_params)):
    assert torch.all(before_leaf == after_leaf)
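
Conceptually, `ravel_pytree` concatenates every leaf of the parameter tree into a single flat vector and returns `unravel_fn`, which slices that vector back into leaves of the original shapes. The stdlib-only sketch below illustrates the idea for a dict of nested Python lists. `ravel_dict` and its helpers are hypothetical names, and this is not ASTRA's implementation, which works on torch tensors.

```python
import math
from typing import Any


def _shape(x: Any) -> tuple:
    """Shape of a rectangular nested list; a scalar leaf has shape ()."""
    if isinstance(x, list):
        return (len(x),) + (_shape(x[0]) if x else ())
    return ()


def _flatten(x: Any) -> list:
    """Row-major flattening of a nested list into a flat list of scalars."""
    if isinstance(x, list):
        return [v for item in x for v in _flatten(item)]
    return [x]


def _build(values: list, shape: tuple) -> Any:
    """Inverse of _flatten: rebuild a nested list of the given shape."""
    if not shape:
        return values[0]
    step = math.prod(shape[1:])  # number of scalars in each sub-list
    return [_build(values[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def ravel_dict(params: dict) -> tuple:
    """Flatten a dict of nested lists into one vector and return an unravel function (hypothetical helper)."""
    keys = list(params)  # keep insertion order, like model.named_parameters()
    shapes = {k: _shape(params[k]) for k in keys}
    sizes = {k: math.prod(shapes[k]) for k in keys}
    flat = [v for k in keys for v in _flatten(params[k])]

    def unravel(vector: list) -> dict:
        if len(vector) != len(flat):
            raise ValueError(f"expected {len(flat)} values, got {len(vector)}")
        out, offset = {}, 0
        for k in keys:
            out[k] = _build(vector[offset:offset + sizes[k]], shapes[k])
            offset += sizes[k]
        return out

    return flat, unravel


params = {"layer1.weight": [[1.0, 2.0], [3.0, 4.0]], "layer1.bias": [0.5, -0.5], "scale": 2.0}
flat, unravel = ravel_dict(params)
print(flat)  # [1.0, 2.0, 3.0, 4.0, 0.5, -0.5, 2.0]
assert unravel(flat) == params  # round trip restores the original structure and values

# Any elementwise update on the flat vector maps back onto the original tree
doubled = unravel([2 * v for v in flat])
print(doubled["layer1.weight"])  # [[2.0, 4.0], [6.0, 8.0]]
```

A single flat vector like this is convenient for optimizers or samplers that expect one 1-D parameter array, while `unravel` recovers named, shaped parameters afterwards.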


Download files


Source Distribution

astra-lib-0.0.2.tar.gz (136.6 kB)

File details

Details for the file astra-lib-0.0.2.tar.gz.

File metadata

  • Download URL: astra-lib-0.0.2.tar.gz
  • Upload date:
  • Size: 136.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.12.0

File hashes

Hashes for astra-lib-0.0.2.tar.gz
Algorithm     Hash digest
SHA256        b9cbf6af2b30b32c420a955820368adda8406ab8aaa54edd31aa469a48ef0494
MD5           c06e2cf264b4446672d2cc1627b323b4
BLAKE2b-256   a4aa5b95b9500842a66faafc47d4726cfdc45fe1857f73d19e8e5efa9f5cb3f9

