JUML

A Judicious, Unified and extendable framework for multi-paradigm Machine Learning research, powered by jutility and PyTorch.

Judicious [adjective]: having or showing reason and good judgment in making decisions

Contents

  • Installation
  • Overview
  • Usage examples
  • Citation

Installation

JUML is available as a Python package on PyPI, and can be installed using pip with the following commands:

python -m pip install -U pip
python -m pip install -U juml-toolkit

Alternatively, JUML can be installed in "editable mode" from the GitHub repository:

git clone https://github.com/jakelevi1996/juml.git
python -m pip install -U pip
python -m pip install -e ./juml

JUML depends on PyTorch, but the appropriate PyTorch build depends on which (if any) CUDA version is available, so pip won't install PyTorch automatically when installing JUML. Instead, please install PyTorch by following the official PyTorch installation instructions.
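For example, at the time of writing, a CPU-only build can be installed with the following command (taken from the official PyTorch instructions; check them for the exact command for your platform and CUDA version):

python -m pip install torch --index-url https://download.pytorch.org/whl/cpu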

Overview

The JUML framework defines 6 fundamental classes (and several example subclasses), available in the juml.base namespace module, which are expected to be subclassed in downstream projects. Of these, the examples on this page directly use juml.base.Model, juml.base.Dataset, juml.base.Loss and juml.base.Framework.

Coming soon: juml.base.Environment for RL
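Once installed, the built-in subclasses can be listed from Python. This minimal sketch uses juml.models.get_all and juml.datasets.get_all, which also appear in the demo script later on this page (equivalent helpers for the other base classes are not shown here):

import juml

# Print the names of the built-in model and dataset subclasses
for model_type in juml.models.get_all():
    print("model:", model_type.__name__)

for dataset_type in juml.datasets.get_all():
    print("dataset:", dataset_type.__name__)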

Usage examples

The JUML framework is designed to be extended in downstream research projects, but nonetheless contains enough built-in functionality to run some simple ML experiments and visualise the results from the command line (without writing any Python code). The following subsections demonstrate:

  1. The built-in functionality of JUML
  2. A simple example demonstrating how to extend JUML with a new model and dataset

JUML can be used with GPUs; however, all of the usage examples below run purely on CPU (using multiple processes for the hyperparameter sweeps) in under a minute (not including downloading datasets). These commands can be run on GPUs by specifying --devices.

Out of the box

Help interface

juml -h
juml train -h
juml sweep -h
juml profile -h

Train a model

juml train --model Mlp --model.Mlp.embedder Flatten --model.Mlp.embedder.Flatten.n 3 --dataset Mnist --trainer.BpSp.epochs 3
cli: Mnist()
cli: Flatten(n=3)
cli: Identity()
cli: Mlp(depth=3, embedder=Flatten(num_params=0), hidden_dim=100, input_shape=[1, 28, 28], output_shape=[10], pooler=Identity(num_params=0))
cli: CrossEntropy()
cli: Adam(lr=0.001, params=<generator object Module.parameters at 0x750d7f1f2ea0>)
Time        | Epoch      | Batch      | Batch loss | Train metric | Test metric
----------- | ---------- | ---------- | ---------- | ------------ | ------------
0.0002s     |          0 |            |            |      0.12013 |      0.11540
1.7009s     |          0 |          0 |    2.33147 |              |
2.0003s     |          0 |         91 |    0.47548 |              |
3.0021s     |          0 |        398 |    0.24203 |              |
3.6553s     |          0 |        599 |    0.22922 |              |
3.6566s     |          1 |            |            |      0.95568 |      0.95320
5.3384s     |          1 |          0 |    0.21946 |              |
6.0018s     |          1 |        200 |    0.07385 |              |
7.0005s     |          1 |        500 |    0.19840 |              |
7.8268s     |          1 |        599 |    0.13859 |              |
7.8288s     |          2 |            |            |      0.97107 |      0.96400
11.1821s    |          2 |          0 |    0.05180 |              |
12.0019s    |          2 |        244 |    0.07798 |              |
13.0009s    |          2 |        538 |    0.03853 |              |
13.2144s    |          2 |        599 |    0.08072 |              |
13.2156s    |          3 |            |            |      0.97862 |      0.97010
Saving in "results/train/dM_lC_mMd3eFen3h100pI_tBb100e3lCle1E-05oAol0.001_s0/cmd.txt"
Saving in "results/train/dM_lC_mMd3eFen3h100pI_tBb100e3lCle1E-05oAol0.001_s0/args.json"
Saving in "results/train/dM_lC_mMd3eFen3h100pI_tBb100e3lCle1E-05oAol0.001_s0/metrics.json"
Saving in "results/train/dM_lC_mMd3eFen3h100pI_tBb100e3lCle1E-05oAol0.001_s0/metrics.png"
Saving in "results/train/dM_lC_mMd3eFen3h100pI_tBb100e3lCle1E-05oAol0.001_s0/model.pth"
Saving in "results/train/dM_lC_mMd3eFen3h100pI_tBb100e3lCle1E-05oAol0.001_s0/table.pkl"
Model name = `dM_lC_mMd3eFen3h100pI_tBb100e3lCle1E-05oAol0.001_s0`
Final metrics = 0.97862 (train), 0.97010 (test)
Time taken for `train` = 15.2683 seconds
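Note how nested CLI arguments are namespaced by the class they configure: --model.Mlp.embedder.Flatten.n 3 sets n=3 on the Flatten embedder of the Mlp model, as echoed in the cli: lines above.

Each run saves its artifacts under results/train/<model_name>/. The saved metrics can be inspected with standard tools; the minimal sketch below simply pretty-prints metrics.json (its exact layout is not documented here, so inspect it rather than assuming specific keys):

import json

model_name = "dM_lC_mMd3eFen3h100pI_tBb100e3lCle1E-05oAol0.001_s0"
with open(f"results/train/{model_name}/metrics.json") as f:
    metrics = json.load(f)

# Pretty-print to inspect the structure of the saved metrics
print(json.dumps(metrics, indent=4))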

Plot confusion matrix

juml plotconfusionmatrix --model_name dM_lC_mMd3eFen3h100pI_tBb100e3lCle1E-05oAol0.001_s0

Sweep hyperparameters

juml sweep --model LinearModel --dataset LinearDataset --dataset.LinearDataset.input_dim 5 --dataset.LinearDataset.output_dim 10 --print_level 1 --sweep.seeds 1 2 3 --sweep.params '{"trainer.BpSp.epochs":[100,200,300],"trainer.BpSp.optimiser.Adam.lr":[1e-5,1e-4,1e-3,1e-2]}' --sweep.log_x trainer.BpSp.optimiser.Adam.lr --sweep.devices "[[],[],[],[],[],[]]" --sweep.no_cache

[Figure: full_sweep_results (plot of the sweep results)]
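This sweep covers 3 values of trainer.BpSp.epochs and 4 values of the Adam learning rate, each repeated over the 3 seeds, for 3 * 4 * 3 = 36 training runs in total. The argument --sweep.devices "[[],[],[],[],[],[]]" appears to request 6 parallel worker processes, each with an empty list of GPU devices (in other words, CPU only), consistent with the earlier note about running sweeps on CPU with multiple processes.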

Profile

juml profile --model_name dM_lC_mMd3eFen3h100pI_tBb100e3lCle1E-05oAol0.001_s0 --profile.num_warmup 1000 --profile.num_profile 1000 --profile.batch_size 100
Key                     | Value
----------------------- | ----------------------
Model                   | Mlp(num_params=99.7k)
Time (total)            | 0.45093 s
Time (average)          | 0.00451 ms/sample
Throughput              | 221.8k samples/second
FLOPS                   | 199.1 KFLOPS/sample
Total number of samples | 100.0k
Batch size              | 100
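These figures are self-consistent: the 100.0k total samples appear to correspond to --profile.num_profile 1000 batches of --profile.batch_size 100, 0.45093 s / 100000 samples is approximately 0.00451 ms/sample, and 100000 samples / 0.45093 s is approximately 221.8k samples/second. The 199.1 KFLOPS/sample is also roughly twice the 99.7k parameter count, as expected for a dense forward pass (one multiply and one add per parameter per sample).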
juml train --dataset RandomImage --model Cnn --model.Cnn.pooler Average2d --trainer BpSp --trainer.BpSp.epochs 1
juml profile --model_name dRfFi3,32,32o10te200tr200_lC_mCb2c64eIk5n3pAs2_tBb100e1lCle1E-05oAol0.001_s0
Key                     | Value
----------------------- | ----------------------
Model                   | Cnn(num_params=620.3k)
Time (total)            | 0.22285 s
Time (average)          | 0.22285 ms/sample
Throughput              | 4.5k samples/second
FLOPS                   | 7.2 MFLOPS/sample
Total number of samples | 1.0k
Batch size              | 100
juml train --dataset RandomImage --model RzCnn --model.RzCnn.pooler Average2d --trainer BpSp --trainer.BpSp.epochs 1
juml profile --model_name dRfFi3,32,32o10te200tr200_lC_mRZCb2eIk5m64n3pAs2x2.0_tBb100e1lCle1E-05oAol0.001_s0
Key                     | Value
----------------------- | ----------------------
Model                   | RzCnn(num_params=283.4k)
Time (total)            | 0.22679 s
Time (average)          | 0.22679 ms/sample
Throughput              | 4.4k samples/second
FLOPS                   | 7.2 MFLOPS/sample
Total number of samples | 1.0k
Batch size              | 100
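Comparing the last two profiles: RzCnn has fewer than half the parameters of Cnn (283.4k vs 620.3k) yet performs the same 7.2 MFLOPS/sample and achieves similar throughput. This is expected for convolutional models, where per-sample compute is dominated by kernels reapplied across spatial positions, so FLOPs per sample are not proportional to parameter count.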

Further examples

See scripts/further_examples.sh for further example juml commands, with explanations.

Extending JUML

The script scripts/demo_extend_juml.py (also shown below) is a simple demonstration of how the JUML framework can be extended with a new model and a synthetic dataset. Despite being only 58 lines (including whitespace), this script inherits the following capabilities from JUML without any additional code:

  1. CLI interface
  2. Training loop
  3. Hyperparameter sweeps
  4. Automated model naming/saving/loading
  5. Visualisation commands
  6. Profiling (from the CLI)
  7. Other models to compare against (by passing appropriate CLI arguments)

import torch
from jutility import cli
import juml

class PolynomialRegression1d(juml.base.Model):
    def __init__(
        self,
        n:              int,
        input_shape:    list[int],
        output_shape:   list[int],
    ):
        self._torch_module_init()
        # Powers 0..(n-1) and learnable coefficients of a degree-(n-1) polynomial
        self.p_i    = torch.arange(n)
        self.w_i1   = torch.nn.Parameter(torch.zeros([n, 1]))

    def forward(self, x_n1: torch.Tensor) -> torch.Tensor:
        # Expand [n, 1] inputs into polynomial features [n, i], then take a
        # learned linear combination to produce [n, 1] outputs
        x_ni = (x_n1 ** self.p_i)
        x_n1 = x_ni @ self.w_i1
        return x_n1

    @classmethod
    def get_cli_options(cls) -> list[cli.Arg]:
        return [cli.Arg("n", type=int, default=5)]

class Step1d(juml.datasets.Synthetic):
    def __init__(self):
        # Synthetic 1D regression dataset: noisy samples of a step function
        self._init_synthetic(
            input_shape=[1],
            output_shape=[1],
            n_train=200,
            n_test=200,
            x_std=0.1,
            t_std=0.02,
        )

    def _compute_target(self, x: torch.Tensor) -> torch.Tensor:
        return torch.where(x > 0, 1.0, 0.0)

    @classmethod
    def get_default_loss(cls) -> type[juml.base.Loss] | None:
        return juml.loss.Mse

class DemoExtendFramework(juml.base.Framework):
    # Register the new model and dataset alongside JUML's built-in classes
    @classmethod
    def get_models(cls) -> list[type[juml.base.Model]]:
        return [
            *juml.models.get_all(),
            PolynomialRegression1d,
        ]

    @classmethod
    def get_datasets(cls) -> list[type[juml.base.Dataset]]:
        return [
            *juml.datasets.get_all(),
            Step1d,
        ]

if __name__ == "__main__":
    DemoExtendFramework.run()

Example usage:

python scripts/demo_extend_juml.py -h
python scripts/demo_extend_juml.py train -h
python scripts/demo_extend_juml.py train --model PolynomialRegression1d --dataset Step1d --trainer.BpSp.epochs 1000 --print_level 1
python scripts/demo_extend_juml.py plot1dregression --model_name dST_lM_mPn5_tBb100e1000lCle1E-05oAol0.001_s0
python scripts/demo_extend_juml.py sweep --model PolynomialRegression1d --dataset Step1d --trainer.BpSp.epochs 1000 --print_level 1 --sweep.params '{"model.PolynomialRegression1d.n":[3,4,5,6,7,8,9,10]}' --sweep.devices "[[],[],[],[],[],[]]" --sweep.no_cache
python scripts/demo_extend_juml.py plot1dregression --model_name dST_lM_mPn6_tBb100e1000lCle1E-05oAol0.001_s0

python scripts/demo_extend_juml.py train --model RzMlp --dataset Step1d --trainer.BpSp.epochs 1000 --print_level 1
python scripts/demo_extend_juml.py plot1dregression --model_name dST_lM_mRZMd3eIm100pIx2.0_tBb100e1000lCle1E-05oAol0.001_s0
python scripts/demo_extend_juml.py plotsequential --model_name dST_lM_mRZMd3eIm100pIx2.0_tBb100e1000lCle1E-05oAol0.001_s0
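
The model names used above are generated automatically from the configuration (capability 4 in the list above). For example, dST_lM_mPn5_tBb100e1000lCle1E-05oAol0.001_s0 appears to decode as: dataset Step1d (dST), loss Mse (lM), model PolynomialRegression1d with n=5 (mPn5), trainer BpSp with batch size 100, 1000 epochs and Adam with lr=0.001 (tBb100e1000...oAol0.001), and seed 0 (s0). This makes it possible to pass trained models between commands such as train, plot1dregression and profile without manually tracking filenames.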

Citation

If you find JUML helpful in your research, please cite:

@misc{levi_juml_2025,
    title = {{JUML}: {A} {Judicious}, {Unified} and extendable framework for multi-paradigm {Machine} {Learning} research},
    shorttitle = {{JUML}},
    url = {https://github.com/jakelevi1996/juml},
    abstract = {A Judicious, Unified and extendable framework for multi-paradigm Machine Learning research, powered by jutility and PyTorch.},
    author = {Levi, Jake},
    year = {2025},
}
