Spline-Based Transformer

Implementation of the Spline-Based Transformer proposed in a paper from Disney Research.

This is essentially a transformer-based autoencoder with one clever twist: its set of latent tokens serves as the (high-dimensional) control points of a spline.
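To make "latent tokens as control points" concrete, here is a minimal sketch of turning a few control points into a sampled curve. It assumes a Bézier curve evaluated with de Casteljau's algorithm, which may not be the spline basis the paper or this library uses. `evaluate_spline` and `sample_spline` are hypothetical helpers, not part of the package, and plain Python lists stand in for tensors.

```python
def lerp(a, b, t):
    """Linearly interpolate between two equal-length vectors."""
    return [(1.0 - t) * x + t * y for x, y in zip(a, b)]


def evaluate_spline(control_points, t):
    """Evaluate a Bezier curve at t in [0, 1] (assumed basis, see lead-in)."""
    points = [list(p) for p in control_points]
    # de Casteljau: repeatedly interpolate neighbours until one point is left
    while len(points) > 1:
        points = [lerp(p, q, t) for p, q in zip(points, points[1:])]
    return points[0]


def sample_spline(control_points, num_times):
    """Sample the curve at num_times evenly spaced values of t."""
    if num_times == 1:
        return [evaluate_spline(control_points, 0.0)]
    return [
        evaluate_spline(control_points, i / (num_times - 1))
        for i in range(num_times)
    ]


# four 2-dimensional control points; real latents would be e.g. 512-dimensional
control_points = [[0.0, 0.0], [1.0, 2.0], [3.0, 2.0], [4.0, 0.0]]
curve = sample_spline(control_points, num_times=5)
```

The number of samples is independent of the number of control points, so a small set of latents can describe a sequence of any length.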

Install

$ pip install spline-based-transformer

Usage

import torch
from spline_based_transformer import SplineBasedTransformer

model = SplineBasedTransformer(
    dim = 512,
    enc_depth = 6,
    dec_depth = 6
)

data = torch.randn(1, 1024, 512)

loss = model(data, return_loss = True)
loss.backward()

# after much training

recon, control_points = model(data, return_latents = True)
assert data.shape == recon.shape

# mess with the control points, which should preserve continuity better

control_points += 1

controlled_recon = model.decode_from_latents(control_points, num_times = 1024)
assert controlled_recon.shape == data.shape
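As a rough illustration of what `control_points += 1` does to the latent curve itself, the sketch below shifts every control point by one and checks that every sampled latent shifts by the same amount. It assumes a Bézier (Bernstein) basis, whose weights sum to one at every time step, and that may not be the basis used by the library. `bernstein_weights` and `sample_latents` are hypothetical helpers. This only describes the sampled latents: the decoder is a learned network, so the reconstruction itself need not move by a constant.

```python
from math import comb


def bernstein_weights(num_control_points, t):
    """Bernstein basis weights at time t in [0, 1]; they always sum to 1."""
    n = num_control_points - 1
    return [comb(n, i) * t**i * (1 - t) ** (n - i) for i in range(n + 1)]


def sample_latents(control_points, num_times):
    """Sample the latent curve at num_times >= 2 evenly spaced times."""
    dim = len(control_points[0])
    samples = []
    for step in range(num_times):
        t = step / (num_times - 1)
        weights = bernstein_weights(len(control_points), t)
        samples.append([
            sum(w * p[d] for w, p in zip(weights, control_points))
            for d in range(dim)
        ])
    return samples


control_points = [[0.0, 0.0], [1.0, 2.0], [3.0, 2.0], [4.0, 0.0]]
shifted_points = [[x + 1.0 for x in p] for p in control_points]

base = sample_latents(control_points, num_times=8)
moved = sample_latents(shifted_points, num_times=8)

# largest deviation from "every sample moved by exactly 1"
max_error = max(
    abs(m - (b + 1.0))
    for moved_row, base_row in zip(moved, base)
    for m, b in zip(moved_row, base_row)
)
```

Because each sample is a weighted average of the control points, an edit to the control points carries over smoothly to every time step instead of disturbing individual positions independently.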

An example of an image autoencoder:

import torch

from spline_based_transformer import (
    SplineBasedTransformer,
    ImageAutoencoderWrapper
)

model = ImageAutoencoderWrapper(
    image_size = 256,
    patch_size = 32,
    spline_transformer = SplineBasedTransformer(
        dim = 512,
        enc_depth = 6,
        dec_depth = 6
    )
)

images = torch.randn(2, 3, 256, 256)

loss = model(images, return_loss = True)
loss.backward()

# after much training

recon_images, control_points = model(images, return_latents = True)
assert images.shape == recon_images.shape

# changing the control points

control_points += 1

controlled_recon_images = model.decode_from_latents(control_points)

assert controlled_recon_images.shape == images.shape
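For a sense of the sequence the transformer sees in the image example, the arithmetic below works out the patch count for `image_size = 256` and `patch_size = 32`. It assumes the wrapper uses the usual non-overlapping, ViT-style patching over 3 channels, which this README does not state. `patch_grid` is a hypothetical helper, not part of the package.

```python
def patch_grid(image_size, patch_size, channels=3):
    """Return (number of patches, values per flattened patch).

    Assumes square images and non-overlapping square patches.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    values_per_patch = channels * patch_size ** 2
    return num_patches, values_per_patch


num_patches, values_per_patch = patch_grid(image_size=256, patch_size=32)
print(num_patches, values_per_patch)  # 64 3072
```

Under these assumptions, each 256 x 256 image becomes a sequence of 64 patches with 3072 raw values each.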

Citations

@misc{Chandran2024,
    author  = {Prashanth Chandran and Agon Serifi and Markus Gross and Moritz Bächer},
    url     = {https://la.disneyresearch.com/publication/spline-based-transformers/}
}
