
PyTorch-like hooks for TensorFlow Keras layers


tensorflow-hooks: PyTorch-like hooks for TensorFlow Keras layers

One of PyTorch's many strengths is its torch.nn.Module hooks. Inspired by this issue, this utility aims to provide functionality similar to PyTorch's forward pre-hooks and forward hooks for TensorFlow Keras layers.

Note: Backward hooks are not supported, and there are no plans to support them at this time.

Prerequisites

TensorFlow must be installed. tensorflow-hooks is tested with TensorFlow 2.14.1 and above.

Installation

Install via pip:

pip install tensorflow-hooks

or clone this repository and, from its root directory, run

pip install .

Using a Forward Pre-hook

A forward pre-hook callable must be:

fn(layer: tf.keras.layers.Layer, args: tuple, kwargs: dict) -> Union[None, Tuple[tuple, dict]]

When the layer is called, the hook executes before the layer's forward pass. The hook returns either None, to pass the inputs through unchanged (or after modifying them in place), or an (args, kwargs) tuple giving the positional and keyword arguments the layer will receive instead.
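To illustrate the two return modes, the sketch below defines one pre-hook that only inspects its inputs and one that overrides a keyword argument. The hook names are hypothetical, and the second hook assumes the layer accepts a `training` keyword argument (as Keras layers such as BatchNormalization do).

```python
from types import SimpleNamespace
from typing import Optional, Tuple


def log_shapes_prehook(layer, args: tuple, kwargs: dict) -> None:
    # Inspect only; returning None leaves args and kwargs untouched.
    shapes = [getattr(a, "shape", None) for a in args]
    print(f"{layer.name}: positional input shapes {shapes}")
    return None


def force_inference_prehook(layer, args: tuple, kwargs: dict) -> Optional[Tuple[tuple, dict]]:
    # Return a new (args, kwargs) pair; the layer receives these instead.
    # Assumption: the hooked layer accepts a `training` keyword argument.
    new_kwargs = dict(kwargs)  # copy so the caller's dict is not mutated
    new_kwargs["training"] = False
    return args, new_kwargs


if __name__ == "__main__":
    dummy_layer = SimpleNamespace(name="bn")  # stand-in for a Keras layer
    print(force_inference_prehook(dummy_layer, (1.0,), {"training": True}))
```

Either function can be registered with register_forward_pre_hook(layer, fn).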

The function tf_hooks.register_forward_pre_hook registers the hook and modifies the layer's call method so that the hook runs on every call.
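The mechanism of wrapping a call method can be pictured with a minimal plain-Python sketch that needs no TensorFlow. It is an assumption about the general approach, not tf_hooks' actual implementation; ToyLayer and wrap_call_with_pre_hooks are hypothetical names.

```python
from typing import Callable, List


class ToyLayer:
    """Hypothetical stand-in for a Keras layer."""

    name = "toy"

    def call(self, x, scale=1.0):
        return x * scale


def wrap_call_with_pre_hooks(layer, hooks: List[Callable]) -> None:
    original_call = layer.call  # bound method of the unwrapped layer

    def hooked_call(*args, **kwargs):
        for hook in hooks:
            result = hook(layer, args, kwargs)
            if result is not None:
                # The hook supplied replacement inputs.
                args, kwargs = result
        return original_call(*args, **kwargs)

    # Shadow the class-level method with the wrapper on this instance only.
    layer.call = hooked_call


layer = ToyLayer()
hooks = []
wrap_call_with_pre_hooks(layer, hooks)
hooks.append(lambda lyr, args, kwargs: (args, {**kwargs, "scale": 3.0}))
print(layer.call(2.0))  # 6.0
```

Because the wrapper reads the hooks list at call time, hooks can be added to or removed from the list after wrapping.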

For example:

import tensorflow as tf
from tf_hooks import register_forward_pre_hook

model = tf.keras.applications.resnet50.ResNet50()

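def prehook_fn(layer: tf.keras.layers.Layer, args: tuple, kwargs: dict):
    # Log the inputs; returning None (implicitly) leaves them unchanged.
    print(f"{layer.name} args received: {args}")
    print(f"{layer.name} kwargs received: {kwargs}")

# Keep each returned handle so the pre-hooks can be removed later.
prehooks = []
for layer in model.layers:
    prehooks.append(register_forward_pre_hook(layer, prehook_fn))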

test_input = tf.random.uniform((4, 224, 224, 3))

test_output = model(test_input)

This prints the positional and keyword arguments received by each layer. If a pre-hook modified the received arguments or keyword arguments, or returned new ones, the layer's computation would change accordingly.
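For example, a pre-hook that rescales a layer's first positional input changes what that layer computes. The factory below is a hypothetical helper; it relies only on the documented return convention and on the input supporting multiplication, so it applies to tf.Tensor inputs as well as plain numbers.

```python
from typing import Callable, Optional, Tuple


def make_scaling_prehook(factor: float) -> Callable:
    """Return a pre-hook that multiplies the first positional input by `factor`."""

    def scaling_prehook(layer, args: tuple, kwargs: dict) -> Optional[Tuple[tuple, dict]]:
        if not args:
            return None  # nothing positional to scale; pass inputs through
        # Build a new args tuple rather than mutating the received one.
        return (args[0] * factor,) + args[1:], kwargs

    return scaling_prehook


halve = make_scaling_prehook(0.5)
print(halve(None, (8.0, "extra"), {"training": False}))
# ((4.0, 'extra'), {'training': False})
```

Registering it with register_forward_pre_hook(layer, make_scaling_prehook(0.5)) would halve that layer's first input on every call.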

Each item in the prehooks list above is a tf_hooks.hooks.TFForwardPreHook. To remove a pre-hook, call its remove() method. For example, to remove all the pre-hooks applied above:

for prehook in prehooks:
    prehook.remove()
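When hooks are only needed for a single forward pass, pairing registration with removal in a context manager ensures the hooks are removed even if the pass raises. temporary_hooks is a hypothetical helper built only on the documented register_*(layer, fn) call and the remove() method of the returned handle.

```python
from contextlib import contextmanager
from typing import Callable, Iterable, Iterator, List


@contextmanager
def temporary_hooks(layers: Iterable, register: Callable, hook_fn: Callable) -> Iterator[List]:
    handles = []
    try:
        # Register inside the try so a failure partway still cleans up.
        for layer in layers:
            handles.append(register(layer, hook_fn))
        yield handles
    finally:
        # Runs on normal exit and when an exception propagates.
        for handle in handles:
            handle.remove()
```

With tf_hooks, this could be used as `with temporary_hooks(model.layers, register_forward_pre_hook, prehook_fn): model(test_input)`, and register_forward_hook can be passed in the same way.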

Notes:

  • Multiple pre-hooks can be applied to the same layer by calling register_forward_pre_hook again; pre-hooks execute in the order they were registered for that layer. To prepend a pre-hook so it runs before those already registered, pass prepend=True to register_forward_pre_hook.
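The ordering rule can be pictured as a per-layer list in which prepend=True inserts at the front. The sketch assumes, as in PyTorch, that each pre-hook receives the (args, kwargs) returned by the one before it; whether tf_hooks chains results exactly this way is an assumption, and HookList is a hypothetical name.

```python
from typing import Callable, List


class HookList:
    """Hypothetical ordered container for one layer's pre-hooks."""

    def __init__(self) -> None:
        self._hooks: List[Callable] = []

    def add(self, hook: Callable, prepend: bool = False) -> None:
        if prepend:
            self._hooks.insert(0, hook)  # runs before all existing hooks
        else:
            self._hooks.append(hook)  # runs after all existing hooks

    def run(self, layer, args: tuple, kwargs: dict):
        for hook in self._hooks:
            result = hook(layer, args, kwargs)
            if result is not None:
                args, kwargs = result  # later hooks see the replaced inputs
        return args, kwargs


order = []
hooks = HookList()
hooks.add(lambda layer, a, k: order.append("first"))
hooks.add(lambda layer, a, k: order.append("second"))
hooks.add(lambda layer, a, k: order.append("prepended"), prepend=True)
hooks.run(None, (), {})
print(order)  # ['prepended', 'first', 'second']
```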

Using a Forward Hook

A forward hook callable must be:

fn(layer: tf.keras.layers.Layer, args: tuple, kwargs: dict, outputs: Union[tf.Tensor, tuple]) -> Union[None, tf.Tensor, tuple]

After the layer's forward pass, the hook executes with access to the layer's inputs and outputs. The hook returns either None, to pass the outputs through unchanged (or after modifying them in place), or replacement outputs; whatever the hook returns is provided to the next layer(s) in place of the original outputs.
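Because outputs may be a single tensor or a tuple, a hook that replaces outputs should handle both cases. make_output_transform_hook is a hypothetical helper; the transform is supplied by the caller (for instance tf.nn.relu, or a lambda that scales its input).

```python
from typing import Any, Callable


def make_output_transform_hook(transform: Callable[[Any], Any]) -> Callable:
    """Return a forward hook that applies `transform` to each layer output."""

    def output_transform_hook(layer, args: tuple, kwargs: dict, outputs):
        if isinstance(outputs, tuple):
            # Multi-output layer: transform each element, keep the tuple shape.
            return tuple(transform(o) for o in outputs)
        # Single output: the returned value replaces the layer's output.
        return transform(outputs)

    return output_transform_hook


double = make_output_transform_hook(lambda t: t * 2)
print(double(None, (), {}, 3))       # 6
print(double(None, (), {}, (1, 2)))  # (2, 4)
```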

The function tf_hooks.register_forward_hook registers the hook and modifies the layer's call method so that the hook runs on every call.

For example:

import tensorflow as tf
from tf_hooks import register_forward_hook
from typing import Union

model = tf.keras.applications.resnet50.ResNet50()

def hook_fn(layer: tf.keras.layers.Layer, args: tuple, kwargs: dict, outputs: Union[tf.Tensor, tuple]):
    # Log inputs and outputs; returning None (implicitly) leaves outputs unchanged.
    print(f"{layer.name} args received: {args}")
    print(f"{layer.name} kwargs received: {kwargs}")
    print(f"{layer.name} outputs: {outputs}")

# Keep each returned handle so the hooks can be removed later.
hooks = []
for layer in model.layers:
    hooks.append(register_forward_hook(layer, hook_fn))

test_input = tf.random.uniform((4, 224, 224, 3))

test_output = model(test_input)

This prints every layer's inputs and outputs. If a hook modified the received outputs, or returned new ones, the computation of downstream layers would change accordingly.
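A common use of forward hooks is collecting intermediate activations without changing them. The sketch below returns None so outputs pass through unchanged; make_capture_hook and the activations dict are hypothetical names.

```python
from types import SimpleNamespace
from typing import Callable, Dict


def make_capture_hook(store: Dict[str, object]) -> Callable:
    """Return a forward hook that records each layer's outputs by layer name."""

    def capture_hook(layer, args: tuple, kwargs: dict, outputs) -> None:
        store[layer.name] = outputs  # keep a reference; outputs are not modified
        return None  # passthrough: downstream layers see the original outputs

    return capture_hook


activations: Dict[str, object] = {}
hook = make_capture_hook(activations)
hook(SimpleNamespace(name="conv1"), (), {}, 42)  # stand-in layer and output
print(activations)  # {'conv1': 42}
```

Registering make_capture_hook(activations) on each layer in model.layers via register_forward_hook and then calling model(test_input) would fill activations keyed by layer name; a layer called more than once keeps only its last output.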

Each item in the hooks list above is a tf_hooks.hooks.TFForwardHook. To remove a hook, call its remove() method. For example, to remove all the hooks applied above:

for hook in hooks:
    hook.remove()

Notes:

  • Multiple hooks can be applied to the same layer by calling register_forward_hook again; hooks execute in the order they were registered for that layer. To prepend a hook so it runs before those already registered, pass prepend=True to register_forward_hook.
  • To have a hook called even if an exception occurs, pass always_call=True to register_forward_hook.
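The effect of always_call can be sketched with a try/except around the forward pass. Following PyTorch's behavior, the sketch assumes an always-call hook receives None as outputs when the forward pass raised, and that the original exception is re-raised afterward; how tf_hooks handles this exactly is an assumption, and run_with_forward_hooks is a hypothetical name.

```python
from typing import Callable, List, Tuple


def run_with_forward_hooks(layer, forward: Callable, hooks: List[Tuple[Callable, bool]], *args, **kwargs):
    """Call forward, then apply (hook, always_call) pairs in order."""
    try:
        outputs = forward(*args, **kwargs)
    except Exception:
        # Only hooks registered with always_call=True run on failure.
        for hook, always_call in hooks:
            if always_call:
                hook(layer, args, kwargs, None)
        raise  # the original exception still propagates
    for hook, _ in hooks:
        result = hook(layer, args, kwargs, outputs)
        if result is not None:
            outputs = result  # replacement outputs flow to later hooks
    return outputs


calls = []
hooks = [
    (lambda l, a, k, out: calls.append(("normal", out)), False),
    (lambda l, a, k, out: calls.append(("always", out)), True),
]


def failing_forward(x):
    raise ValueError("bad input")


try:
    run_with_forward_hooks(None, failing_forward, hooks, 1.0)
except ValueError:
    pass
print(calls)  # [('always', None)]
```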


Download files


Source Distribution

tensorflow-hooks-1.0.0.tar.gz (11.0 kB)

Built Distribution

tensorflow_hooks-1.0.0-py3-none-any.whl (8.7 kB, Python 3)

File details

Details for the file tensorflow-hooks-1.0.0.tar.gz.

File metadata

  • Download URL: tensorflow-hooks-1.0.0.tar.gz
  • Size: 11.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.11.8

File hashes

Hashes for tensorflow-hooks-1.0.0.tar.gz
Algorithm Hash digest
SHA256 0a89b9091482a9ff461d05b2a5194897bd6bb26f2ec04f514f0da9d10ca10d2a
MD5 ac7c92f3858ff0c038d83e2e66946bd0
BLAKE2b-256 3b628f0a6f451641c773fffb65c71210b25c9ae6b9d97615701ae3de19b416fa


File details

Details for the file tensorflow_hooks-1.0.0-py3-none-any.whl.

File hashes

Hashes for tensorflow_hooks-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 6897e1cd303a7cc458f28eeb3d48d90ca7b46d9e9b5da3563da1d709dc2cf68b
MD5 8d718ecd7e110f58b3b6e2a1c9eac1f4
BLAKE2b-256 01b9e05713cc199970663bbe225bf849763633e366d8fb9618c1be922127c016

