# ⚡ lightning-tutorial
## Installation

```bash
pip install lightning-tutorial
```
## Table of contents

- PyTorch Datasets and DataLoaders
  - Key module: `torch.utils.data.Dataset`
  - Key module: `torch.utils.data.DataLoader`
  - Other essential functions
- Single-cell data structures meet pytorch: torch-adata
- Lightning basics and the `LightningModule`
- `LightningDataModule`
## PyTorch Datasets and DataLoaders

### Key module: `torch.utils.data.Dataset`
The `Dataset` class is designed to be subclassed. You can modify it at will, as long as you implement the following three methods:

- `__init__`
- `__len__`
- `__getitem__`

These are name-specific handles used by `torch` under the hood when passing data through a model.
```python
from torch.utils.data import Dataset

class TurtleData(Dataset):
    def __init__(self):
        """
        Here we should accept whatever arguments are required
        to enable __len__() and __getitem__().
        """

    def __len__(self):
        """
        Return the number of samples in the dataset.
        e.g., a 20,000-cell dataset would return `20_000`.
        """
        return  # len

    def __getitem__(self, idx):
        """
        Subset and return a batch of the data.
        `idx` is the batch index (# of idx values = batch size).
        The maximum `idx` passed is `self.__len__() - 1`.
        """
        return  # sampled data
```
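To make the skeleton concrete, here is a minimal, hypothetical subclass backed by a random matrix; the class name, data, and shapes are illustrative only, not part of the tutorial package:

```python
import torch
from torch.utils.data import Dataset

class ToyTurtleData(Dataset):
    """Hypothetical Dataset backed by a random 20,000 x 50 matrix."""

    def __init__(self, n_samples=20_000, n_features=50):
        self.X = torch.randn(n_samples, n_features)

    def __len__(self):
        # number of samples in the dataset
        return self.X.shape[0]

    def __getitem__(self, idx):
        # return the sample(s) at the requested index (or indices)
        return self.X[idx]

dataset = ToyTurtleData()
print(len(dataset))  # 20000
```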
- Try it for yourself! Colab `Dataset` tutorial notebook
### Key module: `torch.utils.data.DataLoader`
Similar to the usefulness of `AnnData`, the `Dataset` class creates a base unit for distributing and handling data. We can then take advantage of several `torch` built-ins to enable not only more organized but also faster data processing.
```python
from torch.utils.data import DataLoader

dataset = TurtleData()
data_size = len(dataset)
print(data_size)
# 20_000
```
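With a `Dataset` in hand, the `DataLoader` handles batching, shuffling, and (optionally) parallel loading. A minimal sketch, using the hypothetical `ToyTurtleData` from above (the batch size here is arbitrary):

```python
loader = DataLoader(ToyTurtleData(), batch_size=64, shuffle=True)

for X_batch in loader:
    print(X_batch.shape)  # torch.Size([64, 50]); the last batch may be smaller
    break
```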
### Other essential functions
```python
from torch.utils.data import random_split

train_dataset, val_dataset = random_split(dataset, [18_000, 2_000])

# these can then be fed to a DataLoader, as above
train_loader = DataLoader(train_dataset)
val_loader = DataLoader(val_dataset)
```
### Useful tutorials and documentation

- Parent module: `torch.utils.data`
- Datasets and DataLoaders tutorial
## Single-cell data structures meet pytorch: torch-adata

Create pytorch Datasets from `AnnData`.

### Installation

Note: this is already done for you if you've installed this tutorial's associated package.

```bash
pip install torch-adata
```
### Example use of the base class

The base class, `AnnDataset`, is a subclass of the widely used `torch.utils.data.Dataset`.

```python
import anndata as a
import torch_adata

adata = a.read_h5ad("/path/to/data.h5ad")
dataset = torch_adata.AnnDataset(adata)
```
Returns sampled data `X_batch` as a `torch.Tensor`.

```python
import numpy as np

# create a dummy index
idx = np.random.choice(range(len(dataset)), 5)
X_batch = dataset[idx]
```
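Because `AnnDataset` subclasses `torch.utils.data.Dataset`, it can also be wrapped directly in a `DataLoader`; a minimal sketch (the batch size is arbitrary):

```python
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=128, shuffle=True)
X_batch = next(iter(loader))  # a torch.Tensor batch of 128 cells
```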
### TimeResolvedAnnDataset

A specialized class for time-resolved datasets; a subclass of `AnnDataset`.

```python
import anndata as a
import torch_adata as ta

adata = a.read_h5ad("/path/to/data.h5ad")
dataset = ta.TimeResolvedAnnDataset(adata, time_key="Time point")
```
## Lightning basics and the `LightningModule`
```python
import torch
from pytorch_lightning import LightningModule

class YourSOTAModel(LightningModule):
    def __init__(self,
                 net,
                 optimizer_kwargs={"lr": 1e-3},
                 scheduler_kwargs={"step_size": 10},  # StepLR requires `step_size`; 10 is an arbitrary default
                 ):
        super().__init__()
        self.net = net
        self.optimizer_kwargs = optimizer_kwargs
        self.scheduler_kwargs = scheduler_kwargs
        # placeholder loss; swap in whatever suits your task
        self.loss_func = torch.nn.MSELoss(reduction="none")

    def forward(self, batch):
        x, y = batch
        y_hat = self.net(x)
        loss = self.loss_func(y_hat, y)
        return y_hat, loss

    def training_step(self, batch, batch_idx):
        y_hat, loss = self.forward(batch)
        return loss.sum()

    def validation_step(self, batch, batch_idx):
        y_hat, loss = self.forward(batch)
        return loss.sum()

    def test_step(self, batch, batch_idx):
        y_hat, loss = self.forward(batch)
        return loss.sum()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), **self.optimizer_kwargs)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **self.scheduler_kwargs)
        return [optimizer], [scheduler]
```
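To instantiate this skeleton, any `torch.nn.Module` can serve as `net`; a minimal sketch (the two-layer architecture below is illustrative only):

```python
import torch

net = torch.nn.Sequential(
    torch.nn.Linear(50, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 1),
)
model = YourSOTAModel(net, optimizer_kwargs={"lr": 1e-3})
```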
### Additional useful documentation and standalone tutorials
## `LightningDataModule`

Purpose: make your model independent of a given dataset, while at the same time making your dataset reproducible and, perhaps just as important, easily shareable.
```python
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

class YourDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        # define any setup computations

    def prepare_data(self):
        # download data, if applicable
        ...

    def setup(self, stage=None):
        # assign data to `Dataset`(s), e.g., self.train_dataset
        ...

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
```
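As a concrete (hypothetical) sketch, `setup` might build the toy `Dataset` from the earlier section and split it with `random_split`:

```python
from torch.utils.data import DataLoader, random_split

class ToyDataModule(LightningDataModule):
    def __init__(self, batch_size=64, num_workers=2):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        # hypothetical Dataset from the earlier section
        dataset = ToyTurtleData()
        self.train_dataset, self.val_dataset = random_split(dataset, [18_000, 2_000])

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)
```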
When it comes to actually using one of these, it looks something like the following:

```python
from pytorch_lightning import Trainer

# Init the LightningDataModule as well as the LightningModule
data = YourDataModule()
model = YourSOTAModel(net)

# Define the trainer
trainer = Trainer(accelerator="auto", devices=1)

# Ultimately, both model and data are passed to trainer.fit
trainer.fit(model, data)
```
Here's an example of a `LightningDataModule` implemented in practice, using the LARRY single-cell dataset: link. Initial downloading and formatting occurs only once, but it takes several minutes, so we will leave it outside the scope of this tutorial.