Multi-output evaluator TFX component for deep learning models that utilize TensorFlow and TensorFlow Extended.
MultiOutputEvaluator (TFX Component)
A lightweight and flexible TFX component for evaluating multi-output TensorFlow models. It computes per-output and global metrics and writes a TFMA-compatible JSON artifact for downstream analysis and reporting.
📦 Installation
Install from PyPI:
pip install tfx-moe
⚙️ Features
- ✅ Evaluates multi-output models (e.g., predictions of shape `(batch, num_outputs)`)
- ✅ Computes per-output and global metrics
- ✅ Currently supports MSE and MAE (easy to extend)
- ✅ TFMA-compatible JSON output (`ModelEvaluation` artifact)
- ✅ Works with TFX Transform artifacts (`TransformGraph`)
- ✅ Pluggable dataset input function (your data, your logic)
📚 API Reference
MultiOutputEvaluator Component
```python
MultiOutputEvaluator(
    model,                   # Channel[Model]
    examples,                # Channel[Examples]
    transform_graph,         # Channel[TransformGraph]
    output_names,            # list[str]: names of model outputs in order
    metrics,                 # list[str]: e.g., ['mse', 'mae']
    input_fn_path,           # str: "pkg.mod:func" or "pkg.mod.func"
    input_fn_kwargs=None,    # dict: extra kwargs for input function
    example_split='test'     # str: which split to evaluate
)
```
Parameters
- model: `standard_artifacts.Model` – Trained SavedModel (expects predictions of shape `(batch, num_outputs)`).
- examples: `standard_artifacts.Examples` – TFRecords location for the chosen split.
- transform_graph: `standard_artifacts.TransformGraph` – Loaded via `tft.TFTransformOutput`.
- output_names (list[str]): Logical names per output index; order must correspond to your model's output dimension.
- metrics (list[str]): Currently `['mse', 'mae']` supported.
- input_fn_path (str): Dotted path to your dataset builder function.
- input_fn_kwargs (dict, optional): Keyword args passed to the dataset function.
- example_split (str, default `'test'`): Split folder name under the `Examples` artifact (e.g., `'Split-test'` if that's how your pipeline writes splits).
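The `input_fn_path` parameter accepts either a `"pkg.mod:func"` or a `"pkg.mod.func"` spelling. A minimal sketch of how such a dotted path can be resolved with the standard library (this is illustrative, not the component's actual resolution code), demonstrated here with `math.sqrt`:

```python
import importlib


def resolve_dotted_path(path: str):
    """Resolve 'pkg.mod:func' or 'pkg.mod.func' to a callable.

    Illustrative sketch only; the component's internal resolution
    logic may differ in detail.
    """
    if ':' in path:
        module_name, attr = path.split(':', 1)
    else:
        module_name, _, attr = path.rpartition('.')
    module = importlib.import_module(module_name)
    return getattr(module, attr)


# Both spellings resolve to the same function:
sqrt_a = resolve_dotted_path('math:sqrt')
sqrt_b = resolve_dotted_path('math.sqrt')
```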
🚀 Quickstart (TFX Usage)
[1] Add Component to your Pipeline and Run
```python
from tfx_moe import MultiOutputEvaluator

# InteractiveContext pipeline runner
multi_output_evaluator = MultiOutputEvaluator(
    model=trainer.outputs['model'],
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    output_names=OUTPUT_KEYS,               # e.g., ["y_loc1", "y_loc2", ...]
    metrics=['mse', 'mae'],
    input_fn_path="my_pkg.mdata:input_fn",  # your dataset builder function
    example_split='Split-test'              # match your artifact split naming
)

# Run (InteractiveContext example)
context.run(multi_output_evaluator)

# Access the TFMA-like JSON results
evaluation_uri = multi_output_evaluator.outputs['evaluation'].get()[0].uri
print(evaluation_uri)
```
[2] Parse Results (per-output + global)
```python
import json
import pandas as pd
import tensorflow as tf

with tf.io.gfile.GFile(evaluation_uri, 'r') as f:
    data = json.load(f)  # TFMA-style dict

output_names = []
mse_values = []
mae_values = []
for key, value in data['metrics'].items():
    parts = key.split('>>')
    if key.endswith('>>mse'):
        # 'per_output>>{name}>>mse' → '{name}'; 'global>>mse' → 'global'
        output_names.append(parts[1] if parts[0] == 'per_output' else 'global')
        mse_values.append(value)
    elif key.endswith('>>mae'):
        mae_values.append(value)

df = pd.DataFrame({
    'Output Name': output_names,
    'MSE': mse_values,
    'MAE': mae_values
})

# The last row holds the global metrics (the executor writes them last).
global_mse = df['MSE'].iloc[-1]
global_mae = df['MAE'].iloc[-1]
df = df.iloc[:-1]  # keep per-output rows only
```
[3] Global Metrics
```python
print(f'Global MSE: {global_mse}')
print(f'Global MAE: {global_mae}')
```
[4] Aggregate Across Outputs
```python
global_mse_all_outputs = df['MSE'].mean()
global_mae_all_outputs = df['MAE'].mean()
print(f'Mean MSE across all outputs: {global_mse_all_outputs}')
print(f'Mean MAE across all outputs: {global_mae_all_outputs}')
```
🧠 How It Works
The Executor:
- Loads the `TransformGraph` (via `tft.TFTransformOutput`) and the SavedModel.
- Builds an evaluation dataset via your input function (`input_fn_path`).
- Computes per-output and global metrics across the dataset.
- Writes a TFMA-like JSON to the `ModelEvaluation` artifact.
Metrics Supported
`mse`, `mae` (easy to extend in code).
Output Format (metrics dictionary keys)
- Per-output: `per_output>>{output_name}>>{metric}`
- Global: `global>>{metric}`
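The metric math and key format described above can be sketched in plain NumPy. This is a hypothetical re-implementation of the arithmetic only (the real executor iterates a `tf.data.Dataset` and calls the SavedModel); `evaluate_batches` is an illustrative name, not part of the library:

```python
import numpy as np


def evaluate_batches(batches, output_names):
    """Accumulate per-output and global MSE/MAE over (labels, predictions) batches.

    Each batch element is a pair of arrays shaped (batch, num_outputs).
    """
    n = 0
    sq_sum = np.zeros(len(output_names))
    abs_sum = np.zeros(len(output_names))
    for labels, preds in batches:
        err = labels - preds
        sq_sum += (err ** 2).sum(axis=0)   # per-output sum of squared errors
        abs_sum += np.abs(err).sum(axis=0) # per-output sum of absolute errors
        n += labels.shape[0]

    metrics = {}
    for i, name in enumerate(output_names):
        metrics[f'per_output>>{name}>>mse'] = sq_sum[i] / n
        metrics[f'per_output>>{name}>>mae'] = abs_sum[i] / n
    # Global metrics average over every (example, output) pair.
    metrics['global>>mse'] = sq_sum.sum() / (n * len(output_names))
    metrics['global>>mae'] = abs_sum.sum() / (n * len(output_names))
    return metrics


# Toy data: two outputs, predictions off by exactly 1 everywhere.
batches = [(np.ones((2, 2)), np.zeros((2, 2)))]
metrics = evaluate_batches(batches, ['y_loc1', 'y_loc2'])
```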
📥 Dataset Function Contract
You supply the dataset function via input_fn_path, e.g. "my_pkg.mdata:input_fn".
Expected Signature
```python
def input_fn(file_pattern: str, tf_transform_output, **kwargs) -> tf.data.Dataset:
    # returns a dataset of (features, labels)
    ...
```
Expected yields:
- features: a structure compatible with your model's SavedModel signature.
- labels: a Tensor shaped `(batch, 1, num_outputs)`, so that `labels[:, 0, i]` is the true label vector for `output_names[i]`.
- Your model's predictions should have shape `(batch, num_outputs)`, so that `predictions[:, i]` aligns with `output_names[i]`.

If your labels are shaped differently, update either your input function or the executor's slicing logic to align shapes.
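A quick NumPy sanity check of the shape contract above (the shapes follow the contract; the values are toy data, not from the library):

```python
import numpy as np

# Toy shapes matching the contract: 4 examples, 3 outputs.
batch, num_outputs = 4, 3
labels = np.arange(batch * num_outputs, dtype=float).reshape(batch, 1, num_outputs)

# Stand-in for a model call: shape (batch, num_outputs), each value off by 0.5.
predictions = labels[:, 0, :] + 0.5

for i in range(num_outputs):
    y_true = labels[:, 0, i]    # true label vector for output_names[i]
    y_pred = predictions[:, i]  # aligned prediction vector
    assert y_true.shape == y_pred.shape == (batch,)
```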
📤 Artifact Output
- Writes a single JSON file, `evaluation.json`, under the `ModelEvaluation` artifact directory.
- Example keys:
  - `metrics["per_output>>y_loc3>>mse"] = 0.0123`
  - `metrics["global>>mae"] = 0.1357`
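For illustration, a file with this structure can be written and read back as follows. The dictionary contents are made-up example values, and the temporary directory merely stands in for the `ModelEvaluation` artifact URI; the real file is produced by the executor:

```python
import json
import os
import tempfile

# Illustrative contents only; key names follow the documented format.
evaluation = {
    'metrics': {
        'per_output>>y_loc3>>mse': 0.0123,
        'per_output>>y_loc3>>mae': 0.0876,
        'global>>mse': 0.0451,
        'global>>mae': 0.1357,
    }
}

artifact_dir = tempfile.mkdtemp()  # stands in for the artifact directory
path = os.path.join(artifact_dir, 'evaluation.json')
with open(path, 'w') as f:
    json.dump(evaluation, f, indent=2)

with open(path) as f:
    loaded = json.load(f)
```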
📜 License
This project is licensed under the MIT License.
© 2025 Dr. Ahmed Moussa
🤝 Contributing
Pull requests are welcome.
For major changes, please open an issue first to discuss what you would like to change.
📫 Contact
For feedback, bugs, or collaboration ideas:
- GitHub: @real-ahmed-moussa
⭐️ Show Your Support
If you find this project useful, consider giving it a ⭐️ on GitHub!