
Analysis functions to quantify inputs importance in neural network models.

Project description

NeuralSens

Jaime Pizarroso Gonzalo, jpizarroso@comillas.edu

Antonio Muñoz San Roque, Antonio.Munoz@iit.comillas.edu

José Portela González, jose.portela@iit.comillas.edu


This is the development repository for the NeuralSens package. Functions within this package can be used for the analysis of neural network models created in R.

The development version of this package can be installed from GitHub:

install.packages('devtools')
library(devtools)
install_github('JaiPizGon/NeuralSens')

The latest release can be installed from CRAN:

install.packages('NeuralSens')

Bug reports

Please submit any bug reports (or suggestions) using the issues tab of the GitHub page.

Functions

One function is available to analyze the sensitivity of a multilayer perceptron, evaluating variable importance and plotting the analysis results. A sample dataset is also provided for use with the examples. The function has S3 methods developed for neural networks from the following packages: nnet, neuralnet, RSNNS, caret, neural, h2o and forecast. Numeric inputs that describe model weights are also acceptable.

Start by loading the package and the sample dataset.

library(NeuralSens)
data(DAILY_DEMAND_TR)

The SensAnalysisMLP function analyzes the sensitivity of the output with respect to the inputs and produces three plots that summarize the analysis. To compute this sensitivity, it evaluates the partial derivatives of the output with respect to the inputs at each sample of the training data. The first plot shows the mean and the standard deviation of the sensitivity across the training data:

  • If the mean is clearly different from zero, the output depends on the input: on average, the output changes when the input changes.
  • If the mean is close to zero, the output may not depend on the input. If the standard deviation is also close to zero, it is almost certain that the output does not depend on the input, because the partial derivative is zero for all the training data.
  • If the standard deviation is different from zero, the output has a non-linear relation with the input, because the partial derivative of the output depends on the value of the input.
  • If the standard deviation is close to zero, the output has a linear relation with the input, because the partial derivative of the output does not depend on the value of the input.

The second plot gives an absolute measure of the importance of each input, computed by aggregating the squares of the partial derivatives of the output with respect to that input. The third plot is a density plot of the partial derivatives of the output with respect to each input across the training data, and gives information similar to the first plot.

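The interpretation rules above can be checked on a case where the answer is known in advance. The following sketch does not use NeuralSens: a hypothetical function f, linear in x1, quadratic in x2 and independent of x3, stands in for a trained network, and its partial derivatives are estimated with central finite differences (a simplification made only for this sketch).

```r
# Toy illustration (not part of NeuralSens): a known function stands in
# for a trained network so that the expected sensitivities are obvious.
set.seed(1)
n <- 500
X <- data.frame(x1 = runif(n, -1, 1),
                x2 = runif(n, -1, 1),
                x3 = runif(n, -1, 1))

# Hypothetical "model": linear in x1, quadratic in x2, ignores x3
f <- function(X) 2 * X$x1 + X$x2^2 + 0 * X$x3

# Central finite-difference estimate of df/dx_k at every sample
num_sens <- function(f, X, h = 1e-5) {
  sapply(names(X), function(k) {
    Xp <- X
    Xm <- X
    Xp[[k]] <- Xp[[k]] + h
    Xm[[k]] <- Xm[[k]] - h
    (f(Xp) - f(Xm)) / (2 * h)
  })
}

S <- num_sens(f, X)              # n x 3 matrix of partial derivatives
summary_tab <- data.frame(mean = colMeans(S),
                          std  = apply(S, 2, sd))
print(round(summary_tab, 3))
```

x1 should show a mean of 2 with a standard deviation of essentially zero (linear relation), x2 a mean near zero with a large standard deviation (non-linear relation), and x3 zero for both statistics (no dependence).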
# Scale the data
DAILY_DEMAND_TR[,4] <- DAILY_DEMAND_TR[,4]/10
DAILY_DEMAND_TR[,2] <- DAILY_DEMAND_TR[,2]/100
# Parameters of the neural network
hidden_neurons <- 5
iters <- 250
decay <- 0.1

# create neural network
library(caret)
ctrl_tune <- trainControl(method = "boot",
                          savePredictions = FALSE,
                          summaryFunction = defaultSummary)
set.seed(150) #For replication
mod <- caret::train(form = DEM~TEMP + WD,
                    data = DAILY_DEMAND_TR,
                    method = "nnet",
                    linout = TRUE,
                    tuneGrid = data.frame(size = hidden_neurons,
                                          decay = decay),
                    maxit = iters,
                    preProcess = c("center","scale"),
                    trControl = ctrl_tune,
                    metric = "RMSE")

# Analysis of the neural network
sens <- SensAnalysisMLP(mod)
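For a network with a single hidden layer, the partial derivatives can be written in closed form. The sketch below is a simplified illustration, not the package's implementation: it assumes logistic hidden units and a linear output (the configuration nnet uses with linout = TRUE), uses random placeholder weights instead of the trained mod, and summarizes the sensitivities per input with the mean, the standard deviation and a root-mean-square measure. The root-mean-square form of the squared-derivative measure, and the name rms, are assumptions of this sketch.

```r
# Sketch: sensitivities of a one-hidden-layer MLP (logistic hidden units,
# linear output). Weights and inputs are random placeholders, not a trained model.
set.seed(42)
n <- 200; p <- 2; h <- 5
X  <- matrix(rnorm(n * p), n, p, dimnames = list(NULL, c("TEMP", "WD")))
W1 <- matrix(rnorm(h * p), h, p)   # hidden-layer weights (h x p)
b1 <- rnorm(h)                     # hidden-layer biases
w2 <- rnorm(h)                     # output weights
b2 <- rnorm(1)                     # output bias

# Forward pass: hidden activations (n x h), then linear output
mlp <- function(X) {
  A <- plogis(sweep(X %*% t(W1), 2, b1, "+"))
  drop(A %*% w2 + b2)
}

# Analytical partial derivatives dy/dx_k at every sample (n x p):
# dy/dx_k = sum_j w2[j] * sigma'(z_j) * W1[j, k], with sigma' = sigma * (1 - sigma)
mlp_sens <- function(X) {
  A <- plogis(sweep(X %*% t(W1), 2, b1, "+"))
  S <- sweep(A * (1 - A), 2, w2, "*") %*% W1
  colnames(S) <- colnames(X)
  S
}

S <- mlp_sens(X)
stats <- data.frame(mean = colMeans(S),
                    std  = apply(S, 2, sd),
                    rms  = sqrt(colMeans(S^2)))  # assumption: RMS of derivatives
print(stats)
```

The closed form avoids the step-size trade-off of finite differences, and it can be checked against them as a sanity test.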

Apart from the plots created by SensAnalysisMLP through an internal call to SensitivityPlot, other plots can be obtained to analyze the neural network model. For forecasting (time-series) problems, the SensTimePlot function returns a plot that shows how the sensitivity of the output changes over time.

SensTimePlot(sens, fdata = DAILY_DEMAND_TR, facet = TRUE)
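A time plot is most informative when the relation between an input and the output is non-linear, because the sensitivity then changes as the input moves through its range. The toy sketch below uses synthetic data and a hypothetical model y = x^2 (it is neither DAILY_DEMAND_TR nor SensTimePlot itself) and plots, with base R graphics, the sensitivity of a seasonal input against the sample index.

```r
# Toy sketch: sensitivity of a non-linear model over time.
# The input is a synthetic seasonal signal, not DAILY_DEMAND_TR.
day <- 1:730                            # two years of daily samples
x   <- sin(2 * pi * day / 365)          # synthetic seasonal input
f   <- function(x) x^2                  # hypothetical non-linear model
h   <- 1e-5
s   <- (f(x + h) - f(x - h)) / (2 * h)  # sensitivity dy/dx at each time step

plot(day, s, type = "l", xlab = "Sample (time)", ylab = "dy/dx",
     main = "Sensitivity over time (toy example)")
abline(h = 0, lty = 2)
```

Because dy/dx = 2x here, the curve follows the seasonal cycle of the input and changes sign twice a year, a pattern that a single mean value over the whole dataset would hide.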

A more detailed view of the distribution of the sensitivities can be obtained with the SensFeaturePlot function. This function returns a scatter plot over a violin plot for each input variable, where each point represents the sensitivity value for one sample of the dataset. The color of each point depends on the value of the input for its corresponding sample.

SensFeaturePlot(sens, fdata = DAILY_DEMAND_TR)
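The coloring is what lets a feature plot separate different kinds of non-linearity. The following base R sketch imitates the idea with synthetic data and a hypothetical model y = 2*x1 + x2^3; it is not SensFeaturePlot, and it uses a jittered strip of points instead of a violin plot. Each point is one sample's sensitivity, colored from blue (low input value) to red (high input value).

```r
# Toy sketch of the idea behind a feature plot. Synthetic data only.
set.seed(7)
n  <- 300
x1 <- runif(n, -1, 1)
x2 <- runif(n, -1, 1)
inputs <- cbind(x1 = x1, x2 = x2)
S <- cbind(x1 = rep(2, n), x2 = 3 * x2^2)  # dy/dx for y = 2*x1 + x2^3

# Map each input value to one of 100 colors (low = blue, high = red)
pal     <- colorRampPalette(c("blue", "red"))(100)
col_idx <- apply(inputs, 2, function(v) cut(v, 100, labels = FALSE))

plot(NULL, xlim = c(0.5, 2.5), ylim = range(S), xaxt = "n",
     xlab = "Input", ylab = "Sensitivity (dy/dx)")
axis(1, at = 1:2, labels = colnames(S))
for (k in 1:2) {
  points(jitter(rep(k, n), amount = 0.2), S[, k],
         col = pal[col_idx[, k]], pch = 16)
}
```

For x2 the sensitivity 3*x2^2 is large at both ends of the input range, so blue and red points gather at the top of the strip while intermediate colors sit near zero; for x1 every point has the same sensitivity of 2.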

Citation

To cite NeuralSens in publications, please use:

Pizarroso J, Portela J, Muñoz A (2022). “NeuralSens: Sensitivity Analysis of Neural Networks.” Journal of Statistical Software, 102(7), 1-36. doi: 10.18637/jss.v102.i07 (URL: https://doi.org/10.18637/jss.v102.i07).

License

This package is released under the GNU General Public License (GPL).

Association

Package created at the Institute for Research in Technology (IIT).


Download files


Source Distribution

neuralsens-0.0.3.dev18.tar.gz (25.1 kB)

Uploaded Source

Built Distribution

neuralsens-0.0.3.dev18-py3-none-any.whl (22.5 kB)

Uploaded Python 3

File details

Details for the file neuralsens-0.0.3.dev18.tar.gz.

File metadata

  • Download URL: neuralsens-0.0.3.dev18.tar.gz
  • Upload date:
  • Size: 25.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.9.7

File hashes

Hashes for neuralsens-0.0.3.dev18.tar.gz
Algorithm Hash digest
SHA256 427b0b82d7ed89d3b6eec12bba27bce97436dd3aa1eee8a18811f25c6c684ba3
MD5 663f4f2a7f87676dadb7d1913a729e15
BLAKE2b-256 a2b5f5670cafbabfe152a188737c7aa7ccbe6fff9383db0c3f6e5ac1bcf3fe08


File details

Details for the file neuralsens-0.0.3.dev18-py3-none-any.whl.

File metadata

File hashes

Hashes for neuralsens-0.0.3.dev18-py3-none-any.whl
Algorithm Hash digest
SHA256 ea93e096130decb9e3764424ebf8a7e28a3a129c0785043fbdb7e4eee07d8723
MD5 06222f12a8d6ca21fadec694768d1290
BLAKE2b-256 a270b7a1cec7f71d624a599b3d75ee93434cdb986d193fbfed56f7c77b31dba5

