Random Forest Model with Cross-validation and Exclusion
Your Name
Source: vignettes/get_zone_exclusioned_rf_model_with_cv.Rmd
Introduction
The `get_zone_exclusioned_rf_model_with_cv` function implements a Random Forest classification model with cross-validation.
It provides tools for evaluating the model’s performance, including
sensitivity, specificity, accuracy, and other metrics. The function
allows users to handle indeterminate predictions and includes an option
for undersampling the data, which can be particularly useful when
dealing with imbalanced datasets.
This document explains how to use the function, describes its inputs, outputs, and the key steps involved in the model training and evaluation process.
Function Purpose
The main goal of this function is to train a Random Forest model and evaluate it using cross-validation. The function:
- Performs cross-validation across a specified number of repetitions (`testReps`).
- Allows for undersampling of the dataset to address class imbalance if required.
- Handles indeterminate predictions by setting them to `NA`.
- Tracks performance metrics like sensitivity, specificity, positive predictive value (PPV), and accuracy for each repetition.
- Provides an aggregated summary of performance metrics across all repetitions.
Additionally, the function lets you choose which feature importance measure is used: Mean Decrease Accuracy or Mean Decrease Gini.
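The two importance measures can be compared directly with the `randomForest` package. The sketch below is illustrative only: it fits a forest on the built-in `iris` data rather than on this package's data, and it does not show how `get_zone_exclusioned_rf_model_with_cv` uses the importance values internally.

```r
library(randomForest)

set.seed(1)  # reproducible illustration

# importance = TRUE is required so that the permutation-based
# (Mean Decrease Accuracy) measure is computed during fitting.
rf <- randomForest(Species ~ ., data = iris, ntree = 200, importance = TRUE)

imp_accuracy <- importance(rf, type = 1)  # Mean Decrease Accuracy
imp_gini     <- importance(rf, type = 2)  # Mean Decrease Gini

# Rank the features by each measure.
imp_accuracy[order(imp_accuracy[, 1], decreasing = TRUE), , drop = FALSE]
imp_gini[order(imp_gini[, 1], decreasing = TRUE), , drop = FALSE]
```

The two rankings often agree on the strongest features but can differ lower down, so it is worth checking both before drawing conclusions.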
Parameters
The function accepts the following parameters:
- Data (`Data`): A data frame containing the features and the target variable (`Target_Organ`) to train the model on.
- Undersample (`Undersample`): A logical value indicating whether to undersample the data to balance the class distribution. If set to `TRUE`, the function undersamples the negative class to match the number of positive-class instances.
- Best Model Parameter (`best.m`): A numeric value giving the number of variables (`mtry`) to try at each split in the Random Forest model. This value can be provided manually or determined through optimization.
- Test Repetitions (`testReps`): The number of times to repeat the cross-validation process. This value must be at least 2, as the function relies on multiple test sets to assess model performance.
- Indeterminate Prediction Thresholds (`indeterminateUpper`, `indeterminateLower`): These parameters define the upper and lower bounds of the "indeterminate" zone. If the model's predicted probability falls between these thresholds, the prediction is considered indeterminate and set to `NA`.
- Feature Importance Type (`Type`): An integer indicating the type of feature importance to use in the Random Forest model. Typically, this will be either `1` for "Mean Decrease Accuracy" or `2` for "Mean Decrease Gini".
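The constraints described above can be checked before a long run is started. The helper below is hypothetical (`check_rf_cv_args` is not part of the package), and the assumptions that the thresholds lie in [0, 1] and that the lower bound does not exceed the upper bound are inferred from the parameter descriptions rather than taken from the function's source.

```r
# Hypothetical pre-flight check; not part of the package.
check_rf_cv_args <- function(testReps, indeterminateLower,
                             indeterminateUpper, Type) {
  stopifnot(
    is.numeric(testReps), length(testReps) == 1, testReps >= 2,
    is.numeric(indeterminateLower), is.numeric(indeterminateUpper),
    indeterminateLower >= 0, indeterminateUpper <= 1,  # assumed probability scale
    indeterminateLower <= indeterminateUpper,          # assumed ordering
    Type %in% c(1, 2)
  )
  invisible(TRUE)
}

# Passes silently.
ok <- check_rf_cv_args(testReps = 10, indeterminateLower = 0.2,
                       indeterminateUpper = 0.8, Type = 1)

# testReps = 1 violates the "at least 2" requirement and raises an error.
bad <- tryCatch(
  check_rf_cv_args(testReps = 1, indeterminateLower = 0.2,
                   indeterminateUpper = 0.8, Type = 1),
  error = function(e) FALSE
)
```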
Model Workflow
- Data Preparation:
  - The input data frame (`Data`) is processed to ensure that it is formatted correctly for model training. The column names are simplified to numeric identifiers for easier manipulation.
- Cross-validation:
  - The function performs cross-validation, repeating the training and testing process for the specified number of repetitions (`testReps`). In each repetition:
    - The dataset is split into a training set and a test set, with each iteration using different random samples.
    - The Random Forest model is trained on the training set, and predictions are made on the test set.
- Undersampling (Optional):
  - If `Undersample` is set to `TRUE`, the function balances the dataset by undersampling the negative class. The positive class is left unchanged, and the negative class is reduced to match the size of the positive class.
- Prediction and Evaluation:
  - After training the model, predictions are made on the test data. The predicted probabilities are stored and later used to calculate performance metrics.
  - Indeterminate predictions are identified based on the upper and lower thresholds (`indeterminateUpper` and `indeterminateLower`). These predictions are marked as `NA` and not included in performance calculations.
- Performance Metrics:
  - For each repetition, the function calculates various performance metrics, including:
    - Sensitivity: The proportion of actual positives that the model correctly identifies.
    - Specificity: The proportion of actual negatives that the model correctly identifies.
    - Accuracy: The proportion of all predictions, across both classes, that are correct.
    - PPV (Positive Predictive Value): The proportion of positive predictions that are correct.
    - NPV (Negative Predictive Value): The proportion of negative predictions that are correct.
    - Prevalence: The proportion of positive cases in the dataset.
  - These metrics are computed using the `caret` package's confusion matrix function.
- Aggregated Results:
  - After completing all test repetitions, the function calculates the mean of each performance metric across all repetitions to provide an aggregated performance summary.
  - The results include both individual metrics for each repetition and the overall performance summary.
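The undersampling, indeterminate-zone, and metric steps of the workflow can be sketched in base R. This is a minimal illustration under stated assumptions, not the package's implementation: the helper names (`undersample_negatives`, `call_with_indeterminate`, `binary_metrics`) are hypothetical, the 0.5 cut-off and the strict inequalities at the zone boundaries are assumptions, and the metrics are computed by hand here rather than through `caret`.

```r
set.seed(42)  # reproducible illustration

# Hypothetical toy data: 20 positives (1) and 80 negatives (0).
toy <- data.frame(
  Target_Organ = rep(c(1L, 0L), times = c(20, 80)),
  x1 = c(rnorm(20, mean = 1), rnorm(80, mean = 0))
)

# Keep every positive row and a random subset of negatives of the same size.
undersample_negatives <- function(df, target = "Target_Organ") {
  pos <- which(df[[target]] == 1)
  neg <- which(df[[target]] == 0)
  keep_neg <- neg[sample.int(length(neg), min(length(neg), length(pos)))]
  df[c(pos, keep_neg), , drop = FALSE]
}

# Convert probabilities to 0/1 calls; values inside the zone become NA.
call_with_indeterminate <- function(prob, lower, upper) {
  pred <- ifelse(prob >= 0.5, 1, 0)          # assumed 0.5 cut-off
  pred[prob > lower & prob < upper] <- NA    # assumed strict boundaries
  pred
}

# Confusion-matrix metrics on the determinate predictions only.
binary_metrics <- function(pred, truth) {
  keep  <- !is.na(pred)
  pred  <- pred[keep]
  truth <- truth[keep]
  tp <- sum(pred == 1 & truth == 1)
  tn <- sum(pred == 0 & truth == 0)
  fp <- sum(pred == 1 & truth == 0)
  fn <- sum(pred == 0 & truth == 1)
  c(sensitivity = tp / (tp + fn),
    specificity = tn / (tn + fp),
    accuracy    = (tp + tn) / length(pred),
    ppv         = tp / (tp + fp),
    npv         = tn / (tn + fn),
    prevalence  = (tp + fn) / length(pred))
}

balanced <- undersample_negatives(toy)
table(balanced$Target_Organ)  # 20 rows in each class

# Hand-written probabilities and labels, for illustration only.
probs <- c(0.95, 0.85, 0.50, 0.10, 0.30, 0.05)
truth <- c(1, 0, 1, 0, 1, 0)
pred  <- call_with_indeterminate(probs, lower = 0.2, upper = 0.8)
metrics <- binary_metrics(pred, truth)
```

Because the two probabilities inside the zone (0.50 and 0.30) are dropped, the metrics are computed on the four remaining predictions. Excluding indeterminate calls in this way usually raises the reported accuracy at the cost of leaving some cases unclassified.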
Outputs
The function returns a list with two components:
- `performance_metrics`: A vector containing the aggregated performance metrics (mean sensitivity, specificity, accuracy, etc.) calculated across all test repetitions.
- `raw_results`: A list containing the raw performance metrics for each repetition, including:
  - `sensitivity`: A vector of sensitivity values for each test repetition.
  - `specificity`: A vector of specificity values for each test repetition.
  - `accuracy`: A vector of accuracy values for each test repetition.
These outputs can be used to evaluate the model’s performance and further analyze the results.
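To show how a return value of this shape could arise, the sketch below repeats a train/test cycle with `randomForest` and collects the per-repetition metrics before averaging them. It is an approximation under assumptions: the toy data, the random 20% hold-out in each repetition, the 0.5 cut-off, and the use of `na.rm = TRUE` when averaging are all illustrative choices, and the function's actual splitting scheme may differ.

```r
library(randomForest)

set.seed(123)  # reproducible illustration

# Hypothetical toy data with a binary Target_Organ column.
toy <- data.frame(
  Target_Organ = factor(rep(c(0, 1), each = 50), levels = c(0, 1)),
  x1 = c(rnorm(50, mean = 0), rnorm(50, mean = 1.5)),
  x2 = rnorm(100)
)

testReps <- 5
sens <- spec <- acc <- numeric(testReps)

for (i in seq_len(testReps)) {
  # Assumed split: a fresh random 20% hold-out in every repetition.
  test_idx <- sample(nrow(toy), size = 20)
  fit <- randomForest(Target_Organ ~ ., data = toy[-test_idx, ],
                      mtry = 1, ntree = 100)

  # Probability of the positive class ("1") on the held-out rows.
  prob  <- predict(fit, newdata = toy[test_idx, ], type = "prob")[, "1"]
  pred  <- as.integer(prob >= 0.5)
  truth <- as.integer(as.character(toy$Target_Organ[test_idx]))

  sens[i] <- sum(pred == 1 & truth == 1) / sum(truth == 1)
  spec[i] <- sum(pred == 0 & truth == 0) / sum(truth == 0)
  acc[i]  <- mean(pred == truth)
}

# Assemble a list shaped like the documented return value.
raw_results <- list(sensitivity = sens, specificity = spec, accuracy = acc)
results <- list(
  performance_metrics = vapply(raw_results, mean, numeric(1), na.rm = TRUE),
  raw_results = raw_results
)
str(results)
```

Keeping the per-repetition vectors alongside the means makes it possible to inspect the spread of each metric, for example with `sd(results$raw_results$accuracy)`.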
Example Usage
Below is an example of how to use the function:
# Example dataset (replace with actual data)
Data <- your_data_frame
# Run the model with cross-validation and undersampling
results <- get_zone_exclusioned_rf_model_with_cv(
  Data = Data,
  Undersample = TRUE,
  best.m = 5,
  testReps = 10,
  indeterminateUpper = 0.8,
  indeterminateLower = 0.2,
  Type = 1
)
# View the aggregated performance metrics
print(results$performance_metrics)
# Access raw results for further analysis
print(results$raw_results)