An object-oriented wrapper around language models with caching, batching, and more.


lmwrapper provides a wrapper around the OpenAI API and Hugging Face language models, focusing on being a clean, object-oriented, and user-friendly interface. It has two main goals:

A) Make it easier to use the OpenAI API

B) Make it easy to reuse your code for other language models with minimal changes.

Key features currently include local caching of responses and super simple use of the OpenAI batching API, which can save 50% on costs.

lmwrapper is lightweight and can serve as a flexible stand-in for the OpenAI API.

Installation

For usage with just OpenAI models:

pip install lmwrapper

For usage with HuggingFace models as well:

pip install 'lmwrapper[hf]'

For development dependencies:

pip install 'lmwrapper[dev]'

Example usage

Basic Completion and Prompting

from lmwrapper.openai_wrapper import get_open_ai_lm, OpenAiModelNames
from lmwrapper.structs import LmPrompt

lm = get_open_ai_lm(
    model_name=OpenAiModelNames.gpt_3_5_turbo_instruct,
    api_key_secret=None,  # By default, this will read from the OPENAI_API_KEY environment variable.
    # If that isn't set, it will try the file ~/oai_key.txt
    # You need to place the key in one of these places,
    # or pass in a different location. You can get an API
    # key at (https://platform.openai.com/account/api-keys)
)

prediction = lm.predict(
    LmPrompt(  # An LmPrompt object lets your IDE hint on args
        "Once upon a",
        max_tokens=10,
        temperature=1,  # Set this to 0 for deterministic completions
    )
)
print(prediction.completion_text)
# " time, there were three of us." - Example. This will change with each sample.

Chat

from lmwrapper.openai_wrapper import get_open_ai_lm, OpenAiModelNames
from lmwrapper.structs import LmPrompt, LmChatTurn

lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)

# Single user utterance
pred = lm.predict("What is 2+2?")
print(pred.completion_text)  # "2+2 is equal to 4."

# Conversation alternating between `user` and `assistant`.
pred = lm.predict(LmPrompt(
    [
        "What is 2+2?",  # user turn
        "4",  # assistant turn
        "What is 5+3?",  # user turn
        "8",  # assistant turn
        "What is 4+4?",  # user turn
        # We use few-shot turns to encourage the answer to be our desired format.
        #   If you don't give example turns you might get something like
        #   "The answer is 8." instead of just "8".
    ],
    max_tokens=10,
))
print(pred.completion_text)  # "8"

# If you want things like the system message, you can use LmChatTurn objects
pred = lm.predict(LmPrompt(
    text=[
        LmChatTurn(role="system", content="You always answer like a pirate"),
        LmChatTurn(role="user", content="How does bitcoin work?"),
    ],
    max_tokens=25,
    temperature=0,
))
print(pred.completion_text)
# "Arr, me matey! Bitcoin be a digital currency that be workin' on a technology called blockchain..."
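In the plain-list form above, position decides the role. The first string is a user turn, the next is an assistant turn, and so on. The following is a minimal sketch of that mapping into the role/content message dicts used by the OpenAI chat API. It assumes exactly that alternation and is an illustration, not lmwrapper's internal code. turns_to_messages is a hypothetical helper.

The sketch also shows why the comma after each string matters. Python silently concatenates adjacent string literals, so a missing comma merges two turns and shifts every later role.

```python
def turns_to_messages(turns: list[str]) -> list[dict[str, str]]:
    """Map alternating turns (user first) to role-tagged chat messages.

    Hypothetical helper: assumes even indices are user turns and odd
    indices are assistant turns, as described for plain-string lists.
    """
    roles = ("user", "assistant")
    return [
        {"role": roles[i % 2], "content": text}
        for i, text in enumerate(turns)
    ]


good = turns_to_messages(["What is 2+2?", "4", "What is 5+3?", "8", "What is 4+4?"])
print([m["role"] for m in good])
# ['user', 'assistant', 'user', 'assistant', 'user']

# A missing comma makes Python join "What is 5+3?" and "8" into one string,
# which shifts every later turn to the wrong role.
bad = turns_to_messages(["What is 2+2?", "4", "What is 5+3?" "8", "What is 4+4?"])
print(bad[2]["content"], bad[3]["role"])
# What is 5+3?8 assistant
```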

Hugging Face models

Local causal language models from Hugging Face can be used interchangeably with the OpenAI models.

from lmwrapper.huggingface_wrapper import get_huggingface_lm
from lmwrapper.structs import LmPrompt

lm = get_huggingface_lm("gpt2")  # Download the smallest 124M parameter model

prediction = lm.predict(LmPrompt(
    "The capital of Germany is Berlin. The capital of France is",
    max_tokens=1,
    temperature=0,
))
print(prediction.completion_text)
assert prediction.completion_text == " Paris"

Additionally, for HuggingFace models, lmwrapper provides an interface for accessing the model's internal states.
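The interchangeability described in this section comes from both wrappers exposing the same predict method, so calling code never needs to know which backend it has. The sketch below illustrates that duck-typed design with typing.Protocol. TextModel, EchoModel, UpperModel, and run_all are hypothetical stand-ins, not lmwrapper classes. The prompt is reduced to a plain string for brevity.

```python
from typing import Protocol


class TextModel(Protocol):
    """Hypothetical minimal interface: anything with predict(str) -> str."""

    def predict(self, prompt: str) -> str: ...


class EchoModel:
    """Stand-in for one backend (e.g. a hosted API)."""

    def predict(self, prompt: str) -> str:
        return prompt.split()[-1]


class UpperModel:
    """Stand-in for another backend (e.g. a local model)."""

    def predict(self, prompt: str) -> str:
        return prompt.upper()


def run_all(model: TextModel, prompts: list[str]) -> list[str]:
    # Caller code depends only on the shared method, so backends can be swapped.
    return [model.predict(p) for p in prompts]


prompts = ["The capital of France is Paris"]
print(run_all(EchoModel(), prompts))   # ['Paris']
print(run_all(UpperModel(), prompts))  # ['THE CAPITAL OF FRANCE IS PARIS']
```

Neither stand-in inherits from TextModel. A Protocol only describes the required shape, which mirrors how any object with a compatible predict method can be dropped into the same code.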

Caching

Set cache=True in the prompt to cache the output to disk. Any subsequent call with the same prompt will return the same value. Note that this may be unexpected if your temperature is non-zero: you will always get the same sampled output on reruns.

OpenAI Batching

The OpenAI batching API offers a 50% cost reduction if you are willing to accept up to a 24-hour turnaround. This makes it a good fit for processing datasets and other non-interactive tasks (currently the main target for lmwrapper).

lmwrapper takes care of managing the batch files and other details so that it's as easy as the normal API.

from lmwrapper.openai_wrapper import get_open_ai_lm, OpenAiModelNames
from lmwrapper.structs import LmPrompt
from lmwrapper.batch_config import CompletionWindow

def load_dataset() -> list:
    """Load some toy task"""
    return ["France", "United States", "China"]

def make_prompts(data) -> list[LmPrompt]:
    """Make some toy prompts for our data"""
    return [
        LmPrompt(
            f"What is the capital of {country}? Answer with just the city name.",
            max_tokens=10,
            temperature=0,
            cache=True,
        ) 
        for country in data
    ]

data = load_dataset()
prompts = make_prompts(data)
lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)
predictions = lm.predict_many(
    prompts,
    completion_window=CompletionWindow.BATCH_ANY 
    #                 ^ swap out for CompletionWindow.ASAP
    #                   to complete as soon as possible via
    #                   the non-batching API at a higher cost.
) # The batch is submitted here

for ex, pred in zip(data, predictions):  # Will wait for the batch to complete
    print(f"Country: {ex} --- Capital: {pred.completion_text}")
    if ex == "France":
        assert pred.completion_text == "Paris"
    # ...

The above code could technically take up to 24 hours to complete. However, OpenAI seems to complete these sooner (for example, these three prompts finished in about a minute or less). For a large batch, you don't have to keep the process running for hours. Thanks to lmwrapper's caching, rerunning the script will automatically load the results or pick back up waiting on the existing batch.
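One way to get the resume-on-rerun behavior described above is to persist the submitted batch ID under a hash of the prompts, and to check that record before submitting. The sketch assumes that approach. get_or_submit_batch, prompts_key, fake_submit, and the JSON state file are hypothetical, and lmwrapper's real bookkeeping may differ.

```python
import hashlib
import json
import tempfile
from pathlib import Path


def prompts_key(prompts: list[str]) -> str:
    # JSON-encode the list so prompts containing newlines cannot collide.
    return hashlib.sha256(json.dumps(prompts).encode("utf-8")).hexdigest()


def get_or_submit_batch(prompts: list[str], submit, state_file: Path) -> str:
    """Reuse a previously submitted batch for these prompts, else submit a new one."""
    state = json.loads(state_file.read_text()) if state_file.exists() else {}
    key = prompts_key(prompts)
    if key not in state:
        state[key] = submit(prompts)  # e.g. returns a remote batch ID
        state_file.write_text(json.dumps(state))
    return state[key]


submitted: list[list[str]] = []


def fake_submit(prompts: list[str]) -> str:
    """Hypothetical stand-in for a remote batch submission."""
    submitted.append(prompts)
    return f"batch_{len(submitted)}"


state_file = Path(tempfile.mkdtemp()) / "batches.json"
prompts = ["Capital of France?", "Capital of China?"]
first_run = get_or_submit_batch(prompts, fake_submit, state_file)
second_run = get_or_submit_batch(prompts, fake_submit, state_file)  # simulated rerun
print(first_run, second_run, len(submitted))  # batch_1 batch_1 1
```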

The lmwrapper cache also lets you intermix cached and uncached examples.

# ... above code

def load_more_data() -> list:
    """Load some toy task"""
    return ["Mexico", "Canada"]

data = load_dataset() + load_more_data()
prompts = make_prompts(data)
# If we submit the five prompts, only the two new prompts will be
# submitted to the batch. The already completed prompts will
# be loaded near-instantly from the local cache.
predictions = list(lm.predict_many(
    prompts,
    completion_window=CompletionWindow.BATCH_ANY
))
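Mixing cached and uncached prompts can be pictured as a partition step. Each prompt is looked up first, only the misses go to the batch, and the results are reassembled in the original order. The sketch assumes a dict-like cache and a hypothetical run_batch function that returns one result per submitted prompt. predict_many_with_cache and fake_batch are illustrative names, not lmwrapper APIs.

```python
from typing import Callable


def predict_many_with_cache(
    prompts: list[str],
    cache: dict[str, str],
    run_batch: Callable[[list[str]], list[str]],
) -> list[str]:
    """Answer cached prompts locally and batch only the misses, preserving order."""
    misses = [p for p in prompts if p not in cache]
    if misses:
        for prompt, result in zip(misses, run_batch(misses)):
            cache[prompt] = result
    return [cache[p] for p in prompts]


submitted: list[list[str]] = []


def fake_batch(batch: list[str]) -> list[str]:
    """Hypothetical batch backend that records what it was sent."""
    submitted.append(batch)
    return [f"capital of {p}" for p in batch]


cache: dict[str, str] = {}
predict_many_with_cache(["France", "United States", "China"], cache, fake_batch)
results = predict_many_with_cache(
    ["France", "United States", "China", "Mexico", "Canada"], cache, fake_batch
)
print(submitted[-1])  # ['Mexico', 'Canada']
print(results[0])     # capital of France
```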

lmwrapper is designed to automatically manage the batching of thousands or millions of prompts. If needed, it will automatically split up prompts into sub-batches and will manage issues around rate limits.
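Splitting into sub-batches has to respect both a request-count limit and a file-size limit. The caveats below cite OpenAI limits of 50,000 prompts and 100 MB per batch. The sketch takes those figures as defaults and reads 100 MB conservatively as 100,000,000 bytes (an assumption). It packs requests greedily in order and measures size as the UTF-8 length of each JSONL line. split_into_batches is a hypothetical helper, and the request format is simplified.

```python
import json


def split_into_batches(
    requests: list[dict],
    max_count: int = 50_000,
    max_bytes: int = 100_000_000,  # assumed reading of "100 MB"
) -> list[list[dict]]:
    """Greedily pack requests into batches under a count and a byte limit."""
    batches: list[list[dict]] = []
    current: list[dict] = []
    current_bytes = 0
    for req in requests:
        line_bytes = len((json.dumps(req) + "\n").encode("utf-8"))
        if line_bytes > max_bytes:
            raise ValueError("a single request exceeds the batch size limit")
        # Close the current batch if adding this request would break either limit.
        too_many = len(current) >= max_count
        too_big = current_bytes + line_bytes > max_bytes
        if current and (too_many or too_big):
            batches.append(current)
            current, current_bytes = [], 0
        current.append(req)
        current_bytes += line_bytes
    if current:
        batches.append(current)
    return batches


reqs = [{"prompt": f"q{i}"} for i in range(7)]
print([len(b) for b in split_into_batches(reqs, max_count=3)])  # [3, 3, 1]
```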

This feature is mostly designed for the OpenAI cost savings. You could swap out the model for a HuggingFace one and the same code would still work; however, internally it is effectively a loop over the prompts. Eventually we want lmwrapper to do more sophisticated batching when GPU/CPU/accelerator memory is available.
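For a backend without a native batch API, the loop mentioned above can be as simple as a lazy generator that calls a single-prompt predict function once per prompt. Each result is then available as soon as it is computed. The following is a sketch. predict_many_sequential and fake_predict are hypothetical names, and predict stands in for any single-prompt call.

```python
from typing import Callable, Iterable, Iterator


def predict_many_sequential(
    prompts: Iterable[str], predict: Callable[[str], str]
) -> Iterator[str]:
    """Yield one prediction per prompt, in order, computing each on demand."""
    for prompt in prompts:
        yield predict(prompt)


seen: list[str] = []


def fake_predict(prompt: str) -> str:
    """Hypothetical model call that records its input and reverses it."""
    seen.append(prompt)
    return prompt[::-1]


gen = predict_many_sequential(["abc", "xyz"], fake_predict)
print(seen)       # [] (nothing runs until iteration starts)
print(next(gen))  # cba
print(seen)       # ['abc']
```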

Caveats / Implementation needs

This feature is still somewhat experimental. It likely works in typical use cases, but recovery from failures (like invalid prompts or errors on OpenAI's end) might not be handled ideally. There are a few known things to sort out / TODOs:

  • Retry batch API connection errors
  • Automatically splitting up batches when there are >50,000 prompts (limit from OpenAI)
  • Recovering / splitting up batches when hitting your token Batch Queue Limit (see docs on limits)
  • Handle canceled batches during current run (use the web interface to cancel)
  • Handle/recover canceled batches outside of current run
  • Handle an OpenAI batch that expires unfinished after 24hrs (though this has not actually been tested or observed)
  • Automatically splitting up batch when exceeding 100MB prompts limit
  • Handling of failed prompts (like when a prompt has too many tokens). Use LmPrediction.has_errors to check for an error on a response.
  • Handle when there are duplicate prompts in batch submission
  • Handle when a given prompt has num_completions>1
  • Automatically clean up API files when done (right now you end up with a lot of files in storage. There isn't an obvious cost for these batch files, but this might change and it would be better to clean them up.)
  • Test on free-tier accounts. It is not clear what counts toward the tiny request limit
  • Fancy batching of HF
  • Concurrent batching when in ASAP mode
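One gap listed above is duplicate prompts in a batch submission. A common approach, sketched here as an assumption rather than lmwrapper's planned design, is to submit each distinct prompt once and copy its result back to every position where it appeared. dedup_submit and fake_batch are hypothetical helpers.

```python
from typing import Callable


def dedup_submit(
    prompts: list[str], run_batch: Callable[[list[str]], list[str]]
) -> list[str]:
    """Send each unique prompt once, then map results back to every position."""
    unique = list(dict.fromkeys(prompts))  # drops repeats, keeps first-seen order
    results = dict(zip(unique, run_batch(unique)))
    return [results[p] for p in prompts]


sent: list[list[str]] = []


def fake_batch(batch: list[str]) -> list[str]:
    """Hypothetical batch backend that records what it was sent."""
    sent.append(batch)
    return [p.upper() for p in batch]


out = dedup_submit(["a", "b", "a", "c", "b"], fake_batch)
print(sent[0])  # ['a', 'b', 'c']
print(out)      # ['A', 'B', 'A', 'C', 'B']
```

This is only appropriate when duplicates are meant to share one answer. With a non-zero temperature, a caller may instead want independent samples for each copy.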

Please open an issue if you want to discuss one of these or something else.

Note: progress bars in PyCharm can be a bit cleaner if you enable terminal emulation in your run configuration.

Retries on rate limit

from lmwrapper.openai_wrapper import *

lm = get_open_ai_lm(
    OpenAiModelNames.gpt_3_5_turbo_instruct,
    retry_on_rate_limit=True
)
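This page does not document exactly what retry_on_rate_limit=True does internally. A common strategy for this kind of option is retrying with exponential backoff, and the sketch below shows that technique under that assumption. RateLimitError, call_with_retries, flaky, and the delay values are hypothetical. The sleep function is injectable so the example runs instantly.

```python
import random
import time
from typing import Callable, TypeVar

T = TypeVar("T")


class RateLimitError(Exception):
    """Hypothetical stand-in for a provider's rate-limit exception."""


def call_with_retries(
    fn: Callable[[], T],
    max_retries: int = 5,
    base_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn, retrying on RateLimitError with exponential backoff plus jitter."""
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except RateLimitError:
            if attempt == max_retries:
                raise
            # Wait base_delay * 2**attempt seconds, plus up to 10% random jitter.
            delay = base_delay * 2**attempt
            sleep(delay + random.uniform(0, 0.1 * delay))
    raise AssertionError("unreachable")


attempts = 0
delays: list[float] = []


def flaky() -> str:
    """Fails with a rate-limit error twice, then succeeds."""
    global attempts
    attempts += 1
    if attempts < 3:
        raise RateLimitError()
    return "ok"


print(call_with_retries(flaky, sleep=delays.append), attempts)  # ok 3
```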

Other features

Built-in token counting

from lmwrapper.openai_wrapper import *
from lmwrapper.structs import LmPrompt

lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo_instruct)
assert lm.estimate_tokens_in_prompt(
    LmPrompt("My name is Spingldorph", max_tokens=10)) == 7
assert not lm.could_completion_go_over_token_limit(LmPrompt(
    "My name is Spingldorph", max_tokens=1000))
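The second assertion can be read as a budget check: a completion could overflow if the prompt's tokens plus max_tokens exceed the model's context window. The following is a minimal sketch of that arithmetic, assuming the prompt's token count is already known. could_go_over is hypothetical, and the 4,096-token window is an illustrative value, not a claim about any specific model.

```python
def could_go_over(prompt_tokens: int, max_tokens: int, context_window: int) -> bool:
    """True if the prompt plus the requested completion might not fit in the window."""
    return prompt_tokens + max_tokens > context_window


# A 7-token prompt with max_tokens=1000 fits easily in a 4,096-token window.
print(could_go_over(7, 1000, context_window=4096))  # False
print(could_go_over(7, 4090, context_window=4096))  # True
```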

TODOs

If you are interested in one of these features or something else, please open a GitHub issue.

  • OpenAI completion
  • OpenAI chat
  • Huggingface interface
  • Huggingface device checking on PyTorch
  • Move cache to be per project
  • Redesign cache away from generic diskcache to make it easier to manage
  • Smart caching when num_completions > 1 (reusing prior completions)
  • OpenAI batching interface (experimental)
  • Anthropic interface
  • Be able to add user metadata to a prompt
  • Automatic cache eviction to limit count or disk size
  • Multimodal/images in super easy format (like automatically process pil, opencv, etc)
  • sort through usage of quantized models
  • Cost estimation (to estimate the cost of a prompt before running it and track total cost)
  • Additional Huggingface runtimes (TensorRT, BetterTransformers, etc)
  • async / streaming (not a top priority for non-interactive research use cases)
  • some lightweight utilities to help with tool use


Download files

Source Distribution

lmwrapper-0.10.0.0.tar.gz (77.8 kB)

Uploaded Source

Built Distribution

lmwrapper-0.10.0.0-py3-none-any.whl (61.3 kB)

Uploaded Python 3

File details

Details for the file lmwrapper-0.10.0.0.tar.gz.

File metadata

  • Download URL: lmwrapper-0.10.0.0.tar.gz
  • Upload date:
  • Size: 77.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.11

File hashes

Hashes for lmwrapper-0.10.0.0.tar.gz
Algorithm Hash digest
SHA256 e4467cbf428462d9c5ab20f9011611058a2bb8e7cf4ba5b681ce2f20ca3e4a07
MD5 b624d0784b669e34171b62d21dc0a901
BLAKE2b-256 58537d9b85c98fc30f15b1a7f166aca2847443bea6f1d8d824c14edf6ba9728c

File details

Details for the file lmwrapper-0.10.0.0-py3-none-any.whl.

File metadata

  • Download URL: lmwrapper-0.10.0.0-py3-none-any.whl
  • Upload date:
  • Size: 61.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.11

File hashes

Hashes for lmwrapper-0.10.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 e547378b57dc315389c93151b6daf911b5854a9aa0f7cd0ed9f492217c5fc7b8
MD5 990a9e26061b0fa8f6bda9371a3b9e0d
BLAKE2b-256 1d50ebf371877469b6c428038bcbe4508218227c4e306d6c5a645d52f48837bd
