RheoJAX - JAX-Powered Rheological Analysis


A JAX-accelerated package for rheological data analysis, providing 21 rheological models, 5 data transforms, Bayesian inference via NumPyro, and 27 tutorial notebooks.

Features

A rheological analysis toolkit with Bayesian inference and parameter optimization:

Core Capabilities

  • 21 Rheological Models: Classical (Maxwell, Zener, SpringPot), Fractional (11 variants), Flow (6 models), Multi-Mode (Generalized Maxwell)
  • 5 Data Transforms: FFT, Mastercurve (TTS), Mutation Number, OWChirp (LAOS), Smooth Derivative
  • Model-Data Compatibility Checking: Detects when models are inappropriate for data based on physics (exponential vs power-law decay, material type classification)
  • Bayesian Inference: All 21 models support NumPyro NUTS sampling with NLSQ warm-start
  • Pipeline API: Fluent interface for load → fit → plot → save workflows
  • Automatic Initialization: Parameter initialization for fractional models in oscillation mode
  • JAX-First Architecture: 5-270x performance improvement with automatic differentiation and GPU support

Data & I/O

  • Data Support: Automatic test mode detection (relaxation, creep, oscillation, rotation)
  • File Formats: TRIOS, CSV, Excel, Anton Paar with format auto-detection
  • Parameter System: Type-safe parameter management with bounds and constraints
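Automatic test-mode detection could, for example, key off the names of the independent and measured variables. The helper below is a hypothetical illustration of the concept only, not RheoJAX's actual detection logic:

```python
# Hypothetical sketch: classify a rheological test from column names.
# RheoJAX's real detector may use data shape and physics as well.
def guess_test_mode(x_name: str, y_name: str) -> str:
    x, y = x_name.lower(), y_name.lower()
    if "freq" in x or "omega" in x:
        return "oscillation"   # G'(w), G''(w) sweeps
    if "rate" in x:
        return "rotation"      # flow curves vs shear rate
    if "time" in x and ("strain" in y or "compliance" in y):
        return "creep"         # strain response under constant stress
    return "relaxation"        # default: modulus decay vs time

print(guess_test_mode("time", "modulus"))   # relaxation
print(guess_test_mode("frequency", "G*"))   # oscillation
```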

Visualization & Diagnostics

  • Visualization: Three built-in styles (default, publication, presentation)
  • ArviZ Diagnostic Suite: 6 plot types (pair, forest, energy, autocorr, rank, ESS) for MCMC quality
  • Plugin System: Support for custom models and transforms

Tutorial Notebooks & Examples

  • 27 Tutorial Notebooks: Organized in 4 categories
    • examples/basic/ - 5 notebooks covering fundamental models
    • examples/transforms/ - 7 notebooks for data transforms and analysis
    • examples/bayesian/ - 7 notebooks for Bayesian inference workflows
    • examples/advanced/ - 8 notebooks for production patterns
  • I/O Examples: TRIOS complex modulus handling and plotting

Installation

Requirements

  • Python 3.12 or later (3.8-3.11 are NOT supported due to JAX 0.8.0 requirements)
  • JAX and jaxlib for acceleration
  • NLSQ for GPU-accelerated optimization
  • NumPyro for Bayesian inference
  • ArviZ for Bayesian diagnostics

Basic Installation

pip install rheojax

Development Installation

git clone https://github.com/imewei/rheojax.git
cd rheojax
pip install -e ".[dev]"

GPU Installation (Linux Only)

Performance Impact: 20-100x speedup for large datasets (>10K points)

Option 1: Install via Makefile

From the repository:

git clone https://github.com/imewei/rheojax.git
cd rheojax
make install-jax-gpu  # Handles uninstall + GPU install

This command:

  • Uninstalls CPU-only JAX
  • Installs GPU-enabled JAX with CUDA 12 support
  • Verifies GPU detection

Option 2: Manual Installation

For GPU-accelerated computation on Linux systems with CUDA 12+:

# Step 1: Uninstall CPU-only version
pip uninstall -y jax jaxlib

# Step 2: Install JAX with CUDA support
pip install jax[cuda12-local]==0.8.0 jaxlib==0.8.0

# Step 3: Verify GPU detection
python -c "import jax; print('Devices:', jax.devices())"
# Should show: [cuda(id=0)] instead of [CpuDevice(id=0)]

Why separate installation? JAX with CUDA support is Linux-specific and requires system CUDA 12.1-12.9 pre-installed. Separating the installation avoids dependency conflicts on macOS/Windows.

GPU Troubleshooting

Issue: Warning "An NVIDIA GPU may be present... but a CUDA-enabled jaxlib is not installed"

Solution:

# 1. Check GPU hardware
nvidia-smi  # Should show your GPU

# 2. Check CUDA version
nvcc --version  # Should show CUDA 12.1-12.9

# 3. Reinstall JAX with GPU support
pip uninstall -y jax jaxlib
pip install jax[cuda12-local]==0.8.0 jaxlib==0.8.0

# 4. Verify JAX detects GPU
python -c "import jax; print(jax.devices())"
# Expected: [cuda(id=0)]
# If still showing [CpuDevice(id=0)], check CUDA installation

Issue: ImportError or CUDA library not found

Solution:

# Set CUDA library path (add to ~/.bashrc for permanent fix)
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH

Platform Support

  • Linux + NVIDIA GPU + CUDA 12.1-12.9: Full GPU acceleration (20-100x speedup)
  • macOS: CPU-only (Apple Silicon/Intel, no NVIDIA GPU support)
  • Windows: CPU-only (CUDA support experimental/unstable)

Requirements (Linux GPU):

  • System CUDA 12.1-12.9 pre-installed
  • NVIDIA driver >= 525
  • Linux x86_64 or aarch64

Conda/Mamba Users

The package works in conda environments when installed with pip:

conda create -n rheojax python=3.12
conda activate rheojax
pip install rheojax

# For GPU acceleration (Linux only)
git clone https://github.com/imewei/rheojax.git
cd rheojax
make install-jax-gpu

Note: Conda extras syntax (conda install rheojax[gpu]) is not supported. Use the Makefile or manual pip installation method above.

Quick Start

Loading and Visualizing Data

from rheojax.io.readers import auto_read
from rheojax.visualization import plot_rheo_data
import matplotlib.pyplot as plt

# Load data (auto-detect format)
data = auto_read("stress_relaxation.txt")

# Check detected test mode
print(f"Test mode: {data.test_mode}")  # Output: relaxation

# Visualize
fig, ax = plot_rheo_data(data, style='publication')
plt.show()

Basic Model Fitting

from rheojax.models.maxwell import Maxwell
import numpy as np

# Generate or load data
t = np.logspace(-2, 2, 50)
G_data = 1e5 * np.exp(-t / 0.01) + np.random.normal(0, 1e3, 50)

# Fit with NLSQ (5-270x faster than scipy)
model = Maxwell()
model.fit(t, G_data)

print(f"G0 = {model.parameters.get_value('G0'):.3e} Pa")
print(f"eta = {model.parameters.get_value('eta'):.3e} Pa·s")

Bayesian Inference Workflow

from rheojax.models.maxwell import Maxwell
import numpy as np

# Create model and data
model = Maxwell()
t = np.logspace(-2, 2, 50)
G_data = 1e5 * np.exp(-t / 0.01) + np.random.normal(0, 1e3, 50)

# Step 1: NLSQ optimization (fast point estimate)
model.fit(t, G_data)
print(f"NLSQ: G0={model.parameters.get_value('G0'):.3e}")

# Step 2: Bayesian inference with warm-start
result = model.fit_bayesian(
    t, G_data,
    num_warmup=1000,
    num_samples=2000
)

# Step 3: Analyze results
print(f"Posterior mean: G0={result.summary['G0']['mean']:.3e} ± {result.summary['G0']['std']:.3e}")
print(f"Convergence: R-hat={result.diagnostics['r_hat']['G0']:.4f}, ESS={result.diagnostics['ess']['G0']:.0f}")

# Get credible intervals
intervals = model.get_credible_intervals(result.posterior_samples, credibility=0.95)
print(f"G0 95% CI: [{intervals['G0'][0]:.3e}, {intervals['G0'][1]:.3e}]")
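For context, a 95% equal-tailed credible interval like the one above is simply the 2.5th and 97.5th percentiles of the posterior samples. A standalone sketch with stand-in samples (not drawn from an actual RheoJAX posterior):

```python
import numpy as np

# Stand-in posterior samples for G0: illustration only, not real MCMC output
rng = np.random.default_rng(0)
posterior_G0 = rng.normal(loc=1e5, scale=2e3, size=2000)

# 95% equal-tailed credible interval = central 95% of the samples
lo, hi = np.percentile(posterior_G0, [2.5, 97.5])
print(f"G0 95% CI: [{lo:.3e}, {hi:.3e}]")
```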

Bayesian Pipeline with ArviZ Diagnostics

from rheojax.pipeline.bayesian import BayesianPipeline

pipeline = BayesianPipeline()

# Fluent API: load → fit_nlsq → fit_bayesian → plot → save
(pipeline
    .load('data.csv', x_col='time', y_col='stress')
    .fit_nlsq('maxwell')
    .fit_bayesian(num_samples=2000, num_warmup=1000)
    .plot_posterior()
    .plot_trace()
    .save('results.hdf5'))

# ArviZ diagnostic plots (MCMC quality assessment)
(pipeline
    .plot_pair(divergences=True)        # Parameter correlations with divergences
    .plot_forest(hdi_prob=0.95)         # Credible intervals comparison
    .plot_energy()                       # NUTS energy diagnostic
    .plot_autocorr()                     # Mixing diagnostic
    .plot_rank()                         # Convergence diagnostic
    .plot_ess(kind='local'))            # Effective sample size

Reference: See Bayesian Quick Start Guide for:

  • When and why to use Bayesian inference
  • NLSQ → NUTS → ArviZ workflow walkthrough
  • Troubleshooting convergence issues
  • Best practices checklist
  • Runnable demo: python examples/bayesian_workflow_demo.py

Model-Data Compatibility Checking

RheoJAX detects when models are inappropriate for data based on physics:

from rheojax.models.fractional_zener_ss import FractionalZenerSolidSolid
from rheojax.utils.compatibility import check_model_compatibility, format_compatibility_message
import numpy as np

# Generate exponential decay data
t = np.logspace(-2, 2, 50)
G_t = 1e5 * np.exp(-t / 1.0)

# Check compatibility before fitting
model = FractionalZenerSolidSolid()
compat = check_model_compatibility(
    model, t=t, G_t=G_t, test_mode='relaxation'
)

# Get report
print(format_compatibility_message(compat))
# Output:
# ⚠ Model may not be appropriate for this data
#   Confidence: 90%
#   Detected decay: exponential
#   Material type: viscoelastic_liquid
#
# Warnings:
#   • FZSS model expects Mittag-Leffler (power-law) relaxation,
#     but data shows exponential decay.
#
# Recommended alternative models:
#   • Maxwell
#   • Zener

# Or enable checking during fit
model.fit(t, G_t, check_compatibility=True)  # Warns if incompatible

Features:

  • Detects decay type (exponential, power-law, stretched, Mittag-Leffler)
  • Classifies material type (solid, liquid, gel, viscoelastic)
  • Provides model recommendations when incompatible
  • Error messages explain physics mismatches
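One rough way to separate the two decay families is that an exponential decay is linear on semilog axes (log G vs t) while a power law is linear on log-log axes (log G vs log t). The sketch below compares straight-line fits in each space to illustrate the idea; it is not RheoJAX's actual classifier:

```python
import numpy as np

def detect_decay_type(t, G):
    """Classify decay by comparing linearity on semilog vs log-log axes."""
    logG = np.log(G)

    def r_squared(x, y):
        slope, intercept = np.polyfit(x, y, 1)
        resid = y - (slope * x + intercept)
        return 1.0 - np.sum(resid**2) / np.sum((y - y.mean()) ** 2)

    r2_exp = r_squared(t, logG)           # linear here -> exponential decay
    r2_pow = r_squared(np.log(t), logG)   # linear here -> power-law decay
    return "exponential" if r2_exp > r2_pow else "power-law"

t = np.logspace(-2, 2, 50)
print(detect_decay_type(t, 1e5 * np.exp(-t / 1.0)))  # exponential
print(detect_decay_type(t, 1e5 * t ** -0.5))         # power-law
```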

Reference: Model Selection Guide for decision flowcharts and model characteristics.

Working with Parameters

from rheojax.core import ParameterSet

# Create parameter set
params = ParameterSet()
params.add("E", value=1000.0, bounds=(100, 10000), units="Pa")
params.add("tau", value=1.0, bounds=(0.1, 100), units="s")

# Get/set values
E = params.get_value("E")
params.set_value("tau", 2.5)

Data Transforms

from rheojax.transforms import FFTAnalysis, Mastercurve, MutationNumber

# FFT analysis for frequency spectrum
fft = FFTAnalysis(window='hann', detrend=True)
freq_data = fft.transform(data)
tau_char = fft.get_characteristic_time(freq_data)

# Time-temperature superposition (mastercurves)
mc = Mastercurve(reference_temp=298.15, method='wlf')

# Option 1: Create mastercurve (basic)
mastercurve = mc.create_mastercurve(datasets)

# Option 2: Transform with shift factors (for plotting)
mastercurve, shift_factors = mc.transform(datasets)

# Get parameters and arrays for analysis
wlf_params = mc.get_wlf_parameters()
temps, shifts = mc.get_shift_factors_array()

# Mutation number (viscoelastic character)
mn = MutationNumber()
delta = mn.calculate(data)  # 0=elastic, 1=viscous
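For reference, the 'wlf' shifting method is based on the Williams-Landel-Ferry equation, log10(a_T) = -C1 (T - T_ref) / (C2 + T - T_ref). A standalone sketch using the classic "universal" constants (in practice C1 and C2 are fitted to the data, as returned by get_wlf_parameters):

```python
# WLF horizontal shift factor a_T, using the classic "universal" constants.
def wlf_shift(T, T_ref=298.15, C1=17.44, C2=51.6):
    dT = T - T_ref
    return 10.0 ** (-C1 * dT / (C2 + dT))

# a_T = 1 at the reference temperature; >1 below it, <1 above it
for T in (288.15, 298.15, 308.15):
    print(f"T = {T:.2f} K  a_T = {wlf_shift(T):.3e}")
```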

Tutorial Notebooks

27 tutorial notebooks organized by topic:

examples/
├── basic/                       # 5 notebooks: Fundamental models
│   ├── 01-maxwell-fitting.ipynb
│   ├── 02-zener-fitting.ipynb
│   ├── 03-springpot-fitting.ipynb
│   ├── 04-bingham-fitting.ipynb
│   └── 05-power-law-fitting.ipynb
├── transforms/                  # 7 notebooks: Data analysis workflows
│   ├── 01-fft-analysis.ipynb
│   ├── 02-mastercurve-tts.ipynb
│   ├── 02b-mastercurve-wlf-validation.ipynb
│   ├── 03-mutation-number.ipynb
│   ├── 04-owchirp-laos-analysis.ipynb
│   ├── 05-smooth-derivative.ipynb
│   └── 07-mastercurve_auto_shift.ipynb
├── bayesian/                    # 7 notebooks: Bayesian inference
│   ├── 01-bayesian-basics.ipynb
│   ├── 02-prior-selection.ipynb
│   ├── 03-convergence-diagnostics.ipynb
│   ├── 04-model-comparison.ipynb
│   ├── 05-uncertainty-propagation.ipynb
│   ├── 06-bayesian_workflow_demo.ipynb
│   └── 07-gmm_bayesian_workflow.ipynb
├── advanced/                    # 8 notebooks: Production patterns
│   ├── 01-multi-technique-fitting.ipynb
│   ├── 02-batch-processing.ipynb
│   ├── 03-custom-models.ipynb
│   ├── 04-fractional-models-deep-dive.ipynb
│   ├── 05-performance-optimization.ipynb
│   ├── 06-frequentist-model-selection.ipynb
│   ├── 07-trios_chunked_reading_example.ipynb
│   └── 08-generalized_maxwell_fitting.ipynb
└── io/                          # I/O demonstrations
    └── plot_trios_complex_modulus.ipynb

See examples/README.md for learning path guide.

Documentation

Documentation: https://rheojax.readthedocs.io

Performance

NLSQ Optimization Performance

NLSQ performance compared to scipy:

Dataset Size    scipy (Powell)   NLSQ (JAX)   Speedup
50 points       180 ms           35 ms        5x
500 points      920 ms           48 ms        19x
5000 points     8.2 s            95 ms        86x
50000 points    94 s             350 ms       270x

Bayesian Warm-Start Performance

NLSQ → NUTS warm-start improves MCMC convergence:

Method                     Convergence Time   Divergences   ESS/sec
Cold start (random init)   45 s               15%           44
NLSQ warm-start            18 s               0.2%          111
Improvement                2.5x faster        75x fewer     2.5x higher

Benchmarks on M1 MacBook Pro. GPU acceleration provides additional 5-20x speedups for large datasets.
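Absolute timings vary by hardware, but the effect of JIT compilation on a residual evaluation is easy to reproduce with plain JAX (numbers illustrative only; the first call pays a one-time compilation cost):

```python
import time

import jax
import jax.numpy as jnp

# Sum-of-squares residual for a Maxwell fit over a large dataset
def residual(params, t, y):
    G0, tau = params
    return jnp.sum((G0 * jnp.exp(-t / tau) - y) ** 2)

residual_jit = jax.jit(residual)

t = jnp.logspace(-2, 2, 50_000)
y = 1e5 * jnp.exp(-t / 0.01)
params = jnp.array([9e4, 0.02])

residual_jit(params, t, y).block_until_ready()  # warm-up: triggers compilation
start = time.perf_counter()
residual_jit(params, t, y).block_until_ready()  # steady-state, compiled call
print(f"jitted evaluation: {(time.perf_counter() - start) * 1e3:.3f} ms")
```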

Contributing

Contributions are welcome. See the Contributing Guide for details.

Development Setup

# Clone repository
git clone https://github.com/imewei/rheojax.git
cd rheojax

# Install development dependencies
pip install -e ".[dev]"

# Install pre-commit hooks
pre-commit install

# Run tests
pytest

License

This project is licensed under the MIT License - see the LICENSE file for details.

Citation

If you use RheoJAX in your research, please cite:

@software{rheo2024,
  title = {Rheo: JAX-Powered Unified Rheology Package with Bayesian Inference},
  year = {2024},
  author = {Wei Chen},
  url = {https://github.com/imewei/rheojax},
  version = {0.2.2}
}

Acknowledgments

Built on open-source software:

  • JAX for automatic differentiation and acceleration
  • NLSQ for GPU-accelerated nonlinear least squares
  • NumPyro for probabilistic programming
  • ArviZ for Bayesian visualization
  • NumPy and SciPy for numerical computing
  • matplotlib for visualization

Roadmap

See CHANGELOG.md for development history and examples/ for tutorial notebooks.

