
A simple JAX graphics library

Project description

JaxGL

JaxGL is a simple and flexible graphics library written entirely in JAX. JaxGL was created by Michael Matthews and Michael Beukman for the Kinetix project.

💻 Basic Usage

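import jax.numpy as jnp
from jaxgl.renderer import clear_screen, make_renderer
from jaxgl.shaders import fragment_shader_triangle

# 512x512 pixels
screen_size = (512, 512)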

# Clear a fresh screen with a black background
clear_colour = jnp.array([0.0, 0.0, 0.0])
pixels = clear_screen(screen_size, clear_colour)

# We render to a 256x256 'patch'
patch_size = (256, 256)
triangle_renderer = make_renderer(screen_size, fragment_shader_triangle, patch_size)

# Patch position (top left corner)
pos = jnp.array([128, 128])

triangle_data = (
    # Vertices (these must be listed in anti-clockwise order)
    jnp.array([[150, 200], [150, 300], [300, 150]]),
    # Colour
    jnp.array([255.0, 0.0, 0.0]),
)

# Render the triangle to the screen
pixels = triangle_renderer(pixels, pos, triangle_data)

This produces the following image:
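
The README says the triangle's vertices must be anti-clockwise but does not explain why. One common reason in rasterisers is the edge-function test: a pixel is inside the triangle when it lies on the same side of all three directed edges, and a shader that checks one fixed sign only accepts one winding order. The sketch below illustrates that idea in plain JAX. It is an assumption for illustration, not JaxGL's `fragment_shader_triangle`; `edge` and `sketch_triangle_shader` are hypothetical names, and positions are assumed to be `(x, y)` pixel coordinates. Whether a vertex list counts as "anti-clockwise" depends on whether y points up or down on screen, so the sketch simply uses the sign that accepts the example's vertices.

```python
import jax.numpy as jnp


def edge(a, b, p):
    # 2D cross product of (b - a) and (p - a). Its sign tells which side of
    # the directed edge a -> b the point p lies on (zero means on the line).
    return (b[0] - a[0]) * (p[1] - a[1]) - (b[1] - a[1]) * (p[0] - a[0])


def sketch_triangle_shader(position, current_frag, unit_position, uniform):
    # Hypothetical shader using the same argument order as this README's shaders.
    vertices, colour = uniform
    v0, v1, v2 = vertices[0], vertices[1], vertices[2]
    # Inside means the point is on the same (non-positive) side of all three
    # edges. Only one winding order passes this fixed-sign test.
    inside = (
        (edge(v0, v1, position) <= 0)
        & (edge(v1, v2, position) <= 0)
        & (edge(v2, v0, position) <= 0)
    )
    # Keep the existing fragment colour outside the triangle.
    return jnp.where(inside, colour, current_frag)


# Usage with the vertices and colour from the example above.
vertices = jnp.array([[150.0, 200.0], [150.0, 300.0], [300.0, 150.0]])
red = jnp.array([255.0, 0.0, 0.0])
black = jnp.zeros(3)
unit = jnp.zeros(2)  # unused by this shader

inside_px = sketch_triangle_shader(jnp.array([200.0, 216.0]), black, unit, (vertices, red))
outside_px = sketch_triangle_shader(jnp.array([400.0, 400.0]), black, unit, (vertices, red))
# Reversing the vertex order flips every edge's sign, so the same interior
# point is now rejected and keeps the background colour.
reversed_px = sketch_triangle_shader(jnp.array([200.0, 216.0]), black, unit, (vertices[::-1], red))
```

Under this kind of test, passing the vertices in the opposite order would silently draw nothing, which is one plausible reason a library would fix the winding order.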

👨‍💻 Custom Shaders

Arbitrary rendering effects can be achieved by writing your own shaders.

import jax
import jax.numpy as jnp
from jaxgl.renderer import clear_screen, make_renderer

screen_size = (512, 512)

clear_colour = jnp.array([0.0, 0.0, 0.0])
pixels = clear_screen(screen_size, clear_colour)

patch_size = (256, 256)

# We make our own variation of the circle shader
# We give both a central and edge colour and interpolate between these

# Each fragment shader has access to
# position: global position in screen space
# current_frag: the current colour of the fragment (useful for transparency)
# unit_position: the position inside the patch (scaled to between 0 and 1)
# uniform: any data your shader needs; it is the same for every fragment.

def my_shader(position, current_frag, unit_position, uniform):
    centre, radius, colour_centre, colour_outer = uniform

    dist = jnp.sqrt(jnp.square(position - centre).sum())
    colour_interp = dist / radius

    colour = colour_interp * colour_outer + (1 - colour_interp) * colour_centre

    return jax.lax.select(dist < radius, colour, current_frag)

circle_renderer = make_renderer(screen_size, my_shader, patch_size)

# Patch position (top left corner)
pos = jnp.array([128, 128])

# This is the uniform that is passed to the shader
circle_data = (
    # Centre
    jnp.array([256.0, 256.0]),
    # Radius
    100.0,
    # Colour centre
    jnp.array([255.0, 0.0, 0.0]),
    # Colour outer
    jnp.array([0.0, 255.0, 0.0]),
)

# Render the circle to the screen
pixels = circle_renderer(pixels, pos, circle_data)
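
The README does not describe how `make_renderer` applies a fragment shader to a patch. The following is a minimal pure-JAX sketch of one way such a renderer could work, showing where the four shader arguments listed above might come from: each patch pixel's screen position, the colour already on screen, the position within the patch scaled to [0, 1), and a shared uniform. It assumes pixels are stored as a `(width, height, channels)` float array indexed by `(x, y)`; that layout is not confirmed by the source, and `sketch_make_renderer` and `blend_shader` are hypothetical names, not JaxGL's API.

```python
import jax
import jax.numpy as jnp


def sketch_make_renderer(screen_size, shader, patch_size):
    # Hypothetical stand-in for make_renderer; NOT JaxGL's implementation.
    # screen_size is unused (bounds come from the pixels array); it is kept
    # only so the call mirrors make_renderer(screen_size, shader, patch_size).
    patch_w, patch_h = patch_size

    # Offsets of every pixel in the patch as (x, y), shape (patch_w, patch_h, 2).
    xs, ys = jnp.meshgrid(jnp.arange(patch_w), jnp.arange(patch_h), indexing="ij")
    offsets = jnp.stack([xs, ys], axis=-1).astype(jnp.float32)
    # Position inside the patch, scaled to [0, 1).
    unit_positions = offsets / jnp.array([patch_w, patch_h], dtype=jnp.float32)

    # Map the per-fragment shader over both patch axes; the uniform (None)
    # is shared by every fragment.
    per_column = jax.vmap(shader, in_axes=(0, 0, 0, None))
    shade_patch = jax.vmap(per_column, in_axes=(0, 0, 0, None))

    def render(pixels, pos, uniform):
        # Start index (x, y, channel). dynamic_slice clamps out-of-range
        # starts, so this sketch assumes the patch lies fully on screen.
        start = jnp.concatenate([pos.astype(jnp.int32), jnp.zeros(1, jnp.int32)])
        size = (patch_w, patch_h, pixels.shape[-1])
        current = jax.lax.dynamic_slice(pixels, start, size)
        positions = offsets + pos.astype(jnp.float32)
        shaded = shade_patch(positions, current, unit_positions, uniform)
        # Write the shaded patch back into the screen.
        return jax.lax.dynamic_update_slice(pixels, shaded.astype(pixels.dtype), start)

    return render


def blend_shader(position, current_frag, unit_position, uniform):
    # Blend a flat colour over what is already on screen (uses current_frag).
    colour, alpha = uniform
    return alpha * colour + (1.0 - alpha) * current_frag


# Usage: blend half-transparent red over a 4x4 patch of an 8x8 black screen.
pixels = jnp.zeros((8, 8, 3))
render = jax.jit(sketch_make_renderer((8, 8), blend_shader, (4, 4)))
pixels = render(pixels, jnp.array([2, 2]), (jnp.array([1.0, 0.0, 0.0]), 0.5))
```

Writing shaders as per-pixel functions and vectorising them with `jax.vmap` keeps each shader simple, while `jax.lax.dynamic_slice` and `jax.lax.dynamic_update_slice` allow the patch position to be a traced value under `jax.jit`.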

🔄 In Kinetix

JaxGL is used for rendering in Kinetix. Shown below is an example robotic grasping task.

⬇️ Installation

To use JaxGL in your work, install it from PyPI:

pip install jaxgl

If you want to extend JaxGL, install it from source in editable mode with the development dependencies, then set up the pre-commit hooks:

git clone https://github.com/FLAIROx/JaxGL
cd JaxGL
pip install -e ".[dev]"
pre-commit install

🔍 See Also

  • JAX Renderer: a more complete JAX renderer, better suited to 3D rendering.
  • Jax2D: a 2D physics engine in JAX.
  • Kinetix: physics-based reinforcement learning in JAX.

Download files

Source Distribution

jaxgl-1.0.1.tar.gz (6.5 kB)

Built Distribution

jaxgl-1.0.1-py3-none-any.whl (6.8 kB)

File details

Details for the file jaxgl-1.0.1.tar.gz.

File metadata

  • Download URL: jaxgl-1.0.1.tar.gz
  • Size: 6.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.12

File hashes

Hashes for jaxgl-1.0.1.tar.gz
Algorithm Hash digest
SHA256 051a1975f2e1932c64b98e9770473ed5c102387f69fa2edded214439de8688e0
MD5 2cfe88cc8463ec2d5c25f38338cea6b6
BLAKE2b-256 504dbfa0f8bf4fd3e85312a77cacba2a507d22feb211e9151bc4ea5da14e6050

File details

Details for the file jaxgl-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: jaxgl-1.0.1-py3-none-any.whl
  • Size: 6.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.12

File hashes

Hashes for jaxgl-1.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 b9940ef8a578e4571a4a8ffcdb1e1a2b686df7f1cd46cb6e680437edbffe947e
MD5 7565e739540785cd6ac841933df6d579
BLAKE2b-256 f3e13ac61628a0cc048a83a088a67603486d7fbfd68ee0e8f8a22b41a0ddecf6
