
A simple JAX graphics library


JaxGL

JaxGL is a simple and flexible graphics library written entirely in JAX. JaxGL was created by Michael Matthews and Michael Beukman for the Kinetix project.

Basic Usage

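import jax.numpy as jnp
# Import paths assume the jaxgl package layout (jaxgl.renderer, jaxgl.shaders)
from jaxgl.renderer import clear_screen, make_renderer
from jaxgl.shaders import fragment_shader_triangle

# 512x512 pixels
screen_size = (512, 512)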

# Clear a fresh screen with a black background
clear_colour = jnp.array([0.0, 0.0, 0.0])
pixels = clear_screen(screen_size, clear_colour)

# We render to a 256x256 'patch'
patch_size = (256, 256)
triangle_renderer = make_renderer(screen_size, fragment_shader_triangle, patch_size)

# Patch position (top left corner)
pos = jnp.array([128, 128])

triangle_data = (
    # Vertices (note these must be anti-clockwise)
    jnp.array([[150, 200], [150, 300], [300, 150]]),
    # Colour
    jnp.array([255.0, 0.0, 0.0]),
)

# Render the triangle to the screen
pixels = triangle_renderer(pixels, pos, triangle_data)

This produces the following image:
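The README does not show how `fragment_shader_triangle` decides which fragments to fill, or why the vertices must be anti-clockwise. One common approach is an edge-function test, sketched below as a hypothetical shader with the same four-argument signature used by the custom shaders in the next section. `edge_function` and `triangle_coverage_shader` are illustrative names, not JaxGL functions, and the sign convention (interior points give non-positive edge values) is an assumption chosen to match the README's vertex order rather than something taken from JaxGL's source.

```python
import jax
import jax.numpy as jnp


def edge_function(a, b, p):
    # 2D cross product (b - a) x (p - a): its sign tells which side of the
    # directed edge a -> b the point p lies on (zero means on the edge).
    return (b[0] - a[0]) * (p[1] - a[1]) - (b[1] - a[1]) * (p[0] - a[0])


def triangle_coverage_shader(position, current_frag, unit_position, uniform):
    # Hypothetical triangle shader; NOT JaxGL's fragment_shader_triangle.
    # uniform mirrors the README's triangle_data: (vertices (3, 2), colour (3,)).
    vertices, colour = uniform
    a, b, c = vertices[0], vertices[1], vertices[2]
    e0 = edge_function(a, b, position)
    e1 = edge_function(b, c, position)
    e2 = edge_function(c, a, position)
    # Assumed convention: with the README's vertex order, interior points give
    # non-positive values on all three edges. Reversing the winding flips every
    # sign, so a one-sided test like this one then draws nothing.
    inside = (e0 <= 0) & (e1 <= 0) & (e2 <= 0)
    # Outside the triangle, keep the existing fragment (as the circle shader does)
    return jax.lax.select(inside, colour.astype(current_frag.dtype), current_frag)
```

Called directly, it colours the point (200, 216), which lies inside the README's triangle, red; passing the same vertices in reverse order flips every edge sign, so that point keeps its existing colour. A one-sided test like this is cheaper than checking both signs, which is why a fixed winding order matters in a sketch of this kind.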

Custom Shaders

Arbitrary rendering effects can be achieved by writing your own shaders.

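import jax
import jax.numpy as jnp
# Import path assumes the jaxgl package layout (jaxgl.renderer)
from jaxgl.renderer import clear_screen, make_renderer

screen_size = (512, 512)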

clear_colour = jnp.array([0.0, 0.0, 0.0])
pixels = clear_screen(screen_size, clear_colour)

patch_size = (256, 256)

# We make our own variation of the circle shader
# We give both a central and edge colour and interpolate between these

# Each fragment shader has access to
# position: global position in screen space
# current_frag: the current colour of the fragment (useful for transparency)
# unit_position: the position inside the patch (scaled to between 0 and 1)
# uniform: anything you want for your shader (the same value is passed to every fragment)

def my_shader(position, current_frag, unit_position, uniform):
    centre, radius, colour_centre, colour_outer = uniform

    dist = jnp.sqrt(jnp.square(position - centre).sum())
    colour_interp = dist / radius

    colour = colour_interp * colour_outer + (1 - colour_interp) * colour_centre

    return jax.lax.select(dist < radius, colour, current_frag)

circle_renderer = make_renderer(screen_size, my_shader, patch_size)

# Patch position (top left corner)
pos = jnp.array([128, 128])

# This is the uniform that is passed to the shader
circle_data = (
    # Centre
    jnp.array([256.0, 256.0]),
    # Radius
    100.0,
    # Colour centre
    jnp.array([255.0, 0.0, 0.0]),
    # Colour outer
    jnp.array([0.0, 255.0, 0.0]),
)

# Render the circle to the screen
pixels = circle_renderer(pixels, pos, circle_data)
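`make_renderer` takes care of evaluating a shader over the patch. The sketch below shows one way such a step could work and is not JaxGL's implementation: `render_patch` and `translucent_fill_shader` are hypothetical, `pixels` is assumed to be a `(width, height, channels)` array indexed `[x, y]`, and the patch is assumed to fit on screen (`jax.lax.dynamic_slice` clamps out-of-range start indices). The fill shader uses `current_frag` for alpha blending, the transparency case mentioned in the shader comments above.

```python
import jax
import jax.numpy as jnp


def render_patch(pixels, pos, shader, patch_size, uniform):
    """Hypothetical sketch of patch-based shading; NOT JaxGL's make_renderer.

    Assumptions: pixels has shape (W, H, C) and is indexed [x, y, channel];
    pos is the integer top-left corner of the patch; the patch fits on screen.
    """
    pw, ph = patch_size
    # Integer offset (dx, dy) of every fragment in the patch: shape (pw, ph, 2)
    xs, ys = jnp.meshgrid(jnp.arange(pw), jnp.arange(ph), indexing="ij")
    offsets = jnp.stack([xs, ys], axis=-1)
    # Global screen-space position of each fragment
    positions = (pos + offsets).astype(jnp.float32)
    # Position inside the patch, scaled into [0, 1)
    unit_positions = offsets / jnp.array([pw, ph], dtype=jnp.float32)
    # Current colours under the patch, so shaders can blend with them
    start = (pos[0], pos[1], 0)
    current = jax.lax.dynamic_slice(pixels, start, (pw, ph, pixels.shape[2]))
    # Call the shader once per fragment; the uniform is shared (in_axes=None)
    over_y = jax.vmap(shader, in_axes=(0, 0, 0, None))
    over_xy = jax.vmap(over_y, in_axes=(0, 0, 0, None))
    shaded = over_xy(positions, current, unit_positions, uniform)
    # Write the shaded patch back into the screen
    return jax.lax.dynamic_update_slice(pixels, shaded.astype(pixels.dtype), start)


def translucent_fill_shader(position, current_frag, unit_position, uniform):
    # Hypothetical shader: blend a fill colour over the existing fragment
    colour, alpha = uniform
    return alpha * colour + (1.0 - alpha) * current_frag
```

To compile the sketch, `shader` and `patch_size` must be static because slice sizes are fixed at trace time, e.g. `jax.jit(render_patch, static_argnums=(2, 3))`.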

In Kinetix

JaxGL is used for rendering in Kinetix. Shown below is an example robotic grasping task.

Installation

To use JaxGL in your work, install it from PyPI:

pip install jaxgl

If you want to extend JaxGL, install it from source in editable mode with the development extras:

git clone https://github.com/FLAIROx/JaxGL
cd JaxGL
pip install -e ".[dev]"
pre-commit install

See Also

  • JAX Renderer: a more complete JAX renderer, better suited to 3D rendering.
  • Jax2D: a 2D physics engine in JAX.
  • Kinetix: physics-based reinforcement learning in JAX.



Download files


Source Distribution

jaxgl-1.0.0.tar.gz (5.7 kB)

Built Distribution

jaxgl-1.0.0-py3-none-any.whl (6.3 kB)

File details

Details for the file jaxgl-1.0.0.tar.gz.

File metadata

  • Download URL: jaxgl-1.0.0.tar.gz
  • Size: 5.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.12

File hashes

Hashes for jaxgl-1.0.0.tar.gz
Algorithm Hash digest
SHA256 3466034c5df6d50612813653dd85881a6f53e284065f77dd918c965f2b57dd91
MD5 12a3e84f33188817d53c2f255cf60e9a
BLAKE2b-256 4518937e51f09a2153ecc6f57c424d21444085c3120e10ba5e80f806ebbb17d8


File details

Details for the file jaxgl-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: jaxgl-1.0.0-py3-none-any.whl
  • Size: 6.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.12

File hashes

Hashes for jaxgl-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 a6059a4aaa2d679ca84d76ba624a504552eb9d2ecb82685299ec2d10945bd09d
MD5 cfc8fc89a9e05211d63aacdfd3e37cec
BLAKE2b-256 052887d9b5852d8ee5ae39ce39eb307ffec4725147be12ec3cd36f94c65a8778
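The SHA256 digests listed above can be checked locally before installing a downloaded file. A minimal sketch using Python's standard `hashlib`; the helper names `sha256_of` and `verify_download` are illustrative and not part of JaxGL or pip.

```python
import hashlib
from pathlib import Path

# SHA256 digests published above for the jaxgl 1.0.0 release files
EXPECTED_SHA256 = {
    "jaxgl-1.0.0.tar.gz": "3466034c5df6d50612813653dd85881a6f53e284065f77dd918c965f2b57dd91",
    "jaxgl-1.0.0-py3-none-any.whl": "a6059a4aaa2d679ca84d76ba624a504552eb9d2ecb82685299ec2d10945bd09d",
}


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path):
    """True if the file's SHA256 matches the published digest for its name."""
    expected = EXPECTED_SHA256.get(Path(path).name)
    return expected is not None and sha256_of(path) == expected
```

For example, `verify_download("jaxgl-1.0.0-py3-none-any.whl")` returns `True` only if the local wheel matches the published digest. pip can also enforce hashes itself through its hash-checking mode (`--require-hashes` with `--hash=sha256:...` entries in a requirements file).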

