SageMaker framework for tidymodels

sagemaker-tidymodels

sagemaker-tidymodels is an AWS SageMaker framework for training and deploying machine learning models written in R.

This framework lets you do cloud-based training and deployment with tidymodels, using the same code you would write locally.

Installation

This Python package is not yet available on PyPI. In the meantime, you can install it from GitHub:

git clone https://github.com/tmastny/sagemaker-tidymodels.git
pip install sagemaker-tidymodels/

The Docker image is hosted on Docker Hub, and you can pull it directly with:

docker pull tmastny/sagemaker-tidymodels

Usage

The sagemaker-tidymodels Python package provides simple wrappers around the Estimator and Model classes of the sagemaker Python SDK.

The main difference is the entry_point parameter, where you can supply an R script. This R script should process the raw data, train the model, and save the final fit.

from sagemaker_tidymodels import Tidymodels, get_role

tidymodels = Tidymodels(
    entry_point="tests/train.R",
    train_instance_type="local",
    role=get_role(),
    image_name="tmastny/sagemaker-tidymodels:latest",
)

s3_data = "s3://sagemaker-sample-data-us-east-2/processing/census/census-income.csv"
tidymodels.fit({"train": s3_data})

train.R is a normal R script, with a few necessary additions so that it can be run from the command line:

#!/usr/bin/env Rscript

library(tidymodels)

if (sys.nframe() == 0) {

  input_path <- file.path(Sys.getenv('SM_CHANNEL_TRAIN'), "census-income.csv")
  df <- read.csv(input_path, stringsAsFactors = TRUE)


  pipeline <- workflow() %>%
    add_formula(income ~ age) %>%
    add_model(logistic_reg() %>% set_engine("glm"))

  model <- pipeline %>%
    fit(data = df)

  output_path <- file.path(Sys.getenv('SM_MODEL_DIR'), "model.RDS")
  saveRDS(model, output_path)
}

  1. The first line should be the shebang #!/usr/bin/env Rscript, so that the script can be run as ./train.R, as the framework requires. Make sure to run chmod +x train.R so that it is executable.

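  2. All the training logic should be wrapped by the following if statement. This seems a little mysterious, but sys.nframe() is 0 only when the script is executed directly (for example with Rscript), not when it is loaded from other code. The check makes sure that the training logic doesn’t accidentally run when the model has been deployed for predictions.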

if (sys.nframe() == 0) {
  # training logic goes here!
}

  3. SageMaker is very specific about input and output locations. The directory containing the input data is given by an environment variable that can be read with Sys.getenv('SM_CHANNEL_TRAIN'). Likewise, the directory where the model should be saved is given by Sys.getenv('SM_MODEL_DIR').
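As a rough illustration of where these variables come from, the sketch below maps a channel name (the key passed to fit, here "train") to its environment variable and data directory. It assumes the container follows the usual SageMaker training conventions, where each channel is mounted under /opt/ml/input/data. The helper names channel_env_var and channel_data_dir are hypothetical and are not part of sagemaker-tidymodels.

```python
# Hypothetical helpers, assuming the standard SageMaker training layout.

def channel_env_var(channel_name):
    """Name of the environment variable that points at a channel's data."""
    return "SM_CHANNEL_" + channel_name.upper()


def channel_data_dir(channel_name):
    """Directory inside the container where a channel's files are placed."""
    return "/opt/ml/input/data/" + channel_name


# The channel key used in fit({"train": s3_data}).
print(channel_env_var("train"))
print(channel_data_dir("train"))
```

Under these assumptions, a second channel such as {"validation": ...} would be exposed through SM_CHANNEL_VALIDATION in the same way.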

From there, you can deploy the model as normal!

# deploy the estimator fitted above
predictor = tidymodels.deploy(initial_instance_count=1, instance_type="local")
predictor.predict(r'28\n')
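The payload above is a line of CSV text holding the single predictor, age. A minimal sketch for building such a payload for several observations follows. It assumes the default input handling reads headerless CSV with one observation per line, which is not confirmed here, and to_csv_payload is a hypothetical helper rather than part of the package.

```python
import csv
import io


def to_csv_payload(rows):
    """Serialize rows (lists of values) as headerless CSV, one row per line."""
    buffer = io.StringIO()
    # Use "\n" line endings instead of the csv module's default "\r\n".
    writer = csv.writer(buffer, lineterminator="\n")
    writer.writerows(rows)
    return buffer.getvalue()


# Two observations, each with a single column (age).
payload = to_csv_payload([[28], [45]])
print(repr(payload))
```

The resulting string could then be passed to predictor.predict in place of a hand-written payload.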

Advanced Usage

The docker container has some additional features that may be useful.

Custom Model Serving

The model serving defaults are defined in docker/server/default_fn.R. If you’d like to customize how the model is served, you can override these defaults by defining functions with the same names in your entry_point script.

The valid options are model_fn, input_fn, predict_fn, and output_fn. In our script basic-train.R, the default predict_fn means we get class predictions, either "- 50000." or "50000+.".

If we wanted to output the probability of belonging to either class, we could include our own predict_fn in basic-train.R:

# add to `train.R`
predict_fn <- function(model, new_data) {
  predict(model, new_data, type = "prob")
}

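This is also why the training logic needs to be wrapped in the if statement: the same script is loaded when the model is served so that these functions can be picked up, and the training code must not run at that point.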
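The override mechanism can be pictured with a small conceptual sketch. It is written in Python purely for illustration; the real defaults are the R functions in docker/server/default_fn.R, and every name, default behaviour, and the call order below (model_fn, then input_fn, predict_fn, output_fn, as is common in SageMaker framework containers) is an assumption rather than the container's actual implementation.

```python
# Conceptual sketch only; all handlers here are hypothetical stand-ins.

def default_model_fn(model_dir):
    # Stand-in for loading the saved fit: a toy classifier on age.
    return lambda age: "over" if age > 40 else "under"


def default_input_fn(body):
    # Parse one numeric value per non-empty line of the request body.
    return [float(line) for line in body.splitlines() if line.strip()]


def default_predict_fn(model, new_data):
    return [model(row) for row in new_data]


def default_output_fn(predictions):
    return "\n".join(str(p) for p in predictions)


DEFAULTS = {
    "model_fn": default_model_fn,
    "input_fn": default_input_fn,
    "predict_fn": default_predict_fn,
    "output_fn": default_output_fn,
}


def resolve_handlers(user_defined):
    """Start from the defaults and replace any handler the user defined."""
    unknown = set(user_defined) - set(DEFAULTS)
    if unknown:
        raise ValueError("unknown handlers: " + ", ".join(sorted(unknown)))
    return {**DEFAULTS, **user_defined}


def handle_request(body, model_dir, handlers):
    """Run one request through the four handlers in order."""
    model = handlers["model_fn"](model_dir)
    new_data = handlers["input_fn"](body)
    predictions = handlers["predict_fn"](model, new_data)
    return handlers["output_fn"](predictions)


# Overriding only predict_fn leaves the other three defaults in place.
custom = {"predict_fn": lambda model, new_data: ["custom"] * len(new_data)}
print(handle_request("28\n45\n", "unused", resolve_handlers(custom)))
```

The point of the sketch is that only the functions you define are replaced; anything you leave out keeps its default.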

Identical Local and Cloud Scripts

In train.R, the logic you use to train is exactly the same as what you would write locally. However, you can’t run the script locally as is, because it relies on the environment variables SM_CHANNEL_TRAIN and SM_MODEL_DIR, which SageMaker defines (along with many others you might want to use).

A nice way to set some defaults so that the script can run both locally and in SageMaker is to use the optparse R package (r-optparse).

For example:

library(optparse)

option_list <- list(
  make_option(c("-i", "--input"), default = Sys.getenv("SM_CHANNEL_TRAIN")),
  make_option(c("-o", "--output"), default = Sys.getenv("SM_MODEL_DIR"))
)

args <- parse_args(OptionParser(option_list = option_list))

This lets us use args$input and args$output for the input data path and output model path, whether we are running locally or in SageMaker as the entry_point.

When running locally, we can then pass the input and output paths on the command line:

Rscript tests/train.R -i data/census-income.csv -o models/
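If you would rather leave the script untouched, another option is to set the two environment variables yourself before invoking it. The sketch below does this from Python. It assumes Rscript is on the PATH and that the script only needs SM_CHANNEL_TRAIN and SM_MODEL_DIR; the helper names sagemaker_like_env and run_training_script are hypothetical and not part of the package.

```python
import os
import subprocess


def sagemaker_like_env(input_dir, model_dir, base_env=None):
    """Copy an environment and add the two SageMaker-style variables."""
    env = dict(os.environ if base_env is None else base_env)
    env["SM_CHANNEL_TRAIN"] = str(input_dir)
    env["SM_MODEL_DIR"] = str(model_dir)
    return env


def run_training_script(script, input_dir, model_dir):
    """Run an R training script locally with the variables set."""
    # Create the output directory so the script has somewhere to save to.
    os.makedirs(model_dir, exist_ok=True)
    env = sagemaker_like_env(input_dir, model_dir)
    # check=True raises CalledProcessError if the script exits with an error.
    return subprocess.run(["Rscript", str(script)], env=env, check=True)
```

For example, run_training_script("tests/train.R", "data", "models") would run the script with data/ as the training channel and models/ as the model directory.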
