JAX Logical Neural Networks – neuro-symbolic framework with interval Łukasiewicz logic
JLNN – JAX Logical Neural Networks
Neuro-symbolic framework for interval-based Łukasiewicz logic built on JAX + Flax NNX.
JLNN turns symbolic logical rules into differentiable neural networks that can be trained on data while preserving interpretability and logical consistency.
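To make the interval semantics concrete, here is a minimal, unweighted sketch in plain JAX of the standard Łukasiewicz connectives lifted to intervals [L, U]. It is an illustration under our own assumptions: the helpers `iv_not`, `iv_and`, `iv_or` and `iv_implies` are hypothetical and not part of JLNN's API. Because each connective is monotone in each argument, its interval bounds follow from evaluating the scalar operator at the matching endpoints.

```python
import jax.numpy as jnp

# Hypothetical helpers (not JLNN API). An interval truth value is an array
# of shape (..., 2) holding [lower, upper] bounds in [0, 1].

def iv_not(a):
    # Negation 1 - x is decreasing, so the bounds swap: [1 - U, 1 - L].
    return jnp.stack([1.0 - a[..., 1], 1.0 - a[..., 0]], axis=-1)

def iv_and(a, b):
    # Łukasiewicz t-norm max(0, x + y - 1) is increasing in both arguments,
    # so lower bounds combine with lower bounds and upper with upper.
    return jnp.clip(a + b - 1.0, 0.0, 1.0)

def iv_or(a, b):
    # Łukasiewicz t-conorm min(1, x + y), also increasing in both arguments.
    return jnp.clip(a + b, 0.0, 1.0)

def iv_implies(a, b):
    # Łukasiewicz implication min(1, 1 - x + y) is decreasing in x and
    # increasing in y, which is exactly iv_or(iv_not(a), b).
    return iv_or(iv_not(a), b)

# The Quickstart inputs, written as degenerate (point) intervals.
A = jnp.array([0.9, 0.9])
B = jnp.array([0.7, 0.7])
C = jnp.array([0.5, 0.5])
print(iv_implies(iv_and(A, B), C))
```

This version uses the classical truth functions only; it does not model the `0.8::` weight prefix, whose exact meaning in JLNN's compiler is not described here.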
Features
- Interval truth values [L, U]: Full support for uncertainty and contradiction modeling.
- Weighted logical gates: AND, OR, XOR, Implication, and NOT with Łukasiewicz semantics.
- Symbolic Compiler: Compile formulas such as `"0.8::A & B -> C"` directly to NNX graphs.
- Temporal Logic: Experimental support for temporal operators (G, F, X).
- Logical Constraints: Built-in enforcement of axioms (e.g., weights $w \geq 1.0$).
- High Performance: JIT-compilation and hardware acceleration via JAX.
- Interoperability: Export trained models to ONNX, StableHLO, or PyTorch.
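The weighted gates and the weight constraint can be illustrated with the parameterisation from IBM's LNN, which we assume here for illustration only; JLNN's actual gate formulas, parameter names and constraint set may differ. A weighted Łukasiewicz AND computes max(0, min(1, β − Σ w_i(1 − x_i))), and since it is increasing in every x_i it can be applied separately to the lower and upper bounds. The constraint step is sketched as a projection that clips weights to w ≥ 1.0; `weighted_and` and `project_weights` are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Hypothetical weighted Łukasiewicz AND in the style of IBM's LNN
# (an assumption, not JLNN's verified implementation).
# x: (n_inputs, 2) interval bounds, w: (n_inputs,) weights, beta: scalar bias.
@jax.jit
def weighted_and(x, w, beta):
    # The gate is increasing in every input, so evaluate it per bound column.
    s = beta - jnp.sum(w[:, None] * (1.0 - x), axis=0)
    return jnp.clip(s, 0.0, 1.0)

def project_weights(w, w_min=1.0):
    # Projection onto {w >= w_min}, mirroring the "w >= 1.0" axiom above.
    return jnp.maximum(w, w_min)

x = jnp.array([[0.9, 0.9],   # A
               [0.7, 0.7]])  # B
w = project_weights(jnp.array([0.8, 1.3]))  # -> [1.0, 1.3]
print(weighted_and(x, w, 1.0))
```

With all weights equal to 1 and β = 1 the gate reduces to the unweighted Łukasiewicz AND; larger weights make the gate more sensitive to a low truth value on the corresponding input.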
Installation
```bash
# From PyPI
pip install jax-lnn

# From GitHub
pip install git+https://github.com/RadimKozl/JLNN.git

# For development
git clone https://github.com/RadimKozl/JLNN.git
cd JLNN
uv sync  # or: pip install -e ".[test]"
```
Quickstart
```python
import jax
import jax.numpy as jnp
import optax
from flax import nnx

from jlnn.symbolic.compiler import LNNFormula
from jlnn.nn.constraints import apply_constraints
from jlnn.training.losses import total_lnn_loss, logical_mse_loss, contradiction_loss
from jlnn.storage.checkpoints import save_checkpoint, load_checkpoint

# 1. Define and compile the formula
model = LNNFormula("0.8::A & B -> C", nnx.Rngs(42))

# 2. Ground inputs (including an initial state for C)
inputs = {
    "A": jnp.array([[0.9]]),
    "B": jnp.array([[0.7]]),
    "C": jnp.array([[0.5]]),  # MANDATORY: the consequent must have a grounding!
}
# Target truth interval [L, U] for the formula
target = jnp.array([[0.6, 0.85]])

# 3. Loss function
def loss_fn(model, inputs, target):
    pred = model(inputs)
    pred = jnp.nan_to_num(pred, nan=0.5, posinf=1.0, neginf=0.0)  # protection against NaN
    return total_lnn_loss(pred, target)

# 4. Initialize the optimizer (global-norm gradient clipping + Adam)
optimizer = nnx.Optimizer(
    model,
    wrt=nnx.Param,
    tx=optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate=0.001),
    ),
)

# 5. Training step
@nnx.jit
def train_step(model, optimizer, inputs, target):
    # Loss (before the update) and gradients w.r.t. the model parameters in one pass
    loss, grads = nnx.value_and_grad(loss_fn)(model, inputs, target)
    optimizer.update(model, grads)
    # Enforce the logical constraints (e.g. weights >= 1.0) after the update
    apply_constraints(model)
    final_loss = loss_fn(model, inputs, target)
    final_pred = model(inputs)
    return loss, final_loss, final_pred

print("=== Starting training ===")
steps = 50
for step in range(steps):
    loss, final_loss, pred = train_step(model, optimizer, inputs, target)
    print(f"Step {step:3d} | Loss before/after update: {loss:.6f} → {final_loss:.6f}")
    print(f"Prediction: {pred}")
    print("─" * 60)
    if jnp.isnan(final_loss).any():
        print("❌ NaN detected! Stopping.")
        break
print("=== Training completed ===")

# 6. Result after training
final_pred = model(inputs)
print("\nFinal prediction after training:")
print(final_pred)
print(f"\nTarget interval: {target}")
print(f"Final loss: {total_lnn_loss(final_pred, target):.6f}")
```
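The Quickstart relies on `total_lnn_loss` and `apply_constraints`, whose definitions are not shown here. As a hedged, self-contained illustration of the same pattern (gradient step, then constraint projection), the sketch below trains the two weights of a hypothetical weighted AND gate toward an interval target. The loss is our own assumption, not JLNN's: squared error on both bounds plus a penalty on contradictory intervals where L > U. `interval_loss`, the learning rate and the target values are illustrative choices.

```python
import jax
import jax.numpy as jnp

def weighted_and(x, w, beta):
    # Hypothetical LNN-style weighted Łukasiewicz AND on interval bounds.
    return jnp.clip(beta - jnp.sum(w[:, None] * (1.0 - x), axis=0), 0.0, 1.0)

def interval_loss(pred, target):
    # Assumed loss: MSE on [L, U] plus a penalty when L exceeds U.
    mse = jnp.mean((pred - target) ** 2)
    contradiction = jnp.maximum(0.0, pred[0] - pred[1]) ** 2
    return mse + contradiction

def loss_fn(w, x, target):
    return interval_loss(weighted_and(x, w, 1.0), target)

@jax.jit
def train_step(w, x, target):
    lr = 0.5
    loss, g = jax.value_and_grad(loss_fn)(w, x, target)
    w = w - lr * g
    # Projected gradient descent: enforce w >= 1.0 after each update.
    return jnp.maximum(w, 1.0), loss

x = jnp.array([[0.9, 0.9], [0.7, 0.7]])  # A, B as point intervals
target = jnp.array([0.45, 0.5])
w = jnp.array([1.0, 1.0])
for _ in range(200):
    w, loss = train_step(w, x, target)
print(w, weighted_and(x, w, 1.0), loss)
```

Because the inputs are point intervals, the prediction is also a point, and the best it can do against the target [0.45, 0.5] is its midpoint; the projection keeps every weight admissible without changing the gradient step itself.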
See the introductory Jupyter notebook: Jax_lnn_base.ipynb
Acknowledgments & Inspiration
JLNN is inspired by and builds upon the foundations laid by several pioneering neuro-symbolic projects:
- LNN (IBM Research) – The primary inspiration for interval-based logical neural networks.
- LTNtorch – Logic Tensor Networks implementation in PyTorch.
- PyReason – Software for open-world temporal logic reasoning.
Documentation
Discord channel
License
This project is licensed under the MIT License - see the LICENSE file for details.
Download files
File details
Details for the file jax_lnn-0.1.1.post1.tar.gz.
File metadata
- Download URL: jax_lnn-0.1.1.post1.tar.gz
- Upload date:
- Size: 42.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | eaa4e6c84494dc5d3558954087612f98bc97c788a424f03d4b2479efb84fa754 |
| MD5 | 38af19add18ef526dcc54fb2f1e15356 |
| BLAKE2b-256 | 6ee72674cb012f104845ebc548f86c4a283e026ac090f2f0702213ede61aa995 |
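The SHA256 digest in the table above can be checked locally after downloading the source distribution. This sketch uses only Python's standard `hashlib`; the local file path is an assumption (use wherever you saved the download).

```python
import hashlib
from pathlib import Path

def sha256_of(path, chunk_size=1 << 16):
    # Stream the file in chunks so large archives need not fit in memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

EXPECTED = "eaa4e6c84494dc5d3558954087612f98bc97c788a424f03d4b2479efb84fa754"
path = Path("jax_lnn-0.1.1.post1.tar.gz")  # assumed download location
if path.exists():
    print("OK" if sha256_of(path) == EXPECTED else "MISMATCH")
```

pip's hash-checking mode (`--require-hashes` with `--hash=sha256:...` entries in a requirements file) performs the same comparison at install time.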
Provenance
The following attestation bundles were made for jax_lnn-0.1.1.post1.tar.gz:
- Publisher: publish.yml on RadimKozl/JLNN
- Statement:
  - Statement type: https://in-toto.io/Statement/v1
  - Predicate type: https://docs.pypi.org/attestations/publish/v1
  - Subject name: jax_lnn-0.1.1.post1.tar.gz
  - Subject digest: eaa4e6c84494dc5d3558954087612f98bc97c788a424f03d4b2479efb84fa754
  - Sigstore transparency entry: 1066037482
  - Sigstore integration time:
- Permalink: RadimKozl/JLNN@d7b41efc077f684585db29b2d33be8386f21f59f
- Branch / Tag: refs/tags/v0.1.1.post1
- Owner: https://github.com/RadimKozl
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@d7b41efc077f684585db29b2d33be8386f21f59f
- Trigger Event: release
File details
Details for the file jax_lnn-0.1.1.post1-py3-none-any.whl.
File metadata
- Download URL: jax_lnn-0.1.1.post1-py3-none-any.whl
- Upload date:
- Size: 53.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2919d18fa50920cf745854cc28ef1c001eac2c14cbe0e3d95d8a9d7159b8814d |
| MD5 | ed1e6f5c179bb15b36f970946ae07373 |
| BLAKE2b-256 | f7661db4455dab8f23ff9f389fe9d250683ee4e1ff1b8a8b9390f3471c0ab247 |
Provenance
The following attestation bundles were made for jax_lnn-0.1.1.post1-py3-none-any.whl:
- Publisher: publish.yml on RadimKozl/JLNN
- Statement:
  - Statement type: https://in-toto.io/Statement/v1
  - Predicate type: https://docs.pypi.org/attestations/publish/v1
  - Subject name: jax_lnn-0.1.1.post1-py3-none-any.whl
  - Subject digest: 2919d18fa50920cf745854cc28ef1c001eac2c14cbe0e3d95d8a9d7159b8814d
  - Sigstore transparency entry: 1066037484
  - Sigstore integration time:
- Permalink: RadimKozl/JLNN@d7b41efc077f684585db29b2d33be8386f21f59f
- Branch / Tag: refs/tags/v0.1.1.post1
- Owner: https://github.com/RadimKozl
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@d7b41efc077f684585db29b2d33be8386f21f59f
- Trigger Event: release