
The Deep Learning framework to train, deploy, and ship AI products Lightning fast.

Project description


Blazingly fast, distributed streaming of training data from any cloud storage

⚡ Welcome to LitData

With LitData, users can transform and optimize their data in cloud storage environments efficiently and intuitively, at any scale.

Once optimized, efficient distributed training becomes practical regardless of where the data is located, enabling users to seamlessly stream data of any size to one or multiple machines.

LitData supports images, text, video, audio, geo-spatial, and multimodal data types. It has already been adopted by frameworks such as LitGPT to pretrain LLMs, and it integrates smoothly with PyTorch Lightning, Lightning Fabric, and PyTorch.

Runnable templates published on the Lightning.AI Platform, reproducible in one click, are listed at the end.


Getting Started

Installation

Install LitData with pip

pip install litdata

Install LitData with the extras

pip install 'litdata[extras]'

Quick Start

1. Prepare Your Data

Convert your raw dataset into the LitData optimized streaming format using the optimize operator.

Here is an example with some random images.

import numpy as np
from litdata import optimize
from PIL import Image


# Store random images into the data chunks
def random_images(index):
    data = {
        "index": index, # int data type
        "image": Image.fromarray(np.random.randint(0, 256, (32, 32, 3), np.uint8)), # PIL image data type
        "class": np.random.randint(10),  # int data type (random class label)
    }
    # The data is serialized into bytes and stored into data chunks by the optimize operator.
    return data

if __name__ == "__main__":
    optimize(
        fn=random_images,  # The function applied over each input.
        inputs=list(range(1000)),  # Provide any inputs. The fn is applied on each item.
        output_dir="my_optimized_dataset",  # The directory where the optimized data are stored.
        num_workers=4,  # The number of workers. The inputs are distributed among them.
        chunk_bytes="64MB"  # The maximum number of bytes to write into a data chunk.
    )

The optimize operator supports any data structures and types. Serialize whatever you want. The optimized data is stored under the output directory my_optimized_dataset.
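To make chunk_bytes concrete, here is a minimal, stdlib-only sketch of the idea: serialize each item and start a new chunk whenever adding the next item would exceed the byte budget. This is not LitData's on-disk format; pack_into_chunks is a hypothetical helper, and pickle stands in for LitData's own serializers.

```python
import pickle
from typing import Any, Iterable, List


def pack_into_chunks(items: Iterable[Any], chunk_bytes: int) -> List[List[bytes]]:
    """Group serialized items into chunks of at most `chunk_bytes` bytes.

    Hypothetical helper for illustration only. An item larger than the
    budget still gets a chunk of its own rather than being split.
    """
    chunks: List[List[bytes]] = []
    current: List[bytes] = []
    current_size = 0
    for item in items:
        blob = pickle.dumps(item)  # stand-in serializer
        # Close the current chunk if this item would overflow it.
        if current and current_size + len(blob) > chunk_bytes:
            chunks.append(current)
            current, current_size = [], 0
        current.append(blob)
        current_size += len(blob)
    if current:
        chunks.append(current)
    return chunks


if __name__ == "__main__":
    samples = [{"index": i, "class": i % 10} for i in range(1000)]
    print(len(pack_into_chunks(samples, chunk_bytes=4096)), "chunks")
```

Fewer, larger chunks mean fewer requests to cloud storage, while smaller chunks let a reader start sooner; chunk_bytes is the knob between the two.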

2. Upload your Data to Cloud Storage

Cloud providers such as AWS, Google Cloud, and Azure provide command-line clients to upload your data to their storage solutions.

Here is how to upload the optimized dataset using the AWS CLI to AWS S3.

aws s3 cp --recursive my_optimized_dataset s3://my-bucket/my_optimized_dataset

3. Use StreamingDataset

StreamingDataset can then read the data directly from AWS S3.

from litdata import StreamingDataset, StreamingDataLoader

# Remote path where full dataset is stored
input_dir = 's3://my-bucket/my_optimized_dataset'

# Create the Streaming Dataset
dataset = StreamingDataset(input_dir, shuffle=True)

# Access any elements of the dataset
sample = dataset[50]
img = sample['image']
cls = sample['class']

# Create dataLoader and iterate over it to train your AI models.
dataloader = StreamingDataLoader(dataset)

Key Features

Multi-GPU / Multi-Node Support

The StreamingDataset and StreamingDataLoader automatically ensure that each rank receives the same number of batches, each containing different data, so they work out of the box with your favorite frameworks (PyTorch Lightning, Lightning Fabric, or PyTorch) for distributed training.

The illustration below shows how the StreamingDataset works under the hood with multiple nodes and GPUs.

[Figure: How the StreamingDataset works with multiple nodes.]
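One way to picture the equal-quantity guarantee is the sketch below. Every rank shuffles the chunk indices with a shared seed, so all ranks agree on the permutation. The chunks are dealt out round-robin, and the remainder is dropped so every rank gets equally many. This is a simplification under stated assumptions; LitData's real scheduler also accounts for dataloader workers and chunk sizes, and assign_chunks is hypothetical.

```python
import random
from typing import List


def assign_chunks(num_chunks: int, world_size: int, rank: int, seed: int = 42) -> List[int]:
    """Return the chunk indices a given rank reads in this epoch (illustrative only)."""
    order = list(range(num_chunks))
    # Every rank uses the same seed, so they agree on the permutation.
    random.Random(seed).shuffle(order)
    # Drop the remainder so that each rank receives the same number of chunks.
    per_rank = num_chunks // world_size
    order = order[: per_rank * world_size]
    # Round-robin assignment: rank r takes positions r, r + world_size, ...
    return order[rank::world_size]


if __name__ == "__main__":
    for r in range(4):
        print(r, assign_chunks(num_chunks=10, world_size=4, rank=r))
```

Giving every rank the same number of steps matters because a rank that runs out early would leave the others waiting at the next gradient synchronization.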

Access any item

Access the data you need, whenever you need it, regardless of where it is stored.

from litdata import StreamingDataset

dataset = StreamingDataset("s3://my-bucket/my-data") # data are stored in the cloud

print(len(dataset)) # display the length of your data

print(dataset[42]) # show the element at index 42
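Random access such as dataset[42] only requires knowing how many items each chunk holds. The following is a minimal sketch that assumes a hypothetical list of per-chunk item counts. LitData keeps comparable metadata in its index, but the names here are illustrative. Prefix sums plus a binary search map a global index to a chunk and an offset inside it, so a reader would only need to fetch that one chunk.

```python
import bisect
from itertools import accumulate
from typing import List, Tuple


def locate(index: int, items_per_chunk: List[int]) -> Tuple[int, int]:
    """Map a global item index to (chunk_index, index_within_chunk)."""
    # ends[i] is the global index one past the last item of chunk i.
    ends = list(accumulate(items_per_chunk))
    if not 0 <= index < ends[-1]:
        raise IndexError(index)
    chunk = bisect.bisect_right(ends, index)
    start = ends[chunk - 1] if chunk > 0 else 0
    return chunk, index - start


if __name__ == "__main__":
    print(locate(42, [30, 30, 40]))
```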

Use any data transforms

Subclass the StreamingDataset and override its __getitem__ method to add any extra data transformations.

from litdata import StreamingDataset, StreamingDataLoader
import torchvision.transforms.v2.functional as F

class ImagenetStreamingDataset(StreamingDataset):

    def __getitem__(self, index):
        image = super().__getitem__(index)
        return F.resize(image, (224, 224))

dataset = ImagenetStreamingDataset(...)
dataloader = StreamingDataLoader(dataset, batch_size=4)

for batch in dataloader:
    print(batch.shape)
    # Out: torch.Size([4, 3, 224, 224])

The Map Operator

The map operator can be used to apply a function over a list of inputs.

Here is an example where the map operator is used to apply a resize_image function over a folder of large images.

import os

from litdata import map
from PIL import Image


# resize_image receives one input (image_path) and the output directory.
# Files written to output_dir are persisted.
def resize_image(image_path, output_dir):
    output_image_path = os.path.join(output_dir, os.path.basename(image_path))
    Image.open(image_path).resize((224, 224)).save(output_image_path)


if __name__ == "__main__":
    # Note: Inputs could also refer to files on s3 directly.
    input_dir = "my_large_images"
    inputs = [os.path.join(input_dir, f) for f in os.listdir(input_dir)]

    map(
        fn=resize_image,
        inputs=inputs,
        output_dir="s3://my-bucket/my_resized_images",
    )
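The map operator applies fn(input, output_dir) to every input in parallel and keeps whatever lands in output_dir. That contract can be sketched locally with the standard library. local_map is a hypothetical stand-in, not LitData's implementation, and it uses threads where LitData uses its own workers.

```python
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List


def local_map(fn: Callable[[str, str], None], inputs: List[str], output_dir: str, num_workers: int = 4) -> None:
    """Apply fn(item, output_dir) to each input with a pool of workers."""
    os.makedirs(output_dir, exist_ok=True)
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # list() forces evaluation so any exception raised by fn propagates here.
        list(pool.map(lambda item: fn(item, output_dir), inputs))


def write_upper(text: str, output_dir: str) -> None:
    # Toy fn: writes one output file per input, like resize_image does.
    with open(os.path.join(output_dir, f"{text}.txt"), "w") as f:
        f.write(text.upper())


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as out:
        local_map(write_upper, ["a", "b", "c"], out)
        print(sorted(os.listdir(out)))
```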

Easy Data Mixing with the Combined Streaming Dataset

Easily experiment with dataset mixtures using the CombinedStreamingDataset class.

For example, this mixture of SlimPajama and StarCoder was used in the TinyLlama project to pretrain a 1.1B Llama model on 3 trillion tokens.

from litdata import StreamingDataset, CombinedStreamingDataset, StreamingDataLoader
from litdata.streaming.item_loader import TokensLoader
from tqdm import tqdm
import os

train_datasets = [
    StreamingDataset(
        input_dir="s3://tinyllama-template/slimpajama/train/",
        item_loader=TokensLoader(block_size=2048 + 1), # Optimized loader for tokens used by LLMs 
        shuffle=True,
        drop_last=True,
    ),
    StreamingDataset(
        input_dir="s3://tinyllama-template/starcoder/",
        item_loader=TokensLoader(block_size=2048 + 1), # Optimized loader for tokens used by LLMs 
        shuffle=True,
        drop_last=True,
    ),
]

# Mix SlimPajama data and Starcoder data with these proportions:
weights = (0.693584, 0.306416)
combined_dataset = CombinedStreamingDataset(datasets=train_datasets, seed=42, weights=weights)

train_dataloader = StreamingDataLoader(combined_dataset, batch_size=8, pin_memory=True, num_workers=os.cpu_count())

# Iterate over the combined datasets
for batch in tqdm(train_dataloader):
    pass
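Conceptually, weighted mixing means that each time a sample is requested, a source dataset is picked at random in proportion to its weight. A seeded generator makes the sequence reproducible. The sketch below illustrates that idea with plain iterables. mix is hypothetical, and its stop-when-exhausted policy is just one simple choice; LitData's own sampling details may differ.

```python
import random
from typing import Iterable, Iterator, Sequence, TypeVar

T = TypeVar("T")


def mix(datasets: Sequence[Iterable[T]], weights: Sequence[float], seed: int = 42) -> Iterator[T]:
    """Yield items by picking a source with probability proportional to its weight.

    Stops as soon as the chosen source is exhausted.
    """
    rng = random.Random(seed)
    iterators = [iter(d) for d in datasets]
    while True:
        (chosen,) = rng.choices(range(len(iterators)), weights=weights, k=1)
        try:
            yield next(iterators[chosen])
        except StopIteration:
            return


if __name__ == "__main__":
    sample = list(mix([["slimpajama"] * 70, ["starcoder"] * 70], (0.693584, 0.306416)))
    print(sample.count("slimpajama") / len(sample))
```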

Pause & Resume Made Simple

LitData provides a stateful StreamingDataLoader, so you can pause and resume your training whenever you want.

Info: The StreamingDataLoader was used by LitGPT to pretrain LLMs. Restarting from an older checkpoint was critical to completing the pretraining of the full model, which was interrupted by several failures (network errors, CUDA errors, etc.).

import os
import torch
from litdata import StreamingDataset, StreamingDataLoader

dataset = StreamingDataset("s3://my-bucket/my-data", shuffle=True)
dataloader = StreamingDataLoader(dataset, num_workers=os.cpu_count(), batch_size=64)

# Restore the dataLoader state if it exists
if os.path.isfile("dataloader_state.pt"):
    state_dict = torch.load("dataloader_state.pt")
    dataloader.load_state_dict(state_dict)

# Iterate over the data
for batch_idx, batch in enumerate(dataloader):
  
    # Store the state every 1000 batches
    if batch_idx % 1000 == 0:
        torch.save(dataloader.state_dict(), "dataloader_state.pt")
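The pause-and-resume pattern boils down to an iterator that can report how far it has progressed and fast-forward to that point. Below is a toy sketch, assuming a deterministic sample order. ResumableLoader is hypothetical and far simpler than StreamingDataLoader's real state, which also covers workers and shuffling.

```python
from typing import Any, Dict, Iterator, List


class ResumableLoader:
    """Yield fixed-size batches from a list and checkpoint the position."""

    def __init__(self, data: List[Any], batch_size: int) -> None:
        self.data = data
        self.batch_size = batch_size
        self.batches_yielded = 0  # progress within the current epoch

    def __iter__(self) -> Iterator[List[Any]]:
        start = self.batches_yielded * self.batch_size
        for i in range(start, len(self.data), self.batch_size):
            # Count the batch before handing it out, so a state_dict() taken
            # while processing it resumes at the *next* batch.
            self.batches_yielded += 1
            yield self.data[i : i + self.batch_size]
        self.batches_yielded = 0  # epoch finished; the next one starts from the top

    def state_dict(self) -> Dict[str, int]:
        return {"batches_yielded": self.batches_yielded}

    def load_state_dict(self, state: Dict[str, int]) -> None:
        self.batches_yielded = state["batches_yielded"]


if __name__ == "__main__":
    loader = ResumableLoader(list(range(10)), batch_size=3)
    it = iter(loader)
    next(it), next(it)
    saved = loader.state_dict()
    resumed = ResumableLoader(list(range(10)), batch_size=3)
    resumed.load_state_dict(saved)
    print(list(resumed))
```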

Support Profiling

The StreamingDataLoader supports profiling of your data loading process. Simply use the profile_batches argument to specify the number of batches you want to profile:

from litdata import StreamingDataset, StreamingDataLoader

StreamingDataLoader(..., profile_batches=5)

This generates a Chrome trace called result.json. To visualize it, open chrome://tracing in the Chrome browser and load the trace.
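For intuition about what a file like result.json contains, here is a short sketch of the format. Chrome's trace viewer reads the Trace Event Format: a JSON object with a traceEvents list. In that list, complete events ("ph": "X") carry a start timestamp and a duration in microseconds. The sketch hand-writes such a trace for a few simulated batches. It only illustrates the format and is not how LitData's profiler is implemented; write_toy_trace is hypothetical.

```python
import json
import os
import tempfile
import time
from typing import Any, Dict, List


def write_toy_trace(num_batches: int, path: str) -> List[Dict[str, Any]]:
    """Time a simulated batch-loading loop and save it as a Chrome trace."""
    events: List[Dict[str, Any]] = []
    for i in range(num_batches):
        start = time.perf_counter()
        time.sleep(0.001)  # stand-in for fetching and collating one batch
        end = time.perf_counter()
        events.append({
            "name": f"batch_{i}",
            "ph": "X",                   # complete event: start plus duration
            "ts": start * 1e6,           # timestamps are in microseconds
            "dur": (end - start) * 1e6,
            "pid": 0,
            "tid": 0,
        })
    with open(path, "w") as f:
        json.dump({"traceEvents": events}, f)
    return events


if __name__ == "__main__":
    out_path = os.path.join(tempfile.gettempdir(), "result.json")
    write_toy_trace(5, out_path)
    print("Open chrome://tracing and load", out_path)
```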

Reduce your memory footprint

When processing large files such as compressed parquet files, use the Python yield keyword to process and store one item at a time, reducing the memory footprint of the entire program.

from pathlib import Path
import pyarrow.parquet as pq
from litdata import optimize
from tokenizer import Tokenizer
from functools import partial

# 1. Define a function to convert the text within the parquet files into tokens
def tokenize_fn(filepath, tokenizer=None):
    parquet_file = pq.ParquetFile(filepath)
    # Process per batch to reduce RAM usage
    for batch in parquet_file.iter_batches(batch_size=8192, columns=["content"]):
        for text in batch.to_pandas()["content"]:
            yield tokenizer.encode(text, bos=False, eos=True)

# 2. Generate the inputs
input_dir = "/teamspace/s3_connections/tinyllama-template"
inputs = [str(file) for file in Path(f"{input_dir}/starcoderdata").rglob("*.parquet")]

# 3. Store the optimized data wherever you want under "/teamspace/datasets" or "/teamspace/s3_connections"
outputs = optimize(
    fn=partial(tokenize_fn, tokenizer=Tokenizer(f"{input_dir}/checkpoints/Llama-2-7b-hf")), # Note: Use HF tokenizer or any others
    inputs=inputs,
    output_dir="/teamspace/datasets/starcoderdata",
    chunk_size=(2049 * 8012), # Number of tokens to store by chunks. This is roughly 64MB of tokens per chunk.
)
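The memory saving comes from the generator: tokenize_fn hands back one tokenized document at a time instead of first building a list of every document. A stdlib-only sketch of the same pattern follows. The whitespace encode is a hypothetical stand-in for the real tokenizer, and lines of text stand in for parquet batches.

```python
import io
from typing import Iterable, Iterator, List


def encode(text: str) -> List[str]:
    # Hypothetical stand-in for tokenizer.encode(text, bos=False, eos=True).
    return text.split() + ["<eos>"]


def tokenize_lines(lines: Iterable[str]) -> Iterator[List[str]]:
    """Yield the tokens of one non-empty line at a time."""
    for line in lines:
        line = line.strip()
        if line:
            # Nothing is accumulated across lines; each result goes straight to the caller.
            yield encode(line)


if __name__ == "__main__":
    stream = io.StringIO("first document\nsecond document\n")
    for tokens in tokenize_lines(stream):
        print(tokens)
```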

Configure Cache Size Limit

Set the local cache size limit of the StreamingDataset. This ensures that downloaded data chunks are deleted once used, keeping disk usage low.

from litdata import StreamingDataset

dataset = StreamingDataset(..., max_cache_size="10GB")
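A bounded cache like max_cache_size can be sketched this way: track chunk files in the order they were added, and evict the oldest ones once the total exceeds the limit. The sketch assumes decimal units ("10GB" = 10 × 10^9 bytes) and oldest-first eviction. parse_size and ChunkCache are hypothetical names. LitData's actual eviction policy and unit parsing may differ.

```python
import re
from collections import OrderedDict
from typing import List

_UNITS = {"B": 1, "KB": 10**3, "MB": 10**6, "GB": 10**9, "TB": 10**12}


def parse_size(text: str) -> int:
    """Parse strings like '10GB' or '64MB' into bytes (decimal units assumed)."""
    match = re.fullmatch(r"\s*(\d+(?:\.\d+)?)\s*([KMGT]?B)\s*", text.upper())
    if match is None:
        raise ValueError(f"unrecognized size: {text!r}")
    value, unit = match.groups()
    return int(float(value) * _UNITS[unit])


class ChunkCache:
    """Keep chunk sizes in insertion order and evict the oldest over the limit."""

    def __init__(self, max_cache_size: str) -> None:
        self.limit = parse_size(max_cache_size)
        self.chunks: "OrderedDict[str, int]" = OrderedDict()
        self.total = 0

    def add(self, name: str, num_bytes: int) -> List[str]:
        """Register a downloaded chunk and return the names of evicted chunks."""
        if name in self.chunks:  # re-adding a chunk must not double-count it
            self.total -= self.chunks.pop(name)
        self.chunks[name] = num_bytes
        self.total += num_bytes
        evicted = []
        # Never evict the chunk that was just added.
        while self.total > self.limit and len(self.chunks) > 1:
            old_name, old_size = self.chunks.popitem(last=False)
            self.total -= old_size
            evicted.append(old_name)
        return evicted


if __name__ == "__main__":
    cache = ChunkCache("100B")
    for i in range(4):
        print(f"chunk-{i}", "evicts", cache.add(f"chunk-{i}", 40))
```

A real cache would also delete the evicted files from disk and avoid evicting chunks that are still being read.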

On-Prem Optimizations

On-prem compute nodes can mount and use a network drive, that is, a shared storage device on a local area network. To reduce the load on the network, the StreamingDataset supports caching the data chunks.

from litdata import StreamingDataset

dataset = StreamingDataset(input_dir="local:/data/shared-drive/some-data")
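The idea behind caching chunks from a network drive is copy-on-first-read. The first access copies a chunk file to local disk, and later reads hit the local copy instead of the network. A minimal sketch with shutil follows. cached_path and the flat cache directory layout are hypothetical, not LitData's cache structure.

```python
import os
import shutil
import tempfile


def cached_path(remote_path: str, cache_dir: str) -> str:
    """Return a local copy of remote_path, copying it on first access only."""
    local_path = os.path.join(cache_dir, os.path.basename(remote_path))
    if not os.path.exists(local_path):
        os.makedirs(cache_dir, exist_ok=True)
        # Copy under a temporary name, then rename, so readers never see
        # a partially copied chunk under the final name.
        tmp_path = local_path + ".tmp"
        shutil.copyfile(remote_path, tmp_path)
        os.replace(tmp_path, local_path)
    return local_path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as share, tempfile.TemporaryDirectory() as cache:
        chunk = os.path.join(share, "chunk-0.bin")
        with open(chunk, "wb") as f:
            f.write(b"example chunk bytes")
        print(cached_path(chunk, cache))  # copied on first call
        print(cached_path(chunk, cache))  # served from the local cache
```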

Benchmarks

To measure the effectiveness of LitData, we used a common benchmark dataset, Imagenet-1.2M, whose training set contains 1,281,167 images.

To align with other benchmarks, we measured the streaming speed (images per second) loaded from AWS S3 for several frameworks.

Reproduce our benchmark by running this Studio.

Imagenet-1.2M Streaming from AWS S3

LitData is up to 85% faster than the second-best framework. Higher is better in the table below.

| Framework | Images / sec 1st Epoch (float32) | Images / sec 2nd Epoch (float32) | Images / sec 1st Epoch (torch16) | Images / sec 2nd Epoch (torch16) |
|---|---|---|---|---|
| PL Data | 5800.34 | 6589.98 | 6282.17 | 7221.88 |
| Web Dataset | 3134.42 | 3924.95 | 3343.40 | 4424.62 |
| Mosaic ML | 2898.61 | 5099.93 | 2809.69 | 5158.98 |
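The relative speedups can be recomputed directly from the table. The short script below divides the PL Data throughput by the best competing framework in each column. On these numbers, the first-epoch float32 column reproduces the quoted 85%.

```python
# Images/sec copied from the table above, in column order:
# 1st epoch float32, 2nd epoch float32, 1st epoch torch16, 2nd epoch torch16.
results = {
    "PL Data": [5800.34, 6589.98, 6282.17, 7221.88],
    "Web Dataset": [3134.42, 3924.95, 3343.40, 4424.62],
    "Mosaic ML": [2898.61, 5099.93, 2809.69, 5158.98],
}


def speedup_over_runner_up(results, name="PL Data"):
    """Percentage by which `name` beats the best other framework, per column."""
    others = [v for k, v in results.items() if k != name]
    return [
        100 * (ours / max(column) - 1)
        for ours, column in zip(results[name], zip(*others))
    ]


print([round(p) for p in speedup_over_runner_up(results)])  # [85, 29, 88, 40]
```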

Imagenet-1.2M Conversion

We measured how fast the 1.2 million images can be converted into a streamable format. Faster is better in the table below.

| Framework | Train Conversion Time | Val Conversion Time | Dataset Size | # Files |
|---|---|---|---|---|
| PL Data | 10:05 min | 00:30 min | 143.1 GB | 2,339 |
| Web Dataset | 32:36 min | 01:22 min | 147.8 GB | 1,144 |
| Mosaic ML | 49:49 min | 01:04 min | 143.1 GB | 2,298 |
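The conversion times can be compared the same way. The sketch below assumes the table entries are minutes:seconds, converts them to seconds, and expresses each framework's train-set conversion time as a multiple of PL Data's.

```python
def to_seconds(mm_ss: str) -> int:
    """Convert a 'MM:SS' string (as in the table above) to seconds."""
    minutes, seconds = mm_ss.split(":")
    return int(minutes) * 60 + int(seconds)


train_times = {"PL Data": "10:05", "Web Dataset": "32:36", "Mosaic ML": "49:49"}
baseline = to_seconds(train_times["PL Data"])
for name, t in train_times.items():
    print(f"{name}: {to_seconds(t) / baseline:.2f}x the PL Data time")
```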

Runnable Templates

The fastest way to learn is with Studios.

Studios are reproducible cloud IDEs that include data, code, and dependencies, so you can easily redo everything yourself!

We've published public templates that demonstrate how best to use the LitData framework at scale and with several data types.

Sign up here and run your first Studio for free.

| Studio | Data type | Dataset |
|---|---|---|
| Use or explore LAION-400MILLION dataset | Image & Text | LAION-400M |
| Convert GeoSpatial data to Lightning Streaming | Image & Mask | Chesapeake Roads Spatial Context |
| Benchmark cloud data-loading libraries | Image & Label | Imagenet 1M |
| Prepare the TinyLlama 1T token dataset | Text | SlimPajama & StarCoder |
| Tokenize 2M Swedish Wikipedia Articles | Text | Swedish Wikipedia |
| Embed English Wikipedia under 5 dollars | Text | English Wikipedia |
| Convert parquets to Lightning Streaming | Parquet Files | Randomly Generated data |

Infinite cloud data processing

Scaling data processing typically requires more machines, and provisioning and managing them yourself is tedious and time-consuming.

Instead, create a free account on the Lightning.ai platform and use as many machines as you need from code.

On the platform, simply specify the number of nodes and the machine type you need as follows:

from litdata import map, Machine

map(
  ...
  num_nodes=32,
  machine=Machine.DATA_PREP, # Select between dozens of optimized machines
)

Similarly, the optimize operator can use multiple machines to make immense datasets streamable:

from litdata import optimize, Machine

optimize(
  ...
  num_nodes=32,
  machine=Machine.DATA_PREP, # Select between dozens of optimized machines
)
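With num_nodes=32, the inputs have to be split twice: first across machines, then across the workers on each machine. Below is a minimal sketch of one way to do that, using contiguous slices that are as even as possible. split_evenly and split_inputs are hypothetical, and LitData's actual distribution strategy may differ.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_evenly(items: Sequence[T], parts: int) -> List[List[T]]:
    """Split items into `parts` contiguous slices whose lengths differ by at most one."""
    q, r = divmod(len(items), parts)
    out, start = [], 0
    for i in range(parts):
        end = start + q + (1 if i < r else 0)  # the first r slices get one extra item
        out.append(list(items[start:end]))
        start = end
    return out


def split_inputs(items: Sequence[T], num_nodes: int, num_workers: int) -> List[List[List[T]]]:
    """Return assignments[node][worker] -> list of inputs."""
    return [split_evenly(node_items, num_workers) for node_items in split_evenly(items, num_nodes)]


if __name__ == "__main__":
    plan = split_inputs(list(range(100)), num_nodes=32, num_workers=4)
    print(plan[0], plan[-1])
```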

Within the LAION-400M Studio, we used 32 machines, each with 32 CPUs, to run the optimize operator, downloading 400 million images in just 2 hours. Below is a screenshot of that job within the Lightning.ai platform. You can execute it yourself here.


⚡ Contributors

We welcome any contributions, pull requests, or issues. If you use the Streaming Dataset for your own project, please reach out to us on Discord.


Download files


Source Distribution

litdata-0.2.9.tar.gz (91.1 kB)

Uploaded Source

Built Distribution

litdata-0.2.9-py3-none-any.whl (104.3 kB)

Uploaded Python 3

File details

Details for the file litdata-0.2.9.tar.gz.

File metadata

  • Download URL: litdata-0.2.9.tar.gz
  • Size: 91.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.4

File hashes

Hashes for litdata-0.2.9.tar.gz
Algorithm Hash digest
SHA256 689b8faac7f67e2de37bf6c87fab11521cbe8b129461e2fa587899318ff7dfbb
MD5 29367dff51d6598b5985f8130c94d328
BLAKE2b-256 3b9ad0e802c86e590fa083f2f47d9c00a45a2c0f036845de1667056f5e36542d


File details

Details for the file litdata-0.2.9-py3-none-any.whl.

File metadata

  • Download URL: litdata-0.2.9-py3-none-any.whl
  • Size: 104.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.4

File hashes

Hashes for litdata-0.2.9-py3-none-any.whl
Algorithm Hash digest
SHA256 80e60f6be00a4e7399e7f06183da60522b14fa47e0541cf37798954d9436b795
MD5 1d535201f0ca2a6e481d77d6a6180dfc
BLAKE2b-256 6d3eab9af0f327b0c8b5bc6c396a9c1f426ce69199e3dda2f274292124330c04

