
Client SDK for BastionAI Confidential AI Training.

Project description

BastionAI Client

BastionAI Client is a Python library for building client applications for BastionAI Server (Mithril Security's confidential training server).

If you wish to know more about BastionAI, please have a look at the project's GitHub repository.

Installation

Using pip

$ pip install bastionai

Local installation

Note: It is preferable to install the BastionAI package in a virtual environment.

Execute the following command from the root of the package's source directory to install BastionAI locally in editable mode.

pip install -e .
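As a minimal, standard-library-only sketch of the virtual environment recommendation above, the helpers below check whether the current interpreter already runs inside a virtual environment and, if not, create one with Python's `venv` module. The function names and the `.venv` directory are illustrative assumptions, not part of BastionAI.

```python
import sys
import venv
from pathlib import Path


def in_virtualenv() -> bool:
    """Return True if this interpreter runs inside a virtual environment."""
    # Inside a venv, sys.prefix points at the environment directory,
    # while sys.base_prefix still points at the base Python installation.
    return sys.prefix != sys.base_prefix


def create_env(path: str = ".venv", with_pip: bool = True) -> Path:
    """Create a virtual environment at `path` (hypothetical helper)."""
    env_dir = Path(path)
    # with_pip=True bootstraps pip into the new environment via ensurepip.
    venv.EnvBuilder(with_pip=with_pip).create(env_dir)
    return env_dir
```

After creating and activating the environment (for example with `source .venv/bin/activate` on Linux or macOS), run `pip install -e .` as shown above.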

Usage

Uploading a model and datasets to BastionAI

The snippet below sets up a very simple linear regression model and a dataset to train it with.

import torch
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader
from bastionai.psg.nn import Linear  # BastionAI's Linear layer, used here instead of torch.nn.Linear
from bastionai.utils import TensorDataset

class LReg(Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = Linear(1, 1, 2)

    def forward(self, x: Tensor) -> Tensor:
        return self.fc1(x)

lreg_model = LReg()

# Training data: every target is twice its input (Y = 2 * X).
X = torch.tensor([[0.0], [1.0], [0.5], [0.2]])
Y = torch.tensor([[0.0], [2.0], [1.0], [0.4]])
train_dataset = TensorDataset([X], Y)
train_dataloader = DataLoader(train_dataset, batch_size=2)

# Test data, following the same relation.
X = torch.tensor([[0.1], [-1.0]])
Y = torch.tensor([[0.2], [-2.0]])
test_dataset = TensorDataset([X], Y)
test_dataloader = DataLoader(test_dataset, batch_size=2)
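Because the toy data follow Y = 2X exactly (the dataset description "param is 2" used later in this README says as much), a well-trained model should end up with a weight close to 2 and a bias close to 0. The plain-Python sketch below, which uses neither BastionAI nor PyTorch, computes the ordinary least-squares line for the same training points, giving a reference to compare the remotely trained model against. The helper name `least_squares_line` is hypothetical.

```python
def least_squares_line(xs: list[float], ys: list[float]) -> tuple[float, float]:
    """Return (slope, intercept) of the ordinary least-squares fit y = slope * x + intercept."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Slope = covariance(x, y) / variance(x); the 1/n factors cancel out.
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    slope = cov_xy / var_x
    intercept = mean_y - slope * mean_x
    return slope, intercept


# Same training points as X and Y above, flattened to plain floats.
train_x = [0.0, 1.0, 0.5, 0.2]
train_y = [0.0, 2.0, 1.0, 0.4]

slope, intercept = least_squares_line(train_x, train_y)
print(f"slope={slope:.4f}, intercept={intercept:.4f}")
```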

Training a model on BastionAI

The snippet below uses BastionAI to train the model securely and remotely.

The model, along with the training and testing datasets, is uploaded to BastionAI through the client API.

from bastionai.client import Connection, SGD

# Connect to a BastionAI server listening on localhost:50051.
with Connection("localhost", 50051, default_secret=b"") as client:
    # Upload the training and test data loaders along with a dataset description.
    remote_dataloader = client.RemoteDataLoader(
        train_dataloader,
        test_dataloader,
        "Dummy 1D Linear Regression Dataset (param is 2)",
    )
    # Upload the model and configure remote training (metric "l2", SGD with lr=0.1).
    remote_learner = client.RemoteLearner(
        lreg_model,
        remote_dataloader,
        metric="l2",
        optimizer=SGD(lr=0.1),
        model_description="1D Linear Regression Model",
        expand=False,
    )

    # Train on the server for 100 epochs.
    remote_learner.fit(nb_epochs=100, eps=100.0)

    lreg_model = remote_learner.get_model()  # Fetch the trained model from the BastionAI server.
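To make concrete what `fit` is asked to do here, the sketch below runs ordinary, non-private minibatch SGD locally in plain Python on the same toy data: mean squared error loss, learning rate 0.1, batches of 2 in their original order, 100 epochs. It is only an approximation of the remote job under stated assumptions: it starts from zero weights rather than a random initialization, ignores the `eps` argument and anything else the BastionAI server does during training, and then evaluates the result on the test points. The helper names `train_linear_sgd` and `mse` are hypothetical.

```python
def train_linear_sgd(xs, ys, lr=0.1, batch_size=2, epochs=100):
    """Fit y = w * x + b with minibatch SGD on the mean squared error (hypothetical helper)."""
    w, b = 0.0, 0.0  # assumption: zero initialization
    for _ in range(epochs):
        # Visit batches in order, like a DataLoader without shuffling.
        for start in range(0, len(xs), batch_size):
            bx = xs[start:start + batch_size]
            by = ys[start:start + batch_size]
            n = len(bx)
            errors = [w * x + b - y for x, y in zip(bx, by)]
            # Gradients of mean((w * x + b - y) ** 2) with respect to w and b.
            grad_w = 2.0 * sum(e * x for e, x in zip(errors, bx)) / n
            grad_b = 2.0 * sum(errors) / n
            w -= lr * grad_w
            b -= lr * grad_b
    return w, b


def mse(w, b, xs, ys):
    """Mean squared error of y = w * x + b on the given points."""
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


train_x, train_y = [0.0, 1.0, 0.5, 0.2], [0.0, 2.0, 1.0, 0.4]
test_x, test_y = [0.1, -1.0], [0.2, -2.0]

w, b = train_linear_sgd(train_x, train_y)
print(f"w={w:.3f}, b={b:.3f}, test MSE={mse(w, b, test_x, test_y):.5f}")
```

Comparing the parameters of the model returned by `get_model()` with a local baseline like this one is a quick sanity check, bearing in mind that server-side training settings (for example `eps`) may legitimately move the result away from this non-private fit.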

Contributing

Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.

License

This project is licensed under the Apache 2.0 License.

Download files

Source Distributions

No source distribution files available for this release.

Built Distribution

bastionai-0.2.1-py3-none-any.whl (25.4 kB)

Uploaded Python 3

File details

Details for the file bastionai-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: bastionai-0.2.1-py3-none-any.whl
  • Upload date:
  • Size: 25.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.14

File hashes

Hashes for bastionai-0.2.1-py3-none-any.whl
Algorithm     Hash digest
SHA256        ad4ce9126e0b70d8ca8dddb71971d49fd26df7c142adb059a613707b79a22879
MD5           f6e886d56d8daf1ac1d7c2f31196ce5b
BLAKE2b-256   2ad92e094a42b62b784ae305b15d57e0a3945019131703baf672390a8956452b
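A downloaded wheel can be checked against the SHA256 digest listed above. The sketch below uses Python's standard `hashlib` module; the helper names `sha256_of` and `verify_file`, and the assumption that the wheel sits in the current directory, are illustrative.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for bastionai-0.2.1-py3-none-any.whl.
EXPECTED_SHA256 = "ad4ce9126e0b70d8ca8dddb71971d49fd26df7c142adb059a613707b79a22879"


def sha256_of(path: Path, chunk_size: int = 65536) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_file(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """Return True if the file's SHA256 digest matches `expected` (case-insensitive)."""
    return sha256_of(path) == expected.strip().lower()
```

For example, `verify_file(Path("bastionai-0.2.1-py3-none-any.whl"))` should return True for an intact download. pip can also enforce this at install time through its hash-checking mode, with a requirements line such as `bastionai==0.2.1 --hash=sha256:<digest>`.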