WavEncoder - PyTorch backed audio encoder
WavEncoder
WavEncoder is a Python library for encoding audio signals, applying transforms for audio augmentation, and training audio classification models with a PyTorch backend.
Package Contents
Layers | Models | Transforms | Trainer and utils |
---|---|---|---|
Wav Models to be added
- wav2vec [1]
- wav2vec2 [2]
- SincNet [3]
- PASE [4]
- MockingJay [5]
- RawNet [6]
- GaborNet [7]
- LEAF [8]
- CNN-1D
- CNN-LSTM
- CNN-LSTM-Attn
Check the Demo Colab Notebook.
Installation
Use the package manager pip to install wavencoder.
pip install wavencoder
Usage
Import pretrained encoder, baseline models and classifiers
import torch
import wavencoder
x = torch.randn(1, 16000) # [1, 16000]
encoder = wavencoder.models.Wav2Vec(pretrained=True)
z = encoder(x) # [1, 512, 98]
classifier = wavencoder.models.LSTM_Attn_Classifier(512, 64, 2,
                                                    return_attn_weights=True,
                                                    attn_type='soft')
y_hat, attn_weights = classifier(z) # [1, 2], [1, 98]
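The `[1, 512, 98]` shape above can be sanity-checked by hand. The sketch below assumes the encoder uses the five strided convolutions described in the wav2vec paper (kernel sizes 10, 8, 4, 4, 4 with strides 5, 4, 2, 2, 2) and no padding; these values are an assumption taken from the paper, not read from the wavencoder source, and `num_frames` is a hypothetical helper.

```python
# (kernel_size, stride) for each strided conv layer.
# Assumed from the wav2vec paper, not taken from wavencoder itself.
WAV2VEC_ENCODER_LAYERS = [(10, 5), (8, 4), (4, 2), (4, 2), (4, 2)]


def num_frames(n_samples, layers=WAV2VEC_ENCODER_LAYERS):
    """Output length of a stack of unpadded 1-D convolutions."""
    length = n_samples
    for kernel, stride in layers:
        # Standard Conv1d length formula with padding=0, dilation=1.
        length = (length - kernel) // stride + 1
    return length


frames_for_one_second = num_frames(16000)
print(frames_for_one_second)
```

Under these assumptions, one second of 16 kHz audio yields 98 frames, which agrees with the shape comment in the example.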
Use wavencoder with PyTorch Sequential or class modules
import torch
import torch.nn as nn
import wavencoder
model = nn.Sequential(
    wavencoder.models.Wav2Vec(),
    wavencoder.models.LSTM_Attn_Classifier(512, 64, 2,
                                           return_attn_weights=True,
                                           attn_type='soft')
)
x = torch.randn(1, 16000) # [1, 16000]
y_hat, attn_weights = model(x) # [1, 2], [1, 98]
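The `attn_weights` output has one value per encoder frame, which is why its shape is `[1, 98]`. As a rough illustration of what soft attention pooling does in general, here is a minimal pure-Python sketch. It is not wavencoder's implementation: `soft_attention_pool` is a hypothetical helper, and the scores are supplied by hand, whereas a trained model would compute them from the LSTM outputs.

```python
import math


def soft_attention_pool(frames, scores):
    """Pool T feature vectors into one using softmax-normalized scores.

    frames: list of T feature vectors (lists of floats)
    scores: list of T unnormalized scalar scores, one per frame
    """
    # Numerically stable softmax over the time axis.
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]

    # Weighted sum of the frames, one feature dimension at a time.
    dim = len(frames[0])
    pooled = [sum(w * f[d] for w, f in zip(weights, frames)) for d in range(dim)]
    return pooled, weights


frames = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
scores = [0.0, 0.0, 0.0]  # equal scores give uniform weights
pooled, weights = soft_attention_pool(frames, scores)
```

The weights always sum to 1, so they can be read as how much each frame contributed to the prediction.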
import torch
import torch.nn as nn
import wavencoder
class AudioClassifier(nn.Module):
    def __init__(self):
        super(AudioClassifier, self).__init__()
        self.encoder = wavencoder.models.Wav2Vec(pretrained=True)
        self.classifier = nn.Linear(512, 2)

    def forward(self, x):
        z = self.encoder(x)        # [batch, 512, frames]
        z = torch.mean(z, dim=2)   # average over time -> [batch, 512]
        out = self.classifier(z)   # [batch, 2]
        return out

model = AudioClassifier()
x = torch.randn(1, 16000) # [1, 16000]
y_hat = model(x) # [1, 2]
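To try the mean-pooling pattern without downloading pretrained weights, the encoder can be swapped for a stand-in. The sketch below is only an illustration: `ToyEncoder` and `MeanPoolClassifier` are hypothetical names, and the single `Conv1d` layer is a placeholder that merely produces the same `[batch, 512, frames]` layout, not a substitute for Wav2Vec.

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in encoder: [batch, samples] -> [batch, feat_dim, frames]."""

    def __init__(self, feat_dim=512):
        super().__init__()
        self.conv = nn.Conv1d(1, feat_dim, kernel_size=400, stride=160)

    def forward(self, x):
        # Add a channel axis so Conv1d sees [batch, 1, samples].
        return self.conv(x.unsqueeze(1))


class MeanPoolClassifier(nn.Module):
    """Average encoder features over time, then apply a linear layer."""

    def __init__(self, encoder, feat_dim=512, n_classes=2):
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(feat_dim, n_classes)

    def forward(self, x):
        z = self.encoder(x)        # [batch, feat_dim, frames]
        z = z.mean(dim=2)          # [batch, feat_dim]
        return self.classifier(z)  # [batch, n_classes]


model = MeanPoolClassifier(ToyEncoder())
y_hat = model(torch.randn(4, 16000))
```

Because the pooling removes the time axis, the same classifier accepts clips of different lengths.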
Train the encoder-classifier models
import torch.nn as nn
from wavencoder.models import Wav2Vec, LSTM_Attn_Classifier
from wavencoder.trainer import train, test_evaluate_classifier, test_predict_classifier

# Train the encoder from scratch together with the classifier
model = nn.Sequential(
    Wav2Vec(pretrained=False),
    LSTM_Attn_Classifier(512, 64, 2)
)

# Replace the placeholders with your own PyTorch DataLoaders
trainloader = ...
valloader = ...
testloader = ...

trained_model, train_dict = train(model, trainloader, valloader, n_epochs=20)
test_prediction_dict = test_predict_classifier(trained_model, testloader)
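The loaders are left as placeholders above. As a rough sketch of what one might look like, the block below builds a standard PyTorch `DataLoader` over random waveforms. It assumes the trainer consumes `(waveform, label)` batches, which this README does not confirm; `RandomWavDataset` is a hypothetical dataset used only for illustration.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class RandomWavDataset(Dataset):
    """Hypothetical dataset: random 1-second clips with binary labels."""

    def __init__(self, n_items=8, n_samples=16000, transform=None):
        self.waves = torch.randn(n_items, n_samples)
        self.labels = torch.randint(0, 2, (n_items,))
        self.transform = transform  # optional callable applied per clip

    def __len__(self):
        return len(self.waves)

    def __getitem__(self, idx):
        wav = self.waves[idx]
        if self.transform is not None:
            wav = self.transform(wav)
        return wav, self.labels[idx]


trainloader = DataLoader(RandomWavDataset(), batch_size=4, shuffle=True)
x, y = next(iter(trainloader))  # x: [4, 16000], y: [4]
```

The optional `transform` argument is one place where augmentation callables can be applied to each clip before batching.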
Add Transforms to your DataLoader for Augmentation/Processing the wav signal
import torchaudio
from wavencoder.transforms import Compose, AdditiveNoise, SpeedChange, Clipping, PadCrop, Reverberation

audio, _ = torchaudio.load('test.wav')

# Each transform with p=0.5 is applied with probability 0.5
transforms = Compose([
    AdditiveNoise('path-to-noise-folder', p=0.5, snr_levels=[5, 10, 15]),
    SpeedChange(factor_range=(-0.5, 0.0), p=0.5),
    Clipping(p=0.5),
    PadCrop(48000, crop_position='random', pad_position='random')
])

transformed_audio = transforms(audio)
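The `snr_levels` values are signal-to-noise ratios in decibels. To make that concrete, here is a minimal pure-Python sketch of mixing noise into a signal at a target SNR. It illustrates the general technique only and is not wavencoder's `AdditiveNoise` implementation; `mix_at_snr` and `measured_snr_db` are hypothetical helpers.

```python
import math
import random


def power(samples):
    """Mean squared amplitude of a list of samples."""
    return sum(s * s for s in samples) / len(samples)


def mix_at_snr(signal, noise, snr_db):
    """Scale `noise` so that signal power / noise power matches `snr_db`, then add it."""
    target_noise_power = power(signal) / (10 ** (snr_db / 10))
    scale = math.sqrt(target_noise_power / power(noise))
    return [s + scale * n for s, n in zip(signal, noise)]


def measured_snr_db(signal, mixed):
    """Recover the SNR of a mixture given the clean signal."""
    residual = [m - s for m, s in zip(mixed, signal)]
    return 10 * math.log10(power(signal) / power(residual))


random.seed(0)
sample_rate = 16000
signal = [math.sin(2 * math.pi * 440 * t / sample_rate) for t in range(sample_rate)]
noise = [random.gauss(0.0, 1.0) for _ in range(sample_rate)]

mixed = mix_at_snr(signal, noise, snr_db=10)
print(round(measured_snr_db(signal, mixed), 3))
```

Lower SNR values mean louder noise relative to the signal, so a list such as `[5, 10, 15]` ranges from heavier to lighter corruption.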
Contributing
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
Please make sure to update tests as appropriate.
License
Reference