Distributed PyTorch compilation cache for Baseten - Environment-aware, lock-free compilation cache management

https://www.notion.so/ml-infra/mega-base-cache-24291d247273805b8e20fe26677b7b0f

B10 Transfer

PyTorch compilation cache for Baseten deployments.

Usage

Synchronous Operations (Blocking)

import torch
import b10_transfer

# Inside your model's load() function
def load():
    # Load the cache before calling torch.compile()
    status = b10_transfer.load_compile_cache()

    # ... build or load `model` here

    # Your model compilation
    model = torch.compile(model)
    # Warm up the model with dummy prompts and the arguments your requests typically use (e.g., resolutions).
    # torch.compile compiles lazily, so these calls are what actually populate the cache.
    dummy_input = "What is the capital of France?"
    model(dummy_input)

    # ...

    # Save the cache after compilation, unless it was already loaded successfully
    if status != b10_transfer.LoadStatus.SUCCESS:
        b10_transfer.save_compile_cache()

Asynchronous Operations (Non-blocking)

import time

import torch
import b10_transfer

def load_with_async_cache():
    # Start an async cache load (returns immediately with an operation ID)
    operation_id = b10_transfer.load_compile_cache_async()

    # Poll the status periodically
    while not b10_transfer.is_transfer_complete(operation_id):
        status = b10_transfer.get_transfer_status(operation_id)
        print(f"Cache load status: {status.status}")
        time.sleep(1)

    # Check the final status
    final_status = b10_transfer.get_transfer_status(operation_id)
    if final_status.status == b10_transfer.AsyncTransferStatus.SUCCESS:
        print("Cache loaded successfully!")

    # ... build or load `model` here
    model = torch.compile(model)
    # Warm up with representative inputs before saving; torch.compile compiles lazily on the first call

    # Start an async save
    save_op_id = b10_transfer.save_compile_cache_async()

    # You can continue with other work while the save runs in the background,
    # or wait for completion if needed
    b10_transfer.wait_for_completion(save_op_id, timeout=300)  # 5-minute timeout

# With a progress callback
def on_progress(operation_id: str):
    status = b10_transfer.get_transfer_status(operation_id)
    print(f"Transfer {operation_id}: {status.status}")

operation_id = b10_transfer.load_compile_cache_async(progress_callback=on_progress)
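
The polling loop above runs until the transfer reports completion, with no upper bound on how long load() can block. The documented wait_for_completion(operation_id, timeout=...) already caps the wait; if you also want a per-second status printout, a small bounded poller does both. The helper name wait_until is hypothetical and not part of b10_transfer; it accepts any zero-argument completion check, such as a lambda wrapping b10_transfer.is_transfer_complete.

```python
import time
from typing import Callable


def wait_until(done: Callable[[], bool], timeout: float, interval: float = 1.0) -> bool:
    """Call done() every `interval` seconds until it returns True or `timeout` seconds pass.

    Returns True if the condition was met, False if the deadline passed first.
    """
    deadline = time.monotonic() + timeout
    while True:
        if done():
            return True
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        time.sleep(min(interval, remaining))  # never sleep past the deadline


# Hypothetical usage with the documented b10_transfer calls:
#   op_id = b10_transfer.load_compile_cache_async()
#   if not wait_until(lambda: b10_transfer.is_transfer_complete(op_id), timeout=300):
#       print("Cache load timed out")
```

Using time.monotonic() rather than time.time() keeps the deadline correct even if the system clock is adjusted during startup.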

Generic Async Operations

You can also use the generic async system for custom transfer operations:

import shutil
from pathlib import Path

import b10_transfer

def my_custom_callback(source: Path, dest: Path):
    # Your custom transfer logic here
    # This could be any file operation, compression, etc.
    shutil.copy2(source, dest)

# Start a generic async transfer
operation_id = b10_transfer.start_transfer_async(
    source=Path("/source/file.txt"),
    dest=Path("/dest/file.txt"),
    callback=my_custom_callback,
    operation_name="custom_file_copy",
    monitor_local=True,
    monitor_b10fs=False
)

# Use the same progress tracking as torch cache operations
b10_transfer.wait_for_completion(operation_id)
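
To make the contract of start_transfer_async concrete, here is a minimal stand-alone sketch of the same idea: run a (source, dest) callback on a background thread and track it by an operation ID. This is not how b10_transfer is implemented; start_transfer, is_complete, and wait are hypothetical names, and the sketch omits the operation_name, monitor_local, and monitor_b10fs options.

```python
import shutil
import tempfile
import uuid
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Callable, Dict, Optional

# Hypothetical module-level registry: operation ID -> Future of the running callback.
_executor = ThreadPoolExecutor(max_workers=2)
_operations: Dict[str, Future] = {}


def start_transfer(source: Path, dest: Path, callback: Callable[[Path, Path], object]) -> str:
    """Run callback(source, dest) on a worker thread and return a new operation ID."""
    operation_id = uuid.uuid4().hex
    _operations[operation_id] = _executor.submit(callback, source, dest)
    return operation_id


def is_complete(operation_id: str) -> bool:
    """True once the callback has finished, whether it succeeded or raised."""
    return _operations[operation_id].done()


def wait(operation_id: str, timeout: Optional[float] = None) -> bool:
    """Block until the operation ends; True on success, False on error or timeout."""
    try:
        _operations[operation_id].result(timeout=timeout)
        return True
    except Exception:  # covers both the callback's own exception and the timeout error
        return False


if __name__ == "__main__":
    # Usage: copy a file in the background, mirroring the shutil.copy2 callback above.
    with tempfile.TemporaryDirectory() as tmp:
        src = Path(tmp, "file.txt")
        src.write_text("hello")
        op_id = start_transfer(src, Path(tmp, "copy.txt"), shutil.copy2)
        print(op_id, wait(op_id, timeout=10))
```

A thread pool fits here because file copies are I/O-bound; the Future already records completion and any exception, so the ID registry only has to map IDs to Futures.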

Configuration

Configure via environment variables:

# Cache directories
export TORCH_CACHE_DIR="/tmp/torchinductor_root"      # Default
export B10FS_CACHE_DIR="/cache/model/compile_cache"   # Default
export LOCAL_WORK_DIR="/app"                          # Default

# Cache limits
export MAX_CACHE_SIZE_MB="1024"                       # Default (1 GB)
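
The variables and defaults above can be read as follows. This is an illustrative sketch, not the library's own configuration code: CacheConfig, load_config, and within_limit are hypothetical names, and treating MAX_CACHE_SIZE_MB as binary megabytes (1024 * 1024 bytes, so the 1024 default equals 1 GiB) is an assumption.

```python
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Mapping, Optional


@dataclass(frozen=True)
class CacheConfig:
    torch_cache_dir: Path
    b10fs_cache_dir: Path
    local_work_dir: Path
    max_cache_size_mb: int


def load_config(env: Optional[Mapping[str, str]] = None) -> CacheConfig:
    """Read the documented variables, falling back to the documented defaults."""
    env = os.environ if env is None else env
    return CacheConfig(
        torch_cache_dir=Path(env.get("TORCH_CACHE_DIR", "/tmp/torchinductor_root")),
        b10fs_cache_dir=Path(env.get("B10FS_CACHE_DIR", "/cache/model/compile_cache")),
        local_work_dir=Path(env.get("LOCAL_WORK_DIR", "/app")),
        max_cache_size_mb=int(env.get("MAX_CACHE_SIZE_MB", "1024")),
    )


def within_limit(size_bytes: int, config: CacheConfig) -> bool:
    """True if an archive of size_bytes fits under MAX_CACHE_SIZE_MB (assumed binary MB)."""
    return size_bytes <= config.max_cache_size_mb * 1024 * 1024
```

Passing an explicit mapping instead of reading os.environ directly makes the defaults easy to check without touching the process environment.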

How It Works

Environment-Specific Caching

The library automatically creates unique cache keys based on your environment:

torch-2.1.0_cuda-12.1_cc-8.6_triton-2.1.0 → cache_a1b2c3d4e5f6.latest.tar.gz
torch-2.0.1_cuda-11.8_cc-7.5_triton-2.0.1 → cache_x9y8z7w6v5u4.latest.tar.gz
torch-2.1.0_cpu_triton-none                → cache_m1n2o3p4q5r6.latest.tar.gz

Components used:

  • PyTorch version (e.g., torch-2.1.0)
  • CUDA version (e.g., cuda-12.1 or cpu)
  • GPU compute capability (e.g., cc-8.6 for an A10G; an A100 is cc-8.0)
  • Triton version (e.g., triton-2.1.0 or triton-none)
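
The key format above can be reproduced with a few lines of string building. In a live process the inputs would come from torch.__version__, torch.version.cuda (None on CPU-only builds), torch.cuda.get_device_capability(), and triton.__version__. How the library turns the key into the short hash in the filename is not documented here; the SHA-256 prefix below is an assumption, so its output will not match the illustrative names shown above. environment_key and cache_filename are hypothetical helpers.

```python
import hashlib
from typing import Optional, Tuple


def environment_key(
    torch_version: str,
    cuda_version: Optional[str],
    compute_capability: Optional[Tuple[int, int]],
    triton_version: Optional[str],
) -> str:
    """Build a key like 'torch-2.1.0_cuda-12.1_cc-8.6_triton-2.1.0'."""
    parts = [f"torch-{torch_version}"]
    if cuda_version is None:
        parts.append("cpu")  # CPU-only build: no CUDA version or compute capability
    else:
        if compute_capability is None:
            raise ValueError("compute_capability is required when cuda_version is set")
        major, minor = compute_capability
        parts += [f"cuda-{cuda_version}", f"cc-{major}.{minor}"]
    parts.append(f"triton-{triton_version}" if triton_version else "triton-none")
    return "_".join(parts)


def cache_filename(env_key: str, digest_len: int = 12) -> str:
    """Map a key to a filename; the hash function and length are assumptions."""
    digest = hashlib.sha256(env_key.encode("utf-8")).hexdigest()[:digest_len]
    return f"cache_{digest}.latest.tar.gz"
```

Hashing the key keeps filenames short and filesystem-safe while any change in PyTorch, CUDA, GPU architecture, or Triton still produces a different file.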

Cache Workflow

  1. Load phase (at startup): generate the environment key, look for a matching cache archive in B10FS, and extract it into the local cache directory
  2. Save phase (after compilation): archive the local cache directory and atomically copy it to B10FS under the environment-specific filename
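
Both phases revolve around a gzip-compressed tarball of the local cache directory. A minimal sketch of that archive round trip with the standard library's tarfile, assuming the archive simply mirrors the directory tree (the library's actual archive layout is not documented here); create_archive and extract_archive are hypothetical names.

```python
import tarfile
import tempfile
from pathlib import Path


def create_archive(cache_dir: Path, archive_path: Path) -> None:
    """Pack everything under cache_dir into a gzip-compressed tarball."""
    with tarfile.open(archive_path, "w:gz") as tar:
        for child in sorted(cache_dir.iterdir()):
            # Store paths relative to cache_dir so the archive can be unpacked anywhere.
            tar.add(child, arcname=child.name)


def extract_archive(archive_path: Path, cache_dir: Path) -> bool:
    """Unpack archive_path into cache_dir; return False if no archive exists."""
    if not archive_path.exists():
        return False  # nothing cached yet for this environment
    cache_dir.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Newer Pythons can reject absolute paths, ".." entries, and other unsafe members.
            tar.extractall(cache_dir, filter="data")
        else:
            tar.extractall(cache_dir)
    return True


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        src = Path(tmp, "torchinductor_root", "example")
        src.mkdir(parents=True)
        (src / "entry.bin").write_bytes(b"compiled artifact")
        archive = Path(tmp, "cache.tar.gz")
        create_archive(src.parent, archive)
        print(extract_archive(archive, Path(tmp, "restored")))
```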

Lock-Free Race Prevention

Uses a journal pattern with atomic filesystem operations, so multiple replicas can save caches in parallel without taking locks.
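
The README does not spell out the journal format, so the sketch below shows only the general write-to-a-temporary-file-then-rename technique that lock-free saves usually build on, not the library's own journal. Each writer gets a unique temporary file in the destination directory, so concurrent savers never interleave bytes in one file, and os.replace swaps the finished file in atomically on POSIX filesystems: readers see either the previous archive or a complete new one, and the last rename wins. Whether a networked mount such as B10FS gives the same rename guarantees is an assumption to verify. publish_atomically is a hypothetical name.

```python
import os
import shutil
import tempfile
from pathlib import Path


def publish_atomically(local_archive: Path, final_path: Path) -> None:
    """Copy local_archive to final_path so readers never observe a partial file."""
    final_path.parent.mkdir(parents=True, exist_ok=True)
    # Unique temp file per writer, created in the destination directory so the
    # rename below never crosses filesystems and parallel savers never share a file.
    fd, tmp_name = tempfile.mkstemp(
        dir=final_path.parent, prefix=f".{final_path.name}.", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "wb") as out, open(local_archive, "rb") as src:
            shutil.copyfileobj(src, out)
            out.flush()
            os.fsync(out.fileno())  # make sure the bytes are on disk before publishing
        # Atomically replace any existing archive; with concurrent writers, the last rename wins.
        os.replace(tmp_name, final_path)
    except BaseException:
        # Remove the partial temp file if anything failed before the rename.
        if os.path.exists(tmp_name):
            os.unlink(tmp_name)
        raise
```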

API Reference

Synchronous Functions

  • load_compile_cache() -> LoadStatus: Load cache from B10FS for current environment
  • save_compile_cache() -> SaveStatus: Save cache to B10FS with environment-specific filename
  • clear_local_cache() -> bool: Clear local cache directory
  • get_cache_info() -> Dict[str, Any]: Get cache status information for current environment
  • list_available_caches() -> Dict[str, Any]: List all cache files with environment details

Generic Asynchronous Functions

  • start_transfer_async(source, dest, callback, operation_name, **kwargs) -> str: Start any async transfer operation
  • get_transfer_status(operation_id: str) -> TransferProgress: Get current status of async operation
  • is_transfer_complete(operation_id: str) -> bool: Check if async operation has completed
  • wait_for_completion(operation_id: str, timeout=None) -> bool: Wait for async operation to complete
  • cancel_transfer(operation_id: str) -> bool: Attempt to cancel running operation
  • list_active_transfers() -> Dict[str, TransferProgress]: Get all active transfer operations

Torch Cache Async Functions

  • load_compile_cache_async(progress_callback=None) -> str: Start async cache load, returns operation ID
  • save_compile_cache_async(progress_callback=None) -> str: Start async cache save, returns operation ID

Status Enums

  • LoadStatus: SUCCESS, ERROR, DOES_NOT_EXIST, SKIPPED
  • SaveStatus: SUCCESS, ERROR, SKIPPED
  • AsyncTransferStatus: NOT_STARTED, IN_PROGRESS, SUCCESS, ERROR, INTERRUPTED, CANCELLED

Data Classes

  • TransferProgress: Contains operation_id, status, started_at, completed_at, error_message
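
A polling loop should stop on any terminal state, not just SUCCESS: an operation can also end as ERROR, INTERRUPTED, or CANCELLED. The stand-in classes below mirror the documented enum members and TransferProgress fields to make that explicit. They are hypothetical, not imports from b10_transfer; the enum values, the datetime type of the timestamps, and the is_done and duration helpers are assumptions.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Optional


class AsyncTransferStatus(Enum):
    # Member names from the README; the string values are assumptions.
    NOT_STARTED = "not_started"
    IN_PROGRESS = "in_progress"
    SUCCESS = "success"
    ERROR = "error"
    INTERRUPTED = "interrupted"
    CANCELLED = "cancelled"


# Any of these means the operation will not change state again.
TERMINAL_STATUSES = frozenset({
    AsyncTransferStatus.SUCCESS,
    AsyncTransferStatus.ERROR,
    AsyncTransferStatus.INTERRUPTED,
    AsyncTransferStatus.CANCELLED,
})


@dataclass
class TransferProgress:
    # Field names from the README; the types are assumptions.
    operation_id: str
    status: AsyncTransferStatus
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    error_message: Optional[str] = None

    @property
    def is_done(self) -> bool:
        return self.status in TERMINAL_STATUSES

    @property
    def duration(self) -> Optional[timedelta]:
        """Elapsed time once both timestamps are set, else None."""
        if self.started_at is None or self.completed_at is None:
            return None
        return self.completed_at - self.started_at
```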

Exceptions

  • CacheError: Base exception for cache operations
  • CacheValidationError: Path validation or compatibility check failed
  • CacheOperationInterrupted: Operation interrupted due to insufficient disk space
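
Because a missing or broken compile cache only costs compilation time, it is usually reasonable to keep cache failures from aborting model startup. The README does not say which calls raise these exceptions rather than returning an ERROR status, so the sketch defines stand-in classes with the documented names and hierarchy (CacheError as the base) and a hypothetical best_effort wrapper that logs and falls back on any CacheError while letting unrelated exceptions propagate.

```python
import logging
from typing import Callable, TypeVar

logger = logging.getLogger(__name__)
T = TypeVar("T")


# Stand-ins with the documented names and hierarchy; not imported from b10_transfer.
class CacheError(Exception):
    """Base exception for cache operations."""


class CacheValidationError(CacheError):
    """Path validation or compatibility check failed."""


class CacheOperationInterrupted(CacheError):
    """Operation interrupted due to insufficient disk space."""


def best_effort(operation: Callable[[], T], fallback: T) -> T:
    """Run a cache operation; log any CacheError and return fallback instead of raising."""
    try:
        return operation()
    except CacheOperationInterrupted as exc:
        # Subclass first: disk-space problems get a distinct message.
        logger.warning("Cache operation interrupted (low disk space?): %s", exc)
    except CacheError as exc:
        logger.warning("Cache operation failed, continuing without cache: %s", exc)
    return fallback
```

In a real deployment the operation would be a zero-argument callable such as the documented load_compile_cache, but whether that function raises or returns LoadStatus.ERROR on failure should be checked against the library before relying on this wrapper.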

Performance Impact

Debugging

Enable debug logging:

import logging
logging.basicConfig()  # attach a handler; without one, DEBUG records are not printed
logging.getLogger('b10_tcache').setLevel(logging.DEBUG)

Download files


Source Distribution

b10_transfer-0.0.2.tar.gz (25.8 kB)

Uploaded Source

Built Distribution


b10_transfer-0.0.2-py3-none-any.whl (30.4 kB)

Uploaded Python 3

File details

Details for the file b10_transfer-0.0.2.tar.gz.

File metadata

  • Download URL: b10_transfer-0.0.2.tar.gz
  • Size: 25.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.4 CPython/3.11.13 Linux/6.5.0-45-generic

File hashes

Hashes for b10_transfer-0.0.2.tar.gz
  • SHA256: 3059e93c29431cabfee34996b1411bc7dd68df516fd4553548ab3f0fb4a1d836
  • MD5: 1a21d7d7da4fc4a0e02cd922d6f0d03c
  • BLAKE2b-256: 311c3a73f3b9477b84f66c3eb746542b7d88e7a9f83bf2c2e6a8cc3a792c662a


File details

Details for the file b10_transfer-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: b10_transfer-0.0.2-py3-none-any.whl
  • Size: 30.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.4 CPython/3.11.13 Linux/6.5.0-45-generic

File hashes

Hashes for b10_transfer-0.0.2-py3-none-any.whl
  • SHA256: cca61df4074b6ba99f62f98f198b7eb68b9c246bf990f4563c9171d72b9b688c
  • MD5: debe23b2712269a5685fb226e5ea48de
  • BLAKE2b-256: 7c65f829c70a355d39abe945f95c416d6ea7fa8b1bf5344e61cc3b739e6b550b
