
Sagemaker framework for Tidymodels

Project description

sagemaker-tidymodels

sagemaker-tidymodels is an AWS SageMaker framework for training and deploying machine learning models written in R.

This framework lets you do cloud-based training and deployment with tidymodels, using the same code you would write locally.

Installation

You can install the framework from PyPI:

pip install sagemaker-tidymodels

The Docker image is available on Docker Hub:

docker pull tmastny/sagemaker-tidymodels

Usage

The sagemaker-tidymodels Python package provides simple wrappers around the SageMaker Estimator and Model classes.

The main difference is the entry_point parameter, where you can supply an R script. This R script should process the raw data, train the model, and save the final fit.

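from sagemaker_tidymodels import Tidymodels, get_role

tidymodels = Tidymodels(
    entry_point="tests/train.R",  # R script that processes the data, trains, and saves the fit
    train_instance_type="local",  # "local" runs the training container on this machine
    role=get_role(),
    image_name="tmastny/sagemaker-tidymodels:latest",
)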

s3_data = "s3://sagemaker-sample-data-us-east-2/processing/census/census-income.csv"
tidymodels.fit({"train": s3_data})
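The "train" key in the dictionary passed to fit() is a channel name. As we understand SageMaker's training-container conventions (an assumption; this package does not document it), each channel is copied to /opt/ml/input/data/<channel> and exposed as an environment variable SM_CHANNEL_<CHANNEL> with the name upper-cased, while SM_MODEL_DIR points at /opt/ml/model. The hypothetical helper channel_env below only illustrates that mapping; it is not part of sagemaker-tidymodels.

```python
from typing import Dict

# Hypothetical helper illustrating the assumed SageMaker convention for
# turning fit() channel names into environment variables in the container.
INPUT_DATA_ROOT = "/opt/ml/input/data"  # assumed container layout
MODEL_DIR = "/opt/ml/model"             # assumed value of SM_MODEL_DIR


def channel_env(channels: Dict[str, str]) -> Dict[str, str]:
    """Return the environment variables a training script would see for these channels."""
    env = {"SM_MODEL_DIR": MODEL_DIR}
    for name in channels:
        # Only the channel name shapes the path, e.g. "train" -> SM_CHANNEL_TRAIN.
        env[f"SM_CHANNEL_{name.upper()}"] = f"{INPUT_DATA_ROOT}/{name}"
    return env


s3_data = "s3://sagemaker-sample-data-us-east-2/processing/census/census-income.csv"
env = channel_env({"train": s3_data})
# env["SM_CHANNEL_TRAIN"] is "/opt/ml/input/data/train"
```

This is why train.R builds its input path from SM_CHANNEL_TRAIN rather than using the S3 URL directly.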

train.R is a normal R script with a few additions it needs to run effectively from the command line.

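#!/usr/bin/env Rscript

library(tidymodels)

# Only train when run as a script (./train.R), not when the file is loaded for serving
if (sys.nframe() == 0) {

  # SageMaker places the "train" channel in the directory named by SM_CHANNEL_TRAIN
  input_path <- file.path(Sys.getenv("SM_CHANNEL_TRAIN"), "census-income.csv")
  df <- read.csv(input_path, stringsAsFactors = TRUE)

  # Logistic regression of income on age
  pipeline <- workflow() %>%
    add_formula(income ~ age) %>%
    add_model(logistic_reg() %>% set_engine("glm"))

  model <- pipeline %>%
    fit(data = df)

  # Files written to SM_MODEL_DIR are saved as the model artifact
  output_path <- file.path(Sys.getenv("SM_MODEL_DIR"), "model.RDS")
  saveRDS(model, output_path)
}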
  1. The first line should be the shebang #!/usr/bin/env Rscript, so the framework can run the script as ./train.R. Make sure to run chmod +x train.R so the file is executable.

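  2. All the training logic should be wrapped in the following if statement. It may look mysterious, but sys.nframe() returns 0 only when the code runs at the top level, as it does when the file is executed with Rscript. When the file is loaded from inside another function, such as source(), the condition is false, so the training logic doesn’t accidentally run when the model has been deployed for predictions.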

if (sys.nframe() == 0) {
  # training logic goes here!
}
  3. SageMaker is very specific about input and output locations. The input data directory is stored in an environment variable that can be read with Sys.getenv('SM_CHANNEL_TRAIN'). Likewise, the output model directory can be found with Sys.getenv('SM_MODEL_DIR').
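Before calling fit(), it can help to check the first two requirements locally. The hypothetical check_entry_point function below (not part of sagemaker-tidymodels) reads the script, confirms the shebang line and the sys.nframe() guard are present, and tests the owner execute bit that chmod +x sets. The guard check is a plain text search, so treat it as a rough heuristic only.

```python
import os
import stat
from typing import List

SHEBANG = "#!/usr/bin/env Rscript"
GUARD = "sys.nframe() == 0"


def check_entry_point(path: str) -> List[str]:
    """Return human-readable problems with an R entry point script (empty if none found)."""
    problems = []
    with open(path, encoding="utf-8") as f:
        text = f.read()
    lines = text.splitlines()
    # Requirement 1: the very first line must be the Rscript shebang.
    if not lines or lines[0].strip() != SHEBANG:
        problems.append(f"first line is not {SHEBANG!r}")
    # Requirement 2: training code should sit behind the guard (text search only).
    if GUARD not in text:
        problems.append(f"no '{GUARD}' guard found")
    # Requirement 1 again: chmod +x sets the owner execute bit (among others).
    if not os.stat(path).st_mode & stat.S_IXUSR:
        problems.append(f"not executable; run: chmod +x {path}")
    return problems
```

For a correctly prepared script, check_entry_point("tests/train.R") should return an empty list.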

From there, you can deploy the model as normal!

predictor = tidymodels.deploy(initial_instance_count=1, instance_type="local")
predictor.predict("28\n")
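The string "28\n" is a single CSV row holding the value for age, the only predictor in the formula. Assuming, as that example suggests, that the default input_fn accepts header-less CSV with one row per line (we have not confirmed whether it handles several rows in one request), a payload could be built with Python's csv module. to_csv_payload is a hypothetical helper, not part of the package.

```python
import csv
import io
from typing import Iterable, Sequence


def to_csv_payload(rows: Iterable[Sequence[object]]) -> str:
    """Serialize rows as header-less CSV text, one record per line."""
    buffer = io.StringIO()
    # lineterminator="\n" matches the "28\n" example; csv's default is "\r\n".
    writer = csv.writer(buffer, lineterminator="\n")
    writer.writerows(rows)
    return buffer.getvalue()


payload = to_csv_payload([[28]])  # the same body as "28\n"
# predictor.predict(payload)      # needs the endpoint deployed above
```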

Advanced Usage

The docker container has some additional features that may be useful.

Custom Model Serving

The model serving defaults are defined in docker/server/default_fn.R. To customize how the model is served, override these defaults by defining functions with the same names in your entry_point script.

The valid options are model_fn, input_fn, predict_fn, and output_fn. With our train.R, the default predict_fn returns class predictions, either - 50000. or 50000+.
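Conceptually, each handler you define replaces the matching default and the rest are kept. The Python sketch below models that lookup and chains the handlers in the order used by SageMaker's usual inference convention (model_fn to load the fit, then input_fn, predict_fn, and output_fn for a request). It is an illustration under those assumptions, not the container's actual R code, and the default bodies are placeholders rather than the behavior in default_fn.R.

```python
from typing import Any, Callable, Dict, List


# Placeholder defaults; the real ones are R functions in docker/server/default_fn.R.
def default_model_fn(model_dir: str) -> Any:
    return {"model_dir": model_dir}  # stands in for loading the saved fit


def default_input_fn(body: str) -> List[str]:
    return [line for line in body.splitlines() if line]  # one record per non-empty line


def default_predict_fn(model: Any, new_data: List[str]) -> List[str]:
    return ["<class>" for _ in new_data]  # placeholder class label per record


def default_output_fn(predictions: List[Any]) -> str:
    return "".join(f"{p}\n" for p in predictions)


DEFAULTS: Dict[str, Callable] = {
    "model_fn": default_model_fn,
    "input_fn": default_input_fn,
    "predict_fn": default_predict_fn,
    "output_fn": default_output_fn,
}


def resolve_handlers(user_namespace: Dict[str, Any]) -> Dict[str, Callable]:
    """Prefer handlers defined in the entry point; fall back to the defaults."""
    handlers = {}
    for name, default in DEFAULTS.items():
        candidate = user_namespace.get(name)
        handlers[name] = candidate if callable(candidate) else default
    return handlers


def serve(body: str, model_dir: str, user_namespace: Dict[str, Any]) -> str:
    """Run one request through the resolved handlers."""
    h = resolve_handlers(user_namespace)
    model = h["model_fn"](model_dir)  # a real server would load the model once, not per request
    new_data = h["input_fn"](body)
    predictions = h["predict_fn"](model, new_data)
    return h["output_fn"](predictions)
```

Overriding only predict_fn, as in the R example that follows, leaves model loading and input/output handling on the defaults.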

If we wanted to output the probability of belonging to either class, we could include our own predict_fn in train.R:

# add to `train.R`
predict_fn <- function(model, new_data) {
  predict(model, new_data, type = "prob")
}

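This is also why the training script needs to be wrapped in the if statement: the serving container loads train.R to find these functions, and the guard keeps that from retraining the model.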
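Python users may recognize the same idea as the if __name__ == "__main__": guard. The sketch below (all names hypothetical, and the script is written in Python purely for illustration) loads a script's definitions with exec() under a non-main name, roughly the way the serving container loads train.R to look for predict_fn, and shows that the guarded training code runs only when the script is executed as the main program.

```python
from typing import Any, Dict

# A stand-in for an entry point script, written in Python for illustration.
SCRIPT = '''
trained = []

def predict_fn(model, new_data):
    return [model(x) for x in new_data]

if __name__ == "__main__":
    # Training logic: runs only when the file is executed directly.
    trained.append("fit")
'''


def load_script(source: str, run_name: str) -> Dict[str, Any]:
    """Execute the script in a fresh namespace whose __name__ is run_name."""
    namespace: Dict[str, Any] = {"__name__": run_name}
    exec(source, namespace)
    return namespace


served = load_script(SCRIPT, run_name="serving")  # like loading for predictions
run = load_script(SCRIPT, run_name="__main__")     # like ./train.R
```

After loading for serving, served["trained"] stays empty while predict_fn is available; only the "__main__" run records a fit.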

Identical Local and Cloud Scripts

In train.R, the training logic is exactly what you would write locally. However, you can’t run the script locally as is, because it relies on the environment variables SM_CHANNEL_TRAIN and SM_MODEL_DIR, which SageMaker sets (along with many others you might want to use) but which are not defined on your machine.

A convenient way to set defaults so the script runs both locally and in SageMaker is the optparse R package.

For example:

library(optparse)

# Default to SageMaker's environment variables; override with -i/-o when running locally
option_list <- list(
  make_option(c("-i", "--input"), default = Sys.getenv("SM_CHANNEL_TRAIN")),
  make_option(c("-o", "--output"), default = Sys.getenv("SM_MODEL_DIR"))
)

args <- parse_args(OptionParser(option_list = option_list))

This lets us use args$input and args$output to refer to the input data path and output model path. When running locally, we can pass the input and output on the command line:

Rscript tests/train.R -i data/census-income.csv -o models/

and it works just as it would in SageMaker.
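For comparison, the same defaults pattern in Python would use argparse with defaults read from the environment. This is a sketch, not part of the package; note that os.environ.get returns None for an unset variable, whereas R's Sys.getenv returns an empty string.

```python
import argparse
import os
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse -i/--input and -o/--output, defaulting to SageMaker's environment variables."""
    parser = argparse.ArgumentParser(description="Train locally or on SageMaker")
    # Defaults are read each time parse_args() is called; locally these variables
    # are usually unset, so pass -i and -o explicitly.
    parser.add_argument("-i", "--input", default=os.environ.get("SM_CHANNEL_TRAIN"))
    parser.add_argument("-o", "--output", default=os.environ.get("SM_MODEL_DIR"))
    return parser.parse_args(argv)
```

Calling parse_args(["-i", "data/census-income.csv", "-o", "models/"]) mirrors the Rscript invocation above.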

Download files

Source Distribution

sagemaker-tidymodels-0.1.1.tar.gz (6.3 kB)

Built Distribution

sagemaker_tidymodels-0.1.1-py3-none-any.whl (9.0 kB)

File details

Details for the file sagemaker-tidymodels-0.1.1.tar.gz.

File metadata

  • Download URL: sagemaker-tidymodels-0.1.1.tar.gz
  • Size: 6.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/47.3.1.post20200622 requests-toolbelt/0.9.1 tqdm/4.48.0 CPython/3.7.7

File hashes

Hashes for sagemaker-tidymodels-0.1.1.tar.gz
  • SHA256: aa50b7eff1e77462d9899ed0f6f130daa6aa8d02bdcbffab473d6aeea582a763
  • MD5: b5a6a4390a6e5c28ab3d6d4385786b3b
  • BLAKE2b-256: 1f02bb6f38e39732da3127a0028138ce3af9e1c646e9d35a7604b2afee91c216

File details

Details for the file sagemaker_tidymodels-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: sagemaker_tidymodels-0.1.1-py3-none-any.whl
  • Size: 9.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/47.3.1.post20200622 requests-toolbelt/0.9.1 tqdm/4.48.0 CPython/3.7.7

File hashes

Hashes for sagemaker_tidymodels-0.1.1-py3-none-any.whl
  • SHA256: e3efd8460d02aa7f0bb1a80f3b1c4ac0aceca29ce73b3a2e89a0251fab516dec
  • MD5: 805ec7e66983bd6b7682d22c777c830d
  • BLAKE2b-256: ac993f5fb92d0afff6f70ff8ba2a9b68ada2218899a616435bfba92db959e758
