À l'aide du curseur R package, comment puis-je générer une courbe ROC basée sur les résultats de validation croisée de la fonction train ()?
Dis, je fais ce qui suit:
data(Sonar)
ctrl <- trainControl(method="cv",
summaryFunction=twoClassSummary,
classProbs=T)
rfFit <- train(Class ~ ., data=Sonar,
method="rf", preProc=c("center", "scale"),
trControl=ctrl)
La fonction d'apprentissage passe en revue une plage de paramètres et calcule l'ASC ROC. Je voudrais voir la courbe ROC associée - comment faire?
Remarque: si la méthode utilisée pour l'échantillonnage est LOOCV, alors rfFit
contiendra une trame de données non nulle dans le rfFit$pred
slot, qui semble être exactement ce dont j'ai besoin. Cependant, j'ai besoin de cela pour la méthode "cv" (validation k-fold) plutôt que LOO.
Aussi: non, la fonction roc
qui était incluse dans les anciennes versions de caret n'est pas une réponse - c'est une fonction de bas niveau, vous ne pouvez pas l'utiliser si vous n'avez pas les probabilités de prédiction pour chaque échantillon à validation croisée.
Il n'y a que l'argument savePredictions = TRUE
Manquant dans ctrl
(cela fonctionne également pour d'autres méthodes de rééchantillonnage):
library(caret)
library(mlbench)
data(Sonar)
ctrl <- trainControl(method="cv",
summaryFunction=twoClassSummary,
classProbs=T,
savePredictions = T)
rfFit <- train(Class ~ ., data=Sonar,
method="rf", preProc=c("center", "scale"),
trControl=ctrl)
library(pROC)
# Select a parameter setting
selectedIndices <- rfFit$pred$mtry == 2
# Plot:
plot.roc(rfFit$pred$obs[selectedIndices],
rfFit$pred$M[selectedIndices])
Il me manque peut-être quelque chose, mais une petite préoccupation est que train
estime toujours des valeurs AUC légèrement différentes de plot.roc
Et pROC::auc
(Différence absolue <0,005), bien que twoClassSummary
utilise pROC::auc
pour estimer l'ASC. Edit: Je suppose que cela se produit parce que le ROC de train
est la moyenne de l'AUC utilisant les CV-Sets séparés et nous voici calculer l'ASC sur tous les rééchantillons simultanément pour obtenir l'ASC globale.
Mise à jour Puisque cela retient un peu l'attention, voici une solution utilisant plotROC::geom_roc()
pour ggplot2
:
library(ggplot2)
library(plotROC)
ggplot(rfFit$pred[selectedIndices, ],
aes(m = M, d = factor(obs, levels = c("R", "M")))) +
geom_roc(hjust = -0.4, vjust = 1.5) + coord_equal()
Ici, je modifie l'intrigue de @ thei1e que d'autres peuvent trouver utile.
Former le modèle et faire des prédictions
library(caret)
library(ggplot2)
library(mlbench)
library(plotROC)
data(Sonar)
ctrl <- trainControl(method="cv", summaryFunction=twoClassSummary, classProbs=T,
savePredictions = T)
rfFit <- train(Class ~ ., data=Sonar, method="rf", preProc=c("center", "scale"),
trControl=ctrl)
# Select a parameter setting
selectedIndices <- rfFit$pred$mtry == 2
Mise à jour du tracé de la courbe ROC
g <- ggplot(rfFit$pred[selectedIndices, ], aes(m=M, d=factor(obs, levels = c("R", "M")))) +
geom_roc(n.cuts=0) +
coord_equal() +
style_roc()
g + annotate("text", x=0.75, y=0.25, label=paste("AUC =", round((calc_auc(g))$AUC, 4)))