Skip to main content

Embedding-based few-shot learning library for keyword spotting and related tasks

Project description

fewshotkit

fewshotkit là thư viện Python cho bài toán few-shot learning dựa trên embedding (đặc biệt phù hợp với Keyword Spotting).

Mục tiêu của thư viện:

  • Nhẹ, dễ dùng, API gần giống sklearn.
  • Không chứa encoder (bạn tự tạo embedding từ pipeline khác).
  • Chạy offline, deterministic, tập trung vào suy luận few-shot.

1. fewshotkit giải quyết bài toán gì?

Bạn đã có embedding vector (NumPy) và muốn:

  • Huấn luyện nhanh từ ít mẫu support (fit).
  • Dự đoán class cho query (predict).
  • Lấy xác suất cho từng class (predict_proba).
  • Đánh giá bằng episodic few-shot (evaluate_n_way_k_shot).

fewshotkit không trích xuất đặc trưng từ audio thô, chỉ làm việc trên embedding đầu vào.


2. Cài đặt

Cài đặt cơ bản

cd /home/ngocan/Projects/KWS_Project/Fewshot
python3 -m venv .venv
source .venv/bin/activate
pip install -e .

Cài đặt cho phát triển (test/lint/benchmark)

pip install -e .[dev]

3. Kiểm tra nhanh môi trường

Đảm bảo pythonpip cùng một virtualenv:

which python
python -c "import sys; print(sys.executable)"

Cả 2 lệnh nên trỏ về: /home/ngocan/Projects/KWS_Project/Fewshot/.venv/bin/python


4. Định dạng dữ liệu đầu vào

Input chuẩn cho model

  • X_support: np.ndarray, shape (N, D)
  • y_support: np.ndarray, shape (N,)
  • X_query: np.ndarray, shape (M, D)

Output

  • predict(X_query): shape (M,)
  • predict_proba(X_query): shape (M, num_classes)

Lưu ý

  • D (embedding dimension) phải nhất quán giữa support/query.
  • y_support có thể là số hoặc chuỗi.
  • Nếu shape sai, thư viện sẽ raise ValueError rõ ràng.

5. API chính

Import public API:

from fewshotkit import (
    ProtoNet,
    SiameseKNN,
    MatchingNet,
    evaluate_n_way_k_shot,
)

Tất cả model đều hỗ trợ:

  • fit(X, y)
  • predict(X)
  • predict_proba(X)
  • score(X, y)
  • add_class(label, embeddings)
  • remove_class(label)

6. Dùng từng model

6.1 ProtoNet (khuyên dùng để bắt đầu)

import numpy as np
from fewshotkit import ProtoNet

x_support = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 1.0]])
y_support = np.array(["off", "off", "on", "on"], dtype=object)
x_query = np.array([[0.05, 0.0], [1.05, 1.0]])

model = ProtoNet(metric="cosine")  # hoặc metric="euclidean"
model.fit(x_support, y_support)

pred = model.predict(x_query)
proba = model.predict_proba(x_query)

Cơ chế: tính centroid cho từng class rồi so sánh query với centroid.

6.2 SiameseKNN

from fewshotkit import SiameseKNN

model = SiameseKNN(metric="euclidean", k=3)
model.fit(x_support, y_support)
pred = model.predict(x_query)

Cơ chế: so similarity tới toàn bộ support, voting theo top-k.

6.3 MatchingNet

from fewshotkit import MatchingNet

model = MatchingNet(metric="cosine")
model.fit(x_support, y_support)
pred = model.predict(x_query)

Cơ chế: attention softmax giữa query và support để tính xác suất lớp.


7. Unknown detection (phát hiện ngoài tập lớp)

Bạn có thể set ngưỡng để trả về nhãn unknown khi độ tương đồng thấp:

from fewshotkit import ProtoNet

model = ProtoNet(metric="cosine", threshold=0.8, unknown_label="unknown")
model.fit(x_support, y_support)
pred = model.predict(x_query)

Nếu score tốt nhất < threshold thì output là unknown_label.


8. Thêm/Xóa class động

model.add_class("maybe", np.array([[0.5, 0.5], [0.6, 0.5]]))
model.remove_class("maybe")

Hữu ích khi cần enroll class mới mà không train lại pipeline encoder.


9. Đánh giá episodic few-shot

import numpy as np
from fewshotkit import ProtoNet, evaluate_n_way_k_shot

x = np.random.randn(300, 64)
y = np.repeat(np.arange(10), 30)

model = ProtoNet(metric="euclidean")
result = evaluate_n_way_k_shot(
    model,
    x,
    y,
    n_way=5,
    k_shot=5,
    n_query=10,
    n_episodes=100,
    random_state=42,
)

print(result["accuracy_mean"], result["accuracy_std"])

Kết quả trả về gồm:

  • accuracy_mean
  • accuracy_std
  • episode_accuracies
  • metadata episode (n_way, k_shot, n_query, n_episodes, random_state)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

fewshotkit-0.0.1.tar.gz (14.3 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

fewshotkit-0.0.1-py3-none-any.whl (15.5 kB view details)

Uploaded Python 3

File details

Details for the file fewshotkit-0.0.1.tar.gz.

File metadata

  • Download URL: fewshotkit-0.0.1.tar.gz
  • Upload date:
  • Size: 14.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.3

File hashes

Hashes for fewshotkit-0.0.1.tar.gz
Algorithm Hash digest
SHA256 c8bf3140e2cc29198a2f97f5dbc2167a375324b58398cf63aed39bf109e902d1
MD5 65934d963fa84ed048a125ed5a56a108
BLAKE2b-256 6505d11dd936b30c9adf567aef7d3282802ba0e5de6d3815561c784250be4cce

See more details on using hashes here.

File details

Details for the file fewshotkit-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: fewshotkit-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 15.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.3

File hashes

Hashes for fewshotkit-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 95e6b8352988e18e907a2f58b65e53f212cc5e995cda3b4a3e8cd14b7534aa4a
MD5 2b83f87988b558647dbacd78a540231f
BLAKE2b-256 0c6b2111b3ef7690a2169cce569f903481c75357287a490006640b216b114ce8

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page