Skip to main content

A toolkit for scaling law research

Project description

chinchilla

chinchilla is a research toolkit designed to estimate scaling laws and train compute-optimal models for various deep learning tasks.

Expected Use Cases:

  • Researching the neural scaling law itself
  • Scaling compute for
    • Large Language Models (LLM)
    • Vision Transformers (ViT)
    • Reinforcement Learning (RL)
    • Embedding Models
    • Knowledge distillation
  • Evaluating compute efficiencies of new algorithms & architectures
  • Researching the neural scaling law itself

Probably Not For:

  • Fine-tuning tasks

  • Data-scarce domains

[!IMPORTANT] This work builds upon the scaling law formulation proposed in the original Chinchilla paper by DeepMind (2022), with some modifications detailed in ./docs/changes.md.

Features

  • Scaling Law Estimation: Fit a loss predictor based on multiple training runs.
  • Compute-Optimal Allocation: Train the best possible model within a given compute budget.
  • Progressive Scaling: Iteratively update the scaling law estimation and scale up the compute.
  • Simulation Mode: Test scaling law estimations in hypothetical scenarios.

Basics

Definitions

  • $N$: The number of parameters
  • $D$: The number of data samples
  • $C$: Total compute in FLOPs ($C\approx 6\ ND$)
  • $L(N,\ D) = E + A/N ^ \alpha + B / D ^ \beta$: A loss predictor parameterized by ${E, A, B, \alpha, \beta}$ and $C$

Compute-Optimal Allocation

  1. Optimize the parameters ${E, A, B, \alpha, \beta}$ to better predict losses $L_i$ from $(N_i, D_i)$
  2. Solve $\underset{N,\ D}{argmin}\ L(N,\ D\ |\ C)$, which can be derived from ${A, B, \alpha, \beta}$

chinchilla Procedure

  • seed: Sample X training runs $(N_i, D_i, L_i)$, referred to as seeds
  • For i = 0 to K:
    • fit: Optimize the scaling law parameters to fit $L(N,\ D)$ on the training runs
    • scale: Configure a new model with a scaled compute
    • Evaluate the allocation by training a model
    • append: Add the result to the database of training runs

Installation

[!WARNING]

chinchilla requires Python >= 3.8

From Source (Recommended for Customization)

git clone https://github.com/kyo-takano/chinchilla.git
cd chinchilla
pip install -e .

From PyPI

pip install -U chinchilla

Usage

Below is an example to get started with chinchilla.

import numpy as np
from chinchilla import Chinchilla

cc = Chinchilla(
    "your_project__dir",
    param_grid=dict(
        E=np.linspace(1.1, 1.5, 5),
        A=np.linspace(200, 1000, 5),
        B=np.linspace(200, 1000, 5),
        alpha=np.linspace(0.1, 0.5, 5),
        beta=np.linspace(0.1, 0.5, 5),
    ),
    seed_ranges=dict(C=(1e15, 1e16), N_to_D=(10, 100)),
    # To search for the model configuration with N closest to suggested:
    model_search_config=dict(
        hyperparam_grid=dict(
            hidden_size=list(range(64, 16384 + 1, 64)),
            num_hidden_layers=list(range(1, 50 + 1)),
            num_heads=list(range(1, 40 + 1)),
        ),
        size_estimator=estimate_model_size,  # You gotta define a function to estimate & return model size also
    ),
    # Parameters you may pre-set
    num_seeding_steps=100,
    scaling_factor=2.0,
)


# Run the scaling law estimation and training process
for i in range(100 + 5):
    # Sample a new model
    (N, D), model_config = cc.step(num_seeding_steps=100)

    # Define a model
    model = YourModelClass(**model_config)

    # Train & evaluate the allocation C => (N, D)
    loss = train_and_evaluate(model, D)

    # Finally, append the training run into the database
    cc.append(N=N, D=D, loss=loss)

Ensure you define functionally equivalent versions of:

  • estimate_model_size: Estimates and returns the model size.
  • YourModelClass: Your model class definition.
  • train_and_evaluate: Function to train and evaluate your model.

Simulation

You can also visualize how chinchilla would perform under the given setup and a hypothetical scaling law, optionally with a noise term:

import random

cc.simulate(
    num_seeding_steps=401,
    num_scaling_steps=1,
    scaling_factor=10.0,
    target_params=dict(
        E=1.69337368,
        A=406.401018,
        B=410.722827,
        alpha=0.33917084,
        beta=0.2849083
    ),
    # Add exponentially distributed loss averaging at 0.1
    noise_generator=(random.expovariate, (10,))
)

Please see API Reference for more.

Examples

Find a practical application of chinchilla in the examples directory (more to come):

Documentation

For a detailed API Reference, tips, differences from the original Chinchilla paper, etc., please browse to ./docs.

Contributing

We welcome your contributions. Please report bugs and suggest improvements through new issues and pull requests.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files available for this release.See tutorial on generating distribution archives.

Built Distribution

chinchilla-0.1.4-py3-none-any.whl (35.7 kB view details)

Uploaded Python 3

File details

Details for the file chinchilla-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: chinchilla-0.1.4-py3-none-any.whl
  • Upload date:
  • Size: 35.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.13

File hashes

Hashes for chinchilla-0.1.4-py3-none-any.whl
Algorithm Hash digest
SHA256 3a1cacd69045892db8f02001da2ecf59d07e0ec78ff554c53ad955e5ee266627
MD5 0c7768550ea1935a864855dc3510bb2b
BLAKE2b-256 64ab85f53c3f4ea14ff7c87878278394c55133405858d6c86a539cc79637b0bc

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page