Gradient Accumulation Optimizer Wrapper for Keras/TensorFlow
GA Optimizer
Info
GA Optimizer is a wrapper for Keras/TensorFlow optimizers to enable gradient accumulation, allowing you to simulate larger batch sizes than your hardware can handle.
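Why can accumulation stand in for a larger batch? Here is a framework-free sketch that does not use GA Optimizer itself. When the loss is averaged over samples, averaging the gradients of equally sized micro-batches produces the same gradient as one pass over the full batch. The toy linear model, the data, and the function names are assumptions made for illustration.

```python
# Framework-free sketch of the idea behind gradient accumulation.
# Toy model: y_hat = w * x, loss = mean squared error over a batch.

def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)**2) with respect to w."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n

def accumulated_grad(w, xs, ys, micro_batch_size):
    """Average the gradients of equally sized micro-batches."""
    grads = []
    for start in range(0, len(xs), micro_batch_size):
        end = start + micro_batch_size
        grads.append(mse_grad(w, xs[start:end], ys[start:end]))
    return sum(grads) / len(grads)

xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
ys = [1.1, 1.9, 3.2, 3.9, 5.1, 6.0, 7.2, 7.9]
w = 0.3

full = mse_grad(w, xs, ys)              # one "large" batch of 8 samples
accum = accumulated_grad(w, xs, ys, 2)  # four micro-batches of 2 samples
print(full, accum)                      # equal up to floating-point rounding
```

The equality relies on every micro-batch having the same size. If the last micro-batch were smaller, a plain average of micro-batch gradients would weight its samples more heavily.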
Getting Started
Installation
Pip
You can install the package using pip:
pip install ga-optimizer
Clone repo
Alternatively, you can clone the repository directly from GitHub and install it manually:
git clone https://github.com/kimjansheden/GAOptimizer.git
cd GAOptimizer
pip install .
Usage
Here's an example of how to use GA Optimizer in your TensorFlow/Keras project:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from ga_optimizer import make_ga_optimizer
# Assumed import path for the class that defines the LOG_* constants
from ga_optimizer.optimizers import Optimizer as Ga_Optimizer

# Micro-batch size that fits in memory; reused in model.fit below
batch_size = 8

# Define a simple model
model = tf.keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=(784,)),
    layers.Dense(10, activation='softmax')
])

# Define loss and metrics
loss = tf.keras.losses.SparseCategoricalCrossentropy()
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]

# Define base optimizer
base_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Wrap the optimizer with GA Optimizer: 64 / 8 = 8 micro-batches per weight update
ga_optimizer = make_ga_optimizer(
    desired_batch_size=64,
    batch_size=batch_size,
    base_optimizer=base_optimizer,
    log_level=Ga_Optimizer.LOG_PARANOID
)

# Compile your model with the GA optimizer
model.compile(optimizer=ga_optimizer, loss=loss, metrics=metrics)

# Generate some dummy data
x_train = np.random.random((1000, 784))
y_train = np.random.randint(10, size=(1000,))

# Train the model with the same micro-batch size that was given to make_ga_optimizer
model.fit(x_train, y_train, epochs=10, batch_size=batch_size)
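For a sense of scale, the following plain arithmetic works out what the usage example implies. It makes no calls into ga_optimizer. The line counting optimizer updates across all epochs assumes the wrapper keeps a global step counter, so leftover micro-batches at the end of an epoch carry over into the next one. The page does not say how the package handles this.

```python
# Back-of-the-envelope numbers for the usage example (plain arithmetic only).
num_samples = 1000
epochs = 10
batch_size = 8            # micro-batch actually fed to the model
desired_batch_size = 64   # effective batch size being simulated

accumulation_steps = desired_batch_size // batch_size   # micro-batches per update
batches_per_epoch = -(-num_samples // batch_size)       # ceiling division, as Keras counts steps
total_micro_batches = batches_per_epoch * epochs

# Assumption: the step counter is global, so partial accumulations span epochs.
optimizer_updates = total_micro_batches // accumulation_steps

print(accumulation_steps, batches_per_epoch, total_micro_batches, optimizer_updates)
```

Under that assumption, the model sees 1250 micro-batches but its weights change only 156 times. This is worth keeping in mind when you compare learning-rate schedules with a non-accumulating run.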
Functions and Classes
make_ga_optimizer
The make_ga_optimizer function wraps a given Keras/TensorFlow optimizer to enable gradient accumulation.
Arguments
- desired_batch_size (int): The effective batch size you want to simulate.
- batch_size (int): The actual batch size that your hardware can handle.
- base_optimizer (tf.keras.optimizers.Optimizer): The base optimizer to wrap.
- base_optimizer_params (dict, optional): Parameters for the base optimizer. Only needed if you passed parameters to your base optimizer and you are on a Mac, where the optimizer gets converted to the legacy Keras optimizer. Defaults to None.
- log_level (int, optional): Logging level. Defaults to optimizers.Optimizer.LOG_NONE.
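The number of micro-batches accumulated per update follows from the first two arguments. The page does not say what happens when desired_batch_size is not a multiple of batch_size. The hypothetical helper below simply rejects that case. It is a sketch of a sensible check, not the package's own validation.

```python
def accumulation_steps(desired_batch_size: int, batch_size: int) -> int:
    """Hypothetical helper: micro-batches to accumulate before each update."""
    if desired_batch_size <= 0 or batch_size <= 0:
        raise ValueError("batch sizes must be positive")
    if desired_batch_size % batch_size != 0:
        # Assumption: reject sizes that would leave a partial micro-batch.
        raise ValueError(
            f"desired_batch_size ({desired_batch_size}) is not a multiple of "
            f"batch_size ({batch_size})"
        )
    return desired_batch_size // batch_size

print(accumulation_steps(64, 8))  # 8 micro-batches of 8 samples per update
```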
Optimizer Class
This class is a wrapper for Keras optimizers to support gradient accumulation.
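The page does not show the class's internals. The toy, pure-Python sketch below shows what supporting gradient accumulation generally means for a wrapper: gradients are summed over several calls, and only the average is handed to the base optimizer once. The names SGD and AccumulatingOptimizer and the simplified apply_gradients(grads, params) signature are hypothetical. Keras's own apply_gradients takes (gradient, variable) pairs instead.

```python
class SGD:
    """Toy base optimizer: params <- params - learning_rate * grads."""
    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def apply_gradients(self, grads, params):
        return [p - self.learning_rate * g for p, g in zip(params, grads)]


class AccumulatingOptimizer:
    """Hypothetical wrapper: sums gradients for `steps` calls, then lets the
    base optimizer apply their average once and resets the accumulator."""
    def __init__(self, base_optimizer, steps):
        self.base_optimizer = base_optimizer
        self.steps = steps
        self._accum = None
        self._count = 0

    def apply_gradients(self, grads, params):
        if self._accum is None:
            self._accum = [0.0] * len(grads)
        self._accum = [a + g for a, g in zip(self._accum, grads)]
        self._count += 1
        if self._count < self.steps:
            return params  # still accumulating: weights unchanged
        mean_grads = [a / self.steps for a in self._accum]
        self._accum = [0.0] * len(grads)
        self._count = 0
        return self.base_optimizer.apply_gradients(mean_grads, params)


opt = AccumulatingOptimizer(SGD(learning_rate=0.1), steps=4)
params = [1.0, -2.0]
for grads in ([1.0, 0.0], [2.0, 4.0], [3.0, 0.0], [2.0, 0.0]):
    params = opt.apply_gradients(grads, params)
print(params)  # one SGD step with the mean gradient [2.0, 1.0]
```

Dividing by `steps` keeps the update on the same scale as a single large-batch step. Without it, the effective learning rate would grow with the number of accumulated micro-batches.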
Logging Levels
The GA Optimizer supports the following logging levels:
- LOG_NONE: No logs.
- LOG_INFO: Informational messages.
- LOG_DEBUG: Debug messages, more verbose.
- LOG_PARANOID: Paranoid debug messages, extremely verbose.
- LOG_EXTREMELY_PARANOID: Extremely paranoid debug messages, the most verbose level.
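The levels are listed in order of increasing verbosity. A common way to implement such a scheme is a threshold check, sketched below. The integer values and the should_log function are assumptions for illustration. The real constants live on the package's Optimizer class, and their values and filtering logic may differ.

```python
# Hypothetical stand-ins for the documented level names (values assumed).
LOG_NONE, LOG_INFO, LOG_DEBUG, LOG_PARANOID, LOG_EXTREMELY_PARANOID = range(5)

def should_log(message_level: int, configured_level: int) -> bool:
    """Emit a message only if its level is enabled by the configured level."""
    return message_level != LOG_NONE and message_level <= configured_level

configured = LOG_DEBUG
for name, level in [("info", LOG_INFO), ("debug", LOG_DEBUG), ("paranoid", LOG_PARANOID)]:
    print(name, should_log(level, configured))
```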
Tests
To run the tests:
python -m unittest discover -s tests
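By default, unittest discovery collects modules whose names match test*.py in the start directory, so a new test only needs to follow that naming. The module below is a hypothetical example (tests/test_micro_batches.py is an assumed file name). It checks the property that accumulation relies on, using a small helper defined inline rather than the package's code.

```python
# tests/test_micro_batches.py  (hypothetical file name matching "test*.py")
import unittest


def mean(values):
    return sum(values) / len(values)


class MicroBatchAveragingTest(unittest.TestCase):
    def test_equal_micro_batches_match_full_batch(self):
        data = [float(i) for i in range(16)]
        micro_means = [mean(data[i:i + 4]) for i in range(0, 16, 4)]
        self.assertAlmostEqual(mean(micro_means), mean(data))

    def test_unequal_micro_batches_do_not_match(self):
        data = [0.0, 1.0, 2.0, 3.0]
        micro_means = [mean(data[:3]), mean(data[3:])]  # sizes 3 and 1
        self.assertNotAlmostEqual(mean(micro_means), mean(data))
```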
License
This project is licensed under the MIT License. See the LICENSE file for details.
Contributions
Contributions are welcome! Please open an issue or submit a pull request on GitHub.
Acknowledgements
I would like to acknowledge and thank the developers of the RunAI GA Optimizer project, which served as the foundation for this work. Their innovative approach to gradient accumulation inspired and enabled the development of this extended version.
Download files
File details
Details for the file ga-optimizer-0.1.1.tar.gz.
File metadata
- Download URL: ga-optimizer-0.1.1.tar.gz
- Upload date:
- Size: 15.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.10.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 67eeeabab96f88a8ea89e5e625ab5d15e1c602b74b5515f23f8a62ca64504f32
MD5 | fb7967920269abf37064c9b1c9878ca7
BLAKE2b-256 | c984081115176f922cf72b4f3a7c3e88d91db1f9c2160ac07b6213c695441d45
File details
Details for the file ga_optimizer-0.1.1-py3-none-any.whl.
File metadata
- Download URL: ga_optimizer-0.1.1-py3-none-any.whl
- Upload date:
- Size: 19.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.10.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 3f212f41b5e1b48562cdaceddf68b7edff8f693f261974cf708f875d6d1e7848
MD5 | 59078cd8ce55f2ead9cc0f694b8c6dd0
BLAKE2b-256 | 522814b0e2adba5e537eb5f6279597c19cb4e37a860238e8a8a082ed653a20fc
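To check a downloaded file against the SHA256 digests listed above, you can stream it through Python's standard hashlib module. This is a minimal sketch. The function names are illustrative, and the expected digest is the one listed for ga-optimizer-0.1.1.tar.gz.

```python
import hashlib

EXPECTED_SDIST_SHA256 = "67eeeabab96f88a8ea89e5e625ab5d15e1c602b74b5515f23f8a62ca64504f32"

def sha256_of(path, chunk_size=1 << 16):
    """Stream a file through SHA-256 in chunks and return the hex digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify(path, expected=EXPECTED_SDIST_SHA256):
    """True if the file's SHA-256 digest matches the expected hex string."""
    return sha256_of(path) == expected.lower()

# Example: verify("ga-optimizer-0.1.1.tar.gz") returns True for an intact download.
```

pip's hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file) performs the same check automatically during installation.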