
JAX Logical Neural Networks – neuro-symbolic framework with interval Łukasiewicz logic

Project description

JLNN – JAX Logical Neural Networks


JLNN Logo

Neuro-symbolic framework for interval-based Łukasiewicz logic built on JAX + Flax NNX.

License: MIT Python 3.10+ Open In Colab

JLNN compiles symbolic logical rules into differentiable neural networks that can be trained on data while preserving interpretability and logical consistency.

Features

  • Interval truth values [L, U]: Full support for uncertainty and contradiction modeling.
  • Weighted logical gates: AND, OR, XOR, Implication, and NOT with Łukasiewicz semantics.
  • Symbolic Compiler: Compile formulas like "0.8::A & B -> C" directly to NNX graphs.
  • Temporal Logic: Experimental support for temporal operators (G, F, X).
  • Logical Constraints: Built-in enforcement of axioms (e.g., weights $w \geq 1.0$).
  • High Performance: JIT-compilation and hardware acceleration via JAX.
  • Interoperability: Export trained models to ONNX, StableHLO, or PyTorch.
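To make the Łukasiewicz semantics above concrete, here is a minimal, self-contained sketch of the (unweighted) operators lifted to interval truth values [L, U], in plain Python. This illustrates the math only; it is not JLNN's internal implementation, which additionally handles weights (with the $w \geq 1.0$ constraint), batching, and JAX arrays.

```python
# Unweighted Lukasiewicz operators on interval truth values (L, U).
# Illustrative only -- not the library's actual implementation.

def luk_and(a, b):
    # Conjunction max(0, x + y - 1), applied bound-wise (monotone in both args)
    return (max(0.0, a[0] + b[0] - 1.0), max(0.0, a[1] + b[1] - 1.0))

def luk_or(a, b):
    # Disjunction min(1, x + y), also bound-wise
    return (min(1.0, a[0] + b[0]), min(1.0, a[1] + b[1]))

def luk_not(a):
    # Negation 1 - x flips AND swaps the bounds
    return (1.0 - a[1], 1.0 - a[0])

def luk_imp(a, b):
    # Implication min(1, 1 - x + y); antitone in the antecedent,
    # so each bound pairs the opposite bound of a with the same bound of b
    return (min(1.0, 1.0 - a[1] + b[0]), min(1.0, 1.0 - a[0] + b[1]))

A = (0.9, 0.9)   # near-crisp fact
B = (0.6, 0.8)   # uncertain fact (wider interval)
print("A & B  =", luk_and(A, B))
print("A -> B =", luk_imp(A, B))
# An interval with L > U (e.g. (0.9, 0.2)) encodes a contradiction.
```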

Installation

# From PyPI
pip install jax-lnn

# From GitHub
pip install git+https://github.com/RadimKozl/JLNN.git

# For development
git clone https://github.com/RadimKozl/JLNN.git
cd JLNN
uv sync  # or pip install -e ".[test]"

Quickstart

import jax.numpy as jnp
import optax
from flax import nnx

from jlnn.symbolic.compiler import LNNFormula
from jlnn.nn.constraints import apply_constraints
from jlnn.training.losses import total_lnn_loss

# 1. Define and compile the formula
model = LNNFormula("0.8::A & B -> C", nnx.Rngs(42))

# 2. Ground inputs (including initial state for C)
inputs = {
    "A": jnp.array([[0.9]]),
    "B": jnp.array([[0.7]]),
    "C": jnp.array([[0.5]])   # MANDATORY – consequent must have grounding!
}

target = jnp.array([[0.6, 0.85]])

# 3. Loss function
def loss_fn(model, inputs, target):
    pred = model(inputs)
    pred = jnp.nan_to_num(pred, nan=0.5, posinf=1.0, neginf=0.0)  # protection against NaN
    return total_lnn_loss(pred, target)

# 4. Initialize Optimizer
optimizer = nnx.Optimizer(
    model,
    wrt=nnx.Param,
    tx=optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate=0.001)
    )
)

# 5. Training Step
@nnx.jit
def train_step(model, optimizer, inputs, target):
    # Loss and gradients in a single pass; the closure is traceable
    # (inputs/target are arrays captured by value)
    loss, grads = nnx.value_and_grad(lambda m: loss_fn(m, inputs, target))(model)

    optimizer.update(model, grads)
    apply_constraints(model)

    # Loss and prediction after the update and constraint projection
    final_loss = loss_fn(model, inputs, target)
    final_pred = model(inputs)

    return loss, final_loss, final_pred

print("=== Starting training ===")
steps = 50
for step in range(steps):
    loss, final_loss, pred = train_step(model, optimizer, inputs, target)

    print(f"Step {step:3d} | Loss before/after constraints: {loss:.6f}{final_loss:.6f}")
    print(f"Prediction: {pred}")
    print("─" * 60)

    if jnp.isnan(final_loss).any():
        print("❌ NaN detected! Stopping.")
        break

print("=== Training completed ===")

# 6. Result after training

final_pred = model(inputs)
print("\nFinal prediction after training:")
print(final_pred)

print(f"\nTarget interval: {target}")
print(f"Final loss: {total_lnn_loss(final_pred, target):.6f}")

See the introductory Jupyter notebook: Jax_lnn_base.ipynb

Acknowledgments & Inspiration

JLNN is inspired by and builds upon the foundations laid by several pioneering neuro-symbolic projects:

  • LNN (IBM Research) – The primary inspiration for interval-based logical neural networks.
  • LTNtorch – Logic Tensor Networks implementation in PyTorch.
  • PyReason – Software for open-world temporal logic reasoning.

Documentation

Discord channel

Discord link: https://discord.gg/ADehdYCM


License

This project is licensed under the MIT License - see the LICENSE file for details.

Download files

Download the file for your platform.

Source Distribution

jax_lnn-0.1.1.post1.tar.gz (42.0 kB)

Uploaded Source

Built Distribution


jax_lnn-0.1.1.post1-py3-none-any.whl (53.8 kB)

Uploaded Python 3

File details

Details for the file jax_lnn-0.1.1.post1.tar.gz.

File metadata

  • Download URL: jax_lnn-0.1.1.post1.tar.gz
  • Upload date:
  • Size: 42.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jax_lnn-0.1.1.post1.tar.gz:

  • SHA256: eaa4e6c84494dc5d3558954087612f98bc97c788a424f03d4b2479efb84fa754
  • MD5: 38af19add18ef526dcc54fb2f1e15356
  • BLAKE2b-256: 6ee72674cb012f104845ebc548f86c4a283e026ac090f2f0702213ede61aa995


Provenance

The following attestation bundles were made for jax_lnn-0.1.1.post1.tar.gz:

Publisher: publish.yml on RadimKozl/JLNN

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jax_lnn-0.1.1.post1-py3-none-any.whl.

File metadata

  • Download URL: jax_lnn-0.1.1.post1-py3-none-any.whl
  • Upload date:
  • Size: 53.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jax_lnn-0.1.1.post1-py3-none-any.whl:

  • SHA256: 2919d18fa50920cf745854cc28ef1c001eac2c14cbe0e3d95d8a9d7159b8814d
  • MD5: ed1e6f5c179bb15b36f970946ae07373
  • BLAKE2b-256: f7661db4455dab8f23ff9f389fe9d250683ee4e1ff1b8a8b9390f3471c0ab247


Provenance

The following attestation bundles were made for jax_lnn-0.1.1.post1-py3-none-any.whl:

Publisher: publish.yml on RadimKozl/JLNN

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
