
Train a speech-to-speech model using your own language model

Project description

🦙🎤 Llama-Jarvis


Train a speech-to-speech model using your own language model. The project is currently based on the Seamless model, but we plan to support more models in the future.

This model is based on speech-to-speech models such as Llama-Omni. However, it aims to take advantage of the joint speech-text embeddings of the Seamless Model.

This code is very much a work in progress. Any and all contributions are welcome!

Why this Library?

This library aims to make speech-to-speech models more compatible with the HuggingFace ecosystem, rather than requiring you to modify your models and datasets to work with a new library. This allows us to take advantage of things like the HuggingFace Trainer.
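
Because the model and processor follow the usual HuggingFace conventions (a config class, a processor that returns a batch of tensors, and a forward pass that returns an output with a .loss), they should in principle drop into the HuggingFace Trainer. The sketch below is illustrative only: train_dataset (a dataset yielding batches in the processor's output format) and the TrainingArguments values are assumptions, not settings shipped with this project.

from transformers import Trainer, TrainingArguments

from llama_jarvis.model import JarvisConfig, JarvisModel

jarvis_model = JarvisModel(
    JarvisConfig("meta-llama/Llama-3.2-1B", "facebook/hf-seamless-m4t-medium")
)

training_args = TrainingArguments(
    output_dir="jarvis-checkpoints",  # placeholder output directory
    per_device_train_batch_size=2,    # illustrative value
    num_train_epochs=1,               # illustrative value
    remove_unused_columns=False,      # keep the processor's custom fields intact
)

trainer = Trainer(
    model=jarvis_model,
    args=training_args,
    train_dataset=train_dataset,  # hypothetical: a dataset of processor outputs
)
trainer.train()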

Getting Started

NOTE: For some of the examples below, you may first have to log in to HuggingFace to gain access to gated models (especially the Llama models).
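
One quick way to authenticate is the huggingface_hub login helper (the huggingface-cli login command works equally well); the token is whatever access token you created for an account that has been granted access to the gated repositories.

from huggingface_hub import login

# Prompts for (or reuses) a Hugging Face access token with permission to
# download the gated checkpoints (e.g., meta-llama/Llama-3.2-1B).
login()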

Running Locally

This code is not yet available on PyPI (I am hesitant to release it without thoroughly testing the code). To try it locally, run:

git clone https://github.com/johnsutor/llama-jarvis
cd llama-jarvis 
pip install -e . 

Phase One Loss

The example code below returns the phase-one loss (i.e., the loss used when training the first phase of Llama-Omni):

from llama_jarvis.model import JarvisModel, JarvisConfig, JarvisProcessor

BASE_LLM = "meta-llama/Llama-3.2-1B"
SEAMLESS_MODEL = "facebook/hf-seamless-m4t-medium"
LANGUAGE = "eng"

jarvis_config = JarvisConfig(
    BASE_LLM,
    SEAMLESS_MODEL
)
jarvis_model = JarvisModel(jarvis_config)
jarvis_processor = JarvisProcessor(
    BASE_LLM,
    SEAMLESS_MODEL
)

inputs = jarvis_processor(
    instruction=["You are a language model who should respond to my speech"],
    text=["What is two plus two?"],
    label=["Two plus two is four"],
    src_lang=LANGUAGE,
    return_tensors="pt",
    padding=True
)

outputs = jarvis_model.forward(
    **inputs,
    tgt_lang=LANGUAGE
)

print(outputs.loss)
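
For anything beyond a quick smoke test you will probably want a GPU. Continuing the example above, a minimal sketch using only standard PyTorch (nothing project-specific) for moving the model and the processor batch onto CUDA before computing the loss:

import torch

# Standard PyTorch device handling: move the model and every tensor in the
# processor batch onto the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
jarvis_model = jarvis_model.to(device)
inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}

outputs = jarvis_model.forward(**inputs, tgt_lang=LANGUAGE)
print(outputs.loss)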

Phase Two Loss

The example code below returns the phase-two loss (i.e., the loss used when training the second phase of Llama-Omni):

from llama_jarvis.model import JarvisModel, JarvisConfig, JarvisProcessor

BASE_LLM = "meta-llama/Llama-3.2-1B"
SEAMLESS_MODEL = "facebook/hf-seamless-m4t-medium"
LANGUAGE = "eng"

jarvis_config = JarvisConfig(
    BASE_LLM,
    SEAMLESS_MODEL
)
jarvis_model = JarvisModel(jarvis_config)
jarvis_processor = JarvisProcessor(
    BASE_LLM,
    SEAMLESS_MODEL
)

inputs = jarvis_processor(
    instruction=["You are a language model who should respond to my speech"],
    text=["What is two plus two?"],
    label=["Two plus two is four"],
    src_lang=LANGUAGE,
    return_tensors="pt",
    padding=True
)

outputs = jarvis_model.forward(
    **inputs,
    tgt_lang=LANGUAGE,
    train_phase=2
)

print(outputs.loss)
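
Taken together, the two losses mirror the two-stage recipe of Llama-Omni: optimize the phase-one loss first, then continue with train_phase=2. The sketch below, continuing from the example above, is illustrative only; the optimizer, learning rate, step count, and the reuse of a single batch are all assumptions rather than project-provided training code.

import torch

# Illustrative only: optimizer, learning rate, and step count are assumptions,
# and a real run would draw fresh batches from a DataLoader instead of reusing
# the single batch built above.
optimizer = torch.optim.AdamW(jarvis_model.parameters(), lr=1e-5)
jarvis_model.train()

for phase in (1, 2):
    for _ in range(100):
        extra = {"train_phase": 2} if phase == 2 else {}
        outputs = jarvis_model.forward(**inputs, tgt_lang=LANGUAGE, **extra)
        outputs.loss.backward()
        optimizer.step()
        optimizer.zero_grad()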

Roadmap

  • Release the code on PyPi
  • Train a baseline model using Llama 3.2 1B and Seamless Medium
  • Provide training example code
  • Fully document the code
  • Create an inference script for the model
  • Write thorough tests for the code, and test with a multitude of open-source models

Other Cool Libraries

We take a lot of inspiration from some other nice open-source libraries out there. Shoutout to

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

llama_jarvis-0.1.0.tar.gz (110.6 kB)

Uploaded Source

Built Distribution

llama_jarvis-0.1.0-py3-none-any.whl (13.7 kB)

Uploaded Python 3

File details

Details for the file llama_jarvis-0.1.0.tar.gz.

File metadata

  • Download URL: llama_jarvis-0.1.0.tar.gz
  • Upload date:
  • Size: 110.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-httpx/0.27.0

File hashes

Hashes for llama_jarvis-0.1.0.tar.gz

  • SHA256: ba157a5dfed9dec02c6f73e14c70dfe99187d76a7e9187b68e1d0b303b72c110
  • MD5: 17567f5bd70e8110dc951826e6256eba
  • BLAKE2b-256: eb85ca0bc953b342855eb6f2dfb2cf2b5e7b65b898242df617fd90a0710a0669


File details

Details for the file llama_jarvis-0.1.0-py3-none-any.whl.

File metadata

File hashes

Hashes for llama_jarvis-0.1.0-py3-none-any.whl

  • SHA256: 1d060dc0b6a41d68eac42fea2c9fa3b1174b90ae40df91b39c2dd7ddc221b2e6
  • MD5: 99783667c1febf17daff75f07eb7265a
  • BLAKE2b-256: 75656ff8b96c166864380239df2d3a6c913d4706e8f10e7704d039ae6889c3f6

