

GDES-DeBERTaV3

ELECTRA-style training with Gradient-Disentangled Embedding Sharing (GDES) in native PyTorch using HuggingFace Transformers.

Overview

GDES-DeBERTaV3 implements the replaced token detection (RTD) pretraining objective from ELECTRA together with the gradient-disentangled embedding sharing (GDES) technique introduced in DeBERTaV3. Rather than giving the generator and discriminator separate embedding tables, or tying them naively, GDES shares the token embeddings between the two networks while disentangling their gradient flows: the discriminator sees the shared embeddings as a constant plus its own learnable residual, so RTD gradients never flow back into the shared table. This keeps the parameter efficiency of sharing without the gradient interference caused by naive weight tying.

The training loop performs two forward passes per step (sketched below):

  1. Generator pass — predict masked tokens via MLM
  2. Discriminator pass — classify each token as original or replaced, with embedding gradients frozen to disentangle the two objectives
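
For intuition, the sketch below shows what such a step could look like in plain PyTorch with two HuggingFace models (a masked-LM generator and a one-logit-per-token discriminator). It is an illustration under stated assumptions, not the actual rtd_gdes implementation: the batch keys (masked_input_ids, input_ids, mlm_labels, attention_mask), the delta_embed residual parameter, and the gdes_step signature are placeholders; the real code lives in src/rtd_gdes/gdes/model.py and trainer.py.

import torch
import torch.nn.functional as F

def gdes_step(generator, discriminator, delta_embed, batch, lambda_disc=0.5):
    """One RTD training step with gradient-disentangled embedding sharing."""
    # 1) Generator pass: standard MLM loss on the masked inputs.
    gen_out = generator(
        input_ids=batch["masked_input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["mlm_labels"],          # -100 everywhere except masked positions
    )
    loss_gen = gen_out.loss

    # 2) Build the discriminator input: fill masked positions with the
    #    generator's predictions (greedy here for brevity; ELECTRA samples).
    with torch.no_grad():
        predicted = gen_out.logits.argmax(dim=-1)
    masked = batch["mlm_labels"] != -100
    disc_input = torch.where(masked, predicted, batch["input_ids"])
    # Token-level RTD targets: 1 if the token differs from the original, else 0.
    rtd_labels = (disc_input != batch["input_ids"]).float()

    # 3) Discriminator pass with GDES: the shared embedding table is detached,
    #    so RTD gradients only update the discriminator's residual delta_embed.
    shared = generator.get_input_embeddings().weight
    disc_embeds = F.embedding(disc_input, shared.detach() + delta_embed)
    disc_logits = discriminator(
        inputs_embeds=disc_embeds,
        attention_mask=batch["attention_mask"],
    ).logits.squeeze(-1)                     # one logit per token

    # 4) Binary cross-entropy over real (non-padding) tokens only.
    per_token = F.binary_cross_entropy_with_logits(disc_logits, rtd_labels, reduction="none")
    attn = batch["attention_mask"].float()
    loss_disc = (per_token * attn).sum() / attn.sum()

    # Combined objective: L = L_gen + lambda * L_disc.
    return loss_gen + lambda_disc * loss_disc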

Installation

git clone https://github.com/rsyue/rtd-gdes
cd rtd-gdes
pip install .

If you intend to contribute or modify the source, an editable install is recommended so changes are reflected immediately without reinstalling:

pip install -e .

Requirements

  • Python ≥ 3.10
  • PyTorch ≥ 2.0 (CUDA recommended, ROCm builds supported)
  • Transformers
  • Datasets
  • Safetensors
  • scikit-learn
  • tqdm

Quick Start

python -m rtd_gdes.train \
  --model microsoft/deberta-v3-base \
  --lambda_disc 0.5 \
  --batch_size 8 \
  --epochs 5 \
  --learning_rate 2e-5 \
  --weight_decay 0.01 \
  --gamma 0.9 \
  --output_dir ./checkpoints \
  --bf16

Usage

CLI Arguments

  • --model / -m (str, default: microsoft/deberta-v3-base): Pretrained model to train with RTD + GDES
  • --lambda_disc / -ld (float, default: 0.5): Lambda coefficient scaling the discriminator loss
  • --batch_size / -bs (int, default: 8): Batch size for training and evaluation
  • --epochs / -ep (int, default: 5): Number of training epochs
  • --learning_rate / -lr (float, default: 2e-5): Learning rate for AdamW
  • --weight_decay / -wd (float, default: 0.01): Weight decay for AdamW
  • --gamma / -g (float, default: 0.9): Gamma for the exponential LR scheduler
  • --dataset (str, default: imdb): HuggingFace dataset name
  • --output_dir / -o (str, default: None): Directory to save the model and tokenizer; the model is saved under <output_dir>/<save_name>/, defaulting to ./<save_name>/
  • --fp16 (flag, default: off): Enable FP16 mixed precision
  • --bf16 (flag, default: off): Enable BF16 mixed precision
  • --compile / -c (flag, default: off): Run torch.compile with max-autotune mode
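
The optimizer-related flags map onto standard PyTorch components. Below is a minimal sketch of how the defaults above could be wired up; the per-epoch scheduler step is an assumption, and the actual setup lives in train.py and trainer.py.

import torch
from torch import nn

model = nn.Linear(8, 8)  # stand-in for the GDES model

# --learning_rate 2e-5, --weight_decay 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
# --gamma 0.9: exponential learning-rate decay
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

for epoch in range(5):  # --epochs 5
    # ... train_one_epoch(...) and evaluate(...) would run here ...
    scheduler.step()    # lr <- lr * gamma after each epoch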

Training Details

The script trains on the IMDB unsupervised split by default, with a configurable 90/10 train/eval split. The dataset can be changed via --dataset. The combined loss is computed as:

$$\mathcal{L} = \mathcal{L}_{\text{gen}} + \lambda \cdot \mathcal{L}_{\text{disc}}$$

where $\mathcal{L}_{\text{gen}}$ is the standard MLM cross-entropy loss and $\mathcal{L}_{\text{disc}}$ is the binary cross-entropy over token-level replaced/original predictions.

Evaluation reports discriminator loss, accuracy, and F1 score on the held-out set.
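
Concretely, these metrics can be computed from the token-level discriminator outputs. The sketch below assumes sigmoid thresholding at 0.5 and binary F1 on the replaced class; the exact evaluation code is in gdes/trainer.py.

import torch
from sklearn.metrics import accuracy_score, f1_score

def rtd_metrics(disc_logits, rtd_labels, attention_mask):
    """Token-level accuracy and F1, ignoring padded positions."""
    preds = (torch.sigmoid(disc_logits) > 0.5).long()
    keep = attention_mask.bool()
    y_true = rtd_labels[keep].long().cpu().numpy()
    y_pred = preds[keep].cpu().numpy()
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
    }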

Saved Outputs

After training, the model and tokenizer are saved under <output_dir>/<save_name>/ where save_name is derived from the model id (e.g. deberta_v3_base_gdes). If --output_dir is not specified, the model is saved to ./<save_name>/ in the current working directory.
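
As an illustration of the path logic (the exact derivation of save_name is in train.py; the rule below is an assumption based on the example above):

from pathlib import Path

model_id = "microsoft/deberta-v3-base"
output_dir = None  # value of --output_dir

# Derive save_name from the model id, e.g. "deberta_v3_base_gdes".
save_name = model_id.split("/")[-1].replace("-", "_") + "_gdes"
save_dir = Path(output_dir) / save_name if output_dir else Path(save_name)
print(save_dir)  # deberta_v3_base_gdes
# model.save_pretrained(save_dir); tokenizer.save_pretrained(save_dir)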

Development

Install with dev dependencies:

pip install -e ".[dev]"

Run the test suite:

pytest tests/ -v --cov=rtd_gdes

Project Structure

rtd-gdes/
├── src/
│   └── rtd_gdes/
│       ├── config.py          # TrainConfig dataclass — all hyperparameter defaults
│       ├── train.py           # Entry point and CLI
│       └── gdes/
│           ├── data.py        # Dataset loading and DataLoader construction
│           ├── model.py       # DebertaV3GDES — generator + discriminator
│           ├── trainer.py     # train_one_epoch and evaluate loops
│           └── utils.py       # Shared exceptions
├── tests/
│   └── test_gdes.py           # Model, trainer, and config unit + integration tests
└── pyproject.toml

Roadmap

  • Distributed training (DDP / FSDP)
  • Publish as PyPI package
  • Support additional model architectures beyond DeBERTaV3

Citation

If you use this code, please cite the original papers:

@article{he2021debertav3,
  title={DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing},
  author={He, Pengcheng and Gao, Jianfeng and Chen, Weizhu},
  journal={arXiv preprint arXiv:2111.09543},
  year={2021}
}

@article{clark2020electra,
  title={ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators},
  author={Clark, Kevin and Luong, Minh-Thang and Le, Quoc V. and Manning, Christopher D.},
  journal={arXiv preprint arXiv:2003.10555},
  year={2020}
}

License

MIT
