Doing k-fold Cross-Validation for Imbalanced Data (Stratification) in R (Example Code)
In this tutorial, you’ll learn how to assign observations to the folds of a k-fold cross-validation via stratification in R. With stratification, the relative class frequencies in each fold are close to those in the complete dataset. This is especially useful when the classes are imbalanced. With a purely random assignment, a small class can end up underrepresented in some folds or even missing from them.
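To make this more precise, the notation below is our own and not from the original tutorial. Let class $c$ have $n_c$ of the $n$ observations, and let $k$ be the number of folds. Ideally, stratification gives every fold $j$ either $\lfloor n_c/k \rfloor$ or $\lceil n_c/k \rceil$ observations of class $c$, so that the class proportions within each fold approximate those in the full data:

$$
n_{c,j} \in \left\{ \left\lfloor \frac{n_c}{k} \right\rfloor, \left\lceil \frac{n_c}{k} \right\rceil \right\}
\quad\Rightarrow\quad
\frac{n_{c,j}}{\sum_{c'} n_{c',j}} \approx \frac{n_c}{n}.
$$

If every $n_c$ is a multiple of $k$, then $n_{c,j} = n_c/k$ in every fold. In that case, the fold proportions equal $n_c/n$ exactly.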
Creation of Example Data
We use the iris dataset for illustration.
```r
data(iris)   # Load iris data set
head(iris)   # Print head of data
#   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
# 1          5.1         3.5          1.4         0.2  setosa
# 2          4.9         3.0          1.4         0.2  setosa
# 3          4.7         3.2          1.3         0.2  setosa
# 4          4.6         3.1          1.5         0.2  setosa
# 5          5.0         3.6          1.4         0.2  setosa
# 6          5.4         3.9          1.7         0.4  setosa
```
Take a look at the absolute frequencies of the three species classes.
```r
table(iris$Species)   # Species classes in the data
#     setosa versicolor  virginica 
#         50         50         50 
```
From these data, we create a new dataset, iris_2, with imbalanced class frequencies.
```r
nr_rows_species <- c(50, 30, 10)                    # Preferred number of observations per species
names(nr_rows_species) <- levels(iris$Species)
set.seed(543)                                       # Set seed for reproducible results
sample_IDs <- unlist(lapply(levels(iris$Species),   # Sample desired species observations from iris
                            function (x) {
                              sample(which(iris$Species == x), nr_rows_species[x], replace = FALSE)
                            }))
iris_2 <- iris[sample_IDs, ]                        # New dataset with unbalanced species
table(iris_2$Species)                               # Absolute frequencies of species classes in new data
#     setosa versicolor  virginica 
#         50         30         10 
```
In iris_2, there are five times as many setosa observations as virginica observations.
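The stratified folds should reproduce the relative class frequencies of the complete dataset. The small sketch below computes these reference values (in %) for iris_2. It repeats the sampling code from above so that it runs on its own.

```r
# Recreate iris_2 (same code and seed as above)
data(iris)
nr_rows_species <- c(50, 30, 10)
names(nr_rows_species) <- levels(iris$Species)
set.seed(543)
sample_IDs <- unlist(lapply(levels(iris$Species), function (x) {
  sample(which(iris$Species == x), nr_rows_species[x], replace = FALSE)
}))
iris_2 <- iris[sample_IDs, ]

# Relative class frequencies (in %) in the complete imbalanced dataset
overall_pct <- round(prop.table(table(iris_2$Species)) * 100)
overall_pct
#     setosa versicolor  virginica 
#         56         33         11 
```

The values are 50/90, 30/90 and 10/90, each rounded to whole percentages.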
Example: Create k Stratified Folds for Cross-Validation
For the k-fold cross-validation, we use k = 10 folds.
```r
nr_folds <- 10   # Number of folds
```
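For contrast, here is a sketch of plain, non-stratified fold assignment. It is not part of the original tutorial. The fold IDs are shuffled over all rows and the species labels are ignored. Every fold still receives 9 observations, but nothing forces the 10 virginica rows to be spread one per fold. As a result, some folds may contain no virginica at all while others contain several. The seed `set.seed(1)` is an arbitrary choice for this illustration, and the exact counts depend on it.

```r
# Recreate iris_2 (same code and seed as above)
data(iris)
nr_rows_species <- c(50, 30, 10)
names(nr_rows_species) <- levels(iris$Species)
set.seed(543)
sample_IDs <- unlist(lapply(levels(iris$Species), function (x) {
  sample(which(iris$Species == x), nr_rows_species[x], replace = FALSE)
}))
iris_2 <- iris[sample_IDs, ]
nr_folds <- 10

# Non-stratified assignment: shuffle fold IDs over all rows, ignoring the
# species labels (each fold still gets 90 / 10 = 9 rows)
set.seed(1)   # Arbitrary seed, for this illustration only
fold_id_random <- sample(rep_len(1:nr_folds, nrow(iris_2)))

# Number of observations per species (rows) and fold (columns)
table(iris_2$Species, fold_id_random)
```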
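Now, we create an identifier for 10 stratified folds. The idea is to assign the fold IDs separately within each species, so that the observations of every species are spread as evenly as possible over all 10 folds.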
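```r
fold_id <- integer(nrow(iris_2))                   # New vector for fold ID
for (species_i in levels(iris_2$Species)) {        # Create stratified folds
  species_i_ID <- iris_2$Species == species_i      # Rows belonging to this species
  n_i <- sum(species_i_ID)                         # Number of observations of this species
  # Cycle through 1, ..., nr_folds until there are n_i IDs, then shuffle them, so each
  # fold receives floor(n_i / nr_folds) or ceiling(n_i / nr_folds) rows of this species
  fold_id[species_i_ID] <- sample(rep_len(1:nr_folds, n_i))
}
```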
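If you need stratified folds repeatedly, you can wrap the loop into a small function. The sketch below is a hypothetical helper. `stratified_fold_id()` is our own name, not a function from base R or from a package. It spreads each class as evenly as possible over the k folds, including when class sizes are not multiples of k.

```r
# Hypothetical helper (not part of the original tutorial): returns a vector of
# stratified fold IDs 1, ..., k for a vector of class labels y
stratified_fold_id <- function(y, k) {
  y <- as.factor(y)
  fold_id <- integer(length(y))
  for (level_i in levels(y)) {
    idx <- which(y == level_i)        # Positions of this class
    n_i <- length(idx)
    # rep_len() cycles 1, ..., k until there are n_i IDs (so each fold gets
    # floor(n_i / k) or ceiling(n_i / k) of them); sample.int() shuffles them
    fold_id[idx] <- rep_len(seq_len(k), n_i)[sample.int(n_i)]
  }
  fold_id
}

# Usage with class sizes that are not multiples of k
set.seed(1)
y <- factor(rep(c("a", "b"), times = c(23, 7)))
table(y, fold = stratified_fold_id(y, 10))
# Class a: 3 observations in folds 1-3 and 2 in folds 4-10
# Class b: 1 observation in folds 1-7 and none in folds 8-10
```

With this design, the leftover observations of each class always go to the lowest-numbered folds. If several classes have leftovers, you could permute the fold labels per class instead, e.g. with `rep_len(sample.int(k), n_i)`, so that the total fold sizes stay more balanced.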
Let us check the relative frequencies (in %) of the three classes in each of the 10 folds.
```r
n_species_per_fold <- sapply(1:nr_folds,   # Class proportions per fold_id
                             function (fold_i) {
                               round(prop.table(table(iris_2$Species[fold_id == fold_i])) * 100)
                             })
colnames(n_species_per_fold) <- paste0("fold ", 1:nr_folds)
n_species_per_fold
#            fold 1 fold 2 fold 3 fold 4 fold 5 fold 6 fold 7 fold 8 fold 9
# setosa         56     56     56     56     56     56     56     56     56
# versicolor     33     33     33     33     33     33     33     33     33
# virginica      11     11     11     11     11     11     11     11     11
#            fold 10
# setosa          56
# versicolor      33
# virginica       11
```
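The same check can also be done with a two-way `table()` of species by fold. It shows the absolute counts, and `prop.table(..., margin = 2)` then gives the percentages within each fold in one call. The sketch below rebuilds iris_2 and the stratified folds so that it runs on its own.

```r
# Rebuild iris_2 and the stratified folds (same steps as above)
data(iris)
nr_rows_species <- c(50, 30, 10)
names(nr_rows_species) <- levels(iris$Species)
set.seed(543)
sample_IDs <- unlist(lapply(levels(iris$Species), function (x) {
  sample(which(iris$Species == x), nr_rows_species[x], replace = FALSE)
}))
iris_2 <- iris[sample_IDs, ]
nr_folds <- 10
fold_id <- integer(nrow(iris_2))
for (species_i in levels(iris_2$Species)) {
  species_i_ID <- iris_2$Species == species_i
  fold_id[species_i_ID] <- sample(rep_len(1:nr_folds, sum(species_i_ID)))
}

# Absolute counts per species (rows) and fold (columns);
# every fold contains 5 setosa, 3 versicolor and 1 virginica observation
counts <- table(Species = iris_2$Species, fold = fold_id)
counts

# Percentages within each fold (margin = 2 divides by the column totals)
pct <- round(prop.table(counts, margin = 2) * 100)
pct
```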
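In each of the ten folds, the class proportions are the same and match those in the complete dataset iris_2 (50/90 ≈ 56%, 30/90 ≈ 33%, 10/90 ≈ 11%). Because each class size is a multiple of 10, every fold contains exactly 5 setosa, 3 versicolor and 1 virginica observation. Now, you can use these stratified folds for a cross-validation task.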
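To illustrate how the folds are used, here is a minimal cross-validation loop. In each iteration, one fold is held out for testing, and the model is fitted on the remaining nine folds. The classifier is an assumption made only so that the sketch runs with base R: a simple nearest-centroid rule, written by hand as the hypothetical function `predict_nearest_centroid()`. Any other classifier could be plugged in instead.

```r
# Rebuild iris_2 and the stratified folds (same steps as above)
data(iris)
nr_rows_species <- c(50, 30, 10)
names(nr_rows_species) <- levels(iris$Species)
set.seed(543)
sample_IDs <- unlist(lapply(levels(iris$Species), function (x) {
  sample(which(iris$Species == x), nr_rows_species[x], replace = FALSE)
}))
iris_2 <- iris[sample_IDs, ]
nr_folds <- 10
fold_id <- integer(nrow(iris_2))
for (species_i in levels(iris_2$Species)) {
  species_i_ID <- iris_2$Species == species_i
  fold_id[species_i_ID] <- sample(rep_len(1:nr_folds, sum(species_i_ID)))
}

# Hypothetical stand-in classifier: predict the species whose mean feature
# vector in the training data is closest in squared Euclidean distance
predict_nearest_centroid <- function(train_x, train_y, test_x) {
  # One row per species, one column per feature
  centroids <- sapply(train_x, function(col) tapply(col, train_y, mean))
  # Squared distances: one row per species, one column per test observation
  dists <- apply(as.matrix(test_x), 1,
                 function(obs) colSums((t(centroids) - obs)^2))
  rownames(centroids)[apply(dists, 2, which.min)]
}

features <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width")
accuracy <- numeric(nr_folds)
for (fold_i in 1:nr_folds) {
  test_rows <- fold_id == fold_i   # Held-out fold
  pred <- predict_nearest_centroid(iris_2[!test_rows, features],
                                   iris_2$Species[!test_rows],
                                   iris_2[test_rows, features])
  accuracy[fold_i] <- mean(pred == iris_2$Species[test_rows])
}
round(accuracy, 2)   # Accuracy per fold
mean(accuracy)       # Cross-validated accuracy estimate
```

Because the folds are stratified, each test fold contains all three species in the same proportions as the full data. Each training set therefore also contains every species.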
Note: This article was created in collaboration with Anna-Lena Wölwer. Anna-Lena is a researcher and programmer who creates tutorials on statistical methodology as well as on the R programming language. You may find more info about Anna-Lena and her other articles on her profile page.