Train transformer-based models
Zelda Rose
A trainer for transformer-based models.
Installation
Simply install with pip (preferably in a virtual env, you know the drill):

```shell
pip install git+https://github.com/LoicGrobol/zeldarose.git
```
Train a model
Here is a short example:

```shell
zeldarose-tokenizer --vocab-size 4096 --out-path local/tokenizer \
    --model-name "my-muppet" tests/fixtures/raw.txt
zeldarose-transformer --tokenizer local/tokenizer \
    --pretrained-model flaubert/flaubert_small_cased \
    --out-dir local/muppet --val-text tests/fixtures/raw.txt \
    tests/fixtures/raw.txt
```
There are other parameters (see `zeldarose-transformer --help` for a comprehensive list); the one you are most probably interested in is `--config`, for which there is an example target in `examples/`.
The parameters `--pretrained-model`, `--tokenizer` and `--model-config` are all fed directly to Huggingface's `transformers` and can be pretrained model names or local paths.
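Since these arguments go straight through to `transformers`, the usual resolution rule applies: an existing local directory is loaded from disk, and anything else is treated as a model name to fetch from the hub. A minimal sketch of that dispatch (the helper name is hypothetical, not part of `zeldarose` or `transformers`):

```python
from pathlib import Path

def resolve_model_source(name_or_path: str) -> str:
    """Hypothetical helper mirroring how `transformers` interprets
    `from_pretrained` arguments: an existing local directory is loaded
    from disk, anything else is treated as a hub model name."""
    return "local" if Path(name_or_path).is_dir() else "hub"

print(resolve_model_source("."))                              # an existing directory
print(resolve_model_source("flaubert/flaubert_small_cased"))  # a hub model name
```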
Distributed training
This is somewhat tricky, you have several options:

- If you are running in a SLURM cluster, use `--distributed-backend ddp` and invoke via `srun`.
- Otherwise you have two options:
  - Run with `--distributed-backend ddp_spawn`, which uses `multiprocessing.spawn` to start the process swarm (tested, but possibly slower and more limited, see the `pytorch-lightning` doc).
  - Run with `--distributed-backend ddp` and start with `torch.distributed.launch` with `--use_env` and `--no_python` (untested).
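For the `torch.distributed.launch` route, `--use_env` makes the launcher pass rendezvous information through environment variables instead of appending a `--local_rank` argument to each worker. A minimal sketch of what each worker then reads (the helper is hypothetical; the variable names are the standard `torch.distributed` ones):

```python
import os

def read_dist_env(env=os.environ):
    """Hypothetical helper: collect the rendezvous variables that
    torch.distributed.launch exports for each worker when --use_env is set."""
    return {
        "rank": int(env.get("RANK", 0)),
        "local_rank": int(env.get("LOCAL_RANK", 0)),
        "world_size": int(env.get("WORLD_SIZE", 1)),
        "master_addr": env.get("MASTER_ADDR", "127.0.0.1"),
        "master_port": int(env.get("MASTER_PORT", 29500)),
    }

# What worker 1 of a 2-process job would see:
cfg = read_dist_env({"RANK": "1", "LOCAL_RANK": "1", "WORLD_SIZE": "2"})
```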
Whatever you do, for now it's safer to run once without distributed training in order to preprocess the raw texts in a predictable environment.
Inspirations
Hashes for zeldarose-0.1.0-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | `4ddda274c334b2acbb797901d8219ec90d4c87c4f4dd6fd9dba756ad80b7125c` |
| MD5 | `1a3109fae0dc61ed8bbed077db971b4c` |
| BLAKE2b-256 | `e2889df11f5afaa736c9cad1abadffdcb9bf93e1034b92bfecd62fb854bc0f58` |