Title: | Reinforcement Learning Trees |
---|---|
Description: | Random forest with a variety of additional features for regression, classification and survival analysis. The features include: parallel computing with OpenMP, embedded model for selecting the splitting variable, based on Zhu, Zeng & Kosorok (2015) <doi:10.1080/01621459.2015.1036994>, subject weight, variable weight, tracking subjects used in each tree, etc. |
Authors: | Ruoqing Zhu [aut, cre, cph] |
Maintainer: | Ruoqing Zhu <[email protected]> |
License: | GPL (>= 2) |
Version: | 3.2.6 |
Built: | 2024-12-31 07:34:38 UTC |
Source: | CRAN |
Get the muting rate based on sample size N and dimension P. This is an experimental feature and is not recommended when P is too small.
MuteRate(N, P, speed = NULL, info = FALSE)
N | sample size |
P | dimension |
speed | Muting speed: moderate or aggressive |
info | Whether to output detailed information |
A suggested muting rate
MuteRate(500, 100, speed = "aggressive")
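To see how the two documented speed settings differ for a given problem size, the calls can be placed side by side. The following is a minimal sketch that assumes the RLT package is installed; the `requireNamespace` guard and the variable name `rates` are illustrative additions and not part of the package.

```r
# Compare the suggested muting rate under the two documented speed settings.
# Assumes the RLT package is installed; the block is skipped otherwise.
rates <- NULL
if (requireNamespace("RLT", quietly = TRUE)) {
  rates <- sapply(c("moderate", "aggressive"), function(s) {
    RLT::MuteRate(N = 500, P = 100, speed = s, info = FALSE)
  })
  print(rates)  # one suggested rate per speed setting
}
```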
Predict future subjects with a fitted RLT model
## S3 method for class 'RLT'
predict(object, testx, ...)
object | A fitted RLT object |
testx | Testing data |
... | ... |
The predicted values. For a survival model, it returns the fitted survival functions.
x = matrix(rnorm(100), ncol = 10)
y = rowMeans(x)
fit = RLT(x, y, ntrees = 5)
predict(fit, x)
Print an RLT object
## S3 method for class 'RLT'
print(x, ...)
x | A fitted RLT object |
... | ... |
No return value
x = matrix(rnorm(100), ncol = 10)
y = rowMeans(x)
fit = RLT(x, y, ntrees = 5)
fit
Fit models for regression, classification and survival analysis using reinforced splitting rules
RLT(
  x,
  y,
  censor = NULL,
  model = "regression",
  print.summary = 0,
  use.cores = 1,
  ntrees = if (reinforcement) 100 else 500,
  mtry = max(1, as.integer(ncol(x)/3)),
  nmin = max(1, as.integer(log(nrow(x)))),
  alpha = 0.4,
  split.gen = "random",
  nsplit = 1,
  resample.prob = 0.9,
  replacement = TRUE,
  npermute = 1,
  select.method = "var",
  subject.weight = NULL,
  variable.weight = NULL,
  track.obs = FALSE,
  importance = TRUE,
  reinforcement = FALSE,
  muting = -1,
  muting.percent = if (reinforcement) MuteRate(nrow(x), ncol(x), speed = "aggressive", info = FALSE) else 0,
  protect = as.integer(log(ncol(x))),
  combsplit = 1,
  combsplit.th = 0.25,
  random.select = 0,
  embed.n.th = 4 * nmin,
  embed.ntrees = max(1, -atan(0.01 * (ncol(x) - 500))/pi * 100 + 50),
  embed.resample.prob = 0.8,
  embed.mtry = 1/2,
  embed.nmin = as.integer(nrow(x)^(1/3)),
  embed.split.gen = "random",
  embed.nsplit = 1
)
x | A matrix or data.frame for features |
y | Response variable, a numeric/factor vector or a Surv object |
censor | The censoring indicator if a survival model is used |
model | The model type: |
print.summary | Whether summary should be printed |
use.cores | Number of cores |
ntrees | Number of trees, |
mtry | Number of variables used at each internal node, only for |
nmin | Minimum number of observations required in an internal node to perform a split. Set this to twice the desired terminal node size. |
alpha | Minimum number of observations required for each child node, as a proportion of the parent node. Must be within |
split.gen | How the cutting points are generated |
nsplit | Number of random cutting points to compare for each variable at an internal node |
resample.prob | Proportion of in-bag samples |
replacement | Whether the in-bag samples are sampled with replacement |
npermute | Number of imputations (currently not implemented, saved for future use) |
select.method | Method to compare different splits |
subject.weight | Subject weights |
variable.weight | Variable weights when randomly sample |
track.obs | Track which terminal node each observation belongs to |
importance | Whether importance measures should be calculated |
reinforcement | Whether reinforcement splitting rules should be used. There are default values for all tuning parameters under this feature. |
muting | Muting method, |
muting.percent | Only for |
protect | Number of protected variables that will not be muted. These variables are adaptively selected for each tree. |
combsplit | Number of variables used in a combination split. |
combsplit.th | The minimum threshold (as a relative measurement compared to the best variable) for a variable to be used in the combination split. |
random.select | Randomly select a variable from the top variables in the linear combination as the splitting rule. |
embed.n.th | Number of observations at which to stop the embedded model and choose randomly from the current protected variables. |
embed.ntrees | Number of embedded trees |
embed.resample.prob | Proportion of in-bag samples for embedded trees |
embed.mtry | Number of variables used for embedded trees, as a proportion |
embed.nmin | Terminal node size for embedded trees |
embed.split.gen | How the cutting points are generated in the embedded trees |
embed.nsplit | Number of random cutting points for embedded trees |
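Several defaults in the usage line are expressions of the data dimensions rather than constants. The sketch below evaluates those expressions in base R so the resulting values can be inspected before fitting. The helper name `rlt_defaults` and its arguments `n` and `p` are hypothetical and not part of the package; the expressions themselves are copied from the usage line, and how RLT handles a non-integer `embed.ntrees` internally is not covered here.

```r
# Evaluate the data-dependent default expressions from the RLT() usage line.
# `rlt_defaults` is a hypothetical helper for illustration only.
rlt_defaults <- function(n, p, reinforcement = FALSE) {
  nmin <- max(1, as.integer(log(n)))
  list(
    ntrees       = if (reinforcement) 100 else 500,
    mtry         = max(1, as.integer(p / 3)),
    nmin         = nmin,
    protect      = as.integer(log(p)),
    embed.n.th   = 4 * nmin,
    embed.ntrees = max(1, -atan(0.01 * (p - 500)) / pi * 100 + 50),
    embed.nmin   = as.integer(n^(1/3))
  )
}

# Defaults for a data set with 200 observations and 100 features
str(rlt_defaults(n = 200, p = 100))
```

Calling the helper with the intended `nrow(x)` and `ncol(x)` shows, for instance, how `embed.ntrees` shrinks as the number of features grows.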
An RLT object; a list consisting of
FittedTrees | Fitted tree structure |
FittedSurv, timepoints | Terminal node survival estimation and all time points, if survival model is used |
AllError | All out-of-bag errors, if |
VarImp | Variable importance measures, if |
ObsTrack | Registration of each observation in each fitted tree |
... | All the tuning parameters are saved in the fitted |
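The components listed above can be inspected directly on a fitted object. This is a minimal sketch that assumes the RLT package is installed; the simulated data, the seed, and the `requireNamespace` guard are illustrative choices, and the exact shape of each component is left for `str()` to report rather than assumed.

```r
# Fit a small forest and inspect the documented components of the result.
# Assumes the RLT package is installed; the fit is skipped otherwise.
set.seed(1)
x <- matrix(rnorm(200 * 10), ncol = 10)
y <- rowMeans(x[, 1:3]) + rnorm(200, sd = 0.1)

if (requireNamespace("RLT", quietly = TRUE)) {
  fit <- RLT::RLT(x, y, ntrees = 20, importance = TRUE, track.obs = TRUE)
  print(names(fit))   # names of the components stored in the fitted list
  str(fit$VarImp)     # variable importance measures
  str(fit$ObsTrack)   # registration of each observation in each tree
}
```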
Zhu, R., Zeng, D., & Kosorok, M. R. (2015). Reinforcement learning trees. Journal of the American Statistical Association, 110(512), 1770-1784.
Zhu, R., & Kosorok, M. R. (2012). Recursively imputed survival trees. Journal of the American Statistical Association, 107(497), 331-340.
N = 600
P = 100
X = matrix(runif(N*P), N, P)
Y = rowSums(X[,1:5]) + rnorm(N)
trainx = X[1:200,]
trainy = Y[1:200]
testx = X[-c(1:200),]
testy = Y[-c(1:200)]

# Regular ensemble trees (Extremely Randomized Trees, Geurts et al., 2006)
RLT.fit = RLT(trainx, trainy, model = "regression", use.cores = 6)
barplot(RLT.fit$VarImp)
RLT.pred = predict(RLT.fit, testx)
mean((RLT.pred$Prediction - testy)^2)

# Reinforcement Learning Trees, using an embedded model to find the splitting rule
Mark0 = proc.time()
RLT.fit = RLT(trainx, trainy, model = "regression", use.cores = 6, ntrees = 100,
              importance = TRUE, reinforcement = TRUE, combsplit = 3, embed.ntrees = 25)
proc.time() - Mark0
barplot(RLT.fit$VarImp)
RLT.pred = predict(RLT.fit, testx)
mean((RLT.pred$Prediction - testy)^2)
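Besides plotting the importance measures, it can be useful to list which variables rank highest, since only the first five columns carry signal in this simulation. The sketch below uses a smaller simulated data set of the same form and assumes the RLT package is installed; the names `imp` and `top5`, the seed, and the reduced dimensions are illustrative choices. `as.numeric()` is applied in case `VarImp` is stored as a matrix.

```r
# Rank variables by importance on data where columns 1 to 5 carry the signal.
# Assumes the RLT package is installed; the fit is skipped otherwise.
set.seed(1)
N <- 300
P <- 20
X <- matrix(runif(N * P), N, P)
Y <- rowSums(X[, 1:5]) + rnorm(N)

if (requireNamespace("RLT", quietly = TRUE)) {
  fit <- RLT::RLT(X, Y, model = "regression", ntrees = 50, importance = TRUE)
  imp <- as.numeric(fit$VarImp)                # flatten in case it is a matrix
  top5 <- order(imp, decreasing = TRUE)[1:5]   # indices of the top-ranked variables
  print(top5)
}
```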