
Paper - Pytorch

Project description

[Paper Implementation] AN EVOLVED UNIVERSAL TRANSFORMER MEMORY


An open source implementation of the paper: "AN EVOLVED UNIVERSAL TRANSFORMER MEMORY"

Abstract:

Prior methods propose to offset the escalating costs of modern foundation models by dropping specific parts of their contexts with hand-designed rules, while attempting to preserve their original performance. We overcome this trade-off with Neural Attention Memory Models (NAMMs), introducing a learned network for memory management that improves both the performance and efficiency of transformers. We evolve NAMMs atop pre-trained transformers to provide different latent contexts focusing on the most relevant information for individual layers and attention heads. NAMMs are universally applicable to any model using self-attention as they condition exclusively on the values in the produced attention matrices. Learning NAMMs on a small set of problems, we achieve substantial performance improvements across multiple long-context benchmarks while cutting the model’s input contexts up to a fraction of the original sizes. We show the generality of our conditioning enables zero-shot transfer of NAMMs trained only on language to entirely new transformer architectures even across input modalities, with their benefits carrying over to vision and reinforcement learning.
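The abstract's central claim is that a NAMM looks only at attention values, which is what lets it sit on top of any self-attention model. The plain-Python sketch below illustrates that interface under stated assumptions: each cached token is scored from its row of the attention matrix (laid out as tokens × recent queries, as in the usage example), and tokens whose score is negative are dropped from the KV cache. The two features (mean attention and attention from the most recent query), the linear scorer, its weights and bias, and the keep-if-non-negative threshold are all hypothetical stand-ins chosen for illustration; the real NAMM scoring network is learned by evolution and is not reproduced here.

```python
"""Illustrative sketch of attention-conditioned KV-cache eviction.

Everything here is hypothetical: the features, the linear scorer and the
threshold stand in for the learned NAMM network.
"""
from typing import List, Sequence, Tuple


def token_features(attn: Sequence[Sequence[float]]) -> List[Tuple[float, float]]:
    # attn[i][q] = attention that recent query q paid to cached token i.
    feats = []
    for row in attn:
        mean = sum(row) / len(row)  # average attention the token received
        recent = row[-1]            # attention from the most recent query
        feats.append((mean, recent))
    return feats


def score_tokens(feats, weights=(1.0, 1.0), bias=-0.2) -> List[float]:
    # Hypothetical linear scorer standing in for the learned network.
    w_mean, w_recent = weights
    return [w_mean * m + w_recent * r + bias for m, r in feats]


def evict(keys, values, attn, **scorer_kwargs):
    # Keep tokens with a non-negative score, drop the rest.
    scores = score_tokens(token_features(attn), **scorer_kwargs)
    keep = [i for i, s in enumerate(scores) if s >= 0.0]
    return [keys[i] for i in keep], [values[i] for i in keep], keep


if __name__ == "__main__":
    # 4 cached tokens x 3 recent queries; each query's column sums to 1.
    attn = [
        [0.70, 0.60, 0.50],  # heavily attended token
        [0.05, 0.10, 0.10],
        [0.20, 0.25, 0.30],
        [0.05, 0.05, 0.10],
    ]
    keys = ["k0", "k1", "k2", "k3"]
    values = ["v0", "v1", "v2", "v3"]
    kept_keys, kept_values, keep = evict(keys, values, attn)
    print(f"kept {keep}, retention rate {len(keep) / len(attn):.0%}")
    # kept [0, 2], retention rate 50%
```

Because the scorer never touches the keys or values themselves, the same function works for any layer, head or model that exposes its attention matrix, which is the property the abstract relies on for zero-shot transfer.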

Install

$ pip3 install -U 

Usage

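import torch
from loguru import logger

# Import path assumed from the distribution name (open_namm); adjust if the
# package exposes NAMMConfig and create_namm from a different module.
from open_namm import NAMMConfig, create_namm


def create_sample_inputs(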
    batch_size: int = 2,
    seq_len: int = 1024,
    n_queries: int = 512,
    d_model: int = 256,
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
    """Create sample inputs for NAMM testing.
    
    Args:
        batch_size: Batch size
        seq_len: Sequence length (number of tokens in KV cache)
        n_queries: Number of recent queries
        d_model: Model dimension
        device: Device to create tensors on
    
    Returns:
        Tuple of (kv_cache, attention_matrix)
    """
    logger.info(f"Creating sample inputs on device: {device}")
    
    # Create sample KV cache
    # In practice, these would be the key and value tensors from transformer layers
    kv_cache = {
        "key": torch.randn(batch_size, seq_len, d_model, device=device),
        "value": torch.randn(batch_size, seq_len, d_model, device=device)
    }
    
    # Create sample attention matrix
    # In practice, this would be the recent attention scores from transformer layers
    attention_matrix = torch.randn(batch_size, seq_len, n_queries, device=device)
    
    # Softmax over the key axis (dim=1) so each query's attention over the cached tokens sums to 1
    attention_matrix = torch.softmax(attention_matrix, dim=1)
    
    logger.info(
        f"Created inputs - KV cache size: {kv_cache['key'].shape}, "
        f"Attention matrix size: {attention_matrix.shape}"
    )
    
    return kv_cache, attention_matrix

def main():
    """Main function demonstrating NAMM usage."""
    # Setup logging
    logger.remove()
    # loguru messages already end with a newline, so suppress print's own
    logger.add(lambda msg: print(msg, end="", flush=True), colorize=True, level="INFO")
    
    # Set random seed for reproducibility
    torch.manual_seed(42)
    
    # Create NAMM instance with custom config
    config = NAMMConfig(
        update_interval=256,  # More frequent updates for demonstration
        stride_size=16,
        window_size=64,
        d_model=256,
        n_head=4,
        gamma=0.95,
        dropout=0.1
    )
    
    namm = create_namm(config)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    namm = namm.to(device)
    
    logger.info(f"Created NAMM model on device: {device}")
    
    # Create sample inputs
    kv_cache, attention_matrix = create_sample_inputs(
        batch_size=2,
        seq_len=1024,
        n_queries=512,
        d_model=config.d_model,
        device=device
    )
    
    # Simulate multiple steps of processing
    n_steps = 1000
    retention_stats = []
    
    logger.info(f"Starting simulation for {n_steps} steps")
    
    for step in range(n_steps):
        # Process the KV cache
        updated_cache, _ = namm(kv_cache, attention_matrix)
        
        # Every few steps, evaluate retention
        if step % 100 == 0:
            stats = namm.evaluate_retention(kv_cache, attention_matrix)
            if stats:  # Only store if we got stats (remember NAMM only updates every update_interval)
                retention_stats.append(stats)
                logger.info(
                    f"Step {step}: Retention rate = {stats['retention_rate']:.2%}, "
                    f"Mean score = {stats['mean_score']:.3f}"
                )
        
        # Update KV cache and attention matrix for next step
        if updated_cache:  # If NAMM made updates
            kv_cache = updated_cache
            # Create new attention matrix for reduced sequence length
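            # Reuse the current batch size and query count instead of hard-coding them
            batch_size, _, n_queries = attention_matrix.shape
            _, new_seq_len, _ = kv_cache['key'].shape
            attention_matrix = torch.randn(
                batch_size, new_seq_len, n_queries, device=device
            )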
            attention_matrix = torch.softmax(attention_matrix, dim=1)
    
    # Print final statistics
    if retention_stats:
        avg_retention = sum(s['retention_rate'] for s in retention_stats) / len(retention_stats)
        logger.info(f"Average retention rate over simulation: {avg_retention:.2%}")

if __name__ == "__main__":
    main()
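NAMMConfig in the usage example exposes window_size, stride_size and gamma, but this page does not document what they control. Assuming they parameterise the kind of feature extraction described in the paper (a short-time Fourier transform over each token's sequence of attention values from recent queries, compressed with an exponential moving average over the resulting frames), the plain-Python sketch below shows that computation for a single token, using the same values as the example config (64, 16, 0.95). The Hann window, the one-sided magnitude spectrum, the (1 - gamma) weighting and the zero-initialised average are assumptions of this sketch, not confirmed details of open_namm.

```python
"""Sketch: spectral features of one token's attention history (assumed design)."""
import cmath
import math
from typing import List, Sequence


def hann(n: int) -> List[float]:
    # Periodic Hann window of length n.
    return [0.5 - 0.5 * math.cos(2 * math.pi * k / n) for k in range(n)]


def stft_magnitudes(signal: Sequence[float], window_size: int, stride_size: int) -> List[List[float]]:
    # One-sided magnitude spectrum (window_size // 2 + 1 bins) of each
    # Hann-windowed frame, computed with a naive DFT for clarity.
    win = hann(window_size)
    frames = []
    for start in range(0, len(signal) - window_size + 1, stride_size):
        seg = [signal[start + k] * win[k] for k in range(window_size)]
        spectrum = []
        for f in range(window_size // 2 + 1):
            acc = sum(
                seg[k] * cmath.exp(-2j * math.pi * f * k / window_size)
                for k in range(window_size)
            )
            spectrum.append(abs(acc))
        frames.append(spectrum)
    return frames


def ema_reduce(frames: List[List[float]], gamma: float) -> List[float]:
    # Exponential moving average over frames, oldest to newest, so the
    # most recent queries weigh most. Starts from zero (an assumption).
    state = [0.0] * len(frames[0])
    for frame in frames:
        state = [gamma * s + (1 - gamma) * x for s, x in zip(state, frame)]
    return state


def token_spectral_features(attn_row: Sequence[float], window_size: int = 64,
                            stride_size: int = 16, gamma: float = 0.95) -> List[float]:
    # attn_row: attention one cached token received from each recent query.
    frames = stft_magnitudes(attn_row, window_size, stride_size)
    if not frames:
        raise ValueError("need at least window_size attention values per token")
    return ema_reduce(frames, gamma)


if __name__ == "__main__":
    feats = token_spectral_features([1.0] * 64)
    print(len(feats), round(feats[0], 3), round(feats[1], 3))
```

For a token whose attention is constant across 64 queries, only the DC bin and its first neighbour are non-zero (both contributed by the Hann window itself); attention that fluctuates across queries spreads energy into higher bins, which is the kind of signal a learned scorer could exploit. The naive DFT is O(window_size²) per frame; a real implementation would batch this over tokens with an FFT routine instead.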

License

MIT



Download files


Source Distribution

open_namm-0.0.1.tar.gz (8.6 kB)

Built Distribution


open_namm-0.0.1-py3-none-any.whl (9.0 kB)

File details

Details for the file open_namm-0.0.1.tar.gz.

File metadata

  • Download URL: open_namm-0.0.1.tar.gz
  • Upload date:
  • Size: 8.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.12.6 Darwin/23.3.0

File hashes

Hashes for open_namm-0.0.1.tar.gz
Algorithm Hash digest
SHA256 34b0da7998e994dbfcae90517d9f9e075890ae796c932f301a29ab712e956cd5
MD5 bf285345de2f7aa1dc3b840f71d099bb
BLAKE2b-256 f3c0226fd7e82377fbeff36dfd8d96d5a4bc7052dad1344028c850819bf8183b


File details

Details for the file open_namm-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: open_namm-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 9.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.12.6 Darwin/23.3.0

File hashes

Hashes for open_namm-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 72cd739d500b11c53bcf5f58cfe749d49d8aa7898364b2ffc27531451f0d69d4
MD5 021b61cbfff2a0a32bfd0d9ff372c8db
BLAKE2b-256 f0b4037d94e9fc2cdee3024afe8158e77d12ae8a67b40a3ec9c3c4b9b8e1726d

