A PyTorch helper library designed to save only the changes in a fine-tuned base model.
pytorch-diff-checkpoint
pytorch-diff-checkpoint is a simple library designed to efficiently save only the modified parameters of a fine-tuned base model. It is particularly useful when storage is limited, because only the altered parameters are written to disk instead of a full copy of the model.
It decides whether a parameter has changed using two checks: the parameter's "requires_grad" attribute, and whether its first element differs from the corresponding value in the base model.
This automatically covers both parameters updated by the optimizer and statistics such as the running averages tracked by batch norm layers.
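The detection rule described above can be sketched in plain PyTorch. This is a hypothetical re-implementation for illustration, not the library's code: `snapshot_first_elements` and `changed_entries` are made-up names, and treating an entry as changed when it is trainable or its first element moved is an assumption about how the two checks combine.

```python
from typing import Dict, List

import torch
from torch import nn


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.bn1 = nn.BatchNorm1d(10)
        self.fc2 = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc2(self.bn1(torch.relu(self.fc1(x))))


def snapshot_first_elements(model: nn.Module) -> Dict[str, torch.Tensor]:
    # Hypothetical helper: copy the first element of every parameter and buffer.
    return {name: t.flatten()[0].clone() for name, t in model.state_dict().items()}


def changed_entries(model: nn.Module, base_first: Dict[str, torch.Tensor]) -> List[str]:
    # Hypothetical helper. Assumed rule: an entry counts as changed if it is a
    # trainable parameter OR its first element differs from the base snapshot.
    params = dict(model.named_parameters())
    changed = []
    for name, tensor in model.state_dict().items():
        param = params.get(name)
        trainable = param is not None and param.requires_grad
        first_moved = not torch.equal(tensor.flatten()[0], base_first[name])
        if trainable or first_moved:
            changed.append(name)
    return changed


torch.manual_seed(0)
model = SimpleModel()
base_first = snapshot_first_elements(model)

# Freeze fc1 so it is neither trainable nor modified.
for p in model.fc1.parameters():
    p.requires_grad_(False)

# A forward pass in training mode updates the batch norm running statistics,
# even without gradients.
model.train()
with torch.no_grad():
    model(torch.randn(64, 10))

print(changed_entries(model, base_first))
```

Batch norm running statistics are buffers rather than parameters, so they carry no `requires_grad` flag; in this sketch only the first-element comparison can catch them, which is why both checks are useful. Comparing a single element is cheap, but it assumes a fine-tuned tensor rarely keeps its first value exactly unchanged.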
Installation
poetry add pytorch-diff-checkpoint
Usage
```python
import torch
from torch.nn import Module
from diff_checkpoint import DiffCheckpoint


class SimpleModel(Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 10)
        self.bn1 = torch.nn.BatchNorm1d(10)
        self.fc2 = torch.nn.Linear(10, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.bn1(x)
        x = self.fc2(x)
        return x


model = SimpleModel()

# Create a DiffCheckpoint from the base model (before fine-tuning)
diff_checkpoint = DiffCheckpoint.from_base_model(model)

# Train
# ...

# Save only the parameters that changed relative to the base model
diff_checkpoint.save(model, 'diff_checkpoint.pth')
```
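The usage above stops at `save`, and the project description does not document the saved file format or a loading API. The following is a minimal sketch of the full round trip in plain PyTorch, assuming the diff is stored as a partial state dict. It does not call `DiffCheckpoint`; the names `base_state` and `diff` are illustrative, and for brevity it keeps entries using a whole-tensor comparison instead of the first-element check, so it stands in for the concept rather than the library's implementation.

```python
import copy
import io

import torch
from torch import nn


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.bn1 = nn.BatchNorm1d(10)
        self.fc2 = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc2(self.bn1(torch.relu(self.fc1(x))))


torch.manual_seed(0)
base = SimpleModel()
base_state = copy.deepcopy(base.state_dict())

# Fine-tune a copy with fc1 frozen (one SGD step on random data).
model = copy.deepcopy(base)
for p in model.fc1.parameters():
    p.requires_grad_(False)
opt = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)
model.train()
loss = model(torch.randn(32, 10)).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()

# Keep only entries that differ from the base (whole-tensor check for brevity).
diff = {
    name: t.clone()
    for name, t in model.state_dict().items()
    if not torch.equal(t, base_state[name])
}

# Serialize the partial state dict to an in-memory buffer instead of a file.
buf = io.BytesIO()
torch.save(diff, buf)
buf.seek(0)

# Restore: rebuild the base model, then overlay the saved differences.
restored = SimpleModel()
restored.load_state_dict(base_state)
result = restored.load_state_dict(torch.load(buf), strict=False)
print("restored keys:", sorted(diff))
print("left at base values:", result.missing_keys)

# Both models now hold identical weights and statistics.
model.eval()
restored.eval()
x = torch.randn(4, 10)
print(torch.allclose(model(x), restored(x)))
```

Loading with `strict=False` leaves every key absent from the diff (here the frozen `fc1.weight` and `fc1.bias`) at the base model's values, which is why the original base weights must still be available when restoring a differential checkpoint.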
File details
Details for the file pytorch_diff_checkpoint-1.1.0-py3-none-any.whl.
File metadata
- Download URL: pytorch_diff_checkpoint-1.1.0-py3-none-any.whl
- Upload date:
- Size: 3.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.8.3 CPython/3.8.8 Linux/6.5.0-35-generic
File hashes
Algorithm | Hash digest
---|---
SHA256 | 01f3b4c11c8c0208fbb359c15ca32da5df4e7dcea2e032b05dca3d1b164e7170
MD5 | 4307f613827591b9492f95b03b4f2320
BLAKE2b-256 | a4a21e61f40e054210b047e48c04043df8a2a425de53a7b80901f4b22037f4b6