博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
scikit-learn一般实例之八:多标签分类
阅读量:5809 次
发布时间:2019-06-18

本文共 3187 字,大约阅读时间需要 10 分钟。

本例模拟一个多标签文档分类问题.数据集基于下面的处理随机生成:

  • 选取标签的数目:泊松(n~Poisson,n_labels)
  • n次,选取类别C:多项式(c~Multinomial,theta)
  • 选取文档长度:泊松(k~Poisson,length)
  • k次,选取一个单词:多项式(w~Multinomial,theta_c)

在上面的处理中,拒绝抽样用来确保n大于2,文档长度不为0.同样,我们拒绝已经被选取的类别.被同事分配给两个分类的文档会被两个圆环包围.

通过投影到由PCA和CCA选取进行可视化的前两个主成分进行分类.接着通过元分类器使用两个线性核的SVC来为每个分类学习一个判别模型.注意,PCA用于无监督降维,CCA用于有监督.

注:在下面的绘制中,"无标签样例"不是说我们不知道标签(就像半监督学习中的那样),而是这些样例根本没有标签~~~

673170-20161005193051160-1300178212.png

# coding:utf-8import numpy as npfrom pylab import *from sklearn.datasets import make_multilabel_classificationfrom sklearn.multiclass import OneVsRestClassifierfrom sklearn.svm import SVCfrom sklearn.preprocessing import LabelBinarizerfrom sklearn.decomposition import PCAfrom sklearn.cross_decomposition import CCAmyfont = matplotlib.font_manager.FontProperties(fname="Microsoft-Yahei-UI-Light.ttc")mpl.rcParams['axes.unicode_minus'] = Falsedef plot_hyperplane(clf, min_x, max_x, linestyle, label):    # 获得分割超平面    w = clf.coef_[0]    a = -w[0] / w[1]    xx = np.linspace(min_x - 5, max_x + 5)  # 确保线足够长    yy = a * xx - (clf.intercept_[0]) / w[1]    plt.plot(xx, yy, linestyle, label=label)def plot_subfigure(X, Y, subplot, title, transform):    if transform == "pca":        X = PCA(n_components=2).fit_transform(X)    elif transform == "cca":        X = CCA(n_components=2).fit(X, Y).transform(X)    else:        raise ValueError    min_x = np.min(X[:, 0])    max_x = np.max(X[:, 0])    min_y = np.min(X[:, 1])    max_y = np.max(X[:, 1])    classif = OneVsRestClassifier(SVC(kernel='linear'))    classif.fit(X, Y)    plt.subplot(2, 2, subplot)    plt.title(title,fontproperties=myfont)    zero_class = np.where(Y[:, 0])    one_class = np.where(Y[:, 1])    plt.scatter(X[:, 0], X[:, 1], s=40, c='gray')    plt.scatter(X[zero_class, 0], X[zero_class, 1], s=160, edgecolors='b',               facecolors='none', linewidths=2, label=u'类别-1')    plt.scatter(X[one_class, 0], X[one_class, 1], s=80, edgecolors='orange',               facecolors='none', linewidths=2, label=u'类别-2')    plot_hyperplane(classif.estimators_[0], min_x, max_x, 'k--',                    u'类别-1的\n边界')    plot_hyperplane(classif.estimators_[1], min_x, max_x, 'k-.',                    u'类别-2的\n边界')    plt.xticks(())    plt.yticks(())    plt.xlim(min_x - .5 * max_x, max_x + .5 * max_x)    plt.ylim(min_y - .5 * max_y, max_y + .5 * max_y)    if subplot == 2:        plt.xlabel(u'第一主成分',fontproperties=myfont)        plt.ylabel(u'第二主成分',fontproperties=myfont)        plt.legend(loc="upper left",prop=myfont)plt.figure(figsize=(8, 6))X, Y = make_multilabel_classification(n_classes=2, n_labels=1,                                      allow_unlabeled=True,                                      random_state=1)plot_subfigure(X, Y, 1, u"有无标签样例 + CCA", "cca")plot_subfigure(X, Y, 2, u"有无标签样例 + PCA", "pca")X, Y = make_multilabel_classification(n_classes=2, n_labels=1,                                      allow_unlabeled=False,                                      random_state=1)plot_subfigure(X, Y, 3, u"没有无标签样例 + CCA", "cca")plot_subfigure(X, Y, 4, u"没有无标签样例 + PCA", "pca")plt.subplots_adjust(.04, .02, .97, .94, .09, .2)plt.suptitle(u"多标签分类", size=20,fontproperties=myfont)plt.show()

转载于:https://www.cnblogs.com/taceywong/p/5932682.html

你可能感兴趣的文章
DEV-C++ 调试方法简明图文教程(转)
查看>>
参加婚礼
查看>>
Java重写equals方法和hashCode方法
查看>>
Spark API编程动手实战-07-join操作深入实战
查看>>
Spring ’14 Wave Update: Installing Dynamics CRM on Tablets for Windows 8.1
查看>>
MySQL 备份与恢复
查看>>
TEST
查看>>
PAT A1037
查看>>
(六)Oracle学习笔记—— 约束
查看>>
[Oracle]如何在Oracle中设置Event
查看>>
top.location.href和localtion.href有什么不同
查看>>
02-创建hibernate工程
查看>>
Scrum之 Sprint计划会议
查看>>
svn命令在linux下的使用
查看>>
Gradle之module间依赖版本同步
查看>>
java springcloud版b2b2c社交电商spring cloud分布式微服务(十五)Springboot整合RabbitMQ...
查看>>
10g手动创建数据库
查看>>
Windwos Server 2008 R2 DHCP服务
查看>>
UVa 11292 勇者斗恶龙(The Dragon of Loowater)
查看>>
白话算法(7) 生成全排列的几种思路(二) 康托展开
查看>>