Validation

To evaluate the predictive quality of a network, use the DSL_validator class. The DSL_validator constructor takes references to a DSL_dataset object and a DSL_network object. To match the network to the data, the constructor also requires a vector of DSL_datasetMatch objects (as does the DSL_em::Learn method). After the validator object is constructed, specify which nodes in the network are class nodes by calling the DSL_validator::AddClassNode method. Validation requires at least one class node.

During validation, for each record in the data set, the variables matched to non-class nodes are used to set the evidence. The posterior probabilities are then calculated, and for each class node the outcome with the highest probability is selected as the predicted outcome. The prediction is compared with the value of the data set variable that is matched to the class node. The number of matches and the calculated posteriors are used to obtain the accuracy, the confusion matrix, ROC curves (including the AUC), and calibration curves.
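
The bookkeeping described above can be sketched in plain C++, independent of SMILE. The names predictOutcome and ClassNodeStats are hypothetical and are not part of the SMILE API; the sketch also assumes that ties between equally probable outcomes go to the lowest outcome index, which the text above does not specify.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical helper: index of the outcome with the highest posterior.
// Ties resolve to the lowest index (an assumption of this sketch).
int predictOutcome(const std::vector<double>& posterior) {
    assert(!posterior.empty());
    return static_cast<int>(std::distance(
        posterior.begin(),
        std::max_element(posterior.begin(), posterior.end())));
}

// Hypothetical accumulator for a single class node.
struct ClassNodeStats {
    std::vector<std::vector<int>> confusion; // confusion[actual][predicted]
    int correct = 0;
    int total = 0;

    explicit ClassNodeStats(std::size_t outcomeCount)
        : confusion(outcomeCount, std::vector<int>(outcomeCount, 0)) {}

    // Record one data set row: the actual outcome index from the data
    // and the posterior distribution calculated for the class node.
    void add(int actual, const std::vector<double>& posterior) {
        assert(posterior.size() == confusion.size());
        assert(actual >= 0 &&
               static_cast<std::size_t>(actual) < confusion.size());
        const int predicted = predictOutcome(posterior);
        ++confusion[actual][predicted];
        if (predicted == actual) {
            ++correct;
        }
        ++total;
    }

    // Fraction of records whose predicted outcome matched the data.
    double accuracy() const {
        return total > 0 ? static_cast<double>(correct) / total : 0.0;
    }
};
```

Each call to add corresponds to one record. The diagonal of confusion counts the matches, so accuracy is the sum of the diagonal divided by the number of records.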

Validation can be performed without parameter learning, using the DSL_validator::Test method, or with parameter learning, using the DSL_validator::KFold and LeaveOneOut methods. K-fold cross-validation divides the data set into K parts of equal size, trains the network on K-1 parts, and tests it on the remaining part. The process is repeated K times, with a different part of the data selected for testing each time. Leave-one-out is an extreme case of K-fold in which K is equal to the number of records in the data set.
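
To make the partitioning concrete, here is a minimal sketch in plain C++ that splits record indices into K folds and builds the training set for a given test fold. The functions makeFolds and trainingIndices are hypothetical and are not part of SMILE. The round-robin assignment (record i goes to fold i mod K) is an assumption chosen for simplicity; the way DSL_validator actually assigns records to folds is not described here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: split record indices 0..recordCount-1 into k folds.
// Round-robin assignment is an assumption of this sketch. Fold sizes
// differ by at most one when recordCount is not divisible by k.
std::vector<std::vector<std::size_t>> makeFolds(std::size_t recordCount,
                                                std::size_t k) {
    assert(k >= 2 && k <= recordCount);
    std::vector<std::vector<std::size_t>> folds(k);
    for (std::size_t i = 0; i < recordCount; ++i) {
        folds[i % k].push_back(i);
    }
    return folds;
}

// Hypothetical helper: indices of all records outside the test fold,
// that is, the K-1 parts used for training in one iteration.
std::vector<std::size_t> trainingIndices(
        const std::vector<std::vector<std::size_t>>& folds,
        std::size_t testFold) {
    assert(testFold < folds.size());
    std::vector<std::size_t> train;
    for (std::size_t f = 0; f < folds.size(); ++f) {
        if (f != testFold) {
            train.insert(train.end(), folds[f].begin(), folds[f].end());
        }
    }
    return train;
}
```

Calling makeFolds with k equal to recordCount gives folds of one record each, which is the leave-one-out case.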

The example below performs K-fold cross-validation with five folds and one class node.

DSL_dataset ds;
DSL_network net;
vector<DSL_datasetMatch> matching;
// load network and dataset, create the matching here

DSL_validator validator(ds, net, matching);
int classNodeHandle = net.FindNode("someNodeIdentifier");
validator.AddClassNode(classNodeHandle);

DSL_em em;
// optionally tweak the EM options here

// run 5-fold cross-validation, learning parameters with EM in each fold
int res = validator.KFold(em, 5);
if (DSL_OKAY == res)
{
    // accuracy for outcome 0 of the class node
    double acc;
    validator.GetAccuracy(classNodeHandle, 0, acc);

    // ROC curve points, thresholds and AUC for outcome 0 of the class node
    vector<pair<double, double> > roc;
    vector<double> thresholds;
    double auc;
    validator.CreateROC(classNodeHandle, 0, roc, thresholds, auc);

    printf("Accuracy=%f Area under the curve=%f\n", acc, auc);
}
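
As a rough illustration of how an area under the curve relates to a list of ROC points such as roc above, the following plain C++ sketch applies the trapezoidal rule. The function trapezoidAuc is hypothetical and not part of SMILE. It assumes that each pair holds (false positive rate, true positive rate), which is not stated above, and it is not a description of how CreateROC computes its auc output.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: area under a piecewise-linear ROC curve.
// Assumes each point is (false positive rate, true positive rate).
// The points are taken by value so they can be sorted by false
// positive rate without modifying the caller's vector.
double trapezoidAuc(std::vector<std::pair<double, double>> roc) {
    std::sort(roc.begin(), roc.end());
    double area = 0.0;
    for (std::size_t i = 1; i < roc.size(); ++i) {
        const double width = roc[i].first - roc[i - 1].first;
        const double meanHeight = (roc[i].second + roc[i - 1].second) / 2.0;
        area += width * meanHeight;
    }
    return area;
}
```

Under these assumptions, a curve that follows the diagonal from (0, 0) to (1, 1) has an area of 0.5, and a curve that passes through (0, 1) has an area of 1.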

See the DSL_validator reference for more details.