pytorch-image-translation-models

License: MIT

A PyTorch library for multi-modal image translation with diffusion bridges, GANs, and transformer backbones.

Installation

Install from source

pip install -e .

With optional dependencies:

# With training extras (accelerate, peft, datasets, tensorboard)
pip install -e ".[training]"

# With metrics extras (torchmetrics, lpips, torch-fidelity, scipy)
pip install -e ".[metrics]"

# Everything
pip install -e ".[all]"

Note: PyTorch is listed as a dependency, but you may want to install a specific CUDA build first. See PyTorch — Get Started for details.

Features

Models

  • GAN generators — UNetGenerator (encoder-decoder with skip connections), ResNetGenerator (residual blocks)
  • GAN discriminators — PatchGANDiscriminator (Markovian patch-level classifier)
  • Diffusion bridge — I2SBUNet (ADM-style U-Net for Image-to-Image Schrödinger Bridge)
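
A PatchGAN discriminator classifies overlapping patches rather than the whole image, and the patch size is just the network's receptive field. As a standalone illustration (not part of this library's API), the receptive field of a stack of convolutions can be computed from kernel sizes and strides:

```python
def receptive_field(layers):
    """Receptive field of stacked convolutions, given (kernel, stride) pairs."""
    rf, jump = 1, 1
    for kernel, stride in layers:
        rf += (kernel - 1) * jump  # each layer widens the field by (k-1) * jump
        jump *= stride             # stride compounds the step between outputs
    return rf

# The classic 70x70 PatchGAN: three stride-2 convolutions followed by
# two stride-1 convolutions, all with 4x4 kernels.
print(receptive_field([(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]))  # → 70
```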

Schedulers

  • I2SBScheduler — Symmetric beta schedule with forward/reverse bridge kernels for I2SB
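
A symmetric schedule ramps the noise level up toward the midpoint of the bridge and mirrors it back down toward the other endpoint. A minimal sketch of the idea (the library's actual discretization and scaling may differ):

```python
import numpy as np

def symmetric_beta_schedule(interval=1000, beta_max=0.3):
    """Betas rise linearly to beta_max at the midpoint, then mirror back down."""
    half = np.linspace(beta_max / interval, beta_max, interval // 2)
    return np.concatenate([half, half[::-1]])

betas = symmetric_beta_schedule(interval=1000, beta_max=0.3)
# len(betas) == 1000, peak of 0.3 at the midpoint, symmetric about it
```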

Pipelines

  • I2SBPipeline — End-to-end inference for I2SB models (supports "pt", "pil", "np" output)

Data

  • PairedImageDataset / UnpairedImageDataset with configurable transform pipelines

Losses

  • GANLoss (vanilla / LSGAN / hinge), VGG-based PerceptualLoss
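
Of the three variants, hinge loss penalizes real logits below +1 and fake logits above -1, while the generator simply pushes fake logits up. A self-contained sketch of the objectives on plain floats (the library's GANLoss operates on tensors):

```python
def hinge_d_loss(real_logits, fake_logits):
    """Discriminator hinge loss: mean of max(0, 1 - D(x)) + max(0, 1 + D(G(z)))."""
    real = sum(max(0.0, 1.0 - r) for r in real_logits) / len(real_logits)
    fake = sum(max(0.0, 1.0 + f) for f in fake_logits) / len(fake_logits)
    return real + fake

def hinge_g_loss(fake_logits):
    """Generator hinge loss: -mean D(G(z))."""
    return -sum(fake_logits) / len(fake_logits)

# A confident, correct discriminator incurs no loss:
print(hinge_d_loss([2.0, 1.5], [-2.0, -1.0]))  # → 0.0
```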

Training

  • Pix2PixTrainer — Paired GAN training with checkpoint save/load
  • I2SBTrainer — I2SB bridge model training (in examples/i2sb/)

Metrics

  • compute_psnr, compute_ssim, compute_lpips, compute_fid
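
For reference, PSNR is just the log-scaled ratio of peak signal power to mean squared error. A quick pure-Python sanity check (the library's compute_psnr works on image tensors):

```python
import math

def psnr(mse, max_val=1.0):
    """Peak signal-to-noise ratio in dB for a given mean squared error."""
    return 10.0 * math.log10(max_val ** 2 / mse)

print(psnr(0.01))  # → 20.0  (MSE of 0.01 on images scaled to [0, 1])
```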

Quick Start

GAN-based translation (Pix2Pix)

import src

gen = src.UNetGenerator(in_channels=3, out_channels=3)
disc = src.PatchGANDiscriminator(in_channels=6)

from src.training import Pix2PixTrainer, TrainingConfig
config = TrainingConfig(epochs=100, device="cuda")
trainer = Pix2PixTrainer(gen, disc, config)
trainer.fit(dataloader)  # expects {"source": tensor, "target": tensor}

translator = src.ImageTranslator(gen, device="cuda")
result = translator.predict(pil_image)

Diffusion bridge translation (I2SB)

from src.models.unet import I2SBUNet, create_model
from src.schedulers import I2SBScheduler
from src.pipelines.i2sb import I2SBPipeline

# Create model and scheduler
model = create_model(
    image_size=256, in_channels=3, num_channels=128,
    num_res_blocks=2, attention_resolutions="32,16,8",
    condition_mode="concat",
)
scheduler = I2SBScheduler(interval=1000, beta_max=0.3)

# Inference pipeline
pipeline = I2SBPipeline(unet=model, scheduler=scheduler)
result = pipeline(source_tensor, nfe=20, output_type="pt")
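
The nfe argument (number of function evaluations) trades quality for speed by running only that many reverse steps instead of the scheduler's full interval. One common way to pick the subsampled timesteps, shown here as a sketch rather than the pipeline's exact rule, is even spacing:

```python
import numpy as np

def spaced_timesteps(interval=1000, nfe=20):
    """Pick nfe + 1 evenly spaced timestep boundaries from the full interval."""
    return np.linspace(0, interval - 1, nfe + 1).round().astype(int)

steps = spaced_timesteps(interval=1000, nfe=20)
# 21 boundaries spanning 0..999; the sampler steps between consecutive pairs
```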

I2SB training with task configs

from examples.i2sb.config import sar2eo_config
from examples.i2sb.trainer import I2SBTrainer

cfg = sar2eo_config(resolution=256, train_batch_size=8)
trainer = I2SBTrainer(cfg)
model = trainer.build_model()
scheduler = trainer.build_scheduler()

# Single-step loss computation
loss = I2SBTrainer.compute_training_loss(model, scheduler, source_batch, target_batch)
loss.backward()
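
Under the I2SB formulation, the training input x_t is drawn from a Gaussian bridge between the clean target x0 and the degraded source x1, and its mean is a variance-weighted interpolation of the two endpoints. A sketch of that interpolation following the I2SB paper (the trainer's exact implementation may differ):

```python
import numpy as np

def bridge_mean(x0, x1, var_fwd, var_bwd):
    """Mean of the bridge marginal at time t.

    var_fwd: accumulated noise variance from x0 up to t
    var_bwd: accumulated noise variance from t up to x1
    """
    w0 = var_bwd / (var_fwd + var_bwd)  # weight on the clean endpoint
    return w0 * x0 + (1.0 - w0) * x1

x0, x1 = np.zeros(4), np.ones(4)
print(bridge_mean(x0, x1, 0.5, 0.5))  # → [0.5 0.5 0.5 0.5] at equal variances
```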

Package Structure

src/
├── __init__.py              # Public API
├── models/
│   ├── generators.py        # UNetGenerator, ResNetGenerator
│   ├── discriminators.py    # PatchGANDiscriminator
│   └── unet/                # ADM-style U-Net for I2SB
│       ├── i2sb_unet.py     # I2SBUNet
│       └── unet_2d.py       # create_model factory
├── schedulers/
│   └── i2sb.py              # I2SBScheduler
├── pipelines/
│   └── i2sb.py              # I2SBPipeline
├── data/
│   ├── datasets.py          # PairedImageDataset, UnpairedImageDataset
│   └── transforms.py        # get_transforms, default_transforms
├── losses/
│   ├── adversarial.py       # GANLoss
│   └── perceptual.py        # PerceptualLoss
├── training/
│   └── trainer.py           # Pix2PixTrainer, TrainingConfig
├── inference/
│   └── predictor.py         # ImageTranslator
└── metrics/
    └── image_quality.py     # PSNR, SSIM, LPIPS, FID
examples/
└── i2sb/
    ├── config.py            # TaskConfig, sar2eo_config, etc.
    └── trainer.py           # I2SBTrainer

Credits

Reference papers

License

MIT
