
System to ease incremental training of a Huggingface transformer model from a large S3-based dataset

Project description

continuing_education

The continuing_education module makes it easy to train large language models (LLMs) with Quantized Low Rank Adapters (QLoRA), focusing on network and storage efficiency and on the ability to continue training from checkpoints. This is particularly useful in environments prone to interruption or where network bandwidth and storage space are limited, e.g. vast.ai or salad.com instances.

Features

  • Efficient Training with QLoRA: Enhance your Hugging Face Transformer models with Low Rank Adapters, optimizing for both performance and memory efficiency.
  • Checkpointing for Continuation: Seamlessly resume training from the last checkpoint, minimizing data reprocessing and model initialization times.
  • AWS Integration: Leverage AWS S3 buckets for dataset storage, ensuring scalable and accessible data management.
  • Flexible Training Schemes: Configure training sessions according to your specific needs, including setting steps per round, maximum steps, and segment-based training.
  • Automatic Tokenization and Dataset Preparation: Utilize integrated tokenization and dataset management for a streamlined setup process.

Installation

pip install continuing-education

Usage

Setting Up Environment Variables

Before using continuing_education, ensure the following environment variables are set:

  • AWS_PROFILE or AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY: For AWS authentication.

Example Usage

from continuing_education import QLoRAContinuingTrainer, logger
import os

# Ensure AWS credentials and bucket names are set
if 'AWS_PROFILE' not in os.environ and not ('AWS_SECRET_ACCESS_KEY' in os.environ and 'AWS_ACCESS_KEY_ID' in os.environ):
    raise EnvironmentError("AWS credentials required.")

# Where to get training data, where to store checkpoints locally, and where to sync them remotely
dataset_bucket_name = os.environ['DATASET_BUCKET']
output_dir = os.environ['OUTPUT_DIR']
checkpoint_sync_uri = os.environ['CHECKPOINT_SYNC_URI']

base_model_id = 'mistralai/Mistral-7B-v0.1'  # Model which will have a trainable LoRA attached to it

# Initialize trainer with environment configurations
continuing_trainer = QLoRAContinuingTrainer(
    base_model_id=base_model_id,
    dataset_bucket_name=dataset_bucket_name,
    checkpoint_sync_uri=checkpoint_sync_uri,
    output_dir=output_dir,
    steps_per_round=10000,  # Example configuration
    max_steps=250000  # Example configuration
)

# Start training
continuing_trainer.train()

Note that unless you supply checkpoint_sync_uri, it is up to you to download the resulting checkpoints yourself. An example script to do this with rsync to a vast.ai instance is included in the examples directory of the source repo: vast_sync.bash
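
If you do supply checkpoint_sync_uri, the checkpoints end up under that S3 prefix and you can retrieve them whenever you like. Below is a minimal sketch (not part of continuing_education) that pulls everything under an assumed s3://my-bucket/checkpoints/seriesname prefix down to a local directory with boto3; the bucket, prefix, and local directory are placeholder values for illustration only.

# Hypothetical helper: download all objects under a checkpoint prefix with boto3.
import os
import boto3

def download_checkpoints(bucket: str, prefix: str, local_dir: str) -> None:
    s3 = boto3.client("s3")
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if key.endswith("/"):
                continue  # skip folder-marker objects
            # Reproduce the S3 key layout under local_dir
            dest = os.path.join(local_dir, os.path.relpath(key, prefix))
            os.makedirs(os.path.dirname(dest), exist_ok=True)
            s3.download_file(bucket, key, dest)

# Example with assumed values (checkpoint_sync_uri = s3://my-bucket/checkpoints/seriesname)
download_checkpoints("my-bucket", "checkpoints/seriesname", "./checkpoints")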

API Reference

QLoRAContinuingTrainer

QLoRAContinuingTrainer is the core class for training models with QLoRA. It extends BaseContinuingTrainer, providing mechanisms to load models, prepare datasets, and manage training sessions effectively.

Initialization Parameters

The QLoRAContinuingTrainer class initializes with the following parameters:

  • base_model_id (str): Identifier for the base model to which LoRA will be attached. This is the ID of a Hugging Face model that will serve as the starting point for training.

  • dataset_bucket_name (str): Name of the S3 bucket containing the training text documents.

  • output_dir (str, optional): Directory where checkpoints will be saved. Defaults to "/root/outputs". It is crucial for managing training interruptions and resumptions.

  • checkpoint_sync_uri (str, optional): URI describing a place to store and retrieve checkpoints and the checkpoint registry JSON (typically something like "s3://my-bucket/checkpoints/seriesname").

  • dataset_id (Union[str, None], optional): Key to an S3 object containing a JSON array of keys for training text objects. This is used when your entire dataset key list fits in a single JSON document listing the keys of all of the training text documents.

  • dataset_series (Union[str, None], optional): Pattern for the S3 keys of JSON objects that each list keys of training text objects. Must include "{segment_number}" if specified. This is useful for large datasets where the key list is split across multiple files. Defaults to None. Exactly one of dataset_id or dataset_series must be specified (see the sketch following this parameter list).

  • test_dataset_id (Union[str, None], optional): Key to an S3 object with a JSON array of keys for evaluation text objects.

  • steps_per_round (Union[int, None], optional): Maximum number of training steps per call to train(). This can be used to limit the training duration per execution, e.g. for quick turnaround while debugging or experimenting.

  • max_seq_length (int, optional): Maximum token count per training example. This parameter is critical for memory management and ensuring the model can handle the inputs without exceeding GPU memory limits. Defaults to 2048.

  • max_steps (Union[int, None]): Explicit maximum training steps. This parameter sets a hard limit on the number of training steps, providing a way to precisely control the training duration. It's necessary to specify this if not using estimated_max_steps(). Defaults to None.

  • save_steps (int, optional): Interval of training steps after which a checkpoint is automatically saved. Regular checkpoints are crucial for resuming training without losing progress. Defaults to 1000.
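
As a concrete illustration of the dataset_series option, here is a hedged sketch of a trainer configured with a segmented key list. The bucket name, key pattern, URIs, and step counts below are made-up values for illustration, not defaults of the library.

from continuing_education import QLoRAContinuingTrainer

# Hypothetical setup: the key-list documents live at
# my_datasets/pretraining_test_1.json.gz, my_datasets/pretraining_test_2.json.gz, ...
trainer = QLoRAContinuingTrainer(
    base_model_id='mistralai/Mistral-7B-v0.1',
    dataset_bucket_name='my-dataset-bucket',   # assumed bucket name
    dataset_series='my_datasets/pretraining_test_{segment_number}.json.gz',
    checkpoint_sync_uri='s3://my-bucket/checkpoints/seriesname',
    output_dir='/root/outputs',
    steps_per_round=1000,   # small round for quick debugging turnaround
    max_steps=250000,
    save_steps=1000,
)
trainer.train()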

Methods

  • train(): Starts the training process, utilizing checkpoints and dataset management to efficiently continue training sessions.

Turnkey solution for training a QLoRA on vast.ai

  1. Place a large number of text documents in an S3 bucket
  2. Create a DATASET_SERIES in your bucket (detailed instructions in the next section)
  3. Create an IAM user that can read your dataset S3 bucket and read/write your checkpoints bucket (they can be the same bucket if you like), and generate "Access Keys" for that user
  4. Login to your vast.ai account
  5. Create a new vast.ai template
    1. Choose the Image Path stevemadere/vast_train_llm:qlora
    2. Set your Docker Options to this (editing as necessary)
    -e NOTEBOOK_DIR=/root/notebooks
    -e HUGGINGFACE_TOKEN=REPLACE_WITH_YOUR_HUGGINGFACE_API_TOKEN
    -e HF_MODEL_NAME=Mistral-7B-v0.1
    -e HF_CONTRIBUTOR=mistralai
    -e HF_MODEL_REVISION=main
    -e DATASET_BUCKET=THE_NAME_OF_YOUR_S3_BUCKET_CONTAINING_THE_TEXT_DOCUMENTS
    -e DATASET_SERIES=THE_KEY_PATTERN_FOR_YOUR_DATASET_SERIES
    -e CHECKPOINT_SYNC_URI=THE_S3_URI_WHERE_YOU_WANT_TO_SAVE_MODEL_CHECKPOINTS
    -e DATA_DIR=/root/huggingface
    -e AWS_ACCESS_KEY_ID=YOUR_IAM_CREDENTIALS
    -e AWS_SECRET_ACCESS_KEY=YOUR_IAM_CREDENTIALS
    -e OUTPUT_DIR=/root/outputs
    -e STEPS_PER_ROUND=10000
    -e SHOULD_DOWNLOAD_MODEL=YES
    -e SHOULD_START_TRAINING=YES
    
    3. Select "Run interactive shell server, SSH" and "Use direct SSH connection"
    4. Empty the "On-start Script" field, as there is already an onstart.sh script in the docker image
    5. Fill in a template name and description
    6. Press [SELECT AND SAVE]
  6. Go to the Search tab and find a host with an RTX-4090 and high bandwidth (>500Mbps is best). (Watch out for excessive bandwidth rates: some unscrupulous hosts try to pull a fast one with rates exceeding $20/TB, whereas most charge less than $3/TB.)
  7. RENT
  8. Check that everything is working by switching to the Instances tab and pressing the [ >_ CONNECT ] button, which will give you an ssh command to copy and paste into your local shell.
  • It may take several minutes for the base model to finish downloading to your instance (about 5-7 minutes if the available bandwidth is around 500 Mbps). If you suspect that the base model download failed, you can examine the system logs from the instance card.
  • Once the base model finishes downloading, the training process should create a pair of log files, /root/continuing_trainer.info.log and /root/continuing_trainer.debug.log. You can examine either of those to see what kind of progress the trainer is making.

Creating a DATASET_SERIES

While the ContinuingTrainer can flexibly handle multiple methods of specifying the set of documents in your bucket to be used as a training corpus, the example script used in the pre-built docker image only allows for the DATASET_SERIES pattern method.

This is how you create a DATASET_SERIES (a Python sketch automating these steps follows the numbered list):

  1. Decide on a name pattern such as my_datasets/pretraining_test_{segment_number}.json.gz
  2. Decide how many segments you want to have in your series. One is fine.
  3. Make a list of all of the S3 keys of all of the text documents in your dataset bucket that you want to use for your training dataset.
  4. Split that list into as many segments as you chose in step 2.
  5. Create gzip-compressed JSON files, one per segment, each containing the list of document keys for that segment, and name each file to match the pattern you chose in step 1, e.g. my_datasets/pretraining_test_1.json.gz
  6. Upload those files to your S3 bucket with keys matching the filenames (aws s3 sync works well for this)
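
The steps above can be scripted. The sketch below is an illustration only (none of these names come from continuing_education): it lists the text-document keys under an assumed documents/ prefix, splits them into segments, writes gzip-compressed JSON key lists named after the pattern from step 1, and uploads them back to the bucket with boto3. Adjust the bucket name, prefixes, and segment count to your own setup.

# Hypothetical script to build a DATASET_SERIES; bucket, prefix, and
# segment count are assumptions for illustration.
import gzip
import json
import boto3

bucket = 'my-dataset-bucket'      # assumed dataset bucket
doc_prefix = 'documents/'         # assumed prefix of the training text documents
pattern = 'my_datasets/pretraining_test_{segment_number}.json.gz'
num_segments = 4                  # however many segments you chose in step 2

s3 = boto3.client('s3')

# Step 3: list all text-document keys in the bucket
keys = []
paginator = s3.get_paginator('list_objects_v2')
for page in paginator.paginate(Bucket=bucket, Prefix=doc_prefix):
    keys.extend(obj['Key'] for obj in page.get('Contents', []))

# Steps 4-6: split the key list, write each segment as gzipped JSON, upload it
for i in range(num_segments):
    segment_keys = keys[i::num_segments]
    segment_object_key = pattern.format(segment_number=i + 1)
    body = gzip.compress(json.dumps(segment_keys).encode('utf-8'))
    s3.put_object(Bucket=bucket, Key=segment_object_key, Body=body)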

If you encounter problems, feel free to reach out to me on LinkedIn.

Contributing

Contributions to continuing_education are welcome! Feel free to open an issue if you encounter problems or want to suggest features.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

continuing_education-0.0.2a3.tar.gz (28.3 kB)

Uploaded Source

Built Distribution

continuing_education-0.0.2a3-py3-none-any.whl (23.2 kB)

Uploaded Python 3

File details

Details for the file continuing_education-0.0.2a3.tar.gz.

File metadata

File hashes

Hashes for continuing_education-0.0.2a3.tar.gz
Algorithm Hash digest
SHA256 2eb750288d1fcf0ef2dcc8f97a98cbc553a3db4e18b4be9d257b21c0ad83c20b
MD5 84e5eeb15f850635b656fc827eb48170
BLAKE2b-256 01bc3ddc1d6592deaa87dfdca0584b32477aa125522186ea9e7015fb5ef14b20


File details

Details for the file continuing_education-0.0.2a3-py3-none-any.whl.

File metadata

File hashes

Hashes for continuing_education-0.0.2a3-py3-none-any.whl
Algorithm Hash digest
SHA256 a2821a062d2e784ffe95d065ebdbe77f71067f725356651ee5e1c239de8b22e5
MD5 2e39e49cf02aad29d29678762ec407da
BLAKE2b-256 1e8ea6749735cff96282cd8d5860f9a031c2e8b320a7cfcb8c3a196adaf9e050

