Wrapper around language model APIs

Project description

This package provides a wrapper around the OpenAI API and Hugging Face language models, focusing on a clean and user-friendly interface. Because every input and output is object-oriented (rather than JSON dictionaries with string keys and values), your IDE can help you with argument and property names and catch certain bugs statically. Additionally, it lets you switch between OpenAI endpoints and local models with minimal changes.

Installation

For usage with just OpenAI models:

pip install "lmwrapper @ git+https://github.com/DNGros/lmwrapper.git"

For usage with HuggingFace models as well:

pip install "lmwrapper[huggingface] @ git+https://github.com/DNGros/lmwrapper.git"

Example usage

Completion models

from lmwrapper.openai_wrapper import get_open_ai_lm, OpenAiModelNames
from lmwrapper.structs import LmPrompt

lm = get_open_ai_lm(
    model_name=OpenAiModelNames.text_ada_001,
    api_key_secret=None, # By default this will read from the OPENAI_API_KEY environment variable.
                         # If that isn't set, it will try the file ~/oai_key.txt
                         # You need to place the key in one of these places,
                         # or pass in a different location. You can get an API
                         # key at (https://platform.openai.com/account/api-keys)
)

prediction = lm.predict(
    LmPrompt(  # A LmPrompt object lets your IDE hint on args
        "Once upon a",
        max_tokens=10,
    )
)
print(prediction.completion_text)
# " time, there were three of us." - Example. This will change with each sample.

Chat models

from lmwrapper.openai_wrapper import get_open_ai_lm, OpenAiModelNames
from lmwrapper.structs import LmPrompt, LmChatTurn
lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)

# Single user utterance
pred = lm.predict("What is 2+2?")
print(pred.completion_text)  # "2+2 is equal to 4."

# Conversation alternating between `user` and `assistant` turns.
pred = lm.predict(LmPrompt(
    [
        "What is 2+2?",  # user turn
        "4",             # assistant turn
        "What is 5+3?",  # user turn
        "8",             # assistant turn
        "What is 4+4?",  # user turn
        # Because we have the few-shot examples, we might expect one number
    ],
    max_tokens=10,
))
print(pred.completion_text) # "8"
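A plain list like the one above presumably maps onto alternating chat roles by position. A self-contained sketch of that convention (illustrative only, not lmwrapper's internals):

```python
def turns_to_chat(turns):
    """Map a plain list of strings to alternating user/assistant chat turns."""
    roles = ("user", "assistant")
    return [{"role": roles[i % 2], "content": text} for i, text in enumerate(turns)]

chat = turns_to_chat(["What is 2+2?", "4", "What is 5+3?"])
# chat[0] is a user turn, chat[1] an assistant turn, chat[2] a user turn again
```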

# If you want things like the system message, you can use LmChatTurn objects
pred = lm.predict(LmPrompt(
    text=[
        LmChatTurn(role="system", content="You always answer like a pirate"),
        LmChatTurn(role="user", content="How does bitcoin work?"),
    ],
    max_tokens=25,
    temperature=0,
))
print(pred.completion_text)
# "Arr, me matey! Bitcoin be a digital currency that be workin' on a technology called blockchain..."

Huggingface models

Causal language models from Hugging Face can be used interchangeably with the OpenAI models. Note that making sure devices and GPUs are used appropriately is still a TODO.

from lmwrapper.huggingface_wrapper import get_huggingface_lm
from lmwrapper.structs import LmPrompt
lm = get_huggingface_lm("gpt2")  # The smallest 124M parameter model

prediction = lm.predict(LmPrompt(
    "The capital of Germany is Berlin. The capital of France is",
    max_tokens=1,
    temperature=0,
))
print(prediction.completion_text)
assert prediction.completion_text == " Paris"
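The interchangeability claimed above rests on both wrappers exposing the same `predict(LmPrompt)` interface, so caller code does not change when the backend does. A minimal sketch of that pattern (names are illustrative, not lmwrapper's internals):

```python
from dataclasses import dataclass
from typing import Protocol


@dataclass(frozen=True)
class Prompt:
    text: str
    max_tokens: int = 10


class LanguageModel(Protocol):
    """The shared interface every backend implements."""
    def predict(self, prompt: Prompt) -> str: ...


class ApiBackend:
    """Stands in for an API-backed model like OpenAI's."""
    def predict(self, prompt: Prompt) -> str:
        return prompt.text.upper()


class LocalBackend:
    """Stands in for a local Hugging Face model."""
    def predict(self, prompt: Prompt) -> str:
        return prompt.text.lower()


def run(lm: LanguageModel, prompt: Prompt) -> str:
    # Caller code is identical regardless of which backend is passed in.
    return lm.predict(prompt)
```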

Features

lmwrapper provides several features missing from the OpenAI API.

Caching

Add `caching=True` in the prompt to cache the output to disk. Any subsequent call with the same prompt will return the cached value. Note that this might be unexpected behavior if your temperature is non-zero: you will always get the same sampled output on reruns.
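The idea can be illustrated with a self-contained sketch of prompt-keyed disk caching (an illustration of the concept, not lmwrapper's actual implementation):

```python
import hashlib
import json
import os
import tempfile


def cache_key(prompt_text: str, max_tokens: int, temperature: float) -> str:
    """Hash every field that affects the output, so any change busts the cache."""
    payload = json.dumps(
        {"text": prompt_text, "max_tokens": max_tokens, "temperature": temperature},
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode()).hexdigest()


def predict_with_cache(call_api, prompt_text, max_tokens=10, temperature=0.7,
                       cache_dir=None):
    cache_dir = cache_dir or tempfile.gettempdir()
    path = os.path.join(cache_dir, cache_key(prompt_text, max_tokens, temperature))
    if os.path.exists(path):                      # cache hit: no API call
        with open(path) as f:
            return f.read()
    completion = call_api(prompt_text)            # cache miss: call the model
    with open(path, "w") as f:
        f.write(completion)
    return completion
```

Note how sampling with non-zero temperature still returns the first cached completion on every rerun, which is the caveat mentioned above.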

Retries on rate limit

from lmwrapper.openai_wrapper import get_open_ai_lm, OpenAiModelNames
lm = get_open_ai_lm(OpenAiModelNames.text_ada_001, retry_on_rate_limit=True)
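Rate-limit retries typically use exponential backoff with jitter; a self-contained sketch of that strategy (illustrative, not lmwrapper's actual retry code):

```python
import random
import time


class RateLimitError(Exception):
    """Stand-in for the provider's rate-limit exception."""


def with_retries(call, max_attempts=5, base_delay=0.01):
    """Retry `call` on rate-limit errors, waiting 1x, 2x, 4x, ... the base delay."""
    for attempt in range(max_attempts):
        try:
            return call()
        except RateLimitError:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the error to the caller
            time.sleep(base_delay * 2 ** attempt + random.random() * base_delay)
```

The jitter term spreads out retries from concurrent clients so they do not all hammer the API at the same instant.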

Other features

Built-in token counting

from lmwrapper.openai_wrapper import get_open_ai_lm, OpenAiModelNames
lm = get_open_ai_lm(OpenAiModelNames.text_ada_001)
assert lm.estimate_tokens_in_prompt(
    LmPrompt("My name is Spingldorph", max_tokens=10)) == 7
assert not lm.could_completion_go_over_token_limit(LmPrompt(
    "My name is Spingldorph", max_tokens=1000))
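The over-limit check reduces to simple arithmetic: prompt tokens plus the requested completion length against the model's context window. A sketch, assuming a text-ada-001-style 2048-token window (illustrative numbers):

```python
def could_go_over_limit(prompt_tokens: int, max_tokens: int,
                        context_window: int) -> bool:
    # The completion can run past the window only if the prompt plus the
    # requested completion exceeds the context size.
    return prompt_tokens + max_tokens > context_window


# A 7-token prompt with max_tokens=1000 fits comfortably in 2048 tokens,
# while max_tokens=3000 could not.
assert not could_go_over_limit(7, 1000, 2048)
assert could_go_over_limit(7, 3000, 2048)
```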

TODOs

  • OpenAI completion
  • OpenAI chat
  • Hugging Face interface
  • Proper GPU handling with Hugging Face
  • Sort through usage of quantized models
  • Improved caching (per-project cache and committable)
  • Anthropic interface
  • Cost estimating
