LC-Checkpoint
LC-Checkpoint is a Python package that implements the LC-Checkpoint method for compressing and checkpointing PyTorch models during training.
Installation
You can install LC-Checkpoint using pip:
pip install lc_checkpoint
Usage
To use LC-Checkpoint in your PyTorch training script, you can follow these steps:
- Import the LC-Checkpoint module:
from lc_checkpoint import main as lc
- Initialize the LC-Checkpoint method with your PyTorch model, optimizer, loss function, and other hyperparameters:
# `model` is your PyTorch model instance
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()
checkpoint_dir = 'checkpoints/lc-checkpoint'
num_buckets = 5
num_bits = 32
prev_state_dict = model.state_dict()
lc.initialize(model, optimizer, criterion, checkpoint_dir, num_buckets, num_bits)
- Use the LC-Checkpoint method in your training loop:
start_time = time.time()
total_restore_time = 0.0

for epoch in range(epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # Load the previous checkpoint if one exists
        try:
            # Find the latest checkpoint file
            lc_checkpoint_files = glob.glob(
                os.path.join(checkpoint_dir, 'lc_checkpoint_epoch*.pt'))
            latest_checkpoint_file = max(lc_checkpoint_files, key=os.path.getctime)
            prev_state_dict, epoch_loaded = lc.load_checkpoint(latest_checkpoint_file)
            print('Restored latest checkpoint:', latest_checkpoint_file)
            latest_checkpoint_file_size = os.path.getsize(latest_checkpoint_file)
            latest_checkpoint_file_size_kb = latest_checkpoint_file_size / 1024
            print('Latest checkpoint file size:', latest_checkpoint_file_size_kb, 'KB')
            restore_time = time.time() - start_time
            total_restore_time += restore_time
            print('Time taken to restore checkpoint:', restore_time)
            start_time = time.time()  # reset the start time
            print('-' * 50)
        except ValueError:
            pass  # no checkpoint files yet

        # Get the inputs and labels
        inputs, labels = data

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        new_state_dict = model.state_dict()
        # Convert each tensor to a NumPy array and concatenate them
        δt = np.concatenate(
            [tensor.detach().cpu().numpy().flatten() for tensor in new_state_dict.values()])
        prev_state_dict = new_state_dict

        # Save the checkpoint
        compressed_data = lc.compress_data(δt, num_bits=num_bits, k=num_buckets)
        save_start_time = time.time()  # record the start time
        lc.save_checkpoint('checkpoint.pt', compressed_data, epoch, i)
        save_time = time.time() - save_start_time  # time taken to save the checkpoint

        # Print statistics
        running_loss += loss.item()
        if i % 1 == 0:  # print every mini-batch
            print('[Epoch: %d, Iteration: %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 1))
            running_loss = 0.0
            print('Time taken to save checkpoint:', save_time)
API Reference
lc_checkpoint.initialize(model, optimizer, criterion, checkpoint_dir, num_buckets, num_bits)
Initializes the LC-Checkpoint method with the given PyTorch model, optimizer, loss function, checkpoint directory, number of buckets, and number of bits.
lc_checkpoint.compress_data(δt, num_bits=num_bits, k=num_buckets, treshold=True)
Compresses the model parameters and returns the compressed data.
lc_checkpoint.decode_data(encoded)
Decodes the compressed data and returns the original model parameters.
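To make the compress/decode pair concrete, here is a minimal NumPy sketch of bucket quantization in the spirit of the method: values are grouped into k buckets and each value is replaced by its bucket's mean, so only a small bucket index plus k means must be stored. This is an illustrative assumption about the approach, not the package's actual implementation; `compress_data` and `decode_data` below are standalone stand-ins.

```python
import numpy as np

def compress_data(delta, k=5):
    """Quantize a delta vector into k buckets (illustrative sketch)."""
    # Bucket boundaries spanning the value range of the delta
    edges = np.linspace(delta.min(), delta.max(), k + 1)
    # np.digitize assigns each value a bucket index; clip to [1, k], shift to [0, k-1]
    indices = np.clip(np.digitize(delta, edges), 1, k) - 1
    # Representative value for each bucket: the mean of its members
    means = np.array([delta[indices == b].mean() if np.any(indices == b) else 0.0
                      for b in range(k)])
    return indices.astype(np.uint8), means

def decode_data(encoded):
    """Reconstruct the (lossy) delta from bucket indices and bucket means."""
    indices, means = encoded
    return means[indices]

delta = np.array([0.01, -0.02, 0.5, 0.49, -0.01])
encoded = compress_data(delta, k=2)
restored = decode_data(encoded)  # same shape as delta, values snapped to bucket means
```

With k=2, the two large values 0.5 and 0.49 share one bucket and decode to their mean 0.495, while the three near-zero values share the other; storing one byte per index instead of a float is where the savings come from.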
lc_checkpoint.save_checkpoint(filename, compressed_data, epoch, iteration)
Saves the compressed data to a file with the given filename, epoch, and iteration.
lc_checkpoint.load_checkpoint(filename)
Loads the compressed data from a file with the given filename.
lc_checkpoint.calculate_compression_rate(prev_state_dict, num_bits=num_bits, num_buckets=num_buckets)
Calculates the compression rate of the LC-Checkpoint method based on the previous state dictionary and the current number of bits and buckets.
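As a rough sanity check on what such a rate means (an assumption about the arithmetic, not the package's exact formula): if each parameter originally occupies num_bits bits and bucket quantization replaces it with a ceil(log2(k))-bit bucket index, the rate is roughly num_bits divided by that index width. The helper below is hypothetical:

```python
import math

def approx_compression_rate(num_bits=32, num_buckets=5):
    # Bits per parameter after quantization: just the bucket index width
    bits_after = math.ceil(math.log2(num_buckets))
    return num_bits / bits_after

rate = approx_compression_rate(32, 5)  # ceil(log2 5) = 3 bits, so about 32/3 ≈ 10.7x
```

This ignores the small fixed cost of storing the k bucket representatives, so the real rate is slightly lower for small models.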
License
LC-Checkpoint is licensed under the MIT License. See the LICENSE file for more information.
Acknowledgements
LC-Checkpoint is based on the paper "On Efficient Constructions of Checkpoints" by Yu Chen, Zhenming Liu, Bin Ren, and Xin Jin.
Hashes for lc_checkpoint-0.2.1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 154a7d7c20ea3b1b198b0d9034fa8dec18db459fe9859b02d200ed49900900d0
MD5 | aebd99470b70775bc4710697f3867c23
BLAKE2b-256 | 5d255a6dca987b8d2a54070a82e987bd3dc9da82deee0af6a858af75e86f7c87