ConfusionTableR

This package rapidly transforms confusion matrix objects from the caret package, which are natively stored as lists, into data frame objects.

Why is this useful?

It turns the list items into a transposed data frame: a single row, with one column per statistic. I had the idea when working with a number of machine learning models whose results I wanted to store in database tables, so I needed a way to have one row per model run. This is not implemented in the excellent caret package created by Max Kuhn [https://CRAN.R-project.org/package=caret].
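The base-R sketch below flattens a predicted-vs-reference table into one dated row, to make the idea concrete. It is an illustration under our own assumptions, not ConfusionTableR's implementation. The helper `flatten_cm_table`, the toy labels and the `Pred_<x>_Ref_<y>` column names are all hypothetical; the naming simply mirrors the binary output shown later.

```r
# Hypothetical helper (not ConfusionTableR's code): flatten a
# predicted-vs-reference table into a single dated row
flatten_cm_table <- function(predicted, reference) {
  tab <- table(Prediction = predicted, Reference = reference)
  # as.vector() reads the table column by column (Prediction varies fastest);
  # expand.grid() enumerates the cells in the same order
  cells <- expand.grid(pred = rownames(tab), ref = colnames(tab),
                       stringsAsFactors = FALSE)
  values <- as.vector(tab)
  names(values) <- paste0("Pred_", cells$pred, "_Ref_", cells$ref)
  out <- as.data.frame(as.list(values), check.names = FALSE)
  out$Accuracy <- sum(diag(tab)) / sum(tab)
  out$cm_ts <- Sys.time()  # timestamp: one dated record per model run
  out
}

# Toy labels, purely for illustration
pred <- factor(c("a", "a", "b", "b", "b"), levels = c("a", "b"))
ref  <- factor(c("a", "b", "b", "b", "a"), levels = c("a", "b"))
flatten_cm_table(pred, ref)
```

Each call yields one row, so rows from successive model runs can be stacked with rbind() and written to a table.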

Preparing the ML model for evaluation

The following example shows how the multi_class_cm function can flatten all the results of a caret confusion matrix for a multi-class classification model into a single row:

Example:

library(caret)
library(dplyr)
library(mlbench)
library(tidyr)
library(e1071)
library(randomForest)

# Load in the iris data set for this problem 
data(iris)
df <- iris
# View the class distribution; as this is a multi-class problem, we can use the multi-classification data table builder
table(iris$Species)
#> 
#>     setosa versicolor  virginica 
#>         50         50         50

Splitting the data into train and test splits:

# Define a 75/25 split index (sampled within each Species level)
train_split_idx <- caret::createDataPartition(df$Species, p = 0.75, list = FALSE)
# Use the index to build the training and test sets
train <- df[train_split_idx, ]
test <- df[-train_split_idx, ]
str(train)
#> 'data.frame':    114 obs. of  5 variables:
#>  $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 4.6 5 4.4 5.4 4.8 ...
#>  $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.4 3.4 2.9 3.7 3 ...
#>  $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.4 1.5 1.4 1.5 1.4 ...
#>  $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.3 0.2 0.2 0.2 0.1 ...
#>  $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...

This creates a 75% training set for fitting the ML model. The remaining 25% is held back as test data to validate the model.
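For a factor outcome, caret::createDataPartition() samples within each class, so the class balance is preserved. The base-R sketch below approximates that behaviour for illustration only. It is not caret's code, and the seed and the `train_s`/`test_s` names are arbitrary. Rounding 50 × 0.75 up within each class gives 38 + 38 + 38 = 114 training rows, matching the str() output above.

```r
# Base-R sketch of a stratified 75/25 split (an approximation, not caret's code)
data(iris)
set.seed(42)  # arbitrary seed so the split is reproducible
p <- 0.75
# Within each class, sample ceiling(n * p) row indices for training
train_idx <- unlist(lapply(split(seq_len(nrow(iris)), iris$Species),
                           function(idx) idx[sample.int(length(idx), ceiling(length(idx) * p))]))
train_s <- iris[train_idx, ]
test_s  <- iris[-train_idx, ]
table(train_s$Species)  # 38 per class
table(test_s$Species)   # 12 per class
```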

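# Note: this fits on the full data set (df), which includes the test rows;
# use data = train for a genuine hold-out evaluation
rf_model <- caret::train(Species ~ .,
                         data = df,
                         method = "rf",
                         metric = "Accuracy")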

rf_model
#> Random Forest 
#> 
#> 150 samples
#>   4 predictor
#>   3 classes: 'setosa', 'versicolor', 'virginica' 
#> 
#> No pre-processing
#> Resampling: Bootstrapped (25 reps) 
#> Summary of sample sizes: 150, 150, 150, 150, 150, 150, ... 
#> Resampling results across tuning parameters:
#> 
#>   mtry  Accuracy   Kappa    
#>   2     0.9482995  0.9214239
#>   3     0.9473322  0.9199555
#>   4     0.9444634  0.9156175
#> 
#> Accuracy was used to select the optimal model using the largest value.
#> The final value used for the model was mtry = 2.

The model is relatively accurate. This is not a lesson on machine learning, but now that we know how well the model performs during training, we need to validate it with a confusion matrix. Note that rf_model was fitted on df, which contains the test rows, so the test-set results below will be optimistic. The Random Forest output shows that the model was trained on more than two classes, so this is a multi-class rather than a binary model. The functions in the package work with both binary and multi-class models.

Using the native Confusion Matrix in CARET

The native confusion matrix in caret is excellent, but it is stored as a series of list items. To compare model performance over time, and check that there is no underlying feature drift or regression in performance, these items need to be brought together. This is where my solution comes in.

# Make a prediction on the fitted model with the test data
rf_class <- predict(rf_model, newdata = test, type = "raw")
# Pair the predicted labels with the true test labels
predictions <- data.frame(train_preds = rf_class,
                          test.Species = test$Species)
# Create a confusion matrix object
cm <- caret::confusionMatrix(predictions$train_preds, predictions$test.Species)
print(cm) 
#> Confusion Matrix and Statistics
#> 
#>             Reference
#> Prediction   setosa versicolor virginica
#>   setosa         12          0         0
#>   versicolor      0         12         0
#>   virginica       0          0        12
#> 
#> Overall Statistics
#>                                      
#>                Accuracy : 1          
#>                  95% CI : (0.9026, 1)
#>     No Information Rate : 0.3333     
#>     P-Value [Acc > NIR] : < 2.2e-16  
#>                                      
#>                   Kappa : 1          
#>                                      
#>  Mcnemar's Test P-Value : NA         
#> 
#> Statistics by Class:
#> 
#>                      Class: setosa Class: versicolor Class: virginica
#> Sensitivity                 1.0000            1.0000           1.0000
#> Specificity                 1.0000            1.0000           1.0000
#> Pos Pred Value              1.0000            1.0000           1.0000
#> Neg Pred Value              1.0000            1.0000           1.0000
#> Prevalence                  0.3333            0.3333           0.3333
#> Detection Rate              0.3333            0.3333           0.3333
#> Detection Prevalence        0.3333            0.3333           0.3333
#> Balanced Accuracy           1.0000            1.0000           1.0000

The outputs of the matrix are really useful, but I want to combine all this information into one row of a data frame, so that it can be stored in a data table and imported into a database.
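For a multi-class problem, caret reports the "Statistics by Class" by treating each class in turn as the positive class against all the others. The base-R sketch below recomputes several of these numbers from the 3 × 3 table printed above: overall Accuracy, Cohen's Kappa and a few one-vs-rest statistics. It is a hand-rolled check for illustration, not caret's own code.

```r
# Confusion table as printed above (rows = Prediction, columns = Reference)
lv <- c("setosa", "versicolor", "virginica")
tab <- matrix(c(12, 0, 0,
                0, 12, 0,
                0, 0, 12),
              nrow = 3, byrow = TRUE,
              dimnames = list(Prediction = lv, Reference = lv))

n <- sum(tab)
accuracy <- sum(diag(tab)) / n
# Agreement expected by chance, from the row and column totals
p_e <- sum(rowSums(tab) * colSums(tab)) / n^2
kappa_val <- (accuracy - p_e) / (1 - p_e)

# One-vs-rest statistics for each class
by_class <- t(sapply(lv, function(k) {
  tp <- tab[k, k]
  fn <- sum(tab[, k]) - tp  # reference k, predicted something else
  fp <- sum(tab[k, ]) - tp  # predicted k, reference something else
  tn <- n - tp - fn - fp
  c(Sensitivity = tp / (tp + fn),
    Specificity = tn / (tn + fp),
    Prevalence = (tp + fn) / n,
    `Detection Rate` = tp / n,
    `Balanced Accuracy` = (tp / (tp + fn) + tn / (tn + fp)) / 2)
}))
by_class
```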

Using ConfusionTableR to collapse this data into a data frame

The package has two functions for dealing with these types of problems. First, I will show the multi-class version and how it can be implemented:

Example

# Implementing function to collapse data
library(ConfusionTableR)
mc_df <- ConfusionTableR::multi_class_cm(predictions$train_preds, predictions$test.Species,
                                         mode="everything")
# Access the reduced data for storage in databases
print(mc_df$record_level_cm)
#>   setosa : setosa setosa : versicolor setosa : virginica versicolor : setosa
#> 1              12                   0                  0                   0
#>   versicolor : versicolor versicolor : virginica virginica : setosa
#> 1                      12                      0                  0
#>   virginica : versicolor virginica : virginica Accuracy Kappa AccuracyLower
#> 1                      0                    12        1     1     0.9026062
#>   AccuracyUpper AccuracyNull AccuracyPValue McnemarPValue setosa : Sensitivity
#> 1             1    0.3333333   6.662463e-18           NaN                    1
#>   setosa : Specificity setosa : Pos Pred Value setosa : Neg Pred Value
#> 1                    1                       1                       1
#>   setosa : Precision setosa : Recall setosa : F1 setosa : Prevalence
#> 1                  1               1           1           0.3333333
#>   setosa : Detection Rate setosa : Detection Prevalence
#> 1               0.3333333                     0.3333333
#>   setosa : Balanced Accuracy versicolor : Sensitivity versicolor : Specificity
#> 1                          1                        1                        1
#>   versicolor : Pos Pred Value versicolor : Neg Pred Value
#> 1                           1                           1
#>   versicolor : Precision versicolor : Recall versicolor : F1
#> 1                      1                   1               1
#>   versicolor : Prevalence versicolor : Detection Rate
#> 1               0.3333333                   0.3333333
#>   versicolor : Detection Prevalence versicolor : Balanced Accuracy
#> 1                         0.3333333                              1
#>   virginica : Sensitivity virginica : Specificity virginica : Pos Pred Value
#> 1                       1                       1                          1
#>   virginica : Neg Pred Value virginica : Precision virginica : Recall
#> 1                          1                     1                  1
#>   virginica : F1 virginica : Prevalence virginica : Detection Rate
#> 1              1              0.3333333                  0.3333333
#>   virginica : Detection Prevalence virginica : Balanced Accuracy
#> 1                        0.3333333                             1
#>                 cm_ts
#> 1 2024-12-16 06:55:14
glimpse(mc_df$record_level_cm)
#> Rows: 1
#> Columns: 50
#> $ `setosa : setosa`                   <int> 12
#> $ `setosa : versicolor`               <int> 0
#> $ `setosa : virginica`                <int> 0
#> $ `versicolor : setosa`               <int> 0
#> $ `versicolor : versicolor`           <int> 12
#> $ `versicolor : virginica`            <int> 0
#> $ `virginica : setosa`                <int> 0
#> $ `virginica : versicolor`            <int> 0
#> $ `virginica : virginica`             <int> 12
#> $ Accuracy                            <dbl> 1
#> $ Kappa                               <dbl> 1
#> $ AccuracyLower                       <dbl> 0.9026062
#> $ AccuracyUpper                       <dbl> 1
#> $ AccuracyNull                        <dbl> 0.3333333
#> $ AccuracyPValue                      <dbl> 6.662463e-18
#> $ McnemarPValue                       <dbl> NaN
#> $ `setosa : Sensitivity`              <dbl> 1
#> $ `setosa : Specificity`              <dbl> 1
#> $ `setosa : Pos Pred Value`           <dbl> 1
#> $ `setosa : Neg Pred Value`           <dbl> 1
#> $ `setosa : Precision`                <dbl> 1
#> $ `setosa : Recall`                   <dbl> 1
#> $ `setosa : F1`                       <dbl> 1
#> $ `setosa : Prevalence`               <dbl> 0.3333333
#> $ `setosa : Detection Rate`           <dbl> 0.3333333
#> $ `setosa : Detection Prevalence`     <dbl> 0.3333333
#> $ `setosa : Balanced Accuracy`        <dbl> 1
#> $ `versicolor : Sensitivity`          <dbl> 1
#> $ `versicolor : Specificity`          <dbl> 1
#> $ `versicolor : Pos Pred Value`       <dbl> 1
#> $ `versicolor : Neg Pred Value`       <dbl> 1
#> $ `versicolor : Precision`            <dbl> 1
#> $ `versicolor : Recall`               <dbl> 1
#> $ `versicolor : F1`                   <dbl> 1
#> $ `versicolor : Prevalence`           <dbl> 0.3333333
#> $ `versicolor : Detection Rate`       <dbl> 0.3333333
#> $ `versicolor : Detection Prevalence` <dbl> 0.3333333
#> $ `versicolor : Balanced Accuracy`    <dbl> 1
#> $ `virginica : Sensitivity`           <dbl> 1
#> $ `virginica : Specificity`           <dbl> 1
#> $ `virginica : Pos Pred Value`        <dbl> 1
#> $ `virginica : Neg Pred Value`        <dbl> 1
#> $ `virginica : Precision`             <dbl> 1
#> $ `virginica : Recall`                <dbl> 1
#> $ `virginica : F1`                    <dbl> 1
#> $ `virginica : Prevalence`            <dbl> 0.3333333
#> $ `virginica : Detection Rate`        <dbl> 0.3333333
#> $ `virginica : Detection Prevalence`  <dbl> 0.3333333
#> $ `virginica : Balanced Accuracy`     <dbl> 1
#> $ cm_ts                               <dttm> 2024-12-16 06:55:14
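The flattened column names contain spaces and colons, so in R they must be quoted with backticks (for example mc_df$record_level_cm$`setosa : F1`). Some databases also reject such identifiers. If your target table needs plain names, one option is a small renaming step before writing. The helper `db_safe_names` below is a hypothetical sketch; adapt its rules to your database's naming conventions.

```r
# Hypothetical helper: turn names such as "setosa : Pos Pred Value"
# into plain lower-case identifiers such as "setosa_pos_pred_value"
db_safe_names <- function(x) {
  x <- tolower(x)
  x <- gsub("[^a-z0-9]+", "_", x)              # runs of spaces, colons, dots -> "_"
  make.unique(gsub("^_|_$", "", x), sep = "_") # trim edges, keep names unique
}

example_names <- c("setosa : Pos Pred Value", "Accuracy",
                   "versicolor : Detection Prevalence", "Pos.Pred.Value")
db_safe_names(example_names)
```

Applied as `names(x) <- db_safe_names(names(x))`, it renames the record-level data frame in place before it is written out.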

multi_class_cm returns a list. From it you can retrieve:

  • the confusion matrix, which the function generates itself, so one does not need to be fitted beforehand as in the previous example
  • the record_level_cm data frame, which can be written to a database
  • the confusion matrix numerical table
  • the datetime the list was created

To get the original confusion matrix:

mc_df$confusion_matrix
#> Confusion Matrix and Statistics
#> 
#>             Reference
#> Prediction   setosa versicolor virginica
#>   setosa         12          0         0
#>   versicolor      0         12         0
#>   virginica       0          0        12
#> 
#> Overall Statistics
#>                                      
#>                Accuracy : 1          
#>                  95% CI : (0.9026, 1)
#>     No Information Rate : 0.3333     
#>     P-Value [Acc > NIR] : < 2.2e-16  
#>                                      
#>                   Kappa : 1          
#>                                      
#>  Mcnemar's Test P-Value : NA         
#> 
#> Statistics by Class:
#> 
#>                      Class: setosa Class: versicolor Class: virginica
#> Sensitivity                 1.0000            1.0000           1.0000
#> Specificity                 1.0000            1.0000           1.0000
#> Pos Pred Value              1.0000            1.0000           1.0000
#> Neg Pred Value              1.0000            1.0000           1.0000
#> Precision                   1.0000            1.0000           1.0000
#> Recall                      1.0000            1.0000           1.0000
#> F1                          1.0000            1.0000           1.0000
#> Prevalence                  0.3333            0.3333           0.3333
#> Detection Rate              0.3333            0.3333           0.3333
#> Detection Prevalence        0.3333            0.3333           0.3333
#> Balanced Accuracy           1.0000            1.0000           1.0000

To get the confusion matrix table:

mc_df$cm_tbl
#>             Reference
#> Prediction   setosa versicolor virginica
#>   setosa         12          0         0
#>   versicolor      0         12         0
#>   virginica       0          0        12

This data frame can now be used to store and analyse these records over time, i.e. to check whether the machine learning statistics degrade across different training runs.
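Once each run is a single row, tracking performance over time comes down to stacking rows and comparing them. The sketch below uses toy labels with progressively more errors in place of real model runs. The helpers `make_run_row` and `flip_first` and the 0.02 accuracy-drop threshold are hypothetical choices for illustration. In practice, each row would be the record_level_cm from a real run.

```r
# Toy reference labels standing in for a fixed test set
truth <- factor(rep(c("benign", "malignant"), times = c(60, 40)))

# Hypothetical helper: swap the label of the first k observations
flip_first <- function(x, k) {
  idx <- seq_len(k)
  x[idx] <- levels(x)[3 - as.integer(x[idx])]  # assumes exactly two levels
  x
}

# Hypothetical helper: one summary row per model run
make_run_row <- function(pred, truth, run_id) {
  tab <- table(Prediction = pred, Reference = truth)
  data.frame(run_id = run_id,
             Accuracy = sum(diag(tab)) / sum(tab),
             cm_ts = Sys.time())
}

# Three toy "runs" with 0, 5 and 10 errors, stacked into one history table
run_history <- do.call(rbind, lapply(1:3, function(i) {
  make_run_row(flip_first(truth, (i - 1) * 5), truth, run_id = i)
}))

# Flag runs whose accuracy fell by more than an (arbitrary) 0.02
run_history$acc_change <- c(NA, diff(run_history$Accuracy))
run_history$degraded <- !is.na(run_history$acc_change) & run_history$acc_change < -0.02
run_history
```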

Using ConfusionTableR to collapse binary confusion matrix outputs

In this example we will use the BreastCancer dataset from mlbench to predict whether a tumour is benign or malignant, based on the patterns in the retrospective data and its underlying features.

Preparing data and fitting the model

# Load the required packages
library(dplyr)
library(ConfusionTableR)
library(caret)
library(tidyr)
library(mlbench)

# Load in the data
data("BreastCancer", package = "mlbench")
breast <- BreastCancer[complete.cases(BreastCancer), ] # Keep only complete rows
breast <- breast[, -1] # Drop the Id column
breast$Class <- factor(breast$Class) # Ensure the outcome is a factor
# Convert the nine predictor columns from factors to numeric values
for (i in 1:9) {
  breast[, i] <- as.numeric(as.character(breast[, i]))
}
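The as.numeric(as.character(...)) idiom in the loop above matters because the BreastCancer predictors are stored as factors. Calling as.numeric() directly on a factor returns its internal level codes rather than the numbers in its labels. This small base-R example shows the difference:

```r
# A factor whose labels are numbers; levels are sorted as text
f <- factor(c("10", "2", "5"))
levels(f)                    # "10" "2" "5"
as.numeric(f)                # 1 2 3  (level codes, not the values)
as.numeric(as.character(f))  # 10 2 5 (the intended numeric values)
```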

The data are now ready for modelling. Next we fit a model, build a confusion matrix and use the tools in ConfusionTableR to output a record-level data frame, as in the previous section, and to build a visualisation of the confusion matrix.

Predicting the class labels on the test dataset

This snippet shows how to achieve this:

#Perform train / test split on the data
train_split_idx <- caret::createDataPartition(breast$Class, p = 0.75, list = FALSE)
train <- breast[train_split_idx, ]
test <- breast[-train_split_idx, ]
rf_fit <- caret::train(Class ~ ., data=train, method="rf")
#Make predictions to expose class labels
preds <- predict(rf_fit, newdata=test, type="raw")
predicted <- cbind(data.frame(class_preds=preds), test)

This is where the package comes in: we use it to reduce the results to a data frame and to visualise them.

Binary Confusion Matrix Data Frame

The following example shows how this is implemented:

bin_cm <- ConfusionTableR::binary_class_cm(predicted$class_preds, predicted$Class)
# Get the record level data
bin_cm$record_level_cm
#>   Pred_benign_Ref_benign Pred_malignant_Ref_benign Pred_benign_Ref_malignant
#> 1                    107                         4                         1
#>   Pred_malignant_Ref_malignant  Accuracy     Kappa AccuracyLower AccuracyUpper
#> 1                           58 0.9705882 0.9358684     0.9327003     0.9903825
#>   AccuracyNull AccuracyPValue McnemarPValue Sensitivity Specificity
#> 1    0.6529412   1.692566e-24     0.3710934    0.963964   0.9830508
#>   Pos.Pred.Value Neg.Pred.Value Precision   Recall        F1 Prevalence
#> 1      0.9907407      0.9354839 0.9907407 0.963964 0.9771689  0.6529412
#>   Detection.Rate Detection.Prevalence Balanced.Accuracy               cm_ts
#> 1      0.6294118            0.6352941         0.9735074 2024-12-16 06:55:23
glimpse(bin_cm$record_level_cm)
#> Rows: 1
#> Columns: 23
#> $ Pred_benign_Ref_benign       <int> 107
#> $ Pred_malignant_Ref_benign    <int> 4
#> $ Pred_benign_Ref_malignant    <int> 1
#> $ Pred_malignant_Ref_malignant <int> 58
#> $ Accuracy                     <dbl> 0.9705882
#> $ Kappa                        <dbl> 0.9358684
#> $ AccuracyLower                <dbl> 0.9327003
#> $ AccuracyUpper                <dbl> 0.9903825
#> $ AccuracyNull                 <dbl> 0.6529412
#> $ AccuracyPValue               <dbl> 1.692566e-24
#> $ McnemarPValue                <dbl> 0.3710934
#> $ Sensitivity                  <dbl> 0.963964
#> $ Specificity                  <dbl> 0.9830508
#> $ Pos.Pred.Value               <dbl> 0.9907407
#> $ Neg.Pred.Value               <dbl> 0.9354839
#> $ Precision                    <dbl> 0.9907407
#> $ Recall                       <dbl> 0.963964
#> $ F1                           <dbl> 0.9771689
#> $ Prevalence                   <dbl> 0.6529412
#> $ Detection.Rate               <dbl> 0.6294118
#> $ Detection.Prevalence         <dbl> 0.6352941
#> $ Balanced.Accuracy            <dbl> 0.9735074
#> $ cm_ts                        <dttm> 2024-12-16 06:55:23

This is now a data.frame and can be saved as a single record to a database server to monitor confusion matrix performance over time.
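The base-R sketch below recomputes several record-level statistics from the four counts above, to show where they come from. It takes benign (the first factor level) as the positive class, which the reported Sensitivity of 107 / 111 implies. It is a hand check, not the package's code.

```r
# Counts from the record_level_cm output above; "benign" is the positive class
tp <- 107  # Pred_benign_Ref_benign
fp <- 1    # Pred_benign_Ref_malignant
fn <- 4    # Pred_malignant_Ref_benign
tn <- 58   # Pred_malignant_Ref_malignant

sensitivity <- tp / (tp + fn)  # also reported as Recall
specificity <- tn / (tn + fp)
ppv         <- tp / (tp + fp)  # also reported as Precision
npv         <- tn / (tn + fn)
f1          <- 2 * ppv * sensitivity / (ppv + sensitivity)
balanced    <- (sensitivity + specificity) / 2
round(c(Sensitivity = sensitivity, Specificity = specificity,
        Pos.Pred.Value = ppv, Neg.Pred.Value = npv,
        F1 = f1, Balanced.Accuracy = balanced), 7)
```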

Visualising the confusion matrix

The last tool in the package produces a nice visual of the confusion matrix that can be used in presentations and papers to display the matrix and its associated summary statistics:


ConfusionTableR::binary_visualiseR(train_labels = predicted$class_preds,
                                   truth_labels = predicted$Class,
                                   class_label1 = "Benign",
                                   class_label2 = "Malignant",
                                   quadrant_col1 = "#28ACB4",
                                   quadrant_col2 = "#4397D2",
                                   custom_title = "Breast Cancer Confusion Matrix",
                                   text_col = "black")
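For a quick plot without ConfusionTableR, base R's graphics::fourfoldplot() can draw a 2 × 2 table. The sketch below is an alternative we are suggesting, not part of the package. It rebuilds the table from the counts in the binary record-level output and reuses the quadrant colours above. With std = "all.max" and conf.level = 0, quarter-circle areas are proportional to the raw counts and no confidence rings are drawn.

```r
# Counts from the binary record_level_cm output (rows = Prediction, columns = Reference)
lv <- c("Benign", "Malignant")
bc_tab <- as.table(matrix(c(107, 4, 1, 58), nrow = 2,
                          dimnames = list(Prediction = lv, Reference = lv)))

# Base-R alternative (not part of ConfusionTableR)
fourfoldplot(bc_tab, color = c("#28ACB4", "#4397D2"),
             std = "all.max", conf.level = 0,
             main = "Breast Cancer Confusion Matrix")
```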

These can be used alongside the outputs from the caret package to analyse how well the model fits, and to track how that fit changes over retraining runs, using Cohen's Kappa and the other associated metrics.
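Cohen's Kappa compares the observed agreement p_o with the agreement p_e expected by chance from the row and column totals. The binary counts from the record-level output above give n = 170, with 111 benign and 59 malignant references and 108 benign and 62 malignant predictions. As a worked check, these reproduce the reported Kappa of 0.9358684:

```latex
\kappa = \frac{p_o - p_e}{1 - p_e}, \qquad
p_o = \frac{107 + 58}{170} \approx 0.97059, \qquad
p_e = \frac{111 \cdot 108 + 59 \cdot 62}{170^2} = \frac{15646}{28900} \approx 0.54138,
\qquad
\kappa \approx \frac{0.97059 - 0.54138}{1 - 0.54138} \approx 0.93587
```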

Wrapping up

This package was created to store confusion matrix outputs in a flat, row-wise structure for data tables, data frames and data warehouses. In my experience, we tend to monitor these test statistics over time as models are retrained.