Continuation Methods for Deep Neural Networks.
Project description
continuation-jax: Continuation framework for lambda (the continuation parameter)
Continuation methods for deep neural networks.
Tags: optimization, deep-learning, homotopy, bifurcation-analysis, continuation
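A continuation (homotopy) method traces minimizers of a family of objectives indexed by a parameter lambda, warm-starting each solve from the solution at the previous lambda value. The sketch below is a generic, minimal illustration of natural-parameter continuation written in JAX; the objectives, step size, and loop counts are placeholder assumptions and do not reflect the continuation-jax API.
import jax
import jax.numpy as jnp

def target_loss(theta):
    # the "hard" objective we ultimately care about (placeholder)
    return jnp.sum(jnp.sin(theta) ** 2 + 0.1 * theta ** 4)

def easy_loss(theta):
    # a simple convex objective that is trivial to minimize (placeholder)
    return jnp.sum(theta ** 2)

def homotopy(theta, lam):
    # blend the easy objective into the target as lam goes from 0 to 1
    return lam * target_loss(theta) + (1.0 - lam) * easy_loss(theta)

grad_fn = jax.grad(homotopy, argnums=0)

theta = jnp.ones(3)
for lam in jnp.linspace(0.0, 1.0, 11):
    # a few gradient steps at each lambda, warm-started from the previous solution
    for _ in range(200):
        theta = theta - 0.1 * grad_fn(theta, lam)
    print(float(lam), float(homotopy(theta, lam)))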
Install using pip:
pip install continuation-jax
Import the package and check its version:
import cjax
print(cjax.__version__)
Examples:
- Example problems: https://github.com/harsh306/continuation-jax/tree/main/examples
- Sample runner: https://github.com/harsh306/continuation-jax/blob/main/run.py (see the runner script below)
from cjax.continuation.creator.continuation_creator import ContinuationCreator
from examples.toy.vectror_pitchfork import SigmoidFold
from cjax.utils.abstract_problem import ProblemWraper
import json
from jax.config import config
from datetime import datetime
from cjax.utils.visualizer import bif_plot, pick_array
config.update("jax_debug_nans", True)
# TODO: use **kwargs to reduce params
if __name__ == "__main__":
    problem = SigmoidFold()
    # Wrap the example problem so the continuation framework can consume it
    problem = ProblemWraper(problem)

    # Load the hyperparameters that configure the continuation run
    with open(problem.HPARAMS_PATH, "r") as hfile:
        hparams = json.load(hfile)

    start_time = datetime.now()

    if hparams["n_perturbs"] > 1:
        # Run continuation once per random perturbation, keyed by its index
        for perturb in range(hparams["n_perturbs"]):
            print(f"Running perturb {perturb}")
            continuation = ContinuationCreator(
                problem=problem, hparams=hparams, key=perturb
            ).get_continuation_method()
            continuation.run()
    else:
        continuation = ContinuationCreator(
            problem=problem, hparams=hparams
        ).get_continuation_method()
        continuation.run()

    end_time = datetime.now()
    print(f"Duration: {end_time - start_time}")

    # Plot the bifurcation diagram from the saved run outputs
    bif_plot(hparams["output_dir"], pick_array, hparams["n_perturbs"])
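The runner loads its configuration from a JSON file located at problem.HPARAMS_PATH. Only n_perturbs and output_dir are referenced in the script above; the sketch below shows, as a Python dict, what a minimal hyperparameters file might contain. The values and the output path are hypothetical, and the actual files shipped with each example likely define additional keys.
# Hypothetical minimal contents of the JSON file read via problem.HPARAMS_PATH
hparams = {
    "n_perturbs": 1,                       # number of perturbed runs; values > 1 trigger the loop above
    "output_dir": "outputs/sigmoid_fold",  # hypothetical path; bif_plot reads run results from here
}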
Contact:
harshnpathak@gmail.com
Download files
Source Distribution
continuation_jax-0.0.2.tar.gz (26.6 kB)