Getting Started with ensembleML

Overview

ensembleML provides a single, consistent API for ensemble machine learning in R. Regardless of which algorithm you choose, the core workflow is always:

em_fit()  ->  em_predict()  ->  em_evaluate()

Advanced usage adds:

em_cv()          # k-fold cross-validation (stability estimates)
em_plot_cv()     # plot cross-validation results
em_tune()        # grid-search hyperparameter optimisation
em_compare()     # side-by-side algorithm comparison
em_importance()  # feature importance
em_partial()     # partial dependence plots
em_confusion()   # confusion matrix heatmap
em_calibration() # calibration / reliability diagram
em_residuals()   # regression diagnostics

1. Train a model

# Load the built-in iris data and split it into 120 training rows and
# 30 held-out test rows.
data(iris)
# Fix the RNG so the random split (and the output below) is reproducible.
set.seed(42)
idx   <- sample(nrow(iris), 120)
train <- iris[idx,  ]
test  <- iris[-idx, ]

# Fit a random forest. The task type is inferred from the response, and
# verbose = TRUE prints the fit summary shown below.
rf <- em_fit(Species ~ ., data = train, method = "random_forest",
             verbose = TRUE)
#> [ensembleML] task auto-detected as 'classification'
#> 
#> ╭────────────────────────────────────────────────────╮
#> │  Algorithm:           random_forest               │
#> │  Task:                classification              │
#> │  Response:            Species                     │
#> │  Classes:             setosa, versicolor, virginica│
#> │  Predictors:          4  (Sepal.Length, Sepal.Width, Petal.Length, …)│
#> │  Training n:          120                         │
#> │  Fit time:            0.025 sec                   │
#> │  Train metrics:       accuracy=1.0000  kappa=1.0000  precision=1.0000  recall=1.0000  f1=1.0000  auc=NA│
#> │  ⚠  Use em_evaluate() on held-out data         │
#> ╰────────────────────────────────────────────────────╯

Switching algorithms requires changing a single argument:

# Same formula and data for every model; only `method` changes.
xgb <- em_fit(Species ~ ., data = train, method = "xgboost")
ada <- em_fit(Species ~ ., data = train, method = "adaboost")
bag <- em_fit(Species ~ ., data = train, method = "bagging")

2. Predict

# Predicted class labels for the held-out rows, returned as a factor.
preds <- em_predict(rf, newdata = test)
head(preds)
#>      7     11     12     19     23     28 
#> setosa setosa setosa setosa setosa setosa 
#> Levels: setosa versicolor virginica

Class probabilities:

# type = "prob" returns one probability column per class level.
probs <- em_predict(rf, newdata = test, type = "prob")
head(probs, 3)
#>    setosa versicolor virginica
#> 7   1.000      0.000         0
#> 11  0.998      0.002         0
#> 12  1.000      0.000         0

3. Evaluate

# Score the model on the held-out test set; the result is a named
# vector with one entry per metric.
em_evaluate(rf, newdata = test)
#>  accuracy     kappa precision    recall        f1       auc 
#>    0.9333    0.8997    0.9364    0.9364    0.9364        NA

Select specific metrics:

# Only the requested metrics are returned, in the order they were listed.
em_evaluate(rf, newdata = test, metrics = c("accuracy", "f1", "kappa"))
#> accuracy       f1    kappa 
#>   0.9333   0.9364   0.8997

4. Cross-validation

Use em_cv() to estimate the mean ± SD of each metric across folds before committing to a method. Here 5-fold cross-validation is repeated 3 times:

# 5-fold cross-validation, repeated 3 times, on the full iris data.
cv_res <- em_cv(Species ~ ., data = iris, method = "random_forest",
                cv_folds = 5, repeats = 3)
# Fold-level results summarised per metric (mean and SD).
cv_res$summary
# Plot the cross-validated accuracy.
em_plot_cv(cv_res, metric = "accuracy")

5. Tune hyperparameters

# Candidate values for each hyperparameter; the grid is their
# combinations (3 x 3 = 9 settings).
grid <- list(ntree = c(100, 300, 500), mtry = c(1, 2, 3))

# Grid search scored with 5-fold cross-validation on the training data.
tuned <- em_tune(
  Species ~ ., data = train, method = "random_forest",
  param_grid = grid, cv_folds = 5
)

# Best setting found, its score, and the results for the whole grid.
tuned$best_params
tuned$best_score
tuned$results

6. Compare algorithms

# Side-by-side comparison of algorithms trained on `train` and scored on
# `test`. No `method` is given, so the package's default set of methods is
# used (TODO: confirm which methods that includes).
cmp <- em_compare(Species ~ ., train = train, test = test)
cmp$table

7. Feature importance

# Top 4 predictors by importance; iris has exactly 4 predictors, so all
# of them are shown.
em_importance(rf, top_n = 4)


8. Partial dependence

# Partial dependence of the model's predictions on Petal.Length, computed
# over the training data.
em_partial(rf, data = train, feature = "Petal.Length")

9. Confusion matrix

# Confusion-matrix heatmap on the test set: raw counts first, then the
# normalised version.
em_confusion(rf, newdata = test)
em_confusion(rf, newdata = test, normalise = TRUE)

10. Regression example

The same workflow applies unchanged to a numeric response; em_fit() detects the regression task automatically:

# Simulate a regression problem: y depends linearly on x1, while x2 is a
# pure-noise predictor. The predictors are created first and y is added
# afterwards, because arguments to a single data.frame() call cannot refer
# to sibling columns.
set.seed(7)
n_obs      <- 200
reg_data   <- data.frame(x1 = rnorm(n_obs), x2 = rnorm(n_obs))
reg_data$y <- 3 + 2 * reg_data$x1 + rnorm(n_obs)
# First 160 rows for training, last 40 held out for testing.
reg_train <- reg_data[1:160, ]
reg_test  <- reg_data[161:200, ]

reg_model <- em_fit(y ~ ., data = reg_train, method = "random_forest")
#> [ensembleML] task auto-detected as 'regression'
# Returns rmse, mae, mape, rsq and adj_rsq. Re-knit the vignette to refresh
# the printed values: with y now driven by x1, rsq should be clearly positive.
em_evaluate(reg_model, newdata = reg_test)
# Regression diagnostic plots for the held-out data.
em_residuals(reg_model, reg_test)
#> `geom_smooth()` using formula = 'y ~ x'


Citation

If you use ensembleML in published work, please cite it:

# Print the recommended citation(s) for the package.
citation("ensembleML")

The individual algorithms should also be cited — see citation("ensembleML") for the full list of references.


Session info

# Record the R version, platform and package versions used to build this
# document.
sessionInfo()
#> R version 4.6.0 (2026-04-24)
#> Platform: x86_64-pc-linux-gnu
#> Running under: Ubuntu 24.04.4 LTS
#> 
#> Matrix products: default
#> BLAS:   /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 
#> LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so;  LAPACK version 3.12.0
#> 
#> locale:
#>  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
#>  [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
#>  [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
#>  [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
#>  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
#> [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
#> 
#> time zone: Etc/UTC
#> tzcode source: system (glibc)
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] ensembleML_0.2.5 rmarkdown_2.31  
#> 
#> loaded via a namespace (and not attached):
#>  [1] Matrix_1.7-5         randomForest_4.7-1.2 gtable_0.3.6        
#>  [4] jsonlite_2.0.0       dplyr_1.2.1          compiler_4.6.0      
#>  [7] tidyselect_1.2.1     gridExtra_2.3        jquerylib_0.1.4     
#> [10] splines_4.6.0        scales_1.4.0         yaml_2.3.12         
#> [13] fastmap_1.2.0        lattice_0.22-9       ggplot2_4.0.3       
#> [16] R6_2.6.1             labeling_0.4.3       generics_0.1.4      
#> [19] knitr_1.51           tibble_3.3.1         maketools_1.3.2     
#> [22] bslib_0.11.0         pillar_1.11.1        RColorBrewer_1.1-3  
#> [25] rlang_1.2.0          cachem_1.1.0         xfun_0.58           
#> [28] sass_0.4.10          sys_3.4.3            S7_0.2.2            
#> [31] otel_0.2.0           cli_3.6.6            mgcv_1.9-4          
#> [34] withr_3.0.2          magrittr_2.0.5       digest_0.6.39       
#> [37] grid_4.6.0           nlme_3.1-169         lifecycle_1.0.5     
#> [40] vctrs_0.7.3          evaluate_1.0.5       glue_1.8.1          
#> [43] farver_2.1.2         buildtools_1.0.0     tools_4.6.0         
#> [46] pkgconfig_2.0.3      htmltools_0.5.9