
Coconut in PyTorch

Project description

🥥 Coconut

Implementation of Coconut, proposed in the FAIR paper Training Large Language Models to Reason in a Continuous Latent Space, in PyTorch

Architecture-wise, the closest prior work to the one proposed here is RMT (the Recurrent Memory Transformer), whose memory tokens could serve as the continuous latent tokens. Both directions are worth exploring

Install

$ pip install coconut-pytorch

Usage

import torch
from coconut_pytorch import Coconut

model = Coconut(
    num_reasoning_steps = 3,
    num_latents_per_step = 1,
    transformer = dict(
        num_tokens = 256,
        dim = 512,
        depth = 6
    )
)

prompt = torch.randint(0, 256, (2, 1024))
answer = torch.randint(0, 256, (2, 64))

loss = model(prompt, answer)
loss.backward()

# after much training

answer = model.generate(prompt, max_length = 64) # (2, 64)

Citation

@inproceedings{Hao2024TrainingLL,
    title   = {Training Large Language Models to Reason in a Continuous Latent Space},
    author  = {Shibo Hao and Sainbayar Sukhbaatar and DiJia Su and Xian Li and Zhiting Hu and Jason Weston and Yuandong Tian},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:274610816}
}
@article{Burtsev2021MultiStreamT,
    title   = {Multi-Stream Transformers},
    author  = {Mikhail S. Burtsev and Anna Rumshisky},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2107.10342},
    url     = {https://api.semanticscholar.org/CorpusID:236171087}
}
@article{Zhu2024HyperConnections,
    title   = {Hyper-Connections},
    author  = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2409.19606},
    url     = {https://api.semanticscholar.org/CorpusID:272987528}
}
@inproceedings{Zhou2024ValueRL,
    title   = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
    author  = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:273532030}
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

coconut_pytorch-0.0.32.tar.gz (131.5 kB)


Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

coconut_pytorch-0.0.32-py3-none-any.whl (13.5 kB)


File details

Details for the file coconut_pytorch-0.0.32.tar.gz.

File metadata

  • Download URL: coconut_pytorch-0.0.32.tar.gz
  • Upload date:
  • Size: 131.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.9.21

File hashes

Hashes for coconut_pytorch-0.0.32.tar.gz
SHA256: 1ba32ef07c878e22463cb342f2ab2af614c925f3e4c3c07bc97371251a3cee21
MD5: 3f5150275af7e20286076ed6fedebb7e
BLAKE2b-256: b09afc6fb6f84b1667770b9fdb6c047bbebb45021c026dacdd3b0ad73253bee3

See more details on using hashes here.
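The published digests can be checked locally before installing a downloaded archive. A minimal sketch using only Python's standard library (the file path and the commented assertion are illustrative, matching the sdist digest listed above):

```python
import hashlib

def sha256_of(path, chunk_size = 8192):
    # stream the file in chunks so large archives need not fit in memory
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

# compare against the published digest for the sdist above, e.g.:
# expected = '1ba32ef07c878e22463cb342f2ab2af614c925f3e4c3c07bc97371251a3cee21'
# assert sha256_of('coconut_pytorch-0.0.32.tar.gz') == expected
```

pip can also enforce this automatically via hash-checking mode (`--require-hashes` with pinned hashes in a requirements file).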

File details

Details for the file coconut_pytorch-0.0.32-py3-none-any.whl.

File metadata

File hashes

Hashes for coconut_pytorch-0.0.32-py3-none-any.whl
SHA256: a05f8a9d224e6ea15110d440b362396e11047a48bc751e2b82e976b87ff076ac
MD5: fbc1d2c24b74e8cdeec4a14d4f245bf5
BLAKE2b-256: b9cb93322861b3dd6faf75b2c3dda969ddfc718c0be5d40eb6fb20692b9357dc

