


Orient Express

A library that speeds up deploying models to Vertex AI directly from Colab notebooks.


Installation

pip install orient_express

Example

Train Model

Train a regular model. In the example below, it's an XGBoost model trained on the Titanic dataset.

# Import necessary libraries
import seaborn as sns  # used only to load the Titanic sample dataset
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer

# Load the Titanic dataset
data = sns.load_dataset('titanic').dropna(subset=['survived'])  # Dropping rows with missing target labels

# Select features and target
X = data[['pclass', 'sex', 'age', 'sibsp', 'parch', 'fare', 'embarked']]
y = data['survived']

# Define preprocessing for numeric columns (impute missing values and scale features)
numeric_features = ['age', 'fare', 'sibsp', 'parch']
numeric_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler())
])

# Define preprocessing for categorical columns (impute missing values and one-hot encode)
categorical_features = ['pclass', 'sex', 'embarked']
categorical_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='most_frequent')),
    ('onehot', OneHotEncoder(handle_unknown='ignore'))
])

# Combine preprocessing steps
preprocessor = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, numeric_features),
        ('cat', categorical_transformer, categorical_features)
    ])

# Create a pipeline that first transforms the data, then trains an XGBoost model
model = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('classifier', xgb.XGBClassifier(eval_metric='logloss'))
])

# Split the dataset into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

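# Train the model
model.fit(X_train, y_train)

# Evaluate on the held-out test set
y_pred = model.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))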

Upload Model To Model Registry

from orient_express import ModelExpress  # import path assumed from the package name

model_wrapper = ModelExpress(model=model,
                             project_name='my-project-name',
                             region='us-central1',
                             bucket_name='my-artifacts-bucket',
                             model_name='titanic')
model_wrapper.upload()
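The bucket_name parameter suggests that the fitted pipeline is staged in Cloud Storage before it is registered. As a rough illustration of that step, here is a minimal sketch assuming a pickle-based artifact and a `<model_name>/<version>/model.pkl` object layout. Both the layout and the helper names (`serialize_model`, `artifact_uri`) are hypothetical and not taken from ModelExpress.

```python
import pickle


def serialize_model(model) -> bytes:
    """Serialize any picklable object (e.g. a fitted sklearn Pipeline) to bytes.

    Hypothetical helper: ModelExpress may use a different format internally.
    """
    return pickle.dumps(model)


def artifact_uri(bucket_name: str, model_name: str, version: int) -> str:
    """Build a Cloud Storage URI for a model artifact.

    Hypothetical layout: gs://<bucket>/<model_name>/<version>/model.pkl
    """
    return f"gs://{bucket_name}/{model_name}/{version}/model.pkl"


# Stand-in object instead of a real fitted pipeline, so the sketch runs anywhere
payload = serialize_model({"weights": [0.1, 0.2]})
restored = pickle.loads(payload)  # round-trip to confirm the bytes are loadable
uri = artifact_uri("my-artifacts-bucket", "titanic", 1)
print(uri)
```

Keep in mind that pickle only restores objects whose classes can be imported, so whatever environment loads the artifact needs the same libraries (here scikit-learn and XGBoost) installed.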

Local Inference (Without Online Prediction Endpoint)

The following code downloads the latest model version from the model registry and runs inference locally.

import pandas as pd

# Create the input DataFrame
titanic_data = {
    "pclass": [1],          # Passenger class (1st, 2nd, 3rd)
    "sex": ["female"],      # Gender
    "age": [29],            # Age
    "sibsp": [0],           # Number of siblings/spouses aboard
    "parch": [0],           # Number of parents/children aboard
    "fare": [100.0],        # Ticket fare
    "embarked": ["S"]       # Port of Embarkation (C = Cherbourg, Q = Queenstown, S = Southampton)
}
input_df = pd.DataFrame(titanic_data)

# Initialize the model wrapper
model_wrapper = ModelExpress(project_name='my-project-name',
                             region='us-central1',
                             model_name='titanic')

# Run inference locally
# It will download the most recent version from the model registry automatically
model_wrapper.local_predict(input_df)
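Because the preprocessing step selects columns by name, the input passed to local_predict needs every feature the pipeline was trained on. A small pre-flight check like the one below can catch a misspelled or missing column before any model is downloaded. `check_columns` is a hypothetical helper, not part of ModelExpress; it works on the same column-oriented dict used to build input_df.

```python
# Feature columns the Titanic pipeline was trained on (from the training example)
FEATURES = ["pclass", "sex", "age", "sibsp", "parch", "fare", "embarked"]


def check_columns(data: dict, features=FEATURES) -> int:
    """Hypothetical pre-flight check for a column-oriented input dict.

    Raises ValueError if a training feature is missing or if the feature
    columns have different lengths; otherwise returns the number of rows.
    """
    missing = [f for f in features if f not in data]
    if missing:
        raise ValueError(f"missing feature columns: {missing}")
    lengths = {len(data[f]) for f in features}
    if len(lengths) != 1:
        raise ValueError(f"columns have different lengths: {sorted(lengths)}")
    return lengths.pop()


titanic_data = {
    "pclass": [1], "sex": ["female"], "age": [29], "sibsp": [0],
    "parch": [0], "fare": [100.0], "embarked": ["S"],
}
print(check_columns(titanic_data))
```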

Pin Model Version

In many cases, a pipeline should be pinned to a specific model version so that the model changes only when you update it explicitly. To do this, pass the model_version parameter when instantiating the ModelExpress wrapper.

# Initialize the model wrapper pinned to version 11
model_wrapper = ModelExpress(project_name='my-project-name',
                             region='us-central1',
                             model_name='titanic',
                             model_version=11)
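The difference between the default wrapper and a pinned one comes down to which registry version gets loaded. The sketch below models that choice with a hypothetical `resolve_version` helper over integer version IDs. With no pin, it takes the newest version. With a pin, it requires that exact version. It illustrates the behavior described above and is not ModelExpress's actual implementation.

```python
from typing import Iterable, Optional


def resolve_version(available: Iterable[int], pinned: Optional[int] = None) -> int:
    """Pick which registry version to load (hypothetical helper).

    With no pin, the highest-numbered version wins, mirroring the
    "most recent version" default; with a pin, that exact version is required.
    """
    versions = sorted(set(available))
    if not versions:
        raise LookupError("no versions registered")
    if pinned is None:
        return versions[-1]
    if pinned not in versions:
        raise LookupError(f"version {pinned} not found; available: {versions}")
    return pinned


registry = [9, 10, 11, 12]
print(resolve_version(registry))      # latest version
print(resolve_version(registry, 11))  # pinned version
```

With a pin in place, uploading a version 13 later would not change what the pinned pipeline loads. An unpinned wrapper, by contrast, would pick the new version up the next time it resolves the latest one.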

Remote Inference (With Online Prediction Endpoint)

First, make sure the model is deployed to an online prediction endpoint:

model_wrapper = ModelExpress(model=model,
                             project_name='my-project-name',
                             region='us-central1',
                             bucket_name='my-artifacts-bucket',
                             model_name='titanic')

# Upload a new version to the registry and deploy it to the endpoint
model_wrapper.deploy()

Then run inference with the remote_predict method. It calls the endpoint remotely instead of downloading the model locally.

titanic_data = {
    "pclass": [1],             # Passenger class (1st, 2nd, 3rd)
    "sex": ["female"],         # Gender
    "age": [29],               # Age
    "sibsp": [0],              # Number of siblings/spouses aboard
    "parch": [0],              # Number of parents/children aboard
    "fare": [100.0],           # Ticket fare
    "embarked": ["S"]          # Port of Embarkation (C = Cherbourg, Q = Queenstown, S = Southampton)
}
df = pd.DataFrame(titanic_data)

model_wrapper.remote_predict(df)
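Vertex AI online prediction requests carry a JSON body with an `instances` list containing one entry per example. The sketch below shows how a column-oriented dict like titanic_data can be turned into such a body using only the standard library. Encoding each row as a dict keyed by column name is an assumption about what the deployed model's serving code expects; ModelExpress may encode rows differently. `to_instances` is a hypothetical helper.

```python
import json


def to_instances(columns: dict) -> list:
    """Convert column-oriented data (column -> list of values) into one dict per row.

    Hypothetical helper; assumes every column list has the same length.
    """
    names = list(columns)
    n_rows = len(columns[names[0]]) if names else 0
    return [{name: columns[name][i] for name in names} for i in range(n_rows)]


titanic_data = {
    "pclass": [1], "sex": ["female"], "age": [29], "sibsp": [0],
    "parch": [0], "fare": [100.0], "embarked": ["S"],
}

# Request body in the {"instances": [...]} shape used by Vertex AI online prediction
body = json.dumps({"instances": to_instances(titanic_data)})
print(body)
```

With pandas, `df.to_dict(orient="records")` produces an equivalent list of row dicts from the DataFrame built above.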

Download files

Source Distribution

orient_express-0.1.2.tar.gz (3.1 kB)

Uploaded Source

Built Distribution

orient_express-0.1.2-py3-none-any.whl (3.0 kB)

Uploaded Python 3

File details

Details for the file orient_express-0.1.2.tar.gz.

File metadata

  • Download URL: orient_express-0.1.2.tar.gz
  • Upload date:
  • Size: 3.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for orient_express-0.1.2.tar.gz
Algorithm Hash digest
SHA256 1d84561e90869c2b38faed0fa2405243fe11668c8df78a63845d553924061bab
MD5 2769d51967f4d4d8f4a3d8b0dd3fe307
BLAKE2b-256 62fe481538b9a86c56d79b6512058f07ad332529e673520b74e33b6631ac191f

File details

Details for the file orient_express-0.1.2-py3-none-any.whl.

File metadata

File hashes

Hashes for orient_express-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 732c002571eb7e1d5f03512b7063699216c9c08692c3799a9a941f0ca8701356
MD5 ac3140a376993f09c894e33fc8344e01
BLAKE2b-256 cecea98299afc409e6fee50a3e5aba3f6cae9b0c2a61486ce566e0cce0e24a2d
