
An ML toolkit with code snippets useful for our day-to-day research

Project description

ASTRA

"AI for Sustainability" Toolkit for Research and Analysis. ASTRA (अस्त्र) means "a weapon" in Sanskrit, Hindi and a few other Bharatiya languages.


Install

pip install astra-lib

Contributing

Please read the contributing guidelines before contributing.

Useful Code Snippets

Data

Load Data

from astra.torch.data import load_mnist, load_cifar_10
ds, ds_name = load_cifar_10()

Models

MLPs

from astra.torch.models import MLP

mlp = MLP(input_dim=100, hidden_dims=[128, 64], output_dim=10, activation="relu", dropout=0.1)
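The following dependency-free sketch shows how these constructor arguments could combine in a forward pass. It assumes, without confirmation from astra's source, that each hidden layer is Linear → ReLU → Dropout and that the output layer is a plain Linear. `init_layers` and `mlp_forward` are hypothetical helpers, not part of astra.

```python
import random


def init_layers(dims, rng):
    """Random (weights, biases) per layer; weights are d_out rows of d_in entries."""
    return [
        ([[rng.uniform(-0.1, 0.1) for _ in range(d_in)] for _ in range(d_out)], [0.0] * d_out)
        for d_in, d_out in zip(dims[:-1], dims[1:])
    ]


def mlp_forward(x, layers, dropout=0.1, training=True, rng=None):
    """Linear -> ReLU -> Dropout on hidden layers, plain Linear on the output layer (assumed)."""
    rng = rng or random.Random(0)
    h = x
    for i, (weights, biases) in enumerate(layers):
        # affine map: out[j] = sum_k weights[j][k] * h[k] + biases[j]
        h = [sum(w * v for w, v in zip(row, h)) + b for row, b in zip(weights, biases)]
        if i < len(layers) - 1:  # hidden layers only
            h = [max(0.0, v) for v in h]  # ReLU
            if training and dropout > 0:
                keep = 1.0 - dropout
                # inverted dropout: zero each unit with prob. `dropout`, rescale survivors
                h = [v / keep if rng.random() < keep else 0.0 for v in h]
    return h


rng = random.Random(42)
dims = [100, 128, 64, 10]  # input_dim, *hidden_dims, output_dim
layers = init_layers(dims, rng)
x = [rng.gauss(0.0, 1.0) for _ in range(dims[0])]

y_train = mlp_forward(x, layers, dropout=0.1, training=True, rng=rng)
y_eval = mlp_forward(x, layers, training=False)  # dropout disabled at eval time
print(len(y_train), len(y_eval))  # 10 10
```

Dropout only changes the output during training. With `training=False` the pass is deterministic.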

CNNs

from astra.torch.models import CNN
cnn = CNN(image_dim=32, 
          kernel_size=5, 
          n_channels=3, 
          conv_hidden_dims=[32, 64], 
          dense_hidden_dims=[128, 64], 
          output_dim=10)
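The number of features that reach the dense layers depends on padding, stride and pooling, and the snippet shows none of these. The sketch below applies the standard output-size formula documented for PyTorch's `Conv2d` and `MaxPool2d`. It assumes stride-1, unpadded convolutions with optional max-pooling, and these settings have not been checked against astra's `CNN`. `conv_out_size` and `flattened_dim` are hypothetical helpers.

```python
def conv_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial dim (same formula as torch.nn.Conv2d / MaxPool2d)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def flattened_dim(image_dim, kernel_size, conv_hidden_dims, pool=None):
    """Size of the flattened feature map after the conv stack (square images assumed)."""
    size = image_dim
    for _ in conv_hidden_dims:
        size = conv_out_size(size, kernel_size)  # stride 1, no padding (assumed)
        if pool:
            size = conv_out_size(size, pool, stride=pool)  # e.g. MaxPool2d(2)
    return conv_hidden_dims[-1] * size * size


print(flattened_dim(32, 5, [32, 64]))          # 36864  (32 -> 28 -> 24)
print(flattened_dim(32, 5, [32, 64], pool=2))  # 1600   (32 -> 28 -> 14 -> 10 -> 5)
```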

EfficientNets

from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from astra.torch.models import EfficientNet

model = EfficientNet(efficientnet_b0, EfficientNet_B0_Weights.DEFAULT, output_dim=10)
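Passing `output_dim=10` suggests that the wrapper replaces the 1000-class ImageNet classifier with a 10-way head. In torchvision, the classifier of `efficientnet_b0` ends in a `Linear(1280, num_classes)` layer. The arithmetic below assumes that astra replaces only that final layer, which has not been verified against astra's source. `linear_params` is a hypothetical helper.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameters in a fully connected layer: weight matrix plus optional bias."""
    return in_features * out_features + (out_features if bias else 0)


imagenet_head = linear_params(1280, 1000)  # pretrained 1000-class ImageNet head
new_head = linear_params(1280, 10)         # head for output_dim=10
print(imagenet_head, new_head)  # 1281000 12810
```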

ViT

from torchvision.models import vit_b_16, ViT_B_16_Weights
from astra.torch.models import ViT

model = ViT(vit_b_16, ViT_B_16_Weights.DEFAULT, output_dim=10)
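ViT-B/16 splits each image into non-overlapping 16×16 patches and prepends one class token. torchvision's `vit_b_16` is configured for 224×224 inputs. The token count below is pure arithmetic (`vit_num_tokens` is a hypothetical helper). It gives the sequence length the encoder sees and shows why 32×32 images such as CIFAR-10 are usually resized before being passed to this pretrained model.

```python
def vit_num_tokens(image_size, patch_size, class_token=True):
    """Sequence length for a ViT: one token per non-overlapping patch, plus a class token."""
    assert image_size % patch_size == 0, "image size must be divisible by patch size"
    n_patches = (image_size // patch_size) ** 2
    return n_patches + (1 if class_token else 0)


print(vit_num_tokens(224, 16))  # 197 = 14 * 14 patches + 1 class token
print(vit_num_tokens(32, 16))   # 5   = 2 * 2 patches + 1 class token
```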

Training

Quickly train a model

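from astra.torch.utils import train_fn
# model: a torch.nn.Module (e.g. the MLP above); inputs/outputs: training tensors;
# loss_fn: a loss such as torch.nn.CrossEntropyLoss(); lr, n_epochs, batch_size: hyperparameters
result = train_fn(model, inputs, outputs, loss_fn, lr, n_epochs, batch_size, enable_tqdm=True)
print(result.keys()) # dict_keys(['epoch_losses', 'iter_losses'])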
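The two returned keys suggest a standard mini-batch loop that records the loss of every batch and the mean loss of each epoch. This is an assumption about what the lists contain, not astra's implementation. The dependency-free sketch below fits y ≈ w·x + b with mean squared error, and `train_fn_sketch` is a hypothetical name.

```python
import random


def train_fn_sketch(xs, ys, lr, n_epochs, batch_size, seed=0):
    """Mini-batch gradient descent on y ~ w * x + b with mean squared error."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    order = list(range(len(xs)))
    iter_losses, epoch_losses = [], []
    for _ in range(n_epochs):
        rng.shuffle(order)  # reshuffle the samples every epoch
        batch_losses = []
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]
            errs = [w * xs[i] + b - ys[i] for i in batch]  # residuals before the update
            loss = sum(e * e for e in errs) / len(batch)
            # gradients of the batch MSE with respect to w and b
            grad_w = 2 * sum(e * xs[i] for e, i in zip(errs, batch)) / len(batch)
            grad_b = 2 * sum(errs) / len(batch)
            w -= lr * grad_w
            b -= lr * grad_b
            iter_losses.append(loss)
            batch_losses.append(loss)
        epoch_losses.append(sum(batch_losses) / len(batch_losses))  # mean batch loss
    return {"epoch_losses": epoch_losses, "iter_losses": iter_losses}, (w, b)


xs = [i / 50 for i in range(50)]
ys = [3 * x + 1 for x in xs]
result, (w, b) = train_fn_sketch(xs, ys, lr=0.1, n_epochs=20, batch_size=8)
print(result.keys())  # dict_keys(['epoch_losses', 'iter_losses'])
print(len(result["iter_losses"]))  # 140 = 20 epochs * 7 batches (50 samples / 8, rounded up)
```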

Ad hoc

Count the number of parameters in a model

from astra.torch.utils import count_params
n_params = count_params(mlp)
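For the `mlp` defined above, the count can be checked by hand. This assumes the MLP consists only of `Linear` layers with biases, since ReLU and dropout have no parameters. Each layer then contributes `d_in * d_out + d_out` parameters. The expected total holds only under that assumption, and `mlp_param_count` is a hypothetical helper. In plain PyTorch, the common idiom for the same count is `sum(p.numel() for p in model.parameters())`.

```python
def mlp_param_count(dims):
    """Parameters of a stack of Linear layers: weights (d_in * d_out) plus biases (d_out)."""
    return sum(d_in * d_out + d_out for d_in, d_out in zip(dims[:-1], dims[1:]))


dims = [100, 128, 64, 10]  # input_dim, *hidden_dims, output_dim from the MLP example
print(mlp_param_count(dims))  # 21834 = 12928 + 8256 + 650
```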

Flatten/Unflatten the weights of a model

import torch
from astra.torch.models import ViT
from torchvision.models import vit_b_16, ViT_B_16_Weights
from astra.torch.utils import ravel_pytree
import optree

model = ViT(vit_b_16, ViT_B_16_Weights.DEFAULT, output_dim=10)
params = dict(model.named_parameters())

flat_params, unravel_fn = ravel_pytree(params)
unraveled_params = unravel_fn(flat_params) # returns the original params

# check if the tree structure is preserved
assert optree.tree_structure(params) == optree.tree_structure(unraveled_params)

# check if the values are preserved
for before_leaf, after_leaf in zip(optree.tree_leaves(params), optree.tree_leaves(unraveled_params)):
    assert torch.all(before_leaf == after_leaf)
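The round-trip above shows what `ravel_pytree` provides: all leaves concatenated into one flat vector, plus a function that splits the vector back into the original structure. The dependency-free sketch below illustrates the idea for a dict whose leaves are 1-D lists of floats, and `ravel_dict` is a hypothetical helper. A tensor version would also record each leaf's shape (and dtype) so it can reshape on the way back.

```python
def ravel_dict(params):
    """Flatten a dict of 1-D lists into one list; return it with an inverse function."""
    keys = sorted(params)  # fix a deterministic leaf order
    sizes = [len(params[k]) for k in keys]
    flat = [v for k in keys for v in params[k]]

    def unravel(vec):
        assert len(vec) == sum(sizes), "vector length must match the total parameter count"
        out, offset = {}, 0
        for k, n in zip(keys, sizes):
            out[k] = list(vec[offset:offset + n])  # slice this leaf back out
            offset += n
        return out

    return flat, unravel


params = {"layer.weight": [0.5, -1.0, 2.0], "layer.bias": [0.1]}
flat, unravel = ravel_dict(params)
print(flat)  # [0.1, 0.5, -1.0, 2.0]  ("layer.bias" sorts before "layer.weight")
assert unravel(flat) == params

# A typical use: apply one vector-level operation to all weights at once.
scaled = unravel([2 * v for v in flat])
print(scaled["layer.weight"])  # [1.0, -2.0, 4.0]
```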


Download files


Source Distribution

astra-lib-0.0.1.tar.gz (15.9 kB)

Uploaded Source

File details

Details for the file astra-lib-0.0.1.tar.gz.

File metadata

  • Download URL: astra-lib-0.0.1.tar.gz
  • Upload date:
  • Size: 15.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.12.0

File hashes

Hashes for astra-lib-0.0.1.tar.gz
  • SHA256: f4d0ce7551bfa3f5c73c2a5efa286e40e1b295936f91752c570bcb937015f62a
  • MD5: 54821bd4c1dd745b1933864291550d82
  • BLAKE2b-256: 98063d1335385ecb423aaa72b1c419747e53c6bce717830f5b7f30c8139464dd

