🔄 Transformer Cloner

Clone and prune transformer models with new tokenizers. Create smaller, more efficient models by mapping vocabularies, reducing dimensions, and pruning layers.

Perfect for:

  • 🌍 Language adaptation: Use a custom tokenizer optimized for your language
  • 📉 Model compression: Create smaller models for edge deployment
  • 🎓 Knowledge distillation: Generate student models from teacher models
  • 🔬 Research: Experiment with different model architectures

📦 Installation

pip install transformer-cloner

Requirements:

  • Python 3.10+
  • PyTorch 2.0+
  • Transformers 4.40+

🚀 Quick Start

Basic Cloning with New Tokenizer

from transformer_cloner import TransformerCloner, EmbeddingStrategy

# Initialize with original model and target tokenizer
cloner = TransformerCloner(
    org_model_id="google/gemma-3-270m-it",
    target_tokenizer_id="your-username/custom-tokenizer",
)

# Clone the model
model = cloner.clone(strategy=EmbeddingStrategy.MEAN)

# Save the cloned model
model.save_pretrained("./cloned-model")

📖 Detailed Usage

1. Vocabulary Mapping (New Tokenizer)

When the new tokenizer's vocabulary differs from the original, the cloner maps each new token to one or more tokens of the original vocabulary and initializes the new token's embedding from theirs:

from transformer_cloner import TransformerCloner, EmbeddingStrategy

cloner = TransformerCloner(
    org_model_id="google/gemma-3-270m-it",
    target_tokenizer_id="alibayram/gemma3-tr-v64k",  # Turkish tokenizer
)

# View vocabulary samples
cloner.print_vocab_samples(n=10)

# Build token ID map (this maps new tokens → original tokens)
token_map = cloner.build_token_id_map()

# Check how a specific token is mapped
info = cloner.get_token_info("merhaba")
print(info)
# {'token': 'merhaba', 'target_id': 1234, 'source_ids': [567, 890], 'source_tokens': ['mer', 'haba']}

# Clone with your chosen embedding strategy
model = cloner.clone(strategy=EmbeddingStrategy.MEAN)
model.save_pretrained("./turkish-gemma")
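The mapping step can be pictured as re-tokenizing each target token's string with the original tokenizer, which is consistent with the source_tokens that get_token_info reports above. The sketch below is a minimal, hypothetical illustration of that idea, not transformer-cloner's implementation: a toy greedy longest-match tokenizer over a made-up vocabulary (SOURCE_VOCAB, greedy_tokenize) stands in for the real original tokenizer.

```python
from typing import Dict, List

# Hypothetical toy vocabulary standing in for the original model's tokenizer.
SOURCE_VOCAB: Dict[str, int] = {
    "m": 0, "e": 1, "r": 2, "h": 3, "a": 4, "b": 5, "mer": 6, "haba": 7, "ba": 8,
}


def greedy_tokenize(text: str, vocab: Dict[str, int]) -> List[int]:
    """Longest-match-first tokenization (a stand-in for the real source tokenizer)."""
    ids: List[int] = []
    i = 0
    while i < len(text):
        # Try the longest substring starting at position i that is in the vocabulary.
        for j in range(len(text), i, -1):
            piece = text[i:j]
            if piece in vocab:
                ids.append(vocab[piece])
                i = j
                break
        else:
            raise ValueError(f"cannot tokenize character {text[i]!r}")
    return ids


def build_token_id_map(target_vocab: Dict[str, int],
                       source_vocab: Dict[str, int]) -> Dict[int, List[int]]:
    """Map each target token ID to the source token IDs that spell the same string."""
    return {tid: greedy_tokenize(tok, source_vocab) for tok, tid in target_vocab.items()}


target_vocab = {"merhaba": 0, "ab": 1}
print(build_token_id_map(target_vocab, SOURCE_VOCAB))  # {0: [6, 7], 1: [4, 5]}
```

Real tokenizers add details this sketch skips, such as SentencePiece's ▁ word-boundary marker and byte fallback for characters missing from the vocabulary.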

2. Model Architecture Pruning

Create smaller models by reducing dimensions:

from transformer_cloner import TransformerCloner, PruningConfig, EmbeddingStrategy

cloner = TransformerCloner(
    org_model_id="google/gemma-3-270m-it",
    target_tokenizer_id="alibayram/gemma3-tr-v64k",
)

# Define pruning configuration
pruning_config = PruningConfig(
    hidden_size=320,           # Original: 640
    num_hidden_layers=9,       # Original: 18
    intermediate_size=1024,    # Original: 2048
    num_attention_heads=2,     # Original: 4
    num_key_value_heads=1,     # Original: 1
    head_dim=128,              # Original: 256
)

# Validate before cloning (optional)
errors = pruning_config.validate(cloner.org_model.config)
if errors:
    print("Validation errors:", errors)
else:
    # Clone with pruning
    model = cloner.clone_pruned(
        pruning_config=pruning_config,
        strategy=EmbeddingStrategy.MEAN,
    )
    model.save_pretrained("./gemma-small")
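The README does not say how the pruned weights are initialized. One simple approach is to keep the leading slice of each weight matrix and a subset of the original layers; the sketch below illustrates that idea with plain Python lists. The helper names (slice_matrix, select_layers) and the evenly spaced layer choice are hypothetical, not transformer-cloner's API or its confirmed method.

```python
from typing import List

Matrix = List[List[float]]


def slice_matrix(weight: Matrix, out_dim: int, in_dim: int) -> Matrix:
    """Keep the first out_dim rows and first in_dim columns of a weight matrix."""
    if not weight or out_dim > len(weight) or in_dim > len(weight[0]):
        raise ValueError("pruned dimensions must not exceed the original ones")
    return [row[:in_dim] for row in weight[:out_dim]]


def select_layers(num_original: int, num_kept: int) -> List[int]:
    """Pick num_kept roughly evenly spaced layer indices out of num_original layers."""
    if not 0 < num_kept <= num_original:
        raise ValueError("num_kept must be between 1 and num_original")
    step = num_original / num_kept
    return [int(i * step) for i in range(num_kept)]


# A 4x4 toy weight matrix whose entries encode their (row, column) position.
w = [[float(r * 4 + c) for c in range(4)] for r in range(4)]
print(slice_matrix(w, 2, 2))  # [[0.0, 1.0], [4.0, 5.0]]
print(select_layers(18, 9))   # [0, 2, 4, 6, 8, 10, 12, 14, 16]
```

Whatever the initialization, a pruned model usually needs further training (for example, distillation from the original) to recover quality.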

3. Vocabulary Pruning (Smaller Embedding Table)

Keep only specific tokens in the embedding table:

from transformer_cloner import TransformerCloner

# Use the same tokenizer for both
cloner = TransformerCloner(
    org_model_id="google/gemma-3-270m-it",
    target_tokenizer_id="google/gemma-3-270m-it",
)

# Option A: Keep first N tokens
model, tokenizer, id_mapping = cloner.clone_with_vocab_pruning(vocab_size=16000)

# Option B: Keep specific token IDs
important_tokens = [0, 1, 2, 100, 200, 500]  # Replace with your own list of token IDs
model, tokenizer, id_mapping = cloner.clone_with_vocab_pruning(keep_token_ids=important_tokens)

# Save model (tokenizer is unchanged)
model.save_pretrained("./vocab-pruned")

# Use id_mapping to convert token IDs: old_id -> new_embedding_index
print(id_mapping)  # {0: 0, 1: 1, 2: 2, 100: 3, ...}

Note: The original tokenizer is returned unchanged, because removing entries from a SentencePiece/BPE vocabulary breaks its tokenization. Use id_mapping to convert the tokenizer's token IDs into embedding indices of the pruned model.
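The id_mapping printed above assigns kept IDs consecutive indices in ascending order. Assuming that is how it is built, converting tokenizer output into embedding indices (and generated indices back into tokenizer IDs) is a dictionary lookup in each direction. The helpers below and the fallback index for pruned-away IDs are hypothetical choices for illustration, not part of transformer-cloner.

```python
from typing import Dict, Iterable, List


def build_id_mapping(keep_token_ids: Iterable[int]) -> Dict[int, int]:
    """Assumed scheme: kept IDs, sorted ascending, get consecutive new indices."""
    return {old: new for new, old in enumerate(sorted(set(keep_token_ids)))}


def to_embedding_indices(token_ids: List[int], id_mapping: Dict[int, int],
                         fallback: int) -> List[int]:
    """Convert tokenizer IDs to embedding indices; pruned IDs map to `fallback`."""
    return [id_mapping.get(t, fallback) for t in token_ids]


def to_tokenizer_ids(indices: List[int], id_mapping: Dict[int, int]) -> List[int]:
    """Invert the mapping so generated indices can be decoded by the original tokenizer."""
    inverse = {new: old for old, new in id_mapping.items()}
    return [inverse[i] for i in indices]


id_mapping = build_id_mapping([100, 0, 2, 1])
print(id_mapping)                                                   # {0: 0, 1: 1, 2: 2, 100: 3}
print(to_embedding_indices([2, 100, 999], id_mapping, fallback=0))  # [2, 3, 0]
print(to_tokenizer_ids([3, 1], id_mapping))                         # [100, 1]
```

In practice the fallback should point at a token that is still in the mapping, such as an unknown or padding token.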

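4. Combined Pruning (Vocabulary + Architecture)

Pass a PruningConfig to clone_with_vocab_pruning to reduce the vocabulary and the architecture in a single call: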

from transformer_cloner import TransformerCloner, PruningConfig

cloner = TransformerCloner(
    org_model_id="google/gemma-3-270m-it",
    target_tokenizer_id="google/gemma-3-270m-it",
)

# Combine vocabulary and architecture pruning
model, tokenizer, id_mapping = cloner.clone_with_vocab_pruning(
    vocab_size=16000,
    pruning_config=PruningConfig(
        hidden_size=320,
        num_hidden_layers=9,
    ),
)

🎯 Embedding Strategies

When a target token maps to multiple source tokens, choose how to combine their embeddings:

| Strategy | Description | Use case |
|----------|-------------|----------|
| MEAN | Average of all embeddings | Default, balanced representation |
| SUM | Sum of all embeddings | Preserve total magnitude |
| FIRST | First token's embedding | Prefix-focused tokens |
| LAST | Last token's embedding | Suffix-focused tokens |
| WEIGHTED | Weighted average (earlier tokens weighted more) | Morphological priority |
| MAX | Element-wise maximum | Preserve dominant features |
| MIN | Element-wise minimum | Preserve minimal features |

from transformer_cloner import EmbeddingStrategy

# Use different strategies
model = cloner.clone(strategy=EmbeddingStrategy.MEAN)
model = cloner.clone(strategy=EmbeddingStrategy.WEIGHTED)
model = cloner.clone(strategy=EmbeddingStrategy.FIRST)
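To make the table concrete, here is a minimal pure-Python sketch of the seven combination rules applied to a list of source-token embeddings. It defines a local Strategy enum instead of importing the library, and the WEIGHTED weights (n, n-1, ..., 1, normalized) are an assumption: the README only says earlier tokens are weighted more.

```python
from enum import Enum
from typing import List, Sequence

Vector = List[float]


class Strategy(Enum):
    """Local stand-in mirroring the strategy names in the table."""
    MEAN = "mean"
    SUM = "sum"
    FIRST = "first"
    LAST = "last"
    WEIGHTED = "weighted"
    MAX = "max"
    MIN = "min"


def combine(embeddings: Sequence[Vector], strategy: Strategy) -> Vector:
    """Combine the embeddings of several source tokens into one target embedding."""
    if not embeddings:
        raise ValueError("need at least one source embedding")
    n = len(embeddings)
    columns = list(zip(*embeddings))  # one tuple per embedding dimension
    if strategy is Strategy.MEAN:
        return [sum(col) / n for col in columns]
    if strategy is Strategy.SUM:
        return [sum(col) for col in columns]
    if strategy is Strategy.FIRST:
        return list(embeddings[0])
    if strategy is Strategy.LAST:
        return list(embeddings[-1])
    if strategy is Strategy.WEIGHTED:
        # Assumed weights n, n-1, ..., 1 so earlier tokens count more; normalized to sum to 1.
        weights = [n - i for i in range(n)]
        total = sum(weights)
        return [sum(wt * x for wt, x in zip(weights, col)) / total for col in columns]
    if strategy is Strategy.MAX:
        return [max(col) for col in columns]
    if strategy is Strategy.MIN:
        return [min(col) for col in columns]
    raise ValueError(f"unknown strategy: {strategy}")


src = [[1.0, 4.0], [3.0, 0.0]]  # embeddings of two source tokens
for s in Strategy:
    print(s.value, combine(src, s))
```

Note that SUM grows with the number of source tokens, while MEAN and WEIGHTED keep the result on the same scale as a single embedding.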

⚙️ Pruning Configuration

| Parameter | Description | Example |
|-----------|-------------|---------|
| hidden_size | Hidden (embedding) dimension | 768 → 512 |
| num_hidden_layers | Number of transformer layers | 12 → 6 |
| intermediate_size | FFN intermediate dimension | 3072 → 1536 |
| num_attention_heads | Number of attention heads | 12 → 8 |
| num_key_value_heads | Number of key/value heads (for grouped-query attention, GQA) | 4 → 2 |
| head_dim | Dimension per attention head | 64 → 32 |

Validation

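The library validates pruning configs; you can also call PruningConfig.validate() yourself, which returns a list of error messages (empty if the config is valid):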

pruning_config = PruningConfig(hidden_size=1000, num_attention_heads=7)
errors = pruning_config.validate(cloner.org_model.config)
# Includes: 'hidden_size (1000) must be divisible by num_attention_heads (7)'

Validation checks:

  • ✅ No dimension exceeds the original model's value
  • ✅ All values are positive
  • ✅ num_attention_heads is divisible by num_key_value_heads
  • ✅ hidden_size is divisible by num_attention_heads
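The four checks can be sketched in a few lines. This is a hypothetical re-implementation for illustration, not the library's code: the class name PruningConfigSketch, the wording of the "exceeds" and "must be positive" messages, and the rule that unset (None) fields fall back to the original values are all assumptions; only the divisibility message mirrors the example above. The original config is faked with SimpleNamespace using the Gemma-3-270M values from the next section.

```python
from dataclasses import dataclass, fields
from types import SimpleNamespace
from typing import Dict, List, Optional


@dataclass
class PruningConfigSketch:
    hidden_size: Optional[int] = None
    num_hidden_layers: Optional[int] = None
    intermediate_size: Optional[int] = None
    num_attention_heads: Optional[int] = None
    num_key_value_heads: Optional[int] = None
    head_dim: Optional[int] = None

    def resolved(self, original) -> Dict[str, int]:
        """Assumption: fields left as None keep the original model's value."""
        out: Dict[str, int] = {}
        for f in fields(self):
            value = getattr(self, f.name)
            out[f.name] = getattr(original, f.name) if value is None else value
        return out

    def validate(self, original) -> List[str]:
        errors: List[str] = []
        # Checks 1 and 2: every explicitly set value is positive and not above the original.
        for f in fields(self):
            value = getattr(self, f.name)
            if value is None:
                continue
            limit = getattr(original, f.name)
            if value <= 0:
                errors.append(f"{f.name} ({value}) must be positive")
            elif value > limit:
                errors.append(f"{f.name} ({value}) exceeds the original ({limit})")
        cfg = self.resolved(original)
        hidden = cfg["hidden_size"]
        heads, kv = cfg["num_attention_heads"], cfg["num_key_value_heads"]
        # Check 3: query heads must split evenly into key/value groups (GQA).
        if heads > 0 and kv > 0 and heads % kv != 0:
            errors.append(f"num_attention_heads ({heads}) must be divisible by num_key_value_heads ({kv})")
        # Check 4: hidden_size must split evenly across the attention heads.
        if hidden > 0 and heads > 0 and hidden % heads != 0:
            errors.append(f"hidden_size ({hidden}) must be divisible by num_attention_heads ({heads})")
        return errors


gemma = SimpleNamespace(hidden_size=640, num_hidden_layers=18, intermediate_size=2048,
                        num_attention_heads=4, num_key_value_heads=1, head_dim=256)
half = PruningConfigSketch(hidden_size=320, num_hidden_layers=9, intermediate_size=1024,
                           num_attention_heads=2, num_key_value_heads=1, head_dim=128)
print(half.validate(gemma))  # []
print(PruningConfigSketch(hidden_size=1000, num_attention_heads=7).validate(gemma))
```

Against Gemma-3-270M, the hidden_size=1000, num_attention_heads=7 example also exceeds the original values (640 and 4), so this sketch reports three errors for it.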

📊 Example: Creating a Half-Size Model

Original Gemma-3-270M architecture:

hidden_size: 640
num_hidden_layers: 18
intermediate_size: 2048
num_attention_heads: 4
num_key_value_heads: 1
head_dim: 256

Half-size configuration:

pruning_config = PruningConfig(
    hidden_size=320,          # 640 / 2
    num_hidden_layers=9,      # 18 / 2
    intermediate_size=1024,   # 2048 / 2
    num_attention_heads=2,    # 4 / 2
    num_key_value_heads=1,    # Keep 1
    head_dim=128,             # 256 / 2
)

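Result: the decoder layers shrink to roughly 1/9 of their original parameter count (about 100M → 11M). Because the embedding table scales only with hidden_size, keeping the original 262k-token vocabulary leaves the whole model at roughly 1/3 of its original size; pruning the vocabulary as well reduces it further.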
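The parameter counts for both configurations can be estimated with a short script. It assumes a Gemma-style decoder layer (q/k/v/o projections sized by the head counts and head_dim, plus a gated MLP with gate, up and down projections), input and output embeddings tied into one table, and a vocabulary of 262,144 tokens. Norm weights and biases are ignored, so the totals are approximate.

```python
from dataclasses import dataclass


@dataclass
class ArchConfig:
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int


def layer_params(c: ArchConfig) -> int:
    """Approximate weights in one decoder layer (norms and biases ignored)."""
    q = c.hidden_size * c.num_attention_heads * c.head_dim
    k_and_v = 2 * c.hidden_size * c.num_key_value_heads * c.head_dim
    o = c.num_attention_heads * c.head_dim * c.hidden_size
    mlp = 3 * c.hidden_size * c.intermediate_size  # gate, up and down projections
    return q + k_and_v + o + mlp


def total_params(c: ArchConfig, vocab_size: int) -> int:
    """All decoder layers plus one embedding table shared with the LM head (assumed tied)."""
    return c.num_hidden_layers * layer_params(c) + vocab_size * c.hidden_size


VOCAB_SIZE = 262_144  # assumed Gemma 3 vocabulary size
original = ArchConfig(640, 18, 2048, 4, 1, 256)
half = ArchConfig(320, 9, 1024, 2, 1, 128)

layers_orig = original.num_hidden_layers * layer_params(original)
layers_half = half.num_hidden_layers * layer_params(half)
total_orig = total_params(original, VOCAB_SIZE)
total_half = total_params(half, VOCAB_SIZE)
print(f"layers: {layers_orig:,} -> {layers_half:,} ({layers_half / layers_orig:.3f}x)")
print(f"total:  {total_orig:,} -> {total_half:,} ({total_half / total_orig:.3f}x)")
```

The original total of about 268M lines up with the model's 270M name, which is a quick sanity check on the formula.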


🔧 API Reference

TransformerCloner

TransformerCloner(
    org_model_id: str,         # Hugging Face model ID or local path
    target_tokenizer_id: str,  # Hugging Face tokenizer ID or local path
)

Methods:

| Method | Returns | Description |
|--------|---------|-------------|
| build_token_id_map() | dict[int, list[int]] | Map target token IDs to source token IDs |
| clone(strategy) | AutoModelForCausalLM | Clone with the new tokenizer |
| clone_with_lm_head(strategy) | AutoModelForCausalLM | Clone, including lm_head |
| clone_pruned(pruning_config, strategy) | AutoModelForCausalLM | Clone with architecture pruning |
| clone_with_vocab_pruning(...) | (model, tokenizer, id_mapping) | Clone with vocabulary reduction |
| create_pruned_tokenizer(keep_token_ids) | (tokenizer, id_mapping) | Create smaller tokenizer |
| get_token_info(token) | dict | Inspect how a token is mapped |
| print_vocab_samples(n) | None | Print vocabulary samples |

PruningConfig

from dataclasses import dataclass
from typing import Optional

@dataclass
class PruningConfig:
    hidden_size: Optional[int] = None
    num_hidden_layers: Optional[int] = None
    intermediate_size: Optional[int] = None
    num_attention_heads: Optional[int] = None
    num_key_value_heads: Optional[int] = None
    head_dim: Optional[int] = None

    def validate(self, original_config) -> list[str]:
        """Return a list of validation errors (empty if valid)."""

EmbeddingStrategy

from enum import Enum

class EmbeddingStrategy(Enum):
    MEAN = "mean"
    SUM = "sum"
    FIRST = "first"
    LAST = "last"
    WEIGHTED = "weighted"
    MAX = "max"
    MIN = "min"

🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.


📄 License

MIT License - see LICENSE for details.


🙏 Acknowledgments

  • Built on top of 🤗 Transformers
  • Inspired by vocabulary adaptation research in multilingual NLP

Download files

Source Distribution

transformer_cloner-0.1.2.tar.gz (12.8 kB)

Built Distribution

transformer_cloner-0.1.2-py3-none-any.whl (14.0 kB, Python 3)

File details

Details for the file transformer_cloner-0.1.2.tar.gz.

File metadata

  • Download URL: transformer_cloner-0.1.2.tar.gz
  • Size: 12.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.3

File hashes

Hashes for transformer_cloner-0.1.2.tar.gz
| Algorithm | Hash digest |
|-----------|-------------|
| SHA256 | 9c385dc7858591abed502ffa5167507391007ebd30501e577e4aecd7d8c16a31 |
| MD5 | 5d352b1e7e66c7ef54735698e1fefd1b |
| BLAKE2b-256 | af896c77459e99c973498b6bf0dbb5ed314bdc122efdcabbfd17795ea19c65b0 |

File details

Details for the file transformer_cloner-0.1.2-py3-none-any.whl.

File hashes

Hashes for transformer_cloner-0.1.2-py3-none-any.whl
| Algorithm | Hash digest |
|-----------|-------------|
| SHA256 | 6afcf5fbcc0dd3ac00b2aa0e4b21a2295c4b11dbf19c291a7a344610666e15a9 |
| MD5 | b00b3a37d7fa9cefb48f7ead766a5839 |
| BLAKE2b-256 | 0895d61f4ff25363d024ce0732ef0dd1fd46984dc6c6038ed8f3ef8a0e439c56 |
