Atlas artifact bundle SDK — pushes a PyTorch model + validation data to MLflow in the format atlas (TRT/Triton optimizer) consumes.
ml_atlas_sdk
Pushes a PyTorch model + validation data to MLflow in the format atlas (the TensorRT/Triton optimizer) consumes.
Why
Atlas needs 9 specific input artifacts on GCS in a strict format. Without this SDK, data scientists hand-write ONNX exports, val_inputs.pt, metadata.json, etc. — and silently get them wrong (FP16/GPU-generated reference, bad constant-folding, missing dynamic axes). Failures only surface inside atlas hours later.
ml_atlas_sdk consolidates all 9 artifacts behind one push_to_mlflow(bundle, run)
call, validates the model+data locally before upload, and runs an
ORT round-trip self-check that catches export bugs before they ship.
Install

    pip install ml_atlas_sdk
    # optional extras
    pip install ml_atlas_sdk[architecture]  # LLM-driven architecture.md
    pip install ml_atlas_sdk[gcs]           # for atlas reverse-side resolver

LLM architecture autogen (`architecture_doc="auto"`, the default) requires both the `[architecture]` extra and `ANTHROPIC_API_KEY` in the environment when `push_to_mlflow` runs. If either is missing, the SDK silently falls back to a structural+ONNX-summary markdown (PyTorch `repr` + op histogram); it never raises.
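The two preconditions above can be checked before a push, so the silent fallback does not come as a surprise. The following is a minimal sketch, not part of the SDK: the helper name `llm_autogen_ready` is hypothetical, and it assumes the `[architecture]` extra installs a top-level package called `anthropic`, which this README does not confirm.

```python
import importlib.util
import os


def llm_autogen_ready(env=None, sdk_module="anthropic"):
    """Return (ready, reasons) for LLM-driven architecture.md generation.

    Assumption: the [architecture] extra provides the top-level module named
    by `sdk_module`. The README does not name the dependency; adjust as needed.
    """
    env = os.environ if env is None else env
    reasons = []
    # find_spec returns None when a top-level module cannot be imported.
    if importlib.util.find_spec(sdk_module) is None:
        reasons.append(f"module {sdk_module!r} is not importable")
    # An empty string counts as unset here.
    if not env.get("ANTHROPIC_API_KEY"):
        reasons.append("ANTHROPIC_API_KEY is unset or empty")
    return (not reasons, reasons)


if __name__ == "__main__":
    ready, reasons = llm_autogen_ready()
    print("LLM autogen available" if ready else f"fallback expected: {reasons}")
```

If the check reports a problem, the push should still succeed according to the description above; only the content of `architecture.md` changes.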
Minimal usage

    import mlflow
    import torch

    from ml_atlas_sdk import (
        AtlasArtifactBundle,
        InputBinding,
        InputSchema,
        ModelMetadata,
        push_to_mlflow,
    )

    bundle = AtlasArtifactBundle(
        model_name="my_model",
        model=my_model,  # nn.Module
        metadata=ModelMetadata(
            model_name="my_model",
            input_names=["input__0", "input__1"],
            output_names=["output__0"],
            input_shapes={"input__0": [-1, 4, 8], "input__1": [-1, 16]},
            output_shape={"output__0": [-1, 3]},
            dynamic_axes={  # paste from your torch.onnx.export call
                "input__0": {0: "batch_size"},
                "input__1": {0: "batch_size"},
                "output__0": {0: "batch_size"},
            },
            onnx_opset=17,
            torch_version=torch.__version__,
            architecture="MyModel",
            task_type="multitask_ranking",
        ),
        input_schema=InputSchema(inputs=[  # MANDATORY: predator-shaped feature bindings
            InputBinding(name="input__0", features=["ONLINE_FEATURE|user:age_bucket"]),
            InputBinding(name="input__1", features=["OFFLINE_FEATURE|user_segments"]),
        ]),
        val_inputs_df=val_in_df,  # one column per ONNX input
        export_batch_size=64,     # rows to slice for the ONNX trace sample
    )

    mlflow.set_tracking_uri("...")
    mlflow.set_experiment("ranking_models")
    with mlflow.start_run() as run:
        run_id = push_to_mlflow(bundle, run)

That's it. The SDK:
- forces `model.eval().cpu().float()` + `torch.no_grad()` for export
- generates canonical val_outputs by running the model on val_inputs
- exports ONNX with proper dynamic_axes
- runs ORT round-trip self-check (with 5-step escalation if it fails)
- uploads the required atlas artifacts under `atlas/` in the run (including `input_schema.json`, the predator-shaped feature bindings, structurally validated) plus any optional ones the bundle supplied (`traced.pt`, `calibration_inputs.pt`, `val_outputs_user.pt`, `drift_report.json`, `architecture.md`)
- logs all params/tags atlas needs
See PLAN.md for the full design.
Full usage
Every field on AtlasArtifactBundle (and the nested ModelMetadata) shown
with realistic values, including a non-batch dynamic axis (seq_len). Use
this when you want INT8 calibration, drift comparison against your own
reference outputs, or to override SDK defaults.

    from pathlib import Path

    import mlflow
    import pandas as pd
    import torch

    from ml_atlas_sdk import (
        AtlasArtifactBundle,
        InputBinding,
        InputSchema,
        ModelMetadata,
        push_to_mlflow,
    )

    # ── 1. Metadata: ONNX-side schema.
    # In input_shapes / output_shape: use -1 for every dynamic dim and
    # positive ints for static dims.
    # In dynamic_axes: paste your torch.onnx.export `dynamic_axes` dict
    # verbatim (same format, {io_name: {axis_idx: axis_name}}). Static
    # axes are NOT listed here; axis 0 (batch) MUST be listed for every
    # input AND every output.
    metadata = ModelMetadata(
        model_name="my_model",
        input_names=["tokens", "dense"],
        output_names=["logits"],
        input_shapes={
            "tokens": [-1, -1],  # axes 0 and 1 both dynamic; names below
            "dense": [-1, 16],   # axis 1 static (16)
        },
        output_shape={
            "logits": [-1, 3],   # axis 1 static (3)
        },
        dynamic_axes={
            "tokens": {0: "batch_size", 1: "seq_len"},
            "dense": {0: "batch_size"},
            "logits": {0: "batch_size"},
        },
        onnx_opset=17,  # must be >= 17 (LayerNorm/SDPA precision)
        torch_version=torch.__version__,
        architecture="MyModel",
        task_type="multitask_ranking",
    )

    # ── 2. Input schema: predator-shaped feature bindings (MANDATORY).
    # One binding per ONNX input; each feature is "TYPE|name" where
    # TYPE is one of ALLOWED_FEATURE_TYPES. Validation is structural:
    # the SDK does not reach out to OFS / offline APIs.
    input_schema = InputSchema(inputs=[
        InputBinding(name="tokens", features=["ONLINE_FEATURE|user:tokens"]),
        InputBinding(name="dense", features=["OFFLINE_FEATURE|user_dense"]),
    ])
    # Equivalent accepted forms:
    # input_schema = {"inputs": [...]}           # dict
    # input_schema = Path("input_schema.json")   # file path on disk
    # input_schema = "input_schema.json"         # str path on disk

    # ── 3. Validation data: one column per ONNX input, ndarray cells.
    val_in_df: pd.DataFrame = ...  # cols: "tokens", "dense"

    # Optional: your own reference outputs. SDK still generates canonical
    # val_outputs internally; yours is used ONLY for the drift report
    # (val_outputs_user.pt is uploaded for debugging, never as the atlas artifact).
    val_out_df: pd.DataFrame = ...  # cols: "logits"

    # Optional: calibration sample for INT8. >= 1000 rows recommended.
    calib_in_df: pd.DataFrame = ...  # cols: "tokens", "dense"

    bundle = AtlasArtifactBundle(
        model_name="my_model",
        model=my_model,                     # nn.Module; SDK forces .eval().cpu().float()
        metadata=metadata,
        input_schema=input_schema,          # mandatory; see step 2
        val_inputs_df=val_in_df,
        val_outputs_df=val_out_df,          # optional; drives drift report
        export_batch_size=64,               # rows to slice for the ONNX trace sample (default 64)
        calibration_inputs_df=calib_in_df,  # optional; needed only for INT8 path in atlas
        architecture_doc="auto",            # "auto" (LLM-generated), a str/Path, or None to skip
        skip_sdk_validations=False,         # see Bypass section below
        drift_tolerance_mse=1e-3,           # raises drift error if user vs canonical mse exceeds this
        drift_tolerance_max_abs=5e-2,       # raises drift error if max|user-canonical| exceeds this
        bundle_schema_version="1",
    )

    mlflow.set_tracking_uri("...")
    mlflow.set_experiment("ranking_models")
    with mlflow.start_run() as run:
        run_id = push_to_mlflow(bundle, run)

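The rules stated in the metadata comments (every `-1` dim needs a named axis, static axes are not listed, and axis 0 is dynamic for every input and output) can be checked before building the bundle. This is a minimal sketch of such a pre-check under those stated rules; `check_dynamic_axes` is a hypothetical helper, not an SDK function, and the SDK's own validation may differ.

```python
def check_dynamic_axes(input_shapes, output_shape, dynamic_axes):
    """Return a list of problems; an empty list means shapes and axes agree."""
    problems = []
    all_shapes = {**input_shapes, **output_shape}
    for name, shape in all_shapes.items():
        named = dynamic_axes.get(name, {})
        for axis, dim in enumerate(shape):
            if axis == 0:
                # Batch axis: must be dynamic and named for every input and output.
                if dim != -1 or 0 not in named:
                    problems.append(f"{name}: axis 0 (batch) must be -1 and listed in dynamic_axes")
            elif dim == -1 and axis not in named:
                problems.append(f"{name}: axis {axis} is -1 but has no name in dynamic_axes")
            elif dim != -1 and axis in named:
                problems.append(f"{name}: axis {axis} is static ({dim}) but listed in dynamic_axes")
    # Names that appear only in dynamic_axes are likely typos.
    for name in dynamic_axes:
        if name not in all_shapes:
            problems.append(f"{name}: listed in dynamic_axes but not declared as an input or output")
    return problems


if __name__ == "__main__":
    issues = check_dynamic_axes(
        {"tokens": [-1, -1], "dense": [-1, 16]},
        {"logits": [-1, 3]},
        {
            "tokens": {0: "batch_size", 1: "seq_len"},
            "dense": {0: "batch_size"},
            "logits": {0: "batch_size"},
        },
    )
    print(issues or "metadata is consistent")
```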
Notes:
- `input_schema` is mandatory. It mirrors the predator-side feature schema (`{inputs: [{name, features: ["TYPE|name"]}]}`) and is written pass-through as `atlas/input_schema.json` in the MLflow run. Accepted forms on the bundle: an `InputSchema` instance, a `dict` of the same shape, or a `Path`/`str` pointing to a JSON file on disk (auto-loaded). Valid feature-type prefixes are exported as `ALLOWED_FEATURE_TYPES`: `ONLINE_FEATURE`, `PARENT_ONLINE_FEATURE`, `OFFLINE_FEATURE`, `PARENT_OFFLINE_FEATURE`, `RTP_FEATURE`, `PARENT_RTP_FEATURE`, `DEFAULT_FEATURE`, `PARENT_DEFAULT_FEATURE`, `MODEL_FEATURE`, `CALIBRATION`, `PCTR_CALIBRATION`, `PCVR_CALIBRATION`. Validation is structural only: the SDK does not call OFS or offline APIs.
- `dynamic_axes` is the same dict you'd pass to `torch.onnx.export`; copy it in verbatim. Every input AND every output must have axis 0 listed (batch is always dynamic for atlas). Non-batch dynamic axes are named here too.
- Canonical value lists for non-batch dynamic axes (e.g. `seq_len = [64, 128, 256]`) AND the deployment-time batch list are supplied as part of the atlas trigger-run request, not on the bundle. The bundle only declares which axes are dynamic and what they're called.
- `export_batch_size` is only the SDK-side trace sample size: how many rows of `val_inputs_df` get sliced when running `torch.onnx.export`. It has nothing to do with deployment batch sizes. The default of 64 works for most models; bump it if the trace needs more variance, drop it for tiny models / sparse data. If `val_inputs_df` has fewer rows, all rows are used.
- `val_outputs_df` is optional; the SDK always generates the canonical `val_outputs.pt` itself by running the FP32-CPU model under `torch.no_grad()` on `val_inputs_df`. If you supply yours, it's saved as `val_outputs_user.pt` for debugging and compared against the canonical (mse / max_abs); that comparison is the drift report.
- `calibration_inputs_df` is optional; omit it unless atlas will run the INT8 quantization path for this model.
- `architecture_doc="auto"` is the default. To get LLM-generated docs, install with `pip install ml_atlas_sdk[architecture]` and set `ANTHROPIC_API_KEY` in the environment. If the extra is missing, the API call fails, or the env var is unset, the SDK silently falls back to a structural+ONNX-summary markdown (PyTorch `repr` + op histogram). Pass a `str`/`Path` to supply your own markdown, or `None` to skip the file entirely.
- `skip_sdk_validations` accepts `True` (skip all), `False` (skip none), or a list/set of the check names documented in the Validations section below.
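To make the "structural only" validation of `input_schema` concrete, here is a minimal sketch of what such a check might look like on the dict form. It is an illustration under stated assumptions, not the SDK's implementation: `validate_input_schema` and the local `FEATURE_TYPES` set are hypothetical names, with the set copied from the prefix list in the notes above.

```python
# Local copy of the prefixes listed in this README (the SDK exports its own
# ALLOWED_FEATURE_TYPES; this set is only for the sketch).
FEATURE_TYPES = frozenset({
    "ONLINE_FEATURE", "PARENT_ONLINE_FEATURE",
    "OFFLINE_FEATURE", "PARENT_OFFLINE_FEATURE",
    "RTP_FEATURE", "PARENT_RTP_FEATURE",
    "DEFAULT_FEATURE", "PARENT_DEFAULT_FEATURE",
    "MODEL_FEATURE", "CALIBRATION",
    "PCTR_CALIBRATION", "PCVR_CALIBRATION",
})


def validate_input_schema(schema, input_names):
    """Return a list of structural problems; an empty list means the schema passes."""
    inputs = schema.get("inputs") if isinstance(schema, dict) else None
    if not isinstance(inputs, list) or not inputs:
        return ["schema must be a dict with a non-empty 'inputs' list"]
    problems, bound = [], set()
    for i, binding in enumerate(inputs):
        name = binding.get("name") if isinstance(binding, dict) else None
        if not isinstance(name, str) or not name:
            problems.append(f"inputs[{i}]: missing or empty 'name'")
            continue
        bound.add(name)
        features = binding.get("features")
        if not isinstance(features, list) or not features:
            problems.append(f"{name}: 'features' must be a non-empty list")
            continue
        for feat in features:
            # Each feature must look like "TYPE|name" with a known TYPE.
            ftype, sep, fname = feat.partition("|") if isinstance(feat, str) else ("", "", "")
            if not sep or ftype not in FEATURE_TYPES or not fname:
                problems.append(f"{name}: bad feature {feat!r}, expected 'TYPE|name'")
    # One binding per ONNX input.
    for name in input_names:
        if name not in bound:
            problems.append(f"{name}: ONNX input has no binding")
    return problems
```

No network call is involved: the check looks only at the shape of the data, which matches the note that the SDK does not contact OFS or offline APIs.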
Validations
Every check below runs by default during push_to_mlflow. Pass the check's
name in skip_sdk_validations to skip it. Atlas re-validates everything on its
side regardless — skipping here only changes whether you get fast local
feedback or slower atlas-side feedback after the upload.
| Check | What it does | Skip when | If you skip |
|---|---|---|---|
| `df_shape_dtype` | Confirms every cell in `val_inputs_df` / `val_outputs_df` / `calibration_inputs_df` has shape `[d for d in metadata.input_shapes[name] if d != -1]` and the right dtype (mirrors atlas `validate_val_data`). | You've already verified DF shapes match metadata and are debugging a different failure. | Wrong-shape cells slip through; atlas rejects the bundle after upload. |
| `nan_inf` | Hard-fails on NaN/Inf in `val_inputs_df` and in the canonical val_outputs the SDK generates; warns (doesn't fail) on NaN/Inf in user-supplied `val_outputs_df`. | Input data is known to contain NaN/Inf sentinels that your model handles correctly. | Bad-numerics inputs/outputs ship and TRT engine builds may NaN-out in prod. |
| `calib_sample_count` | Warns if `calibration_inputs_df` has < 1000 rows (recommended floor for INT8 quantizer stability). | You've intentionally supplied a small calibration sample (e.g. for a tiny model or test path). | INT8 quantization may pick poor scale/zero-point and drop accuracy. |
| `forward_scan` | Greps `inspect.getsource(model.forward)` for tracing footguns: `.item()`, `.tolist()`, `int(tensor.x(...))`, `range(tensor.size(...))`, `if x.training`, `.cpu().numpy()`, `if tensor.any()`/`.all()`. Warns per match. | Forward method genuinely uses one of these (e.g. eval-only `.item()` outside the traced path). | You lose the early heads-up about Python-side control flow that bakes constants into ONNX. |
| `roundtrip` | Loads the freshly-exported ONNX in ORT (CPU), runs val_inputs through it, and compares to the canonical PyTorch outputs (rtol=1e-2, atol=5e-2). On failure, the 5-step escalation loop walks opset bumps + const-fold off + dynamo before giving up. | The roundtrip flakes on a known-tolerated op and you've already verified accuracy off-band. | Export bugs (FP16 contamination, opset-decomposition drift, broken constant-folding) ship silently. |
| `sample_invariance` | Exports the model twice on disjoint val_input slices, hashes both ONNX byte streams after stripping initializer values. Different hashes → tracing was sample-dependent; warns naming the divergent op type. | Known harmless per-sample variation (rare). | A Python-control-flow bug in forward produces sample-dependent ONNX with no warning. |
| `onnx_cross_check` | Post-export: `onnx.checker.check_model` passes, ONNX input/output names equal `metadata.input_names`/`output_names`, every I/O has `dim_param` at axis 0 (dynamic batch), opset version is exactly `metadata.onnx_opset` (or higher if escalation bumped it). Mirrors atlas `validate_onnx`. | Investigating exporter internals where the cross-check is the symptom, not the cause. | Schema mismatches between your declared metadata and the actual ONNX slip through. |
| `drift` | Computes mse, max_abs_err, mean_abs_err, max_pct_err, mean_pct_err, and (for last-axis size ≥ 8) cosine_sim between user-supplied `val_outputs_df` and the SDK's canonical outputs. Fails if mse > `drift_tolerance_mse` or max_abs > `drift_tolerance_max_abs`. | Tolerances are intentionally tight for diagnostics and you accept the noise. | A discrepancy between the outputs you think your model produces and what it actually produces ships unnoticed. |
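As an illustration of the `forward_scan` idea, the sketch below runs a few regular expressions over a source string and reports one warning per match. The patterns are a guess at a subset of the footguns named in the table; the SDK's actual patterns are not shown in this README, and `scan_forward_source` is a hypothetical name. In practice the string would come from `inspect.getsource(model.forward)`.

```python
import re

# Assumed patterns for a subset of the footguns listed in the table.
FOOTGUN_PATTERNS = {
    ".item()": re.compile(r"\.item\(\)"),
    ".tolist()": re.compile(r"\.tolist\(\)"),
    "range(tensor.size(...))": re.compile(r"range\([^)]*\.size\("),
    "if x.training": re.compile(r"\bif\s+[\w.]+\.training\b"),
    ".cpu().numpy()": re.compile(r"\.cpu\(\)\.numpy\(\)"),
    "if tensor.any()/.all()": re.compile(r"\bif\s+.*\.(any|all)\(\)"),
}


def scan_forward_source(source):
    """Return (line_number, label) pairs for every footgun pattern that matches."""
    warnings = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        code = line.split("#", 1)[0]  # crude: drop trailing comments before matching
        for label, pattern in FOOTGUN_PATTERNS.items():
            if pattern.search(code):
                warnings.append((lineno, label))
    return warnings


if __name__ == "__main__":
    example = (
        "def forward(self, x):\n"
        "    n = x.sum().item()\n"
        "    for i in range(x.size(0)):\n"
        "        pass\n"
        "    return x\n"
    )
    for lineno, label in scan_forward_source(example):
        print(f"line {lineno}: found {label}")
```

A text scan like this cannot tell whether a match sits on the traced path, which is consistent with the check warning rather than failing.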
Bypass for emergency pushes

    bundle = AtlasArtifactBundle(..., skip_sdk_validations=True)           # all
    bundle = AtlasArtifactBundle(..., skip_sdk_validations={"roundtrip"})  # granular

Atlas re-validates on its side regardless; a bypass only saves the data scientist local
feedback time. See PLAN.md §Bypass validations.
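The three accepted forms of `skip_sdk_validations` (`True`, `False`, or a list/set of check names) could be normalized into a single set of skipped checks along the following lines. This is a sketch only: `resolve_skipped` is a hypothetical helper, the check names are copied from the Validations table, and rejecting unknown names is an assumption about desirable behavior rather than documented SDK behavior.

```python
# Check names as listed in the Validations table of this README.
CHECK_NAMES = (
    "df_shape_dtype", "nan_inf", "calib_sample_count", "forward_scan",
    "roundtrip", "sample_invariance", "onnx_cross_check", "drift",
)


def resolve_skipped(skip_sdk_validations):
    """Normalize True / False / iterable-of-names into a frozenset of skipped checks."""
    if skip_sdk_validations is True:
        return frozenset(CHECK_NAMES)
    if skip_sdk_validations is False:
        return frozenset()
    requested = frozenset(skip_sdk_validations)
    unknown = requested - frozenset(CHECK_NAMES)
    if unknown:
        # Assumption: a typo in a check name should fail loudly rather than
        # silently leave the check enabled.
        raise ValueError(f"unknown check name(s): {sorted(unknown)}")
    return requested


if __name__ == "__main__":
    skipped = resolve_skipped({"roundtrip"})
    for name in CHECK_NAMES:
        print(f"{name}: {'skipped' if name in skipped else 'runs'}")
```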
Layout

    ml_atlas_sdk/
    models.py                AtlasArtifactBundle, ModelMetadata, InputSchema, InputBinding
    errors.py                AtlasBundleError
    materialize.py           orchestrator (14-step deterministic order)
    validators/
    df_check.py              DataFrame schema check (mirrors atlas validate_val_data)
    onnx_check.py            post-export ONNX cross-check (mirrors atlas validate_onnx)
    roundtrip.py             ORT round-trip vs canonical PyTorch outputs
    exporters/
    df_to_tensor.py          DataFrame -> dict[name, Tensor]
    onnx_export.py           export-env pin + variance-aware sample + dynamic_axes
    escalation.py            5-attempt round-trip escalation loop
    forward_scan.py          inspect.getsource footgun warnings
    sample_invariance.py     two-sample export-hash divergence probe
    drift.py                 DS-supplied vs canonical metrics (mse, max_abs, ...)
    architecture/
    autogen.py               LLM-driven architecture.md
    inspector.py             PyTorch + ONNX graph signal extraction
    cache.py                 ~/.cache/ml-atlas-sdk/architecture/
    upload/
    mlflow_upload.py         push_to_mlflow public entrypoint
    tests/
    smoke_e2e.py             happy-path e2e against file:// MLflow tracker
    smoke_bypass_drift.py    bypass + drift report paths
    fixtures/tiny_model.py

File details
Details for the file ml_atlas_sdk-0.2.0.tar.gz.
File metadata
- Download URL: ml_atlas_sdk-0.2.0.tar.gz
- Upload date:
- Size: 51.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bb8b62ab26dedb45b7e5cfc91ede54835bf72481cec2efbef40f2ecd636fd392 |
| MD5 | 598d2ce03177ad25215dc713c36abf6d |
| BLAKE2b-256 | 0482104c21e9439be8250fc51b38e2fc09b35b4aa4ad2529737608961ceec27a |
File details
Details for the file ml_atlas_sdk-0.2.0-py3-none-any.whl.
File metadata
- Download URL: ml_atlas_sdk-0.2.0-py3-none-any.whl
- Upload date:
- Size: 41.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 913cf9e6985c5ec11b4ae7653c93dcd847acf1e363a12e5bcaadb2145c44fbc4 |
| MD5 | 79863dd998b7bacef59939baed395ca3 |
| BLAKE2b-256 | 7e33490c5351785916617d17ef3ffe706ac45f2bf46f898274c13ed7493856dd |