
Interact with the Databricks Foundation Model API from Python


Databricks Generative AI Inference SDK (wfork)


The Databricks Generative AI Inference Python library provides a user-friendly Python interface to the Databricks Foundation Model API. Because the original library is released under the Apache License, this is wfork: my fork, which provides minor updates while we wait for the long-overdue official new version.

[!NOTE] This SDK was primarily designed for pay-per-token endpoints (databricks-*). It keeps a list of known model names (e.g. dbrx-instruct) and automatically maps each one to the corresponding shared endpoint (databricks-dbrx-instruct). You can also use it with provisioned throughput endpoints, as long as their names do not match a known model name. If there is an overlap, you can use the DATABRICKS_MODEL_URL_ENV URL to provide the endpoint URL directly.
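The name mapping described in the note can be pictured as a small lookup. This is a minimal sketch, not the SDK's actual code. `KNOWN_MODELS` and `resolve_endpoint` are hypothetical names, and the model list contains only the names used in this README. The explicit `endpoint_url` argument stands in for however the real URL override is wired up.

```python
from typing import Optional

# Hypothetical list: only the model names that appear in this README's examples.
KNOWN_MODELS = {"dbrx-instruct", "bge-large-en", "mpt-7b-instruct", "llama-2-70b-chat"}


def resolve_endpoint(model: str, endpoint_url: Optional[str] = None) -> str:
    """Return the endpoint a request would target (illustrative sketch only)."""
    # An explicitly supplied URL always wins, which sidesteps name collisions.
    if endpoint_url:
        return endpoint_url
    # Known pay-per-token model names map to the shared "databricks-*" endpoint.
    if model in KNOWN_MODELS:
        return f"databricks-{model}"
    # Anything else is treated as the name of your own (e.g. provisioned throughput) endpoint.
    return model


print(resolve_endpoint("dbrx-instruct"))   # databricks-dbrx-instruct
print(resolve_endpoint("my-pt-endpoint"))  # my-pt-endpoint
```

The sketch also shows why an overlapping name is a problem. A provisioned throughput endpoint literally named `dbrx-instruct` would be redirected to the shared endpoint unless a URL is supplied directly.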

This library includes a predefined set of API classes (Embedding, Completion, and ChatCompletion) with convenience functions for making API requests and parsing the contents of raw JSON responses.

It also offers a high-level ChatSession object that makes multi-round chat completions easy to manage, which is especially handy when building your next chatbot.
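To give a feel for the raw-JSON parsing that the API classes take care of, here is a minimal sketch for embeddings. It assumes an OpenAI-style payload: a `data` list of objects, each carrying an `index` and an `embedding`. That wire format is an assumption, and `parse_embeddings` is a hypothetical helper, not the SDK's own parser.

```python
import json
from typing import List


def parse_embeddings(raw: str) -> List[List[float]]:
    """Extract embedding vectors from an OpenAI-style embeddings payload (assumed schema)."""
    payload = json.loads(raw)
    # Each item records its position in the batch; sort so output order matches input order.
    items = sorted(payload["data"], key=lambda item: item["index"])
    return [item["embedding"] for item in items]


# Hypothetical sample response, deliberately out of order.
raw_response = json.dumps({
    "object": "list",
    "data": [
        {"object": "embedding", "index": 1, "embedding": [0.3, 0.4]},
        {"object": "embedding", "index": 0, "embedding": [0.1, 0.2]},
    ],
})
print(parse_embeddings(raw_response))  # [[0.1, 0.2], [0.3, 0.4]]
```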

You can find more usage details in the Databricks SDK onboarding doc.

[!IMPORTANT]
Databricks is reportedly preparing to release version 1.0 of the official Databricks Generative AI Inference Python library, which will probably be better than this fork. Watch https://pypi.org/project/databricks-genai-inference/ like a hawk for more.

Installation

pip install wfork-databricks-genai-inference

(Note that this step differs from the original project, but it's the only step that does.)

Usage

Embedding

from databricks_genai_inference import Embedding

(note that the import statement has not changed from the original package!)

Text embedding

response = Embedding.create(
    model="bge-large-en", 
    input="3D ActionSLAM: wearable person tracking in multi-floor environments")
print(f'embeddings: {response.embeddings[0]}')

[!TIP]
You may want to reuse HTTP connections to reduce request latency for large-scale workloads, for example:

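import requests
from databricks_genai_inference import Embedding

texts = ["first document", "second document"]  # replace with your own inputs

# One Session keeps a pool of open connections that every request reuses
with requests.Session() as client:
    for text in texts:
        response = Embedding.create(
            client=client,
            model="bge-large-en",
            input=text
        )
        print(response.embeddings[0])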

Text embedding (async)

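import httpx

# Top-level `async with` works in IPython-based notebooks (e.g. Jupyter); in a plain
# script, put this inside an `async def` and run it with asyncio.run().
async with httpx.AsyncClient() as client:
    response = await Embedding.acreate(
        client=client,
        model="bge-large-en",
        input="3D ActionSLAM: wearable person tracking in multi-floor environments")
    print(f'embeddings: {response.embeddings[0]}')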
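The main payoff of the async API is sending many requests concurrently. Here is a minimal sketch using only `asyncio`. `embed_one` is a hypothetical placeholder for the `Embedding.acreate` call above; it sleeps instead of calling the API. The concurrency limit of 8 is an arbitrary assumption, not a documented Databricks rate limit. In a notebook you would `await embed_all(...)` directly instead of calling `asyncio.run`.

```python
import asyncio
from typing import List


async def embed_one(text: str) -> List[float]:
    """Hypothetical stand-in for `await Embedding.acreate(client=client, ...)`."""
    await asyncio.sleep(0.01)  # simulate network latency
    return [float(len(text))]


async def embed_all(texts: List[str], max_concurrency: int = 8) -> List[List[float]]:
    # The semaphore caps how many requests are in flight at once.
    semaphore = asyncio.Semaphore(max_concurrency)

    async def bounded(text: str) -> List[float]:
        async with semaphore:
            return await embed_one(text)

    # gather() returns results in input order, regardless of completion order.
    return await asyncio.gather(*(bounded(t) for t in texts))


vectors = asyncio.run(embed_all(["a", "bb", "ccc"]))
print(vectors)  # [[1.0], [2.0], [3.0]]
```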

Text embedding with instruction

response = Embedding.create(
    model="bge-large-en", 
    instruction="Represent this sentence for searching relevant passages:", 
    input="3D ActionSLAM: wearable person tracking in multi-floor environments")
print(f'embeddings: {response.embeddings[0]}')

Text embedding (batching)

[!IMPORTANT]
Supports a maximum batch size of 150.

response = Embedding.create(
    model="bge-large-en", 
    input=[
        "3D ActionSLAM: wearable person tracking in multi-floor environments",
        "3D ActionSLAM: wearable person tracking in multi-floor environments"])
print(f'response.embeddings[0]: {response.embeddings[0]}\n')
print(f'response.embeddings[1]: {response.embeddings[1]}')
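To embed more than 150 inputs, split them into batches that respect the limit. Here is a minimal sketch. `batched` is a hypothetical helper, not part of the SDK; Python 3.12's `itertools.batched` does the same job but yields tuples. Each yielded list would be passed as `input=batch` to `Embedding.create`.

```python
from itertools import islice
from typing import Iterable, Iterator, List

MAX_EMBEDDING_BATCH = 150  # limit stated above for Embedding.create


def batched(items: Iterable[str], size: int = MAX_EMBEDDING_BATCH) -> Iterator[List[str]]:
    """Yield consecutive lists of at most `size` items."""
    iterator = iter(items)
    while True:
        chunk = list(islice(iterator, size))
        if not chunk:
            return
        yield chunk


texts = [f"document {i}" for i in range(301)]
print([len(batch) for batch in batched(texts)])  # [150, 150, 1]
```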

Text embedding with instruction (batching)

[!IMPORTANT]
Only one instruction is supported per batch.

response = Embedding.create(
    model="bge-large-en", 
    instruction="Represent this sentence for searching relevant passages:",
    input=[
        "3D ActionSLAM: wearable person tracking in multi-floor environments",
        "3D ActionSLAM: wearable person tracking in multi-floor environments"])
print(f'response.embeddings[0]: {response.embeddings[0]}\n')
print(f'response.embeddings[1]: {response.embeddings[1]}')
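Because a batch carries only one instruction, inputs that need different instructions must be sent as separate requests. Here is a minimal sketch of that grouping step. `group_by_instruction` is hypothetical. Each resulting group would become one `Embedding.create(instruction=..., input=texts)` call, still subject to the 150-item limit.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def group_by_instruction(pairs: Iterable[Tuple[str, str]]) -> Dict[str, List[str]]:
    """Group (instruction, text) pairs so each group can be sent as one batch."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for instruction, text in pairs:
        groups[instruction].append(text)
    return dict(groups)


pairs = [
    ("Represent this sentence for searching relevant passages:", "query one"),
    ("Represent the Science title:", "title one"),
    ("Represent this sentence for searching relevant passages:", "query two"),
]
for instruction, texts in group_by_instruction(pairs).items():
    print(instruction, texts)
```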

Text completion

from databricks_genai_inference import Completion

Text completion

response = Completion.create(
    model="mpt-7b-instruct",
    prompt="Represent the Science title:")
print(f'response.text: {response.text}')

Text completion (async)

import httpx

async with httpx.AsyncClient() as client:
    response = await Completion.acreate(
        client=client,
        model="mpt-7b-instruct",
        prompt="Represent the Science title:")
    print(f'response.text: {response.text}')

Text completion (streaming)

[!IMPORTANT]
Streaming mode supports only a batch size of 1.

response = Completion.create(
    model="mpt-7b-instruct", 
    prompt="Count from 1 to 100:",
    stream=True)
print(f'response.text:')
for chunk in response:
    print(f'{chunk.text}', end="")
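If you also need the complete text after streaming it to the screen, accumulate the chunks as they arrive. Here is a minimal sketch. The fake stream built from `SimpleNamespace` objects is a hypothetical stand-in that mimics only the `chunk.text` attribute used above. With the SDK you would pass the `response` from `Completion.create(..., stream=True)` instead; `ChatCompletion` chunks expose `chunk.message`.

```python
from types import SimpleNamespace
from typing import Iterable


def collect_stream(chunks: Iterable) -> str:
    """Print each streamed chunk's `.text` as it arrives and return the joined result."""
    parts = []
    for chunk in chunks:
        print(chunk.text, end="", flush=True)  # show tokens as they arrive
        parts.append(chunk.text)
    print()
    return "".join(parts)


# Hypothetical stand-in for `Completion.create(..., stream=True)`.
fake_stream = (SimpleNamespace(text=t) for t in ["1, ", "2, ", "3"])
full_text = collect_stream(fake_stream)
```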

Text completion (streaming + async)

import httpx

async with httpx.AsyncClient() as client:
    response = await Completion.acreate(
        client=client,
        model="mpt-7b-instruct",
        prompt="Count from 1 to 10:",
        stream=True)
    print('response.text:')
    async for chunk in response:
        print(f'{chunk.text}', end="")

Text completion (batching)

[!IMPORTANT]
Supports a maximum batch size of 16.

response = Completion.create(
    model="mpt-7b-instruct", 
    prompt=[
        "Represent the Science title:", 
        "Represent the Science title:"])
print(f'response.text[0]:{response.text[0]}')
print(f'response.text[1]:{response.text[1]}')

Chat completion

from databricks_genai_inference import ChatCompletion

[!IMPORTANT]
Batching is not supported for ChatCompletion

Chat completion

response = ChatCompletion.create(model="llama-2-70b-chat", messages=[
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Knock knock."}])
print(f'response.message: {response.message}')

Chat completion (async)

import httpx

async with httpx.AsyncClient() as client:
    response = await ChatCompletion.acreate(
        client=client,
        model="llama-2-70b-chat",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Knock knock."}],
    )
    print(f'response.message: {response.message}')

Chat completion (streaming)

response = ChatCompletion.create(model="llama-2-70b-chat", stream=True, messages=[
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Count from 1 to 30, add one emoji after each number"}])
for chunk in response:
    print(f'{chunk.message}', end="")

Chat completion (streaming + async)

import httpx

async with httpx.AsyncClient() as client:
    response = await ChatCompletion.acreate(
        client=client,
        model="llama-2-70b-chat",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Count from 1 to 30, add one emoji after each number"}],
        stream=True,
    )
    async for chunk in response:
        print(f'{chunk.message}', end="")

Chat session

from databricks_genai_inference import ChatSession

[!IMPORTANT]
Streaming mode is not supported for ChatSession

chat = ChatSession(model="llama-2-70b-chat")
chat.reply("Knock, knock!")
print(f'chat.last: {chat.last}')
chat.reply("Take a guess!")
print(f'chat.last: {chat.last}')

print(f'chat.history: {chat.history}')
print(f'chat.count: {chat.count}')
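Conceptually, a chat session keeps a running message list and resends it on every turn. The sketch below is a hypothetical re-creation for illustration only. `MiniChatSession` and `echo` are made-up names, and `echo` stands in for a real chat completion call. Treating `count` as the number of completed rounds is an assumption about what `chat.count` reports.

```python
from typing import Callable, Dict, List

Message = Dict[str, str]


class MiniChatSession:
    """Hypothetical sketch of multi-round history management (not the SDK's ChatSession)."""

    def __init__(self, complete: Callable[[List[Message]], str],
                 system_message: str = "You are a helpful assistant."):
        self.complete = complete  # any function mapping a message list to a reply string
        self.history: List[Message] = [{"role": "system", "content": system_message}]

    def reply(self, user_text: str) -> None:
        self.history.append({"role": "user", "content": user_text})
        # The whole history is sent each round, so the model sees earlier turns.
        answer = self.complete(self.history)
        self.history.append({"role": "assistant", "content": answer})

    @property
    def last(self) -> str:
        return self.history[-1]["content"]

    @property
    def count(self) -> int:
        # Number of completed user/assistant rounds (assumed meaning).
        return sum(1 for m in self.history if m["role"] == "assistant")


def echo(messages: List[Message]) -> str:
    # Hypothetical stand-in for a ChatCompletion call.
    return f"echo: {messages[-1]['content']}"


chat = MiniChatSession(echo)
chat.reply("Knock, knock!")
chat.reply("Take a guess!")
print(chat.last)   # echo: Take a guess!
print(chat.count)  # 2
```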



Download files

Source distribution: wfork_databricks_genai_inference-2.0.1.tar.gz (28.1 kB)

Built distribution: wfork_databricks_genai_inference-2.0.1-py3-none-any.whl

Hashes for wfork_databricks_genai_inference-2.0.1.tar.gz:
SHA256: 7f06f3e455cb6a1d888ca15534449246cd4c5f4b4d1fe5b69adaa1eac997ef6b
MD5: 324d89dc3bf3d1ef0f598f58cfd44813
BLAKE2b-256: 73b8b058f0a2409685ea147e6de4120ad4eb0f602647a04a37da9f17fd99bbe1

Hashes for wfork_databricks_genai_inference-2.0.1-py3-none-any.whl:
SHA256: ce1b087dd7a02e6f15fe245d312b4bb725e7d04be74da1d92a2b5fdd322657aa
MD5: 34642e8b9267e4ada3cbb200dfe713a3
BLAKE2b-256: 4ec0cdb629728ef423bf661483ac863b61430c152053c28f48bcb4802ca82a97
