A JAX-based L-BFGS optimizer
L-BFGS optimizer written with JAX
Features
- Implements the Limited-memory BFGS algorithm.
- JIT/vmap/pmap compatible for performance with JAX.
- Note: requirements.txt is set up for JAX[CPU].
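L-BFGS never forms the inverse Hessian explicitly. Instead, it uses the two-loop recursion to compute the product of the approximate inverse Hessian and the gradient from the last m pairs s_i = x_{i+1} - x_i and y_i = grad f(x_{i+1}) - grad f(x_i). The following is a minimal textbook sketch in JAX, not this package's code. The function name two_loop_direction and the (m, n) oldest-first history layout are assumptions. Both loops use jax.lax.scan, so array shapes stay fixed and the function can be jit-compiled.

```python
import jax
import jax.numpy as jnp

def two_loop_direction(grad, s_hist, y_hist):
    """Return the L-BFGS search direction -H @ grad via the two-loop recursion.

    s_hist, y_hist: arrays of shape (m, n) with the m most recent pairs,
    oldest in row 0 and newest in row m - 1 (assumed layout).
    """
    m = s_hist.shape[0]
    rho = 1.0 / jnp.sum(s_hist * y_hist, axis=1)  # rho_i = 1 / (s_i . y_i)

    # First loop: newest pair to oldest pair
    def first_loop(q, i):
        alpha = rho[i] * jnp.dot(s_hist[i], q)
        return q - alpha * y_hist[i], alpha

    q, alphas = jax.lax.scan(first_loop, grad, jnp.arange(m - 1, -1, -1))

    # Initial inverse Hessian H0 = gamma * I, scaled with the newest pair
    gamma = jnp.dot(s_hist[-1], y_hist[-1]) / jnp.dot(y_hist[-1], y_hist[-1])
    r = gamma * q

    # Second loop: oldest pair to newest; alphas[::-1] puts alpha_i back in index order
    def second_loop(r, inputs):
        i, alpha = inputs
        beta = rho[i] * jnp.dot(y_hist[i], r)
        return r + (alpha - beta) * s_hist[i], None

    r, _ = jax.lax.scan(second_loop, r, (jnp.arange(m), alphas[::-1]))
    return -r
```

For a quadratic with Hessian A, the pairs satisfy y_i = A s_i. The result then obeys the secant condition for the newest pair: passing y_{m-1} as the gradient returns -s_{m-1}.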
Usage
Define a function to minimize
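```python
import jax.numpy as jnp
coefficients = jnp.arange(100.0)  # assumed targets 0..99, consistent with the output below
def func(x):
    return jnp.sum((-1 * coefficients + x) ** 2)
```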
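Assume coefficients holds 0, 1, ..., 99, which is what the printed minimizer in the output suggests. Writing c for coefficients, the objective and its derivatives are:

```latex
f(x) = \sum_{i=0}^{99} (x_i - c_i)^2, \qquad c_i = i
\nabla f(x) = 2\,(x - c), \qquad \nabla^2 f(x) = 2I
x^\star = c, \qquad f(x^\star) = 0
```

The exact minimizer is therefore the coefficient vector itself. A reported function value such as 5.7412694e-29 is floating-point round-off around the true minimum of 0: (7.577116e-15)^2 is about 5.74e-29, so it comes from the error in the first component.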
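Create an Lbfgs optimizer with these arguments:
- f: function to minimize
- m: number of previous iterations kept in memory
- tol: convergence tolerance; the optimizer stops early once the gradient norm falls below tol

```python
# Lbfgs is provided by this package; its import path is not shown in this description
optimizer = Lbfgs(f=func, m=10, tol=1e-6)
```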
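The m argument bounds how many (s, y) pairs are kept. One common jit-friendly way to hold them is a fixed-shape (m, n) buffer that drops its oldest row whenever a new pair arrives. This is offered as an assumption, not as a description of this package's internals, and the helper name push_pair is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def push_pair(s_hist, y_hist, s_new, y_new):
    """Shift (m, n) history buffers up by one row and store the newest pair last.

    The oldest pair (row 0) is discarded, so the shapes never change and
    the function traces once no matter how many steps have run.
    """
    s_hist = jnp.roll(s_hist, shift=-1, axis=0).at[-1].set(s_new)
    y_hist = jnp.roll(y_hist, shift=-1, axis=0).at[-1].set(y_new)
    return s_hist, y_hist
```

Until m pairs have been collected, the unused rows are still zeros. Code that consumes the buffer must skip or mask those rows, for example by setting rho to 0 for them, to avoid dividing by zero.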
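Iterate to find the minimum:

```python
import jax
import jax.numpy as jnp

# Initialize the optimizer state (x0 is an assumed starting point)
x0 = jax.random.normal(jax.random.PRNGKey(0), (100,))
opt_state = optimizer.init(x0)

@jax.jit
def opt_step(carry, _):
    opt_state, losses = carry
    opt_state = optimizer.update(opt_state)
    # Record the loss at iteration index k
    losses = losses.at[opt_state.k].set(func(opt_state.position))
    return (opt_state, losses), None

iterations = 10000  # generous upper bound on the number of iterations
losses = jnp.zeros((iterations,))
(final_state, losses), _ = jax.lax.scan(opt_step, (opt_state, losses), None, length=iterations)

# losses has length `iterations`; entries never written are still 0, so mark them as NaN
losses = jnp.where(losses == 0, jnp.nan, losses)

print(final_state.position)
print("Function value at minimum:", func(final_state.position))
print("k:", final_state.k)
```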
Output:

```
[-7.577116e-15 1.000000e+00 2.000000e+00 3.000000e+00 4.000000e+00
 5.000000e+00 6.000000e+00 7.000000e+00 8.000000e+00 9.000000e+00
 1.000000e+01 1.100000e+01 1.200000e+01 1.300000e+01 1.400000e+01
 1.500000e+01 1.600000e+01 1.700000e+01 1.800000e+01 1.900000e+01
 2.000000e+01 2.100000e+01 2.200000e+01 2.300000e+01 2.400000e+01
 2.500000e+01 2.600000e+01 2.700000e+01 2.800000e+01 2.900000e+01
 3.000000e+01 3.100000e+01 3.200000e+01 3.300000e+01 3.400000e+01
 3.500000e+01 3.600000e+01 3.700000e+01 3.800000e+01 3.900000e+01
 4.000000e+01 4.100000e+01 4.200000e+01 4.300000e+01 4.400000e+01
 4.500000e+01 4.600000e+01 4.700000e+01 4.800000e+01 4.900000e+01
 5.000000e+01 5.100000e+01 5.200000e+01 5.300000e+01 5.400000e+01
 5.500000e+01 5.600000e+01 5.700000e+01 5.800000e+01 5.900000e+01
 6.000000e+01 6.100000e+01 6.200000e+01 6.300000e+01 6.400000e+01
 6.500000e+01 6.600000e+01 6.700000e+01 6.800000e+01 6.900000e+01
 7.000000e+01 7.100000e+01 7.200000e+01 7.300000e+01 7.400000e+01
 7.500000e+01 7.600000e+01 7.700000e+01 7.800000e+01 7.900000e+01
 8.000000e+01 8.100000e+01 8.200000e+01 8.300000e+01 8.400000e+01
 8.500000e+01 8.600000e+01 8.700000e+01 8.800000e+01 8.900000e+01
 9.000000e+01 9.100000e+01 9.200000e+01 9.300000e+01 9.400000e+01
 9.500000e+01 9.600000e+01 9.700000e+01 9.800000e+01 9.900000e+01]
Function value at minimum: 5.7412694e-29
k: 2
```

Here k: 2 means the optimizer stopped early, after two iterations, because the gradient norm fell below tol.
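jax.lax.scan always runs its full length. With iterations=10000, the loop body therefore still executes 10000 times even though convergence is reached at k = 2. If that cost matters, jax.lax.while_loop can exit as soon as the gradient-norm test passes. The following sketch assumes a generic step(x, g) update. The names minimize_until_converged and step are hypothetical, and plain gradient descent stands in for the L-BFGS update used in the usage example.

```python
import jax
import jax.numpy as jnp

def minimize_until_converged(f, step, x0, tol=1e-6, max_iter=10000):
    """Apply step(x, g) until ||grad f(x)|| < tol or max_iter updates have run.

    Returns the final point and the number of updates performed.
    """
    grad_f = jax.grad(f)

    def cond(state):
        _, g, k = state
        # Keep iterating while the gradient is large and the budget is not used up
        return (jnp.linalg.norm(g) >= tol) & (k < max_iter)

    def body(state):
        x, g, k = state
        x = step(x, g)
        return x, grad_f(x), k + 1

    x, _, k = jax.lax.while_loop(cond, body, (x0, grad_f(x0), jnp.array(0)))
    return x, k
```

In a real optimizer, the carried state would be the full optimizer state (position, history buffers, iteration count) rather than a bare x.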
NOTE: The examples include the quadratic function and the Rosenbrock function.
The 1000-dimensional Rosenbrock function is solved in 4038 steps.
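The examples' Rosenbrock code is not reproduced in this description. The following is the standard n-dimensional Rosenbrock function written in JAX, a sketch of an objective with the same shape as func that could be passed as f. Its global minimum is 0 at x = (1, ..., 1).

```python
import jax
import jax.numpy as jnp

def rosenbrock(x):
    """n-dimensional Rosenbrock: sum over i of 100 (x_{i+1} - x_i^2)^2 + (1 - x_i)^2."""
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)

# Value and gradient at the known minimizer of the 1000-dimensional case
value, grad = jax.value_and_grad(rosenbrock)(jnp.ones(1000))
```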
Installation
Download files
Source Distribution
Built Distribution
File details
Details for the file gradienttransformation-1.0.0.tar.gz.
File metadata
- Download URL: gradienttransformation-1.0.0.tar.gz
- Upload date:
- Size: 7.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 1d1dde18a26d6f44f6090021d704c39a92bf06a44ef87bb9f32aac2b52ba1eef
MD5 | 18b5d27181eff6a712525de6ea3de8e4
BLAKE2b-256 | 8da5e1ef05f136fbc1f13bb77258981f63d499b2d2416ca0066376fa7c4fc884
File details
Details for the file GradientTransformation-1.0.0-py3-none-any.whl.
File metadata
- Download URL: GradientTransformation-1.0.0-py3-none-any.whl
- Upload date:
- Size: 7.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 876b80629cbf944204b7d51d7e3ec57c4b558b270f373f079c0327de34634f90
MD5 | c1d3ecc4f8d05a2a6d8a17dd11cfebcd
BLAKE2b-256 | 9864ba80981ac6016fa380c8cd8dc318bad26c7faf513e71f1586dac33facef5
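To check a downloaded file against one of the SHA256 digests listed above, hash it locally and compare the results. The following minimal sketch uses Python's standard hashlib. The helper name sha256_of_file is hypothetical, and the file is assumed to be in the current directory.

```python
import hashlib

def sha256_of_file(path, chunk_size=65536):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# SHA256 digest listed above for the wheel
EXPECTED_WHEEL_SHA256 = "876b80629cbf944204b7d51d7e3ec57c4b558b270f373f079c0327de34634f90"

# Example (requires the downloaded wheel):
# assert sha256_of_file("GradientTransformation-1.0.0-py3-none-any.whl") == EXPECTED_WHEEL_SHA256
```

Reading in chunks keeps memory use constant regardless of file size.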