
Linear Consistency Autoencoders

Reason this release was yanked:

bug: default scale factor not applied


This repository contains the official code and pretrained models for the paper "Learning Linearity in Audio Consistency Autoencoders via Implicit Regularization".

Paper | Demo Page (with audio examples)

About The Project

This work introduces Linear Consistency Autoencoders (Lin-CAE), a training methodology to induce linearity in audio autoencoders. By using data augmentation, we enforce homogeneity (equivariance to scalar gain) and additivity (preservation of addition) in the latent space without altering the model's architecture or loss function.

This creates a structured latent space where simple algebraic operations correspond directly to intuitive audio manipulations like mixing and volume scaling.
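In symbols, write $E$ for the encoder, $D$ for the decoder and $a$ for a scalar gain (this notation is introduced here for illustration). The augmentations in the training loop under Algorithm then ask the decoder to reproduce scaled and mixed audio from scaled and summed latents:

$$
D\big(a\,E(x)\big) \approx a\,x \qquad \text{(homogeneity)}
$$

$$
D\big(a\,[E(x_1) + E(x_2)]\big) \approx a\,(x_1 + x_2) \qquad \text{(additivity)}
$$

To the extent these hold, scaling a latent changes the volume of the decoded audio, and adding two latents decodes to a mix of the two sources.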

This repository currently provides inference code for our pretrained models and the code to reproduce the demos on our project page. Training code will be made available soon.

Usage

You can install the linear-cae package with pip or any other package manager (we recommend Poetry for development):

pip install linear-cae
poetry add linear-cae  # if using Poetry

Loading a Pretrained Model

The Autoencoder class provides a from_pretrained method to load models from the Hugging Face Hub.

from linear_cae import Autoencoder
import torch

model_id = "lin-cae"  # or "m2l", "lin-cae-2"
model = Autoencoder.from_pretrained(model_id)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval().to(device)

Available model_ids are:

  • "m2l": Music2Latent retrained on our dataset as a baseline.
  • "lin-cae": Our proposed Linear Consistency Autoencoder.
  • "lin-cae-2": A second version of our Lin-CAE, trained without the gain annealing described in the paper.

Note: each model ships with a default scale_factor that accounts for differences in latent scale between models. The default value scales the latents to a standard deviation of approximately 1.0 on the MUSDB training set. You can override it by passing a different scale_factor argument to from_pretrained. Depending on your application or data domain, you may want to choose a scale factor that matches the standard deviation of the latents on your own dataset. Example:

# Multiplies the latents by 20.0 after encoding and divides by 20.0 before decoding
# All done internally by the model, no need to deal with it manually.
model = Autoencoder.from_pretrained("lin-cae-2", scale_factor=20.0)
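Because scale_factor multiplies the latents after encoding, one way to pick a value for your own data is to encode a few representative batches with scale_factor=1.0, which leaves the latents unscaled. You then divide the target standard deviation by the measured one. The following is a minimal sketch. estimate_scale_factor is a hypothetical helper and is not part of linear-cae. It only assumes that encode maps [batch_size, num_samples] audio to a latent tensor.

```python
import torch


def estimate_scale_factor(encode, batches, target_std=1.0):
    """Return a factor that rescales latents to roughly `target_std`.

    `encode` is any callable mapping audio [batch_size, num_samples] to latents.
    With linear-cae it would be the encode method of a model loaded with
    scale_factor=1.0, so that the measured latents are unscaled.
    """
    total, total_sq, count = 0.0, 0.0, 0
    with torch.no_grad():
        for audio in batches:
            z = encode(audio).double()  # accumulate in float64 for stability
            total += z.sum().item()
            total_sq += (z ** 2).sum().item()
            count += z.numel()
    mean = total / count
    std = (total_sq / count - mean ** 2) ** 0.5  # population std over all latent values
    return target_std / std
```

For example, pass the encode method of a model loaded with Autoencoder.from_pretrained("lin-cae", scale_factor=1.0) and a list of your audio batches, placed on the model's device. Then reload the model with scale_factor set to the returned value.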

The models are trained on 44.1 kHz mono audio. encode expects audio tensors of shape [batch_size, num_samples] sampled at 44.1 kHz, and decode produces audio in the same format.
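Audio in other formats must be converted first. The following is a minimal sketch. It assumes a [channels, num_samples] waveform, such as the one returned by torchaudio.load. Channels are averaged to mono, which is a common downmix but not necessarily the one used to prepare the training data. The signal is resampled with torchaudio.functional.resample only when its rate differs from 44.1 kHz. to_model_input is a hypothetical helper.

```python
import torch

TARGET_SR = 44100  # the models expect 44.1 kHz mono audio


def to_model_input(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
    """Convert a [channels, num_samples] waveform to [1, num_samples] mono at 44.1 kHz."""
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)  # treat 1-D input as a single channel
    mono = waveform.mean(dim=0, keepdim=True)  # downmix by averaging channels
    if sample_rate != TARGET_SR:
        # torchaudio is only imported when resampling is actually needed
        import torchaudio.functional as AF
        mono = AF.resample(mono, orig_freq=sample_rate, new_freq=TARGET_SR)
    return mono
```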

Encoding and Decoding Audio

Once the model is loaded, you can encode audio into a latent representation and decode it back to audio.

# `audio_tensor` is a torch.Tensor of shape [batch_size, num_samples] representing 44.1 kHz audio
audio_tensor = torch.randn(1, 44100 * 2).to(device)

z = model.encode(audio_tensor)
reconstructed_audio = model.decode(z, full_length=audio_tensor.shape[-1])

Check out the demo.ipynb notebook used to create the demos on our project page for more examples of latent space manipulations (requires the musdb dataset package).
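Because the latent space is trained to behave linearly, you can apply mixing and gain changes directly to latents and compare the result with the same operation applied to the waveforms. The following is a minimal sketch with two hypothetical helpers, latent_mix and snr_db, neither of which is part of linear-cae. encode and decode can be any callables that follow the calling convention of model.encode and model.decode(z, full_length=...). The demo notebook may implement these manipulations differently.

```python
import torch


def snr_db(estimate: torch.Tensor, reference: torch.Tensor) -> float:
    """Signal-to-noise ratio of `estimate` with respect to `reference`, in dB."""
    noise = estimate - reference
    return (10 * torch.log10(reference.pow(2).sum() / noise.pow(2).sum())).item()


def latent_mix(encode, decode, a, b, gain_a=1.0, gain_b=1.0):
    """Mix two clips by scaling and summing their latents, then decoding."""
    z = gain_a * encode(a) + gain_b * encode(b)
    return decode(z, full_length=a.shape[-1])
```

For instance, compare latent_mix(model.encode, model.decode, x1, x2, gain_a=0.5) with 0.5 * x1 + x2 using snr_db. Higher values mean the latent-space operation is closer to the waveform operation.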

Handling Long Audio Files and Batching

To handle long audio files without running out of memory, you can set max_chunk_size (in samples) when loading the model. Audio longer than this is automatically split into overlapping chunks.

The parameter max_batch_size controls the maximum number of chunks processed in a single forward pass. If the number of chunks exceeds this value, they will be further split into smaller batches. You can adjust this parameter to match your available GPU memory.
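To make these parameters concrete, the following sketch splits a [batch_size, num_samples] signal into overlapping frames with Tensor.unfold. It then runs an encoder on at most max_batch_size frames at a time. This illustrates the idea under assumptions and is not linear-cae's internal code. The hop of chunk_size * (1 - overlap) and the zero-padding of the tail are illustrative choices, and split_into_chunks and encode_in_batches are hypothetical names.

```python
import math

import torch


def split_into_chunks(audio: torch.Tensor, chunk_size: int, overlap: float) -> torch.Tensor:
    """Split [batch_size, num_samples] audio into [batch_size, num_chunks, chunk_size] frames.

    Consecutive chunks start chunk_size * (1 - overlap) samples apart; the tail is
    zero-padded so that the last chunk reaches the end of the signal.
    """
    hop = max(1, int(chunk_size * (1 - overlap)))
    n = audio.shape[-1]
    num_chunks = 1 if n <= chunk_size else 1 + math.ceil((n - chunk_size) / hop)
    padded_len = (num_chunks - 1) * hop + chunk_size
    padded = torch.nn.functional.pad(audio, (0, padded_len - n))
    return padded.unfold(-1, chunk_size, hop)


def encode_in_batches(encode, chunks: torch.Tensor, max_batch_size: int) -> torch.Tensor:
    """Run `encode` on at most `max_batch_size` chunks at a time and concatenate the results."""
    return torch.cat([encode(batch) for batch in torch.split(chunks, max_batch_size)], dim=0)
```

encode_in_batches works along the first dimension. Before calling it, flatten the chunked tensor with chunks.reshape(-1, chunk_size).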

# 10-second chunks at 44.1 kHz with 25% overlap; at most 8 chunks per forward pass
model = Autoencoder.from_pretrained("lin-cae",
                                    max_chunk_size=44100 * 10,
                                    overlap_percentage=0.25,
                                    max_batch_size=8)
model.eval().to(device)

# Encode two long audio files (e.g., 30 seconds each)
# Internally, the chunks are processed in batches of at most max_batch_size
long_audio_tensor_1 = torch.randn(1, 44100 * 30).to(device)
long_audio_tensor_2 = torch.randn(1, 44100 * 30).to(device)
z_chunked_1 = model.encode(long_audio_tensor_1)
z_chunked_2 = model.encode(long_audio_tensor_2)

# z_chunked_1 is a 4D tensor: [batch_size, num_chunks, channels, latent_dim]
print(z_chunked_1.shape)

# Mix in the latent space
z_chunked = z_chunked_1 + z_chunked_2

# Decoding the chunked latent requires the original audio length for proper reconstruction
reconstructed_mix = model.decode(z_chunked, full_length=long_audio_tensor_1.shape[-1])

The model stitches the decoded chunks back together using overlap-add with a crossfade. You can control the amount of overlap with the overlap_percentage argument when loading the model. We recommend a nonzero overlap to avoid artifacts at chunk boundaries.
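The following sketch shows one way overlap-add with a linear crossfade can work on decoded frames of length chunk_size placed every hop samples. Each frame is weighted by a window that fades in and out over the overlapping region. The sum is then divided by the accumulated window weights, so overlapping copies of the same signal are reconstructed exactly. The linear ramp is an assumption, because the crossfade shape used by linear-cae is not specified here. overlap_add is a hypothetical name.

```python
import torch


def overlap_add(chunks: torch.Tensor, hop: int, full_length: int) -> torch.Tensor:
    """Stitch [num_chunks, chunk_size] frames back into a [full_length] signal."""
    num_chunks, chunk_size = chunks.shape
    overlap = chunk_size - hop
    window = torch.ones(chunk_size, dtype=chunks.dtype, device=chunks.device)
    if overlap > 0:
        # Strictly positive ramp so that every covered sample has nonzero weight
        ramp = torch.linspace(0.0, 1.0, overlap + 2, dtype=chunks.dtype, device=chunks.device)[1:-1]
        window[:overlap] = ramp           # fade in
        window[-overlap:] = ramp.flip(0)  # fade out
    total_len = (num_chunks - 1) * hop + chunk_size
    out = torch.zeros(total_len, dtype=chunks.dtype, device=chunks.device)
    weight = torch.zeros_like(out)
    for i in range(num_chunks):
        start = i * hop
        out[start:start + chunk_size] += chunks[i] * window
        weight[start:start + chunk_size] += window
    # Normalizing by the summed weights keeps unit gain, even where ramps overlap
    return (out / weight.clamp_min(1e-8))[:full_length]
```

For overlaps of up to half a chunk, the fade-in and fade-out ramps already sum to one. The final normalization only matters for larger overlaps and at the signal edges.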

Algorithm

The full training code for CAEs will be released soon. In the meantime, the general algorithm below shows how to adapt any autoencoder training loop to induce linearity with our method. The core idea is to use data augmentation to implicitly teach the model the properties of linearity.

# Training step for inducing linearity in an autoencoder (sketch).
# encoder, decoder, optimizer, sample_random_gains and reconstruction_loss
# stand for your own components; latents are assumed to be [batch, channels, frames].
import torch


def linearity_training_step(x, encoder, decoder, optimizer,
                            sample_random_gains, reconstruction_loss):
    """Run one training step on a batch of audio x of shape [batch_size, num_samples]."""
    original_batch_size = x.shape[0]

    # 1. Create artificial mixtures
    x_roll = torch.roll(x, shifts=1, dims=0)  # Circularly shift the batch
    x_mixed = x + x_roll
    x_augmented = torch.cat([x, x_mixed], dim=0)  # Batch now contains original and mixed audio

    # 2. Encode all audio to get latents
    z = encoder(x_augmented)

    # For the mixed portion of the batch, create the "additive" latent
    # by summing the latents of the original unmixed sources.
    z_original = z[:original_batch_size]
    z_roll = torch.roll(z_original, shifts=1, dims=0)
    z_add = z_original + z_roll

    # Replace the encoded latents of the mixed audio with the sum of source latents
    z_final = torch.cat([z_original, z_add], dim=0)

    # 3. Apply random gains (one per example) to enforce homogeneity
    gains = sample_random_gains(x_augmented.shape[0]).to(x.device)
    z_scaled = z_final * gains[:, None, None]
    x_scaled = x_augmented * gains[:, None]

    # 4. Standard autoencoder training step
    # The decoder receives the scaled latent and must reconstruct the scaled audio
    x_reconstructed = decoder(z_scaled)
    loss = reconstruction_loss(x_reconstructed, x_scaled)

    # Update model weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# for x in dataloader:
#     linearity_training_step(x, encoder, decoder, optimizer,
#                             sample_random_gains, reconstruction_loss)
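The training loop leaves sample_random_gains unspecified, and this page does not state the gain distribution used in the paper. One plausible choice draws gains uniformly in decibels and converts them to linear amplitude. The following is a hypothetical sketch, and its [-20 dB, +6 dB] default range is arbitrary.

```python
from typing import Optional

import torch


def sample_random_gains(batch_size: int, min_db: float = -20.0, max_db: float = 6.0,
                        generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Draw one linear gain per example, uniformly distributed in decibels.

    The default [-20 dB, +6 dB] range is an illustrative assumption.
    """
    u = torch.rand(batch_size, generator=generator)  # uniform in [0, 1)
    gains_db = min_db + (max_db - min_db) * u
    return 10.0 ** (gains_db / 20.0)  # convert dB to linear amplitude
```

Sampling in decibels spreads the gains evenly on a perceptual loudness scale rather than concentrating them near the largest amplitudes.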

For Consistency Autoencoders (CAEs), as used in the paper, x_scaled is first corrupted with noise, and the decoder (a denoising model) is conditioned on z_scaled to reconstruct the clean x_scaled.
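The following is a simplified sketch of that sentence. The scaled audio is corrupted with Gaussian noise at a random level, and a denoiser conditioned on the scaled latent must recover the clean signal. It is a plain denoising objective, not the consistency-training objective used for CAEs in the paper. The decoder(x_noisy, sigma, z) signature, the log-uniform noise levels and the MSE loss are all assumptions. denoising_reconstruction_loss is a hypothetical name.

```python
import math

import torch


def denoising_reconstruction_loss(decoder, x_scaled, z_scaled, sigma_min=0.01, sigma_max=1.0):
    """Corrupt x_scaled [batch_size, num_samples] with noise and ask the decoder to recover it.

    `decoder(x_noisy, sigma, z)` is a hypothetical denoiser signature: noisy audio,
    per-example noise level, and latent condition.
    """
    batch_size = x_scaled.shape[0]
    # One noise level per example, log-uniform in [sigma_min, sigma_max] (illustrative choice)
    u = torch.rand(batch_size, device=x_scaled.device)
    log_sigma = math.log(sigma_min) + (math.log(sigma_max) - math.log(sigma_min)) * u
    sigma = log_sigma.exp()
    # Corrupt the scaled target audio with Gaussian noise
    x_noisy = x_scaled + sigma[:, None] * torch.randn_like(x_scaled)
    # The denoiser sees the noisy audio, its noise level and the scaled latent
    x_denoised = decoder(x_noisy, sigma, z_scaled)
    # Mean squared error against the clean scaled audio
    return torch.mean((x_denoised - x_scaled) ** 2)
```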

Citation

If you use our work in your research, please cite our paper:

@misc{torres2025learninglinearityaudioconsistency,
      title={Learning Linearity in Audio Consistency Autoencoders via Implicit Regularization},
      author={Bernardo Torres and Manuel Moussallam and Gabriel Meseguer-Brocal},
      year={2025},
      eprint={2510.23530},
      archivePrefix={arXiv},
      primaryClass={cs.SD},
      url={https://arxiv.org/abs/2510.23530},
}

Download files


Source Distribution

linear_cae-0.1.3.tar.gz (28.8 kB view details)

Uploaded Source

Built Distribution


linear_cae-0.1.3-py3-none-any.whl (29.3 kB view details)

Uploaded Python 3

File details

Details for the file linear_cae-0.1.3.tar.gz.

File metadata

  • Download URL: linear_cae-0.1.3.tar.gz
  • Upload date:
  • Size: 28.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.18

File hashes

Hashes for linear_cae-0.1.3.tar.gz
Algorithm Hash digest
SHA256 2575115bc034650b5d210c0cee5b0f390374e95e00a9e3e7fc0f5b19b993c88c
MD5 1b99c51b6bdcec87068cb65afe13f7de
BLAKE2b-256 f75867aaea0aa62dfb4f22937eb9366d8481a92340ea4500465e808bdbfd18cd


File details

Details for the file linear_cae-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: linear_cae-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 29.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.18

File hashes

Hashes for linear_cae-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 08c488c7701f65b82d4353b7ebe46c37c61dad6b3dac3cce04ee66cb6c72b646
MD5 d4adff5606389754f993be7375867a5b
BLAKE2b-256 6bba45ae8aec218bf76f2ce9363f2f6a6748ae787ddb422d67727bbd5050e59b

