Efficient and Automatic Rematerialization for PyTorch training
Rockmate
Warning: Rockmate currently relies on Gurobi to solve its Integer Linear Programming (ILP) model, so Gurobi must be installed.
Given a module, a sample (i.e. an example input for it), and a memory budget, Rockmate builds a new torch.nn.Module that produces the same forward and backward results as the original module while keeping peak memory usage under the given budget.
The backward pass populates the gradients of the original model's parameters.
The model and the sample should be on the GPU.
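Rematerialization trades memory for compute: instead of keeping every activation for the backward pass, some are discarded and recomputed when needed. The sketch below illustrates that trade-off with a deliberately simple cost model (a chain of `n` layers with unit-size activations, checkpointed every `k` layers, in the spirit of classic uniform checkpointing). It is an assumption-laden toy, not Rockmate's method: Rockmate solves an ILP over the real computation graph. The helper names `segment_cost` and `best_segment_length` are hypothetical.

```python
import math
from typing import Optional, Tuple


def segment_cost(n: int, k: int) -> Tuple[int, int]:
    """Toy cost model (assumption): a chain of n layers with unit-size
    activations, checkpointed every k layers.

    Returns (peak, recomputed):
      peak       -- checkpoints kept, plus the k - 1 activations of one
                    segment recomputed during the backward pass
      recomputed -- forward ops executed a second time, i.e. every
                    activation that is not a checkpoint
    """
    checkpoints = math.ceil(n / k)
    peak = checkpoints + k - 1
    recomputed = n - checkpoints
    return peak, recomputed


def best_segment_length(n: int, budget: int) -> Optional[Tuple[int, int, int]]:
    """Pick the segment length k whose peak fits in `budget` while
    recomputing as little as possible (ties broken by lower peak).

    Returns (k, peak, recomputed), or None if no uniform segmentation fits.
    """
    best = None
    for k in range(1, n + 1):
        peak, recomputed = segment_cost(n, k)
        if peak > budget:
            continue  # this segmentation exceeds the memory budget
        candidate = (recomputed, peak, k)
        if best is None or candidate < best:
            best = candidate
    if best is None:
        return None
    recomputed, peak, k = best
    return k, peak, recomputed


print(best_segment_length(16, 16))  # (1, 16, 0): keep everything, no recomputation
print(best_segment_length(16, 7))   # (4, 7, 12): tighter budget, more recomputation
print(best_segment_length(16, 6))   # None: no uniform segmentation fits
```

Lowering the budget forces more recomputation, and below some threshold no schedule in this toy family fits at all. Rockmate searches a much richer space of schedules than uniform segments, but it faces the same kind of trade-off.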
Complete example
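import torch
from rockmate import Rockmate
from torchvision.models import resnet101

# Both the model and the sample input must live on the GPU
device = torch.device("cuda")
model = resnet101().to(device)
x = torch.randn([100, 3, 128, 128]).to(device)

# Memory budget: 2 * 1024**3 = 2 GiB
m_budget = 2 * 1024**3

# Build the memory-constrained module from the model, a sample input and the budget
rkMod = Rockmate(model, x, m_budget)

loss = rkMod(x).mean()
loss.backward()
# As in the original example, loss.backward() is followed by an explicit rkMod.backward()
rkMod.backward()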
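The budget in the example, `2 * 1024**3`, appears to be expressed in bytes (2 GiB). When choosing a budget it can help to compare it with the size of the tensors involved. The sketch below does that arithmetic for the example's sample input, assuming the default float32 dtype (4 bytes per element) of `torch.randn`. `tensor_bytes` is a hypothetical helper, not part of Rockmate.

```python
from math import prod

GiB = 1024**3
MiB = 1024**2


def tensor_bytes(shape, bytes_per_element=4):
    """Size in bytes of a dense tensor; 4 bytes per element assumes float32."""
    return prod(shape) * bytes_per_element


m_budget = 2 * GiB                         # budget used in the example
sample = tensor_bytes([100, 3, 128, 128])  # the sample input x

print(m_budget)       # 2147483648
print(sample)         # 19660800
print(sample / MiB)   # 18.75
```

The input alone takes about 18.75 MiB. Most of the budget goes to the model's parameters and, above all, to the intermediate activations, which is what rematerialization reduces.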
Source distribution: rockmate-1.0.1.tar.gz (64.2 kB)