Package 'mlr3viz'

Title: Visualizations for 'mlr3'
Description: Visualization package of the 'mlr3' ecosystem. It features plots for mlr3 objects such as tasks, learners, predictions, benchmark results, tuning instances and filters via the 'autoplot()' generic of 'ggplot2'. The package draws plots with the 'viridis' color palette and applies the minimal theme. Visualizations include barplots, boxplots, histograms, ROC curves, and Precision-Recall curves.
Authors: Michel Lang [aut], Patrick Schratz [aut], Raphael Sonabend [aut], Marc Becker [cre, aut], Jakob Richter [aut], Damir Pulatov [ctb], John Zobolas [aut]
Maintainer: Marc Becker <[email protected]>
License: LGPL-3
Version: 0.10.0
Built: 2024-12-08 07:21:06 UTC
Source: CRAN


mlr3viz: Visualizations for 'mlr3'

Description

Visualization package of the 'mlr3' ecosystem. It features plots for mlr3 objects such as tasks, learners, predictions, benchmark results, tuning instances and filters via the 'autoplot()' generic of 'ggplot2'. The package draws plots with the 'viridis' color palette and applies the minimal theme. Visualizations include barplots, boxplots, histograms, ROC curves, and Precision-Recall curves.

Author(s)

Maintainer: Marc Becker [email protected] (ORCID)



Convert to 'precrec' Format

Description

Converts to a format which is understood by precrec::evalmod() of package precrec.

Usage

as_precrec(object)

## S3 method for class 'PredictionClassif'
as_precrec(object)

## S3 method for class 'ResampleResult'
as_precrec(object)

## S3 method for class 'BenchmarkResult'
as_precrec(object)

Arguments

object

(any)
Object to convert.

Value

Object as created by precrec::mmdata().
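
As a usage sketch, the converted object can be handed directly to precrec::evalmod(). The snippet below assumes that the packages mlr3, mlr3viz, rpart and precrec are installed and does nothing otherwise; the task and learner are illustrative choices borrowed from other examples in this manual.

```r
# Run only when all required packages are available.
pkgs = c("mlr3", "mlr3viz", "rpart", "precrec")
has_pkgs = all(vapply(pkgs, requireNamespace, logical(1), quietly = TRUE))

curves = NULL
if (has_pkgs) {
  library(mlr3)
  library(mlr3viz)

  # Probability predictions are needed to rank observations by score.
  task = tsk("sonar")
  learner = lrn("classif.rpart", predict_type = "prob")
  prediction = learner$train(task)$predict(task)

  # Convert to the precrec input format and evaluate ROC and PR curves.
  mm = as_precrec(prediction)
  curves = precrec::evalmod(mm)
}
```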

References

Saito T, Rehmsmeier M (2017). “Precrec: fast and accurate precision-recall and ROC curve calculations in R.” Bioinformatics, 33(1), 145-147. doi:10.1093/bioinformatics/btw570.


Plots for Benchmark Results

Description

Visualizations for mlr3::BenchmarkResult. The argument type controls what kind of plot is drawn. The default is "boxplot"; the example below also uses "roc".

Usage

## S3 method for class 'BenchmarkResult'
autoplot(
  object,
  type = "boxplot",
  measure = NULL,
  theme = theme_minimal(),
  ...
)

Arguments

object

(mlr3::BenchmarkResult).

type

(character(1)):
Type of the plot. See description.

measure

(mlr3::Measure)
Performance measure to use.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

...

(ignored).

Value

ggplot2::ggplot().

References

Saito T, Rehmsmeier M (2017). “Precrec: fast and accurate precision-recall and ROC curve calculations in R.” Bioinformatics, 33(1), 145-147. doi:10.1093/bioinformatics/btw570.

Examples

if (requireNamespace("mlr3")) {
  library(mlr3)
  library(mlr3viz)

  tasks = tsks(c("pima", "sonar"))
  learner = lrns(c("classif.featureless", "classif.rpart"),
    predict_type = "prob")
  resampling = rsmps("cv")
  object = benchmark(benchmark_grid(tasks, learner, resampling))

  head(fortify(object))
  autoplot(object)
  autoplot(object$clone(deep = TRUE)$filter(task_ids = "pima"), type = "roc")
}

Plots for Ensemble Feature Selection Results

Description

Visualizations for EnsembleFSResult. The argument type determines the type of plot generated. The available options are:

  • "pareto" (default): Scatterplot of performance versus the number of features, possibly including the Pareto front, which allows users to decide how much performance they are willing to trade off for a more sparse model.

  • "performance": Boxplot of performance across the different learners used in the ensemble feature selection process. Each box represents the distribution of scores across different resampling iterations for a particular learner.

  • "n_features": Boxplot of the number of features selected by each learner in the different resampling iterations.

  • "stability": Barplot of stability score for each learner used in the ensemble feature selection. This plot shows how similar the feature sets selected by each learner are across the different resamplings.
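
To make the idea behind the "pareto" plot concrete, here is a minimal base R sketch of how a set of non-dominated points could be identified. It assumes that a lower score is better (as with a classification error) and uses made-up values; it is not the package's implementation, and the column names n_features and error are illustrative.

```r
# Hypothetical results: number of selected features and error per run.
res = data.frame(
  n_features = c(2, 3, 3, 5, 8, 10),
  error      = c(0.30, 0.25, 0.28, 0.20, 0.22, 0.19)
)

# Sort by model size, then by error, so the best run per size comes first.
res = res[order(res$n_features, res$error), ]

# A point is on the front if its error is strictly lower than the error
# of every run that precedes it in this ordering.
best_so_far = cummin(res$error)
on_front = res$error == best_so_far & !duplicated(best_so_far)
pareto = res[on_front, ]
```

Each remaining row is a trade-off: moving to a sparser model on the front always costs some performance.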

Usage

## S3 method for class 'EnsembleFSResult'
autoplot(
  object,
  type = "pareto",
  pareto_front = "stepwise",
  stability_measure = "jaccard",
  stability_args = NULL,
  theme = theme_minimal(),
  ...
)

Arguments

object

(mlr3fselect::EnsembleFSResult).

type

(character(1)):
Type of the plot. See description.

pareto_front

(character(1))
Type of pareto front to plot. Can be "stepwise" (default), "estimated" or "none".

stability_measure

(character(1))
The stability measure to be used in case type = "stability". One of the measures returned by stabm::listStabilityMeasures() in lower case. Default is "jaccard".

stability_args

(list)
Additional arguments passed to the stability measure function.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

...

(ignored).
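
For the "stability" plot, the default measure is "jaccard". As a rough illustration, and assuming the score is the average pairwise Jaccard similarity between the selected feature sets, it could be computed in base R as follows. The feature sets are made up, and the helper jaccard() is hypothetical rather than part of stabm.

```r
# Hypothetical feature sets selected in three resampling iterations.
feature_sets = list(
  c("V1", "V2", "V3"),
  c("V1", "V2", "V4"),
  c("V1", "V2", "V3")
)

# Jaccard similarity: size of the intersection over size of the union.
jaccard = function(a, b) length(intersect(a, b)) / length(union(a, b))

# Average the similarity over all unordered pairs of feature sets.
pair_idx = combn(length(feature_sets), 2)
scores = apply(pair_idx, 2, function(idx) {
  jaccard(feature_sets[[idx[1]]], feature_sets[[idx[2]]])
})
stability = mean(scores)
```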

Value

ggplot2::ggplot().

Examples

if (requireNamespace("mlr3")) {
  library(mlr3)
  library(mlr3fselect)
  library(mlr3viz)

  set.seed(42)
  efsr = ensemble_fselect(
    fselector = fs("random_search"),
    task = tsk("sonar"),
    learners = lrns(c("classif.rpart", "classif.featureless")),
    init_resampling = rsmp("subsampling", repeats = 5),
    inner_resampling = rsmp("cv", folds = 3),
    measure = msr("classif.ce"),
    terminator = trm("evals", n_evals = 5)
  )

  # Pareto front (default, stepwise)
  autoplot(efsr)

  # Pareto front (estimated)
  autoplot(efsr, pareto_front = "estimated")

  # Performance
  autoplot(efsr, type = "performance")

  # Number of features
  autoplot(efsr, type = "n_features")

  # stability
  autoplot(efsr, type = "stability")
}

Plots for Filter Scores

Description

Visualizations for mlr3filters::Filter. The argument type controls what kind of plot is drawn. Possible choices are:

  • "barplot" (default): Bar plot of filter scores. Note that the usage signature lists the default value as type = "boxplot".

Usage

## S3 method for class 'Filter'
autoplot(object, type = "boxplot", n = Inf, theme = theme_minimal(), ...)

Arguments

object

(mlr3filters::Filter).

type

(character(1)):
Type of the plot. See description.

n

(integer(1))
Only include the first n features with the highest importance. Defaults to all features.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

...

(ignored).
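
The effect of n can be pictured with a small base R sketch: scores are sorted in decreasing order and only the first n entries are kept. The scores below are invented, and top_n() is a hypothetical helper, not a function of mlr3viz or mlr3filters.

```r
# Hypothetical filter scores, one per feature.
scores = c(hp = 0.78, wt = 0.87, qsec = 0.42, cyl = 0.85, drat = 0.68, disp = 0.83)

# Keep the n highest-scoring features; n = Inf keeps all of them.
top_n = function(scores, n = Inf) {
  sorted = sort(scores, decreasing = TRUE)
  sorted[seq_len(min(n, length(sorted)))]
}

top3 = top_n(scores, n = 3)
all_features = top_n(scores)
```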

Value

ggplot2::ggplot().

Examples

if (requireNamespace("mlr3")) {
  library(mlr3)
  library(mlr3viz)
  library(mlr3filters)

  task = tsk("mtcars")
  f = flt("correlation")
  f$calculate(task)

  head(fortify(f))
  autoplot(f, n = 5)
}

Plot for Classification Learners

Description

Visualizations for mlr3::LearnerClassif. The argument type controls what kind of plot is drawn. Possible choices are:

  • "prediction" (default): Decision boundary of the learner and the true class labels.

Usage

## S3 method for class 'LearnerClassif'
autoplot(
  object,
  type = "prediction",
  task,
  grid_points = 100L,
  expand_range = 0,
  theme = theme_minimal(),
  ...
)

Arguments

object

(mlr3::LearnerClassif).

type

(character(1)):
Type of the plot. See description.

task

(mlr3::Task)
Train task.

grid_points

(integer(1))
Number of grid points per feature dimension.

expand_range

(numeric(1))
Expand the range of the grid.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

...

(ignored).

Value

ggplot2::ggplot().

Examples

if (requireNamespace("mlr3")) {
  library(mlr3)
  library(mlr3viz)

  task = tsk("pima")$select(c("age", "pedigree"))
  learner = lrn("classif.rpart", predict_type = "prob")
  learner$train(task)

  autoplot(learner, type = "prediction", task)
}

Plots for GLMNet Learners

Description

Visualizations for mlr3learners::LearnerClassifGlmnet. The argument type controls what kind of plot is drawn. Possible choices are:

  • "prediction" (default): Decision boundary of the learner and the true class labels.

  • "ggfortify": Visualizes the model using the package ggfortify.

Usage

## S3 method for class 'LearnerClassifCVGlmnet'
autoplot(
  object,
  type = "prediction",
  task = NULL,
  grid_points = 100L,
  expand_range = 0,
  theme = theme_minimal(),
  ...
)

## S3 method for class 'LearnerClassifGlmnet'
autoplot(
  object,
  type = "prediction",
  task = NULL,
  grid_points = 100L,
  expand_range = 0,
  theme = theme_minimal(),
  ...
)

## S3 method for class 'LearnerRegrCVGlmnet'
autoplot(
  object,
  type = "prediction",
  task = NULL,
  grid_points = 100L,
  expand_range = 0,
  theme = theme_minimal(),
  ...
)

## S3 method for class 'LearnerRegrGlmnet'
autoplot(
  object,
  type = "prediction",
  task = NULL,
  grid_points = 100L,
  expand_range = 0,
  theme = theme_minimal(),
  ...
)

Arguments

object

(mlr3learners::LearnerClassifGlmnet | mlr3learners::LearnerClassifCVGlmnet | mlr3learners::LearnerRegrGlmnet | mlr3learners::LearnerRegrCVGlmnet).

type

(character(1)):
Type of the plot. See description.

task

(mlr3::Task)
Train task.

grid_points

(integer(1))
Number of grid points per feature dimension.

expand_range

(numeric(1))
Expand the range of the grid.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

...

(ignored).

Value

ggplot2::ggplot().

References

Tang Y, Horikoshi M, Li W (2016). “ggfortify: Unified Interface to Visualize Statistical Result of Popular R Packages.” The R Journal, 8(2), 474–485. doi:10.32614/RJ-2016-060.

Examples

## Not run: 
library(mlr3)
library(mlr3viz)
library(mlr3learners)

# classification
task = tsk("sonar")
learner = lrn("classif.glmnet")
learner$train(task)
autoplot(learner, type = "ggfortify")

# regression
task = tsk("mtcars")
learner = lrn("regr.glmnet")
learner$train(task)
autoplot(learner, type = "ggfortify")

## End(Not run)

Plots for Rpart Learners

Description

Visualizations for mlr3::LearnerClassifRpart. The argument type controls what kind of plot is drawn. Possible choices are:

  • "prediction" (default): Decision boundary of the learner and the true class labels.

  • "ggparty": Visualizes the tree using the package ggparty.

Usage

## S3 method for class 'LearnerClassifRpart'
autoplot(
  object,
  type = "prediction",
  task = NULL,
  grid_points = 100L,
  expand_range = 0,
  theme = theme_minimal(),
  ...
)

## S3 method for class 'LearnerRegrRpart'
autoplot(
  object,
  type = "prediction",
  task = NULL,
  grid_points = 100L,
  expand_range = 0,
  theme = theme_minimal(),
  ...
)

Arguments

object

(mlr3::LearnerClassifRpart | mlr3::LearnerRegrRpart).

type

(character(1)):
Type of the plot. See description.

task

(mlr3::Task)
Train task.

grid_points

(integer(1))
Number of grid points per feature dimension.

expand_range

(numeric(1))
Expand the range of the grid.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

...

(ignored).

Value

ggplot2::ggplot().

Examples

if (requireNamespace("mlr3")) {
  library(mlr3)
  library(mlr3viz)

  # classification
  task = tsk("iris")
  learner = lrn("classif.rpart", keep_model = TRUE)
  learner$train(task)
  autoplot(learner, type = "ggparty")

  # regression
  task = tsk("mtcars")
  learner = lrn("regr.rpart", keep_model = TRUE)
  learner$train(task)
  autoplot(learner, type = "ggparty")
}

Plots for Hierarchical Clustering Learners

Description

Visualizations for hierarchical clusters. The argument type controls what kind of plot is drawn. Possible choices are:

  • "dend" (default): Dendrograms using ggdendro package.

  • "scree": Scree plot that shows the number of possible clusters on the x-axis and the height on the y-axis.
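
The quantities behind the "scree" plot can be sketched with base R's stats::hclust(). This assumes that each merge height is paired with the number of clusters remaining after that merge; the package's exact pairing may differ.

```r
# Hierarchical clustering of a built-in dataset (complete linkage).
hc = hclust(dist(USArrests))
n = nrow(USArrests)

# Each of the n - 1 merges reduces the number of clusters by one.
scree = data.frame(
  clusters = n - seq_along(hc$height),
  height   = hc$height
)

# Inspect the last merges, where few clusters remain.
tail(scree)
```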

Usage

## S3 method for class 'LearnerClustHierarchical'
autoplot(
  object,
  type = "dend",
  task = NULL,
  theme = theme_minimal(),
  theme_dendro = TRUE,
  ...
)

Arguments

object

(mlr3cluster::LearnerClustAgnes | mlr3cluster::LearnerClustDiana | mlr3cluster::LearnerClustHclust).

type

(character(1)):
Type of the plot. See description.

task

(mlr3::Task)
Optionally, pass the task to add labels of observations to a hclust dendrogram. Labels are set via the row names of the task.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

theme_dendro

(logical(1))
If TRUE (default), the special dendrogram theme from ggdendro package is used in plot "dend". Set to FALSE to use the theme passed in theme.

...

(ignored).

Value

ggplot2::ggplot().

Examples

if (requireNamespace("mlr3")) {
  library(mlr3)
  library(mlr3cluster)
  library(mlr3viz)

  task = tsk("usarrests")

  # agnes clustering
  learner = lrn("clust.agnes")
  learner$train(task)
  autoplot(learner)

  # diana clustering
  learner = lrn("clust.diana")
  learner$train(task)
  autoplot(learner)

  # hclust clustering
  learner = lrn("clust.hclust")
  learner$train(task)
  autoplot(learner, type = "scree")
}

Plot for Regression Learners

Description

Visualizations for mlr3::LearnerRegr. The argument type controls what kind of plot is drawn. Possible choices are:

  • "prediction" (default): Predictions of the learner over a grid of feature values, together with the observed target values.

Usage

## S3 method for class 'LearnerRegr'
autoplot(
  object,
  type = "prediction",
  task,
  grid_points = 100L,
  expand_range = 0,
  theme = theme_minimal(),
  ...
)

Arguments

object

(mlr3::LearnerRegr).

type

(character(1)):
Type of the plot. See description.

task

(mlr3::Task)
Train task.

grid_points

(integer(1))
Number of grid points per feature dimension.

expand_range

(numeric(1))
Expand the range of the grid.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

...

(ignored).

Value

ggplot2::ggplot().

Examples

if (requireNamespace("mlr3")) {
  library(mlr3)
  library(mlr3viz)

  task = tsk("mtcars")$select(c("am", "carb"))
  learner = lrn("regr.rpart")
  learner$train(task)

  autoplot(learner, type = "prediction", task)
}

Plots for Cox Proportional Hazards Learner

Description

Visualizations for mlr3proba::LearnerSurvCoxPH.

The argument type controls what kind of plot is drawn. The only possible choice right now is "ggforest" (default) which is a Forest Plot, using ggforest. This plot displays the estimated hazard ratios (HRs) and their confidence intervals (CIs) for different variables included in the (trained) model.

Usage

## S3 method for class 'LearnerSurvCoxPH'
autoplot(object, type = "ggforest", ...)

Arguments

object

(mlr3proba::LearnerSurvCoxPH).

type

(character(1)):
Type of the plot. See description.

...

Additional parameters passed down to ggforest.

Value

ggplot2::ggplot().

Examples

if (requireNamespace("mlr3proba")) {
  library(mlr3proba)
  library(mlr3viz)

  task = tsk("lung")
  learner = lrn("surv.coxph")
  learner$train(task)
  autoplot(learner)
}

Plots for Optimization Instances

Description

Visualizations for bbotk::OptimInstanceBatchSingleCrit. The argument type controls what kind of plot is drawn. Possible choices are:

  • "marginal" (default): Scatter plots of x versus y. The color of the points shows the batch number.

  • "performance": Scatter plots of batch number versus y.

  • "parameter": Scatter plots of batch number versus input. The color of the points shows the y values.

  • "parallel": Parallel coordinates plot. x values are rescaled by (x - mean(x)) / sd(x).

  • "points": Scatter plot of two x dimensions against each other. The color of the points shows the y values.

  • "surface": Surface plot of two x dimensions versus y values. The y values are interpolated with the supplied mlr3::Learner.

  • "pairs": Plots all x and y values against each other.

  • "incumbent": Plots the incumbent versus the number of configurations.
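
Two of the plot types above rest on simple transformations of the archive, sketched here in base R with invented values. The "incumbent" is taken to be the best y seen so far (assuming maximization), and the "parallel" rescaling applies the formula (x - mean(x)) / sd(x) to each x column. The data frame archive is a stand-in for the real optimization archive.

```r
# Hypothetical archive: two inputs and one objective value per evaluation.
archive = data.frame(
  x1 = c(-4, 1, 3, 2, 6),
  x2 = c(2, -1, -3, -4, 0),
  y  = c(-51, 5, 9, 9, -15)
)

# Incumbent: running best objective value (maximization assumed).
archive$incumbent = cummax(archive$y)

# Parallel coordinates: standardize each x column to mean 0 and sd 1.
rescale = function(x) (x - mean(x)) / sd(x)
scaled = as.data.frame(lapply(archive[c("x1", "x2")], rescale))
```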

Usage

## S3 method for class 'OptimInstanceBatchSingleCrit'
autoplot(
  object,
  type = "marginal",
  cols_x = NULL,
  trafo = FALSE,
  learner = mlr3::lrn("regr.ranger"),
  grid_resolution = 100,
  batch = NULL,
  theme = theme_minimal(),
  ...
)

Arguments

object

(bbotk::OptimInstanceBatchSingleCrit).

type

(character(1)):
Type of the plot. See description.

cols_x

(character())
Column names of x values. By default, all untransformed x values from the search space are plotted. Transformed hyperparameters are prefixed with x_domain_.

trafo

(logical(1))
If FALSE (default), the untransformed x values are plotted. If TRUE, the transformed x values are plotted.

learner

(mlr3::Learner)
Regression learner used to interpolate the data of the surface plot.

grid_resolution

(numeric())
Resolution of the surface plot.

batch

(integer())
The batch number(s) to limit the plot to. The default is all batches.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

...

(ignored).

Value

ggplot2::ggplot().

Examples

if (requireNamespace("mlr3") && requireNamespace("bbotk") && requireNamespace("patchwork")) {
  library(bbotk)
  library(paradox)

  fun = function(xs) {
    c(y = -(xs[[1]] - 2)^2 - (xs[[2]] + 3)^2 + 10)
  }
  domain = ps(
    x1 = p_dbl(-10, 10),
    x2 = p_dbl(-5, 5)
  )
  codomain = ps(
    y = p_dbl(tags = "maximize")
  )
  obfun = ObjectiveRFun$new(
    fun = fun,
    domain = domain,
    codomain = codomain
  )

  instance = oi(objective = obfun, terminator = trm("evals", n_evals = 20))

  optimizer = opt("random_search", batch_size = 2)
  optimizer$optimize(instance)

  # plot y versus batch number
  print(autoplot(instance, type = "performance"))

  # plot x1 values versus performance
  print(autoplot(instance, type = "marginal", cols_x = "x1"))

  # plot parallel coordinates plot
  print(autoplot(instance, type = "parallel"))

  # plot pairs
  print(autoplot(instance, type = "pairs"))

  # plot incumbent
  print(autoplot(instance, type = "incumbent"))
}

Plots for Classification Predictions

Description

Visualizations for mlr3::PredictionClassif. The argument type controls what kind of plot is drawn. Possible choices are:

  • "stacked" (default): Stacked barplot of true and estimated class labels.

  • "roc": ROC curve (1 - specificity on x, sensitivity on y). Requires package precrec.

  • "prc": Precision recall curve. Requires package precrec.

  • "threshold": Systematically varies the threshold of the mlr3::PredictionClassif object and plots the resulting performance as returned by measure.
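
The "threshold" plot can be understood as the following loop, shown here as a base R sketch with made-up probabilities. It assumes a binary problem in which an observation is labelled positive when its predicted probability reaches the threshold, and it uses the classification error as the measure; ce_at() is a hypothetical helper.

```r
# Hypothetical truth and predicted probability of the positive class.
truth = c("pos", "pos", "neg", "pos", "neg", "neg")
prob  = c(0.9, 0.7, 0.6, 0.4, 0.3, 0.1)

# Classification error when predicting "pos" for prob >= threshold.
ce_at = function(threshold) {
  response = ifelse(prob >= threshold, "pos", "neg")
  mean(response != truth)
}

# Evaluate the measure on a coarse grid of thresholds.
thresholds = seq(0, 1, by = 0.25)
ce = vapply(thresholds, ce_at, numeric(1))
result = data.frame(threshold = thresholds, ce = ce)
```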

Usage

## S3 method for class 'PredictionClassif'
autoplot(
  object,
  type = "stacked",
  measure = NULL,
  theme = theme_minimal(),
  ...
)

Arguments

object

(mlr3::PredictionClassif).

type

(character(1)):
Type of the plot. See description.

measure

(mlr3::Measure)
Performance measure to use.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

...

(ignored).

Value

ggplot2::ggplot().

References

Saito T, Rehmsmeier M (2017). “Precrec: fast and accurate precision-recall and ROC curve calculations in R.” Bioinformatics, 33(1), 145-147. doi:10.1093/bioinformatics/btw570.

Examples

if (requireNamespace("mlr3")) {
  library(mlr3)
  library(mlr3viz)

  task = tsk("spam")
  learner = lrn("classif.rpart", predict_type = "prob")
  object = learner$train(task)$predict(task)

  head(fortify(object))
  autoplot(object)
  autoplot(object, type = "roc")
  autoplot(object, type = "prc")
}

Plots for Cluster Predictions

Description

Visualizations for mlr3cluster::PredictionClust. The argument type controls what kind of plot is drawn. Possible choices are:

  • "scatter" (default): Scatterplot with correlation values and colored cluster assignments.

  • "sil": Silhouette plot with mean silhouette value as the reference line. Requires package ggfortify.

  • "pca": Perform PCA on data and color code cluster assignments. Inspired by and uses ggfortify::autoplot.kmeans.

Usage

## S3 method for class 'PredictionClust'
autoplot(
  object,
  task,
  row_ids = NULL,
  type = "scatter",
  theme = theme_minimal(),
  ...
)

Arguments

object

(mlr3cluster::PredictionClust).

task

(mlr3cluster::TaskClust).

row_ids

(integer())
Row ids to subset task data to ensure that only the data used to make predictions are shown in plots.

type

(character(1)):
Type of the plot. See description.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

...

(ignored).

Value

ggplot2::ggplot().

References

Tang Y, Horikoshi M, Li W (2016). “ggfortify: Unified Interface to Visualize Statistical Result of Popular R Packages.” The R Journal, 8(2), 474–485. doi:10.32614/RJ-2016-060.

Examples

if (requireNamespace("mlr3")) {
  library(mlr3)
  library(mlr3cluster)
  library(mlr3viz)

  task = tsk("usarrests")
  learner = lrn("clust.kmeans", centers = 3)
  object = learner$train(task)$predict(task)

  head(fortify(object))
  autoplot(object, task)
}

Plots for Regression Predictions

Description

Visualizations for mlr3::PredictionRegr. The argument type controls what kind of plot is drawn. Possible choices are:

  • "xy" (default): Scatterplot of "true" response vs. "predicted" response. By default a linear model is fitted via geom_smooth(method = "lm") to visualize the trend between x and y (by default colored blue). In addition geom_abline() with slope = 1 is added to the plot. Note that geom_smooth() and geom_abline() may overlap, depending on the given data.

  • "histogram": Histogram of residuals: r = y - \hat{y}.

  • "residual": Plot of the residuals, with the response \hat{y} on the "x" and the residuals on the "y" axis. By default a linear model is fitted via geom_smooth(method = "lm") to visualize the trend between x and y (by default colored blue).

  • "confidence": Scatterplot of "true" response vs. "predicted" response with confidence intervals. Error bars are calculated as object$response +- quantile * object$se and are therefore only available with predict_type = "se". geom_abline() with slope = 1 is added to the plot.
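
The quantities plotted by "histogram", "residual" and "confidence" follow directly from these definitions. Below is a base R sketch with invented numbers, where truth, response and se stand in for the corresponding fields of a prediction object and multiplier plays the role of the quantile argument.

```r
# Hypothetical true values, predictions and standard errors.
truth    = c(21.0, 22.8, 18.7, 14.3)
response = c(20.0, 24.0, 18.7, 15.3)
se       = c(1.0, 0.5, 2.0, 1.5)

# Residuals as used by the "histogram" and "residual" plots.
res = truth - response

# Error bars for the "confidence" plot with the default multiplier.
multiplier = 1.96
lower = response - multiplier * se
upper = response + multiplier * se
```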

Usage

## S3 method for class 'PredictionRegr'
autoplot(
  object,
  type = "xy",
  binwidth = NULL,
  theme = theme_minimal(),
  quantile = 1.96,
  ...
)

Arguments

object

(mlr3::PredictionRegr).

type

(character(1)):
Type of the plot. See description.

binwidth

(integer(1))
Width of the bins for the histogram.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

quantile

(numeric(1))
Quantile multiplier for standard errors for type="confidence". Default 1.96.

...

(ignored).

Value

ggplot2::ggplot().

Examples

if (requireNamespace("mlr3")) {
  library(mlr3)
  library(mlr3viz)

  task = tsk("mtcars")
  learner = lrn("regr.rpart")
  object = learner$train(task)$predict(task)

  head(fortify(object))
  autoplot(object)
  autoplot(object, type = "histogram", binwidth = 1)
  autoplot(object, type = "residual")

  if (requireNamespace("mlr3learners")) {
    library(mlr3learners)
    learner = lrn("regr.ranger", predict_type = "se")
    object = learner$train(task)$predict(task)
    autoplot(object, type = "confidence")
  }
}

Plots for Resample Results

Description

Visualizations for mlr3::ResampleResult. The argument type controls what kind of plot is drawn. Possible choices are:

  • "boxplot" (default): Boxplot of performance measures.

  • "histogram": Histogram of performance measures.

  • "roc": ROC curve (1 - specificity on x, sensitivity on y). The predictions of the individual mlr3::Resamplings are merged prior to calculating the ROC curve (micro averaged). Requires package precrec.

  • "prc": Precision recall curve. See "roc".

  • "prediction": Plots the learner prediction for a grid of points. Needs models to be stored. Set store_models = TRUE for mlr3::resample(). For classification, we support tasks with exactly two features and learners with predict_type set to "response" or "prob". For regression, we support tasks with one or two features. For tasks with one feature we can print confidence bounds if the predict type of the learner was set to "se". For tasks with two features the predict type will be ignored.
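
The micro averaging used for "roc" can be sketched in base R: predictions from all resampling iterations are pooled first, and sensitivity and 1 - specificity are then computed on the pooled set for each threshold. The fold data are invented and roc_point() is a hypothetical helper; the package itself relies on precrec for the curve computation.

```r
# Hypothetical predictions from two resampling iterations.
fold1 = data.frame(truth = c(1, 0, 1), prob = c(0.8, 0.4, 0.6))
fold2 = data.frame(truth = c(0, 1, 0), prob = c(0.7, 0.9, 0.2))

# Micro averaging: pool all predictions before computing the curve.
pooled = rbind(fold1, fold2)

# One ROC point: false positive rate (x) and true positive rate (y).
roc_point = function(threshold, data) {
  pred = data$prob >= threshold
  c(
    fpr = sum(pred & data$truth == 0) / sum(data$truth == 0),
    tpr = sum(pred & data$truth == 1) / sum(data$truth == 1)
  )
}

# Sweep from the strictest threshold to the most lenient one.
thresholds = c(Inf, sort(unique(pooled$prob), decreasing = TRUE))
roc = t(vapply(thresholds, roc_point, numeric(2), data = pooled))
```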

Usage

## S3 method for class 'ResampleResult'
autoplot(
  object,
  type = "boxplot",
  measure = NULL,
  predict_sets = "test",
  binwidth = NULL,
  theme = theme_minimal(),
  ...
)

Arguments

object

(mlr3::ResampleResult).

type

(character(1)):
Type of the plot. See description.

measure

(mlr3::Measure)
Performance measure to use.

predict_sets

(character())
Only for type set to "prediction". Which points should be shown in the plot? Can be a subset of ("train", "test") or empty.

binwidth

(integer(1))
Width of the bins for the histogram.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

...

(ignored).

Value

ggplot2::ggplot().

References

Saito T, Rehmsmeier M (2017). “Precrec: fast and accurate precision-recall and ROC curve calculations in R.” Bioinformatics, 33(1), 145-147. doi:10.1093/bioinformatics/btw570.

Examples

if (requireNamespace("mlr3")) {
  library(mlr3)
  library(mlr3viz)

  task = tsk("sonar")
  learner = lrn("classif.rpart", predict_type = "prob")
  resampling = rsmp("cv", folds = 3)
  object = resample(task, learner, resampling)

  head(fortify(object))

  # Default: boxplot
  autoplot(object)

  # Histogram
  autoplot(object, type = "histogram", bins = 30)

  # ROC curve, averaged over resampling folds:
  autoplot(object, type = "roc")

  # ROC curve of joint prediction object:
  autoplot(object$prediction(), type = "roc")

  # Precision Recall Curve
  autoplot(object, type = "prc")

  # Prediction Plot
  task = tsk("iris")$select(c("Sepal.Length", "Sepal.Width"))
  resampling = rsmp("cv", folds = 3)
  object = resample(task, learner, resampling, store_models = TRUE)
  autoplot(object, type = "prediction")
}

Plots for Classification Tasks

Description

Visualizations for mlr3::TaskClassif. The argument type controls what kind of plot is drawn. Possible choices are:

  • "target" (default): Bar plot of the target variable.

  • "duo": Passes data to GGally::ggduo(). columnsX is the target and columnsY are the features.

  • "pairs": Passes data to GGally::ggpairs(). Color is set to target column.

Usage

## S3 method for class 'TaskClassif'
autoplot(object, type = "target", theme = theme_minimal(), ...)

Arguments

object

(mlr3::TaskClassif).

type

(character(1)):
Type of the plot. See description.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

...

(ignored).

Value

ggplot2::ggplot().

Examples

if (requireNamespace("mlr3")) {
  library(mlr3)
  library(mlr3viz)

  task = tsk("iris")

  head(fortify(task))
  autoplot(task)
  autoplot(task$clone()$select(c("Sepal.Length", "Sepal.Width")),
    type = "pairs")
  autoplot(task, type = "duo")
}

Plots for Clustering Tasks

Description

Visualizations for mlr3cluster::TaskClust. The argument type controls what kind of plot is drawn. The default is "pairs".

Usage

## S3 method for class 'TaskClust'
autoplot(object, type = "pairs", theme = theme_minimal(), ...)

Arguments

object

(mlr3cluster::TaskClust).

type

(character(1)):
Type of the plot. See description.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

...

(ignored).

Value

ggplot2::ggplot().

Examples

if (requireNamespace("mlr3")) {
  library(mlr3)
  library(mlr3cluster)
  library(mlr3viz)

  task = mlr_tasks$get("usarrests")

  head(fortify(task))
  autoplot(task)
}

Plots for Regression Tasks

Description

Visualizations for mlr3::TaskRegr. The argument type controls what kind of plot is drawn. Possible choices are:

  • "target" (default): Box plot of the target variable.

  • "pairs": Passes data to GGally::ggpairs(). Color is set to target column.

Usage

## S3 method for class 'TaskRegr'
autoplot(object, type = "target", theme = theme_minimal(), ...)

Arguments

object

(mlr3::TaskRegr).

type

(character(1)):
Type of the plot. See description.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

...

(ignored).

Value

ggplot2::ggplot().

Examples

if (requireNamespace("mlr3")) {
  library(mlr3)
  library(mlr3viz)

  task = tsk("mtcars")
  task$select(c("am", "carb"))

  head(fortify(task))
  autoplot(task)
  autoplot(task, type = "pairs")
}

Plots for Tuning Instances

Description

Visualizations for mlr3tuning::TuningInstanceBatchSingleCrit. The argument type controls what kind of plot is drawn. Possible choices are:

  • "marginal" (default): Scatter plots of x versus y. The color of the points shows the batch number.

  • "performance": Scatter plots of batch number versus y.

  • "parameter": Scatter plots of batch number versus input. The color of the points shows the y values.

  • "parallel": Parallel coordinates plot. Hyperparameters are rescaled by (x - mean(x)) / sd(x).

  • "points": Scatter plot of two x dimensions against each other. The color of the points shows the y values.

  • "surface": Surface plot of two x dimensions versus y values. The y values are interpolated with the supplied mlr3::Learner.

  • "pairs": Plots all x and y values against each other.

  • "incumbent": Plots the incumbent versus the number of configurations.

Usage

## S3 method for class 'TuningInstanceBatchSingleCrit'
autoplot(
  object,
  type = "marginal",
  cols_x = NULL,
  trafo = FALSE,
  learner = mlr3::lrn("regr.ranger"),
  grid_resolution = 100,
  theme = theme_minimal(),
  ...
)

Arguments

object

(mlr3tuning::TuningInstanceBatchSingleCrit).

type

(character(1)):
Type of the plot. See description.

cols_x

(character())
Column names of hyperparameters. By default, all untransformed hyperparameters are plotted. Transformed hyperparameters are prefixed with x_domain_.

trafo

(logical(1))
If FALSE (default), the untransformed hyperparameters are plotted. If TRUE, the transformed hyperparameters are plotted.

learner

(mlr3::Learner)
Regression learner used to interpolate the data of the surface plot.

grid_resolution

(numeric())
Resolution of the surface plot.

theme

(ggplot2::theme())
The ggplot2::theme_minimal() is applied by default to all plots.

...

(ignored).
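
The difference between untransformed and transformed hyperparameter values can be illustrated with a base R sketch. It assumes a hypothetical search space in which cp is tuned on a log scale and mapped back with exp(); the x_domain_ prefix mirrors the naming described for cols_x, and the data frame archive is a stand-in for the real tuning archive.

```r
# Hypothetical values sampled on the log scale (untransformed).
archive = data.frame(cp = log(c(0.001, 0.01, 0.1)))

# Transformed values as the learner would receive them (assumed trafo: exp).
archive$x_domain_cp = exp(archive$cp)

# trafo = FALSE corresponds to plotting `cp`,
# trafo = TRUE corresponds to plotting `x_domain_cp`.
cols_untransformed = "cp"
cols_transformed = "x_domain_cp"
```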

Value

ggplot2::ggplot().

Examples

if (requireNamespace("mlr3tuning") && requireNamespace("patchwork")) {
  library(mlr3tuning)

  learner = lrn("classif.rpart")
  learner$param_set$values$cp = to_tune(0.001, 0.1)
  learner$param_set$values$minsplit = to_tune(1, 10)

  instance = ti(
    task = tsk("iris"),
    learner = learner,
    resampling = rsmp("holdout"),
    measure = msr("classif.ce"),
    terminator = trm("evals", n_evals = 10))

  tuner = tnr("random_search")

  tuner$optimize(instance)

  # plot performance versus batch number
  autoplot(instance, type = "performance")

  # plot cp values versus performance
  autoplot(instance, type = "marginal", cols_x = "cp")

  # plot transformed parameter values versus batch number
  autoplot(instance, type = "parameter", trafo = TRUE)

  # plot parallel coordinates plot
  autoplot(instance, type = "parallel")

  # plot pairs
  autoplot(instance, type = "pairs")
}

Plots for Learner Predictions

Description

Visualizations for the mlr3::Prediction of a single mlr3::Learner on a single mlr3::Task.

  • For classification we support tasks with exactly two features and learners with predict_type set to "response" or "prob".

  • For regression we support tasks with one or two features. For tasks with one feature we print confidence bounds if the predict type of the learner was set to "se". For tasks with two features the predict type will be ignored.

Note that this function is a wrapper around autoplot.ResampleResult() for a temporary mlr3::ResampleResult using mlr3::mlr_resamplings_holdout with ratio 1 (all observations in the training set).

Usage

plot_learner_prediction(learner, task, grid_points = 100L, expand_range = 0)

Arguments

learner

(mlr3::Learner).

task

(mlr3::Task).

grid_points

(integer(1))
Resolution of the grid. For factors, ordered and logicals this value is ignored.

expand_range

(numeric(1))
Expand the prediction range for numerical features.

Value

ggplot2::ggplot().

Examples

if (requireNamespace("mlr3")) {
  library(mlr3)
  library(mlr3viz)

  task = mlr3::tsk("pima")$select(c("age", "glucose"))
  learner = lrn("classif.rpart", predict_type = "prob")
  p = plot_learner_prediction(learner, task)
  print(p)
}

Generates a data.table of evenly distributed points.

Description

For each point we have the predicted class / regression value in column response. If the learner predicts probabilities, a column ".prob.response" is added that contains the probability of the predicted class.

Usage

predict_grid(learners, task, grid_points, expand_range)

Arguments

learners

list of trained learners, each learner belongs to one resampling iteration

task

the task all learners are trained on

grid_points

(int): see sequenize

expand_range

see sequenize
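
A grid of evenly spaced points like the one described here can be sketched in base R with expand.grid(). The helper make_grid() is hypothetical, and the way expand_range widens each feature's range (as a fraction of the range added on both sides) is an assumption, not the documented behavior of the internal function.

```r
# Build an evenly spaced grid over all numeric columns of `data`.
make_grid = function(data, grid_points = 100L, expand_range = 0) {
  axes = lapply(data, function(x) {
    r = range(x)
    pad = expand_range * diff(r)  # assumed meaning of expand_range
    seq(r[1] - pad, r[2] + pad, length.out = grid_points)
  })
  # One row per combination of axis values.
  expand.grid(axes)
}

features = data.frame(age = c(21, 35, 50), glucose = c(80, 120, 160))
grid = make_grid(features, grid_points = 5L, expand_range = 0.1)
```

A trained learner would then predict on every row of grid to color the background of the prediction plot.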