
td-model-store

Store and load ML models (XGBoost, LightGBM, CatBoost) in Treasure Data Plazma via pytd.

Features

  • Simple API: Save and load models with just a few lines of code
  • Multiple Model Types: Supports XGBoost, LightGBM, and CatBoost
  • Chunked Storage: Automatically handles large models by chunking data
  • Session Management: Track model versions with session IDs
  • Model Versioning: Load specific versions or the latest model
  • Seamless Integration: Works directly with Treasure Data via pytd

Installation

Basic Installation

pip install td-model-store

With Specific ML Framework

# For XGBoost
pip install td-model-store[xgboost]

# For LightGBM
pip install td-model-store[lightgbm]

# For CatBoost
pip install td-model-store[catboost]

# Install all frameworks
pip install td-model-store[all]

Quick Start

Environment Setup

First, set your Treasure Data credentials:

export TD_API_KEY="your_api_key_here"

Or pass them directly in your code (see examples below).

Saving a Model

from td_model_store import save_model
import xgboost as xgb

# Train your model
model = xgb.XGBClassifier()
model.fit(X_train, y_train)

# Save to Treasure Data (using environment variable)
metadata = save_model(
    model=model,
    database="my_database",
    table="ml_model_store",
    model_name="my_xgboost_model"
)

# Or pass credentials directly
metadata = save_model(
    model=model,
    database="my_database",
    table="ml_model_store",
    model_name="my_xgboost_model",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"  # or your TD endpoint
)

print(f"Model saved with session_id: {metadata['session_id']}")

Loading a Model

from td_model_store import load_model

# Load the latest version of a model (using environment variable)
model = load_model(
    database="my_database",
    table="ml_model_store",
    model_name="my_xgboost_model"
)

# Or pass credentials directly
model = load_model(
    database="my_database",
    table="ml_model_store",
    model_name="my_xgboost_model",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"  # or your TD endpoint
)

# Use the model for predictions
predictions = model.predict(X_test)
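
To reload the exact version you just saved, pass the session_id from the returned metadata (both parameters are described in the API reference below):

# Pin an exact model version using the session_id returned by save_model
pinned_model = load_model(
    database="my_database",
    table="ml_model_store",
    model_name="my_xgboost_model",
    session_id=metadata["session_id"]
)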

API Reference

save_model()

Save an ML model to a Treasure Data table.

Parameters

  • model (required): Trained model object (XGBClassifier, XGBRegressor, LGBMClassifier, LGBMRegressor, CatBoostClassifier, or CatBoostRegressor)
  • database (required): Target TD database name
  • table (str, default="ml_model_store"): Target table name
  • model_name (str, optional): Name tag for the model. Defaults to "xgboost_model", "lightgbm_model", or "catboost_model" based on model type
  • session_id (int, optional): Unique session identifier. Defaults to current unix timestamp
  • apikey (str, optional): TD API key. If not passed, it is read from the TD_API_KEY or TDX_API_KEY* env vars
  • endpoint (str, default="https://api.treasuredata.com"): TD API endpoint. Use region-specific endpoints:
    • US: https://api.treasuredata.com
    • EU: https://api.eu01.treasuredata.com
    • Japan: https://api.treasuredata.co.jp
    • Asia Pacific: https://api.ap02.treasuredata.com
  • chunk_size (int, default=130000): Max characters per chunk (must be < 131072)

Returns

Dictionary containing:

  • session_id: Session identifier
  • model_name: Model name
  • model_type: Model type (xgboost, lightgbm, or catboost)
  • model_format: Serialization format
  • num_chunks: Number of chunks
  • size_bytes: Model size in bytes

Example

metadata = save_model(
    model=my_model,
    database="production_db",
    table="models",
    model_name="fraud_detector_v1",
    session_id=20240420001,
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)
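
None of the examples set chunk_size; if you ever hit a row-size limit, you can pass a smaller value (65000 below is an arbitrary illustration):

metadata = save_model(
    model=my_model,
    database="production_db",
    table="models",
    model_name="fraud_detector_v1",
    chunk_size=65000  # smaller chunks; must stay below 131072
)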

load_model()

Load an ML model from a Treasure Data table.

Parameters

  • database (required): Source TD database name
  • table (str, default="ml_model_store"): Source table name
  • model_name (str, optional): Filter by model name. If not provided, loads the latest session
  • session_id (int, optional): Load a specific session. If not provided, loads the max session_id (optionally filtered by model_name)
  • apikey (str, optional): TD API key. If not passed, it is read from the TD_API_KEY or TDX_API_KEY* env vars
  • endpoint (str, default="https://api.treasuredata.com"): TD API endpoint. Use region-specific endpoints:
    • US: https://api.treasuredata.com
    • EU: https://api.eu01.treasuredata.com
    • Japan: https://api.treasuredata.co.jp
    • Asia Pacific: https://api.ap02.treasuredata.com

Returns

The reconstructed ML model (XGBoost, LightGBM, or CatBoost)

Example

# Load latest version of a specific model
model = load_model(
    database="production_db",
    table="models",
    model_name="fraud_detector_v1",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

# Load a specific session
model = load_model(
    database="production_db",
    table="models",
    session_id=20240420001,
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

# Load the absolute latest model (any name)
model = load_model(
    database="production_db",
    table="models",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

Complete Examples

XGBoost Example

from td_model_store import save_model, load_model
import xgboost as xgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Prepare data
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Train model
model = xgb.XGBClassifier(n_estimators=100, max_depth=3)
model.fit(X_train, y_train)

# Save to TD
metadata = save_model(
    model=model,
    database="ml_models",
    model_name="iris_classifier",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

# Load from TD
loaded_model = load_model(
    database="ml_models",
    model_name="iris_classifier",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

# Verify
score = loaded_model.score(X_test, y_test)
print(f"Model accuracy: {score:.3f}")

LightGBM Example

from td_model_store import save_model, load_model
import lightgbm as lgb
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

# Prepare data (load_boston was removed from scikit-learn; use the California housing data instead)
X, y = fetch_california_housing(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Train model
model = lgb.LGBMRegressor(n_estimators=100)
model.fit(X_train, y_train)

# Save to TD
metadata = save_model(
    model=model,
    database="ml_models",
    model_name="boston_price_predictor",
    session_id=20240420001,
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

# Load from TD
loaded_model = load_model(
    database="ml_models",
    session_id=20240420001,
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

# Make predictions
predictions = loaded_model.predict(X_test)

CatBoost Example

from td_model_store import save_model, load_model
from catboost import CatBoostClassifier
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

# Prepare data
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Train model
model = CatBoostClassifier(iterations=100, verbose=False)
model.fit(X_train, y_train)

# Save to TD
metadata = save_model(
    model=model,
    database="ml_models",
    model_name="digit_classifier",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

print(f"Saved {metadata['num_chunks']} chunks, {metadata['size_bytes']} bytes")

# Load from TD
loaded_model = load_model(
    database="ml_models",
    model_name="digit_classifier",
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com"
)

# Make predictions
predictions = loaded_model.predict(X_test)

Configuration

API Key

You can provide your Treasure Data API key in two ways:

Option 1: Environment Variable (Recommended)

export TD_API_KEY="your_api_key_here"

Then use the library without passing the apikey parameter:

save_model(model=model, database="db")
load_model(database="db", model_name="my_model")

Option 2: Direct Parameter

save_model(model=model, database="db", apikey="your_api_key")
load_model(database="db", model_name="my_model", apikey="your_api_key")
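
A common variant of Option 2 is to read the key from the environment explicitly, so it never appears in source code (plain os.getenv, nothing library-specific):

import os

# Read the TD API key from the environment instead of hardcoding it
apikey = os.getenv("TD_API_KEY")
save_model(model=model, database="db", apikey=apikey)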

Endpoint (Regional)

Specify the correct endpoint for your Treasure Data region:

# US region (default)
endpoint = "https://api.treasuredata.com"

# EU region
endpoint = "https://api.eu01.treasuredata.com"

# Japan region
endpoint = "https://api.treasuredata.co.jp"

# Asia Pacific region
endpoint = "https://api.ap02.treasuredata.com"

Example with region:

save_model(
    model=model,
    database="db",
    apikey="your_api_key",
    endpoint="https://api.eu01.treasuredata.com"  # EU region
)
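
If you deploy across regions, a small lookup table keeps the endpoint choice in one place (TD_ENDPOINTS is an illustrative helper, not part of the library):

# Illustrative mapping of regions to the endpoints listed above
TD_ENDPOINTS = {
    "us": "https://api.treasuredata.com",
    "eu": "https://api.eu01.treasuredata.com",
    "jp": "https://api.treasuredata.co.jp",
    "ap": "https://api.ap02.treasuredata.com",
}

save_model(
    model=model,
    database="db",
    apikey="your_api_key",
    endpoint=TD_ENDPOINTS["eu"]
)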

Model Storage Schema

Models are stored in the following table schema:

Column        Type    Description
time          int     Unix timestamp
session_id    int     Unique session identifier
model_name    string  Model name tag
model_type    string  Model type (xgboost, lightgbm, catboost)
model_format  string  Serialization format (ubj, txt, cbm)
chunk_index   int     Chunk index (0-based)
total_chunks  int     Total number of chunks
chunk_data    string  Base64-encoded model data chunk
created_at    string  ISO timestamp
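
Because models are stored as ordinary table rows, you can inspect stored versions with a plain query through pytd itself (pytd.Client below is standard pytd usage, independent of td-model-store):

import pytd

# List stored sessions via each session's first chunk row (chunk_index = 0)
client = pytd.Client(
    apikey="your_td_api_key",
    endpoint="https://api.treasuredata.com",
    database="ml_models"
)
result = client.query("""
    SELECT session_id, model_name, model_type, total_chunks
    FROM ml_model_store
    WHERE chunk_index = 0
    ORDER BY session_id DESC
""")
for row in result["data"]:
    print(row)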

How It Works

  1. Serialization: Models are serialized to their native format:

    • XGBoost: .ubj (Universal Binary JSON)
    • LightGBM: .txt (text format)
    • CatBoost: .cbm (CatBoost binary)
  2. Chunking: Large models are split into chunks (default 130,000 characters per chunk) to comply with Treasure Data's bulk import limits (see the sketch after this list)

  3. Storage: Chunks are stored as base64-encoded strings in a Treasure Data table via pytd's bulk_import

  4. Retrieval: When loading, chunks are reassembled, decoded, and deserialized back into the original model object
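
A minimal sketch of steps 2-4, assuming only that chunks are base64 text split at a fixed size (illustrative, not the library's internal code):

import base64

CHUNK_SIZE = 130_000  # matches the default chunk_size

# Chunking: base64-encode the serialized model bytes and split the string
raw = b"serialized-model-bytes"  # stand-in for a real .ubj/.txt/.cbm payload
encoded = base64.b64encode(raw).decode("ascii")
chunks = [encoded[i:i + CHUNK_SIZE] for i in range(0, len(encoded), CHUNK_SIZE)]

# Retrieval: concatenate chunks in chunk_index order, then decode
restored = base64.b64decode("".join(chunks))
assert restored == raw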

Requirements

  • Python >= 3.8
  • pytd >= 1.0.0
  • pandas >= 1.0.0
  • At least one of: xgboost, lightgbm, or catboost

Logging

The library uses Python's standard logging module. Enable logging to see save/load operations:

import logging
logging.basicConfig(level=logging.INFO)

Example output:

INFO:td_model_store.model_store:Saved model 'my_model' (type=xgboost) to ml_models.ml_model_store | session_id=1713628800 | chunks=5 | size=12345 bytes
INFO:td_model_store.model_loader:Loaded model 'my_model' (type=xgboost) from ml_models.ml_model_store | session_id=1713628800 | chunks=5 | size=12345 bytes

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

Support

For issues and questions, please open an issue on GitHub or contact the maintainers.
