Sample-Splitting and Cross-Validation

Statistical Computing, 36-350

Wednesday November 30, 2016

Reminder: estimating test error

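Often, we want an accurate estimate of the test error of our method (e.g., linear regression), i.e., its error when predicting new data that were not used for fitting. Why? Two main purposes: predictive assessment (how large should we expect our errors to be when making future predictions?) and model/method selection (choosing among candidate models or methods so as to minimize test error)

As we’ve seen, training error is inappropriate for both purposes: it is generally too optimistic, and it becomes more optimistic the more complex/adaptive the method is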
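A small simulation can make the optimism of training error concrete. This is a sketch on simulated data (the model y = 2x + noise, the sample size, and the seed are assumptions for illustration, not from the lecture): it fits polynomials of increasing degree and compares the training error with the error on a fresh sample drawn from the same model. Because the polynomial models are nested, training error can only go down as the degree grows; test error has no such guarantee.

```r
# Assumed model for illustration: y = 2x + N(0, 2^2) noise
set.seed(1)
n.sim = 50
x = runif(n.sim, -3, 3)
train = data.frame(x=x, y=2*x + rnorm(n.sim, sd=2))

# A fresh sample from the same model stands in for "future" test data
x.new = runif(n.sim, -3, 3)
test = data.frame(x=x.new, y=2*x.new + rnorm(n.sim, sd=2))

# Fit polynomials of increasing degree; record training and test error
degrees = c(1, 3, 10)
errs = sapply(degrees, function(d) {
  fit = lm(y ~ poly(x, d), data=train)
  c(train=mean((train$y - fitted(fit))^2),                # error on the fitting data
    test=mean((test$y - predict(fit, newdata=test))^2))   # error on fresh data
})
colnames(errs) = paste("degree", degrees)
round(errs, 3)
```

The training row is non-increasing from left to right by construction; compare it with the test row to see how far training error can understate the error on new data.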

Sample-splitting

Given a data set, how can we estimate test error? (Can’t simply simulate more data for testing.) We know training error won’t work

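A tried-and-true technique with a long history in statistics: sample-splitting. Randomly split the data into two parts, fit the model/method on the first part (training data), and compute its prediction error on the second part (test data)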
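The sample-splitting recipe can be packaged as a small function. This is a sketch, not code from the lecture: `split.test.err` is a hypothetical name, it assumes a linear model given by a formula whose first variable is the response, and the usage example runs on simulated data (y = 2x + noise) rather than the course data set. Fixing the seed means every model is evaluated on the same random split, so their test errors are directly comparable.

```r
# Hypothetical helper (not from the lecture): estimate the test error of a
# linear model by splitting the data in half at random
split.test.err = function(dat, formula, seed=0) {
  set.seed(seed)                        # same seed => same split for every model
  n = nrow(dat)
  inds = sample(rep(1:2, length=n))     # random labels: 1 = training, 2 = test
  dat.tr = dat[inds==1, ]
  dat.te = dat[inds==2, ]
  fit = lm(formula, data=dat.tr)        # fit on the training half only
  pred = predict(fit, newdata=dat.te)   # predict the held-out half
  response = all.vars(formula)[1]       # name of the response, e.g. "y"
  mean((dat.te[[response]] - pred)^2)   # mean squared test error
}

# Usage on simulated data (assumed model: y = 2x + noise)
set.seed(1)
x = runif(50, -3, 3)
sim = data.frame(x=x, y=2*x + rnorm(50, sd=2))
split.test.err(sim, y ~ x)
split.test.err(sim, y ~ poly(x, 10))
```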

Examples

dat = read.table("http://www.stat.cmu.edu/~ryantibs/statcomp-F16/data/xy.dat")
head(dat)
##           x         y
## 1 -2.908021 -7.298187
## 2 -2.713143 -3.105055
## 3 -2.439708 -2.855283
## 4 -2.379042 -4.902240
## 5 -2.331305 -6.936175
## 6 -2.252199 -2.703149
n = nrow(dat)

# Split data in half, randomly
set.seed(0)
inds = sample(rep(1:2, length=n))
head(inds, 10)
##  [1] 1 2 2 1 2 2 2 1 2 2
table(inds)
## inds
##  1  2 
## 25 25
dat.tr = dat[inds==1,] # Training data
dat.te = dat[inds==2,] # Test data

plot(dat$x, dat$y, pch=c(21,19)[inds], main="Sample-splitting")
legend("topleft", legend=c("Training","Test"), pch=c(21,19))

(Continued)

# Train on the first half
lm.1 = lm(y ~ x, data=dat.tr)
lm.10 = lm(y ~ poly(x,10), data=dat.tr)

# Predict on the second half, evaluate test error
pred.1 = predict(lm.1, data.frame(x=dat.te$x))
pred.10 = predict(lm.10, data.frame(x=dat.te$x))

test.err.1 = mean((dat.te$y - pred.1)^2)
test.err.10 = mean((dat.te$y - pred.10)^2)

# Plot the results
par(mfrow=c(1,2))
xx = seq(min(dat$x), max(dat$x), length=100)

plot(dat$x, dat$y, pch=c(21,19)[inds], main="Sample-splitting")
lines(xx, predict(lm.1, data.frame(x=xx)), col=2, lwd=2)
legend("topleft", legend=c("Training","Test"), pch=c(21,19))
text(0, -6, label=paste("Test error:", round(test.err.1,3)))

plot(dat$x, dat$y, pch=c(21,19)[inds], main="Sample-splitting")
lines(xx, predict(lm.10, data.frame(x=xx)), col=3, lwd=2)
legend("topleft", legend=c("Training","Test"), pch=c(21,19))
text(0, -6, label=paste("Test error:", round(test.err.10,3)))

Cross-validation

Sample-splitting is simple and effective. But it estimates the test error of the model/method when trained on less data (say, roughly half as much), so it tends to overstate the test error of the same method trained on the full data set
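To see why training on less data matters, here is a small simulation sketch (the model y = 2x + noise, n = 50, 500 repetitions, and the large fresh test set are all assumptions for illustration): it compares the true test error of a linear fit trained on all 50 points with that of a fit trained on only 25. On average the half-data fit has the larger test error, and that larger error is what a 50/50 split actually estimates.

```r
# Assumed setting: y = 2x + N(0, 2^2) noise, full sample size 50, linear model
set.seed(1)
n.sim = 50; reps = 500
err.full = err.half = numeric(reps)
for (r in 1:reps) {
  x = runif(n.sim, -3, 3)
  sim = data.frame(x=x, y=2*x + rnorm(n.sim, sd=2))
  x.new = runif(1000, -3, 3)                       # large fresh test set
  test = data.frame(x=x.new, y=2*x.new + rnorm(1000, sd=2))
  fit.full = lm(y ~ x, data=sim)                   # trained on all 50 points
  fit.half = lm(y ~ x, data=sim[1:(n.sim/2), ])    # trained on 25 points (rows are in random order)
  err.full[r] = mean((test$y - predict(fit.full, newdata=test))^2)
  err.half[r] = mean((test$y - predict(fit.half, newdata=test))^2)
}
c(full=mean(err.full), half=mean(err.half))
```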

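An improvement over sample-splitting: \(k\)-fold cross-validation. Randomly split the data into \(k\) parts (folds) of roughly equal size; for each part in turn, train the model/method on the other \(k-1\) parts and use it to predict the held-out part; finally, average the squared prediction errors over all \(n\) data points. Each model is then trained on a fraction \((k-1)/k\) of the data, rather than only half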

Common choices are \(k=5\) or \(k=10\) (sometimes \(k=n\), which is called leave-one-out cross-validation)
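The \(k\)-fold recipe can also be wrapped in a small function. This is an illustrative sketch, not code from the lecture: `cv.err` is a hypothetical name, it handles only `lm` fits given by a formula, and the data are simulated. Setting `k = nrow(dat)` gives leave-one-out; for least squares, the leave-one-out error can also be computed from a single full-data fit using the leverages (`hatvalues`), which provides a convenient check.

```r
# Hypothetical helper (not from the lecture): k-fold cross-validation error
# for a linear model specified by a formula
cv.err = function(dat, formula, k=5, seed=0) {
  set.seed(seed)
  n = nrow(dat)
  inds = sample(rep(1:k, length=n))    # random fold labels, sizes as equal as possible
  response = all.vars(formula)[1]      # name of the response, e.g. "y"
  pred = numeric(n)
  for (i in 1:k) {
    fit = lm(formula, data=dat[inds!=i, ])                 # train without fold i
    pred[inds==i] = predict(fit, newdata=dat[inds==i, ])   # predict fold i
  }
  mean((dat[[response]] - pred)^2)     # average squared error over all n points
}

# Usage on simulated data (assumed model: y = 2x + noise)
set.seed(1)
x = runif(50, -3, 3)
sim = data.frame(x=x, y=2*x + rnorm(50, sd=2))
cv.err(sim, y ~ x, k=5)
cv.err(sim, y ~ poly(x, 10), k=5)

# Leave-one-out shortcut for least squares: average of
# (residual / (1 - leverage))^2 from one fit on the full data
fit = lm(y ~ x, data=sim)
loo.shortcut = mean((residuals(fit) / (1 - hatvalues(fit)))^2)
all.equal(cv.err(sim, y ~ x, k=nrow(sim)), loo.shortcut)   # TRUE
```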

(Continued)

For demonstration purposes, suppose \(n=6\) and we choose \(k=3\) parts

| Data point | Part | Trained on | Prediction |
|---|---|---|---|
| \(Y_1\) | 1 | 2, 3 | \(\hat{Y}^{-(1)}_1\) |
| \(Y_2\) | 1 | 2, 3 | \(\hat{Y}^{-(1)}_2\) |
| \(Y_3\) | 2 | 1, 3 | \(\hat{Y}^{-(2)}_3\) |
| \(Y_4\) | 2 | 1, 3 | \(\hat{Y}^{-(2)}_4\) |
| \(Y_5\) | 3 | 1, 2 | \(\hat{Y}^{-(3)}_5\) |
| \(Y_6\) | 3 | 1, 2 | \(\hat{Y}^{-(3)}_6\) |

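Notation: \(\hat{Y}^{-(j)}_i\) is the prediction for \(Y_i\) from the model trained on all data except that in part \(j\). For example, the model trained on parts 2 and 3 makes the predictions for part 1, so the prediction \(\hat{Y}^{-(1)}_1\) for \(Y_1\) comes from the model trained on all data except that in part 1. And so on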

The cross-validation estimate of test error (also called the cross-validation error) is \[ \frac{1}{6}\Big( (Y_1-\hat{Y}^{-(1)}_1)^2 + (Y_2-\hat{Y}^{-(1)}_2)^2 + (Y_3-\hat{Y}^{-(2)}_3)^2 + (Y_4-\hat{Y}^{-(2)}_4)^2 + (Y_5-\hat{Y}^{-(3)}_5)^2 + (Y_6-\hat{Y}^{-(3)}_6)^2 \Big) \]
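The 6-point, 3-part computation can be reproduced directly in R. The toy \(y\) values below are made up purely for illustration; the fold assignment matches the table, and the last line is the \(\frac{1}{6}\)-average of the six squared errors.

```r
# Toy data: n = 6 points (the y values are made up for illustration)
toy = data.frame(x=1:6, y=c(1.2, 1.9, 3.2, 3.8, 5.1, 6.3))
part = c(1, 1, 2, 2, 3, 3)    # fold assignment, as in the table

pred = numeric(6)
for (j in 1:3) {
  fit.j = lm(y ~ x, data=toy[part!=j, ])                  # train on the other two parts
  pred[part==j] = predict(fit.j, newdata=toy[part==j, ])  # predict part j
}
cv.error = mean((toy$y - pred)^2)   # (1/6) * sum of the six squared errors
cv.error
```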

Examples

# Split data in 5 parts, randomly
k = 5
set.seed(0)
inds = sample(rep(1:k, length=n))
head(inds, 10)
##  [1] 5 4 3 2 2 5 5 1 3 1
table(inds)
## inds
##  1  2  3  4  5 
## 10 10 10 10 10
# Now run cross-validation: easiest with for loop, running over
# which part to leave out
pred.mat = matrix(0, n, 2) # Empty matrix to store predictions
for (i in 1:k) {
  cat(paste("Fold",i,"... "))
  
  dat.tr = dat[inds!=i,] # Training data
  dat.te = dat[inds==i,] # Test data
  
  # Train our models
  lm.1.minus.i = lm(y ~ x, data=dat.tr)
  lm.10.minus.i = lm(y ~ poly(x,10), data=dat.tr)
  
  # Record predictions
  pred.mat[inds==i,1] = predict(lm.1.minus.i, data.frame(x=dat.te$x))
  pred.mat[inds==i,2] = predict(lm.10.minus.i, data.frame(x=dat.te$x))
}
## Fold 1 ... Fold 2 ... Fold 3 ... Fold 4 ... Fold 5 ...

(Continued)

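# Compute cross-validation error, one for each model
cv.errs = colMeans((pred.mat - dat$y)^2)

# Refit both models on the FULL data set for display below (lm.1 and lm.10
# from the sample-splitting example were trained on only half the data)
lm.1 = lm(y ~ x, data=dat)
lm.10 = lm(y ~ poly(x,10), data=dat)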

# Plot the results
par(mfrow=c(1,2))
xx = seq(min(dat$x), max(dat$x), length=100)

plot(dat$x, dat$y, pch=20, col=inds+1, main="Cross-validation")
lines(xx, predict(lm.1, data.frame(x=xx)), # Note: model trained on FULL data!
      lwd=2, lty=2) 
legend("topleft", legend=paste("Fold",1:k), pch=20, col=2:(k+1))
text(0, -6, label=paste("CV error:", round(cv.errs[1],3)))

plot(dat$x, dat$y, pch=20, col=inds+1, main="Cross-validation")
lines(xx, predict(lm.10, data.frame(x=xx)), # Note: model trained on FULL data!
      lwd=2, lty=2) 
legend("topleft", legend=paste("Fold",1:k), pch=20, col=2:(k+1))
text(0, -6, label=paste("CV error:", round(cv.errs[2],3)))

(Continued)

# Now we visualize the different models trained, one for each CV fold
for (i in 1:k) {
  dat.tr = dat[inds!=i,] # Training data
  dat.te = dat[inds==i,] # Test data
  
  # Train our models
  lm.1.minus.i = lm(y ~ x, data=dat.tr)
  lm.10.minus.i = lm(y ~ poly(x,10), data=dat.tr)
  
  # Plot fitted models
  par(mfrow=c(1,2)); cols = c("red","gray")
  plot(dat$x, dat$y, pch=20, col=cols[(inds!=i)+1], main=paste("Fold",i))
  lines(xx, predict(lm.1.minus.i, data.frame(x=xx)), lwd=2, lty=2) 
  legend("topleft", legend=c(paste("Fold",i),"Other folds"), pch=20, col=cols)
  text(0, -6, label=paste("Fold",i,"error:",
       round(mean((dat.te$y - pred.mat[inds==i,1])^2),3)))
                          
  plot(dat$x, dat$y, pch=20, col=cols[(inds!=i)+1], main=paste("Fold",i))
  lines(xx, predict(lm.10.minus.i, data.frame(x=xx)), lwd=2, lty=2) 
  legend("topleft", legend=c(paste("Fold",i),"Other folds"), pch=20, col=cols)
  text(0, -6, label=paste("Fold",i,"error:",
       round(mean((dat.te$y - pred.mat[inds==i,2])^2),3)))
}
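The loop above reports a separate error for each fold. The overall cross-validation error is the average of these per-fold errors weighted by fold size; with equal fold sizes (10 points each here) it is simply their plain average. A quick check on simulated data with deliberately unequal folds (n = 47 and k = 5 are assumptions for illustration):

```r
# Assumed setup: simulated data with n = 47, so the 5 folds have unequal sizes
set.seed(1)
n.sim = 47; k.sim = 5
x = runif(n.sim, -3, 3)
sim = data.frame(x=x, y=2*x + rnorm(n.sim, sd=2))
folds = sample(rep(1:k.sim, length=n.sim))

# Standard k-fold loop: each point is predicted by the model that left its fold out
pred = numeric(n.sim)
for (i in 1:k.sim) {
  fit.i = lm(y ~ x, data=sim[folds!=i, ])
  pred[folds==i] = predict(fit.i, newdata=sim[folds==i, ])
}

sq.err = (sim$y - pred)^2
fold.errs = tapply(sq.err, folds, mean)   # error within each fold
fold.sizes = table(folds)                 # fold sizes 10, 10, 9, 9, 9
cv.error = mean(sq.err)                   # overall CV error
all.equal(cv.error, sum(as.numeric(fold.sizes) * as.numeric(fold.errs)) / n.sim)   # TRUE
```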