Printing CPT of a given node

rickyegeland
Posts: 4
Joined: Fri Mar 29, 2013 8:13 pm


Post by rickyegeland »

I'm trying to learn the SMILE API and wanted to write a function that would print the CPT of a given node, without assuming any knowledge of the network. Would any experts be willing to look it over to see if I am accessing data in the "best" way? Getting the parent node's name and state names given the DSL_sysCoordinates was a bit tricky... maybe there's a shortcut?


#include <iostream>
#include "smile.h"
using namespace std;

void printCPT(DSL_node* node) {
  DSL_network* net = node->Network(); // node network
  int nid = node->Handle(); // node handle
  const char* name = node->GetId(); // node name
  DSL_idArray* nstates = node->Definition()->GetOutcomesNames(); // names of node states
  DSL_intArray parents = net->GetParents(nid); // handles (not names) of node's parents

  DSL_sysCoordinates coords(*(node->Definition())); // node coordinate navigator
  int result = DSL_OKAY;
  coords.GoFirst();
  while (result != DSL_OUT_OF_RANGE) {
    DSL_intArray cix = coords.Coordinates(); // state indexes: parents first, node last
    int six = cix[ cix.GetSize()-1 ]; // node's state index
    const char* state = (*nstates)[six]; // node's state name
    cout << "P(" << name << " = " << state << " | ";
    for (int pix = 0; pix < cix.GetSize() - 1; pix++) { // iterate parent nodes
      int pid = parents[pix]; // parent node handle
      six = cix[pix]; // parent node state index
      DSL_node* pnode = net->GetNode(pid); // parent node
      const char* pname = pnode->Info().Header().GetId(); // parent node name
      DSL_idArray* pstates = pnode->Definition()->GetOutcomesNames(); // parent node state names
      cout << pname << " = " << (*pstates)[six];
      if (pix + 1 < cix.GetSize() - 1) { cout << ", "; } // separator between parents
    }
    double prob = coords.UncheckedValue(); // probability at the current coordinates
    cout << ") = " << prob << endl;
    result = coords.Next(); // DSL_OUT_OF_RANGE after the last entry
  }
}
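DSL_sysCoordinates walks the CPT one entry at a time. The following sketch mimics that walk in plain standard C++ so the visiting order can be inspected without SMILE; it is an assumption about the behavior, not SMILE's implementation. The hypothetical `nextCoordinates` treats the coordinate array as an odometer over the dimension sizes (parents first, then the node itself, last coordinate fastest) and returns `false` once it runs past the final entry, which is the role `DSL_OUT_OF_RANGE` plays in the loop above.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (not part of SMILE): advance `coords` like an odometer
// over `dims`, where dims[i] is the number of states of dimension i and the
// last dimension (assumed to be the node's own state) changes fastest.
// Returns false after the last combination, leaving coords reset to zeros.
bool nextCoordinates(std::vector<int>& coords, const std::vector<int>& dims) {
    assert(coords.size() == dims.size());
    for (int d = static_cast<int>(dims.size()) - 1; d >= 0; --d) {
        if (++coords[d] < dims[d]) {
            return true;   // this digit did not overflow, so stop here
        }
        coords[d] = 0;     // overflow: reset this digit and carry leftward
    }
    return false;          // carried past the first dimension: all entries visited
}
```

For one parent with 2 states and a node with 3 states (`dims = {2, 3}`), starting from `{0, 0}` the calls visit `{0, 1}`, `{0, 2}`, `{1, 0}`, `{1, 1}`, `{1, 2}` and then return `false`.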
shooltz[BayesFusion]
Site Admin
Posts: 1417
Joined: Mon Nov 26, 2007 5:51 pm

Re: Printing CPT of a given node

Post by shooltz[BayesFusion] »

rickyegeland wrote: Getting the parent node's name and state names given the DSL_sysCoordinates was a bit tricky... maybe there's a shortcut?
You can avoid DSL_sysCoordinates and get a slightly shorter version of your function, but the general idea is the same: each CPT entry is addressed by a set of integer coordinates, one per parent followed by one for the node itself.


#include <iostream>
#include "smile.h"
using namespace std;

void printCPT2(DSL_node *node)
{
	DSL_network* net = node->Network(); // node network
	int handle = node->Handle();
	DSL_nodeDefinition *def = node->Definition();
	const DSL_Dmatrix &cpt = *def->GetMatrix(); // flat CPT storage
	const DSL_idArray &outcomes = *def->GetOutcomesNames();
	const DSL_intArray &parents = net->GetParents(handle); // parent handles
	int parentCount = parents.NumItems();

	DSL_intArray coords;
	for (int elemIdx = 0; elemIdx < cpt.GetSize(); elemIdx ++)
	{
		// convert the flat index into one state index per dimension
		cpt.IndexToCoordinates(elemIdx, coords);
		cout << "P(" << node->GetId() << " = " << outcomes[coords[parentCount]] << " | ";
		for (int parentIdx = 0; parentIdx < parentCount; parentIdx ++)
		{
			if (parentIdx > 0) cout << ", ";
			DSL_node *parentNode = net->GetNode(parents[parentIdx]);
			const DSL_idArray &parentStates = *parentNode->Definition()->GetOutcomesNames();
			cout << parentNode->GetId() << " = " << parentStates[coords[parentIdx]];
		}
		cout << ") = " << cpt[elemIdx] << endl;
	}
}
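The `IndexToCoordinates` call above turns a flat position in the CPT into one state index per dimension. The sketch below reproduces that idea in plain C++ so it can be tried without SMILE. It assumes, consistent with both posts reading the node's own state from the last coordinate, that the CPT is stored in row-major order with the node's dimension varying fastest; `MockNode`, `indexToCoordinates`, and `formatCpt` are hypothetical stand-ins, not SMILE API. Unlike the two functions in the thread, it also omits the ` | ` separator for nodes without parents.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for a SMILE node: plain names instead of handles.
struct MockNode {
    std::string id;
    std::vector<std::string> outcomes;
};

// Convert a flat CPT index into one coordinate per dimension, assuming
// row-major order where the last dimension (the node itself) varies fastest.
std::vector<int> indexToCoordinates(int index, const std::vector<int>& dims) {
    std::vector<int> coords(dims.size(), 0);
    for (int d = static_cast<int>(dims.size()) - 1; d >= 0; --d) {
        coords[d] = index % dims[d];
        index /= dims[d];
    }
    return coords;
}

// Format every entry of a flat CPT as "P(node = s | parent = t, ...) = p".
std::string formatCpt(const MockNode& node,
                      const std::vector<MockNode>& parents,
                      const std::vector<double>& cpt) {
    std::vector<int> dims;  // parents first, node last
    for (const MockNode& p : parents) dims.push_back(static_cast<int>(p.outcomes.size()));
    dims.push_back(static_cast<int>(node.outcomes.size()));

    int total = 1;
    for (int d : dims) total *= d;
    assert(static_cast<int>(cpt.size()) == total);

    std::ostringstream out;
    const std::size_t parentCount = parents.size();
    for (int i = 0; i < total; ++i) {
        std::vector<int> coords = indexToCoordinates(i, dims);
        out << "P(" << node.id << " = " << node.outcomes[coords[parentCount]];
        for (std::size_t k = 0; k < parentCount; ++k) {
            // " | " before the first parent, ", " between later ones
            out << (k == 0 ? " | " : ", ")
                << parents[k].id << " = " << parents[k].outcomes[coords[k]];
        }
        out << ") = " << cpt[i] << '\n';
    }
    return out.str();
}
```

With a parent `Rain` (`yes`, `no`), a child `Grass` (`wet`, `dry`), and `cpt = {0.9, 0.1, 0.2, 0.8}`, the first line is `P(Grass = wet | Rain = yes) = 0.9` and the last is `P(Grass = dry | Rain = no) = 0.8`.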
rickyegeland
Posts: 4
Joined: Fri Mar 29, 2013 8:13 pm

Re: Printing CPT of a given node

Post by rickyegeland »

Thanks, it is useful to see another way to do it!