Embedding-based few-shot learning library for keyword spotting and related tasks
Project description
fewshotkit
fewshotkit là thư viện Python cho bài toán few-shot learning dựa trên embedding (đặc biệt phù hợp với Keyword Spotting).
Mục tiêu của thư viện:
- Nhẹ, dễ dùng, API gần giống sklearn.
- Không chứa encoder (bạn tự tạo embedding từ pipeline khác).
- Chạy offline, deterministic, tập trung vào suy luận few-shot.
1. fewshotkit giải quyết bài toán gì?
Bạn đã có embedding vector (NumPy) và muốn:
- Huấn luyện nhanh từ ít mẫu support (
fit). - Dự đoán class cho query (
predict). - Lấy xác suất cho từng class (
predict_proba). - Đánh giá bằng episodic few-shot (
evaluate_n_way_k_shot).
fewshotkit không trích xuất đặc trưng từ audio thô, chỉ làm việc trên embedding đầu vào.
2. Cài đặt
Cài đặt cơ bản
cd /home/ngocan/Projects/KWS_Project/Fewshot
python3 -m venv .venv
source .venv/bin/activate
pip install -e .
Cài đặt cho phát triển (test/lint/benchmark)
pip install -e .[dev]
3. Kiểm tra nhanh môi trường
Đảm bảo python và pip cùng một virtualenv:
which python
python -c "import sys; print(sys.executable)"
Cả 2 lệnh nên trỏ về:
/home/ngocan/Projects/KWS_Project/Fewshot/.venv/bin/python
4. Định dạng dữ liệu đầu vào
Input chuẩn cho model
X_support:np.ndarray, shape(N, D)y_support:np.ndarray, shape(N,)X_query:np.ndarray, shape(M, D)
Output
predict(X_query): shape(M,)predict_proba(X_query): shape(M, num_classes)
Lưu ý
D(embedding dimension) phải nhất quán giữa support/query.y_supportcó thể là số hoặc chuỗi.- Nếu shape sai, thư viện sẽ raise
ValueErrorrõ ràng.
5. API chính
Import public API:
from fewshotkit import (
ProtoNet,
SiameseKNN,
MatchingNet,
evaluate_n_way_k_shot,
)
Tất cả model đều hỗ trợ:
fit(X, y)predict(X)predict_proba(X)score(X, y)add_class(label, embeddings)remove_class(label)
6. Dùng từng model
6.1 ProtoNet (khuyên dùng để bắt đầu)
import numpy as np
from fewshotkit import ProtoNet
x_support = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 1.0]])
y_support = np.array(["off", "off", "on", "on"], dtype=object)
x_query = np.array([[0.05, 0.0], [1.05, 1.0]])
model = ProtoNet(metric="cosine") # hoặc metric="euclidean"
model.fit(x_support, y_support)
pred = model.predict(x_query)
proba = model.predict_proba(x_query)
Cơ chế: tính centroid cho từng class rồi so sánh query với centroid.
6.2 SiameseKNN
from fewshotkit import SiameseKNN
model = SiameseKNN(metric="euclidean", k=3)
model.fit(x_support, y_support)
pred = model.predict(x_query)
Cơ chế: so similarity tới toàn bộ support, voting theo top-k.
6.3 MatchingNet
from fewshotkit import MatchingNet
model = MatchingNet(metric="cosine")
model.fit(x_support, y_support)
pred = model.predict(x_query)
Cơ chế: attention softmax giữa query và support để tính xác suất lớp.
7. Unknown detection (phát hiện ngoài tập lớp)
Bạn có thể set ngưỡng để trả về nhãn unknown khi độ tương đồng thấp:
from fewshotkit import ProtoNet
model = ProtoNet(metric="cosine", threshold=0.8, unknown_label="unknown")
model.fit(x_support, y_support)
pred = model.predict(x_query)
Nếu score tốt nhất < threshold thì output là unknown_label.
8. Thêm/Xóa class động
model.add_class("maybe", np.array([[0.5, 0.5], [0.6, 0.5]]))
model.remove_class("maybe")
Hữu ích khi cần enroll class mới mà không train lại pipeline encoder.
9. Đánh giá episodic few-shot
import numpy as np
from fewshotkit import ProtoNet, evaluate_n_way_k_shot
x = np.random.randn(300, 64)
y = np.repeat(np.arange(10), 30)
model = ProtoNet(metric="euclidean")
result = evaluate_n_way_k_shot(
model,
x,
y,
n_way=5,
k_shot=5,
n_query=10,
n_episodes=100,
random_state=42,
)
print(result["accuracy_mean"], result["accuracy_std"])
Kết quả trả về gồm:
accuracy_meanaccuracy_stdepisode_accuracies- metadata episode (
n_way,k_shot,n_query,n_episodes,random_state)
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file fewshotkit-0.0.1.tar.gz.
File metadata
- Download URL: fewshotkit-0.0.1.tar.gz
- Upload date:
- Size: 14.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.3
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
c8bf3140e2cc29198a2f97f5dbc2167a375324b58398cf63aed39bf109e902d1
|
|
| MD5 |
65934d963fa84ed048a125ed5a56a108
|
|
| BLAKE2b-256 |
6505d11dd936b30c9adf567aef7d3282802ba0e5de6d3815561c784250be4cce
|
File details
Details for the file fewshotkit-0.0.1-py3-none-any.whl.
File metadata
- Download URL: fewshotkit-0.0.1-py3-none-any.whl
- Upload date:
- Size: 15.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.3
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
95e6b8352988e18e907a2f58b65e53f212cc5e995cda3b4a3e8cd14b7534aa4a
|
|
| MD5 |
2b83f87988b558647dbacd78a540231f
|
|
| BLAKE2b-256 |
0c6b2111b3ef7690a2169cce569f903481c75357287a490006640b216b114ce8
|