
HighJax: A JAX implementation of the HighwayEnv driving environment


HighJax: Highway Driving environment for Reinforcement Learning research

PPO agent learning to drive on a 4-lane highway

HighJax is an autonomous driving environment for Reinforcement Learning research. It is a JAX implementation of HighwayEnv, and its highway driving simulation is fully JIT-compilable and vectorizable (it can be wrapped in `jax.jit` and `jax.vmap`).

In addition to running much faster than the original HighwayEnv, HighJax provides Octane, a Rust-based terminal UI (TUI) for examining your experiment runs. Octane lets you define behaviors and then measure how strongly each policy exhibits them.

HighJax was produced as part of our research project about BXRL: Behavior-Explainable Reinforcement Learning.

Installation

pip install highjax-rl # Minimal installation
pip install "highjax-rl[cuda12]" # Including GPU support
pip install "highjax-rl[trainer]" # Including PPO implementation
pip install "highjax-rl[cuda12,trainer]" # Including both
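After installing the cuda12 extra, a quick way to confirm that JAX (which HighJax runs on) actually sees the GPU is to query the default backend. This is a generic JAX check rather than a HighJax command:

```python
import jax

# jax.default_backend() names the platform JAX uses by default,
# e.g. "cpu", "gpu", or "tpu".
backend = jax.default_backend()
print("JAX backend:", backend)
print("Devices:", jax.devices())

if backend != "gpu":
    print("No GPU detected; JAX (and therefore HighJax) will run on CPU.")
```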

Quick Start

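import jax
import highjax

env, params = highjax.make('highjax-v0')
key = jax.random.PRNGKey(0)
key, reset_key, step_key = jax.random.split(key, 3)  # Fresh subkey for each call
obs, state = env.reset(reset_key, params)
obs, state, reward, done, info = env.step(step_key, state, 1, params)  # Action 1 = IDLE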

Using with JAX RL Libraries

HighJax follows the gymnax API, so it works with JAX RL frameworks that expect gymnax-style environments. The examples/ folder includes integrations with PureJaxRL, Stoix and Rejax (see Examples below).
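Frameworks like PureJaxRL get their speed by combining `jax.lax.scan` (a compiled time loop), `jax.vmap` (many episodes in parallel) and `jax.jit`. The sketch below shows that pattern on a hypothetical stand-in environment, `ToyEnv`, that only mimics the reset/step signatures from the Quick Start; it is not part of HighJax, and in real code you would use `env, params = highjax.make('highjax-v0')` instead. The sketch also ignores `done` and never resets finished episodes.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in with the gymnax-style signatures used in Quick Start:
#   reset(key, params) -> (obs, state)
#   step(key, state, action, params) -> (obs, state, reward, done, info)
class ToyEnv:
    def reset(self, key, params):
        state = jnp.zeros(())  # scalar "position"
        return state[None], state

    def step(self, key, state, action, params):
        new_state = state + action.astype(jnp.float32)
        reward = -jnp.abs(new_state - params["target"])
        done = new_state >= params["target"]
        return new_state[None], new_state, reward, done, {}


def rollout(env, params, key, n_steps, n_actions):
    """Run one episode of uniformly random actions for n_steps with lax.scan."""
    reset_key, key = jax.random.split(key)
    obs, state = env.reset(reset_key, params)

    def body(carry, _):
        state, key = carry
        key, act_key, step_key = jax.random.split(key, 3)
        action = jax.random.randint(act_key, (), 0, n_actions)
        obs, state, reward, done, info = env.step(step_key, state, action, params)
        return (state, key), reward

    _, rewards = jax.lax.scan(body, (state, key), None, length=n_steps)
    return rewards


env = ToyEnv()
params = {"target": 5.0}
n_envs, n_steps, n_actions = 8, 40, 5

# vmap over PRNG keys runs n_envs independent episodes in parallel;
# jit compiles the whole batched rollout once.
batched_rollout = jax.jit(
    jax.vmap(lambda k: rollout(env, params, k, n_steps, n_actions))
)
keys = jax.random.split(jax.random.PRNGKey(0), n_envs)
rewards = batched_rollout(keys)
print(rewards.shape)  # (8, 40): one row of per-step rewards per episode
```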

Training

Train a PPO agent via the CLI:

highjax-trainer train

Key options:

Flag              | Default | Description
----------------- | ------- | -----------
--n-epochs / -e   | 300     | Training epochs
--n-es            | 400     | Parallel episodes per epoch
--n-ts            | 40      | Timesteps per episode
--seed / -s       | 0       | Random seed
--actor-lr        | 3e-4    | Actor learning rate
--critic-lr       | 3e-3    | Critic learning rate
--n-npcs          | 50      | NPC vehicles
--no-trek         |         | Disable trek recording
--n-sample-es     | 1       | Episodes to sample per epoch for trek
--trek-path       | auto    | Custom trek directory path
--discount        | 0.95    | Discount factor (gamma)
--n-lanes         | 4       | Number of highway lanes
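To get a feel for the --discount and --n-ts defaults: with gamma = 0.95, a reward k steps ahead is weighted by 0.95^k, so the reward at the last of 40 timesteps is weighted by 0.95^39 ≈ 0.135, and the common rule of thumb 1/(1 − gamma) puts the effective horizon at about 20 steps. The plain-Python sketch below illustrates the standard discounted-return definition; it is not the trainer's own code.

```python
def discounted_return(rewards, gamma):
    """Return G_0 = sum_k gamma**k * r_k, accumulated from the last reward backwards."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g


gamma = 0.95  # --discount default
n_ts = 40     # --n-ts default

effective_horizon = 1.0 / (1.0 - gamma)
last_step_weight = gamma ** (n_ts - 1)  # weight of the reward at the final timestep
constant_return = discounted_return([1.0] * n_ts, gamma)

print(f"effective horizon ~ {effective_horizon:.0f} steps")
print(f"weight of the final step's reward: {last_step_weight:.3f}")
print(f"return for a reward of 1 at every step: {constant_return:.2f}")
```

A reward of 1 at every one of the 40 steps is therefore worth about 17.4 at the start of the episode, not 40.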

Training automatically records episode data (a "trek") under ~/.highjax/t/, one timestamped directory per run, for browsing with Octane. Use --no-trek to disable this.

Here's a snazzy one-liner that opens the results of your most recent experiment run in VisiData:

pip install visidata
vd "$(ls -d ~/.highjax/t/2*/ | tail -1)"/epochia.pq

To reproduce results similar to those in Figure 2 of the paper, run:

highjax-trainer train --n-es 128 --n-ts 400 --n-epochs 300 --target-kld 0.0005
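The Figure 2 configuration uses fewer but much longer episodes than the defaults. A back-of-the-envelope count, assuming every one of the --n-es episodes runs for the full --n-ts timesteps (the trainer may count differently, for example if episodes end early), shows how much more experience each epoch collects:

```python
def env_steps(n_es, n_ts, n_epochs):
    """Environment transitions collected per epoch and over a whole run."""
    per_epoch = n_es * n_ts
    return per_epoch, per_epoch * n_epochs


default = env_steps(n_es=400, n_ts=40, n_epochs=300)    # CLI defaults
figure2 = env_steps(n_es=128, n_ts=400, n_epochs=300)   # Figure 2 command
print(f"defaults: {default[0]:,} per epoch, {default[1]:,} total")
print(f"Figure 2: {figure2[0]:,} per epoch, {figure2[1]:,} total")
```

That is 16,000 versus 51,200 transitions per epoch, about 3.2 times as many.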

Octane (Episode Browser)

This repo also includes Octane, a Rust-based TUI for browsing HighJax experiments.

Installation

sudo apt-get install build-essential # C toolchain (needed by Rust)
sudo apt-get install ffmpeg # Needed for `octane animate`
git clone https://github.com/HumanCompatibleAI/HighJax # Clone this repo
cd HighJax
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh # Install Rust
source "$HOME/.cargo/env"
(cd octane && cargo build --release) # Build Octane (subshell keeps you in HighJax/)
alias octane="$(readlink -f octane/target/release/octane)" # Make `octane` available in this shell

The binary will be at octane/target/release/octane.

Usage

After training, launch Octane to see all the experiments you ran with highjax-trainer:

octane

Figures

Use Octane to make figures for your paper:

octane draw -t ~/.highjax/t/2026-03-15-20-02-25-101327 --epoch 300 -e 0 --timestep 19 --theme light \
  --zoom 1.8 --png ~/figure.png

Octane figure output

Behavior crafting

Octane includes a behavior explorer for defining measurable policy properties. While watching an episode, press b to capture a scenario, then mark which actions you want (positive weight) or don't want (negative weight) in that traffic state. Name the behavior, and Octane saves it to ~/.highjax/behaviors/. The next time you run highjax-trainer train, all discovered behaviors are evaluated every epoch, and their scores are recorded as behavior.{name} columns in epochia.parquet.

Defining a behavior scenario in Octane
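How Octane turns the action weights into a score is not spelled out here. As a purely hypothetical illustration, one simple choice is to score a policy at a captured scenario as the weighted sum of its action probabilities, so probability mass on wanted actions raises the score and mass on unwanted ones lowers it. The function names, the softmax policy and the action indices below are all assumptions, not Octane's implementation.

```python
import math


def softmax(logits):
    """Convert policy logits to action probabilities (numerically stable)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def behavior_score(action_probs, weights):
    """Hypothetical score: sum over actions of weight * probability.

    `weights` maps action index -> weight (positive = wanted,
    negative = unwanted); actions not listed get weight 0.
    """
    return sum(w * action_probs[a] for a, w in weights.items())


# Hypothetical scenario: the behavior wants action 2 and penalizes action 0.
weights = {2: 1.0, 0: -1.0}
probs = softmax([0.1, 0.5, 2.0, 0.2, -1.0])  # made-up policy logits
print(f"score = {behavior_score(probs, weights):+.3f}")
```

With weights of +1 and -1 this score ranges from -1 (always picks the unwanted action) to +1 (always picks the wanted one), and a uniform policy scores 0.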

Press B (Shift-B) to open the full Behavior Explorer tab.

See the Octane docs for full details.

Documentation

Full documentation is in the docs/ folder.

Examples

  • examples/basic_usage.py — Create env, reset, step, print observations
  • examples/train_ppo.py — Train a PPO agent and evaluate it
  • examples/use_purejaxrl.py — PureJaxRL integration (vectorized scan loop)
  • examples/use_stoix.py — Stoix integration (via stoa gymnax adapter)
  • examples/use_rejax.py — Rejax integration (JIT-compiled training, vmapped seeds)

Citation

If you use HighJax in your research, please cite:

@article{rachum2026bxrl,
  title={BXRL: Behavior-Explainable Reinforcement Learning},
  author={Ram Rachum and Yotam Amitai and Yonatan Nakar and Reuth Mirsky and Cameron Allen},
  year={2026},
  eprint={2603.23738},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2603.23738},
}

