| Title: | Random Forest-Based Multistate Survival Analysis |
|---|---|
| Description: | Fits cause-specific random survival forests for flexible multistate survival analysis with covariate-adjusted transition probabilities computed via product-integral. State transitions are modeled by random forests. Subject-specific transition probability matrices are assembled from predicted cumulative hazards using the product-integral formula. Also provides a standalone Aalen-Johansen nonparametric estimator as a covariate-free baseline. Supports arbitrary state spaces with any number of states (three or more) and any set of allowed transitions, applicable to clinical trials, disease progression, reliability engineering, and other domains where subjects move among discrete states over time. Provides per-transition feature importance, bias-variance diagnostics, and comprehensive visualizations. Handles right censoring and competing transitions. Methods are described in Ishwaran et al. (2008) <doi:10.1214/08-AOAS169> for random survival forests, Putter et al. (2007) <doi:10.1002/sim.2712> for multistate competing risks decomposition, and Aalen and Johansen (1978) <https://www.jstor.org/stable/4615704> for the nonparametric estimator. |
| Authors: | Yiqing Chen [aut, cre] |
| Maintainer: | Yiqing Chen <[email protected]> |
| License: | MIT + file LICENSE |
| Version: | 0.1.2 |
| Built: | 2026-05-10 08:58:01 UTC |
| Source: | https://github.com/cran/RFmstate |
Fits cause-specific random survival forests for flexible multistate survival analysis with covariate-adjusted transition probabilities computed via product-integral. For each transient state, competing transitions are modeled by separate random forests, and patient-specific transition probability matrices are assembled from the predicted cumulative hazards using the product-integral formula. Also provides a standalone Aalen-Johansen nonparametric estimator as a covariate-free baseline. Supports arbitrary state spaces with any number of states (three or more) and any set of allowed transitions, applicable to clinical trials, disease progression, reliability engineering, and other domains where subjects move among discrete states over time. The package provides:
- State space and transition structure definition
- Wide-to-long data conversion for multistate counting processes
- Cause-specific random forest fitting per origin state
- Transition probability matrices via product-integral of predicted cumulative hazards
- Aalen-Johansen nonparametric estimation (covariate-free baseline)
- Per-transition feature importance
- Bias-variance diagnostics with Brier score and C-index
- Comprehensive visualizations
Computes nonparametric estimates of transition probabilities using the Aalen-Johansen estimator via Nelson-Aalen cumulative hazard increments and product-integral construction.
aalen_johansen(msdata, s = 0)
| Argument | Description |
|---|---|
| msdata | An |
| s | Numeric, the starting time for transition probabilities (default 0). |
The Aalen-Johansen estimator generalizes the Kaplan-Meier estimator to multistate models under the Markov assumption and independent right censoring. It provides population-level transition probability matrices without covariate adjustment.
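As an illustration of the Kaplan-Meier connection, the following base R sketch (independent of this package, using made-up toy data) checks that in a two-state model (alive, dead) the product of the matrices I + dA built from Nelson-Aalen increments reproduces the Kaplan-Meier survival estimate. All variable names here are hypothetical.

```r
# Toy two-state data (hypothetical): follow-up time and status
# (1 = death observed, 0 = right censored)
time   <- c(2, 3, 3, 5, 7, 8, 8, 10)
status <- c(1, 1, 0, 1, 0, 1, 1, 0)

event_times <- sort(unique(time[status == 1]))

# Numbers at risk and numbers of events at each event time
n_risk  <- sapply(event_times, function(tt) sum(time >= tt))
n_event <- sapply(event_times, function(tt) sum(time == tt & status == 1))

# Kaplan-Meier: running product of (1 - d_k / n_k)
km <- cumprod(1 - n_event / n_risk)

# Aalen-Johansen: running product of (I + dA_k), where dA_k holds the
# Nelson-Aalen increments (state 1 = alive, state 2 = dead, absorbing)
P <- diag(2)
aj_alive <- numeric(length(event_times))
for (k in seq_along(event_times)) {
  dA <- matrix(0, 2, 2)
  dA[1, 2] <- n_event[k] / n_risk[k]
  dA[1, 1] <- -dA[1, 2]          # rows of dA sum to zero
  P <- P %*% (diag(2) + dA)
  aj_alive[k] <- P[1, 1]         # probability of still being alive
}

print(cbind(event_times, km, aj_alive))
```

With more than two states the same product is taken over larger matrices, one off-diagonal entry per allowed transition.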
An object of class "aj_estimate" containing:
Numeric vector of unique event times.
List of transition probability matrices P(s,t) at each event time.
Matrix of state occupation probabilities over time. Rows are time points, columns are states.
List of Nelson-Aalen cumulative hazard matrices.
List of hazard increment matrices at each event time.
List of Greenwood-type variance estimates for state occupation probabilities.
Matrix of at-risk counts over time.
Data frame of event counts per transition.
The multistate structure used.
The starting time.
ms <- clinical_states()
set.seed(42)
dat <- sim_clinical_data(n = 200, structure = ms)
msdata <- prepare_data(
  data = dat, id = "ID", structure = ms,
  time_map = list(
    Responded = "time_Responded",
    Unresponded = "time_Unresponded",
    Stabilized = "time_Stabilized",
    Progressed = "time_Progressed",
    Death = "time_Death"
  ),
  censor_col = "time_censored",
  covariates = c("age", "sex", "BMI", "treatment")
)
aj <- aalen_johansen(msdata)
print(aj)
A convenience function that creates the standard clinical trial multistate structure with states: Baseline, Responded, Unresponded, Stabilized, Progressed, Death.
clinical_states()
An mstate_structure object.
ms <- clinical_states()
print(ms)
Given cause-specific cumulative hazard functions for all transitions, computes the full transition probability matrix P(s,t) using the product-integral formula.
compute_trans_prob(cum_hazards, structure, s = 0, times = NULL)
| Argument | Description |
|---|---|
| cum_hazards | A list of cumulative hazard data frames, one per transition. Each should have columns |
| structure | An |
| s | Numeric, starting time (default 0). |
| times | Numeric vector of times at which to evaluate P(s,t). If |
An object of class "trans_prob" containing:
Evaluation times.
List of transition probability matrices at each time.
Matrix of state occupation probabilities.
The multistate structure.
Starting time.
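The product-integral step can be sketched in base R as follows. This is a minimal illustration under stated assumptions, not the package's implementation: it uses a three-state illness-death layout, hypothetical cumulative hazard values on a shared time grid, and plain numeric vectors rather than the data frames that `compute_trans_prob()` expects.

```r
# Shared evaluation grid and hypothetical cumulative hazards for the
# transitions 1->2, 1->3 and 2->3 (state 3 is absorbing)
times <- c(1, 2, 3, 4)
A12 <- c(0.10, 0.25, 0.30, 0.45)
A13 <- c(0.05, 0.05, 0.15, 0.20)
A23 <- c(0.00, 0.20, 0.35, 0.60)

# Hazard increments between consecutive grid points
increments <- function(A) diff(c(0, A))
dA12 <- increments(A12)
dA13 <- increments(A13)
dA23 <- increments(A23)

# P(0, t_k) = product over j <= k of (I + dA_j)
P <- diag(3)
P_list <- vector("list", length(times))
for (k in seq_along(times)) {
  dA <- matrix(0, 3, 3)
  dA[1, 2] <- dA12[k]
  dA[1, 3] <- dA13[k]
  dA[2, 3] <- dA23[k]
  diag(dA) <- -rowSums(dA)       # each row of dA sums to zero
  P <- P %*% (diag(3) + dA)
  P_list[[k]] <- P
}

# State occupation probabilities for a subject starting in state 1
occupation <- t(sapply(P_list, function(M) M[1, ]))
rownames(occupation) <- times
print(round(occupation, 4))
```

Each matrix in `P_list` has rows summing to one, and the row for the absorbing state stays at (0, 0, 1).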
Defines the state space, absorbing states, and allowed transitions for a multistate model.
define_multistate(state_names, absorbing, transitions)
| Argument | Description |
|---|---|
| state_names | Character vector of state names. |
| absorbing | Character vector of absorbing state names (must be a subset of |
| transitions | A named list where each element name is an origin state and the value is a character vector of destination states reachable from that origin. Absorbing states should not appear as list names. |
An object of class "mstate_structure" containing:
Character vector of all state names.
Integer, number of states.
Character vector of absorbing states.
Character vector of transient (non-absorbing) states.
Named list of allowed transitions.
Integer matrix where entry [i,j] is the
transition number for allowed transition i->j, or NA if not
allowed.
Total number of allowed transitions.
Data frame listing all transitions with columns
trans_id, from, to.
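To make the transition matrix component concrete, here is a base R sketch that numbers the allowed transitions of a small structure and stores them in an integer matrix with NA for disallowed moves. The numbering order (list order, row by row) is an assumption for illustration and may differ from what `define_multistate()` produces.

```r
state_names <- c("Baseline", "Responded", "Progressed", "Death")
transitions <- list(
  Baseline   = c("Responded", "Progressed", "Death"),
  Responded  = c("Progressed", "Death"),
  Progressed = "Death"
)

# Integer matrix: entry [from, to] is the transition number, NA if not allowed
tmat <- matrix(NA_integer_, nrow = 4, ncol = 4,
               dimnames = list(from = state_names, to = state_names))
trans_id <- 0L
for (from in names(transitions)) {
  for (to in transitions[[from]]) {
    trans_id <- trans_id + 1L
    tmat[from, to] <- trans_id
  }
}

# States with no outgoing transitions are absorbing
absorbing <- setdiff(state_names, names(transitions))

print(tmat)
print(absorbing)
```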
ms <- define_multistate(
  state_names = c("Baseline", "Responded", "Progressed", "Death"),
  absorbing = "Death",
  transitions = list(
    Baseline = c("Responded", "Progressed", "Death"),
    Responded = c("Progressed", "Death"),
    Progressed = c("Death")
  )
)
print(ms)
Computes diagnostic measures including OOB-based prediction error, Brier score, concordance index, and bias-variance decomposition for each transition-specific model.
diagnose(object, ...)

## S3 method for class 'rfmstate'
diagnose(object, eval_times = NULL, ...)
| Argument | Description |
|---|---|
| object | A fitted |
| ... | Ignored. |
| eval_times | Numeric vector of times at which to evaluate diagnostics. If |
The bias-variance decomposition uses OOB predictions from the random forest ensemble. For each transition:
Bias: systematic difference between predicted and observed survival.
Variance: variability of predictions across trees (estimated from tree-level OOB predictions when available).
Brier score: integrated prediction error combining bias and variance.
C-index: concordance between predicted risk and observed event ordering.
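The quantities listed above can be illustrated with a small base R sketch. It is a simplified stand-in, not the code behind `diagnose()`: the concordance index is computed Harrell-style over comparable pairs, and the bias-variance split assumes fully observed outcomes at a single time point with no censoring weights. All data and names are hypothetical.

```r
# --- Concordance index on toy data ---
time   <- c(5, 8, 12, 3, 9, 15)
status <- c(1, 1, 0, 1, 0, 1)              # 1 = event, 0 = censored
risk   <- c(0.8, 0.5, 0.3, 0.9, 0.6, 0.1)  # hypothetical predicted risk

concordant <- 0
comparable <- 0
n <- length(time)
for (i in seq_len(n)) {
  for (j in seq_len(n)) {
    # A pair is comparable when the earlier time is an observed event
    if (status[i] == 1 && time[i] < time[j]) {
      comparable <- comparable + 1
      if (risk[i] > risk[j]) {
        concordant <- concordant + 1
      } else if (risk[i] == risk[j]) {
        concordant <- concordant + 0.5
      }
    }
  }
}
c_index <- concordant / comparable

# --- Bias-variance split of squared error at one time point ---
set.seed(1)
n_sub  <- 6
n_tree <- 50
observed  <- c(1, 1, 0, 1, 0, 0)   # 1 = event-free at the time point
tree_pred <- matrix(runif(n_sub * n_tree), n_sub, n_tree)  # per-tree predictions

ensemble <- rowMeans(tree_pred)
bias_sq  <- (ensemble - observed)^2             # squared bias of the ensemble
variance <- rowMeans((tree_pred - ensemble)^2)  # spread across trees
mse_tree <- rowMeans((tree_pred - observed)^2)  # average per-tree squared error
brier    <- mean(bias_sq)                       # Brier score of the ensemble

print(c(c_index = c_index, brier = brier))
```

For every subject the average per-tree squared error equals the squared bias plus the across-tree variance, which is the identity a decomposition of this kind relies on.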
An object of class "rfmstate_diag" containing:
Data frame of OOB prediction errors per transition.
List of time-dependent Brier score components per transition.
Data frame of concordance indices per transition.
Data frame of bias-variance decomposition per transition.
Evaluation times used.
ms <- clinical_states()
set.seed(42)
dat <- sim_clinical_data(n = 200, structure = ms)
msdata <- prepare_data(
  data = dat, id = "ID", structure = ms,
  time_map = list(
    Responded = "time_Responded",
    Unresponded = "time_Unresponded",
    Stabilized = "time_Stabilized",
    Progressed = "time_Progressed",
    Death = "time_Death"
  ),
  censor_col = "time_censored",
  covariates = c("age", "sex", "BMI", "treatment")
)
fit <- rfmstate(msdata, covariates = c("age", "sex", "BMI", "treatment"),
                num.trees = 100)
diag <- diagnose(fit)
print(diag)
Extracts and organizes variable importance scores from the fitted random forest models for each transition.
importance(object, ...)
| Argument | Description |
|---|---|
| object | A fitted |
| ... | Ignored. |
An object of class "rfmstate_importance" containing:
Data frame with columns variable,
from, to, importance.
Matrix with variables as rows and transitions as columns.
Covariate names.
Character vector of transition labels.
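The relationship between the long data frame and the variables-by-transitions matrix can be sketched in base R. The importance values and the " -> " label format below are made up for illustration; only the column names `variable`, `from`, `to`, and `importance` come from the description above.

```r
# Hypothetical long-format importance scores
imp_long <- data.frame(
  variable   = rep(c("age", "sex"), times = 2),
  from       = rep("Baseline", 4),
  to         = c("Responded", "Responded", "Death", "Death"),
  importance = c(0.04, 0.01, 0.02, 0.03)
)

# One label per transition, then pivot to variables x transitions
imp_long$transition <- paste(imp_long$from, imp_long$to, sep = " -> ")
imp_mat <- tapply(imp_long$importance,
                  list(imp_long$variable, imp_long$transition),
                  sum)

# Most important variable for each transition
top_var <- apply(imp_mat, 2, function(col) names(which.max(col)))

print(imp_mat)
print(top_var)
```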
ms <- clinical_states()
set.seed(42)
dat <- sim_clinical_data(n = 200, structure = ms)
msdata <- prepare_data(
  data = dat, id = "ID", structure = ms,
  time_map = list(
    Responded = "time_Responded",
    Unresponded = "time_Unresponded",
    Stabilized = "time_Stabilized",
    Progressed = "time_Progressed",
    Death = "time_Death"
  ),
  censor_col = "time_censored",
  covariates = c("age", "sex", "BMI", "treatment")
)
fit <- rfmstate(msdata, covariates = c("age", "sex", "BMI", "treatment"),
                num.trees = 100)
imp <- importance(fit)
print(imp)
Draws a state transition diagram with event counts annotated on edges. Uses a layered layout that adapts to any number of states and automatically routes arrows around intermediate state boxes using Bezier curves when needed.
plot_transition_diagram(
  structure,
  msdata = NULL,
  col = NULL,
  main = "Transition Diagram",
  ...
)
| Argument | Description |
|---|---|
| structure | An |
| msdata | Optional |
| col | Node colors. Default uses the standard palette. |
| main | Title. |
| ... | Ignored. |
No return value, called for its side effect of producing a plot.
ms <- clinical_states()
plot_transition_diagram(ms)
Visualizes state occupation probabilities and transition probabilities from the Aalen-Johansen estimator.
## S3 method for class 'aj_estimate'
plot(
  x,
  type = c("state_occupation", "stacked_transition_prob",
           "cumulative_hazard", "transition_intensity"),
  states = NULL,
  ci = TRUE,
  col = NULL,
  main = NULL,
  ...
)
| Argument | Description |
|---|---|
| x | An |
| type | Character, one of |
| states | Character vector of states to plot (default: all). For |
| ci | Logical, whether to show confidence intervals (default |
| col | Colors for each state/transition. If |
| main | Title (default: auto-generated). |
| ... | Additional arguments passed to |
The input x object, returned invisibly. Called for its
side effect of producing a plot.
Visualizes diagnostic measures including Brier score curves, concordance indices, and bias-variance decomposition.
## S3 method for class 'rfmstate_diag'
plot(
  x,
  type = c("brier", "concordance", "bias_variance"),
  col = NULL,
  main = NULL,
  ...
)
| Argument | Description |
|---|---|
| x | An |
| type | Character, one of |
| col | Colors. |
| main | Title. |
| ... | Additional arguments. |
The input x object, returned invisibly. Called for its
side effect of producing a plot.
Visualizes per-transition feature importance as a grouped barplot or heatmap.
## S3 method for class 'rfmstate_importance'
plot(x, type = c("barplot", "heatmap"), col = NULL, main = NULL, ...)
| Argument | Description |
|---|---|
| x | An |
| type | Character, one of |
| col | Colors. |
| main | Title. |
| ... | Additional arguments. |
The input x object, returned invisibly. Called for its
side effect of producing a plot.
Visualizes predicted state occupation probabilities and transition probabilities for individual patients.
## S3 method for class 'rfmstate_pred'
plot(
  x,
  type = c("state_occupation", "transition_prob"),
  subject = 1L,
  col = NULL,
  main = NULL,
  ...
)
| Argument | Description |
|---|---|
| x | An |
| type | Character, one of |
| subject | Integer, which subject to plot (default 1). Use 0 for mean across all subjects. |
| col | Colors. If |
| main | Title. |
| ... | Additional arguments passed to |
The input x object, returned invisibly. Called for its
side effect of producing a plot.
Predicts patient-specific transition probability matrices and state occupation probabilities using fitted random forest multistate models.
## S3 method for class 'rfmstate'
predict(object, newdata = NULL, times = NULL, s = 0, ...)
| Argument | Description |
|---|---|
| object | A fitted |
| newdata | A data frame with the same covariates used in fitting. If |
| times | Numeric vector of times at which to compute transition probabilities. If |
| s | Numeric, starting time (default 0). |
| ... | Ignored. |
An object of class "rfmstate_pred" containing:
Evaluation times.
Array of transition probability matrices (n_subjects x n_states x n_states x n_times).
Array of state occupation probabilities (n_subjects x n_states x n_times).
List of per-subject cumulative hazard matrices.
The multistate structure.
The prediction data.
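The array layouts described above can be explored with a base R sketch. It builds a hypothetical 4-dimensional array with the stated dimension order (subjects x states x states x times) and derives state occupation probabilities by taking the row of the starting state, which assumes every subject starts in the first state. The numbers and dimnames are invented and this is not the output of `predict()`.

```r
states <- c("Baseline", "Progressed", "Death")
times  <- c(30, 90)
n_sub  <- 2

# Hypothetical one-step transition matrices, one per subject
step <- list(
  matrix(c(0.80, 0.15, 0.05,
           0.00, 0.90, 0.10,
           0.00, 0.00, 1.00), nrow = 3, byrow = TRUE),
  matrix(c(0.60, 0.25, 0.15,
           0.00, 0.80, 0.20,
           0.00, 0.00, 1.00), nrow = 3, byrow = TRUE)
)

# Array of transition probability matrices: subject x from x to x time
P <- array(0, dim = c(n_sub, 3, 3, 2),
           dimnames = list(NULL, states, states, as.character(times)))
for (i in seq_len(n_sub)) {
  P[i, , , 1] <- step[[i]]
  P[i, , , 2] <- step[[i]] %*% step[[i]]
}

# State occupation for subjects starting in "Baseline": subject x state x time
occ <- P[, "Baseline", , ]

# Mean across subjects: state x time
mean_occ <- apply(occ, c(2, 3), mean)
print(round(mean_occ, 4))
```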
ms <- clinical_states()
set.seed(42)
dat <- sim_clinical_data(n = 200, structure = ms)
msdata <- prepare_data(
  data = dat, id = "ID", structure = ms,
  time_map = list(
    Responded = "time_Responded",
    Unresponded = "time_Unresponded",
    Stabilized = "time_Stabilized",
    Progressed = "time_Progressed",
    Death = "time_Death"
  ),
  censor_col = "time_censored",
  covariates = c("age", "sex", "BMI", "treatment")
)
fit <- rfmstate(msdata, covariates = c("age", "sex", "BMI", "treatment"),
                num.trees = 100)
newpat <- data.frame(age = c(50, 70), sex = c(0, 1),
                     BMI = c(25, 30), treatment = c(1, 0))
pred <- predict(fit, newdata = newpat, times = c(30, 90, 180, 365))
Converts wide-format clinical data into long counting-process format suitable for multistate survival analysis.
prepare_data(
  data,
  id,
  structure,
  time_map,
  censor_col,
  covariates,
  initial_state = NULL
)
| Argument | Description |
|---|---|
| data | A data frame in wide format with one row per patient. |
| id | Character string, name of the patient ID column. |
| structure | An |
| time_map | A named list mapping state names to column names in |
| censor_col | Character string, name of the column containing the right censoring time (last follow-up time). |
| covariates | Character vector of covariate column names to carry into the long-format data. |
| initial_state | Character string, the starting state for all patients (default: first state in the structure). |
Each patient's trajectory is reconstructed from event times, validated against the allowed transitions, and expanded into start-stop intervals with covariates.
An object of class "msdata" (a data frame) with columns:
Patient identifier.
Origin state for this interval.
Destination state (or NA if censored).
Start time of the interval.
End time of the interval.
1 if a transition occurred, 0 if censored.
Integer transition ID (from structure) or NA.
Duration of the interval.
Covariate columns.
The object also carries an attribute "structure" (the
mstate_structure).
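The wide-to-long expansion can be sketched for a single patient in base R. The helper `wide_to_long_one()` and its column names (other than `Tstart` and `Tstop`) are hypothetical, and the sketch skips the validation against allowed transitions and the covariate columns that `prepare_data()` handles.

```r
# Hypothetical helper: expand one patient's state entry times into
# start-stop intervals, adding a censored interval if no absorbing
# state was reached
wide_to_long_one <- function(id, entry_times, censor_time,
                             initial_state, absorbing) {
  visited <- sort(entry_times[!is.na(entry_times)])  # states in time order
  from  <- initial_state
  start <- 0
  rows  <- list()
  for (k in seq_along(visited)) {
    to <- names(visited)[k]
    rows[[length(rows) + 1]] <- data.frame(
      id = id, from = from, to = to,
      Tstart = start, Tstop = visited[[k]], status = 1L
    )
    from  <- to
    start <- visited[[k]]
  }
  if (!(from %in% absorbing)) {
    rows[[length(rows) + 1]] <- data.frame(
      id = id, from = from, to = NA_character_,
      Tstart = start, Tstop = censor_time, status = 0L
    )
  }
  out <- do.call(rbind, rows)
  out$duration <- out$Tstop - out$Tstart
  out
}

long_one <- wide_to_long_one(
  id = 1,
  entry_times = c(Responded = 40, Progressed = 120, Death = NA),
  censor_time = 200,
  initial_state = "Baseline",
  absorbing = "Death"
)
print(long_one)
```

This patient contributes two observed transitions and one censored interval in the Progressed state.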
ms <- clinical_states()
set.seed(42)
dat <- sim_clinical_data(n = 50, structure = ms)
msdata <- prepare_data(
  data = dat, id = "ID", structure = ms,
  time_map = list(
    Responded = "time_Responded",
    Unresponded = "time_Unresponded",
    Stabilized = "time_Stabilized",
    Progressed = "time_Progressed",
    Death = "time_Death"
  ),
  censor_col = "time_censored",
  covariates = c("age", "sex", "BMI", "treatment")
)
head(msdata)
Fits transition-specific cause-specific random survival forests for multistate survival analysis. For each transient origin state, a competing risks model is fit using random forests, where the competing events are the possible transitions to destination states.
rfmstate(
  msdata,
  covariates = NULL,
  num.trees = 1000L,
  mtry = NULL,
  min.node.size = 15L,
  importance = "permutation",
  seed = NULL,
  ...
)
| Argument | Description |
|---|---|
| msdata | An |
| covariates | Character vector of covariate column names to use as predictors. If |
| num.trees | Integer, number of trees per forest (default 1000). |
| mtry | Integer, number of variables to try at each split. Default |
| min.node.size | Integer, minimum node size (default 15). |
| importance | Character, variable importance mode. One of |
| seed | Integer, random seed for reproducibility (default |
| ... | Additional arguments passed to |
For each transient state, the method:
1. Subsets all intervals in which the patient is in that state.
2. Defines time as the duration in the state (Tstop - Tstart).
3. Codes competing events: 0 = censored; 1, 2, ... for each possible destination state.
4. Fits a cause-specific random survival forest using ranger with survival tree type.
Transition probabilities are then computed by combining per-origin-state predicted cumulative hazards via the product-integral formula.
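The data preparation part of this procedure (subsetting, durations, and event coding) can be sketched in base R on a toy long-format table. The column names other than `Tstart` and `Tstop`, and the cause-specific indicator at the end, are assumptions for illustration. The forest fit itself is omitted.

```r
# Toy long-format intervals (hypothetical)
long <- data.frame(
  id     = c(1, 1, 2, 3, 3, 4),
  from   = c("Baseline", "Responded", "Baseline",
             "Baseline", "Responded", "Baseline"),
  to     = c("Responded", "Death", "Death", "Responded", NA, NA),
  Tstart = c(0, 30, 0, 0, 50, 0),
  Tstop  = c(30, 90, 60, 50, 200, 365),
  status = c(1, 1, 1, 1, 0, 0)
)

origin <- "Baseline"
dests  <- c("Responded", "Death")   # allowed destinations from the origin

# Subset the intervals spent in the origin state
sub <- long[long$from == origin, ]

# Time is the duration of the stay in the origin state
sub$time <- sub$Tstop - sub$Tstart

# Competing event codes: 0 = censored, k = k-th destination
sub$event <- ifelse(sub$status == 1, match(sub$to, dests), 0L)

# A cause-specific analysis of one destination treats the other
# destinations as censored
sub$event_Responded <- as.integer(sub$event == 1)

print(sub)
```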
An object of class "rfmstate" containing:
Named list of fitted ranger objects, one per
origin state.
The multistate structure.
Character vector of covariate names used.
Named list of per-origin-state data subsets.
Named list of unique event times per origin state.
The matched call.
List of tuning parameters used.
ms <- clinical_states()
set.seed(42)
dat <- sim_clinical_data(n = 200, structure = ms)
msdata <- prepare_data(
  data = dat, id = "ID", structure = ms,
  time_map = list(
    Responded = "time_Responded",
    Unresponded = "time_Unresponded",
    Stabilized = "time_Stabilized",
    Progressed = "time_Progressed",
    Death = "time_Death"
  ),
  censor_col = "time_censored",
  covariates = c("age", "sex", "BMI", "treatment")
)
fit <- rfmstate(msdata, covariates = c("age", "sex", "BMI", "treatment"))
print(fit)
Generates realistic clinical trial data with covariates and multistate event times for testing and demonstration. Works with any multistate structure.
sim_clinical_data(n = 500, structure = NULL, max_followup = 365, seed = NULL)
| Argument | Description |
|---|---|
| n | Integer, number of patients to simulate. |
| structure | An |
| max_followup | Numeric, maximum follow-up time (for generating censoring). Default 365. |
| seed | Optional integer for reproducibility. |
Transition intensities follow Weibull distributions with covariate effects
on the scale parameter. For the default clinical_states()
structure, transition-specific parameters are calibrated to produce
realistic clinical trial trajectories. For custom structures, sensible
default parameters are used for all transitions.
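A rough idea of this kind of simulation, sketched in base R: competing Weibull latent times out of a single origin state, with covariates acting on the scale parameter and administrative censoring at the maximum follow-up. The shapes, scales, and coefficients are invented, only the first transition out of Baseline is simulated, and none of this reflects the calibrated parameters inside `sim_clinical_data()`.

```r
set.seed(1)
n <- 200
max_followup <- 365

# Covariates
age       <- rnorm(n, mean = 60, sd = 12)
treatment <- rep(0:1, length.out = n)   # balanced arms

# Covariate effects on the Weibull scale (hypothetical coefficients):
# a larger scale means later events
scale_prog  <- 300 * exp(-0.01 * (age - 60) + 0.3 * treatment)
scale_death <- 800 * exp(-0.02 * (age - 60))

# Latent times for the two competing transitions out of Baseline
t_prog  <- rweibull(n, shape = 1.2, scale = scale_prog)
t_death <- rweibull(n, shape = 1.5, scale = scale_death)

# The earliest latent time wins, unless it falls after follow-up ends
first_time <- pmin(t_prog, t_death)
first_dest <- ifelse(t_prog <= t_death, "Progressed", "Death")
observed   <- first_time <= max_followup

dat <- data.frame(
  ID = seq_len(n),
  age = age,
  treatment = treatment,
  time_Progressed = ifelse(observed & first_dest == "Progressed", first_time, NA),
  time_Death      = ifelse(observed & first_dest == "Death", first_time, NA),
  time_censored   = ifelse(observed & first_dest == "Death", NA, max_followup)
)
head(dat)
```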
A data frame in wide format with columns:
Patient identifier (1 to n).
Continuous, simulated from Normal(60, 12).
Binary 0/1.
Continuous, simulated from Normal(26, 5).
Binary 0/1 (balanced arms).
For each non-initial state in the
structure, the time (days) of entry into that state, or NA
if the state was not visited. Column names follow the pattern
time_<StateName> (e.g., time_Death).
Days until last follow-up (right censoring
time), or NA if an absorbing state was reached.
set.seed(123)
dat <- sim_clinical_data(n = 100)
head(dat)
summary(dat)
Provides a comprehensive summary of the fitted model including per-origin state information, OOB prediction error, and transition event counts.
## S3 method for class 'rfmstate'
summary(object, ...)
| Argument | Description |
|---|---|
| object | A fitted |
| ... | Ignored. |
An object of class "summary.rfmstate", printed invisibly.
ms <- clinical_states()
set.seed(42)
dat <- sim_clinical_data(n = 200, structure = ms)
msdata <- prepare_data(
  data = dat, id = "ID", structure = ms,
  time_map = list(
    Responded = "time_Responded",
    Unresponded = "time_Unresponded",
    Stabilized = "time_Stabilized",
    Progressed = "time_Progressed",
    Death = "time_Death"
  ),
  censor_col = "time_censored",
  covariates = c("age", "sex", "BMI", "treatment")
)
fit <- rfmstate(msdata, covariates = c("age", "sex", "BMI", "treatment"),
                num.trees = 100)
summary(fit)