A gradient descent optimizer for fitting multivariate nonlinear curves to data

Project description

Multivariate Nonlinear Gradient Descent Curve Fitting

This package provides a multivariate nonlinear curve fitter that uses gradient descent to minimize the mean squared error. It is designed for complex curve-fitting problems where the relationship between the input data and the target variable is nonlinear. The package computes numerical gradients by default, but users can supply an analytical gradient function for more complex models to improve performance and precision.

Features

  • Flexible Model Definition: Users can define their own model functions and, optionally, gradient functions.
  • Gradient Descent Optimization: Automatically uses numerical gradients but can switch to analytical gradients when specified.
  • Data Scaling: Optional data scaling for more stable optimization.
  • Customizable Learning Rate and Parameters: Users can adjust the learning rate, decay factor, and other parameters to fine-tune the optimization process.

Installation

To install the package, use the following command:

pip install curvefit_gd

You can then import the FunctionFitter class:

from curvefit_gd import FunctionFitter

How it works

The optimization process follows these steps:

1. Specify the Model Function:

The model function defines the relationship between your input data and the target output. An example model could be a combination of exponential and quadratic terms.

Example Model Structure:

$$f(x_1, x_2) = c_0 \cdot e^{x_1} + (1 + c_1 \cdot x_2^2)$$

import numpy as np

def model_function(x, coefficients):
    # f(x1, x2) = c0 * e^{x1} + (1 + c1 * x2^2)
    x1, x2 = x[0], x[1]
    return coefficients[0] * np.exp(x1) + (1 + coefficients[1] * x2**2)

Here x1 and x2 are the input features, and c0 and c1 are the coefficients to be learned.

2. Optionally Specify the Gradient Function:

For more complex models, you may want to provide your own gradient function to improve optimization accuracy.

Example Gradient Terms:

To minimize the loss function using gradient descent, the gradients of the model with respect to the coefficients are calculated as follows:

The gradient with respect to $c_0$ is:

$$\frac{\partial f(x_1, x_2)}{\partial c_0} = e^{x_1}$$

The gradient with respect to $c_1$ is:

$$\frac{\partial f(x_1, x_2)}{\partial c_1} = x_2^2$$

def gradient_terms(x_data, coefficients):
    # Partial derivatives of the model with respect to c0 and c1
    term1 = np.exp(x_data[0])
    term2 = x_data[1]**2
    return np.array([term1, term2])
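When no gradient function is supplied, the optimizer falls back to numerical gradients. A minimal sketch of how such a fallback can work, using central finite differences (this is an illustration of the general technique, not the package's internal implementation):

```python
import numpy as np

def model_function(x, coefficients):
    # Example model from above: f = c0 * e^{x1} + (1 + c1 * x2^2)
    return coefficients[0] * np.exp(x[0]) + (1 + coefficients[1] * x[1]**2)

def numerical_gradient_terms(model_func, x, coefficients, eps=1e-6):
    # Central differences: perturb one coefficient at a time and
    # approximate df/dc_i from the resulting change in the model output.
    coefficients = np.asarray(coefficients, dtype=float)
    grads = []
    for i in range(coefficients.size):
        step = np.zeros_like(coefficients)
        step[i] = eps
        grads.append((model_func(x, coefficients + step)
                      - model_func(x, coefficients - step)) / (2 * eps))
    return np.array(grads)

# For this model the result should match the analytical terms e^{x1} and x2^2
x = np.array([0.5, 2.0])
approx = numerical_gradient_terms(model_function, x, np.array([1.0, 1.0]))
```

Numerical gradients cost one or two extra model evaluations per coefficient per iteration, which is why an analytical gradient function pays off for larger models.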

3. Fit the Model:

Provide your data and call fit(); the model fits the curve to the data using gradient descent, adjusting the coefficients to minimize the error.

4. Predict New Values:

After training, use the predict() method to generate predictions based on new input data.
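The optimization loop behind the steps above can be sketched in plain NumPy (an illustration of the algorithm with synthetic data, not the package's internal code):

```python
import numpy as np

def model_function(x, coefficients):
    # Example model from above: f = c0 * e^{x1} + (1 + c1 * x2^2)
    return coefficients[0] * np.exp(x[0]) + (1 + coefficients[1] * x[1]**2)

def gradient_terms(x, coefficients):
    # Analytical partial derivatives of the model w.r.t. c0 and c1
    return np.array([np.exp(x[0]), x[1]**2])

# Synthetic data generated from known coefficients c0=2, c1=3
rng = np.random.default_rng(0)
x_data = rng.uniform(0.0, 1.0, size=(2, 50))   # two features, 50 samples
y_data = model_function(x_data, [2.0, 3.0])

# Plain gradient descent on the mean squared error
coeffs = np.zeros(2)
learning_rate = 1e-2
for _ in range(30000):
    residual = model_function(x_data, coeffs) - y_data   # shape (50,)
    grads = gradient_terms(x_data, coeffs)               # shape (2, 50)
    # dMSE/dc = (2/n) * sum(residual * df/dc)
    coeffs -= learning_rate * (2.0 / y_data.size) * (grads @ residual)
```

On noise-free data generated from the model itself, the loop recovers the true coefficients; prediction on new inputs is then just another call to the model function with the fitted coefficients.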

Scaling Data

It is highly recommended to scale your data for better stability in the optimization process. Keep the following in mind:

  • If you scale your data, you must apply the same scaling parameters to any future input data, ideally through the predict() method.
  • If you do not scale your data, the coefficients will be easier to interpret, but the optimization may be less stable.
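The key point, that prediction-time inputs must be transformed with the training-time scaling parameters, can be illustrated in plain NumPy (hypothetical data; the package applies its own scaling internally when enabled):

```python
import numpy as np

# Standardize each feature using statistics from the *training* data only.
x_train = np.array([[1.0, 2.0, 3.0, 4.0],
                    [10.0, 20.0, 30.0, 40.0]])   # shape (features, samples)

mu = x_train.mean(axis=1, keepdims=True)
sigma = x_train.std(axis=1, keepdims=True)
x_train_scaled = (x_train - mu) / sigma

# New inputs at prediction time MUST reuse the same mu and sigma,
# not statistics recomputed from the new data.
x_new = np.array([[2.5], [25.0]])
x_new_scaled = (x_new - mu) / sigma
```

Recomputing the statistics on the new inputs would map them to a different region of feature space than the one the coefficients were fitted in, silently corrupting predictions.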

Class and Methods

FunctionFitter

The primary class for performing curve fitting using gradient descent.

Constructor

FunctionFitter(model_func, learning_rate=1e-3, decay_factor=0, max_iterations=100000, user_gradients=None, error_tolerance=1e-5, gradient_tolerance=1e-5)

Parameters

  • model_func (callable): The function defining the relationship between input data and target values. (required)

  • learning_rate (float): The initial learning rate for the optimizer. Default is 1e-3. (optional)

  • decay_factor (float): Decay factor for the learning rate over iterations. Default is 0. (optional)

  • max_iterations (int): Maximum number of iterations to perform during optimization. Default is 100,000. (optional)

  • user_gradients (callable): A user-defined function for calculating gradients. Default is None (numerical gradients will be used). (optional)

  • error_tolerance (float): The threshold for convergence based on error reduction. Default is 1e-5. (optional)

  • gradient_tolerance (float): The threshold for convergence based on the gradient's norm. Default is 1e-5. (optional)


fit(x_data, y_data)

Fit the model to the input data using gradient descent.

  • x_data (numpy.ndarray): Input data (features).

  • y_data (numpy.ndarray): Target values.


predict(x_data)

Generate predictions using the optimized model.

  • x_data (numpy.ndarray): Input data (features).

get_coefficients()

Return the optimized coefficients after fitting the model.


get_error()

Return the final error (mean squared error) after the optimization process.

Example Usage

  1. Define your model function.
  2. (Optional) Define your gradient function for complex models.
  3. Fit the model to your data.
  4. Use the trained model to make predictions.

For a detailed example, refer to the gradient_optimizer_example.py file in this repository.

Recommendations

  • Data Scaling: We highly recommend scaling your data for better optimization stability. Note that if you scale, the learned coefficients apply to scaled inputs; the predict() method accounts for this when scaling is used.

  • Model Tuning: Start with a basic set of parameters (learning rate, decay, etc.) and adjust based on model performance.

License

This project is licensed under the MIT License.

Project details


Download files


Source Distribution

curvefit_gd-1.3.0.tar.gz (6.4 kB)

Uploaded Source

Built Distribution

curvefit_gd-1.3.0-py3-none-any.whl (7.1 kB)

Uploaded Python 3

File details

Details for the file curvefit_gd-1.3.0.tar.gz.

File metadata

  • Download URL: curvefit_gd-1.3.0.tar.gz
  • Upload date:
  • Size: 6.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.19

File hashes

Hashes for curvefit_gd-1.3.0.tar.gz

  • SHA256: d9251edb272fa0121c9b6667262f2daf9c908552579c3bfb9a69dd4afd24fefe
  • MD5: 09f4495c8ca8d110261c6563c3a774a7
  • BLAKE2b-256: 9ef58c7c6fdad2500d884a27650f7d6150abb908192e2dd40e9b8ea47ad1422f


File details

Details for the file curvefit_gd-1.3.0-py3-none-any.whl.

File metadata

  • Download URL: curvefit_gd-1.3.0-py3-none-any.whl
  • Upload date:
  • Size: 7.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.19

File hashes

Hashes for curvefit_gd-1.3.0-py3-none-any.whl

  • SHA256: 4d14ec0c78f53823835377f2cfcc1d5d4e5a2bb88420004c9a5b492fc44fe87e
  • MD5: bd40f30ad104407a1f53e1adaf2ac8b5
  • BLAKE2b-256: ead4769feab84cc8f1c92ba1d9cce9c10abdffae06ed629dbf557dfcca74d2e1

