
Atlas artifact bundle SDK — pushes a PyTorch model + validation data to MLflow in the format atlas (TRT/Triton optimizer) consumes.


ml_atlas_sdk

Pushes a PyTorch model + validation data to MLflow in the format atlas (the TensorRT/Triton optimizer) consumes.

Why

Atlas needs 9 specific input artifacts on GCS in a strict format. Without this SDK, data scientists hand-write ONNX exports, val_inputs.pt, metadata.json, etc. — and silently get them wrong (FP16/GPU-generated reference, bad constant-folding, missing dynamic axes). Failures only surface inside atlas hours later.

ml_atlas_sdk consolidates all 9 artifacts behind one push_to_mlflow(bundle, run) call, validates the model+data locally before upload, and runs an ONNX Runtime (ORT) round-trip self-check that catches export bugs before they ship.

Install

pip install ml_atlas_sdk
# optional extras
pip install "ml_atlas_sdk[architecture]"   # LLM-driven architecture.md (quotes keep zsh from globbing the brackets)
pip install "ml_atlas_sdk[gcs]"            # for atlas reverse-side resolver

LLM architecture autogen (architecture_doc="auto", the default) requires both the [architecture] extra and ANTHROPIC_API_KEY in the environment when push_to_mlflow runs. If either is missing, the SDK silently falls back to a structural+ONNX-summary markdown (PyTorch repr + op histogram) — it never raises.
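The fallback rule above comes down to a two-condition pre-check that you can run before pushing, to predict which architecture.md you will get. This is a minimal sketch. It assumes the [architecture] extra provides the `anthropic` package, which is a guess and not confirmed here. `architecture_doc_mode` is a hypothetical helper, not part of the SDK. It also does not model the third fallback trigger, an API call that fails at push time.

```python
import importlib.util
import os
from typing import Mapping, Optional


def architecture_doc_mode(
    extra_module: str = "anthropic",
    env: Optional[Mapping[str, str]] = None,
) -> str:
    """Hypothetical helper: predict which path architecture_doc="auto" takes.

    Returns "llm" only when the extra's module is importable AND a non-empty
    ANTHROPIC_API_KEY is set; otherwise "structural" (the silent fallback).
    """
    env = os.environ if env is None else env
    # Assumption: the [architecture] extra installs a module named `extra_module`.
    has_extra = importlib.util.find_spec(extra_module) is not None
    has_key = bool(env.get("ANTHROPIC_API_KEY"))
    return "llm" if (has_extra and has_key) else "structural"
```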

Minimal usage

import mlflow
import torch
from ml_atlas_sdk import AtlasArtifactBundle, ModelMetadata, push_to_mlflow

bundle = AtlasArtifactBundle(
    model_name="my_model",
    model=my_model,                                  # nn.Module
    metadata=ModelMetadata(
        model_name="my_model",
        input_names=["input__0", "input__1"],
        output_names=["output__0"],
        input_shapes={"input__0": [-1, 4, 8], "input__1": [-1, 16]},
        output_shape={"output__0": [-1, 3]},
        dynamic_axes={                               # paste from your torch.onnx.export call
            "input__0": {0: "batch_size"},
            "input__1": {0: "batch_size"},
            "output__0": {0: "batch_size"},
        },
        onnx_opset=17,
        torch_version=torch.__version__,
        architecture="MyModel",
        task_type="multitask_ranking",
    ),
    val_inputs_df=val_in_df,                         # one column per ONNX input
    export_batch_size=64,                            # rows to slice for the ONNX trace sample
)

mlflow.set_tracking_uri("...")
mlflow.set_experiment("ranking_models")
with mlflow.start_run() as run:
    run_id = push_to_mlflow(bundle, run)

That's it. The SDK:

  • forces model.eval().cpu().float() + torch.no_grad() for export
  • generates canonical val_outputs by running the model on val_inputs
  • exports ONNX with proper dynamic_axes
  • runs ORT round-trip self-check (with 5-step escalation if it fails)
  • uploads 8 files under atlas/ in the run + logs all params/tags atlas needs

See PLAN.md for the full design.

Full usage

This example sets every field on AtlasArtifactBundle (and the nested ModelMetadata) to a realistic value, including a non-batch dynamic axis (seq_len). Use it when you want INT8 calibration, a drift comparison against your own reference outputs, or to override SDK defaults.

from pathlib import Path

import mlflow
import pandas as pd
import torch
from ml_atlas_sdk import AtlasArtifactBundle, ModelMetadata, push_to_mlflow

# ── 1. Metadata: ONNX-side schema.
#       In input_shapes / output_shape: use -1 for every dynamic dim and
#       positive ints for static dims.
#       In dynamic_axes: paste your torch.onnx.export `dynamic_axes` dict
#       verbatim — same format, {io_name: {axis_idx: axis_name}}. Static
#       axes are NOT listed here; axis 0 (batch) MUST be listed for every
#       input AND every output.
metadata = ModelMetadata(
    model_name="my_model",
    input_names=["tokens", "dense"],
    output_names=["logits"],
    input_shapes={
        "tokens": [-1, -1],         # axes 0 and 1 both dynamic; names below
        "dense":  [-1, 16],         # axis 1 static (16)
    },
    output_shape={
        "logits": [-1, 3],          # axis 1 static (3)
    },
    dynamic_axes={
        "tokens": {0: "batch_size", 1: "seq_len"},
        "dense":  {0: "batch_size"},
        "logits": {0: "batch_size"},
    },
    onnx_opset=17,                  # must be >= 17 (LayerNorm/SDPA precision)
    torch_version=torch.__version__,
    architecture="MyModel",
    task_type="multitask_ranking",
)

# ── 2. Validation data: one column per ONNX input, ndarray cells.
val_in_df: pd.DataFrame = ...        # cols: "tokens", "dense"

# Optional: your own reference outputs. SDK still generates canonical
# val_outputs internally; yours is used ONLY for the drift report
# (val_outputs_user.pt is uploaded for debugging, never as the atlas artifact).
val_out_df: pd.DataFrame = ...       # cols: "logits"

# Optional: calibration sample for INT8. >= 1000 rows recommended.
calib_in_df: pd.DataFrame = ...      # cols: "tokens", "dense"

bundle = AtlasArtifactBundle(
    model_name="my_model",
    model=my_model,                                  # nn.Module — SDK forces .eval().cpu().float()
    metadata=metadata,
    val_inputs_df=val_in_df,
    val_outputs_df=val_out_df,                       # optional; drives drift report
    export_batch_size=64,                            # rows to slice for the ONNX trace sample (default 64)
    calibration_inputs_df=calib_in_df,               # optional; needed only for INT8 path in atlas
    architecture_doc="auto",                         # "auto" (LLM-generated), a str/Path, or None to skip
    skip_sdk_validations=False,                      # see Bypass section below
    drift_tolerance_mse=1e-3,                        # raises drift error if user vs canonical mse exceeds this
    drift_tolerance_max_abs=5e-2,                    # raises drift error if max|user-canonical| exceeds this
    bundle_schema_version="1",
)

mlflow.set_tracking_uri("...")
mlflow.set_experiment("ranking_models")
with mlflow.start_run() as run:
    run_id = push_to_mlflow(bundle, run)

Notes:

  • dynamic_axes is the same dict you'd pass to torch.onnx.export — copy it in verbatim. Every input AND every output must have axis 0 listed (batch is always dynamic for atlas). Non-batch dynamic axes are named here too.
  • Canonical value lists for non-batch dynamic axes (e.g. seq_len = [64, 128, 256]) AND the deployment-time batch list are supplied as part of the atlas trigger-run request, not on the bundle — the bundle only declares which axes are dynamic and what they're called.
  • export_batch_size is only the SDK-side trace sample size: the number of val_inputs_df rows sliced off when running torch.onnx.export. It has nothing to do with deployment batch sizes. The default of 64 works for most models. Raise it if the trace needs more variance, and lower it for tiny models or sparse data. If val_inputs_df has fewer rows, all rows are used.
  • val_outputs_df is optional; the SDK always generates the canonical val_outputs.pt itself by running the FP32-CPU model under torch.no_grad() on val_inputs_df. If you supply yours, it's saved as val_outputs_user.pt for debugging and compared against the canonical (mse / max_abs) — that's the drift report.
  • calibration_inputs_df is optional; omit it unless atlas will run the INT8 quantization path for this model.
  • architecture_doc="auto" is the default. To get LLM-generated docs, install with pip install ml_atlas_sdk[architecture] and set ANTHROPIC_API_KEY in the environment. If the extra is missing, the API call fails, or the env var is unset, the SDK silently falls back to a structural+ONNX-summary markdown (PyTorch repr + op histogram). Pass a str/Path to supply your own markdown, or None to skip the file entirely.
  • skip_sdk_validations accepts True (skip all), False (skip none), or a list/set of the check names documented in the Validations section below.
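The shape and dynamic_axes rules in the metadata comments and the first note are mechanical enough to check yourself before calling push_to_mlflow. This is a minimal pre-flight sketch. `check_dynamic_axes` is a hypothetical helper, not the SDK's df_shape_dtype or onnx_cross_check implementation. It encodes only the rules stated above: a dim is -1 exactly when its axis is named in dynamic_axes, static dims are positive ints, and axis 0 is listed for every input and every output.

```python
from typing import Dict, List


def check_dynamic_axes(
    input_shapes: Dict[str, List[int]],
    output_shape: Dict[str, List[int]],
    dynamic_axes: Dict[str, Dict[int, str]],
) -> List[str]:
    """Hypothetical pre-flight check; returns problems (empty list = consistent)."""
    problems: List[str] = []
    shapes = {**input_shapes, **output_shape}
    for name, shape in shapes.items():
        axes = dynamic_axes.get(name)
        if axes is None:
            problems.append(f"{name}: missing from dynamic_axes")
            continue
        if 0 not in axes:
            problems.append(f"{name}: axis 0 (batch) not listed")
        for idx, dim in enumerate(shape):
            if dim == -1 and idx not in axes:
                problems.append(f"{name}: axis {idx} is -1 but has no name")
            elif dim != -1 and dim <= 0:
                problems.append(f"{name}: axis {idx} static dim must be > 0, got {dim}")
            elif dim > 0 and idx in axes:
                problems.append(f"{name}: axis {idx} is static ({dim}) but listed as dynamic")
        for idx in axes:
            if not 0 <= idx < len(shape):
                problems.append(f"{name}: dynamic axis {idx} out of range for rank {len(shape)}")
    # Names in dynamic_axes that are not declared inputs/outputs are likely typos.
    for name in dynamic_axes:
        if name not in shapes:
            problems.append(f"{name}: in dynamic_axes but not a declared input/output")
    return problems
```

Run against the Full usage metadata, it returns an empty list. Dropping the seq_len entry for "tokens" is reported as "tokens: axis 1 is -1 but has no name".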

Validations

Every check below runs by default during push_to_mlflow. Pass the check's name in skip_sdk_validations to skip it. Atlas re-validates everything on its side regardless — skipping here only changes whether you get fast local feedback or slower atlas-side feedback after the upload.

Each check, what it does, when skipping it is reasonable, and what you lose if you skip it:

df_shape_dtype
  • What it does: Confirms every cell in val_inputs_df / val_outputs_df / calibration_inputs_df has shape [d for d in metadata.input_shapes[name] if d != -1] (metadata.output_shape[name] for val_outputs_df columns) and the right dtype (mirrors atlas validate_val_data).
  • Skip when: You've already verified DF shapes match metadata and are debugging a different failure.
  • If you skip: Wrong-shape cells slip through; atlas rejects the bundle after upload.

nan_inf
  • What it does: Hard-fails on NaN/Inf in val_inputs_df and in the canonical val_outputs the SDK generates; warns (doesn't fail) on NaN/Inf in user-supplied val_outputs_df.
  • Skip when: Input data is known to contain NaN/Inf sentinels that your model handles correctly.
  • If you skip: Bad-numerics inputs/outputs ship, and TRT engine builds may NaN-out in prod.

calib_sample_count
  • What it does: Warns if calibration_inputs_df has < 1000 rows (recommended floor for INT8 quantizer stability).
  • Skip when: You've intentionally supplied a small calibration sample (e.g. for a tiny model or test path).
  • If you skip: INT8 quantization may pick poor scale/zero-point values and drop accuracy.

forward_scan
  • What it does: Greps inspect.getsource(model.forward) for tracing footguns: .item(), .tolist(), int(tensor.x(...)), range(tensor.size(...)), if x.training, .cpu().numpy(), if tensor.any()/.all(). Warns per match.
  • Skip when: Your forward method genuinely uses one of these (e.g. an eval-only .item() outside the traced path).
  • If you skip: You lose the early heads-up about Python-side control flow that bakes constants into ONNX.

roundtrip
  • What it does: Loads the freshly exported ONNX in ORT (CPU), runs val_inputs through it, and compares the results to the canonical PyTorch outputs (rtol=1e-2, atol=5e-2). On failure, the 5-step escalation loop walks opset bumps + const-fold off + dynamo before giving up.
  • Skip when: The round-trip flakes on a known-tolerated op and you've already verified accuracy off-band.
  • If you skip: Export bugs (FP16 contamination, opset-decomposition drift, broken constant-folding) ship silently.

sample_invariance
  • What it does: Exports the model twice on disjoint val_input slices and hashes both ONNX byte streams after stripping initializer values. Different hashes → tracing was sample-dependent; warns naming the divergent op type.
  • Skip when: There is known harmless per-sample variation (rare).
  • If you skip: A Python-control-flow bug in forward produces sample-dependent ONNX with no warning.

onnx_cross_check
  • What it does: Post-export, checks that onnx.checker.check_model passes, ONNX input/output names equal metadata.input_names/output_names, every I/O has a dim_param at axis 0 (dynamic batch), and the opset version is exactly metadata.onnx_opset (or higher if escalation bumped it). Mirrors atlas validate_onnx.
  • Skip when: You're investigating exporter internals where the cross-check is the symptom, not the cause.
  • If you skip: Schema mismatches between your declared metadata and the actual ONNX slip through.

drift
  • What it does: Computes mse, max_abs_err, mean_abs_err, max_pct_err, mean_pct_err, and (for last-axis size ≥ 8) cosine_sim between user-supplied val_outputs_df and the SDK's canonical outputs. Fails if mse > drift_tolerance_mse or max_abs_err > drift_tolerance_max_abs.
  • Skip when: The tolerances are intentionally tight for diagnostics and you accept the noise.
  • If you skip: A discrepancy between the outputs you think your model produces and what it actually produces ships unnoticed.
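The forward_scan check is described as a plain-text grep over the forward source. The rough sketch below shows the idea using line-by-line regex matching. The pattern set, `scan_source`, and `scan_forward` are hypothetical and not the SDK's forward_scan.py. Like any grep, they will miss some cases and flag false positives, such as an eval-only `.item()` outside the traced path.

```python
import inspect
import re
import warnings
from typing import List, Tuple

# Hypothetical heuristic patterns, one per footgun listed in the forward_scan entry.
FOOTGUNS = {
    ".item()": re.compile(r"\.item\(\)"),
    ".tolist()": re.compile(r"\.tolist\(\)"),
    "int(tensor.method(...))": re.compile(r"\bint\(\s*\w+\.\w+\("),
    "range(tensor.size(...))": re.compile(r"\brange\(\s*\w+\.size\("),
    "if x.training": re.compile(r"\bif\s+(?:not\s+)?[\w.]*\.training\b"),
    ".cpu().numpy()": re.compile(r"\.cpu\(\)\.numpy\(\)"),
    "if tensor.any()/.all()": re.compile(r"\bif\b.*\.(?:any|all)\(\)"),
}


def scan_source(src: str) -> List[Tuple[int, str]]:
    """Return (1-based line number, footgun label) for every match in src."""
    hits: List[Tuple[int, str]] = []
    for lineno, line in enumerate(src.splitlines(), start=1):
        code = line.split("#", 1)[0]  # crude: drop trailing comments (breaks on '#' in strings)
        for label, pattern in FOOTGUNS.items():
            if pattern.search(code):
                hits.append((lineno, label))
    return hits


def scan_forward(model) -> List[Tuple[int, str]]:
    """Grep model.forward's source and emit one warning per match."""
    hits = scan_source(inspect.getsource(model.forward))
    for lineno, label in hits:
        warnings.warn(f"forward line {lineno}: {label} may bake a constant into the ONNX trace")
    return hits
```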
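The drift entry names its metrics but not their exact formulas. The sketch below computes them in plain Python over row-major lists, under stated assumptions: percentage error is taken relative to |canonical| (floored at a small eps), and cosine_sim is the mean per-row cosine along the last axis. `drift_metrics` and `drift_failed` are hypothetical names. The default tolerances are the values from the Full usage example, not confirmed SDK defaults.

```python
import math
from typing import Dict, List, Sequence

Rows = Sequence[Sequence[float]]


def drift_metrics(user: Rows, canonical: Rows, eps: float = 1e-12) -> Dict[str, float]:
    """Element-wise drift between user-supplied and canonical outputs.

    Assumptions: pct errors are relative to |canonical| (floored at eps), and
    cosine_sim is the mean per-row cosine along the last axis, reported only
    when that axis has >= 8 elements.
    """
    if len(user) != len(canonical):
        raise ValueError("row count mismatch")
    diffs: List[float] = []
    pcts: List[float] = []
    cosines: List[float] = []
    for u_row, c_row in zip(user, canonical):
        if len(u_row) != len(c_row):
            raise ValueError("last-axis size mismatch")
        for u, c in zip(u_row, c_row):
            diffs.append(u - c)
            pcts.append(abs(u - c) / max(abs(c), eps) * 100.0)
        if len(c_row) >= 8:
            dot = sum(u * c for u, c in zip(u_row, c_row))
            norms = math.sqrt(sum(u * u for u in u_row)) * math.sqrt(sum(c * c for c in c_row))
            cosines.append(dot / max(norms, eps))
    if not diffs:
        raise ValueError("nothing to compare")
    n = len(diffs)
    abs_diffs = [abs(d) for d in diffs]
    metrics = {
        "mse": sum(d * d for d in diffs) / n,
        "max_abs_err": max(abs_diffs),
        "mean_abs_err": sum(abs_diffs) / n,
        "max_pct_err": max(pcts),
        "mean_pct_err": sum(pcts) / n,
    }
    if cosines:
        metrics["cosine_sim"] = sum(cosines) / len(cosines)
    return metrics


def drift_failed(metrics: Dict[str, float], tol_mse: float = 1e-3, tol_max_abs: float = 5e-2) -> bool:
    """Fail rule from the drift entry: mse or max_abs_err above its tolerance."""
    return metrics["mse"] > tol_mse or metrics["max_abs_err"] > tol_max_abs
```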

Bypass for emergency pushes

bundle = AtlasArtifactBundle(..., skip_sdk_validations=True)              # all
bundle = AtlasArtifactBundle(..., skip_sdk_validations={"roundtrip"})     # granular

Atlas re-validates on its side regardless, so bypassing only skips the fast local feedback; it does not let a bad bundle through atlas. See PLAN.md §Bypass validations.
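The accepted forms of skip_sdk_validations (True, False, or a list/set of check names) can be normalized to a single set of names. This is a minimal sketch. `checks_to_skip` and `checks_to_run` are hypothetical helpers. Rejecting unknown names (for example a typo like "roundtrp") and bare strings is an assumption about sensible behavior, not documented SDK behavior.

```python
from typing import FrozenSet, Iterable, List, Union

# The eight check names from the Validations section.
CHECKS: FrozenSet[str] = frozenset({
    "df_shape_dtype", "nan_inf", "calib_sample_count", "forward_scan",
    "roundtrip", "sample_invariance", "onnx_cross_check", "drift",
})


def checks_to_skip(skip: Union[bool, Iterable[str]]) -> FrozenSet[str]:
    """Normalize a skip_sdk_validations value into a set of check names."""
    if skip is True:
        return CHECKS
    if skip is False:
        return frozenset()
    if isinstance(skip, str):
        # A bare string would otherwise be iterated character by character.
        raise TypeError("pass a list/set of check names, not a bare string")
    names = frozenset(skip)
    unknown = names - CHECKS
    if unknown:
        raise ValueError(f"unknown check name(s): {sorted(unknown)}")
    return names


def checks_to_run(skip: Union[bool, Iterable[str]]) -> List[str]:
    """Checks that would still run, in a stable (sorted) order."""
    return sorted(CHECKS - checks_to_skip(skip))
```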

Layout

ml_atlas_sdk/
  models.py            AtlasArtifactBundle, ModelMetadata
  errors.py            AtlasBundleError
  materialize.py       orchestrator (14-step deterministic order)
  validators/
    df_check.py        DataFrame schema check (mirrors atlas validate_val_data)
    onnx_check.py      post-export ONNX cross-check (mirrors atlas validate_onnx)
    roundtrip.py       ORT round-trip vs canonical PyTorch outputs
  exporters/
    df_to_tensor.py    DataFrame -> dict[name, Tensor]
    onnx_export.py     export-env pin + variance-aware sample + dynamic_axes
    escalation.py      5-attempt round-trip escalation loop
    forward_scan.py    inspect.getsource footgun warnings
    sample_invariance.py  two-sample export-hash divergence probe
    drift.py           DS-supplied vs canonical metrics (mse, max_abs, ...)
  architecture/
    autogen.py         LLM-driven architecture.md
    inspector.py       PyTorch + ONNX graph signal extraction
    cache.py           ~/.cache/ml-atlas-sdk/architecture/
  upload/
    mlflow_upload.py   push_to_mlflow public entrypoint
tests/
  smoke_e2e.py         happy-path e2e against file:// MLflow tracker
  smoke_bypass_drift.py  bypass + drift report paths
  fixtures/tiny_model.py



Download files


Source Distribution

ml_atlas_sdk-0.1.0.tar.gz (47.9 kB)

Uploaded Source

Built Distribution


ml_atlas_sdk-0.1.0-py3-none-any.whl (39.5 kB)

Uploaded Python 3

File details

Details for the file ml_atlas_sdk-0.1.0.tar.gz.

File metadata

  • Download URL: ml_atlas_sdk-0.1.0.tar.gz
  • Upload date:
  • Size: 47.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.7

File hashes

Hashes for ml_atlas_sdk-0.1.0.tar.gz
  • SHA256: aa0fcd83042ca8b7b51d7a284643d46281c58b71aa4d1011449da956c39f04e0
  • MD5: c1b86a6b002e11f4c847c49240a59441
  • BLAKE2b-256: 8f23fa5b594c663709ea7cd143a1cf65074eb3f026ac7def8a19663833294d24


File details

Details for the file ml_atlas_sdk-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: ml_atlas_sdk-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 39.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.7

File hashes

Hashes for ml_atlas_sdk-0.1.0-py3-none-any.whl
  • SHA256: 2be16d0fe2cc0545261e7e593362b928daf6ab0c1368f0cecbbebc9a689f0314
  • MD5: 5e62e8b30d71e42a1a00ee255a1fdc33
  • BLAKE2b-256: ed00c32344a495236d6dbe591e2eeaa93eb062cfe281bdf3e1f17efcadbf9da4

