
A comprehensive parameter tracking and storage system for TensorFlow models using Zarr


ParamLake

A tool for tracking, storing, and analyzing deep learning model parameters, gradients, and activations during training. ParamLake writes this data to Zarr stores, with optional Icechunk support for transactional cloud storage.

Features

  • Minimal Code Changes: Simply add a decorator to your training function
  • Automatic Gradient Capture: Multiple methods for tracking gradients without manual instrumentation
  • Comprehensive Data Collection: Capture trainable weights, non-trainable variables, gradients, and activations
  • Optimized Storage: Specialized chunk sizes and compression strategies for different tensor types
  • Efficient Analysis: Tools for analyzing and visualizing model parameters and gradients
  • Transactional Storage: Supports Icechunk for cloud-native transactional tensor storage
  • Framework-Agnostic Design: Core schema designed to work across TensorFlow, PyTorch, and JAX (only the TensorFlow implementation is currently provided)
  • Flexible Configuration: YAML-based configuration for customizing what and how data is collected
  • Production Ready: Optimized for minimal training overhead while providing comprehensive parameter tracking

Installation

# Basic installation
pip install paramlake

# With Icechunk support (for transactional cloud storage)
pip install paramlake icechunk

# With visualization support
pip install paramlake matplotlib

Quick Start

import tensorflow as tf
from paramlake import paramlake

# 1. Optional: Configure via YAML
# config.yaml:
# capture_frequency: 1  # every epoch
# compression:
#   algorithm: blosc_zstd
#   level: 3
# output_path: "model_data.zarr"
# gradients:
#   enabled: true
#   auto_tracking: true
#   track_method: "auto"  # Can be "auto", "train_step", "optimizer", or "callback"

# 2. Add the decorator to your training function
@paramlake(config="config.yaml")  # or inline config: @paramlake(capture_frequency=5)
def train_model():
    # Define and train your model as usual
    model = tf.keras.Sequential([...])
    model.compile(...)
    model.fit(...)

# 3. Call your training function - ParamLake will automatically log parameters and gradients
train_model()

# 4. Analyze the data
from paramlake import ZarrModelAnalyzer

analyzer = ZarrModelAnalyzer("model_data.zarr")
analyzer.plot_weight_evolution("dense_1/kernel")

# 5. Analyze gradient behavior
analyzer.plot_gradient_norm_by_layer()
gradient_stats = analyzer.analyze_gradient_statistics()
print(f"Gradient coverage: {gradient_stats['summary']['gradient_coverage']:.2%}")
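For readers curious how a decorator can accept either a configuration source or inline keyword options, here is a minimal sketch. It is not ParamLake's implementation: the `tracked` decorator, the `DEFAULTS` dictionary, and the merge order (defaults, then config, then inline overrides) are all assumptions made for illustration, and the config is passed as a plain dict rather than a YAML path to keep the example dependency-free.

```python
import functools

# Hypothetical defaults; the keys mirror options shown in this README,
# but the values and the merge order are assumptions.
DEFAULTS = {"capture_frequency": 1, "output_path": "model_data.zarr"}


def tracked(config=None, **overrides):
    """Toy decorator factory: merge settings, then wrap the training function."""
    settings = {**DEFAULTS, **(config or {}), **overrides}

    def decorator(train_fn):
        @functools.wraps(train_fn)
        def wrapper(*args, **kwargs):
            # A real tracker would open its store and hook into the model here.
            wrapper.runs.append(dict(settings))
            return train_fn(*args, **kwargs)

        wrapper.runs = []          # one settings record per call
        wrapper.settings = settings
        return wrapper

    return decorator


@tracked(capture_frequency=5)
def train_model():
    return "trained"


result = train_model()
```

The training function itself is unchanged; everything the tracker needs is set up in the wrapper, which is what keeps the required code changes small.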

Automatic Gradient Capture

ParamLake provides multiple methods to automatically capture gradients during training:

# Configure gradient tracking method in the decorator
@paramlake(
    gradients={
        "enabled": True,
        "auto_tracking": True,
        "track_method": "auto"  # Automatically select the best method
    }
)
def train_model():
    model = create_model()
    model.compile(...)
    model.fit(...)
    return model

# Alternatively, use a configuration file
@paramlake(config="paramlake_config.yaml")
def train_model():
    # ParamLake will handle gradient tracking based on config file settings
    model = create_model()
    model.compile(...)
    model.fit(...)
    return model

The available gradient tracking methods are:

  • "auto": Automatically detect and use the best method for the model
  • "train_step": Override the model's train_step method
  • "optimizer": Override the optimizer's apply_gradients method
  • "callback": Use a callback-based approach with GradientTape

Configuration Options

ParamLake can be configured through a YAML file or by passing parameters directly to the decorator:

# Basic options
output_path: "model_data.zarr"  # Where to store the dataset
capture_frequency: 5  # Capture every 5 steps/epochs
capture_gradients: true  # Whether to capture gradients
capture_activations: false  # Whether to capture activations

# Gradient options
gradients:
  enabled: true
  auto_tracking: true
  track_method: "auto"  # "auto", "train_step", "optimizer", or "callback"

# Layer filtering
include_layers: ["dense*", "conv*"]  # Only include layers matching patterns
exclude_layers: ["batch_normalization*"]  # Exclude specific layers

# Storage optimization
compression:
  algorithm: blosc_zstd  # Compression algorithm: blosc, zstd, lz4, etc.
  level: 3  # Compression level (higher = more compression but slower)
  shuffle: true  # Whether to shuffle data before compression

# Gradient-specific compression
gradient_compression:
  algorithm: blosc_zstd
  level: 5  # Higher compression for gradients
  shuffle: true

# Chunking strategy
chunking:
  time_dimension: 1  # Number of time steps per chunk
  spatial_dimensions: auto  # Automatic sizing based on tensor shape
  target_chunk_size: 1048576  # Target chunk size in bytes (1 MiB)
  gradient_chunk_size: 524288  # Smaller chunks for gradients (512 KiB)
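The include_layers and exclude_layers options above take wildcard patterns. As a rough sketch of how such filtering could behave, the example below uses glob matching from Python's standard library. Two assumptions are made here that the README does not confirm: patterns follow shell-style glob rules, and an exclude match takes precedence over an include match. The `select_layers` helper is hypothetical.

```python
from fnmatch import fnmatchcase


def select_layers(layer_names, include=None, exclude=None):
    """Return the layer names that pass the include/exclude glob patterns."""
    selected = []
    for name in layer_names:
        # No include list means every layer is a candidate.
        included = not include or any(fnmatchcase(name, p) for p in include)
        # Assumption: an exclude match always wins over an include match.
        excluded = any(fnmatchcase(name, p) for p in (exclude or []))
        if included and not excluded:
            selected.append(name)
    return selected


layers = ["conv2d", "batch_normalization", "dense", "dense_1", "dropout"]
kept = select_layers(
    layers,
    include=["dense*", "conv*"],
    exclude=["batch_normalization*"],
)
print(kept)  # ['conv2d', 'dense', 'dense_1']
```

With these patterns, "dropout" is dropped because it matches no include pattern, and "batch_normalization" would be dropped even if an include pattern matched it.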

Cloud Storage with Icechunk

ParamLake supports Icechunk, a transactional storage engine for tensor data designed for cloud object storage. This provides:

  • Transactional Consistency: Prevent data corruption when multiple processes write to the store
  • Version Control: Track model parameters across different training runs with branches and tags
  • Time Travel: Go back to previous states of model parameters for comparison
  • Cloud Optimization: Optimized for Amazon S3, Google Cloud Storage, and Azure Blob Storage

Using ParamLake with Icechunk

import tensorflow as tf
from paramlake import paramlake

# Configure S3 storage with Icechunk backend
@paramlake(
    storage_backend="icechunk",
    storage_type="s3",
    bucket="paramlake",
    prefix="mnist_training",
    region="us-east-1",
    create_repo=True,
    icechunk={
        "commit_frequency": 5,  # Commit changes every 5 epochs
        "tag_snapshots": True,  # Create tags for snapshots
    },
    capture_frequency=1,
    gradients={
        "enabled": True,
        "auto_tracking": True
    }
)
def train_model():
    # Train your model as usual
    model = tf.keras.Sequential([...])
    model.compile(...)
    model.fit(...)
    return model

# Analyze data with IcechunkModelAnalyzer
from paramlake import IcechunkModelAnalyzer

analyzer = IcechunkModelAnalyzer({
    "type": "s3",
    "bucket": "paramlake",
    "prefix": "mnist_training"
})

# Analyze snapshots, compare runs, gradient behavior, etc.
analyzer.plot_weight_evolution("dense/kernel")
analyzer.plot_gradient_norm_by_layer()
gradient_stats = analyzer.analyze_gradient_statistics()

Analyzing the Data

ParamLake provides utilities for analyzing the collected data:

from paramlake import ZarrModelAnalyzer

analyzer = ZarrModelAnalyzer("model_data.zarr")

# Get layer statistics over time
stats = analyzer.get_layer_stats("dense_1/kernel")

# Plot weight evolution
analyzer.plot_weight_evolution("dense_1/kernel")

# Analyze gradients
gradient_stats = analyzer.analyze_gradient_statistics()
for layer_name, layer_stats in gradient_stats["layer_stats"].items():
    for tensor_name, tensor_stats in layer_stats.items():
        print(f"{layer_name}/{tensor_name}:")
        print(f"  Mean gradient magnitude: {tensor_stats['mean_abs']:.6f}")
        print(f"  Max gradient magnitude: {tensor_stats['max']:.6f}")
        print(f"  Zero fraction: {tensor_stats['zero_fraction']:.2%}")

# Plot gradient norms
analyzer.plot_gradient_norm_by_layer()

# Compare two training runs
analyzer.compare_runs("run1.zarr", "run2.zarr")
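To give a feel for what the statistics printed above could represent, here is a small pure-Python sketch that computes mean_abs, max, and zero_fraction for one tensor, plus a coverage ratio across tensors. The definitions are assumptions inferred from the field names and print labels, not taken from ParamLake's source: "max" is treated as the largest absolute value, and coverage as the share of tracked tensors with at least one captured gradient. Both helper functions are hypothetical.

```python
def gradient_stats(values):
    """Summarize a flat list of gradient values for one tensor."""
    if not values:
        raise ValueError("need at least one gradient value")
    magnitudes = [abs(v) for v in values]
    return {
        "mean_abs": sum(magnitudes) / len(magnitudes),
        "max": max(magnitudes),  # assumption: maximum magnitude, not signed max
        "zero_fraction": sum(1 for v in values if v == 0) / len(values),
    }


def gradient_coverage(tensors):
    """Fraction of tracked tensors that have at least one captured gradient."""
    if not tensors:
        return 0.0
    return sum(1 for grads in tensors.values() if grads) / len(tensors)


captured = {
    "dense_1/kernel": [0.5, -1.5, 0.0, 2.0],
    "dense_1/bias": [],  # nothing captured for this tensor
}
stats = gradient_stats(captured["dense_1/kernel"])
coverage = gradient_coverage(captured)
print(stats)     # {'mean_abs': 1.0, 'max': 2.0, 'zero_fraction': 0.25}
print(coverage)  # 0.5
```

A high zero_fraction or low coverage is the kind of signal worth checking when a layer appears not to be learning.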

For Icechunk storage, use the IcechunkModelAnalyzer:

from paramlake import IcechunkModelAnalyzer

# Analyze S3 storage
analyzer = IcechunkModelAnalyzer({
    "type": "s3", 
    "bucket": "my-bucket", 
    "prefix": "my-training-run"
})

# Get training history (snapshots)
history = analyzer.get_training_history()

# Compare snapshots
analyzer.plot_snapshot_comparison(
    other_snapshot_id="H5CCPE350FJV69V9D0HG",
    layer_name="dense/kernel"
)

# Analyze gradients across snapshots
for snapshot in history[:3]:  # Look at the latest 3 snapshots
    temp_analyzer = IcechunkModelAnalyzer({
        "type": "s3", 
        "bucket": "my-bucket", 
        "prefix": "my-training-run"
    }, snapshot_id=snapshot["id"])
    
    # Get gradient statistics
    grad_stats = temp_analyzer.analyze_gradient_statistics()
    print(f"Snapshot {snapshot['id']}: Gradient coverage {grad_stats['summary']['gradient_coverage']:.2%}")

Extensibility

ParamLake is designed to be framework-agnostic. While the current implementation focuses on TensorFlow, the schema and storage mechanism are designed to support other frameworks like PyTorch and JAX.

License

MIT License
