
MiniML - a minimalistic ML framework

Project description

MiniML


MiniML (pronounced "minimal") is a tiny machine-learning framework that uses Jax as its core engine. It combines a PyTorch-inspired approach to building models with Scikit-learn's interface (the .fit and .predict methods), and is powered by SciPy's optimization algorithms. It is meant for simple prototyping of small ML architectures, allowing more flexibility than Scikit-learn's built-in models without sacrificing too much performance.

For example, training a linear model in MiniML is as simple as this:

class LinearModel(MiniMLModel):
    A: MiniMLParam
    b: MiniMLParam

    def __init__(self, n_in: int, n_out: int):
        self.A = MiniMLParam((n_in,n_out))
        self.b = MiniMLParam((n_out,))
        super().__init__()

    def _predict_kernel(self, X, buffer):
        return X@self.A(buffer)+self.b(buffer)

lin_model = LinearModel(X.shape[1], y.shape[1])
lin_model.randomize()
lin_model.fit(X, y)
y_hat = lin_model.predict(X)

Note that calling a parameter with the buffer as an argument returns the value of that parameter.
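
One way to picture this calling convention is a parameter object that only records a shape and an offset into a flat array, and slices its value out of that array when called. The sketch below is a conceptual illustration in plain NumPy, not MiniML's actual implementation; ToyParam and layout are hypothetical names introduced here, and NumPy arrays stand in for Jax arrays.

```python
import numpy as np


class ToyParam:
    """Hypothetical stand-in for a parameter: a shape plus an offset into a flat buffer."""

    def __init__(self, shape):
        self.shape = tuple(shape)
        self.size = int(np.prod(self.shape))
        self.offset = None  # assigned later by layout()

    def __call__(self, buffer):
        # Slice this parameter's values out of the shared buffer and reshape them.
        start = self.offset
        return buffer[start:start + self.size].reshape(self.shape)


def layout(params):
    """Assign consecutive offsets to the parameters and return the total buffer size."""
    total = 0
    for p in params:
        p.offset = total
        total += p.size
    return total


A = ToyParam((3, 2))
b = ToyParam((2,))
n = layout([A, b])

buffer = np.arange(n, dtype=float)  # one flat array holds every parameter value
X = np.ones((4, 3))
y_hat = X @ A(buffer) + b(buffer)   # same expression as in _predict_kernel above
```

Keeping all values in a single flat array is convenient for optimizers that work on one parameter vector, which is the form SciPy's minimizers expect.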

Installation

Simply install this package from PyPI:

pip install miniml-jax

or to use the CUDA-enabled version of Jax:

pip install "miniml-jax[cuda]"

Usage

The two core types are MiniMLParam and MiniMLModel. There are also MiniMLParamList and MiniMLModelList containers for storing several parameters or several models, respectively.

To define a model in MiniML, subclass MiniMLModel and define your parameters as MiniMLParam attributes in the __init__ method. Remember to make sure that:

  • every parameter or child model is stored either directly as a class member, or inside a corresponding List class;
  • the super().__init__() constructor is called at the end.

Then, implement the internal _predict_kernel method, which takes an input array and a memory buffer containing the parameters, and returns the model's prediction. After instantiating your model, call bind() to initialize the parameter buffers, or call randomize() directly to initialize the parameter values. You can then use methods like fit, save, and load.

Example: Linear Model

import jax.numpy as jnp
from miniml.param import MiniMLParam
from miniml.model import MiniMLModel

class LinearModel(MiniMLModel):
    def __init__(self):
        self.a = MiniMLParam((1,))
        self.b = MiniMLParam((1,))
        super().__init__()

    def _predict_kernel(self, X, buffer):
        # Element-wise scale and shift; the (1,) parameters broadcast over the 1-D input X
        return X * self.a(buffer) + self.b(buffer)

# Create and bind the model
model = LinearModel()
model.bind()
model.randomize()

# Fit to data (e.g., y = 2x + 1)
X = jnp.linspace(0, 10, 20)
y = 2 * X + 1
model.fit(X, y)

# Save and load
model.save('model.npz')
model.load('model.npz')
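
For this noiseless data the target values are known in advance: a fit that converges should recover a slope close to 2 and an intercept close to 1. As an independent sanity check that does not use MiniML at all, the same problem can be solved in closed form with NumPy's least-squares routine. The variable names below are chosen for illustration only.

```python
import numpy as np

# Same data as in the example above: y = 2x + 1
X = np.linspace(0, 10, 20)
y = 2 * X + 1

# Design matrix with a column of ones for the intercept term
design = np.stack([X, np.ones_like(X)], axis=1)

# Ordinary least squares: minimizes ||design @ [a, b] - y||^2
(a, b), *_ = np.linalg.lstsq(design, y, rcond=None)
print(f"slope={a:.3f}, intercept={b:.3f}")
```

Comparing a trained model's parameters or predictions against such a reference is a quick way to confirm that an optimizer-based fit has converged.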

Nested Models

You can compose models by including other MiniMLModel instances as attributes. In this case, always call child models through their own _predict_kernel methods, passing along the same buffer. For example:

class ConstantModel(MiniMLModel):

    def __init__(self):
        self._c = MiniMLParam((1,))
        super().__init__()

    def _predict_kernel(self, X, buffer):
        return self._c(buffer)

class LinearWithConstant(MiniMLModel):
    def __init__(self):
        self._b = MiniMLParam((5,))
        self._M = MiniMLParam((5, 5))
        self._c = ConstantModel()
        super().__init__()

    def _predict_kernel(self, X, buffer):
        return self._M(buffer) @ X + self._b(buffer)[:, None] + self._c._predict_kernel(X, buffer)
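
The shapes in LinearWithConstant imply that X is laid out as (5, n_samples), with one column per sample: the (5, 5) matrix multiplies from the left, the (5,) bias is expanded to a column, and the (1,) constant broadcasts over every entry. The NumPy sketch below checks this broadcasting with arbitrary stand-in values. It does not use MiniML, and the array names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 3

M = rng.normal(size=(5, 5))          # stands in for self._M(buffer)
b = rng.normal(size=(5,))            # stands in for self._b(buffer)
c = rng.normal(size=(1,))            # stands in for the child model's output
X = rng.normal(size=(5, n_samples))  # one column per sample

# Same expression as in LinearWithConstant._predict_kernel
out = M @ X + b[:, None] + c
print(out.shape)
```

Without the [:, None] indexing, the (5,) bias would be aligned against the sample axis instead of the feature axis, which fails unless n_samples happens to be 5.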

See the full documentation.
