System to ease incremental training of a Hugging Face transformer model from a large S3-based dataset

Project description

continuing_education

The continuing_education module makes it easy to train large language models (LLMs) with Quantized Low Rank Adapters (QLoRA), focusing on network and storage efficiency and the ability to continue training from checkpoints. This is particularly useful in environments prone to interruption or where network bandwidth and storage space are limited, e.g. vast.ai or salad.com instances.

Features

  • Efficient Training with QLoRA: Enhance your Hugging Face Transformer models with Low Rank Adapters, optimizing for both performance and memory efficiency.
  • Checkpointing for Continuation: Seamlessly resume training from the last checkpoint, minimizing data reprocessing and model initialization times.
  • AWS Integration: Leverage AWS S3 buckets for dataset storage, ensuring scalable and accessible data management.
  • Flexible Training Schemes: Configure training sessions according to your specific needs, including setting steps per round, maximum steps, and segment-based training.
  • Automatic Tokenization and Dataset Preparation: Utilize integrated tokenization and dataset management for a streamlined setup process.

Installation

pip install continuing-education

Usage

Setting Up Environment Variables

Before using continuing_education, ensure the following environment variables are set:

  • AWS_PROFILE or AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY: For AWS authentication.
  • DATASET_BUCKET and OUTPUT_DIR: Used by the example below for the S3 dataset bucket name and the local checkpoint output directory.

Example Usage

from continuing_education import QLoRAContinuingTrainer, logger
import os

# Ensure AWS credentials and bucket names are set
if 'AWS_PROFILE' not in os.environ and not ('AWS_SECRET_ACCESS_KEY' in os.environ and 'AWS_ACCESS_KEY_ID' in os.environ):
    raise EnvironmentError("AWS credentials required.")

dataset_bucket_name = os.environ['DATASET_BUCKET']
output_dir = os.environ['OUTPUT_DIR']

base_model_id = 'mistralai/Mistral-7B-v0.1'  # Default model

# Initialize trainer with environment configurations
continuing_trainer = QLoRAContinuingTrainer(
    base_model_id=base_model_id,
    dataset_bucket_name=dataset_bucket_name,
    output_dir=output_dir,
    steps_per_round=10000,  # Example configuration
    max_steps=250000  # Example configuration
)

# Start training
continuing_trainer.train()

Note that it is up to you to download the resulting checkpoints. An example script that does this with rsync and a vast.ai instance is included in the examples directory of the source repository: vast_sync.bash
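
If you would rather copy checkpoints from Python instead of the provided rsync script, here is a minimal sketch. It assumes the usual Hugging Face Trainer "checkpoint-<step>" subdirectory naming under output_dir; the destination path is illustrative and not part of the package.

import os
import shutil

# Minimal sketch: copy the newest checkpoint directory out of OUTPUT_DIR.
# Assumes the usual Hugging Face Trainer "checkpoint-<step>" naming;
# the destination path below is illustrative, not part of the package.
output_dir = os.environ['OUTPUT_DIR']
checkpoints = [d for d in os.listdir(output_dir) if d.startswith('checkpoint-')]
if checkpoints:
    latest = max(checkpoints, key=lambda d: int(d.split('-')[-1]))
    shutil.copytree(os.path.join(output_dir, latest),
                    os.path.join('/mnt/durable', latest),  # illustrative destination
                    dirs_exist_ok=True)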

API Reference

QLoRAContinuingTrainer

QLoRAContinuingTrainer is the core class for training models with QLoRA. It extends BaseContinuingTrainer, providing mechanisms to load models, prepare datasets, and manage training sessions effectively.

The parameter descriptions below are drawn directly from the class docstrings.

Initialization Parameters

The QLoRAContinuingTrainer class initializes with the following parameters (a configuration sketch follows the list):

  • base_model_id (str): Identifier for the base model to which LoRA will be attached. This is the ID of a Hugging Face model that will serve as the starting point for training.

  • dataset_bucket_name (str): Name of the S3 bucket containing the training text documents.

  • output_dir (str, optional): Directory where checkpoints will be saved. Defaults to "/root/outputs". It is crucial for managing training interruptions and resumptions.

  • dataset_id (Union[str, None], optional): Key to an S3 object containing a JSON array of keys for training text objects. This is used if your dataset is described by a single JSON document listing the keys of all of the training text documents.

  • dataset_series (Union[str, None], optional): Pattern for S3 keys of JSON objects specifying keys of objects for training. Must include "{segment_number}" if specified. This is useful for large datasets where the key list is split across multiple files. Defaults to None. Exactly one of dataset_id or dataset_series must be specified.

  • test_dataset_id (Union[str, None], optional): Key to an S3 object with a JSON array of keys for evaluation text objects.

  • steps_per_round (Union[int, None], optional): Maximum number of training steps per call to train(). This can be used to limit the training duration per execution, e.g. for quick turnaround while debugging or experimenting.

  • max_seq_length (int, optional): Maximum token count per training example. This parameter is critical for memory management and ensuring the model can handle the inputs without exceeding GPU memory limits. Defaults to 2048.

  • max_steps (Union[int, None]): Explicit maximum training steps. This parameter sets a hard limit on the number of training steps, providing a way to precisely control the training duration. It's necessary to specify this if not using estimated_max_steps(). Defaults to None.

  • save_steps (int, optional): Interval of training steps after which a checkpoint is automatically saved. Regular checkpoints are crucial for resuming training without losing progress. Defaults to 1000.
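
For illustration, a dataset_series-style configuration might look like the sketch below. The bucket name, key pattern, and numeric values are hypothetical, not package defaults; each segment object is presumably a JSON array of S3 keys (e.g. ["texts/0001.txt", "texts/0002.txt"]), mirroring the dataset_id format.

from continuing_education import QLoRAContinuingTrainer

# Hypothetical S3 layout (illustrative only):
#   keylists/segment_{segment_number}.json -> one JSON array of training text keys per segment
#   keylists/test.json                     -> JSON array of evaluation text keys
trainer = QLoRAContinuingTrainer(
    base_model_id='mistralai/Mistral-7B-v0.1',
    dataset_bucket_name='my-dataset-bucket',
    dataset_series='keylists/segment_{segment_number}.json',  # must contain "{segment_number}"
    test_dataset_id='keylists/test.json',
    output_dir='/root/outputs',
    max_seq_length=2048,
    steps_per_round=10000,
    save_steps=1000,
    max_steps=250000,
)
trainer.train()

Since exactly one of dataset_id or dataset_series must be specified, a single-key-list setup would simply replace the dataset_series argument with dataset_id.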

Methods

  • train(): Starts the training process, utilizing checkpoints and dataset management to efficiently continue training sessions.

Contributing

Contributions to continuing_education are welcome! Feel free to open an issue if you encounter problems or want to suggest features.



Download files

Download the file for your platform.

Source Distribution

continuing_education-0.0.1a4.tar.gz (19.8 kB)

Uploaded Source

Built Distribution

continuing_education-0.0.1a4-py3-none-any.whl (18.7 kB)

Uploaded Python 3

File details

Details for the file continuing_education-0.0.1a4.tar.gz.

File metadata

File hashes

Hashes for continuing_education-0.0.1a4.tar.gz
SHA256: 7715e41383b5e23ea097d301f2b7bfdebb64ed89310771e0fbcd08bafc7d67d4
MD5: 24f3018b60788c89ca4a99e97efdcfef
BLAKE2b-256: 700058bc0801df607956628c1af5a74baa956e176b503417c3c50757ac96c93c


File details

Details for the file continuing_education-0.0.1a4-py3-none-any.whl.

File metadata

File hashes

Hashes for continuing_education-0.0.1a4-py3-none-any.whl
SHA256: e9175d1559fa0fa59f17ce09fa85c075d2615800932c5b890b198ed5abe2097c
MD5: 7c48543472701bfd9a9d2be7c9dbcd42
BLAKE2b-256: 61a720d67c8815bec86a7be430f337a8e8edd67ec7b105650bda37662a03786c

