
SHC: Sparse Selective Hyper-Connections


A PyTorch implementation of Sparse Selective Hyper-Connections for stable and efficient deep residual learning.

Overview

SHC replaces traditional residual connections with sparse mixtures of orthogonal routing matrices, providing:

  • Bounded spectral norm (ρ ≤ 1): Guarantees training stability (see the sketch after this list)
  • 16× faster routing: Via closed-form Cayley transform (vs. Sinkhorn iteration)
  • 3.3× KV cache reduction: Through learned low-rank factorization
  • O(L) inference: Optional SSM distillation for linear-time generation
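
The boundedness claim follows from convexity: each routing matrix Qᵢ is orthogonal (spectral norm 1) and the weights αᵢ(x) sum to one, so ‖Σ αᵢ(x)·Qᵢ‖₂ ≤ 1. A minimal numerical sketch in plain PyTorch (generic stand-ins, not the package's SparseOrthogonalMixture):

import torch

# Random orthogonal matrices via QR (stand-ins for the learned Q_i)
d, n = 64, 4
Qs = [torch.linalg.qr(torch.randn(d, d))[0] for _ in range(n)]

# Convex mixing weights (stand-ins for the learned alpha_i(x))
alpha = torch.softmax(torch.randn(n), dim=0)

# Mixture H = sum_i alpha_i * Q_i
H = sum(a * Q for a, Q in zip(alpha, Qs))

# Spectral norm (largest singular value) never exceeds 1
print(torch.linalg.matrix_norm(H, ord=2))  # <= 1.0 up to float error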

Installation

# Clone repository
git clone https://github.com/rahvis/shc.git
cd shc

# Create virtual environment
python -m venv venv
source venv/bin/activate  # Linux/Mac
# or: venv\Scripts\activate  # Windows

# Install dependencies
pip install -r shc/requirements.txt

# Install package in development mode
pip install -e .

Quick Start

Basic Usage

from shc.models import SHCTransformer, get_config

# Create model with predefined config
config = get_config('500m')  # Options: '500m', '1b', '3b', '7b'
model = SHCTransformer(config)

# Forward pass
import torch
input_ids = torch.randint(0, 32000, (2, 512))
logits = model(input_ids)

# Generate text
output = model.generate(
    input_ids[:, :10],  # prompt
    max_new_tokens=100,
    temperature=0.7,
    top_p=0.9,
)

Training

# Single GPU
python -m shc.scripts.train --model_size 500m --output_dir ./output

# Multi-GPU with DDP
torchrun --nproc_per_node=8 -m shc.scripts.train \
    --model_size 3b \
    --batch_size 32 \
    --learning_rate 3e-4

# FSDP for 7B+ models (memory efficient)
torchrun --nproc_per_node=8 -m shc.scripts.train \
    --model_size 7b \
    --use_fsdp \
    --mixed_precision bf16

Evaluation

# Run benchmarks
python -m shc.scripts.evaluate \
    --model_path ./output/final \
    --benchmarks bbh gsm8k mmlu

# Efficiency profiling
python -m shc.scripts.evaluate \
    --model_path ./output/final \
    --profile \
    --analyze_routing

SSM Distillation

from shc.models import SHCTransformer, SSMStudent
from shc.training import DistillationTrainer, DistillationConfig

# Load trained teacher
teacher = SHCTransformer.from_pretrained('path/to/teacher')

# Create student matching teacher dimensions
student = SSMStudent.from_teacher_config(teacher.config)

# Distill
config = DistillationConfig(max_steps=10000)
# train_loader: a user-provided DataLoader over the distillation corpus
trainer = DistillationTrainer(teacher, student, config, train_loader)
trainer.train()

# Student generates with O(1) per step (no KV cache!)
output = student.generate(input_ids, max_new_tokens=100)

Architecture

Core Components

Component                  Description                                        Reference
CayleyTransform            Closed-form orthogonal matrix: Q = (I-A)(I+A)⁻¹    Eq. 9
SparseOrthogonalMixture    H^res = Σ αᵢ(x)·Qᵢ with ρ ≤ 1                      Eq. 7, Prop. 1
FactorizedKVCache          Low-rank compression: x̄ ≈ UV^T                     Eq. 14
AdaptiveRankSelector       Gumbel-Softmax rank selection                      Eq. 16
SHCBlock                   Complete block with triple routing                 Algorithm 1
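
As a rough illustration of the CayleyTransform row above (a generic sketch, not the package's module): for any skew-symmetric A, Q = (I-A)(I+A)⁻¹ is orthogonal, so an orthogonal routing matrix comes out in closed form rather than from an iterative scheme such as Sinkhorn.

import torch

d = 64
W = torch.randn(d, d)
A = W - W.T                    # skew-symmetric: A^T = -A
I = torch.eye(d)

# Cayley transform: Q = (I - A)(I + A)^{-1}
Q = (I - A) @ torch.linalg.inv(I + A)

# Q is orthogonal: Q @ Q^T equals I up to float error
print(torch.dist(Q @ Q.T, I))  # close to zero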

Model Configurations

Size    Hidden    Layers    Heads    Parameters
500M    1024      24        16       ~500M
1B      2048      24        16       ~1B
3B      2560      32        32       ~3B
7B      4096      32        32       ~7B

Project Structure

shc/
├── __init__.py              # Package init
├── requirements.txt         # Dependencies
├── configs/
│   └── config.py           # Model/training configs
├── layers/
│   ├── cayley.py           # Cayley transform
│   ├── sparse_mixture.py   # Sparse orthogonal routing
│   ├── factorized_cache.py # KV cache compression
│   └── adaptive_rank.py    # Rank selection
├── blocks/
│   ├── attention.py        # Multi-head attention + RoPE
│   ├── feedforward.py      # SwiGLU FFN
│   └── shc_block.py        # Complete SHC block
├── models/
│   ├── embeddings.py       # Token/positional embeddings
│   ├── transformer.py      # SHCTransformer
│   └── ssm_student.py      # SSM for O(L) inference
├── training/
│   ├── distributed.py      # DDP/FSDP utilities
│   ├── optimizer.py        # Adam + cosine scheduler
│   ├── trainer.py          # Training loop
│   └── distillation.py     # Teacher→student distillation
├── data/
│   ├── dataset.py          # Dataset classes
│   └── dataloader.py       # Distributed data loading
├── evaluation/
│   ├── metrics.py          # PPL, accuracy, F1, BLEU
│   ├── benchmarks.py       # BBH, GSM8K, MMLU
│   └── profiler.py         # Efficiency profiling
└── scripts/
    ├── train.py            # Training CLI
    └── evaluate.py         # Evaluation CLI

Key Features

1. Stable Training via Orthogonal Routing

# Spectral norm is bounded by construction
routing = SparseOrthogonalMixture(n=4, k=2, hidden_dim=768)
H_res = routing(x)                            # x: hidden states entering the block
spectral_norm = routing.get_spectral_norm(x)  # always ≤ 1.0

2. Efficient Multi-GPU Training

from shc.training import setup_distributed, wrap_model_ddp, wrap_model_fsdp

# Automatic distributed setup
rank, local_rank, world_size = setup_distributed()

# DDP for data parallelism
model = wrap_model_ddp(model)

# Or FSDP for memory efficiency
model = wrap_model_fsdp(model, mixed_precision=True)

3. Memory-Efficient KV Cache

from shc.layers import FactorizedKVCache

cache = FactorizedKVCache(n=4, d=768, r=1)
compressed = cache.compress(x_bar)             # rank-1 compression of the 4×768 expanded stream
reconstructed = cache.decompress(compressed)   # approximate reconstruction (~99% accurate)
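
For intuition only, here is a plain SVD-based rank-1 factorization (not the package's learned cache, which trains its factors end-to-end). Storing the two rank-1 factors keeps n + d numbers per block instead of n·d:

import torch

x_bar = torch.randn(4, 768)        # stand-in for the expanded residual stream (n=4, d=768)

# Best rank-1 approximation via truncated SVD: x_bar ≈ u v^T
U, S, Vh = torch.linalg.svd(x_bar, full_matrices=False)
u = U[:, :1] * S[:1]               # 4 x 1
v = Vh[:1, :].T                    # 768 x 1
approx = u @ v.T                   # 4 x 768 reconstruction

ratio = x_bar.numel() / (u.numel() + v.numel())            # ~4x fewer stored numbers
rel_err = torch.norm(x_bar - approx) / torch.norm(x_bar)   # large here: random data is not low-rank
print(ratio, rel_err)

Random data compresses poorly; the learned factorization relies on real activations being close to low-rank.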

4. Routing Analysis

from shc.evaluation import RoutingAnalyzer

analyzer = RoutingAnalyzer(model)
analyzer.analyze_batch(input_ids)
stats = analyzer.get_summary()
# {'spectral_norms': {'mean': 0.98, 'max': 1.0}, 'mixing_entropy': {...}}

Benchmarks

Target performance (from paper):

Benchmark         SHC      MHC      DenseRes
BBH (23 tasks)    42.3%    42.1%    40.8%
GSM8K             28.7%    28.5%    27.2%
MMLU (5-shot)     45.2%    45.0%    44.1%

Efficiency gains:

  • 16× speedup in routing computation
  • 3.3× reduction in KV cache memory
  • <1% overhead vs baseline Transformer

Requirements

  • Python 3.9+
  • PyTorch 2.0+
  • CUDA 11.8+ (for GPU training)

See shc/requirements.txt for full dependencies.

Citation

@article{shc2026,
  title={Sparse Selective Hyper-Connections: A Unified Framework for 
         Stable and Efficient Deep Residual Learning},
  author={...},
  journal={...},
  year={2026}
}

License

MIT License - see LICENSE for details.
