SIMPLE MAML

A generic Python and TensorFlow function that implements a simple version of the Model-Agnostic Meta-Learning (MAML) algorithm described in "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks" by Chelsea Finn et al. (2017) [1]. In particular, this implementation focuses on regression and prediction problems.

Original algorithm adapted for regression

[Figure: the original MAML algorithm, adapted for regression]

Usage

  1. Install with pip install simplemaml
  2. In your Python code:
    • from simplemaml import MAML
    • MAML(model=your_model, tasks=your_array_of_tasks, etc.)
  3. Each task should be in one of the two following formats:
    • tasks=[{"inputs": [], "target": []}, etc.]
    • tasks=[{"train": {"inputs": [], "target": []}, "test": {"inputs": [], "target": []}}, etc.]
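
As a minimal sketch of the two task formats, the snippet below builds a few synthetic sine-wave regression tasks with NumPy. The helper names make_flat_task and make_presplit_task, the sine data, and the sample sizes are assumptions made for illustration; only the dictionary keys ("inputs", "target", "train", "test") come from the formats listed above.

```python
import numpy as np

def make_flat_task(rng, n_samples=20):
    """Hypothetical helper: one task in the flat {"inputs", "target"} format."""
    amplitude = rng.uniform(0.1, 5.0)
    phase = rng.uniform(0.0, np.pi)
    x = rng.uniform(-5.0, 5.0, size=(n_samples, 1)).astype("float32")
    y = (amplitude * np.sin(x + phase)).astype("float32")
    return {"inputs": x, "target": y}

def make_presplit_task(rng, n_train=10, n_test=10):
    """Hypothetical helper: one task in the pre-split {"train", "test"} format."""
    flat = make_flat_task(rng, n_train + n_test)
    return {
        "train": {"inputs": flat["inputs"][:n_train], "target": flat["target"][:n_train]},
        "test": {"inputs": flat["inputs"][n_train:], "target": flat["target"][n_train:]},
    }

rng = np.random.default_rng(0)
flat_tasks = [make_flat_task(rng) for _ in range(5)]          # first format
presplit_tasks = [make_presplit_task(rng) for _ in range(5)]  # second format
```

Either list could then be passed as tasks= to MAML, together with a Keras model whose input shape matches the task inputs (here, one feature per sample). With the first format, the function splits each task itself using validation_split or k_folds.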

You can also download the lib as a .whl file using pip download simplemaml --only-binary=:all: --no-deps

More about the algorithm

Tools needed

Refer to this repository in scientific documents

Neumann, Anas. (2023). Simple Python and TensorFlow implementation of the optimization-based Model-Agnostic Meta-Learning (MAML) algorithm for supervised regression problems. GitHub repository: https://github.com/AnasNeumann/simplemaml.

    @misc{simplemaml,
      author = {Anas Neumann},
      title = {Simple Python and TensorFlow implementation of the optimization-based Model-Agnostic Meta-Learning (MAML) algorithm for supervised regression problems},
      year = {2023},
      publisher = {GitHub},
      journal = {GitHub repository},
      howpublished = {\url{https://github.com/AnasNeumann/simplemaml}},
      commit = {main}
    }

Complete code

import random

import numpy as np
import tensorflow as tf
from tensorflow import keras

def MAML(model, alpha=0.005, beta=0.005, optimizer=keras.optimizers.SGD, c_loss=keras.losses.mse, f_loss=keras.losses.MeanSquaredError(), meta_epochs=100, meta_tasks_per_epoch=[10, 30], inputs_dimension=1, validation_split=0.2, k_folds=0, tasks=[], cumul=False):
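    """
    Simple MAML algorithm implementation for supervised regression.
        :param model: A Keras model to be trained using MAML.
        :param alpha: Learning rate for task-specific updates.
        :param beta: Learning rate for meta-updates.
        :param optimizer: Keras optimizer class (not an instance); it is instantiated once per learning rate.
        :param c_loss: Loss function used to compute the training and test losses inside the meta-learning loop.
        :param f_loss: Loss object passed to model.compile().
        :param meta_epochs: Number of meta-training epochs.
        :param meta_tasks_per_epoch: [min, max] range (inclusive) for the number of tasks sampled per epoch.
        :param inputs_dimension: Number of model inputs; if greater than 1, each task's inputs are repeated that many times and passed as a list (for sequence-to-sequence models).
        :param validation_split: Split ratio for tasks without a train/test split; either a fixed value or a [min, max] list to sample from. In _split_task, this fraction of the samples forms the training slice and the remainder forms the test slice.
        :param k_folds: If greater than 0, hold out one randomly chosen fold out of k_folds each time a task is sampled (used instead of validation_split).
        :param tasks: List of tasks for meta-training.
        :param cumul: If True, sum the task gradients in the outer loop; if False, average them.
        :return: Tuple of the trained model and the running average of the loss over epochs.
    """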
    if "train" in tasks[0] and "test" in tasks[0]:
        build_task_f = _get_task
        build_task_param = {"dimension": inputs_dimension}
    elif k_folds>0:
        build_task_f = _k_fold_task
        build_task_param = {"dimension": inputs_dimension, "k": k_folds}
    else:
        build_task_f = _split_task
        build_task_param = {"dimension": inputs_dimension, "split": validation_split}
    if tf.config.list_physical_devices('GPU'):
        with tf.device('/GPU:0'):
            return _MAML_compute(model, alpha, beta, optimizer, c_loss, f_loss, meta_epochs, meta_tasks_per_epoch, build_task_f, build_task_param, tasks, cumul)
    else:
        return _MAML_compute(model, alpha, beta, optimizer, c_loss, f_loss, meta_epochs, meta_tasks_per_epoch, build_task_f, build_task_param, tasks, cumul)

def _split_task(t, param):
    d = param["dimension"]
    split = param["split"]
    v = random.uniform(split[0], split[1]) if isinstance(split,list) else split
    split_idx = int(len(t["inputs"]) * v)
    train_input = t["inputs"][:split_idx] if d<=1 else [t["inputs"][:split_idx] for _ in range(d)]
    test_input = t["inputs"][split_idx:] if d<=1 else [t["inputs"][split_idx:] for _ in range(d)]
    train_target, test_target = t["target"][:split_idx], t["target"][split_idx:]
    return train_input, test_input, train_target, test_target

def _k_fold_task(t, param):
    d = param["dimension"]
    k = param["k"]
    fold = random.randint(0, k-1)
    fold_size = (len(t["inputs"]) // k)
    v_start = fold * fold_size
    v_end = (fold + 1) * fold_size if fold < k - 1 else len(t["inputs"])
    t_i = np.concatenate((t["inputs"][:v_start], t["inputs"][v_end:]), axis=0)
    train_input = t_i if d<=1 else [t_i for _ in range(d)]
    test_input = t["inputs"][v_start:v_end] if d<=1 else [t["inputs"][v_start:v_end] for _ in range(d)]
    train_target = np.concatenate((t["target"][:v_start], t["target"][v_end:]), axis=0)
    test_target = t["target"][v_start:v_end]
    return train_input, test_input, train_target, test_target

def _get_task(t, param):
    d = param["dimension"]
    train_input = t["train"]["inputs"] if d<=1 else [t["train"]["inputs"] for _ in range(d)]
    test_input = t["test"]["inputs"] if d<=1 else [t["test"]["inputs"] for _ in range(d)]
    return train_input, test_input, t["train"]["target"], t["test"]["target"] 

def _MAML_compute(model, alpha, beta, optimizer, c_loss, f_loss, meta_epochs, meta_tasks_per_epoch, build_task_f, build_task_param, tasks, cumul):
    log_step = meta_epochs // 10 if meta_epochs > 10 else 1
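    optim_test=optimizer(learning_rate=beta)    # meta-update (outer loop), applied to the original model
    optim_train=optimizer(learning_rate=alpha)  # task-specific update (inner loop), applied to the model copy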
    model_copy = tf.keras.models.clone_model(model)
    model_copy.build(model.input_shape)
    model_copy.set_weights(model.get_weights())
    optim_test.build(model.trainable_variables)
    optim_train.build(model_copy.trainable_variables)
    model.compile(loss=f_loss, optimizer=optim_test)
    model_copy.compile(loss=f_loss, optimizer=optim_train)
    losses=[]
    total_loss=0.
    for step in range(meta_epochs):
        sum_gradients = [tf.zeros_like(variable) for variable in model.trainable_variables]
        num_tasks_sampled = random.randint(meta_tasks_per_epoch[0], meta_tasks_per_epoch[1])
        model_copy.set_weights(model.get_weights())
        for _ in range(num_tasks_sampled):
            train_input, test_input, train_target, test_target = build_task_f(random.choice(tasks), build_task_param)

            # 1. Inner loop: Update the model copy on the current task
            with tf.GradientTape(watch_accessed_variables=False) as train_tape:
                train_tape.watch(model_copy.trainable_variables)
                train_pred = model_copy(train_input)
                train_loss = tf.reduce_mean(c_loss(train_target, train_pred))
            g = train_tape.gradient(train_loss, model_copy.trainable_variables)
            optim_train.apply_gradients(zip(g, model_copy.trainable_variables))

            # 2. Compute gradients with respect to the test data
            with tf.GradientTape(watch_accessed_variables=False) as test_tape:
                test_tape.watch(model_copy.trainable_variables)
                test_pred = model_copy(test_input)
                test_loss = tf.reduce_mean(c_loss(test_target, test_pred))
            g = test_tape.gradient(test_loss, model_copy.trainable_variables)
            for i, gradient in enumerate(g):
                sum_gradients[i] += gradient

        # 3. Meta-update: apply the accumulated gradients to the original model
        cumul_gradients = [grad / (1.0 if cumul else num_tasks_sampled) for grad in sum_gradients]
        optim_test.apply_gradients(zip(cumul_gradients, model.trainable_variables))
        total_loss += test_loss.numpy()
        loss_evol = total_loss/(step+1)
        losses.append(loss_evol)
        if step % log_step == 0:
            print(f'Meta epoch: {step+1}/{meta_epochs},  Loss: {loss_evol}')
    return model, losses
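
To make the two splitting strategies concrete, here is a small stand-alone sketch of the index arithmetic used by _split_task and _k_fold_task, written with plain Python lists. The function names split_bounds and fold_bounds are hypothetical and are not part of the library.

```python
def split_bounds(n_samples, split):
    """Index where the first slice ends, as in _split_task with a fixed ratio."""
    return int(n_samples * split)

def fold_bounds(n_samples, k, fold):
    """Start and end of the held-out slice, as in _k_fold_task for a given fold."""
    fold_size = n_samples // k
    start = fold * fold_size
    # The last fold absorbs the remainder when n_samples is not divisible by k.
    end = (fold + 1) * fold_size if fold < k - 1 else n_samples
    return start, end

samples = list(range(10))

# Fixed-ratio split: with split=0.2, the first 2 samples form the training slice
# and the remaining 8 form the test slice.
idx = split_bounds(len(samples), 0.2)
train, test = samples[:idx], samples[idx:]

# 3-fold split of 10 samples: the held-out slices are [0, 3), [3, 6) and [6, 10).
folds = [fold_bounds(len(samples), 3, f) for f in range(3)]
```

Note that, as the listing is written, the validation_split ratio sets the size of the first (training) slice, so a value of 0.2 leaves 80% of each task for the test slice.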

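The inner and outer loops of the listing can also be illustrated without TensorFlow. The sketch below applies a first-order MAML-style update to a one-parameter linear model y = w * x; the model, the synthetic tasks, and the names mse_grad and meta_step are assumptions made for illustration. Unlike the listing, which copies the weights once per meta-epoch, this sketch recreates the adapted copy from the current weight for every task.

```python
import numpy as np

def mse_grad(w, x, y):
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar weight w."""
    return float(np.mean(2.0 * (w * x - y) * x))

def meta_step(w, tasks, alpha=0.05, beta=0.1, cumul=False):
    """One meta-epoch: adapt a copy per task, then update w from the test gradients."""
    sum_grad = 0.0
    for x_train, y_train, x_test, y_test in tasks:
        # Inner loop: one gradient step on the task's training data.
        w_copy = w - alpha * mse_grad(w, x_train, y_train)
        # First-order approximation: gradient of the test loss at the adapted weight.
        sum_grad += mse_grad(w_copy, x_test, y_test)
    # cumul=True sums the task gradients, cumul=False averages them.
    grad = sum_grad if cumul else sum_grad / len(tasks)
    # Meta-update of the original weight.
    return w - beta * grad

rng = np.random.default_rng(0)
tasks = []
for slope in (1.5, 2.0, 2.5):  # each task is a line with a different slope
    x = rng.uniform(-1.0, 1.0, size=20)
    y = slope * x
    tasks.append((x[:10], y[:10], x[10:], y[10:]))

w = 0.0
for _ in range(500):
    w = meta_step(w, tasks)
```

After training, w settles between the smallest and largest task slopes, which is the kind of shared starting point that a single inner-loop step can then adapt to each task. With cumul=True, the update at a given weight is the mean update multiplied by the number of tasks.
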
Build a new version of the lib (after updating the version number in setup.py)

  1. rm -rf dist/ build/ simplemaml.egg-info/
  2. python3 setup.py sdist bdist_wheel
  3. twine upload dist/*

REFERENCES

[1] Finn, C., Abbeel, P. & Levine, S. (2017). Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. Proceedings of the 34th International Conference on Machine Learning, in Proceedings of Machine Learning Research 70:1126-1135. Available from https://proceedings.mlr.press/v70/finn17a.html and https://proceedings.mlr.press/v70/finn17a/finn17a.pdf.
