
Store and load ML models (XGBoost, LightGBM, CatBoost) in Treasure Data Plazma via pytd

Project description

td-model-store

Store and load ML models (XGBoost, LightGBM, CatBoost) in a Treasure Data database via pytd.

Features

  • Simple API: Save and load models with just a few lines of code
  • Multiple Model Types: Supports XGBoost, LightGBM, and CatBoost
  • Chunked Storage: Automatically handles large models by chunking data
  • Session Management: Track model versions with session IDs
  • Model Versioning: Load specific versions or the latest model
  • Seamless Integration: Works directly with Treasure Data via pytd

Installation

Basic Installation

pip install td-model-store

td-model-store needs at least one supported model library at runtime, so install one alongside it, for example:

pip install td-model-store xgboost

Quick Start

Saving a Model

from td_model_store import save_model
import xgboost as xgb

# Train your model
model = xgb.XGBClassifier()
model.fit(X_train, y_train)

# Save to Treasure Data
metadata = save_model(
    model=model,
    database="my_database",
    table="ml_model_store",
    model_name="my_xgboost_model",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

print(f"Model saved with session_id: {metadata['session_id']}")

Loading a Model

from td_model_store import load_model

# Load the latest version of a model
model = load_model(
    database="my_database",
    table="ml_model_store",
    model_name="my_xgboost_model",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

# Use the model for predictions
predictions = model.predict(X_test)

API Reference

save_model()

Save an ML model to a Treasure Data table.

Parameters

  • model (required): Trained model object (XGBClassifier, XGBRegressor, LGBMClassifier, LGBMRegressor, CatBoostClassifier, or CatBoostRegressor)
  • database (required): Target TD database name
  • table (str, default="ml_model_store"): Target table name
  • model_name (str, optional): Name tag for the model. Defaults to "xgboost_model", "lightgbm_model", or "catboost_model" based on model type
  • session_id (int, optional): Unique session identifier. Defaults to current unix timestamp
  • apikey (str, required): TD API key. Can also be set via TD_API_KEY or TDX_API_KEY* env vars
  • endpoint (str, default="https://api.treasuredata.com"): TD API endpoint. Use region-specific endpoints:
    • US: https://api.treasuredata.com
    • EU: https://api.eu01.treasuredata.com
    • Japan: https://api.treasuredata.co.jp
    • Asia Pacific: https://api.ap02.treasuredata.com
  • chunk_size (int, default=130000): Max characters per chunk (must be < 131072)
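
To see why chunk_size must stay below 131072, here is a minimal sketch of the chunk arithmetic; `estimate_chunks` is a hypothetical helper for illustration, not part of the library:

```python
import base64
import math

def estimate_chunks(model_bytes: bytes, chunk_size: int = 130_000) -> int:
    """Estimate how many table rows a serialized model occupies.

    Treasure Data caps the size of a string cell, so each base64 chunk
    must stay under 131072 characters; the 130000 default leaves headroom.
    """
    if not chunk_size < 131_072:
        raise ValueError("chunk_size must be < 131072")
    encoded = base64.b64encode(model_bytes).decode("ascii")
    return math.ceil(len(encoded) / chunk_size)

# Base64 inflates payloads by ~4/3, so a 1 MB model needs 11 chunks.
print(estimate_chunks(b"\x00" * 1_000_000))  # → 11
```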

Returns

Dictionary containing:

  • session_id: Session identifier
  • model_name: Model name
  • model_type: Model type (xgboost, lightgbm, or catboost)
  • model_format: Serialization format
  • num_chunks: Number of chunks
  • size_bytes: Model size in bytes

Example

metadata = save_model(
    model=my_model,
    database="production_db",
    table="models",
    model_name="fraud_detector_v1",
    session_id=20240420001,
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

load_model()

Load an ML model from a Treasure Data table.

Parameters

  • database (required): Source TD database name
  • table (str, default="ml_model_store"): Source table name
  • model_name (str, optional): Filter by model name. If not provided, loads the latest session
  • session_id (int, optional): Load a specific session. If not provided, loads the max session_id (optionally filtered by model_name)
  • apikey (str, required): TD API key. Can also be set via TD_API_KEY or TDX_API_KEY* env vars
  • endpoint (str, default="https://api.treasuredata.com"): TD API endpoint. Use region-specific endpoints:
    • US: https://api.treasuredata.com
    • EU: https://api.eu01.treasuredata.com
    • Japan: https://api.treasuredata.co.jp
    • Asia Pacific: https://api.ap02.treasuredata.com
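
The version-resolution rules above can be sketched as plain Python over hypothetical stored rows. This mirrors the documented behavior only; `resolve_session` and the row dicts are illustrative, not the library's internals:

```python
def resolve_session(rows, model_name=None, session_id=None):
    """Pick which stored session to load, per load_model()'s rules."""
    if session_id is not None:
        candidates = [r for r in rows if r["session_id"] == session_id]
    else:
        candidates = [r for r in rows
                      if model_name is None or r["model_name"] == model_name]
    if not candidates:
        raise LookupError("no matching model found")
    return max(r["session_id"] for r in candidates)

rows = [
    {"session_id": 100, "model_name": "fraud_detector_v1"},
    {"session_id": 200, "model_name": "churn_model"},
    {"session_id": 150, "model_name": "fraud_detector_v1"},
]
print(resolve_session(rows, model_name="fraud_detector_v1"))  # → 150
print(resolve_session(rows))                                  # → 200
```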

Returns

The reconstructed ML model (XGBoost, LightGBM, or CatBoost)

Example

# Load latest version of a specific model
model = load_model(
    database="production_db",
    table="models",
    model_name="fraud_detector_v1",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

# Load a specific session
model = load_model(
    database="production_db",
    table="models",
    session_id=20240420001,
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

# Load the absolute latest model (any name)
model = load_model(
    database="production_db",
    table="models",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

Complete Examples

XGBoost Example

from td_model_store import save_model, load_model
import xgboost as xgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Prepare data
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Train model
model = xgb.XGBClassifier(n_estimators=100, max_depth=3)
model.fit(X_train, y_train)

# Save to TD
metadata = save_model(
    model=model,
    database="ml_models",
    model_name="iris_classifier",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

# Load from TD
loaded_model = load_model(
    database="ml_models",
    model_name="iris_classifier",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

# Verify
score = loaded_model.score(X_test, y_test)
print(f"Model accuracy: {score:.3f}")

LightGBM Example

from td_model_store import save_model, load_model
import lightgbm as lgb
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

# Prepare data (load_boston was removed in scikit-learn 1.2)
X, y = fetch_california_housing(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Train model
model = lgb.LGBMRegressor(n_estimators=100)
model.fit(X_train, y_train)

# Save to TD
metadata = save_model(
    model=model,
    database="ml_models",
    model_name="housing_price_predictor",
    session_id=20240420001,
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

# Load from TD
loaded_model = load_model(
    database="ml_models",
    session_id=20240420001,
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

# Make predictions
predictions = loaded_model.predict(X_test)

CatBoost Example

from td_model_store import save_model, load_model
from catboost import CatBoostClassifier
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

# Prepare data
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Train model
model = CatBoostClassifier(iterations=100, verbose=False)
model.fit(X_train, y_train)

# Save to TD
metadata = save_model(
    model=model,
    database="ml_models",
    model_name="digit_classifier",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

print(f"Saved {metadata['num_chunks']} chunks, {metadata['size_bytes']} bytes")

# Load from TD
loaded_model = load_model(
    database="ml_models",
    model_name="digit_classifier",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

# Make predictions
predictions = loaded_model.predict(X_test)

Configuration

API Key and Endpoint

The apikey parameter is required for all operations. Specify the correct endpoint for your Treasure Data region:

# US region (default)
save_model(
    model=model,
    database="my_database",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

# EU region
save_model(
    model=model,
    database="my_database",
    apikey="your_td_api_key",
    endpoint="https://api.eu01.treasuredata.com"
)

# Japan region
save_model(
    model=model,
    database="my_database",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.co.jp"
)

# Asia Pacific region
save_model(
    model=model,
    database="my_database",
    apikey="your_td_api_key",
    endpoint="https://api.ap02.treasuredata.com"
)

Model Storage Schema

Models are stored in the following table schema:

Column        Type    Description
------------  ------  -----------------------------------------
time          int     Unix timestamp
session_id    int     Unique session identifier
model_name    string  Model name tag
model_type    string  Model type (xgboost, lightgbm, catboost)
model_format  string  Serialization format (ubj, txt, cbm)
chunk_index   int     Chunk index (0-based)
total_chunks  int     Total number of chunks
chunk_data    string  Base64-encoded model data chunk
created_at    string  ISO timestamp
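
Given that schema, rebuilding a model payload on load amounts to sorting a session's rows by chunk_index and base64-decoding their concatenation. A simplified sketch follows; the dict keys match the columns above, but `reassemble` itself is illustrative, not the library's code:

```python
import base64

def reassemble(rows) -> bytes:
    """Rebuild the serialized model from the chunk rows of one session."""
    ordered = sorted(rows, key=lambda r: r["chunk_index"])
    assert len(ordered) == ordered[0]["total_chunks"], "missing chunks"
    return base64.b64decode("".join(r["chunk_data"] for r in ordered))

# Round-trip a small payload through two synthetic chunk rows.
payload = b"example model bytes"
encoded = base64.b64encode(payload).decode("ascii")
rows = [
    {"chunk_index": i, "total_chunks": 2, "chunk_data": encoded[i * 14:(i + 1) * 14]}
    for i in range(2)
]
assert reassemble(rows) == payload
```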

How It Works

  1. Serialization: Models are serialized to their native format:

    • XGBoost: .ubj (Universal Binary JSON)
    • LightGBM: .txt (text format)
    • CatBoost: .cbm (CatBoost binary)
  2. Chunking: Large models are split into chunks (default 130,000 characters each, matching chunk_size) to stay within Treasure Data's bulk import limits

  3. Storage: Chunks are stored as base64-encoded strings in a Treasure Data table via pytd's bulk_import

  4. Retrieval: When loading, chunks are reassembled, decoded, and deserialized back into the original model object
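
Steps 1-3 of the save path can be sketched end to end, with serialization stubbed out as raw bytes. The dict keys follow the storage schema; `to_chunk_rows` and the values used are illustrative assumptions, not the library's implementation:

```python
import base64
import time

def to_chunk_rows(model_bytes, model_name, model_type, chunk_size=130_000):
    """Split an encoded model into schema-shaped rows for bulk import."""
    encoded = base64.b64encode(model_bytes).decode("ascii")
    chunks = [encoded[i:i + chunk_size]
              for i in range(0, len(encoded), chunk_size)]
    session_id = int(time.time())  # default session_id: current Unix time
    return [
        {
            "session_id": session_id,
            "model_name": model_name,
            "model_type": model_type,
            "chunk_index": i,
            "total_chunks": len(chunks),
            "chunk_data": chunk,
        }
        for i, chunk in enumerate(chunks)
    ]

# 300 KB of raw bytes → 400,000 base64 chars → 4 rows at 100,000 each.
rows = to_chunk_rows(b"\x00" * 300_000, "demo_model", "xgboost",
                     chunk_size=100_000)
print(len(rows))  # → 4
```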

Requirements

  • Python >= 3.8
  • pytd >= 1.0.0
  • pandas >= 1.0.0
  • At least one of: xgboost, lightgbm, or catboost

Logging

The library uses Python's standard logging module. Enable logging to see save/load operations:

import logging
logging.basicConfig(level=logging.INFO)

Example output:

INFO:td_model_store.model_store:Saved model 'my_model' (type=xgboost) to ml_models.ml_model_store | session_id=1713628800 | chunks=5 | size=12345 bytes
INFO:td_model_store.model_loader:Loaded model 'my_model' (type=xgboost) from ml_models.ml_model_store | session_id=1713628800 | chunks=5 | size=12345 bytes

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

Support

For issues and questions, please open an issue on GitHub or contact the maintainers.

Download files

Source Distribution

td_model_store-0.1.5.tar.gz (10.5 kB)

Built Distribution

td_model_store-0.1.5-py3-none-any.whl (9.2 kB)

File details

Details for the file td_model_store-0.1.5.tar.gz.

File metadata

  • Download URL: td_model_store-0.1.5.tar.gz
  • Size: 10.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.9

File hashes

Hashes for td_model_store-0.1.5.tar.gz

Algorithm    Hash digest
SHA256       ece8cc4888bbcc6837d74e2ac3aa43854149cdf7e58f555261f68f8d4d15ee72
MD5          69ee50baaa10304f3aad28fe3cb4f7da
BLAKE2b-256  c55d4bf81b421b707c0d9c733c08483a4f86a43bf6007d8f6b1e5f57f73275df

File details

Details for the file td_model_store-0.1.5-py3-none-any.whl.

File metadata

  • Download URL: td_model_store-0.1.5-py3-none-any.whl
  • Size: 9.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.9

File hashes

Hashes for td_model_store-0.1.5-py3-none-any.whl

Algorithm    Hash digest
SHA256       469cf1ec7b6323608733eb70164c0a368350ff1a79f2cf6f5dbc3d968a7a6d54
MD5          a23b22165d32b2c0f617c16ccce097ac
BLAKE2b-256  a8b06690a182027b190e6ba57d7d44fa9cef25b3e6e169ef1f4f9ba6e78e2290
