System to ease incremental training of a Hugging Face transformer model from a large S3-based dataset
continuing_education
The continuing_education module makes it easy to train large language models (LLMs) with Quantized Low-Rank Adapters (QLoRA), with a focus on network and storage efficiency and on the ability to continue training from checkpoints. This is particularly useful in environments that are prone to interruptions or where network bandwidth and storage space are limited, such as vast.ai or salad.com instances.
Features
- Efficient Training with QLoRA: Enhance your Hugging Face Transformer models with Low Rank Adapters, optimizing for both performance and memory efficiency.
- Checkpointing for Continuation: Seamlessly resume training from the last checkpoint, minimizing data reprocessing and model initialization times.
- AWS Integration: Leverage AWS S3 buckets for dataset storage, ensuring scalable and accessible data management.
- Flexible Training Schemes: Configure training sessions according to your specific needs, including setting steps per round, maximum steps, and segment-based training.
- Automatic Tokenization and Dataset Preparation: Utilize integrated tokenization and dataset management for a streamlined setup process.
Installation
pip install continuing-education
Usage
Setting Up Environment Variables
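Before using continuing_education, ensure the following environment variables are set:
- AWS_PROFILE, or both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY: for AWS authentication.
The example below also reads its dataset, output, and checkpoint locations from environment variables (see its os.environ lookups).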
Example Usage
from continuing_education import QLoRAContinuingTrainer, logger
import os

# Ensure AWS credentials are available
if 'AWS_PROFILE' not in os.environ and not ('AWS_SECRET_ACCESS_KEY' in os.environ and 'AWS_ACCESS_KEY_ID' in os.environ):
    raise EnvironmentError("AWS credentials required.")

# Where to get training data, where to store checkpoints locally, and where to sync them remotely
dataset_bucket_name = os.environ['DATASET_BUCKET']
dataset_series = os.environ['DATASET_SERIES']  # e.g. "my_datasets/pretraining_test_{segment_number}.json.gz"
output_dir = os.environ['OUTPUT_DIR']
checkpoint_sync_uri = os.environ['CHECKPOINT_SYNC_URI']
base_model_id = 'mistralai/Mistral-7B-v0.1'  # Model which will have a trainable LoRA attached to it

# Initialize trainer with environment configurations
continuing_trainer = QLoRAContinuingTrainer(
    base_model_id=base_model_id,
    dataset_bucket_name=dataset_bucket_name,
    dataset_series=dataset_series,  # Exactly one of dataset_id or dataset_series is required
    checkpoint_sync_uri=checkpoint_sync_uri,
    output_dir=output_dir,
    steps_per_round=10000,  # Example configuration
    max_steps=250000,  # Example configuration
)

# Start training
continuing_trainer.train()
Note that it is up to you to download the resulting checkpoints unless you supply checkpoint_sync_uri. An example script that does this with rsync from a vast.ai instance, vast_sync.bash, is included in the examples directory of the source repo.
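The example configuration caps each call to train() at 10,000 steps, with a hard limit of 250,000 steps overall, so reaching max_steps takes repeated calls (for instance, one per restart of an interruptible instance). The back-of-the-envelope sketch below assumes each call completes a full round and that save_steps keeps its documented default of 1000; the helper names are hypothetical and not part of continuing_education.

```python
def rounds_needed(max_steps: int, steps_per_round: int, completed_steps: int = 0) -> int:
    """Hypothetical helper: train() calls still needed, assuming each call runs a full round."""
    remaining = max(0, max_steps - completed_steps)
    # Integer ceiling division, avoiding floating point
    return -(-remaining // steps_per_round)


def checkpoints_per_round(steps_per_round: int, save_steps: int = 1000) -> int:
    """Hypothetical helper: checkpoints written during one full round, one every save_steps steps."""
    return steps_per_round // save_steps


print(rounds_needed(250_000, 10_000))           # 25
print(rounds_needed(250_000, 10_000, 123_000))  # 13
print(checkpoints_per_round(10_000))            # 10
```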
API Reference
QLoRAContinuingTrainer
QLoRAContinuingTrainer is the core class for training models with QLoRA. It extends BaseContinuingTrainer, providing mechanisms to load models, prepare datasets, and manage training sessions.
Initialization Parameters
The QLoRAContinuingTrainer class accepts the following initialization parameters:
- base_model_id (str): Identifier of the base model to which the LoRA will be attached. This is the ID of a Hugging Face model that serves as the starting point for training.
- dataset_bucket_name (str): Name of the S3 bucket containing the training text documents.
- output_dir (str, optional): Local directory where checkpoints are saved. Defaults to "/root/outputs". Checkpoints stored here are what allow training to resume after an interruption.
- checkpoint_sync_uri (str, optional): URI of a location for storing and retrieving checkpoints and the checkpoint registry JSON (typically something like "s3://my-bucket/checkpoints/seriesname").
- dataset_id (Union[str, None], optional): Key of an S3 object containing a JSON array of keys for the training text objects. Use this if your dataset is described by a single JSON document listing the keys of all the training text documents.
- dataset_series (Union[str, None], optional): Pattern for the S3 keys of JSON objects that list the keys of the training objects. Must include "{segment_number}" if specified. This is useful for large datasets whose key list is split across multiple files. Defaults to None. Exactly one of dataset_id or dataset_series must be specified.
- test_dataset_id (Union[str, None], optional): Key of an S3 object containing a JSON array of keys for evaluation text objects.
- steps_per_round (Union[int, None], optional): Maximum number of training steps per call to train(). Use it to limit how long each run trains, for example to get quick turnaround while debugging or experimenting.
- max_seq_length (int, optional): Maximum token count per training example. This is critical for memory management, since it keeps inputs small enough to fit within GPU memory. Defaults to 2048.
- max_steps (Union[int, None], optional): Explicit maximum number of training steps. This sets a hard limit on training duration. It must be specified if you are not using estimated_max_steps(). Defaults to None.
- save_steps (int, optional): Number of training steps between automatic checkpoint saves. Regular checkpoints let you resume training without losing progress. Defaults to 1000.
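The rules for dataset_id and dataset_series above can be expressed as a small check. This is a minimal sketch of the documented constraints, not the library's own validation code; the function names are hypothetical, and expanding the pattern with Python's str.format is an assumption about how "{segment_number}" gets substituted.

```python
from typing import List, Optional


def check_dataset_args(dataset_id: Optional[str], dataset_series: Optional[str]) -> None:
    """Hypothetical check: exactly one of dataset_id / dataset_series, and a valid pattern."""
    if (dataset_id is None) == (dataset_series is None):
        raise ValueError("Specify exactly one of dataset_id or dataset_series")
    if dataset_series is not None and "{segment_number}" not in dataset_series:
        raise ValueError('dataset_series must include "{segment_number}"')


def series_keys(dataset_series: str, segment_numbers: List[int]) -> List[str]:
    """Expand a dataset_series pattern into concrete S3 keys (assumes str.format substitution)."""
    return [dataset_series.format(segment_number=n) for n in segment_numbers]


pattern = "my_datasets/pretraining_test_{segment_number}.json.gz"
check_dataset_args(None, pattern)  # passes: only dataset_series is given
print(series_keys(pattern, [1, 2]))
# ['my_datasets/pretraining_test_1.json.gz', 'my_datasets/pretraining_test_2.json.gz']
```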
Methods
- train(): Starts the training process, utilizing checkpoints and dataset management to efficiently continue training sessions.
Turnkey solution for training a QLoRA on vast.ai
- Place a large number of text documents in an S3 bucket
- Create a DATASET_SERIES in your bucket (detailed instructions in the next section)
- Create an IAM user that can read your dataset S3 bucket and read/write your checkpoints bucket (these can be the same bucket if you like), and generate access keys for that IAM user
- Log in to your vast.ai account
- Create a new vast.ai template
- Choose the Image Path stevemadere/vast_train_llm:qlora
- Set your Docker Options to this (editing as necessary)
-e NOTEBOOK_DIR=/root/notebooks -e HUGGINGFACE_TOKEN=REPLACE_WITH_YOUR_HUGGINGFACE_API_TOKEN -e HF_MODEL_NAME=Mistral-7B-v0.1 -e HF_CONTRIBUTOR=mistralai -e HF_MODEL_REVISION=main -e DATASET_BUCKET=THE_NAME_OF_YOUR_S3_BUCKET_CONTAINING_THE_TEXT_DOCUMENTS -e DATASET_SERIES=THE_KEY_PATTERN_FOR_YOUR_DATASET_SERIES -e CHECKPOINT_SYNC_URI=THE_S3_URI_WHERE_YOU_WANT_TO_SAVE_MODEL_CHECKPOINTS -e DATA_DIR=/root/huggingface -e AWS_ACCESS_KEY_ID=YOUR_IAM_CREDENTIALS -e AWS_SECRET_ACCESS_KEY=YOUR_IAM_CREDENTIALS -e OUTPUT_DIR=/root/outputs -e STEPS_PER_ROUND=10000 -e SHOULD_DOWNLOAD_MODEL=YES -e SHOULD_START_TRAINING=YES
- Select "Run interactive shell server, SSH" and "Use direct SSH connection"
- Clear the "On-start Script" field, since the Docker image already contains an onstart.sh script
- Fill in a template name and description
- press [SELECT AND SAVE]
- Go to the Search tab and find a host with an RTX 4090 and high bandwidth (>500 Mbps is best). Watch out for excessive bandwidth rates: some unscrupulous hosts charge more than $20/TB, whereas most charge less than $3/TB.
- Press [RENT]
- Check that everything is working by switching to the Instances tab and pressing the [ >_ CONNECT ] button, which gives you an ssh command to copy and paste into your local shell.
- It may take several minutes for the base model to finish downloading to your instance (about 5-7 minutes if the available bandwidth is around 500 Mbps). If you suspect that the base model download failed, you can examine the system logs from the instance card.
- Once the base model finishes downloading, the training process should create a pair of log files, continuing_trainer.info.log and continuing_trainer.debug.log, in /root. Examine either one to see how the trainer is progressing.
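The Docker options above pass the trainer's configuration as environment variables. As a rough illustration of how those variables line up with the QLoRAContinuingTrainer parameters, here is a hypothetical mapping. It is not the onstart.sh or training script actually shipped in stevemadere/vast_train_llm:qlora, and the function name trainer_kwargs_from_env is made up for this sketch.

```python
from typing import Dict, Mapping, Union


def trainer_kwargs_from_env(env: Mapping[str, str]) -> Dict[str, Union[str, int]]:
    """Hypothetical mapping from the template's -e variables to trainer parameters."""
    kwargs: Dict[str, Union[str, int]] = {
        # HF_CONTRIBUTOR and HF_MODEL_NAME together form the Hugging Face model ID
        "base_model_id": f"{env['HF_CONTRIBUTOR']}/{env['HF_MODEL_NAME']}",
        "dataset_bucket_name": env["DATASET_BUCKET"],
        "dataset_series": env["DATASET_SERIES"],
        "checkpoint_sync_uri": env["CHECKPOINT_SYNC_URI"],
        # Fall back to the documented default for output_dir
        "output_dir": env.get("OUTPUT_DIR", "/root/outputs"),
    }
    if "STEPS_PER_ROUND" in env:
        # Environment variables are strings; the trainer expects an int
        kwargs["steps_per_round"] = int(env["STEPS_PER_ROUND"])
    return kwargs


example_env = {
    "HF_CONTRIBUTOR": "mistralai",
    "HF_MODEL_NAME": "Mistral-7B-v0.1",
    "DATASET_BUCKET": "my-dataset-bucket",
    "DATASET_SERIES": "my_datasets/pretraining_test_{segment_number}.json.gz",
    "CHECKPOINT_SYNC_URI": "s3://my-bucket/checkpoints/seriesname",
    "STEPS_PER_ROUND": "10000",
}
print(trainer_kwargs_from_env(example_env)["base_model_id"])  # mistralai/Mistral-7B-v0.1
```

With the library installed, the result could be passed as QLoRAContinuingTrainer(**trainer_kwargs_from_env(os.environ), max_steps=250000). The template sets no max_steps variable, so this sketch supplies it explicitly, following the earlier example configuration.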
Creating a DATASET_SERIES
While the ContinuingTrainer can flexibly handle several methods of specifying which documents in your bucket make up the training corpus, the example script in the pre-built Docker image supports only the DATASET_SERIES pattern method.
This is how you create a DATASET_SERIES:
1. Decide on a name pattern, such as my_datasets/pretraining_test_{segment_number}.json.gz
2. Decide how many segments you want in your series. One is fine.
3. Make a list of the S3 keys of all the text documents in your dataset bucket that you want to include in your training dataset.
4. Split that list into as many segments as you chose in step 2.
5. For each segment, create a compressed JSON file containing that segment's list of document keys, named according to the pattern you chose in step 1, e.g. my_datasets/pretraining_test_1.json.gz
6. Upload those files to your S3 bucket with keys matching the filenames (aws s3 sync works well for this).
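Steps 4 and 5 can be scripted with the Python standard library. This is a minimal sketch that assumes you already have the document keys in a Python list (for example, collected from an S3 listing) and that segments are numbered from 1, as in the example name my_datasets/pretraining_test_1.json.gz; the function names are hypothetical.

```python
import gzip
import json
import os
import tempfile
from typing import List, Sequence


def write_dataset_series(
    keys: Sequence[str], pattern: str, num_segments: int, out_dir: str = "."
) -> List[str]:
    """Split document keys into segments and write each one as a gzipped JSON array.

    Assumption: segments are numbered starting at 1.
    """
    if "{segment_number}" not in pattern:
        raise ValueError('pattern must contain "{segment_number}"')
    if num_segments < 1:
        raise ValueError("num_segments must be at least 1")

    n = len(keys)
    paths = []
    for i in range(num_segments):
        # Balanced split: segment sizes differ by at most one key
        segment = list(keys[i * n // num_segments:(i + 1) * n // num_segments])
        path = os.path.join(out_dir, pattern.format(segment_number=i + 1))
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        with gzip.open(path, "wt", encoding="utf-8") as f:
            json.dump(segment, f)
        paths.append(path)
    return paths


def read_segment(path: str) -> List[str]:
    """Load one segment file back, e.g. to check its contents before uploading."""
    with gzip.open(path, "rt", encoding="utf-8") as f:
        return json.load(f)


if __name__ == "__main__":
    # Hypothetical keys; in practice, use the keys of your own training documents
    keys = [f"documents/doc_{i:05d}.txt" for i in range(10)]
    with tempfile.TemporaryDirectory() as tmp:
        pattern = "my_datasets/pretraining_test_{segment_number}.json.gz"
        for p in write_dataset_series(keys, pattern, 3, tmp):
            print(os.path.relpath(p, tmp), len(read_segment(p)))
```

Run from the directory you want to upload, this leaves the files under my_datasets/, which can then be uploaded with aws s3 sync my_datasets s3://YOUR_BUCKET/my_datasets so that the S3 keys match the pattern you use as DATASET_SERIES.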
If you encounter problems, feel free to reach out to me on LinkedIn.
Contributing
Contributions to continuing_education are welcome! Feel free to open an issue if you encounter problems or want to suggest features.