
jaxquantum


S. R. Jha, S. Chowdhury, G. Rolleri, M. Hays, J. A. Grover, W. D. Oliver

Docs: equs.github.io/jaxquantum

Community Discord: discord.gg/frWqbjvZ4s

jaxquantum leverages JAX to enable auto-differentiable, hardware-accelerated (CPU, GPU, TPU) simulation of quantum dynamical systems, including tooling for operator construction, unitary evolution, and master equation solving. As such, jaxquantum serves as a drop-in replacement for QuTiP, written entirely in JAX.

Moreover, jaxquantum has recently absorbed the bosonic and qcsys packages, making it a unified toolkit for quantum circuit design, simulation, and control.
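To make "auto-differentiable simulation" concrete, the sketch below uses plain JAX rather than jaxquantum's API. It builds a truncated annihilation operator, evolves the vacuum under H = theta (a + a^dag) for unit time with a matrix exponential, and differentiates the resulting mean photon number with jax.grad. The helpers destroy and mean_photon_number are hypothetical, written only for this illustration; jaxquantum ships its own operator constructors and solvers.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

N = 10  # Fock-space truncation (illustrative choice)


def destroy(N):
    # Hypothetical helper: truncated annihilation operator with <n-1|a|n> = sqrt(n)
    return jnp.diag(jnp.sqrt(jnp.arange(1, N, dtype=jnp.float32)), k=1).astype(jnp.complex64)


a = destroy(N)
num = a.conj().T @ a  # Number operator a^dag a


def mean_photon_number(theta):
    # Evolve the vacuum for unit time under H = theta * (a + a^dag).
    # exp(-iH) is a displacement by alpha = -i*theta, so <n> = theta**2.
    H = theta * (a + a.conj().T)
    U = expm(-1j * H)
    psi0 = jnp.zeros(N, dtype=jnp.complex64).at[0].set(1.0)
    psi = U @ psi0
    return jnp.real(jnp.vdot(psi, num @ psi))


value = mean_photon_number(0.3)
slope = jax.grad(mean_photon_number)(0.3)  # d<n>/dtheta, analytically 2*theta
print(value, slope)  # approximately 0.09 and 0.6
```

The gradient flows through the matrix exponential and the complex intermediate state automatically, which is the property that makes gradient-based design and control possible.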

Installation

Installing from source (recommended)

As this is a rapidly evolving project, we recommend installing the latest version of jaxquantum from source:

pip install git+https://github.com/EQuS/jaxquantum.git

If you are installing on a GPU (NVIDIA, CUDA12), then run this instead:

pip install 'git+https://github.com/EQuS/jaxquantum.git#egg=jaxquantum[gpu]'

On a TPU, run this instead:

pip install 'git+https://github.com/EQuS/jaxquantum.git#egg=jaxquantum[tpu]'

If you face issues running JAX on your hardware, visit this page: https://docs.jax.dev/en/latest/installation.html

Installing from source in editable mode (recommended for developers)

If you are interested in contributing to the package, please clone this repository, change into its root directory, and install the package in editable mode:

pip install -e ".[dev,docs]"

This also installs the dev and docs extras, which are useful when developing the package. Because the package is installed in editable mode, it is updated automatically whenever you pull new changes into the repository. As before, add the gpu or tpu extra if needed.

Installing from PyPI (not recommended)

jaxquantum is also published on PyPI. To install the released package, run:

pip install jaxquantum

If you are installing on a GPU (NVIDIA, CUDA12), then run this instead:

pip install 'jaxquantum[gpu]'

On a TPU, run this instead:

pip install 'jaxquantum[tpu]'

If you face issues running JAX on your hardware, visit this page: https://docs.jax.dev/en/latest/installation.html

For more details, please visit the Getting Started > Installation section of our docs.
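Since the PyPI release can lag behind the source repository, it can help to confirm which version your environment actually resolved. A minimal sketch using the standard library's importlib.metadata, which reads the installed distribution's metadata and does not rely on any jaxquantum-specific attribute (installed_version is a hypothetical helper name):

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name="jaxquantum"):
    # Return the installed distribution's version string, or None if it is not installed
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


print(installed_version())  # e.g. "0.3.0", or None if jaxquantum is not installed
```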

Check Hardware

To check which hardware JAX is running on, run the following Python code:

import jax.numpy as jnp
x = jnp.array([1.0, 2.0, 3.0])
print(x.device)

This will, for example, print out cuda:0 if running on a GPU.
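JAX can also report the default backend and every visible device directly, which is a quick way to confirm that the gpu or tpu extra was picked up. A minimal sketch using jax.default_backend() and jax.devices():

```python
import jax

backend = jax.default_backend()  # e.g. "cpu", "gpu", or "tpu"
devices = jax.devices()          # devices of the default backend
print(backend)
for d in devices:
    print(d.platform, d.id)  # platform name and device index
```

If this prints cpu on a machine with an accelerator, the accelerator-specific JAX build was most likely not installed.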

An Example

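Here's an example of how to set up a simulation in jaxquantum. It evolves a damped cavity prepared in a weak coherent state under the Lindblad master equation, batching over two decay rates kappa in a single mesolve call.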

from jax import jit
import jaxquantum as jqt 
import jax.numpy as jnp
import matplotlib.pyplot as plt

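N = 100  # Fock-space truncation of the cavity

omega_a = 2.0*jnp.pi*5.0  # Cavity frequency (rad/ns, i.e. 5 GHz given the ns time axis)
kappa = 2*jnp.pi*jnp.array([1,2]) # Batching to explore two different kappa values!
initial_state = jqt.displace(N, 0.1) @ jqt.basis(N,0)  # Coherent state |alpha = 0.1>
initial_state_dm = initial_state.to_dm()  # Density-matrix form of the initial state, passed to mesolve
ts = jnp.linspace(0, 4*2*jnp.pi/omega_a, 101)  # Four cavity periods, 101 time points

a = jqt.destroy(N)  # Annihilation operator
n = a.dag() @ a  # Number operator

c_ops = jqt.Qarray.from_list([jnp.sqrt(kappa)*a])  # Photon-loss collapse operator, batched over kappa

@jit
def Ht(t):
    # Hamiltonian as a function of time; here it is static, H = omega_a * a^dag a
    H0 = omega_a*n
    return H0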

solver_options = jqt.SolverOptions.create(progress_meter=True)
states = jqt.mesolve(Ht, initial_state_dm, ts, c_ops=c_ops, solver_options=solver_options) 
nt = jnp.real(jqt.overlap(n, states))
a_expect = jqt.overlap(a, states)  # <a(t)>, computed once and split into real/imaginary parts
a_real = jnp.real(a_expect)
a_imag = jnp.imag(a_expect)

fig, axs = plt.subplots(2,1, dpi=200, figsize=(6,5))
ax = axs[0]
ax.plot(ts, a_real[:,0], label=r"$Re[\langle a(t)\rangle]$, $\kappa/2\pi=1$", color="blue") # Batch kappa value 0
ax.plot(ts, a_real[:,1], "--", label=r"$Re[\langle a(t)\rangle]$, $\kappa/2\pi=2$", color="blue") # Batch kappa value 1
ax.plot(ts, a_imag[:,0], label=r"$Im[\langle a(t)\rangle]$, $\kappa/2\pi=1$", color="red") # Batch kappa value 0
ax.plot(ts, a_imag[:,1], "--", label=r"$Im[\langle a(t)\rangle]$, $\kappa/2\pi=2$", color="red") # Batch kappa value 1
ax.set_xlabel("Time (ns)")
ax.set_ylabel("Expectations")
ax.legend()

ax = axs[1]
ax.plot(ts, nt[:,0], label=r"$Re[\langle n(t)\rangle]$, $\kappa/2\pi=1$", color="green") # Batch kappa value 0
ax.plot(ts, nt[:,1], "--", label=r"$Re[\langle n(t)\rangle]$, $\kappa/2\pi=2$", color="green") # Batch kappa value 1
ax.set_xlabel("Time (ns)")
ax.set_ylabel("Expectations")
ax.legend()
fig.tight_layout()
plt.show()

Figure: output of the code above.
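For this particular setup the curves can be checked against a closed form. With the zero-temperature loss operator sqrt(kappa) a used above and the coherent initial state |alpha> with alpha = 0.1, the Lindblad equation gives d<a>/dt = -(i omega_a + kappa/2) <a> and d<n>/dt = -kappa <n>, so <a(t)> = alpha exp[-(i omega_a + kappa/2) t] and <n(t)> = |alpha|^2 exp(-kappa t). The sketch below evaluates these in plain JAX, using jax.vmap to batch over the two kappa values. It is an independent reference, not how jaxquantum batches internally, and it assumes the Fock truncation at N = 100 is negligible (analytic_expectations is a hypothetical helper name).

```python
import jax
import jax.numpy as jnp

omega_a = 2.0 * jnp.pi * 5.0                 # same cavity frequency as the example
kappas = 2 * jnp.pi * jnp.array([1.0, 2.0])  # same two decay rates
alpha = 0.1                                  # displacement used for the initial state
ts = jnp.linspace(0, 4 * 2 * jnp.pi / omega_a, 101)


def analytic_expectations(kappa, ts):
    # Closed-form <a(t)> and <n(t)> for a damped cavity starting in |alpha>
    a_t = alpha * jnp.exp(-(1j * omega_a + kappa / 2) * ts)
    n_t = jnp.abs(alpha) ** 2 * jnp.exp(-kappa * ts)
    return a_t, n_t


# vmap over kappa, then transpose so axis 0 is time and axis 1 is the batch,
# mirroring the a_real[:, 0] / a_real[:, 1] indexing used in the example.
a_ref, n_ref = jax.vmap(analytic_expectations, in_axes=(0, None))(kappas, ts)
a_ref, n_ref = a_ref.T, n_ref.T
print(a_ref.shape, n_ref.shape)  # (101, 2) (101, 2)
```

Assuming nt from the example has the same (time, batch) layout, as its indexing suggests, jnp.max(jnp.abs(nt - n_ref)) gives a quick consistency check; small deviations from solver tolerances are expected. Note that |<a(t)>|^2 = <n(t)> throughout, reflecting that pure loss keeps a coherent state coherent.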

Acknowledgements & History

Core Devs: Shantanu R. Jha, Shoumik Chowdhury, Gabriele Rolleri

This package was initially a small part of bosonic. In early 2022, jaxquantum was extracted and made into its own package. It was briefly announced to the world at the APS March Meeting 2023 and released to a select few academic groups shortly after. Since then, the package has been open-sourced and developed alongside research in the Engineering Quantum Systems Group at MIT, with advice and support from Prof. William D. Oliver.

Citation

Thank you for taking the time to try our package out. If you found it useful in your research, please cite us as follows:

@software{jha2024jaxquantum,
  author = {Shantanu R. Jha and Shoumik Chowdhury and Gabriele Rolleri and Max Hays and Jeff A. Grover and William D. Oliver},
  title  = {JAXQuantum: An auto-differentiable and hardware-accelerated toolkit for quantum hardware design, simulation, and control},
  url    = {https://jaxquantum.org},
  version = {0.2.2},
  year   = {2025},
}

S. R. Jha, S. Chowdhury, G. Rolleri, M. Hays, J. A. Grover, and W. D. Oliver. "JAXQuantum: An auto-differentiable and hardware-accelerated toolkit for quantum hardware design, simulation, and control," jaxquantum.org (2025).

Contributions & Contact

This package is open source and, as such, very open to contributions. Please don't hesitate to open an issue, report a bug, request a feature, or create a pull request. We are also open to deeper collaborations to create a tool that is more useful for everyone. If a discussion would be helpful, please email shanjha@mit.edu to set up a meeting.


Download files


Source Distribution

jaxquantum-0.3.0.tar.gz (333.3 kB)

Uploaded Source

Built Distribution


jaxquantum-0.3.0-py3-none-any.whl (352.4 kB)

Uploaded Python 3

File details

Details for the file jaxquantum-0.3.0.tar.gz.

File metadata

  • Download URL: jaxquantum-0.3.0.tar.gz
  • Upload date:
  • Size: 333.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jaxquantum-0.3.0.tar.gz
Algorithm Hash digest
SHA256 02d4fc890ad077696ab4dcdce891a813d1b8b3fac979d3193fa4a1420c79751e
MD5 da05331100d2df86427e5fe4f6410f85
BLAKE2b-256 cb14fc0f7d8c8bf73fcf087fdab170cca2a817e7125f5cc42583821056450505
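To check a downloaded archive against the SHA256 digest listed above, you can hash it locally with Python's standard hashlib module. A minimal sketch, assuming the source distribution was saved as jaxquantum-0.3.0.tar.gz in the current directory (sha256_of is a hypothetical helper name):

```python
import hashlib
from pathlib import Path

# SHA256 digest published above for jaxquantum-0.3.0.tar.gz
EXPECTED_SHA256 = "02d4fc890ad077696ab4dcdce891a813d1b8b3fac979d3193fa4a1420c79751e"


def sha256_of(path, chunk_size=1 << 20):
    # Stream the file in chunks so large archives need not fit in memory
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


archive = Path("jaxquantum-0.3.0.tar.gz")
if archive.exists():
    print(sha256_of(archive) == EXPECTED_SHA256)  # True if the download is intact
```

The same helper works for the wheel if you swap in its file name and the wheel's SHA256 digest.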


Provenance

The following attestation bundles were made for jaxquantum-0.3.0.tar.gz:

Publisher: publish.yml on EQuS/jaxquantum

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jaxquantum-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: jaxquantum-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 352.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jaxquantum-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 83035432802c325f0bcb7bc51ea3dbe94b96636b314a3a942c4b8a67aa212832
MD5 f466cf145eeb3be69779e3dbf955c708
BLAKE2b-256 2129575030881f8bc9e8140ba0b3ef41f4b1ede982ce44b781f139bdca043b6f


Provenance

The following attestation bundles were made for jaxquantum-0.3.0-py3-none-any.whl:

Publisher: publish.yml on EQuS/jaxquantum

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
