Train transformer-based models

Project description

Zelda Rose

A trainer for transformer-based models.

Installation

Simply install with pip (preferably in a virtual env, you know the drill)

pip install zeldarose

Train a model

Here is a short example:

TOKENIZERS_PARALLELISM=true zeldarose-tokenizer --vocab-size 4096 --out-path local/tokenizer  --model-name "my-muppet" tests/fixtures/raw.txt
zeldarose-transformer --tokenizer local/tokenizer --pretrained-model flaubert/flaubert_small_cased --out-dir local/muppet --val-text tests/fixtures/raw.txt tests/fixtures/raw.txt

There are other parameters (see zeldarose-transformer --help for a comprehensive list). The one you are probably most interested in is --config, for which there is an example in examples/.

The parameters --pretrained-model, --tokenizer and --model-config are all fed directly to Huggingface's transformers and can be pretrained model names or local paths.

Distributed training

This is somewhat tricky; you have several options:

  • If you are running in a SLURM cluster use --accelerator ddp and invoke via srun

  • Otherwise you have two options

    • Run with --accelerator ddp_spawn, which uses torch.multiprocessing.spawn to start the process swarm (tested, but possibly slower and more limited, see the pytorch-lightning documentation)
    • Run with --accelerator ddp and start with torch.distributed.launch with --use_env and --no_python (untested)
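
The choice between the two tested options above can be sketched as a small shell helper. This is only an illustration: pick_launcher is a hypothetical function that is not part of zeldarose, and it assumes that a set SLURM_JOB_ID variable is a good enough sign of running inside a SLURM allocation. It prints the command instead of running it, so you can check it before launching anything. The untested torch.distributed.launch route is left out on purpose.

```shell
#!/bin/sh
# Sketch only: pick_launcher is a hypothetical helper, not part of zeldarose.
# It prints the command it would run instead of running it.
pick_launcher() {
    if [ -n "${SLURM_JOB_ID:-}" ]; then
        # Assumption: SLURM_JOB_ID set means we are inside a SLURM job,
        # so srun starts the processes and the trainer uses plain ddp.
        echo "srun zeldarose-transformer --accelerator ddp $*"
    else
        # Anywhere else: let the trainer spawn its own worker processes.
        echo "zeldarose-transformer --accelerator ddp_spawn $*"
    fi
}

# Same arguments as in the short example from "Train a model".
pick_launcher --tokenizer local/tokenizer \
    --pretrained-model flaubert/flaubert_small_cased \
    --out-dir local/muppet \
    --val-text tests/fixtures/raw.txt tests/fixtures/raw.txt
```

Once the printed command looks right, it can be copied into a job script or run directly.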

Other hints

  • Data management relies on 🤗 datasets and uses its cache management system. To run in a clean environment, you might have to check the cache directory pointed to by the HF_DATASETS_CACHE environment variable.
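
One way to get a clean environment, sketched under the assumption that a throwaway cache is acceptable for your run, is to point HF_DATASETS_CACHE at a fresh temporary directory before starting the trainer. When the variable is unset, 🤗 datasets normally caches under ~/.cache/huggingface/datasets.

```shell
# Sketch: give this shell session its own empty 🤗 datasets cache.
# mktemp -d creates a new, empty directory and prints its path.
HF_DATASETS_CACHE="$(mktemp -d)"
export HF_DATASETS_CACHE
echo "datasets cache: $HF_DATASETS_CACHE"
```

Any zeldarose command started from the same shell afterwards will then preprocess its data from scratch rather than reuse earlier cached results.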

Inspirations

Download files

Source Distribution

zeldarose-0.3.4.tar.gz (15.0 kB)

Uploaded Source

Built Distribution

zeldarose-0.3.4-py3-none-any.whl (15.3 kB)

Uploaded Python 3
