
LmWrapper

Provides a wrapper around the OpenAI API and Hugging Face language models, with a focus on a clean, user-friendly interface. Because every input and output is object-oriented (rather than a JSON dictionary with string keys and values), your IDE can help you with argument and property names and catch certain bugs statically. Additionally, it lets you switch between OpenAI endpoints and local models with minimal code changes.
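The benefit of object-oriented inputs can be illustrated with a minimal sketch (hypothetical names, not lmwrapper's actual classes): a typed prompt object rejects a misspelled field at construction time, while a raw dictionary silently accepts it.

```python
from dataclasses import dataclass

@dataclass
class Prompt:
    text: str
    max_tokens: int = 16

# Typed object: a typo like `max_token=10` raises a TypeError immediately,
# and IDEs can autocomplete the field names.
p = Prompt("Once upon a", max_tokens=10)
assert p.max_tokens == 10

# Raw dict: the same typo is silently accepted and only fails much later.
d = {"text": "Once upon a", "max_token": 10}  # note the misspelled key
assert d.get("max_tokens") is None  # the intended setting was never applied
```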

Installation

For usage with just OpenAI models:

pip install lmwrapper

For usage with HuggingFace models as well:

pip install 'lmwrapper[hf]'

For development dependencies:

pip install 'lmwrapper[dev]'

Note that this install path is intended for development of lmwrapper itself and is not officially supported.

Example usage

Completion models

from lmwrapper.openai_wrapper import get_open_ai_lm, OpenAiModelNames
from lmwrapper.structs import LmPrompt

lm = get_open_ai_lm(
    model_name=OpenAiModelNames.gpt_3_5_turbo_instruct,
    api_key_secret=None,  # By default, this will read from the OPENAI_API_KEY environment variable.
    # If that isn't set, it will try the file ~/oai_key.txt
    # You need to place the key in one of these places,
    # or pass in a different location. You can get an API
    # key at (https://platform.openai.com/account/api-keys)
)

prediction = lm.predict(
    LmPrompt(  # A LmPrompt object lets your IDE hint on args
        "Once upon a",
        max_tokens=10,
    )
)
print(prediction.completion_text)
# " time, there were three of us." - Example. This will change with each sample.

Chat models

from lmwrapper.openai_wrapper import get_open_ai_lm, OpenAiModelNames
from lmwrapper.structs import LmPrompt, LmChatTurn

lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)

# Single user utterance
pred = lm.predict("What is 2+2?")
print(pred.completion_text)  # "2+2 is equal to 4."

# Conversation alternating between `user` and `assistant`.
pred = lm.predict(LmPrompt(
    [
        "What is 2+2?",  # user turn
        "4",  # assistant turn
        "What is 5+3?",  # user turn
        "8",  # assistant turn
        "What is 4+4?",  # user turn
        # We use few-shot turns to encourage the answer to be our desired format.
        #   If you don't give example turns you might get something like
        #   "The answer is 8." instead of just "8".
    ],
    max_tokens=10,
))
print(pred.completion_text)  # "8"

# If you want things like the system message, you can use LmChatTurn objects
pred = lm.predict(LmPrompt(
    text=[
        LmChatTurn(role="system", content="You always answer like a pirate"),
        LmChatTurn(role="user", content="How does bitcoin work?"),
    ],
    max_tokens=25,
    temperature=0,
))
print(pred.completion_text)
# "Arr, me matey! Bitcoin be a digital currency that be workin' on a technology called blockchain..."

Hugging Face models

Causal LM models from Hugging Face can be used interchangeably with the OpenAI models.

from lmwrapper.huggingface_wrapper import get_huggingface_lm
from lmwrapper.structs import LmPrompt

lm = get_huggingface_lm("gpt2")  # The smallest 124M parameter model

prediction = lm.predict(LmPrompt(
    "The capital of Germany is Berlin. The capital of France is",
    max_tokens=1,
    temperature=0,
))
print(prediction.completion_text)
assert prediction.completion_text == " Paris"

Features

lmwrapper provides several features missing from the OpenAI API.

Caching

Add caching = True in the prompt to cache the output to disk. Any subsequent call with the same prompt will return the cached value. Note that this may be unexpected behavior if your temperature is non-zero: reruns will always return the same sampled output.
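The caching behavior can be sketched independently of lmwrapper (a simplified illustration, not the library's actual implementation): hash the prompt, store the completion on disk, and return the stored value on a hit.

```python
import hashlib
import json
import tempfile
from pathlib import Path

CACHE_DIR = Path(tempfile.mkdtemp())  # lmwrapper manages its own cache location

def cached_predict(prompt: str, predict_fn) -> str:
    """Return a cached completion for `prompt`, calling `predict_fn` on a miss."""
    key = hashlib.sha256(prompt.encode()).hexdigest()
    path = CACHE_DIR / f"{key}.json"
    if path.exists():  # cache hit: reruns return the same value,
        return json.loads(path.read_text())  # even with temperature > 0
    completion = predict_fn(prompt)
    path.write_text(json.dumps(completion))
    return completion

calls = []
def fake_lm(prompt: str) -> str:
    calls.append(prompt)
    return f"completion for: {prompt}"

first = cached_predict("Once upon a", fake_lm)
second = cached_predict("Once upon a", fake_lm)
assert first == second and len(calls) == 1  # second call never hit the model
```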

Retries on rate limit

An OpenAIPredictor can be configured to parse rate limit errors and wait the number of seconds suggested in the error message before retrying.

from lmwrapper.openai_wrapper import get_open_ai_lm, OpenAiModelNames

lm = get_open_ai_lm(
    OpenAiModelNames.gpt_3_5_turbo_instruct,
    retry_on_rate_limit=True
)
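The retry logic can be sketched as follows (a simplified illustration; lmwrapper's actual parsing of OpenAI error messages may differ). Rate limit errors typically include a suggested wait time, which is extracted and slept for before retrying:

```python
import re
import time

def extract_wait_seconds(error_message: str) -> float:
    """Pull a 'try again in Ns' style wait time out of a rate-limit error."""
    match = re.search(r"try again in ([\d.]+)\s*s", error_message)
    return float(match.group(1)) if match else 1.0  # fall back to 1 second

def predict_with_retry(predict_fn, prompt: str, max_retries: int = 3) -> str:
    for _ in range(max_retries):
        try:
            return predict_fn(prompt)
        except RuntimeError as err:  # stand-in for the API's rate-limit error
            time.sleep(extract_wait_seconds(str(err)))
    raise RuntimeError("rate limited on every attempt")

attempts = []
def flaky_lm(prompt: str) -> str:
    attempts.append(prompt)
    if len(attempts) < 2:
        raise RuntimeError("Rate limit reached. Please try again in 0.01s.")
    return "4"

assert predict_with_retry(flaky_lm, "What is 2+2?") == "4"
assert len(attempts) == 2  # one failure, then a successful retry
```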

Other features

Built-in token counting

from lmwrapper.openai_wrapper import get_open_ai_lm, OpenAiModelNames
from lmwrapper.structs import LmPrompt

lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo_instruct)
assert lm.estimate_tokens_in_prompt(
    LmPrompt("My name is Spingldorph", max_tokens=10)) == 7
assert not lm.could_completion_go_over_token_limit(LmPrompt(
    "My name is Spingldorph", max_tokens=1000))
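The limit check above amounts to comparing the estimated prompt tokens plus `max_tokens` against the model's context window. A sketch of that logic (with a naive whitespace tokenizer and a hypothetical window size, not the library's real tokenizer):

```python
def estimate_tokens(prompt: str) -> int:
    # Naive stand-in: real estimates use the model's tokenizer (e.g. tiktoken).
    return len(prompt.split())

def could_go_over_limit(prompt: str, max_tokens: int,
                        context_window: int = 4096) -> bool:
    return estimate_tokens(prompt) + max_tokens > context_window

assert not could_go_over_limit("My name is Spingldorph", max_tokens=1000)
assert could_go_over_limit("My name is Spingldorph", max_tokens=5000)
```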

TODOs

If you are interested in one of these features, or something else, please open a GitHub Issue.

  • OpenAI completion
  • OpenAI chat
  • Hugging Face interface
  • Hugging Face device checking on PyTorch
  • Move cache to be per project
  • Anthropic interface
  • Redesign cache to make it easier to manage
  • Sort through usage of quantized models
  • Async / streaming
  • Additional Hugging Face runtimes (TensorRT, BetterTransformer, etc.)
  • Cost estimation (so the cost of a prompt can be estimated before running / total cost tracked)
