
Python package allowing for the exploration of latent spaces of generative models through Riemannian geometry.


Latent Geometry

A Python package, developed as part of a Master's thesis, for exploring the latent spaces of generative models through Riemannian geometry.

By pulling back a metric from the observation space, one can reveal nuanced geometric structure in latent spaces. The framework is agnostic to the automatic differentiation backend (e.g. PyTorch, TensorFlow) and also works with custom, hand-made differentiable mappings.
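
To make the pull-back idea concrete: for a differentiable mapping with Jacobian J(z) and a Euclidean metric on the observation space, the induced metric on the latent space is G(z) = J(z)^T J(z). The sketch below computes it for a hand-made toy mapping using only NumPy. The names `decoder`, `jacobian` and `pullback_metric` are hypothetical and not part of latent-geometry.

```python
import numpy as np


def decoder(z: np.ndarray) -> np.ndarray:
    """Toy mapping R^2 -> R^3 that lifts the plane onto a paraboloid."""
    return np.array([z[0], z[1], z[0] ** 2 + z[1] ** 2])


def jacobian(z: np.ndarray) -> np.ndarray:
    """Analytic Jacobian of `decoder`, shape (3, 2)."""
    return np.array([[1.0, 0.0], [0.0, 1.0], [2.0 * z[0], 2.0 * z[1]]])


def pullback_metric(z: np.ndarray) -> np.ndarray:
    """Euclidean ambient metric pulled back to the latent space: J^T J."""
    J = jacobian(z)
    return J.T @ J


G_origin = pullback_metric(np.array([0.0, 0.0]))  # identity: flat at the origin
G_away = pullback_metric(np.array([1.0, 0.0]))    # stretched along z[0]
```

Away from the origin the metric is no longer the identity, which is exactly the structure that a plain Euclidean view of the latent space would miss.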

License: GPLv3

Installation

pip install latent-geometry

Usage

Geodesics

import numpy as np
import torch.nn as nn # just for the sake of example

your_neural_net: nn.Module = YourPyTorchNet(latent_dim=8)

# create Mapping
from latent_geometry.mapping import Mapping, TorchModelMapping

mapping: Mapping = TorchModelMapping(
    model=your_neural_net,
    in_shape=(8,), # dimensionality of the domain w/o batch size
    out_shape=(3, 32, 32), # dimensionality of the co-domain w/o batch size
    batch_size=64, # example value (an assumption); tune it to your hardware
    call_fn=your_neural_net.forward, # optional
)

# define your favourite metric for the observation space
from latent_geometry.metric import EuclideanMetric

ambient_metric = EuclideanMetric()

# create the manifold spanned by your latent space with the pulled-back ambient metric
from latent_geometry.manifold import LatentManifold

latent_manifold = LatentManifold(
    mapping=mapping,
    ambient_metric=ambient_metric,
)

# calculate the geodesic starting from z_0 with velocity v_0
from latent_geometry.path import ManifoldPath

z_0 = np.zeros(8)
v_0 = np.ones_like(z_0)

geodesic: ManifoldPath = latent_manifold.geodesic(z=z_0, velocity_vec=v_0)
# geodesic(0) == z_0

# calculate the shortest path between z_a and z_b
z_a = np.zeros(8)
z_b = np.zeros(8) + 3

shortest_path: ManifoldPath = latent_manifold.shortest_path(z_a=z_a, z_b=z_b)
# shortest_path(0.0), shortest_path(1.0) == z_a, z_b
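
The shortest path minimises length as measured in the observation space, not in the latent space. As a rough illustration of that quantity (and not of how latent-geometry solves for the path), the sketch below approximates the length of a latent curve under the pulled-back Euclidean metric by summing distances between decoded points. `decoder` and `curve_length` are hypothetical names, and the toy mapping is an assumption.

```python
import numpy as np


def decoder(zs: np.ndarray) -> np.ndarray:
    """Toy batched mapping (B, 2) -> (B, 3) onto a paraboloid."""
    return np.stack([zs[:, 0], zs[:, 1], (zs ** 2).sum(axis=1)], axis=1)


def curve_length(zs: np.ndarray) -> float:
    """Approximate length of the decoded polyline through latent points zs."""
    xs = decoder(zs)
    return float(np.linalg.norm(np.diff(xs, axis=0), axis=1).sum())


ts = np.linspace(0.0, 1.0, 201)[:, None]
z_a, z_b = np.array([-1.0, 0.0]), np.array([1.0, 0.0])
straight = (1 - ts) * z_a + ts * z_b            # straight line in latent space
latent_len = float(np.linalg.norm(z_b - z_a))   # Euclidean length in latent space
pulled_back_len = curve_length(straight)        # length of the decoded curve
```

Here the decoded curve is longer than the latent segment, so a straight line in latent coordinates is not necessarily the shortest path on the manifold.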

If your mapping is not based on PyTorch, you need to implement one of two possible interfaces:

from typing import Callable

import numpy as np

# you can implement the first and second derivatives of the output w.r.t. the input
from latent_geometry.mapping import DerivativeMapping

class YourMappingWrapper(DerivativeMapping):
    def __init__(self, your_mapping: Callable) -> None: ...
    def jacobian(self, zs: np.ndarray) -> np.ndarray: ...
    def second_derivative(self, zs: np.ndarray) -> np.ndarray: ...

# or follow the other interface (for speed-up purposes)
from latent_geometry.mapping import MatrixMapping

class YourMappingWrapper(MatrixMapping):
    def __init__(self, your_mapping: Callable) -> None: ...
    def jacobian(self, zs: np.ndarray) -> np.ndarray: ...
    def metric_matrix_derivative(
        self, zs: np.ndarray, ambient_metric_matrices: np.ndarray
    ) -> np.ndarray: ...
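
As an illustration of what a hand-made mapping with analytic derivatives could look like, the sketch below implements `jacobian` and `second_derivative` for a quadratic map using NumPy only. It deliberately does not subclass `DerivativeMapping`, so it runs without the package installed. The class name `QuadraticMapping` is hypothetical, and the batch-first shapes `(batch, N, D)` and `(batch, N, D, D)` are assumptions rather than the documented contract.

```python
import numpy as np


class QuadraticMapping:
    """Hand-made map f_n(z) = A_n . z + 0.5 * z^T B_n z with analytic derivatives."""

    def __init__(self, A: np.ndarray, B: np.ndarray) -> None:
        self.A = A                                 # shape (N, D)
        self.B = 0.5 * (B + B.transpose(0, 2, 1))  # shape (N, D, D), symmetrised

    def __call__(self, zs: np.ndarray) -> np.ndarray:
        # (batch, D) -> (batch, N)
        return zs @ self.A.T + 0.5 * np.einsum("bi,nij,bj->bn", zs, self.B, zs)

    def jacobian(self, zs: np.ndarray) -> np.ndarray:
        # (batch, D) -> (batch, N, D): A + B z
        return self.A[None] + np.einsum("nij,bj->bni", self.B, zs)

    def second_derivative(self, zs: np.ndarray) -> np.ndarray:
        # (batch, D) -> (batch, N, D, D): the Hessian is the constant B
        return np.broadcast_to(self.B, (zs.shape[0], *self.B.shape)).copy()


rng = np.random.default_rng(0)
mapping = QuadraticMapping(rng.normal(size=(3, 2)), rng.normal(size=(3, 2, 2)))
zs = rng.normal(size=(4, 2))
J = mapping.jacobian(zs)           # shape (4, 3, 2)
H = mapping.second_derivative(zs)  # shape (4, 3, 2, 2)
```

Checking analytic derivatives against finite differences before plugging a wrapper into the manifold is a cheap way to catch index-ordering mistakes.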

Riemannian optimizer

Currently available for PyTorch only.

import torch
import torch.nn as nn
from latent_geometry.optim import TorchMetric, InputGDOptimizer

your_neural_net: nn.Module = YourPyTorchNet()
loss_fn = lambda x: x.mean()
example_input = torch.zeros(8, requires_grad=True)

optimizer = InputGDOptimizer(
    param=example_input,
    metric=TorchMetric(mapping=your_neural_net),
    lr=0.001,
    gradient_type="geometric" # may also be "standard", "retractive"
)

for _ in range(1_000):
    optimizer.zero_grad()
    x = your_neural_net(example_input)
    loss = loss_fn(x)

    loss.backward()
    optimizer.step()
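
One common way to take the geometry into account during optimisation is to precondition the Euclidean gradient with the inverse of the pulled-back metric, which gives a Riemannian gradient step. Whether this matches what `gradient_type="geometric"` does internally is an assumption; the NumPy sketch below only illustrates the general idea, and `jacobian` and `riemannian_step` are hypothetical names.

```python
import numpy as np


def jacobian(z: np.ndarray) -> np.ndarray:
    """Jacobian of a toy linear decoder that stretches z[0] tenfold."""
    return np.array([[10.0, 0.0], [0.0, 1.0]])


def riemannian_step(z: np.ndarray, grad: np.ndarray, lr: float) -> np.ndarray:
    """One descent step along G(z)^{-1} grad, with G = J^T J."""
    J = jacobian(z)
    G = J.T @ J
    return z - lr * np.linalg.solve(G, grad)


z = np.array([1.0, 1.0])
grad = np.array([1.0, 1.0])   # Euclidean gradient of some loss w.r.t. z
z_std = z - 0.1 * grad        # plain gradient descent step
z_geo = riemannian_step(z, grad, lr=0.1)
```

The preconditioned step moves much less along the direction in which the decoder stretches the latent space, so equal steps in latent coordinates correspond to more comparable changes in the observation space.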


