A Homomorphic Encryption implementation (CGGI) written in JAX

Jaxite

Jaxite is a fully homomorphic encryption backend targeting TPUs and GPUs, written in JAX.

It implements the CGGI cryptosystem with some optimizations, and is a supported backend for Google's FHE compiler.

Quick start

jaxite_bool

A program that shows how to use the jaxite_bool boolean gate API.

from jaxite import jaxite_bool

bool_params = jaxite_bool.bool_params

# Seeded RNGs and parameters for the 128-bit security setting.
lwe_rng = bool_params.get_lwe_rng_for_128_bit_security(seed=1)
rlwe_rng = bool_params.get_rlwe_rng_for_128_bit_security(seed=1)
params = bool_params.get_params_for_128_bit_security()

# Client key set: used to encrypt inputs and decrypt results.
cks = jaxite_bool.ClientKeySet(
    params,
    lwe_rng=lwe_rng,
    rlwe_rng=rlwe_rng,
)
# Server key set: passed to every gate that is evaluated on ciphertexts.
# NOTE: `callback` is not defined in this snippet; it must be defined by
# the caller before this point.
sks = jaxite_bool.ServerKeySet(
    cks,
    params,
    lwe_rng=lwe_rng,
    rlwe_rng=rlwe_rng,
    bootstrap_callback=callback,
)

ct_true = jaxite_bool.encrypt(True, cks, lwe_rng)
ct_false = jaxite_bool.encrypt(False, cks, lwe_rng)

# Evaluate (((not False) or False) and True) xor True on encrypted bits.
not_false = jaxite_bool.not_(ct_false, params)
or_false = jaxite_bool.or_(not_false, ct_false, sks, params)
and_true = jaxite_bool.and_(or_false, ct_true, sks, params)
xor_true = jaxite_bool.xor_(and_true, ct_true, sks, params)
actual = jaxite_bool.decrypt(xor_true, cks)

# The same expression on plaintext booleans; `!=` acts as XOR.
expected = (((not False) or False) and True) != True
assert actual == expected

Jaxite also supports higher-bit-width gates, which are called look-up tables (LUTs).

For example, this function is an 8-bit adder consisting entirely of lut3 gates.

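from typing import Dict, List

from jaxite import jaxite_bool
# Assumed import path for the `types` module; adjust to your jaxite version.
from jaxite.jaxite_lib import types


def add_i8(
    x: List[types.LweCiphertext],
    y: List[types.LweCiphertext],
    sks: jaxite_bool.ServerKeySet,
    params: jaxite_bool.Parameters) -> List[types.LweCiphertext]: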
  temp_nodes: Dict[int, types.LweCiphertext] = {}
  false = jaxite_bool.constant(False, params)
  out = [None] * 8
  temp_nodes[0] = jaxite_bool.lut3(x[0], y[0], false, 8, sks, params)
  temp_nodes[1] = jaxite_bool.lut3(temp_nodes[0], x[1], y[1], 23, sks, params)
  temp_nodes[2] = jaxite_bool.lut3(temp_nodes[1], x[2], y[2], 43, sks, params)
  temp_nodes[3] = jaxite_bool.lut3(temp_nodes[2], x[3], y[3], 43, sks, params)
  temp_nodes[4] = jaxite_bool.lut3(temp_nodes[3], x[4], y[4], 43, sks, params)
  temp_nodes[5] = jaxite_bool.lut3(temp_nodes[4], x[5], y[5], 43, sks, params)
  temp_nodes[6] = jaxite_bool.lut3(temp_nodes[5], x[6], y[6], 43, sks, params)
  out[0] = jaxite_bool.lut3(x[0], y[0], false, 6, sks, params)
  out[1] = jaxite_bool.lut3(temp_nodes[0], x[1], y[1], 150, sks, params)
  out[2] = jaxite_bool.lut3(temp_nodes[1], x[2], y[2], 105, sks, params)
  out[3] = jaxite_bool.lut3(temp_nodes[2], x[3], y[3], 105, sks, params)
  out[4] = jaxite_bool.lut3(temp_nodes[3], x[4], y[4], 105, sks, params)
  out[5] = jaxite_bool.lut3(temp_nodes[4], x[5], y[5], 105, sks, params)
  out[6] = jaxite_bool.lut3(temp_nodes[5], x[6], y[6], 105, sks, params)
  out[7] = jaxite_bool.lut3(temp_nodes[6], x[7], y[7], 105, sks, params)
  return out
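
To make the truth-table constants in add_i8 easier to read, here is a minimal plaintext sketch of the same circuit. It is not part of Jaxite: lut3_plain, add_i8_plain, to_bits, and from_bits are hypothetical helpers, and the sketch assumes that the first lut3 input is the least significant bit of the table index. That assumption is consistent with the constants above (8 is AND, 6 is XOR, 23 and 43 produce an inverted carry, 150 is a three-input XOR, 105 is its complement), but it is not confirmed against the library.

```python
from typing import List


def lut3_plain(a: bool, b: bool, c: bool, truth_table: int) -> bool:
    """Plaintext 3-input LUT; assumes `a` is the least significant index bit."""
    index = int(a) | (int(b) << 1) | (int(c) << 2)
    return bool((truth_table >> index) & 1)


def add_i8_plain(x: List[bool], y: List[bool]) -> List[bool]:
    """Mirrors add_i8 gate for gate, on plain booleans (little-endian bits)."""
    temp = [False] * 7
    out = [False] * 8
    temp[0] = lut3_plain(x[0], y[0], False, 8)        # carry out of bit 0
    temp[1] = lut3_plain(temp[0], x[1], y[1], 23)     # inverted carry out of bit 1
    for i in range(2, 7):
        temp[i] = lut3_plain(temp[i - 1], x[i], y[i], 43)  # inverted carry chain
    out[0] = lut3_plain(x[0], y[0], False, 6)         # sum bit 0
    out[1] = lut3_plain(temp[0], x[1], y[1], 150)     # sum bit 1, true carry in
    for i in range(2, 8):
        out[i] = lut3_plain(temp[i - 1], x[i], y[i], 105)  # sum bit, inverted carry in
    return out


def to_bits(value: int) -> List[bool]:
    """Little-endian list of the low 8 bits of `value`."""
    return [bool((value >> i) & 1) for i in range(8)]


def from_bits(bits: List[bool]) -> int:
    """Inverse of to_bits."""
    return sum(int(bit) << i for i, bit in enumerate(bits))


# Exhaustive check: the circuit matches addition modulo 256.
assert all(
    from_bits(add_i8_plain(to_bits(a), to_bits(b))) == (a + b) % 256
    for a in range(256)
    for b in range(256)
)
```

Under the same bit-order assumption, the encrypted version would take to_bits-style lists whose entries are each encrypted with jaxite_bool.encrypt.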

On a platform with parallelism, Jaxite uses JAX's pmap API to evaluate gates that have no interdependencies in parallel. For example, the last eight gate operations of the i8 adder above could be rewritten as

  inputs = [
    (x[0], y[0], false, 6),            # out[0]
    (temp_nodes[0], x[1], y[1], 150),  # out[1]
    (temp_nodes[1], x[2], y[2], 105),  # out[2]
    (temp_nodes[2], x[3], y[3], 105),  # out[3]
    (temp_nodes[3], x[4], y[4], 105),  # out[4]
    (temp_nodes[4], x[5], y[5], 105),  # out[5]
    (temp_nodes[5], x[6], y[6], 105),  # out[6]
    (temp_nodes[6], x[7], y[7], 105),  # out[7]
  ]
  outputs = jaxite_bool.pmap_lut3(inputs, sks, params)
  return outputs

These circuits were generated with the jaxite support in Google's Fully Homomorphic Encryption Transpiler project. See transpiler/jaxite in that project for instructions.
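
To illustrate which gates "have no interdependencies" in the i8 adder, the following sketch groups gates by their depth in the dependency graph. It is plain Python and does not call Jaxite: levelize and the gate names (t0..t6 for temp_nodes, out0..out7 for out) are hypothetical, and the constant false input is left out because it does not add a dependency.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def levelize(gates: Dict[str, Tuple[str, ...]]) -> List[List[str]]:
    """Group gates by depth; gates within one level never depend on each other."""
    depth: Dict[str, int] = {}

    def gate_depth(name: str) -> int:
        if name not in gates:
            return 0  # primary input such as x0 or y0
        if name not in depth:
            depth[name] = 1 + max(
                (gate_depth(dep) for dep in gates[name]), default=0
            )
        return depth[name]

    levels: Dict[int, List[str]] = defaultdict(list)
    for name in gates:
        levels[gate_depth(name)].append(name)
    return [levels[d] for d in sorted(levels)]


# Dependency graph of the i8 adder shown earlier.
adder_gates: Dict[str, Tuple[str, ...]] = {"t0": ("x0", "y0")}
for i in range(1, 7):
    adder_gates[f"t{i}"] = (f"t{i - 1}", f"x{i}", f"y{i}")
adder_gates["out0"] = ("x0", "y0")
for i in range(1, 8):
    adder_gates[f"out{i}"] = (f"t{i - 1}", f"x{i}", f"y{i}")

levels = levelize(adder_gates)
for number, level in enumerate(levels, start=1):
    print(number, level)
```

The carry chain forces eight sequential levels, so the carries cannot be batched with each other. The eight output gates only depend on carries, which is why they can all be handed to a single pmap_lut3 call once temp_nodes has been computed.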

Contributing

See CONTRIBUTING.md for details.

License

Apache 2.0; see LICENSE for details.

Disclaimer

This project is not an official Google project. It is not supported by Google and Google specifically disclaims all warranties as to its quality, merchantability, or fitness for a particular purpose.
