SHC: Sparse Selective Hyper-Connections
A PyTorch implementation of Sparse Selective Hyper-Connections for stable and efficient deep residual learning.
Overview
SHC replaces traditional residual connections with sparse mixtures of orthogonal routing matrices, providing:
- Bounded spectral norm (ρ ≤ 1): Guarantees training stability (see the numerical check after this list)
- 16× faster routing: Via closed-form Cayley transform (vs. Sinkhorn iteration)
- 3.3× KV cache reduction: Through learned low-rank factorization
- O(L) inference: Optional SSM distillation for linear-time generation
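The bound in the first bullet follows from the triangle inequality: for convex weights αᵢ ≥ 0 with Σ αᵢ = 1 and orthogonal Qᵢ, ‖Σ αᵢ(x)·Qᵢ‖₂ ≤ Σ αᵢ‖Qᵢ‖₂ = 1. A minimal numerical check in standalone PyTorch (not the package API):

```python
import torch

# Check: a convex mixture of orthogonal matrices has spectral norm <= 1.
d, n = 64, 4
Qs = [torch.linalg.qr(torch.randn(d, d)).Q for _ in range(n)]  # random orthogonal matrices
alpha = torch.softmax(torch.randn(n), dim=0)                   # convex mixing weights
H = sum(a * Q for a, Q in zip(alpha, Qs))
print(torch.linalg.matrix_norm(H, ord=2).item())  # always <= 1.0, up to float error
```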
Installation
```bash
# Clone repository
git clone https://github.com/your-org/shc.git
cd shc

# Create virtual environment
python -m venv venv
source venv/bin/activate    # Linux/Mac
# or: venv\Scripts\activate # Windows

# Install dependencies
pip install -r shc/requirements.txt

# Install package in development mode
pip install -e .
```
Quick Start
Basic Usage
```python
import torch
from shc.models import SHCTransformer, get_config

# Create model with a predefined config
config = get_config('500m')  # Options: '500m', '1b', '3b', '7b'
model = SHCTransformer(config)

# Forward pass
input_ids = torch.randint(0, 32000, (2, 512))
logits = model(input_ids)

# Generate text
output = model.generate(
    input_ids[:, :10],  # prompt
    max_new_tokens=100,
    temperature=0.7,
    top_p=0.9,
)
```
Training
```bash
# Single GPU
python -m shc.scripts.train --model_size 500m --output_dir ./output

# Multi-GPU with DDP
torchrun --nproc_per_node=8 -m shc.scripts.train \
    --model_size 3b \
    --batch_size 32 \
    --learning_rate 3e-4

# FSDP for 7B+ models (memory efficient)
torchrun --nproc_per_node=8 -m shc.scripts.train \
    --model_size 7b \
    --use_fsdp \
    --mixed_precision bf16
```
Evaluation
```bash
# Run benchmarks
python -m shc.scripts.evaluate \
    --model_path ./output/final \
    --benchmarks bbh gsm8k mmlu

# Efficiency profiling
python -m shc.scripts.evaluate \
    --model_path ./output/final \
    --profile \
    --analyze_routing
```
SSM Distillation
```python
from shc.models import SHCTransformer, SSMStudent
from shc.training import DistillationTrainer, DistillationConfig

# Load trained teacher
teacher = SHCTransformer.from_pretrained('path/to/teacher')

# Create student matching teacher dimensions
student = SSMStudent.from_teacher_config(teacher.config)

# Distill
config = DistillationConfig(max_steps=10000)
trainer = DistillationTrainer(teacher, student, config, train_loader)
trainer.train()

# Student generates with O(1) per step (no KV cache!)
output = student.generate(input_ids, max_new_tokens=100)
```
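For intuition, here is a minimal sketch of the logit-matching objective typically used in this kind of distillation. The actual `DistillationTrainer` may add further terms (e.g., hidden-state matching), so treat `distillation_loss` below as an illustration, not the package's implementation:

```python
import torch.nn.functional as F

def distillation_loss(teacher_logits, student_logits, T: float = 2.0):
    # KL divergence between temperature-softened teacher and student
    # distributions, scaled by T^2 to keep gradient magnitudes comparable across T.
    return F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
```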
Architecture
Core Components
| Component | Description | Reference |
|---|---|---|
| `CayleyTransform` | Closed-form orthogonal matrix: Q = (I-A)(I+A)⁻¹ (see the sketch below the table) | Eq. 9 |
| `SparseOrthogonalMixture` | H^res = Σ αᵢ(x)·Qᵢ with ρ ≤ 1 | Eq. 7, Prop. 1 |
| `FactorizedKVCache` | Low-rank compression: x̄ ≈ UV^T | Eq. 14 |
| `AdaptiveRankSelector` | Gumbel-Softmax rank selection | Eq. 16 |
| `SHCBlock` | Complete block with triple routing | Algorithm 1 |
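For reference, the Cayley transform of Eq. 9 fits in a few lines. A minimal sketch, assuming the skew-symmetric matrix is built as A = W − Wᵀ from an unconstrained weight W (the package's `CayleyTransform` may parameterize this differently):

```python
import torch

def cayley(W: torch.Tensor) -> torch.Tensor:
    """Map an unconstrained square matrix to an orthogonal one."""
    A = W - W.T  # skew-symmetric, so I + A is always invertible
    I = torch.eye(W.shape[0], dtype=W.dtype, device=W.device)
    # Q = (I - A)(I + A)^{-1}; the factors commute for skew-symmetric A,
    # so this equals (I + A)^{-1}(I - A), computed with a single linear solve.
    return torch.linalg.solve(I + A, I - A)

Q = cayley(torch.randn(8, 8))
print(torch.allclose(Q @ Q.T, torch.eye(8), atol=1e-5))  # True: Q is orthogonal
```

One closed-form solve replaces the iterative normalization a Sinkhorn-based router needs, which is where the claimed 16× routing speedup comes from.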
Model Configurations
| Size | Hidden | Layers | Heads | Parameters |
|---|---|---|---|---|
| 500M | 1024 | 24 | 16 | ~500M |
| 1B | 2048 | 24 | 16 | ~1B |
| 3B | 2560 | 32 | 32 | ~3B |
| 7B | 4096 | 32 | 32 | ~7B |
Project Structure
```
shc/
├── __init__.py              # Package init
├── requirements.txt         # Dependencies
├── configs/
│   └── config.py            # Model/training configs
├── layers/
│   ├── cayley.py            # Cayley transform
│   ├── sparse_mixture.py    # Sparse orthogonal routing
│   ├── factorized_cache.py  # KV cache compression
│   └── adaptive_rank.py     # Rank selection
├── blocks/
│   ├── attention.py         # Multi-head attention + RoPE
│   ├── feedforward.py       # SwiGLU FFN
│   └── shc_block.py         # Complete SHC block
├── models/
│   ├── embeddings.py        # Token/positional embeddings
│   ├── transformer.py       # SHCTransformer
│   └── ssm_student.py       # SSM for O(L) inference
├── training/
│   ├── distributed.py       # DDP/FSDP utilities
│   ├── optimizer.py         # Adam + cosine scheduler
│   ├── trainer.py           # Training loop
│   └── distillation.py      # Teacher→student distillation
├── data/
│   ├── dataset.py           # Dataset classes
│   └── dataloader.py        # Distributed data loading
├── evaluation/
│   ├── metrics.py           # PPL, accuracy, F1, BLEU
│   ├── benchmarks.py        # BBH, GSM8K, MMLU
│   └── profiler.py          # Efficiency profiling
└── scripts/
    ├── train.py             # Training CLI
    └── evaluate.py          # Evaluation CLI
```
Key Features
1. Stable Training via Orthogonal Routing
```python
from shc.layers import SparseOrthogonalMixture

# Spectral norm is bounded by construction
routing = SparseOrthogonalMixture(n=4, k=2, hidden_dim=768)
H_res = routing(x)
spectral_norm = routing.get_spectral_norm(x)  # Always ≤ 1.0
```
2. Efficient Multi-GPU Training
```python
from shc.training import setup_distributed, wrap_model_ddp, wrap_model_fsdp

# Automatic distributed setup
rank, local_rank, world_size = setup_distributed()

# DDP for data parallelism
model = wrap_model_ddp(model)

# Or FSDP for memory efficiency
model = wrap_model_fsdp(model, mixed_precision=True)
```
3. Memory-Efficient KV Cache
```python
from shc.layers import FactorizedKVCache

cache = FactorizedKVCache(n=4, d=768, r=1)
compressed = cache.compress(x_bar)            # low-rank factorization of the 4×768 block (Eq. 14)
reconstructed = cache.decompress(compressed)  # ~99% reconstruction accuracy
```
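To see what the rank-r factorization of Eq. 14 buys, here is a sketch using a truncated SVD. The package learns U and V end-to-end rather than computing an SVD, and `lowrank_factorize` is illustrative only; storage drops from n·d to r·(n+d) values:

```python
import torch

def lowrank_factorize(x_bar: torch.Tensor, r: int = 1):
    """Approximate an n×d matrix as U @ V.T with U: n×r, V: d×r."""
    U, S, Vh = torch.linalg.svd(x_bar, full_matrices=False)
    return U[:, :r] * S[:r], Vh[:r, :].T  # singular values folded into U

x_bar = torch.randn(4, 768)
U, V = lowrank_factorize(x_bar, r=1)
print(U.shape, V.shape)  # torch.Size([4, 1]) torch.Size([768, 1])
# Relative reconstruction error; small when x_bar is approximately low-rank
print((x_bar - U @ V.T).norm() / x_bar.norm())
```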
4. Routing Analysis
```python
from shc.evaluation import RoutingAnalyzer

analyzer = RoutingAnalyzer(model)
analyzer.analyze_batch(input_ids)
stats = analyzer.get_summary()
# {'spectral_norms': {'mean': 0.98, 'max': 1.0}, 'mixing_entropy': {...}}
```
Benchmarks
Target performance (from paper):
| Benchmark | SHC | MHC | DenseRes |
|---|---|---|---|
| BBH (23 tasks) | 42.3% | 42.1% | 40.8% |
| GSM8K | 28.7% | 28.5% | 27.2% |
| MMLU (5-shot) | 45.2% | 45.0% | 44.1% |
Efficiency gains:
- 16× speedup in routing computation
- 3.3× reduction in KV cache memory
- <1% overhead vs baseline Transformer
Requirements
- Python 3.9+
- PyTorch 2.0+
- CUDA 11.8+ (for GPU training)
See shc/requirements.txt for full dependencies.
Citation
```bibtex
@article{shc2026,
  title={Sparse Selective Hyper-Connections: A Unified Framework for
         Stable and Efficient Deep Residual Learning},
  author={...},
  journal={...},
  year={2026}
}
```
License
MIT License - see LICENSE for details.