Package: tabnet 0.9.0

Christophe Regouby

tabnet: Fit 'TabNet' Models for Classification and Regression

Implements the 'TabNet' model by Sercan O. Arik et al. (2019) <doi:10.48550/arXiv.1908.07442> with 'Coherent Hierarchical Multi-label Classification Networks' by Giunchiglia et al. <doi:10.48550/arXiv.2010.10151> and provides a consistent interface for fitting and creating predictions. It's also fully compatible with the 'tidymodels' ecosystem.

Authors: Daniel Falbel [aut], RStudio [cph], Christophe Regouby [cre, ctb], Egill Fridgeirsson [ctb], Philipp Haarmeyer [ctb], Sven Verweij [ctb]

# Install 'tabnet' in R:
install.packages('tabnet', repos = c('https://cran.r-universe.dev', 'https://cloud.r-project.org'))

Bug tracker: https://github.com/mlverse/tabnet/issues

Pkgdown/docs site: https://mlverse.github.io

5.42 score, 96 scripts, 4.6k downloads, 40 exports, 117 dependencies

Last updated from: 6f51ecf80f. Checks: 4 OK. Indexed: no.

Target                Result  Time/Files/Syslog
linux-devel-x86_64    OK      233
source / vignettes    OK      242
linux-release-x86_64  OK      229
wasm-release          OK      312

Exports: %>%, attention_width, augment, build_ancestor_matrix_from_outcomes, cat_emb_dim, check_compliant_node, checkpoint_epochs, decision_width, drop_last, encoder_activation, entmax, entmax15, feature_reusage, lr_scheduler, mask_type, mlp_activation, mlp_hidden_multiplier, momentum, nn_aum_loss, nn_mc_loss, nnf_mc_loss, nnf_multilabel_one_hot, node_to_df, num_independent, num_independent_decoder, num_shared, num_shared_decoder, num_steps, optimizer, penalty, sparsemax, sparsemax15, tabnet, tabnet_config, tabnet_explain, tabnet_fit, tabnet_nn, tabnet_pretrain, verbose, virtual_batch_size

Dependencies: base64enc, bit, bit64, bslib, cachem, callr, class, cli, clock, codetools, coro, cpp11, crayon, data.table, data.tree, desc, diagram, dials, DiceDesign, digest, dplyr, evaluate, farver, fastmap, fontawesome, fs, furrr, future, future.apply, GauPro, generics, ggplot2, globals, glue, gower, gtable, hardhat, highr, hms, htmltools, ipred, isoband, jquerylib, jsonlite, KernSmooth, knitr, labeling, lattice, lava, lbfgs, lifecycle, listenv, lubridate, magrittr, MASS, Matrix, memoise, mime, mixopt, modelenv, nnet, numDeriv, otel, parallelly, parsnip, pillar, pkgconfig, prettyunits, processx, prodlim, progress, progressr, ps, purrr, R6, rappdirs, RColorBrewer, Rcpp, RcppArmadillo, recipes, rlang, rmarkdown, rpart, rsample, S7, safetensors, sass, scales, sfd, shape, slider, sparsevctrs, splitfngr, SQUAREM, stringi, stringr, survival, tailor, tibble, tidyr, tidyselect, timechange, timeDate, tinytex, torch, tune, tzdb, utf8, vctrs, viridisLite, warp, withr, workflows, xfun, yaml, yardstick, zeallot

Fitting tabnet with tidymodels

Rendered from tidymodels-interface.Rmd using knitr::rmarkdown on Jun 12 2026.

Last update: 2026-01-31
Started: 2021-01-14

Hierarchical Classification

Rendered from Hierarchical_classification.Rmd using knitr::rmarkdown on Jun 12 2026.

Last update: 2026-06-12
Started: 2023-12-06

Interpretation tools

Rendered from interpretation.Rmd using knitr::rmarkdown on Jun 12 2026.

Last update: 2026-01-31
Started: 2021-01-14

Self-supervised training and fine-tuning

Rendered from selfsupervised_training.Rmd using knitr::rmarkdown on Jun 12 2026.

Last update: 2026-06-12
Started: 2023-12-06

Training a Tabnet model from missing-values dataset

Rendered from Missing_data_predictors.Rmd using knitr::rmarkdown on Jun 12 2026.

Last update: 2025-04-17
Started: 2023-05-11

Using ROC AUM loss for imbalanced binary classification

Rendered from aum_loss.Rmd using knitr::rmarkdown on Jun 12 2026.

Last update: 2026-01-31
Started: 2026-01-31
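The vignette above concerns the ROC AUM (Area Under Min(FP, FN)) loss. The following minimal base-R sketch illustrates what that quantity measures. A constant c is added to every score, and a sample counts as predicted positive when score + c >= 0. Sweeping c traces the ROC curve, and AUM is the area under the smaller of the false-positive and false-negative counts. This is an illustration only. It is not `nn_aum_loss()`, which is a torch module whose normalization and weighting may differ. The function name `aum_sketch()` is hypothetical, and raw counts (not rates) are an assumption.

```r
# Hypothetical helper: un-normalized AUM computed from raw FP/FN counts.
# Assumption: labels are 0/1; a sample is predicted positive when score + c >= 0.
aum_sketch <- function(scores, labels) {
  stopifnot(length(scores) == length(labels), all(labels %in% c(0, 1)))
  n <- length(scores)
  if (n < 2) return(0)
  thresholds <- -scores                  # value of c at which each sample flips to positive
  ord <- order(thresholds)
  thr <- thresholds[ord]
  is_pos <- labels[ord] == 1
  # For c between thr[k] and thr[k + 1], the first k samples are predicted positive.
  fp <- cumsum(!is_pos)[-n]              # negatives already predicted positive
  fn <- sum(is_pos) - cumsum(is_pos)[-n] # positives still predicted negative
  widths <- diff(thr)                    # length of each interval where FP/FN are constant
  # Outside [thr[1], thr[n]] either FP or FN is 0, so only inner intervals contribute.
  sum(pmin(fp, fn) * widths)
}

aum_sketch(c(-1, -2, 1, 2), c(0, 0, 1, 1))  # perfectly ranked scores
aum_sketch(c(1, -1), c(0, 1))               # one mis-ranked pair
```

Perfectly ranked scores give an AUM of 0. A mis-ranked pair contributes an area that grows with the gap between its scores, which is what makes the quantity usable as a ranking-oriented loss for imbalanced data.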

Readme and manuals

Help Manual

Help page: Topics
Parameters for the tabnet model: attention_width, decision_width, feature_reusage, mask_type, momentum, num_independent, num_shared, num_steps
Plot tabnet_explain mask importance heatmap: autoplot.tabnet_explain
Plot tabnet_fit model loss along epochs: autoplot.tabnet_fit, autoplot.tabnet_pretrain
Build ancestor matrix aligned with observed outcome classes: build_ancestor_matrix_from_outcomes
Non-tunable parameters for the tabnet model: cat_emb_dim, checkpoint_epochs, drop_last, encoder_activation, lr_scheduler, mlp_activation, mlp_hidden_multiplier, num_independent_decoder, num_shared_decoder, optimizer, penalty, verbose, virtual_batch_size
Check that Node object names are compliant: check_compliant_node
Alpha-entmax: entmax, entmax15
Apply hierarchy constraints via max-pooling over descendants (MCM): get_constr_output
Optimal threshold (tau) computation for 1.5-entmax: get_tau
AUM loss: nn_aum_loss
Max-Constraint Margin Loss (module): nn_mc_loss
Prune top layer(s) of a tabnet network: nn_prune_head.tabnet_fit, nn_prune_head.tabnet_pretrain
Max-Constraint Margin Loss (functional): nnf_mc_loss
Convert class_id tensor to binary one-hot tensor: nnf_multilabel_one_hot
Turn a Node object into predictor and outcome: node_to_df
Predict using 'tabnet': augment.tabnet_fit, predict.tabnet_fit
Sparsemax: sparsemax, sparsemax15
Parsnip compatible tabnet model: tabnet
Configuration for TabNet models: tabnet_config
Interpretation metrics from a TabNet model: tabnet_explain, tabnet_explain.default, tabnet_explain.model_fit, tabnet_explain.tabnet_fit, tabnet_explain.tabnet_pretrain
Tabnet model: tabnet_fit, tabnet_fit.data.frame, tabnet_fit.default, tabnet_fit.formula, tabnet_fit.Node, tabnet_fit.recipe
TabNet Model Architecture: tabnet_nn
Tabnet model: tabnet_pretrain, tabnet_pretrain.data.frame, tabnet_pretrain.default, tabnet_pretrain.formula, tabnet_pretrain.Node, tabnet_pretrain.recipe
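The help topics list `sparsemax` and `sparsemax15`, which are sparse alternatives to softmax. The following minimal base-R sketch shows the sort-based sparsemax transform on a single numeric vector. It illustrates the technique only and is not the package's torch implementation. The function name `sparsemax_sketch()` is hypothetical. The sketch projects the scores onto the probability simplex, so low scores receive exactly zero weight.

```r
# Hypothetical helper: sparsemax of one numeric vector (not the package's torch version).
sparsemax_sketch <- function(z) {
  z_sorted <- sort(z, decreasing = TRUE)
  k <- seq_along(z_sorted)
  cum_z <- cumsum(z_sorted)
  # Index k belongs to the support while 1 + k * z_(k) exceeds the sum of the k largest scores.
  in_support <- 1 + k * z_sorted > cum_z
  k_max <- max(k[in_support])            # k = 1 always qualifies
  # Threshold tau makes the retained (shifted) scores sum to 1.
  tau <- (cum_z[k_max] - 1) / k_max
  pmax(z - tau, 0)
}

sparsemax_sketch(c(1, 0.5, -1))  # the lowest score gets exactly zero weight
```

With these inputs, tau is 0.25, giving weights 0.75, 0.25 and 0. A softmax would instead assign a small positive weight to every entry.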
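The `entmax15` and `get_tau` topics refer to 1.5-entmax. This transform maps scores z to p_i = max(z_i / 2 - tau, 0)^2, where the threshold tau is chosen so that the p_i sum to 1. The following minimal base-R sketch computes tau exactly by sorting the halved scores. For each candidate support size k, it solves the quadratic sum((x_i - tau)^2) = 1 and keeps the largest k whose tau stays below the k-th largest score. This mirrors the approach of common entmax implementations but is an assumption about, not a copy of, the package's `get_tau()`. The name `entmax15_sketch()` is hypothetical.

```r
# Hypothetical helper: exact sort-based 1.5-entmax of one numeric vector.
entmax15_sketch <- function(z) {
  x <- z / 2                              # alpha = 1.5, so (alpha - 1) * z = z / 2
  xs <- sort(x, decreasing = TRUE)
  k <- seq_along(xs)
  mean_k <- cumsum(xs) / k                # mean of the k largest entries
  mean_sq_k <- cumsum(xs^2) / k           # mean of their squares
  ss_k <- k * (mean_sq_k - mean_k^2)      # sum of squared deviations
  delta_k <- (1 - ss_k) / k
  # Smaller root of k * tau^2 - 2 * tau * sum(x) + sum(x^2) - 1 = 0
  tau_k <- mean_k - sqrt(pmax(delta_k, 0))
  support_size <- sum(tau_k <= xs)        # largest consistent support
  tau <- tau_k[support_size]
  pmax(x - tau, 0)^2
}

entmax15_sketch(c(0, 0))   # equal scores: uniform weights
entmax15_sketch(c(10, 0))  # large gap: all weight on the first entry
```

For equal scores, the result is uniform (0.5, 0.5), as softmax would give. For a large gap, the result is exactly sparse (1, 0), as sparsemax would give. This illustrates how 1.5-entmax sits between the two.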
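The `get_constr_output` topic applies the hierarchy constraint from Coherent Hierarchical Multi-label Classification Networks by max-pooling over descendants. Each class's output becomes the maximum raw score among that class and its descendants, so a parent class never scores below any of its children. The following minimal base-R sketch assumes a 0/1 matrix `R` in which `R[i, j] = 1` when class j is class i itself or one of its descendants. That layout is an assumption and may not match what `build_ancestor_matrix_from_outcomes()` returns. The name `mcm_sketch()` and the example classes are hypothetical.

```r
# Hypothetical helper: max-constraint output over descendants.
# h: n x C matrix of raw per-class scores (e.g. sigmoid outputs).
# R: C x C 0/1 matrix, R[i, j] = 1 if class j is class i or one of its descendants.
mcm_sketch <- function(h, R) {
  out <- matrix(0, nrow(h), ncol(h), dimnames = dimnames(h))
  for (i in seq_len(ncol(h))) {
    desc <- which(R[i, ] == 1)                              # class i and its descendants
    out[, i] <- apply(h[, desc, drop = FALSE], 1, max)      # max-pool per observation
  }
  out
}

classes <- c("animal", "dog", "cat")
R <- matrix(c(1, 1, 1,   # animal: itself, dog, cat
              0, 1, 0,   # dog: itself only
              0, 0, 1),  # cat: itself only
            nrow = 3, byrow = TRUE, dimnames = list(classes, classes))
h <- matrix(c(0.2, 0.7, 0.1), nrow = 1, dimnames = list(NULL, classes))
mcm_sketch(h, R)
```

In this example, the raw score for "animal" (0.2) is lower than that of its child "dog" (0.7). After the constraint, "animal" is raised to 0.7, so predictions respect the hierarchy.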