I've built a simple network from scratch and want to perform inference with it. If I save it to disk and load it in GeNIe, everything is fine, but when I use the following code, the line
Code: Select all
DSLNetwork.GetNode(_queryNodesIdx2[i])->Value()->IsValueValid() != DSL_OKAY
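evaluates to true, so the program reports that the query node's value isn't valid. What am I doing wrong?
Thanks in advance,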
nordic
Code: Select all
DSL_Dmatrix cptObjectClass;
DSL_Dmatrix cptArea;
vector<string> stateNamesObjectclass;
vector<string> stateNamesArea;
DSL_network DSLNetwork;
// setup DSL reference network
cptObjectClass.AddDimension(2);
cptObjectClass[0] = 0.5;
cptObjectClass[1] = 0.5;
cptArea.AddDimension(5);
cptArea.AddDimension(2);
cptArea[0] = 0.4;
cptArea[1] = 0.4;
cptArea[2] = 0.2;
cptArea[3] = 0;
cptArea[4] = 0;
cptArea[5] = 0;
cptArea[6] = 0;
cptArea[7] = 0.2;
cptArea[8] = 0.4;
cptArea[9] = 0.4;
stateNamesObjectclass.resize(2);
stateNamesObjectclass[0] = "Pier";
stateNamesObjectclass[1] = "Wasser";
stateNamesArea.resize(5);
stateNamesArea[0] = "state_1_below_16";
stateNamesArea[1] = "state_2_16_44";
stateNamesArea[2] = "state_3_44_372";
stateNamesArea[3] = "state_4_372_31134";
stateNamesArea[4] = "state_5_31134_up";
int idxObjectclass = DSLNetwork.AddNode(DSL_CPT, "objectclass");
int idxArea = DSLNetwork.AddNode(DSL_CPT, "Area");
DSL_stringArray outcomesObjectclass;
if( outcomesObjectclass.Add( "Pier") < 0 )
std::perror( "Cannot add state names to node \"objectclass\"");
if( !outcomesObjectclass.Add( "Wasser") < 0)
std::perror( "Cannot add state names to node \"objectclass\"");
DSL_stringArray outcomesArea;
outcomesArea.Add(stateNamesArea.at(0).c_str());
outcomesArea.Add(stateNamesArea.at(1).c_str());
outcomesArea.Add(stateNamesArea.at(2).c_str());
outcomesArea.Add(stateNamesArea.at(3).c_str());
outcomesArea.Add(stateNamesArea.at(4).c_str());
DSLNetwork.GetNode( DSLNetwork.FindNode("objectclass") )->Definition()->SetNumberOfOutcomes(outcomesObjectclass);
DSLNetwork.GetNode( DSLNetwork.FindNode("Area") )->Definition()->SetNumberOfOutcomes(outcomesArea);
DSLNetwork.AddArc( idxObjectclass, idxArea);
DSLNetwork.GetNode( DSLNetwork.FindNode("objectclass") )->Definition()->SetDefinition(cptObjectClass);
DSLNetwork.GetNode( DSLNetwork.FindNode("Area") )->Definition()->SetDefinition(cptArea);
// setup inference parameter
vector<int> _evidenceNodesIdx2;
vector<int> _queryNodesIdx2;
vector<int> _evidence2;
vector<vector<double> > _beliefs2;
_evidenceNodesIdx2.push_back(idxObjectclass);
_queryNodesIdx2.push_back(idxArea);
_beliefs2.resize(_queryNodesIdx2.size());
_beliefs2.at(0).push_back(0.4);
_beliefs2.at(0).push_back(0.4);
_beliefs2.at(0).push_back(0.2);
_beliefs2.at(0).push_back(0);
_beliefs2.at(0).push_back(0);
_evidence2.push_back(0);
// set Evidence
for(int i=0; i< _evidenceNodesIdx2.size(); i++)
{
if(DSLNetwork.GetNode(_evidenceNodesIdx2[i])->Value()->SetEvidence(_evidence2[i]) != DSL_OKAY)
{
cout <<"\t\t\t" << ErrorH.GetLastErrorMessage() << endl;
return 1;
}
}
DSLNetwork.SetDefaultBNAlgorithm(DSL_ALG_BN_LAURITZEN);
// do inference
if( DSLNetwork.UpdateBeliefs() != DSL_OKAY)
{
cout << "\t\t\t" << ErrorH.GetLastErrorMessage() << endl;
return 1;
}
// get beliefs
vector< vector<double> > beliefs(_queryNodesIdx2.size()); // inner vector is still empty
for(int i=0; i<_queryNodesIdx2.size(); i++)
{
DSL_Dmatrix currentBeliefs( *(DSLNetwork.GetNode(_queryNodesIdx2[i])->Value()->GetMatrix()) );
if( DSLNetwork.GetNode(_queryNodesIdx2[i])->Value()->IsValueValid() != DSL_OKAY )
{
cout << "NodeValue of node with idx: " << _queryNodesIdx2[i] << " isn't valid " << endl;
return 1;
}
//convert into STL class
(beliefs[i]).resize( currentBeliefs.GetSize() ); // init inner vector
//iterate over all entries
for(int j=0; j<currentBeliefs.GetSize(); ++j)
{
beliefs[i][j] = currentBeliefs[j];
}
}
//prepare for next inference
if(DSLNetwork.ClearAllEvidence() != DSL_OKAY)
{
cout <<"\t\t" << ErrorH.GetLastErrorMessage() << endl;
return -1;
}
// print beliefs
cout << "current beliefs: " << endl;
for(int i=0; i< beliefs.size(); ++i)
for(int j=0; j < beliefs.at(i).size(); ++j)
cout << beliefs.at(i).at(j) << " " << endl;
// print reference
cout << "refernce beliefs: " << endl;
for(int i=0; i< _beliefs2.size(); ++i)
for(int j=0; j < _beliefs2.at(i).size(); ++j)
cout << _beliefs2.at(i).at(j) << " " << endl;
return 0;
}