Additional code for Stable-Baselines3 to load models from the Hub and upload models to it.
Hugging Face 🤗 x Stable-baselines3
A library to load Stable-Baselines3 models from the Hub and upload them to it.
Installation
With pip
pip install huggingface-sb3
Examples
We wrote a tutorial on how to use the 🤗 Hub with Stable-Baselines3.
Case 1: I want to download a model from the Hub
import gym
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
# Retrieve the model from the hub
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename = name of the model zip file from the repository
checkpoint = load_from_hub(
    repo_id="sb3/demo-hf-CartPole-v1",
    filename="ppo-CartPole-v1.zip",
)
model = PPO.load(checkpoint)
# Evaluate the agent over 5 deterministic episodes
eval_env = gym.make("CartPole-v1")
mean_reward, std_reward = evaluate_policy(
    model, eval_env, render=False, n_eval_episodes=5, deterministic=True, warn=False
)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
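evaluate_policy reports its result as the mean and standard deviation of the per-episode returns. The sketch below reproduces only that summary step in plain Python, assuming the population standard deviation (what numpy.std computes by default). The function name summarize_returns and the sample returns are hypothetical; they are not part of huggingface_sb3 or Stable-Baselines3.

```python
from statistics import fmean, pstdev
from typing import Sequence, Tuple


def summarize_returns(episode_returns: Sequence[float]) -> Tuple[float, float]:
    """Return (mean, population std) of per-episode returns.

    Hypothetical helper: assumes the std is computed with ddof=0,
    matching numpy.std's default behaviour.
    """
    if not episode_returns:
        raise ValueError("need at least one episode return")
    return fmean(episode_returns), pstdev(episode_returns)


# Made-up returns for 5 evaluation episodes (illustrative only).
mean_reward, std_reward = summarize_returns([480.0, 500.0, 500.0, 490.0, 500.0])
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
```

A std of 0 means every evaluation episode reached the same return, which is common for a well-trained deterministic CartPole policy.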
Case 2: I trained an agent and want to upload it to the Hub
First you need to be logged in to Hugging Face:
- If you're using Colab/Jupyter Notebooks:
from huggingface_hub import notebook_login
notebook_login()
- Otherwise, run in a terminal:
huggingface-cli login
Then:
from huggingface_sb3 import push_to_hub
from stable_baselines3 import PPO
# Define a PPO model with MLP policy network
model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
# Train it for 10000 timesteps
model.learn(total_timesteps=10_000)
# Save the model
model.save("ppo-CartPole-v1")
# Push the saved model to a Hugging Face Hub repository
# If this repo does not exist, it will be created
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename = name of the zip file written by model.save("ppo-CartPole-v1"), i.e. "ppo-CartPole-v1.zip"
push_to_hub(
    repo_id="sb3/demo-hf-CartPole-v1",
    filename="ppo-CartPole-v1.zip",
    commit_message="Added Cartpole-v1 model trained with PPO",
)
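The comments above describe two naming conventions: repo_id has the form {organization}/{repo_name}, and filename is the name passed to model.save with a .zip extension. The hypothetical helpers below (make_repo_id, zip_filename) simply encode those conventions so the two arguments stay consistent between push_to_hub and load_from_hub. They are not part of huggingface_sb3, and the rule that ".zip" is appended only when missing is an assumption based on the "ppo-CartPole-v1" example in this README.

```python
def make_repo_id(organization: str, repo_name: str) -> str:
    """Build a Hub repo_id of the form "{organization}/{repo_name}" (hypothetical helper)."""
    for part in (organization, repo_name):
        # Each component must be non-empty and must not contain a slash itself.
        if not part or "/" in part:
            raise ValueError(f"invalid repo_id component: {part!r}")
    return f"{organization}/{repo_name}"


def zip_filename(save_name: str) -> str:
    """Map the name given to model.save() to the .zip file name (hypothetical helper).

    Assumption: the saved archive is the name plus ".zip", as in the
    "ppo-CartPole-v1" -> "ppo-CartPole-v1.zip" example above.
    """
    return save_name if save_name.endswith(".zip") else f"{save_name}.zip"


repo_id = make_repo_id("sb3", "demo-hf-CartPole-v1")
filename = zip_filename("ppo-CartPole-v1")
print(repo_id, filename)
```

The resulting strings can be passed unchanged as the repo_id and filename arguments of both push_to_hub and load_from_hub.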
Download files
Source distribution: huggingface_sb3-1.0.5.tar.gz (3.7 kB)
Built distribution: huggingface_sb3-1.0.5-py3-none-any.whl
Hashes for huggingface_sb3-1.0.5-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | d9dda7ca4dfc7cc50b8ce3205567f7a7403454792472bc5c4fad87e18e037d8a
MD5 | 5a35c73258ce375910b247ac72cf3bc9
BLAKE2b-256 | 5b9459229e1104d773033d00cefcd58165752ffc70cf8beae8eb3f7f6dfef1e0