机器学习朴素贝叶斯作业 3.0

机器学习朴素贝叶斯作业 3.0,第1张

机器学习朴素贝叶斯作业 3.0 书接上回

机器学习朴素贝叶斯作业

获取每类数据在每个属性集上的均值与方差 目的:获取数据特征,便于对数据进行分类 核心代码
def fit():  # 求均值&方差
    mean_var = [[np.mean(X[y == label], axis=0), np.var(X[y == label], axis=0)] for label in np.unique(y)]
    return mean_var
所有代码
import numpy as np


def fit():  # 求均值&方差
    mean_var = [[np.mean(X[y == label], axis=0), np.var(X[y == label], axis=0)] for label in np.unique(y)]
    return mean_var


if __name__ == '__main__':
    # 读取数据
    data = np.loadtxt("iris.csv",  # 数据源
                      dtype='str',  # 读取类型
                      delimiter=',',  # 分割符号
                      skiprows=1)

    # 数据预处理
    X = data[::, 0:-1].astype('float32')
    y = data[:, -1]
    print(fit())
输出结果
[[array([5.006, 3.428, 1.462, 0.246]), array([0.121764, 0.140816, 0.029556, 0.010884])], 
[array([5.936, 2.77 , 4.26 , 1.326]), array([0.261104, 0.0965  , 0.2164  , 0.038324])], 
[array([6.588, 2.974, 5.552, 2.026]), array([0.396256, 0.101924, 0.298496, 0.073924])]]
### array([5.006, 3.428, 1.462, 0.246]) 
### 为第一类(Setosa)在sepal.length sepal.width petal.length petal.width四个属性的均值
### array([0.121764, 0.140816, 0.029556, 0.010884]) 
### 为第一类(Setosa)在sepal.length sepal.width petal.length petal.width四个属性的方差
### array([5.936, 2.77 , 4.26 , 1.326])
### 为第二类(Versicolor)在sepal.length sepal.width petal.length petal.width四个属性的均值
### array([0.261104, 0.0965  , 0.2164  , 0.038324])
### 为第二类(Versicolor)在sepal.length sepal.width petal.length petal.width四个属性的方差
### array([6.588, 2.974, 5.552, 2.026])
### 为第三类(Virginica)在sepal.length sepal.width petal.length petal.width四个属性的均值
### array([0.396256, 0.101924, 0.298496, 0.073924])
### 为第三类(Virginica)在sepal.length sepal.width petal.length petal.width四个属性的方差
模型评估 目的:对模型准确率进行检测 核心代码
    def score(self, feature, label):  # 模型评估
        return np.sum(np.array([[np.argmax(np.array([np.array(
            self.GaussianProbability(sample, np.array([i[0] for i in self.fit()])[j],
                                     np.array([i[1] for i in self.fit()])[j])).prod() for j in
                                                     range(len(np.unique(label)))]))] for sample in
                                feature]).ravel() == label) / label.size
手撕核心代码 判断预测结果与真实标签是否一致
### 注意为核心代码块 由外及里 暂时先不考虑内容 
np.sum(np.array([<core> for sample in feature]).ravel() == label) / label.size
获取每类数据在每个属性上的均值与方差后求贝叶斯概率,由于朴素贝叶斯的属性条件独立性假设,每个属性独立的对分类结果产生影响, 因此需要对每个条件概率累乘后取最大值索引,即为预测结果。
### 代码
np.argmax(np.array([np.array(
            self.GaussianProbability(sample, np.array([i[0] for i in self.fit()])[j],
                                     np.array([i[1] for i in self.fit()])[j])).prod() for j in
                                                     range(len(np.unique(label)))]))
所有代码
import numpy as np


def GaussianProbability(x, mean, var):  # 高斯概率密度函数
    return np.array(
        [1 / (np.sqrt(2 * np.pi) * var[i]) *
         np.exp(-np.power(x[i] - mean[i], 2) / (2 * np.power(var[i], 2))) for i in range(len(x))])


def fit():  # 求均值&方差
    mean_var = [[np.mean(X[y == label], axis=0), np.var(X[y == label], axis=0)] for label in np.unique(y)]
    return mean_var


def score(feature, label):  # 模型评估
    return np.sum(np.array([[np.argmax(np.array([np.array(
        GaussianProbability(sample, np.array([i[0] for i in fit()])[j],
                            np.array([i[1] for i in fit()])[j])
    ).prod() for j in range(len(np.unique(label)))]))] for sample in
                            feature]).ravel() == label) / label.size


if __name__ == '__main__':
    # 读取数据
    data = np.loadtxt("iris.csv",  # 数据源
                      dtype='str',  # 读取类型
                      delimiter=',',  # 分割符号
                      skiprows=1)

    # 数据预处理
    X = data[::, 0:-1].astype('float32')
    y = data[:, -1]
    y = [0 if i == 'Setosa' else i for i in list(y)]
    y = [1 if i == 'Versicolor' else i for i in list(y)]
    y = [2 if i == 'Virginica' else i for i in list(y)]
    print(f'模型准确率为:{score(X, np.array(y)) * 100}%')

原创不易 转载请标明出处
如果对你有所帮助 别忘啦点赞支持哈

欢迎分享,转载请注明来源:内存溢出

原文地址: http://www.outofmemory.cn/langs/717256.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-25
下一篇 2022-04-25

发表评论

登录后才能评论

评论列表(0条)

保存