
High-level interface for tinygrad

Project description

Tinygrad Lightning - WIP

A PyTorch Lightning clone for tinygrad, with easy data loading, training, logging, and checkpointing.
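The example below defines hooks (`configure_optimizers`, `training_step`, `validation_step`) and hands the module to a `Trainer`. As a rough mental model, here is a minimal, framework-free sketch of the order in which a Lightning-style trainer typically calls those hooks, and of how a bare `self.log("loss", ...)` can end up namespaced as `train/loss` or `val/loss`. `ToyModule`, `ToySGD` and `fit` are hypothetical names; this is not tinygrad_lightning's implementation, and because there is no autograd here the gradient is written out by hand where a real trainer would call `loss.backward()`.

```python
# Hypothetical sketch: ToyModule, ToySGD and fit() are illustrative names,
# not tinygrad_lightning's API. No autograd: the gradient is computed by hand.


class ToySGD:
    """Plain gradient descent on the module's single parameter `w`."""

    def __init__(self, module, lr):
        self.module = module
        self.lr = lr

    def zero_grad(self):
        self.module.grad = 0.0

    def step(self):
        self.module.w -= self.lr * self.module.grad


class ToyModule:
    """Fits y = w * x with mean squared error."""

    def __init__(self):
        self.w = 0.0
        self.grad = 0.0
        self.logs = {}        # "stage/name" -> last logged value
        self.stage = "train"  # set by fit() before each phase

    def log(self, name, value):
        # Namespace metrics by stage, e.g. "loss" -> "train/loss" or "val/loss".
        self.logs[f"{self.stage}/{name}"] = value

    def configure_optimizers(self):
        return ToySGD(self, lr=0.1)

    def _loss(self, batch):
        return sum((self.w * x - y) ** 2 for x, y in batch) / len(batch)

    def training_step(self, batch, batch_idx):
        loss = self._loss(batch)
        # Stand-in for loss.backward(): d(loss)/dw written out by hand.
        self.grad = sum(2 * (self.w * x - y) * x for x, y in batch) / len(batch)
        self.log("loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._loss(batch)
        self.log("loss", loss)
        return loss


def fit(module, train_batches, val_batches, epochs=1):
    """Call the hooks in the order a Lightning-style trainer typically does."""
    opt = module.configure_optimizers()
    for _ in range(epochs):
        module.stage = "train"
        for batch_idx, batch in enumerate(train_batches):
            opt.zero_grad()
            module.training_step(batch, batch_idx)
            opt.step()
        module.stage = "val"
        for batch_idx, batch in enumerate(val_batches):
            module.validation_step(batch, batch_idx)
    return module


# Usage: learn w close to 3 from samples of y = 3x.
model = fit(
    ToyModule(),
    train_batches=[[(1.0, 3.0), (2.0, 6.0)], [(3.0, 9.0)]],
    val_batches=[[(4.0, 12.0)]],
    epochs=50,
)
print(model.w, model.logs)
```

Keeping the stage on the module and prefixing inside `log` is one simple way to let the same metric name serve both `training_step` and `validation_step`; the real library may do this differently.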

Example

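import numpy as np
from tinygrad.tensor import Tensor
import tinygrad.nn.optim as optim
import tinygrad_lightning as pl

# Assumed locations of these helpers in the tinygrad repository (not the pip package);
# they may move between tinygrad versions.
from extra.training import sparse_categorical_crossentropy
from models.resnet import ResNet18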

### model ###

class TinyBobNet(pl.LightningModule):
    def __init__(self, filters=64):
        super().__init__()  # let the base class set up its own state first
        self.model = ResNet18(num_classes=10)

    def forward(self, input: Tensor):
        return self.model(input)

    def configure_optimizers(self):
        return optim.SGD(optim.get_parameters(self), lr=5e-3, momentum=0.9)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch

        for image in x:
            self.log_image("inputs", image)

        out = self.forward(x)

        cat = np.argmax(out.cpu().numpy(), axis=-1)
        accuracy = (cat == y).mean()

        loss = sparse_categorical_crossentropy(out, y)
        loss_value = loss.detach().cpu().numpy()

        # automatically logs to train/loss, ...
        self.log("loss", loss_value.mean())
        self.log("accuracy", accuracy)

        return loss

    def validation_step(self, val_batch, val_idx):
        x, y = val_batch
        out = self.forward(x)

        cat = np.argmax(out.cpu().numpy(), axis=-1)
        accuracy = (cat == y).mean()

        loss = sparse_categorical_crossentropy(out, y)
        loss_value = loss.detach().cpu().numpy()

        # automatically logs to val/loss, ...
        self.log("loss", loss_value.mean())
        self.log("accuracy", accuracy)

        return loss

batch_size = 4

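# MnistDataset is not shown here; it follows the torch Dataset interface.
# The 'train' variant is assumed by analogy with 'test'.
train_ds = MnistDataset(variant='train')
test_ds = MnistDataset(variant='test')
train_loader = pl.DataLoader(train_ds, batch_size, workers=1, shuffle=True)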

# define your model
model = TinyBobNet()
callbacks=[pl.TQDMProgressBar(refresh_rate=10), pl.TensorboardLogger("./logdir")]

trainer = pl.Trainer(model, train_loader=train_loader, callbacks=callbacks)
trainer.fit(epochs=1)  # optional batch limits, e.g. train_batches=2, val_batches=4
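`pl.DataLoader(train_ds, batch_size, workers=1, shuffle=True)` turns an indexable dataset into batches. The sketch below shows what such a loader does in the simplest single-process case: reshuffle the indices every epoch, slice them into batches, and collate `(x, y)` pairs into `(xs, ys)`. `SimpleLoader` is a hypothetical stand-in, not the library's code; it ignores `workers`, and it collates into Python lists where a real loader would more likely stack samples into arrays.

```python
import random


class SimpleLoader:
    """Hypothetical single-process stand-in for pl.DataLoader (not its real code).

    Works with any dataset exposing __len__ and __getitem__, as a torch Dataset does.
    """

    def __init__(self, dataset, batch_size, shuffle=False, seed=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        if self.shuffle:
            self.rng.shuffle(indices)  # a fresh order on every pass
        for start in range(0, len(indices), self.batch_size):
            samples = [self.dataset[i] for i in indices[start:start + self.batch_size]]
            # Collate a list of (x, y) pairs into (xs, ys).
            xs, ys = zip(*samples)
            yield list(xs), list(ys)


# Usage with a toy dataset of (x, y) pairs.
for xs, ys in SimpleLoader([(i, i * i) for i in range(10)], batch_size=4, shuffle=True, seed=0):
    print(xs, ys)
```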



Download files


Source Distribution

tinygrad_lightning-0.0.1.tar.gz (8.2 kB, source)

Built Distribution

tinygrad_lightning-0.0.1-py3-none-any.whl (9.4 kB, Python 3)

File details

Details for the file tinygrad_lightning-0.0.1.tar.gz.

File metadata

  • Download URL: tinygrad_lightning-0.0.1.tar.gz
  • Upload date:
  • Size: 8.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.10

File hashes

Hashes for tinygrad_lightning-0.0.1.tar.gz
Algorithm Hash digest
SHA256 765e177141ae3d6fe71a826b601e19cb69d49b15c999a5761c9897c8cb19fff9
MD5 ef7f621c1647a64acc955258f80c69f8
BLAKE2b-256 bd15d0ff0ce512ad9f52ffc6d5a9b8abc1c00d5c7eedcf529272e80bd029a9a9
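To check a downloaded file against the SHA256 digest listed above, a small standard-library sketch is enough. `sha256_of` and `verify` are illustrative helper names; the file is read in chunks so large archives do not need to fit in memory.

```python
import hashlib

# SHA256 digest of tinygrad_lightning-0.0.1.tar.gz, copied from the table above.
EXPECTED_SDIST_SHA256 = "765e177141ae3d6fe71a826b601e19cb69d49b15c999a5761c9897c8cb19fff9"


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected):
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()


# Usage, after downloading the sdist into the current directory:
# verify("tinygrad_lightning-0.0.1.tar.gz", EXPECTED_SDIST_SHA256)
```

Alternatively, pip's hash-checking mode does the same at install time: a requirements line such as `tinygrad_lightning==0.0.1 --hash=sha256:<digest>` installed with `pip install --require-hashes -r requirements.txt`. Listing both the sdist and wheel digests avoids a mismatch when pip picks the other file.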


File details

Details for the file tinygrad_lightning-0.0.1-py3-none-any.whl.

File metadata

File hashes

Hashes for tinygrad_lightning-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 4f952659e0830df3c561724dd4b75ee86496821b668ad658a6eddfa4b238be10
MD5 1fbf55a3b606e96942e989e27dd03713
BLAKE2b-256 2543c5046a678db116c31f977c78d6a1b374113e05955e20416fd3038e1f75e3

