
PyLevel Optimisation: A Python library for bilevel optimisation


PyLevel Optimisation (PyLOpt)

PyLOpt is a PyTorch-based library for learning hyperparameters $\theta$ within the context of image reconstruction by means of solving the bilevel problem

$$(P_\text{bilevel}) ~~~~~\inf_{\theta} F(u_{\theta}, u^{(0)}) ~~~ \text{s.t.}
~~~ u_{\theta}\in\mathop{\text{arginf}}_{u}E(u, u^{(\delta)}, \theta)$$

The function $F$ refers to the upper loss function quantifying the goodness of the learned $\theta$ w.r.t. ground truth data $u^{(0)}$. $E$ denotes the lower cost or energy function, which is used to reconstruct clean data $u^{(0)}$ from noisy observations $u^{(\delta)}$. We assume that $E$ is of the form

$$E(u, u^{(\delta)}, \theta) = \frac{1}{2\sigma^{2}}\|u - u^{(\delta)}\|_{2}^{2} + \lambda R(u, \theta)$$

where $\sigma$ indicates the noise level, $R$ refers to a regulariser and $\lambda>0$ denotes a regularisation parameter. We consider trainable regularisers modelled by means of a Fields of Experts (FoE) model: let $k_{1}, \ldots, k_{m}$ denote square image filters and let $\rho(\cdot|\gamma)$ be a parameter-dependent potential function. Given potential parameters $\gamma_{1}, \ldots, \gamma_{m}$ and the merged parameter $\theta$ comprising filters and potential parameters, we define

$$R(u, \theta) = \sum_{j=1}^{m}\sum_{i}\rho([k_{j} * u]_{i}|\gamma_{j}),$$

where the inner sum runs over all elements of the $j$-th filter response.
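To make the definitions concrete, here is a minimal PyTorch sketch of $E$ and the FoE regulariser $R$. It assumes a Student-t type potential $\rho(x|\gamma) = \gamma\log(1 + x^{2})$ and valid (unpadded) filter responses; the function names are hypothetical and do not correspond to pylopt's `FieldsOfExperts` or `Energy` classes.

```python
import torch
import torch.nn.functional as F


def student_t_potential(x: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
    """Assumed potential rho(x | gamma) = gamma * log(1 + x^2).

    gamma has shape (m,) and is broadcast over the m filter-response channels.
    """
    return gamma.view(1, -1, 1, 1) * torch.log1p(x ** 2)


def foe_regulariser(u: torch.Tensor, filters: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
    """R(u, theta) = sum_j sum_i rho([k_j * u]_i | gamma_j), one value per image.

    u: (B, 1, H, W), filters: (m, 1, k, k). F.conv2d computes a cross-correlation
    without padding, i.e. only "valid" responses; for learned filters this merely
    flips the kernels compared to a true convolution.
    """
    responses = F.conv2d(u, filters)  # (B, m, H - k + 1, W - k + 1)
    return student_t_potential(responses, gamma).sum(dim=(1, 2, 3))


def energy(u, u_noisy, filters, gamma, sigma: float, lam: float) -> torch.Tensor:
    """E(u, u_noisy, theta) = 1/(2 sigma^2) ||u - u_noisy||_2^2 + lam * R(u, theta), per image."""
    data_term = ((u - u_noisy) ** 2).sum(dim=(1, 2, 3)) / (2.0 * sigma ** 2)
    return data_term + lam * foe_regulariser(u, filters, gamma)


# Usage with hypothetical sizes: 2 images, m = 8 square 7x7 filters.
torch.manual_seed(0)
u_noisy = torch.rand(2, 1, 32, 32)
filters = 0.1 * torch.randn(8, 1, 7, 7)
gamma = torch.ones(8)
E = energy(u_noisy.clone(), u_noisy, filters, gamma, sigma=0.1, lam=1.0)  # shape (2,)
```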

Objective

PyLOpt aims to serve as a toolbox for scientists and engineers to address bilevel problems in imaging, supporting different gradient-based solution methods. The package is modular and extendable by design and follows familiar interfaces from popular Python packages such as SciPy ([2]), scikit-learn ([8]) and PyTorch ([9]).


Features

Current features

  • Image reconstruction using pretrained filter and potential models by solving the lower level problem $\mathop{\text{arginf}}_{u}E(u, u^{(\delta)}, \theta)$ using one of the following gradient-based methods

    • NAG: Nesterov accelerated gradient method - see e.g. [5]
    • NAPG: Proximal gradient method with Nesterov acceleration - see e.g. [4]
    • Adam: see [6]
    • Unrolling: Both NAG and NAPG are also implemented as unrolled schemes. This allows the upper-level problem to be solved using automatic differentiation.
  • Training of filters and/or potentials of an FoE regulariser model by solving the bilevel problem $P_\text{bilevel}$. The training relies on applying one of the following gradient-based methods to the upper problem:

    • NAG
    • Adam
    • LBFGS: Quasi-Newton method - see [7]

    Gradients of solutions of the lower level problem w.r.t. the parameter $\theta$ are computed either by implicit differentiation or, provided the lower problem is solved by means of an unrolling scheme, by automatic differentiation.

  • Modularity and extensibility: The package is modular by design. Its architecture allows easy customisation and extension. Each of the core components

    • ImageFilter
    • Potential
    • FieldsOfExperts
    • Energy

    is encapsulated in its own module. Thus, all of these components can be exchanged easily without any need to modify the core logic. In addition, methods for the solution of the lower problem (solve_lower.py) and the upper problem (solve_bilevel.py) can easily be added.

  • The repository contains pretrained models and sample scripts and notebooks showing the application of the package for image denoising.
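Implicit differentiation, one of the two ways listed above to obtain gradients of lower-level solutions w.r.t. $\theta$, can be sketched as follows. At a minimiser $u_{\theta}$ we have $\nabla_{u}E = 0$, so the implicit function theorem gives $\frac{dF}{d\theta} = -\left(\partial_{\theta}\nabla_{u}E\right)^{\top}\left(\nabla^{2}_{uu}E\right)^{-1}\nabla_{u}F$. The sketch solves the linear system with conjugate gradients (cf. `solver='cg'` in the Usage section) using Hessian-vector products, and assumes the Hessian is positive definite. The function names are hypothetical, not pylopt's API; the toy quadratic at the end only checks the result against a closed form.

```python
import torch


def hessian_vector_product(grad_u, u, v):
    # d/du <grad_u E, v> = H v; grad_u must have been built with create_graph=True.
    return torch.autograd.grad(grad_u, u, grad_outputs=v, retain_graph=True)[0]


def conjugate_gradient(matvec, b, max_iter=500, tol=1e-10):
    """Solve A x = b for symmetric positive definite A, given only x -> A x."""
    x = torch.zeros_like(b)
    r = b.clone()
    p = r.clone()
    rs_old = (r * r).sum()
    for _ in range(max_iter):
        if rs_old.sqrt() < tol:
            break
        Ap = matvec(p)
        alpha = rs_old / (p * Ap).sum()
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = (r * r).sum()
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


def implicit_hypergradient(energy_fn, upper_loss_fn, u_star, theta, max_iter=500):
    """dF/dtheta at a lower-level minimiser u_star (theta must require grad)."""
    u = u_star.detach().requires_grad_(True)
    # Gradient of the upper loss w.r.t. the reconstruction.
    grad_F_u = torch.autograd.grad(upper_loss_fn(u), u)[0]
    # Gradient of the lower energy, kept on the graph for second-order terms.
    grad_E_u = torch.autograd.grad(energy_fn(u, theta), u, create_graph=True)[0]
    # Solve H v = grad_F_u, where H is the Hessian of E w.r.t. u.
    v = conjugate_gradient(lambda p: hessian_vector_product(grad_E_u, u, p), grad_F_u, max_iter)
    # Mixed second derivative: dF/dtheta = -d/dtheta <grad_u E(u, theta), v>.
    return -torch.autograd.grad(grad_E_u, theta, grad_outputs=v)[0]


# Toy check (hypothetical): E(u, t) = 1/2||u - y||^2 + t/2 ||u||^2 has minimiser u* = y / (1 + t).
y = torch.tensor([1.0, 2.0, -0.5], dtype=torch.float64)
g = torch.tensor([0.5, 1.0, 0.0], dtype=torch.float64)
theta = torch.tensor(0.3, dtype=torch.float64, requires_grad=True)
u_star = y / (1.0 + theta.detach())
hypergrad = implicit_hypergradient(
    lambda u, t: 0.5 * ((u - y) ** 2).sum() + 0.5 * t * (u ** 2).sum(),
    lambda u: 0.5 * ((u - g) ** 2).sum(),
    u_star, theta)
# Closed form: dF/dt = -<y, u* - g> / (1 + t)^2
expected = -(y * (u_star - g)).sum() / (1.0 + 0.3) ** 2
```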

Upcoming features

  • Sampling-based approach for solving the inner problem

Installation

  • Through pip

    pip install pylopt
    
  • From source

    git clone https://github.com/VLOGroup/pylopt.git
    cd pylopt
    pip install .
    

Core components

The FoE regulariser is implemented via the FieldsOfExperts class. It combines an ImageFilter, which defines the convolutional filters applied to the image, and a subclass of Potential, which models the corresponding potential functions. The lower problem is modelled by the PyTorch module Energy, which represents the energy function to be minimised. An object of this class contains as its components a MeasurementModel instance, a PyTorch module modelling the measurement process, and a FieldsOfExperts instance.

Image reconstruction, i.e. the solution of the lower level problem, is carried out by the function solve_lower(). The training of filters and potentials is managed by the class BilevelOptimisation.

For the usage of the package and its methods see section Usage.

Usage

Conceptual

The interface of the function solve_lower() which is used to solve the lower level problem is designed to closely follow the conventions of SciPy optimisation routines. Given an Energy instance, the corresponding lower level problem can be solved for example using Nesterov's accelerated gradient method (NAG) via

lower_prob_result = solve_lower(energy=energy, method='nag', 
                                options={'max_num_iterations': 1000, 'rel_tol': 1e-5, 'batch_optimisation': False})
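For orientation, a generic version of NAG with constant step size $1/L$ (cf. the `lip_const` option in the results table below) can be sketched as follows. The stopping rule used for `rel_tol` (relative change of consecutive iterates) is an assumption; pylopt's actual criterion and internals may differ, and `nag` here is a hypothetical stand-alone function, not `solve_lower()`.

```python
import torch


def nag(grad_fn, x0, lip_const, max_num_iterations=1000, rel_tol=1e-5):
    """Nesterov's accelerated gradient method with constant step size 1/L.

    Assumed stopping rule: ||x_k - x_{k-1}|| <= rel_tol * ||x_{k-1}||.
    Returns the last iterate and the number of iterations performed.
    """
    x_prev = x0.clone()
    y = x0.clone()
    t = 1.0
    for k in range(1, max_num_iterations + 1):
        x = y - grad_fn(y) / lip_const                  # gradient step at the extrapolated point
        t_next = (1.0 + (1.0 + 4.0 * t * t) ** 0.5) / 2.0
        y = x + ((t - 1.0) / t_next) * (x - x_prev)     # Nesterov extrapolation
        if torch.linalg.norm(x - x_prev) <= rel_tol * torch.linalg.norm(x_prev):
            return x, k
        x_prev, t = x, t_next
    return x_prev, max_num_iterations


# Toy check (hypothetical): f(x) = 1/2 sum_i d_i (x_i - c_i)^2, grad f(x) = d * (x - c), L = max(d).
d = torch.tensor([1.0, 2.0, 5.0, 10.0], dtype=torch.float64)
c = torch.tensor([1.0, -2.0, 3.0, 0.5], dtype=torch.float64)
x_opt, num_iter = nag(lambda x: d * (x - c), torch.zeros(4, dtype=torch.float64),
                      lip_const=10.0, max_num_iterations=5000, rel_tol=1e-12)
```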

The upper-level optimisation, i.e. training of filters and potential parameters, follows scikit-learn conventions for interface design and usability. To train these parameters using Adam for the upper-level and NAPG for the lower-level optimisation, use

prox = DenoisingProx(noise_level=noise_level)
bilevel_optimisation = BilevelOptimisation(method_lower='napg',
                                           options_lower={'max_num_iterations': 1000, 'rel_tol': 1e-5, 'prox': prox, 
                                                          'batch_optimisation': False}, 
                                           operator=torch.nn.Identity(),
                                           noise_level=0.1, 
                                           solver='cg', options_solver={'max_num_iterations': 500},
                                           path_to_experiments_dir=path_to_eval_dir)

bilevel_optimisation.learn(regulariser, lam, l2_loss_func, train_image_dataset,
                           optimisation_method_upper='adam', 
                           optimisation_options_upper={'max_num_iterations': 10000, 'lr': [1e-3, 1e-1], 
                                                       'alternating': True},
                           dtype=dtype, device=device, callbacks=callbacks, schedulers=schedulers)
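The `DenoisingProx` in this example suggests that NAPG handles the data term $D(u) = \frac{1}{2\sigma^{2}}\|u - u^{(\delta)}\|_{2}^{2}$ through its proximal operator, which has the closed form $\text{prox}_{\tau D}(v) = \frac{\tau u^{(\delta)} + \sigma^{2} v}{\tau + \sigma^{2}}$. The sketch below combines this prox with a FISTA-type accelerated proximal gradient loop. It illustrates the idea under that assumption and is not pylopt's implementation; `denoising_prox` and `napg` are hypothetical names, and the quadratic smooth term only stands in for $\lambda R(u, \theta)$.

```python
import torch


def denoising_prox(v, u_noisy, sigma, tau):
    """prox_{tau D}(v) for D(u) = 1/(2 sigma^2) ||u - u_noisy||^2 (closed form)."""
    return (tau * u_noisy + sigma ** 2 * v) / (tau + sigma ** 2)


def napg(grad_smooth, prox, x0, lip_const, max_num_iterations=1000, rel_tol=1e-5):
    """Accelerated proximal gradient for min_x D(x) + S(x), with S having an L-Lipschitz gradient."""
    tau = 1.0 / lip_const
    x_prev, y, t = x0.clone(), x0.clone(), 1.0
    for k in range(1, max_num_iterations + 1):
        x = prox(y - tau * grad_smooth(y), tau)          # forward (gradient) then backward (prox) step
        t_next = (1.0 + (1.0 + 4.0 * t * t) ** 0.5) / 2.0
        y = x + ((t - 1.0) / t_next) * (x - x_prev)      # Nesterov extrapolation
        if torch.linalg.norm(x - x_prev) <= rel_tol * torch.linalg.norm(x_prev):
            return x, k
        x_prev, t = x, t_next
    return x_prev, max_num_iterations


# Toy check (hypothetical): S(u) = lam/2 ||u||^2, so the minimiser of D + S is u_noisy / (1 + lam * sigma^2).
sigma, lam = 0.5, 2.0
u_noisy = torch.tensor([0.3, -1.2, 2.0, 0.7], dtype=torch.float64)
u_rec, num_iter = napg(lambda u: lam * u,
                       lambda v, tau: denoising_prox(v, u_noisy, sigma, tau),
                       u_noisy.clone(), lip_const=4.0, max_num_iterations=5000, rel_tol=1e-12)
```

Any upper bound on the Lipschitz constant of $\nabla S$ (here `lip_const=4.0` instead of the exact value 2) keeps the iteration convergent, at the price of smaller steps.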

Concrete

Concrete and executable code for training and prediction is contained in pylopt/examples. Please note that reproducible training results can only be obtained when using the data type torch.float64, at the cost of increased computation time.
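A minimal sketch of the corresponding setup, assuming `dtype` and `device` are the variables passed to `learn()` in the usage example above. The seed and the small rounding demo are illustrative additions, not part of pylopt. One plausible reason for the float64 recommendation is visible in the demo: float32 cannot resolve `1e8 + 1`, so the result depends on the order of accumulation, which parallel GPU reductions do not fix.

```python
import torch

torch.manual_seed(0)   # fix random initialisations (e.g. of filters)
dtype = torch.float64  # double precision, as recommended for reproducible training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Order of operations matters in float32: the +1 is absorbed next to 1e8.
a32 = (torch.tensor(1e8, dtype=torch.float32) + 1.0) - 1e8  # 0.0
b32 = (torch.tensor(1e8, dtype=torch.float32) - 1e8) + 1.0  # 1.0
# float64 still resolves the +1 in the first ordering.
a64 = (torch.tensor(1e8, dtype=torch.float64) + 1.0) - 1e8  # 1.0
```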

Denoising using pretrained models

  • Example I

    • Filters: Pretrained filters from [1]
    • Potential:
      • Type: Student-t
      • Weights: Optimised using pylopt

    To run the script, execute

    python examples/scripts/denoising_predict.py
    

    Alternatively, run the Jupyter notebook denoising_predict.ipynb. Denoising the images watercastle and koala from the well-known BSDS300 dataset (see [3]), we obtain the following results:

    | Method | Options | Mean PSNR [dB] | Iterations | Time on GPU [s] |
    |---|---|---|---|---|
    | 'nag' | {'max_num_iterations': 1000, 'rel_tol': 1e-5, 'batch_optimisation': False, 'lip_const': 1e1} | 29.199 | 312 | 0.988 |
    | 'napg' | {'max_num_iterations': 1000, 'rel_tol': 1e-5, 'prox': ..., 'batch_optimisation': False, 'lip_const': 1} | 29.207 | 361 | 1.577 |
    | 'adam' | {'max_num_iterations': 1000, 'rel_tol': 1e-5, 'lr': [1e-3, 1e-3], 'batch_optimisation': False} | 28.833 | 1000 | 1.667 |

    [Figure: denoising results obtained with the NAG optimiser]
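The mean PSNR reported in the table is presumably computed per image as $10\log_{10}(\text{MAX}^{2}/\text{MSE})$ and then averaged. A minimal sketch, assuming intensities scaled to $[0, 1]$ (so $\text{MAX} = 1$); the value range used by pylopt's scripts is not stated here, and `psnr` is a hypothetical helper.

```python
import torch


def psnr(u: torch.Tensor, u_clean: torch.Tensor, data_range: float = 1.0) -> torch.Tensor:
    """Peak signal-to-noise ratio in dB, one value per image of a (B, C, H, W) batch."""
    mse = ((u - u_clean) ** 2).mean(dim=(1, 2, 3))
    return 10.0 * torch.log10(data_range ** 2 / mse)


# Hypothetical check: a constant error of 0.1 gives MSE = 0.01, i.e. about 20 dB.
u_clean = torch.zeros(2, 1, 8, 8, dtype=torch.float64)
u_rec = u_clean + 0.1
mean_psnr = psnr(u_rec, u_clean).mean()
```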

Training of FoE models

The script denoising_train.py contains several setups for training filters and/or potential functions. To run the script with the corresponding setup, execute

python examples/scripts/denoising_train.py --example <example_id>

with example_id in {training_I, training_II, training_III}. An overview of these examples is given below.

  • Example I (example_id = training_I)

    • Filters:
      • Pretrained filters from [1]
      • Frozen, i.e. non-trainable
    • Potential:
      • Type: Student-t
      • Weights:
        • Uniform initialisation
        • Trainable
    • Lower level: NAPG
    • Upper level: Adam
    [Figures: training stats, potential weight stats, test triplet]
  • Example II (example_id = training_II)

    • Filters:
      • Random initialisation
      • Trainable
    • Potential:
      • Type: Student-t
      • Weights:
        • Uniform initialisation
        • Trainable
    • Optimiser:
      • Inner: NAPG
      • Outer: Adam
  • Example III (example_id = training_III)
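A toy sketch in the spirit of Example I (frozen filters, trainable Student-t weights with uniform initialisation, read here as all weights set to the same value, and Adam on the upper level). To stay short and self-contained it swaps in plain unrolled gradient descent for the lower level instead of NAPG, so the hypergradient comes from automatic differentiation through the unrolled steps, as in the unrolling option described under Features. The data, the finite-difference filters, `lam`, the potential $\gamma\log(1 + x^{2})$ and the log-parametrisation of the weights are all hypothetical choices, not pylopt's.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dtype = torch.float64
sigma, lam = 0.1, 10.0

# Hypothetical toy data: a horizontal ramp and a noisy copy of it.
u_clean = torch.linspace(0.0, 1.0, 16, dtype=dtype).repeat(16, 1).view(1, 1, 16, 16)
u_noisy = u_clean + sigma * torch.randn_like(u_clean)

# Frozen filters (stand-in for pretrained ones): horizontal and vertical differences, no gradients.
filters = torch.tensor([[[[-1.0, 1.0], [0.0, 0.0]]], [[[-1.0, 0.0], [1.0, 0.0]]]], dtype=dtype)
# Trainable potential weights, uniformly initialised; optimised in log-space to stay positive.
log_gamma = torch.full((2,), -2.0, dtype=dtype, requires_grad=True)


def grad_energy(u, gamma):
    """grad_u of 1/(2 sigma^2)||u - u_noisy||^2 + lam * sum_j sum_i gamma_j log(1 + [k_j * u]_i^2)."""
    r = F.conv2d(u, filters)                                  # filter responses
    dr = gamma.view(1, -1, 1, 1) * 2.0 * r / (1.0 + r ** 2)   # rho'(r | gamma)
    return (u - u_noisy) / sigma ** 2 + lam * F.conv_transpose2d(dr, filters)  # adjoint of conv2d


def unrolled_solve(gamma, num_steps=30):
    """Unrolled gradient descent on the lower problem; every step stays on the autograd graph."""
    # Lipschitz bound: |rho''| <= 2 gamma and ||K||^2 <= 8 for the two difference filters.
    lip = 1.0 / sigma ** 2 + lam * 16.0 * gamma.detach().max()
    u = u_noisy.clone()
    for _ in range(num_steps):
        u = u - grad_energy(u, gamma) / lip
    return u


optimiser = torch.optim.Adam([log_gamma], lr=1e-1)
losses = []
for _ in range(20):
    optimiser.zero_grad()
    u_rec = unrolled_solve(log_gamma.exp())
    loss = 0.5 * ((u_rec - u_clean) ** 2).sum()  # upper-level L2 loss
    loss.backward()                               # backpropagate through all unrolled steps
    optimiser.step()
    losses.append(loss.item())
```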

Contributing

  1. Fork the repository
  2. Create a feature branch
  3. Submit a pull request

References

[1] Chen, Y., Ranftl, R. and Pock, T., 2014. Insights into analysis operator learning: From patch-based sparse models to higher order MRFs. IEEE Transactions on Image Processing, 23(3), pp.1060-1072.

[2] Virtanen, P., Gommers, R., Oliphant, T.E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., van der Walt, S.J., Brett, M., Wilson, J., Millman, K.J., Mayorov, N., Nelson, A.R.J., Jones, E., Kern, R., Larson, E., Carey, C.J., Polat, İ., Feng, Y., Moore, E.W., VanderPlas, J., Laxalde, D., Perktold, J., Cimrman, R., Henriksen, I., Quintero, E.A., Harris, C.R., Archibald, A.M., Ribeiro, A.H., Pedregosa, F., van Mulbregt, P. and SciPy 1.0 Contributors, 2020. SciPy 1.0: Fundamental algorithms for scientific computing in Python. Nature Methods, 17(3), pp.261-272.

[3] Martin, D., Fowlkes, C., Tal, D. and Malik, J., 2001, July. A database of human segmented natural images and its application to evaluating segmentation algorithms and measuring ecological statistics. In Proceedings eighth IEEE international conference on computer vision. ICCV 2001 (Vol. 2, pp. 416-423). IEEE.

[4] Beck, A., 2017. First-order methods in optimization. Society for Industrial and Applied Mathematics.

[5] d’Aspremont, A., Scieur, D. and Taylor, A., 2021. Acceleration methods. Foundations and Trends® in Optimization, 5(1-2), pp.1-245.

[6] Kingma, D.P. and Ba, J., 2014. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.

[7] Nocedal, J. and Wright, S.J., 2006. Numerical optimization. New York, NY: Springer New York.

[8] Pedregosa, F., Varoquaux, G., Gramfort, A., Michel, V., Thirion, B., Grisel, O., Blondel, M., Prettenhofer, P., Weiss, R., Dubourg, V. and Vanderplas, J., 2011. Scikit-learn: Machine learning in Python. Journal of Machine Learning Research, 12, pp.2825-2830.

[9] Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L. and Desmaison, A., 2019. PyTorch: An imperative style, high-performance deep learning library. Advances in Neural Information Processing Systems, 32.

License

MIT License

Download files


Source Distribution

pylopt-0.1.0.tar.gz (59.0 kB)

Uploaded Source

Built Distribution


pylopt-0.1.0-py3-none-any.whl (80.0 kB)

Uploaded Python 3

File details

Details for the file pylopt-0.1.0.tar.gz.

File metadata

  • Download URL: pylopt-0.1.0.tar.gz
  • Upload date:
  • Size: 59.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.13

File hashes

Hashes for pylopt-0.1.0.tar.gz
Algorithm Hash digest
SHA256 86dc7d2090988cf4e59543884e70a6a61b4e534a0e11023e1ec33740aeb89e8a
MD5 d6aa3630f3dce924f592e605aed0e388
BLAKE2b-256 7a98d8180467ed9dc649b1dd6b3c8b24a26acc4c9df153233d4578b1d577ba5b


File details

Details for the file pylopt-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: pylopt-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 80.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.13

File hashes

Hashes for pylopt-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 911eeb0204f36e9ba2296ba3cbf7b8a36ab83d577da819b538c000f7eae728ac
MD5 5b75fcc656551dcd67b493fd28b4e869
BLAKE2b-256 11f67d80167b3c9384722b6f3162dbc50f7c998aa13071c1cf0928ad15d73ad5

