JAX implementation of a rasterizer renderer.

JAX Renderer: Differentiable Rendering in Batch on Accelerators

JaxRenderer is a differentiable renderer implemented in JAX. It supports differentiable rendering and batch rendering on accelerators (e.g. GPU, TPU) using the simple function transformations provided by JAX. It is designed to replace erwincoumans/tinyrenderer in BRAX, so that simulation results can be visualised through fast rendering on accelerators with no external dependencies (other than JAX).

You may find the slides of my final year project presentation useful: they give a brief introduction to the renderer and its implementation details, including the design of the pipeline and a comparison with OpenGL's.

Installation

This project is distributed on PyPI and can be installed with pip:

pip install jaxrenderer

The minimum Python version is 3.8, and the minimum JAX version is 0.4.0. You may need to install jaxlib separately if you are using GPU or TPU; by default, the CPU version of jaxlib is installed. Please refer to JAX's installation guide for more details.

Usage

Please note that the package is imported with name renderer rather than the PyPI package name jaxrenderer. This may change in the future though.

Some example scripts are provided in the examples folder. You may find the demo notebook useful as well; it contains batch rendering and differentiable rendering examples.

The following is a simple example of rendering a cube with a texture map:

import jax.numpy as jnp
import renderer


ImageWidth: int = 640
ImageHeight: int = 480

# Create a cube with texture map of pure blue
cube = renderer.create_cube(
    half_extents=jnp.ones(3, dtype=jnp.single),
    texture_scaling=jnp.ones(2, dtype=jnp.single),
    # pure blue texture map
    diffuse_map=jnp.zeros((2, 2, 3), dtype=jnp.single).at[..., 2].set(1),
    specular_map=jnp.ones((2, 2), dtype=jnp.single) * 2.0,
)

# Render the cube
image = renderer.Renderer.get_camera_image(
    objects=[renderer.ModelObject(model=cube)],
    # Use defaults except for the view size and camera position
    camera=renderer.CameraParameters(
        viewWidth=ImageWidth,
        viewHeight=ImageHeight,
        position=jnp.array([2.0, 4.0, 1.0], dtype=jnp.single),
    ),
    # Simply use defaults
    light=renderer.LightParameters(),
    width=ImageWidth,
    height=ImageHeight,
)

You may refer to the demo for more complex examples, including differentiable rendering and batch rendering.
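The batching that JAX transformations provide can be illustrated without the renderer itself. The following is a minimal sketch in plain JAX and does not use the jaxrenderer API: `shade` is a hypothetical toy function that computes Lambertian shading for a single surface point, and `jax.vmap` turns it into a function over a batch of light directions.

```python
import jax
import jax.numpy as jnp


def shade(normal, light_dir, light_colour):
    """Toy Lambertian shading of one surface point.

    Hypothetical helper for illustration only; not part of jaxrenderer.
    """
    intensity = jnp.maximum(jnp.dot(normal, light_dir), 0.0)
    return intensity * light_colour


normal = jnp.array([0.0, 0.0, 1.0])
light_colour = jnp.array([1.0, 0.5, 0.25])
# A batch of three unit-length light directions.
light_dirs = jnp.array(
    [
        [0.0, 0.0, 1.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, -1.0],
    ]
)

# Map over the leading axis of light_dirs only; normal and colour are shared.
batched_shade = jax.jit(jax.vmap(shade, in_axes=(None, 0, None)))
colours = batched_shade(normal, light_dirs, light_colour)  # shape (3, 3)
```

The same pattern (write the function for one item, then wrap it with `jax.vmap` and `jax.jit`) is what makes batch rendering on accelerators cheap to express.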

Supported Shaders

Built-in Shaders

See renderer/shaders for more details.

| Shader Name | Description | Light Direction |
| --- | --- | --- |
| depth | Depth Shader, outputs only screen-space depth value | N.A. |
| gouraud | Gouraud Shading, interpolates vertex colour and outputs it as fragment colour | In model space |
| gouraud_texture | Gouraud Shading with Texture, interpolates vertex colour and samples texture map in fragment shader | In model space |
| phong | Phong Shading, interpolates vertex normal and computes light direction in fragment shader | In eye space, like "head light" |
| phong_darboux | Phong Shading with Normal Map in Tangent Space, interpolates vertex normal and computes light direction in fragment shader, and samples normal map in tangent space | In eye space, like "head light" |
| phong_reflection | Phong Shading with Phong Reflection Approximation, interpolates vertex normal and computes light direction in fragment shader, and samples texture map and specular map in fragment shader | In eye space |
| phong_reflection_shadow | Phong Shading with Phong Reflection Approximation and Shadow, interpolates vertex normal and computes light direction in fragment shader, samples texture map and specular map in fragment shader, and tests shadow in fragment shader | In eye space |

Custom Shaders

You may implement your own shaders by inheriting from Shader and implementing the following methods:

  • vertex: this is like the vertex shader in OpenGL; it must be overridden.
  • primitive_chooser: at this stage the visibility of each primitive is tested at every pixel. It works like the pre-z test in OpenGL and makes the pipeline behave like a deferred shading pipeline. Note that you may override it to return more than one primitive per pixel to support transparency. The default implementation simply chooses the primitive with the minimum z value (depth).
  • interpolate: this controls how the attributes associated with a fragment are interpolated from the vertices; the default implementation interpolates everything using perspective interpolation.
  • fragment: this is like the fragment shader in OpenGL; a default implementation is provided if you do not need to process any data in the fragment shader.
  • mix: this is like the blending stage in OpenGL; the default implementation simply uses the data from the fragment with the minimum screen-space z value (depth).
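To make two of the default behaviours above more concrete, here is a rough sketch in plain JAX of choosing the nearest primitive for one pixel and of perspective-correct interpolation. The function names, argument layouts and the use of -1 for "no primitive" are assumptions made for illustration; they are not the signatures of the Shader methods in this package.

```python
import jax.numpy as jnp


def choose_nearest(depths, covered):
    """Index of the covering primitive with the smallest depth at one pixel.

    depths:  (P,) screen-space depth of each primitive at this pixel
    covered: (P,) bool, whether each primitive covers the pixel
    Returns -1 when no primitive covers the pixel (an illustrative convention).
    """
    masked = jnp.where(covered, depths, jnp.inf)
    index = jnp.argmin(masked)
    return jnp.where(covered.any(), index, -1)


def perspective_interpolate(values, barycentric, clip_w):
    """Perspective-correct interpolation of per-vertex attributes.

    values:      (3, C) attribute at each triangle vertex
    barycentric: (3,) screen-space barycentric coordinates of the pixel
    clip_w:      (3,) clip-space w of each vertex
    """
    # Divide screen-space weights by w, then renormalise so they sum to one.
    weights = barycentric / clip_w
    weights = weights / weights.sum()
    return weights @ values


nearest = choose_nearest(
    jnp.array([0.5, 0.2, 0.1]), jnp.array([True, True, False])
)
colour = perspective_interpolate(
    jnp.eye(3), jnp.array([0.25, 0.25, 0.5]), jnp.array([1.0, 1.0, 2.0])
)
```

In the usage lines, the third primitive is closest but does not cover the pixel, so the second one is chosen; the vertex with the larger w contributes less than its screen-space weight suggests.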

Gallery

Batch Rendering Example, 30 Ants inference on A100 GPU with 90 iterations, rendered onto 84x84 canvas in 5.26s

Phong Reflection Model + Hard Shadow, 30 frames 1920x1080, 2492 triangles in 9.25s

Differentiable Rendering Toy Example, deduce light colour parameters
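The idea behind deducing light colour by differentiation can be sketched in plain JAX. This is not the code used for the gallery image: `render` below is a hypothetical stand-in that scales a light colour by fixed per-pixel diffuse intensities, and the light colour is recovered from a target image by gradient descent on a mean squared error.

```python
import jax
import jax.numpy as jnp


def render(light_colour, intensities):
    """Toy 'renderer': per-pixel diffuse intensity times the light colour.

    Hypothetical stand-in for illustration; not the jaxrenderer pipeline.
    intensities: (N,), light_colour: (3,), result: (N, 3)
    """
    return intensities[:, None] * light_colour[None, :]


def loss(light_colour, intensities, target):
    """Mean squared error between the rendered and the target pixels."""
    return jnp.mean((render(light_colour, intensities) - target) ** 2)


intensities = jnp.linspace(0.1, 1.0, 16)
true_colour = jnp.array([0.2, 0.6, 0.9])
target = render(true_colour, intensities)

# Gradient of the loss with respect to the first argument (the light colour).
grad_fn = jax.jit(jax.grad(loss))

estimate = jnp.ones(3)  # initial guess: white light
for _ in range(200):
    estimate = estimate - 1.0 * grad_fn(estimate, intensities, target)
```

After the loop, `estimate` is close to `true_colour`. With a real renderer the structure stays the same; only `render` is replaced by the actual rendering call.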

Key Differences from erwincoumans/tinyrenderer

  • Native JAX implementation that supports jit, vmap, grad, etc.
  • Lighting is computed in the main camera's eye space, whereas PyTinyrenderer computes it in world space.
  • Texture specification is different: PyTinyrenderer takes the texture as a flattened array, whereas JAX Renderer takes it as an array of shape (width, height, colour channels). A simple way to convert the old specification to the new one is the convenience method build_texture_from_PyTinyrenderer.
  • The rendering pipeline is different. PyTinyrenderer renders one object at a time and shares the z-buffer and framebuffer across one pass. This renderer first merges all objects into one big mesh in world space, then processes all vertices together, then interpolates, rasterises and renders. Fragment shading sweeps over the rows in a for loop and computes all pixels of a row in a batch. For each pixel, all fragments for that pixel are computed in a batch and then mixed. This is more memory efficient and allows vmap batching as far as possible.
  • Shadowing within the same object / mesh is allowed. This is not possible in PyTinyrenderer, which deliberately checks whether the shadow comes from the same object and, if so, does not draw a shadow there.
  • Quaternions (for specifying rotation/orientation) are in the form (w, x, y, z) instead of the (x, y, z, w) used in PyTinyrenderer. This is for consistency with BRAX.
  • No clipping is performed. To render objects with vertices at or behind the camera plane correctly, homogeneous interpolation (Olano and Greer, 1997)[^1] is used, which avoids the need for homogeneous division.
  • Bug fixes
    • Specular lighting was wrong: the light direction vector was not reversed.
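When porting orientations from code that uses the (x, y, z, w) convention, the components only need to be rotated by one position. A minimal sketch with two hypothetical helper functions (these names are not part of this package):

```python
import jax.numpy as jnp


def xyzw_to_wxyz(quaternion):
    """Reorder (x, y, z, w) into (w, x, y, z) along the last axis."""
    return jnp.roll(quaternion, shift=1, axis=-1)


def wxyz_to_xyzw(quaternion):
    """Reorder (w, x, y, z) into (x, y, z, w) along the last axis."""
    return jnp.roll(quaternion, shift=-1, axis=-1)


# Identity rotation in the (x, y, z, w) convention.
identity_xyzw = jnp.array([0.0, 0.0, 0.0, 1.0])
identity_wxyz = xyzw_to_wxyz(identity_xyzw)
```

Because the helpers act on the last axis, they also work on batches of quaternions with shape (N, 4).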

[^1]: Marc Olano and Trey Greer. 1997. Triangle Scan Conversion Using 2D Homogeneous Coordinates. In Proceedings of the ACM SIGGRAPH/EUROGRAPHICS Workshop on Graphics Hardware (HWWS ’97). ACM, New York, NY, USA, 89–95.
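The row sweep mentioned in the pipeline comparison above (a loop over rows, with all pixels of a row computed in a batch) can be pictured with the following plain JAX sketch. It is an illustration of the pattern under assumptions, not the renderer's actual code: `shade_pixel` is a hypothetical toy function, and `jax.lax.map` stands in for the loop over rows.

```python
import jax
import jax.numpy as jnp
from jax import lax

WIDTH, HEIGHT = 8, 4


def shade_pixel(x, y):
    """Toy per-pixel function (hypothetical): a simple colour gradient."""
    return jnp.stack([x / (WIDTH - 1), y / (HEIGHT - 1), jnp.zeros_like(x)])


def shade_row(y):
    """Compute every pixel of one row in a single batched call."""
    xs = jnp.arange(WIDTH, dtype=jnp.float32)
    return jax.vmap(shade_pixel, in_axes=(0, None))(xs, y)


rows = jnp.arange(HEIGHT, dtype=jnp.float32)
# Rows are processed sequentially, so only one row of intermediates is live at a time.
image = lax.map(shade_row, rows)  # shape (HEIGHT, WIDTH, 3)
```

Looping over rows while batching within a row trades some parallelism for a smaller peak memory footprint than batching over the whole image at once.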

Roadmap

  • Support double-sided objects
  • Profile and accelerate implementation
  • Build a ray tracer as well
  • Differentiable rendering with respect to mesh
  • Differentiable rendering with respect to light parameters
  • Differentiable rendering with respect to camera parameters (not tested)
  • Correctly implement a proper clipping algorithm
