🔄 Transformer Cloner
Clone and prune transformer models with new tokenizers. Create smaller, more efficient models by mapping vocabularies, reducing dimensions, and pruning layers.
Perfect for:
- 🌍 Language adaptation: Use a custom tokenizer optimized for your language
- 📉 Model compression: Create smaller models for edge deployment
- 🎓 Knowledge distillation: Generate student models from teacher models
- 🔬 Research: Experiment with different model architectures
📦 Installation
pip install transformer-cloner
Requirements:
- Python 3.10+
- PyTorch 2.0+
- Transformers 4.40+
🚀 Quick Start
Basic Cloning with New Tokenizer
from transformer_cloner import TransformerCloner, EmbeddingStrategy
# Initialize with original model and target tokenizer
cloner = TransformerCloner(
    org_model_id="google/gemma-3-270m-it",
    target_tokenizer_id="your-username/custom-tokenizer",
)
# Clone the model
model = cloner.clone(strategy=EmbeddingStrategy.MEAN)
# Save the cloned model
model.save_pretrained("./cloned-model")
📖 Detailed Usage
1. Vocabulary Mapping (New Tokenizer)
When your new tokenizer has a different vocabulary, the cloner maps each token in the new vocabulary to one or more tokens from the original vocabulary and builds the new token's embedding from theirs:
from transformer_cloner import TransformerCloner, EmbeddingStrategy
cloner = TransformerCloner(
    org_model_id="google/gemma-3-270m-it",
    target_tokenizer_id="alibayram/gemma3-tr-v64k",  # Turkish tokenizer
)
# View vocabulary samples
cloner.print_vocab_samples(n=10)
# Build token ID map (this maps new tokens → original tokens)
token_map = cloner.build_token_id_map()
# Check how a specific token is mapped
info = cloner.get_token_info("merhaba")
print(info)
# {'token': 'merhaba', 'target_id': 1234, 'source_ids': [567, 890], 'source_tokens': ['mer', 'haba']}
# Clone with your chosen embedding strategy
model = cloner.clone(strategy=EmbeddingStrategy.MEAN)
model.save_pretrained("./turkish-gemma")
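The `get_token_info` output above suggests that each target token is mapped by re-tokenizing its text with the original tokenizer (`'merhaba'` → `['mer', 'haba']`). The sketch below illustrates that idea under stated assumptions: a toy greedy longest-match tokenizer stands in for the real SentencePiece tokenizer (no word-boundary markers, byte fallback, or special tokens), and `toy_tokenize` and `build_token_id_map_sketch` are hypothetical helpers, not transformer-cloner internals.

```python
from typing import Dict, List


def toy_tokenize(text: str, source_vocab: Dict[str, int]) -> List[int]:
    """Hypothetical greedy longest-match tokenizer over a toy source vocabulary."""
    ids: List[int] = []
    i = 0
    while i < len(text):
        # Try the longest substring starting at position i that exists in the vocabulary.
        for j in range(len(text), i, -1):
            piece = text[i:j]
            if piece in source_vocab:
                ids.append(source_vocab[piece])
                i = j
                break
        else:
            raise ValueError(f"cannot tokenize {text[i]!r}")
    return ids


def build_token_id_map_sketch(
    target_vocab: Dict[str, int], source_vocab: Dict[str, int]
) -> Dict[int, List[int]]:
    """Map each target token ID to the source token IDs that spell the same string."""
    return {
        target_id: toy_tokenize(token, source_vocab)
        for token, target_id in target_vocab.items()
    }


# Toy vocabularies reusing the IDs from the get_token_info example.
source_vocab = {"mer": 567, "haba": 890, "m": 1, "e": 2, "r": 3}
target_vocab = {"merhaba": 1234, "mer": 7}
token_map = build_token_id_map_sketch(target_vocab, source_vocab)
print(token_map)  # {1234: [567, 890], 7: [567]}
```

The resulting `dict[int, list[int]]` has the same shape as the return type listed for `build_token_id_map()` in the API reference; the embedding strategies then decide how the listed source embeddings are combined.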
2. Model Architecture Pruning
Create smaller models by reducing dimensions:
from transformer_cloner import TransformerCloner, PruningConfig, EmbeddingStrategy
cloner = TransformerCloner(
    org_model_id="google/gemma-3-270m-it",
    target_tokenizer_id="alibayram/gemma3-tr-v64k",
)
# Define pruning configuration
pruning_config = PruningConfig(
    hidden_size=320,         # Original: 640
    num_hidden_layers=9,     # Original: 18
    intermediate_size=1024,  # Original: 2048
    num_attention_heads=2,   # Original: 4
    num_key_value_heads=1,   # Original: 1 (unchanged)
    head_dim=128,            # Original: 256
)
# Validate before cloning (optional)
errors = pruning_config.validate(cloner.org_model.config)
if errors:
    print("Validation errors:", errors)
else:
    # Clone with pruning
    model = cloner.clone_pruned(
        pruning_config=pruning_config,
        strategy=EmbeddingStrategy.MEAN,
    )
    model.save_pretrained("./gemma-small")
3. Vocabulary Pruning (Smaller Embedding Table)
Keep only specific tokens in the embedding table:
from transformer_cloner import TransformerCloner
# Use the same tokenizer for both
cloner = TransformerCloner(
    org_model_id="google/gemma-3-270m-it",
    target_tokenizer_id="google/gemma-3-270m-it",
)
# Option A: Keep first N tokens
model, tokenizer, id_mapping = cloner.clone_with_vocab_pruning(vocab_size=16000)
# Option B: Keep specific token IDs
important_tokens = [0, 1, 2, 100, 200, 500, ...] # Your list
model, tokenizer, id_mapping = cloner.clone_with_vocab_pruning(keep_token_ids=important_tokens)
# Save model (tokenizer is unchanged)
model.save_pretrained("./vocab-pruned")
# Use id_mapping to convert token IDs: old_id -> new_embedding_index
print(id_mapping) # {0: 0, 1: 1, 2: 2, 100: 3, ...}
Note: The original tokenizer is returned unchanged because modifying SentencePiece/BPE vocabularies breaks them. Use `id_mapping` to convert token IDs to embedding indices.
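The printed mapping `{0: 0, 1: 1, 2: 2, 100: 3, ...}` is consistent with assigning consecutive embedding indices to the kept IDs in ascending order. Here is a minimal sketch under that assumption, plus a helper that converts tokenizer output into pruned embedding indices before it reaches the model. Both `build_id_mapping` and `remap_ids` are hypothetical names, not part of transformer-cloner, and the `fallback` parameter is an illustrative choice for handling pruned tokens.

```python
from typing import Dict, Iterable, List, Optional


def build_id_mapping(keep_token_ids: Iterable[int]) -> Dict[int, int]:
    """Map each kept original token ID to a new, consecutive embedding index."""
    return {old_id: new_idx for new_idx, old_id in enumerate(sorted(set(keep_token_ids)))}


def remap_ids(
    input_ids: List[int], id_mapping: Dict[int, int], fallback: Optional[int] = None
) -> List[int]:
    """Translate tokenizer output (original IDs) into pruned embedding indices."""
    remapped = []
    for token_id in input_ids:
        if token_id in id_mapping:
            remapped.append(id_mapping[token_id])
        elif fallback is not None:
            remapped.append(fallback)  # hypothetical: e.g. the new index of an UNK token
        else:
            raise KeyError(f"token ID {token_id} was pruned from the vocabulary")
    return remapped


id_mapping = build_id_mapping([0, 1, 2, 100, 200, 500])
print(id_mapping)                            # {0: 0, 1: 1, 2: 2, 100: 3, 200: 4, 500: 5}
print(remap_ids([2, 100, 500], id_mapping))  # [2, 3, 5]
```

Under this assumption, Option A ("keep first N tokens") is the special case `keep_token_ids=range(N)`, where the mapping is the identity.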
4. Combined Pruning (Vocabulary + Architecture)
from transformer_cloner import TransformerCloner, PruningConfig
cloner = TransformerCloner(
    org_model_id="google/gemma-3-270m-it",
    target_tokenizer_id="google/gemma-3-270m-it",
)
# Combine vocabulary and architecture pruning
model, tokenizer, id_mapping = cloner.clone_with_vocab_pruning(
    vocab_size=16000,
    pruning_config=PruningConfig(
        hidden_size=320,
        num_hidden_layers=9,
    ),
)
🎯 Embedding Strategies
When a target token maps to multiple source tokens, choose how to combine their embeddings:
| Strategy | Description | Use Case |
|---|---|---|
| `MEAN` | Average of all embeddings | Default, balanced representation |
| `SUM` | Sum of all embeddings | Preserve total magnitude |
| `FIRST` | First token's embedding | Prefix-focused tokens |
| `LAST` | Last token's embedding | Suffix-focused tokens |
| `WEIGHTED` | Weighted average (first tokens weighted more) | Morphological priority |
| `MAX` | Element-wise maximum | Preserve dominant features |
| `MIN` | Element-wise minimum | Preserve minimal features |
from transformer_cloner import EmbeddingStrategy
# Use different strategies
model = cloner.clone(strategy=EmbeddingStrategy.MEAN)
model = cloner.clone(strategy=EmbeddingStrategy.WEIGHTED)
model = cloner.clone(strategy=EmbeddingStrategy.FIRST)
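To make the table concrete, the sketch below applies each strategy element-wise to a few source embeddings using plain Python lists. It is an illustration, not the library's implementation: `Strategy` is a hypothetical stand-in that mirrors the `EmbeddingStrategy` values, and because the table does not specify the `WEIGHTED` scheme, linearly decaying weights (n, n-1, ..., 1) are an assumption.

```python
from enum import Enum
from typing import List, Sequence


class Strategy(Enum):
    """Hypothetical stand-in mirroring the EmbeddingStrategy values."""
    MEAN = "mean"
    SUM = "sum"
    FIRST = "first"
    LAST = "last"
    WEIGHTED = "weighted"
    MAX = "max"
    MIN = "min"


def combine(vectors: Sequence[Sequence[float]], strategy: Strategy) -> List[float]:
    """Combine several source-token embeddings into one target-token embedding."""
    if not vectors:
        raise ValueError("need at least one source embedding")
    columns = list(zip(*vectors))  # one tuple per embedding dimension
    if strategy is Strategy.MEAN:
        return [sum(c) / len(c) for c in columns]
    if strategy is Strategy.SUM:
        return [sum(c) for c in columns]
    if strategy is Strategy.FIRST:
        return list(vectors[0])
    if strategy is Strategy.LAST:
        return list(vectors[-1])
    if strategy is Strategy.WEIGHTED:
        # Assumption: linearly decaying weights n, n-1, ..., 1 (earlier tokens count more).
        n = len(vectors)
        weights = [n - i for i in range(n)]
        total = sum(weights)
        return [sum(w * x for w, x in zip(weights, c)) / total for c in columns]
    if strategy is Strategy.MAX:
        return [max(c) for c in columns]
    if strategy is Strategy.MIN:
        return [min(c) for c in columns]
    raise ValueError(f"unknown strategy: {strategy}")


# Two toy source embeddings, e.g. for 'mer' and 'haba'.
vecs = [[3.0, 0.0], [0.0, 3.0]]
print(combine(vecs, Strategy.MEAN))      # [1.5, 1.5]
print(combine(vecs, Strategy.WEIGHTED))  # [2.0, 1.0]
print(combine(vecs, Strategy.MAX))       # [3.0, 3.0]
```

Note that `SUM` grows with the number of source tokens, while `MEAN` and `WEIGHTED` keep the result on the same scale as a single embedding.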
⚙️ Pruning Configuration
| Parameter | Description | Example |
|---|---|---|
| `hidden_size` | Embedding dimension | 768 → 512 |
| `num_hidden_layers` | Number of transformer layers | 12 → 6 |
| `intermediate_size` | FFN intermediate dimension | 3072 → 1536 |
| `num_attention_heads` | Number of attention heads | 12 → 8 |
| `num_key_value_heads` | Number of KV heads (for GQA) | 4 → 2 |
| `head_dim` | Dimension per attention head | 64 → 32 |
Validation
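The library validates your pruning config against the original model. You can also run the check yourself: `validate()` returns a list of error messages (empty if the config is valid):
pruning_config = PruningConfig(hidden_size=1000, num_attention_heads=7)
errors = pruning_config.validate(cloner.org_model.config)
# Includes: 'hidden_size (1000) must be divisible by num_attention_heads (7)'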
Validation checks:
- ✅ Dimensions don't exceed the original model's
- ✅ All values are positive
- ✅ `num_attention_heads` is divisible by `num_key_value_heads`
- ✅ `hidden_size` is compatible with the attention heads
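The four checks listed above can be sketched as a small dataclass. This is a hypothetical re-implementation for illustration (`PruningConfigSketch` is not the library's class, and the real rules and messages may differ); it takes the original config as a plain dict rather than a Transformers config object, and treats fields left as `None` as keeping the original value.

```python
from dataclasses import dataclass, fields
from typing import Dict, List, Optional


@dataclass
class PruningConfigSketch:
    """Hypothetical re-implementation of the listed checks (not the library's code)."""
    hidden_size: Optional[int] = None
    num_hidden_layers: Optional[int] = None
    intermediate_size: Optional[int] = None
    num_attention_heads: Optional[int] = None
    num_key_value_heads: Optional[int] = None
    head_dim: Optional[int] = None

    def validate(self, original: Dict[str, int]) -> List[str]:
        errors: List[str] = []
        # Fields left as None keep the original model's value.
        resolved: Dict[str, int] = {}
        for f in fields(self):
            value = getattr(self, f.name)
            resolved[f.name] = original[f.name] if value is None else value
        # Check 1 and 2: positive and not larger than the original.
        for name, value in resolved.items():
            if value <= 0:
                errors.append(f"{name} ({value}) must be positive")
            elif value > original[name]:
                errors.append(f"{name} ({value}) exceeds original ({original[name]})")
        heads = resolved["num_attention_heads"]
        kv_heads = resolved["num_key_value_heads"]
        # Check 3: each KV head must serve a whole number of query heads (GQA).
        if heads > 0 and kv_heads > 0 and heads % kv_heads != 0:
            errors.append(
                f"num_attention_heads ({heads}) must be divisible by num_key_value_heads ({kv_heads})"
            )
        # Check 4: hidden_size must split evenly across the attention heads.
        if heads > 0 and resolved["hidden_size"] % heads != 0:
            errors.append(
                f"hidden_size ({resolved['hidden_size']}) must be divisible by num_attention_heads ({heads})"
            )
        return errors


GEMMA_270M = dict(hidden_size=640, num_hidden_layers=18, intermediate_size=2048,
                  num_attention_heads=4, num_key_value_heads=1, head_dim=256)
half = PruningConfigSketch(hidden_size=320, num_hidden_layers=9, intermediate_size=1024,
                           num_attention_heads=2, num_key_value_heads=1, head_dim=128)
print(half.validate(GEMMA_270M))  # []
print(PruningConfigSketch(hidden_size=1000, num_attention_heads=7).validate(GEMMA_270M))
```

With Gemma-3-270M as the original, the second config fails more than one check, since 1000 and 7 also exceed the original `hidden_size` (640) and `num_attention_heads` (4).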
📊 Example: Creating a Half-Size Model
Original Gemma-3-270M architecture:
hidden_size: 640
num_hidden_layers: 18
intermediate_size: 2048
num_attention_heads: 4
num_key_value_heads: 1
head_dim: 256
Half-size configuration:
pruning_config = PruningConfig(
    hidden_size=320,         # 640 / 2
    num_hidden_layers=9,     # 18 / 2
    intermediate_size=1024,  # 2048 / 2
    num_attention_heads=2,   # 4 / 2
    num_key_value_heads=1,   # Keep 1
    head_dim=128,            # 256 / 2
)
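Result: roughly 1/8 to 1/9 of the original transformer-layer parameters. The embedding table only halves (its width is `hidden_size` and the vocabulary is unchanged), so the total reduction depends heavily on vocabulary size.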
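To see where the savings come from, here is a rough parameter-count sketch for the two configurations above. It assumes a Gemma-style decoder with a gated MLP (three projection matrices), no biases, tied input/output embeddings, and ignores normalization weights; the vocabulary size of 262,144 for Gemma 3 is also an assumption, not stated above. `estimate_params` is a hypothetical helper.

```python
from typing import Dict


def estimate_params(cfg: Dict[str, int], vocab_size: int) -> Dict[str, int]:
    """Rough parameter count for a Gemma-style decoder-only model.

    Assumptions: gated MLP (gate, up, down projections), no biases,
    tied input/output embeddings, and normalization weights ignored.
    """
    h = cfg["hidden_size"]
    q_dim = cfg["num_attention_heads"] * cfg["head_dim"]
    kv_dim = cfg["num_key_value_heads"] * cfg["head_dim"]
    # q_proj and o_proj map between h and q_dim; k_proj and v_proj map h -> kv_dim.
    attention = 2 * h * q_dim + 2 * h * kv_dim
    # gate_proj, up_proj and down_proj are each h x intermediate_size.
    mlp = 3 * h * cfg["intermediate_size"]
    layers = cfg["num_hidden_layers"] * (attention + mlp)
    # One vocab x hidden table, shared with the output head (tied embeddings).
    embedding = vocab_size * h
    return {"layers": layers, "embedding": embedding, "total": layers + embedding}


ORIGINAL = dict(hidden_size=640, num_hidden_layers=18, intermediate_size=2048,
                num_attention_heads=4, num_key_value_heads=1, head_dim=256)
HALF = dict(hidden_size=320, num_hidden_layers=9, intermediate_size=1024,
            num_attention_heads=2, num_key_value_heads=1, head_dim=128)
VOCAB = 262_144  # assumed Gemma 3 vocabulary size

orig = estimate_params(ORIGINAL, VOCAB)
half = estimate_params(HALF, VOCAB)
print(f"layers: {half['layers'] / orig['layers']:.3f}")  # layers: 0.110
print(f"total:  {half['total'] / orig['total']:.3f}")    # total:  0.354
```

Under these assumptions the transformer layers shrink to about 11% of the original, while the total stays near 35% because the large embedding table only halves. Combining architecture pruning with vocabulary pruning (for example `vocab_size=16000`) shrinks that embedding term much further.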
🔧 API Reference
TransformerCloner
TransformerCloner(
    org_model_id: str,         # HuggingFace model ID or local path
    target_tokenizer_id: str,  # HuggingFace tokenizer ID or local path
)
Methods:
| Method | Returns | Description |
|---|---|---|
| `build_token_id_map()` | `dict[int, list[int]]` | Map target tokens to source tokens |
| `clone(strategy)` | `AutoModelForCausalLM` | Clone with new tokenizer |
| `clone_with_lm_head(strategy)` | `AutoModelForCausalLM` | Clone, including the `lm_head` |
| `clone_pruned(pruning_config, strategy)` | `AutoModelForCausalLM` | Clone with architecture pruning |
| `clone_with_vocab_pruning(...)` | `(model, tokenizer, id_mapping)` | Clone with vocabulary reduction |
| `create_pruned_tokenizer(keep_token_ids)` | `(tokenizer, id_mapping)` | Create smaller tokenizer |
| `get_token_info(token)` | `dict` | Debug token mapping |
| `print_vocab_samples(n)` | `None` | Print vocabulary samples |
PruningConfig
from dataclasses import dataclass
from typing import Optional
@dataclass
class PruningConfig:
    hidden_size: Optional[int] = None
    num_hidden_layers: Optional[int] = None
    intermediate_size: Optional[int] = None
    num_attention_heads: Optional[int] = None
    num_key_value_heads: Optional[int] = None
    head_dim: Optional[int] = None
    def validate(self, original_config) -> list[str]:
        """Returns list of validation errors (empty if valid)"""
EmbeddingStrategy
from enum import Enum
class EmbeddingStrategy(Enum):
    MEAN = "mean"
    SUM = "sum"
    FIRST = "first"
    LAST = "last"
    WEIGHTED = "weighted"
    MAX = "max"
    MIN = "min"
🤝 Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
📄 License
MIT License - see LICENSE for details.
🙏 Acknowledgments
- Built on top of 🤗 Transformers
- Inspired by vocabulary adaptation research in multilingual NLP
Download files
File details
Details for the file transformer_cloner-0.1.3.tar.gz.
File metadata
- Download URL: transformer_cloner-0.1.3.tar.gz
- Size: 15.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 68d0c2e36032faccc7fadc63a18284554eec9b82a30d4d49ad15df4a3b260230 |
| MD5 | 0bbb47cf07a45d8d26e9fe626eb657f4 |
| BLAKE2b-256 | 52fce96813eec2b2eb78d9828a9a29ac5728d74f9976733366c3f92b57af962c |
File details
Details for the file transformer_cloner-0.1.3-py3-none-any.whl.
File metadata
- Download URL: transformer_cloner-0.1.3-py3-none-any.whl
- Size: 14.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 23ab09d9c511951600a662e108b84740f8cec944153fd53d7a066440c282df44 |
| MD5 | e8a1cecb4e4c5d4e971dd76989d0735b |
| BLAKE2b-256 | e075de7ae13a869cbcfb1ad0240df8314cf7ae2ca5551561fc0d2d0a22716d77 |