PyTorch-like hooks for TensorFlow Keras layers
tensorflow-hooks: PyTorch-like hooks for TensorFlow Keras layers
One of PyTorch's many strengths is its torch.nn.Module hooks.
Inspired by this issue,
this utility aims to provide functionality similar to PyTorch's forward pre-hooks and forward hooks for TensorFlow Keras layers.
Note: Backward hooks are not supported, and there are no plans to support them at this time.
Prerequisites
TensorFlow must be installed.
tensorflow-hooks is tested with TensorFlow 2.14.1 and above.
Installation
Install with pip:
pip install tensorflow-hooks
or clone this repo and run:
pip install .
Using a Forward Pre-hook
A forward pre-hook must be a callable with the signature:
fn(layer: tf.keras.layers.Layer, args: tuple, kwargs: dict) -> Union[None, Tuple[tuple, dict]]
The hook executes during the layer's forward method, before the forward pass itself. It either returns None, in which case the inputs pass through unchanged (or as modified in place), or returns a (tuple, dict) pair containing the arguments and keyword arguments that the layer will receive.
tf_hooks.register_forward_pre_hook registers the hook and modifies the layer's call method.
For example:
import tensorflow as tf
from tf_hooks import register_forward_pre_hook

model = tf.keras.applications.resnet50.ResNet50()

# Pre-hook that prints each layer's inputs and returns None (passthrough).
def prehook_fn(layer: tf.keras.layers.Layer, args: tuple, kwargs: dict):
    print(f"{layer.name} args received: {args}")
    print(f"{layer.name} kwargs received: {kwargs}")

# Register the pre-hook on every layer, keeping the handles for later removal.
prehooks = []
for layer in model.layers:
    prehooks.append(register_forward_pre_hook(layer, prehook_fn))

test_input = tf.random.uniform((4, 224, 224, 3))
test_output = model(test_input)
The above would result in printing out all the inputs seen by each layer. If the received arguments / keyword arguments were modified, or new ones provided, this would affect layer computation.
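To make the return contract concrete, here is a minimal pure-Python sketch of how a pre-hook's return value could be interpreted. It uses neither TensorFlow nor tensorflow-hooks: ToyLayer and run_with_pre_hooks are hypothetical stand-ins written for illustration, not the library's implementation. Only the hook signature (layer, args, kwargs) and the None versus (args, kwargs) return convention come from the description above.

```python
class ToyLayer:
    """Hypothetical stand-in for a Keras layer (not part of tensorflow-hooks)."""

    name = "toy"

    def call(self, x, scale=1.0):
        return x * scale


def run_with_pre_hooks(layer, pre_hooks, *args, **kwargs):
    """Assumed dispatch logic: run each pre-hook in order, then the layer."""
    for hook in pre_hooks:
        result = hook(layer, args, kwargs)
        if result is not None:
            # The hook supplied a replacement (args, kwargs) pair.
            args, kwargs = result
    return layer.call(*args, **kwargs)


def log_only(layer, args, kwargs):
    # Implicitly returns None, so the inputs pass through unchanged.
    print(f"{layer.name} received {args} {kwargs}")


def double_scale(layer, args, kwargs):
    # Returns a new (args, kwargs) pair to change what the layer receives.
    new_kwargs = dict(kwargs, scale=kwargs.get("scale", 1.0) * 2)
    return args, new_kwargs


layer = ToyLayer()
passthrough = run_with_pre_hooks(layer, [log_only], 3.0, scale=1.0)
modified = run_with_pre_hooks(layer, [log_only, double_scale], 3.0, scale=1.0)
```

In this sketch, log_only leaves the computation untouched, while double_scale changes the keyword argument the toy layer sees and therefore its result.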
Each item in the prehooks list above is a tf_hooks.hooks.TFForwardPreHook.
To remove a hook, use prehook.remove(). For example, to remove all the hooks applied above:
for prehook in prehooks:
    prehook.remove()
Notes:
- Multiple pre-hooks can be applied to the same layer by calling register_forward_pre_hook again; each pre-hook executes in the order it was registered for the layer. To prepend a pre-hook, pass prepend=True to register_forward_pre_hook.
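The ordering, prepend, and removal behaviour described above can be pictured with a small pure-Python sketch. Everything here (HookHandle, register, the list-based registry) is hypothetical and only mimics the documented behaviour; how tensorflow-hooks stores hooks internally is not specified in this document.

```python
class HookHandle:
    """Hypothetical removable handle (not the library's own class)."""

    def __init__(self, registry, fn):
        self._registry = registry
        self._fn = fn

    def remove(self):
        # Calling remove() twice is harmless in this sketch.
        if self._fn in self._registry:
            self._registry.remove(self._fn)


def register(registry, fn, prepend=False):
    """Assumed semantics: append by default, insert at the front if prepend."""
    if prepend:
        registry.insert(0, fn)
    else:
        registry.append(fn)
    return HookHandle(registry, fn)


calls = []
registry = []

h_a = register(registry, lambda: calls.append("a"))
h_b = register(registry, lambda: calls.append("b"))
h_c = register(registry, lambda: calls.append("c"), prepend=True)

# Run every registered hook in order; "c" runs first because it was prepended.
for fn in list(registry):
    fn()
order_before = list(calls)

# Remove one hook and run again; the remaining hooks keep their order.
h_a.remove()
calls.clear()
for fn in list(registry):
    fn()
order_after = list(calls)
```

Returning a handle from registration, rather than asking the caller to pass the function back in, keeps removal simple even when the same function is registered on many layers.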
Using a Forward Hook
A forward hook must be a callable with the signature:
fn(layer: tf.keras.layers.Layer, args: tuple, kwargs: dict, outputs: Union[tf.Tensor, tuple]) -> Union[None, tf.Tensor, tuple]
The hook executes after the layer's forward pass and receives the layer's inputs and outputs. It either returns None, in which case the outputs pass through unchanged (or as modified in place), or returns replacement outputs, which are provided to the next layer(s).
tf_hooks.register_forward_hook registers the hook and modifies the layer's call method.
For example:
import tensorflow as tf
from tf_hooks import register_forward_hook
from typing import Union

model = tf.keras.applications.resnet50.ResNet50()

# Hook that prints each layer's inputs and outputs and returns None (passthrough).
def hook_fn(layer: tf.keras.layers.Layer, args: tuple, kwargs: dict, outputs: Union[tf.Tensor, tuple]):
    print(f"{layer.name} args received: {args}")
    print(f"{layer.name} kwargs received: {kwargs}")
    print(f"{layer.name} outputs: {outputs}")

# Register the hook on every layer, keeping the handles for later removal.
hooks = []
for layer in model.layers:
    hooks.append(register_forward_hook(layer, hook_fn))

test_input = tf.random.uniform((4, 224, 224, 3))
test_output = model(test_input)
The above would result in printing out all layers' inputs and outputs. If the received outputs were modified, or new ones provided, this would affect downstream layers.
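As a rough illustration of the two return modes, the following pure-Python sketch captures outputs with one hook and replaces them with another. ToyLayer and run_with_hooks are hypothetical stand-ins and do not reflect how tensorflow-hooks is implemented; only the hook signature (layer, args, kwargs, outputs) and the None versus replacement-output convention are taken from the description above.

```python
class ToyLayer:
    """Hypothetical stand-in for a Keras layer (not part of tensorflow-hooks)."""

    def __init__(self, name):
        self.name = name

    def call(self, x):
        return x + 1


def run_with_hooks(layer, hooks, *args, **kwargs):
    """Assumed dispatch logic: run the layer, then each forward hook in order."""
    outputs = layer.call(*args, **kwargs)
    for hook in hooks:
        result = hook(layer, args, kwargs, outputs)
        if result is not None:
            # A non-None return value replaces the layer's outputs.
            outputs = result
    return outputs


captured = {}


def capture(layer, args, kwargs, outputs):
    # Record the outputs and return None so they pass through unchanged.
    captured[layer.name] = outputs


def clamp_to_zero(layer, args, kwargs, outputs):
    # Return a replacement value that downstream layers would see instead.
    return max(outputs, 0)


layer = ToyLayer("toy")
unchanged = run_with_hooks(layer, [capture], -5)
replaced = run_with_hooks(layer, [capture, clamp_to_zero], -5)
```

The sketch checks `result is not None` rather than truthiness, so a hook can legitimately return a falsy replacement such as 0.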
Each item in the hooks list above is a tf_hooks.hooks.TFForwardHook.
To remove a hook, use hook.remove(). For example, to remove all the hooks applied above:
for hook in hooks:
    hook.remove()
Notes:
- Multiple hooks can be applied to the same layer by calling register_forward_hook again; each hook executes in the order it was registered for the layer. To prepend a hook, pass prepend=True to register_forward_hook.
- Should you want the hook to always be called even if an exception occurs, pass always_call=True to register_forward_hook.
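One way to picture always_call is the sketch below, which assumes semantics similar to PyTorch's flag of the same name: if the forward pass raises, only hooks registered with always_call still run, they receive None as the outputs, and the exception is re-raised. Whether tensorflow-hooks behaves exactly this way is an assumption; run_with_hooks and the (fn, always_call) pairs are hypothetical and simplified.

```python
def run_with_hooks(call, hooks, x):
    """Assumed semantics, modelled on PyTorch's always_call flag.

    `hooks` is a list of (fn, always_call) pairs; this is not the
    library's implementation.
    """
    try:
        outputs = call(x)
    except Exception:
        # The forward pass failed: only always_call hooks run, with no outputs.
        for fn, always_call in hooks:
            if always_call:
                fn(x, None)
        raise
    for fn, _ in hooks:
        fn(x, outputs)
    return outputs


def failing_call(x):
    raise ValueError("forward pass failed")


seen = []
hooks = [
    (lambda x, out: seen.append("normal"), False),
    (lambda x, out: seen.append("always"), True),
]

try:
    run_with_hooks(failing_call, hooks, 1)
except ValueError:
    error_propagated = True
else:
    error_propagated = False
```

Such a hook is mainly useful for cleanup or logging, since it cannot rely on the outputs being available.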