Additional code for Stable-baselines3 to load and upload models from the Hub.
Hugging Face x Stable-baselines3
A library to load and upload Stable-baselines3 models from the Hub.
Installation
With pip:
pip install huggingface-sb3
Examples
[Todo: add colab tutorial]
Case 1: I want to download a model from the Hub
import gym
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
env = gym.make("CartPole-v1")
# Retrieve the model from the hub
## repo_id = id of the model repository on the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename = name of the model zip file in the repository
checkpoint = load_from_hub(repo_id="ThomasSimonini/ppo-CartPole-v1", filename="CartPole-v1")
# Load the downloaded checkpoint; PPO.load returns the model, so keep the result
model = PPO.load(checkpoint)
obs = env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        obs = env.reset()
env.close()
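The evaluation loop above follows the classic Gym pattern: predict an action, step the environment, and reset when `done` is true. Below is a minimal sketch of that control flow that runs without Gym or Stable-baselines3. `ToyEnv`, `constant_policy` and `run_steps` are hypothetical names invented for this illustration; they stand in for `gym.make(...)` and `model.predict(...)` and are not part of either library.

```python
class ToyEnv:
    """Hypothetical stand-in for a Gym env using the 4-tuple step API shown above."""

    def __init__(self, episode_length=5):
        self.episode_length = episode_length
        self.t = 0

    def reset(self):
        self.t = 0
        return 0  # dummy observation

    def step(self, action):
        self.t += 1
        done = self.t >= self.episode_length
        # Same return shape as the loop above: obs, reward, done, info
        return self.t, 1.0, done, {}


def constant_policy(obs):
    """Hypothetical stand-in for model.predict(obs, deterministic=True)."""
    return 0, None  # (action, state)


def run_steps(env, predict, n_steps):
    """Step the env n_steps times, resetting whenever an episode ends.

    Returns the total reward of each completed episode.
    """
    episode_returns = []
    current_return = 0.0
    obs = env.reset()
    for _ in range(n_steps):
        action, _state = predict(obs)
        obs, reward, done, info = env.step(action)
        current_return += reward
        if done:
            episode_returns.append(current_return)
            current_return = 0.0
            obs = env.reset()
    return episode_returns


returns = run_steps(ToyEnv(episode_length=5), constant_policy, n_steps=12)
print(returns)  # two completed episodes; the last 2 steps belong to an unfinished one
```

Collecting per-episode returns this way gives a quick sanity check that a downloaded model behaves as expected before rendering it.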
Case 2: I trained an agent and want to upload it to the Hub
First you need to be logged in to Hugging Face:
- If you're using Colab/Jupyter Notebooks:
from huggingface_hub import notebook_login
notebook_login()
- Else:
huggingface-cli login
Then:
import gym
from huggingface_sb3 import push_to_hub
from stable_baselines3 import PPO
# Create the environment
env = gym.make('CartPole-v1')
# Define a PPO model with an MlpPolicy architecture
model = PPO('MlpPolicy', env, verbose=1)
# Train it for 10000 timesteps
model.learn(total_timesteps=10000)
# Save the model
model.save("CartPole-v1")
# Push this saved model to the Hugging Face repo
# If this repo does not exist, it will be created
## filename: the name of the saved file, i.e. the name passed to model.save("CartPole-v1")
push_to_hub(repo_name="CartPole-v1",
            organization="ThomasSimonini",
            filename="CartPole-v1",
            commit_message="Added Cartpole-v1 trained model")
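The upload call takes `organization` and `repo_name` separately, while the download call in Case 1 takes a single `repo_id` of the form `{organization}/{repo_name}`. The sketch below shows how the two relate. `hub_coordinates` is a hypothetical helper written for this illustration, not part of `huggingface_sb3`, and it assumes the file is fetched under the same name it was pushed with, as in both examples above.

```python
def hub_coordinates(organization, repo_name, filename):
    """Hypothetical helper: map push_to_hub-style arguments to
    load_from_hub-style arguments (repo_id and filename)."""
    for part in (organization, repo_name):
        if not part or "/" in part:
            raise ValueError(f"invalid repository component: {part!r}")
    return {"repo_id": f"{organization}/{repo_name}", "filename": filename}


coords = hub_coordinates("ThomasSimonini", "CartPole-v1", "CartPole-v1")
print(coords["repo_id"])
```

The resulting dictionary could then be unpacked into the download call, for example `load_from_hub(**coords)`.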