Covariance Matrix Adaptation Evolution Strategy (CMA-ES) implemented with TensorFlow
Covariance Matrix Adaptation Evolution Strategy (CMA-ES)
A TensorFlow 2 implementation.
What is CMA-ES?
The CMA-ES (Covariance Matrix Adaptation Evolution Strategy) is an evolutionary algorithm for difficult non-linear, non-convex black-box optimisation problems in continuous domains. It is considered state of the art in evolutionary computation and has been adopted as one of the standard tools for continuous optimisation in many (probably hundreds of) research labs and industrial environments around the world.
Installation
The package is available on PyPI and can be installed with pip:
pip install cma-es
Example Usage
1. Define the fitness function
CMA attempts to minimize a user-defined fitness function.
Function signature:
Args:
    x: tf.Tensor of shape (M, N)
Returns:
    Fitness evaluations: tf.Tensor of shape (M,)
where M is the number of solutions to evaluate and N is the dimension of a single solution.
def fitness_fn(x):
    """
    Six-Hump Camel Function
    https://www.sfu.ca/~ssurjano/camel6.html
    """
    return (
        (4 - 2.1 * x[:, 0]**2 + x[:, 0]**4 / 3) * x[:, 0]**2 +
        x[:, 0] * x[:, 1] +
        (-4 + 4 * x[:, 1]**2) * x[:, 1]**2
    )
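`fitness_fn` uses only slicing and arithmetic, so its batched (M, N) → (M,) contract can also be checked outside TensorFlow. The sketch below uses NumPy as a stand-in for `tf.Tensor`, and `fitness_fn_np` is a hypothetical name. It evaluates the function at the two global minimisers listed on the SFU page, where f(x*) ≈ -1.0316.

```python
import numpy as np

def fitness_fn_np(x):
    """Six-Hump Camel Function on a NumPy array of shape (M, 2); returns shape (M,)."""
    x1, x2 = x[:, 0], x[:, 1]
    return (
        (4 - 2.1 * x1**2 + x1**4 / 3) * x1**2
        + x1 * x2
        + (-4 + 4 * x2**2) * x2**2
    )

# A batch of M = 3 candidate solutions, each of dimension N = 2
batch = np.array([
    [0.0898, -0.7126],   # global minimiser (SFU page)
    [-0.0898, 0.7126],   # the symmetric second global minimiser
    [1.5, -0.4],         # the initial solution used in the configuration step
])
values = fitness_fn_np(batch)
print(values.shape)  # (3,)
```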
2. Configure CMA-ES
from cma import CMA

cma = CMA(
    initial_solution=[1.5, -0.4],
    initial_step_size=1.0,
    fitness_function=fitness_fn,
)
The initial solution and the initial step size (i.e. the initial standard deviation of the search distribution) are problem-specific.
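The following shows what the step size means in practice. In standard CMA-ES the covariance matrix starts as the identity, so the first generation is drawn from an isotropic Gaussian. That Gaussian is centred on the initial solution, and its standard deviation equals the initial step size. The NumPy sketch below covers only this first sampling step. `sample_first_generation` and its `population_size` default are hypothetical and not part of this package.

```python
import numpy as np

def sample_first_generation(initial_solution, initial_step_size, population_size=6, seed=0):
    """Draw candidates x_k = m0 + sigma0 * z_k with z_k ~ N(0, I), since C starts as I."""
    rng = np.random.default_rng(seed)
    m0 = np.asarray(initial_solution, dtype=float)
    z = rng.standard_normal((population_size, m0.size))
    return m0 + initial_step_size * z

candidates = sample_first_generation([1.5, -0.4], initial_step_size=1.0)
print(candidates.shape)  # (6, 2)
```

A larger step size spreads the first candidates further from the initial solution. This is why both values depend on the scale of the problem.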
The population size is set automatically by default, but it can be overridden with the parameter population_size.
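The default is not spelled out here. For reference, Hansen's CMA-ES tutorial recommends a standard default of λ = 4 + ⌊3 ln n⌋ candidates for an n-dimensional problem, with μ = ⌊λ/2⌋ parents selected. The sketch below computes these values, assuming this package follows that convention (not confirmed above). The helper names are hypothetical.

```python
import math

def default_population_size(n: int) -> int:
    """Standard CMA-ES default: lambda = 4 + floor(3 * ln(n))."""
    return 4 + math.floor(3 * math.log(n))

def default_parent_count(n: int) -> int:
    """Number of selected parents: mu = floor(lambda / 2)."""
    return default_population_size(n) // 2

for n in (2, 10, 100):
    print(n, default_population_size(n), default_parent_count(n))
# 2 6 3
# 10 10 5
# 100 17 8
```

Because the default grows only logarithmically with the dimension, increasing population_size is a common way to make the search more global on multimodal functions. The cost is more fitness evaluations per generation.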
For bound-constrained optimization problems, the parameter enforce_bounds can be set, e.g. enforce_bounds=[[-2, 2], [-1, 1]] for a 2D function.
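This page does not describe how enforce_bounds handles candidates that fall outside the box. Clipping, resampling and penalty terms are all common choices. The sketch below assumes clipping and maps each coordinate back into its [low, high] interval, using the same [[low, high], ...] layout as enforce_bounds. `clip_to_bounds` is a hypothetical helper, not this package's code.

```python
import numpy as np

def clip_to_bounds(x, bounds):
    """Clip candidates x of shape (M, N) to per-dimension bounds [[low, high], ...]."""
    bounds = np.asarray(bounds, dtype=float)   # shape (N, 2)
    lower, upper = bounds[:, 0], bounds[:, 1]
    return np.clip(x, lower, upper)            # broadcasts over the M rows

candidates = np.array([[3.0, -5.0], [0.0, 0.5], [-2.5, 1.2]])
clipped = clip_to_bounds(candidates, [[-2, 2], [-1, 1]])
print(clipped)
```

Clipping is the simplest option, but it can distort the sampled distribution near the boundary. This is one reason some implementations prefer penalty-based boundary handling.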
3. Run the optimizer
The search method runs until the maximum number of generations is reached or until one of the early termination criteria is met. By default, the maximum number of generations is 500.
best_solution, best_fitness = cma.search()
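The loop inside `cma.search()` is not reproduced on this page. The NumPy sketch below shows what one CMA-ES generation involves: sampling, selection, mean update, evolution paths, covariance and step-size adaptation, and a stopping check. It follows the default parameter settings from Hansen's CMA-ES tutorial. It is not this package's implementation. The name `cma_es_sketch` and the `tol_sigma` stopping rule are hypothetical.

```python
import numpy as np

def cma_es_sketch(f, x0, sigma0, max_generations=500, tol_sigma=1e-12, seed=0):
    """Minimal (mu/mu_w, lambda)-CMA-ES; f maps shape (M, N) to fitness of shape (M,)."""
    rng = np.random.default_rng(seed)
    n = len(x0)
    mean = np.asarray(x0, dtype=float)
    sigma = float(sigma0)

    # Default strategy parameters (Hansen's tutorial)
    lam = 4 + int(3 * np.log(n))
    mu = lam // 2
    weights = np.log(mu + 0.5) - np.log(np.arange(1, mu + 1))
    weights /= weights.sum()
    mueff = 1.0 / np.sum(weights ** 2)
    cc = (4 + mueff / n) / (n + 4 + 2 * mueff / n)
    cs = (mueff + 2) / (n + mueff + 5)
    c1 = 2 / ((n + 1.3) ** 2 + mueff)
    cmu = min(1 - c1, 2 * (mueff - 2 + 1 / mueff) / ((n + 2) ** 2 + mueff))
    damps = 1 + 2 * max(0.0, np.sqrt((mueff - 1) / (n + 1)) - 1) + cs
    chi_n = np.sqrt(n) * (1 - 1 / (4 * n) + 1 / (21 * n ** 2))

    pc, ps, C = np.zeros(n), np.zeros(n), np.eye(n)
    best_x, best_f = mean.copy(), np.inf

    for generation in range(1, max_generations + 1):
        # Sample lambda candidates from N(mean, sigma^2 C) via C = B diag(D^2) B^T
        eigvals, B = np.linalg.eigh(C)
        D = np.sqrt(np.maximum(eigvals, 1e-20))
        y = (rng.standard_normal((lam, n)) * D) @ B.T   # rows are y_k ~ N(0, C)
        x = mean + sigma * y
        fitness = np.asarray(f(x))

        # Rank candidates and keep track of the best solution seen so far
        order = np.argsort(fitness)
        if fitness[order[0]] < best_f:
            best_x, best_f = x[order[0]].copy(), float(fitness[order[0]])

        # Weighted recombination of the mu best steps moves the mean
        y_sel = y[order[:mu]]
        y_w = weights @ y_sel
        mean = mean + sigma * y_w

        # Evolution paths; the sigma path uses C^(-1/2) y_w = B diag(1/D) B^T y_w
        ps = (1 - cs) * ps + np.sqrt(cs * (2 - cs) * mueff) * (B @ ((B.T @ y_w) / D))
        hsig = float(np.linalg.norm(ps) / np.sqrt(1 - (1 - cs) ** (2 * generation)) / chi_n
                     < 1.4 + 2 / (n + 1))
        pc = (1 - cc) * pc + hsig * np.sqrt(cc * (2 - cc) * mueff) * y_w

        # Covariance update: rank-one (from pc) plus rank-mu (from selected steps)
        rank_mu = (weights[:, None] * y_sel).T @ y_sel
        C = ((1 - c1 - cmu) * C
             + c1 * (np.outer(pc, pc) + (1 - hsig) * cc * (2 - cc) * C)
             + cmu * rank_mu)
        C = (C + C.T) / 2

        # Cumulative step-size adaptation
        sigma *= np.exp((cs / damps) * (np.linalg.norm(ps) / chi_n - 1))

        # Early termination: the search distribution has collapsed
        if sigma * np.sqrt(np.max(eigvals)) < tol_sigma:
            break

    return best_x, best_f

def sphere(x):
    """Simple convex test function, batched like fitness_fn."""
    return np.sum(x ** 2, axis=1)

best_x, best_f = cma_es_sketch(sphere, x0=[1.0] * 5, sigma0=0.5)
```

The loop has the same two exits as described above: a generation cap and an early-termination criterion. Here the criterion is a collapsed step size, which is only one of several criteria a full implementation may check.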
The notebook Example 1 - Six Hump Camel Function goes into more detail, including ways to plot the optimization path, as in the figure below.
Logging
A user-defined callback function can be specified to inspect variables during the search.
It is mainly intended for logging purposes, e.g.:
max_epochs = 500

def logging_function(cma, logger):
    if cma.generation % 10 == 0:
        fitness = cma.best_fitness()
        logger.info(f'Generation {cma.generation} - fitness {fitness}')

    if cma.termination_criterion_met or cma.generation == max_epochs:
        sol = cma.best_solution()
        fitness = cma.best_fitness()
        logger.info(f'Final solution at gen {cma.generation}: {sol} (fitness: {fitness})')

cma = CMA(
    initial_solution=[1.5, -0.4],
    initial_step_size=1.0,
    fitness_function=fitness_fn,
    callback_function=logging_function,
)
cma.search(max_epochs)
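A callback like this can be tested without running the optimizer. The sketch below passes a hypothetical stub that exposes only the attributes the callback reads (generation, termination_criterion_met, best_fitness(), best_solution()), together with a standard `logging.Logger` whose handler collects the messages. The real CMA object passes itself and its own logger. The stub values are illustrative only.

```python
import logging
from types import SimpleNamespace

max_epochs = 500

def logging_function(cma, logger):
    if cma.generation % 10 == 0:
        fitness = cma.best_fitness()
        logger.info(f'Generation {cma.generation} - fitness {fitness}')

    if cma.termination_criterion_met or cma.generation == max_epochs:
        sol = cma.best_solution()
        fitness = cma.best_fitness()
        logger.info(f'Final solution at gen {cma.generation}: {sol} (fitness: {fitness})')

def make_stub(generation, fitness, solution, done=False):
    """Hypothetical stand-in for a CMA object, exposing only what the callback uses."""
    return SimpleNamespace(
        generation=generation,
        termination_criterion_met=done,
        best_fitness=lambda: fitness,
        best_solution=lambda: solution,
    )

class ListHandler(logging.Handler):
    """Collects formatted log messages in a list instead of printing them."""
    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())

logger = logging.getLogger("cma-callback-demo")
logger.setLevel(logging.INFO)
logger.propagate = False
handler = ListHandler()
logger.addHandler(handler)

logging_function(make_stub(10, -1.03, [0.09, -0.71]), logger)         # periodic log
logging_function(make_stub(11, -1.03, [0.09, -0.71]), logger)         # nothing logged
logging_function(make_stub(500, -1.0316, [0.0898, -0.7126]), logger)  # periodic + final
```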
Check out an example logging progress to TensorBoard: tensorboard_example.py
Run on a GPU
Because CMA is written in TensorFlow, it can be run on a GPU by placing the search inside a device scope:
import tensorflow as tf

with tf.device('/GPU:0'):
    cma.search()
More examples
- Jupyter notebooks with examples are available:
- Unit tests provide a few more examples: cma/core_test.py
Resources
File details
Details for the file cma-es-1.5.0.tar.gz.
File metadata
- Download URL: cma-es-1.5.0.tar.gz
- Upload date:
- Size: 11.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.0 importlib_metadata/4.8.2 pkginfo/1.8.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | 9bfa5b9f8a47f8ddf8cb3baa6d417de72536546720e56ac0ba393dfdcb9d19f0
MD5 | 6ae7edb502b39a2f862432f2ec100564
BLAKE2b-256 | 844a65a5d171a85c3ce6c3797d8b418fa3a8f5bfb7d1a832d90590631cc4c3dd
File details
Details for the file cma_es-1.5.0-py3-none-any.whl.
File metadata
- Download URL: cma_es-1.5.0-py3-none-any.whl
- Upload date:
- Size: 10.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.0 importlib_metadata/4.8.2 pkginfo/1.8.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | a93eac1b3ca8d2e8e7975ac4e748cad4b809f6d68c8e813a9e13495c5f5abc89
MD5 | 0577d17bed374550a13c53e58a2bfcba
BLAKE2b-256 | 2b4a5d568d63eb277f5565890cda7bcf0c4adbc78afd59f1eef3e296c5b6123d