Construct neural network models and training workflows with the structcast package.
StructCast-Model
StructCast-Model is a configuration-driven toolkit that generates PyTorch models and training workflows from YAML templates. Built on top of StructCast, it lets you describe model architecture, optimizer logic, dataset configuration, and training orchestration declaratively — then generates runnable Python code from those descriptions.
The current implementation focuses on PyTorch. JAX and TensorFlow support is planned (see Roadmap), but all active CLI commands, code generators, and tests target the PyTorch stack.
Table of Contents
- StructCast-Model
- Table of Contents
- What This Project Does
- Installation
- Project Structure
- Core Workflow
- StructCast Pattern Basics
- Command Guide
- Configuration Examples
- API Reference: base_trainer.py
- API Reference: trainer.py
- Minimal End-to-End Example
- Development
- Roadmap
What This Project Does
- Generate model code: Produce PyTorch nn.Module classes from YAML layer templates.
- Generate training code: Produce backward-pass, optimizer, and scheduler orchestration classes from YAML templates.
- Format reusable templates: Render parameterized YAML templates into concrete runtime configurations.
- Inspect model complexity: Compute FLOPs and parameter counts with ptflops and calflops.
- Train end-to-end: Run training with Automatic Mixed Precision (AMP), timm datasets, Exponential Moving Average (EMA), optional torch.compile, and MLflow experiment logging.
Installation
StructCast-Model is installed with uv and exposes the scm CLI entry point.
uv sync --extra torch-cu130 --extra mlflow --extra flops
Each extra installs a group of optional dependencies:
| Extra | What it provides |
|---|---|
| torch-cu130 | PyTorch and torchvision with CUDA 13.0 support |
| mlflow | Experiment tracking for scm torch train |
| flops | Both ptflops and calflops for model complexity inspection |
Omit any extra you do not need. For example, uv sync --extra torch-cu130 is sufficient if you only want to generate and train models without FLOPs analysis or MLflow logging.
Project Structure
structcast-model/
├── cfg/
│ ├── backwards/ # backward, optimizer, scheduler templates
│ ├── datasets/ # reusable dataset/dataloader templates
│ ├── losses/ # loss module templates
│ ├── metrics/ # metric module templates
│ ├── models/ # model architecture templates
│ └── others/ # compile and EMA presets
├── src/structcast_model/
│ ├── builders/ # generic and torch-specific code generators
│ ├── commands/ # Typer CLI entry points
│ ├── torch/ # trainer, layers, optimizer helpers
│ ├── utils/ # shared helpers
│ └── base_trainer.py
├── tests/ # CLI, builder, trainer, and layer tests
└── README.md
The main package areas are:
| Directory | Purpose |
|---|---|
| builders/ | Converts validated YAML templates into intermediate representations, then renders Python source code. |
| commands/ | Exposes the scm CLI (built with Typer). |
| torch/ | Runtime utilities used by the CLI and available for direct Python usage: training steps, trackers, timm wrappers, optimizer helpers. |
| cfg/ | Declarative source of truth: YAML templates for models, backward logic, datasets, and runtime presets. |
Core Workflow
The repository follows a repeatable five-step workflow:
- Write or reuse YAML templates under cfg/.
- Render templates with scm format and -p/--parameter overrides to produce concrete configuration files.
- Generate Python source files for the model, loss, metric, and backward logic using scm torch create.
- Instantiate those generated modules at runtime through StructCast object patterns (see StructCast Pattern Basics).
- Train through scm torch train, which wires together datasets, models, losses, metrics, optimizer logic, AMP, EMA, and MLflow.
YAML templates ---> scm format / scm torch create ---> Generated .py files
|
StructCast patterns <-------------------------------------------+
|
v
scm torch train ---> MLflow logs + model checkpoints
StructCast Pattern Basics
This repository relies heavily on StructCast object patterns to bridge generated source files and runtime commands. The minimum syntax you need to read the CLI examples is:
| Alias | Meaning | Example |
|---|---|---|
| _obj_ | Chain multiple construction steps | [_obj_, ..., ...] |
| _addr_ | Import a class or function by dotted path | {_addr_: torch.nn.ReLU} |
| _file_ | Load the symbol from a local Python file | {_addr_: model.Model, _file_: model.py} |
| _call_ | Invoke the current callable | _call_ or {_call_: {out_features: 10}} |
| _bind_ | Partially apply arguments | {_bind_: {lr: 0.001}} |
| _attr_ | Access an attribute or method | {_attr_: model_validate} |
Example:
[_obj_, {_addr_: model.Model, _file_: model.py}, _call_]
This pattern does the following:
- Import Model from the local file model.py.
- Call Model() with no arguments and return the instance.
This pattern is the bridge between generated source files and runtime commands like ptflops, calflops, and train. For full documentation on StructCast patterns, see the StructCast README.
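To make the step-by-step reading of a pattern concrete, here is a minimal sketch of how such a chain could be evaluated. It is not StructCast's implementation: the `resolve` function is hypothetical, it accepts already-parsed Python lists and dicts rather than YAML text, and it omits `_file_` handling entirely. Standard-library symbols are used as import targets so the snippet runs without any generated files.

```python
import importlib
from functools import partial


def resolve(pattern):
    """Evaluate a toy subset of an [_obj_, ...] pattern, one step at a time."""
    if not (isinstance(pattern, list) and pattern and pattern[0] == "_obj_"):
        raise ValueError("expected a list starting with '_obj_'")
    current = None
    for step in pattern[1:]:
        if step == "_call_":
            # Bare _call_: invoke the current callable with no arguments.
            current = current()
        elif not isinstance(step, dict):
            raise ValueError(f"unsupported step: {step!r}")
        elif "_addr_" in step:
            # Import a symbol by dotted path, e.g. "fractions.Fraction".
            module_name, _, symbol = step["_addr_"].rpartition(".")
            current = getattr(importlib.import_module(module_name), symbol)
        elif "_call_" in step:
            # Invoke the current callable with keyword arguments.
            current = current(**step["_call_"])
        elif "_bind_" in step:
            # Partially apply keyword arguments without calling yet.
            current = partial(current, **step["_bind_"])
        elif "_attr_" in step:
            # Look up an attribute or method on the current object.
            current = getattr(current, step["_attr_"])
        else:
            raise ValueError(f"unsupported step: {step!r}")
    return current


three_quarters = resolve(
    ["_obj_", {"_addr_": "fractions.Fraction"}, {"_call_": {"numerator": 3, "denominator": 4}}]
)
parse_binary = resolve(["_obj_", {"_addr_": "builtins.int"}, {"_bind_": {"base": 2}}])
empty_keys = resolve(
    ["_obj_", {"_addr_": "collections.OrderedDict"}, "_call_", {"_attr_": "keys"}, "_call_"]
)
print(three_quarters, parse_binary("101"), list(empty_keys))
```

Each step replaces the "current" object with the result of the step, which is why the order of entries in the list matters.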
Command Guide
1. Format Templates
Use scm format to render a parameterized YAML template (such as cfg/datasets/default_timm.yaml) into a concrete configuration file.
scm format cfg/datasets/default_timm.yaml \
-o dataset_train.yaml \
-p 'DEFAULT: {training: true, epochs: 5, batch_size: 32, dataset: torch/cifar100, num_classes: 100, label_smoothing: 0.1, input_size: [3, 224, 224], image_dtype: bfloat16, download: true}'
scm format cfg/datasets/default_timm.yaml \
-o dataset_valid.yaml \
-p 'DEFAULT: {training: false, epochs: 5, batch_size: 32, dataset: torch/cifar100, num_classes: 100, input_size: [3, 224, 224], image_dtype: bfloat16, download: true}'
What this does:
- Loads the YAML template.
- Merges any repeated -p/--parameter groups into a single parameter set.
- Renders Jinja-based sections within the template.
- Writes the resolved YAML to -o/--output (or prints to stdout if -o is omitted).
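The merge-then-render sequence above can be sketched in a few lines. This is an illustration under stated assumptions, not the scm format implementation: `merge_parameter_groups` and `render` are hypothetical names, the merge rule (later values win per key within a group) is assumed, and the standard-library `string.Template` stands in for Jinja so the snippet has no third-party dependencies.

```python
from string import Template


def merge_parameter_groups(*groups):
    """Merge repeated parameter groups; later values win per key (assumed rule)."""
    merged = {}
    for group in groups:
        for name, values in group.items():
            merged.setdefault(name, {}).update(values)
    return merged


def render(template_text, parameters):
    """Fill $name placeholders from a single parameter group."""
    return Template(template_text).substitute(parameters)


# Two groups as they might arrive from repeated -p options.
params = merge_parameter_groups(
    {"DEFAULT": {"batch_size": 32, "training": True}},
    {"DEFAULT": {"batch_size": 64}},
)
text = render("batch_size: $batch_size\ntraining: $training\n", params["DEFAULT"])
print(text)
```

In this sketch the second group overrides only batch_size, while training is carried over from the first group.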
2. Generate a Model Class
Generate a Python nn.Module from a YAML layer template (such as cfg/models/ConvNeXtV2.yaml).
scm torch create model cfg/models/ConvNeXtV2.yaml
scm torch create model cfg/models/ConvNeXtV2.yaml -p 'DEFAULT: {backbone: femto}'
scm torch create model cfg/models/ConvNeXtV2.yaml -p 'DEFAULT: {backbone: femto}' -o model.py
Useful options:
- -p/--parameter: override template parameters
- -c/--classname: set the generated class name (default: Model)
- --no-structured-output: force tuple-like return behavior instead of a structured output mapping
- -s/--sublayer: generate a named sublayer from the template instead of the root model
- -o/--output: output file path; if omitted, defaults to the snake-cased class name in the current directory (e.g., model.py for the default class name Model)
The ConvNeXtV2 template uses Jinja parameter groups to switch between backbone variants such as atto, femto, tiny, and base.
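The default output path for -o/--output is described as the snake-cased class name. A minimal sketch of one such conversion follows; `default_output_path` is a hypothetical helper and the exact rule the CLI applies (for example, how it treats runs of capitals) is an assumption.

```python
import re
from pathlib import Path


def default_output_path(classname: str) -> Path:
    """Derive <snake_case_classname>.py relative to the current directory."""
    # Insert an underscore before each capital that follows a lowercase letter or digit.
    snake = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", classname).lower()
    return Path(f"{snake}.py")


print(default_output_path("Model"), default_output_path("MyLoss"))
```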
3. Generate Loss, Metric, and Backward Classes
Losses and metrics use the same scm torch create model command because they are also layer graphs.
scm torch create model cfg/losses/cls.yaml -c Loss -o loss.py
scm torch create model cfg/metrics/topk.yaml -c Metric -o metric.py
scm torch create backward cfg/backwards/ConvNeXtV2.yaml -p 'DEFAULT: {epochs: 5}' -o backward.py
The scm torch create backward command turns a backward template into a class that manages:
- optimizer construction
- optional gradient scaler creation
- optional gradient clipping
- optional gradient accumulation
- optimizer stepping and zeroing
- learning-rate and parameter-group inspection helpers
4. Inspect FLOPs and Parameters
Once a model has been generated, you can instantiate it from a StructCast pattern and measure its computational complexity.
scm torch ptflops '[_obj_, {_addr_: model.Model, _file_: model.py}, _call_]' \
-s 'image: [3, 224, 224]' \
--backend pytorch
scm torch calflops '[_obj_, {_addr_: model.Model, _file_: model.py}, _call_]' \
-s 'image: [3, 224, 224]'
What these commands do internally:
- Instantiate the model from the _obj_ pattern.
- Create dummy tensors from the -s/--shape specification.
- Run one initialization forward pass via initial_model(...).
- Pass the initialized model to ptflops or calflops for complexity analysis.
5. Train a Generated Model
Below is the complete training command from the included ConvNeXtV2 example.
scm torch train \
'model: [_obj_, {_addr_: model.Model, _file_: model.py}, _call_]' \
-s 'image: [3, 224, 224]' \
-d cuda \
--ema cfg/others/ema.yaml \
-l '[_obj_, {_addr_: loss.Loss, _file_: loss.py}, _call_]' \
-m '[_obj_, {_addr_: metric.Metric, _file_: metric.py}, _call_]' \
-b '[_obj_, {_addr_: backward.Backward, _file_: backward.py}]' \
-c cfg/others/compile_default.yaml \
-e 5 \
-t dataset_train.yaml \
-v dataset_valid.yaml \
-f 1 \
-LC ce_loss \
-LC val_ce_loss \
-HC acc1 \
-HC val_acc1 \
-HC acc5 \
-HC val_acc5 \
-SC val_acc1 \
--matmul-precision high \
-E Test \
-A model.py \
-A cfg/others/ema.yaml \
-A loss.py \
-A metric.py \
-A backward.py \
-A cfg/others/compile_default.yaml \
-A dataset_train.yaml \
-A dataset_valid.yaml
Key arguments:
- positional model patterns: one or more named model definitions
- -s/--shape: dummy input shapes used for model initialization
- -d/--device: cpu or cuda
- --ema: boolean, YAML file, or inline dict for timm.utils.ModelEmaV3
- -l/--loss: StructCast pattern for the loss module
- -m/--metric: StructCast pattern for the metric module
- -b/--backward: StructCast pattern for the backward class
- -c/--compile: boolean, YAML file, or inline dict for torch.compile
- -t/--training-dataset: training dataset pattern or rendered dataset YAML
- -v/--validation-dataset: validation dataset pattern or rendered dataset YAML
- -LC/--lower-criterion: criteria where lower is better
- -HC/--higher-criterion: criteria where higher is better
- -SC/--save-criterion: criteria that should trigger best-model saving
- -E/--experiment: MLflow experiment name
- -A/--log-artifacts: artifacts to store in MLflow
What the train command does internally:
- Instantiates datasets and determines their lengths.
- Initializes models with optional dummy-input forward passes.
- Instantiates loss, metric, backward, compile, and EMA objects.
- Builds a TorchTracker from the declared output names.
- Creates a TorchTrainer with training and validation step objects.
- Logs metrics, arguments, model states, optimizer states, gradient scaler states, and best checkpoints to MLflow.
Configuration Examples
The cfg/ directory contains working YAML templates that demonstrate each part of the workflow.
cfg/models/ConvNeXtV2.yaml
Demonstrates the model-building style used throughout the project:
- parameter groups for multiple backbone sizes
- nested user-defined layers such as Backbone, Stem, DownSample, and Block
- Jinja-driven layer expansion
- separate training and inference flow support
- structured outputs such as {cls: torch.tensor(...), ...}
cfg/backwards/ConvNeXtV2.yaml
Demonstrates how backward logic is configured declaratively:
- MIXED_PRECISION for torch.amp.GradScaler
- MIXED_PRECISION_TYPE for autocast dtype
- ACCUMULATE_GRADIENTS for delayed optimizer updates
- optimizer creation through structcast_model.torch.optimizers.create_with_scheduler
- optional gradient clipping via timm.utils.clip_grad.dispatch_clip_grad
cfg/datasets/default_timm.yaml
Formats directly into a TimmDataLoaderWrapper.model_validate(...) pattern. Covers:
- timm dataset construction
- timm dataloader construction
- device and prefetch settings
- mixup and cutmix options
- train or validation split generation from one template
API Reference: base_trainer.py
src/structcast_model/base_trainer.py provides the framework-agnostic training loop, state management, and callback system. Concrete trainers such as TorchTrainer build on top of these abstractions.
Utility functions
get_dataset(dataset)
Resolves a DatasetLike or a zero-argument callable into an actual iterable. This allows lazy dataset construction.
get_dataset_size(dataset)
Returns the number of batches. Uses __len__ when available, otherwise iterates to count.
invoke_callback(callbacks, info, *args, **models)
Iterates over a callback list and calls each entry with info and keyword model arguments.
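The behavior of these three helpers can be approximated in plain Python. The sketch below uses different names (`resolve_dataset`, `count_batches`, `run_callbacks`) to make clear that it is a hypothetical re-creation based only on the descriptions above, not the code in base_trainer.py.

```python
def resolve_dataset(dataset):
    """Return the iterable itself, or call a zero-argument factory to build it."""
    return dataset() if callable(dataset) else dataset


def count_batches(dataset):
    """Prefer len(); otherwise iterate once and count the batches."""
    data = resolve_dataset(dataset)
    if hasattr(data, "__len__"):
        return len(data)
    # Note: counting by iteration consumes a one-shot iterator.
    return sum(1 for _ in data)


def run_callbacks(callbacks, info, *args, **models):
    """Call every hook with the shared info object and the named models."""
    for callback in callbacks:
        callback(info, *args, **models)


seen = []
run_callbacks(
    [lambda info, **models: seen.append((info, sorted(models)))],
    "epoch-1",
    model="net",
    ema="net-ema",
)
print(count_batches([1, 2, 3]), count_batches(lambda: iter(range(4))), seen)
```

Passing a factory such as `lambda: iter(range(4))` is what makes lazy construction possible: nothing is built until the helper calls it.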
Protocols
Forward
Called once per batch during training or validation. Accepts an inputs dictionary and keyword model arguments; returns a dict[str, Any] of named outputs and criteria.
Backward
Called once per training step. Receives the step index and criterion keyword arguments; returns True when the optimizer has stepped, False when gradients are being accumulated.
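The True/False return contract is easiest to see with gradient accumulation. The following toy class is a sketch only: `AccumulatingBackward` is a hypothetical name, it updates a single scalar weight instead of a real optimizer, it treats each criterion value as an already computed gradient (a stand-in for calling backward on a loss), and it assumes the step index starts at 1.

```python
class AccumulatingBackward:
    """Toy Backward-style callable for one scalar parameter (illustrative only)."""

    def __init__(self, weight: float, lr: float, accumulate: int) -> None:
        self.weight = weight
        self.lr = lr
        self.accumulate = accumulate
        self._grad_sum = 0.0

    def __call__(self, step: int, **criteria: float) -> bool:
        # Stand-in for loss.backward(): each criterion value is treated as a gradient.
        self._grad_sum += sum(criteria.values())
        if step % self.accumulate != 0:
            return False  # still accumulating, no optimizer update yet
        # "Optimizer step" on the averaged gradient, then zero the accumulator.
        self.weight -= self.lr * self._grad_sum / self.accumulate
        self._grad_sum = 0.0
        return True


backward = AccumulatingBackward(weight=1.0, lr=0.5, accumulate=2)
stepped = [backward(step, grad=2.0) for step in range(1, 5)]
print(stepped, backward.weight)
```

With accumulate=2, only every second call reports an optimizer update, which is the signal a trainer could use to advance its update counter and fire on_update callbacks.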
Callback and BestCallback
Lifecycle hooks called with (info: BaseInfo, **models). BestCallback additionally receives target: str and best: float arguments.
InferenceWrapper
Applied to models before each validation epoch. Returns a remapped model dictionary, e.g., swapping a trained model for its EMA copy.
State and callbacks
BaseInfo
Dataclass holding mutable training state:
- step: total training steps taken
- update: optimizer update count
- epoch: current epoch number
- history: per-epoch log dictionaries
- logs(epoch=None): returns the log dict for the current (or given) epoch
Callbacks
Dataclass holding callback lists for each lifecycle hook:
- on_update: after each optimizer update
- on_training_begin / on_training_end
- on_training_step_begin / on_training_step_end
- on_validation_begin / on_validation_end
- on_validation_step_begin / on_validation_step_end
- on_epoch_begin / on_epoch_end
When add_global_callbacks=True (the default), entries from GLOBAL_CALLBACKS are copied into each list at construction time.
GLOBAL_CALLBACKS
A shared Callbacks[Any] instance. Callbacks registered here are automatically picked up by every newly created trainer.
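The copy-at-construction behavior can be illustrated with a reduced dataclass. This is a sketch, not the library's Callbacks class: `ToyCallbacks` and `TOY_GLOBAL` are hypothetical, only two hooks are modeled, and placing global entries before local ones is an assumption.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List

Hook = Callable[..., Any]


@dataclass
class ToyCallbacks:
    on_epoch_begin: List[Hook] = field(default_factory=list)
    on_epoch_end: List[Hook] = field(default_factory=list)
    add_global_callbacks: bool = True

    def __post_init__(self) -> None:
        if self.add_global_callbacks:
            # Copy the shared entries instead of aliasing the lists, so later
            # local appends do not leak back into the global registry.
            self.on_epoch_begin = [*TOY_GLOBAL.on_epoch_begin, *self.on_epoch_begin]
            self.on_epoch_end = [*TOY_GLOBAL.on_epoch_end, *self.on_epoch_end]


# The registry itself must not try to copy from itself.
TOY_GLOBAL = ToyCallbacks(add_global_callbacks=False)


def log_epoch(info, **models):
    print("epoch finished:", info)


TOY_GLOBAL.on_epoch_end.append(log_epoch)
with_globals = ToyCallbacks()
isolated = ToyCallbacks(add_global_callbacks=False)
with_globals.on_epoch_end.append(print)  # local only
print(len(with_globals.on_epoch_end), len(isolated.on_epoch_end), len(TOY_GLOBAL.on_epoch_end))
```

Because the copy happens at construction time, a hook registered globally after an instance exists would not reach that instance in this sketch.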
Core classes
BaseTrainer
The main training loop driver. Inherits both BaseInfo and Callbacks.
Required fields: training_step (Forward), backward (Backward), tracker (callable returning dict[str, float]).
Optional fields: validation_step, inference_wrapper, training_prefix (default ""), validation_prefix (default "val_").
Key methods:
- train(dataset, **models): runs one training epoch, returns the final step logs
- evaluate(dataset, **models): runs one validation epoch, returns the final step logs
- fit(epochs, training_dataset, validation_dataset=None, start_epoch=1, validation_frequency=1, **models): runs the full loop and returns the complete history dict
- sync(): optional synchronization hook, no-op by default (overridden in TorchTrainer)
trainer = MyTrainer(
training_step=my_forward,
backward=my_backward,
tracker=my_tracker,
validation_step=my_val_forward,
)
history = trainer.fit(
epochs=10,
training_dataset=train_loader,
validation_dataset=val_loader,
model=model,
)
BestCriterion
A callable that monitors a log key and fires on_best callbacks whenever a new best is found. Attach it to on_epoch_end or on_validation_end.
checkpoint = BestCriterion(
target="val_acc1",
mode="max",
on_best=[save_checkpoint],
)
trainer.on_epoch_end.append(checkpoint)
Fields: target (str), mode ("min" or "max", default "min"), on_best (list of BestCallback).
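For readers who want to see the monitoring logic spelled out, here is a minimal sketch of a best-value tracker with the same three fields. `ToyInfo` and `ToyBestCriterion` are hypothetical stand-ins; in particular, passing target and best to the hooks as keyword arguments and skipping epochs where the key is missing are assumptions.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional


@dataclass
class ToyInfo:
    epoch: int = 0
    history: Dict[int, Dict[str, float]] = field(default_factory=dict)

    def logs(self, epoch: Optional[int] = None) -> Dict[str, float]:
        """Return (and create if needed) the log dict for an epoch."""
        return self.history.setdefault(self.epoch if epoch is None else epoch, {})


@dataclass
class ToyBestCriterion:
    target: str
    mode: str = "min"
    on_best: List[Callable[..., Any]] = field(default_factory=list)
    best: Optional[float] = None

    def __call__(self, info: ToyInfo, **models: Any) -> None:
        value = info.logs().get(self.target)
        if value is None:
            return  # the monitored key was not logged this epoch
        if self.best is None:
            improved = True
        elif self.mode == "max":
            improved = value > self.best
        else:
            improved = value < self.best
        if improved:
            self.best = value
            for callback in self.on_best:
                callback(info, target=self.target, best=value, **models)


saved = []
criterion = ToyBestCriterion(
    "val_acc1",
    mode="max",
    on_best=[lambda info, target, best, **models: saved.append((info.epoch, best))],
)
info = ToyInfo()
for epoch, acc in enumerate([0.5, 0.4, 0.7], start=1):
    info.epoch = epoch
    info.logs()["val_acc1"] = acc
    criterion(info)
print(saved)
```

The hook fires on epochs 1 and 3 only, since epoch 2 does not improve on the best value seen so far.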
API Reference: trainer.py
src/structcast_model/torch/trainer.py contains the PyTorch-specific runtime layer.
Utility functions
create_torch_inputs(shape)
Creates dummy float32 tensors from tuple, list, or dict shape descriptions. Used for model initialization and FLOPs inspection.
get_torch_device(device=None)
Returns the runtime device. Selects cuda when available and requested, otherwise falls back to cpu.
initial_model(model, shapes=None, compile_fn=None)
Walks a module or nested module structure, optionally builds dummy inputs, runs a forward pass, and applies a compile function to each module. Returns:
(initialized_model, inputs, outputs)
get_autocast(mixed_precision_type, device)
Returns a context manager for automatic mixed precision:
- contextlib.suppress when AMP is disabled.
- A configured torch.autocast(...) partial when AMP is enabled.
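The two return values share one calling convention: both are zero-argument factories that produce a context manager. The sketch below shows the pattern without requiring PyTorch. `fake_autocast` is a hypothetical stand-in for torch.autocast, `make_autocast` is not the library function, and using None to mean "AMP disabled" is an assumption.

```python
import contextlib
from functools import partial

events = []


@contextlib.contextmanager
def fake_autocast(device_type, dtype):
    """Hypothetical stand-in for torch.autocast that only records entry and exit."""
    events.append(f"enter {device_type}/{dtype}")
    try:
        yield
    finally:
        events.append("exit")


def make_autocast(mixed_precision_type, device):
    if mixed_precision_type is None:
        # contextlib.suppress() called with no exception types suppresses
        # nothing, so it behaves as a do-nothing context manager.
        return contextlib.suppress
    return partial(fake_autocast, device_type=device, dtype=mixed_precision_type)


with make_autocast(None, "cuda")():
    events.append("no amp")
with make_autocast("bfloat16", "cuda")():
    events.append("amp")
print(events)
```

Returning a factory rather than an entered context lets a step object open a fresh context for every batch.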
Step objects
TrainingStep
TrainingStep chains one or more models, updates a shared output dictionary, computes losses, and optionally computes metrics.
step = TrainingStep(
models=["model"],
losses=loss_module,
metrics=metric_module,
autocast=get_autocast("bfloat16", "cuda"),
)
criteria = step({"image": image, "label": label}, model=model)
ValidationStep
Same interface as TrainingStep, but always executes under torch.no_grad().
Tracking and orchestration
TorchTracker
Wraps CriteriaTracker instances for losses and metrics, resets them through global callbacks, and returns float-valued logs suitable for history storage and MLflow logging.
tracker = TorchTracker.from_criteria(["ce_loss"], ["acc1", "acc5"])
logs = tracker(ce_loss=loss_tensor, acc1=acc1_tensor, acc5=acc5_tensor)
TorchTrainer
TorchTrainer extends the generic BaseTrainer with PyTorch-specific synchronization.
trainer = TorchTrainer(
device="cuda",
training_step=TrainingStep(models=["model"], losses=loss_module, metrics=metric_module),
validation_step=ValidationStep(models=["model"], losses=loss_module, metrics=metric_module),
backward=backward,
tracker=tracker,
)
history = trainer.fit(
epochs=5,
training_dataset=train_loader,
validation_dataset=valid_loader,
model=model,
)
timm integrations
TimmDatasetWrapper
Holds validated dataset configuration and lazily calls timm.data.create_dataset(...).
TimmDataLoaderWrapper
Builds a timm dataloader with support for:
- Prefetching
- Channels-last memory format conversion
- Mixup and cutmix data augmentation
- Train/validation-specific augmentation settings
- Distributed device initialization
- Optional FlexSpec output remapping
The dataset template at cfg/datasets/default_timm.yaml formats into this wrapper.
TimmEmaWrapper
Creates and updates timm.utils.ModelEmaV3 instances and swaps them into inference-time evaluation when appropriate.
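The two responsibilities described here, maintaining an averaged copy and swapping it in for evaluation, can be sketched with plain dictionaries. This is not the TimmEmaWrapper code: `ema_update` and `use_ema_for_inference` are hypothetical helpers, parameters are plain floats rather than tensors, the update is the standard exponential moving average rule without any decay warmup, and the keyword-argument calling convention of the wrapper is an assumption.

```python
def ema_update(ema, params, decay):
    """Blend the live parameters into the averaged copy, in place."""
    for name, value in params.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value


def use_ema_for_inference(ema_models):
    """Build a wrapper that swaps in an EMA copy wherever one exists."""

    def wrapper(**models):
        return {name: ema_models.get(name, model) for name, model in models.items()}

    return wrapper


params = {"w": 1.0}
ema = dict(params)  # the averaged copy starts as a snapshot of the live weights
params["w"] = 2.0  # pretend one optimizer update happened
ema_update(ema, params, decay=0.5)

wrapper = use_ema_for_inference({"model": "ema-model"})
remapped = wrapper(model="live-model", head="live-head")
print(ema, remapped)
```

Models without an EMA copy (here, head) pass through unchanged, so the validation step can keep using the same keyword names.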
Minimal End-to-End Example
uv sync --extra torch-cu130 --extra mlflow --extra flops
scm torch create model cfg/models/ConvNeXtV2.yaml -p 'DEFAULT: {backbone: femto}' -o model.py
scm torch create model cfg/losses/cls.yaml -c Loss -o loss.py
scm torch create model cfg/metrics/topk.yaml -c Metric -o metric.py
scm torch create backward cfg/backwards/ConvNeXtV2.yaml -p 'DEFAULT: {epochs: 5}' -o backward.py
scm format cfg/datasets/default_timm.yaml \
-o dataset_train.yaml \
-p 'DEFAULT: {training: true, epochs: 5, batch_size: 32, dataset: torch/cifar100, num_classes: 100, label_smoothing: 0.1, input_size: [3, 224, 224], image_dtype: bfloat16, download: true}'
scm format cfg/datasets/default_timm.yaml \
-o dataset_valid.yaml \
-p 'DEFAULT: {training: false, epochs: 5, batch_size: 32, dataset: torch/cifar100, num_classes: 100, input_size: [3, 224, 224], image_dtype: bfloat16, download: true}'
scm torch train \
'model: [_obj_, {_addr_: model.Model, _file_: model.py}, _call_]' \
-s 'image: [3, 224, 224]' \
-d cuda \
--ema cfg/others/ema.yaml \
-l '[_obj_, {_addr_: loss.Loss, _file_: loss.py}, _call_]' \
-m '[_obj_, {_addr_: metric.Metric, _file_: metric.py}, _call_]' \
-b '[_obj_, {_addr_: backward.Backward, _file_: backward.py}]' \
-c cfg/others/compile_default.yaml \
-e 5 \
-t dataset_train.yaml \
-v dataset_valid.yaml \
-f 1 \
-LC ce_loss \
-LC val_ce_loss \
-HC acc1 \
-HC val_acc1 \
-HC acc5 \
-HC val_acc5 \
-SC val_acc1 \
--matmul-precision high \
-E Test
Development
Set up the development environment with:
uv sync --extra torch-cpu --dev --group tox
Run the test suite:
pytest
Run static type checks:
mypy src
mypy tests
Run linting and formatting:
ruff check src tests
ruff format src tests
The repository includes tests for:
- CLI behavior
- Builder code generation
- Schema validation
- Trainer utilities
- timm dataset and dataloader wrappers
- Custom torch layers
Roadmap
- PyTorch model construction from YAML configuration files (implemented)
- PyTorch training workflow generation from YAML configuration files (implemented)
- JAX model construction from YAML configuration files (planned)
- JAX training workflow generation from YAML configuration files (planned)
- TensorFlow model construction from YAML configuration files (planned)
- TensorFlow training workflow generation from YAML configuration files (planned)