Introduction to Multiclass

Multiclass projects in DataRobot predict targets with more than two classes (unlike binary prediction, which is for precisely two classes). Currently, DataRobot supports predicting up to 10 different classes.
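
If you are unsure whether a candidate target fits within that limit, you can check its cardinality locally before creating a project. A minimal sketch, using the built-in iris target column as a stand-in for your own data:

# DataRobot currently supports at most 10 classes for a multiclass target,
# so count the distinct values of the candidate target column first.
length(unique(iris$Species)) <= 10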

Connect to DataRobot

To explore multiclass projects, let’s first connect to DataRobot. Begin by loading the DataRobot R package library.

If you have set up a credentials file, library(datarobot) will initialize a connection to DataRobot automatically. Otherwise, you can specify your endpoint and apiToken as in this example to connect to DataRobot directly. For more information on connecting to DataRobot, see the “Introduction to DataRobot” vignette.

library(datarobot)
endpoint <- "https://<YOUR DATAROBOT URL GOES HERE>/api/v2"
apiToken <- "<YOUR API TOKEN GOES HERE>"
ConnectToDataRobot(endpoint = endpoint, token = apiToken)

Creating a Multiclass Project

Let’s predict species for the iris dataset:

library(knitr)
data(iris) # Load the built-in `iris` dataset.
kable(iris)

If your target is categorical with a cardinality of up to 10, DataRobot automatically selects a Multiclass targetType, so that argument is not needed when calling StartProject. However, if the target is numeric and you would like DataRobot to treat the project as Multiclass anyway, you can specify the targetType explicitly, as seen below:

project <- StartProject(iris,
                        projectName = "multiclassExample",
                        target = "Species",
                        targetType = TargetType$Multiclass,
                        maxWait = 600)
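
For a categorical target like Species, the same call works without the targetType argument, since DataRobot detects the multiclass setup from the data. A minimal sketch (the project name here is just illustrative):

# With a categorical target of 10 or fewer classes, targetType can be omitted
# and DataRobot selects Multiclass automatically.
projectAuto <- StartProject(iris,
                            projectName = "multiclassExampleAuto",
                            target = "Species",
                            maxWait = 600)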

Now we can build a model:

blueprint <- ListBlueprints(project)[[1]]
RequestNewModel(project, blueprint)
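
RequestNewModel returns a model job ID rather than a model object. If you would rather block until training finishes and retrieve the model directly (instead of listing models afterwards, as below), a sketch of an alternative using GetModelFromJobId:

# Alternative to the call above: capture the job ID, wait for the job to
# complete, then fetch the trained model once it is ready.
modelJobId <- RequestNewModel(project, blueprint)
modelFromJob <- GetModelFromJobId(project, modelJobId, maxWait = 600)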

And then we can get predictions:

model <- ListModels(project)[[1]]
predictions <- Predict(model, iris)
print(table(predictions))
## request issued, waiting for predictions
## Multiclass with labels setosa, versicolor, virginica
setosa versicolor  virginica 
    50         47         53 
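
Since these are predictions on the training data itself, a quick (and optimistic) sanity check against the actual labels is possible; a minimal sketch:

# Compare the in-sample label predictions with the actual species. Because the
# model was trained on this same data, this accuracy estimate is optimistic.
mean(as.character(predictions) == as.character(iris$Species))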

You can also get a dataframe with the probabilities of each class using type = "probability":

predictions <- Predict(model, iris, type = "probability")
kable(head(predictions))
## request issued, waiting for predictions
## Multiclass with labels setosa, versicolor, virginica
class_setosa class_versicolor class_virginica
0.9987500 0.0000000 0.0012500
0.9344544 0.0491984 0.0163472
0.9854799 0.0080586 0.0064615
0.9931519 0.0054731 0.0013750
0.9954167 0.0022222 0.0023611
0.9883673 0.0017766 0.0098561
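
If you need the predicted label rather than the probabilities, you can recover it by taking the class with the highest probability in each row. A minimal sketch, assuming the data frame contains only the class_<label> probability columns shown above:

# Take the column with the highest probability in each row and strip the
# "class_" prefix to recover the predicted label.
predictedLabels <- sub("^class_", "", names(predictions))[max.col(predictions)]
table(predictedLabels)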

Confusion Charts

The confusion chart helps you understand how the multiclass model performs across classes:

confusionChart <- GetConfusionChart(model, source = DataPartition$VALIDATION)
kable(capture.output(confusionChart))
$source
[1] "validation"

$data
$data$classes
[1] "setosa"     "versicolor" "virginica"

$data$classMetrics
$data$classMetrics$wasActualPercentages
$data$classMetrics$wasActualPercentages[[1]]
  percentage otherClassName
1          1         setosa
2          0     versicolor
3          0      virginica

$data$classMetrics$wasActualPercentages[[2]]
  percentage otherClassName
1        0.0         setosa
2        0.8     versicolor
3        0.2      virginica

$data$classMetrics$wasActualPercentages[[3]]
  percentage otherClassName
1          0         setosa
2          0     versicolor
3          1      virginica

$data$classMetrics$f1
[1] 1.0000000 0.8888889 0.9523810

$data$classMetrics$confusionMatrixOneVsAll
$data$classMetrics$confusionMatrixOneVsAll[[1]]
     [,1] [,2]
[1,]   15    0
[2,]    0    9

$data$classMetrics$confusionMatrixOneVsAll[[2]]
     [,1] [,2]
[1,]   19    0
[2,]    1    4

$data$classMetrics$confusionMatrixOneVsAll[[3]]
     [,1] [,2]
[1,]   13    1
[2,]    0   10

$data$classMetrics$recall
[1] 1.0 0.8 1.0

$data$classMetrics$actualCount
[1]  9  5 10

$data$classMetrics$precision
[1] 1.0000000 1.0000000 0.9090909

$data$classMetrics$wasPredictedPercentages
$data$classMetrics$wasPredictedPercentages[[1]]
  percentage otherClassName
1          1         setosa
2          0     versicolor
3          0      virginica

$data$classMetrics$wasPredictedPercentages[[2]]
  percentage otherClassName
1          0         setosa
2          1     versicolor
3          0      virginica

$data$classMetrics$wasPredictedPercentages[[3]]
  percentage otherClassName
1 0.00000000         setosa
2 0.09090909     versicolor
3 0.90909091      virginica

$data$classMetrics$className
[1] "setosa"     "versicolor" "virginica"

$data$classMetrics$predictedCount
[1]  9  4 11

$data$confusionMatrix
     [,1] [,2] [,3]
[1,]    9    0    0
[2,]    0    4    1
[3,]    0    0   10

Here, we can see the source comes from the "validation" partition (options are in the DataPartition object), and class metrics show:

  • wasActualPercentages: for each actual class, the percentage of its records that were predicted as each class. Any prediction of a different class is a misprediction.
  • wasPredictedPercentages: for each predicted class, the percentage of those predictions that actually belong to each class.
  • confusionMatrixOneVsAll: a one-vs-all matrix for each class. The rows indicate whether a record's actual class is the class in question (row 1: it is not, row 2: it is), and the columns indicate whether the record was predicted to be that class (column 1: it was not, column 2: it was). Thus the top-left cell (1,1) counts records that are not the class and were not predicted to be it (true negatives), the top-right cell (1,2) counts records that are not the class but were predicted to be it (false positives), the bottom-left cell (2,1) counts records that are the class but were predicted not to be it (false negatives), and the bottom-right cell (2,2) counts records that are the class and were correctly predicted to be it (true positives). A worked sketch at the end of this section shows how these matrices relate to the full confusion matrix.
  • f1: The F1 score for each class.
  • precision: The precision statistic for each class.
  • recall: The recall statistic for each class.
  • actualCount: The number of records that actually belong to each class.
  • predictedCount: The number of times each class was predicted.

The confusion chart also shows a full confusion matrix with one row and one column for each class, showing how each class was predicted or mispredicted. The columns represent the predicted classes and the rows represent the actual classes.
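
To make these relationships concrete, here is a small worked sketch (not DataRobot output) that recomputes the per-class precision, recall, and F1 scores, and one of the one-vs-all matrices, from the full validation confusion matrix printed above:

# The full confusion matrix from the validation partition, as printed above:
# rows are actual classes, columns are predicted classes.
cm <- matrix(c(9, 0, 0,
               0, 4, 1,
               0, 0, 10),
             nrow = 3, byrow = TRUE,
             dimnames = list(actual = c("setosa", "versicolor", "virginica"),
                             predicted = c("setosa", "versicolor", "virginica")))

precision <- diag(cm) / colSums(cm)  # correct predictions / predictedCount
recall <- diag(cm) / rowSums(cm)     # correct predictions / actualCount
f1 <- 2 * precision * recall / (precision + recall)
round(rbind(precision, recall, f1), 7)

# One-vs-all matrix for a single class, laid out as in confusionMatrixOneVsAll:
# row/column 1 means "not the class", row/column 2 means "the class".
OneVsAll <- function(cm, cls) {
  isActual <- rownames(cm) == cls
  isPredicted <- colnames(cm) == cls
  matrix(c(sum(cm[!isActual, !isPredicted]),  # true negatives
           sum(cm[!isActual, isPredicted]),   # false positives
           sum(cm[isActual, !isPredicted]),   # false negatives
           sum(cm[isActual, isPredicted])),   # true positives
         nrow = 2, byrow = TRUE)
}
OneVsAll(cm, "versicolor")  # matches the [[2]] matrix printed above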