// (Forum "Code: Select all" header from the original paste; not part of the program.)
void EM(DSL_network &theNet, string NBfile_Learning, string DiscParameterFileName)
{
DSL_dataset dataset;
cout<<"File to be read: "<<DiscParameterFileName<<endl;
if(dataset.ReadFile(DiscParameterFileName)!=DSL_OKAY)
{
cout<<"Reading failed!"<<endl;
return;
}
else cout<<"Reading successful!"<<endl;
DSL_network result;
vector<DSL_datasetMatch> matchedData;
string error;
if(dataset.MatchNetwork(theNet, matchedData, error) !=DSL_OKAY)
{
cout<<"Matching failed!"<<endl;
return;
}
cout<<"Matching successful!"<<endl;
DSL_em em;
cout<<"Learning in progress...."<<endl;
if (em.Learn(dataset, theNet, matchedData)!=DSL_OKAY)
{
cout << "Learning failed!" << endl;
return;
}
cout<< "Learning successful!" << endl;
theNet.WriteFile(NBfile_Learning.c_str());
}