Snowflake LLM training library

License: Apache 2.0

ArcticTraining: Simplifying and Accelerating Post-Training for LLMs

ArcticTraining is a framework designed to simplify and accelerate the post-training process for large language models (LLMs). It addresses challenges in current frameworks, such as limited support for rapid prototyping and the lack of native data generation tools, by offering modular trainer designs, simplified code structures, and integrated pipelines for creating and cleaning synthetic data. These features enable users to enhance LLM capabilities, like code generation and complex reasoning, with greater efficiency and flexibility. Read more about ArcticTraining in our blog.

Projects

The projects folder contains various special projects we have released that build on top of ArcticTraining. Each project includes its own README and associated assets to help you get started.

Quickstart

To get started training a model with ArcticTraining, follow the steps below:

  1. Install the ArcticTraining package and its dependencies:
pip install arctic-training
  2. Create a training recipe YAML that uses the built-in Supervised Fine-Tuning (SFT) trainer:
type: sft
micro_batch_size: 2
model:
  name_or_path: meta-llama/Llama-3.1-8B-Instruct
data:
  sources:
    - HuggingFaceH4/ultrachat_200k
checkpoint:
  - type: huggingface
    save_end_of_training: true
    output_dir: ./fine-tuned-model
  3. Run the training recipe with the ArcticTraining CLI, as shown below. The CLI uses the DeepSpeed launcher behind the scenes, so you can pass any compatible DeepSpeed launcher arguments to it (e.g., --num_nodes, --num_gpus).
arctic_training path/to/sft-recipe.yaml
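
For scripted launches, the same command can be assembled in Python. The following is a minimal sketch that only builds the argument list. The helper name `build_launch_command` is hypothetical, and placing the launcher flags after the recipe path is an assumption rather than documented behavior, so verify the ordering against the CLI documentation before relying on it.

```python
import shlex
from typing import List, Optional


def build_launch_command(
    recipe_path: str,
    num_gpus: Optional[int] = None,
    num_nodes: Optional[int] = None,
) -> List[str]:
    """Assemble an ArcticTraining CLI invocation as an argument list.

    Assumption: DeepSpeed launcher flags are appended after the recipe path.
    """
    cmd = ["arctic_training", recipe_path]
    if num_nodes is not None:
        cmd += ["--num_nodes", str(num_nodes)]
    if num_gpus is not None:
        cmd += ["--num_gpus", str(num_gpus)]
    return cmd


cmd = build_launch_command("path/to/sft-recipe.yaml", num_gpus=4)
print(shlex.join(cmd))
# To actually start training (requires arctic-training and GPUs):
#   import subprocess; subprocess.run(cmd, check=True)
```

Building a list instead of a single string avoids shell-quoting problems when the recipe path contains spaces.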

Customize Training

To customize the training workflow, modify the training recipe YAML created in step 2 above. For example, you can change the model, dataset, checkpoint, or other settings to meet your specific requirements. A full list of configuration options can be found on the configuration documentation page.
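
When several recipe variants are needed, the recipe text can be generated rather than edited by hand. The sketch below is illustrative only: `render_sft_recipe` is a hypothetical helper, not part of ArcticTraining, and it emits only the fields that appear in the Quickstart recipe.

```python
from textwrap import dedent


def render_sft_recipe(
    model: str,
    dataset: str,
    output_dir: str,
    micro_batch_size: int = 2,
) -> str:
    """Return recipe YAML text using only the fields shown in the Quickstart."""
    return dedent(
        f"""\
        type: sft
        micro_batch_size: {micro_batch_size}
        model:
          name_or_path: {model}
        data:
          sources:
            - {dataset}
        checkpoint:
          - type: huggingface
            save_end_of_training: true
            output_dir: {output_dir}
        """
    )


# Same recipe as the Quickstart, but with a smaller micro batch size.
recipe = render_sft_recipe(
    model="meta-llama/Llama-3.1-8B-Instruct",
    dataset="HuggingFaceH4/ultrachat_200k",
    output_dir="./fine-tuned-model",
    micro_batch_size=1,
)
print(recipe)
```

The returned string can be written to a file and passed to the CLI like any hand-written recipe.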

Creating a New Trainer

If you want to create a new trainer, you can do so by subclassing the Trainer or SFTTrainer classes and implementing the necessary modifications. For example, you could create a new trainer from SFTTrainer that uses a different loss function:

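from arctic_training import SFTTrainer

class CustomTrainer(SFTTrainer):
    # The name used in the `type` field of the training recipe
    name = "my_custom_trainer"

    def loss(self, batch):
        # Custom loss function implementation goes here. As a placeholder,
        # this starts from the parent SFT loss; replace or modify as needed.
        loss = super().loss(batch)
        return loss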

This new trainer will be automatically registered with ArcticTraining when the script containing the declaration of CustomTrainer is imported. By default, ArcticTraining looks for a train.py in the current working directory to find custom trainers. You can also specify a custom path to the trainers with the code field in your training recipe:

type: my_custom_trainer
code: path/to/custom_trainers.py
model:
  name_or_path: meta-llama/Llama-3.1-8B-Instruct
data:
  sources:
    - HuggingFaceH4/ultrachat_200k
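
To illustrate how a trainer can become available by name as soon as its module is imported, here is a generic sketch of the subclass-registry pattern. It is not ArcticTraining's actual implementation: `BaseTrainer`, `SFTLikeTrainer`, `TRAINER_REGISTRY`, and `get_trainer_class` are stand-ins invented for this example.

```python
from typing import Dict, Type

# Maps a trainer's `name` to its class (hypothetical registry).
TRAINER_REGISTRY: Dict[str, Type["BaseTrainer"]] = {}


class BaseTrainer:
    """Stand-in base class; not the real arctic_training.Trainer."""

    name: str = ""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register only classes that declare their own `name`, so an
        # inherited name never overwrites the parent's registration.
        if "name" in cls.__dict__:
            TRAINER_REGISTRY[cls.name] = cls


class SFTLikeTrainer(BaseTrainer):
    name = "sft"


class CustomTrainer(SFTLikeTrainer):
    name = "my_custom_trainer"


def get_trainer_class(type_name: str) -> Type[BaseTrainer]:
    """Resolve the recipe's `type` value to a registered trainer class."""
    return TRAINER_REGISTRY[type_name]


print(sorted(TRAINER_REGISTRY))  # ['my_custom_trainer', 'sft']
```

Because registration runs when the class body is executed, importing the file that defines `CustomTrainer` is enough to make `type: my_custom_trainer` resolvable in this sketch.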

CI funding

Modal is kindly supporting our GPU CI runs by funding the hardware for us. Modal is an AI infrastructure platform for inference, fine-tuning, batch jobs, and more. Get started with $30/mo in free credits today at https://modal.com. We have received amazing support from Modal's team and would gladly recommend them for your business.
