
pytorch-image-translation-models


A PyTorch library for multi-modal image translation with diffusion bridges, GANs, and transformer backbones.

Installation

Install from PyPI

pip install pytorch-image-translation-models

Install from source (run from the root of a local clone of the repository)

pip install -e .

With optional dependencies:

# With training extras (accelerate, peft, datasets, tensorboard)
pip install -e ".[training]"

# With metrics extras (torchmetrics, lpips, torch-fidelity, scipy)
pip install -e ".[metrics]"

# Everything
pip install -e ".[all]"

Note: PyTorch is listed as a dependency, but you may want to install a specific CUDA build of PyTorch first. See the "Get Started" page on pytorch.org for details.
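To see which optional extras are already available before installing, a small standard-library script can help. This is a sketch and not part of the package: the module names are the usual import names of the listed distributions (for example, torch-fidelity is imported as torch_fidelity), and `report_extras` is a hypothetical helper.

```python
import importlib.util

# Import names for the extras listed above (assumed; pip names can differ,
# e.g. the "torch-fidelity" distribution is imported as "torch_fidelity").
EXTRAS = {
    "core": ["torch"],
    "training": ["accelerate", "peft", "datasets", "tensorboard"],
    "metrics": ["torchmetrics", "lpips", "torch_fidelity", "scipy"],
}


def report_extras(groups=EXTRAS):
    """Return {group: {module: installed?}} without importing the modules."""
    return {
        group: {name: importlib.util.find_spec(name) is not None for name in names}
        for group, names in groups.items()
    }


if __name__ == "__main__":
    for group, status in report_extras().items():
        missing = [name for name, ok in status.items() if not ok]
        print(f"{group}: {'ok' if not missing else 'missing ' + ', '.join(missing)}")
```

`find_spec` only locates a top-level module; it does not import it, so the check is cheap even for heavy packages.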

Features

Models

  • GAN generators: UNetGenerator (encoder-decoder with skip connections), ResNetGenerator (residual blocks)
  • GAN discriminators: PatchGANDiscriminator (Markovian patch-level classifier)
  • Diffusion bridge: I2SBUNet (ADM-style U-Net for Image-to-Image Schrödinger Bridge)
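PatchGANDiscriminator scores overlapping patches rather than the whole image. To make "patch-level" concrete, the sketch below computes the receptive field of one output logit and the size of the logit map, assuming the classic pix2pix layout (three stride-2 and two stride-1 4x4 convolutions, padding 1). The library's default layer configuration may differ; the layer list is an assumption.

```python
# Each layer is (kernel_size, stride, padding). This is the classic pix2pix
# "n_layers=3" PatchGAN layout, assumed here for illustration only.
PATCHGAN_LAYERS = [
    (4, 2, 1),
    (4, 2, 1),
    (4, 2, 1),
    (4, 1, 1),
    (4, 1, 1),
]


def receptive_field(layers):
    """Side length (in input pixels) of the patch one output logit sees."""
    rf = 1
    # Walk from the output back to the input, growing the field per layer.
    for kernel, stride, _ in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf


def output_size(input_size, layers):
    """Side length of the patch-logit map for a square input."""
    size = input_size
    for kernel, stride, padding in layers:
        size = (size + 2 * padding - kernel) // stride + 1
    return size


print(receptive_field(PATCHGAN_LAYERS))   # 70
print(output_size(256, PATCHGAN_LAYERS))  # 30
```

Under this assumed layout, each logit judges a 70x70 input patch, and a 256x256 input yields a 30x30 grid of real/fake scores that the adversarial loss typically averages over.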

Schedulers

  • I2SBScheduler — Symmetric beta schedule with forward/reverse bridge kernels for I2SB
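For intuition about the symmetric beta schedule, the sketch below follows our reading of the public I2SB reference code: betas ramp quadratically from 1e-4 up to beta_max / interval, and the first half of that ramp is mirrored, so per-step noise is largest mid-bridge and smallest at both ends. It then derives the coefficients of the forward bridge marginal x_t = mu_x0 * x0 + mu_x1 * x1 + std_sb * noise, where x0 is the clean target and x1 the source. The function names and the 1e-4 start value are assumptions; I2SBScheduler's internals may differ.

```python
import math


def symmetric_betas(interval=1000, beta_max=0.3, beta_start=1e-4):
    """Symmetric schedule: a quadratic ramp up, then its mirror image."""
    beta_end = beta_max / interval  # e.g. 0.3 / 1000 = 3e-4
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    # Interpolate linearly in sqrt(beta) space, then square.
    ramp = [(lo + (hi - lo) * i / (interval - 1)) ** 2 for i in range(interval)]
    first_half = ramp[: interval // 2]
    return first_half + first_half[::-1]


def bridge_coefficients(betas):
    """Per-timestep (mu_x0, mu_x1, std_sb) for x_t = mu_x0*x0 + mu_x1*x1 + std_sb*noise."""
    # Cumulative variance accumulated from the x0 end (forward) ...
    var_fwd, total = [], 0.0
    for beta in betas:
        total += beta
        var_fwd.append(total)
    # ... and from the x1 end (backward).
    var_bwd, total = [], 0.0
    for beta in reversed(betas):
        total += beta
        var_bwd.append(total)
    var_bwd.reverse()

    coefs = []
    for vf, vb in zip(var_fwd, var_bwd):
        denom = vf + vb
        coefs.append((vb / denom, vf / denom, math.sqrt(vf * vb / denom)))
    return coefs


betas = symmetric_betas(interval=1000, beta_max=0.3)
coefs = bridge_coefficients(betas)
# Near the start x_t is dominated by x0; near the end, by x1.
print(coefs[0], coefs[-1])
```

The weights mu_x0 and mu_x1 always sum to 1, so x_t slides from the target toward the source as t grows, with Gaussian noise in between.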

Pipelines

  • I2SBPipeline — End-to-end inference for I2SB models (supports "pt", "pil", "np" output)
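The pipeline's nfe argument sets how many network evaluations inference uses. One way to choose which of the `interval` training timesteps to visit, used in the I2SB reference code as far as we know, is to space nfe + 1 indices evenly across [0, interval - 1]. The sketch below assumes that scheme; I2SBPipeline may select steps differently, and `space_indices` is a hypothetical name.

```python
def space_indices(num_steps, count):
    """Return `count` roughly evenly spaced indices in [0, num_steps - 1]."""
    if not 1 <= count <= num_steps:
        raise ValueError("count must be between 1 and num_steps")
    stride = 0.0 if count == 1 else (num_steps - 1) / (count - 1)
    return [round(i * stride) for i in range(count)]


nfe = 20
# nfe network evaluations move between nfe + 1 grid points.
steps = space_indices(1000, nfe + 1)
print(steps[:3], steps[-1])  # [0, 50, 100] 999
```

A sampler would then walk these indices from last to first, evaluating the network once per interval.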

Data

  • PairedImageDataset / UnpairedImageDataset with configurable transform pipelines
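A paired dataset needs a rule for matching each source image to its target. A common convention, assumed here and not confirmed for PairedImageDataset, is to match files by name (ignoring the extension) across two folders. `pair_by_stem` and the suffix list are hypothetical.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to your data.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}


def pair_by_stem(source_dir, target_dir):
    """Return (source, target) path pairs whose file stems match, sorted by stem."""
    def index(folder):
        return {
            p.stem: p
            for p in Path(folder).iterdir()
            if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
        }

    sources, targets = index(source_dir), index(target_dir)
    shared = sorted(sources.keys() & targets.keys())
    # Files present in only one folder are skipped silently.
    return [(sources[stem], targets[stem]) for stem in shared]
```

An unpaired dataset, by contrast, can draw source and target files independently, since no one-to-one correspondence is required.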

Losses

  • GANLoss (vanilla / LSGAN / hinge), VGG-based PerceptualLoss
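The three adversarial objectives differ only in how logits are scored. The sketch below writes out their standard forms on plain lists of raw (pre-sigmoid) discriminator logits; the mode strings and function names are illustrative and are not GANLoss's actual interface. Some implementations also halve the discriminator loss.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def mean(values):
    return sum(values) / len(values)


def d_loss(real_logits, fake_logits, mode):
    """Discriminator loss on raw logits."""
    if mode == "vanilla":  # binary cross-entropy: real -> 1, fake -> 0
        return mean([softplus(-x) for x in real_logits]) + mean([softplus(x) for x in fake_logits])
    if mode == "lsgan":    # least squares: real -> 1, fake -> 0
        return mean([(x - 1) ** 2 for x in real_logits]) + mean([x ** 2 for x in fake_logits])
    if mode == "hinge":    # margin of 1 on each side
        return mean([max(0.0, 1 - x) for x in real_logits]) + mean([max(0.0, 1 + x) for x in fake_logits])
    raise ValueError(f"unknown mode: {mode}")


def g_loss(fake_logits, mode):
    """Generator loss: push the discriminator's fake logits toward 'real'."""
    if mode == "vanilla":  # non-saturating BCE with target 1
        return mean([softplus(-x) for x in fake_logits])
    if mode == "lsgan":
        return mean([(x - 1) ** 2 for x in fake_logits])
    if mode == "hinge":
        return -mean(fake_logits)
    raise ValueError(f"unknown mode: {mode}")
```

Using softplus on raw logits, rather than applying a sigmoid and then a logarithm, avoids overflow and log(0) for large-magnitude logits.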

Training

  • Pix2PixTrainer — Paired GAN training with checkpoint save/load
  • I2SBTrainer — I2SB bridge model training (in examples/i2sb/)
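For background, the original pix2pix paper (Isola et al., 2017) trains the generator with a conditional GAN term plus an L1 reconstruction term, using λ = 100 in its experiments (the paper's noise input z is omitted below). Whether Pix2PixTrainer uses the same terms and weight is not stated here, so treat this as context rather than a description of the trainer:

```latex
G^{*} = \arg\min_{G}\max_{D}\; \mathcal{L}_{\mathrm{cGAN}}(G, D) + \lambda\, \mathcal{L}_{L1}(G),
\qquad
\mathcal{L}_{\mathrm{cGAN}}(G, D) = \mathbb{E}_{x,y}\big[\log D(x, y)\big] + \mathbb{E}_{x}\big[\log\big(1 - D(x, G(x))\big)\big],
\qquad
\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y}\big[\lVert y - G(x) \rVert_{1}\big]
```

Because D sees the source x together with y or G(x), the discriminator input stacks source and target channels.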

Metrics

  • compute_psnr, compute_ssim, compute_lpips, compute_fid
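As a reminder of what compute_psnr measures, here is the standard definition, PSNR = 10 · log10(MAX² / MSE), as a dependency-free sketch on flat pixel lists. The real function's signature and expected data range (for example [0, 1] versus [0, 255]) are not documented here, so `max_value` is an assumption to set to match your data.

```python
import math


def psnr(prediction, target, max_value=1.0):
    """Peak signal-to-noise ratio in dB for two equal-length pixel sequences."""
    if len(prediction) != len(target):
        raise ValueError("inputs must have the same number of pixels")
    mse = sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)
    if mse == 0:
        return math.inf  # identical images
    return 10 * math.log10(max_value ** 2 / mse)


print(psnr([0.0, 0.5], [0.0, 0.6]))  # about 23.01 dB
```

Higher is better; passing the wrong `max_value` shifts every score by a constant, so comparisons are only meaningful under one convention.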

Quick Start

GAN-based translation (Pix2Pix)

import src
from src.training import Pix2PixTrainer, TrainingConfig

gen = src.UNetGenerator(in_channels=3, out_channels=3)
disc = src.PatchGANDiscriminator(in_channels=6)  # conditional: source and target stacked (3 + 3 channels)

config = TrainingConfig(epochs=100, device="cuda")
trainer = Pix2PixTrainer(gen, disc, config)
trainer.fit(dataloader)  # dataloader (not defined here) must yield {"source": tensor, "target": tensor}

translator = src.ImageTranslator(gen, device="cuda")
result = translator.predict(pil_image)  # pil_image: a PIL.Image.Image (not defined here)

Diffusion bridge translation (I2SB)

from src.models.unet import I2SBUNet, create_model
from src.schedulers import I2SBScheduler
from src.pipelines.i2sb import I2SBPipeline

# Create model and scheduler
model = create_model(
    image_size=256, in_channels=3, num_channels=128,
    num_res_blocks=2, attention_resolutions="32,16,8",
    condition_mode="concat",
)
scheduler = I2SBScheduler(interval=1000, beta_max=0.3)

# Inference pipeline
pipeline = I2SBPipeline(unet=model, scheduler=scheduler)
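# source_tensor (not defined here): a batch of source images, e.g. shape (B, 3, 256, 256)
# nfe: number of network function evaluations used for sampling
result = pipeline(source_tensor, nfe=20, output_type="pt")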
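Each of the nfe evaluations predicts x0 and then moves the current sample x_n to an earlier timestep. The sketch below shows one such update as it appears in the I2SB reference code, as we understand it: a Gaussian-product posterior between the predicted x0 and x_n, with noise skipped for the deterministic (OT-ODE) variant and on the final step. The function names and the std_fwd inputs (cumulative forward noise levels at the two timesteps) are assumptions; I2SBPipeline's internals are not documented here.

```python
import math
import random


def gaussian_product_coef(sigma1, sigma2):
    """Return (sigma2^2/d, sigma1^2/d, sigma1^2*sigma2^2/d) with d = sigma1^2 + sigma2^2."""
    denom = sigma1 ** 2 + sigma2 ** 2
    return sigma2 ** 2 / denom, sigma1 ** 2 / denom, sigma1 ** 2 * sigma2 ** 2 / denom


def reverse_step(x_n, x0_pred, std_fwd_prev, std_fwd_n, ot_ode=False, last_step=False, rng=random):
    """Move a sample from timestep n to an earlier timestep n_prev (std_fwd_prev < std_fwd_n)."""
    # Noise added between n_prev and n in the forward direction.
    std_delta = math.sqrt(std_fwd_n ** 2 - std_fwd_prev ** 2)
    mu_x0, mu_xn, var = gaussian_product_coef(std_fwd_prev, std_delta)
    # Deterministic for OT-ODE sampling and on the final step.
    noise_std = 0.0 if (ot_ode or last_step) else math.sqrt(var)
    return [mu_x0 * a + mu_xn * b + noise_std * rng.gauss(0.0, 1.0) for a, b in zip(x0_pred, x_n)]


# Illustrative noise levels for two timesteps (not taken from the library).
x_prev = reverse_step(x_n=[1.0, -0.5], x0_pred=[0.8, -0.2],
                      std_fwd_prev=0.2, std_fwd_n=0.3, rng=random.Random(0))
```

When n_prev approaches 0, mu_x0 approaches 1 and the step lands on the predicted x0; when n_prev approaches n, the sample barely moves.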

I2SB training with task configs

from examples.i2sb.config import sar2eo_config
from examples.i2sb.trainer import I2SBTrainer

cfg = sar2eo_config(resolution=256, train_batch_size=8)
trainer = I2SBTrainer(cfg)
model = trainer.build_model()
scheduler = trainer.build_scheduler()

# Single-step loss computation
# source_batch / target_batch (not defined here): paired source and target image tensors
loss = I2SBTrainer.compute_training_loss(model, scheduler, source_batch, target_batch)
loss.backward()
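What compute_training_loss regresses is not spelled out above. In the I2SB reference formulation, as we read it, the network is trained to predict the scaled displacement (x_t - x0) / std_fwd[t], from which x0 can be recovered. The list-based sketch below illustrates that target with hypothetical helpers and illustrative noise levels; the library's loss may be parameterised differently.

```python
import math
import random


def bridge_coefs(std_fwd_t, std_bwd_t):
    """Mixing weights and noise std of the bridge marginal at one timestep."""
    denom = std_fwd_t ** 2 + std_bwd_t ** 2
    mu_x0 = std_bwd_t ** 2 / denom
    mu_x1 = std_fwd_t ** 2 / denom
    std_sb = math.sqrt(std_fwd_t ** 2 * std_bwd_t ** 2 / denom)
    return mu_x0, mu_x1, std_sb


def q_sample(x0, x1, mu_x0, mu_x1, std_sb, rng):
    """Draw x_t between the clean target x0 and the source x1."""
    return [mu_x0 * a + mu_x1 * b + std_sb * rng.gauss(0.0, 1.0) for a, b in zip(x0, x1)]


def training_label(x_t, x0, std_fwd_t):
    """Regression target: displacement of x_t from x0, scaled by std_fwd[t]."""
    return [(xt - a) / std_fwd_t for xt, a in zip(x_t, x0)]


def predict_x0(x_t, net_out, std_fwd_t):
    """Invert the label parameterisation to estimate x0 from a network output."""
    return [xt - std_fwd_t * out for xt, out in zip(x_t, net_out)]


def mse(a, b):
    return sum((p - q) ** 2 for p, q in zip(a, b)) / len(a)


# Illustrative values for a single timestep t (not taken from the library).
rng = random.Random(0)
x0 = [0.2, -0.4, 0.9]   # clean target pixels
x1 = [0.1, 0.0, 0.5]    # degraded source pixels
std_fwd_t, std_bwd_t = 0.25, 0.3

x_t = q_sample(x0, x1, *bridge_coefs(std_fwd_t, std_bwd_t), rng)
label = training_label(x_t, x0, std_fwd_t)
net_out = [0.0, 0.0, 0.0]  # stand-in for a model prediction
loss = mse(net_out, label)
```

A perfect prediction (net_out equal to label) makes predict_x0 return x0 exactly, which is what makes this target usable at sampling time.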

Package Structure

src/
├── __init__.py              # Public API
├── models/
│   ├── generators.py        # UNetGenerator, ResNetGenerator
│   ├── discriminators.py    # PatchGANDiscriminator
│   └── unet/                # ADM-style U-Net for I2SB
│       ├── i2sb_unet.py     # I2SBUNet
│       └── unet_2d.py       # create_model factory
├── schedulers/
│   └── i2sb.py              # I2SBScheduler
├── pipelines/
│   └── i2sb.py              # I2SBPipeline
├── data/
│   ├── datasets.py          # PairedImageDataset, UnpairedImageDataset
│   └── transforms.py        # get_transforms, default_transforms
├── losses/
│   ├── adversarial.py       # GANLoss
│   └── perceptual.py        # PerceptualLoss
├── training/
│   └── trainer.py           # Pix2PixTrainer, TrainingConfig
├── inference/
│   └── predictor.py         # ImageTranslator
└── metrics/
    └── image_quality.py     # PSNR, SSIM, LPIPS, FID
examples/
└── i2sb/
    ├── config.py            # TaskConfig, sar2eo_config, etc.
    └── trainer.py           # I2SBTrainer

Credits

Reference papers

License

MIT

Release files (version 0.1.1)

Source distribution: pytorch_image_translation_models-0.1.1.tar.gz (49.3 kB)

Algorithm    Hash digest
SHA256       51bfe40380b1bbf0bd290d3fd17d2ebc57f3861505501c6d097be8148aad6807
MD5          70ee55aac5f7d50acd4086fcb8fc654b
BLAKE2b-256  d402a2a7c24d7d531df6b959b09144f3af6d3e76ece01157a6494655a1eae1ff

Built distribution: pytorch_image_translation_models-0.1.1-py3-none-any.whl

Algorithm    Hash digest
SHA256       890b6de3d308c40d2ae6f0539b4f60afd107532ba38a3ce467d6b6cc83ff582b
MD5          4185a96af1013c449a95156f13e2c368
BLAKE2b-256  515fc25e2e498fe2345861c477bfee4996dea84662f4c4d0cfd4982049f7efbe

Provenance: attestation bundles for both files were made by publish.yml on Bili-Sakura/pytorch-image-translation-models. Attestation values reflect the state when the release was signed and may no longer be current.
