McTorch, a manifold optimization library for deep learning

McTorch is a Python library that adds manifold optimization functionality to PyTorch.

McTorch:

  • Leverages tensor computation and GPU acceleration from PyTorch.
  • Enables optimization on manifold-constrained tensors to address nonlinear optimization problems.
  • Supports weight tensors constrained to a manifold in deep learning layers.

More about McTorch

McTorch builds on top of PyTorch and supports all PyTorch functions in addition to manifold optimization, so researchers and developers who already use PyTorch can easily experiment with McTorch functions. McTorch's manifold implementations and optimization methods are derived from the MATLAB toolbox Manopt and the Python toolbox Pymanopt.

Using McTorch for Optimization

  1. Initialize Parameter - McTorch manifold parameters (mctorch.nn.Parameter) are the same as PyTorch parameters; constraining the parameter values only requires passing one additional property, the manifold, at initialization.
  2. Define Cost - The cost function can be any PyTorch function of the above parameter, optionally mixed with unconstrained parameters.
  3. Optimize - Any optimizer from mctorch.optim can minimize the cost function, using the same backward/step/zero_grad workflow as regular PyTorch code.
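To give a rough idea of what a manifold optimizer does on each step, here is a minimal plain-PyTorch sketch of Riemannian gradient descent on the Stiefel manifold (matrices with orthonormal columns). It assumes the common Manopt-style recipe: project the Euclidean gradient onto the tangent space, take a step, then retract back onto the manifold with a QR decomposition. This is not McTorch's internal code, and the helper names `stiefel_project`, `qr_retract`, `riemannian_sgd_step` and `pca_cost` are hypothetical.

```python
import torch


def stiefel_project(w: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    """Project a Euclidean gradient g onto the tangent space of the Stiefel manifold at w."""
    wtg = w.t() @ g
    sym = 0.5 * (wtg + wtg.t())
    return g - w @ sym


def qr_retract(w: torch.Tensor, xi: torch.Tensor) -> torch.Tensor:
    """Map the point w + xi back onto the manifold using the Q factor of its QR decomposition."""
    q, r = torch.linalg.qr(w + xi)
    # Flip column signs so diag(r) is non-negative; this makes the retraction well defined.
    signs = torch.sign(torch.sign(torch.diagonal(r)) + 0.5)
    return q * signs


def riemannian_sgd_step(w: torch.Tensor, euclid_grad: torch.Tensor, lr: float) -> torch.Tensor:
    """One Riemannian gradient-descent step: tangent projection, then retraction."""
    rgrad = stiefel_project(w, euclid_grad)
    return qr_retract(w, -lr * rgrad)


# Same cost as the PCA example: ||X - w w^T X||^2 over 3x2 matrices w with orthonormal columns.
torch.manual_seed(0)
X = torch.diag(torch.tensor([3.0, 2.0, 1.0])) @ torch.randn(3, 200)


def pca_cost(w: torch.Tensor) -> torch.Tensor:
    return torch.sum((X - w @ (w.t() @ X)) ** 2)


w, _ = torch.linalg.qr(torch.randn(3, 2))  # random starting point on the manifold
for _ in range(300):
    w_var = w.clone().requires_grad_(True)
    (grad,) = torch.autograd.grad(pca_cost(w_var), w_var)  # Euclidean gradient via autograd
    w = riemannian_sgd_step(w, grad, lr=1e-4)
```

Because every iterate comes out of the retraction, `w` keeps orthonormal columns throughout; no penalty term or clipping is needed to enforce the constraint.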

PCA Example

import torch
import mctorch.nn as mnn
import mctorch.optim as moptim

# Random data with high variance in the first two dimensions
X = torch.diag(torch.tensor([3.0, 2.0, 1.0])).matmul(torch.randn(3, 200))

# 1. Initialize Parameter - a 3x2 matrix with orthonormal columns (Stiefel manifold)
manifold_param = mnn.Parameter(manifold=mnn.Stiefel(3, 2))

# 2. Define Cost - squared reconstruction error ||X - w w^T X||^2
def cost(X, w):
    wTX = torch.matmul(w.transpose(1, 0), X)
    wwTX = torch.matmul(w, wTX)
    return torch.sum((X - wwTX) ** 2)

# 3. Optimize
optimizer = moptim.rAdagrad(params=[manifold_param], lr=1e-2)

for epoch in range(30):
    cost_step = cost(X, manifold_param)
    print(cost_step.item())
    cost_step.backward()
    optimizer.step()
    optimizer.zero_grad()
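Because this PCA cost has a closed-form minimum, one way to sanity-check an optimizer's result is to compare it with an SVD solution. The sketch below uses only plain PyTorch, not McTorch: by the Eckart-Young theorem, the top-p left singular vectors of X minimize the reconstruction error, and the minimum equals the sum of the squared singular values beyond the p-th. The names `reconstruction_cost` and `pca_reference` are hypothetical.

```python
import torch


def reconstruction_cost(X: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Same squared reconstruction error as the PCA example: ||X - w w^T X||_F^2."""
    return torch.sum((X - w @ (w.transpose(0, 1) @ X)) ** 2)


def pca_reference(X: torch.Tensor, p: int):
    """Closed-form minimizer over n x p orthonormal w: the top-p left singular vectors of X."""
    U, S, _ = torch.linalg.svd(X, full_matrices=False)
    w_star = U[:, :p]
    min_cost = torch.sum(S[p:] ** 2)  # energy in the discarded directions
    return w_star, min_cost


torch.manual_seed(0)
X = torch.diag(torch.tensor([3.0, 2.0, 1.0])) @ torch.randn(3, 200)
w_star, min_cost = pca_reference(X, p=2)
print(reconstruction_cost(X, w_star).item(), min_cost.item())  # the two values should agree
```

To check the McTorch loop, call `pca_reference` on that example's `X` and compare `min_cost` with `reconstruction_cost(X, manifold_param.detach())`; depending on the learning rate, 30 steps may not reach the minimum exactly. The optimal `w` is unique only up to a rotation within the same 2-dimensional subspace, so comparing costs (or the projector `w w^T`) is more meaningful than comparing `w` entry by entry.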

Using McTorch for Deep Learning

Multi Layer Perceptron Example

import torch
import torch.nn as nn
import torch.nn.functional as F
import mctorch.nn as mnn

# a torch module using manifold-constrained linear layers
class ManifoldMLP(nn.Module):
    def __init__(self):
        super(ManifoldMLP, self).__init__()
        self.layer1 = mnn.rLinear(in_features=28*28, out_features=100, weight_manifold=mnn.Stiefel)
        self.layer2 = mnn.rLinear(in_features=100, out_features=100, weight_manifold=mnn.PositiveDefinite)
        self.output = mnn.rLinear(in_features=100, out_features=10, weight_manifold=mnn.Stiefel)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        # normalize over the class dimension (dim=1), not the batch dimension
        x = F.log_softmax(self.output(x), dim=1)
        return x

# placeholder batch: 64 flattened 28x28 images with integer class labels
inputs = torch.randn(64, 28*28)
targets = torch.randint(0, 10, (64,))

# create module object and compute cost by applying module on inputs
mlp_module = ManifoldMLP()
cost = F.nll_loss(mlp_module(inputs), targets)

More examples are available here.

Functionality Supported

This list of features will keep growing. McTorch currently supports:

Manifolds

  • Stiefel
  • Positive Definite

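All manifolds also support a multiplier k, which gives the product of k copies of the manifold (for example, k matrices that are each constrained to the Stiefel manifold).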
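As a plain-PyTorch illustration of what these constraints mean (not McTorch's API), the sketch below samples points on k copies of each manifold and checks membership: a Stiefel point is an n x p matrix with orthonormal columns (W^T W = I), and a positive definite point is a symmetric matrix with strictly positive eigenvalues. Stacking the k matrices along a leading dimension is an assumed layout for picturing the k multiplier; McTorch's internal representation may differ, and all function names here are hypothetical.

```python
import torch


def random_stiefel(k: int, n: int, p: int) -> torch.Tensor:
    """k stacked n x p matrices with orthonormal columns, via a batched reduced QR."""
    q, _ = torch.linalg.qr(torch.randn(k, n, p))
    return q


def random_positive_definite(k: int, n: int, eps: float = 1e-2) -> torch.Tensor:
    """k stacked n x n symmetric positive definite matrices of the form A A^T + eps * I."""
    a = torch.randn(k, n, n)
    return a @ a.transpose(-1, -2) + eps * torch.eye(n)


def on_stiefel(w: torch.Tensor, atol: float = 1e-5) -> bool:
    """True if every matrix in the (batched) tensor w satisfies W^T W = I."""
    wtw = w.transpose(-1, -2) @ w
    eye = torch.eye(w.shape[-1]).expand_as(wtw)
    return torch.allclose(wtw, eye, atol=atol)


def on_positive_definite(s: torch.Tensor, atol: float = 1e-5) -> bool:
    """True if every matrix in s is symmetric with strictly positive eigenvalues."""
    if not torch.allclose(s, s.transpose(-1, -2), atol=atol):
        return False
    return bool((torch.linalg.eigvalsh(s) > 0).all())


W = random_stiefel(k=4, n=5, p=2)       # a point on the product of 4 copies of St(5, 2)
S = random_positive_definite(k=4, n=3)  # a point on the product of 4 copies of the 3x3 SPD manifold
print(on_stiefel(W), on_positive_definite(S))
```

Batching over the leading dimension lets one QR or eigenvalue call handle all k copies at once, which mirrors how a product manifold treats each factor independently.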

Optimizers

  • SGD
  • Adagrad
  • ConjugateGradient

Layers

  • Linear
  • Conv1d, Conv2d, Conv3d

Installation

After installing PyTorch, McTorch can be installed from source with python setup.py install, or from PyPI with pip as shown below.

Linux

source activate myenv
conda install numpy setuptools
# Add LAPACK support for the GPU if needed
conda install -c pytorch magma-cuda90 # or [magma-cuda80 | magma-cuda92 | magma-cuda100 ] depending on your cuda version
conda install pytorch torchvision cudatoolkit=9.0 -c pytorch # or cudatoolkit=10.0 | cudatoolkit=10.1 | .. depending on your cuda version
pip install mctorch-lib

Release and Contribution

McTorch is currently under development, and contributions, suggestions and feature requests are welcome. We closely follow stable PyTorch releases to keep the base up to date, and will have our own versions for other additions.

McTorch is released under the open source 3-clause BSD License.

Reference

Please cite [1] if you find this code useful.

[1] M. Meghwanshi, P. Jawanpuria, A. Kunchukuttan, H. Kasai, and B. Mishra, "McTorch, a manifold optimization library for deep learning," arXiv preprint arXiv:1810.01811, 2018.

@techreport{meghwanshi2018mctorch,
  title={McTorch, a manifold optimization library for deep learning},
  author={Meghwanshi, Mayank and Jawanpuria, Pratik and Kunchukuttan, Anoop and Kasai, Hiroyuki and Mishra, Bamdev},
  institution={arXiv preprint arXiv:1810.01811},
  year={2018}
}

Download files

Source Distribution

mctorch-lib-0.1.0.tar.gz (15.6 kB)

Uploaded Source

Built Distribution

mctorch_lib-0.1.0-py3-none-any.whl (19.5 kB)

Uploaded Python 3

File details

Details for the file mctorch-lib-0.1.0.tar.gz.

File metadata

  • Download URL: mctorch-lib-0.1.0.tar.gz
  • Size: 15.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.0.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.8.8

File hashes

Hashes for mctorch-lib-0.1.0.tar.gz
Algorithm Hash digest
SHA256 3b8de42cce48ec79183a7b254030f9b9ae569e2cd6e8624ff5f68962422af238
MD5 c4dcdf539cc9a05c9b744ea308b83f44
BLAKE2b-256 a1a95157f9af38aa4117104a3cbdd10cdc3adbfd2493a13a6efad821f8ff8838

File details

Details for the file mctorch_lib-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: mctorch_lib-0.1.0-py3-none-any.whl
  • Size: 19.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.0.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.8.8

File hashes

Hashes for mctorch_lib-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 413750ed561223d4b08bbc97d63832d37bc6ff34fec41548ddd9cc5f393c0d64
MD5 f01c6e96a32e887178068515933acb65
BLAKE2b-256 3d87e78f240558083d16f69adeb89b92f1cf871aa1ec1eb58bb1cd047e3b9e5d
