我正在跟随Tibshirani的ISL文本。我试图在ggplot2中绘制SVM的结果。我可以得到点和支持向量,但我无法弄清楚如何获得2D情况下绘制的边距和超平面。我用Google搜索并检查了e1071自述文件。一般的动态解决方案(适用于各种SVM内核,成本等)会很棒。这是我的MWE:
set.seed(1)
N=20
x=matrix(rnorm(n=N*2), ncol=2)
y=c(rep(-1,N/2), rep(1,N/2))
x[y==1,] = x[y==1,] + 1;x[y==1,]
dat = data.frame(x=x, y=as.factor(y))
library(e1071)
library(ggplot2)
svmfit=svm(y~., data=dat, kernel="linear", cost=10, scale=FALSE)
df = dat; df
df = cbind(df, sv=rep(0,nrow(df)))
df[svmfit$index,]$sv = 1
ggplot(data=df,aes(x=x.1,y=x.2,group=y,color=y)) +
geom_point(aes(shape=factor(sv)))
所以你不想绘制支持向量吗?基于plot.svm
源代码,这里有一些适用于您的示例的基本内容。
https://github.com/cran/e1071/blob/master/R/svm.R
你可以通过查看源代码来构建更丰富的东西。
library(e1071)
library(ggplot2)
set.seed(1)
N=20
x=matrix(rnorm(n=N*2), ncol=2)
y=c(rep(-1,N/2), rep(1,N/2))
x[y==1,] = x[y==1,] + 1;x[y==1,]
dat = data.frame(x=x, y=as.factor(y))
svmfit=svm(y~., data=dat, kernel="linear", cost=10, scale=FALSE)
grid <- expand.grid(seq(min(dat[, 1]), max(dat[, 1]),length.out=100),
seq(min(dat[, 2]), max(dat[, 2]),length.out=100))
names(grid) <- names(dat)[1:2]
preds <- predict(svmfit, grid)
df <- data.frame(grid, preds)
ggplot(df, aes(x = x.2, y = x.1, fill = preds)) + geom_tile()
应该输出这个:
将此与plot.svm
输出进行比较:
plot(svmfit, dat)
编辑:
如果你想重现这些点,我稍微修改了上面的代码:
cols <- c('1' = 'red', '-1' = 'black')
tiles <- c('1' = 'magenta', '-1' = 'cyan')
shapes <- c('support' = 4, 'notsupport' = 1)
dat$support <- 'notsupport'
dat[svmfit$index, 'support'] <- 'support'
ggplot(df, aes(x = x.2, y = x.1)) + geom_tile(aes(fill = preds)) +
scale_fill_manual(values = tiles) +
geom_point(data = dat, aes(color = y, shape = support), size = 2) +
scale_color_manual(values = cols) +
scale_shape_manual(values = shapes) +
ggtitle('SVM classification plot')