
Discrete Distribution Network

An exploration of Discrete Distribution Networks, by Lei Yang out of Beijing

Besides split-and-prune, an option for crossover may also be added (mixing the top 2 nodes to replace the pruned one)
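To make the two ideas concrete, here is a toy sketch in plain Python. It is not this library's implementation: the function name `split_and_prune`, the list-based codebook, the `2/K` and `0.5/K` usage thresholds, and the way counts are shared afterwards are all assumptions for illustration. The idea is that when one node is selected far too often, or another far too rarely, the least-used node is pruned and its slot is refilled either with a jittered copy of the most-used node (split) or with the average of the two most-used nodes (crossover).

```python
import random


def split_and_prune(codebook, counts, crossover=False, noise=1e-3, seed=0):
    """Toy split-and-prune over K nodes (each node is a list of floats).

    counts[i] is how often node i was selected since the last call.
    Returns new (codebook, counts) lists; the inputs are not modified.
    """
    k = len(codebook)
    total = sum(counts)
    new_codebook = [list(v) for v in codebook]
    new_counts = list(counts)
    if k < 2 or total == 0:
        return new_codebook, new_counts

    freqs = [c / total for c in counts]
    order = sorted(range(k), key=lambda i: freqs[i], reverse=True)
    top, second, bottom = order[0], order[1], order[-1]

    # Assumed thresholds: act only if a node is over-used or under-used.
    if freqs[top] <= 2 / k and freqs[bottom] >= 0.5 / k:
        return new_codebook, new_counts

    if crossover:
        # Crossover: the replacement is the mean of the two most-used nodes.
        child = [(a + b) / 2 for a, b in zip(codebook[top], codebook[second])]
    else:
        # Split: the replacement is a slightly perturbed copy of the top node.
        rng = random.Random(seed)
        child = [a + rng.gauss(0.0, noise) for a in codebook[top]]

    # Prune: the least-used node's slot is overwritten by the child.
    new_codebook[bottom] = child
    # In this sketch the top node shares its count with the child.
    half = counts[top] / 2
    new_counts[top] = half
    new_counts[bottom] = half
    return new_codebook, new_counts


codebook = [[0.0, 0.0], [1.0, 1.0], [4.0, 4.0], [9.0, 9.0]]
counts = [90, 6, 3, 1]

new_codebook, new_counts = split_and_prune(codebook, counts, crossover=True)
print(new_codebook[3], new_counts)
```

In this example node 0 takes 90% of the selections, so the rarely used node 3 is replaced by the mix of nodes 0 and 1.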

Install

$ pip install discrete-distribution-network

Usage

import torch
from discrete_distribution_network import DDN

ddn = DDN(
    dim = 32,
    image_size = 256
)

images = torch.randn(2, 3, 256, 256)

loss = ddn(images)
loss.backward()

# after much training

sampled = ddn.sample(batch_size = 1)

assert sampled.shape == (1, 3, 256, 256)

The proposed GuidedSampler in the paper

import torch
from discrete_distribution_network import GuidedSampler

sampler = GuidedSampler(
    dim = 16,              # feature dimension
    dim_query = 3,         # the query image dimension
    codebook_size = 10,    # the codebook size K in the paper: K separate projections of the features, each compared by distance to the query image guide
)

features = torch.randn(20, 16, 32, 32)
query_image = torch.randn(20, 3, 32, 32)

out, codes, commit_loss = sampler(features, query_image)

# (20, 3, 32, 32), (20,), ()

assert torch.allclose(sampler.forward_for_codes(features, codes), out, atol = 1e-5)

# this needs to be called after the optimizer step
# there is also a helper function of the same name that takes your module and invokes this on all of its guided samplers

sampler.split_and_prune_()
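The selection step described in the `codebook_size` comment above can be pictured with a small standalone PyTorch sketch. It does not use or reproduce `GuidedSampler`: the 1x1 convolution `to_candidates` standing in for the K projections, the helper name `select_nearest`, and the squared L2 distance are all assumptions. It only shows how K candidates per sample can be compared against a query image and the nearest one picked, which gives the same output shapes as in the usage above.

```python
import torch
from torch import nn

K = 10  # codebook size, as in the usage above

# Hypothetical stand-in for "K separate projections of the features":
# one 1x1 conv producing K candidates with 3 channels each.
to_candidates = nn.Conv2d(16, K * 3, kernel_size=1)


def select_nearest(candidates, query):
    """Pick, per sample, the candidate closest to the query.

    candidates: (B, K, C, H, W), query: (B, C, H, W)
    Returns the chosen candidates (B, C, H, W) and their indices (B,).
    """
    # Squared L2 distance between every candidate and the query -> (B, K)
    distances = (candidates - query.unsqueeze(1)).pow(2).flatten(2).sum(dim=-1)
    codes = distances.argmin(dim=-1)
    batch_index = torch.arange(candidates.shape[0], device=candidates.device)
    return candidates[batch_index, codes], codes


features = torch.randn(20, 16, 32, 32)
query_image = torch.randn(20, 3, 32, 32)

candidates = to_candidates(features).reshape(20, K, 3, 32, 32)
out, codes = select_nearest(candidates, query_image)

print(out.shape, codes.shape)
```

Because only the selected candidate is returned, gradients from a reconstruction loss on `out` reach just the chosen projection for each sample, which is one reason rarely chosen nodes need something like split-and-prune to stay useful.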

Oxford flowers

Install uv, which will probably become the default in the near future

$ pip install uv

Then

$ uv run train_oxford_flowers.py

Citations

@misc{yang2025discretedistributionnetworks,
    title   = {Discrete Distribution Networks}, 
    author  = {Lei Yang},
    year    = {2025},
    eprint  = {2401.00036},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV},
    url     = {https://arxiv.org/abs/2401.00036}, 
}
