GDES-DeBERTaV3
ELECTRA-style training with Gradient-Disentangled Embedding Sharing (GDES) in native PyTorch using HuggingFace Transformers.
Overview
GDES-DeBERTaV3 implements the replaced token detection (RTD) pretraining objective from ELECTRA with the gradient-disentangled embedding sharing technique introduced in DeBERTaV3. Rather than giving the generator and discriminator separate embedding tables, GDES shares the token embeddings between the two while disentangling their gradient flows: only the generator's MLM loss updates the shared table, which keeps training parameter-efficient without the instability of naively tying the weights across both objectives.
The training loop performs two forward passes per step:
- Generator pass — predict masked tokens via MLM
- Discriminator pass — classify each token as original or replaced, with gradients into the shared embeddings stopped so the discriminator objective does not update them
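For orientation, here is a minimal sketch of one such step in plain PyTorch. The module names (generator, discriminator, shared_embeddings, embedding_delta) and the batch layout are illustrative assumptions, not this repo's actual API; see src/rtd_gdes/gdes/model.py and trainer.py for the real implementation.

```python
import torch
import torch.nn.functional as F

def rtd_gdes_step(generator, discriminator, shared_embeddings, batch, lambda_disc=0.5):
    """One illustrative RTD + GDES training step (not the repo's exact code).

    `generator` and `discriminator` are assumed to be encoders that accept
    `inputs_embeds`; `shared_embeddings` is the token embedding table owned by
    the generator and reused (read-only) by the discriminator.
    """
    input_ids = batch["input_ids"]          # original tokens
    masked_ids = batch["masked_input_ids"]  # tokens with [MASK] applied
    mlm_labels = batch["mlm_labels"]        # -100 everywhere except masked positions

    # Generator pass: standard masked-language-model loss.
    gen_logits = generator(inputs_embeds=shared_embeddings(masked_ids)).logits
    gen_loss = F.cross_entropy(
        gen_logits.view(-1, gen_logits.size(-1)), mlm_labels.view(-1), ignore_index=-100
    )

    # Build the discriminator input: fill masked positions with generator predictions.
    with torch.no_grad():
        sampled = gen_logits.argmax(dim=-1)
    is_masked = mlm_labels != -100
    corrupted_ids = torch.where(is_masked, sampled, input_ids)
    rtd_labels = (corrupted_ids != input_ids).float()  # 1 = replaced, 0 = original

    # Discriminator pass with gradient-disentangled embedding sharing: the shared
    # table is read through .detach(), so the RTD loss cannot update it; a small
    # learned residual (embedding_delta, initialized to zero) adapts the embeddings
    # for the discriminator instead.
    disc_inputs = shared_embeddings(corrupted_ids).detach() + discriminator.embedding_delta(corrupted_ids)
    disc_logits = discriminator(inputs_embeds=disc_inputs).logits.squeeze(-1)
    disc_loss = F.binary_cross_entropy_with_logits(disc_logits, rtd_labels)

    # Combined objective: L = L_gen + lambda * L_disc.
    return gen_loss + lambda_disc * disc_loss
```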
Installation
git clone https://github.com/rsyue/rtd-gdes
cd rtd-gdes
pip install .
If you intend to contribute or modify the source, an editable install is recommended so changes are reflected immediately without reinstalling:
pip install -e .
Requirements
- Python ≥ 3.10
- PyTorch ≥ 2.0 (CUDA recommended, ROCm builds supported)
- Transformers
- Datasets
- Safetensors
- scikit-learn
- tqdm
Quick Start
python -m rtd_gdes.train \
--model microsoft/deberta-v3-base \
--lambda_disc 0.5 \
--batch_size 8 \
--epochs 5 \
--learning_rate 2e-5 \
--weight_decay 0.01 \
--gamma 0.9 \
--output_dir ./checkpoints \
--bf16
Usage
CLI Arguments
| Argument | Flag | Type | Default | Description |
|---|---|---|---|---|
| --model | -m | str | microsoft/deberta-v3-base | Pretrained model to train with RTD + GDES |
| --lambda_disc | -ld | float | 0.5 | Lambda coefficient scaling the discriminator loss |
| --batch_size | -bs | int | 8 | Batch size for training and evaluation |
| --epochs | -ep | int | 5 | Number of training epochs |
| --learning_rate | -lr | float | 2e-5 | Learning rate for AdamW |
| --weight_decay | -wd | float | 0.01 | Weight decay for AdamW |
| --gamma | -g | float | 0.9 | Gamma for exponential LR scheduler |
| --dataset | | str | imdb | HuggingFace dataset name |
| --output_dir | -o | str | None | Directory to save the model and tokenizer. The model is saved under <output_dir>/<save_name>/. Defaults to ./<save_name>/ |
| --fp16 | | flag | False | Enable FP16 mixed precision |
| --bf16 | | flag | False | Enable BF16 mixed precision |
| --compile | -c | flag | False | Run torch.compile with max-autotune mode |
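These defaults correspond to the hyperparameters held in the TrainConfig dataclass (src/rtd_gdes/config.py). The following is a hypothetical sketch of that dataclass, reconstructed from the documented defaults above rather than from the source itself:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainConfig:
    """Illustrative sketch mirroring the documented CLI defaults (not the actual source)."""
    model: str = "microsoft/deberta-v3-base"
    lambda_disc: float = 0.5
    batch_size: int = 8
    epochs: int = 5
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    gamma: float = 0.9
    dataset: str = "imdb"
    output_dir: Optional[str] = None
    fp16: bool = False
    bf16: bool = False
    compile: bool = False
```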
Training Details
The script trains on the IMDB unsupervised split by default, with a configurable 90/10 train/eval split. The dataset can be changed via --dataset. The combined loss is computed as:
$$\mathcal{L} = \mathcal{L}_{\text{gen}} + \lambda \cdot \mathcal{L}_{\text{disc}}$$
where $\mathcal{L}_{\text{gen}}$ is the standard MLM cross-entropy loss and $\mathcal{L}_{\text{disc}}$ is the binary cross-entropy over token-level replaced/original predictions.
Evaluation reports discriminator loss, accuracy, and F1 score on the held-out set.
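As a point of reference, the default data preparation and the reported metrics can be reproduced roughly as below; the seed and the variable names in token_level_metrics are illustrative, not the script's internals.

```python
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score

# 90/10 split of the IMDB unsupervised split (the documented default).
raw = load_dataset("imdb", split="unsupervised")
splits = raw.train_test_split(test_size=0.1, seed=42)
train_ds, eval_ds = splits["train"], splits["test"]

def token_level_metrics(preds, labels, attention_mask):
    """Accuracy and F1 over non-padding token positions.

    `preds` and `labels` are 0/1 tensors of shape (batch, seq_len): 1 marks a
    token the discriminator flags (or that truly was) replaced.
    """
    valid = attention_mask.bool().view(-1)
    y_true = labels.view(-1)[valid].cpu().numpy()
    y_pred = preds.view(-1)[valid].cpu().numpy()
    return accuracy_score(y_true, y_pred), f1_score(y_true, y_pred)
```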
Saved Outputs
After training, the model and tokenizer are saved under <output_dir>/<save_name>/ where save_name is derived from the model id (e.g. deberta_v3_base_gdes). If --output_dir is not specified, the model is saved to ./<save_name>/ in the current working directory.
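To illustrate the layout: assuming the checkpoint is written with save_pretrained, the save path can be derived and the tokenizer reloaded along these lines (the exact name-mangling shown is an assumption based on the example above):

```python
from pathlib import Path
from transformers import AutoTokenizer

# "microsoft/deberta-v3-base" -> "deberta_v3_base_gdes" (illustrative mangling,
# matching the documented example).
model_id = "microsoft/deberta-v3-base"
save_name = model_id.split("/")[-1].replace("-", "_").replace(".", "_") + "_gdes"

output_dir = Path("./checkpoints")   # or Path(".") when --output_dir is omitted
save_path = output_dir / save_name   # ./checkpoints/deberta_v3_base_gdes

# Assuming save_pretrained was used, the tokenizer (and model weights) can be
# reloaded from this directory for downstream use.
tokenizer = AutoTokenizer.from_pretrained(save_path)
```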
Development
Install with dev dependencies:
pip install -e ".[dev]"
Run the test suite:
pytest tests/ -v --cov=rtd_gdes
Project Structure
rtd-gdes/
├── src/
│ └── rtd_gdes/
│ ├── config.py # TrainConfig dataclass — all hyperparameter defaults
│ ├── train.py # Entry point and CLI
│ └── gdes/
│ ├── data.py # Dataset loading and DataLoader construction
│ ├── model.py # DebertaV3GDES — generator + discriminator
│ ├── trainer.py # train_one_epoch and evaluate loops
│ └── utils.py # Shared exceptions
├── tests/
│ └── test_gdes.py # Model, trainer, and config unit + integration tests
└── pyproject.toml
Roadmap
- Distributed training (DDP / FSDP)
- Publish as PyPI package
- Support additional model architectures beyond DeBERTaV3
Citation
If you use this code, please cite the original papers:
@article{he2021debertav3,
title={DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing},
author={He, Pengcheng and Gao, Jianfeng and Chen, Weizhu},
journal={arXiv preprint arXiv:2111.09543},
year={2021}
}
@article{clark2020electra,
title={ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators},
author={Clark, Kevin and Luong, Minh-Thang and Le, Quoc V. and Manning, Christopher D.},
journal={arXiv preprint arXiv:2003.10555},
year={2020}
}
License
MIT