Skip to main content

LLM classification toolkit

Project description

from future import annotations import os, re, logging from typing import Iterable, Union, Dict, List, Tuple import numpy as np import torch, pandas as pd from tqdm import tqdm from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

LOGGER = logging.getLogger(name) all = ["load_model_and_tokenizer", "embed_dataset", "embed"]

try PEFT lazily so base use still works without the dependency

try: from peft import PeftModel, PeftConfig _PEFT_AVAILABLE = True except Exception: PeftModel = None # type: ignore PeftConfig = None # type: ignore _PEFT_AVAILABLE = False

skip pure punctuation/symbol tokens

_SKIP_PUNCT_RE = re.compile(r"^[^A-Za-z0-9]+$")

skip pandas duplicate columns: optional non-alnum → dot → digits

_SKIP_DUP_RE = re.compile(r"^[^A-Za-z0-9]*.[0-9]+$")

helpers

def _is_peft_model(m: torch.nn.Module) -> bool: """Return True if m looks like a PEFT-wrapped model.""" return _PEFT_AVAILABLE and isinstance(m, PeftModel) # type: ignore

def _unwrap(model: torch.nn.Module) -> torch.nn.Module: """ Return the underlying base model if wrapped in DataParallel/DDP/PEFT. """ # DataParallel/DDP first if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): model = model.module

# Unwrap PEFT to the base HF model
if _is_peft_model(model):
    get_base = getattr(model, "get_base_model", None)
    if callable(get_base):
        model = get_base()
    elif hasattr(model, "base_model"):
        model = model.base_model  # type: ignore

return model

def _primary_device(model: torch.nn.Module) -> torch.device: """Device where inputs should live (first param/buffer → safe).""" m = _unwrap(model) for t in m.parameters(recurse=True): return t.device for t in m.buffers(recurse=True): return t.device return torch.device("cpu")

def _get_blocks(model: torch.nn.Module): """ Locate transformer blocks across common HF architectures. Works whether the incoming model is PEFT-wrapped or not. """ m = _unwrap(model) if hasattr(m, "layers"): return m.layers if hasattr(m, "model") and hasattr(m.model, "layers"): return m.model.layers if hasattr(m, "gpt_neox") and hasattr(m.gpt_neox, "layers"): return m.gpt_neox.layers if hasattr(m, "encoder") and hasattr(m.encoder, "layer"): return m.encoder.layer if hasattr(m, "decoder") and hasattr(m.decoder, "layers"): return m.decoder.layers raise AttributeError("Cannot locate transformer layers on the (unwrapped) model.")

def _get_attn_and_mlp(block: torch.nn.Module) -> Tuple[torch.nn.Module, torch.nn.Module]: """Return (attention, mlp) sub-modules regardless of naming.""" attn_names = ("self_attn", "attn", "attention", "self_attention") mlp_names = ("mlp", "ffn", "feed_forward", "feedforward") attn = next((getattr(block, n) for n in attn_names if hasattr(block, n)), None) mlp = next((getattr(block, n) for n in mlp_names if hasattr(block, n)), None) if attn is None or mlp is None: raise AttributeError(f"Cannot find attn/mlp in {type(block).name}") return attn, mlp

load

def load_model_and_tokenizer( model_name: str = "Qwen/Qwen3-0.6B", *, quantization: str | None = None, # "4bit" | "8bit" | None device_map: str | None = "auto", lora_adapter: str | None = None, merge_lora: bool = False, attn_implementation: str | None = None # e.g., "eager" for Gemma2 ): LOGGER.info("Loading model %s", model_name)

# Resolve base/task from adapter if provided
if lora_adapter is not None:
    if not _PEFT_AVAILABLE:
        raise RuntimeError(
            "lora_adapter was provided but 'peft' is not installed. "
            "Install with: pip install peft"
        )
    peft_cfg = PeftConfig.from_pretrained(lora_adapter)
    base_name = peft_cfg.base_model_name_or_path or model_name
    task_type = str(peft_cfg.task_type).lower()
    if attn_implementation is None:
        attn_implementation = "eager"   # Gemma2 safe default
else:
    base_name = model_name
    task_type = "causal_lm"  # sensible default for LLMs

tokenizer = AutoTokenizer.from_pretrained(base_name)
if getattr(tokenizer, "pad_token", None) is None and hasattr(tokenizer, "eos_token"):
    tokenizer.pad_token = tokenizer.eos_token

use_causal = (lora_adapter is not None) or ("causal" in task_type)
loader = AutoModelForCausalLM if use_causal else AutoModel

# Build kwargs selectively
kw = {}
if device_map is not None:
    kw["device_map"] = device_map
if attn_implementation is not None:
    kw["attn_implementation"] = attn_implementation
if quantization == "4bit":
    kw["load_in_4bit"] = True
elif quantization == "8bit":
    kw["load_in_8bit"] = True

model = loader.from_pretrained(base_name, **kw)

# Make sure hidden states are returned
if getattr(model.config, "output_hidden_states", False) is not True:
    model.config.output_hidden_states = True

# Attach LoRA if provided
if lora_adapter is not None:
    model = PeftModel.from_pretrained(model, lora_adapter)
    if merge_lora and hasattr(model, "merge_and_unload"):
        model = model.merge_and_unload()

model.eval()
return model, tokenizer

hooks

def _register_block_hooks(model) -> tuple[ List[torch.utils.hooks.RemovableHandle], Dict[int, torch.Tensor], Dict[int, torch.Tensor], ]: """Attach hooks capturing attention & MLP outputs.""" attn_cache: Dict[int, torch.Tensor] = {} mlp_cache: Dict[int, torch.Tensor] = {} handles: List[torch.utils.hooks.RemovableHandle] = []

for idx, block in enumerate(_get_blocks(model)):
    attn_sub, mlp_sub = _get_attn_and_mlp(block)

    def make_hook(cache, i):
        def hook(_mod, _inp, out):
            t = out[0] if isinstance(out, tuple) else out
            cache[i] = t.detach().cpu().squeeze(0)
        return hook

    handles.append(attn_sub.register_forward_hook(make_hook(attn_cache, idx)))
    handles.append(mlp_sub .register_forward_hook(make_hook(mlp_cache , idx)))

return handles, attn_cache, mlp_cache

token helpers

def _find_target_token(tokens: List[str]) -> tuple[int, str]: """Last meaningful token (skip eos, punct, pandas dup suffix).""" for i in range(len(tokens) - 1, -1, -1): tok = tokens[i] if "end_of_turn" in tok or _SKIP_PUNCT_RE.match(tok) or _SKIP_DUP_RE.match(tok): continue return i, tok return len(tokens) - 1, tokens[-1]

def _pooled(t: torch.Tensor, idxs: List[int], method: str) -> torch.Tensor: """ Return a 1-D pooled vector from a 2-D [seq_len, hidden] tensor. idxs are indices of valid tokens (already filtered). """ if not idxs: # degenerate: fall back to all idxs = list(range(t.size(0)))

if method == "last":
    return t[idxs[-1]]
elif method == "first":
    return t[idxs[0]]
elif method == "mean":
    return t[idxs].mean(dim=0)
elif method == "max":
    return t[idxs].max(dim=0).values
else:
    raise ValueError(f"Unknown pooling method '{method}'")

def _pool_key(method: str, tokens: List[str], valid_idxs: List[int]) -> str: """ Column / dict key to use for a given pooling method. • "last" → last valid token string • "first" → first valid token string • others → method name itself """ if method == "last": return tokens[valid_idxs[-1]] if method == "first": return tokens[valid_idxs[0]] return method

main functions

def embed_dataset( data: Union[pd.DataFrame, str, Iterable[str]], *, input_col: str | None = None, model_name: str = "Qwen/Qwen3-0.6B", output_dir: str = "embeddings", layers: List[int] | None = None, parts: List[str] | None = None, pooling: Union[str, List[str]] = "last", eos_token: str | None = None, device: str | None = "auto", filter_non_text: bool = False, # LoRA lora_adapter: str | None = None, merge_lora: bool = False, # NEW: quantization passthrough quantization: str | None = None, # "4bit" | "8bit" | None ): """ Extract embeddings with multiple pooling strategies and save CSVs. If filter_non_text is True, skip pure-punct/symbol tokens, pandas-duplicate suffixes, and any token containing the provided eos_token string. Otherwise include all tokens.

LoRA:
- Provide `lora_adapter` (HF repo id or local path) to load and apply an adapter for inference.
- Set `merge_lora=True` to bake adapters into the base weights (optional).

Quantization:
- Set `quantization` to "4bit" or "8bit" to load quantized base weights for low-VRAM inference.
"""
# ------------- load / setup -------------
if isinstance(data, (str, os.PathLike)):
    df = pd.read_csv(data)
elif isinstance(data, pd.DataFrame):
    df = data.copy()
else:
    df = pd.DataFrame({"__input__": list(data)})
    input_col = "__input__"
if input_col is None:
    raise ValueError("input_col required when data is DataFrame")

if isinstance(pooling, str):
    pooling = [pooling]

model, tokenizer = load_model_and_tokenizer(
    model_name,
    device_map=device,
    lora_adapter=lora_adapter,
    merge_lora=merge_lora,
    quantization=quantization,  # <-- pass through
)
n_layers = len(_get_blocks(model))
layers   = layers or list(range(n_layers))
parts    = parts  or ["rs", "attn", "mlp"]

handles, attn_cache, mlp_cache = _register_block_hooks(model)

try:
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Embedding"):
        prompt = row[input_col]
        enc    = tokenizer(prompt, return_tensors="pt")
        dev    = _primary_device(model)
        enc    = {k: v.to(dev) for k, v in enc.items()}

        attn_cache.clear(); mlp_cache.clear()
        with torch.no_grad():
            out = model(**enc, output_hidden_states=True)

        tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0].tolist())

        if filter_non_text:
            # guard eos_token is None
            valid_idxs = [
                i for i, tok in enumerate(tokens)
                if not _SKIP_PUNCT_RE.match(tok)
                   and not _SKIP_DUP_RE.match(tok)
                   and (eos_token is None or eos_token not in tok)
            ]
        else:
            valid_idxs = list(range(len(tokens)))

        row_id = row.get("id", f"row_{idx}")

        for layer in layers:
            if not (0 <= layer < n_layers):
                raise ValueError(f"Layer {layer} out of 0-{n_layers-1}")

            seq_rs   = out.hidden_states[layer + 1][0].cpu()
            seq_attn = attn_cache[layer].cpu()
            seq_mlp  = mlp_cache[layer].cpu()

            for method in pooling:
                key_name = _pool_key(method, tokens, valid_idxs)
                vecs = {
                    "rs":   _pooled(seq_rs,   valid_idxs, method),
                    "attn": _pooled(seq_attn, valid_idxs, method),
                    "mlp":  _pooled(seq_mlp,  valid_idxs, method),
                }
                for part in parts:
                    path_dir = os.path.join(output_dir, part)
                    os.makedirs(path_dir, exist_ok=True)
                    path = os.path.join(path_dir, f"{row_id}_L{layer:02d}.csv")
                    pd.DataFrame(
                        vecs[part].numpy().reshape(-1, 1),
                        columns=[key_name]
                    ).to_csv(path, index=False)
finally:
    for h in handles: h.remove()

def embed( inputs: Union[str, List[str]], *, model_name: str = "Qwen/Qwen3-0.6B", layers: List[int] | None = None, parts: List[str] | None = None, pooling: Union[str, List[str]] = "last", eos_token: str | None = None, device: str | None = "auto", filter_non_text: bool = False, # LoRA lora_adapter: str | None = None, merge_lora: bool = False, # NEW: quantization passthrough quantization: str | None = None, # "4bit" | "8bit" | None ) -> Dict[str, Dict[int, Dict[str, Dict[str, np.ndarray]]]]: """ Return embeddings with multiple pooling strategies.

If filter_non_text is True, skip punct/symbol tokens, pandas-duplicate suffixes,
and any tokens containing eos_token; otherwise include all tokens.

LoRA:
- Provide `lora_adapter` (HF repo id or local path) to load and apply an adapter for inference.
- Set `merge_lora=True` to bake adapters into the base weights (optional).

Quantization:
- Set `quantization` to "4bit" or "8bit" to load quantized base weights for low-VRAM inference.
"""
if isinstance(inputs, str):
    inputs = [inputs]
if isinstance(pooling, str):
    pooling = [pooling]

model, tokenizer = load_model_and_tokenizer(
    model_name,
    device_map=device,
    lora_adapter=lora_adapter,
    merge_lora=merge_lora,
    quantization=quantization,  # <-- pass through
)
n_layers = len(_get_blocks(model))
layers   = layers or list(range(n_layers))
parts    = parts  or ["rs", "attn", "mlp"]

handles, attn_cache, mlp_cache = _register_block_hooks(model)
embeddings: Dict[str, Dict[int, Dict[str, Dict[str, np.ndarray]]]] = {}

try:
    for s in tqdm(inputs, desc="Embedding"):
        prompt = s
        enc    = tokenizer(prompt, return_tensors="pt")
        dev    = _primary_device(model)
        enc    = {k: v.to(dev) for k, v in enc.items()}

        attn_cache.clear(); mlp_cache.clear()
        with torch.no_grad():
            out = model(**enc, output_hidden_states=True)

        tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0].tolist())

        if filter_non_text:
            valid_idxs = [
                i for i, tok in enumerate(tokens)
                if not _SKIP_PUNCT_RE.match(tok)
                   and not _SKIP_DUP_RE.match(tok)
                   and (eos_token is None or eos_token not in tok)
            ]
        else:
            valid_idxs = list(range(len(tokens)))

        embeddings[s] = {}
        for layer in layers:
            seq_rs   = out.hidden_states[layer + 1][0].cpu()
            seq_attn = attn_cache[layer].cpu()
            seq_mlp  = mlp_cache[layer].cpu()

            part_pool: Dict[str, Dict[str, np.ndarray]] = {p: {} for p in parts}
            for method in pooling:
                key_name = _pool_key(method, tokens, valid_idxs)
                part_pool["rs"][key_name]   = _pooled(seq_rs,   valid_idxs, method).numpy()
                part_pool["attn"][key_name] = _pooled(seq_attn, valid_idxs, method).numpy()
                part_pool["mlp"][key_name]  = _pooled(seq_mlp,  valid_idxs, method).numpy()
            embeddings[s][layer] = part_pool
finally:
    for h in handles: h.remove()

return embeddings

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pingkit-0.4.8.tar.gz (27.7 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

pingkit-0.4.8-py3-none-any.whl (29.5 kB view details)

Uploaded Python 3

File details

Details for the file pingkit-0.4.8.tar.gz.

File metadata

  • Download URL: pingkit-0.4.8.tar.gz
  • Upload date:
  • Size: 27.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.2

File hashes

Hashes for pingkit-0.4.8.tar.gz
Algorithm Hash digest
SHA256 53c84136c5579b5606d56534c9e49c820755da5d703cab783e5d22907d740d5a
MD5 49fbdfa2f42e9423b1a58606e900143f
BLAKE2b-256 0fa0c410cf3ab11738dc3b0be017bb9e39a507c42782463d73849e3e06297550

See more details on using hashes here.

File details

Details for the file pingkit-0.4.8-py3-none-any.whl.

File metadata

  • Download URL: pingkit-0.4.8-py3-none-any.whl
  • Upload date:
  • Size: 29.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.2

File hashes

Hashes for pingkit-0.4.8-py3-none-any.whl
Algorithm Hash digest
SHA256 8092a6349c194a5ce260dbeb63ddbf16f3f3671445bceaa75a9f77de535fee04
MD5 4ad62b53513c785d8e50c04a0869731c
BLAKE2b-256 6dd02ca11fa84b0f7175e9c2fe4566435ffe599c49631cb3e6d2f2bd830bf324

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page