
Deep unfolding of iterative methods to solve linear equations



deep-unfolding: Deep unfolding of iterative methods

The deep-unfolding package provides iterative methods for solving linear equations. How fast these methods converge depends strongly on their parameters (for example, the relaxation factor of SOR), so the parameters must be tuned to obtain fast convergence. deep-unfolding takes an iterative algorithm with a fixed number of iterations $T$, unrolls it into a layered structure with one layer per iteration, and adds trainable parameters. These parameters are then trained with standard deep learning tools: a loss function, stochastic gradient descent, and backpropagation.
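To make the idea concrete, here is a minimal, self-contained sketch (plain Python, no PyTorch) of deep unfolding applied to Richardson iteration $x_{t+1} = x_t + \omega_t (b - A x_t)$. The $T$ iterations are unrolled, each one gets its own trainable step size $\omega_t$, the loss is the mean squared error of $x_T$, and the gradient with respect to every $\omega_t$ is obtained by backpropagating by hand through the unrolled iterations. The toy matrix, the function names, and the use of full-batch gradient descent instead of SGD or Adam are assumptions for illustration; this is not the package's implementation.

```python
import random


def matvec(A, x):
    """A @ x for a dense list-of-lists matrix."""
    return [sum(a * xj for a, xj in zip(row, x)) for row in A]


def matvec_T(A, g):
    """A^T @ g."""
    return [sum(A[i][j] * g[i] for i in range(len(A))) for j in range(len(A[0]))]


def loss_and_grad(A, b, x_true, omegas):
    """Unroll T = len(omegas) Richardson steps from x_0 = 0 and backpropagate the MSE loss.

    Returns the loss and dL/d(omega_t) for every t.
    """
    n = len(b)
    x = [0.0] * n
    residuals = []  # r_t = b - A x_t, kept for the backward pass
    for w in omegas:  # forward pass: x_{t+1} = x_t + omega_t * r_t
        r = [bi - axi for bi, axi in zip(b, matvec(A, x))]
        residuals.append(r)
        x = [xi + w * ri for xi, ri in zip(x, r)]
    err = [xi - ti for xi, ti in zip(x, x_true)]
    loss = sum(e * e for e in err) / n
    g = [2.0 * e / n for e in err]  # dL/dx_T
    grads = [0.0] * len(omegas)
    for t in reversed(range(len(omegas))):  # backward pass; g holds dL/dx_{t+1}
        grads[t] = sum(gi * ri for gi, ri in zip(g, residuals[t]))  # dL/d(omega_t)
        g = [gi - omegas[t] * v for gi, v in zip(g, matvec_T(A, g))]  # (I - omega_t A)^T g
    return loss, grads


def mean_loss(A, samples, omegas):
    """Average loss over a list of true solutions (b = A x_true for each)."""
    return sum(loss_and_grad(A, matvec(A, x), x, omegas)[0] for x in samples) / len(samples)


def train(A, samples, omegas, lr=0.05, steps=300):
    """Full-batch gradient descent on the per-iteration step sizes omega_t."""
    omegas = list(omegas)
    for _ in range(steps):
        total = [0.0] * len(omegas)
        for x_true in samples:
            _, g = loss_and_grad(A, matvec(A, x_true), x_true, omegas)
            total = [s + gi for s, gi in zip(total, g)]
        omegas = [w - lr * s / len(samples) for w, s in zip(omegas, total)]
    return omegas


if __name__ == "__main__":
    rng = random.Random(0)
    A = [[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]]  # small SPD toy matrix
    samples = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(8)]
    init = [0.2, 0.3, 0.4]  # T = 3 unrolled iterations, deliberately distinct initial values
    trained = train(A, samples, init)
    print("initial loss:", mean_loss(A, samples, init))
    print("trained loss:", mean_loss(A, samples, trained))
    print("learned omegas:", trained)
```

The toy matrix has the three distinct eigenvalues $3-\sqrt{3}$, $3$ and $3+\sqrt{3}$, so taking the $\omega_t$ as their reciprocals solves every instance exactly in $T = 3$ iterations, something no constant $\omega$ can do. The initial values are distinct on purpose: Richardson steps commute, so identical starting values would receive identical gradients and never separate.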

The package contains two modules of iterative methods: methods, which implements conventional iterative methods, and train_methods, which implements deep-unfolded versions of these methods.

Installation

```bash
pip install --upgrade pip
pip install deep-unfolding
```

Quick start

```python
from deep_unfolding import device, evaluate_model, generate_A_H_sol, SORNet, train_model
from torch import nn, optim

total_itr = 25  # Total number of iterations
n = 300  # Number of rows
m = 600  # Number of columns
bs = 10000  # Mini-batch size (samples)
num_batch = 500  # Number of mini-batches
lr_adam = 0.002  # Learning rate of the Adam optimizer
init_val_SORNet = 1.1  # Initial value of omega for SORNet

seed = 12  # Random seed for data generation

# Generate random problem data
A, H, W, solution, y = generate_A_H_sol(n=n, m=m, seed=seed, bs=bs)
loss_func = nn.MSELoss()  # Mean squared error loss

# Model
model_SORNet = SORNet(A, H, bs, y, init_val_SORNet, device=device)

# Optimizer
opt_SORNet = optim.Adam(model_SORNet.parameters(), lr=lr_adam)

# Training
trained_model_SORNet, loss_gen_SORNet = train_model(model_SORNet, opt_SORNet, loss_func, solution, total_itr, num_batch)

# Evaluation
norm_list_SORNet = evaluate_model(trained_model_SORNet, solution, n, bs, total_itr, device=device)
```

Package contents

This package implements several iterative techniques for approximating the solution of linear systems of the form $Ax = b$. The conventional methods implemented in the methods module are:

  • GS: Gauss-Seidel (GS) algorithm
  • RI: Richardson iteration (RI) algorithm
  • Jacobi: Jacobi iteration algorithm
  • SOR: Successive Over-Relaxation (SOR) algorithm
  • SORCheby: SOR with Chebyshev acceleration
  • AOR: Accelerated Over-Relaxation (AOR) algorithm
  • AORCheby: AOR with Chebyshev acceleration
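The update rules of these conventional methods can be sketched as follows, assuming the usual splitting $A = D - L - U$ into diagonal, strictly lower and strictly upper parts, with AOR defined by $(D - rL)x^{(k+1)} = [(1-\omega)D + (\omega - r)L + \omega U]x^{(k)} + \omega b$. This is a plain-Python illustration of the textbook iterations, not the code of the methods module, and the Chebyshev-accelerated variants are omitted. A single aor_step covers several methods: $r = \omega$ gives SOR, $r = \omega = 1$ gives Gauss-Seidel, and $r = 0$, $\omega = 1$ gives Jacobi. The function names and the toy system are hypothetical.

```python
def richardson_step(A, b, x, omega):
    """RI: x <- x + omega * (b - A x)."""
    Ax = [sum(a * xj for a, xj in zip(row, x)) for row in A]
    return [xi + omega * (bi - axi) for xi, bi, axi in zip(x, b, Ax)]


def jacobi_step(A, b, x):
    """Jacobi: every component is updated from the previous iterate only."""
    n = len(b)
    return [(b[i] - sum(A[i][j] * x[j] for j in range(n) if j != i)) / A[i][i] for i in range(n)]


def aor_step(A, b, x, omega, r):
    """AOR sweep: (D - rL) x_new = [(1 - omega) D + (omega - r) L + omega U] x + omega b.

    r = omega gives SOR, r = omega = 1 gives Gauss-Seidel, r = 0 and omega = 1 gives Jacobi.
    """
    n = len(b)
    new = list(x)
    for i in range(n):
        lower_new = sum(A[i][j] * new[j] for j in range(i))  # already-updated components
        lower_old = sum(A[i][j] * x[j] for j in range(i))
        upper_old = sum(A[i][j] * x[j] for j in range(i + 1, n))
        new[i] = (1 - omega) * x[i] + (
            omega * b[i] - r * lower_new - (omega - r) * lower_old - omega * upper_old
        ) / A[i][i]
    return new


def run(step, A, b, iters, **params):
    """Apply `step` `iters` times, starting from x = 0."""
    x = [0.0] * len(b)
    for _ in range(iters):
        x = step(A, b, x, **params)
    return x


if __name__ == "__main__":
    A = [[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]]  # toy SPD, diagonally dominant
    x_true = [1.0, -2.0, 3.0]
    b = [sum(a * xj for a, xj in zip(row, x_true)) for row in A]
    methods = {
        "RI (omega=1/3)": (richardson_step, {"omega": 1 / 3}),
        "Jacobi": (jacobi_step, {}),
        "GS": (aor_step, {"omega": 1.0, "r": 1.0}),
        "SOR (omega=1.1)": (aor_step, {"omega": 1.1, "r": 1.1}),
        "AOR (omega=1.1, r=1.0)": (aor_step, {"omega": 1.1, "r": 1.0}),
    }
    for name, (step, params) in methods.items():
        x = run(step, A, b, 25, **params)
        err = max(abs(xi - ti) for xi, ti in zip(x, x_true))
        print(f"{name:24s} max error after 25 iterations: {err:.2e}")
```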

This package also implements several models based on deep unfolding, which learn the parameters of some of the preceding algorithms so that the approximation obtained after a fixed number of iterations is as accurate as possible. The models implemented in the train_methods module are:

  • SORNet: Optimization via Deep Unfolding Learning of the Successive Over-Relaxation (SOR) algorithm
  • SORChebyNet: Optimization via Deep Unfolding Learning of the Successive Over-Relaxation (SOR) with Chebyshev acceleration algorithm
  • AORNet: Optimization via Deep Unfolding Learning of the Accelerated Over-Relaxation (AOR) algorithm
  • RINet: Optimization via Deep Unfolding Learning of the Richardson iteration (RI) algorithm
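The following sketch shows the idea behind a model such as SORNet, assuming (as is common in deep unfolding) one trainable relaxation factor $\omega_t$ per unrolled SOR iteration, all initialized to 1.1 as init_val_SORNet is in the quick start. To stay dependency-free it swaps in simpler tools than the PyTorch setup of the quick start: central finite differences replace autograd backpropagation, plain gradient descent on a fixed training set replaces Adam, and a small 1-D Laplacian stands in for the matrices produced by generate_A_H_sol. All function names are hypothetical and this is not the package's SORNet.

```python
import random


def laplacian_1d(n):
    """Toy SPD test matrix tridiag(-1, 2, -1), on which SOR-type sweeps converge slowly."""
    return [[2.0 if i == j else -1.0 if abs(i - j) == 1 else 0.0 for j in range(n)] for i in range(n)]


def sor_sweep(A, b, x, omega):
    """One SOR sweep; x[j] for j < i already holds the updated values."""
    x = list(x)
    n = len(b)
    for i in range(n):
        sigma = sum(A[i][j] * x[j] for j in range(n) if j != i)
        x[i] = (1 - omega) * x[i] + omega * (b[i] - sigma) / A[i][i]
    return x


def unfolded_sor(A, b, omegas):
    """Forward pass: T = len(omegas) SOR sweeps from x = 0, sweep t using its own omega_t."""
    x = [0.0] * len(b)
    for w in omegas:
        x = sor_sweep(A, b, x, w)
    return x


def mse(A, samples, omegas):
    """Mean squared error of the T-th iterate over (x_true, b) pairs."""
    total = 0.0
    for x_true, b in samples:
        x = unfolded_sor(A, b, omegas)
        total += sum((xi - ti) ** 2 for xi, ti in zip(x, x_true)) / len(x)
    return total / len(samples)


def make_samples(A, count, rng):
    """Random problems b = A x_true with x_true ~ N(0, I)."""
    samples = []
    for _ in range(count):
        x_true = [rng.gauss(0.0, 1.0) for _ in range(len(A))]
        samples.append((x_true, [sum(a * xj for a, xj in zip(row, x_true)) for row in A]))
    return samples


def train(A, samples, omegas, lr=0.05, steps=40, eps=1e-6):
    """Gradient descent on the omega_t; central finite differences stand in for backpropagation."""
    omegas = list(omegas)
    for _ in range(steps):
        grad = []
        for t in range(len(omegas)):
            up, dn = list(omegas), list(omegas)
            up[t] += eps
            dn[t] -= eps
            grad.append((mse(A, samples, up) - mse(A, samples, dn)) / (2 * eps))
        omegas = [w - lr * g for w, g in zip(omegas, grad)]
    return omegas


if __name__ == "__main__":
    A = laplacian_1d(6)
    samples = make_samples(A, 8, random.Random(12))
    init = [1.1] * 5  # T = 5 unrolled iterations, each omega_t starting at 1.1
    trained = train(A, samples, init)
    print("loss with omega_t = 1.1:", mse(A, samples, init))
    print("loss after training:   ", mse(A, samples, trained))
    print("learned omega_t:", [round(w, 3) for w in trained])
```

In a real implementation the finite-difference loop would be replaced by a single backward pass, which costs about as much as one forward pass regardless of $T$.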

Reference

If you use this software, please cite the following reference (available soon).

License

GPLv3 License


Download files

Download the file for your platform.

Source Distribution

deep_unfolding-0.2.0.tar.gz (26.0 kB)

Uploaded Source

Built Distribution

deep_unfolding-0.2.0-py3-none-any.whl (22.7 kB)

Uploaded Python 3

File details

Details for the file deep_unfolding-0.2.0.tar.gz.

File metadata

  • Download URL: deep_unfolding-0.2.0.tar.gz
  • Size: 26.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.12.4

File hashes

Hashes for deep_unfolding-0.2.0.tar.gz
Algorithm Hash digest
SHA256 d78f267afd5e5f41dbe81ac6b00d33258dd8390fee7f1670dc6d88d80e85e394
MD5 4d085c50600ef9ba1d874879bbebcfb1
BLAKE2b-256 bb850db53bb41ab7dba7acc565d23c893dcb577796fbc1477d47b78e92fef827


File details

Details for the file deep_unfolding-0.2.0-py3-none-any.whl.


File hashes

Hashes for deep_unfolding-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 0f7b474d1fce0ed80af4d5365d63352a3e9c40976cdd7b9d3eaf76e40b81f789
MD5 38ac9159c0d5d71605c304143f49f7c1
BLAKE2b-256 eb370350b4df739491a1bfdd183fa28d7a5d41bda34473daaa9c259206a88647

