No Dicas de R de hoje, vamos falar como funciona a classificação KNN (K nearest neighbours), aplicar o método em um exemplo, e visualizar o resultado a partir da amostra. O escopo de métodos como o KNN é de problemas de classificação. Ao invés de utilizar dados para prever resultados contínuos como em regressões lineares, são criadas regras de decisão para a classe esperada de uma observação. Ademais, tais classes podem ser utilizadas para a previsão de variáveis supostas contínuas, sendo um dos métodos mais simples de estimação não-paramétrica.
O desenvolvimento de métodos de classificação é baseado em distribuições condicionais. Considerando como melhor classificador aquele que minimiza o valor esperado da taxa de erros de teste (o número de classificações feitas erradas sobre o total, a partir da estimação de treino), podemos provar que a melhor regra de decisão é aquela que assinala a classe de maior probabilidade, dado os valores das variáveis de previsão. Essa regra é chama de classificador de Bayes, e depende do conhecimento completo das distribuições condicionais da variável a ser classificada, logo muitas vezes não é possível utilizá-lo.
Com isso, vamos explicar como funciona o classificador KNN. Após separarmos os dados entre treino e teste, definimos a seguinte regra: para uma observação das variáveis regressoras de teste, são identificados os K pontos de treino mais próximos à posição da amostra, e a probabilidade condicional de cada classe é a frequência relativa de sua ocorrência entre os pontos escolhidos. Desse modo, a classificação estimada para o ponto de teste é igual à classe que ocorre mais vezes dentre os K pontos de treino. A escolha do número K é importante: quanto menor for K, menor será a taxa de erro de treino (para K=1, a amostra será perfeitamente classificada), porém a classificação gerará padrões complicados e pode ter alta taxa de erro de teste, dependendo da amostra. Isso condiz com o tradeoff de viés-variância, de modo que, quanto menor o K, menor é o viés da classificação (pois está mais próxima dos dados reais), porém maior a variância da classificação entre amostras distintas, enquanto que, quanto maior o K, menor é o impacto de amostras diferentes de treino sobre o classificador final, abrindo mão de parte da acurácia do modelo.
Vamos então partir para um exemplo. Os dados utilizados serão do clássico dataset iris, contendo 4 dados de medição de flores, e suas respectivas espécies. Primeiramente, vamos utilizar a classificação KNN para tentar prever a espécie de cada flor a partir das medições de suas pétalas. Gerando os dados:
dados <- iris class <- dados[,3:5] train <- c(1:30, 51:80, 101:130) class_train <- class[train,] class_test <- class[-train,]
Então, podemos realizar a classificação utilizando a função knn(), do pacote class. Seus inputs são os dados de treino, teste, o número K e a classificação correta do treino. Abaixo, o modelo e sua previsão, comparada com os valores reais:
library(class) pred = knn(train = class_train[,1:2], test = class_test[1:2], cl = class_train[,3], k = 4) table(pred, class_test[,3]) pred setosa versicolor virginica setosa 20 0 0 versicolor 0 20 1 virginica 0 0 19
Como podemos ver, o modelo acertou quase todas as classificações, com apenas um erro. Ele funcionou tão bem pois as variáveis se separam em grupos bem distintos, e, como o comprimento das pétalas se diferencia bastante entre as espécies - e possui escala maior, logo tem muito mais impacto na regra de decisão -, ele facilita a regra de decisão. Abaixo, vamos plotar o classificador para K=4, em conjunto com os pontos de teste. O código gera uma malha de pontos que cobrem o gráfico, e então faz a classificação KNN sobre cada um dos pontos da malha. Após isso, geramos um dataframe que indica os pontos onde o contorno deve ser de uma espécie, e onde não deve ser, para cada uma das 3 espécies, de modo a gerar 3 contornos que, por hipótese, não se sobrepõem. Com esse dataframe, podemos utilizar a geom_contour() do ggplot2, criando assim o gráfico abaixo:
library(ggplot2) library(tidyverse) library(MASS) cover <- expand.grid(x=seq(min(class_train[,1]-1), max(class_train[,1]+1), by=0.1), y=seq(min(class_train[,2]-1), max(class_train[,2]+1), by=0.1)) cover_class <- knn(train = class_train[,1:2], cover, cl = class_train[,3], k = 4) dataf <- bind_rows(mutate(cover, cls="setosa", prob_cls=ifelse(cover_class==cls, T, F)), mutate(cover, cls="versicolor", prob_cls=ifelse(cover_class==cls, T, F)), mutate(cover, cls="virginica", prob_cls=ifelse(cover_class==cls, T, F))) ggplot(dataf) + geom_point(aes(x=x, y=y, color=cls), data = mutate(cover, cls=cover_class), size=1.2) + geom_contour(aes(x=x, y=y, z=prob_cls, group=cls, color=cls), bins=2, data=dataf) + geom_point(aes(x=Petal.Length, y=Petal.Width, col=Species), size=3, data=class_test)+ labs(x='Petal Length',y = 'Petal Width', colour = 'Species')+ theme_minimal()
Conteúdos como esse podem ser encontrados no nosso Curso de Machine Learning usando o R.