pytorch-lightning tutorial


⚡ lightning-tutorial


Installation of the tutorial's companion package

pip install lightning-tutorial

Table of contents

  • PyTorch Datasets and DataLoaders
  • Single-cell data structures meet pytorch: torch-adata
  • Lightning basics and the LightningModule
  • LightningDataModule

PyTorch Datasets and DataLoaders

Key module: torch.utils.data.Dataset

The Dataset class is designed to be subclassed. You can customize it at will as long as you implement the following three methods:

  1. __init__
  2. __len__
  3. __getitem__

These are name-specific hooks that torch calls under the hood when passing data through a model.

import torch
from torch.utils.data import Dataset

class TurtleData(Dataset):
    def __init__(self, n_samples=20_000, n_features=50):
        """
        Here we should pass (or create) the requisite
        attributes that enable __len__() and __getitem__().
        A dummy data matrix is used for illustration.
        """
        self.X = torch.randn(n_samples, n_features)

    def __len__(self):
        """
        Returns the length/size/# of samples in the dataset.
        e.g., a 20,000-cell dataset would return `20_000`.
        """
        return len(self.X)

    def __getitem__(self, idx):
        """
        Subset and return a batch of the data.

        `idx` is a sample index (when batching, the DataLoader
        passes one idx per sample). The maximum `idx` passed
        is `self.__len__() - 1`.
        """
        return self.X[idx]

Key module: torch.utils.data.DataLoader

Much like AnnData, the Dataset class creates a base unit for distributing and handling data. We can then take advantage of several torch built-ins to enable not only better-organized, but also faster data processing.

from torch.utils.data import DataLoader

dataset = TurtleData()
data_size = len(dataset)
print(data_size)
# 20000
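
Wrapping the Dataset in a DataLoader is what actually enables batching, shuffling, and parallel loading. A minimal sketch (the batch size and shuffling here are illustrative choices, not something the tutorial prescribes):

# batch_size and shuffle are illustrative choices
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# each iteration yields one collated batch of 64 samples
for X_batch in loader:
    print(X_batch.shape)  # torch.Size([64, 50])
    break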

Other essential functions

from torch.utils.data import random_split

train_dataset, val_dataset = random_split(dataset, [18_000, 2_000])

# this can then be fed to a DataLoader, as above
train_loader = DataLoader(train_dataset)
val_loader = DataLoader(val_dataset)
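
A note on reproducibility: random_split accepts an optional generator argument, so the same split can be fixed across runs. A small sketch:

import torch

# seed the generator so the same cells land in train/val each run
train_dataset, val_dataset = random_split(
    dataset, [18_000, 2_000], generator=torch.Generator().manual_seed(0)
)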

Useful tutorials and documentation

☝️ back to table of contents

Single-cell data structures meet pytorch: torch-adata


Create pytorch Datasets from AnnData

Installation

  • Note: this is already done for you if you've installed this tutorial's associated package:
pip install torch-adata


Example use of the base class

The base class, AnnDataset, is a subclass of the widely used torch.utils.data.Dataset.

import anndata as a
import torch_adata

adata = a.read_h5ad("/path/to/data.h5ad")
dataset = torch_adata.AnnDataset(adata)

Sampling returns a batch of data, X_batch, as a torch.Tensor.

import numpy as np

# create a dummy index
idx = np.random.choice(len(dataset), 5)
X_batch = dataset[idx]

TimeResolvedAnnDataset

A specialized class for time-resolved datasets; a subclass of the base class, AnnDataset.

import anndata as a
import torch_adata as ta

adata = a.read_h5ad("/path/to/data.h5ad")
dataset = ta.TimeResolvedAnnDataset(adata, time_key="Time point")

☝️ back to table of contents

Lightning basics and the LightningModule

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule

class YourSOTAModel(LightningModule):
    def __init__(self,
                 net,
                 optimizer_kwargs={"lr": 1e-3},
                 scheduler_kwargs={"step_size": 10},  # StepLR requires a step_size
                ):
        super().__init__()

        self.net = net
        self.optimizer_kwargs = optimizer_kwargs
        self.scheduler_kwargs = scheduler_kwargs


    def forward(self, batch):

        x, y = batch

        y_hat = self.net(x)
        # substitute whichever loss function suits your task
        loss  = F.mse_loss(y_hat, y)

        return y_hat, loss

    def training_step(self, batch, batch_idx):

        y_hat, loss = self.forward(batch)

        return loss.sum()

    def validation_step(self, batch, batch_idx):

        y_hat, loss = self.forward(batch)

        return loss.sum()

    def test_step(self, batch, batch_idx):

        y_hat, loss = self.forward(batch)

        return loss.sum()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), **self.optimizer_kwargs)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **self.scheduler_kwargs)

        return [optimizer], [scheduler]
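
To make the pattern concrete, here is a hedged usage sketch; the two-layer net and the synthetic regression data below are placeholders for illustration, not part of the tutorial package:

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer

# placeholder network and synthetic (x, y) data
net = torch.nn.Sequential(
    torch.nn.Linear(50, 25), torch.nn.ReLU(), torch.nn.Linear(25, 1)
)
X, y = torch.randn(1_000, 50), torch.randn(1_000, 1)
loader = DataLoader(TensorDataset(X, y), batch_size=64)

model = YourSOTAModel(net=net)
trainer = Trainer(max_epochs=5, accelerator="auto", devices=1)
trainer.fit(model, loader)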

Additional useful documentation and standalone tutorials

☝️ back to table of contents

LightningDataModule

Purpose: make your model independent of any given dataset, while at the same time making your dataset reproducible and, perhaps just as important, easily shareable.

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

class YourDataModule(LightningDataModule):

    def __init__(self, batch_size=64, num_workers=2):
        super().__init__()
        # define any setup computations
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        # download data if applicable
        ...

    def setup(self, stage=None):
        # assign data to `Dataset`(s), e.g.:
        # self.train_dataset, self.val_dataset, self.test_dataset = ...
        ...

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
        

When it comes to actually using one of these, it looks something like the following:

from pytorch_lightning import Trainer

# Init the LightningDataModule as well as the LightningModule
data = YourDataModule()
model = YourLightningModel()

# Define trainer
trainer = Trainer(accelerator="auto", devices=1)

# Ultimately, both model and data are passed to trainer.fit
trainer.fit(model, data)
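
After fitting, the same datamodule can be reused for evaluation; assuming setup assigns a test dataset (as sketched above), the test loop runs with:

# runs test_step over the datamodule's test_dataloader
trainer.test(model, datamodule=data)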

Here's an example of a LightningDataModule implemented in practice, using the LARRY single-cell dataset: link. Initial downloading and formatting occurs only once, but it takes several minutes, so we will leave it outside the scope of this tutorial.

☝️ back to table of contents
