
example community package



Swarmauri Embedding MLM

Trainable embedding provider that fine-tunes a Hugging Face masked language model (MLM) end-to-end so Swarmauri agents can produce contextual document vectors without leaving the framework.

Features

  • Wraps any Hugging Face masked language model (embedding_name) behind the Swarmauri EmbeddingBase interface.
  • Supports optional vocabulary expansion via add_new_tokens before fine-tuning to capture domain-specific terminology.
  • Handles end-to-end fine-tuning with masking, AdamW optimization, and GPU/CPU selection based on availability.
  • Exposes pooling utilities (transform, infer_vector) that mean-pool the last hidden state across token positions to yield dense vectors ready for downstream retrieval or clustering.
  • Provides save_model/load_model helpers so trained weights and tokenizers can be persisted and reloaded across workers.
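The masking step in fine-tuning follows the usual MLM recipe: some tokens are hidden, and the model is trained to reconstruct them. The dependency-free sketch below illustrates the idea and is not the package's code. It assumes a 15% masking probability (the BERT default) and always substitutes the mask token, whereas BERT-style trainers typically split masked positions 80/10/10 across mask, random token, and unchanged. `mask_tokens`, `mask_token_id`, and `special_ids` are hypothetical names.

```python
import random
from typing import List, Optional, Sequence, Set, Tuple

IGNORE_INDEX = -100  # label value PyTorch's cross-entropy ignores by default


def mask_tokens(
    input_ids: Sequence[int],
    mask_token_id: int,
    special_ids: Set[int],
    mask_prob: float = 0.15,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: randomly mask non-special tokens for MLM training.

    Returns (masked_ids, labels). Labels hold the original id at masked
    positions and IGNORE_INDEX elsewhere, so the loss only covers tokens
    the model has to reconstruct.
    """
    rng = rng or random.Random()
    masked_ids = list(input_ids)
    labels = [IGNORE_INDEX] * len(masked_ids)
    for i, token_id in enumerate(input_ids):
        # Never mask [CLS], [SEP], padding, or other special tokens.
        if token_id in special_ids:
            continue
        if rng.random() < mask_prob:
            labels[i] = token_id
            masked_ids[i] = mask_token_id  # simplified: always use [MASK]
    return masked_ids, labels


# Illustrative bert-base-uncased ids: 101 = [CLS], 102 = [SEP], 103 = [MASK].
ids = [101, 2023, 2003, 1037, 3231, 102]
masked, labels = mask_tokens(ids, mask_token_id=103, special_ids={101, 102}, rng=random.Random(0))
print(masked)
print(labels)
```

Seeding the `random.Random` instance makes the masking reproducible, which helps when comparing training runs.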

Prerequisites

  • Python 3.10 or newer.
  • PyTorch with CUDA support if you plan to train on GPU (the class falls back to CPU automatically).
  • Access to the Hugging Face model hub for downloading embedding_name. Set HF_HOME to relocate the model cache, and configure proxies or an access token (for example, via HF_TOKEN) if your environment requires authentication.
  • Enough disk space to cache the chosen MLM (e.g., bert-base-uncased ~420 MB).

Installation

# pip
pip install swarmauri_embedding_mlm

# poetry
poetry add swarmauri_embedding_mlm

# uv (pyproject-based projects)
uv add swarmauri_embedding_mlm

Quickstart: Fine-tune and Embed Documents

from swarmauri_embedding_mlm import MlmEmbedding

docs = [
    "Swarmauri SDK ships modular agents.",
    "Masked language models produce contextual embeddings.",
]

embedder = MlmEmbedding(
    embedding_name="distilbert-base-uncased",
    batch_size=16,
    learning_rate=3e-5,
)

# One epoch of MLM fine-tuning on your corpus
embedder.fit(docs)

# Generate vectors for downstream tasks
vectors = embedder.transform([
    "Agents coordinate through shared memory",
    "Fine-tuning improves domain recall",
])

for v in vectors:
    print(len(v.value), v.value[:4])  # dimension and preview

# Single-text inference helper
query_vector = embedder.infer_vector("How do masked models compute embeddings?")
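The transform and infer_vector calls above turn the model's last hidden state (one vector per token) into one vector per text by averaging over token positions. The dependency-free sketch below shows that arithmetic on plain lists. `mean_pool` is a hypothetical name, and the optional attention-mask handling (which skips padding positions) is included for comparison only, since this README does not say whether the package excludes padding when averaging.

```python
from typing import List, Optional, Sequence


def mean_pool(
    hidden_state: Sequence[Sequence[float]],
    attention_mask: Optional[Sequence[int]] = None,
) -> List[float]:
    """Average token vectors (seq_len x dim) into one dim-sized vector.

    If attention_mask is given, positions with mask 0 (padding) are skipped.
    """
    if attention_mask is None:
        attention_mask = [1] * len(hidden_state)
    kept = [vec for vec, keep in zip(hidden_state, attention_mask) if keep]
    if not kept:
        raise ValueError("no unmasked tokens to pool")
    dim = len(kept[0])
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]


# Three "tokens" with 2-dimensional hidden states; the last one is padding.
hidden = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]
print(mean_pool(hidden))             # averages all three positions
print(mean_pool(hidden, [1, 1, 0]))  # → [2.0, 3.0]
```

The two calls differ because padding vectors pull an unmasked average toward zero, which matters most when short and long texts share a batch.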

Expanding the Vocabulary

Set add_new_tokens=True to capture domain-specific terms before training. New tokens are identified via simple whitespace tokenization and appended to the tokenizer before the first epoch.

from swarmauri_embedding_mlm import MlmEmbedding

domain_docs = [
    "Neo4j graph embeddings power fraud detection",
    "Qdrant supports hybrid sparse-dense search",
]

embedder = MlmEmbedding(add_new_tokens=True)
embedder.fit(domain_docs)

# Inspect the tokenizer to confirm additions
print(f"Vocabulary size: {len(embedder.extract_features())}")
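New-token discovery, as described above, splits each document on whitespace and keeps words the tokenizer does not already know. A minimal sketch of that idea, assuming the vocabulary is available as a set of strings (as the extract_features() call suggests); `find_new_tokens` is a hypothetical helper, and the package may normalize case or punctuation differently. With Hugging Face models, adding tokens also requires resizing the embedding matrix, typically via `model.resize_token_embeddings(len(tokenizer))`.

```python
from typing import Iterable, List, Set


def find_new_tokens(documents: Iterable[str], vocab: Set[str]) -> List[str]:
    """Return whitespace-delimited words missing from vocab, in first-seen order."""
    seen: Set[str] = set()
    new_tokens: List[str] = []
    for doc in documents:
        for word in doc.split():  # simple whitespace tokenization
            if word not in vocab and word not in seen:
                seen.add(word)
                new_tokens.append(word)
    return new_tokens


# Toy vocabulary standing in for the tokenizer's real one.
vocab = {"graph", "embeddings", "power", "fraud", "detection", "supports", "hybrid", "search"}
docs = [
    "Neo4j graph embeddings power fraud detection",
    "Qdrant supports hybrid sparse-dense search",
]
print(find_new_tokens(docs, vocab))  # → ['Neo4j', 'Qdrant', 'sparse-dense']
```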

Persisting and Reloading Models

from pathlib import Path
from swarmauri_embedding_mlm import MlmEmbedding

save_dir = Path("models/mlm-distilbert")

embedder = MlmEmbedding(embedding_name="distilbert-base-uncased")  # match the save_dir name
embedder.fit(["short corpus", "to warm up the model"])
embedder.save_model(save_dir.as_posix())

# Later or on another machine
restored = MlmEmbedding()
restored.load_model(save_dir.as_posix())

embedding = restored.infer_vector("Reuse the trained weights instantly")
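When several workers share one trained model, a common pattern is to load the saved directory if it exists and otherwise train and save once. The sketch below is duck-typed so it runs without downloading a model: it only assumes an object with the fit, save_model, and load_model methods used above. `load_or_train` and `FakeEmbedder` are hypothetical; treating any non-empty directory as a saved model is a simplification, and nothing here guards against two workers training at the same time.

```python
import tempfile
from pathlib import Path
from typing import List


def load_or_train(embedder, save_dir: Path, docs: List[str]) -> str:
    """Load weights from save_dir if present; otherwise fit and persist them.

    Returns "loaded" or "trained" so callers can log what happened.
    """
    if save_dir.is_dir() and any(save_dir.iterdir()):
        embedder.load_model(save_dir.as_posix())
        return "loaded"
    save_dir.mkdir(parents=True, exist_ok=True)
    embedder.fit(docs)
    embedder.save_model(save_dir.as_posix())
    return "trained"


class FakeEmbedder:
    """Hypothetical stand-in exposing the same method names as MlmEmbedding."""

    def fit(self, docs: List[str]) -> None:
        self.trained_on = list(docs)

    def save_model(self, path: str) -> None:
        (Path(path) / "weights.txt").write_text("fake weights")

    def load_model(self, path: str) -> None:
        self.loaded_from = path


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "mlm-distilbert"
    print(load_or_train(FakeEmbedder(), target, ["short corpus"]))  # → trained
    print(load_or_train(FakeEmbedder(), target, ["short corpus"]))  # → loaded
```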

Operational Tips

  • Batch size and sequence length drive GPU memory usage; on constrained hardware, reduce batch_size or shorten the max_length used in tokenizer calls.
  • fit_transform runs one fine-tuning pass and then returns embeddings for the same documents, which is useful for one-off adaptation jobs.
  • When training on large corpora, feed documents to .fit in chunks (for example, read from a generator) rather than all at once, and wrap the .fit calls in your own loop if you need more than one epoch.
  • Run extract_features() to audit the tokenizer vocabulary (helpful when debugging domain token coverage).
  • Combine the generated vectors with Swarmauri vector stores (Redis, Qdrant, etc.) to build end-to-end retrieval pipelines.
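The chunking and epoch-loop tip can be sketched as follows. It assumes, per the Quickstart comment, that each .fit call performs one pass over the documents it receives. `iter_chunks`, `train_in_chunks`, `CountingEmbedder`, and `corpus` are hypothetical; the stand-in embedder only records calls so the example runs without a model.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator, List


def iter_chunks(docs: Iterable[str], chunk_size: int) -> Iterator[List[str]]:
    """Yield successive lists of at most chunk_size documents."""
    it = iter(docs)
    while True:
        chunk = list(islice(it, chunk_size))
        if not chunk:
            return
        yield chunk


def train_in_chunks(
    embedder,
    make_docs: Callable[[], Iterable[str]],
    chunk_size: int,
    epochs: int,
) -> None:
    """Run several epochs, streaming a fresh document iterator each epoch."""
    for _ in range(epochs):
        # make_docs is called per epoch because a generator can only be consumed once.
        for chunk in iter_chunks(make_docs(), chunk_size):
            embedder.fit(chunk)


class CountingEmbedder:
    """Hypothetical stand-in that records the chunks passed to fit."""

    def __init__(self) -> None:
        self.calls: List[List[str]] = []

    def fit(self, docs: List[str]) -> None:
        self.calls.append(docs)


def corpus() -> Iterator[str]:
    # In practice this might read lines from a large file.
    for i in range(5):
        yield f"document {i}"


emb = CountingEmbedder()
train_in_chunks(emb, corpus, chunk_size=2, epochs=2)
print(len(emb.calls))  # → 6 (3 chunks per epoch x 2 epochs)
```

Passing a factory (`corpus`) instead of a generator object keeps memory flat while still allowing multiple epochs over the same data.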

Want to help?

If you want to contribute to swarmauri-sdk, read our contributing guidelines to get started.

Download files

Source Distribution

swarmauri_embedding_mlm-0.8.2.dev3.tar.gz (9.9 kB)

Uploaded Source

Built Distribution


swarmauri_embedding_mlm-0.8.2.dev3-py3-none-any.whl (11.0 kB)

Uploaded Python 3

File details

Details for the file swarmauri_embedding_mlm-0.8.2.dev3.tar.gz.

File metadata

  • Download URL: swarmauri_embedding_mlm-0.8.2.dev3.tar.gz
  • Upload date:
  • Size: 9.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.10.3

File hashes

Hashes for swarmauri_embedding_mlm-0.8.2.dev3.tar.gz
Algorithm Hash digest
SHA256 0ef9fb06142b6b67fb6bca432b19afc1f440133f1f56733f7593f49e453c7510
MD5 d12a70a1a4004d0bbaa1a2a521d26214
BLAKE2b-256 805d2c30869f454691f03ec5363583d5c483b8bde92c94ec1ee4836c1fd6af80


File details

Details for the file swarmauri_embedding_mlm-0.8.2.dev3-py3-none-any.whl.

File metadata

  • Download URL: swarmauri_embedding_mlm-0.8.2.dev3-py3-none-any.whl
  • Upload date:
  • Size: 11.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.10.3

File hashes

Hashes for swarmauri_embedding_mlm-0.8.2.dev3-py3-none-any.whl
Algorithm Hash digest
SHA256 e7550c2df2bcbc9ae11e35e4cb6e8ba974b552510f240e3e14dc87ae70fd3db5
MD5 61c30cf873f47421b6e87d6bd2c8b6d5
BLAKE2b-256 9b3e2b96bf79f45ce84fc3c6e2a867e1275ddca8121cec199f6b8c471c95df5c

