


Gradient Cache Contrastive Learning

Gradient Cache Contrastive Learning is a technique for scaling the batch size of contrastive learning in computer vision far beyond GPU/TPU memory limits. Training that used to require heavy hardware, e.g. 8 V100 GPUs, can be done on a single GPU. Gradient Cache also lets users replace large-memory GPUs/TPUs with more cost-efficient systems that have high FLOPs but little memory. It is an adaptation of the method from the paper Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup to the SimCLR, SupCon and SelfCon losses.
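The reason a large-batch gradient can be assembled from small sub-batches, as described in the Gradient Cache paper, is the chain rule. Writing f_i for the representation of sample i, θ for the encoder parameters and L for the contrastive loss over all N representations:

```latex
\frac{\partial \mathcal{L}}{\partial \theta}
  = \sum_{i=1}^{N} \frac{\partial \mathcal{L}}{\partial f_i}\,
                   \frac{\partial f_i}{\partial \theta}
```

The factors ∂L/∂f_i (the cached gradients) depend only on the representations, which can be computed without keeping activations in memory. Each term ∂f_i/∂θ can then be back-propagated one sub-batch at a time, so peak activation memory is governed by the sub-batch size rather than by N.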

Installation

First install PyTorch. Then install grad_cache_con_learning by running:

pip install grad-cache-con-learning

Usage

Gradient caching functionality is implemented in the GradCache class.

Initialization

The class's __init__ method sets up the cache and takes a functional parameter, loss_fn, that lets you plug in your own loss function.

grad_cache_con_learning.GradCache(
  model: nn.Module,
  chunk_size: int,
  loss_fn: Callable[..., Tensor],
  loss_type: str = "SupCon",
  fp16: bool = False,
  scaler: GradScaler = None,
)

model - The encoder model to be updated with the Gradient Cache.

chunk_size - An integer giving the sub-batch size used for each forward-backward pass. Set it based on available GPU memory; a value that is too small leaves the GPU underutilized.

loss_fn - A loss function that takes representation tensors and computes the model's loss from them. The options are grad_cache_con_learning.losses.SupConLoss for SimCLR and SupCon, and grad_cache_con_learning.losses.ConLoss for SelfCon.

loss_type - The loss type: 'SimCLR', 'SupCon' or 'SelfCon'.

fp16 - If True, run mixed precision training, which requires scaler to also be set.

scaler - A GradScaler object for automatic mixed precision training.
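A minimal sketch of what chunk_size controls, assuming it is the number of samples per sub-batch (as in the original Gradient Cache library). Plain torch.split stands in for the library's internal splitting, and the tensor shape is made up for illustration.

```python
import math

import torch

batch = torch.randn(10, 3, 32, 32)  # 10 images (hypothetical shape)
chunk_size = 4

# torch.split yields sub-batches of chunk_size along dim 0; the last one may be smaller.
chunks = torch.split(batch, chunk_size, dim=0)
sizes = [c.shape[0] for c in chunks]

print(len(chunks), sizes)  # 3 [4, 4, 2]
assert len(chunks) == math.ceil(batch.shape[0] / chunk_size)
```

Only one sub-batch's activations need to be held for back-propagation at a time, so a larger chunk_size costs more memory but needs fewer, better-utilized passes.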

Cache Gradient Step

To run a cached gradient computation step, call the cache_step function:

cache_step(
  model_input: Tensor,
  labels: Tensor = None,
  no_sync_except_last: bool = False,
  **loss_kwargs
)

Runs a single gradient cache step. When the function returns, gradients have been computed and populated on the model's weights, as if model_input had been run as one huge batch on sufficiently large hardware. Calling a GradCache object directly (via __call__) also invokes this function.

model_input - Tensor containing the input for the encoder model.

labels - Tensor containing the ground-truth labels used for training. For SimCLR, no labels are provided.

no_sync_except_last - If True, under a distributed setup, trigger gradient reduction across processes only for the last sub-batch's forward-backward pass. This can come in handy when dealing with a) a large model and/or b) a non-trivial number of sub-batches.

loss_kwargs - Additional keyword arguments to the loss function loss_fn.

Return - loss, the current step's loss as a scalar tensor (detached from the graph).
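To make the cache_step semantics concrete, here is a minimal pure-PyTorch sketch of the gradient-cache computation described in the paper. It is not the library's implementation: cached_step and toy_contrastive_loss are hypothetical names, the toy loss is a simple InfoNCE stand-in for SupConLoss/ConLoss, and the sketch does not replay random state between passes, so it assumes a deterministic model without dropout or batch norm. The no_sync branch only fires when the model actually has a no_sync method, as DistributedDataParallel does.

```python
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch.nn.functional as F


def toy_contrastive_loss(z, temperature=0.5):
    """InfoNCE over 2B embeddings where row i's positive is row (i + B) mod 2B."""
    z = F.normalize(z, dim=1)
    n = z.shape[0]
    logits = z @ z.T / temperature
    logits = logits.masked_fill(torch.eye(n, dtype=torch.bool), -1e9)  # drop self-pairs
    targets = (torch.arange(n) + n // 2) % n
    return F.cross_entropy(logits, targets)


def cached_step(model, loss_fn, inputs, chunk_size, no_sync_except_last=False):
    """Hypothetical gradient-cache step; not the grad_cache_con_learning code."""
    chunks = torch.split(inputs, chunk_size, dim=0)

    # 1) Representations for all chunks, without storing activations.
    with torch.no_grad():
        reps = torch.cat([model(c) for c in chunks], dim=0)

    # 2) Full-batch loss on the representations; cache dL/d(reps).
    reps = reps.detach().requires_grad_()
    loss = loss_fn(reps)
    loss.backward()
    cached_grads = torch.split(reps.grad, chunk_size, dim=0)

    # 3) Re-encode each chunk with a graph and back-propagate its cached gradient.
    #    Under DDP, skip the gradient all-reduce for every chunk except the last.
    last = len(chunks) - 1
    for i, (chunk, grad) in enumerate(zip(chunks, cached_grads)):
        skip_sync = no_sync_except_last and i < last and hasattr(model, "no_sync")
        with model.no_sync() if skip_sync else nullcontext():
            model(chunk).backward(gradient=grad)
    return loss.detach()


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 4))
x = torch.randn(12, 8)

# Reference: ordinary full-batch forward/backward.
model.zero_grad()
ref_loss = toy_contrastive_loss(model(x))
ref_loss.backward()
ref_grads = [p.grad.clone() for p in model.parameters()]

# Gradient cache with sub-batches of 4 samples.
model.zero_grad()
gc_loss = cached_step(model, toy_contrastive_loss, x, chunk_size=4)
```

After the call, every parameter's .grad matches the full-batch reference up to floating-point rounding, which is the "as if run as one huge batch" property; only stage 3 ever builds a graph, and it does so for one chunk at a time.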

Example Usage with Contrastive Losses (SimCLR, SupCon, SelfCon)

You need to preserve the original training procedures of SimCLR, SupCon and SelfCon; GradCache works only with these original methods.

SupCon Example

First, initialize the GradCache object:

from grad_cache_con_learning import GradCache
from grad_cache_con_learning.losses import SupConLoss
...
loss_fn = SupConLoss()
gc = GradCache(
  model=model, 
  chunk_size=2,
  loss_fn=loss_fn,
  loss_type="SupCon"
)
...

Only replace:

...
optimizer.zero_grad()
features = model(images)
f1, f2 = torch.split(features, [batch_size, batch_size], dim=0)
features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
loss = criterion(features, y)
loss.backward()
optimizer.step()
...

with the following:

...
optimizer.zero_grad()
gc(images, y)
optimizer.step()
...
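The replaced loop follows the original SupCon training code, where the two augmented views are concatenated along the batch dimension before the encoder, and the loss receives features of shape [batch_size, n_views, dim]. The sketch below shows only that reshaping, with random tensors standing in for model(images); the sizes are made up, and it assumes images is the concatenation [view1; view2] as in the original SupCon code.

```python
import torch

batch_size, dim = 4, 128  # hypothetical sizes

# Stand-in for model(images), where images = cat([view1, view2], dim=0).
features = torch.randn(2 * batch_size, dim)

# Rows 0..B-1 belong to the first view, rows B..2B-1 to the second view.
f1, f2 = torch.split(features, [batch_size, batch_size], dim=0)
stacked = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)

print(stacked.shape)  # torch.Size([4, 2, 128])
```

Sample i's two views end up side by side in stacked[i], which is how the loss knows which rows are positives of each other.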

SimCLR Example

First, initialize the GradCache object:

from grad_cache_con_learning import GradCache
from grad_cache_con_learning.losses import SupConLoss
...
loss_fn = SupConLoss()
gc = GradCache(
  model=model, 
  chunk_size=2, 
  loss_fn=loss_fn,
  loss_type="SimCLR" 
)
...

Only replace:

...
optimizer.zero_grad()
features = model(images)
f1, f2 = torch.split(features, [batch_size, batch_size], dim=0)
features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
loss = criterion(features)
loss.backward()
optimizer.step()
...

with the following:

...
optimizer.zero_grad()
gc(images)
optimizer.step()
...
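SimCLR and SupCon both use SupConLoss; the only difference is whether labels are passed. The sketch below is a simplified re-implementation of the supervised contrastive loss (Khosla et al.) to show why, not the package's SupConLoss, whose exact signature and defaults are not documented here; sup_con_loss is a hypothetical name, and the temperature rescaling of the original code is omitted. With labels=None, each sample's only positives are its own other views, which is the SimCLR objective.

```python
import torch
import torch.nn.functional as F


def sup_con_loss(features, labels=None, temperature=0.07):
    """Simplified SupCon loss sketch; features has shape [bsz, n_views, dim]."""
    bsz, n_views, _ = features.shape
    if labels is None:
        # SimCLR case: every sample is its own class, so positives are its other views.
        labels = torch.arange(bsz)
    # Flatten view-major: all first views, then all second views, and so on.
    z = F.normalize(torch.cat(torch.unbind(features, dim=1), dim=0), dim=1)
    all_labels = labels.repeat(n_views)
    n = z.shape[0]

    logits = z @ z.T / temperature
    self_mask = torch.eye(n, dtype=torch.bool)
    logits = logits.masked_fill(self_mask, -1e9)  # exclude self-contrast
    log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)

    # Positives share a label with the anchor, excluding the anchor itself.
    pos_mask = (all_labels[:, None] == all_labels[None, :]) & ~self_mask
    mean_log_prob_pos = (log_prob * pos_mask).sum(1) / pos_mask.sum(1)
    return -mean_log_prob_pos.mean()


torch.manual_seed(0)
feats = torch.randn(4, 2, 16)
simclr = sup_con_loss(feats)  # no labels, as in gc(images)
supcon = sup_con_loss(feats, torch.tensor([0, 0, 1, 1]))  # shared labels add positives
```

Any set of distinct labels gives the same value as labels=None, while shared labels add positives and change the loss; this is why the SimCLR example calls gc(images) without labels.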

SelfCon Example

First, initialize the GradCache object:

from grad_cache_con_learning import GradCache
from grad_cache_con_learning.losses import ConLoss
...
loss_fn = ConLoss()
gc = GradCache(
  model=model, 
  chunk_size=2, 
  loss_fn=loss_fn,
  loss_type="SelfCon"
)
...

Only replace:

...
optimizer.zero_grad()
features = model(images)
f1, f2 = features
features = torch.cat([f.unsqueeze(1) for f in f1] + [f2.unsqueeze(1)], dim=1)
loss = criterion(features, labels)
loss.backward()
optimizer.step()
...

with the following:

...
optimizer.zero_grad()
gc(images, labels)
optimizer.step()
...
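In SelfCon the views come from the encoder's own outputs: the model returns a collection of sub-network features f1 and the backbone feature f2, which the replaced loop stacks into [batch_size, n_views, dim] for ConLoss. The sketch shows only that stacking, with random tensors standing in for model(images); the sizes and the single sub-network exit are assumptions for illustration.

```python
import torch

bsz, dim = 4, 128  # hypothetical sizes

# Stand-ins for model(images) in SelfCon: sub-network outputs and the final output.
f1 = [torch.randn(bsz, dim)]  # one sub-network exit (assumption)
f2 = torch.randn(bsz, dim)

# Each output becomes one "view" along dim 1; the backbone output goes last.
features = torch.cat([f.unsqueeze(1) for f in f1] + [f2.unsqueeze(1)], dim=1)

print(features.shape)  # torch.Size([4, 2, 128])
```

With more sub-network exits, n_views grows to len(f1) + 1 while the batch dimension stays the same.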


Download files


Source Distributions

No source distribution files are available for this release.

Built Distribution

grad_cache_con_learning-1.0.2-py3-none-any.whl (11.8 kB)

Uploaded Python 3

File details


File hashes

Hashes for grad_cache_con_learning-1.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 19eff731dd2a5f280e46474bd57d596667f631c34c8b908db8a26205b075661c
MD5 94975bbaf071da39abfeeb2a33ff3fd1
BLAKE2b-256 785ddb24113a060edc794eab8efeac0440eaed8e7723b11f0d2d08600c6f62fc

