Sagemaker framework for Tidymodels
sagemaker-tidymodels
sagemaker-tidymodels is an AWS Sagemaker framework for training and deploying machine learning models written in R.
This framework lets you do cloud-based training and deployment with tidymodels, using the same code you would write locally.
Installation
This Python package is not yet available on PyPI. In the meantime, you can install it from GitHub:
git clone https://github.com/tmastny/sagemaker-tidymodels.git
pip install sagemaker-tidymodels/
The Docker image is hosted on Docker Hub. You can pull it directly with
docker pull tmastny/sagemaker-tidymodels
Usage
The sagemaker-tidymodels Python package provides simple wrappers around the sagemaker Estimator and Model classes.
The main difference is the entry_point parameter, where you can supply an R script. This R script should process the raw data, train the model, and save the final fit.
from sagemaker_tidymodels import Tidymodels, get_role
tidymodels = Tidymodels(
    entry_point="tests/train.R",
    train_instance_type="local",
    role=get_role(),
    image_name="tmastny/sagemaker-tidymodels:latest",
)
s3_data = "s3://sagemaker-sample-data-us-east-2/processing/census/census-income.csv"
tidymodels.fit({"train": s3_data})
train.R is a normal R script, with a few necessary additions so it can run effectively on the command line.
#!/usr/bin/env Rscript
library(tidymodels)
if (sys.nframe() == 0) {
  input_path <- file.path(Sys.getenv('SM_CHANNEL_TRAIN'), "census-income.csv")
  df <- read.csv(input_path, stringsAsFactors = TRUE)

  pipeline <- workflow() %>%
    add_formula(income ~ age) %>%
    add_model(logistic_reg() %>% set_engine("glm"))

  model <- pipeline %>%
    fit(data = df)

  output_path <- file.path(Sys.getenv('SM_MODEL_DIR'), "model.RDS")
  saveRDS(model, output_path)
}
- The first line should be the shebang #!/usr/bin/env Rscript, so the script can be run as ./train.R, as required by the framework. Make sure to run chmod +x train.R so that it is executable.
- All the training logic should be wrapped in the following if statement. This seems a little mysterious, but it makes sure that the training logic doesn't accidentally run when we've deployed our model for predictions.
if (sys.nframe() == 0) {
# training logic goes here!
}
- Sagemaker is very specific about input and output locations. The input data path is found in an environment variable that can be read using Sys.getenv('SM_CHANNEL_TRAIN'). Likewise, the output model path can be found with Sys.getenv('SM_MODEL_DIR').
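To make the path handling above concrete, here is a minimal Python sketch of the variables the training script reads. It assumes the usual Sagemaker container convention, where a channel named train passed to fit() is exposed as SM_CHANNEL_TRAIN. The helper names (channel_env_var, training_env) and the default directories are assumptions for illustration, not part of sagemaker-tidymodels.

```python
import posixpath


def channel_env_var(channel):
    """Environment variable name for a data channel passed to fit()."""
    return "SM_CHANNEL_" + channel.upper()


def training_env(channels, input_root="/opt/ml/input/data", model_dir="/opt/ml/model"):
    """Build the path variables a training script would read.

    input_root and model_dir are assumed container defaults; pass your own
    directories to mimic the layout on a local machine.
    """
    env = {channel_env_var(name): posixpath.join(input_root, name) for name in channels}
    env["SM_MODEL_DIR"] = model_dir
    return env


# The channel name "train" matches the key used in tidymodels.fit({"train": s3_data}).
env = training_env(["train"])
for name, path in sorted(env.items()):
    print(name, "=", path)
```

For a local dry run, a dictionary like this could be merged into os.environ and passed to subprocess.run(["./train.R"], env=...), provided Rscript and the data are available on the machine.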
From there, you can deploy the model as normal!
predictor = model.deploy(initial_instance_count=1, instance_type="local")
predictor.predict(r'28\n')
Advanced Usage
The docker container has some additional features that may be useful.
Custom Model Serving
The model serving defaults are defined in docker/server/default_fn.R.
If you'd like to customize how the model is served, you can overwrite these defaults by defining these functions in your entry_point script.
The valid options are model_fn, input_fn, predict_fn, and output_fn. In our script basic-train.R, the default predict_fn returns class predictions, which are either "- 50000." or "50000+." (the two income labels in the data).
If we wanted to output the probability of belonging to either class, we could include our own predict_fn in basic-train.R:
# add to `train.R`
predict_fn <- function(model, new_data) {
  predict(model, new_data, type = "prob")
}
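This is also why the training logic needs to be wrapped in the if statement: the serving code has to load the entry_point script to pick up these functions, and the training logic must not run when it does.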
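The override mechanism can be pictured as a chain of four hooks in which any user-defined function replaces its default. The following is a conceptual Python sketch only: the real defaults are R functions in docker/server/default_fn.R, and everything here apart from the four hook names and the predict_fn(model, new_data) signature (the toy model, the other signatures, build_handler, the "low"/"high" labels) is a hypothetical stand-in.

```python
def default_model_fn(model_dir):
    # Stand-in for loading the saved fit; returns a toy threshold "model".
    return {"threshold": 40.0}


def default_input_fn(body):
    # Parse newline-separated numbers from the request body.
    return [float(line) for line in body.splitlines() if line.strip()]


def default_predict_fn(model, new_data):
    # Default behaviour in this sketch: one class label per row.
    return ["high" if x >= model["threshold"] else "low" for x in new_data]


def default_output_fn(predictions):
    # Serialize predictions back into a newline-separated response.
    return "\n".join(str(p) for p in predictions)


DEFAULTS = {
    "model_fn": default_model_fn,
    "input_fn": default_input_fn,
    "predict_fn": default_predict_fn,
    "output_fn": default_output_fn,
}


def build_handler(overrides=None):
    """Combine the defaults with user overrides and return a request handler."""
    overrides = overrides or {}
    unknown = set(overrides) - set(DEFAULTS)
    if unknown:
        raise ValueError("unknown hooks: " + ", ".join(sorted(unknown)))
    hooks = {**DEFAULTS, **overrides}

    def handle(model_dir, body):
        model = hooks["model_fn"](model_dir)
        data = hooks["input_fn"](body)
        predictions = hooks["predict_fn"](model, data)
        return hooks["output_fn"](predictions)

    return handle


def margin_predict_fn(model, new_data):
    # Custom hook: return a numeric score instead of a class label.
    return [x - model["threshold"] for x in new_data]


print(build_handler()("unused", "28\n"))
print(build_handler({"predict_fn": margin_predict_fn})("unused", "28\n"))
```

Only the overridden hook changes. Parsing, model loading, and serialization keep their defaults, which mirrors how defining predict_fn alone in the R script changes the output from classes to probabilities.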
Identical Local and Cloud Scripts
In train.R, the logic you use to train is exactly the same as you would write locally. However, you can't run the script locally as is, because sagemaker defines the environment variables SM_CHANNEL_TRAIN and SM_MODEL_DIR (as well as many others you might want to use).
A nice way to set some defaults so the script can run both locally and in Sagemaker is to use r-optparse.
For example:
library(optparse)
option_list <- list(
  make_option(c("-i", "--input"), default = Sys.getenv("SM_CHANNEL_TRAIN")),
  make_option(c("-o", "--output"), default = Sys.getenv("SM_MODEL_DIR"))
)
args <- parse_args(OptionParser(option_list = option_list))
This lets us use args$input and args$output for the input data path and output model path, whether we are running locally or in Sagemaker as the entry_point.
Then, when running locally, we can define the inputs and outputs on the command line so it runs properly:
Rscript tests/train.R -i data/census-income.csv -o models/
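The precedence at work here (an explicit command-line flag wins, otherwise the Sagemaker environment variable is used) can be sketched in Python with argparse. This is an analogue of the optparse setup for illustration, not code from the package; parse_paths is a hypothetical helper name, and the empty-string fallback mirrors what Sys.getenv returns for an unset variable.

```python
import argparse
import os


def parse_paths(argv, environ=None):
    """Resolve input/output paths: CLI flags first, then environment variables."""
    environ = os.environ if environ is None else environ
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", default=environ.get("SM_CHANNEL_TRAIN", ""))
    parser.add_argument("-o", "--output", default=environ.get("SM_MODEL_DIR", ""))
    return parser.parse_args(argv)


# In the container: no flags, so the environment variables supply the paths.
in_cloud = parse_paths([], {"SM_CHANNEL_TRAIN": "/opt/ml/input/data/train",
                            "SM_MODEL_DIR": "/opt/ml/model"})
print(in_cloud.input, in_cloud.output)

# Locally: the variables are unset, so the flags supply the paths.
local = parse_paths(["-i", "data/census-income.csv", "-o", "models/"], {})
print(local.input, local.output)
```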