
HAMUX

A new class of Deep Learning models built around ENERGY.

HAMUX Logo

Part proof-of-concept, part functional prototype, HAMUX is designed to bridge modern AI architectures and Hopfield Networks.

HAMUX: A Hierarchical Associative Memory User eXperience.

See the documentation. HAMUX is also described in our tutorial on Associative Memory (ch. 3).

Quickstart

This codebase is intentionally minimal and is not a complete component library for associative memory. It consists of one main logic file (hamux/core.py, under 200 lines of core code), one file of example Lagrangians (hamux/lagrangians.py), and one demo notebook (demo.ipynb). It is intended as research code.

A Universal Abstraction for Hopfield Networks

HAMUX fully captures the energy fundamentals of Hopfield Networks and enables anyone to:

  • 🧠 Build DEEP Hopfield nets

  • 🧱 With modular ENERGY components

  • 🏆 That resemble modern DL operations

Every architecture built using HAMUX is a formal Associative Memory (AM). That is, the architecture defines a tractable energy, whose minimization describes a dynamical system that is guaranteed to converge to a fixed point. Hierarchical Associative Memories (HAMs) have several additional advantages over traditional Hopfield Networks (HNs):

| Hopfield Networks (HNs) | Hierarchical Associative Memories (HAMs) |
| --- | --- |
| HNs are only two-layer systems | HAMs connect any number of layers |
| HNs have limited storage capacity | HAMs describe Associative Memories with much denser storage |
| HNs model only simple relationships between layers | HAMs model any complex but differentiable operation (e.g., convolutions, pooling, attention, ...) |
| HNs use only pairwise synapses | HAMs can also use many-body synapses (which we denote HyperSynapses) |
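
Every HAMUX model combines these pieces into one global energy. As a rough sketch (our notation, following the hierarchical associative memory literature; the exact conventions in hamux/core.py may differ), each layer $\ell$ carries states $x_\ell$, a convex Lagrangian $L_\ell$, and activations $\hat{x}_\ell := \nabla_{x_\ell} L_\ell(x_\ell)$, and the total energy is

$$
E \;=\; \sum_{\ell \in \text{layers}} \Big[ x_\ell \cdot \hat{x}_\ell - L_\ell(x_\ell) \Big] \;+\; \sum_{s \in \text{synapses}} E_s\big(\hat{x}_{\ell_1}, \ldots, \hat{x}_{\ell_k}\big).
$$

The states then descend this energy,

$$
\tau \, \frac{d x_\ell}{d t} \;=\; -\frac{\partial E}{\partial \hat{x}_\ell},
$$

and with convex Lagrangians $E$ is non-increasing along these dynamics, which is what guarantees convergence to a fixed point.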

How does HAMUX work?

HAMUX is a hypergraph of 🌀neurons connected via 🤝hypersynapses, an abstraction sufficiently general to model any conceivable Associative Memory.

See our walkthrough in this notebook for a more detailed explanation of how everything works.

In summary, this library handles all the complexity of scaling modular, learnable energy functions that interconnect many layers and hypersynapses. It is a barebones framework to explore Associative Memories that look like Deep Learning architectures.

  1. Implement your favorite Deep Learning operations as HyperSynapses.
  2. Port over your favorite activation functions as Lagrangians (see the sketch below).
  3. Connect layers and hypersynapses into a hypergraph with a single total energy.
  4. Use autograd to descend that energy over the states.

All of this is made possible by JAX and equinox.
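
Step 2 works because an activation function is just the gradient of a scalar Lagrangian, so autograd recovers it for free. A minimal, self-contained sketch (lagr_softmax_sketch below is our own illustration, not hamux.lagr_softmax itself, though the library's version is built on the same idea):

import jax, jax.numpy as jnp
from jax.scipy.special import logsumexp

def lagr_softmax_sketch(x, beta: float = 1.0):
    """A Lagrangian whose gradient is the softmax activation."""
    return logsumexp(beta * x) / beta

x = jnp.arange(3.0)
activation = jax.grad(lagr_softmax_sketch)(x)
print(jnp.allclose(activation, jax.nn.softmax(x)))  # True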

Installation

Install the latest version from the GitHub repo:

pip install git+https://github.com/bhoov/hamux.git

or from PyPI:

pip install hamux

How to use

import hamux as hmx
import jax, jax.numpy as jnp, jax.random as jr, jax.tree_util as jtu
import equinox as eqx
from typing import *
import matplotlib.pyplot as plt
class LinearSynapse(eqx.Module):
    """The energy synapse corrolary of the linear layer in standard neural networks"""
    W: jax.Array
    def __call__(self, xhat1:jax.Array, xhat2:jax.Array):
        return xhat1 @ self.W @ xhat2

    @classmethod
    def rand_init(cls, key: jax.Array, D1: int, D2: int):
        Winit = 0.02 * jr.normal(key, (D1, D2))
        return cls(W=Winit)

key = jr.key(0)
nhid = 9    # hidden layer size
nlabel = 8  # label layer size
ninput = 7  # input layer size

# Neuron layers pair a Lagrangian (whose gradient is the activation function)
# with a state shape.
neurons = {
    "input": hmx.NeuronLayer(hmx.lagr_identity, (ninput,)),
    "labels": hmx.NeuronLayer(hmx.lagr_softmax, (nlabel,)),
    "hidden": hmx.NeuronLayer(hmx.lagr_softmax, (nhid,)),
}

# Synapses are learnable energy terms (the key is reused here only for brevity).
synapses = {
    "dense1": LinearSynapse.rand_init(key, ninput, nhid),
    "dense2": LinearSynapse.rand_init(key, nlabel, nhid),
}

# Connections declare which named layers each synapse's energy reads from.
connections = [
    (("input", "hidden"), "dense1"),
    (("labels", "hidden"), "dense2"),
]

ham = hmx.HAM(neurons, synapses, connections)
xs = ham.init_states()       # initial neuron states (no batch dimension)
xhats = ham.activations(xs)  # activations = gradients of each layer's Lagrangian

ham.energy_tree(xhats, xs)   # per-layer and per-synapse energy terms
ham.energy(xhats, xs)        # total scalar energy
ham.dEdact(xhats, xs)        # gradient of the energy w.r.t. the activations
{'hidden': Array([ 0.00705259,  0.00320656, -0.02189678,  0.00424237,  0.00248319,
        -0.00192548,  0.00498188,  0.00388546,  0.00415148], dtype=float32),
 'input': Array([ 0.00667361,  0.00921866,  0.00110246, -0.00476699, -0.0013505 ,
        -0.00371795, -0.00420904], dtype=float32),
 'labels': Array([ 0.00667361,  0.00921866,  0.00110246, -0.00476699, -0.0013505 ,
        -0.00371795, -0.00420904,  0.00254421], dtype=float32)}
vham = ham.vectorize()     # batched version of the same model
bs = 3
xs = vham.init_states(bs)  # batch size 3
xhats = vham.activations(xs)

print(vham.energy_tree(xhats, xs))
print(vham.energy(xhats, xs))
print(vham.dEdact(xhats, xs))

ham = vham.unvectorize()   # recover the unbatched model
{'connections': [Array([0., 0., 0.], dtype=float32), Array([0.00068681, 0.00068681, 0.00068681], dtype=float32)], 'neurons': {'hidden': Array([-2.1972246, -2.1972246, -2.1972246], dtype=float32), 'input': Array([0., 0., 0.], dtype=float32), 'labels': Array([-2.0794415, -2.0794415, -2.0794415], dtype=float32)}}
[-4.275979 -4.275979 -4.275979]
{'hidden': Array([[ 0.00705259,  0.00320656, -0.02189678,  0.00424237,  0.00248319,
        -0.00192548,  0.00498188,  0.00388546,  0.00415148],
       [ 0.00705259,  0.00320656, -0.02189678,  0.00424237,  0.00248319,
        -0.00192548,  0.00498188,  0.00388546,  0.00415148],
       [ 0.00705259,  0.00320656, -0.02189678,  0.00424237,  0.00248319,
        -0.00192548,  0.00498188,  0.00388546,  0.00415148]],      dtype=float32), 'input': Array([[ 0.00667361,  0.00921866,  0.00110246, -0.00476699, -0.0013505 ,
        -0.00371795, -0.00420904],
       [ 0.00667361,  0.00921866,  0.00110246, -0.00476699, -0.0013505 ,
        -0.00371795, -0.00420904],
       [ 0.00667361,  0.00921866,  0.00110246, -0.00476699, -0.0013505 ,
        -0.00371795, -0.00420904]], dtype=float32), 'labels': Array([[ 0.00667361,  0.00921866,  0.00110246, -0.00476699, -0.0013505 ,
        -0.00371795, -0.00420904,  0.00254421],
       [ 0.00667361,  0.00921866,  0.00110246, -0.00476699, -0.0013505 ,
        -0.00371795, -0.00420904,  0.00254421],
       [ 0.00667361,  0.00921866,  0.00110246, -0.00476699, -0.0013505 ,
        -0.00371795, -0.00420904,  0.00254421]], dtype=float32)}
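
As the printed values suggest, energy_tree decomposes the total energy into per-neuron and per-connection terms, and summing its leaves recovers energy. A quick sanity check (our own, not part of the library's documented API):

# Sum every leaf of the energy tree and compare against the total energy.
tree_total = jtu.tree_reduce(lambda a, b: a + b, vham.energy_tree(xhats, xs))
print(jnp.allclose(tree_total, vham.energy(xhats, xs)))  # True, matching the values above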

We can check that energy descent from randomly initialized states behaves as expected (the energy should decrease at every step).

# Randomize the initial states
rkey = jr.key(42)
key_tree = dict(zip(xs.keys(), jr.split(rkey, len(xs))))
xs = jtu.tree_map(lambda key, x: jr.normal(key, x.shape), key_tree, xs)
xhats = vham.activations(xs)

nsteps = 40
step_size = 0.5
energies = jnp.empty((nsteps, bs))
for i in range(nsteps):
    energy, dEdxhats = vham.dEdact(xhats, xs, return_energy=True)
    energies = energies.at[i].set(energy)
    xs = jtu.tree_map(lambda x, u: x - step_size * u, xs, dEdxhats)  # descend the states
    xhats = vham.activations(xs)  # recompute activations from the updated states

plt.plot(jnp.arange(nsteps), energies)
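
For longer descents the step can typically be compiled. A sketch using equinox's filter_jit (this assumes the vectorized HAM object is a jit-friendly pytree, which we have not verified here):

@eqx.filter_jit
def descent_step(vham, xs, step_size):
    """One gradient-descent step on the neuron states."""
    xhats = vham.activations(xs)
    energy, dEdxhats = vham.dEdact(xhats, xs, return_energy=True)
    new_xs = jtu.tree_map(lambda x, u: x - step_size * u, xs, dEdxhats)
    return new_xs, energy

for i in range(nsteps):
    xs, energy = descent_step(vham, xs, step_size)
    energies = energies.at[i].set(energy)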

See the nbs/_examples directory for more examples.
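
The same recipe extends to the many-body HyperSynapses mentioned above. Here is a sketch of a three-layer synapse written in the style of LinearSynapse (the class itself, and the assumption that HAM accepts three layer names in a connection, are ours and not verified against the library):

class ThreeWaySynapse(eqx.Module):
    """A many-body energy synapse coupling three layers at once."""
    W: jax.Array  # shape (D1, D2, D3)

    def __call__(self, xhat1: jax.Array, xhat2: jax.Array, xhat3: jax.Array):
        # Trilinear energy, mirroring the bilinear form of LinearSynapse.
        return jnp.einsum("i,j,k,ijk->", xhat1, xhat2, xhat3, self.W)

    @classmethod
    def rand_init(cls, key: jax.Array, D1: int, D2: int, D3: int):
        return cls(W=0.02 * jr.normal(key, (D1, D2, D3)))

If HAM treats the tuple of layer names positionally, registering it would look like adding (("input", "hidden", "labels"), "triple") to connections.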

Developing locally

uv (via pyproject.toml) manages all dependencies; nbdev (via settings.ini) handles all packaging. scripts/sync_dependencies.py keeps pyproject.toml and settings.ini in sync.

[!WARNING]

Package is currently based on a fork of nbdev that allows development in plain text .qmd files.

Prerequisite: install uv.

uv sync
uv run uv pip install -e .

# OPTIONAL: add GPU-enabled JAX, e.g., for CUDA 12
uv run uv pip install -U "jax[cuda12]"

source .venv/bin/activate
nbdev_prepare

# Make changes to source files in `nbs/`.
uv run nbdev_prepare  # Before committing changes: export and test the library
uv run nbdev_preview  # Preview the docs

VSCode for development: automatic library export

Never let your .qmd source get out of sync with your .py library.

VSCode has an excellent interactive mode for developing quarto files. We install the Run on Save extension to keep the .qmd files in sync with the .py library, removing the need for explicit nbdev_export commands.

To accomplish this, copy and paste the following into your user or workspace settings (Cmd+Shift+P, then "Preferences: Open User Settings (JSON)" or "Preferences: Open Workspace Settings (JSON)"):

{
    "files.watcherExclude": {
        "**/.git/objects/**": true,
        "**/.git/subtree-cache/**": true,
        "**/node_modules/*/**": true,
        "**/.hg/store/**": true,
    },
    "emeraldwalk.runonsave": {
        "commands": [
        {
            "match": "nbs/.*\\.qmd$", // Replace with your own nbs/ directory
            "cmd": "source ${workspaceFolder}/.venv/bin/activate && nbdev_export", // Replace with a path to your python env where `nbdev` is installed
        }
        ]
    }
}
