A library for making PyTorch models streamable

TorchStream

TorchStream is a library that helps ML developers stream PyTorch models without retraining or rewriting them, whether to reduce their latency or to use them in live applications.

TorchStream comes with a website of live examples.

Installation

Install as a package (any OS, CUDA optional):

(uv) pip install torchstream-lib

Install as a project to run the Streamlit examples yourself (replace "--extra cuda" with "--extra cpu" for a CPU-only install):

git clone https://github.com/CorentinJ/TorchStream
cd TorchStream
uv run --group demos streamlit run examples --extra cuda

If you don't have uv yet:

# On Windows:
powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
# On Linux:
curl -LsSf https://astral.sh/uv/install.sh | sh

# Alternatively, on any platform with pip installed:
pip install -U uv
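
Once installed, a quick sanity check can be run from Python. This is a minimal sketch that relies only on the standard library: the distribution name "torchstream-lib" comes from the pip command above, while the helper names installed_version and cuda_available are hypothetical. CUDA is reported as unavailable when PyTorch itself is not installed.

```python
import importlib.util
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "torchstream-lib") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def cuda_available() -> bool:
    """Report whether PyTorch is importable and sees a CUDA device."""
    if importlib.util.find_spec("torch") is None:
        return False  # no PyTorch at all, so no CUDA either
    import torch

    return torch.cuda.is_available()


if __name__ == "__main__":
    print("torchstream-lib:", installed_version() or "not installed")
    print("CUDA available:", cuda_available())
```

A CPU-only install (the "--extra cpu" variant) is expected to print False for CUDA, which is fine: CUDA is optional.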

Overview

TorchStream offers a set of tools to help you stream complex neural networks and other sequence-to-sequence transforms.
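
For intuition on why a model's sliding window parameters are enough to stream it, here is a minimal, framework-free sketch of the simplest case: a single 1D convolution whose kernel size and stride are known in advance. The names conv1d_valid and StreamingConv1d are hypothetical, and this is not how TorchStream is implemented; TorchStream's solver infers such parameters for a whole model, whereas this sketch hard-codes them for one layer. The idea is to buffer only the input frames that a future window still needs, so that the concatenated streamed outputs match the offline output exactly.

```python
from typing import List, Sequence


def conv1d_valid(x: Sequence[float], kernel: Sequence[float], stride: int = 1) -> List[float]:
    """Offline 'valid' 1D cross-correlation (no padding), like a single-channel conv1d."""
    k = len(kernel)
    return [
        sum(w * x[i + j] for j, w in enumerate(kernel))
        for i in range(0, len(x) - k + 1, stride)
    ]


class StreamingConv1d:
    """Hypothetical helper that streams conv1d_valid chunk by chunk."""

    def __init__(self, kernel: Sequence[float], stride: int = 1):
        # Simplifying assumption: consecutive windows never skip input frames.
        assert 1 <= stride <= len(kernel)
        self.kernel = list(kernel)
        self.stride = stride
        self.buffer: List[float] = []  # input frames still needed by future windows

    def push(self, chunk: Sequence[float]) -> List[float]:
        self.buffer.extend(chunk)
        # Emit every output whose input window is now complete.
        out = conv1d_valid(self.buffer, self.kernel, self.stride)
        # The next window starts len(out) * stride frames into the buffer.
        self.buffer = self.buffer[len(out) * self.stride:]
        return out


# Usage: feeding the signal 8 frames at a time reproduces the offline output.
signal = [float(i % 7) for i in range(50)]
kernel = [0.25, 0.5, 0.25]
offline = conv1d_valid(signal, kernel, stride=2)

stream = StreamingConv1d(kernel, stride=2)
streamed = []
for start in range(0, len(signal), 8):
    streamed.extend(stream.push(signal[start:start + 8]))

print(streamed == offline)  # True
```

A real network such as BigVGAN stacks many layers, including upsampling ones, which is why working out the overall window by hand quickly becomes impractical.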

The example below requires cloning the project and installing the demo dependencies (uv sync --group demos). It streams BigVGAN, a state-of-the-art neural vocoder:

import logging

import librosa
import torch

from examples.resources.bigvgan.bigvgan import BigVGAN
from examples.resources.bigvgan.meldataset import get_mel_spectrogram
from torchstream import SeqSpec, SlidingWindowStream, find_sliding_window_params

logging.basicConfig(level=logging.INFO)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x").eval().to(device)
model.remove_weight_norm()

# Get a sample mel spectrogram input
wave, sample_rate = librosa.load(librosa.ex("libri1"), sr=model.h.sampling_rate)
mel = get_mel_spectrogram(torch.from_numpy(wave).unsqueeze(0), model.h).to(device)

# Specify the model's input format: a melspectrogram
in_spec = SeqSpec(1, model.h.num_mels, -1, device=device)
# Output format: an audio waveform
out_spec = SeqSpec(1, 1, -1, device=device)

# Use TorchStream's solver to find the sliding window parameters of BigVGAN
sli_params = find_sliding_window_params(
    trsfm=model,
    in_spec=in_spec,
    out_spec=out_spec,
    max_in_out_seq_size=1_000_000,
)[0]

# Perform streaming inference
stream = SlidingWindowStream(model, sli_params, in_spec, out_spec)
for audio_chunk in stream.forward_in_chunks_iter(mel, chunk_size=80):
    print(f"Got a {tuple(audio_chunk.shapes[0])} shaped audio chunk")
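
As a rough estimate of how much audio each chunk represents, the sketch below assumes that chunk_size=80 counts mel frames and that the checkpoint produces 256 audio samples per mel frame at 24 kHz, as its name "bigvgan_v2_24khz_100band_256x" suggests (in the example, the sampling rate is also available as model.h.sampling_rate). The constant names are illustrative only. The first and last chunks may yield a different amount of audio, depending on how much context the model needs at the sequence boundaries.

```python
# Assumptions read off the checkpoint name, not queried from the model:
SAMPLE_RATE = 24_000  # Hz ("24khz")
HOP_SIZE = 256        # audio samples per mel frame ("256x")
CHUNK_SIZE = 80       # mel frames per input chunk, as passed to forward_in_chunks_iter

samples_per_chunk = CHUNK_SIZE * HOP_SIZE
seconds_per_chunk = samples_per_chunk / SAMPLE_RATE
print(samples_per_chunk, round(seconds_per_chunk, 3))  # 20480 0.853
```

Smaller chunks lower the latency of each step at the cost of more forward calls per second of audio.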

Disclaimer

TorchStream is developed independently by me. It is not affiliated with, endorsed by, or sponsored by the PyTorch team or Meta.

Download files

Source Distribution

torchstream_lib-1.0.2.tar.gz (57.9 kB)

Uploaded Source

Built Distribution

torchstream_lib-1.0.2-py3-none-any.whl (63.2 kB)

Uploaded Python 3

File details

Details for the file torchstream_lib-1.0.2.tar.gz.

File metadata

  • Download URL: torchstream_lib-1.0.2.tar.gz
  • Size: 57.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.5

File hashes

Hashes for torchstream_lib-1.0.2.tar.gz
Algorithm Hash digest
SHA256 a0f78101f8b810943a2b2aa929e36c91eff25f99c5b1d9a8e5ca521d2fe814f4
MD5 915aad4463c188bb84c38bdb200353cc
BLAKE2b-256 28a477b3623233366194bd19bad6daaa5e955c877ee865786a0b1077287b3099

File details

Details for the file torchstream_lib-1.0.2-py3-none-any.whl.

File hashes

Hashes for torchstream_lib-1.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 d8814dc3272f457c177c673e541ffdaddb91cfd1a73af661b6fe5f7c377667a1
MD5 42656bde5298f6f5400fd659d0bd29b6
BLAKE2b-256 c13ec759e546570c05bd28d9e8c533743818785133b1a100791c1e323bd806f5
