Explorations into Transformer Language Model with Adversarial Loss

Language model with adversarial loss

Explorations into adversarial losses on top of autoregressive loss for language modeling

I tried this in the past, when GANs were still dominant, but at the time either I was too inexperienced or the research was not there yet. Either way, I could not get it working. I will give it another shot over the next few weeks, mainly to see whether an adversarial system could benefit world modeling.

Usage

import torch

from transformer_lm_gan import (
    LanguageModelGenerator,
    Discriminator,
    GAN,
)

gan = GAN(
    strategy = 'gumbel_one_hot', # or 'rotate' for the rotation trick; may try a combination of the two if both fail in experiments
    generator = dict(
        num_tokens = 256,
        dim = 512,
        depth = 6,
        dim_head = 64,
        heads = 8,
        max_seq_len = 1024
    ),
    discriminator = dict(
        num_tokens = 256,
        dim = 512,
        depth = 2,
        dim_head = 64,
        heads = 9,
        max_seq_len = 1024
    )
).cuda()

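# mock batch of token ids, shape (batch, seq_len)
seq = torch.randint(0, 256, (2, 1024)).cuda()

# discriminator loss and backward pass
discr_loss = gan.discriminate_forward(seq)
discr_loss.backward()

# generator loss and backward pass
gen_loss = gan.generate_forward(seq)
gen_loss.backward()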
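
The 'gumbel_one_hot' strategy name suggests a straight-through Gumbel-softmax, which is one common way to let discriminator gradients reach a generator through discrete token choices. The sketch below illustrates that general technique only; it is not taken from this library's internals. The helper name `gumbel_one_hot`, the `token_emb` table, the tensor shapes, and the toy loss are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

def gumbel_one_hot(logits, temperature = 1.0):
    # hypothetical helper, not the library's API
    # hard = True yields one-hot vectors (up to floating point) in the forward pass,
    # while gradients flow through the soft Gumbel-softmax sample
    return F.gumbel_softmax(logits, tau = temperature, hard = True, dim = -1)

torch.manual_seed(0)

num_tokens, dim = 256, 512

# stand-in for generator output logits: (batch, seq_len, num_tokens)
logits = torch.randn(2, 16, num_tokens, requires_grad = True)

# stand-in for a discriminator token embedding table (assumption)
token_emb = torch.nn.Embedding(num_tokens, dim)

one_hot = gumbel_one_hot(logits)        # (2, 16, 256)
embedded = one_hot @ token_emb.weight   # (2, 16, 512), differentiable embedding lookup

# toy scalar in place of a real discriminator score
embedded.sum().backward()
```

After the backward pass, `logits.grad` is populated even though each sampled position is a discrete one-hot choice, which is what allows an adversarial loss to be backpropagated into the generator.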

Citations

@inproceedings{Huang2025TheGI,
    title   = {The GAN is dead; long live the GAN! A Modern GAN Baseline},
    author  = {Yiwen Huang and Aaron Gokaslan and Volodymyr Kuleshov and James Tompkin},
    year    = {2025},
    url     = {https://api.semanticscholar.org/CorpusID:275405495}
}
@article{Fifty2024Restructuring,
    title   = {Restructuring Vector Quantization with the Rotation Trick},
    author  = {Christopher Fifty and Ronald G. Junkins and Dennis Duan and Aniketh Iyengar and Jerry W. Liu and Ehsan Amid and Sebastian Thrun and Christopher Ré},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2410.06424},
    url     = {https://api.semanticscholar.org/CorpusID:273229218}
}
