
Sparse Selective Hyper-Connections for stable and efficient deep residual learning


SHC: Sparse Selective Hyper-Connections


A PyTorch implementation of Sparse Selective Hyper-Connections for stable and efficient deep residual learning.

Overview

SHC replaces standard residual connections with sparse mixtures of orthogonal routing matrices. Reported benefits:

  • Bounded spectral norm (ρ ≤ 1): the residual mixing cannot amplify signals from layer to layer, which is what keeps deep training stable
  • 16× faster routing: a closed-form Cayley transform replaces iterative Sinkhorn normalization
  • 3.3× KV cache reduction: via a learned low-rank factorization
  • O(L) inference: optional distillation into an SSM student gives generation that is linear in sequence length L
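The ρ ≤ 1 bound in the first bullet follows from the triangle inequality. The sketch below assumes, as the H^res = Σ αᵢ(x)·Qᵢ form in the Architecture section suggests, that every Qᵢ is orthogonal and that the mixing weights αᵢ(x) are non-negative and sum to 1 (for example, softmax outputs); the paper's Prop. 1 may phrase the conditions differently.

```latex
\|H^{\mathrm{res}}\|_2
  = \Big\| \sum_{i} \alpha_i(x)\, Q_i \Big\|_2
  \le \sum_{i} \alpha_i(x)\, \|Q_i\|_2
  = \sum_{i} \alpha_i(x)
  = 1,
\qquad
\rho\big(H^{\mathrm{res}}\big) \le \|H^{\mathrm{res}}\|_2 \le 1 .
```

Every orthogonal matrix has ‖Qᵢ‖₂ = 1, and the spectral radius ρ never exceeds the spectral norm, so the bound holds for both quantities.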

Installation

# Clone repository
git clone https://github.com/your-org/shc.git
cd shc

# Create virtual environment
python -m venv venv
source venv/bin/activate  # Linux/Mac
# or: venv\Scripts\activate  # Windows

# Install dependencies
pip install -r shc/requirements.txt

# Install package in development mode
pip install -e .

Quick Start

Basic Usage

import torch
from shc.models import SHCTransformer, get_config

# Create model with predefined config
config = get_config('500m')  # Options: '500m', '1b', '3b', '7b'
model = SHCTransformer(config)
model.eval()

# Forward pass (token ids must be below the config's vocabulary size; 32000 assumed here)
input_ids = torch.randint(0, 32000, (2, 512))  # (batch, seq_len)
with torch.no_grad():
    logits = model(input_ids)

# Generate text
output = model.generate(
    input_ids[:, :10],  # prompt: first 10 tokens of each sequence
    max_new_tokens=100,
    temperature=0.7,
    top_p=0.9,
)
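`generate` takes `temperature` and `top_p`, but its decoding code is not shown here. As a rough guide to what those two parameters usually control, here is a generic temperature + nucleus (top-p) sampling step in PyTorch; `sample_next_token` is a hypothetical helper, not part of the `shc` API, and the package's own implementation may differ.

```python
import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 0.7, top_p: float = 0.9) -> torch.Tensor:
    """Sample one token id per row from (batch, vocab) logits using temperature + nucleus filtering."""
    # Temperature < 1 sharpens the distribution; > 1 flattens it.
    probs = torch.softmax(logits / temperature, dim=-1)
    # Sort probabilities in descending order and accumulate them.
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop a token once the mass of the tokens ranked above it already exceeds top_p.
    # The most likely token is always kept.
    remove = (cumulative - sorted_probs) > top_p
    sorted_probs = sorted_probs.masked_fill(remove, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    # Sample among the surviving tokens, then map back to vocabulary ids.
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice)  # (batch, 1)

logits = torch.tensor([[4.0, 2.0, 0.0, -2.0]])
next_id = sample_next_token(logits)
```

With these logits the top token alone holds about 94% of the probability after temperature scaling, so with `top_p=0.9` it is the only candidate left and sampling becomes deterministic.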

Training

# Single GPU
python -m shc.scripts.train --model_size 500m --output_dir ./output

# Multi-GPU with DDP
torchrun --nproc_per_node=8 -m shc.scripts.train \
    --model_size 3b \
    --batch_size 32 \
    --learning_rate 3e-4

# FSDP for 7B+ models (memory efficient)
torchrun --nproc_per_node=8 -m shc.scripts.train \
    --model_size 7b \
    --use_fsdp \
    --mixed_precision bf16

Evaluation

# Run benchmarks
python -m shc.scripts.evaluate \
    --model_path ./output/final \
    --benchmarks bbh gsm8k mmlu

# Efficiency profiling
python -m shc.scripts.evaluate \
    --model_path ./output/final \
    --profile \
    --analyze_routing

SSM Distillation

from shc.models import SHCTransformer, SSMStudent
from shc.training import DistillationTrainer, DistillationConfig

# Load trained teacher
teacher = SHCTransformer.from_pretrained('path/to/teacher')

# Create student matching teacher dimensions
student = SSMStudent.from_teacher_config(teacher.config)

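# Distill (train_loader: a DataLoader of tokenized training batches, built beforehand)
config = DistillationConfig(max_steps=10000)
trainer = DistillationTrainer(teacher, student, config, train_loader)
trainer.train()

# Student generates with O(1) compute and memory per token (no KV cache)
# input_ids: prompt token ids, shape (batch, seq_len)
output = student.generate(input_ids, max_new_tokens=100)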
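The O(1)-per-step claim comes from the student carrying a fixed-size recurrent state instead of a KV cache that grows with every token. The internals of `SSMStudent` are not documented here, so the following is a generic diagonal linear SSM recurrence, h_t = a ⊙ h_{t-1} + B x_t and y_t = C h_t, written as a hypothetical `DiagonalSSMCell` purely to illustrate the idea.

```python
import torch
from torch import nn

class DiagonalSSMCell(nn.Module):
    """Hypothetical diagonal linear SSM cell (illustrative; not the package's SSMStudent)."""

    def __init__(self, d_model: int, d_state: int):
        super().__init__()
        self.decay_logit = nn.Parameter(torch.zeros(d_state))  # sigmoid -> per-channel decay in (0, 1)
        self.B = nn.Linear(d_model, d_state, bias=False)        # input -> state
        self.C = nn.Linear(d_state, d_model, bias=False)        # state -> output

    def step(self, x_t: torch.Tensor, h: torch.Tensor):
        """One generation step; cost depends on d_state, not on how many tokens came before."""
        a = torch.sigmoid(self.decay_logit)
        h = a * h + self.B(x_t)   # fixed-size state replaces a growing KV cache
        return self.C(h), h

cell = DiagonalSSMCell(d_model=16, d_state=8)
h = torch.zeros(1, 8)
for _ in range(100):              # 100 steps; the state never grows
    y, h = cell.step(torch.randn(1, 16), h)
```

A decay in (0, 1) keeps the recurrence from blowing up; real SSM students typically use richer parameterizations, but the memory profile is the same.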

Architecture

Core Components

| Component | Description | Reference |
|---|---|---|
| CayleyTransform | Closed-form orthogonal matrix Q = (I − A)(I + A)⁻¹ from a skew-symmetric A | Eq. 9 |
| SparseOrthogonalMixture | H^res = Σᵢ αᵢ(x)·Qᵢ, with ρ ≤ 1 | Eq. 7, Prop. 1 |
| FactorizedKVCache | Low-rank compression x̄ ≈ UVᵀ | Eq. 14 |
| AdaptiveRankSelector | Gumbel-Softmax rank selection | Eq. 16 |
| SHCBlock | Complete block with triple routing | Algorithm 1 |
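The CayleyTransform row can be sketched in a few lines of PyTorch. This assumes the layer builds A as the skew-symmetric part of an unconstrained parameter matrix; `cayley` is a hypothetical function name, and shc/layers/cayley.py may be organized differently. The point is that orthogonality comes from a single linear solve rather than an iterative projection.

```python
import torch

def cayley(params: torch.Tensor) -> torch.Tensor:
    """Map an unconstrained (..., n, n) matrix to an orthogonal one: Q = (I - A)(I + A)^{-1}."""
    A = params - params.transpose(-1, -2)   # skew-symmetric: A^T = -A
    I = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
    # I + A is always invertible for skew-symmetric A (its eigenvalues are 1 + i*t).
    # Solve X (I + A) = (I - A) instead of forming the inverse explicitly.
    return torch.linalg.solve(I + A, I - A, left=False)

Q = cayley(torch.randn(4, 4, dtype=torch.float64))
is_orthogonal = torch.allclose(Q.T @ Q, torch.eye(4, dtype=torch.float64))
spectral_norm = torch.linalg.matrix_norm(Q, ord=2)  # largest singular value, 1 for orthogonal Q
```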

Model Configurations

| Size | Hidden size | Layers | Heads | Parameters |
|---|---|---|---|---|
| 500M | 1024 | 24 | 16 | ~500M |
| 1B | 2048 | 24 | 16 | ~1B |
| 3B | 2560 | 32 | 32 | ~3B |
| 7B | 4096 | 32 | 32 | ~7B |

Project Structure

shc/
├── __init__.py              # Package init
├── requirements.txt         # Dependencies
├── configs/
│   └── config.py           # Model/training configs
├── layers/
│   ├── cayley.py           # Cayley transform
│   ├── sparse_mixture.py   # Sparse orthogonal routing
│   ├── factorized_cache.py # KV cache compression
│   └── adaptive_rank.py    # Rank selection
├── blocks/
│   ├── attention.py        # Multi-head attention + RoPE
│   ├── feedforward.py      # SwiGLU FFN
│   └── shc_block.py        # Complete SHC block
├── models/
│   ├── embeddings.py       # Token/positional embeddings
│   ├── transformer.py      # SHCTransformer
│   └── ssm_student.py      # SSM for O(L) inference
├── training/
│   ├── distributed.py      # DDP/FSDP utilities
│   ├── optimizer.py        # Adam + cosine scheduler
│   ├── trainer.py          # Training loop
│   └── distillation.py     # Teacher→student distillation
├── data/
│   ├── dataset.py          # Dataset classes
│   └── dataloader.py       # Distributed data loading
├── evaluation/
│   ├── metrics.py          # PPL, accuracy, F1, BLEU
│   ├── benchmarks.py       # BBH, GSM8K, MMLU
│   └── profiler.py         # Efficiency profiling
└── scripts/
    ├── train.py            # Training CLI
    └── evaluate.py         # Evaluation CLI

Key Features

1. Stable Training via Orthogonal Routing

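from shc.layers import SparseOrthogonalMixture

# Spectral norm is bounded by construction
# x: hidden states whose last dimension is hidden_dim
routing = SparseOrthogonalMixture(n=4, k=2, hidden_dim=768)
H_res = routing(x)
spectral_norm = routing.get_spectral_norm(x)  # ≤ 1.0 up to floating-point error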
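To see why the bound holds by construction, here is a self-contained sketch of a sparse orthogonal mixture: a linear gate scores a set of orthogonal matrices, only the top-k scores get non-zero softmax weight, and the weighted sum is checked with `torch.linalg.matrix_norm(..., ord=2)`. The gate design, the names `sparse_orthogonal_mixture`, `num_matrices` and `q_params`, and the shapes are all assumptions for illustration; they are not the signature of `SparseOrthogonalMixture`.

```python
import torch

def cayley(params: torch.Tensor) -> torch.Tensor:
    A = params - params.transpose(-1, -2)                 # skew-symmetric
    I = torch.eye(A.shape[-1], dtype=A.dtype)
    return torch.linalg.solve(I + A, I - A, left=False)   # orthogonal Q

def sparse_orthogonal_mixture(x, gate, q_params, k):
    """Hypothetical sketch: H = sum_i alpha_i(x) Q_i with only the top-k alphas non-zero."""
    logits = gate(x)                                      # (batch, num_matrices) routing scores
    top_vals, top_idx = logits.topk(k, dim=-1)
    # Softmax over the k winners, scattered back; the other weights stay exactly 0.
    alpha = torch.zeros_like(logits).scatter(-1, top_idx, torch.softmax(top_vals, dim=-1))
    Q = cayley(q_params)                                  # (num_matrices, n, n), each orthogonal
    return torch.einsum("be,eij->bij", alpha, Q)          # (batch, n, n)

torch.manual_seed(0)
hidden_dim, n, num_matrices, k = 768, 4, 4, 2
gate = torch.nn.Linear(hidden_dim, num_matrices)
q_params = torch.randn(num_matrices, n, n)
x = torch.randn(3, hidden_dim)

H = sparse_orthogonal_mixture(x, gate, q_params, k)
spectral = torch.linalg.matrix_norm(H, ord=2)             # one value per batch element, each ≤ 1
```

Because the k weights are a softmax, they are non-negative and sum to 1, so each H is a convex combination of orthogonal matrices and its spectral norm cannot exceed 1 (up to float32 rounding).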

2. Efficient Multi-GPU Training

from shc.training import setup_distributed, wrap_model_ddp, wrap_model_fsdp

# Automatic distributed setup (launch the script with torchrun)
rank, local_rank, world_size = setup_distributed()

# DDP for data parallelism
model = wrap_model_ddp(model)

# Or FSDP for memory efficiency (use one wrapper, not both)
# model = wrap_model_fsdp(model, mixed_precision=True)
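The internals of `setup_distributed` are not documented here. Under `torchrun`, a typical setup reads the `RANK`, `LOCAL_RANK` and `WORLD_SIZE` environment variables that torchrun exports and initializes the default process group; the hypothetical `setup_distributed_sketch` below shows that pattern and falls back to single-process values when the variables are absent.

```python
import os

import torch
import torch.distributed as dist

def setup_distributed_sketch():
    """Hypothetical stand-in for setup_distributed(): read torchrun env vars and init the process group."""
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if world_size > 1 and not dist.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        if backend == "nccl":
            torch.cuda.set_device(local_rank)      # one GPU per process
        dist.init_process_group(backend=backend)   # default env:// init reads MASTER_ADDR/MASTER_PORT
    return rank, local_rank, world_size

rank, local_rank, world_size = setup_distributed_sketch()
```

Run as a plain `python` script, this returns `(0, 0, 1)` and skips process-group setup, which keeps single-GPU debugging and multi-GPU launches on the same code path.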

3. Memory-Efficient KV Cache

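from shc.layers import FactorizedKVCache

# x_bar: an n×d hidden-state matrix (here 4×768)
cache = FactorizedKVCache(n=4, d=768, r=1)
compressed = cache.compress(x_bar)             # rank-r factors of the 4×768 matrix
reconstructed = cache.decompress(compressed)   # low-rank approximation of x_bar; accuracy depends on how close x_bar is to rank r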
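`FactorizedKVCache` learns its factorization (Eq. 14). As a stand-in that makes the storage arithmetic concrete, truncated SVD gives the best rank-r factors of a fixed matrix (Eckart–Young). The helper name `low_rank_factors` is hypothetical, and this is not the package's compression method.

```python
import torch

def low_rank_factors(x_bar: torch.Tensor, r: int):
    """Best rank-r approximation x_bar ≈ U @ V.T via truncated SVD."""
    U_full, S, Vh = torch.linalg.svd(x_bar, full_matrices=False)
    U = U_full[:, :r] * S[:r]    # fold singular values into U: (n, r)
    V = Vh[:r, :].T              # (d, r)
    return U, V

n, d, r = 4, 768, 1
x_bar = torch.randn(n, d)
U, V = low_rank_factors(x_bar, r)
stored = U.numel() + V.numel()   # n*r + d*r = 772 values instead of n*d = 3072
rel_err = torch.linalg.norm(x_bar - U @ V.T) / torch.linalg.norm(x_bar)
```

A random matrix is far from rank 1, so `rel_err` is large here. Compression is only accurate when the cached hidden states are close to low-rank, which is presumably what the learned factorization and adaptive rank selection are meant to exploit.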

4. Routing Analysis

from shc.evaluation import RoutingAnalyzer

analyzer = RoutingAnalyzer(model)
analyzer.analyze_batch(input_ids)
stats = analyzer.get_summary()
# e.g. {'spectral_norms': {'mean': 0.98, 'max': 1.0}, 'mixing_entropy': {...}}
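`RoutingAnalyzer` reports a `mixing_entropy` statistic whose exact definition is not given here. Assuming it is the Shannon entropy of the per-token routing weights αᵢ(x), it can be computed as below; `mixing_entropy` is a hypothetical helper. Entropy is 0 when all weight sits on one orthogonal matrix and log(m) when the weight is spread evenly over m matrices.

```python
import torch

def mixing_entropy(alpha: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Shannon entropy (nats) of routing weights over the last dimension."""
    # clamp_min avoids log(0); zero weights then contribute nothing to the sum.
    return (-alpha * torch.log(alpha.clamp_min(eps))).sum(dim=-1)

alpha = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],      # all weight on one matrix
    [0.5, 0.5, 0.0, 0.0],      # top-2 routing, evenly split
    [0.25, 0.25, 0.25, 0.25],  # uniform over 4 matrices
])
ent = mixing_entropy(alpha)
```

With top-k routing the entropy is capped at log(k), so values near that cap indicate the router is splitting weight evenly among its selected matrices.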

Benchmarks

Target performance (from paper):

| Benchmark | SHC | MHC | DenseRes |
|---|---|---|---|
| BBH (23 tasks) | 42.3% | 42.1% | 40.8% |
| GSM8K | 28.7% | 28.5% | 27.2% |
| MMLU (5-shot) | 45.2% | 45.0% | 44.1% |

Reported efficiency gains:

  • 16× speedup in routing computation (Cayley transform vs. Sinkhorn iteration)
  • 3.3× reduction in KV cache memory
  • <1% overhead vs. a baseline Transformer
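For context on the routing comparison, here is a generic Sinkhorn–Knopp normalization, the kind of iterative baseline the 16× figure is measured against. It alternately normalizes rows and columns of a positive matrix toward a doubly stochastic one, so its cost scales with the iteration count. The Cayley map, by contrast, needs one linear solve. The function name and iteration counts are illustrative assumptions, and this sketch does not measure or reproduce the reported speedup.

```python
import torch

def sinkhorn(logits: torch.Tensor, n_iters: int = 20) -> torch.Tensor:
    """Approximately project exp(logits) onto doubly stochastic matrices by alternating normalization."""
    # Subtract the max before exponentiating so entries are positive and cannot overflow.
    P = torch.exp(logits - logits.amax(dim=(-2, -1), keepdim=True))
    for _ in range(n_iters):
        P = P / P.sum(dim=-1, keepdim=True)   # rows sum to 1
        P = P / P.sum(dim=-2, keepdim=True)   # columns sum to 1
    return P

torch.manual_seed(0)
# A generous iteration budget so the rows also converge in this small example.
P = sinkhorn(torch.randn(4, 4, dtype=torch.float64), n_iters=200)
```

After the loop the columns sum to exactly 1 (they were normalized last) and the rows sum to 1 only approximately, which is why Sinkhorn needs several iterations per forward pass.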

Requirements

  • Python 3.9+
  • PyTorch 2.0+
  • CUDA 11.8+ (for GPU training)

See shc/requirements.txt for full dependencies.

Citation

@article{shc2026,
  title={Sparse Selective Hyper-Connections: A Unified Framework for 
         Stable and Efficient Deep Residual Learning},
  author={...},
  journal={...},
  year={2026}
}

License

MIT License - see LICENSE for details.

Download files

Source distribution: sparse_hyper_connections-0.1.0.tar.gz (76.1 kB)

Built distribution: sparse_hyper_connections-0.1.0-py3-none-any.whl (81.3 kB, Python 3)

File hashes

sparse_hyper_connections-0.1.0.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | cc2c4f179fbf77b7ec396869395592a9d56600abb8493e6286fd53fd4984d8c9 |
| MD5 | e8e6cb26cfc363d247ef11eb1773a7f5 |
| BLAKE2b-256 | 0d6aa043d697ad05716500e28611e579c41f6ee0950e73759309726238908112 |

sparse_hyper_connections-0.1.0-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | cfdced7352d71f7e76da98316533180f94241c095497a3ed42f7efa0f13ca74b |
| MD5 | 68b36be76b7b16f8fce0ad907c4e3055 |
| BLAKE2b-256 | d9e085cf055d9254f5b7823deb1f6ed8b0d0b1389e1330d61b6b72fae052a3c8 |
