
A library for making PyTorch models streamable

TorchStream

TorchStream is a library that helps ML developers stream PyTorch models without retraining or rewriting them, either to reduce their latency or to use them in live applications.

TorchStream comes with a website of live examples.

Installation

Install as a package (any OS, CUDA optional):

pip install torchstream-lib
# or, with uv:
uv pip install torchstream-lib

To run the Streamlit examples yourself, install it as a project instead (some examples require CUDA):

git clone https://github.com/CorentinJ/TorchStream
cd TorchStream
uv run --group demos streamlit run examples

If you don't have uv yet:

# On Windows:
powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
# On Linux and macOS:
curl -LsSf https://astral.sh/uv/install.sh | sh

# Alternatively, on any platform where pip is installed:
pip install -U uv

Overview

TorchStream offers a set of tools to help you stream complex neural networks and other sequence-to-sequence transforms.
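To see why streaming a model takes more than splitting its input, consider a single 1D convolution. The sketch below uses plain PyTorch, not TorchStream's API, and the layer sizes and chunk length are arbitrary assumptions. Running chunks independently loses output frames at every chunk boundary, while carrying the last kernel_size - 1 input frames over into the next chunk (a sliding window) reproduces the full-sequence output exactly.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy model: one unpadded Conv1d with kernel size 5 and stride 1
conv = torch.nn.Conv1d(in_channels=2, out_channels=3, kernel_size=5)
x = torch.randn(1, 2, 100)  # (batch, channels, time)

with torch.no_grad():
    full = conv(x)  # reference output: 100 - 5 + 1 = 96 frames

    # Naive chunking: each 20-frame chunk yields only 16 output frames
    naive = torch.cat([conv(c) for c in x.split(20, dim=-1)], dim=-1)

    # Sliding window: prepend the last kernel_size - 1 input frames seen so far
    context = conv.kernel_size[0] - 1
    buffer = x[..., :0]  # empty tensor with the right batch/channel dims
    outputs = []
    for chunk in x.split(20, dim=-1):
        window = torch.cat([buffer, chunk], dim=-1)
        if window.shape[-1] >= conv.kernel_size[0]:
            outputs.append(conv(window))
        buffer = window[..., -context:]
    streamed = torch.cat(outputs, dim=-1)

print(naive.shape, full.shape)  # torch.Size([1, 3, 80]) torch.Size([1, 3, 96])
print(torch.allclose(streamed, full, atol=1e-5))  # True
```

This toy layer only needs left context. Padded, strided or transposed layers also need right context (which adds latency) and output trimming, which makes working out these parameters by hand for a model like BigVGAN tedious and error-prone.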

The example below requires cloning the project and installing the demo dependencies (uv sync --group demos). It streams BigVGAN, a state-of-the-art neural vocoder:

import logging

import librosa
import torch

from examples.resources.bigvgan.bigvgan import BigVGAN
from examples.resources.bigvgan.meldataset import get_mel_spectrogram
from torchstream import SeqSpec, SlidingWindowStream, find_sliding_window_params

logging.basicConfig(level=logging.INFO)

device = "cuda"
model = BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x").eval().to(device)
model.remove_weight_norm()

# Get a sample mel spectrogram input
wave, sample_rate = librosa.load(librosa.ex("libri1"), sr=model.h.sampling_rate)
mel = get_mel_spectrogram(torch.from_numpy(wave).unsqueeze(0), model.h).to(device)

# Specify the model's input format: a melspectrogram
in_spec = SeqSpec(1, model.h.num_mels, -1, device=device)
# Output format: an audio waveform
out_spec = SeqSpec(1, 1, -1, device=device)

# Use TorchStream's solver to find the sliding window parameters of BigVGAN
sli_params = find_sliding_window_params(
    trsfm=model,
    in_spec=in_spec,
    out_spec=out_spec,
    max_in_out_seq_size=1_000_000,
)[0]

# Perform streaming inference
stream = SlidingWindowStream(model, sli_params, in_spec, out_spec)
for audio_chunk in stream.forward_in_chunks_iter(mel, chunk_size=80):
    print(f"Got a {tuple(audio_chunk.shapes[0])} shaped audio chunk")
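Sliding window parameters are closely tied to a model's receptive field and stride. As intuition only, the sketch below uses a hypothetical helper (not part of TorchStream) that computes both for a stack of 1D convolutions: the receptive field is 1 + sum of (k_i - 1) * d_i * (product of the strides of the earlier layers), and the total stride is the product of all strides s_i, where k_i, s_i and d_i are each layer's kernel size, stride and dilation. The layer configuration is an arbitrary assumption, and the result is checked empirically by backpropagating from a single output frame. The solver in the example above takes the whole model as input, so you do not need to do this bookkeeping yourself.

```python
import torch


def conv_stack_receptive_field(layers):
    """Return (receptive field in input frames, total stride) for a stack of Conv1d layers."""
    rf, jump = 1, 1  # jump: distance in input frames between adjacent outputs so far
    for layer in layers:
        k, s, d = layer.kernel_size[0], layer.stride[0], layer.dilation[0]
        rf += (k - 1) * d * jump  # this layer's span, measured in input frames
        jump *= s
    return rf, jump


# Hypothetical three-layer stack: plain, strided and dilated convolutions
layers = [
    torch.nn.Conv1d(1, 1, kernel_size=3),
    torch.nn.Conv1d(1, 1, kernel_size=3, stride=2),
    torch.nn.Conv1d(1, 1, kernel_size=3, dilation=2),
]
rf, stride = conv_stack_receptive_field(layers)
print(rf, stride)  # 13 2

# Empirical check: backpropagate from one output frame and measure the span of
# input frames that receive a gradient. All-ones weights rule out cancellations.
model = torch.nn.Sequential(*layers)
with torch.no_grad():
    for layer in layers:
        layer.weight.fill_(1.0)
x = torch.zeros(1, 1, 64, requires_grad=True)
model(x)[0, 0, 10].backward()
touched = x.grad[0, 0].nonzero().flatten()
span = touched.max().item() - touched.min().item() + 1
print(span)  # 13
```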

Disclaimer

I develop TorchStream independently. It is not affiliated with, endorsed by, or sponsored by the PyTorch team or Meta.

Download files

Source Distribution

torchstream_lib-1.0.1.tar.gz (70.1 kB)

Built Distribution

torchstream_lib-1.0.1-py3-none-any.whl (78.0 kB)

File details

Details for the file torchstream_lib-1.0.1.tar.gz.

File metadata

  • Download URL: torchstream_lib-1.0.1.tar.gz
  • Size: 70.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.5

File hashes

Hashes for torchstream_lib-1.0.1.tar.gz
Algorithm    Hash digest
SHA256       71049ca5035cdc03693249683bef9766978e94f28ae2baff25ade50caf8b7253
MD5          59cee93d6e5ee16d88653e8586c7d26c
BLAKE2b-256  ac1032d75cf3d68558d104af3e8c5b43687fdb6adda45331e309173e5e7e65c4

File details

Details for the file torchstream_lib-1.0.1-py3-none-any.whl.

File hashes

Hashes for torchstream_lib-1.0.1-py3-none-any.whl
Algorithm    Hash digest
SHA256       9f4dd07ffbf765f67bb8f41051eb4a37b23882c02416e7699c48a9cae5361212
MD5          3f5faae0ec1fba2b8fe3a3ac822c8ba2
BLAKE2b-256  241dc1cfb97108a3eea55edef8a251547e1929e460fac719f5111a711e8ffa0e