Example 1: Iris dataset with torch

šŸ“ Note
Since the innsight package relies on the torch package for each method and requires a successful installation of LibTorch and other dependencies (torch::install_torch()), no examples can be run in the R vignette for memory reasons. For the executed code chunks we refer to our website.

In this very elementary example, we want to show you the use of the innsight package and its possibilities in the simplest case. For this, we use the Iris dataset, which contains the length and width of the sepals and petals of the three flower species Iris Setosa, Versicolor and Virginica. In total, there are measurements of 50 flowers of each type in centimeters.

Step 0: Train a model

The first step does not necessarily need the innsight package yet, but at this point, it is listed anyway to have a fully working example. In this example, the model is created and trained in torch.

šŸ“ Note
It is also possible to train the torch model using the luz package, which is described in the vignette with the penguin dataset (see ?vignette("Example_2_penguin", package='innsight')).

library(innsight)
library(torch)

# Set seeds for reproducibility
set.seed(42)
torch_manual_seed(42)

# Prepare data
data(iris)
x <- torch_tensor(as.matrix(iris[, -5]))
# normalize input to [-1, 1]
x <- 2 * (x - x$min(1)[[1]]) / (x$max(1)[[1]] - x$min(1)[[1]]) - 1
y <- torch_tensor(as.integer(iris[, 5]))

# Define model (`torch::nn_sequential`)
model <- nn_sequential(
  nn_linear(4, 30),
  nn_relu(),
  nn_dropout(0.3),
  nn_linear(30, 10),
  nn_relu(),
  nn_linear(10, 3),
  nn_softmax(dim = 2)
)

# Train model
optimizer <- optim_adam(model$parameters, lr = 0.001)
for (t in 1:2500) {
  y_pred <- torch_log(model(x))
  loss <- nnf_nll_loss(y_pred, y)
  if (t %% 250 == 0) {
    cat("Loss: ", as.numeric(loss), "\n")
  }
  optimizer$zero_grad()
  loss$backward()
  optimizer$step()
}

Step 1: Convert the model

The easiest way to convert a model for the innsight package is just using the model for the Converter class without any further arguments (except for torch models, the input dimension is still needed). Thus, the converting part for our trained model is done by:

# Create the converter object
converter <- convert(model, input_dim = c(4))

Since there are no labels for the inputs and outputs stored in the torch model, they were generated automatically by innsight. But you can pass them by yourself using the corresponding arguments. These labels are then also used in the visualizations in step 3. Internally, the dimensions of the input and output labels are also added as checks and errors are thrown if any discrepancies arise.

# Create `Converter` object (with custom labels)
converter <- convert(model,
  input_dim = c(4),
  input_names = c("Sepal (length)", "Sepal (width)", "Petal (length)", "Petal (width)"),
  output_names = c("Setosa", "Versicolor", "Virginica")
)

In addition, the default print() function for R6 classes has been overwritten so that all important properties and contents of the converter object can be displayed in a summarized form:

converter

Step 2: Apply methods

Now we can apply the implemented methods to our trained model. Although these methods are different, they are called the same way in innsight. Essential arguments are the converter object (converter) and the data (data) to which the method is to be applied. In addition to the necessary and method-specific arguments, the following other arguments can be set for each (local) method:

  • channels_first: There are, basically, two data formats for image and signal data, having the channels at the last position or directly after the batch axis. To be able to handle both formats, the format of the passed data can be specified with this argument. In this example, it is unused because we have tabular data.

  • output_idx: These indices specify the output nodes (or classes) for which the method is to be applied. By default, the result for the first ten classes is calculated for efficiency reasons. However, if you want to examine several specific classes simultaneously, you must specify the corresponding indices with this argument.

  • ignore_last_act: Often it makes more sense to ignore the last activation function to get better results. This is especially true for classification problems with softmax or sigmoid as the last activation.

For all possible and more detailed information on the arguments, see the detailed explanation. In the following, we apply our trained model to different methods provided in the package:

  • Gradient without last activation

    grad_no_softmax <- run_grad(converter, x, ignore_last_act = TRUE)
  • Gradient with last activation

    grad_softmax <- run_grad(converter, x, ignore_last_act = FALSE)
  • LRP with Īµ-rule (Īµā€„=ā€„0.01)

    We can also set the method-specific arguments, such as the LRP rule and its associated parameter:

    lrp_eps <- run_lrp(converter, x, rule_name = "epsilon", rule_param = 0.01)
  • DeepLift with mean as reference value

    x_ref <- x$mean(1, keepdim = TRUE) # ref value needs the shape (1,4)
    deeplift_mean <- run_deeplift(converter, x, x_ref = x_ref)

Similar to the instances of the Converter class, the default print() function for R6 classes was also overridden for each method so that all important contents of the corresponding method are displayed:

deeplift_mean

Afterward, the results can be retrieved either with the class method method$get_result() or the equivalent S3 function get_result(). With the argument type, it is also possible to select whether the results are returned as array, data.frame or torch_tensor:

# Get result as a `data.frame` using the class method
head(grad_no_softmax$get_result(type = "data.frame"), 5)

# Get result as `array` (default) using the generic S3 function
str(get_result(grad_no_softmax))

Step 3: Visualization

The package innsight provides two ways to visualize the results of a method, namely as innsight_ggplot2 or innsight_plotly object. Both are S4 classes to combine multiple plots nicely and to be able to make visual modifications or adjustments to the selection of plots even after the object has been created. The first class is based on ggplot2 and behaves partly like an ordinary ggplot2 object. Whereas the other one is based on the plotly package and creates an interactive graph with more detailed information about each variable. For more information on the S4 classes innsight_ggplot and innsight_plotly see the in-depth vignette or the respective R documentation (?innsight_ggplot2 or ?innsight_plotly).

For each of these classes and thus, of course, also for each method, there are two plot functions:

  • plot() shows only individual data points and
  • plot_global() visualizes summaries of multiple data points using summary statistics (you can also use the alias boxplot() for tabular and signal data).

Plot individual results

The function plot() is implemented for each of the available methods. You can select your desired data points and output nodes/classes with the data_idx and output_idx arguments, respectively. To switch between a ggplot2 and plotly based plot, you can use the logical as_plotly parameter, but this requires a successful installation of the plotly package.

šŸ“ Note
Of course, only results of output nodes/classes to which the method has been applied can be plotted, i.e., the indices in the output_idx argument must be a subset of the argument with the same name when calling the method in Step 2.

ggplot2-based plot

Despite being an object of the class innsight_ggplot2, it can be treated like a normal ggplot2 object, i.e., geoms, scales and themes can be added and modified.

# Show data point 1 and 111 for output node 1 (Setosa) and 2 (Versicolor)
plot(grad_no_softmax, data_idx = c(1, 111), output_idx = c(1, 2)) +
  ggplot2::theme_bw()
plotly-based plot
# Show data point 1 for output node 1 (Setosa) and 2 (Versicolor)
plot(deeplift_mean, data_idx = 1, output_idx = c(1, 2), as_plotly = TRUE)

Plot summarized results

The S3 function plot_global() (or boxplot() for tabular and signal data) is implemented for each of the available local methods. You can select your desired data points (default is 'all') and output nodes/classes with the data_idx and output_idx arguments, respectively. To switch between a ggplot2 and plotly based plot, you can use the logical as_plotly parameter, but this requires a successful installation of the plotly package. In addition, you can use ref_data_idx to select a single data point that will be visualized in red as a reference value, and preprocess_FUN to select a function that will be applied to all relevances in advance (e.g., the absolute value).

šŸ“ Note
Of course, only results of output nodes/classes to which the method has been applied can be plotted, i.e., the indices in the output_idx argument must be subset of the argument with the same name when calling the method in Step 2.

ggplot2-based plot

Despite being an object of the class innsight_ggplot2, it can be treated like a normal ggplot2 object, i.e.Ā geoms, scales and themes can be added and modified.

# Summarized results for output node 1 (Setosa) and 2 (Versicolor) and
# reference value of index 3
boxplot(grad_no_softmax, output_idx = c(1, 2), ref_data_idx = 3, preprocess_FUN = abs) +
  ggplot2::theme_bw()
plotly-based plot
# Show boxplot only for instances of class setosa for output node 1 (Setosa)
# and 2 (Versicolor)
boxplot(lrp_eps, data_idx = 1:50, output_idx = c(1, 2), as_plotly = TRUE)