Jax-Gemini

Natural language-driven JAX/Flax model building powered by Google Gemini.


jax-gemini is a Python library that lets AI/ML practitioners, researchers, and domain experts build, train, evaluate, and snapshot JAX/Flax neural network models from conversational plain-English prompts, without writing JAX boilerplate.

Unlike LLM wrappers that only return descriptive text, jax-gemini returns real, executable flax.nnx.Module objects. Your prompt states the instruction, Gemini generates the corresponding code, and jax-gemini runs that code inside a sandboxed namespace and hands the resulting object back to you.
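The generate-then-execute flow described above can be pictured with a minimal sketch. This is not jax-gemini's actual implementation: `run_generated_code`, the `result_name` convention, and the hard-coded `generated_source` string (standing in for a Gemini response) are all hypothetical, and a plain `exec` namespace is not a security boundary on its own.

```python
def run_generated_code(source: str, result_name: str = "model"):
    """Execute `source` in a fresh namespace and return the object bound to `result_name`."""
    # Empty builtins keep the generated code from reaching open(), __import__(), etc.
    # This narrows the attack surface but is NOT a real security sandbox by itself.
    namespace = {"__builtins__": {}}
    exec(source, namespace)
    if result_name not in namespace:
        raise KeyError(f"generated code did not define '{result_name}'")
    return namespace[result_name]


# Stand-in for code returned by the LLM (hypothetical; a real response
# would construct a flax.nnx.Module rather than a plain dict).
generated_source = '''
def layer_sizes(depth, width, n_out):
    return [width] * (depth - 1) + [n_out]

model = {"kind": "mlp", "layer_sizes": layer_sizes(4, 128, 10)}
'''

model = run_generated_code(generated_source)
print(model["layer_sizes"])  # [128, 128, 128, 10]
```

The caller receives an ordinary Python object rather than text, which is the property the library advertises; how jax-gemini names or retrieves that object internally is an assumption here.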

Features

  • Real Objects, Not Text: Every operation yields a genuine Python object, such as a flax.nnx.Module, trained weights, or a dictionary, that works natively in JAX environments.
  • Conversational Memory: Keeps the context of earlier changes across turns, so you can iterate step by step: add dropout layers, reshape variables, or tweak hyperparameters.
  • Safe Execution by Default: All LLM-generated code passes through an Abstract Syntax Tree (AST) validator that restricts arbitrary script execution and imports.
  • Auto-Healing: When generated code raises an exception, jax-gemini relays the failed code and the Python traceback back to Gemini for automatic repair.
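As a rough illustration of how AST validation and a repair loop can fit together, here is a minimal sketch using only the standard library. It is not jax-gemini's code: `ALLOWED_IMPORTS`, `BANNED_CALLS`, `validate`, `build_with_repair`, and the scripted `fake_generate` (which replaces the real Gemini call) are hypothetical names chosen for this example.

```python
import ast
import traceback

ALLOWED_IMPORTS = {"jax", "flax", "optax", "numpy"}  # hypothetical allow-list
BANNED_CALLS = {"eval", "exec", "open", "compile", "__import__"}  # hypothetical


def validate(source: str) -> list[str]:
    """Return the policy violations found in `source` (an empty list means it passed)."""
    problems = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            roots = [alias.name.split(".")[0] for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            roots = [(node.module or "").split(".")[0]]
        else:
            roots = []
        for root in roots:
            if root not in ALLOWED_IMPORTS:
                problems.append(f"import of '{root}' is not allowed")
        if (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
                and node.func.id in BANNED_CALLS):
            problems.append(f"call to '{node.func.id}' is not allowed")
    return problems


def build_with_repair(generate, prompt, max_attempts=3):
    """Ask `generate` for code until it validates and runs; feed errors back each time."""
    feedback = None  # None on the first attempt, error report afterwards
    for attempt in range(1, max_attempts + 1):
        source = generate(prompt, feedback)
        try:
            problems = validate(source)
            if problems:
                raise ValueError("; ".join(problems))
            namespace = {}
            exec(source, namespace)
            return namespace["model"], attempt
        except Exception:
            # Relay both the failed code and the traceback, as the feature list describes.
            feedback = f"Code:\n{source}\n\nError:\n{traceback.format_exc()}"
    raise RuntimeError(f"no working code after {max_attempts} attempts")


# Scripted stand-in for the LLM: two broken answers, then a working one.
scripted = iter([
    "import os\nmodel = os.getcwd()",      # rejected by the validator
    "model = {'hidden': hidden_sizes}",    # passes validation, fails with NameError
    "hidden_sizes = [64, 64]\nmodel = {'hidden': hidden_sizes}",
])
feedback_log = []


def fake_generate(prompt, feedback):
    feedback_log.append(feedback)
    return next(scripted)


model, attempts = build_with_repair(fake_generate, "Build a small MLP")
print(model, attempts)  # {'hidden': [64, 64]} 3
```

Validating before executing means disallowed imports are rejected without ever running, while runtime failures are caught afterwards; both kinds of error travel back to the generator through the same feedback string.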

Quickstart

Installation

Install via PyPI:

pip install jax-gemini

Set up your Gemini API key:

export GEMINI_API_KEY="your-api-key-here"
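Before the first call, it can help to confirm that the key is actually visible to the Python process. The helper below is a small sketch; `require_api_key` is a hypothetical name and not part of jax-gemini. It only reads the `GEMINI_API_KEY` variable named above.

```python
import os


def require_api_key(var_name: str = "GEMINI_API_KEY") -> str:
    """Return the API key from the environment, or fail early with a clear message."""
    key = os.environ.get(var_name, "").strip()
    if not key:
        raise RuntimeError(f"{var_name} is not set; export it before calling jax-gemini")
    return key
```

Calling `require_api_key()` at the top of a script or notebook surfaces a missing key immediately instead of at the first request.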

Conversational Model Building

import jax_gemini as jg
import numpy as np

# Note: jg automatically picks up GEMINI_API_KEY from env
jg.config.set({"model_name": "gemini-3.1-pro"})

# Build a model purely from text
model = jg.build("Build a 4-layer MLP for handwritten digit classification")

# Refine the architecture interactively
model = jg.modify("Ah, wait, add dropout with rate 0.2 between each layer for regularization")

# Train your model
X_train = np.random.randn(100, 28, 28, 1).astype(np.float32)
y_train = np.random.randint(0, 10, size=(100,))

model, metrics = jg.train(
    "Train for 10 epochs with Adam optimizer and Cross Entropy Loss",
    dataset=(X_train, y_train)
)
print(f"Accuracy: {metrics['accuracy']:.2%}")

# Checkpoint Persistence
checkpoint_path = jg.save("digit_classifier_v1")

See examples/ for more comprehensive workflows, such as Jupyter Notebook deployments and end-to-end setups.

License and Contributing

This repository thrives on community input. Check out our Contribution Guidelines to log issues or prepare Pull Requests.

Distributed under the MIT License. See LICENSE for more information.
