
A JAX-based L-BFGS optimizer

Project description

An L-BFGS optimizer written with JAX.

Features

  • Implements the limited-memory BFGS (L-BFGS) algorithm.
  • JIT/vmap/pmap compatible for performance with JAX.
  • Note: requirements.txt is set up for jax[cpu].
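
At the heart of any L-BFGS implementation is the classic two-loop recursion, which applies an approximate inverse Hessian to the gradient using only the m most recent (s, y) pairs. The sketch below is a minimal self-contained illustration of that recursion in JAX, not this package's internal code:

```python
import jax.numpy as jnp

def two_loop(grad, s_list, y_list):
    """L-BFGS two-loop recursion: approximate H_k @ grad using the stored
    displacement pairs s_i = x_{i+1} - x_i and y_i = g_{i+1} - g_i
    (oldest pair first in both lists)."""
    q = grad
    rhos = [1.0 / jnp.dot(y, s) for s, y in zip(s_list, y_list)]
    alphas = []
    # First loop: newest pair to oldest.
    for s, y, rho in zip(reversed(s_list), reversed(y_list), reversed(rhos)):
        a = rho * jnp.dot(s, q)
        alphas.append(a)
        q = q - a * y
    # Initial Hessian scaling gamma = s.y / y.y from the newest pair.
    gamma = jnp.dot(s_list[-1], y_list[-1]) / jnp.dot(y_list[-1], y_list[-1])
    r = gamma * q
    # Second loop: oldest pair to newest.
    for s, y, rho, a in zip(s_list, y_list, rhos, reversed(alphas)):
        b = rho * jnp.dot(y, r)
        r = r + (a - b) * s
    return r  # approximate H_k @ grad; the search direction is -r
```

For a quadratic with diagonal Hessian and coordinate-direction pairs, the recursion recovers the exact inverse-Hessian action.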

Usage

Define a function to minimize

import jax.numpy as jnp

def func(x):
    return jnp.sum((-1 * coefficients + x) ** 2)

Construct an Lbfgs optimizer, where f is the function to minimize, m is the number of previous iterations to store in memory, and tol is the convergence tolerance:

optimizer = Lbfgs(f=func, m=10, tol=1e-6)

Iterate to find the minimum:

# Initialize optimizer state (x0 is the initial guess)
opt_state = optimizer.init(x0)

@jax.jit
def opt_step(carry, _):
    opt_state, losses = carry
    opt_state = optimizer.update(opt_state)
    losses = losses.at[opt_state.k].set(func(opt_state.position))
    return (opt_state, losses), None

iterations = 10000   # <-- a lot of iterations!
losses = jnp.zeros((iterations,))
(final_state, losses), _ = jax.lax.scan(opt_step, (opt_state, losses), None, length=iterations)
# losses has length `iterations`; slots never written (after early stopping)
# stay 0, so mask them with NaN
losses = jnp.where(losses == 0, jnp.nan, losses)

Output:

[-7.577116e-15  1.000000e+00  2.000000e+00  3.000000e+00  4.000000e+00
  5.000000e+00  6.000000e+00  7.000000e+00  8.000000e+00  9.000000e+00
  1.000000e+01  1.100000e+01  1.200000e+01  1.300000e+01  1.400000e+01
  1.500000e+01  1.600000e+01  1.700000e+01  1.800000e+01  1.900000e+01
  2.000000e+01  2.100000e+01  2.200000e+01  2.300000e+01  2.400000e+01
  2.500000e+01  2.600000e+01  2.700000e+01  2.800000e+01  2.900000e+01
  3.000000e+01  3.100000e+01  3.200000e+01  3.300000e+01  3.400000e+01
  3.500000e+01  3.600000e+01  3.700000e+01  3.800000e+01  3.900000e+01
  4.000000e+01  4.100000e+01  4.200000e+01  4.300000e+01  4.400000e+01
  4.500000e+01  4.600000e+01  4.700000e+01  4.800000e+01  4.900000e+01
  5.000000e+01  5.100000e+01  5.200000e+01  5.300000e+01  5.400000e+01
  5.500000e+01  5.600000e+01  5.700000e+01  5.800000e+01  5.900000e+01
  6.000000e+01  6.100000e+01  6.200000e+01  6.300000e+01  6.400000e+01
  6.500000e+01  6.600000e+01  6.700000e+01  6.800000e+01  6.900000e+01
  7.000000e+01  7.100000e+01  7.200000e+01  7.300000e+01  7.400000e+01
  7.500000e+01  7.600000e+01  7.700000e+01  7.800000e+01  7.900000e+01
  8.000000e+01  8.100000e+01  8.200000e+01  8.300000e+01  8.400000e+01
  8.500000e+01  8.600000e+01  8.700000e+01  8.800000e+01  8.900000e+01
  9.000000e+01  9.100000e+01  9.200000e+01  9.300000e+01  9.400000e+01
  9.500000e+01  9.600000e+01  9.700000e+01  9.800000e+01  9.900000e+01]

Function value at minimum: 5.7412694e-29
k:  2   # <-- stops early if the gradient norm is less than tol!
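
The fixed-length scan loop shown in Usage can be exercised without this package by substituting a plain gradient-descent update for optimizer.update; the target vector and step size below are illustrative assumptions, not part of the package:

```python
import jax
import jax.numpy as jnp

# Stand-in objective: same quadratic shape as the Usage example.
coefficients = jnp.arange(4.0)

def func(x):
    return jnp.sum((-1 * coefficients + x) ** 2)

grad_fn = jax.grad(func)

@jax.jit
def opt_step(carry, _):
    x, losses, k = carry
    x = x - 0.1 * grad_fn(x)          # gradient-descent stand-in for the L-BFGS update
    losses = losses.at[k].set(func(x))
    return (x, losses, k + 1), None

iterations = 200
losses = jnp.zeros((iterations,))
(x, losses, k), _ = jax.lax.scan(
    opt_step, (jnp.zeros(4), losses, 0), None, length=iterations
)
```

The carry threads the position, the loss buffer, and the step counter through the scan, mirroring the role of opt_state above.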

NOTE: the Examples directory includes the quadratic function and the Rosenbrock function.
The 1000-dimensional Rosenbrock function was solved in 4038 steps.
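
For reference, the standard N-dimensional Rosenbrock function (an assumption about the exact variant used in Examples) can be written in JAX as:

```python
import jax.numpy as jnp

def rosenbrock(x):
    """Standard N-dimensional Rosenbrock function.

    Global minimum of 0 at x = (1, ..., 1); its narrow curved valley
    makes it a common stress test for quasi-Newton optimizers.
    """
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)
```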

Installation

pip install GradientTransformation

Download files

Download the file for your platform.

Source Distribution

gradienttransformation-1.0.0.tar.gz (7.0 kB)


Built Distribution

GradientTransformation-1.0.0-py3-none-any.whl (7.7 kB)


File details

Details for the file gradienttransformation-1.0.0.tar.gz.

File metadata

File hashes

Hashes for gradienttransformation-1.0.0.tar.gz:
SHA256: 1d1dde18a26d6f44f6090021d704c39a92bf06a44ef87bb9f32aac2b52ba1eef
MD5: 18b5d27181eff6a712525de6ea3de8e4
BLAKE2b-256: 8da5e1ef05f136fbc1f13bb77258981f63d499b2d2416ca0066376fa7c4fc884


File details

Details for the file GradientTransformation-1.0.0-py3-none-any.whl.

File metadata

File hashes

Hashes for GradientTransformation-1.0.0-py3-none-any.whl:
SHA256: 876b80629cbf944204b7d51d7e3ec57c4b558b270f373f079c0327de34634f90
MD5: c1d3ecc4f8d05a2a6d8a17dd11cfebcd
BLAKE2b-256: 9864ba80981ac6016fa380c8cd8dc318bad26c7faf513e71f1586dac33facef5

