Introduction to the MTLR Workflow

Here we give a brief introduction to using Multi-task Logistic Regression (MTLR) for survival prediction. MTLR was specifically designed to give survival probabilities across a range of times for individual observations. This differs from models which produce risk scores (such as Cox proportional hazards), single-time probability models (such as the Gail model), and population-wide models (e.g. Kaplan-Meier curves). Producing survival probabilities over a range of times gives patients and physicians a more holistic view of survival, which may be critical in making healthcare decisions.

MTLR was first introduced at NIPS in 2011 in the paper “Learning Patient-Specific Cancer Survival Distributions as a Sequence of Dependent Regressors”. Since then much work has been done, including a website which can be used to build MTLR models on uploaded data. While this is an extremely beneficial resource, we have brought MTLR into the R environment so that it can be compared with other survival methods and combined with tools from other R packages, such as survival and randomForestSRC.

MTLR can be used for survival data containing right, left, interval, or no censoring. In addition, these types of censoring can be mixed in the same dataset. Documentation on utilizing these different types of censoring can be found using help(mtlr). In this vignette we will consider an example which includes right censoring only. Namely, we will be using the lung dataset from the survival package.

##Data: lung

One can access the lung dataset by loading the survival package.

library(survival)
#Looking at the top 6 rows...
head(lung)
#>   inst time status age sex ph.ecog ph.karno pat.karno meal.cal wt.loss
#> 1    3  306      2  74   1       1       90       100     1175      NA
#> 2    3  455      2  68   1       0       90        90     1225      15
#> 3    3 1010      1  56   1       0       90        90       NA      15
#> 4    5  210      2  57   1       1       90        60     1150      11
#> 5    1  883      2  60   1       0      100        90       NA       0
#> 6   12 1022      1  74   1       1       50        80      513       0
#help(lung) #See the basic information of lung.

The help file for lung (see help(lung)) gives the definition of each of the features shown above.

Most importantly you will notice the two features needed for every survival dataset for use of MTLR – an event time (here time), and the indicator identifying if an observation is uncensored/censored (here status). For this example we have status == 1 indicating a right censored individual and status == 2 indicating an uncensored individual. Later on we will be using the Surv function to structure our survival data for MTLR – there are other acceptable formats for the indicator feature (status) – see help(Surv) for more information.
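As a small illustration of how the indicator is interpreted, the sketch below builds a Surv object from the first six rows of lung. It only relies on Surv from the survival package; the names first6, survObj and eventIndicator are our own and are not used elsewhere in this vignette.

```r
library(survival)

# Build a survival object from the first six rows of lung.
# With a 1/2 coding, Surv treats 2 as the event and 1 as right censored.
first6 <- survival::lung[1:6, ]
survObj <- Surv(first6$time, first6$status)

# Censored observations are printed with a trailing "+".
print(survObj)

# Internally the status is stored as 0 (censored) / 1 (event).
eventIndicator <- unname(unclass(survObj)[, "status"])
print(eventIndicator)
```

Rows 3 and 6 have status == 1 in lung, so they are the two observations flagged as censored here.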

##Pre-processing

We will remove inst for this example since this is a categorical feature with 19 unique values and we would like to keep the number of features relatively small.

lung <- lung[,-1]

Before progressing any further we will split our data into a training and testing set. Note that we could stratify our training/testing set by the censor status but for simplicity we skip that for now.

numberTrain <- floor(nrow(lung)*0.8)
set.seed(42)
trInd <- sample(1:nrow(lung), numberTrain)
training <- lung[trInd,]
testing <- lung[-trInd,]
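For readers who do want the stratified split mentioned above, one possible approach in base R is sketched below. This is only an illustration: the names dat, trIndStrat, trainingStrat and testingStrat are hypothetical, and the rest of this vignette continues with the unstratified training and testing sets.

```r
library(survival)

# A sketch of an 80/20 split stratified by censor status (not used later).
dat <- survival::lung
set.seed(42)

# Row indices grouped by status (1 = censored, 2 = uncensored).
indexByStatus <- split(seq_len(nrow(dat)), dat$status)

# Sample 80% of the rows within each status group.
trIndStrat <- unlist(lapply(indexByStatus, function(idx) {
  idx[sample.int(length(idx), floor(length(idx) * 0.8))]
}), use.names = FALSE)

trainingStrat <- dat[trIndStrat, ]
testingStrat <- dat[-trIndStrat, ]

# The censoring proportions should now be close in both sets.
prop.table(table(trainingStrat$status))
prop.table(table(testingStrat$status))
```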

You may also notice that there are some missing values in the data, most visibly in meal.cal and wt.loss (although ph.ecog, ph.karno, and pat.karno also have missing values). The MTLR package does not impute missing values for users, so these must be handled ahead of time. If one passes in data which contains missing values anyway, all rows with missing values will be removed before model training/predictions. To avoid losing those rows we perform a very basic mean imputation on the dataset. Note that we use the means from the training set to impute the test set.

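#Perform imputation using the training set means.
trMeans <- colMeans(training, na.rm = TRUE)
for(i in seq_len(ncol(training))){
  #Fill missing entries of column i in both sets with the training mean.
  training[is.na(training[,i]), i] <- trMeans[i]
  testing[is.na(testing[,i]), i] <- trMeans[i]
}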
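To make the imputation step easier to follow, here is the same idea on a tiny made-up example. The data frames toyTrain and toyTest are hypothetical and exist only for this illustration; the point is that the test set is filled in with the training means, never its own.

```r
# Toy data with missing values in both sets (values are made up).
toyTrain <- data.frame(a = c(1, NA, 3), b = c(10, 20, 30))
toyTest <- data.frame(a = c(NA, 5), b = c(40, NA))

# Means are computed on the training set only.
toyMeans <- colMeans(toyTrain, na.rm = TRUE)

for (col in names(toyMeans)) {
  toyTrain[is.na(toyTrain[[col]]), col] <- toyMeans[[col]]
  toyTest[is.na(toyTest[[col]]), col] <- toyMeans[[col]]
}

# No missing values remain in either set.
colSums(is.na(toyTrain))
colSums(is.na(toyTest))
```

Here the missing value of a in toyTest becomes 2 (the training mean of a), and the missing value of b becomes 20.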

##Model Training

Once the dataset has been prepared we can begin to play around with some of the functions found in the MTLR package. Most importantly we will be utilizing the mtlr function to train our model. There are a number of arguments that can be used by mtlr, though only a select few are discussed here. There are only two arguments required to train an mtlr model, formula and data. For formula we must structure our event time feature and censor indicator feature using the Surv function. Since we have time and status as these two features we can create our formula object:

formula <- Surv(time,status)~.

The above says we will be training a model on the survival object created from time and status and using all the other features in our dataset as predictors. If we wanted to select a few features we could do this as well, for example, with age and sex.

formulaSmall <- Surv(time, status)~age+sex

Next, we just need the data argument which in our case is training. We can finally make our first model!

library(MTLR)
fullMod <- mtlr(formula = formula, data = training)
smallMod <- mtlr(formula = formulaSmall, data = training)
#We will print the small model so the output is more compact.
smallMod
#> 
#> Call:  mtlr(formula = formulaSmall, data = training) 
#> 
#> Time points:
#>  [1]  60 105 142 173 183 201 223 249 284 307 353 408 458 546 716
#> 
#> 
#> Weights:
#>            Bias      age     sex
#> 60      0.09990  0.04734 -0.0212
#> 105    -0.08775  0.03330 -0.0180
#> 141.56 -0.00532  0.01611 -0.0302
#> 173.25  0.28043  0.01780 -0.0210
#> 182.56  0.19039  0.01233 -0.0363
#> 201    -0.37891  0.00890 -0.0235
#> 223.38  0.45917  0.00148 -0.0338
#> 249    -0.01072 -0.00227 -0.0299
#> 284    -0.61857  0.01121 -0.0377
#> 307.12  0.06531  0.01987 -0.0408
#> 353    -0.18200  0.01450 -0.0176
#> 408.25 -0.04434 -0.00832 -0.0300
#> 458.12  0.10350  0.00884 -0.0240
#> 546    -0.31218  0.00804 -0.0154
#> 715.81 -0.30572  0.01675 -0.0197

There is a lot to take in at first from the output of the mtlr model. The first item is simply the call that was used to build the model. Next are the time points that mtlr used to train the model. If these time points are not specified when constructing the model then mtlr will choose time points based on the quantiles of the event time feature. Additionally, the number of time points is chosen to be sqrt(N), where N is the number of observations. Since we had 205 training instances and sqrt(205) = 14.317, mtlr rounded up to 15 time points.
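The two rules described above can be illustrated with a few lines of base R. This is a rough sketch and not the package's internal code: the event times are simulated, and the choice of evenly spaced quantile levels is our own assumption about what "based on the quantiles" could look like.

```r
# Number of time points: square root of the sample size, rounded up.
nObs <- 205
nTimePoints <- ceiling(sqrt(nObs))
nTimePoints
# → 15

# Time points: evenly spaced quantiles of the event times (assumed spacing).
# The event times here are simulated purely for illustration.
set.seed(1)
eventTimes <- rexp(nObs, rate = 1 / 300)
probs <- seq_len(nTimePoints) / (nTimePoints + 1)
timePoints <- quantile(eventTimes, probs = probs, names = FALSE)
round(timePoints)
```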

Last, mtlr outputs the weight matrix for the model – these are the weights corresponding to each feature at each time point (additionally notice that we include the bias weights). The row names correspond to the time point for which the feature weight belongs. If you would like to access these weights, they are saved in the model object as weight_matrix so you can access them using smallMod$weight_matrix.

We can also plot the weights for an mtlr model. Earlier we printed the small model, but here we will look at the weights for the complete model.

plot(fullMod)
#> Warning: Use of `plot_data$time` is discouraged.
#> ℹ Use `time` instead.
#> Warning: Use of `plot_data$value` is discouraged.
#> ℹ Use `value` instead.
#> Warning: Use of `plot_data$variable` is discouraged.
#> ℹ Use `variable` instead.
#> Use of `plot_data$variable` is discouraged.
#> ℹ Use `variable` instead.

By default, plot will only look at the 5 features which had the largest sum of absolute values across time (the most influence). You can alter these specifications by playing with the arguments in plot.
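The ranking described above (largest sum of absolute weights across time) can be reproduced by hand. The sketch below uses a made-up weight matrix W with hypothetical feature names x1 to x7 and no bias column; with a real model one would start from the stored weight_matrix instead.

```r
# Hypothetical weight matrix: rows are time points, columns are features.
set.seed(1)
W <- matrix(rnorm(3 * 7), nrow = 3,
            dimnames = list(c("60", "105", "142"), paste0("x", 1:7)))

# Influence of each feature: sum of absolute weights across time points.
influence <- colSums(abs(W))

# Names of the 5 most influential features, largest first.
topFeatures <- names(sort(influence, decreasing = TRUE))[1:5]
topFeatures
```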

##Model Predictions

Now that we have trained an MTLR model we should make some predictions! This is where our testing set and the predict function come into play. Note that there are a number of predictions we may be interested in acquiring. First, we may want to view the survival curves of our test observations.

survCurves <- predict(fullMod, testing, type = "survivalcurve")
#survCurves is pretty large so we will look at the first 5 rows/columns.
survCurves[1:5,1:5]
#>       time         1         2         3         4
#> 1   0.0000 1.0000000 1.0000000 1.0000000 1.0000000
#> 2  60.0000 0.9100849 0.9155880 0.9378810 0.8942019
#> 3 105.0000 0.8277074 0.8387033 0.8770047 0.7921875
#> 4 141.5625 0.7333045 0.7517395 0.8092565 0.6744565
#> 5 173.2500 0.6422654 0.6683001 0.7419264 0.5660321

When we use the predict function for survival curves we are returned a matrix where the first column (time) is the list of time points at which the model evaluated the survival probability for each observation (these are the time points used by mtlr plus an additional 0 point). Every following column corresponds to the row number of the data passed in, e.g. column 2 (named 1) corresponds to row 1 of testing. Each row of this matrix gives the probabilities of survival at the corresponding time point (given by the time column). For example, testing observation 1 has a survival probability of 0.910 at time 60.
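To show how this matrix is read, the sketch below rebuilds the first rows and the first two observation columns from the printed output and looks up individual probabilities. The object curves is a hand-copied stand-in for survCurves, so the same indexing should carry over to the real matrix.

```r
# The first rows of survCurves, copied from the printed output.
curves <- cbind(
  time = c(0, 60, 105, 141.5625, 173.25),
  "1" = c(1, 0.9100849, 0.8277074, 0.7333045, 0.6422654),
  "2" = c(1, 0.9155880, 0.8387033, 0.7517395, 0.6683001)
)

# Survival probability of test observation 1 at time 60.
p1 <- unname(curves[curves[, "time"] == 60, "1"])
p1
# → 0.9100849

# Survival probability of test observation 2 at time 105.
p2 <- unname(curves[curves[, "time"] == 105, "2"])
p2
# → 0.8387033
```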

Since these curves may be hard to digest by observing a matrix of survival probabilities we can also choose to plot them.

plotcurves(survCurves, 1:10)
#> Warning: Use of `plot_data$value` is discouraged.
#> ℹ Use `value` instead.
#> Warning: Use of `plot_data$Index` is discouraged.
#> ℹ Use `Index` instead.

Here we have specified that we want to observe the survival curves for the first 10 observations (corresponding to the first 10 rows of testing). You will notice that these curves have been smoothed whereas before we only had probabilities for certain time points. We have performed a monotonic spline fit to those survival probabilities to produce the curves you see here.

Additionally, you may have specific plot specifications you want to make. plotcurves is simply returning a ggplot2 object so specifications can be made like you would make to any other ggplot2 graphic. For example, plotcurves(survCurves, 1:10) + ggplot2::xlab("Days") would change the x-axis label to “Days” instead of “Time”.

###Mean/Median Survival Time

In addition to the entire survival curve one may also be interested in the average survival time. This is again available from the predict function.

#Mean
meanSurv <- predict(fullMod, testing, type = "mean_time")
head(meanSurv)
#> [1] 292.2261 309.4038 337.5738 248.5873 278.6793 331.9974
#Median
medianSurv <- predict(fullMod, testing, type = "median_time")
head(medianSurv)
#> [1] 213.3254 229.9705 288.7202 181.3738 196.7198 285.9620

Here the mean survival time corresponds to the area under the survival curve of each observation. One subtlety is that many survival curves never touch zero probability, making this area not well-defined. When this occurs, a linear fit is drawn from the point (time = 0, survival probability = 1) to the last time point and extended to the time at which the probability reaches 0. For example, the figure below shows this linear extension drawn on the curves to calculate the mean survival time.

This is also performed when calculating the median survival time if the last survival probability is above 0.5.
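The linear extension can be worked through by hand on a small made-up curve. The sketch below assumes a piecewise-linear curve and the trapezoid rule, whereas the package works with smoothed curves, so it illustrates the idea rather than reproducing the package's numbers. All names (times, probs, tZero and so on) are our own.

```r
# Toy survival curve that never reaches zero (values are made up).
times <- c(0, 100, 200)
probs <- c(1, 0.9, 0.6)

tLast <- times[length(times)]
pLast <- probs[length(probs)]

# Line through (0, 1) and (tLast, pLast): S(t) = 1 + slope * t.
slope <- (pLast - 1) / tLast

# Time at which the extended line reaches zero probability.
tZero <- -1 / slope

# Area under the observed curve (trapezoid rule) plus the extension triangle.
areaObserved <- sum(diff(times) * (head(probs, -1) + tail(probs, -1)) / 2)
areaExtension <- pLast * (tZero - tLast) / 2
meanTime <- areaObserved + areaExtension

# The last probability is above 0.5, so the median also uses the line.
medianTime <- (0.5 - 1) / slope

c(mean = meanTime, median = medianTime)
```

For this toy curve the line reaches zero at time 500, giving a mean of 170 + 90 = 260 and a median of 250.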

##Survival Probability at Event Time

The last prediction type supported is each observation's survival probability at its respective event time. In order to use this prediction, the event time (whether censored or uncensored) must be included in the features passed into the predict function.

survivalProbs <- predict(fullMod, testing, type = "prob_event")
head(survivalProbs)
#> [1] 0.3444770 0.3358154 0.7573785 0.0676823 0.9070285 0.4560290
#To see what times these probabilities correspond to:
head(testing$time)
#> [1] 310 361 170 707  61 301

Some of these survival probabilities may be 0 (usually for observations with very large event times). As before, the linear extension of the survival curve is used whenever the event time cannot be mapped onto the curve itself.

##Miscellaneous Commands

###mtlr_cv

Previously we just used the default settings of mtlr. However, a number of things can be adjusted, including the number of time points, the exact time points used, the initialization of the feature weights, and the regularization parameter (C1), which corresponds to the C1 given in the NIPS paper. The mtlr_cv function helps to select a value of C1. Given a vector of values to test for C1, mtlr_cv will do internal cross-validation to select the optimal C1 for some criterion. Currently the only criterion available is the log-likelihood loss (see the “Details” section of help(mtlr_cv)). For example, we use this command with 5 values of C1 (although there is a default of (0.001,0.01,0.1,1,10,100,1000)).

mtlr_cv(formula,training, C1_vec = c(0.01,0.1,1,10,100))
#> $best_C1
#> [1] 1
#> 
#> $avg_loss
#>     0.01      0.1        1       10      100 
#> 2.397453 2.232798 2.186749 2.209120 2.234422

The output gives us the best value of C1 and the losses for the values tested. Once we have the best value of C1 we can then use the mtlr function with the chosen value of C1.
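As a quick check of how the best value follows from the losses, the sketch below copies the printed avg_loss values and picks the candidate with the smallest loss. The final commented line shows how the chosen value might then be passed on; it assumes the regularization argument of mtlr is named C1, as in the text above.

```r
# Average losses copied from the mtlr_cv output.
avg_loss <- c("0.01" = 2.397453, "0.1" = 2.232798, "1" = 2.186749,
              "10" = 2.209120, "100" = 2.234422)

# The best C1 is the candidate with the smallest average loss.
best_C1 <- as.numeric(names(which.min(avg_loss)))
best_C1
# → 1

# With MTLR loaded, the final model could then be refit along the lines of:
# finalMod <- mtlr(formula = formula, data = training, C1 = best_C1)
```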

###create_folds

As we mentioned, mtlr_cv uses an internal k-fold cross validation to evaluate the loss. We also export the function (create_folds) used to create these cross-validation folds as it is creating folds in a unique way.

These folds can be deterministic, semi-deterministic, or totally random. The deterministic folds arise by stratifying folds by censor status and attempting to create equal ranges in the event times within each fold. This is done by first stratifying the survival dataset into a censored and uncensored portion and then sorting each portion by the event time. These portions are then numbered off into k different folds (see figure below). This option corresponds to “fullstrat”.

<img src="crossvalidation.png" alt="The 'fullstrat' method of cross-fold validation. Here 'delta' refers to the feature used to identify censored/uncensored observations." width="90%" />

The “fullstrat” method of cross-fold validation. Here “delta” refers to the feature used to identify censored/uncensored observations.

The semi-deterministic method stratifies the folds by censor status but ignores event time. In the figure above this would be the same process but skipping the sorting of time. This option is called “censorstrat”. The completely random method ignores the censor status and the event time and randomly assigns observations to folds – the “random” option.
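The “fullstrat” idea described above (split by censor status, sort by event time, number off into k folds) can be sketched in base R as follows. This is our own illustration of the description, not the code used by create_folds, and details such as restarting the numbering within each portion are assumptions. The data are made up.

```r
# Toy survival data (made up): event times and a censor indicator "delta".
set.seed(1)
n <- 20
time <- round(rexp(n, rate = 1 / 200))
delta <- sample(rep(c(0, 1), times = c(6, 14)))
k <- 5

# Order rows by censor status first, then by event time within each status.
ord <- order(delta, time)

# Number off 1..k within each status group, restarting for each group.
foldId <- integer(n)
foldId[ord] <- ave(seq_len(n), delta[ord],
                   FUN = function(i) rep_len(seq_len(k), length(i)))

# List of test indices, one element per fold.
folds <- split(seq_len(n), foldId)

# Censored and uncensored observations are spread evenly over the folds.
foldTable <- table(fold = foldId, delta = delta)
foldTable
```

Dropping time from the call to order and shuffling within each status group instead would give something closer to the “censorstrat” behaviour.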

We include the create_folds function in the event one would want to perform k-fold cross-validation on their dataset instead of using a training/testing split. For example, one could set up for 5-fold cross-validation like so:

#Recall we are using the lung dataset.
testInd <- create_folds(lung$time, lung$status, nfolds = 5, foldtype = "fullstrat")

Now testInd is a list of length 5 where each of the 5 items holds the indices of the test set to be used for one fold. One could now use a for loop to iterate over the 5 folds, indexing lung by the i-th set of test indices (lung[testInd[[i]], ]) for the test set and by all but those indices (lung[-testInd[[i]], ]) for the training set, and train these 5 models.
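The loop described above might look like the following sketch. To keep it self-contained, testIndDemo is a hypothetical stand-in for the list returned by create_folds (here built by simple round-robin assignment), and the model fitting is left as a comment.

```r
library(survival)

dat <- survival::lung

# Stand-in for the list returned by create_folds: 5 sets of test indices.
# (Assigned in a simple round-robin here, purely for illustration.)
testIndDemo <- split(seq_len(nrow(dat)), rep_len(1:5, nrow(dat)))

foldSizes <- matrix(NA_integer_, nrow = 5, ncol = 2,
                    dimnames = list(NULL, c("train", "test")))
for (i in seq_along(testIndDemo)) {
  testFold <- dat[testIndDemo[[i]], ]
  trainFold <- dat[-testIndDemo[[i]], ]
  # A model would be trained on trainFold and evaluated on testFold here,
  # e.g. with mtlr() and predict() as in the earlier sections.
  foldSizes[i, ] <- c(nrow(trainFold), nrow(testFold))
}
foldSizes
```

Every row of lung appears in exactly one test fold, so the test sizes add up to the number of rows in the dataset.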

##Evaluation

The MTLR package does not directly supply functions for evaluating survival curves, only the methods to produce them. For a detailed discussion of a variety of evaluation metrics and access to some R scripts containing evaluation code, please see the paper Effective Ways to Build and Evaluate Individual Survival Distributions.