LLM classification toolkit
Project description
from future import annotations import os, re, logging from typing import Iterable, Union, Dict, List, Tuple import numpy as np import torch, pandas as pd from tqdm import tqdm from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
LOGGER = logging.getLogger(name) all = ["load_model_and_tokenizer", "embed_dataset", "embed"]
try PEFT lazily so base use still works without the dependency
try: from peft import PeftModel, PeftConfig _PEFT_AVAILABLE = True except Exception: PeftModel = None # type: ignore PeftConfig = None # type: ignore _PEFT_AVAILABLE = False
skip pure punctuation/symbol tokens
_SKIP_PUNCT_RE = re.compile(r"^[^A-Za-z0-9]+$")
skip pandas duplicate columns: optional non-alnum → dot → digits
_SKIP_DUP_RE = re.compile(r"^[^A-Za-z0-9]*.[0-9]+$")
helpers
def _is_peft_model(m: torch.nn.Module) -> bool: """Return True if m looks like a PEFT-wrapped model.""" return _PEFT_AVAILABLE and isinstance(m, PeftModel) # type: ignore
def _unwrap(model: torch.nn.Module) -> torch.nn.Module: """ Return the underlying base model if wrapped in DataParallel/DDP/PEFT. """ # DataParallel/DDP first if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): model = model.module
# Unwrap PEFT to the base HF model
if _is_peft_model(model):
get_base = getattr(model, "get_base_model", None)
if callable(get_base):
model = get_base()
elif hasattr(model, "base_model"):
model = model.base_model # type: ignore
return model
def _primary_device(model: torch.nn.Module) -> torch.device: """Device where inputs should live (first param/buffer → safe).""" m = _unwrap(model) for t in m.parameters(recurse=True): return t.device for t in m.buffers(recurse=True): return t.device return torch.device("cpu")
def _get_blocks(model: torch.nn.Module): """ Locate transformer blocks across common HF architectures. Works whether the incoming model is PEFT-wrapped or not. """ m = _unwrap(model) if hasattr(m, "layers"): return m.layers if hasattr(m, "model") and hasattr(m.model, "layers"): return m.model.layers if hasattr(m, "gpt_neox") and hasattr(m.gpt_neox, "layers"): return m.gpt_neox.layers if hasattr(m, "encoder") and hasattr(m.encoder, "layer"): return m.encoder.layer if hasattr(m, "decoder") and hasattr(m.decoder, "layers"): return m.decoder.layers raise AttributeError("Cannot locate transformer layers on the (unwrapped) model.")
def _get_attn_and_mlp(block: torch.nn.Module) -> Tuple[torch.nn.Module, torch.nn.Module]: """Return (attention, mlp) sub-modules regardless of naming.""" attn_names = ("self_attn", "attn", "attention", "self_attention") mlp_names = ("mlp", "ffn", "feed_forward", "feedforward") attn = next((getattr(block, n) for n in attn_names if hasattr(block, n)), None) mlp = next((getattr(block, n) for n in mlp_names if hasattr(block, n)), None) if attn is None or mlp is None: raise AttributeError(f"Cannot find attn/mlp in {type(block).name}") return attn, mlp
load
def load_model_and_tokenizer( model_name: str = "Qwen/Qwen3-0.6B", *, quantization: str | None = None, # "4bit" | "8bit" | None device_map: str | None = "auto", lora_adapter: str | None = None, merge_lora: bool = False, attn_implementation: str | None = None # e.g., "eager" for Gemma2 ): LOGGER.info("Loading model %s", model_name)
# Resolve base/task from adapter if provided
if lora_adapter is not None:
if not _PEFT_AVAILABLE:
raise RuntimeError(
"lora_adapter was provided but 'peft' is not installed. "
"Install with: pip install peft"
)
peft_cfg = PeftConfig.from_pretrained(lora_adapter)
base_name = peft_cfg.base_model_name_or_path or model_name
task_type = str(peft_cfg.task_type).lower()
if attn_implementation is None:
attn_implementation = "eager" # Gemma2 safe default
else:
base_name = model_name
task_type = "causal_lm" # sensible default for LLMs
tokenizer = AutoTokenizer.from_pretrained(base_name)
if getattr(tokenizer, "pad_token", None) is None and hasattr(tokenizer, "eos_token"):
tokenizer.pad_token = tokenizer.eos_token
use_causal = (lora_adapter is not None) or ("causal" in task_type)
loader = AutoModelForCausalLM if use_causal else AutoModel
# Build kwargs selectively
kw = {}
if device_map is not None:
kw["device_map"] = device_map
if attn_implementation is not None:
kw["attn_implementation"] = attn_implementation
if quantization == "4bit":
kw["load_in_4bit"] = True
elif quantization == "8bit":
kw["load_in_8bit"] = True
model = loader.from_pretrained(base_name, **kw)
# Make sure hidden states are returned
if getattr(model.config, "output_hidden_states", False) is not True:
model.config.output_hidden_states = True
# Attach LoRA if provided
if lora_adapter is not None:
model = PeftModel.from_pretrained(model, lora_adapter)
if merge_lora and hasattr(model, "merge_and_unload"):
model = model.merge_and_unload()
model.eval()
return model, tokenizer
hooks
def _register_block_hooks(model) -> tuple[ List[torch.utils.hooks.RemovableHandle], Dict[int, torch.Tensor], Dict[int, torch.Tensor], ]: """Attach hooks capturing attention & MLP outputs.""" attn_cache: Dict[int, torch.Tensor] = {} mlp_cache: Dict[int, torch.Tensor] = {} handles: List[torch.utils.hooks.RemovableHandle] = []
for idx, block in enumerate(_get_blocks(model)):
attn_sub, mlp_sub = _get_attn_and_mlp(block)
def make_hook(cache, i):
def hook(_mod, _inp, out):
t = out[0] if isinstance(out, tuple) else out
cache[i] = t.detach().cpu().squeeze(0)
return hook
handles.append(attn_sub.register_forward_hook(make_hook(attn_cache, idx)))
handles.append(mlp_sub .register_forward_hook(make_hook(mlp_cache , idx)))
return handles, attn_cache, mlp_cache
token helpers
def _find_target_token(tokens: List[str]) -> tuple[int, str]: """Last meaningful token (skip eos, punct, pandas dup suffix).""" for i in range(len(tokens) - 1, -1, -1): tok = tokens[i] if "end_of_turn" in tok or _SKIP_PUNCT_RE.match(tok) or _SKIP_DUP_RE.match(tok): continue return i, tok return len(tokens) - 1, tokens[-1]
def _pooled(t: torch.Tensor, idxs: List[int], method: str) -> torch.Tensor:
"""
Return a 1-D pooled vector from a 2-D [seq_len, hidden] tensor.
idxs are indices of valid tokens (already filtered).
"""
if not idxs: # degenerate: fall back to all
idxs = list(range(t.size(0)))
if method == "last":
return t[idxs[-1]]
elif method == "first":
return t[idxs[0]]
elif method == "mean":
return t[idxs].mean(dim=0)
elif method == "max":
return t[idxs].max(dim=0).values
else:
raise ValueError(f"Unknown pooling method '{method}'")
def _pool_key(method: str, tokens: List[str], valid_idxs: List[int]) -> str: """ Column / dict key to use for a given pooling method. • "last" → last valid token string • "first" → first valid token string • others → method name itself """ if method == "last": return tokens[valid_idxs[-1]] if method == "first": return tokens[valid_idxs[0]] return method
main functions
def embed_dataset( data: Union[pd.DataFrame, str, Iterable[str]], *, input_col: str | None = None, model_name: str = "Qwen/Qwen3-0.6B", output_dir: str = "embeddings", layers: List[int] | None = None, parts: List[str] | None = None, pooling: Union[str, List[str]] = "last", eos_token: str | None = None, device: str | None = "auto", filter_non_text: bool = False, # LoRA lora_adapter: str | None = None, merge_lora: bool = False, # NEW: quantization passthrough quantization: str | None = None, # "4bit" | "8bit" | None ): """ Extract embeddings with multiple pooling strategies and save CSVs. If filter_non_text is True, skip pure-punct/symbol tokens, pandas-duplicate suffixes, and any token containing the provided eos_token string. Otherwise include all tokens.
LoRA:
- Provide `lora_adapter` (HF repo id or local path) to load and apply an adapter for inference.
- Set `merge_lora=True` to bake adapters into the base weights (optional).
Quantization:
- Set `quantization` to "4bit" or "8bit" to load quantized base weights for low-VRAM inference.
"""
# ------------- load / setup -------------
if isinstance(data, (str, os.PathLike)):
df = pd.read_csv(data)
elif isinstance(data, pd.DataFrame):
df = data.copy()
else:
df = pd.DataFrame({"__input__": list(data)})
input_col = "__input__"
if input_col is None:
raise ValueError("input_col required when data is DataFrame")
if isinstance(pooling, str):
pooling = [pooling]
model, tokenizer = load_model_and_tokenizer(
model_name,
device_map=device,
lora_adapter=lora_adapter,
merge_lora=merge_lora,
quantization=quantization, # <-- pass through
)
n_layers = len(_get_blocks(model))
layers = layers or list(range(n_layers))
parts = parts or ["rs", "attn", "mlp"]
handles, attn_cache, mlp_cache = _register_block_hooks(model)
try:
for idx, row in tqdm(df.iterrows(), total=len(df), desc="Embedding"):
prompt = row[input_col]
enc = tokenizer(prompt, return_tensors="pt")
dev = _primary_device(model)
enc = {k: v.to(dev) for k, v in enc.items()}
attn_cache.clear(); mlp_cache.clear()
with torch.no_grad():
out = model(**enc, output_hidden_states=True)
tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0].tolist())
if filter_non_text:
# guard eos_token is None
valid_idxs = [
i for i, tok in enumerate(tokens)
if not _SKIP_PUNCT_RE.match(tok)
and not _SKIP_DUP_RE.match(tok)
and (eos_token is None or eos_token not in tok)
]
else:
valid_idxs = list(range(len(tokens)))
row_id = row.get("id", f"row_{idx}")
for layer in layers:
if not (0 <= layer < n_layers):
raise ValueError(f"Layer {layer} out of 0-{n_layers-1}")
seq_rs = out.hidden_states[layer + 1][0].cpu()
seq_attn = attn_cache[layer].cpu()
seq_mlp = mlp_cache[layer].cpu()
for method in pooling:
key_name = _pool_key(method, tokens, valid_idxs)
vecs = {
"rs": _pooled(seq_rs, valid_idxs, method),
"attn": _pooled(seq_attn, valid_idxs, method),
"mlp": _pooled(seq_mlp, valid_idxs, method),
}
for part in parts:
path_dir = os.path.join(output_dir, part)
os.makedirs(path_dir, exist_ok=True)
path = os.path.join(path_dir, f"{row_id}_L{layer:02d}.csv")
pd.DataFrame(
vecs[part].numpy().reshape(-1, 1),
columns=[key_name]
).to_csv(path, index=False)
finally:
for h in handles: h.remove()
def embed( inputs: Union[str, List[str]], *, model_name: str = "Qwen/Qwen3-0.6B", layers: List[int] | None = None, parts: List[str] | None = None, pooling: Union[str, List[str]] = "last", eos_token: str | None = None, device: str | None = "auto", filter_non_text: bool = False, # LoRA lora_adapter: str | None = None, merge_lora: bool = False, # NEW: quantization passthrough quantization: str | None = None, # "4bit" | "8bit" | None ) -> Dict[str, Dict[int, Dict[str, Dict[str, np.ndarray]]]]: """ Return embeddings with multiple pooling strategies.
If filter_non_text is True, skip punct/symbol tokens, pandas-duplicate suffixes,
and any tokens containing eos_token; otherwise include all tokens.
LoRA:
- Provide `lora_adapter` (HF repo id or local path) to load and apply an adapter for inference.
- Set `merge_lora=True` to bake adapters into the base weights (optional).
Quantization:
- Set `quantization` to "4bit" or "8bit" to load quantized base weights for low-VRAM inference.
"""
if isinstance(inputs, str):
inputs = [inputs]
if isinstance(pooling, str):
pooling = [pooling]
model, tokenizer = load_model_and_tokenizer(
model_name,
device_map=device,
lora_adapter=lora_adapter,
merge_lora=merge_lora,
quantization=quantization, # <-- pass through
)
n_layers = len(_get_blocks(model))
layers = layers or list(range(n_layers))
parts = parts or ["rs", "attn", "mlp"]
handles, attn_cache, mlp_cache = _register_block_hooks(model)
embeddings: Dict[str, Dict[int, Dict[str, Dict[str, np.ndarray]]]] = {}
try:
for s in tqdm(inputs, desc="Embedding"):
prompt = s
enc = tokenizer(prompt, return_tensors="pt")
dev = _primary_device(model)
enc = {k: v.to(dev) for k, v in enc.items()}
attn_cache.clear(); mlp_cache.clear()
with torch.no_grad():
out = model(**enc, output_hidden_states=True)
tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0].tolist())
if filter_non_text:
valid_idxs = [
i for i, tok in enumerate(tokens)
if not _SKIP_PUNCT_RE.match(tok)
and not _SKIP_DUP_RE.match(tok)
and (eos_token is None or eos_token not in tok)
]
else:
valid_idxs = list(range(len(tokens)))
embeddings[s] = {}
for layer in layers:
seq_rs = out.hidden_states[layer + 1][0].cpu()
seq_attn = attn_cache[layer].cpu()
seq_mlp = mlp_cache[layer].cpu()
part_pool: Dict[str, Dict[str, np.ndarray]] = {p: {} for p in parts}
for method in pooling:
key_name = _pool_key(method, tokens, valid_idxs)
part_pool["rs"][key_name] = _pooled(seq_rs, valid_idxs, method).numpy()
part_pool["attn"][key_name] = _pooled(seq_attn, valid_idxs, method).numpy()
part_pool["mlp"][key_name] = _pooled(seq_mlp, valid_idxs, method).numpy()
embeddings[s][layer] = part_pool
finally:
for h in handles: h.remove()
return embeddings
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file pingkit-0.4.7.tar.gz.
File metadata
- Download URL: pingkit-0.4.7.tar.gz
- Upload date:
- Size: 27.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.2
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
49e40f515844fc548811277ffe40003795e0d8d6371076ccdc06ee60f8532245
|
|
| MD5 |
9f535c10384d718c3e89b2796e8a2880
|
|
| BLAKE2b-256 |
a46d61855c4a7da487e4da86b6617a44b578604681d8c89d5c51cae7ccf7bb46
|
File details
Details for the file pingkit-0.4.7-py3-none-any.whl.
File metadata
- Download URL: pingkit-0.4.7-py3-none-any.whl
- Upload date:
- Size: 29.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.2
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
39f6d0dbd5debcd35008f5f5b456133959d02ad1dc83ac3c19fb3f3cf02f3085
|
|
| MD5 |
63d31c3e475b31232d3ad36dc5ae4e98
|
|
| BLAKE2b-256 |
13db30eb29cf9cdd67d416b77f9266166c0e8271e5a21fa1a26da6d392ea41e4
|