C/C++知识点之【EM】C++代码实现
小标 2018-12-03 来源 : 阅读 1062 评论 0

摘要:本文主要向大家介绍了 C/C++知识点之【EM】C++代码实现,通过具体的内容向大家展示,希望对大家学习C/C++知识点有所帮助。

本文主要向大家介绍了 C/C++知识点之【EM】C++代码实现,通过具体的内容向大家展示,希望对大家学习C/C++知识点有所帮助。

看了原理和比人的代码后,终于自己写了一个EM的实现。

我从网上找了一些身高性别的数据,用EM算法通过身高信息来识别性别。
实现的效果还行,正确率有84% (初始数据 男生170 女生160 方差都是10)
                                 79%  (初始数据 男生165 女生150 方差都是10)
正确率与初始值有关。

/*
试图用EM算法来根据输入的身高来区分性别
*/

#include
#include
#include
#include
using namespace std;

#define PI 3.14159
#define max(x,y) (x > y ? x : y)

typedef struct FLOAT2
{
    float f1;
    float f2;
}FLOAT2;
typedef struct Gaussian
{
    float mean;
    float var;
}Gaussian;

typedef struct EMData
{
    char sex;
    float fHeight;
}EMData;

//获取身高性别数据
int getdata(vector &Data)
{
    ifstream fin;
    fin.open("data.txt");
    if(!fin)
    {
        cout<<"error: can‘t open the file."<<endl;
        return -1;
    }

    while(!fin.eof())
    {
        char c[10];
        float height;
        fin >> c >> height;
        EMData data;
        data.sex = c[0];
        data.fHeight = height;
        Data.push_back(data);
    }

    return 0;
}

//根据身高数据区分性别, 返回正确率
float predict(vector Data)
{
    //设符合正态分布
    Gaussian sex[2];
    float a[2]; //男女生所占百分比
    float t = 1;
    float tlimit = 0.000001; //收敛条件

    //赋初值 下标0表示男生 1表示女生
    sex[0].mean = 180.0;
    sex[0].var = 10.0;
    sex[1].mean = 150.0;
    sex[1].var = 10.0;
    a[0] = 0.5;
    a[1] = 0.5;

    while(t > tlimit)
    {
        Gaussian sex_old[2];
        float a_old[2];
        sex_old[0] = sex[0];
        sex_old[1] = sex[1];
        a_old[0] = a[0];
        a_old[1] = a[1];

        //计算每个样本分别被两个模型抽中的概率
        vector px;
    
        vector::iterator it;
        for(it = Data.begin(); it < Data.end(); it++)
        {
            FLOAT2 p;
            p.f1 = 1/(sqrt(2 * PI * sex[0].var)) * exp(-(it->fHeight - sex[0].mean) * (it->fHeight - sex[0].mean) / (2 * sex[0].var));
            p.f2 = 1/(sqrt(2 * PI * sex[1].var)) * exp(-(it->fHeight - sex[1].mean) * (it->fHeight - sex[1].mean) / (2 * sex[1].var));
            px.push_back(p);
        }

        //E步
        //计算每个样本属于男生或女生的概率
        vector::iterator it2;
        for(it2 = px.begin(); it2 < px.end(); it2++)
        {
            float sum = 0.0;
            (*it2).f1 *= a[0];
            sum += (*it2).f1;
            (*it2).f2 *= a[1];
            sum += (*it2).f2;

            (*it2).f1 = (*it2).f1/sum;
            (*it2).f2 = (*it2).f2/sum;
        }

        //M步
        float sum_male = 0, sum_female = 0;
        float sum_mean_male = 0, sum_mean_female = 0;
        for(it2 = px.begin(), it = Data.begin(); it2 < px.end(); it2++, it++)
        {
            sum_male += (*it2).f1;
            sum_female += (*it2).f2;
            sum_mean_male += (*it2).f1 * (it->fHeight);
            sum_mean_female += (*it2).f2 * (it->fHeight);
        }
        //更新a
        a[0] = sum_male/(sum_male + sum_female);
        a[1] = sum_female/(sum_male + sum_female);

        //更新均值
        sex[0].mean = sum_mean_male/ sum_male;
        sex[1].mean = sum_mean_female/ sum_female;

        //更新方差
        float sum_var_male = 0, sum_var_female = 0;
        for(it2 = px.begin(), it = Data.begin(); it2 < px.end(); it2++, it++)
        {
            sum_var_male += (*it2).f1 * ((it->fHeight) - sex[0].mean) * ((it->fHeight) - sex[0].mean);
            sum_var_female += (*it2).f2 * ((it->fHeight) - sex[1].mean) * ((it->fHeight) - sex[1].mean);
        }
        sex[0].var = sum_var_male / sum_male;
        sex[1].var = sum_var_female / sum_female;

        //计算变化率
        t = max((a[0] - a_old[0])/a_old[0], (a[1] - a_old[1])/a_old[1]);
        t = max(t, (sex[0].mean - sex_old[0].mean)/sex_old[0].mean);
        t = max(t, (sex[1].mean - sex_old[1].mean)/sex_old[1].mean);
        t = max(t, (sex[0].var - sex_old[0].var)/sex_old[0].var);
        t = max(t, (sex[1].var - sex_old[1].var)/sex_old[1].var);
    }

    //计算正确率
    int correct_num = 0;
    float correct_rate = 0;
    vector::iterator it;
    for(it = Data.begin(); it < Data.end(); it++)
    {
        float p[2];
        char csex;
        for(int i = 0; i < 2; i++)
        {
            p[i] = 1/(sqrt(2 * PI * sex[i].var)) * exp(-(it->fHeight - sex[i].mean) * (it->fHeight - sex[i].mean) / (2 * sex[i].var));
        }

        csex = (p[0] > p[1]) ? ‘m‘ : ‘f‘;
        if(csex == it->sex)
            correct_num++;
    }

    correct_rate = (float)correct_num / Data.size();
    return correct_rate;
}

int main()
{
    vector Data;
    getdata(Data);
    float correct_rate = predict(Data);
    cout << "correct rate = "<< correct_rate << endl;
    return 0;
}

 
数据:data.txt内容

male    164
female    156
male    168
female    160
female    162
male    187
female    162
male    167
female    160.5
female    160
female    158
female    164
female    165
male    174
female    166
female    158
male     162
male    175
male    170
female    161
female    169
female    161
female    160
female    167
male    176
male    169
male    178
male    165
female    155
male    183
male    171
male    179
female    154
male    172
female    172
male    173
male    172
male    175
male    160
male    160
male    160
male    175
male    163
male    181
male    172
male    175
male    175
male    167
male    172
male    169
male    172
male    175
male    172
male    170
male    158
male    167
male    164
male    176
male    182
male    173
male    176
male    163
male    166
male    162
male    169
male    163
male    163
male    176
male    169
male    173
male    163
male    167
male    176
male    168
male    167
male    170
female    155
female    157
female    165
female    156
female    155
female    156
female    160
female    158
female    162
female    162
female    155
female    163
female    160
female    162
female    165
female    159
female    147
female    163
female    157
female    160
female    162
female    158
female    155
female    165
female    161
female    159
female    163
female    158
female    155
female    162
female    157
female    159
female    152
female    156
female    165
female    154
female    156
female    162

本文由职坐标整理并发布,希望对同学们有所帮助。了解更多详情请关注职坐标编程语言C/C+频道!

本文由 @小标 发布于职坐标。未经许可,禁止转载。
喜欢 | 0 不喜欢 | 0
看完这篇文章有何感觉?已经有0人表态,0%的人喜欢 快给朋友分享吧~
评论(0)
后参与评论

您输入的评论内容中包含违禁敏感词

我知道了

助您圆梦职场 匹配合适岗位
验证码手机号,获得海同独家IT培训资料
选择就业方向:
人工智能物联网
大数据开发/分析
人工智能Python
Java全栈开发
WEB前端+H5

请输入正确的手机号码

请输入正确的验证码

获取验证码

您今天的短信下发次数太多了,明天再试试吧!

提交

我们会在第一时间安排职业规划师联系您!

您也可以联系我们的职业规划师咨询:

小职老师的微信号:z_zhizuobiao
小职老师的微信号:z_zhizuobiao

版权所有 职坐标-一站式IT培训就业服务领导者 沪ICP备13042190号-4
上海海同信息科技有限公司 Copyright ©2015 www.zhizuobiao.com,All Rights Reserved.
 沪公网安备 31011502005948号    

©2015 www.zhizuobiao.com All Rights Reserved

208小时内训课程