
Universal Approximation Theorem in JAX


uat


```shell
pip install uat
```

Universal Approximation Theorem

The Universal Approximation Theorem states that a feedforward neural network with a single hidden layer containing a finite number of neurons can approximate any continuous function on a compact subset of $\mathbb{R}^n$ to any desired accuracy, provided the activation function is suitable.
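Written out for the one-hidden-layer networks used in this package: let $K \subset \mathbb{R}^n$ be compact, $g : K \to \mathbb{R}$ continuous, and $\varepsilon > 0$. Then, for a suitable activation $\sigma$, there exist $N$ and parameters $a_i \in \mathbb{R}^n$, $b_i, c_i \in \mathbb{R}$ such that

$$\sup_{x \in K} \left| g(x) - \sum_{i=1}^{N} c_i \sigma(a_i \cdot x + b_i) \right| < \varepsilon.$$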

General Formula

The general formula for a neural network with one hidden layer can be expressed as:

$$f(x) = \sum_{i=1}^{N} c_i \sigma(a_i \cdot x + b_i)$$

where:

  • $x \in \mathbb{R}^n$ is the input vector.
  • $N$ is the number of neurons in the hidden layer.
  • $a_i \in \mathbb{R}^n$ and $b_i \in \mathbb{R}$ are the weight vector and bias of hidden neuron $i$.
  • $c_i \in \mathbb{R}$ is the output-layer weight of hidden neuron $i$.
  • $\sigma$ is the activation function (e.g., sigmoid, tanh).
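A minimal sketch of this formula for a single input vector, assuming the sigmoid as $\sigma$ and stacking the $a_i$ as the columns of a matrix (the same layout the package documents for its own parameters further down). The function name f and the numeric values are illustrative only and are not part of uat.

```python
import jax.numpy as jnp
from jax.nn import sigmoid

def f(x, a, b, c):
    """Evaluate f(x) = sum_i c_i * sigmoid(a_i . x + b_i) for one input vector.

    x: shape (n,)    input vector
    a: shape (n, N)  column i is the weight vector a_i
    b: shape (N,)    b_i is the bias of hidden neuron i
    c: shape (N,)    c_i is the output weight of hidden neuron i
    """
    hidden = sigmoid(x @ a + b)  # shape (N,): sigma(a_i . x + b_i) for every i
    return jnp.sum(c * hidden)   # weighted sum over the N neurons

# Illustrative values: n = 2 inputs, N = 3 hidden neurons
x = jnp.array([1.0, 2.0])
a = jnp.array([[0.5, 2.0, 0.0],
               [-1.0, 0.0, 0.0]])
b = jnp.array([0.0, -1.0, 0.0])
c = jnp.array([1.0, 2.0, -1.0])

# Pre-activations are (-1.5, 1.0, 0.0), so
# f(x) = sigmoid(-1.5) + 2 * sigmoid(1.0) - sigmoid(0.0)
print(f(x, a, b, c))
```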

Activation Function

The Universal Approximation Theorem (UAT) hinges on the choice of activation function used in the neural network. Not all activation functions are suitable for ensuring that a neural network can approximate any continuous function on compact subsets of $\mathbb{R}^n$. Here are the key requirements and considerations for activation functions in the context of the UAT:

Requirements

  1. Non-linearity: The activation function must be non-linear. With a linear activation (such as the identity), the whole network collapses into a single affine map, which cannot approximate non-linear functions. More precisely, the activation must not be a polynomial: with a polynomial activation of degree $d$, the network can only represent polynomials of degree at most $d$.

  2. Boundedness: The classical statements of the theorem assume a bounded activation, whose output lies within a fixed range; for example, the sigmoid function outputs values in the range (0, 1). Boundedness is sufficient rather than necessary: later versions of the theorem also cover unbounded activations such as ReLU, $\max(0, x)$.

  3. Continuity: Common statements of the theorem assume a continuous activation function. Continuity also matters in practice: a discontinuous activation such as the step function has zero gradient almost everywhere, which makes gradient-based training ineffective.

  4. Non-constant: The activation function must not be constant. With a constant activation, the network output would not depend on the input at all.
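The collapse described in the first requirement can be checked numerically. With the identity as activation, $(x a + b) c = x (a c) + b c$, so the hidden layer folds into one affine layer. The sketch below uses randomly drawn parameters in the shapes documented for uat's apply function, not parameters produced by uat itself.

```python
import jax.numpy as jnp
from jax.random import PRNGKey, split, normal

# Random parameters with the documented shapes:
# a: (input_dim, neurons), b: (1, neurons), c: (neurons, output_dim)
k1, k2, k3, k4 = split(PRNGKey(0), 4)
a = normal(k1, (2, 5))
b = normal(k2, (1, 5))
c = normal(k3, (5, 1))
x = normal(k4, (10, 2))  # 10 input samples

# One-hidden-layer network with the identity as "activation"
y_net = (x @ a + b) @ c

# The same map as a single affine layer with W = a @ c and w0 = b @ c
W = a @ c
w0 = b @ c
y_affine = x @ W + w0

print(jnp.allclose(y_net, y_affine, atol=1e-4))
```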

Common Activation Functions

Here are some commonly used activation functions that satisfy the requirements of the UAT:

  • Sigmoid: $\sigma(x) = \frac{1}{1 + e^{-x}}$

    • Bounded between 0 and 1.
    • Non-linear and continuous.
  • Tanh: $\tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}$

    • Bounded between -1 and 1.
    • Non-linear and continuous.
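Both functions are available in JAX as jax.nn.sigmoid and jax.numpy.tanh. The short check below evaluates them on a grid, confirms their ranges, and verifies the standard identity $\tanh(x) = 2\sigma(2x) - 1$, which shows that tanh is a shifted and rescaled sigmoid. Note that in float32, tanh rounds to exactly ±1 for large inputs, so the range check for tanh uses non-strict bounds.

```python
import jax.numpy as jnp
from jax.nn import sigmoid

xs = jnp.linspace(-10.0, 10.0, 201)

s = sigmoid(xs)
t = jnp.tanh(xs)

# Observed ranges on the grid: sigmoid stays inside (0, 1), tanh inside [-1, 1]
print(float(s.min()), float(s.max()))
print(float(t.min()), float(t.max()))

# tanh is a shifted and rescaled sigmoid: tanh(x) = 2 * sigmoid(2x) - 1
print(jnp.allclose(t, 2.0 * sigmoid(2.0 * xs) - 1.0, atol=1e-5))
```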

Creating a Model

To create a model, use the create function, which initializes the model parameters from a PRNG key and a tuple of dimensions. Note that dims is ordered as (input_dim, output_dim, neurons).

```python
from jax.random import PRNGKey
from uat import create

key = PRNGKey(0)
input_dim = 2
output_dim = 1
neurons = 2
dims = (input_dim, output_dim, neurons)
params = create(key, dims)
```
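The README does not say how create initializes the parameters. As a hypothetical stand-in, the sketch below builds a tuple (a, b, c) with the shapes documented for apply, drawing every entry from a standard normal distribution; init_params is an illustrative name, and the real initializer in uat may use a different distribution or scaling.

```python
from jax.random import PRNGKey, split, normal

def init_params(key, dims):
    """Hypothetical stand-in for uat.create (its actual initializer may differ).

    dims follows the order used above: (input_dim, output_dim, neurons).
    Returns the tuple (a, b, c) with the shapes described for apply.
    """
    input_dim, output_dim, neurons = dims
    k_a, k_b, k_c = split(key, 3)
    a = normal(k_a, (input_dim, neurons))   # hidden-layer weights
    b = normal(k_b, (1, neurons))           # hidden-layer biases
    c = normal(k_c, (neurons, output_dim))  # output-layer weights
    return a, b, c

params = init_params(PRNGKey(0), (2, 1, 2))
print([p.shape for p in params])  # [(2, 2), (1, 2), (2, 1)]
```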

Applying the Model

To apply the model to an input, use the apply function. This function computes the output of the model given the input and the model parameters.

```python
from jax.numpy import array
from uat import apply

x = array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
output = apply(params, x)
```

Explanation of the apply Function

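The apply function computes the output of the model as follows:

```python
from jax.nn import sigmoid  # assumed: the standard logistic sigmoid from JAX

def apply(params, x):
    a, b, c = params                   # hidden weights, hidden biases, output weights
    return sigmoid(x @ a + b) @ c      # shape (n_samples, output_dim)
```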

Here's a step-by-step explanation of the formula:

  1. Parameter Unpacking: The parameters params are unpacked into three components: a, b, and c.

    • a is a matrix of shape (input_dim, neurons).
    • b is a bias vector of shape (1, neurons).
    • c is a weight matrix of shape (neurons, output_dim).
  2. Matrix Multiplication: The input x (of shape (n_samples, input_dim)) is multiplied by the matrix a using the @ operator. This results in a matrix of shape (n_samples, neurons).

  3. Bias Addition: The bias vector b is added to the result of the matrix multiplication. Broadcasting is used to add b to each row of the matrix, resulting in a matrix of shape (n_samples, neurons).

  4. Sigmoid Activation: The sigmoid function is applied element-wise to the result of the bias addition. This introduces non-linearity into the model.

  5. Output Calculation: The resulting matrix (of shape (n_samples, neurons)) is then multiplied by the weight matrix c using the @ operator. This results in the final output matrix of shape (n_samples, output_dim).
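The matrix form used by apply is a batched version of the summation formula from the General Formula section: column i of a is $a_i$, entry i of b is $b_i$, and row i of c holds $c_i$. The sketch below re-implements apply (so it runs without uat installed) and compares it against a slow loop that evaluates $\sum_i c_i \sigma(a_i \cdot x + b_i)$ one sample and one neuron at a time. The random parameters are illustrative only.

```python
import jax.numpy as jnp
from jax.nn import sigmoid
from jax.random import PRNGKey, split, normal

def apply(params, x):
    # Re-implementation of the apply function shown above
    a, b, c = params
    return sigmoid(x @ a + b) @ c

def apply_sum(params, x):
    # f(x) = sum_i c_i * sigmoid(a_i . x + b_i), one sample and neuron at a time
    a, b, c = params
    rows = []
    for sample in x:
        total = jnp.zeros(c.shape[1])
        for i in range(a.shape[1]):
            total = total + c[i] * sigmoid(jnp.dot(a[:, i], sample) + b[0, i])
        rows.append(total)
    return jnp.stack(rows)

# input_dim = 2, neurons = 3, output_dim = 1
k_a, k_b, k_c = split(PRNGKey(1), 3)
params = (normal(k_a, (2, 3)), normal(k_b, (1, 3)), normal(k_c, (3, 1)))
x = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])

y_matrix = apply(params, x)
y_loop = apply_sum(params, x)
print(y_matrix.shape)  # (4, 1)
print(jnp.allclose(y_matrix, y_loop, atol=1e-4))
```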

Training the Model on XOR

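To train the model on the XOR problem, you can use the following code. It uses the sgd optimizer from optax with momentum 0.9. Because every step uses all four XOR examples, this is full-batch gradient descent with momentum rather than stochastic gradient descent in the strict sense.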

```python
from optax import sgd, apply_updates
from jax.numpy import array, allclose, abs, mean
from jax.random import PRNGKey
from jax import grad, jit
from uat import create, apply

# Initialize parameters
key = PRNGKey(0)
input_dim = 2
output_dim = 1
neurons = 2
dims = (input_dim, output_dim, neurons)
params = create(key, dims)

# XOR input and output
x = array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
y = array([[0], [1], [1], [0]], dtype=float)

# Define optimizer
optimizer = sgd(learning_rate=0.1, momentum=0.9)
state = optimizer.init(params)

# Define loss function (mean absolute error)
def loss(params, x, y):
    y_hat = apply(params, x)
    return mean(abs(y - y_hat))

# Define training step
@jit
def fit(state, params, x, y):
    grads = grad(loss)(params, x, y)
    updates, state = optimizer.update(grads, state)
    params = apply_updates(params, updates)
    return state, params

# Train the model
for _ in range(1000):
    state, params = fit(state, params, x, y)

# Check the output
y_hat = apply(params, x)
assert allclose(y, y_hat, atol=0.1)
```

This code initializes the model parameters, defines the XOR inputs and targets, sets up the optimizer, and runs 1000 training steps. Finally, it asserts that every model output is within 0.1 (the atol tolerance) of the corresponding XOR target.

Note on Optimizer Compatibility

Since create returns a pytree (the tuple (a, b, c)), the parameters work with any JAX library that operates on pytrees, such as optax. For example, you can replace sgd with another optax optimizer such as adam without changing the model or loss code.
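Because jax.grad returns gradients with the same pytree structure as the parameters, you can also write an update step by hand with jax.tree_util. A minimal sketch, assuming params is the plain tuple (a, b, c) described above; it re-implements apply and the XOR loss so it runs without uat installed, and the random initialization is illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.nn import sigmoid
from jax.random import PRNGKey, split, normal

def apply(params, x):
    # Re-implementation of the apply function shown above
    a, b, c = params
    return sigmoid(x @ a + b) @ c

def loss(params, x, y):
    # Mean absolute error, as in the XOR example
    return jnp.mean(jnp.abs(y - apply(params, x)))

k_a, k_b, k_c = split(PRNGKey(0), 3)
params = (normal(k_a, (2, 2)), normal(k_b, (1, 2)), normal(k_c, (2, 1)))
x = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
y = jnp.array([[0.0], [1.0], [1.0], [0.0]])

grads = jax.grad(loss)(params, x, y)

# The gradient has the same pytree structure as params: a tuple of three arrays
print(jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params))

# So one plain gradient-descent step needs no optimizer library
learning_rate = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
```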

Creating Complex Approximation Models

By increasing the number of hidden neurons, you can use this library to approximate more complicated continuous functions. Keep in mind that the theorem is an existence result: it guarantees that a sufficiently wide network exists for any target accuracy, but it does not say how many neurons are needed or that gradient-based training will find the right weights.
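As an illustration, here is a sketch that fits sin on the compact interval [-π, π] with 16 hidden neurons. To stay self-contained it re-implements apply instead of importing uat, and it uses plain full-batch gradient descent on the mean squared error. The neuron count, learning rate, step count, and initialization are arbitrary choices for the example, not values recommended by the package.

```python
import jax
import jax.numpy as jnp
from jax.nn import sigmoid
from jax.random import PRNGKey, split, normal

LEARNING_RATE = 0.02  # arbitrary, small enough for stable gradient descent here

def apply(params, x):
    # Re-implementation of the apply function shown above
    a, b, c = params
    return sigmoid(x @ a + b) @ c

def mse(params, x, y):
    return jnp.mean((apply(params, x) - y) ** 2)

@jax.jit
def step(params, x, y):
    # One full-batch gradient-descent step on the mean squared error
    grads = jax.grad(mse)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# Target: sin on [-pi, pi], as a column of 200 samples (input_dim = 1)
x = jnp.linspace(-jnp.pi, jnp.pi, 200).reshape(-1, 1)
y = jnp.sin(x)

neurons = 16
k_a, k_b, k_c = split(PRNGKey(0), 3)
params = (normal(k_a, (1, neurons)), normal(k_b, (1, neurons)), 0.1 * normal(k_c, (neurons, 1)))

initial_loss = mse(params, x, y)
for _ in range(2000):
    params = step(params, x, y)
final_loss = mse(params, x, y)
print(float(initial_loss), float(final_loss))
```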

Download files


Source Distribution

uat-0.2.0.tar.gz (4.5 kB view details)

Uploaded Source

Built Distribution


uat-0.2.0-py3-none-any.whl (4.8 kB view details)

Uploaded Python 3

File details

Details for the file uat-0.2.0.tar.gz.

File metadata

  • Download URL: uat-0.2.0.tar.gz
  • Upload date:
  • Size: 4.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.9.13 Windows/10

File hashes

Hashes for uat-0.2.0.tar.gz
Algorithm Hash digest
SHA256 8eb97123ff1e4e99278d7ff62f3df972819e3a6fa263d051fe900cdcc712cf01
MD5 ee0e2ce36b46c73c49fa7193957ffd95
BLAKE2b-256 44027045e9a8337faff0e71adea61ccb0b19c4ae28f4ff609b0b43961ce584f8


File details

Details for the file uat-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: uat-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 4.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.9.13 Windows/10

File hashes

Hashes for uat-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 bc343638483864055edf456248c1abe34984db143d2c6f75bea6b49812f4c550
MD5 240f2e2d99497e4fb1993d28e4c771d0
BLAKE2b-256 0c0877aec7656f985d4f5c053d49370c9043ff23bec2f51909dcc62e8619e67e

