-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy path06-KNN.Rmd
51 lines (50 loc) · 1.96 KB
/
06-KNN.Rmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
The next method is k-nearest neighbors (KNN), which classifies a new observation on the basis of the surrounding observations. This method does not require fitting a statistical model and imposes no distributional requirements. To identify the best number of neighbors, we use the following `do.chunk` function, which computes the training and validation error rates for a single fold, and apply it within a 10-fold cross-validation.
```{r}
# Number of cross-validation folds.
nfold <- 10
# Fix the RNG state so the shuffled fold assignment is reproducible.
set.seed(1)
# Cut the row indices of the training data into `nfold` intervals, then
# shuffle the resulting integer labels to get a random fold for each row.
folds <- sample(
  cut(seq.int(nrow(banking.train)), breaks = nfold, labels = FALSE)
)
# Fit KNN on every fold except `chunkid` and score it on the held-out fold.
#
# Arguments:
#   chunkid  Label of the fold held out for validation.
#   folddef  Vector of fold labels; rows where it equals `chunkid` form the
#            validation set (expected to have one entry per row of `Xdat`).
#   Xdat     Predictors (matrix or data frame), one row per observation.
#   Ydat     Class labels, indexed the same way as the rows of `Xdat`.
#   k        Number of neighbors passed to knn().
#
# Returns a data frame with columns `fold`, `train.error` and `val.error`,
# where the errors are whatever calc_error_rate() returns for the
# training-row and held-out-row predictions respectively.
do.chunk <- function(chunkid, folddef, Xdat, Ydat, k) {
  train <- (folddef != chunkid)

  # drop = FALSE keeps the predictors two-dimensional even when there is a
  # single column (or a single selected row); without it the subset
  # collapses to a bare vector and knn() misreads its shape.
  Xtr <- Xdat[train, , drop = FALSE]
  Ytr <- Ydat[train]
  Xvl <- Xdat[!train, , drop = FALSE]
  Yvl <- Ydat[!train]

  # Predict the training rows themselves (training error) and the held-out
  # rows (validation error), both from the same training subset.
  predYtr <- knn(train = Xtr, test = Xtr, cl = Ytr, k = k)
  predYvl <- knn(train = Xtr, test = Xvl, cl = Ytr, k = k)

  data.frame(
    fold = chunkid,
    train.error = calc_error_rate(predYtr, Ytr),
    val.error = calc_error_rate(predYvl, Yvl)
  )
}
# ---- Cross-validate KNN over a grid of neighbor counts ----
# Candidate numbers of neighbors: 1, 10, 20, 30, 40, 50.
kvec <- c(1, seq(10, 50, length.out = 5))

set.seed(1) # Takes a while to run

# Collect one data frame of per-fold errors for each candidate k in a
# preallocated list and bind them once afterwards, instead of growing a
# data frame with rbind() on every iteration.
fold_results <- vector("list", length(kvec))
for (i in seq_along(kvec)) {
  j <- kvec[i]
  tmp <- ldply(
    seq_len(nfold), do.chunk,
    folddef = folds, Xdat = XTrain, Ydat = YTrain, k = j
  )
  tmp$neighbors <- j
  fold_results[[i]] <- tmp
}
error.folds <- do.call(rbind, fold_results)

# Long format: one row per (fold, neighbors, error type).
errors <- melt(
  error.folds,
  id.vars = c("fold", "neighbors"),
  value.name = "error"
)

# Mean validation error for each k, keeping only the row(s) at the minimum.
# summarise() replaces the deprecated summarise_each(funs(mean), error).
val.error.means <- errors %>%
  filter(variable == "val.error") %>%
  group_by(neighbors, variable) %>%
  summarise(error = mean(error)) %>%
  ungroup() %>%
  filter(error == min(error))

# If several k tie for the minimum, take the largest of them.
numneighbor <- max(val.error.means$neighbors)
numneighbor # the best number of neighbors = 20
```
As it turns out, choosing 20 neighbors minimizes the cross-validation error rate.
```{r}
# Fit KNN with the neighbor count selected by cross-validation
# (`numneighbor`; the original run reported 20) instead of repeating the
# value as a hard-coded literal, so this chunk tracks the CV result.

# training error: predict the training rows from the training set itself
set.seed(20)
pred.YTtrain <- knn(
  train = XTrain, test = XTrain, cl = YTrain, k = numneighbor
)
knn_traing_error <- calc_error_rate(
  predicted.value = pred.YTtrain, true.value = YTrain
)

# test error (0.095 reported with k = 20); the seed is reset so both knn()
# calls start from the same RNG state
set.seed(20)
pred.YTest <- knn(
  train = XTrain, test = XTest, cl = YTrain, k = numneighbor
)
knn_test_error <- calc_error_rate(
  predicted.value = pred.YTest, true.value = YTest
)

# Store (training error, test error) in row 3 of `records`, which is
# defined outside this chunk.
records[3, ] <- c(knn_traing_error, knn_test_error)
```