Construct neural network models and training workflows with the StructCast package.
StructCast-Model
StructCast-Model is a configuration-driven toolkit that generates PyTorch, Flax (JAX), and Keras models — plus PyTorch training workflows — from YAML templates. Built on top of StructCast, it lets you describe model architecture, optimizer logic, dataset configuration, and training orchestration declaratively — then generates runnable Python code from those descriptions.
Model code generation is available for all three frameworks. Training workflow generation and the full training CLI (scm torch train) are currently PyTorch-only; Flax and Keras training support is planned (see Roadmap).
Table of Contents
- StructCast-Model
- Table of Contents
- What This Project Does
- Installation
- Project Structure
- Core Workflow
- StructCast Pattern Basics
- Command Guide
- Distributed Training with torchrun
- Configuration Examples
- Schema Reference
- API Reference: base_trainer.py
- API Reference: trainer.py
- Minimal End-to-End Example
- Development
- Migration Notes
- Roadmap
What This Project Does
- Generate model code: Produce PyTorch nn.Module, Flax nnx.Module, and Keras Layer classes from YAML layer templates.
- Generate training code: Produce backward-pass, optimizer, and scheduler orchestration classes from YAML templates (PyTorch only).
- Format reusable templates: Render parameterized YAML templates into concrete runtime configurations.
- Inspect model complexity: Compute FLOPs and parameter counts with ptflops and calflops (PyTorch only).
- Measure inference time: Benchmark the average forward-pass latency of generated models across all three frameworks via scm [torch/flax/keras] time.
- Train end-to-end: Run PyTorch training with Automatic Mixed Precision (AMP), timm datasets, optional torch.compile, and MLflow experiment logging.
Installation
StructCast-Model is installed with uv and exposes the scm CLI entry point.
uv sync --extra torch-cu130 --extra mlflow --extra flops
Each extra installs a group of optional dependencies. Pick the extras that match your target framework and accelerator.
PyTorch
| Extra | What it provides |
|---|---|
| torch-cpu | PyTorch and torchvision (CPU only) |
| torch-cu118 | PyTorch and torchvision with CUDA 11.8 support |
| torch-cu126 | PyTorch and torchvision with CUDA 12.6 support |
| torch-cu128 | PyTorch and torchvision with CUDA 12.8 support |
| torch-cu130 | PyTorch and torchvision with CUDA 13.0 support |
JAX / Flax
| Extra | What it provides |
|---|---|
| jax-cpu | JAX and Flax (CPU only) |
| jax-cu12 | JAX and Flax with CUDA 12 support |
| jax-cu13 | JAX and Flax with CUDA 13 support |
TensorFlow
| Extra | What it provides |
|---|---|
| tf-cpu | TensorFlow (CPU only) |
| tf-cu12 | TensorFlow with CUDA 12 support |
Keras (multi-backend)
Keras runs on top of JAX, PyTorch, or TensorFlow. Choose the extra that matches your preferred backend:
| Extra | Backend + accelerator |
|---|---|
| keras-jax-cpu | Keras with JAX (CPU) |
| keras-jax-cu12 | Keras with JAX (CUDA 12) |
| keras-jax-cu13 | Keras with JAX (CUDA 13) |
| keras-torch-cpu | Keras with PyTorch (CPU) |
| keras-torch-cu118 | Keras with PyTorch (CUDA 11.8) |
| keras-torch-cu126 | Keras with PyTorch (CUDA 12.6) |
| keras-torch-cu128 | Keras with PyTorch (CUDA 12.8) |
| keras-torch-cu130 | Keras with PyTorch (CUDA 13.0) |
| keras-tf-cpu | Keras with TensorFlow (CPU) |
| keras-tf-cu12 | Keras with TensorFlow (CUDA 12) |
Bundles
| Extra | What it provides |
|---|---|
| all-cpu | JAX + Flax, PyTorch + torchvision + timm, TensorFlow, and Keras (all CPU-only) |
| all-cuda | Same as all-cpu but with CUDA acceleration for every backend |
Tools
| Extra | What it provides |
|---|---|
| ptflops | ptflops for model complexity inspection |
| calflops | calflops and Transformers for model complexity inspection |
| flops | Both ptflops and calflops |
| mlflow | MLflow experiment tracking for scm torch train |
Omit any extra you do not need. For example, uv sync --extra torch-cu130 is sufficient if you only want to generate and train PyTorch models without FLOPs analysis or MLflow logging. To work with all three model frameworks on CPU:
uv sync --extra all-cpu
Project Structure
structcast-model/
├── cfg/
│ ├── torch/
│ │ ├── backwards/ # backward, optimizer, scheduler templates
│ │ ├── datasets/ # reusable dataset/dataloader templates
│ │ ├── losses/ # loss module templates
│ │ ├── metrics/ # metric module templates
│ │ ├── models/ # model architecture templates
│ │ └── others/ # misc templates (e.g. for `torch.compile` options)
│ ├── flax/
│ │ └── models/ # Flax model architecture templates
│ └── keras/
│ └── models/ # Keras model architecture templates
├── src/structcast_model/
│ ├── builders/ # generic and framework-specific code generators
│ ├── commands/ # Typer CLI entry points
│ ├── torch/ # trainer, layers, optimizer helpers
│ ├── flax/ # Flax layers and inference utilities
│ ├── keras/ # Keras layers and inference utilities
│ ├── utils/ # shared helpers
│ └── base_trainer.py
├── tests/ # CLI, builder, trainer, and layer tests
└── README.md
The main package areas are:
| Directory | Purpose |
|---|---|
| builders/ | Converts validated YAML templates into intermediate representations, then renders Python source code for PyTorch, Flax, and Keras. |
| commands/ | Exposes the scm CLI (built with Typer) with torch, flax, and keras sub-commands. |
| torch/ | Runtime utilities used by the CLI and available for direct Python use: training steps, trackers, timm wrappers, and optimizer helpers. |
| flax/ | Flax-specific layers (e.g. GlobalResponseNorm) and JAX inference helpers. |
| keras/ | Keras-specific layers (e.g. GlobalResponseNormalization) and backend-agnostic inference helpers. |
| cfg/torch/ | Declarative source of truth: YAML templates for PyTorch models, backward logic, datasets, and runtime presets. |
| cfg/flax/ | YAML templates for Flax model architectures. |
| cfg/keras/ | YAML templates for Keras model architectures. |
Core Workflow
The repository follows a repeatable workflow:
- Write or reuse YAML templates under cfg/[torch/flax/keras]/.
- Render templates with scm format and -p/--parameter overrides to produce concrete configuration files.
- Generate Python source files for the model (and, for PyTorch, the loss, metric, and backward logic) using scm [torch/flax/keras] create.
- Instantiate those generated modules at runtime through StructCast object patterns (see StructCast Pattern Basics).
- Benchmark inference latency with scm [torch/flax/keras] time.
- (PyTorch only) Train through scm torch train, which wires together datasets, models, losses, metrics, optimizer logic, AMP, and MLflow.
YAML templates ---> scm format / scm [torch/flax/keras] create ---> Generated .py files
                                                                            |
StructCast patterns <-------------------------------------------------------+
         |
         v
scm [torch/flax/keras] time ---> Inference benchmarks
scm torch train             ---> MLflow logs + model checkpoints
StructCast Pattern Basics
This repository relies heavily on StructCast object patterns to bridge generated source files and runtime commands. The minimum syntax you need to read the CLI examples is:
| Alias | Meaning | Example |
|---|---|---|
| _obj_ | Chain multiple construction steps | [_obj_, ..., ...] |
| _addr_ | Import a class or function by dotted path | {_addr_: torch.nn.ReLU} |
| _file_ | Load the symbol from a local Python file | {_addr_: model.Model, _file_: model.py} |
| _call_ | Invoke the current callable | _call_ or {_call_: {out_features: 10}} |
| _bind_ | Partially apply arguments | {_bind_: {lr: 0.001}} |
| _attr_ | Access an attribute or method | {_attr_: model_validate} |
Example:
[_obj_, {_addr_: model.Model, _file_: model.py}, _call_]
This pattern does the following:
- Import Model from the local file model.py.
- Call Model() with no arguments and return the instance.
This pattern is the bridge between generated source files and runtime commands like ptflops, calflops, and train. For full documentation on StructCast patterns, see the StructCast README.
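The toy resolver below makes the alias table concrete by applying the steps of an _obj_ pattern in order. It is a sketch that assumes each alias behaves exactly as the table describes. It is not StructCast's implementation, which supports more aliases and validation. The resolve and _load_addr names are hypothetical. YAML parsing is skipped, so each pattern is written as a Python list, and standard-library classes stand in for models.

```python
"""Toy resolver for StructCast-style object patterns (illustration only).

Assumption: each alias behaves as described in the alias table; this is a
hypothetical sketch, not StructCast's real implementation.
"""
import functools
import importlib
import importlib.util
from typing import Any, Optional


def _load_addr(addr: str, file: Optional[str] = None) -> Any:
    """Resolve a dotted address such as 'torch.nn.ReLU' or 'model.Model'."""
    module_name, _, attr = addr.rpartition(".")
    if file is not None:
        # _file_: execute a local .py file as a module instead of importing it.
        spec = importlib.util.spec_from_file_location(module_name, file)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    else:
        module = importlib.import_module(module_name)
    return getattr(module, attr)


def resolve(pattern: list) -> Any:
    """Apply each step of an [_obj_, step, ...] pattern from left to right."""
    if not pattern or pattern[0] != "_obj_":
        raise ValueError("pattern must start with '_obj_'")
    current: Any = None
    for step in pattern[1:]:
        if step == "_call_":  # bare _call_: call with no arguments
            current = current()
        elif not isinstance(step, dict):
            raise ValueError(f"unsupported step: {step!r}")
        elif "_addr_" in step:  # import by dotted path (optionally from _file_)
            current = _load_addr(step["_addr_"], step.get("_file_"))
        elif "_call_" in step:  # call with keyword arguments
            current = current(**step["_call_"])
        elif "_bind_" in step:  # partial application
            current = functools.partial(current, **step["_bind_"])
        elif "_attr_" in step:  # attribute or method access
            current = getattr(current, step["_attr_"])
        else:
            raise ValueError(f"unsupported step: {step!r}")
    return current


# Usage with standard-library classes so the sketch runs anywhere.
frac = resolve(["_obj_", {"_addr_": "fractions.Fraction"},
                {"_call_": {"numerator": 6, "denominator": 8}}])
print(frac)  # 3/4

thirds = resolve(["_obj_", {"_addr_": "fractions.Fraction"},
                  {"_bind_": {"denominator": 3}}])
print(thirds(2))  # 2/3

ratio = resolve(["_obj_", {"_addr_": "fractions.Fraction"},
                 {"_call_": {"numerator": 6, "denominator": 8}},
                 {"_attr_": "as_integer_ratio"}, "_call_"])
print(ratio)  # (3, 4)
```

With a generated file, the same resolver would be given {"_addr_": "model.Model", "_file_": "model.py"} followed by "_call_".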
Command Guide
1. Format Templates
Use scm format to render a parameterized YAML template (such as cfg/torch/datasets/default_timm.yaml) into a concrete configuration file.
scm format cfg/torch/datasets/default_timm.yaml \
-o dataset_train.yaml \
-p 'DEFAULT: {training: true, epochs: 5, batch_size: 32, dataset: torch/cifar100, num_classes: 100, label_smoothing: 0.1, input_size: [3, 224, 224], image_dtype: bfloat16, download: true}'
scm format cfg/torch/datasets/default_timm.yaml \
-o dataset_valid.yaml \
-p 'DEFAULT: {training: false, epochs: 5, batch_size: 32, dataset: torch/cifar100, num_classes: 100, input_size: [3, 224, 224], image_dtype: bfloat16, download: true}'
What this does:
- Loads the YAML template.
- Merges any repeated -p/--parameter groups into a single parameter set.
- Renders Jinja-based sections within the template.
- Writes the resolved YAML to -o/--output (or prints it to stdout if -o is omitted).
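The merge step can be pictured as a recursive dictionary merge. This minimal sketch rests on two assumptions that the steps above do not spell out. First, nested mappings are merged key by key. Second, when two -p groups set the same key, the later group wins. merge_parameter_groups is a hypothetical name, and the -p strings are shown already parsed from YAML into dicts.

```python
"""Sketch of merging repeated -p/--parameter groups (assumed semantics)."""
import copy


def deep_merge(base: dict, override: dict) -> dict:
    """Return a new dict in which `override` is merged into `base` recursively."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)  # merge nested mappings
        else:
            merged[key] = copy.deepcopy(value)  # assumption: the later group wins
    return merged


def merge_parameter_groups(*groups: dict) -> dict:
    """Fold every -p group, in command-line order, into one parameter set."""
    result: dict = {}
    for group in groups:
        result = deep_merge(result, group)
    return result


# Two -p flags, already parsed from YAML:
#   -p 'DEFAULT: {training: true, batch_size: 32}' -p 'DEFAULT: {batch_size: 64}'
params = merge_parameter_groups(
    {"DEFAULT": {"training": True, "batch_size": 32}},
    {"DEFAULT": {"batch_size": 64}},
)
print(params)  # {'DEFAULT': {'training': True, 'batch_size': 64}}
```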
2. Generate a Model Class
Each framework has its own create model command that reads a YAML layer template and generates a framework-native module.
PyTorch
Generate a PyTorch nn.Module from a YAML layer template (such as cfg/torch/models/ConvNeXtV2.yaml).
scm torch create model cfg/torch/models/ConvNeXtV2.yaml
scm torch create model cfg/torch/models/ConvNeXtV2.yaml -p 'DEFAULT: {backbone: femto}'
scm torch create model cfg/torch/models/ConvNeXtV2.yaml -p 'DEFAULT: {backbone: atto}' -o torch_model.py
Flax
Generate a Flax nnx.Module from a YAML layer template (such as cfg/flax/models/ConvNeXtV2.yaml).
scm flax create model cfg/flax/models/ConvNeXtV2.yaml
scm flax create model cfg/flax/models/ConvNeXtV2.yaml -p 'DEFAULT: {backbone: femto}'
scm flax create model cfg/flax/models/ConvNeXtV2.yaml -p 'DEFAULT: {backbone: atto}' -o flax_model.py
Keras
Generate a Keras Layer from a YAML layer template (such as cfg/keras/models/ConvNeXtV2.yaml).
scm keras create model cfg/keras/models/ConvNeXtV2.yaml
scm keras create model cfg/keras/models/ConvNeXtV2.yaml -p 'DEFAULT: {backbone: femto}'
scm keras create model cfg/keras/models/ConvNeXtV2.yaml -p 'DEFAULT: {backbone: atto}' -o keras_model.py
Common options
All three commands share the same options:
- -p/--parameter: override template parameters
- -c/--classname: set the generated class name (default: Model)
- --no-structured-output: force tuple-like return behavior instead of a structured output mapping
- -s/--sublayer: generate a named sublayer from the template instead of the root model
- -o/--output: output file path; if omitted, defaults to the snake-cased class name in the current directory (e.g., model.py for the default class name Model)
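This sketch shows one way the default -o path could be derived from -c/--classname. The regex-based snake-casing rule is an assumption. The real command may split runs of capitals (as in ConvNeXtV2) differently, so pass -o explicitly when the exact file name matters. to_snake_case and default_output_path are hypothetical helpers.

```python
import re
from pathlib import Path


def to_snake_case(name: str) -> str:
    """'TorchModel' -> 'torch_model' (assumed rule: insert '_' before a
    capital letter that follows a lowercase letter or digit)."""
    return re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", name).lower()


def default_output_path(classname: str = "Model") -> Path:
    """Hypothetical: <snake-cased classname>.py in the current directory."""
    return Path.cwd() / f"{to_snake_case(classname)}.py"


print(default_output_path().name)              # model.py
print(default_output_path("Loss").name)        # loss.py
print(default_output_path("TorchModel").name)  # torch_model.py
```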
The ConvNeXtV2 template uses Jinja parameter groups to switch between backbone variants such as atto, femto, tiny, and base.
3. Generate Loss, Metric, and Backward Classes
Losses and metrics use the same scm torch create model command because they are also layer graphs.
scm torch create model cfg/torch/losses/cls.yaml -c Loss -o loss.py
scm torch create model cfg/torch/metrics/topk.yaml -c Metric -o metric.py
scm torch create backward cfg/torch/backwards/ConvNeXtV2.yaml -p 'DEFAULT: {epochs: 5}' -o backward.py
The scm torch create backward command turns a backward template into a class that manages:
- a training-time execution graph (FLOW) and an inference-time execution graph (INFERENCE_FLOW) per backward entry
- inline layer instantiation (loss layers, metric layers, and arbitrary modules can be defined directly in the flow)
- one or more backward entries, each with its own optimizer and trainable layers, enabling multi-optimizer training (e.g., GAN generator + discriminator)
- optimizer construction via StructCast patterns
- optional gradient scaler creation (MIXED_PRECISION)
- optional gradient clipping (CLIP)
- optional gradient accumulation (ACCUMULATE_GRADIENTS)
- optimizer stepping, zeroing, and automatic train/eval mode switching
- learning-rate and parameter-group inspection helpers
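The accumulation, stepping, and zeroing items above follow a common pattern. The framework-free sketch below only computes that schedule. It rests on two assumptions about typical practice, neither confirmed for the generated Backward class:

- the loss is divided by ACCUMULATE_GRADIENTS on every micro-batch;
- a final partial window still triggers an optimizer step.

MicroStep and accumulation_schedule are hypothetical names.

```python
from typing import List, NamedTuple


class MicroStep(NamedTuple):
    batch_index: int
    loss_scale: float     # factor applied to the loss before backward()
    step_optimizer: bool  # call optimizer.step() and zero_grad() after this batch


def accumulation_schedule(num_batches: int, accumulate: int = 1) -> List[MicroStep]:
    """Plan which micro-batches trigger an optimizer update."""
    if accumulate < 1:
        raise ValueError("accumulate must be >= 1")
    plan = []
    for i in range(num_batches):
        end_of_window = (i + 1) % accumulate == 0
        last_batch = i == num_batches - 1  # assumption: flush a partial window
        plan.append(MicroStep(i, 1.0 / accumulate, end_of_window or last_batch))
    return plan


for step in accumulation_schedule(num_batches=5, accumulate=2):
    print(step)
# Optimizer updates happen after batches 1, 3, and 4.
```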
For example, a CycleGAN backward template defines three backward entries — one for the generator pair and one for each discriminator — each with its own flow, optimizer, and trainable layers:
scm torch create backward cfg/torch/backwards/CycleGAN.yaml -o backward.py
4. Inspect FLOPs and Parameters
Once a model has been generated, you can instantiate it from a StructCast pattern and measure its computational complexity.
scm torch ptflops '[_obj_, {_addr_: model.Model, _file_: model.py}, _call_]' \
-s 'image: [3, 224, 224]' \
--backend pytorch
scm torch calflops '[_obj_, {_addr_: model.Model, _file_: model.py}, _call_]' \
-s 'image: [3, 224, 224]'
What these commands do internally:
- Instantiate the model from the _obj_ pattern.
- Create dummy tensors from the -s/--shape specification.
- Run one initialization forward pass via initial_model(...).
- Pass the initialized model to ptflops or calflops for complexity analysis.
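For intuition about what these tools count, here is the hand calculation for a single convolution. The example layer is an assumed ConvNeXtV2 atto stem, Conv2d(3, 40, kernel_size=4, stride=4) with bias. It takes dims[0] = 40 and stem_kernel_size = 4 from the template parameters and assumes the usual stride-4 ConvNeXt patchify stem. conv2d_cost is a hypothetical helper.

Keep the units straight when comparing numbers:

- ptflops reports multiply-accumulates (MACs).
- FLOPs are commonly approximated as 2 × MACs.
- Bias additions are excluded here, and tools may count them differently.

```python
def conv2d_cost(c_in, c_out, kernel, stride, height, width, padding=0, bias=True):
    """Parameter count, MACs, and output shape of a square-kernel Conv2d (groups=1)."""
    h_out = (height + 2 * padding - kernel) // stride + 1
    w_out = (width + 2 * padding - kernel) // stride + 1
    params = c_out * c_in * kernel * kernel + (c_out if bias else 0)
    # Each output element needs c_in * kernel * kernel multiply-accumulates.
    macs = c_out * h_out * w_out * c_in * kernel * kernel
    return params, macs, (c_out, h_out, w_out)


params, macs, out_shape = conv2d_cost(3, 40, kernel=4, stride=4, height=224, width=224)
print(params, macs, out_shape)  # 1960 6021120 (40, 56, 56)
```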
5. Measure Inference Time
Use scm [torch/flax/keras] time to benchmark the average forward-pass latency of a generated model. All three frameworks share the same basic options:
| Option | Description |
|---|---|
| positional pattern | StructCast object pattern to instantiate the model |
| -s/--shape | Input tensor shapes, e.g. 'image: [3, 224, 224]' |
| -d/--device | Computation device (cpu, cuda, gpu:0, …) |
| -c/--compile | Compile the model before measurement (true, YAML path, or dict) |
| --training-mode | Measure in training mode instead of evaluation mode |
| -w/--warmup-runs | Number of warmup iterations (default: 2) |
| -t/--times | Number of timed iterations (default: 10) |
| -b/--batch-size | Batch size for dummy inputs (default: 1) |
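A framework-free timing loop illustrates how -w/--warmup-runs, -t/--times, and -b/--batch-size interact. This is a sketch of the general technique, not the scm time implementation. average_latency is hypothetical, and the stand-in model is a plain Python function. On GPUs, execution is asynchronous, so a real benchmark must synchronize before reading the clock (for example with torch.cuda.synchronize() or JAX's block_until_ready()).

```python
import time
from typing import Callable


def average_latency(fn: Callable[[], object], warmup_runs: int = 2, times: int = 10) -> float:
    """Average wall-clock seconds per call, excluding warmup iterations."""
    for _ in range(warmup_runs):
        fn()  # warmup: compilation, caching, lazy initialization
    start = time.perf_counter()
    for _ in range(times):
        fn()
    # On GPUs, synchronize the device here before reading the clock.
    return (time.perf_counter() - start) / times


def fake_model(batch_size: int = 1) -> list:
    """Stand-in for a forward pass over a dummy batch."""
    return [sum(range(10_000)) for _ in range(batch_size)]


latency = average_latency(lambda: fake_model(batch_size=1), warmup_runs=2, times=10)
print(f"{latency * 1e3:.3f} ms per forward pass")
```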
PyTorch
scm torch create model cfg/torch/models/ConvNeXtV2.yaml \
-p 'DEFAULT: {backbone: atto}' -o torch_model.py
scm torch time \
'[_obj_, {_addr_: model.Model, _file_: torch_model.py}, _call_]' \
-s 'image: [3, 224, 224]' \
-c cfg/torch/others/compile_default.yaml \
-d cuda
PyTorch-specific option: --matmul-precision (highest, high, medium) controls torch.set_float32_matmul_precision.
Flax
scm flax create model cfg/flax/models/ConvNeXtV2.yaml \
-p 'DEFAULT: {backbone: atto}' -o flax_model.py
scm flax time \
'[_obj_, {_addr_: model.Model, _file_: flax_model.py}, {_call_: {rngs: [_obj_, _addr_: flax.nnx.Rngs, _call_: {params: 0, dropout: 1}]}}]' \
-s 'image: [224, 224, 3]' \
-c true \
-d gpu:0
Flax-specific option: --training-mode-kwargs lets you override the keyword arguments passed to nnx.view when --training-mode is set (e.g. '{deterministic: false, use_running_average: false}').
Note: Flax uses channel-last tensor layout. The shape 'image: [224, 224, 3]' corresponds to H × W × C.
Keras
scm keras create model cfg/keras/models/ConvNeXtV2.yaml \
-p 'DEFAULT: {backbone: atto}' -o keras_model.py
# Keras with JAX backend may need NVIDIA shared libraries on the path
export LD_LIBRARY_PATH=$(find .venv -name "*.so*" | grep nvidia | xargs dirname | sort -u | paste -d ":" -s -)
scm keras time \
'[_obj_, {_addr_: model.Model, _file_: keras_model.py}, _call_]' \
-s 'image: [224, 224, 3]' \
-c true \
-d gpu:0
Compilation for Keras uses keras.Model.compile. The --compile/-c option accepts true/false, a YAML file path, or an inline dict of keyword arguments.
Note: Keras also uses channel-last layout by default. The shape 'image: [224, 224, 3]' corresponds to H × W × C.
6. Train a Generated Model
Below is the complete training command from the included ConvNeXtV2 example.
scm torch train \
'model: [_obj_, {_addr_: model.Model, _file_: model.py}, _call_]' \
-s 'image: [3, 224, 224]' \
-d cuda \
-L '[_obj_, {_addr_: loss.Loss, _file_: loss.py}, _call_]' \
-M '[_obj_, {_addr_: metric.Metric, _file_: metric.py}, _call_]' \
-B '[_obj_, {_addr_: backward.Backward, _file_: backward.py}]' \
-c cfg/torch/others/compile_default.yaml \
-e 5 \
-T dataset_train.yaml \
-V dataset_valid.yaml \
-f 1 \
-LC ce_loss \
-LC val_ce_loss \
-HC acc1 \
-HC val_acc1 \
-HC acc5 \
-HC val_acc5 \
-SC val_acc1 \
--matmul-precision high \
-E Test \
-A model.py \
-A loss.py \
-A metric.py \
-A backward.py \
-A cfg/torch/others/compile_default.yaml \
-A dataset_train.yaml \
-A dataset_valid.yaml
Key arguments:
- positional model patterns: one or more named model definitions
- -s/--shape: dummy input shapes used for model initialization
- -d/--device: cpu or cuda
- -L/--loss: StructCast pattern for the loss module
- -M/--metric: StructCast pattern for the metric module
- -B/--backward: StructCast pattern for the backward class
- -c/--compile: boolean, YAML file, or inline dict for torch.compile
- -T/--training-dataset: training dataset pattern or rendered dataset YAML
- -V/--validation-dataset: validation dataset pattern or rendered dataset YAML
- -LC/--lower-criterion: criteria where lower is better
- -HC/--higher-criterion: criteria where higher is better
- -SC/--save-criterion: criteria that should trigger best-model saving
- -E/--experiment: MLflow experiment name
- -A/--log-artifacts: artifacts to store in MLflow
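The -LC, -HC, and -SC flags amount to tracking a best-so-far value per criterion, each with a direction. This minimal sketch assumes a save is triggered whenever an -SC criterion strictly improves on its previous best. BestTracker is hypothetical, and the real tracker logic may differ, for example in how ties are handled.

```python
import math
from typing import Dict, Iterable, List


class BestTracker:
    """Track best-so-far values for lower-is-better and higher-is-better criteria."""

    def __init__(self, lower: Iterable[str], higher: Iterable[str], save: Iterable[str]):
        self.direction: Dict[str, int] = {name: -1 for name in lower}
        self.direction.update({name: +1 for name in higher})
        self.save = list(save)
        # Start every criterion at its worst possible value.
        self.best = {name: -math.inf * d for name, d in self.direction.items()}

    def update(self, metrics: Dict[str, float]) -> List[str]:
        """Record one epoch of metrics; return the save criteria that improved."""
        improved = []
        for name, value in metrics.items():
            d = self.direction.get(name)
            if d is None:
                continue  # metric not declared with -LC/-HC
            if d * value > d * self.best[name]:  # strict improvement
                self.best[name] = value
                if name in self.save:
                    improved.append(name)
        return improved


tracker = BestTracker(lower=["ce_loss", "val_ce_loss"], higher=["acc1", "val_acc1"],
                      save=["val_acc1"])
for epoch, acc in enumerate([0.31, 0.42, 0.40, 0.47]):
    if tracker.update({"val_acc1": acc, "val_ce_loss": 2.0 - acc}):
        print(f"epoch {epoch}: new best val_acc1={acc}, save checkpoint")
# Checkpoints are saved after epochs 0, 1, and 3.
```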
What the train command does internally:
- Instantiates datasets and determines their lengths.
- Initializes models with optional dummy-input forward passes.
- Instantiates loss, metric, backward, and compile objects.
- Builds a TorchTracker from the declared output names.
- Creates a TorchTrainer with training and validation step objects.
- Logs metrics, arguments, model states, optimizer states, gradient scaler states, and best checkpoints to MLflow.
Distributed Training with torchrun
scm torch train supports multi-GPU and multi-node distributed data parallel (DDP) training out of the box via torchrun. No changes to your generated code, YAML templates, or dataset configurations are required — the same scm torch train command works for both single-GPU and distributed training.
⚠️ SyncBatchNorm Warning
When using multi-GPU training with DistributedDataParallel, scm torch train does not automatically convert BatchNorm layers to SyncBatchNorm. Standard BatchNorm computes statistics per GPU, which can cause inconsistent behavior across ranks, especially with small per-GPU batch sizes. If your model contains BatchNorm layers and you are training with DDP, consider applying torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) to the model before it is wrapped with DistributedDataParallel. This conversion must happen in user code or in the model definition; the CLI will not perform it for you.
How It Works
When launched through torchrun, the environment variables RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT are set automatically. scm torch train detects these and enables distributed mode:
- Process group initialization: The distributed backend (NCCL for GPU training unless overridden with --dist-backend) is initialized via torch.distributed.init_process_group.
- Per-rank device assignment: Each process is assigned to cuda:<LOCAL_RANK>.
- DDP model wrapping: All models are wrapped with DistributedDataParallel.
- Distributed data loading: TimmDataLoaderWrapper automatically creates a DistributedSampler when a distributed environment is detected. The sampler's set_epoch() is called each epoch for proper shuffling.
- Metric synchronization: TorchTracker uses all_reduce to average loss and metric values across all ranks.
- Rank-0 logging: MLflow logging, progress bars, and checkpoint saving are performed only on rank 0.
- Gradient sync optimization: During gradient accumulation steps, DDP gradient synchronization is disabled to reduce communication overhead.
- Cleanup: torch.distributed.destroy_process_group() is called when training finishes.
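The detection step can be sketched from the environment variables that torchrun sets. This is an illustration only; detect_distributed and DistInfo are hypothetical. The sketch treats a process as distributed when RANK and WORLD_SIZE are present. That is a common convention, not a confirmed detail of scm torch train.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class DistInfo:
    """What a process learns about its place in a torchrun job."""
    distributed: bool
    rank: int = 0
    local_rank: int = 0
    world_size: int = 1

    @property
    def is_main(self) -> bool:
        # Rank 0 is the process that logs to MLflow and saves checkpoints.
        return self.rank == 0

    def device(self, default: str = "cuda") -> str:
        # Under torchrun, LOCAL_RANK picks the GPU on this node.
        return f"cuda:{self.local_rank}" if self.distributed else default


def detect_distributed(env: Optional[Mapping[str, str]] = None) -> DistInfo:
    """Read torchrun's environment variables (os.environ by default)."""
    env = os.environ if env is None else env
    if "RANK" not in env or "WORLD_SIZE" not in env:
        return DistInfo(distributed=False)  # plain single-process run
    return DistInfo(
        distributed=True,
        rank=int(env["RANK"]),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env["WORLD_SIZE"]),
    )


# Environment of the second process on node 1 in a 2 x 4 GPU torchrun job:
info = detect_distributed({"RANK": "5", "LOCAL_RANK": "1", "WORLD_SIZE": "8"})
print(info.device(), info.is_main)  # cuda:1 False
```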
Single-Node Multi-GPU
To train on all GPUs of a single machine, prefix your scm torch train command with torchrun:
# Use all available GPUs on the current machine
torchrun --nproc_per_node=gpu \
-m structcast_model.commands.main \
torch train \
'model: [_obj_, {_addr_: model.Model, _file_: model.py}, _call_]' \
-s 'image: [3, 224, 224]' \
-d cuda \
-L '[_obj_, {_addr_: loss.Loss, _file_: loss.py}, _call_]' \
-M '[_obj_, {_addr_: metric.Metric, _file_: metric.py}, _call_]' \
-B '[_obj_, {_addr_: backward.Backward, _file_: backward.py}]' \
-c cfg/torch/others/compile_default.yaml \
-e 5 \
-T dataset_train.yaml \
-V dataset_valid.yaml \
-f 1 \
-LC ce_loss -LC val_ce_loss \
-HC acc1 -HC val_acc1 -HC acc5 -HC val_acc5 \
-SC val_acc1 \
--matmul-precision high \
-E Test
Or specify an exact GPU count:
# Use exactly 4 GPUs
torchrun --nproc_per_node=4 \
-m structcast_model.commands.main \
torch train ...
Note: torchrun launches the training entry point as a Python module (-m structcast_model.commands.main) rather than through the scm console script, because torchrun expects a Python script path or a module name rather than a console-script wrapper.
Multi-Node Training
For training across multiple machines, provide the node topology to torchrun on each node:
# On node 0 (master)
torchrun \
--nproc_per_node=4 \
--nnodes=2 \
--node_rank=0 \
--master_addr=192.168.1.100 \
--master_port=29500 \
-m structcast_model.commands.main \
torch train ...
# On node 1
torchrun \
--nproc_per_node=4 \
--nnodes=2 \
--node_rank=1 \
--master_addr=192.168.1.100 \
--master_port=29500 \
-m structcast_model.commands.main \
torch train ...
This creates 8 total processes (4 GPUs × 2 nodes) training with DDP.
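Under torchrun's static layout (one process per GPU, node ranks given explicitly as above), each process's global rank follows from its node rank and local rank. The sketch below shows that arithmetic. With elastic rendezvous the assignment can differ, and the helper names are hypothetical.

```python
def world_size(nnodes: int, nproc_per_node: int) -> int:
    """Total number of training processes."""
    return nnodes * nproc_per_node


def global_rank(node_rank: int, local_rank: int, nproc_per_node: int) -> int:
    """RANK of a process, assuming ranks are assigned node by node."""
    if not 0 <= local_rank < nproc_per_node:
        raise ValueError("local_rank must be in [0, nproc_per_node)")
    return node_rank * nproc_per_node + local_rank


# The two-node example above: --nnodes=2, --nproc_per_node=4
print(world_size(2, 4))  # 8
for node in range(2):
    print(node, [global_rank(node, local, 4) for local in range(4)])
# 0 [0, 1, 2, 3]
# 1 [4, 5, 6, 7]
```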
torchrun parameters:
| Parameter | Description |
|---|---|
| --nproc_per_node | Number of processes per node. Use gpu for all available GPUs. |
| --nnodes | Total number of nodes. Defaults to 1 for single-node training. |
| --node_rank | Rank of the current node (0-indexed). |
| --master_addr | IP address of the master node. |
| --master_port | Port for inter-node communication. |
scm torch train distributed-related options:
| Option | Description |
|---|---|
| --dist-backend | Distributed backend (nccl, gloo). Auto-selected if omitted. Also settable via the DIST_BACKEND environment variable. |
| --dist-url | URL for distributed setup. Defaults to env://. Also settable via the DIST_URL environment variable. |
| --ci | Disables tqdm progress bars; useful in cluster job logs. |
Dataset Configuration
Dataset YAML files do not need per-rank customization. A single device: cuda value in the dataset configuration works for all ranks — TimmDataLoaderWrapper internally resolves it to the correct cuda:<LOCAL_RANK> device for each process.
# The same dataset YAML works for single-GPU and distributed training
scm format cfg/torch/datasets/default_timm.yaml \
-o dataset_train.yaml \
-p 'DEFAULT: {training: true, epochs: 5, batch_size: 32, dataset: torch/cifar100, num_classes: 100, label_smoothing: 0.1, input_size: [3, 224, 224], image_dtype: bfloat16, download: true}'
Tip: The batch_size in the dataset template is the per-GPU batch size. With 4 GPUs and batch_size: 32, the effective global batch size is 128.
Distributed Training Notes
- Seed reproducibility: Each rank's random seed is offset by global_rank, so data augmentation differs across processes while remaining reproducible.
- Learning rate scaling: When scaling to multiple GPUs, consider adjusting the learning rate. A common practice is linear scaling: multiply the base learning rate by the number of GPUs. This must be configured in the backward template or optimizer settings; scm torch train does not scale the learning rate automatically.
- SyncBatchNorm: BatchNorm layers are not converted to SyncBatchNorm automatically. See the SyncBatchNorm warning for details.
- torch.compile and DDP: When both --compile and DDP are active, torch.compile is applied before DDP wrapping.
- Checkpoint saving: Only rank 0 saves checkpoints and logs to MLflow. When resuming from a checkpoint in a distributed setting, all ranks load the same checkpoint.
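The per-GPU batch-size tip and the linear scaling rule reduce to simple arithmetic. This sketch also folds in ACCUMULATE_GRADIENTS, assuming it multiplies the number of samples per optimizer update. The function names are hypothetical. Any learning rate you compute this way still has to be written into the backward template yourself.

```python
def effective_batch_size(per_gpu_batch: int, world_size: int, accumulate: int = 1) -> int:
    """Samples contributing to one optimizer update across all ranks."""
    return per_gpu_batch * world_size * accumulate


def linearly_scaled_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Linear scaling rule: the learning rate grows in proportion to the batch size."""
    return base_lr * new_batch / base_batch


single = effective_batch_size(32, world_size=1)
multi = effective_batch_size(32, world_size=4)
print(single, multi)                            # 32 128
print(linearly_scaled_lr(4e-3, single, multi))  # 0.016
```

Scaling by the ratio of global batch sizes is equivalent to "multiply by the number of GPUs" when the per-GPU batch and accumulation stay fixed.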
Configuration Examples
The cfg/ directory contains working YAML templates that demonstrate each part of the workflow. Templates are organized by framework under cfg/torch/, cfg/flax/, and cfg/keras/.
PyTorch
cfg/torch/models/ConvNeXtV2.yaml
Demonstrates the model-building style used throughout the project:
- parameter groups for multiple backbone sizes
- nested user-defined layers such as Backbone, Stem, DownSample, and Block
- Jinja-driven layer expansion
- separate training and inference flow support
- structured outputs such as {cls: torch.tensor(...), ...}
cfg/torch/backwards/ConvNeXtV2.yaml
Demonstrates how backward logic is configured declaratively for a single-optimizer workflow:
- MIXED_PRECISION for torch.amp.GradScaler
- MIXED_PRECISION_TYPE for autocast dtype
- ACCUMULATE_GRADIENTS for delayed optimizer updates
- a single BACKWARDS entry whose FLOW contains the model forward pass, loss, and metric computation inline
- a separate INFERENCE_FLOW for evaluation without gradient tracking
- optimizer creation through structcast_model.torch.optimizers.create_with_scheduler
- optional gradient clipping via timm.utils.clip_grad.dispatch_clip_grad
cfg/torch/backwards/CycleGAN.yaml
Demonstrates multi-optimizer backward logic for GAN-style training:
- three BACKWARDS entries: one for the generator pair (G_AB, G_BA) and one for each discriminator (D_A, D_B)
- each entry defines its own FLOW with inline loss layers (L1Loss, MSELoss) and computed expressions
- each entry has a dedicated OPTIMIZER with an independent learning-rate scheduler
- TRAINABLE_LAYERS specifies which models each optimizer manages
- the generated backward class automatically handles train/eval mode switching per backward entry
- OUTPUTS aggregates all tracked values (generator loss, GAN loss, cycle loss, identity loss, discriminator losses)
cfg/torch/models/CycleGAN_generator.yaml and CycleGAN_discriminator.yaml
Pair of model templates for the CycleGAN architecture:
- Generator: uses ResidualBlock, DownBlock, and UpBlock sublayers with reflection padding, instance normalization, and Jinja-driven residual block expansion (n_residual_blocks parameter)
- Discriminator: uses a DiscriminatorBlock sublayer with conditional instance normalization controlled by a normalize parameter
- both templates use LazyConv2d for automatic input channel inference
cfg/torch/datasets/default_timm.yaml
Formats directly into a TimmDataLoaderWrapper.model_validate(...) pattern. Covers:
- timm dataset construction
- timm dataloader construction
- device and prefetch settings
- mixup and cutmix options
- train or validation split generation from one template
Flax
cfg/flax/models/ConvNeXtV2.yaml
Generates a Flax nnx.Module equivalent of the PyTorch ConvNeXtV2 model. The template mirrors the same parameter groups (atto through huge) and uses GlobalResponseNorm as a custom Flax layer. Key differences from the PyTorch variant:
- uses channel-last tensor layout (H × W × C)
- the constructor accepts an rngs: flax.nnx.Rngs argument for parameter initialization
- __call__ propagates a training flag to sub-modules
Keras
cfg/keras/models/ConvNeXtV2.yaml
Generates a Keras Layer equivalent of the ConvNeXtV2 model. Shares the same backbone parameter groups and uses GlobalResponseNormalization as a custom Keras layer. Key differences:
- uses channel-last tensor layout (H × W × C)
- follows the Keras call(self, ..., *, training=None, **kwargs) convention
- runs on any Keras backend (JAX, PyTorch, or TensorFlow)
Schema Reference
All configuration templates under cfg/ follow a shared schema that controls how YAML files are parsed, rendered, and validated by the code generators. This section explains every top-level key and sub-key that appears in these templates.
Template Parameters
Every YAML template may begin with an optional top-level PARAMETERS block that declares named sets of values consumed by the Jinja rendering engine.
PARAMETERS
The top-level container for all template variable groups. Any key nested inside PARAMETERS (other than DEFAULT and SHARED) is treated as a named group that can be selected at render time.
PARAMETERS:
DEFAULT:
backbone: atto
SHARED:
drop_path_rate: 0.0
num_classes: 1000
atto:
dims: [40, 80, 160, 320]
depths: [2, 2, 6, 2]
femto:
dims: [48, 96, 192, 384]
depths: [2, 2, 6, 2]
DEFAULT
Defines the default template variables. These values are active when no named group is selected and can be overridden at the command line with -p 'DEFAULT: {key: value}'.
DEFAULT:
backbone: atto
epochs: 300
lr: 4.0e-3
SHARED
Defines variables that are merged into every named group (including DEFAULT). Use SHARED for constants that apply to all backbone or variant choices.
SHARED:
stem_kernel_size: 4
kernel_size: 7
norm_eps: 1.0e-6
Named groups
Any key in PARAMETERS that is not DEFAULT or SHARED is a named parameter group — for example atto, femto, tiny, or base. A named group is activated via _jinja_group_ and its variables (merged with SHARED) replace the template variables for that rendering scope.
atto:
dims: [40, 80, 160, 320]
depths: [2, 2, 6, 2]
femto:
dims: [48, 96, 192, 384]
depths: [2, 2, 6, 2]
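The sketch below shows how DEFAULT, SHARED, and a named group could combine into the variable scope for a block. It makes three assumptions:

- SHARED is also merged into DEFAULT, as described above;
- a group's own values override SHARED when keys collide;
- a -p DEFAULT override is applied last.

scope_for is a hypothetical helper, not the builder's API.

```python
from typing import Any, Dict, Optional

PARAMETERS: Dict[str, Dict[str, Any]] = {
    "DEFAULT": {"backbone": "atto"},
    "SHARED": {"drop_path_rate": 0.0, "num_classes": 1000},
    "atto": {"dims": [40, 80, 160, 320], "depths": [2, 2, 6, 2]},
    "femto": {"dims": [48, 96, 192, 384], "depths": [2, 2, 6, 2]},
}


def scope_for(parameters: Dict[str, Dict[str, Any]], group: str = "DEFAULT",
              overrides: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Variables visible to a _jinja_yaml_ block rendered under `group`."""
    scope = dict(parameters.get("SHARED", {}))  # SHARED first ...
    scope.update(parameters[group])  # ... then the group's own values win (assumption)
    if group == "DEFAULT" and overrides:
        scope.update(overrides)  # e.g. -p 'DEFAULT: {backbone: femto}'
    return scope


top = scope_for(PARAMETERS, overrides={"backbone": "femto"})
inner = scope_for(PARAMETERS, top["backbone"])  # what _jinja_group_: {{backbone}} selects
print(top["backbone"], inner["dims"], inner["num_classes"])  # femto [48, 96, 192, 384] 1000
```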
_jinja_yaml_
Embeds an inline Jinja template that is rendered and merged back into the surrounding YAML. The rendered result must itself be valid YAML. _jinja_yaml_ blocks are evaluated with the currently active template variables and can emit any number of sibling YAML keys or list entries.
_jinja_yaml_: |-
{% if accumulate_gradients is none %}
ACCUMULATE_GRADIENTS: null
{% else %}
ACCUMULATE_GRADIENTS: {{accumulate_gradients}}
{% endif %}
Inside a _jinja_yaml_ block you can also use standard Jinja control structures ({% for %}, {% if %}, {% set %}, etc.) as well as the custom filter cumsum (provided by structcast_model.builders.jinja_filters).
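A cumsum filter is handy for turning per-stage depths into running block offsets. In a template it would be applied with standard Jinja filter syntax, e.g. {{ depths | cumsum }}. The plain-Python equivalent below assumes the filter computes an inclusive running sum. Check structcast_model.builders.jinja_filters for the exact semantics. stage_ranges is a hypothetical helper.

```python
from itertools import accumulate
from typing import Iterable, List, Tuple


def cumsum(values: Iterable[float]) -> List[float]:
    """Assumed equivalent of the Jinja filter: inclusive running sum."""
    return list(accumulate(values))


def stage_ranges(depths: List[int]) -> List[Tuple[int, int]]:
    """Global [start, end) block indices for each stage."""
    ends = cumsum(depths)
    starts = [0] + ends[:-1]
    return list(zip(starts, ends))


depths = [2, 2, 6, 2]        # atto depths from PARAMETERS
print(cumsum(depths))        # [2, 4, 10, 12]
print(stage_ranges(depths))  # [(0, 2), (2, 4), (4, 10), (10, 12)]
```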
_jinja_group_
Selects a named parameter group from PARAMETERS, merging its values (together with SHARED) into the template variable scope for the enclosing block. _jinja_group_ must appear alongside a _jinja_yaml_ sibling that consumes the newly activated variables.
- _jinja_group_: {{backbone}}
_jinja_yaml_: |-
- [_, cls, head, [_obj_, {_addr_: torch.nn.LazyLinear}, {_call_: {out_features: {{num_classes}}}}]]
When backbone resolves to atto, the atto group from PARAMETERS (merged with SHARED) becomes the local variable scope for the inner _jinja_yaml_ block.
Model Template Schema
The following keys appear in model configuration files such as cfg/torch/models/ConvNeXtV2.yaml. Each top-level key that is not PARAMETERS or a Jinja directive defines either the root model (using the reserved keys below) or a named sublayer (an arbitrary key whose value follows the same schema).
IMPORTS
Additional Python imports to inject at the top of the generated file. Accepts a dict mapping module names to lists of names to import, or an empty dict {} when no extra imports are needed.
IMPORTS: {}
# or
IMPORTS:
torch.nn: [Module, Linear]
my_package.utils: null # imports the entire module
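The two IMPORTS forms differ in the statements they produce. The sketch below turns the mapping into Python import lines to show that difference. The exact lines the generators emit are an assumption, and render_imports is hypothetical.

```python
from typing import Dict, List, Optional


def render_imports(imports: Dict[str, Optional[List[str]]]) -> List[str]:
    """Map IMPORTS entries to import statements (assumed rendering)."""
    lines = []
    for module, names in imports.items():
        if names is None:
            lines.append(f"import {module}")  # null: import the whole module
        else:
            lines.append(f"from {module} import {', '.join(names)}")
    return lines


for line in render_imports({"torch.nn": ["Module", "Linear"], "my_package.utils": None}):
    print(line)
# from torch.nn import Module, Linear
# import my_package.utils
```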
INPUTS
Ordered list of tensor names that the generated forward() method accepts as keyword arguments. These names correspond to the first element of each FLOW entry and to the keys in the inputs dict passed at runtime.
INPUTS: [image]
OUTPUTS
Ordered list of tensor names produced by the generated forward() method. When STRUCTURED_OUTPUT is true, these names become the keys of the returned dict; otherwise, they determine the order of the returned tuple.
OUTPUTS: [cls]
# or, for a multi-output model:
OUTPUTS: [feat1, feat2, feat3, feat4]
STRUCTURED_OUTPUT
Controls the return type of the generated forward() method.
| Value | Behavior |
|---|---|
| true | Returns {"cls": tensor, ...}, a dict keyed by the names in OUTPUTS. |
| false (default) | Returns a plain tuple in the order of OUTPUTS. |
STRUCTURED_OUTPUT: true
FLOW and INFERENCE_FLOW
FLOW is the training-time execution graph: an ordered list of LayerBehavior entries (see FLOW entry format below) that describes how tensors are routed through the model's submodules.
INFERENCE_FLOW is an optional alternative graph used only during inference — for example, to skip DropPath or other training-only layers. When INFERENCE_FLOW is absent, inference uses FLOW unchanged. Both fields must produce the same INPUTS and OUTPUTS.
FLOW:
- [image, {feature: feat4}, backbone, {TYPE: Backbone}]
- [feature, _, [_obj_, {_addr_: torch.nn.AdaptiveAvgPool2d}, {_call_: {output_size: 1}}]]
- [_, _, [_obj_, {_addr_: torch.nn.Flatten}, _call_]]
- [_, cls, head, [_obj_, {_addr_: torch.nn.LazyLinear}, {_call_: {out_features: 1000}}]]
# DropPath sublayer uses a simpler inference path
DropPath:
  FLOW: [[inp, out, [_obj_, {_addr_: timm.layers.DropPath}, {_call_: {drop_prob: 0.1}}]]]
  INFERENCE_FLOW: [[inp, out]]
FLOW entry format
Each entry in FLOW or INFERENCE_FLOW is a LayerBehavior — a list of 2 to 4 elements:
[INPUTS, OUTPUTS]
[INPUTS, OUTPUTS, NAME_or_LAYER]
[INPUTS, OUTPUTS, NAME, LAYER]
| Position | Field | Description |
|---|---|---|
| 0 | INPUTS | Input variable name(s) for this step. A plain string (image, feat1) reads a named tensor from the current scope. Use _ to pass the previous step's output forward. A nested list [[a, b]] collects tensors from multiple sources (e.g., for residual additions). |
| 1 | OUTPUTS | Output variable name(s) produced by this step. Use _ for intermediate values that need not be named. A dict {alias: real_name} renames the output in the current scope. |
| 2 | NAME | (optional) A unique identifier for the generated submodule attribute. Auto-generated when omitted. Must be a valid Python identifier. |
| 2 or 3 | LAYER | (optional) The layer definition: either a StructCast ObjectPattern (e.g., [_obj_, {_addr_: torch.nn.ReLU}, _call_]) or a UserLayer dict (see TYPE, PARAM, and CFG). |
FLOW:
- [image, {feature: feat4}, backbone, {TYPE: Backbone}]
- [feature, _, [_obj_, {_addr_: torch.nn.AdaptiveAvgPool2d}, {_call_: {output_size: 1}}]]
- [_, _, [_obj_, {_addr_: torch.nn.Flatten}, _call_]]
- [_, cls, head, [_obj_, {_addr_: torch.nn.LazyLinear}, {_call_: {out_features: 1000}}]]
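The routing rules in the table can be illustrated with a tiny interpreter. This is a hypothetical sketch, not the code generator (which emits Python source instead): for brevity, layers are plain callables looked up by NAME from a prebuilt dict rather than built from LAYER patterns, a dict in the INPUTS position is assumed to pass named values as keyword arguments (as in the backward examples further down), and nested [[a, b]] inputs are not covered.

```python
from typing import Any, Callable


def run_flow(
    flow: list[list[Any]],
    layers: dict[str, Callable[..., Any]],
    inputs: dict[str, Any],
) -> dict[str, Any]:
    """Hypothetical interpreter for FLOW routing; returns the final named scope."""
    scope: dict[str, Any] = dict(inputs)  # named values visible to later steps
    previous: Any = None  # value that an input of "_" refers to
    for entry in flow:
        src, dst = entry[0], entry[1]
        # Resolve inputs: "_" -> previous output, dict -> keyword arguments, str -> named value
        if src == "_":
            args, kwargs = [previous], {}
        elif isinstance(src, dict):
            args, kwargs = [], {kw: scope[ref] for kw, ref in src.items()}
        else:
            args, kwargs = [scope[src]], {}
        if len(entry) >= 3:
            result = layers[entry[2]](*args, **kwargs)  # call the named submodule
        else:
            result = args[0]  # 2-element entry: identity pass-through (single positional input)
        # Bind outputs: dict -> {alias: real_name}, "_" -> unnamed, str -> named
        if isinstance(dst, dict):
            for alias, real_name in dst.items():
                scope[alias] = result[real_name]
        elif dst != "_":
            scope[dst] = result
        previous = result
    return scope


# Numbers stand in for tensors; each lambda stands in for a submodule.
layers = {
    "backbone": lambda x: {"feat1": x, "feat4": x * 10},
    "pool": lambda x: x / 2,
    "flatten": lambda x: x,
    "head": lambda x: x + 1,
}
flow = [
    ["image", {"feature": "feat4"}, "backbone"],
    ["feature", "_", "pool"],
    ["_", "_", "flatten"],
    ["_", "cls", "head"],
]
scope = run_flow(flow, layers, {"image": 3})
print(scope["cls"])  # 16.0
```

The {feature: feat4} output binds only the backbone's feat4 result under the new name feature, which is what the next step reads by name.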
NAME
NAME appears in two contexts:
- As the third element of a FLOW entry: sets the Python attribute name of the generated submodule (e.g., "block0", "head"). Must be a valid Python identifier.
- As a key in a BACKWARDS entry: sets the generated attribute name for that backward pass and its optimizer.
# In FLOW:
- [feat1, feat1, "block0", {TYPE: Block, PARAM: {DEFAULT: {fout: 40}}}]
# In BACKWARDS:
BACKWARDS:
  - NAME: optimizer
    LOSS: ce_loss
    TRAINABLE_LAYERS: [model]
    OPTIMIZER: [_obj_, ...]
LAYER
The last element of a FLOW entry (the fourth when NAME is given, otherwise the third). Defines how the submodule for this step is constructed. Two forms are accepted:
- StructCast ObjectPattern: an [_obj_, ...] list that constructs a standard PyTorch module:
  [_obj_, {_addr_: torch.nn.LazyConv2d}, {_call_: {out_channels: 40, kernel_size: 4, stride: 4}}]
- UserLayer dict: references a sublayer defined elsewhere in the same file (via TYPE) or in an external file (via CFG):
  {TYPE: Backbone}
  {TYPE: Block, PARAM: {DEFAULT: {fout: 40, drop_path: 0.0}}}
  {CFG: cfg/torch/models/my_sublayer.yaml, TYPE: MySublayer}
TYPE, PARAM, and CFG
These three keys form the UserLayer dict that activates a named sublayer:
| Key | Type | Description |
|---|---|---|
| TYPE | str | Name of a sublayer defined as a top-level key in the same YAML file (e.g., Backbone, Block, Stem), or in the file given by CFG. The code generator expands it into a nested nn.Module subclass. |
| PARAM | PARAMETERS dict | Template variable overrides passed when rendering the sublayer. Uses the same DEFAULT / SHARED / named-group structure as the top-level PARAMETERS block. |
| CFG | file path | Path to an external YAML file that defines the sublayer. Allows sublayer reuse across multiple model templates. When CFG is set, TYPE selects the sublayer name within that file. |
# References Backbone sublayer defined in the same file, no parameter overrides
- [image, {feature: feat4}, backbone, {TYPE: Backbone}]
# References Block sublayer with per-instance parameter overrides
- [feat1, feat1, "block0", {TYPE: Block, PARAM: {DEFAULT: {fout: 40, drop_path: 0.0}}}]
Backward Template Schema
The following keys appear in backward configuration files such as cfg/torch/backwards/ConvNeXtV2.yaml and cfg/torch/backwards/CycleGAN.yaml.
IMPORTS
Same format as in the model schema. Injects additional Python imports into the generated backward file.
IMPORTS: {}
INPUTS and OUTPUTS
INPUTS lists the tensor names the generated backward class expects as keyword arguments during training and inference. OUTPUTS lists the tensor names produced by the backward flow. Both default to [], which instructs the code generator to infer them automatically from the BACKWARDS entries' FLOW definitions.
INPUTS: [] # auto-inferred from BACKWARDS[*].FLOW
OUTPUTS: [loss_G, loss_GAN, loss_cycle, loss_identity, loss_D_A, loss_D_B, fake_A, fake_B]
MIXED_PRECISION
Controls torch.amp.GradScaler for automatic mixed-precision training.
| Value | Behavior |
|---|---|
| false (default) | AMP disabled; no GradScaler is created. |
| true | AMP enabled with default GradScaler settings. |
| dict | AMP enabled; the dict is forwarded as keyword arguments to torch.amp.GradScaler(...). |
MIXED_PRECISION:
  init_scale: "eval: 2.0**16"
  growth_factor: 2.0
  backoff_factor: 0.5
  growth_interval: 2000
  enabled: True
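The three accepted forms amount to a small normalization step. The function below is a hypothetical sketch (not the package's code) that maps a MIXED_PRECISION value to the keyword arguments one would pass to torch.amp.GradScaler, or None when AMP is off. It stays in plain Python so it runs without PyTorch, and it assumes StructCast "eval: ..." strings such as init_scale have already been resolved to numbers.

```python
from typing import Any, Optional


def resolve_mixed_precision(value: Any) -> Optional[dict[str, Any]]:
    """Hypothetical: map MIXED_PRECISION to GradScaler kwargs (None = AMP disabled)."""
    if value is False or value is None:
        return None  # false (default): no GradScaler is created
    if value is True:
        return {}  # true: GradScaler with default settings
    if isinstance(value, dict):
        return dict(value)  # dict: forwarded as keyword arguments
    raise TypeError(f"unsupported MIXED_PRECISION value: {value!r}")


kwargs = resolve_mixed_precision(
    {"init_scale": 2.0**16, "growth_factor": 2.0, "backoff_factor": 0.5,
     "growth_interval": 2000, "enabled": True}
)
# With PyTorch installed, the scaler would then be built roughly as:
#   scaler = torch.amp.GradScaler("cuda", **kwargs)
print(kwargs["init_scale"])  # 65536.0
```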
MIXED_PRECISION_TYPE
The dtype forwarded to torch.autocast when mixed precision is enabled. Accepts "bfloat16" or "float16". Has no effect when MIXED_PRECISION is false.
MIXED_PRECISION_TYPE: bfloat16
ACCUMULATE_GRADIENTS
The number of forward–backward steps to accumulate before calling the optimizer. Set to null to disable accumulation (optimizer steps every batch). When set to a positive integer n, optimizer.step() and optimizer.zero_grad() are called once every n batches.
ACCUMULATE_GRADIENTS: null  # disabled
# or
ACCUMULATE_GRADIENTS: 4  # accumulate over 4 steps
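The stepping schedule can be sketched as follows. optimizer_steps is a hypothetical helper that lists the 0-based batch indices after which optimizer.step() and optimizer.zero_grad() would run; 0-based indexing is an assumption. It corresponds to the Backward protocol (see the base_trainer.py API reference) returning True on those batches and False otherwise.

```python
from typing import Optional


def optimizer_steps(num_batches: int, accumulate: Optional[int]) -> list[int]:
    """Hypothetical: 0-based batch indices after which the optimizer steps."""
    if accumulate is None:
        return list(range(num_batches))  # accumulation disabled: step every batch
    if accumulate < 1:
        raise ValueError("ACCUMULATE_GRADIENTS must be a positive integer or null")
    # Step once every `accumulate` batches, i.e. after batches n-1, 2n-1, ...
    return [i for i in range(num_batches) if (i + 1) % accumulate == 0]


print(optimizer_steps(10, None))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(optimizer_steps(10, 4))     # [3, 7]
```

How gradients left over at the end of an epoch (batches 8 and 9 above) are handled is not specified here; this sketch simply does not step for them.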
BACKWARDS
An ordered list of BackwardBehavior entries. Each entry defines one backward pass — i.e., one loss to differentiate, one optimizer to update, and its own execution graph. Multiple entries enable multi-optimizer training (e.g., GAN-style training where generator and discriminator optimizers are stepped independently).
The generated code sets each entry's trainable layers to training mode before that entry's flow executes and sets them back to eval mode after its optimizer step.
# Single-optimizer example (classification)
BACKWARDS:
  - NAME: optimizer
    LOSS: ce_loss
    TRAINABLE_LAYERS: [model]
    OPTIMIZER: [_obj_, ...]
    CLIP: null
    EXTRA: {}
    FLOW:
      - [image, cls, model]
      - [{target: label, input: cls}, ce_loss, cross_entropy_loss, [_obj_, ...]]
    INFERENCE_FLOW:
      - [image, cls, model]
      - [{target: label, input: cls}, ce_loss, cross_entropy_loss]
# Multi-optimizer example (GAN)
BACKWARDS:
  - NAME: optimizer_G
    LOSS: loss_G
    TRAINABLE_LAYERS: [G_AB, G_BA]
    OPTIMIZER: [_obj_, ...]
    FLOW: [...]  # generator forward + loss computation
    INFERENCE_FLOW: [...]  # inference-only flow
  - NAME: optimizer_D_A
    LOSS: loss_D_A
    TRAINABLE_LAYERS: [D_A]
    OPTIMIZER: [_obj_, ...]
    FLOW: [...]  # discriminator A forward + loss computation
  - NAME: optimizer_D_B
    LOSS: loss_D_B
    TRAINABLE_LAYERS: [D_B]
    OPTIMIZER: [_obj_, ...]
    FLOW: [...]  # discriminator B forward + loss computation
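The per-entry lifecycle described above can be sketched with stand-in objects. Everything here is hypothetical: ToyLayer only tracks a training flag, each FLOW is a plain function, and the log line stands in for loss.backward(), optional CLIP, and optimizer.step(). It also assumes that later entries can read outputs of earlier ones (e.g. the discriminators reading fake_B), which the GAN OUTPUTS list suggests but the schema does not spell out.

```python
from typing import Any


class ToyLayer:
    """Stand-in for an nn.Module that only tracks its training flag."""

    def __init__(self) -> None:
        self.training = False

    def train(self) -> None:
        self.training = True

    def eval(self) -> None:
        self.training = False


def run_backwards(
    entries: list[dict[str, Any]],
    layers: dict[str, ToyLayer],
    batch: dict[str, Any],
    log: list[str],
) -> dict[str, Any]:
    """Hypothetical sketch of one training step over several BACKWARDS entries."""
    scope: dict[str, Any] = dict(batch)
    for entry in entries:
        for name in entry["TRAINABLE_LAYERS"]:
            layers[name].train()  # only this entry's layers enter training mode
        scope.update(entry["FLOW"](scope))  # later entries can read earlier outputs
        active = sorted(name for name, layer in layers.items() if layer.training)
        # Placeholder for loss.backward(), optional CLIP, and optimizer.step()
        log.append(f"{entry['NAME']}: step on {entry['LOSS']}, training={active}")
        for name in entry["TRAINABLE_LAYERS"]:
            layers[name].eval()  # back to eval mode after the optimizer step
    return scope


layers = {name: ToyLayer() for name in ["G_AB", "G_BA", "D_A", "D_B"]}
entries = [
    {"NAME": "optimizer_G", "LOSS": "loss_G", "TRAINABLE_LAYERS": ["G_AB", "G_BA"],
     "FLOW": lambda s: {"fake_B": s["real_A"] + 1.0, "loss_G": 0.5}},
    {"NAME": "optimizer_D_A", "LOSS": "loss_D_A", "TRAINABLE_LAYERS": ["D_A"],
     "FLOW": lambda s: {"loss_D_A": 0.25}},
    {"NAME": "optimizer_D_B", "LOSS": "loss_D_B", "TRAINABLE_LAYERS": ["D_B"],
     "FLOW": lambda s: {"loss_D_B": s["fake_B"] / 4}},
]
log: list[str] = []
scope = run_backwards(entries, layers, {"real_A": 1.0}, log)
for line in log:
    print(line)
# optimizer_G: step on loss_G, training=['G_AB', 'G_BA']
# optimizer_D_A: step on loss_D_A, training=['D_A']
# optimizer_D_B: step on loss_D_B, training=['D_B']
```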
LOSSES and TRAINABLE_LAYERS
Both fields default to [], which instructs the code generator to infer their values automatically from the BACKWARDS entries.
| Key | Type | Description |
|---|---|---|
| LOSSES | list[str] | Explicit list of loss key names that the generated backward class tracks. Auto-inferred from BACKWARDS[*].LOSS when left as []. |
| TRAINABLE_LAYERS | list[str] | Explicit list of trainable model names the generated backward class expects as constructor arguments. Auto-inferred from BACKWARDS[*].TRAINABLE_LAYERS when left as []. |
LOSSES: [] # auto-inferred
TRAINABLE_LAYERS: [] # auto-inferred
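A plausible reading of the auto-inference is an order-preserving union over the BACKWARDS entries. The sketch below is hypothetical; in particular, first-appearance ordering and de-duplication are assumptions rather than documented behavior.

```python
from typing import Any


def infer_from_backwards(backwards: list[dict[str, Any]]) -> tuple[list[str], list[str]]:
    """Hypothetical: derive LOSSES and TRAINABLE_LAYERS from BACKWARDS entries."""
    losses: list[str] = []
    trainable: list[str] = []
    for entry in backwards:
        if entry["LOSS"] not in losses:
            losses.append(entry["LOSS"])  # one loss key per entry, first occurrence kept
        for name in entry["TRAINABLE_LAYERS"]:
            if name not in trainable:
                trainable.append(name)  # union of layer names, order preserved
    return losses, trainable


gan = [
    {"LOSS": "loss_G", "TRAINABLE_LAYERS": ["G_AB", "G_BA"]},
    {"LOSS": "loss_D_A", "TRAINABLE_LAYERS": ["D_A"]},
    {"LOSS": "loss_D_B", "TRAINABLE_LAYERS": ["D_B"]},
]
losses, trainable = infer_from_backwards(gan)
print(losses)     # ['loss_G', 'loss_D_A', 'loss_D_B']
print(trainable)  # ['G_AB', 'G_BA', 'D_A', 'D_B']
```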
BACKWARDS entry keys
Each entry in BACKWARDS is a BackwardBehavior with the following fields:
| Key | Type | Description |
|---|---|---|
| NAME | str | Optional identifier for this backward pass. Used as the generated attribute name for the optimizer. Must be a valid Python identifier. |
| LOSS | str | The loss key (produced by the FLOW) that this backward pass differentiates. |
| TRAINABLE_LAYERS | list[str] | Model names whose parameters this optimizer manages. Each value must match a model passed to the backward class constructor. |
| FLOW | list | Training-time execution graph for this backward entry. Uses the same entry format as model FLOW (see FLOW entry format), plus support for "eval: ..." expressions and inline layer instantiation via StructCast patterns. |
| INFERENCE_FLOW | list | Optional inference-time execution graph. When absent, FLOW is used for inference as well. |
| OPTIMIZER | StructCast pattern | A StructCast ObjectPattern that constructs the optimizer (and optionally its learning-rate scheduler). Commonly uses structcast_model.torch.optimizers.create_with_scheduler with _bind_ to pass optimizer_kwargs and scheduler_kwargs. |
| CLIP | StructCast pattern or null | Optional gradient-clipping callable. When non-null, the pattern is bound once and called before each optimizer step with the parameters identified by TRAINABLE_LAYERS. Set to null to disable gradient clipping. |
| EXTRA | dict | Extra keyword arguments forwarded to the backward/optimizer logic. Default is {}. |
BACKWARDS:
  - NAME: optimizer
    LOSS: ce_loss
    TRAINABLE_LAYERS: [model]
    OPTIMIZER:
      - _obj_
      - _addr_: structcast_model.torch.optimizers.create_with_scheduler
      - _bind_:
          optimizer_kwargs:
            opt: adamw
            lr: 4.0e-3
            weight_decay: 0.001
          scheduler_kwargs:
            name: cosine
            num_epochs: 300
    CLIP:
      - _obj_
      - _addr_: timm.utils.clip_grad.dispatch_clip_grad
      - _bind_: {value: 1.0, mode: norm, norm_type: 2.0}
    EXTRA: {}
    FLOW:
      - [image, cls, model]
      - [{target: label, input: cls}, ce_loss, cross_entropy_loss, [_obj_, {_addr_: torch.nn.CrossEntropyLoss}, _call_]]
    INFERENCE_FLOW:
      - [image, cls, model]
      - [{target: label, input: cls}, ce_loss, cross_entropy_loss]
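Assuming _bind_ behaves like functools.partial (keyword arguments fixed once, remaining arguments supplied at call time), the CLIP entry above corresponds roughly to the Python below. To keep the sketch runnable without PyTorch, fake_clip is a hypothetical stand-in with the same parameter names as timm.utils.clip_grad.dispatch_clip_grad(parameters, value, mode, norm_type).

```python
from functools import partial
from typing import Any, Iterable


def fake_clip(parameters: Iterable[Any], value: float, mode: str = "norm", norm_type: float = 2.0) -> str:
    """Hypothetical stand-in for timm.utils.clip_grad.dispatch_clip_grad."""
    return f"clip {len(list(parameters))} params: mode={mode}, value={value}, norm_type={norm_type}"


# `_bind_: {value: 1.0, mode: norm, norm_type: 2.0}` ~ bind keyword arguments once ...
clip = partial(fake_clip, value=1.0, mode="norm", norm_type=2.0)
# ... then call before each optimizer step with the TRAINABLE_LAYERS parameters
message = clip(["w1", "w2"])
print(message)  # clip 2 params: mode=norm, value=1.0, norm_type=2.0
```

With timm and PyTorch installed, the real equivalent would be partial(dispatch_clip_grad, value=1.0, mode="norm", norm_type=2.0) called with model.parameters().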
API Reference: base_trainer.py
src/structcast_model/base_trainer.py provides the framework-agnostic training loop, state management, and callback system. Concrete trainers such as TorchTrainer build on top of these abstractions.
Utility functions
get_dataset(dataset)
Resolves a DatasetLike or a zero-argument callable into an actual iterable. This allows lazy dataset construction.
get_dataset_size(dataset)
Returns the number of batches. Uses __len__ when available, otherwise iterates to count.
invoke_callback(callbacks, info, *args, **models)
Iterates over a callback list and calls each entry with info and keyword model arguments.
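The three helpers can be approximated as follows. These are hypothetical re-implementations written from the descriptions above, not the package source, and DatasetLike is approximated as any iterable of batches.

```python
from typing import Any, Callable, Iterable, Union

DatasetLike = Iterable[Any]  # assumption: any iterable of batches


def get_dataset(dataset: Union[DatasetLike, Callable[[], DatasetLike]]) -> DatasetLike:
    """Sketch: call a zero-argument factory lazily; return iterables unchanged."""
    return dataset() if callable(dataset) else dataset


def get_dataset_size(dataset: DatasetLike) -> int:
    """Sketch: use __len__ when available, otherwise count batches by iterating."""
    try:
        return len(dataset)  # type: ignore[arg-type]
    except TypeError:
        return sum(1 for _ in dataset)


def invoke_callback(callbacks: list[Callable[..., Any]], info: Any, *args: Any, **models: Any) -> None:
    """Sketch: call each callback with info, any extra args, and the models as keywords."""
    for callback in callbacks:
        callback(info, *args, **models)


lazy = get_dataset(lambda: (batch for batch in range(5)))  # generator has no __len__
sized = get_dataset([10, 20, 30])  # a list passes through unchanged
lazy_size = get_dataset_size(lazy)
sized_size = get_dataset_size(sized)
seen: list[Any] = []
invoke_callback([lambda info, **models: seen.append((info, sorted(models)))], "epoch-1", model="net")
print(lazy_size, sized_size, seen)  # 5 3 [('epoch-1', ['model'])]
```

Note that counting by iteration consumes a one-shot generator, which is one reason to pass a zero-argument factory instead of a generator object.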
Protocols
Forward
Called once per batch during training or validation. Accepts an inputs dictionary and keyword model arguments; returns a dict[str, Any] of named outputs and criteria.
Backward
Called once per training step. Receives the step index and criterion keyword arguments; returns True when the optimizer has stepped, False when gradients are being accumulated.
Callback and BestCallback
Lifecycle hooks called with (info: BaseInfo, **models). BestCallback additionally receives target: str and best: float arguments.
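The call shapes above can be written as typing.Protocol classes. This is a sketch: the parameter names (inputs, step, criteria) and 0-based step indices are assumptions, and ToyForward and ToyBackward are hypothetical implementations used only to show conforming callables.

```python
from typing import Any, Protocol


class Forward(Protocol):
    def __call__(self, inputs: dict[str, Any], **models: Any) -> dict[str, Any]: ...


class Backward(Protocol):
    def __call__(self, step: int, **criteria: Any) -> bool: ...


class Callback(Protocol):
    def __call__(self, info: Any, **models: Any) -> None: ...


class ToyForward:
    """Hypothetical Forward: runs the 'model' keyword argument and reports a loss."""

    def __call__(self, inputs: dict[str, Any], **models: Any) -> dict[str, Any]:
        prediction = models["model"](inputs["x"])
        return {"pred": prediction, "loss": abs(prediction - inputs["y"])}


class ToyBackward:
    """Hypothetical Backward that steps on every second call (accumulating over 2 steps)."""

    def __call__(self, step: int, **criteria: Any) -> bool:
        return (step + 1) % 2 == 0  # True = optimizer stepped, False = still accumulating


forward: Forward = ToyForward()
backward: Backward = ToyBackward()
criteria = forward({"x": 2, "y": 5}, model=lambda x: x * 2)
steps = [backward(step, **criteria) for step in range(4)]
print(criteria)  # {'pred': 4, 'loss': 1}
print(steps)     # [False, True, False, True]
```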
State and callbacks
BaseInfo
Dataclass holding mutable training state:
- step: total training steps taken
- update: optimizer update count
- epoch: current epoch number
- history: per-epoch log dictionaries
- logs(epoch=None): returns the log dict for the current (or given) epoch
Callbacks
Dataclass holding callback lists for each lifecycle hook:
- on_update: after each optimizer update
- on_training_begin / on_training_end
- on_training_step_begin / on_training_step_end
- on_validation_begin / on_validation_end
- on_validation_step_begin / on_validation_step_end
- on_epoch_begin / on_epoch_end
When add_global_callbacks=True (the default), entries from GLOBAL_CALLBACKS are copied into each list at construction time.
GLOBAL_CALLBACKS
A shared Callbacks[Any] instance. Callbacks registered here are automatically picked up by every newly created trainer.
Core classes
BaseTrainer
The main training loop driver. Inherits from both BaseInfo and Callbacks.
Required fields: training_step (Forward), backward (Backward), tracker (callable returning dict[str, float]).
Optional fields: validation_step, training_prefix (default ""), validation_prefix (default "val_").
Key methods:
- train(dataset, **models): runs one training epoch and returns the final step logs
- evaluate(dataset, **models): runs one validation epoch and returns the final step logs
- fit(epochs, training_dataset, validation_dataset=None, start_epoch=1, validation_frequency=1, **models): runs the full loop and returns the complete history dict
- sync(): optional synchronization hook; a no-op by default (overridden in TorchTrainer)
trainer = MyTrainer(
training_step=my_forward,
backward=my_backward,
tracker=my_tracker,
validation_step=my_val_forward,
)
history = trainer.fit(
epochs=10,
training_dataset=train_loader,
validation_dataset=val_loader,
model=model,
)
BestCriterion
A callable that monitors a log key and fires on_best callbacks whenever a new best is found. Attach it to on_epoch_end or on_validation_end.
checkpoint = BestCriterion(
target="val_acc1",
mode="max",
on_best=[save_checkpoint],
)
trainer.on_epoch_end.append(checkpoint)
Fields: target (str), mode ("min" or "max", default "min"), on_best (list of BestCallback).
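The monitoring logic can be sketched as a small dataclass. SketchBestCriterion is hypothetical: it takes a plain logs dict instead of a BaseInfo, treats the first value as the initial best, and counts only strict improvements; the real class may differ on all three points.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Optional


@dataclass
class SketchBestCriterion:
    """Hypothetical re-implementation of the BestCriterion behavior described above."""

    target: str
    mode: str = "min"  # "min" or "max"
    on_best: list[Callable[..., Any]] = field(default_factory=list)
    best: Optional[float] = None

    def __post_init__(self) -> None:
        if self.mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")

    def __call__(self, logs: dict[str, float], **models: Any) -> None:
        value = logs[self.target]
        improved = (
            self.best is None
            or (self.mode == "min" and value < self.best)
            or (self.mode == "max" and value > self.best)
        )
        if improved:
            self.best = value
            for callback in self.on_best:  # fire every on_best hook with target and best
                callback(logs, target=self.target, best=value, **models)


saved: list[float] = []
criterion = SketchBestCriterion(
    target="val_acc1",
    mode="max",
    on_best=[lambda logs, target, best, **models: saved.append(best)],
)
for acc in [0.50, 0.62, 0.60, 0.71]:
    criterion({"val_acc1": acc})
print(saved)  # [0.5, 0.62, 0.71]
```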
API Reference: trainer.py
src/structcast_model/torch/trainer.py contains the PyTorch-specific runtime layer.
Utility functions
create_torch_inputs(shape)
Creates dummy float32 tensors from tuple, list, or dict shape descriptions. Used for model initialization and FLOPs inspection.
get_torch_device(device=None)
Returns the runtime device. Selects cuda when available and requested, otherwise falls back to cpu.
initial_model(model, shapes=None, compile_fn=None)
Walks a module or nested module structure, optionally builds dummy inputs, runs a forward pass, and applies a compile function to each module. Returns:
(initialized_model, inputs, outputs)
get_autocast(mixed_precision_type, device)
Returns a context manager for automatic mixed precision:
- contextlib.suppress when AMP is disabled.
- A configured torch.autocast(...) partial when AMP is enabled.
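A minimal sketch of the two branches, written from the description above. get_autocast_sketch is a hypothetical name; it assumes AMP is disabled by passing None and that the returned object is called to create the context (with autocast(): ...). torch is imported only on the enabled branch, so the disabled path runs without PyTorch installed.

```python
import contextlib
from functools import partial
from typing import Any, Callable, Optional


def get_autocast_sketch(mixed_precision_type: Optional[str], device: str) -> Callable[[], Any]:
    """Hypothetical: return a zero-argument factory for an autocast context manager."""
    if mixed_precision_type is None:
        # AMP disabled: contextlib.suppress() with no exception types is a no-op context
        return contextlib.suppress
    import torch  # only needed when AMP is enabled

    dtype = getattr(torch, mixed_precision_type)  # "bfloat16" -> torch.bfloat16, "float16" -> torch.float16
    return partial(torch.autocast, device_type=torch.device(device).type, dtype=dtype)


autocast = get_autocast_sketch(None, "cpu")
with autocast():
    result = sum([1, 2, 3])  # runs unchanged; no autocasting is applied
print(result)  # 6
```

With PyTorch installed, get_autocast_sketch("bfloat16", "cuda")() would enter torch.autocast(device_type="cuda", dtype=torch.bfloat16).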
Step objects
TrainingStep
TrainingStep chains one or more models, updates a shared output dictionary, computes losses, and optionally computes metrics.
step = TrainingStep(
models=["model"],
losses=loss_module,
metrics=metric_module,
autocast=get_autocast("bfloat16", "cuda"),
)
criteria = step({"image": image, "label": label}, model=model)
ValidationStep
Same interface as TrainingStep, but always executes under torch.no_grad().
Tracking and orchestration
TorchTracker
Wraps CriteriaTracker instances for losses and metrics, resets them through global callbacks, and returns float-valued logs suitable for history storage and MLflow logging.
tracker = TorchTracker.from_criteria(["ce_loss"], ["acc1", "acc5"])
logs = tracker(ce_loss=loss_tensor, acc1=acc1_tensor, acc5=acc5_tensor)
TorchTrainer
TorchTrainer extends the generic BaseTrainer with PyTorch-specific synchronization.
trainer = TorchTrainer(
device="cuda",
training_step=TrainingStep(models=["model"], losses=loss_module, metrics=metric_module),
validation_step=ValidationStep(models=["model"], losses=loss_module, metrics=metric_module),
backward=backward,
tracker=tracker,
)
history = trainer.fit(
epochs=5,
training_dataset=train_loader,
validation_dataset=valid_loader,
model=model,
)
timm integrations
TimmDatasetWrapper
Holds validated dataset configuration and lazily calls timm.data.create_dataset(...).
TimmDataLoaderWrapper
Builds a timm dataloader with support for:
- Prefetching
- Channels-last memory format conversion
- Mixup and cutmix data augmentation
- Train/validation-specific augmentation settings
- Distributed device initialization
- Optional FlexSpec output remapping
The dataset template at cfg/torch/datasets/default_timm.yaml formats into this wrapper.
Minimal End-to-End Example
uv sync --extra torch-cu130 --extra mlflow --extra flops
scm torch create model cfg/torch/models/ConvNeXtV2.yaml -p 'DEFAULT: {backbone: femto}' -o model.py
scm torch create model cfg/torch/losses/cls.yaml -c Loss -o loss.py
scm torch create model cfg/torch/metrics/topk.yaml -c Metric -o metric.py
scm torch create backward cfg/torch/backwards/ConvNeXtV2.yaml -p 'DEFAULT: {epochs: 5}' -o backward.py
scm format cfg/torch/datasets/default_timm.yaml \
-o dataset_train.yaml \
-p 'DEFAULT: {training: true, epochs: 5, batch_size: 32, dataset: torch/cifar100, num_classes: 100, label_smoothing: 0.1, input_size: [3, 224, 224], image_dtype: bfloat16, download: true}'
scm format cfg/torch/datasets/default_timm.yaml \
-o dataset_valid.yaml \
-p 'DEFAULT: {training: false, epochs: 5, batch_size: 32, dataset: torch/cifar100, num_classes: 100, input_size: [3, 224, 224], image_dtype: bfloat16, download: true}'
scm torch train \
'model: [_obj_, {_addr_: model.Model, _file_: model.py}, _call_]' \
-s 'image: [3, 224, 224]' \
-d cuda \
-L '[_obj_, {_addr_: loss.Loss, _file_: loss.py}, _call_]' \
-M '[_obj_, {_addr_: metric.Metric, _file_: metric.py}, _call_]' \
-B '[_obj_, {_addr_: backward.Backward, _file_: backward.py}]' \
-c cfg/torch/others/compile_default.yaml \
-e 5 \
-T dataset_train.yaml \
-V dataset_valid.yaml \
-f 1 \
-LC ce_loss \
-LC val_ce_loss \
-HC acc1 \
-HC val_acc1 \
-HC acc5 \
-HC val_acc5 \
-SC val_acc1 \
--matmul-precision high \
-E Test
Development
Set up the development environment with:
uv sync --extra torch-cpu --dev --group tox
Run the test suite:
pytest
Run static type checks:
mypy src
mypy tests
Run linting and formatting:
ruff check src tests
ruff format src tests
Run all checks in parallel with:
tox run-parallel --parallel all
The repository includes tests for:
- CLI behavior
- Builder code generation
- Schema validation
- Trainer utilities
- timm dataset and dataloader wrappers
- Custom torch layers
Migration Notes
Upgrading from v1.x
The following breaking changes were introduced by the backward-template restructure for multi-optimizer GAN training support:
- EMA support removed: TimmEmaWrapper, the cfg/torch/others/ema.yaml configuration, and all InferenceWrapper-based EMA integration in cmd_torch.py and torch/trainer.py have been removed. If your training workflow relied on built-in EMA, you will need to manage EMA externally.
- Backward template schema restructured: the BACKWARDS key now expects a list of BackwardBehavior entries (each with its own NAME, LOSS, TRAINABLE_LAYERS, OPTIMIZER, FLOW, and optional INFERENCE_FLOW). Previous single-optimizer backward configurations must be wrapped in a single-entry list.
- trainer.fit() signature simplified: unused model arguments were removed from the fit() method. Update any custom callers accordingly.
Roadmap
- PyTorch model construction from YAML configuration files
- PyTorch training workflow generation from YAML configuration files
- JAX (Flax) model construction from YAML configuration files
- JAX (Flax) training workflow generation from YAML configuration files
- Keras model construction from YAML configuration files
- Keras training workflow generation from YAML configuration files
Project details
Download files
Source Distributions
Built Distribution
File details
Details for the file structcast_model-2.0.0-py3-none-any.whl.
File metadata
- Download URL: structcast_model-2.0.0-py3-none-any.whl
- Upload date:
- Size: 87.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 65f001ef4c2176b5afafa7c3882adf12d36a8db5a7d005ef916a51ae034c20cd |
| MD5 | 11e680c49493bf66b10bdb82dac06372 |
| BLAKE2b-256 | ee46f27652ae434d2bacab1dc381f6282d9b7f1cd359fae0df47d19711bac03d |
Provenance
The following attestation bundles were made for structcast_model-2.0.0-py3-none-any.whl:
- Publisher: ci.yml on f6ra07nk14/structcast-model
- Statement:
  - Statement type: https://in-toto.io/Statement/v1
  - Predicate type: https://docs.pypi.org/attestations/publish/v1
  - Subject name: structcast_model-2.0.0-py3-none-any.whl
  - Subject digest: 65f001ef4c2176b5afafa7c3882adf12d36a8db5a7d005ef916a51ae034c20cd
- Sigstore transparency entry: 1280500859
- Sigstore integration time:
- Permalink: f6ra07nk14/structcast-model@7dd60538ff48172f3978c65702d30e265d096b42
- Branch / Tag: refs/heads/main
- Owner: https://github.com/f6ra07nk14
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: ci.yml@7dd60538ff48172f3978c65702d30e265d096b42
- Trigger Event: push
Statement type: