
Easy inference using Google's Vertex AI and Gemini models


EasyInference


A robust system for large-scale text inference using Vertex AI (Gemini).

This repository hosts a modular framework to orchestrate large-scale batch and live inference requests to Gemini models.

📦 Installation

You can install the package directly from GitHub:

pip install git+https://github.com/ericzhao28/easyinference.git

🚀 Quick Start

  1. Set up your credentials for GCP and Vertex AI:
gcloud auth application-default login
  2. Configure the necessary environment variables:
# Google Cloud Platform Configuration
export GCP_PROJECT_ID="your-project-id"
export GCP_PROJECT_NUM="123456789012"
export GCP_REGION="us-central1"
export VERTEX_BUCKET="your-gcs-bucket"

# SQL Configuration
export TABLE_NAME="your-table"
export SQL_DATABASE_NAME="your-database"
export SQL_USER="db-user"
export SQL_PASSWORD="your-password"
export SQL_INSTANCE_CONNECTION_NAME="project-id:region:instance-name"
export POOL_SIZE="50"

# Additional Configuration
export COOLDOWN_SECONDS="1.0"
export MAX_RETRIES="8"
export BATCH_TIMEOUT_HOURS="3"
export ROUND_ROBIN_ENABLED="false"

Alternatively, you can use the provided example.env file:

  • Copy example.env to .env
  • Update the values in .env with your configuration
  • Use python-dotenv to load these variables in your code
  • Set your environment variables before importing easyinference; otherwise, run easyinference.reload_config() after setting them.
  3. Initialize the database connection:
from easyinference import initialize_query_connection

# Initialize the database connection before using any inference functions
initialize_query_connection()
  4. Import and use the package:
from dotenv import load_dotenv  # pip install python-dotenv

# Load environment variables from .env file (if using this approach)
load_dotenv()

from easyinference import inference, individual_inference, run_clearing_inference, reload_config, initialize_query_connection

# Initialize the database connection
initialize_query_connection()

⚙️ Core Functions

1. inference()

Main async function for batch processing multiple datapoints
async def inference(
    prompt_functions: List[Callable[[Any], str]],  # Functions that convert datapoints to prompt text
    datapoints: List[Any],                         # List of data items to process
    tags: Optional[List[str]] = None,              # Identifier tags for tracking
    duplication_indices: Optional[List[int]] = None, # Indices for running datapoints multiple times
    run_fast: bool = True,                         # If True, makes direct API calls; if False, queues for batch
    allow_failure: bool = False,                   # If True, continues after max retries with error messages
    attempts_cap: int = 8,                         # Maximum number of retry attempts
    temperature: float = 0,                        # Temperature parameter for generation
    max_output_tokens: int = 8192,                 # Maximum tokens to generate in response
    system_prompt: str = "",                       # System prompt to guide model behavior
    model: str = "publishers/google/models/gemini-1.5-flash-002", # Generative model to use
    batch_size: int = 1000,                        # Max concurrent requests or batch job size
    run_fast_timeout: float = 200,                 # Timeout in seconds for fast mode calls
    cooldown_seconds: float = 1.0,                 # Base wait time between retries
    batch_timeout_hours: int = 3,                  # Max runtime before restarting
    round_robin_enabled: bool = False,             # Whether to cycle through regions
    round_robin_options: List[str] = ["us-central1", "us-west1", "us-east1", "us-west4", "us-east4", "us-east5", "us-south1"], # Region options for cycling
    initial_histories: Optional[List[dict]] = None,   # Starting conversation histories for the inference sessions
) -> tuple[List[tuple], str]                       # Returns ([([response 1, response 2, ...], [query 1, query 2, ...]), ... one tuple per datapoint], launch_timestamp_tag)

2. individual_inference()

For processing a single datapoint through multiple prompt functions
async def individual_inference(
    prompt_functions: List[Callable[[Any], str]],  # Functions that convert datapoint to prompt text
    datapoint: Any,                                # Data to process
    tags: Optional[List[str]] = None,              # Identifier tags for tracking
    optional_tags: Optional[List[str]] = None,     # Additional tags not used for lookup
    duplication_index: int = 0,                    # Index to distinguish duplicate runs
    run_fast: bool = True,                         # If True, makes direct API calls; if False, queues for batch
    allow_failure: bool = False,                   # If True, continues after max retries with error messages
    attempts_cap: int = 8,                         # Maximum number of retry attempts
    temperature: float = 0,                        # Temperature parameter for generation
    max_output_tokens: int = 8192,                 # Maximum tokens to generate in response
    system_prompt: str = "",                       # System prompt to guide model behavior
    model: str = "publishers/google/models/gemini-1.5-flash-002", # Generative model to use
    run_fast_timeout: float = 200,                 # Timeout in seconds for fast mode calls
    cooldown_seconds: float = 1.0,                 # Base wait time between retries
    round_robin_enabled: bool = False,             # Whether to cycle through regions
    round_robin_options: List[str] = ["us-central1", "us-west1", "us-east1", "us-west4", "us-east4", "us-east5", "us-south1"], # Region options for cycling
    initial_history_json: Optional[dict] = None,   # Starting conversation history for the inference session
) -> tuple[List[str], List[str]]                   # Returns ([response 1, response 2, ...], [query 1, query 2, ...])
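A minimal usage sketch based on the signature above; the prompt functions and datapoint shape are illustrative, not part of the API:

```python
import asyncio

# Prompt functions applied in sequence to the same datapoint; the function
# names and datapoint shape here are illustrative, not part of the API.
def summarize(dp):
    return f"Summarize in one sentence: {dp['text']}"

def critique(dp):
    return f"Point out one weakness in an explanation of: {dp['text']}"

async def main():
    # Imported inside main so the pure helpers above stay importable
    # without the package installed.
    from easyinference import individual_inference, initialize_query_connection

    initialize_query_connection()
    responses, queries = await individual_inference(
        prompt_functions=[summarize, critique],
        datapoint={"text": "Transformers use self-attention."},
        tags=["demo", "v1"],
        run_fast=True,
    )
    for query, response in zip(queries, responses):
        print(f"{query} -> {response}")

if __name__ == "__main__":
    asyncio.run(main())
```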

3. run_clearing_inference()

For managing batch inference jobs
async def run_clearing_inference(
    tag: str,                                      # Unique identifier tag for the batch
    batch_size: int,                               # Maximum number of requests per batch job
    run_batch_jobs: bool,                          # Whether to launch new batch jobs
    batch_timeout_hours: int = 3                   # Maximum runtime hours before restarting
) -> None
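The batch path combines the two entry points. This sketch assumes, per the docs, that inference(run_fast=False) queues requests and returns a launch timestamp tag that run_clearing_inference can then process; the prompt function and datapoints are illustrative:

```python
import asyncio

def make_prompt(dp):
    # Illustrative prompt function; the datapoint shape is an assumption.
    return f"Classify the sentiment: {dp['review']}"

async def queue_then_clear():
    from easyinference import (
        inference,
        initialize_query_connection,
        run_clearing_inference,
    )

    initialize_query_connection()

    # run_fast=False queues the requests for batch processing instead of
    # calling the API live; a launch timestamp tag is returned for tracking.
    _, launch_tag = await inference(
        prompt_functions=[make_prompt],
        datapoints=[{"review": "Great product"}, {"review": "Terrible service"}],
        tags=["sentiment-batch"],
        run_fast=False,
    )

    # Later (or from a scheduler): submit and monitor batch jobs for that tag.
    await run_clearing_inference(
        tag=launch_tag,
        batch_size=1000,
        run_batch_jobs=True,
    )

if __name__ == "__main__":
    asyncio.run(queue_then_clear())
```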

4. reload_config()

For reloading the config after setting environment variables
def reload_config() -> None
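Since configuration is captured when easyinference loads, a typical pattern is to set the variables first, or refresh afterwards:

```python
import os

# Configuration is read when easyinference is imported, so either set the
# variables before importing the package... (values here are placeholders)
os.environ["GCP_PROJECT_ID"] = "your-project-id"
os.environ["GCP_REGION"] = "us-central1"

# ...or, if easyinference was already imported, refresh its cached config:
def refresh_easyinference_config():
    from easyinference import reload_config  # imported lazily for this sketch
    reload_config()
```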

💡 Example Usage

import asyncio
from dotenv import load_dotenv
from easyinference import inference, reload_config, initialize_query_connection

load_dotenv()
reload_config()

# Initialize the database connection before using any inference functions
initialize_query_connection()

async def process_data():
    # Define data and prompt function
    datapoints = [
        {"text": "What is machine learning?"},
        {"text": "Explain neural networks"}
    ]
    
    def create_prompt(dp):
        return f"Please explain: {dp['text']}"
    
    # Run inference
    results, timestamp = await inference(
        prompt_functions=[create_prompt],
        datapoints=datapoints,
        tags=["explanation", "v1"],
        run_fast=True
    )
    
    # Process results: each datapoint yields a ([responses], [queries]) tuple
    first_datapoint_result, second_datapoint_result = results
    responses, queries = first_datapoint_result
    for query, response in zip(queries, responses):
        print(f"Query: {query}")
        print(f"Response: {response}")
    
    return results

# Run the async function
results = asyncio.run(process_data())

📚 Package Overview

Goal: We provide a scalable and robust pipeline to handle:

  • ✨ Conversation-based inference requests to Gemini models
  • ✨ Failure tracking and retry logic to ensure stable operation
  • ✨ Asynchronous or synchronous methods for generating text from the model

We accomplish this by:

  1. Storing every inference "step" in a PostgreSQL table, which captures the query text, model parameters, conversation history, and final responses (or errors).
  2. Separating "fast" live calls vs. "slow" batch-based calls.
  3. Monitoring the status of batch inference jobs, so you can schedule or restart them if they take too long.
  4. Allowing different usage patterns: single datapoint or bulk processing, with multi-prompt sequences, concurrency caps, and retries.

๐Ÿ—๏ธ System Architecture

 ┌───────────────────┐
 │  Your Application │
 └─────────┬─────────┘
           │
           ├─────────────────┐
           │                 │
           ▼                 ▼
┌─────────────────────┐    ┌─────────────────────┐
│ Individual Inference│    │      Inference      │
│       (Fast)        │◀---│                     │
└──────────┬──────────┘    └──────────┬──────────┘
           │                          │
           │                          ▼
           │               ┌─────────────────────────┐
           │               │     Batch Clearing      │
           │               │      (monitoring)       │
           │               └────────────┬────────────┘
           │                            │
           ▼                            ▼
┌────────────────────────┐  ┌────────────────────────┐
│ Vertex AI (Gemini API) │  │ Vertex AI (Gemini API) │
│      (Live Calls)      │  │      (Batch Job)       │
└───────────┬────────────┘  └────────────┬───────────┘
            │                            │
            └──────────┐    ┌────────────┘
                       ▼    ▼
                ┌────────────────────┐
                │     PostgreSQL     │
                │    Master Table    │
                └────────────────────┘
  1. Individual Inference manages a single datapoint and a sequence of prompts.
  2. Inference is a bulk orchestrator that calls individual inference on multiple datapoints.
  3. Clearing Inference takes unprocessed/failed rows and triggers additional attempts (live or batch). It also monitors batch jobs and handles timeouts.

🔑 Key Concepts and Features

1. Conversation History 💬

Stored in PostgreSQL under history_json as a JSON object:

{
  "history": [
    {"role": "user", "parts": {"text": "Hello, how are you?"}},
    {"role": "model", "parts": {"text": "I am fine. How can I help?"}}
  ]
}

This helps Vertex continue the same conversation context across multiple queries without duplication.
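A small helper (the function name is our own; only the JSON shape comes from this section) for building such a history, e.g. to seed individual_inference via initial_history_json:

```python
def make_history(turns):
    """Build a history_json-shaped dict from (role, text) pairs.

    The helper name is illustrative; only the JSON shape comes from the docs.
    """
    return {
        "history": [
            {"role": role, "parts": {"text": text}} for role, text in turns
        ]
    }

seed = make_history([
    ("user", "Hello, how are you?"),
    ("model", "I am fine. How can I help?"),
])
# `seed` can be passed as initial_history_json to individual_inference()
# to resume the conversation from this point.
```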

2. Generation Parameters ⚙️

Stored under generation_params_json (JSON):

{
  "temperature": 0.7,
  "max_output_tokens": 8192,
  "system_prompt": "You are a helpful assistant..."
}
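These keys mirror the temperature, max_output_tokens, and system_prompt parameters of the inference functions, so a stored params dict can plausibly be fed back into a call (an illustrative pattern, not a documented API guarantee):

```python
# Mirrors the generation parameters accepted by inference() and
# individual_inference(); reusing the dict as keyword arguments is an
# assumption about usage, not a documented guarantee.
generation_params = {
    "temperature": 0.7,
    "max_output_tokens": 8192,
    "system_prompt": "You are a helpful assistant...",
}
# e.g. await individual_inference(..., **generation_params)
```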

3. Duplication Index 🔄

An integer marking whether a row is an exact duplicate of an earlier row (e.g., a re-run). Defaults to 0.

4. Tags 🏷️

A list of strings (alphabetically sorted) representing categories or labels applied to a request (e.g. ["admin", "api-v1"]).
This can help in filtering or grouping by usage scenario.

5. Request Cause 📡

Either "intentional" (explicit user request) or "backup" (an automatic fallback).

6. Status and Failure Counts 📊

  • Last Status can be "PENDING", "RUNNING", "FAILED", "SUCCEEDED", "WAITING".
  • Failure Count tracks how many attempts have failed so far, and Attempts Cap sets the max allowed.

7. Content Hash 🔒

A hash of (Model, History, Query, GenerationParams, DuplicationIndex) for deduplicating or resuming.
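A sketch of how such a hash could be computed. The library's exact field serialization is internal, so this illustrates the idea rather than reproducing its digests:

```python
import hashlib
import json

def content_hash(model, history, query, generation_params, duplication_index):
    """Sketch of the dedup key: SHA-256 over the five fields.

    The library's actual serialization is internal; this is an assumption
    about the approach, not a reproduction of its exact digests.
    """
    # Canonical JSON (sorted keys, no whitespace) makes the digest stable
    # across dict orderings.
    payload = json.dumps(
        [model, history, query, generation_params, duplication_index],
        sort_keys=True,
        separators=(",", ":"),
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
```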

8. Modes ⚡

  • Run Fast: calls the Vertex API directly, returning the result in real-time.
  • Run Slow: queues up the request for a batch job. The run_clearing_inference function handles job submission and monitoring.

9. Initialization ⚡

Before using any inference functions, you must initialize the database connection by calling:

from easyinference import initialize_query_connection

initialize_query_connection()

This sets up the necessary connections to the PostgreSQL database for tracking inference requests.


🧩 Core Components

1. cloudsql/schema.py 📝

  • Defines a ConvoRow data class that mirrors each column in the table.
  • Enumerations for RequestStatus and RequestCause.

2. cloudsql/table_utils.py 🔧

  • Helper functions to insert, update, or read rows from PostgreSQL.
  • Includes concurrency checks so you don't overwrite a "SUCCEEDED" row with "FAILED."
  • Functions for connecting to PostgreSQL, creating tables, and querying data.

3. inference.py 🧠

  • Implements both individual_inference and inference functions
  • Contains run_chat_inference_async for "fast" calls with built-in retry/backoff
  • Implements run_clearing_inference that handles both batch submission and monitoring
  • Manages the logic for deduplicating (by content hash), incrementing failure counts, and handling partial successes

4. config.py ⚙️

  • Configuration settings for database connections, retry logic, and batch operations.
  • Contains defaults for constants like MAX_RETRIES, BATCH_TIMEOUT_HOURS, and various connection parameters.

📊 Table Schema

Your master PostgreSQL table has the following columns:

| Column Name | Type | Description |
| --- | --- | --- |
| row_id | INTEGER | Auto-incrementing primary key |
| content_hash | STRING | SHA-256 hash of key fields for deduplication |
| history_json | JSON | Prior conversation messages, stored under the "history" key |
| query | STRING | User's latest query that needs a response |
| model | STRING | Full path of the model (e.g. "publishers/google/models/gemini-1.5-flash-002") |
| generation_params_json | JSON | Generation settings, e.g. {"temperature":0.7,"max_output_tokens":8192,"system_prompt":"..."} |
| duplication_index | INTEGER | Marks re-runs or explicit duplicates. Defaults to 0 |
| tags | ARRAY(STRING) | A sorted list of tags (e.g. ["api-v1","testing"]) |
| request_cause | STRING | "intentional" or "backup". Uses the RequestCause enum |
| request_timestamp | STRING | ISO 8601 timestamp ("2025-02-25T12:34:56Z") |
| access_timestamps | ARRAY(STRING) | ISO 8601 timestamps of each read/update |
| attempts_metadata_json | ARRAY(JSON) | Prior attempts, storing batch info and error messages |
| response_json | JSON | Final successful response if available, e.g. {"text":"...response..."} |
| current_batch | STRING | ID of any currently running batch job. Can be NULL |
| last_status | STRING | "PENDING", "RUNNING", "FAILED", "SUCCEEDED", or "WAITING" |
| failure_count | INTEGER | How many times this row has failed so far |
| attempts_cap | INTEGER | Maximum number of retries allowed |
| notes | STRING | Optional free-text notes |
| insertion_timestamp | TIMESTAMP | When the row was inserted into the database |

Content Hash 🔍

  • SHA-256 over the combination of (Model, History, Query, GenerationParams, DuplicationIndex).
  • Ensures we don't re-run the same content multiple times unless we want to.

Tagging 🏷️

  • A query can have tags like ["api-v1","admin-request"]. The system enforces that the tag list is alphabetically sorted.
  • For batch mode, a timestamp tag is automatically added for tracking.
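Since the table stores tags alphabetically sorted, normalizing with sorted() before insertion keeps lookups and content hashes consistent:

```python
# The system enforces alphabetically sorted tag lists; sorted() gives a
# stable normalization before insertion or lookup.
raw_tags = ["testing", "api-v1"]
tags = sorted(raw_tags)
```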

📄 License

MIT License

This project is provided under the MIT License.

Feel free to modify or extend the code to suit your deployment and usage requirements.
