52matlab技术网站,matlab教程,matlab安装教程,matlab下载

标题: 用t-SNE降维查看分类效果 [打印本页]

作者: matlab的旋律    时间: 2020-2-26 14:57
标题: 用t-SNE降维查看分类效果
本帖最后由 matlab的旋律 于 2020-3-9 15:42 编辑

#####################################加载库#######################################
from sklearn.datasets import load_digits
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt

##################################################################################

# Load scikit-learn's bundled handwritten-digits dataset.
digits = load_digits()

# Reduce the feature vectors to 2-D with t-SNE (n_components defaults to 2).
# A fixed random_state keeps the layout identical between runs.
embeddings = TSNE(random_state=0).fit_transform(digits.data)
vis_x = embeddings[:, 0]  # first embedding dimension
vis_y = embeddings[:, 1]  # second embedding dimension

####################################### Plot ########################################
# One colour and one marker per digit class 0..9. Every entry is a distinct
# colour ('y' and 'yellow' would be the same colour, so 'orange' is used instead).
colors = ['b', 'c', 'y', 'm', 'r', 'g', 'k', 'orange', 'yellowgreen', 'wheat']
markers = ['h', '<', 'x', '.', 'p', '>', '^', 'd', 's', 'o']

for label, (color, marker) in enumerate(zip(colors, markers)):
    # Boolean mask selecting every sample whose true label equals `label`.
    # This indexes the numpy embedding arrays directly. Comparing the whole
    # target array inside a Python `if` would raise an ambiguity error.
    mask = digits.target == label
    # `cmap` is intentionally not passed: it is ignored when `c` is a single
    # named colour, and matplotlib warns about it.
    plt.scatter(vis_x[mask], vis_y[mask], c=color, marker=marker, label=str(label))

plt.title('t-SNE')
plt.legend()
plt.show()




作者: matlab的旋律    时间: 2020-2-26 16:30
#####################################加载库#######################################
from sklearn.datasets import load_digits
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Use a font that contains CJK glyphs so the Chinese title renders.
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
# Draw minus signs with the ASCII hyphen-minus; many CJK fonts lack U+2212.
plt.rcParams['axes.unicode_minus'] = False
##################################################################################

# Load scikit-learn's bundled handwritten-digits dataset.
digits = load_digits()

# Reduce the feature vectors to 3-D with t-SNE. A fixed random_state keeps
# the layout identical between runs.
embeddings = TSNE(n_components=3, random_state=0).fit_transform(digits.data)
vis_x = embeddings[:, 0]  # first embedding dimension
vis_y = embeddings[:, 1]  # second embedding dimension
vis_z = embeddings[:, 2]  # third embedding dimension

####################################### Plot ########################################
# One colour and one marker per digit class 0..9. Every entry is a distinct
# colour ('y' and 'yellow' would be the same colour, so 'orange' is used instead).
colors = ['b', 'c', 'y', 'm', 'r', 'g', 'k', 'orange', 'yellowgreen', 'wheat']
markers = ['h', '<', 'x', '.', 'p', '>', '^', 'd', 's', 'o']

fig = plt.figure()
# Create the 3-D axes through the figure. Since Matplotlib 3.6, constructing
# Axes3D(fig) directly no longer attaches the axes to the figure, which leaves
# an empty plot window.
ax4 = fig.add_subplot(projection='3d')

for label, (color, marker) in enumerate(zip(colors, markers)):
    # Boolean mask selecting every sample whose true label equals `label`.
    mask = digits.target == label
    # `cmap` is intentionally not passed: it is ignored when `c` is a single
    # named colour, and matplotlib warns about it.
    ax4.scatter(vis_x[mask], vis_y[mask], vis_z[mask],
                c=color, marker=marker, label=str(label))

ax4.grid(False)  # hide the grid lines
ax4.patch.set_facecolor(color="green")

ax4.set_title('t-SNE降维分类效果图')
ax4.legend()
plt.show()








欢迎光临 52matlab技术网站,matlab教程,matlab安装教程,matlab下载 (http://www.52matlab.com/) Powered by Discuz! X3.2