|
沙发
楼主 |
发表于 2020-2-26 16:30:16
|
只看该作者
#####################################加载库#######################################
from sklearn.datasets import load_digits
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # kept: registers the '3d' projection on older matplotlib
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']  # font that can render the Chinese title
plt.rcParams['axes.unicode_minus'] = False  # draw minus signs as ASCII '-' (the CJK font may lack U+2212)
##################################################################################
# Load the sklearn digits dataset: features in digits.data, integer labels 0-9 in digits.target.
digits = load_digits()
# Reduce the feature vectors to three dimensions with t-SNE for a 3-D scatter plot.
embeddings = TSNE(n_components=3).fit_transform(digits.data)
vis_x = embeddings[:, 0]  # embedding dimension 0
vis_y = embeddings[:, 1]  # embedding dimension 1
vis_z = embeddings[:, 2]  # embedding dimension 2
#######################################Plot########################################
# One colour and one marker per digit label 0-9, indexed by label.
# Every colour must be distinct: 'y' is already yellow, so 'orange' is used
# instead of a second 'yellow' that would make labels 2 and 7 look identical.
colors = ['b', 'c', 'y', 'm', 'r', 'g', 'k', 'orange', 'yellowgreen', 'wheat']
markers = ['h', '<', 'x', '.', 'p', '>', '^', 'd', 's', 'o']
fig = plt.figure()
# Axes3D(fig) stopped attaching itself to the figure in matplotlib 3.6
# (deprecated since 3.4), which yields an empty plot; add_subplot with
# projection='3d' is the supported way to create 3-D axes.
ax4 = fig.add_subplot(111, projection='3d')
for label, (color, marker) in enumerate(zip(colors, markers)):
    # Boolean mask of the samples whose label equals `label`. Comparing the
    # whole target array inside a per-index `if` (as before) is ambiguous and
    # raises ValueError, so the mask is built with one vectorised comparison.
    mask = digits.target == label
    # `cmap` is omitted: it is ignored when `c` is a single colour.
    ax4.scatter(vis_x[mask], vis_y[mask], vis_z[mask],
                c=color, marker=marker, label=str(label))
ax4.grid(False)  # hide grid lines
ax4.patch.set_facecolor(color="green")  # axes background colour
ax4.set_title('t-SNE降维分类效果图')  # "t-SNE dimensionality-reduction classification plot"
ax4.legend()
plt.show()
|
本帖子中包含更多资源
您需要 登录 才可以下载或查看,没有帐号?立即注册
x
|