
TensorFlow-based 5D→1D regression trainer/tester with plotting.


fivedreg_tf (TensorFlow)

TensorFlow/Keras implementation of the 5D → 1D regressor with a simple training/testing API.

Modules

  • tf_model.py: build_tf_model(hidden_sizes, Lambda) builds a Sequential model with L2 regularisation and He/Xavier init (see the sketch after this list).
  • trainer_tf.py: TrainerTF loads/validates data, trains a TF model, and saves model_tf.keras plus normalisation values. Optional early stopping, LR decay, batch size, and grad clipping.
  • tester_tf.py: TesterTF loads the saved TF model and normalisation stats to make predictions on NumPy arrays or .pkl files.
  • logger.py, utils.py: lightweight logging and decorators.
  • Synthetic data examples use the separate interpy_synth package (installed automatically).
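
For orientation, here is a minimal sketch of what a builder like build_tf_model might look like, assuming only what the module list above states (5D input, L2 factor Lambda, He init for ReLU-family activations, Xavier otherwise). The actual implementation in tf_model.py may differ:

import tensorflow as tf

def build_model_sketch(hidden_sizes, Lambda, activation="relu"):
    # He init suits ReLU/LeakyReLU; Xavier (Glorot) suits tanh/sigmoid.
    init = "he_normal" if activation in ("relu", "leaky_relu") else "glorot_uniform"
    reg = tf.keras.regularizers.l2(Lambda)
    model = tf.keras.Sequential([tf.keras.Input(shape=(5,))])
    for width in hidden_sizes:
        model.add(tf.keras.layers.Dense(
            width, activation=activation,
            kernel_initializer=init, kernel_regularizer=reg))
    model.add(tf.keras.layers.Dense(1))  # linear output: 5D -> 1D regression
    return model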

Installation (package)

PyPI:

pip install fivedreg_tf        # pulls interpy-synth + tensorflow-cpu/macos

From source (editable):

pip install -e .

Headless environments: plotting uses the Agg backend, so no display is required. GPU is not required or supported. On macOS, tensorflow-macos is installed automatically via a platform marker; on Linux and Windows, tensorflow-cpu is used.
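
Internally the headless setup is presumably the standard idiom of selecting the backend before pyplot is imported:

import matplotlib
matplotlib.use("Agg")            # file-only backend, no display needed
import matplotlib.pyplot as plt  # import after the backend is set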

Usage

from fivedreg_tf.trainer_tf import TrainerTF
from fivedreg_tf.tester_tf import TesterTF
from interpy_synth import synthetic_5d_pickle
import os

# Generate a synthetic 5D -> 1D dataset to train on.
out_dir = "outputs_tf"
os.makedirs(out_dir, exist_ok=True)
data_path = synthetic_5d_pickle(os.path.join(out_dir, "train.pkl"), n=1000, seed=42)

# Configure and train; the model and normalisation stats are saved to out_dir.
trainer = TrainerTF(
    directory=out_dir,
    hidden_sizes=[64, 32, 16],
    epochs=100,
    learning_rate=0.01,
    activation="relu",
    weight_init="auto",
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-8,
    early_stop_patience=10,
    lr_decay=0.95,
    seed=42,
)
train_rmse, val_rmse = trainer.train(data_path)

# Reload the saved model and predict on a single 5D point.
tester = TesterTF(directory=out_dir)
y_pred = tester.predict([0.1, 0.2, 0.3, 0.4, 0.5])
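
TesterTF also accepts NumPy arrays, so batch prediction is possible as well. Assuming predict treats a 2D array row-wise, as the module list suggests:

import numpy as np

X = np.random.default_rng(0).uniform(size=(8, 5))  # 8 points in 5D
y_batch = tester.predict(X)                        # one prediction per row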

Note: TensorFlow must be available in your environment to use this package (the PyPI install pulls a CPU build automatically). Training also saves plots (rmse_vs_epochs.png, ytrue_vs_ypred.png) to the output directory. Metadata (tf_model_metadata.json) records hidden sizes, Lambda, activation/init, learning rate, Adam betas/epsilon, batch size and grad clip, epochs run, best epoch, best train/val RMSE, baseline RMSE, and final train/val R².
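
Because the metadata is plain JSON, run summaries can be inspected without loading TensorFlow. The exact key names are not documented here, so treat this as a sketch:

import json, os

with open(os.path.join("outputs_tf", "tf_model_metadata.json")) as f:
    meta = json.load(f)
print(meta)  # e.g. hidden sizes, Lambda, best epoch, RMSEs, R²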

Performance/ops tips:

  • CPU-only build; choose modest hidden sizes/batch sizes for constrained CPUs.
  • Batch size and grad clipping can help stabilise small datasets (see tests for small-batch config).
  • Optimiser: defaults to tf.keras.optimizers.legacy.Adam when available (avoids the slower Apple Silicon path) and falls back to tf.keras.optimizers.Adam otherwise; see the sketch below.
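
One way to implement that fallback (the package's actual logic may differ; values shown are the defaults from the usage example above):

import tensorflow as tf

def make_adam(lr=0.01, b1=0.9, b2=0.999, eps=1e-8):
    # TF 2.11+: the legacy Adam avoids the slow update path on Apple Silicon.
    # Older TF lacks the legacy namespace; Keras 3 rejects it at construction.
    try:
        return tf.keras.optimizers.legacy.Adam(
            learning_rate=lr, beta_1=b1, beta_2=b2, epsilon=eps)
    except (AttributeError, ImportError, ValueError):
        return tf.keras.optimizers.Adam(
            learning_rate=lr, beta_1=b1, beta_2=b2, epsilon=eps)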

Hyperparameter guide (UI/API)

  • hidden_sizes: Layer widths per hidden layer. More/larger layers increase capacity and training time and can overfit small datasets.
  • Lambda: L2 regularisation strength; higher shrinks weights harder to reduce overfitting but can underfit.
  • activation: ReLU default; LeakyReLU avoids dead units; tanh/sigmoid bound outputs but can slow training.
  • weight_init: Auto picks He for ReLU/LeakyReLU and Xavier for tanh/sigmoid; override to experiment.
  • epochs: Full passes over the data. More epochs can fit better but take longer and may overfit.
  • learning_rate: Step size for gradient updates. Higher learns faster but risks divergence; lower is steadier.
  • train_val_split: Fraction for training vs validation/early stopping. Smaller training splits can reduce fit quality.
  • batch_size: Samples per gradient step. Larger batches smooth updates but use more memory; blank/full-batch is allowed.
  • grad_clip: Upper bound on gradient norm to prevent exploding gradients. Lower means more aggressive clipping.
  • lr_decay: Multiplier (<1) applied per epoch to the learning rate. Leave unset to keep LR constant.
  • early_stop_patience: Stop after this many epochs without validation improvement; lower stops sooner to avoid overfitting.
  • beta1 / beta2: Adam momentum terms for first/second moments. Higher values smooth updates but react slower.
  • epsilon: Small constant for numerical stability in Adam; keep default unless debugging NaNs.
  • seed: Set for deterministic initialisation/shuffling; leave unset for nondeterministic runs.
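
Putting the guide together, here is a hedged starting point for a small, noisy dataset (illustrative values, not tuned; keyword names follow the guide above):

from fivedreg_tf.trainer_tf import TrainerTF

trainer = TrainerTF(
    directory="outputs_tf",
    hidden_sizes=[32, 16],     # modest capacity for a small dataset
    Lambda=1e-4,               # light L2 against overfitting
    activation="relu",
    weight_init="auto",        # picks He init for ReLU
    epochs=200,
    learning_rate=0.005,
    train_val_split=0.8,
    batch_size=32,             # small batches; pair with clipping
    grad_clip=1.0,             # tame exploding gradients
    lr_decay=0.98,             # gentle per-epoch LR decay
    early_stop_patience=15,
    seed=0,
)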

Documentation

Full package API documentation, including a complete reference and examples, is hosted on ReadTheDocs.

License

MIT License
