
Maximize memory utilization with PyTorch.


torch-max-mem


This package provides decorators that maximize memory utilization with PyTorch and CUDA. They start with a maximum parameter size, such as a batch size, and apply successive halving until no out-of-memory exception occurs.
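The successive-halving idea can be illustrated with a minimal, hypothetical sketch. It is not the package's implementation: the decorator name maximize_batch_size is made up, the parameter name batch_size is hard-coded, and Python's built-in MemoryError stands in for PyTorch's CUDA out-of-memory error so that the example runs without a GPU.

```python
import functools
from typing import Any, Callable, TypeVar

R = TypeVar("R")


def maximize_batch_size(func: Callable[..., R]) -> Callable[..., R]:
    """Retry func with a halved batch_size whenever it runs out of memory.

    Hypothetical sketch: MemoryError stands in for a CUDA out-of-memory error.
    """

    @functools.wraps(func)
    def wrapper(*args: Any, batch_size: int, **kwargs: Any) -> R:
        while batch_size >= 1:
            try:
                return func(*args, batch_size=batch_size, **kwargs)
            except MemoryError:
                # Successive halving: retry with half the batch size.
                batch_size //= 2
        raise MemoryError("out of memory even with batch_size=1")

    return wrapper


@maximize_batch_size
def fake_job(num_items: int, batch_size: int) -> int:
    # Pretend that more than 13 items per batch exhausts memory.
    if batch_size > 13:
        raise MemoryError
    return batch_size


print(fake_job(100, batch_size=100))  # tries 100, 50, 25, then prints 12
```

Halving, rather than decrementing by one, keeps the number of retries logarithmic in the starting size: starting from 100, at most seven attempts (100, 50, 25, 12, 6, 3, 1) are made before the sketch gives up.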

💪 Getting Started

Assume you have a function for batched computation of nearest neighbors using brute-force distance calculation.

import torch

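def knn(x, y, batch_size, k: int = 3):
    # Process the rows of x in chunks of batch_size to bound peak memory.
    return torch.cat(
        [
            # Distances from the current chunk to all rows of y; keep the k smallest per row.
            torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices
            for start in range(0, x.shape[0], batch_size)
        ],
        dim=0,
    )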
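Because batch_size only controls how the rows of x are chunked, the result does not depend on it, which is what makes it safe for a decorator to shrink it. The following pure-Python sketch mirrors the batching logic of knn without PyTorch; the name knn_reference is hypothetical, math.dist stands in for torch.cdist, and sorting stands in for topk(largest=False).

```python
import math
from typing import List, Sequence

Point = Sequence[float]


def knn_reference(x: Sequence[Point], y: Sequence[Point], batch_size: int, k: int = 3) -> List[List[int]]:
    """Indices of the k nearest rows of y for each row of x, computed batch by batch."""
    result: List[List[int]] = []
    for start in range(0, len(x), batch_size):
        for row in x[start : start + batch_size]:
            # Like torch.cdist for one row: distances to all rows of y.
            distances = [math.dist(row, other) for other in y]
            # Like topk(largest=False): indices of the k smallest distances.
            result.append(sorted(range(len(y)), key=distances.__getitem__)[:k])
    return result


x = [[0.0, 0.0], [5.0, 5.0], [9.0, 1.0]]
y = [[1.0, 0.0], [4.0, 4.0], [9.0, 0.0], [0.0, 6.0]]
print(knn_reference(x, y, batch_size=1, k=2))  # [[0, 1], [1, 3], [2, 1]]
```

Any batch size between 1 and len(x) yields the same indices; only the size of the intermediate distance matrix changes.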

With torch_max_mem, you can decorate this function so that the batch size is reduced automatically until no out-of-memory error occurs.

import torch
from torch_max_mem import MemoryUtilizationMaximizer


@MemoryUtilizationMaximizer()
def knn(x, y, batch_size, k: int = 3):
    return torch.cat(
        [
            # dim=1: the k nearest rows of y for each row of x, as in the undecorated version.
            torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices
            for start in range(0, x.shape[0], batch_size)
        ],
        dim=0,
    )

In the code, you can now always pass the largest sensible batch size, e.g.,

x = torch.rand(100, 100, device="cuda")
y = torch.rand(200, 100, device="cuda")
knn(x, y, batch_size=x.shape[0])
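Halving the batch size helps because the largest intermediate in knn is the distance matrix returned by torch.cdist, which has shape (batch_size, number of rows of y). The hypothetical helper below estimates its size for the shapes above, assuming float32 values (4 bytes each, the default dtype of torch.rand) and ignoring every other allocation.

```python
def distance_block_bytes(batch_size: int, num_y: int, bytes_per_element: int = 4) -> int:
    """Approximate size of one (batch_size, num_y) distance matrix in bytes.

    Assumes float32 elements and ignores all other allocations.
    """
    return batch_size * num_y * bytes_per_element


# Shapes from the example above: x has 100 rows, y has 200 rows.
for batch_size in (100, 50, 25):
    print(batch_size, distance_block_bytes(batch_size, num_y=200))
# 100 80000
# 50 40000
# 25 20000
```

Each halving of the batch size halves this intermediate, so the first batch size that fits is within a factor of two of the largest one that would fit.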

🚀 Installation

The most recent release can be installed from PyPI with:

$ pip install torch_max_mem

The most recent code and data can be installed directly from GitHub with:

$ pip install git+https://github.com/mberr/torch-max-mem.git

To install in development mode, use the following:

$ git clone https://github.com/mberr/torch-max-mem.git
$ cd torch-max-mem
$ pip install -e .

👐 Contributing

Contributions, whether filing an issue, making a pull request, or forking, are appreciated. See CONTRIBUTING.md for more information on getting involved.

👋 Attribution

Parts of the logic have been developed with Laurent Vermue for PyKEEN.

⚖️ License

The code in this package is licensed under the MIT License.

🍪 Cookiecutter

This package was created with @audreyfeldroy's cookiecutter package using @cthoyt's cookiecutter-snekpack template.

🛠️ For Developers


This final section of the README is for those who want to get involved by contributing code.

🥼 Testing

After cloning the repository and installing tox with pip install tox, the unit tests in the tests/ folder can be run reproducibly with:

$ tox

Additionally, these tests are automatically re-run with each commit in a GitHub Action.

📖 Building the Documentation

The documentation can be built locally with:

$ tox -e docs

📦 Making a Release

After installing the package in development mode and installing tox with pip install tox, the commands for making a new release are contained within the finish environment in tox.ini. Run the following from the shell:

$ tox -e finish

This script does the following:

  1. Uses Bump2Version to remove the -dev suffix from the version number in setup.cfg and src/torch_max_mem/version.py.
  2. Packages the code in both a tar archive and a wheel.
  3. Uploads to PyPI using twine. Be sure to have a .pypirc file configured to avoid the need for manual input at this step.
  4. Pushes to GitHub. You'll need to create a GitHub release for the commit in which the version was bumped.
  5. Bumps the version to the next patch. If you made larger changes and want to bump the minor version instead, run tox -e bumpversion minor afterwards.
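The version-string transitions in steps 1 and 5 can be illustrated with two hypothetical helpers. They are not Bump2Version's API; they assume a MAJOR.MINOR.PATCH version with an optional -dev suffix, as step 1 suggests, and the version numbers are examples only.

```python
DEV_SUFFIX = "-dev"


def strip_dev(version: str) -> str:
    """Step 1: drop the development suffix; versions without it are unchanged."""
    return version[: -len(DEV_SUFFIX)] if version.endswith(DEV_SUFFIX) else version


def bump(version: str, part: str = "patch") -> str:
    """Step 5 (and "bumpversion minor"): increment one part and re-append the suffix."""
    major, minor, patch = (int(p) for p in strip_dev(version).split("."))
    if part == "patch":
        patch += 1
    elif part == "minor":
        # A minor bump resets the patch number.
        minor, patch = minor + 1, 0
    else:
        raise ValueError(f"unsupported part: {part!r}")
    return f"{major}.{minor}.{patch}{DEV_SUFFIX}"


released = strip_dev("0.0.3-dev")  # the version that gets packaged and uploaded
print(released, bump(released), bump(bump(released), "minor"))
# 0.0.3 0.0.4-dev 0.1.0-dev
```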


