Package for independent vector analysis in torch

Project description

A package for blind source separation (BSS) and beamforming in PyTorch.

  • supports several BSS and beamforming methods (e.g., AuxIVA-IP, AuxIVA-IP2, T-ISS, FIVE, MWF, MVDR, GEV)

  • supports memory-efficient gradient computation for training neural source models

  • supports batched computations

  • can run on GPU via PyTorch

Author

Quick Start

This guide assumes that Anaconda is installed:

# get code and install environment
git clone <torchiva_repo>
cd torchiva
conda env create -f environment.yml
conda activate torchiva
pip install -e .

cd ./examples
export PYTHONPATH="/path/to/torchiva:$PYTHONPATH"

# BSS example
# ALGORITHM is one of: tiss, auxiva_ip, auxiva_ip2, five
python ./example.py PATH_TO_DATASET ALGORITHM
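
For reference, `auxiva_ip` denotes AuxIVA with iterative projection (IP) updates. As the update is commonly written in the AuxIVA literature, one IP sweep at each frequency f visits the sources k = 1, ..., K in turn. Here x_{ft} is the M-channel mixture STFT vector at frame t (of T frames), W_f = [w_{1f}, ..., w_{Kf}]^H is the demixing matrix, and e_k is the k-th canonical basis vector. The spherical Laplace source model, φ(r) = 1/r, is an assumption chosen for this sketch; torchiva's implementation may use another source model (such as a neural one) and different numerical safeguards.

```latex
\begin{aligned}
y_{k f t} &= \mathbf{w}_{k f}^{\mathsf{H}} \mathbf{x}_{f t},
\qquad r_{k t} = \sqrt{\textstyle\sum_{f} |y_{k f t}|^{2}}, \\
\mathbf{V}_{k f} &= \frac{1}{T} \sum_{t=1}^{T} \varphi(r_{k t})\, \mathbf{x}_{f t} \mathbf{x}_{f t}^{\mathsf{H}},
\qquad \varphi(r) = \frac{1}{r}, \\
\mathbf{w}_{k f} &\leftarrow \left(\mathbf{W}_{f} \mathbf{V}_{k f}\right)^{-1} \mathbf{e}_{k}, \\
\mathbf{w}_{k f} &\leftarrow \frac{\mathbf{w}_{k f}}{\sqrt{\mathbf{w}_{k f}^{\mathsf{H}} \mathbf{V}_{k f} \mathbf{w}_{k f}}}.
\end{aligned}
```

W_f is refreshed with the new row w_{kf}^H before moving on to source k + 1, and implementations typically keep r_{kt} away from zero to avoid a division by zero.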

Separation using Pre-trained Model

We provide a pre-trained model at <hugging face link>. The model was trained on the WSJ1-mix dataset using the same configuration as ./configs/tiss.json. You can try separation with the pre-trained model as follows:

# download the model parameters from Hugging Face

# Separation
python ./example_dnn.py ./configs/tiss.json PATH_TO_DATASET PATH_TO_MODEL_PARAMS

Training

We provide some simple training scripts, which support training with T-ISS, MWF, MVDR, and GEV:

cd examples

# install some modules necessary for training
pip install -r requirements.txt

# training
python train.py PATH_TO_CONFIG PATH_TO_DATASET

Note that our example scripts assume the WSJ1-mix dataset. To use another dataset, modify the part of the script that loads the audio.
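
MWF, MVDR, and GEV are usually written in the following textbook forms. This is an assumption about, not a description of, torchiva's code: Φ_{s,f} and Φ_{n,f} are the target and noise spatial covariance matrices at frequency f (how they are estimated, for instance from the output of a neural source model, is not covered on this page), u is a one-hot vector selecting the reference microphone, and x_{ft} is the mixture STFT vector at frame t.

```latex
\begin{aligned}
\text{MWF:}\quad & \mathbf{w}_f = \left(\boldsymbol{\Phi}_{s,f} + \boldsymbol{\Phi}_{n,f}\right)^{-1} \boldsymbol{\Phi}_{s,f}\, \mathbf{u}, \\
\text{MVDR:}\quad & \mathbf{w}_f = \frac{\boldsymbol{\Phi}_{n,f}^{-1} \boldsymbol{\Phi}_{s,f}}{\operatorname{tr}\!\left(\boldsymbol{\Phi}_{n,f}^{-1} \boldsymbol{\Phi}_{s,f}\right)}\, \mathbf{u}, \\
\text{GEV:}\quad & \mathbf{w}_f = \arg\max_{\mathbf{w}} \frac{\mathbf{w}^{\mathsf{H}} \boldsymbol{\Phi}_{s,f} \mathbf{w}}{\mathbf{w}^{\mathsf{H}} \boldsymbol{\Phi}_{n,f} \mathbf{w}}
\;\Longleftrightarrow\; \boldsymbol{\Phi}_{s,f} \mathbf{w}_f = \lambda_{\max} \boldsymbol{\Phi}_{n,f} \mathbf{w}_f, \\
\text{output:}\quad & \hat{s}_{f t} = \mathbf{w}_f^{\mathsf{H}} \mathbf{x}_{f t}.
\end{aligned}
```

The MWF form assumes uncorrelated target and noise, so that Φ_{s,f} + Φ_{n,f} is the mixture covariance. The MVDR form is the reference-channel formulation that needs no explicit steering vector. The GEV weights are defined only up to a complex scale per frequency, so a normalization or post-filter is usually applied.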

Test your trained model using the checkpoint from epoch 128:

python ./test.py --dataset ../wsj1_6ch --n_fft 2048 --hop 512 --n_iter 40 --iss-hparams checkpoints/tiss_delay1tap5_2ch/lightning_logs/version_0/hparams.yaml --epoch 128 --test
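
The --n_fft 2048 --hop 512 flags presumably set the STFT frame length and hop size in samples. As a worked example, assuming a 16 kHz sampling rate (not stated on this page) and a one-sided STFT padded the way torch.stft pads with center=True, these values give 1025 frequency bins, 128 ms frames, a 32 ms hop, and 75% overlap. The helper `stft_geometry` below is hypothetical and not part of torchiva:

```python
def stft_geometry(n_fft: int, hop: int, n_samples: int, fs: int, center: bool = True) -> dict:
    """Hypothetical helper: basic STFT sizes for a real-valued signal.

    Assumes a one-sided spectrum (n_fft // 2 + 1 bins). With center=True the
    signal is padded by n_fft // 2 samples on each side before framing, which
    matches the default padding of torch.stft.
    """
    if n_fft <= 0 or hop <= 0 or fs <= 0:
        raise ValueError("n_fft, hop and fs must be positive")
    padded = n_samples + 2 * (n_fft // 2) if center else n_samples
    if padded < n_fft:
        raise ValueError("signal is shorter than one frame")
    return {
        "n_freq": n_fft // 2 + 1,                 # one-sided frequency bins
        "n_frames": 1 + (padded - n_fft) // hop,  # full frames that fit
        "window_ms": 1000.0 * n_fft / fs,         # frame length in ms
        "hop_ms": 1000.0 * hop / fs,              # frame shift in ms
        "overlap": 1.0 - hop / n_fft,             # fraction shared by adjacent frames
    }


if __name__ == "__main__":
    # 4 s of audio at an assumed 16 kHz sampling rate
    print(stft_geometry(n_fft=2048, hop=512, n_samples=4 * 16000, fs=16000))
```

For a 4 s signal this yields 126 frames with center padding and 122 frames without it.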

Export the trained model for later use:

python ./export_model.py ../trained_models/tiss checkpoints/tiss_delay1tap5_2ch/lightning_logs/version_0 128 146 148 138 122 116 112 108 104 97

Run the example script using the exported model:

python ./example_dnn.py ../wsj1_6ch ../trained_models/tiss -m 2 -r 100

License

2022 (c) Robin Scheibler, Kohei Saijo, LINE Corporation.

All of this code is released under the MIT License.


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchiva-0.1.0.tar.gz (32.6 kB)

File details

Details for the file torchiva-0.1.0.tar.gz.

File metadata

  • Download URL: torchiva-0.1.0.tar.gz
  • Upload date:
  • Size: 32.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.10.8

File hashes

Hashes for torchiva-0.1.0.tar.gz
  • SHA256: 0ef5c4ff249bfd146ff833da419c6d30287f96a7cab540924e9d5ad2fd526e23
  • MD5: a403193e704f783ff8dacf6b00f5fef9
  • BLAKE2b-256: f0aa6daf402dec40806563032eba4fca304359fc013672bacee53d00a0c211f9
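
To check a downloaded archive against the published SHA256 digest for torchiva-0.1.0.tar.gz, here is a minimal sketch using only Python's standard library. The function names are illustrative and not part of torchiva, and the archive path in the main block is an assumed download location:

```python
import hashlib
import hmac
import os

# SHA256 digest published for torchiva-0.1.0.tar.gz
EXPECTED_SHA256 = "0ef5c4ff249bfd146ff833da419c6d30287f96a7cab540924e9d5ad2fd526e23"


def sha256_of_file(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # iter() with a sentinel stops when read() returns b"" at end of file
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path: str, expected_hex: str) -> bool:
    """Compare a file's digest with an expected hex string (case-insensitive)."""
    return hmac.compare_digest(sha256_of_file(path), expected_hex.strip().lower())


if __name__ == "__main__":
    archive = "torchiva-0.1.0.tar.gz"  # assumed download location
    if os.path.exists(archive):
        print("OK" if verify_sha256(archive, EXPECTED_SHA256) else "MISMATCH")
    else:
        print(f"{archive} not found")
```

Reading in fixed-size chunks keeps memory use constant regardless of archive size, and `hmac.compare_digest` avoids short-circuiting string comparison.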
