Skip to main content

Gradient Accumulation Optimizer Wrapper for Keras/TensorFlow

Project description

GA Optimizer

Info

GA Optimizer is a wrapper for Keras/TensorFlow optimizers to enable gradient accumulation, allowing you to simulate larger batch sizes than your hardware can handle.

Getting Started

Installation

Pip

You can install the package using pip:

pip install ga-optimizer

Clone repo

Alternatively, you can clone the repository directly from GitHub and install it manually:

git clone https://github.com/kimjansheden/GAOptimizer.git
cd GAOptimizer
pip install .

Usage

Here's an example of how to use GA Optimizer in your TensorFlow/Keras project:

import tensorflow as tf
from tensorflow.keras import layers
from ga_optimizer import make_ga_optimizer

# Define a simple model
model = tf.keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=(784,)),
    layers.Dense(10, activation='softmax')
])

# Define loss and metrics
loss = tf.keras.losses.SparseCategoricalCrossentropy()
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]

# Define base optimizer
base_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Wrap the optimizer with GA Optimizer
ga_optimizer = make_ga_optimizer(
    desired_batch_size=64,
    batch_size=8,
    base_optimizer=base_optimizer,
    log_level=Ga_Optimizer.LOG_PARANOID
)

# Compile your model with GA optimizer
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

# Generate some dummy data
import numpy as np
x_train = np.random.random((1000, 784))
y_train = np.random.randint(10, size=(1000,))

# Train the model
model.fit(x_train, y_train, epochs=10, batch_size=batch_size)

Functions and Classes

make_ga_optimizer

The make_ga_optimizer function wraps a given Keras/TensorFlow optimizer to enable gradient accumulation.

Arguments

  • desired_batch_size (int): The effective batch size you want to simulate.
  • batch_size (int): The actual batch size that your hardware can handle.
  • base_optimizer (tf.keras.optimizers.Optimizer): The base optimizer to wrap.
  • base_optimizer_params (dict, optional): Parameters for the base optimizer. Only needed if you have any params in your base_optimizer and you're on a Mac where optimizer gets converted to legacy. Defaults to None.
  • log_level (int, optional): Logging level. Defaults to optimizers.Optimizer.LOG_NONE.

Optimizer Class

This class is a wrapper for Keras optimizers to support gradient accumulation.

Logging Levels

The GA Optimizer supports different logging levels:

  • LOG_NONE: No logs.
  • LOG_INFO: Informational messages.
  • LOG_DEBUG: Debug messages, more verbose.
  • LOG_PARANOID: Paranoid debug messages, extremely verbose.
  • LOG_EXTREMELY_PARANOID: Extremely paranoid debug messages, extremely, extremely verbose.

Tests

To run the tests:

python -m unittest discover -s tests

License

This project is licensed under the MIT License. See the LICENSE file for details.

Contributions

Contributions are welcome! Please open an issue or submit a pull request on GitHub.

Acknowledgements

I would like to acknowledge and thank the developers of the RunAI GA Optimizer project, which served as the foundation for this work. Their innovative approach to gradient accumulation inspired and enabled the development of this extended version.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

ga-optimizer-0.1.1.tar.gz (15.0 kB view details)

Uploaded Source

Built Distribution

ga_optimizer-0.1.1-py3-none-any.whl (19.3 kB view details)

Uploaded Python 3

File details

Details for the file ga-optimizer-0.1.1.tar.gz.

File metadata

  • Download URL: ga-optimizer-0.1.1.tar.gz
  • Upload date:
  • Size: 15.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.7

File hashes

Hashes for ga-optimizer-0.1.1.tar.gz
Algorithm Hash digest
SHA256 67eeeabab96f88a8ea89e5e625ab5d15e1c602b74b5515f23f8a62ca64504f32
MD5 fb7967920269abf37064c9b1c9878ca7
BLAKE2b-256 c984081115176f922cf72b4f3a7c3e88d91db1f9c2160ac07b6213c695441d45

See more details on using hashes here.

File details

Details for the file ga_optimizer-0.1.1-py3-none-any.whl.

File metadata

File hashes

Hashes for ga_optimizer-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 3f212f41b5e1b48562cdaceddf68b7edff8f693f261974cf708f875d6d1e7848
MD5 59078cd8ce55f2ead9cc0f694b8c6dd0
BLAKE2b-256 522814b0e2adba5e537eb5f6279597c19cb4e37a860238e8a8a082ed653a20fc

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page