Blazing fast, distributed streaming of training data from cloud storage

⚡ Welcome to Lightning Data

We developed StreamingDataset to optimize training on large datasets stored in the cloud, while prioritizing speed, affordability, and scalability.

Specifically crafted for multi-node, distributed training with large models, it enhances accuracy, performance, and ease of use. Training efficiently is now possible regardless of where the data lives: simply stream in the required data when it is needed.

The StreamingDataset is compatible with any data type, including images, text, video, and multimodal data, and it is a drop-in replacement for your PyTorch IterableDataset class. For example, it is used by Lit-GPT to pretrain LLMs.

Finally, the StreamingDataset is fast! Check out our benchmark.

Here is an illustration showing how the StreamingDataset works.


🎬 Getting Started

💾 Installation

Lightning Data can be installed with pip:

pip install --no-cache-dir git+https://github.com/Lightning-AI/lit-data.git@master

🏁 Quick Start

1. Prepare Your Data

Convert your raw dataset into Lightning Streaming format using the optimize operator. More formats are coming...

import numpy as np
from lightning_data import optimize
from PIL import Image


# Store random images into the chunks
def random_images(index):
    data = {
        "index": index,
        "image": Image.fromarray(np.random.randint(0, 256, (32, 32, 3), np.uint8)),
        "class": np.random.randint(10),
    }
    return data # The data is serialized into bytes and stored into chunks by the optimize operator.

if __name__ == "__main__":
    optimize(
        fn=random_images,  # The function applied over each input.
        inputs=list(range(1000)),  # Provide any inputs. The fn is applied on each item.
        output_dir="my_dataset",  # The directory where the optimized data are stored.
        num_workers=4,  # The number of workers. The inputs are distributed among them.
        chunk_bytes="64MB"  # The maximum number of bytes to write into a chunk.
    )

The optimize operator supports any data structure and type. Serialize whatever you want.
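For example, a sample can mix plain Python objects with NumPy arrays. Here is a minimal sketch (the field names are only illustrative, not required by the API):

import numpy as np
from lightning_data import optimize

def mixed_sample(index):
    # Each field can have a different type; optimize serializes them into bytes for you.
    return {
        "index": index,                                       # int
        "caption": f"sample number {index}",                  # str
        "embedding": np.random.rand(128).astype(np.float32),  # NumPy array
    }

if __name__ == "__main__":
    optimize(
        fn=mixed_sample,
        inputs=list(range(100)),
        output_dir="my_mixed_dataset",
        num_workers=2,
        chunk_bytes="64MB",
    )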

2. Upload Your Data to Cloud Storage

Cloud providers such as AWS, Google Cloud, and Azure provide command-line clients to upload your data to their storage.

Here is an example with AWS S3.

 aws s3 cp --recursive my_dataset s3://my-bucket/my_dataset
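The equivalents with the Google Cloud and Azure command-line tools look roughly as follows (bucket and container names are placeholders):

gsutil -m cp -r my_dataset gs://my-bucket/my_dataset
az storage blob upload-batch --destination my-container --source my_dataset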

3. Use StreamingDataset and DataLoader

from lightning_data import StreamingDataset
from torch.utils.data import DataLoader

# Remote path where full dataset is persistently stored
input_dir = 's3://pl-flash-data/my_dataset'

# Create streaming dataset
dataset = StreamingDataset(input_dir, shuffle=True)

# Access any element of the dataset
sample = dataset[50]
img = sample['image']
cls = sample['class']

# Create PyTorch DataLoader
dataloader = DataLoader(dataset)
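Note that PyTorch's default collate function cannot batch PIL images, so for this toy dataset you would typically convert the image to a tensor before batching. A minimal sketch, assuming torchvision >= 0.16 is installed (the subclass name is illustrative):

import torchvision.transforms.v2.functional as F

class MyStreamingDataset(StreamingDataset):
    def __getitem__(self, index):
        sample = super().__getitem__(index)
        sample["image"] = F.pil_to_tensor(sample["image"])  # (3, 32, 32) uint8 tensor
        return sample

dataset = MyStreamingDataset(input_dir, shuffle=True)
dataloader = DataLoader(dataset, batch_size=64)

for batch in dataloader:
    pass  # batch["image"]: (batch_size, 3, 32, 32), batch["class"]: (batch_size,)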

Transform data

Similar to optimize, the map operator can be used to transform data by applying a function over a list of items and persisting all the files written inside the output directory.

1. Put some images on cloud storage

We generate 1,000 images and upload them to AWS S3.

import os
from PIL import Image
import numpy as np

data_dir = "my_images"
os.makedirs(data_dir, exist_ok=True)

for i in range(1000):
    width = np.random.randint(224, 320) 
    height = np.random.randint(224, 320) 
    image_path = os.path.join(data_dir, f"{i}.JPEG")
    Image.fromarray(
        np.random.randint(0, 256, (height, width, 3), np.uint8)  # NumPy arrays are (height, width, channels)
    ).save(image_path, format="JPEG", quality=90)

Then upload the images to S3:

aws s3 cp --recursive my_images s3://my-bucket/my_images

2. Resize the images

import os
from lightning_data import map
from PIL import Image

input_dir = "s3://my-bucket/my_images"
# Note: os.listdir cannot list an S3 prefix directly; this assumes the input
# directory resolves to a locally accessible path (e.g. a mounted data connection).
inputs = [os.path.join(input_dir, f) for f in os.listdir(input_dir)]

def resize_image(image_path, output_dir):
    output_image_path = os.path.join(output_dir, os.path.basename(image_path))
    Image.open(image_path).resize((224, 224)).save(output_image_path)
  
if __name__ == "__main__":
    map(
        fn=resize_image,
        inputs=inputs, 
        output_dir="s3://my-bucket/my_resized_images",
        num_workers=4,
    )

📚 End-to-end Lightning Studio Templates

We have end-to-end free Studios showing all the steps to prepare the following datasets:

| Dataset                          | Data type           | Studio                                         |
| -------------------------------- | ------------------- | ---------------------------------------------- |
| LAION-400M                       | Image & description | Use or explore LAION-400MILLION dataset        |
| Chesapeake Roads Spatial Context | Image & Mask        | Convert GeoSpatial data to Lightning Streaming |
| Imagenet 1M                      | Image & Label       | Benchmark cloud data-loading libraries         |
| SlimPajama & StarCoder           | Text                | Prepare the TinyLlama 1T token dataset         |
| English Wikipedia                | Text                | Embed English Wikipedia under 5 dollars        |
| Generated                        | Parquet Files       | Convert parquets to Lightning Streaming        |

Lightning Studios are fully reproducible cloud IDEs with data, code, dependencies, and more. Finally, reproducible science.

📈 Easily scale data processing

To scale data processing, create a free account on the lightning.ai platform. With the platform, the optimize and map operators can start multiple machines to make data processing drastically faster, as follows:

from lightning_data import optimize, Machine

optimize(
  ...
  num_nodes=32,
  machine=Machine.DATA_PREP, # You can select between dozens of optimized machines
)

OR

from lightning_data import map, Machine

map(
  ...
  num_nodes=32,
  machine=Machine.DATA_PREP, # You can select between dozens of optimized machines
)

The Data Prep Job UI from the LAION 400M Studio, where we used 32 machines with 32 CPUs each to download 400 million images in only 2 hours.

🔑 Key Features

🚀 Multi-GPU / Multi-Node

The StreamingDataset and StreamingDataLoader take care of everything for you. They automatically make sure each rank receives a different batch of data. There is nothing else for you to do if you use them.
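A minimal sketch of what this looks like in a training script started with a distributed launcher such as torchrun (the path and batch size are placeholders):

import os
from lightning_data import StreamingDataset, StreamingDataLoader

# Each rank automatically streams a different shard of the samples;
# no DistributedSampler or manual rank-based slicing is required.
dataset = StreamingDataset("s3://my-bucket/my_dataset", shuffle=True)
dataloader = StreamingDataLoader(dataset, batch_size=64, num_workers=os.cpu_count())

for batch in dataloader:
    ...  # run your training step; every rank sees different batches

Launched, for example, with torchrun --nproc_per_node=8 train.py.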

🎨 Easy data mixing

You can easily experiment with dataset mixtures using the CombinedStreamingDataset.

from lightning_data import StreamingDataset, CombinedStreamingDataset
from lightning_data.streaming.item_loader import TokensLoader
from tqdm import tqdm
import os
from torch.utils.data import DataLoader

train_datasets = [
    StreamingDataset(
        input_dir="s3://tinyllama-template/slimpajama/train/",
        item_loader=TokensLoader(block_size=2048 + 1), # Optimized loader for tokens used by LLMs 
        shuffle=True,
        drop_last=True,
    ),
    StreamingDataset(
        input_dir="s3://tinyllama-template/starcoder/",
        item_loader=TokensLoader(block_size=2048 + 1), # Optimized loader for tokens used by LLMs 
        shuffle=True,
        drop_last=True,
    ),
]

# Mix SlimPajama data and Starcoder data with these proportions:
weights = (0.693584, 0.306416)
combined_dataset = CombinedStreamingDataset(datasets=train_datasets, seed=42, weights=weights)

train_dataloader = DataLoader(combined_dataset, batch_size=8, pin_memory=True, num_workers=os.cpu_count())

# Iterate over the combined datasets
for batch in tqdm(train_dataloader):
    pass

🔘 Stateful StreamingDataLoader

Lightning Data provides a stateful StreamingDataLoader. This simplifies resuming training over large datasets.

Note: The StreamingDataLoader is used by Lit-GPT to pretrain LLMs. The statefulness still works when using a mixture of datasets with the CombinedStreamingDataset.

import os
import torch
from lightning_data import StreamingDataset, StreamingDataLoader

dataset = StreamingDataset("s3://my-bucket/my-data", shuffle=True)
dataloader = StreamingDataLoader(dataset, num_workers=os.cpu_count(), batch_size=64)

# Restore the dataLoader state if it exists
if os.path.isfile("dataloader_state.pt"):
    state_dict = torch.load("dataloader_state.pt")
    dataloader.load_state_dict(state_dict)

# Iterate over the data
for batch_idx, batch in enumerate(dataloader):
  
    # Store the state every 1000 batches
    if batch_idx % 1000 == 0:
        torch.save(dataloader.state_dict(), "dataloader_state.pt")
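The same pattern works with a dataset mixture; here is a minimal sketch reusing the datasets from the data-mixing example above (the TokensLoader arguments are omitted for brevity, and the weights are placeholders):

import os
import torch
from lightning_data import StreamingDataset, CombinedStreamingDataset, StreamingDataLoader

train_datasets = [
    StreamingDataset("s3://tinyllama-template/slimpajama/train/", shuffle=True, drop_last=True),
    StreamingDataset("s3://tinyllama-template/starcoder/", shuffle=True, drop_last=True),
]
combined_dataset = CombinedStreamingDataset(datasets=train_datasets, seed=42, weights=(0.7, 0.3))
dataloader = StreamingDataLoader(combined_dataset, batch_size=8, num_workers=os.cpu_count())

# Restore the dataloader state if a previous run saved one
if os.path.isfile("dataloader_state.pt"):
    dataloader.load_state_dict(torch.load("dataloader_state.pt"))

for batch_idx, batch in enumerate(dataloader):
    # Store the state every 1000 batches so training can resume mid-epoch
    if batch_idx % 1000 == 0:
        torch.save(dataloader.state_dict(), "dataloader_state.pt")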

🎥 Profiling

The StreamingDataLoader supports profiling your data loading. Simply use the profile_batches argument as follows:

from lightning_data import StreamingDataset, StreamingDataLoader

StreamingDataLoader(..., profile_batches=5)

This generates a Chrome trace called result.json. You can visualize it by opening the chrome://tracing page in the Chrome browser and loading the trace there.

🪇 Random access

Access the data you need when you need it.

from lightning_data import StreamingDataset

dataset = StreamingDataset(...)

print(len(dataset)) # display the length of your data

print(dataset[42]) # show the 42nd element of the dataset

✢ Use data transforms

from lightning_data import StreamingDataset, StreamingDataLoader
import torchvision.transforms.v2.functional as F

class ImagenetStreamingDataset(StreamingDataset):

    def __getitem__(self, index):
        image = super().__getitem__(index)
        return F.resize(image, (224, 224))

dataset = ImagenetStreamingDataset(...)
dataloader = StreamingDataLoader(dataset, batch_size=4)

for batch in dataloader:
    print(batch.shape)
    # Out: (4, 3, 224, 224)

⚙️ Disk usage limits

Limit the size of the cache holding the chunks.

from lightning_data import StreamingDataset

dataset = StreamingDataset(..., max_cache_size="10GB")

💾 Support yield

When processing large files, like compressed parquet files, you can use the Python yield keyword to process and store one item at a time.

from pathlib import Path
import pyarrow.parquet as pq
from lightning_data import optimize
from tokenizer import Tokenizer
from functools import partial

# 1. Define a function to convert the text within the parquet files into tokens
def tokenize_fn(filepath, tokenizer=None):
    parquet_file = pq.ParquetFile(filepath)
    # Process per batch to reduce RAM usage
    for batch in parquet_file.iter_batches(batch_size=8192, columns=["content"]):
        for text in batch.to_pandas()["content"]:
            yield tokenizer.encode(text, bos=False, eos=True)

# 2. Generate the inputs
input_dir = "/teamspace/s3_connections/tinyllama-template"
inputs = [str(file) for file in Path(f"{input_dir}/starcoderdata").rglob("*.parquet")]

# 3. Store the optimized data wherever you want under "/teamspace/datasets" or "/teamspace/s3_connections"
outputs = optimize(
    fn=partial(tokenize_fn, tokenizer=Tokenizer(f"{input_dir}/checkpoints/Llama-2-7b-hf")), # Note: You can use HF tokenizer or any others
    inputs=inputs,
    output_dir="/teamspace/datasets/starcoderdata",
    chunk_size=(2049 * 8012),
)
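The optimized token chunks can then be streamed back with the TokensLoader shown in the data-mixing example; a minimal sketch (the block_size matches the 2048 + 1 context length used there):

from lightning_data import StreamingDataset, StreamingDataLoader
from lightning_data.streaming.item_loader import TokensLoader

dataset = StreamingDataset(
    "/teamspace/datasets/starcoderdata",
    item_loader=TokensLoader(block_size=2048 + 1),  # optimized loader for LLM tokens
    shuffle=True,
)
dataloader = StreamingDataLoader(dataset, batch_size=8)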

⚡ Contributors

We welcome any contributions, pull requests, or issues. If you use the Streaming Dataset for your own project, please reach out to us on Slack or Discord.

