<< Click to Display Table of Contents >> Navigation: Appendix M: Matlab and SMILE > matsmile.cpp |
// matsmile.cpp
// Simple MEX wrapper for use with existing discrete Bayesian networks
#include "mex.hpp"
#include "mexAdapter.hpp"
#include "smile_license.h"
#include "smile.h"
using namespace std;
class MexFunction : public matlab::mex::Function {
public:
void operator()(matlab::mex::ArgumentList outputs, matlab::mex::ArgumentList inputs) {
auto f = findFunction(outputs, inputs);
(this->*(f.ptr))(outputs, inputs);
}
private:
shared_ptr<matlab::engine::MATLABEngine> matlabPtr = getEngine();
matlab::data::ArrayFactory factory;
union uint64netptr { uint64_t integer; DSL_network* pointer; };
typedef void (MexFunction::* PTR)(matlab::mex::ArgumentList&, matlab::mex::ArgumentList&);
struct MatSmileFunction {
const char* name;
int inputSize;
PTR ptr;
};
void error(const string& msg) {
matlabPtr->feval(u"error",
0, vector<matlab::data::Array>({ factory.createScalar(msg.c_str()) }));
}
matlab::data::TypedArray<matlab::data::MATLABString> createStringArray(const vector<string>& vec) {
return factory.createArray({ vec.size() }, vec.begin(), vec.end());
}
const MatSmileFunction& findFunction(matlab::mex::ArgumentList& outputs, matlab::mex::ArgumentList& inputs) {
static const MatSmileFunction wrapFx[] = {
"newNetwork", 1, &MexFunction::newNetwork,
"deleteNetwork", 2, &MexFunction::deleteNetwork,
"getNodeCount", 2, &MexFunction::getNodeCount,
"readFile", 3, &MexFunction::readFile,
"updateBeliefs", 2, &MexFunction::updateBeliefs,
"setEvidence", 4, &MexFunction::setEvidence,
"getValue", 3, &MexFunction::getValue,
"getAllNodeIds", 2, &MexFunction::getAllNodeIds,
"getOutcomeIds", 3, &MexFunction::getOutcomeIds,
"isEvidence", 3, &MexFunction::isEvidence,
"getEvidence", 3, &MexFunction::getEvidence,
"getOutcomeCount", 3, &MexFunction::getOutcomeCount
};
matlab::data::CharArray charArray = inputs[0];
auto functionName = charArray.toAscii().c_str();
const auto iterator = find_if(begin(wrapFx), end(wrapFx), [&functionName](auto f) {
return !strcmp(f.name, functionName);
});
if (iterator == end(wrapFx)) {
string msg = "Cannot find function with name ";
msg += functionName;
error(msg);
}
int inputSize = iterator->inputSize;
if (inputs.size() != inputSize) {
string msg = "Invalid input size for function: ";
msg += functionName;
msg += ". Expected: ";
DSL_appendInt(msg, inputSize);
msg += ", given: ";
DSL_appendInt(msg, (int)inputs.size());
error(msg);
}
if (outputs.size() > 1) {
error("Outputs size cannot be greater than 1.");
}
return *iterator;
}
int validateNodeId(const DSL_network& net, const char* nodeId) {
int handle = net.FindNode(nodeId);
if (handle < 0) {
string msg = "Cannot find node with ID '";
msg += nodeId;
msg += '\'';
error(msg);
}
return handle;
}
int validateOutcomeId(const DSL_network& net, int nodeHandle, const char* outcomeId) {
const DSL_node* node = net.GetNode(nodeHandle);
const DSL_idArray* outcomeNames = node->Def()->GetOutcomeIds();
int outcomeIndex = outcomeNames->FindPosition(outcomeId);
if (outcomeIndex < 0) {
string msg = "Invalid outcome identifier '";
msg += outcomeId;
msg += "' for node '";
msg += node->GetId();
msg += '\'';
error(msg);
}
return outcomeIndex;
}
uint64_t encodePtr(DSL_network* pointer) {
uint64netptr ivp;
ivp.pointer = pointer;
return ivp.integer;
}
DSL_network* decodePtr(uint64_t integer) {
uint64netptr ivp;
ivp.integer = integer;
return ivp.pointer;
}
void newNetwork(matlab::mex::ArgumentList& outputs, matlab::mex::ArgumentList& inputs) {
auto net = new DSL_network();
outputs[0] = factory.createScalar<uint64_t>(encodePtr(net));
}
void deleteNetwork(matlab::mex::ArgumentList& outputs, matlab::mex::ArgumentList& inputs) {
matlab::data::TypedArray<uint64_t> arr = inputs[1];
auto net = decodePtr(arr[0]);
delete net;
}
void getNodeCount(matlab::mex::ArgumentList& outputs, matlab::mex::ArgumentList& inputs) {
matlab::data::TypedArray<uint64_t> arr = inputs[1];
auto net = decodePtr(arr[0]);
outputs[0] = factory.createScalar(net->GetNumberOfNodes());
}
void readFile(matlab::mex::ArgumentList& outputs, matlab::mex::ArgumentList& inputs) {
matlab::data::TypedArray<uint64_t> arr = inputs[1];
auto net = decodePtr(arr[0]);
matlab::data::CharArray charArray = inputs[2];
auto filePath = charArray.toAscii();
int result = net->ReadFile(filePath.c_str());
if (result != DSL_OKAY) {
string msg = "Cannot read file from '";
msg += filePath;
msg += "' ErrNo ";
DSL_appendInt(msg, result);
error(msg);
}
}
void updateBeliefs(matlab::mex::ArgumentList& outputs, matlab::mex::ArgumentList& inputs) {
matlab::data::TypedArray<uint64_t> arr = inputs[1];
auto net = decodePtr(arr[0]);
int result = net->UpdateBeliefs();
if (result != DSL_OKAY) {
string msg = "Update Beliefs failed. ErrNo ";
DSL_appendInt(msg, result);
error(msg);
}
}
void setEvidence(matlab::mex::ArgumentList& outputs, matlab::mex::ArgumentList& inputs) {
matlab::data::TypedArray<uint64_t> arr = inputs[1];
auto net = decodePtr(arr[0]);
matlab::data::CharArray nodeIdCharArray = inputs[2];
matlab::data::CharArray outcomeIdCharArray = inputs[3];
int nodeHandle = validateNodeId(*net, nodeIdCharArray.toAscii().c_str());
int outcomeIndex = validateOutcomeId(*net, nodeHandle, outcomeIdCharArray.toAscii().c_str());
net->GetNode(nodeHandle)->Val()->SetEvidence(outcomeIndex);
}
void getValue(matlab::mex::ArgumentList& outputs, matlab::mex::ArgumentList& inputs) {
matlab::data::TypedArray<uint64_t> arr = inputs[1];
auto net = (DSL_network*)decodePtr(arr[0]);
matlab::data::CharArray nodeIdCharArray = inputs[2];
auto nodeId = nodeIdCharArray.toAscii();
int nodeHandle = validateNodeId(*net, nodeId.c_str());
const DSL_nodeVal* nodeValue = net->GetNode(nodeHandle)->Val();
if (!nodeValue->IsValueValid()) {
string msg = "Invalid node value for node ";
msg += nodeId;
error(msg);
}
const DSL_Dmatrix* m = nodeValue->GetMatrix();
const double* p = m->GetItems().Items();
size_t arraySize = m->GetSize();
matlab::data::TypedArray<double> resArr = factory.createArray({ arraySize }, p, p + arraySize);
outputs[0] = resArr;
}
void getAllNodeIds(matlab::mex::ArgumentList& outputs, matlab::mex::ArgumentList& inputs) {
matlab::data::TypedArray<uint64_t> arr = inputs[1];
auto net = (DSL_network*)decodePtr(arr[0]);
size_t count = net->GetNumberOfNodes();
vector<string> ids;
ids.resize(count);
DSL_intArray nodes;
net->GetAllNodes(nodes);
for (int i = 0; i < count; i++) {
ids[i] = string(net->GetNode(nodes[i])->GetId());
}
auto resArr = createStringArray(ids);
outputs[0] = resArr;
}
void getOutcomeCount(matlab::mex::ArgumentList& outputs, matlab::mex::ArgumentList& inputs) {
matlab::data::TypedArray<uint64_t> arr = inputs[1];
auto net = decodePtr(arr[0]);
matlab::data::CharArray nodeIdCharArray = inputs[2];
int nodeHandle = validateNodeId(*net, nodeIdCharArray.toAscii().c_str());
outputs[0] = factory.createScalar(net->GetNode(nodeHandle)->Def()->GetNumberOfOutcomes());
}
void getOutcomeIds(matlab::mex::ArgumentList& outputs, matlab::mex::ArgumentList& inputs) {
matlab::data::TypedArray<uint64_t> arr = inputs[1];
auto net = decodePtr(arr[0]);
matlab::data::CharArray nodeIdCharArray = inputs[2];
int nodeHandle = validateNodeId(*net, nodeIdCharArray.toAscii().c_str());
auto def = net->GetNode(nodeHandle)->Def();
const DSL_idArray* names = def->GetOutcomeIds();
size_t count = def->GetNumberOfOutcomes();
vector<string> namesVec(names->begin(), names->end());
auto resArr = createStringArray(namesVec);
outputs[0] = resArr;
}
void isEvidence(matlab::mex::ArgumentList& outputs, matlab::mex::ArgumentList& inputs) {
matlab::data::TypedArray<uint64_t> arr = inputs[1];
auto net = decodePtr(arr[0]);
matlab::data::CharArray nodeIdCharArray = inputs[2];
int nodeHandle = validateNodeId(*net, nodeIdCharArray.toAscii().c_str());
outputs[0] = factory.createScalar(0 != net->GetNode(nodeHandle)->Val()->IsEvidence());
}
void getEvidence(matlab::mex::ArgumentList& outputs, matlab::mex::ArgumentList& inputs) {
matlab::data::TypedArray<uint64_t> arr = inputs[1];
auto net = decodePtr(arr[0]);
matlab::data::CharArray nodeIdCharArray = inputs[2];
int nodeHandle = validateNodeId(*net, nodeIdCharArray.toAscii().c_str());
DSL_node* node = net->GetNode(nodeHandle);
int evidence = node->Val()->GetEvidence();
if (evidence < 0) {
string msg = "Evidence for node ";
msg += node->GetId();
msg += " does not exist";
error(msg);
}
std::string outcomeId = node->Def()->GetOutcomeIds()->Subscript(evidence);
outputs[0] = factory.createCharArray(outcomeId);
}
};