
Clone and prune transformer models with new tokenizers


🔄 Transformer Cloner

PyPI version Python 3.10+ License: MIT

Clone and prune transformer models with new tokenizers. Create smaller, more efficient models by mapping vocabularies, reducing dimensions, and pruning layers.

Perfect for:

  • 🌍 Language adaptation: Use a custom tokenizer optimized for your language
  • 📉 Model compression: Create smaller models for edge deployment
  • 🎓 Knowledge distillation: Generate student models from teacher models
  • 🔬 Research: Experiment with different model architectures

📦 Installation

pip install transformer-cloner

Requirements:

  • Python 3.10+
  • PyTorch 2.0+
  • Transformers 4.40+

🚀 Quick Start

Basic Cloning with New Tokenizer

from transformer_cloner import TransformerCloner, EmbeddingStrategy

# Initialize with original model and target tokenizer
cloner = TransformerCloner(
    org_model_id="google/gemma-3-270m-it",
    target_tokenizer_id="your-username/custom-tokenizer",
)

# Clone the model
model = cloner.clone(strategy=EmbeddingStrategy.MEAN)

# Save the cloned model
model.save_pretrained("./cloned-model")

📖 Detailed Usage

1. Vocabulary Mapping (New Tokenizer)

When your new tokenizer has a different vocabulary, the cloner maps each new token to one or more tokens of the original tokenizer and initializes the new token's embedding from their embeddings:

from transformer_cloner import TransformerCloner, EmbeddingStrategy

cloner = TransformerCloner(
    org_model_id="google/gemma-3-270m-it",
    target_tokenizer_id="alibayram/gemma3-tr-v64k",  # Turkish tokenizer
)

# View vocabulary samples
cloner.print_vocab_samples(n=10)

# Build token ID map (this maps new tokens → original tokens)
token_map = cloner.build_token_id_map()

# Check how a specific token is mapped
info = cloner.get_token_info("merhaba")
print(info)
# {'token': 'merhaba', 'target_id': 1234, 'source_ids': [567, 890], 'source_tokens': ['mer', 'haba']}

# Clone with your chosen embedding strategy
model = cloner.clone(strategy=EmbeddingStrategy.MEAN)
model.save_pretrained("./turkish-gemma")
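The mapping step can be illustrated with a toy example. This is a minimal sketch, not the library's implementation: it assumes each target token is mapped by re-tokenizing its text with the source vocabulary, and it uses a hypothetical greedy longest-match tokenizer (greedy_tokenize) and toy vocabularies in place of real Hugging Face tokenizers.

```python
# Hypothetical sketch: map target-vocabulary tokens to source-vocabulary IDs
# by re-tokenizing each target token's text with the source vocabulary.
# greedy_tokenize is a toy stand-in for the real source tokenizer.

def greedy_tokenize(text: str, vocab: dict[str, int]) -> list[int]:
    """Split text into the longest source-vocabulary pieces, left to right."""
    ids: list[int] = []
    i = 0
    while i < len(text):
        # Try the longest remaining substring first, then shrink it.
        for j in range(len(text), i, -1):
            piece = text[i:j]
            if piece in vocab:
                ids.append(vocab[piece])
                i = j
                break
        else:
            raise ValueError(f"cannot tokenize {text[i:]!r}")
    return ids


def toy_token_id_map(
    target_vocab: dict[str, int], source_vocab: dict[str, int]
) -> dict[int, list[int]]:
    """Return {target_id: [source_id, ...]} for every target token."""
    return {tid: greedy_tokenize(tok, source_vocab) for tok, tid in target_vocab.items()}


source_vocab = {"m": 0, "e": 1, "r": 2, "h": 3, "a": 4, "b": 5, "mer": 6, "haba": 7}
target_vocab = {"merhaba": 0, "mer": 1, "a": 2}

token_map = toy_token_id_map(target_vocab, source_vocab)
print(token_map)  # {0: [6, 7], 1: [6], 2: [4]}
```

Here "merhaba" splits into the source pieces "mer" and "haba", mirroring the get_token_info output shown above.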

2. Model Architecture Pruning

Create smaller models by reducing model dimensions and the number of layers:

from transformer_cloner import TransformerCloner, PruningConfig, EmbeddingStrategy

cloner = TransformerCloner(
    org_model_id="google/gemma-3-270m-it",
    target_tokenizer_id="alibayram/gemma3-tr-v64k",
)

# Define pruning configuration
pruning_config = PruningConfig(
    hidden_size=320,           # Original: 640
    num_hidden_layers=9,       # Original: 18
    intermediate_size=1024,    # Original: 2048
    num_attention_heads=2,     # Original: 4
    num_key_value_heads=1,     # Original: 1 (must divide num_attention_heads)
    head_dim=128,              # Original: 256
)

# Validate before cloning (optional)
errors = pruning_config.validate(cloner.org_model.config)
if errors:
    print("Validation errors:", errors)
else:
    # Clone with pruning
    model = cloner.clone_pruned(
        pruning_config=pruning_config,
        strategy=EmbeddingStrategy.MEAN,
    )
    model.save_pretrained("./gemma-small")
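The num_key_value_heads setting relies on grouped-query attention (GQA), where each key/value head is shared by a fixed-size group of query heads. The sketch below uses a hypothetical helper, kv_head_for_query_heads (not part of transformer_cloner), to show which KV head each query head uses. With the half-size config, the single KV head serves two query heads instead of four.

```python
# Grouped-query attention: query head h uses KV head h // group_size,
# where group_size = num_attention_heads // num_key_value_heads.

def kv_head_for_query_heads(num_attention_heads: int, num_key_value_heads: int) -> list[int]:
    """Return, for each query head, the index of the KV head it attends with."""
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    group_size = num_attention_heads // num_key_value_heads
    return [h // group_size for h in range(num_attention_heads)]


print(kv_head_for_query_heads(4, 1))  # original Gemma-3-270M: [0, 0, 0, 0]
print(kv_head_for_query_heads(2, 1))  # pruned config above: [0, 0]
print(kv_head_for_query_heads(8, 2))  # a 4:1 ratio: [0, 0, 0, 0, 1, 1, 1, 1]
```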

3. Vocabulary Pruning (Smaller Embedding Table)

Keep only specific tokens in the embedding table:

from transformer_cloner import TransformerCloner

# Use the same tokenizer for both
cloner = TransformerCloner(
    org_model_id="google/gemma-3-270m-it",
    target_tokenizer_id="google/gemma-3-270m-it",
)

# Option A: Keep first N tokens
model, tokenizer, id_mapping = cloner.clone_with_vocab_pruning(vocab_size=16000)

# Option B: Keep specific token IDs
important_tokens = [0, 1, 2, 100, 200, 500]  # Replace with your own token IDs
model, tokenizer, id_mapping = cloner.clone_with_vocab_pruning(keep_token_ids=important_tokens)

# Save model (tokenizer is unchanged)
model.save_pretrained("./vocab-pruned")

# Use id_mapping to convert token IDs: old_id -> new_embedding_index
print(id_mapping)  # {0: 0, 1: 1, 2: 2, 100: 3, ...}

Note: The original tokenizer is returned unchanged, because deleting entries from a SentencePiece/BPE vocabulary can break its tokenization. Use id_mapping to convert the tokenizer's token IDs into indices of the pruned embedding table.
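To see what id_mapping does, here is a minimal sketch with plain Python lists standing in for tensors. The helpers build_id_mapping, prune_embeddings, and remap_ids are hypothetical. The sketch assumes kept IDs receive new indices in ascending order (consistent with the printed id_mapping above) and sends IDs that were dropped from the table to a fallback index you choose; how the library itself treats dropped IDs is not documented here.

```python
# Hypothetical sketch of vocabulary pruning driven by an id_mapping
# (old token ID -> row index in the smaller embedding table).

def build_id_mapping(keep_token_ids: list[int]) -> dict[int, int]:
    """Assign new indices to kept IDs in ascending order (assumption)."""
    return {old: new for new, old in enumerate(sorted(set(keep_token_ids)))}


def prune_embeddings(embeddings: list[list[float]], id_mapping: dict[int, int]) -> list[list[float]]:
    """Copy the kept rows into a smaller table, ordered by their new index."""
    ordered = sorted(id_mapping.items(), key=lambda item: item[1])
    return [list(embeddings[old]) for old, _ in ordered]


def remap_ids(token_ids: list[int], id_mapping: dict[int, int], fallback: int) -> list[int]:
    """Convert tokenizer IDs to pruned-table indices; dropped IDs use `fallback`."""
    return [id_mapping.get(t, fallback) for t in token_ids]


embeddings = [[float(i), float(i)] for i in range(600)]  # toy 600 x 2 table
id_mapping = build_id_mapping([500, 0, 1, 2, 100, 200])
print(id_mapping)  # {0: 0, 1: 1, 2: 2, 100: 3, 200: 4, 500: 5}

pruned = prune_embeddings(embeddings, id_mapping)
print(len(pruned), pruned[3])  # 6 [100.0, 100.0]

# Token 999 is not in the pruned table, so it maps to the fallback index 0.
print(remap_ids([2, 100, 999], id_mapping, fallback=0))  # [2, 3, 0]
```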

4. Combined Pruning (Vocabulary + Architecture)

from transformer_cloner import TransformerCloner, PruningConfig

cloner = TransformerCloner(
    org_model_id="google/gemma-3-270m-it",
    target_tokenizer_id="google/gemma-3-270m-it",
)

# Combine vocabulary and architecture pruning
model, tokenizer, id_mapping = cloner.clone_with_vocab_pruning(
    vocab_size=16000,
    pruning_config=PruningConfig(
        hidden_size=320,
        num_hidden_layers=9,
    ),
)

🎯 Embedding Strategies

When a target token maps to multiple source tokens, choose how to combine their embeddings:

| Strategy | Description | Use case |
|---|---|---|
| MEAN | Average of all embeddings | Default, balanced representation |
| SUM | Sum of all embeddings | Preserve total magnitude |
| FIRST | First token's embedding | Prefix-focused tokens |
| LAST | Last token's embedding | Suffix-focused tokens |
| WEIGHTED | Weighted average (earlier tokens weighted more) | Morphological priority |
| MAX | Element-wise maximum | Preserve dominant features |
| MIN | Element-wise minimum | Preserve minimal features |

from transformer_cloner import EmbeddingStrategy

# Use different strategies
model = cloner.clone(strategy=EmbeddingStrategy.MEAN)
model = cloner.clone(strategy=EmbeddingStrategy.WEIGHTED)
model = cloner.clone(strategy=EmbeddingStrategy.FIRST)
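Each strategy in the table amounts to a simple reduction over the source-token embeddings. This sketch uses plain lists and lowercase strategy names (matching the EmbeddingStrategy values); combine is a hypothetical helper, not the library's code. The exact weighting behind WEIGHTED is not specified here, so the sketch assumes linearly decreasing weights (n, n-1, down to 1, normalized to sum to 1).

```python
# Hypothetical sketch: combine the embeddings of the source tokens that a
# single target token maps to, according to a strategy name.

def combine(vectors: list[list[float]], strategy: str) -> list[float]:
    if not vectors:
        raise ValueError("need at least one source embedding")
    n = len(vectors)
    columns = list(zip(*vectors))  # one tuple per embedding dimension
    if strategy == "mean":
        return [sum(c) / n for c in columns]
    if strategy == "sum":
        return [sum(c) for c in columns]
    if strategy == "first":
        return list(vectors[0])
    if strategy == "last":
        return list(vectors[-1])
    if strategy == "weighted":
        # Assumed scheme: earlier tokens get larger, linearly decreasing weights.
        weights = [n - i for i in range(n)]
        total = sum(weights)
        return [sum(w * x for w, x in zip(weights, c)) / total for c in columns]
    if strategy == "max":
        return [max(c) for c in columns]
    if strategy == "min":
        return [min(c) for c in columns]
    raise ValueError(f"unknown strategy: {strategy}")


pieces = [[1.0, 4.0], [3.0, 0.0]]  # toy embeddings of "mer" and "haba"
print(combine(pieces, "mean"))      # [2.0, 2.0]
print(combine(pieces, "max"))       # [3.0, 4.0]
print(combine(pieces, "weighted"))  # weights 2 and 1: roughly [1.667, 2.667]
```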

⚙️ Pruning Configuration

| Parameter | Description | Example |
|---|---|---|
| hidden_size | Hidden (embedding) dimension | 768 → 512 |
| num_hidden_layers | Number of transformer layers | 12 → 6 |
| intermediate_size | FFN intermediate dimension | 3072 → 1536 |
| num_attention_heads | Number of attention heads | 12 → 8 |
| num_key_value_heads | Number of KV heads (for GQA) | 4 → 2 |
| head_dim | Dimension per attention head | 64 → 32 |

Validation

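The library validates your pruning config. You can also run the check yourself with validate(), which returns a list of error messages (empty if the config is valid):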

pruning_config = PruningConfig(hidden_size=1000, num_attention_heads=7)
errors = pruning_config.validate(cloner.org_model.config)
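# errors includes 'hidden_size (1000) must be divisible by num_attention_heads (7)'.
# For Gemma-3-270M, both values also exceed the original (640 and 4), so those checks fail too.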

Validation checks:

  • ✅ Dimensions don't exceed the original model's
  • ✅ All values are positive
  • ✅ num_attention_heads is divisible by num_key_value_heads
  • ✅ hidden_size is divisible by num_attention_heads
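A validator covering these checks could look like the sketch below. SketchPruningConfig and validate are hypothetical stand-ins, not the library's code; apart from the divisibility message quoted above, the error wording is invented for illustration. Unset (None) fields fall back to the original model's value.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchPruningConfig:
    """Hypothetical stand-in for PruningConfig; None means 'keep original'."""
    hidden_size: Optional[int] = None
    num_hidden_layers: Optional[int] = None
    intermediate_size: Optional[int] = None
    num_attention_heads: Optional[int] = None
    num_key_value_heads: Optional[int] = None
    head_dim: Optional[int] = None


def validate(cfg: SketchPruningConfig, original: dict[str, int]) -> list[str]:
    errors: list[str] = []
    eff: dict[str, int] = {}  # effective values after applying the config
    for name, orig_value in original.items():
        value = getattr(cfg, name)
        eff[name] = orig_value if value is None else value
        if eff[name] <= 0:
            errors.append(f"{name} ({eff[name]}) must be positive")
        elif eff[name] > orig_value:
            errors.append(f"{name} ({eff[name]}) exceeds the original ({orig_value})")
    heads, kv = eff["num_attention_heads"], eff["num_key_value_heads"]
    if kv > 0 and heads % kv != 0:
        errors.append(f"num_attention_heads ({heads}) must be divisible by num_key_value_heads ({kv})")
    if heads > 0 and eff["hidden_size"] % heads != 0:
        errors.append(f"hidden_size ({eff['hidden_size']}) must be divisible by num_attention_heads ({heads})")
    return errors


gemma_270m = {"hidden_size": 640, "num_hidden_layers": 18, "intermediate_size": 2048,
              "num_attention_heads": 4, "num_key_value_heads": 1, "head_dim": 256}
half = SketchPruningConfig(hidden_size=320, num_hidden_layers=9, intermediate_size=1024,
                           num_attention_heads=2, num_key_value_heads=1, head_dim=128)
bad = SketchPruningConfig(hidden_size=1000, num_attention_heads=7)

print(validate(half, gemma_270m))  # []
for message in validate(bad, gemma_270m):
    print(message)
```

Run against the Gemma-3-270M values, the half-size config passes, while the hidden_size=1000, num_attention_heads=7 example fails three checks, because both values also exceed the original model.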

📊 Example: Creating a Half-Size Model

Original Gemma-3-270M architecture:

hidden_size: 640
num_hidden_layers: 18
intermediate_size: 2048
num_attention_heads: 4
num_key_value_heads: 1
head_dim: 256

Half-size configuration:

pruning_config = PruningConfig(
    hidden_size=320,          # 640 / 2
    num_hidden_layers=9,      # 18 / 2
    intermediate_size=1024,   # 2048 / 2
    num_attention_heads=2,    # 4 / 2
    num_key_value_heads=1,    # Keep 1
    head_dim=128,             # 256 / 2
)

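Result: the transformer layers shrink to about 1/9 of their original parameters (each layer keeps roughly 22% of its weights, and only half the layers remain). The embedding table (vocab_size × hidden_size) only halves, so with the original vocabulary the whole model ends up at roughly 35% of its size; the total approaches 1/8 only if the vocabulary is also reduced.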
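The arithmetic behind this estimate can be checked with a quick count. This sketch assumes a Gemma-style decoder: bias-free q/k/v/o projections, a gated MLP with three weight matrices (gate, up, down), an lm_head tied to the input embeddings, and normalization weights ignored because they are tiny. It also assumes Gemma 3's vocabulary size of 262,144 tokens; decoder_params is a hypothetical helper.

```python
# Approximate weight counts for a Gemma-style decoder (see assumptions above).

def decoder_params(vocab_size: int, hidden_size: int, num_hidden_layers: int,
                   intermediate_size: int, num_attention_heads: int,
                   num_key_value_heads: int, head_dim: int) -> tuple[int, int]:
    """Return (embedding_params, transformer_layer_params)."""
    q_and_o = 2 * hidden_size * num_attention_heads * head_dim  # q_proj + o_proj
    k_and_v = 2 * hidden_size * num_key_value_heads * head_dim  # k_proj + v_proj
    mlp = 3 * hidden_size * intermediate_size                   # gate, up, down
    layers = num_hidden_layers * (q_and_o + k_and_v + mlp)
    embeddings = vocab_size * hidden_size                       # tied with lm_head
    return embeddings, layers


orig = decoder_params(262_144, 640, 18, 2048, 4, 1, 256)
half = decoder_params(262_144, 320, 9, 1024, 2, 1, 128)
print(orig)  # (167772160, 100270080)
print(half)  # (83886080, 11059200)
print(f"layers: {half[1] / orig[1]:.3f}, embeddings: {half[0] / orig[0]:.3f}, "
      f"total: {sum(half) / sum(orig):.3f}")
# layers: 0.110, embeddings: 0.500, total: 0.354
```

With the full vocabulary the embedding table dominates the pruned model, which is why combining architecture pruning with a smaller vocabulary matters for the overall size.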


🔧 API Reference

TransformerCloner

TransformerCloner(
    org_model_id: str,         # HuggingFace model ID or local path
    target_tokenizer_id: str,  # HuggingFace tokenizer ID or local path
)

Methods:

| Method | Returns | Description |
|---|---|---|
| build_token_id_map() | dict[int, list[int]] | Map target tokens to source tokens |
| clone(strategy) | AutoModelForCausalLM | Clone with new tokenizer |
| clone_with_lm_head(strategy) | AutoModelForCausalLM | Clone including lm_head |
| clone_pruned(pruning_config, strategy) | AutoModelForCausalLM | Clone with architecture pruning |
| clone_with_vocab_pruning(...) | (model, tokenizer, id_mapping) | Clone with vocabulary reduction |
| create_pruned_tokenizer(keep_token_ids) | (tokenizer, id_mapping) | Create smaller tokenizer |
| get_token_info(token) | dict | Debug token mapping |
| print_vocab_samples(n) | None | Print vocabulary samples |

PruningConfig

from dataclasses import dataclass
from typing import Optional

@dataclass
class PruningConfig:
    hidden_size: Optional[int] = None
    num_hidden_layers: Optional[int] = None
    intermediate_size: Optional[int] = None
    num_attention_heads: Optional[int] = None
    num_key_value_heads: Optional[int] = None
    head_dim: Optional[int] = None

    def validate(self, original_config) -> list[str]:
        """Returns list of validation errors (empty if valid)"""

EmbeddingStrategy

from enum import Enum

class EmbeddingStrategy(Enum):
    MEAN = "mean"
    SUM = "sum"
    FIRST = "first"
    LAST = "last"
    WEIGHTED = "weighted"
    MAX = "max"
    MIN = "min"

🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.


📄 License

MIT License - see LICENSE for details.


🙏 Acknowledgments

  • Built on top of 🤗 Transformers
  • Inspired by vocabulary adaptation research in multilingual NLP



Download files


Source Distribution

transformer_cloner-0.1.3.tar.gz (15.2 kB)

Uploaded Source

Built Distribution


transformer_cloner-0.1.3-py3-none-any.whl (14.0 kB)

Uploaded Python 3

File details

Details for the file transformer_cloner-0.1.3.tar.gz.

File metadata

  • Download URL: transformer_cloner-0.1.3.tar.gz
  • Size: 15.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.3

File hashes

Hashes for transformer_cloner-0.1.3.tar.gz
Algorithm Hash digest
SHA256 68d0c2e36032faccc7fadc63a18284554eec9b82a30d4d49ad15df4a3b260230
MD5 0bbb47cf07a45d8d26e9fe626eb657f4
BLAKE2b-256 52fce96813eec2b2eb78d9828a9a29ac5728d74f9976733366c3f92b57af962c


File details

Details for the file transformer_cloner-0.1.3-py3-none-any.whl.

File hashes

Hashes for transformer_cloner-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 23ab09d9c511951600a662e108b84740f8cec944153fd53d7a066440c282df44
MD5 e8a1cecb4e4c5d4e971dd76989d0735b
BLAKE2b-256 e075de7ae13a869cbcfb1ad0240df8314cf7ae2ca5551561fc0d2d0a22716d77

