SageMaker framework for Tidymodels
sagemaker-tidymodels
sagemaker-tidymodels is an AWS SageMaker framework for training and deploying machine learning models written in R.
This framework lets you do cloud-based training and deployment with tidymodels, using the same code you would write locally.
Installation
You can install the framework from PyPI:
pip install sagemaker-tidymodels
The Docker image is available on Docker Hub:
docker pull tmastny/sagemaker-tidymodels
Usage
The sagemaker-tidymodels Python package provides simple wrappers around the SageMaker Estimator and Model classes.
The main difference is the entry_point parameter, where you can supply
an R script. This R script should process the raw data, train the model,
and save the final fit.
from sagemaker_tidymodels import Tidymodels, get_role

tidymodels = Tidymodels(
    entry_point="tests/train.R",
    train_instance_type="local",
    role=get_role(),
    image_name="tmastny/sagemaker-tidymodels:latest",
)

s3_data = "s3://sagemaker-sample-data-us-east-2/processing/census/census-income.csv"
tidymodels.fit({"train": s3_data})
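The key of the dictionary passed to fit ("train") is the channel name. Following SageMaker's convention, the location of each channel's data inside the container is exposed through an environment variable named SM_CHANNEL_ plus the upper-cased channel name, which is why the train.R script reads SM_CHANNEL_TRAIN. The sketch below illustrates only that naming rule; channel_env_vars is a hypothetical helper, and the /opt/ml/input/data/<channel> directory is assumed to be the standard SageMaker layout rather than something this package documents.

```python
import posixpath

# Assumed standard SageMaker location for channel data inside the container.
INPUT_DATA_ROOT = "/opt/ml/input/data"


def channel_env_vars(inputs):
    """Map fit() channel names to the environment variables a training container sees.

    Hypothetical helper: it mirrors the SM_CHANNEL_<NAME> naming convention
    and does not call SageMaker.
    """
    return {
        "SM_CHANNEL_" + name.upper(): posixpath.join(INPUT_DATA_ROOT, name)
        for name in inputs
    }


if __name__ == "__main__":
    s3_data = "s3://sagemaker-sample-data-us-east-2/processing/census/census-income.csv"
    print(channel_env_vars({"train": s3_data}))
    # {'SM_CHANNEL_TRAIN': '/opt/ml/input/data/train'}
```

Adding a second channel, for example {"train": ..., "test": ...}, would correspondingly give the script a SM_CHANNEL_TEST variable to read.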
train.R is an ordinary R script with a few additions that let it run from the command line, as the framework requires.
#!/usr/bin/env Rscript
library(tidymodels)

if (sys.nframe() == 0) {
  # Training data: the directory SageMaker provides for the "train" channel
  input_path <- file.path(Sys.getenv("SM_CHANNEL_TRAIN"), "census-income.csv")
  df <- read.csv(input_path, stringsAsFactors = TRUE)

  # Logistic regression of the income class on age
  pipeline <- workflow() %>%
    add_formula(income ~ age) %>%
    add_model(logistic_reg() %>% set_engine("glm"))

  model <- pipeline %>%
    fit(data = df)

  # Save the fitted workflow where SageMaker expects model artifacts
  output_path <- file.path(Sys.getenv("SM_MODEL_DIR"), "model.RDS")
  saveRDS(model, output_path)
}
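- The first line should be the shebang #!/usr/bin/env Rscript, so the script can be run as ./train.R, as required by the framework. Make sure to run chmod +x train.R so that it is executable.
- All the training logic should be wrapped in the following if statement. This may look a little mysterious: sys.nframe() is 0 when the script is run directly with Rscript, but greater than 0 when the file is loaded from inside a function call (for example with source()). The guard therefore makes sure the training logic doesn't accidentally run when the deployed model is loaded to serve predictions.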
  if (sys.nframe() == 0) {
    # training logic goes here!
  }
- SageMaker is very specific about input and output locations. The input data directory is given by the environment variable SM_CHANNEL_TRAIN, which can be read with Sys.getenv('SM_CHANNEL_TRAIN'). Likewise, the output model directory can be found with Sys.getenv('SM_MODEL_DIR').
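Because these requirements only surface as errors once a job is running, it can help to check them before calling fit. The following is a hypothetical pre-flight check written in Python (check_entry_point is not part of sagemaker-tidymodels). It tests the requirements listed above with simple text matching: the shebang on the first line, the executable bit, the sys.nframe() guard, and references to the two environment variables. The matching is deliberately naive, so treat a pass as a sanity check rather than a guarantee.

```python
import os
import stat

SHEBANG = "#!/usr/bin/env Rscript"


def check_entry_point(path):
    """Return a list of problems found in an R entry_point script (empty if none).

    Hypothetical helper: plain text checks only; the script is never run.
    """
    problems = []
    with open(path, encoding="utf-8") as f:
        text = f.read()
    lines = text.splitlines()

    # 1. The first line must be the Rscript shebang so ./train.R can run directly.
    if not lines or lines[0].strip() != SHEBANG:
        problems.append("first line is not " + SHEBANG)

    # 2. The owner execute bit must be set (what chmod +x train.R does).
    if not os.stat(path).st_mode & stat.S_IXUSR:
        problems.append("file is not executable; run chmod +x")

    # 3. Training logic should sit inside the sys.nframe() guard.
    if "sys.nframe() == 0" not in text:
        problems.append("no if (sys.nframe() == 0) guard found")

    # 4. Input and output locations should come from SageMaker's environment variables.
    for var in ("SM_CHANNEL_TRAIN", "SM_MODEL_DIR"):
        if var not in text:
            problems.append(var + " is never read")

    return problems
```

For example, check_entry_point("tests/train.R") should return an empty list for a script shaped like the train.R example, and a list of human-readable messages otherwise.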
From there, you can deploy the model as normal!
predictor = tidymodels.deploy(initial_instance_count=1, instance_type="local")
predictor.predict("28\n")
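The request body in predictor.predict("28\n") is plain text holding one row of predictor values, here a single age. Assuming the predictor sends the string unchanged and the container's default input_fn reads it as header-less CSV with one observation per line (an assumption suggested by this example, not confirmed here), several observations could be scored in one request. A hypothetical helper, to_csv_payload, that builds such a body with Python's csv module:

```python
import csv
import io


def to_csv_payload(rows):
    """Serialize rows of feature values to header-less CSV text, one row per line.

    Hypothetical helper: assumes the serving container's default input_fn
    parses the request body as CSV with columns in the order the model expects.
    """
    buffer = io.StringIO()
    # lineterminator="\n" matches the "28\n" style of the example request.
    writer = csv.writer(buffer, lineterminator="\n")
    writer.writerows(rows)
    return buffer.getvalue()


payload = to_csv_payload([[28], [45], [61]])
print(repr(payload))  # '28\n45\n61\n'
```

Under those assumptions, predictor.predict(payload) would return one prediction per age.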
Advanced Usage
The docker container has some additional features that may be useful.
Custom Model Serving
The model serving defaults are defined in docker/server/default_fn.R.
If you’d like to customize how the model is served, you can override
these defaults by defining the corresponding functions in your entry_point script.
The valid options are model_fn, input_fn, predict_fn, and
output_fn. With our script train.R, the default predict_fn returns
class predictions, so each prediction is one of the two income labels ("- 50000." or "50000+.").
If we wanted to output the probability of belonging to either class, we
could include our own predict_fn in train.R:
# add to `train.R`
predict_fn <- function(model, new_data) {
  predict(model, new_data, type = "prob")
}
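This is also why the training logic must be wrapped in the if statement: the server loads train.R to pick up these function definitions, and without the guard, loading the script would retrain the model.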
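The override mechanism can be pictured as a lookup: the server starts from a set of default functions and replaces any that the entry_point script defines. The Python sketch below mimics that idea with toy stand-ins; the real defaults live in docker/server/default_fn.R and are written in R, and the names DEFAULTS, SCRIPT, load_handlers, and serve are hypothetical. It also illustrates why the guard matters: the user script is executed under a module name other than "__main__" (Python's counterpart to sys.nframe() being non-zero), so its training block is skipped while its predict_fn is still picked up.

```python
# Toy stand-ins for the R defaults in docker/server/default_fn.R (hypothetical).
def default_model_fn(model_dir):
    # Pretend to load a saved model; just records where it would come from.
    return {"model_dir": model_dir}


def default_input_fn(body):
    # Parse header-less CSV text into rows of floats.
    return [[float(v) for v in line.split(",")] for line in body.splitlines() if line]


def default_predict_fn(model, new_data):
    # Return a placeholder class label for every row.
    return ["class" for _ in new_data]


def default_output_fn(predictions):
    # One prediction per line.
    return "".join(str(p) + "\n" for p in predictions)


DEFAULTS = {
    "model_fn": default_model_fn,
    "input_fn": default_input_fn,
    "predict_fn": default_predict_fn,
    "output_fn": default_output_fn,
}

# A user script that overrides predict_fn and guards its training block.
SCRIPT = '''
trained = False
if __name__ == "__main__":
    trained = True  # training logic would go here


def predict_fn(model, new_data):
    return [{"prob": 0.5} for _ in new_data]
'''


def load_handlers(script_source):
    """Run the script as a library (not as __main__) and merge its overrides."""
    namespace = {"__name__": "entry_point"}
    exec(script_source, namespace)
    handlers = {name: namespace.get(name, fn) for name, fn in DEFAULTS.items()}
    return handlers, namespace


def serve(handlers, model_dir, body):
    """Chain the four handlers the way a serving request would."""
    model = handlers["model_fn"](model_dir)
    new_data = handlers["input_fn"](body)
    predictions = handlers["predict_fn"](model, new_data)
    return handlers["output_fn"](predictions)


handlers, namespace = load_handlers(SCRIPT)
print(namespace["trained"])  # False: the guarded training block was skipped
print(serve(handlers, "/opt/ml/model", "28\n"), end="")  # {'prob': 0.5}
```

Only predict_fn is replaced here; model_fn, input_fn, and output_fn fall back to the defaults, which is the behavior the paragraph above describes for the R container.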
Identical Local and Cloud Scripts
In train.R, the training logic is exactly what you would write locally. However, you can’t run the script locally as is, because it reads the environment variables SM_CHANNEL_TRAIN and SM_MODEL_DIR, which SageMaker defines (along with many others you might want to use) but which are not set on your own machine.
A convenient way to set defaults, so that the script runs both locally and in SageMaker, is the optparse package (r-optparse).
For example:
library(optparse)

option_list <- list(
  make_option(c("-i", "--input"), default = Sys.getenv("SM_CHANNEL_TRAIN")),
  make_option(c("-o", "--output"), default = Sys.getenv("SM_MODEL_DIR"))
)
args <- parse_args(OptionParser(option_list = option_list))
This lets us use args$input and args$output in place of the environment variables to refer to the input data directory and the output model directory. When running locally, we can then pass those directories explicitly (here, data/ is assumed to contain census-income.csv):
Rscript tests/train.R -i data/ -o models/
and the script behaves just as it would in SageMaker.
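For comparison, the same pattern in Python uses argparse with defaults taken from the environment. This is only an analogue of the optparse snippet above (the training scripts in this framework are R); the empty-string fallback mirrors Sys.getenv, which returns "" for unset variables, and parse_args here is a hypothetical wrapper.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse -i/--input and -o/--output, defaulting to SageMaker's environment variables."""
    parser = argparse.ArgumentParser()
    # Like Sys.getenv(), fall back to "" when the variable is not set.
    parser.add_argument("-i", "--input", default=os.environ.get("SM_CHANNEL_TRAIN", ""))
    parser.add_argument("-o", "--output", default=os.environ.get("SM_MODEL_DIR", ""))
    return parser.parse_args(argv)
```

Inside a SageMaker container, parse_args([]) picks up the environment variables; locally, parse_args(["-i", "data/", "-o", "models/"]) overrides both, just as the Rscript command does.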
Download files
File details
Details for the file sagemaker-tidymodels-0.1.1.tar.gz.
File metadata
- Download URL: sagemaker-tidymodels-0.1.1.tar.gz
- Upload date:
- Size: 6.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/47.3.1.post20200622 requests-toolbelt/0.9.1 tqdm/4.48.0 CPython/3.7.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | aa50b7eff1e77462d9899ed0f6f130daa6aa8d02bdcbffab473d6aeea582a763 |
| MD5 | b5a6a4390a6e5c28ab3d6d4385786b3b |
| BLAKE2b-256 | 1f02bb6f38e39732da3127a0028138ce3af9e1c646e9d35a7604b2afee91c216 |
File details
Details for the file sagemaker_tidymodels-0.1.1-py3-none-any.whl.
File metadata
- Download URL: sagemaker_tidymodels-0.1.1-py3-none-any.whl
- Upload date:
- Size: 9.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/47.3.1.post20200622 requests-toolbelt/0.9.1 tqdm/4.48.0 CPython/3.7.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e3efd8460d02aa7f0bb1a80f3b1c4ac0aceca29ce73b3a2e89a0251fab516dec |
| MD5 | 805ec7e66983bd6b7682d22c777c830d |
| BLAKE2b-256 | ac993f5fb92d0afff6f70ff8ba2a9b68ada2218899a616435bfba92db959e758 |