
Process Supervision - PyTorch


"Let’s Verify Step by Step"

Implementation of "Improving Mathematical Reasoning with Process Supervision" by OpenAI.
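In the paper, a process reward model (PRM) assigns each reasoning step a probability of being correct. The score of a whole solution is the probability that every step is correct, computed as the product of the per-step probabilities, and that score is used to pick the best of N sampled solutions. A minimal sketch of that scoring rule in plain Python follows; it is independent of this package's API, and the function names and probabilities are illustrative only.

```python
from typing import List, Sequence


def solution_score(step_probs: Sequence[float]) -> float:
    """Probability that every step is correct: the product of per-step probabilities."""
    score = 1.0
    for p in step_probs:
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"step probability out of range: {p}")
        score *= p
    return score


def best_of_n(candidates: List[Sequence[float]]) -> int:
    """Return the index of the candidate solution with the highest PRM score."""
    if not candidates:
        raise ValueError("need at least one candidate")
    scores = [solution_score(c) for c in candidates]
    return max(range(len(scores)), key=scores.__getitem__)


# Hypothetical per-step probabilities for three sampled solutions.
candidates = [
    [0.9, 0.8, 0.95],
    [0.99, 0.3, 0.99],  # one weak step drags the whole solution down
    [0.85, 0.9, 0.9],
]
print([round(solution_score(c), 4) for c in candidates])
print(best_of_n(candidates))
```

Because the score is a product, a single low-probability step sinks an otherwise confident solution, which is the behavior step-level supervision is meant to capture.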

Install

pip3 install --upgrade process-supervision-torch

Usage:

GPT4 without tokenizer

import torch
from process_supervision.main import GPT4

# Random token IDs as input: a batch of 1 sequence of 1024 tokens, each in [0, 20000)
text = torch.randint(0, 20000, (1, 1024))

# Initialize the model
model = GPT4()
output = model(text)
print(output)
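Since no tokenizer is bundled, the example above feeds random token IDs. To pass real text, something must map it to integer IDs in range. The sketch below uses a UTF-8 byte-level encoding as a stand-in, assuming (as the random-input example suggests) that IDs below 20000 and a length of 1024 are valid. The encoding, the helper name `encode_bytes`, and the padding ID 0 are all assumptions, not part of this package.

```python
from typing import List

VOCAB_SIZE = 20000  # upper bound used with torch.randint above
CONTEXT_LEN = 1024  # sequence length used above
PAD_ID = 0          # assumption: ID 0 is used for padding (it also equals the NUL byte)


def encode_bytes(text: str, context_len: int = CONTEXT_LEN, pad_id: int = PAD_ID) -> List[int]:
    """Map text to UTF-8 byte values (0-255), then truncate or right-pad to context_len."""
    ids = list(text.encode("utf-8"))[:context_len]
    ids += [pad_id] * (context_len - len(ids))
    return ids


ids = encode_bytes("What is 12 * 7? Step 1: 12 * 7 = 84.")
print(len(ids), ids[:5])  # 1024 [87, 104, 97, 116, 32]
```

`torch.tensor([ids])` then yields an integer tensor of shape `(1, 1024)`, the same shape as `text` above. This only produces correctly shaped, in-range input; it says nothing about what vocabulary the model expects.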

PRM

import torch
from process_supervision.prm import PRM
from swarms.models import OpenAIChat
from process_supervision.generator import MathDataGenerator
import os
from dotenv import load_dotenv

load_dotenv()

api_key = os.getenv("OPENAI_API_KEY")

# LLM initialization
llm = OpenAIChat(openai_api_key=api_key)

# Math data generator initialization
math_datagenerator = MathDataGenerator(llm, num_iters=10)

# Device initialization
device = 0 if torch.cuda.is_available() else "cpu"

# Model initialization
prm_model = PRM(
    model_name="lvwerra/gpt2-imdb-pos-v2",
    ref_model_name="lvwerra/gpt2-imdb",
    reward_model_name="lvwerra/distilbert-imdb",
    device=device,
)

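# Generation arguments: sample from the full distribution
# (top_k=0 and top_p=1.0 disable top-k and nucleus filtering; min_length=-1 sets no minimum length).
# GPT-2 has no pad token, so the EOS token is reused for padding.
gen_kwargs = {
    "min_length": -1,
    "top_k": 0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": prm_model.tokenizer.eos_token_id,
}
# Scoring arguments (Hugging Face pipeline style): raw logits for every label, batches of 16
sent_kwargs = {"top_k": None, "function_to_apply": "none", "batch_size": 16}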

# Sample queries
queries = ["Sample query 1", "Sample query 2"]
queries = [math_datagenerator.generate_samples(query) for query in queries]

# Generate responses
responses = prm_model.generate_responses(
    queries, gen_len=10, gen_kwargs=gen_kwargs
)

# Score responses
scores = prm_model.score_responses(responses, sent_kwargs)

# Display results
for query, response, score in zip(queries, responses, scores):
    print(f"Query: {query}\nResponse: {response}\nScore: {score}\n")
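Process supervision gives feedback on each intermediate reasoning step rather than only on the final answer. The sketch below shows one way to prepare step-level inputs, assuming steps are separated by newlines (the format the paper's generator is prompted to use). `step_scorer` is a hypothetical callable standing in for a trained PRM, and the toy scorer is purely illustrative; neither is part of this package's API.

```python
from typing import Callable, List, Tuple


def split_steps(solution: str) -> List[str]:
    """Split a solution into steps, assuming one step per non-empty line."""
    return [line.strip() for line in solution.splitlines() if line.strip()]


def score_steps(
    problem: str,
    solution: str,
    step_scorer: Callable[[str, List[str]], float],
) -> List[Tuple[str, float]]:
    """Score each step given the problem and all steps up to and including it.

    step_scorer is hypothetical: it should return P(latest step is correct).
    """
    steps = split_steps(solution)
    return [(step, step_scorer(problem, steps[: i + 1])) for i, step in enumerate(steps)]


# Toy stand-in for a trained PRM: flags one known-bad step.
KNOWN_BAD = {"12 * 7 = 74"}


def toy_scorer(problem: str, prefix: List[str]) -> float:
    return 0.1 if prefix[-1] in KNOWN_BAD else 0.9


problem = "What is 12 * 7 + 1?"
solution = "12 * 7 = 74\n74 + 1 = 75\n"
for step, p in score_steps(problem, solution, toy_scorer):
    print(f"{p:.1f}  {step}")
```

Note that the second step is locally valid given the first; a step-level scorer can still pinpoint where the solution first went wrong, which an outcome-only reward cannot.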

GPT4 + PRM

Method

Citation

@misc{lightman2023lets,
   title={Let's Verify Step by Step}, 
   author={Hunter Lightman and Vineet Kosaraju and Yura Burda and Harri Edwards and Bowen Baker and Teddy Lee and Jan Leike and John Schulman and Ilya Sutskever and Karl Cobbe},
   year={2023},
   eprint={2305.20050},
   archivePrefix={arXiv},
   primaryClass={cs.LG}
}

Todo

  • Create the PRM reward model

License

MIT

Download files

Source Distribution

process_supervision_torch-0.0.3.tar.gz (13.9 kB)

Uploaded Source

Built Distribution

process_supervision_torch-0.0.3-py3-none-any.whl (14.0 kB)

Uploaded Python 3

File details

Details for the file process_supervision_torch-0.0.3.tar.gz.

File metadata

  • Download URL: process_supervision_torch-0.0.3.tar.gz
  • Upload date:
  • Size: 13.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.3.2 CPython/3.11.0 Darwin/22.4.0

File hashes

Hashes for process_supervision_torch-0.0.3.tar.gz
Algorithm     Hash digest
SHA256        6c2001938ecc0d08ed5d6ce8200dcaffb0287d1b4191156fd3f9e79e6db31cc5
MD5           b488edb0687152f08f314b4d14997111
BLAKE2b-256   a49af3fc0d0f76058e86fb5e423d18c1653e74bcb5ef5e20006ed8405d66f9e9
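A downloaded archive can be checked against the SHA256 digest listed above with the standard library. The sketch below streams the file through `hashlib.sha256`; the helper names and the assumption that the archive sits in the current directory are illustrative.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for process_supervision_torch-0.0.3.tar.gz
EXPECTED_SHA256 = "6c2001938ecc0d08ed5d6ce8200dcaffb0287d1b4191156fd3f9e79e6db31cc5"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hash the file in chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's SHA256 hex digest matches the expected one."""
    return sha256_of(path) == expected.lower()
```

After downloading, `verify(Path("process_supervision_torch-0.0.3.tar.gz"))` should return `True`. pip can enforce the same check at install time through its hash-checking mode (`--hash=sha256:...` entries in a requirements file).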


File details

Details for the file process_supervision_torch-0.0.3-py3-none-any.whl.

File hashes

Hashes for process_supervision_torch-0.0.3-py3-none-any.whl
Algorithm     Hash digest
SHA256        b9579a02d604375244c6b2448255dc0a46d34f3b85555b3b339909b37d24471e
MD5           b8accf2dfed51a6692dac87b431acc4f
BLAKE2b-256   28dbe5144d5c8c128a6e7e8016dde66bb7a94d77dfdaa8f6356124aae923f0db

