HS-TasNet

Implementation of HS-TasNet, "Real-time Low-latency Music Source Separation using Hybrid Spectrogram-TasNet", proposed by the research team at L-Acoustics

Install

$ pip install HS-TasNet

Usage

import torch
from hs_tasnet import HSTasNet

model = HSTasNet()

audio = torch.randn(1, 2, 204800) # ~5 seconds of stereo

separated_audios, _ = model(audio)

assert separated_audios.shape == (1, 4, 2, 204800) # second dimension is the separated tracks
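
To make the shapes above concrete, here is a minimal sketch relating the sample count to clip duration and to the expected output shape. It assumes a 44.1 kHz sample rate (the rate MusDB is distributed at) and the four output tracks seen in the assertion above. The helper names are hypothetical and are not part of hs_tasnet.

```python
# Assumption: 44.1 kHz audio, as in MusDB; adjust if your data differs.
SAMPLE_RATE = 44_100
NUM_SOURCES = 4  # taken from the shape assertion in the usage example


def duration_seconds(num_samples, sample_rate=SAMPLE_RATE):
    """Length of a clip in seconds."""
    return num_samples / sample_rate


def separated_shape(input_shape, num_sources=NUM_SOURCES):
    """Map (batch, channels, samples) to (batch, sources, channels, samples)."""
    batch, channels, samples = input_shape
    return (batch, num_sources, channels, samples)


clip_seconds = duration_seconds(204_800)
out_shape = separated_shape((1, 2, 204_800))
print(f"{clip_seconds:.2f} s", out_shape)
```

The separated tracks sit on the second dimension, so a single track would be picked by indexing that dimension.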

With the Trainer

# model

from hs_tasnet import HSTasNet, Trainer

model = HSTasNet()

# trainer

trainer = Trainer(
    model,
    dataset = None,               # add your in-house Dataset
    concat_musdb_dataset = True,  # concat the musdb dataset automatically
    batch_size = 2,
    max_steps = 2,
    cpu = True,
)

trainer()

# after much training
# run inference

model.sounddevice_stream(
    duration_seconds = 2,
    return_reduced_sources = [0, 2]
)

# or from the exponentially smoothed model (in the trainer)

trainer.ema_model.sounddevice_stream(...)

# or you can load from a specific checkpoint

model.load('./checkpoints/path.to.desired.ckpt.pt')
model.sounddevice_stream(...)

# to initialize and load an HS-TasNet directly from any saved checkpoint, without re-specifying its hyperparameters, just run

model = HSTasNet.init_and_load_from('./checkpoints/path.to.desired.ckpt.pt')
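
The "exponentially smoothed model" mentioned above refers to keeping a running average of the weights during training. As a rough illustration only, the sketch below applies the standard exponential moving average update to a toy dictionary of scalar weights. The function name and the decay value are hypothetical, and this is not the Trainer's actual implementation.

```python
def ema_update(ema_weights, weights, decay=0.99):
    """Blend the current weights into the running average, in place."""
    for name, value in weights.items():
        ema_weights[name] = decay * ema_weights[name] + (1.0 - decay) * value
    return ema_weights


# Toy "model" with a single scalar weight that jumps to 1.0 and stays there.
ema = {"w": 0.0}
for step_weights in ({"w": 1.0}, {"w": 1.0}, {"w": 1.0}):
    ema_update(ema, step_weights, decay=0.5)

print(ema)
```

The averaged copy lags behind the raw weights and moves toward them gradually, which is why it is often preferred for inference.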

Training script

First, install the dependencies by running

$ sh scripts/install.sh

Then make sure uv is installed

$ pip install uv

Finally, run the following to train a newly initialized model on a small subset of MusDB and confirm that the loss goes down

$ uv run train.py

For distributed training, run accelerate config first (courtesy of accelerate from 🤗). Training on a single machine works too.

Experiment tracking

To enable online experiment monitoring and tracking, install wandb and log in

$ pip install wandb && wandb login

Then

$ uv run train.py --use-wandb

To wipe the previous checkpoints and evaluation results, append --clear-folders
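
For readers wiring similar flags into their own script, here is a minimal sketch of how boolean flags like these might be parsed and acted on. It assumes argparse and a hypothetical checkpoints folder name. The real train.py may parse its arguments differently, so treat this as an illustration rather than a description of that script.

```python
import argparse
import shutil
import tempfile
from pathlib import Path


def parse_flags(argv):
    """Parse the two boolean flags; both default to False."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--use-wandb", action="store_true")
    parser.add_argument("--clear-folders", action="store_true")
    return parser.parse_args(argv)


def clear_folder(path):
    """Delete a folder and everything in it, then recreate it empty."""
    path = Path(path)
    shutil.rmtree(path, ignore_errors=True)
    path.mkdir(parents=True, exist_ok=True)


flags = parse_flags(["--use-wandb", "--clear-folders"])

# Demonstrate on a throwaway directory so nothing real is deleted.
with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp) / "checkpoints"  # hypothetical folder name
    ckpt_dir.mkdir()
    (ckpt_dir / "old.ckpt.pt").write_bytes(b"stale")
    if flags.clear_folders:
        clear_folder(ckpt_dir)
    remaining = list(ckpt_dir.iterdir())
```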

Alternative RNNs

The architecture defaults to using PyTorch's LSTM (or GRU), but you can easily swap in any other module by passing an rnn_klass to the HSTasNet constructor, as long as it adheres to a specific interface (read alternative_rnns.py)

For example, to use the minGRU architecture:

import torch
from hs_tasnet import HSTasNet
from hs_tasnet.alternative_rnns import minGRUWrapper

model = HSTasNet(rnn_klass = minGRUWrapper)

audio = torch.randn(1, 2, 204800)
separated_audios, _ = model(audio)
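
For intuition about what minGRU computes, the sketch below implements the recurrence from the cited "Were RNNs All We Needed?" paper with scalar weights in plain Python. The update gate and candidate state depend only on the current input, not on the previous hidden state. The function and parameter names are hypothetical, and this is not the code behind minGRUWrapper.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def min_gru_scan(xs, w_z, b_z, w_h, b_h, h0=0.0):
    """Scalar minGRU: gate and candidate are functions of x_t only."""
    h, hs = h0, []
    for x in xs:
        z = sigmoid(w_z * x + b_z)       # update gate
        h_tilde = w_h * x + b_h          # candidate hidden state
        h = (1.0 - z) * h + z * h_tilde  # interpolate old state and candidate
        hs.append(h)
    return hs


# With a zero gate weight the gate is always 0.5, so each step moves the
# state halfway toward the candidate.
states = min_gru_scan([1.0, 1.0], w_z=0.0, b_z=0.0, w_h=1.0, b_h=0.0)
print(states)
```

Because the gate never looks at the previous hidden state, the recurrence is linear in h, which is what allows the paper's parallel-scan training.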

Test

$ uv pip install '.[test]' --system

Then

$ pytest tests

Sponsors

This open-source work is sponsored by Sweet Spot

Citations

@misc{venkatesh2024realtimelowlatencymusicsource,
    title    = {Real-time Low-latency Music Source Separation using Hybrid Spectrogram-TasNet},
    author   = {Satvik Venkatesh and Arthur Benilov and Philip Coleman and Frederic Roskam},
    year     = {2024},
    eprint   = {2402.17701},
    archivePrefix = {arXiv},
    primaryClass = {eess.AS},
    url      = {https://arxiv.org/abs/2402.17701},
}
@inproceedings{Feng2024WereRA,
    title   = {Were RNNs All We Needed?},
    author  = {Leo Feng and Frederick Tung and Mohamed Osama Ahmed and Yoshua Bengio and Hossein Hajimirsadeghi},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:273025630}
}
