LVSM - Pytorch

Implementation of LVSM, a SOTA Large View Synthesis Model with Minimal 3D Inductive Bias, from Adobe Research

This repository focuses only on the decoder-only architecture.

This paper lines up with another paper from ICLR 2025.
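As a rough sketch of what "decoder-only" means here, following the paper's description rather than this repository's code: each input view's image and its 6-channel Plücker ray map are concatenated channel-wise and cut into patches, the target view contributes ray patches only, and all tokens are then processed jointly by a transformer with full self-attention. The helpers `patchify` and `build_token_sequence` are hypothetical names for illustration; the learned linear projections to `dim`, the transformer, and the unpatchify step back to RGB are omitted.

```python
import torch

def patchify(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    # hypothetical helper: (batch, channels, height, width) -> (batch, num_patches, channels * patch_size ** 2)
    b, c, h, w = x.shape
    assert h % patch_size == 0 and w % patch_size == 0, "image size must be divisible by patch size"
    x = x.reshape(b, c, h // patch_size, patch_size, w // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)  # (b, rows, cols, c, p, p): one patch per (row, col)
    return x.reshape(b, (h // patch_size) * (w // patch_size), c * patch_size * patch_size)

def build_token_sequence(images, rays, target_rays, patch_size):
    # hypothetical helper
    # images: (b, views, 3, h, w), rays: (b, views, 6, h, w), target_rays: (b, 6, h, w)
    b = images.shape[0]
    # input views: image and Plücker ray channels concatenated per pixel, views folded into batch
    inputs = torch.cat((images, rays), dim = 2).flatten(0, 1)      # (b * views, 9, h, w)
    input_tokens = patchify(inputs, patch_size)                     # (b * views, n, 9 * p * p)
    input_tokens = input_tokens.reshape(b, -1, input_tokens.shape[-1])  # (b, views * n, 9 * p * p)
    # target view: only its rays are known, its pixels are what the model predicts
    target_tokens = patchify(target_rays, patch_size)               # (b, n, 6 * p * p)
    return input_tokens, target_tokens
```

With the shapes from the usage example (4 input views of 256 x 256, `patch_size = 32`), each view yields (256 / 32)² = 64 patches, so the transformer would attend over 4 × 64 + 64 = 320 tokens per sample.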

Install

$ pip install lvsm-pytorch

Usage

import torch
from lvsm_pytorch import LVSM

rays = torch.randn(2, 4, 6, 256, 256)
images = torch.randn(2, 4, 3, 256, 256)

target_rays = torch.randn(2, 6, 256, 256)
target_images = torch.randn(2, 3, 256, 256)

model = LVSM(
    dim = 512,
    max_image_size = 256,
    patch_size = 32,
    depth = 2,
)

loss = model(
    input_images = images,
    input_rays = rays,
    target_rays = target_rays,
    target_images = target_images
)

loss.backward()

# after much training

pred_images = model(
    input_images = images,
    input_rays = rays,
    target_rays = target_rays,
) # (2, 3, 256, 256)

assert pred_images.shape == target_images.shape

Or from the raw camera intrinsics / extrinsics (please submit an issue or pull request if you see an error; I am new to view synthesis and out of my depth here)

import torch
from lvsm_pytorch import LVSM, CameraWrapper

input_intrinsic_rotation = torch.randn(2, 4, 3, 3)
input_extrinsic_rotation = torch.randn(2, 4, 3, 3)
input_translation = torch.randn(2, 4, 3)
input_uniform_points = torch.randn(2, 4, 3, 256, 256)

target_intrinsic_rotation = torch.randn(2, 3, 3)
target_extrinsic_rotation = torch.randn(2, 3, 3)
target_translation = torch.randn(2, 3)
target_uniform_points = torch.randn(2, 3, 256, 256)

images = torch.randn(2, 4, 4, 256, 256)
target_images = torch.randn(2, 4, 256, 256)

lvsm = LVSM(
    dim = 512,
    max_image_size = 256,
    patch_size = 32,
    channels = 4,
    depth = 2,
)

model = CameraWrapper(lvsm)

loss = model(
    input_intrinsic_rotation = input_intrinsic_rotation,
    input_extrinsic_rotation = input_extrinsic_rotation,
    input_translation = input_translation,
    input_uniform_points = input_uniform_points,
    target_intrinsic_rotation = target_intrinsic_rotation,
    target_extrinsic_rotation = target_extrinsic_rotation,
    target_translation = target_translation,
    target_uniform_points = target_uniform_points,
    input_images = images,
    target_images = target_images,
)

loss.backward()

# after much training

pred_target_images = model(
    input_intrinsic_rotation = input_intrinsic_rotation,
    input_extrinsic_rotation = input_extrinsic_rotation,
    input_translation = input_translation,
    input_uniform_points = input_uniform_points,
    target_intrinsic_rotation = target_intrinsic_rotation,
    target_extrinsic_rotation = target_extrinsic_rotation,
    target_translation = target_translation,
    target_uniform_points = target_uniform_points,
    input_images = images,
)
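The 6-channel `rays` tensors are Plücker ray maps, as in the Cameras as Rays work cited below: a unit direction d and a moment m = o × d per pixel. A minimal sketch of building one for a single view, assuming `K` is a 3 x 3 pinhole intrinsics matrix, `R` is a camera-to-world rotation, `t` is the camera center in world coordinates, and rays pass through pixel centers. The function `plucker_rays` is hypothetical; it is not necessarily how `CameraWrapper` computes rays, and it does not use the `uniform_points` argument.

```python
import torch
import torch.nn.functional as F

def plucker_rays(K, R, t, height, width):
    # hypothetical helper (assumed conventions: pinhole K, camera-to-world R, camera center t)
    # K: (3, 3), R: (3, 3), t: (3,), all the same dtype
    # returns (6, height, width): unit direction d (3 channels) then moment m = t x d (3 channels)
    ys, xs = torch.meshgrid(
        torch.arange(height, dtype = K.dtype) + 0.5,  # pixel-center row coordinates
        torch.arange(width, dtype = K.dtype) + 0.5,   # pixel-center column coordinates
        indexing = "ij",
    )
    pixels = torch.stack((xs, ys, torch.ones_like(xs)), dim = 0).reshape(3, -1)  # homogeneous (u, v, 1)
    dirs = R @ (torch.linalg.inv(K) @ pixels)   # back-project through intrinsics, rotate to world
    dirs = F.normalize(dirs, dim = 0)           # unit-length directions
    origins = t.reshape(3, 1).expand_as(dirs)   # every ray starts at the camera center
    moments = torch.linalg.cross(origins, dirs, dim = 0)
    return torch.cat((dirs, moments), dim = 0).reshape(6, height, width)
```

For the batched `(batch, views, 6, height, width)` layout used above, one would call this per view and stack the results.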

For improvised self-supervised learning with a masked autoencoder that reconstructs images and Plücker rays, import MAE, wrap your LVSM instance with it, and then pass in your images and rays.

import torch

from lvsm_pytorch import (
    LVSM,
    MAE
)

rays = torch.randn(2, 4, 6, 256, 256)
images = torch.randn(2, 4, 4, 256, 256)

lvsm = LVSM(
    dim = 512,
    max_image_size = 256,
    patch_size = 32,
    channels = 4,
    depth = 2,
    dropout_input_ray_prob = 0.5
)

mae = MAE(
    lvsm = lvsm,
    frac_masked = 0.5,                  # 1 in 2 image/ray pairs are masked out, with a minimum of 1
    frac_images_to_ray_masked = 0.5,    # for a masked image/ray pair, the proportion of images masked vs rays (1. masks only images, 0. masks only rays). image and ray are never both masked
    image_to_ray_loss_weight = 1.       # weight the image reconstruction loss differently from the ray reconstruction loss
)

ssl_loss = mae(
    images,
    rays
)

ssl_loss.backward()

# do the above in a loop on a huge amount of data
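A hypothetical sketch of the masking scheme the `MAE` parameter comments describe, not the library's own code: per sample, `frac_masked` of the image/ray pairs (at least one) are selected, and each selected pair hides either its image (with probability `frac_images_to_ray_masked`) or its ray map, never both. Hiding the selected entries before the LVSM forward pass and combining the image and ray reconstruction losses via `image_to_ray_loss_weight` are omitted, and the exact rounding and sampling in the library may differ.

```python
import torch

def sample_mae_masks(batch, num_views, frac_masked = 0.5, frac_images_to_ray_masked = 0.5, generator = None):
    # hypothetical helper: returns boolean (batch, num_views) masks for images and for rays
    num_masked = max(1, int(frac_masked * num_views))  # at least one pair masked per sample
    # random ranks per sample; the num_masked lowest-ranked views are the masked pairs
    scores = torch.rand(batch, num_views, generator = generator)
    ranks = scores.argsort(dim = -1).argsort(dim = -1)
    pair_masked = ranks < num_masked
    # within each masked pair, mask the image with the given probability, otherwise the ray
    mask_image = torch.rand(batch, num_views, generator = generator) < frac_images_to_ray_masked
    image_masked = pair_masked & mask_image
    ray_masked = pair_masked & ~mask_image
    return image_masked, ray_masked
```

With `frac_masked = 0.5` and 4 views, exactly two pairs per sample are masked; setting `frac_images_to_ray_masked` to `1.` or `0.` reduces this to image-only or ray-only masking.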

The above, with camera intrinsics / extrinsics

import torch

from lvsm_pytorch.lvsm import (
    LVSM,
    MAE,
    MAECameraWrapper
)

intrinsic_rotation = torch.randn(2, 4, 3, 3)
extrinsic_rotation = torch.randn(2, 4, 3, 3)
translation = torch.randn(2, 4, 3)
uniform_points = torch.randn(2, 4, 3, 256, 256)

images = torch.randn(2, 4, 4, 256, 256)

lvsm = LVSM(
    dim = 512,
    max_image_size = 256,
    patch_size = 32,
    channels = 4,
    depth = 2,
)

mae = MAE(lvsm)

model = MAECameraWrapper(mae)

loss = model(
    intrinsic_rotation = intrinsic_rotation,
    extrinsic_rotation = extrinsic_rotation,
    translation = translation,
    uniform_points = uniform_points,
    images = images,
)

loss.backward()

Citations

@inproceedings{Jin2024LVSMAL,
    title   = {LVSM: A Large View Synthesis Model with Minimal 3D Inductive Bias},
    author  = {Haian Jin and Hanwen Jiang and Hao Tan and Kai Zhang and Sai Bi and Tianyuan Zhang and Fujun Luan and Noah Snavely and Zexiang Xu},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:273507016}
}
@article{Zhang2024CamerasAR,
    title     = {Cameras as Rays: Pose Estimation via Ray Diffusion},
    author    = {Jason Y. Zhang and Amy Lin and Moneish Kumar and Tzu-Hsuan Yang and Deva Ramanan and Shubham Tulsiani},
    journal   = {ArXiv},
    year      = {2024},
    volume    = {abs/2402.14817},
    url       = {https://api.semanticscholar.org/CorpusID:267782978}
}
@misc{he2021masked,
    title   = {Masked Autoencoders Are Scalable Vision Learners}, 
    author  = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
    year    = {2021},
    eprint  = {2111.06377},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

Download files

Source Distribution

lvsm_pytorch-0.1.6.tar.gz (1.5 MB)

Uploaded Source

Built Distribution

lvsm_pytorch-0.1.6-py3-none-any.whl (10.1 kB)

Uploaded Python 3

File details

Details for the file lvsm_pytorch-0.1.6.tar.gz.

File metadata

  • Download URL: lvsm_pytorch-0.1.6.tar.gz
  • Upload date:
  • Size: 1.5 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.21

File hashes

Hashes for lvsm_pytorch-0.1.6.tar.gz
Algorithm Hash digest
SHA256 5f688ce960ba8bde5c6c59fada790b1524df661ca9a11ac7f1606ec4c601b461
MD5 e4cd3e2b07a38ad5e76ad05b61e5f55a
BLAKE2b-256 b83ccac861fbafb94656d8610cae1dd394311ed778254ad6b43832b61a04a4d5

File details

Details for the file lvsm_pytorch-0.1.6-py3-none-any.whl.

File metadata

  • Download URL: lvsm_pytorch-0.1.6-py3-none-any.whl
  • Upload date:
  • Size: 10.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.21

File hashes

Hashes for lvsm_pytorch-0.1.6-py3-none-any.whl
Algorithm Hash digest
SHA256 a1ab6ce32c540401b665886b99d40fc59d42e0c2669a40616464b6ae5b71c2dd
MD5 75780082085216da11f0d64128c890af
BLAKE2b-256 344996c63e34d6d3cd3735c80bfa81afbda216582235682aa0b9897d246645b6
