
Torch-Utils, a library of Deep Learning development utilities for PyTorch.

Project description

torch-utils

This repository contains useful functions and classes for Deep Learning engineers using PyTorch.

Installation

You can install this package using pip. The package is published on PyPI as pytorch-utilities:

pip install pytorch-utilities
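
Note that the distribution name on PyPI (pytorch-utilities) differs from the import name, torchutils, which is used throughout the examples below. A quick way to confirm the install:

# Verify that the package is importable after installation
from torchutils.schedulers import CosineAnnealingLinearWarmup, layerwise_lrd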

Cosine Annealing with Linear Warmup Learning Rate

Using this scheduler is as simple as using any built-in PyTorch scheduler.

Example usage:

import torch
from torch.optim import AdamW
from torchutils.schedulers import CosineAnnealingLinearWarmup


# Initialize your model and dataloader
# model = ...
# dataloader = ...
# loss_fn = ...

# Initialize the optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=0.0005)
scheduler = CosineAnnealingLinearWarmup(optimizer, warmup_epochs=5, max_epochs=100)

# If you want to step the scheduler after each iteration (batch), adjust the warmup_epochs and max_epochs accordingly
# scheduler = CosineAnnealingLinearWarmup(optimizer, warmup_epochs=5 * len(dataloader), max_epochs=100 * len(dataloader))

# Training loop
for epoch in range(100):
    for inputs, targets in dataloader:
        optimizer.zero_grad()
  
        # Forward pass
        outputs = model(inputs)
  
        # Compute loss
        loss = loss_fn(outputs, targets)
  
        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # If you want to step the scheduler after each iteration (batch), uncomment the following line
        # scheduler.step()
  
    # If you're stepping the scheduler after each epoch, do it here
    scheduler.step()
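
If you just want to inspect the shape of the schedule before wiring it into a full training run, a minimal sketch like the following works, assuming the scheduler follows the standard PyTorch scheduler interface and exposes get_last_lr():

import torch
from torch.optim import AdamW
from torchutils.schedulers import CosineAnnealingLinearWarmup

# A single dummy parameter is enough to drive the optimizer
param = torch.nn.Parameter(torch.zeros(1))
optimizer = AdamW([param], lr=0.0005)
scheduler = CosineAnnealingLinearWarmup(optimizer, warmup_epochs=5, max_epochs=100)

for epoch in range(100):
    optimizer.step()  # a real run would do a forward/backward pass first
    print(epoch, scheduler.get_last_lr()[0])  # linear ramp for 5 epochs, then cosine decay
    scheduler.step()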

Layer-wise Learning Rate Decay

Using layerwise_lrd, you can assign a different learning rate to each layer of your model, increasing from the first layer to the last. This is a widely used fine-tuning technique for deep vision models: it keeps the parameters of the early layers largely unchanged, since the low-level features they extract, such as edges and shapes, transfer well to most image domains and need little adaptation.

Currently, only ViT models implemented in timm (or models whose layer names follow the same convention) are supported.
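
Conceptually, the decay assigns each transformer block a learning-rate scale that shrinks geometrically with its distance from the output head. A minimal sketch of that scaling rule, independent of this library (num_layers and the exponent convention here are assumptions based on common layer-decay implementations):

# Hypothetical illustration of layer-wise learning-rate scaling (not this library's API)
layer_decay = 0.75
num_layers = 12           # e.g. a ViT-Base has 12 transformer blocks
base_lr = 0.001

for layer_id in range(num_layers + 1):   # layer 0 = patch/positional embeddings
    lr_scale = layer_decay ** (num_layers - layer_id)
    print(f"layer {layer_id:2d}: lr = {base_lr * lr_scale:.6f}")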

Example Usage:

import timm
import torch

from torchutils.schedulers import layerwise_lrd


# Load the model
model = timm.create_model('vit_base_patch14_dinov2.lvd142m', num_classes=1000)

# Fetch model's parameter groups (in place of `model.parameters()`)
param_groups = layerwise_lrd(
    model,
    weight_decay=0.05,
    no_weight_decay_list=model.no_weight_decay(),
    layer_decay=0.75,
)

# Set the optimizer
optimizer = torch.optim.AdamW(param_groups, lr=0.001)

# Rest of your training code as usual
# ...
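
The two utilities compose naturally; for example (a sketch reusing the hyperparameters above, and assuming the scheduler handles multiple parameter groups the way standard PyTorch schedulers do):

from torchutils.schedulers import CosineAnnealingLinearWarmup

# Warmup + cosine annealing applied on top of the layer-wise parameter groups
scheduler = CosineAnnealingLinearWarmup(optimizer, warmup_epochs=5, max_epochs=100)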

Download files

Download the file for your platform.

Source Distribution

pytorch_utilities-0.0.2.tar.gz (176.9 kB)

Uploaded Source

Built Distribution

pytorch_utilities-0.0.2-py3-none-any.whl (8.5 kB)

Uploaded Python 3

File details

Details for the file pytorch_utilities-0.0.2.tar.gz.

File metadata

  • Download URL: pytorch_utilities-0.0.2.tar.gz
  • Upload date:
  • Size: 176.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.11.5

File hashes

Hashes for pytorch_utilities-0.0.2.tar.gz

  • SHA256: f51cf898a2939ee9115d1885783f5c76d4a4bfac1ee25b17ae13c0c7e8e8ba89
  • MD5: cc7cc7251441b17b26d712754d148e10
  • BLAKE2b-256: ee0b00ba029ac441f7b4fc5a0a75ca51dedd57d6b54dff108cf8cdc1a7058bae


File details

Details for the file pytorch_utilities-0.0.2-py3-none-any.whl.

File metadata

File hashes

Hashes for pytorch_utilities-0.0.2-py3-none-any.whl

  • SHA256: 95a64d531d770cd783aa066da42370a9a3e7539510817aab9919db5cbd9bbdcb
  • MD5: f9d0bf0bf84b3092b2431fa96923e67b
  • BLAKE2b-256: 6efecfdd73b26d193dccf85133c3bc1a3e3ee84525bbc7201edeb99d910de905

