朴素贝叶斯分类器实现MNIST手写数字识别(使用TensorFlow加载数据集,分类部分用NumPy实现)

import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST through the TF1 tutorial helper.
# NOTE(review): `tensorflow.examples.tutorials` presumably only exists in
# TensorFlow 1.x installs -- confirm the TF version this script targets.
# one_hot=False keeps labels as integer class ids; the classifier compares
# them directly against the class index (`labels == j`).
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
train_num = 5000   # number of training images (from the front of the set) used as the model
test_num = 100     # number of test images to classify
class_num = 10     # digit classes 0-9
desimon = 784      # feature dimension: pixels per flattened 28x28 MNIST image
# pca_desimon = 20


# Optional PCA reduction, currently disabled. If re-enabled, fit the mean and
# components on the training set and apply the *same* mean/components to the
# test set. Also note that the pixel-match threshold (0.8) used by the
# classifier assumes [0, 1] pixel values and would need retuning for
# projected features.
#
# def pca(X, target):
#     x_centered = X - np.mean(X, axis=0)
#     cov = np.cov(x_centered, rowvar=False)
#     # cov is symmetric, so use eigh: real eigenvalues (ascending order) and
#     # eigenvectors stored as the *columns* of the returned matrix.
#     eig_values, eig_vectors = np.linalg.eigh(cov)
#     top = np.argsort(eig_values)[::-1][:target]
#     return np.dot(x_centered, eig_vectors[:, top])
#

x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels

# Two pixel values "match" when they differ by less than this. MNIST pixels lie
# in [0, 1] and are mostly near 0 or near 1, so 0.8 effectively asks whether
# the pixel is on/off in both images.
match_threshold = 0.8
# Laplace (add-alpha) smoothing for the per-pixel match frequency. Without it,
# one pixel that no class-j training sample matches forces P(test | j) to 0
# regardless of every other pixel (and if all classes hit 0, argmax silently
# returns class 0). Setting this to 0 reproduces the unsmoothed estimate.
smoothing = 1.0


def _split_by_class(x, y, n_classes):
    """Return a list whose j-th entry holds the rows of `x` labelled j."""
    return [x[y == c] for c in range(n_classes)]


def _log_posterior_scores(sample, class_samples, n_total, n_dims, threshold, alpha):
    """Return the unnormalised naive-Bayes log posterior of each class.

    For class j the score is
        log P(j) + sum_k log P(|x_k - sample_k| < threshold | j)
    where P(j) = n_j / n_total and each per-pixel probability is the
    add-`alpha` smoothed fraction of class-j training samples whose pixel k
    lies within `threshold` of the test pixel. Only the first `n_dims`
    features are used. Log space avoids the float underflow of multiplying
    hundreds of ratios. A class with no training samples scores -inf.

    Assumes each sample is a flattened 1-D pixel vector and each entry of
    `class_samples` is a 2-D (n_j, n_features) array.
    """
    scores = []
    for samples_j in class_samples:
        n_j = len(samples_j)
        if n_j == 0:
            scores.append(-np.inf)
            continue
        # Vectorized over all samples and pixels of class j at once.
        hits = np.abs(samples_j[:, :n_dims] - sample[:n_dims]) < threshold
        ratio = (hits.sum(axis=0) + alpha) / (n_j + 2.0 * alpha)
        # With alpha == 0 a zero ratio yields log(0) = -inf, which is the
        # intended "impossible" score, so silence the divide warning.
        with np.errstate(divide='ignore'):
            scores.append(np.log(n_j / n_total) + np.log(ratio).sum())
    return scores


# Group the training subset by label once, instead of re-scanning the labels
# for every test sample and every class.
train_labels = y_train[:train_num]
class_samples = _split_by_class(x_train[:train_num], train_labels, class_num)

prediction = []
for i in range(test_num):
    class_rate = _log_posterior_scores(
        x_test[i], class_samples, len(train_labels), desimon,
        match_threshold, smoothing,
    )
    # The class with the largest posterior is the prediction.
    prediction.append(int(np.argmax(class_rate)))
    print(i, 'prediction:', prediction[-1], 'actual:', y_test[i])

accurancy = np.sum(np.equal(prediction, y_test[:test_num])) / test_num
print('accurancy:', accurancy)

实验结果:

0 prediction: 7 actual: 7
1 prediction: 2 actual: 2
2 prediction: 1 actual: 1
3 prediction: 0 actual: 0
4 prediction: 4 actual: 4
5 prediction: 1 actual: 1
6 prediction: 4 actual: 4
7 prediction: 9 actual: 9
8 prediction: 4 actual: 5
9 prediction: 7 actual: 9
10 prediction: 0 actual: 0
11 prediction: 6 actual: 6
12 prediction: 9 actual: 9
13 prediction: 0 actual: 0
14 prediction: 1 actual: 1
15 prediction: 3 actual: 5
16 prediction: 9 actual: 9
17 prediction: 7 actual: 7
18 prediction: 3 actual: 3
19 prediction: 4 actual: 4
20 prediction: 9 actual: 9
21 prediction: 6 actual: 6
22 prediction: 6 actual: 6
23 prediction: 5 actual: 5
24 prediction: 4 actual: 4
25 prediction: 0 actual: 0
26 prediction: 7 actual: 7
27 prediction: 4 actual: 4
28 prediction: 0 actual: 0
29 prediction: 1 actual: 1
30 prediction: 3 actual: 3
31 prediction: 1 actual: 1
32 prediction: 3 actual: 3
33 prediction: 0 actual: 4
34 prediction: 7 actual: 7
35 prediction: 2 actual: 2
36 prediction: 7 actual: 7
37 prediction: 1 actual: 1
38 prediction: 3 actual: 2
39 prediction: 1 actual: 1
40 prediction: 1 actual: 1
41 prediction: 7 actual: 7
42 prediction: 4 actual: 4
43 prediction: 2 actual: 2
44 prediction: 3 actual: 3
45 prediction: 3 actual: 5
46 prediction: 5 actual: 1
47 prediction: 2 actual: 2
48 prediction: 9 actual: 4
49 prediction: 4 actual: 4
50 prediction: 6 actual: 6
51 prediction: 3 actual: 3
52 prediction: 5 actual: 5
53 prediction: 5 actual: 5
54 prediction: 6 actual: 6
55 prediction: 8 actual: 0
56 prediction: 4 actual: 4
57 prediction: 1 actual: 1
58 prediction: 9 actual: 9
59 prediction: 5 actual: 5
60 prediction: 7 actual: 7
61 prediction: 8 actual: 8
62 prediction: 9 actual: 9
63 prediction: 2 actual: 3
64 prediction: 7 actual: 7
65 prediction: 5 actual: 4
66 prediction: 2 actual: 6
67 prediction: 4 actual: 4
68 prediction: 3 actual: 3
69 prediction: 0 actual: 0
70 prediction: 7 actual: 7
71 prediction: 0 actual: 0
72 prediction: 2 actual: 2
73 prediction: 8 actual: 9
74 prediction: 1 actual: 1
75 prediction: 7 actual: 7
76 prediction: 3 actual: 3
77 prediction: 9 actual: 2
78 prediction: 9 actual: 9
79 prediction: 7 actual: 7
80 prediction: 9 actual: 7
81 prediction: 6 actual: 6
82 prediction: 2 actual: 2
83 prediction: 7 actual: 7
84 prediction: 8 actual: 8
85 prediction: 4 actual: 4
86 prediction: 7 actual: 7
87 prediction: 3 actual: 3
88 prediction: 6 actual: 6
89 prediction: 1 actual: 1
90 prediction: 3 actual: 3
91 prediction: 6 actual: 6
92 prediction: 4 actual: 9
93 prediction: 3 actual: 3
94 prediction: 1 actual: 1
95 prediction: 4 actual: 4
96 prediction: 1 actual: 1
97 prediction: 8 actual: 7
98 prediction: 6 actual: 6
99 prediction: 9 actual: 9
accurancy: 0.83

Process finished with exit code 0
 

已标记关键词 清除标记
©️2020 CSDN 皮肤主题: Age of Ai 设计师:meimeiellie 返回首页