Natural language-driven JAX/Flax model building powered by Gemini
Project description
jax-gemini is a Python library that allows AI/ML practitioners, researchers, and domain experts to build, train, evaluate, and snapshot JAX/Flax neural network models using conversational plain English prompts without writing JAX boilerplate code.
Unlike LLM wrappers that only return descriptive text, jax-gemini returns real, executable Python/JAX flax.nnx.Module objects. Your prompt supplies the instructions, Gemini generates the corresponding code, and jax-gemini runs that code securely inside a sandboxed namespace and hands the resulting object back to you.
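The generate, run, and return flow described above can be sketched with the standard library alone. This is a minimal illustration, not jax-gemini's actual implementation: `GENERATED_CODE`, `SAFE_BUILTINS`, and `run_in_namespace` are hypothetical names, and a plain class stands in for a flax.nnx.Module so the sketch has no dependencies.

```python
import math

# Hypothetical stand-in for code returned by the LLM. In jax-gemini this would
# define a flax.nnx.Module; a plain class keeps the sketch dependency-free.
GENERATED_CODE = """
class TinyModel:
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, x):
        return math.tanh(self.scale * x)

model = TinyModel(scale=2.0)
"""

# Assumed policy: expose only a handful of builtins to the generated code.
SAFE_BUILTINS = {
    "__build_class__": __build_class__,  # required for `class` statements
    "len": len,
    "range": range,
    "float": float,
}


def run_in_namespace(code: str, result_name: str = "model"):
    """Execute code in an isolated namespace and return one named object."""
    namespace = {
        "__builtins__": SAFE_BUILTINS,
        "__name__": "sandbox",  # class bodies read __name__ to set __module__
        "math": math,           # pre-approved module made available to the code
    }
    exec(compile(code, "<generated>", "exec"), namespace)
    return namespace[result_name]


model = run_in_namespace(GENERATED_CODE)
print(type(model).__name__, model(0.5))
```

Restricting builtins like this narrows what generated code can reach, but it is not a complete security boundary on its own, which is presumably why the library pairs it with code validation.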
Features
- Real Objects, Not Text: Every operation yields a genuine Python object (e.g. flax.nnx.Module), trained weights, or dictionaries that integrate natively with JAX environments.
- Conversational Memory: Retains the context of earlier modifications across turns, so you can add dropout layers, reshape variables, or tweak hyperparameters step by step.
- Safe Execution by Default: All LLM-generated code passes through an AST (Abstract Syntax Tree) validator that restricts arbitrary script execution and imports.
- Auto-Healing: When generated code raises an exception, jax-gemini's adaptive auto-correction mechanism relays the failed code and the Python traceback back to Gemini for automatic repair.
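The validation and auto-healing features can be illustrated together. The sketch below is an assumption about how such a pipeline might be structured, not the library's real code: the allow-list, the deny-list, `validate`, `build_with_repair`, and `fake_generate` are all hypothetical, and a fake generator replaces the Gemini call so the example runs offline.

```python
import ast
import traceback

ALLOWED_IMPORTS = {"jax", "flax", "optax", "numpy", "math"}       # assumed allow-list
BANNED_NAMES = {"eval", "exec", "open", "__import__", "compile"}  # assumed deny-list


def validate(code: str) -> None:
    """Raise ValueError if the code imports or names anything off-policy."""
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, ast.Import):
            modules = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            modules = [node.module or ""]  # relative imports have no module name
        else:
            modules = []
        for module in modules:
            if module.split(".")[0] not in ALLOWED_IMPORTS:
                raise ValueError(f"import of '{module}' is not allowed")
        if isinstance(node, ast.Name) and node.id in BANNED_NAMES:
            raise ValueError(f"use of '{node.id}' is not allowed")


def build_with_repair(generate, prompt: str, max_attempts: int = 3):
    """Generate, validate, and run code; feed any failure back to the generator."""
    feedback = None
    for _ in range(max_attempts):
        code = generate(prompt, feedback)
        try:
            validate(code)
            namespace = {}
            exec(code, namespace)
            return namespace["result"]
        except Exception:
            # Relay the failed code and its traceback to the next attempt.
            feedback = {"code": code, "traceback": traceback.format_exc()}
    raise RuntimeError(f"no working code after {max_attempts} attempts")


def fake_generate(prompt, feedback):
    """Stand-in for the LLM: the first attempt is buggy, the retry is fixed."""
    if feedback is None:
        return "import math\nresult = math.sqroot(16)"  # AttributeError at runtime
    return "import math\nresult = math.sqrt(16)"


print(build_with_repair(fake_generate, "compute the square root of 16"))
```

Walking the tree with `ast.walk` checks the code before anything executes, so a disallowed import is rejected without ever being run.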
Quickstart
Installation
Install via PyPI:
pip install jax-gemini
Set up your Gemini API key:
export GEMINI_API_KEY="your-api-key-here"
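Since the library reads the key from the environment, it can help to fail early with a clear message when the variable is missing. The helper below is a hypothetical convenience, not part of jax-gemini:

```python
import os


def require_api_key(env=os.environ, name: str = "GEMINI_API_KEY") -> str:
    """Return the API key from the environment, or raise a descriptive error."""
    key = env.get(name, "").strip()
    if not key:
        raise RuntimeError(f"{name} is not set; export it before using jax_gemini")
    return key
```

Calling `require_api_key()` at the top of a script or notebook surfaces a missing key before any model-building call is attempted.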
Conversational Model Building
import jax_gemini as jg
import numpy as np
# Note: jg automatically picks up GEMINI_API_KEY from env
jg.config.set({"model_name": "gemini-3.1-pro"})
# Build a model purely from text
model = jg.build("Build a 4-layer MLP for handwritten digit classification")
# Refine the architecture interactively
model = jg.modify("Ah, wait, add dropout with rate 0.2 between each layer for regularization")
# Train your model
X_train = np.random.randn(100, 28, 28, 1).astype(np.float32)
y_train = np.random.randint(0, 10, size=(100,))
model, metrics = jg.train(
    "Train for 10 epochs with Adam optimizer and Cross Entropy Loss",
    dataset=(X_train, y_train)
)
print(f"Accuracy: {metrics['accuracy']:.2%}")
# Checkpoint Persistence
checkpoint_path = jg.save("digit_classifier_v1")
See examples/ for more comprehensive workflows, such as Jupyter Notebook deployments and end-to-end setups.
Documentation
- Quickstart Guide: Extensive introduction and core philosophies.
- API Reference: Detailed coverage of the API, including methods and their parameters.
- Architecture Insights: Design elements under the hood.
- Security Protocol: Code validation and restricted runtime policies.
License and Contributing
This repository thrives on community input. Check out our Contribution Guidelines to log issues or prepare Pull Requests.
Distributed under the MIT License. See LICENSE for more information.
File details
Details for the file jax_gemini-0.1.1.tar.gz.
File metadata
- Download URL: jax_gemini-0.1.1.tar.gz
- Upload date:
- Size: 23.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.20
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | de7d0f966188ff64a30304aa6b36bd9554c5f068fe102b797b2d9c5e2bc81118 |
| MD5 | 0c01c765c479c43df2ac86d6f5d74b48 |
| BLAKE2b-256 | 65d2dfc7a7ecb466ff155dbb019743105dc3922fc8ac847efe864d21a6281bbc |
File details
Details for the file jax_gemini-0.1.1-py3-none-any.whl.
File metadata
- Download URL: jax_gemini-0.1.1-py3-none-any.whl
- Upload date:
- Size: 21.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.20
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bb46ccb541201e2ba3ac0d8e148cc72bd8606bdce6a4972ff8a5dc41144e92db |
| MD5 | e1484575d6ae1da0aaea6bd8f57b8a04 |
| BLAKE2b-256 | edc093901fdd18ad478e69e2f68b02012fe4a6499454e3b027cf3c89208e6c51 |
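A downloaded file can be checked against the digests listed above using Python's standard hashlib module. The helper names `sha256_of` and `verify` are illustrative; the expected value is the SHA256 digest published for the wheel.

```python
import hashlib

# SHA256 digest published above for jax_gemini-0.1.1-py3-none-any.whl
EXPECTED_SHA256 = "bb46ccb541201e2ba3ac0d8e148cc72bd8606bdce6a4972ff8a5dc41144e92db"


def sha256_of(path, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected: str) -> bool:
    """Compare a file's SHA256 digest with the expected hex string."""
    return sha256_of(path) == expected.lower()
```

For example, `verify("jax_gemini-0.1.1-py3-none-any.whl", EXPECTED_SHA256)` should return True for an unmodified download of the wheel.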