
Lightning AI Model hub.


Save, share and host AI model checkpoints Lightning fast ⚡


Save, load, host, and share models without slowing down training: LitModels minimizes the training slowdown caused by saving checkpoints. Share models via public links on Lightning AI, or host them in your own cloud with enterprise-grade access controls.

✅ Checkpoint without slowing training.  ✅ Granular access controls.           
✅ Load models anywhere.                 ✅ Host on Lightning or your own cloud.
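
Checkpointing stops slowing training down once the upload is taken off the training loop's critical path. The sketch below shows that general idea with a background thread and a queue. It is not LitModels' implementation: BackgroundUploader and fake_upload are hypothetical names, and fake_upload only simulates a network transfer.

```python
import queue
import threading
import time


def fake_upload(path, uploaded):
    """Hypothetical stand-in for a slow network upload."""
    time.sleep(0.05)  # simulate transfer latency
    uploaded.append(path)


class BackgroundUploader:
    """Uploads checkpoints on a worker thread so the training loop never waits."""

    def __init__(self):
        self.uploaded = []
        self._queue = queue.Queue()
        self._worker = threading.Thread(target=self._run, daemon=True)
        self._worker.start()

    def _run(self):
        while True:
            path = self._queue.get()
            if path is None:  # sentinel: no more checkpoints are coming
                break
            fake_upload(path, self.uploaded)

    def submit(self, path):
        # Enqueue and return immediately; the worker thread does the upload.
        self._queue.put(path)

    def close(self):
        # Tell the worker to stop after draining the queue, then wait for it.
        self._queue.put(None)
        self._worker.join()


uploader = BackgroundUploader()
for epoch in range(3):
    # ... one epoch of training would run here, uninterrupted by uploads ...
    uploader.submit(f"checkpoint-epoch={epoch}.ckpt")
uploader.close()
print(uploader.uploaded)
```

A single worker keeps uploads in submission order. In a real trainer, the checkpoint file must be fully written to disk before its path is enqueued.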


Quick start

Install LitModels via pip:

pip install litmodels

Toy example (see the Examples section below for complete, framework-specific ones):

import litmodels as lm
import torch

# save a model
model = torch.nn.Module()
lm.save_model(model=model, name="model-name")

# load a model
model = lm.load_model(name="model-name")

Examples

PyTorch

Save model:

import torch
from litmodels import save_model

model = torch.nn.Module()
save_model(model=model, name="your_org/your_team/torch-model")

Load model:

from litmodels import load_model

model_ = load_model(name="your_org/your_team/torch-model")

PyTorch Lightning

Save model:

from lightning import Trainer
from litmodels import upload_model
from litmodels.demos import BoringModel

# Configure Lightning Trainer
trainer = Trainer(max_epochs=2)
# Define the model and train it
trainer.fit(BoringModel())

# Upload the best model to cloud storage
checkpoint_path = trainer.checkpoint_callback.best_model_path
# Define the model name - this should be unique to your model
upload_model(model=checkpoint_path, name="<organization>/<teamspace>/<model-name>")

Load model:

from lightning import Trainer
from litmodels import download_model
from litmodels.demos import BoringModel

# Load the model from cloud storage
checkpoint_path = download_model(
    # Define the model name and version - this needs to be unique to your model
    name="<organization>/<teamspace>/<model-name>:<model-version>",
    download_dir="my_models",
)
print(f"model: {checkpoint_path}")

# Resume training from the downloaded checkpoint, extending the run to 4 epochs
trainer = Trainer(max_epochs=4)
trainer.fit(BoringModel(), ckpt_path=checkpoint_path)
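
The registry names above follow the pattern <organization>/<teamspace>/<model-name>, optionally followed by :<model-version>. The hypothetical helper below (parse_model_name and ModelName are not part of litmodels) splits such a string into its parts, which can catch malformed names before calling upload_model or download_model. Treating the version as optional is an assumption based on the examples on this page. The sketch accepts only the fully qualified form, not the short name used in the Quick start.

```python
from typing import NamedTuple, Optional


class ModelName(NamedTuple):
    organization: str
    teamspace: str
    model: str
    version: Optional[str]


def parse_model_name(name: str) -> ModelName:
    """Split '<org>/<teamspace>/<model>[:<version>]' into its components.

    Hypothetical helper for illustration; it is not part of litmodels.
    """
    path, sep, version = name.partition(":")
    parts = path.split("/")
    if len(parts) != 3 or not all(parts):
        raise ValueError(f"expected '<org>/<teamspace>/<model>', got {name!r}")
    if sep and not version:
        raise ValueError(f"empty version in {name!r}")
    return ModelName(parts[0], parts[1], parts[2], version or None)


print(parse_model_name("your_org/your_team/torch-model"))
print(parse_model_name("your_org/your_team/torch-model:v2"))
```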

TensorFlow / Keras

Save model:

from tensorflow import keras

from litmodels import save_model

# Define the model
model = keras.Sequential(
    [
        keras.layers.Dense(10, input_shape=(784,), name="dense_1"),
        keras.layers.Dense(10, name="dense_2"),
    ]
)

# Compile the model
model.compile(optimizer="adam", loss="categorical_crossentropy")

# Save the model
save_model("lightning-ai/jirka/sample-tf-keras-model", model=model)

Load model:

from litmodels import load_model

model_ = load_model(
    "lightning-ai/jirka/sample-tf-keras-model", download_dir="./my-model"
)

scikit-learn

Save model:

from sklearn import datasets, model_selection, svm
from litmodels import save_model

# Load example dataset
iris = datasets.load_iris()
X, y = iris.data, iris.target

# Split dataset into training and test sets
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Train a simple SVC model
model = svm.SVC()
model.fit(X_train, y_train)

# Save and upload the trained model with litmodels
save_model(model=model, name="your_org/your_team/sklearn-svm-model")

Use model:

from litmodels import load_model

# Download and load the model file from cloud storage
model = load_model(
    name="your_org/your_team/sklearn-svm-model", download_dir="my_models"
)

# Example: run inference with the loaded model
sample_input = [[5.1, 3.5, 1.4, 0.2]]
prediction = model.predict(sample_input)
print(f"Prediction: {prediction}")

Features

PyTorch Lightning Callback

Add an automatic checkpointing callback to your Trainer that uploads the model to the registry at the end of each epoch.

from lightning import Trainer
from litmodels.integrations import LightningModelCheckpoint
from litmodels.demos import BoringModel

trainer = Trainer(
    max_epochs=2,
    callbacks=[
        LightningModelCheckpoint(
            # Define the model name - this should be unique to your model
            model_registry="<organization>/<teamspace>/<model-name>",
        )
    ],
)
# BoringModel generates its own random training data, so no external dataset is needed
trainer.fit(
    BoringModel(),
)
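
Conceptually, the callback hooks into each epoch boundary and hands the latest checkpoint to an uploader. The plain-Python sketch below mirrors that control flow with a toy loop. ToyTrainer, UploadOnEpochEnd, and record_upload are hypothetical names; they do not reproduce Lightning's Callback API or the internals of LightningModelCheckpoint.

```python
from typing import Callable, Dict, List, Tuple


class UploadOnEpochEnd:
    """Hypothetical callback that 'uploads' a checkpoint after every epoch."""

    def __init__(self, model_name: str, upload: Callable[[str, str], None]) -> None:
        self.model_name = model_name
        self.upload = upload

    def on_epoch_end(self, epoch: int, metrics: Dict[str, float]) -> None:
        # Name the checkpoint after the epoch and its loss, then hand it off.
        checkpoint = f"epoch={epoch}-loss={metrics['loss']:.3f}.ckpt"
        self.upload(self.model_name, checkpoint)


class ToyTrainer:
    """Minimal loop that fires every callback's hook at each epoch boundary."""

    def __init__(self, max_epochs: int, callbacks: List[UploadOnEpochEnd]) -> None:
        self.max_epochs = max_epochs
        self.callbacks = callbacks

    def fit(self) -> None:
        loss = 1.0
        for epoch in range(self.max_epochs):
            loss *= 0.5  # pretend one epoch of training halves the loss
            for callback in self.callbacks:
                callback.on_epoch_end(epoch, {"loss": loss})


uploads: List[Tuple[str, str]] = []


def record_upload(model_name: str, checkpoint: str) -> None:
    # Stand-in for a real registry upload: just remember what was sent.
    uploads.append((model_name, checkpoint))


trainer = ToyTrainer(
    max_epochs=2,
    callbacks=[UploadOnEpochEnd("my-org/my-team/my-model", record_upload)],
)
trainer.fit()
print(uploads)
```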

Save any Python class as a checkpoint

Mixin classes add reusable functionality, such as saving and loading, to your own classes. Inheriting from a registry mixin gives any Python class upload_model and download_model methods, so every model shares the same persistence code instead of reimplementing it.

Save model:

from litmodels.integrations.mixins import PickleRegistryMixin


class MyModel(PickleRegistryMixin):
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2
        # Your model initialization code
        ...


# Create and push a model instance
model = MyModel(param1=42, param2="hello")
model.upload_model(name="my-org/my-team/my-model")

Load model:

loaded_model = MyModel.download_model(name="my-org/my-team/my-model")
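
The mixin pattern itself is plain Python: a small base class contributes methods that every subclass inherits. The sketch below imitates the shape of upload_model and download_model, with a pickle file in a local directory standing in for the registry. LocalPickleMixin, save_to, and load_from are hypothetical names, and this is not how PickleRegistryMixin is implemented.

```python
import os
import pickle
import tempfile


class LocalPickleMixin:
    """Hypothetical mixin: adds pickle-based save/load to any class."""

    def save_to(self, directory: str, name: str) -> str:
        path = os.path.join(directory, f"{name}.pkl")
        with open(path, "wb") as f:
            pickle.dump(self, f)  # serialize the whole instance
        return path

    @classmethod
    def load_from(cls, directory: str, name: str):
        path = os.path.join(directory, f"{name}.pkl")
        with open(path, "rb") as f:
            obj = pickle.load(f)
        # Guard against loading an object of an unrelated class
        if not isinstance(obj, cls):
            raise TypeError(f"expected {cls.__name__}, got {type(obj).__name__}")
        return obj


class MyModel(LocalPickleMixin):
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2


with tempfile.TemporaryDirectory() as registry_dir:
    MyModel(param1=42, param2="hello").save_to(registry_dir, "my-model")
    restored = MyModel.load_from(registry_dir, "my-model")
    print(restored.param1, restored.param2)
```

As with any pickle-based format, only load files from sources you trust, because unpickling can execute arbitrary code.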

Save custom PyTorch models

The PyTorch mixin centralizes serialization logic for torch.nn.Module subclasses, so you don't have to repeat it in every project. Its download_model class method takes no constructor arguments: it reconstructs the model from the registry using the stored architecture configuration and weights, which prevents mismatches between how a model was initialized and the weights being loaded.

Save model:

import torch
from litmodels.integrations.mixins import PyTorchRegistryMixin


# Important: PyTorchRegistryMixin must be first in the inheritance order
class MyTorchModel(PyTorchRegistryMixin, torch.nn.Module):
    def __init__(self, input_size, hidden_size=128):
        super().__init__()
        self.linear = torch.nn.Linear(input_size, hidden_size)
        self.activation = torch.nn.ReLU()

    def forward(self, x):
        return self.activation(self.linear(x))


# Create and push the model
model = MyTorchModel(input_size=784)
model.upload_model(name="my-org/my-team/torch-model")

Use the model:

# Pull the model with the same architecture
loaded_model = MyTorchModel.download_model(name="my-org/my-team/torch-model")
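
Rebuilding a model without passing constructor arguments can be illustrated by a mixin that records the arguments a model was created with and stores them next to its state. The pure-Python sketch below is an assumption about the general technique, not the library's code. ConfigRecordingMixin, TinyModule, dump, load, and the JSON layout are all hypothetical, and a plain dict stands in for PyTorch weights. As in the example above, the mixin is listed first among the base classes; in Python, the leftmost base takes precedence when two bases define the same method.

```python
import json


class ConfigRecordingMixin:
    """Hypothetical mixin: remembers constructor kwargs so a model can be rebuilt."""

    def __new__(cls, *args, **kwargs):
        # Simplification: only keyword arguments are recorded.
        if args:
            raise TypeError("pass constructor arguments by keyword")
        instance = super().__new__(cls)
        instance._init_kwargs = dict(kwargs)  # captured before __init__ runs
        return instance

    def dump(self) -> str:
        # Serialize both the architecture config and the learned state.
        return json.dumps({"init_kwargs": self._init_kwargs, "state": self.state_dict()})

    @classmethod
    def load(cls, payload: str):
        data = json.loads(payload)
        model = cls(**data["init_kwargs"])  # rebuild with the original arguments
        model.load_state_dict(data["state"])
        return model


class TinyModule:
    """Stand-in for torch.nn.Module with a dict-based state."""

    def __init__(self):
        self._state = {}

    def state_dict(self) -> dict:
        return dict(self._state)

    def load_state_dict(self, state: dict) -> None:
        self._state = dict(state)


class MyModel(ConfigRecordingMixin, TinyModule):
    def __init__(self, input_size, hidden_size=128):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self._state["weights"] = [0.0] * hidden_size


model = MyModel(input_size=784, hidden_size=4)
model._state["weights"] = [0.1, 0.2, 0.3, 0.4]  # pretend training updated the weights
payload = model.dump()

restored = MyModel.load(payload)  # no constructor arguments needed here
print(restored.input_size, restored.hidden_size, restored.state_dict())
```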


Community

💬 Get help on Discord
📋 License: Apache 2.0



Download files


Source Distribution

litmodels-0.1.8.tar.gz (24.1 kB)

Uploaded Source

Built Distribution

litmodels-0.1.8-py3-none-any.whl (22.6 kB)

Uploaded Python 3

File details

Details for the file litmodels-0.1.8.tar.gz.

File metadata

  • Download URL: litmodels-0.1.8.tar.gz
  • Size: 24.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for litmodels-0.1.8.tar.gz
Algorithm Hash digest
SHA256 f8f892966142089163ec01fde9f25a214be9b5abefe65c41141fd16b6ad5247b
MD5 25fdf5cf260e85496efcc18b3ba1a968
BLAKE2b-256 c2b82d365efd0b8db5db0b4626ddb702955e402f83fa03c904d1bf263ef3d06e
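
To check a downloaded archive against the SHA256 digest listed above, hash the file locally with Python's standard hashlib module and compare the results. The sketch assumes the source distribution was saved as litmodels-0.1.8.tar.gz in the current directory (shown as a comment). The demo at the bottom hashes a temporary file instead, so it runs without the download. sha256_of and matches are illustrative helper names.

```python
import hashlib
import os
import tempfile


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path: str, expected_hex: str) -> bool:
    # Normalize case so digests copied in upper case still compare equal.
    return sha256_of(path) == expected_hex.strip().lower()


EXPECTED_SDIST_SHA256 = "f8f892966142089163ec01fde9f25a214be9b5abefe65c41141fd16b6ad5247b"
# Real use (after downloading the sdist):
# print(matches("litmodels-0.1.8.tar.gz", EXPECTED_SDIST_SHA256))

# Self-contained demo on a temporary file with known contents
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"abc")
demo_digest = sha256_of(tmp.name)
demo_ok = matches(tmp.name, demo_digest.upper())
os.remove(tmp.name)
print(demo_digest, demo_ok)
```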


File details

Details for the file litmodels-0.1.8-py3-none-any.whl.

File metadata

  • Download URL: litmodels-0.1.8-py3-none-any.whl
  • Size: 22.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for litmodels-0.1.8-py3-none-any.whl
Algorithm Hash digest
SHA256 58cf7ab25cdcd63d191d69dfb1b8f53881a131b2169be1a027f589443e6b7061
MD5 a543cf45da959737c0aebc6c2dbefa32
BLAKE2b-256 560605eca676cb77d1de55dc131f2f63789f986299a570ab80409e3f7f29dc78

