A JAX-accelerated agent-based modeling framework
JaxABM: JAX-Accelerated Agent-Based Modeling Framework
JaxABM is a high-performance agent-based modeling (ABM) framework that uses JAX for GPU acceleration, vectorization, and automatic differentiation, and exposes an easy-to-use, AgentPy-like interface. This enables faster simulations than traditional pure-Python ABM frameworks, plus capabilities such as gradient-based calibration.
Key Features
- Easy-to-use Interface: AgentPy-like API for intuitive model development
- GPU Acceleration: Run simulations on GPUs with minimal code changes
- Fully Vectorized: Uses JAX's vectorization for highly parallel agent simulations
- Multiple Agent Types: Support for heterogeneous agent populations
- Differentiable Simulations: End-to-end differentiable ABM for gradient-based optimization
- Powerful Analysis Tools: Built-in sensitivity analysis and parameter calibration
- Spatial Structures: Built-in support for grid and network environments
- Backward Compatible: Legacy API support for traditional (non-JAX) modeling
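To make the vectorization feature concrete, here is a minimal sketch in plain JAX. It does not use the JaxABM API; `step_agent` and `step_all` are illustrative names, and the movement rule is an assumption chosen to match the Quick Start example.

```python
import jax
import jax.numpy as jnp

def step_agent(state, speed):
    """Update one agent: move diagonally and wrap at the unit-square boundary."""
    return {
        'x': (state['x'] + speed) % 1.0,
        'y': (state['y'] + speed) % 1.0,
    }

# vmap maps the per-agent rule over the leading axis of every array in the
# state dict; `speed` is shared by all agents (in_axes=None).
# jit compiles the batched update for CPU, GPU, or TPU.
step_all = jax.jit(jax.vmap(step_agent, in_axes=(0, None)))

# Ten agents stored as a struct of arrays (one array per attribute).
states = {'x': jnp.linspace(0.0, 0.9, 10), 'y': jnp.zeros(10)}
new_states = step_all(states, 0.05)
```

The per-agent rule is written for a single agent, and the batching is added afterwards, which is the general pattern behind vectorized agent updates in JAX.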
Installation
Basic Installation
pip install jaxabm
Install with JAX capabilities
First install JAX following the official instructions (for GPU support), then:
pip install jaxabm[jax]
Quick Start
Here's a simple example of a model whose agents drift diagonally across the unit square (a deterministic rule stands in for random movement):
import jaxabm as jx

class MyAgent(jx.Agent):
    def setup(self):
        """Initialize agent state."""
        return {
            'x': 0.5,
            'y': 0.5
        }

    def step(self, model_state):
        """Update agent state."""
        # Get current position
        x = self._state['x']
        y = self._state['y']
        # Move diagonally (a simple deterministic rule for this example)
        x += 0.01
        y += 0.01
        # Wrap around at boundaries
        x = x % 1.0
        y = y % 1.0
        # Return updated state
        return {
            'x': x,
            'y': y
        }

class MyModel(jx.Model):
    def setup(self):
        """Set up model with agents and environment."""
        # Add agents
        self.agents = self.add_agents(10, MyAgent)
        # Set up environment
        self.env.add_state('time', 0)

    def step(self):
        """Execute model logic each step."""
        # Update environment time
        # Note: Agents are updated automatically by the framework
        if hasattr(self._jax_model, 'state'):
            time = self._jax_model.state['env'].get('time', 0)
            self._jax_model.add_env_state('time', time + 1)
            # Record data (inside the guard so `time` is always defined)
            self.record('time', time)

# Run model
model = MyModel({'steps': 100})
results = model.run()
# Plot results
results.plot()
The AgentPy-like Interface
JaxABM now provides an easy-to-use, AgentPy-like interface built on top of the high-performance JAX core.
Agent
The Agent class is the base class for all agents in the model. To create a custom agent, inherit from this class and override the setup and step methods.
class MyAgent(jx.Agent):
    def setup(self):
        """Initialize agent state."""
        return {
            'x': 0,
            'y': 0
        }

    def step(self, model_state):
        """Update agent state."""
        return {
            'x': self._state['x'] + 0.1,
            'y': self._state['y'] + 0.1
        }
AgentList
The AgentList class is a container for managing collections of agents.
# In Model.setup():
self.agents = self.add_agents(10, MyAgent)
# Access agent attributes:
x_positions = self.agents.x # Returns array of x values
# Filter agents:
active_agents = self.agents.select(lambda agents: agents.active)
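The attribute access and filtering shown above can be pictured with a struct-of-arrays layout. The sketch below uses plain JAX arrays and is an assumption about the general idea, not a description of how `AgentList` is implemented; the `agents` dict and its fields are hypothetical.

```python
import jax.numpy as jnp

# Hypothetical struct-of-arrays layout: one array per attribute,
# one entry per agent.
agents = {
    'x': jnp.array([0.1, 0.4, 0.7, 0.9]),
    'active': jnp.array([True, False, True, True]),
}

x_positions = agents['x']   # analogous to `self.agents.x`
mask = agents['active']     # analogous to the predicate passed to `select`

# Outside jit, boolean indexing yields a variable-size subset.
active_x = x_positions[mask]

# Inside jit, shapes must be static, so a masked reduction is the usual
# substitute for materializing the subset.
mean_active_x = jnp.sum(jnp.where(mask, x_positions, 0.0)) / jnp.sum(mask)
```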
Environment
The Environment class is a container for environment state and methods for creating and managing spatial structures.
# In Model.setup():
self.env.add_state('temperature', 25.0)
# Access environment state:
temp = self.env.temperature
Grid and Network
For spatial models, the Grid and Network classes provide structures for agent interactions.
# Create a grid:
self.grid = jx.Grid(self, (10, 10))
# Position agents on grid:
self.grid.position_agents(self.agents)
# Create a network:
self.network = jx.Network(self)
# Add edges:
self.network.add_edge(agent1, agent2)
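As a rough illustration of what grid and network structures make possible, the sketch below counts occupied neighbors on a wrapping grid and computes node degrees from an adjacency matrix. It uses plain JAX arrays rather than `jx.Grid` or `jx.Network`; `count_neighbors`, the periodic boundary, and the adjacency-matrix representation are all assumptions made for the example.

```python
import jax.numpy as jnp

def count_neighbors(occupied):
    """Count occupied Moore neighbors (8 surrounding cells) on a wrapping grid."""
    total = jnp.zeros(occupied.shape, dtype=jnp.int32)
    for dx in (-1, 0, 1):
        for dy in (-1, 0, 1):
            if (dx, dy) != (0, 0):
                # Shifting the whole grid aligns each cell with one neighbor.
                total = total + jnp.roll(occupied, shift=(dx, dy), axis=(0, 1))
    return total

# A 10x10 grid with a single occupied cell at (5, 5).
occupied = jnp.zeros((10, 10), dtype=jnp.int32).at[5, 5].set(1)
neighbors = count_neighbors(occupied)

# A three-agent network as a symmetric adjacency matrix with one edge (0, 1).
adjacency = jnp.zeros((3, 3), dtype=jnp.int32).at[0, 1].set(1).at[1, 0].set(1)
degree = adjacency.sum(axis=1)
```

Neighbor counts of this kind are the building block for models such as Schelling segregation, where each agent compares its neighborhood to a tolerance threshold.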
Model
The Model class is the base class for all models. It provides methods for setting up, running, and analyzing models.
class MyModel(jx.Model):
    def setup(self):
        """Set up model with agents and environment."""
        self.agents = self.add_agents(10, MyAgent)
        self.env.add_state('time', 0)

    def step(self):
        """Execute model logic each step."""
        # Environment updates (agent updates happen automatically)
        if hasattr(self._jax_model, 'state'):
            time = self._jax_model.state['env'].get('time', 0)
            self._jax_model.add_env_state('time', time + 1)
            # Record data (inside the guard so `time` is always defined)
            self.record('time', time)

    def end(self):
        """Execute code at the end of a simulation."""
        print("Simulation completed!")

# Create and run model
model = MyModel({'steps': 100})
results = model.run()
Results
The Results class is a container for simulation results. It provides methods for accessing and visualizing results.
# Run model and get results
results = model.run()
# Plot all metrics
results.plot()
# Access specific variables
results.variables.agent.x.plot()
# Save results
results.save('my_results.pkl')
# Load results
results = jx.Results.load('my_results.pkl')
Advanced Features
Sensitivity Analysis
JaxABM provides tools to analyze how model outputs respond to parameter changes:
from jaxabm.analysis import SensitivityAnalysis
# Create model factory function
def create_model(params=None, config=None):
    # Create model with parameters from the params dict
    model = MyModel(params)
    return model

# Perform sensitivity analysis
sensitivity = SensitivityAnalysis(
    model_factory=create_model,
    param_ranges={
        'propensity_to_consume': (0.6, 0.9),
        'productivity': (0.5, 1.5),
    },
    metrics_of_interest=['gdp', 'unemployment', 'inequality'],
    num_samples=10
)
# Run analysis
results = sensitivity.run()
# Calculate sensitivity indices
indices = sensitivity.sobol_indices()
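For readers unfamiliar with Sobol indices, the sketch below estimates first-order indices for a toy function with NumPy, using the standard pick-freeze (Saltelli-style) estimator. Whether `sobol_indices()` uses this particular estimator is not stated here, so treat it as an illustration of the concept only; `first_order_sobol` and `toy_metric` are hypothetical names.

```python
import numpy as np

def first_order_sobol(f, bounds, num_samples, seed=0):
    """Estimate first-order Sobol indices of f over a box of parameter ranges."""
    rng = np.random.default_rng(seed)
    lo, hi = np.array(bounds, dtype=float).T
    dim = len(bounds)
    # Two independent sample matrices, one row per sample.
    A = lo + (hi - lo) * rng.random((num_samples, dim))
    B = lo + (hi - lo) * rng.random((num_samples, dim))
    fA, fB = f(A), f(B)
    mean = np.mean(np.concatenate([fA, fB]))
    fA, fB = fA - mean, fB - mean  # centering reduces estimator noise
    variance = np.mean(np.concatenate([fA, fB]) ** 2)
    indices = np.empty(dim)
    for i in range(dim):
        # AB equals A except that column i is taken from B.
        AB = A.copy()
        AB[:, i] = B[:, i]
        indices[i] = np.mean(fB * (f(AB) - mean - fA)) / variance
    return indices

def toy_metric(X):
    """Toy additive output: the first parameter has twice the effect."""
    return 2.0 * X[:, 0] + X[:, 1]

indices = first_order_sobol(toy_metric, [(0.0, 1.0), (0.0, 1.0)], 20000)
```

For this additive toy function the exact first-order indices are 0.8 and 0.2, so the estimates should land close to those values; with only a handful of samples they would be much noisier.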
Model Calibration
Find optimal parameters to match target metrics using gradient-based or RL-based methods:
from jaxabm.analysis import ModelCalibrator
# Define target metrics
target_metrics = {
    'gdp': 10.0,
    'unemployment': 0.05,
    'inequality': 2.0
}
# Initialize calibrator
calibrator = ModelCalibrator(
    model_factory=create_model,
    initial_params={
        'propensity_to_consume': 0.7,
        'productivity': 1.0
    },
    target_metrics=target_metrics,
    metrics_weights={
        'gdp': 0.1,
        'unemployment': 1.0,
        'inequality': 0.5
    },
    learning_rate=0.01,
    max_iterations=20,
    method='gradient'  # or 'rl'
)
# Run calibration
optimal_params = calibrator.calibrate()
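The idea behind gradient-based calibration can be sketched in plain JAX: differentiate a weighted squared error between simulated and target metrics, then take gradient-descent steps on the parameters. The toy `simulate` function, the loss form, and the `calibrate` helper below are assumptions for illustration and are not the internals of `ModelCalibrator`.

```python
import jax

def simulate(params, steps=20):
    """Toy differentiable model: output relaxes toward `productivity` each step."""
    output = 0.0
    for _ in range(steps):
        output = output + 0.5 * (params['productivity'] - output)
    return {'gdp': output}

def make_loss(target_metrics, metrics_weights):
    """Build a weighted squared-error loss over the target metrics."""
    def loss(params):
        metrics = simulate(params)
        return sum(
            metrics_weights[name] * (metrics[name] - target) ** 2
            for name, target in target_metrics.items()
        )
    return loss

def calibrate(initial_params, target_metrics, metrics_weights,
              learning_rate=0.1, max_iterations=100):
    """Plain gradient descent on the calibration loss."""
    grad_fn = jax.jit(jax.grad(make_loss(target_metrics, metrics_weights)))
    params = initial_params
    for _ in range(max_iterations):
        grads = grad_fn(params)
        params = jax.tree_util.tree_map(
            lambda p, g: p - learning_rate * g, params, grads
        )
    return params

calibrated = calibrate({'productivity': 1.0}, {'gdp': 1.5}, {'gdp': 1.0})
```

Because the whole simulation is traced by JAX, the gradient flows through every time step, which is what "end-to-end differentiable" refers to in the feature list.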
Examples
The package includes several example models demonstrating different features:
- examples/random_walk.py: Simple model with random walking agents
- examples/schelling_model.py: Classic Schelling segregation model
- examples/minimal_example_agentpy.py: AgentPy-like version of the minimal example
- examples/agentpy_interface_example.py: Bouncing agents with AgentPy-like interface
- examples/minimal_example.py: Original JaxABM API example
- examples/jax_abm_simple.py: Simplified model with original API
- examples/jax_abm_example.py: Detailed economic model with sensitivity analysis
Run examples with:
python examples/random_walk.py
python examples/schelling_model.py
Core Abstractions (Original API)
The AgentPy-like interface is built on the following core abstractions, which remain available as the original API:
AgentType Protocol
Defines the behavior of agents:
- init_state(model_config, key): Initialize agent state
- update(state, model_state, model_config, key): Update agent state based on current state and environment
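A class following this two-method protocol might look like the sketch below. It is written in plain JAX without importing JaxABM; `RandomWalker` is a hypothetical agent type, and passing `None` for `model_state` and `model_config` is a shortcut for the example only.

```python
import jax
import jax.numpy as jnp

class RandomWalker:
    """Hypothetical agent type with the two protocol methods listed above."""

    def init_state(self, model_config, key):
        # Draw an initial position uniformly in the unit square.
        pos = jax.random.uniform(key, (2,))
        return {'x': pos[0], 'y': pos[1]}

    def update(self, state, model_state, model_config, key):
        # Take a small Gaussian step and wrap around at the boundaries.
        delta = 0.01 * jax.random.normal(key, (2,))
        return {
            'x': (state['x'] + delta[0]) % 1.0,
            'y': (state['y'] + delta[1]) % 1.0,
        }

walker = RandomWalker()
init_key, step_key = jax.random.split(jax.random.PRNGKey(0))

# One key per agent keeps each agent's random stream independent.
num_agents = 5
states = jax.vmap(lambda k: walker.init_state(None, k))(
    jax.random.split(init_key, num_agents)
)
states = jax.vmap(lambda s, k: walker.update(s, None, None, k))(
    states, jax.random.split(step_key, num_agents)
)
```

Passing an explicit `key` to both methods is what keeps the simulation reproducible: JAX has no global random state, so every random draw is a pure function of its key.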
AgentCollection
Manages a collection of agents of the same type:
- __init__(agent_type, num_agents): Create collection placeholder
- init(key, model_config): Initialize all agents in the collection
- update(model_state, key, model_config): Update all agents in parallel
- states: Access the current states of all agents
- filter(condition): Create a filtered subset of agents
ModelConfig
Provides simulation configuration:
- seed: Random seed for reproducibility
- steps: Number of simulation steps
- track_history: Whether to track model history
- collect_interval: Interval for collecting metrics
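The four fields can be pictured as a small immutable record. The dataclass below is a stand-in, not the real `ModelConfig`: the default values and the reading of `collect_interval` as "collect every N-th step" are assumptions.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ModelConfigSketch:
    """Hypothetical stand-in mirroring the four fields listed above."""
    seed: int = 0                # assumed default
    steps: int = 100             # assumed default
    track_history: bool = True   # assumed default
    collect_interval: int = 1    # assumed default

    def collection_steps(self):
        """Steps at which metrics would be collected under this reading."""
        return list(range(0, self.steps, self.collect_interval))

config = ModelConfigSketch(seed=42, steps=10, collect_interval=5)
collected_at = config.collection_steps()
```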
JaxModel
Coordinates the overall simulation:
- add_agent_collection(name, collection): Add an agent collection
- add_env_state(name, value): Add an environmental state variable
- initialize(): Prepare the model for simulation
- step(): Execute a single time step
- run(steps): Run the full simulation
- jit_step(): Get a JIT-compiled step function for maximum performance
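The relationship between a single step, a JIT-compiled step, and a full run can be sketched in plain JAX as follows. The state layout (`'agents'` and `'env'` sub-dicts) and the `step`/`run` functions are assumptions for illustration, not the `JaxModel` implementation.

```python
import jax
import jax.numpy as jnp

def step(state, key):
    """One time step: move every agent and advance the environment clock."""
    noise = 0.01 * jax.random.normal(key, state['agents']['x'].shape)
    agents = {'x': (state['agents']['x'] + noise) % 1.0}
    env = {'time': state['env']['time'] + 1}
    return {'agents': agents, 'env': env}

# A compiled single step, useful when stepping manually.
jit_step = jax.jit(step)

def run(state, key, steps):
    """Run `steps` time steps with lax.scan, recording the clock each step."""
    def body(carry, step_key):
        new_state = step(carry, step_key)
        return new_state, new_state['env']['time']
    # One random key per step, consumed in order by scan.
    return jax.lax.scan(body, state, jax.random.split(key, steps))

state = {
    'agents': {'x': jnp.full((100,), 0.5)},
    'env': {'time': jnp.array(0, dtype=jnp.int32)},
}
final_state, time_history = run(state, jax.random.PRNGKey(0), steps=50)
```

Using `jax.lax.scan` for the run loop keeps the whole simulation inside one traced computation instead of dispatching each step from Python.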
Performance
JaxABM provides significant performance improvements over pure-Python ABM code:
- Reported speedups of 10-100x over pure Python implementations (actual gains depend on model size and hardware)
- GPU acceleration with minimal code changes
- Parallel agent updates through vectorization
- JIT compilation for optimal performance
Citation
If you use JaxABM in your research, please cite:
BibTeX
@software{pham2025jaxabm,
title={JaxABM: JAX-Accelerated Agent-Based Modeling Framework},
author={Pham, Anh-Duy and D'Orazio, Paola},
year={2025},
month={June},
version={0.1.1},
url={https://github.com/a11to1n3/JaxABM},
note={High-performance agent-based modeling framework with GPU acceleration and reinforcement learning calibration}
}
APA Style
Pham, A.-D., & D'Orazio, P. (2025). JaxABM: JAX-Accelerated Agent-Based Modeling Framework (Version 0.1.1) [Computer software]. https://github.com/a11to1n3/JaxABM
IEEE Style
A.-D. Pham and P. D'Orazio, "JaxABM: JAX-Accelerated Agent-Based Modeling Framework," Version 0.1.1, June 2025. [Online]. Available: https://github.com/a11to1n3/JaxABM
Key Features to Cite
When citing JaxABM, you may want to highlight these innovations:
- GPU-accelerated agent-based modeling with JAX backend
- Advanced reinforcement learning calibration methods (Actor-Critic, Policy Gradient, Q-Learning, DQN)
- High-performance vectorized simulations with 10-100x speedup over traditional ABM frameworks
- Differentiable agent-based models enabling gradient-based optimization
- Comprehensive parameter optimization toolkit with multiple calibration algorithms
Related Publications
If you use specific features, consider citing the underlying methodologies:
- For reinforcement learning calibration: Reference the specific RL algorithms used (Actor-Critic, Policy Gradient, etc.)
- For sensitivity analysis: Sobol indices methodology
- For JAX backend: The JAX library for high-performance machine learning research
Requirements
- Python 3.8+
- JAX 0.4.1+ (for acceleration features)
- NumPy
- Matplotlib (for visualization)
License
This project is licensed under the MIT License - see the LICENSE file for details.