Unified rheological analysis framework with JAX acceleration

RheoJAX - JAX-Powered Rheological Analysis

JAX-accelerated package for rheological data analysis. Provides 20 rheological models, 5 data transforms, Bayesian inference via NumPyro, and 24 tutorial notebooks.

Features

Rheological analysis toolkit with Bayesian inference and parameter optimization:

Core Capabilities

  • 20 Rheological Models: Classical (Maxwell, Zener, SpringPot), Fractional (11 variants), Flow (6 models)
  • 5 Data Transforms: FFT, Mastercurve (TTS), Mutation Number, OWChirp (LAOS), Smooth Derivative
  • Model-Data Compatibility Checking: Detects when models are inappropriate for data based on physics (exponential vs power-law decay, material type classification)
  • Bayesian Inference: All 20 models support NumPyro NUTS sampling with NLSQ warm-start
  • Pipeline API: Fluent interface for load → fit → plot → save workflows
  • Automatic Initialization: Parameter initialization for fractional models in oscillation mode
  • JAX-First Architecture: 5-270x performance improvement with automatic differentiation and GPU support

Data & I/O

  • Data Support: Automatic test mode detection (relaxation, creep, oscillation, rotation)
  • File Formats: TRIOS, CSV, Excel, Anton Paar with format auto-detection
  • Parameter System: Type-safe parameter management with bounds and constraints

Visualization & Diagnostics

  • Visualization: Three built-in styles (default, publication, presentation)
  • ArviZ Diagnostic Suite: 6 plot types (pair, forest, energy, autocorr, rank, ESS) for MCMC quality
  • Plugin System: Support for custom models and transforms

Tutorial Notebooks & Examples

  • 24 Tutorial Notebooks: Organized in 4 categories
    • examples/basic/ - 5 notebooks covering fundamental models
    • examples/transforms/ - 6 notebooks for data transforms and analysis
    • examples/bayesian/ - 6 notebooks for Bayesian inference workflows
    • examples/advanced/ - 7 notebooks for production patterns
  • I/O Examples: TRIOS complex modulus handling and plotting

Installation

Requirements

  • Python 3.12 or later (3.8-3.11 are NOT supported due to JAX 0.8.0 requirements)
  • JAX and jaxlib for acceleration
  • NLSQ for GPU-accelerated optimization
  • NumPyro for Bayesian inference
  • ArviZ for Bayesian diagnostics

Basic Installation

pip install rheojax

Development Installation

git clone https://github.com/imewei/rheojax.git
cd rheojax
pip install -e ".[dev]"

GPU Installation (Linux Only)

Performance Impact: 20-100x speedup for large datasets (>10K points)

Option 1: Install via Makefile

From the repository:

git clone https://github.com/imewei/rheojax.git
cd rheojax
make install-jax-gpu  # Handles uninstall + GPU install

This command:

  • Uninstalls CPU-only JAX
  • Installs GPU-enabled JAX with CUDA 12 support
  • Verifies GPU detection

Option 2: Manual Installation

For GPU-accelerated computation on Linux systems with CUDA 12+:

# Step 1: Uninstall CPU-only version
pip uninstall -y jax jaxlib

# Step 2: Install JAX with CUDA support
# Quotes keep shells such as zsh from expanding the square brackets
pip install "jax[cuda12-local]==0.8.0" "jaxlib==0.8.0"

# Step 3: Verify GPU detection
python -c "import jax; print('Devices:', jax.devices())"
# Should show: [cuda(id=0)] instead of [CpuDevice(id=0)]

Why separate installation? JAX with CUDA support is Linux-specific and requires system CUDA 12.1-12.9 pre-installed. Separating the installation avoids dependency conflicts on macOS/Windows.
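
The Step 3 check can also be run from inside Python before starting a long fit. The sketch below relies only on jax.devices() and jax.default_backend(), and reports "unavailable" when JAX cannot be imported. The helper name describe_backend is an illustrative assumption, not part of RheoJAX.

```python
def describe_backend() -> str:
    """Return the active JAX backend name, or 'unavailable' if JAX is missing."""
    try:
        import jax
    except ImportError:
        return "unavailable"
    # default_backend() returns a short string such as "cpu" or "gpu".
    name = jax.default_backend()
    print(f"Backend: {name}; devices: {jax.devices()}")
    return name


backend = describe_backend()
if backend != "gpu":
    print(f"No GPU backend active (backend reported: {backend})")
```

A script can branch on the returned string, for example to pick a smaller dataset when only the CPU backend is present.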

GPU Troubleshooting

Issue: Warning "An NVIDIA GPU may be present... but a CUDA-enabled jaxlib is not installed"

Solution:

# 1. Check GPU hardware
nvidia-smi  # Should show your GPU

# 2. Check CUDA version
nvcc --version  # Should show CUDA 12.1-12.9

# 3. Reinstall JAX with GPU support
pip uninstall -y jax jaxlib
pip install "jax[cuda12-local]==0.8.0" "jaxlib==0.8.0"

# 4. Verify JAX detects GPU
python -c "import jax; print(jax.devices())"
# Expected: [cuda(id=0)]
# If still showing [CpuDevice(id=0)], check CUDA installation

Issue: ImportError or CUDA library not found

Solution:

# Set CUDA library path (add to ~/.bashrc for permanent fix)
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH

Platform Support

  • Linux + NVIDIA GPU + CUDA 12.1-12.9: Full GPU acceleration (20-100x speedup)
  • macOS: CPU-only (Apple Silicon/Intel, no NVIDIA GPU support)
  • Windows: CPU-only (CUDA support experimental/unstable)

Requirements (Linux GPU):

  • System CUDA 12.1-12.9 pre-installed
  • NVIDIA driver >= 525
  • Linux x86_64 or aarch64

Conda/Mamba Users

The package works in conda environments using pip:

conda create -n rheojax python=3.12
conda activate rheojax
pip install rheojax

# For GPU acceleration (Linux only)
git clone https://github.com/imewei/rheojax.git
cd rheojax
make install-jax-gpu

Note: Conda extras syntax (conda install rheojax[gpu]) is not supported. Use the Makefile or manual pip installation method above.

Quick Start

Loading and Visualizing Data

from rheojax.io.readers import auto_read
from rheojax.visualization import plot_rheo_data
import matplotlib.pyplot as plt

# Load data (auto-detect format)
data = auto_read("stress_relaxation.txt")

# Check detected test mode
print(f"Test mode: {data.test_mode}")  # Output: relaxation

# Visualize
fig, ax = plot_rheo_data(data, style='publication')
plt.show()

Basic Model Fitting

from rheojax.models.maxwell import Maxwell
import numpy as np

# Generate or load data
t = np.logspace(-2, 2, 50)
G_data = 1e5 * np.exp(-t / 0.01) + np.random.normal(0, 1e3, 50)

# Fit with NLSQ (5-270x faster than scipy)
model = Maxwell()
model.fit(t, G_data)

print(f"G0 = {model.parameters.get_value('G0'):.3e} Pa")
print(f"eta = {model.parameters.get_value('eta'):.3e} Pa·s")
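
As an independent cross-check on a fit like the one above, the standard Maxwell relaxation modulus G(t) = G0 * exp(-t / tau), with tau = eta / G0, is a straight line in (t, ln G). The sketch below recovers G0 and tau from noise-free synthetic data with plain NumPy. It assumes that textbook form and is not RheoJAX's fitting routine; the numerical values are illustrative.

```python
import numpy as np

# Noise-free Maxwell relaxation: G(t) = G0 * exp(-t / tau), with tau = eta / G0.
G0_true, tau_true = 1e5, 1.0          # illustrative values (Pa, s)
t = np.linspace(0.01, 5.0, 50)
G = G0_true * np.exp(-t / tau_true)

# ln G = ln G0 - t / tau is linear in t, so a degree-1 fit recovers both.
slope, intercept = np.polyfit(t, np.log(G), 1)
G0_est = np.exp(intercept)
tau_est = -1.0 / slope
eta_est = G0_est * tau_est             # viscosity implied by G0 and tau

print(f"G0  = {G0_est:.3e} Pa")
print(f"tau = {tau_est:.3e} s")
print(f"eta = {eta_est:.3e} Pa·s")
```

With noisy data the log transform amplifies errors in the tail, which is one reason a nonlinear least-squares fit on the original scale is usually preferred.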

Bayesian Inference Workflow

from rheojax.models.maxwell import Maxwell
import numpy as np

# Create model and data
model = Maxwell()
t = np.logspace(-2, 2, 50)
G_data = 1e5 * np.exp(-t / 0.01) + np.random.normal(0, 1e3, 50)

# Step 1: NLSQ optimization (fast point estimate)
model.fit(t, G_data)
print(f"NLSQ: G0={model.parameters.get_value('G0'):.3e}")

# Step 2: Bayesian inference with warm-start
result = model.fit_bayesian(
    t, G_data,
    num_warmup=1000,
    num_samples=2000
)

# Step 3: Analyze results
print(f"Posterior mean: G0={result.summary['G0']['mean']:.3e} ± {result.summary['G0']['std']:.3e}")
print(f"Convergence: R-hat={result.diagnostics['r_hat']['G0']:.4f}, ESS={result.diagnostics['ess']['G0']:.0f}")

# Get credible intervals
intervals = model.get_credible_intervals(result.posterior_samples, credibility=0.95)
print(f"G0 95% CI: [{intervals['G0'][0]:.3e}, {intervals['G0'][1]:.3e}]")
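
To make the Step 3 quantities concrete, the sketch below computes a posterior mean, standard deviation, and an equal-tailed 95% interval directly from an array of draws. The draws are synthetic stand-ins, and the helper equal_tailed_interval is hypothetical; the source does not state whether get_credible_intervals returns equal-tailed or highest-density intervals.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for posterior draws of G0; real draws would come from the sampler.
samples = rng.normal(loc=1.0e5, scale=2.0e3, size=4000)


def equal_tailed_interval(draws, credibility=0.95):
    """Equal-tailed credible interval from 1-D posterior draws."""
    tail = 100.0 * (1.0 - credibility) / 2.0
    lower, upper = np.percentile(draws, [tail, 100.0 - tail])
    return float(lower), float(upper)


lower, upper = equal_tailed_interval(samples)
print(f"mean = {samples.mean():.3e}, std = {samples.std(ddof=1):.3e}")
print(f"95% interval: [{lower:.3e}, {upper:.3e}]")
```

For a roughly symmetric posterior the equal-tailed and highest-density intervals nearly coincide; they differ for skewed posteriors.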

Bayesian Pipeline with ArviZ Diagnostics

from rheojax.pipeline.bayesian import BayesianPipeline

pipeline = BayesianPipeline()

# Fluent API: load → fit_nlsq → fit_bayesian → plot → save
(pipeline
    .load('data.csv', x_col='time', y_col='stress')
    .fit_nlsq('maxwell')
    .fit_bayesian(num_samples=2000, num_warmup=1000)
    .plot_posterior()
    .plot_trace()
    .save('results.hdf5'))

# ArviZ diagnostic plots (MCMC quality assessment)
(pipeline
    .plot_pair(divergences=True)        # Parameter correlations with divergences
    .plot_forest(hdi_prob=0.95)         # Credible intervals comparison
    .plot_energy()                       # NUTS energy diagnostic
    .plot_autocorr()                     # Mixing diagnostic
    .plot_rank()                         # Convergence diagnostic
    .plot_ess(kind='local'))            # Effective sample size
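
The chained calls above work because each method returns the pipeline object itself. A minimal sketch of that pattern follows; TinyPipeline and its methods are hypothetical stand-ins that only record step names, not the RheoJAX implementation.

```python
class TinyPipeline:
    """Hypothetical stand-in showing method chaining; not the RheoJAX class."""

    def __init__(self):
        self.steps = []

    def load(self, path):
        self.steps.append(f"load:{path}")
        return self  # returning self is what makes chaining possible

    def fit(self, model_name):
        self.steps.append(f"fit:{model_name}")
        return self

    def save(self, path):
        self.steps.append(f"save:{path}")
        return self


pipeline = TinyPipeline().load("data.csv").fit("maxwell").save("results.hdf5")
print(pipeline.steps)
```

Because every call returns the same object, a chain can be split across statements at any point without changing the result.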

Reference: See Bayesian Quick Start Guide for:

  • When and why to use Bayesian inference
  • NLSQ → NUTS → ArviZ workflow walkthrough
  • Troubleshooting convergence issues
  • Best practices checklist
  • Runnable demo: python examples/bayesian_workflow_demo.py
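
For readers troubleshooting convergence, it can help to see what R-hat measures. The sketch below implements the classic (non-split) Gelman-Rubin statistic with NumPy on synthetic chains. It is an illustration only: ArviZ's default R-hat is a rank-normalized split variant, so its values will differ slightly.

```python
import numpy as np


def gelman_rubin(chains):
    """Basic (non-split) R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_draws = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()            # W: mean within-chain variance
    between = n_draws * chains.mean(axis=1).var(ddof=1)   # B: scaled variance of chain means
    pooled = (n_draws - 1) / n_draws * within + between / n_draws
    return float(np.sqrt(pooled / within))


rng = np.random.default_rng(1)
mixed = rng.normal(size=(4, 1000))                        # chains sample the same target
stuck = mixed + np.array([[0.0], [2.0], [4.0], [6.0]])    # chains sit at different levels

print(f"R-hat (mixed): {gelman_rubin(mixed):.3f}")
print(f"R-hat (stuck): {gelman_rubin(stuck):.3f}")
```

Values close to 1 indicate that the chains agree; values well above 1 mean the between-chain spread dominates and the sampler has not converged.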

Model-Data Compatibility Checking

RheoJAX detects when models are inappropriate for data based on physics:

from rheojax.models.fractional_zener_ss import FractionalZenerSolidSolid
from rheojax.utils.compatibility import check_model_compatibility, format_compatibility_message
import numpy as np

# Generate exponential decay data
t = np.logspace(-2, 2, 50)
G_t = 1e5 * np.exp(-t / 1.0)

# Check compatibility before fitting
model = FractionalZenerSolidSolid()
compat = check_model_compatibility(
    model, t=t, G_t=G_t, test_mode='relaxation'
)

# Get report
print(format_compatibility_message(compat))
# Output:
# ⚠ Model may not be appropriate for this data
#   Confidence: 90%
#   Detected decay: exponential
#   Material type: viscoelastic_liquid
#
# Warnings:
#   • FZSS model expects Mittag-Leffler (power-law) relaxation,
#     but data shows exponential decay.
#
# Recommended alternative models:
#   • Maxwell
#   • Zener

# Or enable checking during fit
model.fit(t, G_t, check_compatibility=True)  # Warns if incompatible

Features:

  • Detects decay type (exponential, power-law, stretched, Mittag-Leffler)
  • Classifies material type (solid, liquid, gel, viscoelastic)
  • Provides model recommendations when incompatible
  • Error messages explain physics mismatches

Reference: Model Selection Guide for decision flowcharts and model characteristics.
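
One simple way to see the exponential versus power-law distinction is to compare straight-line fits in two coordinate systems: exponential decay is linear in (t, ln G), while power-law decay is linear in (ln t, ln G). The sketch below is a heuristic illustration with hypothetical helper names; it is not the detection algorithm RheoJAX uses, and it ignores noise, stretched exponentials, and Mittag-Leffler relaxation.

```python
import numpy as np


def r_squared(x, y):
    """Coefficient of determination of a straight-line fit y ~ a*x + b."""
    slope, intercept = np.polyfit(x, y, 1)
    residual = y - (slope * x + intercept)
    total = y - y.mean()
    return 1.0 - residual.dot(residual) / total.dot(total)


def classify_decay(t, G):
    """Pick the coordinate system in which ln G is closer to a straight line."""
    log_G = np.log(G)
    r2_exponential = r_squared(t, log_G)          # linear in t
    r2_power_law = r_squared(np.log(t), log_G)    # linear in ln t
    return "exponential" if r2_exponential > r2_power_law else "power-law"


t = np.logspace(-2, 2, 50)
exp_label = classify_decay(t, 1e5 * np.exp(-t / 1.0))   # Maxwell-like data
pow_label = classify_decay(t, 1e5 * t ** -0.5)          # power-law data
print(exp_label, pow_label)
```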

Working with Parameters

from rheojax.core import ParameterSet

# Create parameter set
params = ParameterSet()
params.add("E", value=1000.0, bounds=(100, 10000), units="Pa")
params.add("tau", value=1.0, bounds=(0.1, 100), units="s")

# Get/set values
E = params.get_value("E")
params.set_value("tau", 2.5)
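
The idea behind bounded parameters can be sketched with a small dataclass that validates each update. BoundedParameter is a hypothetical illustration; the source does not say whether RheoJAX raises, clips, or warns when a value falls outside its bounds.

```python
from dataclasses import dataclass


@dataclass
class BoundedParameter:
    """Hypothetical illustration of a bounded parameter; not RheoJAX's class."""

    name: str
    value: float
    bounds: tuple[float, float]
    units: str = ""

    def set_value(self, new_value: float) -> None:
        lower, upper = self.bounds
        if not lower <= new_value <= upper:
            raise ValueError(
                f"{self.name}={new_value} is outside bounds [{lower}, {upper}]"
            )
        self.value = float(new_value)


tau = BoundedParameter("tau", value=1.0, bounds=(0.1, 100.0), units="s")
tau.set_value(2.5)          # accepted: inside the bounds
try:
    tau.set_value(500.0)    # rejected: above the upper bound
except ValueError as err:
    print(err)
```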

Data Transforms

from rheojax.transforms import FFTAnalysis, Mastercurve, MutationNumber

# FFT analysis for frequency spectrum
fft = FFTAnalysis(window='hann', detrend=True)
freq_data = fft.transform(data)
tau_char = fft.get_characteristic_time(freq_data)

# Time-temperature superposition (mastercurves)
mc = Mastercurve(reference_temp=298.15, method='wlf')

# Option 1: Create mastercurve (basic)
mastercurve = mc.create_mastercurve(datasets)

# Option 2: Transform with shift factors (for plotting)
mastercurve, shift_factors = mc.transform(datasets)

# Get parameters and arrays for analysis
wlf_params = mc.get_wlf_parameters()
temps, shifts = mc.get_shift_factors_array()

# Mutation number (viscoelastic character)
mn = MutationNumber()
delta = mn.calculate(data)  # 0=elastic, 1=viscous
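
For the 'wlf' method, the shift factors follow the Williams-Landel-Ferry equation, log10(a_T) = -C1 (T - Tref) / (C2 + (T - Tref)). The sketch below evaluates it with NumPy. The function name is hypothetical, and the constants are the often-quoted "universal" values used purely as placeholders; in practice C1 and C2 are fitted per material and reference temperature.

```python
import numpy as np


def wlf_log10_shift(temps, ref_temp, c1, c2):
    """WLF equation: log10(a_T) = -C1 (T - Tref) / (C2 + (T - Tref))."""
    delta = np.asarray(temps, dtype=float) - ref_temp
    return -c1 * delta / (c2 + delta)


# Placeholder constants only; real C1 and C2 are fitted to the material.
C1, C2 = 17.44, 51.6
temps = np.array([288.15, 298.15, 308.15, 318.15])   # K
log_aT = wlf_log10_shift(temps, ref_temp=298.15, c1=C1, c2=C2)

for T, value in zip(temps, log_aT):
    print(f"T = {T:.2f} K  ->  log10(a_T) = {value:+.3f}")
```

The shift factor is 1 (log10 of 0) at the reference temperature, greater than 1 below it, and less than 1 above it, which is what moves each dataset onto the mastercurve.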

Tutorial Notebooks

24 tutorial notebooks organized by topic:

examples/
├── basic/                       # 5 notebooks: Fundamental models
│   ├── 01-maxwell-fitting.ipynb
│   ├── 02-zener-fitting.ipynb
│   ├── 03-springpot-fitting.ipynb
│   ├── 04-bingham-fitting.ipynb
│   └── 05-power-law-fitting.ipynb
├── transforms/                  # 6 notebooks: Data analysis workflows
│   ├── 01-fft-analysis.ipynb
│   ├── 02-mastercurve-tts.ipynb
│   ├── 02b-mastercurve-wlf-validation.ipynb
│   ├── 03-mutation-number.ipynb
│   ├── 04-owchirp-laos-analysis.ipynb
│   └── 05-smooth-derivative.ipynb
├── bayesian/                    # 6 notebooks: Bayesian inference
│   ├── 01-bayesian-basics.ipynb
│   ├── 02-prior-selection.ipynb
│   ├── 03-convergence-diagnostics.ipynb
│   ├── 04-model-comparison.ipynb
│   ├── 05-uncertainty-propagation.ipynb
│   └── 06-bayesian_workflow_demo.ipynb
├── advanced/                    # 7 notebooks: Production patterns
│   ├── 01-multi-technique-fitting.ipynb
│   ├── 02-batch-processing.ipynb
│   ├── 03-custom-models.ipynb
│   ├── 04-fractional-models-deep-dive.ipynb
│   ├── 05-performance-optimization.ipynb
│   ├── 06-frequentist-model-selection.ipynb
│   └── 07-trios_chunked_reading_example.ipynb
└── io/                          # I/O demonstrations
    └── plot_trios_complex_modulus.ipynb

See examples/README.md for learning path guide.

Documentation

Documentation: https://rheojax.readthedocs.io

Performance

NLSQ Optimization Performance

NLSQ performance compared to scipy:

Dataset Size    scipy (Powell)   NLSQ (JAX)   Speedup
50 points       180 ms           35 ms        5x
500 points      920 ms           48 ms        19x
5000 points     8.2 s            95 ms        86x
50000 points    94 s             350 ms       270x
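
The Speedup column is the ratio of the two timings in each row. The short sketch below recomputes it from the numbers in the table; the dictionary layout is just a convenience for the illustration.

```python
# Timings copied from the table above, converted to milliseconds.
timings_ms = {
    50: (180.0, 35.0),
    500: (920.0, 48.0),
    5000: (8200.0, 95.0),
    50000: (94000.0, 350.0),
}

# Speedup = scipy time divided by NLSQ time for the same dataset size.
speedups = {n: scipy_ms / nlsq_ms for n, (scipy_ms, nlsq_ms) in timings_ms.items()}
for n, ratio in speedups.items():
    print(f"{n:>6} points: {ratio:6.1f}x")
```

The last row works out to roughly 269x, which the table rounds to 270x.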

Bayesian Warm-Start Performance

NLSQ → NUTS warm-start improves MCMC convergence:

Method                     Convergence Time   Divergences   ESS/sec
Cold start (random init)   45 s               15%           44
NLSQ warm-start            18 s               0.2%          111
Improvement                2.5x faster        75x fewer     2.5x higher

Benchmarks on M1 MacBook Pro. GPU acceleration provides additional 5-20x speedups for large datasets.
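
The intuition behind warm-starting can be shown with a toy sampler. The sketch below uses a random-walk Metropolis chain (not NUTS, and not RheoJAX code) on a narrow one-dimensional Gaussian target, and counts how many steps the chain needs to reach the neighbourhood of the mode from a distant start versus a start at a point estimate. All names and values are illustrative assumptions.

```python
import math
import random

MODE, SCALE = 5.0, 0.1   # illustrative target: a narrow Gaussian


def log_density(x):
    """Unnormalized log density of the target."""
    return -0.5 * ((x - MODE) / SCALE) ** 2


def steps_to_reach_mode(start, step=0.1, tol=0.2, max_steps=10_000, seed=0):
    """Random-walk Metropolis; count steps until |x - MODE| <= tol."""
    rng = random.Random(seed)
    x = start
    for n in range(max_steps):
        if abs(x - MODE) <= tol:
            return n
        proposal = x + rng.gauss(0.0, step)
        log_accept = log_density(proposal) - log_density(x)
        # Accept with probability min(1, exp(log_accept)).
        if rng.random() < math.exp(min(0.0, log_accept)):
            x = proposal
    return max_steps


cold = steps_to_reach_mode(start=0.0)    # far from the mode
warm = steps_to_reach_mode(start=MODE)   # initialized at a point estimate
print(f"cold start: {cold} steps, warm start: {warm} steps")
```

The cold chain spends its early iterations travelling toward the mode, which is the phase a good initial point estimate removes.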

Contributing

Contributions are accepted. See Contributing Guide for details.

Development Setup

# Clone repository
git clone https://github.com/imewei/rheojax.git
cd rheojax

# Install development dependencies
pip install -e ".[dev]"

# Install pre-commit hooks
pre-commit install

# Run tests
pytest

License

This project is licensed under the MIT License - see the LICENSE file for details.

Citation

If you use RheoJAX in your research, please cite:

@software{rheo2024,
  title = {Rheo: JAX-Powered Unified Rheology Package with Bayesian Inference},
  year = {2024},
  author = {Wei Chen},
  url = {https://github.com/imewei/rheojax},
  version = {0.2.0}
}

Acknowledgments

Built on open-source software:

  • JAX for automatic differentiation and acceleration
  • NLSQ for GPU-accelerated nonlinear least squares
  • NumPyro for probabilistic programming
  • ArviZ for Bayesian visualization
  • NumPy and SciPy for numerical computing
  • matplotlib for visualization

Roadmap

See CHANGELOG.md for development history and examples/ for tutorial notebooks.

