TensorFlow-based 5D→1D regression trainer/tester with plotting.

fivedreg_tf (TensorFlow)

TensorFlow/Keras implementation of the 5D → 1D regressor with a simple training/testing API.

Modules

  • tf_model.py: build_tf_model(hidden_sizes, Lambda) builds a Sequential model with L2 regularisation and He/Xavier initialisation.
  • trainer_tf.py: TrainerTF loads and validates data, trains a TF model, and saves model_tf.keras plus the normalisation values (normalisation_values_tf.npz). Early stopping, learning-rate decay and gradient clipping are optional, and the batch size is configurable.
  • tester_tf.py: TesterTF loads the saved TF model and normalisation stats to make predictions on NumPy arrays or .pkl files.
  • logger.py, utils.py: lightweight logging and decorators.
  • Synthetic data examples use the separate interpy_synth package (installed automatically).
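TrainerTF saves normalisation values next to the model and TesterTF reloads them before predicting. The statistics and array names stored in normalisation_values_tf.npz are not documented here, so the sketch below assumes per-feature z-score normalisation of X and of y, stored under hypothetical keys x_mean, x_std, y_mean and y_std. It only illustrates the save/load round trip, not the package's actual file layout.

```python
import os
import tempfile

import numpy as np


def fit_normalisation(X, y):
    """Per-feature mean/std for X with shape (n, 5); scalar mean/std for y with shape (n,)."""
    x_std = X.std(axis=0)
    x_std[x_std == 0] = 1.0  # avoid division by zero for constant features
    y_std = float(y.std()) or 1.0
    return {
        "x_mean": X.mean(axis=0),
        "x_std": x_std,
        "y_mean": float(y.mean()),
        "y_std": y_std,
    }


def save_stats(path, stats):
    # Hypothetical key names; the real normalisation_values_tf.npz layout may differ.
    np.savez(path, **stats)


def load_stats(path):
    with np.load(path) as data:
        return {key: data[key] for key in data.files}


def normalise_x(X, stats):
    return (X - stats["x_mean"]) / stats["x_std"]


def denormalise_y(y_scaled, stats):
    # Map a model output in normalised units back to the original target scale.
    return y_scaled * stats["y_std"] + stats["y_mean"]


rng = np.random.default_rng(42)
X = rng.uniform(size=(1000, 5))
y = X.sum(axis=1)
path = os.path.join(tempfile.mkdtemp(), "normalisation_values_tf.npz")
save_stats(path, fit_normalisation(X, y))
stats = load_stats(path)
X_scaled = normalise_x(X, stats)
```

Storing the statistics rather than the scaled data means the same transform can be reapplied to any new input at prediction time.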

Installation (package)

From this backend/fivedreg_tf directory:

pip install -r requirements.lock  # pinned CPU-only deps
pip install .

Headless environments: plotting uses Matplotlib's Agg backend, so no display is required. GPUs are neither required nor supported; the package depends on tensorflow-cpu. For reproducibility, install from the pinned requirements.lock in backend/.

Docker (whole app):

cd ../..
./scripts/docker_build.sh
./scripts/docker_up.sh   # backend on :8000 (includes TF if built with fivedreg_tf)

Usage

import os

from fivedreg_tf.trainer_tf import TrainerTF
from fivedreg_tf.tester_tf import TesterTF
from interpy_synth import synthetic_5d_pickle

# Create an output directory and generate a synthetic 5D training set (1000 samples)
out_dir = "outputs_tf"
os.makedirs(out_dir, exist_ok=True)
data_path = synthetic_5d_pickle(os.path.join(out_dir, "train.pkl"), n=1000, seed=42)

# Train a 64-32-16 network; the model, normalisation values and plots are saved to out_dir
trainer = TrainerTF(
    directory=out_dir,
    hidden_sizes=[64, 32, 16],
    epochs=100,
    learning_rate=0.01,
    early_stop_patience=10,
    lr_decay=0.95,
    seed=42,
)
train_rmse, val_rmse = trainer.train(data_path)

# Reload the saved model and normalisation values, then predict for a single 5D input
tester = TesterTF(directory=out_dir)
y_pred = tester.predict([0.1, 0.2, 0.3, 0.4, 0.5])
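The early_stop_patience and lr_decay arguments are not described further here. A common interpretation, sketched below in plain Python and not taken from TrainerTF's source, is that the learning rate is multiplied by lr_decay after every epoch and training stops once validation RMSE has failed to improve for early_stop_patience consecutive epochs, while the best epoch is tracked. The replay_schedule helper and the validation curve are hypothetical stand-ins for real per-epoch results.

```python
def replay_schedule(val_rmse_history, learning_rate=0.01, lr_decay=0.95, early_stop_patience=10):
    """Replay a per-epoch validation-RMSE history under exponential learning-rate
    decay and patience-based early stopping (illustrative only).

    Returns (epochs_run, best_epoch, best_val_rmse, lr_per_epoch); epochs are 1-indexed.
    """
    best_val_rmse = float("inf")
    best_epoch = 0
    stale_epochs = 0  # consecutive epochs without a new best validation RMSE
    lr = learning_rate
    lr_per_epoch = []
    epochs_run = 0
    for epoch, val_rmse in enumerate(val_rmse_history, start=1):
        lr_per_epoch.append(lr)  # learning rate used for this epoch
        epochs_run = epoch
        if val_rmse < best_val_rmse:
            best_val_rmse, best_epoch, stale_epochs = val_rmse, epoch, 0
        else:
            stale_epochs += 1
            if stale_epochs >= early_stop_patience:
                break  # no improvement for `early_stop_patience` epochs in a row
        lr *= lr_decay  # multiplicative decay before the next epoch
    return epochs_run, best_epoch, best_val_rmse, lr_per_epoch


# Hypothetical validation curve: improves for 3 epochs, then plateaus.
history = [1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.69]
epochs_run, best_epoch, best_val, lrs = replay_schedule(history, early_stop_patience=3)
print(epochs_run, best_epoch, best_val)  # 6 3 0.7
```

With a patience of 3, the run stops at epoch 6 and never sees the later improvement to 0.69; a larger patience trades extra epochs for a better chance of escaping plateaus.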

Note: TensorFlow must be available in your environment; the package depends on tensorflow-cpu. Training also saves two plots, rmse_vs_epochs.png and ytrue_vs_ypred.png, to the output directory. The metadata file (tf_model_metadata.json) records the hidden sizes, Lambda, number of epochs run, best epoch, best train/val RMSE, baseline RMSE, and final train/val R².
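The metadata compares the model's RMSE with a baseline RMSE and reports R² = 1 - SS_res / SS_tot. The baseline definition and the JSON key names are not documented here; the sketch below assumes the common baseline of always predicting the mean of the training targets, and the metadata keys are hypothetical.

```python
import json
import math
import os
import tempfile


def rmse(y_true, y_pred):
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def r2_score(y_true, y_pred):
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))  # residual sum of squares
    ss_tot = sum((t - mean) ** 2 for t in y_true)  # total sum of squares
    return 1.0 - ss_res / ss_tot


def baseline_rmse(y_train, y_eval):
    # Assumed baseline: always predict the mean of the training targets.
    mean = sum(y_train) / len(y_train)
    return rmse(y_eval, [mean] * len(y_eval))


y_train = [2.0, 3.0, 4.0]
y_val = [1.0, 2.0, 3.0, 4.0]
y_val_pred = [1.0, 2.0, 3.0, 5.0]

# Hypothetical key names; the real tf_model_metadata.json schema may differ.
metadata = {
    "hidden_sizes": [64, 32, 16],
    "val_rmse": rmse(y_val, y_val_pred),
    "baseline_rmse": baseline_rmse(y_train, y_val),
    "val_r2": r2_score(y_val, y_val_pred),
}
path = os.path.join(tempfile.mkdtemp(), "tf_model_metadata.json")
with open(path, "w") as fh:
    json.dump(metadata, fh, indent=2)
```

A model is only useful if its RMSE is clearly below the baseline RMSE; an R² near 0 means it does little better than predicting the mean.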

Performance/ops tips:

  • CPU-only build: choose modest hidden and batch sizes on constrained CPUs.
  • Tuning the batch size and enabling gradient clipping can help stabilise training on small datasets (see the tests for a small-batch configuration).
  • Use requirements.lock for reproducible installs, and mount outputs_tf as a Docker volume in production.
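Gradient clipping bounds the size of each update, which helps when small batches produce noisy, occasionally large gradients. How TrainerTF applies clipping is not documented here; in Keras it is usually enabled through an optimizer's clipnorm or global_clipnorm argument. The NumPy sketch below only illustrates global-norm clipping, the rule used by tf.clip_by_global_norm; assuming TrainerTF uses this exact rule is a guess.

```python
import numpy as np


def clip_by_global_norm(grads, clip_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most clip_norm.

    Same formula as tf.clip_by_global_norm: g * clip_norm / max(global_norm, clip_norm).
    Returns (clipped_grads, global_norm_before_clipping).
    """
    global_norm = float(np.sqrt(sum(np.sum(np.square(g)) for g in grads)))
    scale = clip_norm / max(global_norm, clip_norm)  # 1.0 when already within the limit
    return [g * scale for g in grads], global_norm


grads = [np.array([3.0]), np.array([4.0])]  # joint L2 norm = 5
clipped, norm = clip_by_global_norm(grads, clip_norm=1.0)
```

Because every tensor is scaled by the same factor, global-norm clipping preserves the direction of the overall update and only shrinks its length.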

FastAPI usage

  • When the API is running, /train with model_type=tf trains a TF model and saves its artifacts, including the TF plots, to backend/outputs_tf/.
  • /predict with model_type=tf runs predictions using the TF model.
  • /artifacts/{filename} serves the TF artifacts (model_tf.keras, normalisation_values_tf.npz, tf_model_metadata.json) as well as the NumPy model's artifacts.
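A minimal standard-library client for fetching the TF artifacts from /artifacts/{filename}. It assumes the backend is reachable at http://localhost:8000 (the port used by the Docker setup) and that the endpoint returns the raw file body; BASE_URL, artifact_url and download_artifact are hypothetical names, not part of the package.

```python
import urllib.parse
import urllib.request

BASE_URL = "http://localhost:8000"  # assumption: local Docker deployment
TF_ARTIFACTS = ("model_tf.keras", "normalisation_values_tf.npz", "tf_model_metadata.json")


def artifact_url(filename, base_url=BASE_URL):
    # Percent-encode the filename so it is safe as a single path segment.
    return f"{base_url.rstrip('/')}/artifacts/{urllib.parse.quote(filename, safe='')}"


def download_artifact(filename, dest_path, base_url=BASE_URL, timeout=30):
    # Assumes the endpoint returns the file bytes directly in the response body.
    with urllib.request.urlopen(artifact_url(filename, base_url), timeout=timeout) as resp:
        with open(dest_path, "wb") as fh:
            fh.write(resp.read())
    return dest_path
```

Against a running backend, looping over TF_ARTIFACTS and calling download_artifact(name, name) would copy all three files into the current directory.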

Download files

Source Distribution

fivedreg_tf-0.1.1.tar.gz (2.9 kB)

Built Distribution

fivedreg_tf-0.1.1-py3-none-any.whl (2.7 kB)

File details

Details for the file fivedreg_tf-0.1.1.tar.gz.

File metadata

  • Download URL: fivedreg_tf-0.1.1.tar.gz
  • Size: 2.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.16

File hashes

Hashes for fivedreg_tf-0.1.1.tar.gz
  • SHA256: 71d49567b5a20c701808cc70c6fae7eadaaf0213d832b40792a9d55d2f4fdd86
  • MD5: 3ae33faf4dae2f22f89ce3cc1c0cb0c8
  • BLAKE2b-256: 6cfd6b7f6fc3355fca59665fb7f404d6d87a788b958e0e62d50cf48ce07b6eb0

File details

Details for the file fivedreg_tf-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: fivedreg_tf-0.1.1-py3-none-any.whl
  • Size: 2.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.16

File hashes

Hashes for fivedreg_tf-0.1.1-py3-none-any.whl
  • SHA256: be1676cfc8342eceec176ae66ac2289c729666163a29f80b638e9466d39217c1
  • MD5: b45f6852d8e3ec1dedf2f394d47d5694
  • BLAKE2b-256: 45df309925ae437c376b23b3d9fd85cc090cd692afc48baa1a06be56be0742c4
