To evaluate the predictive quality of a network, use the Validator class.
The Validator constructor requires references to both DataSet and Network objects. To match network nodes to dataset variables correctly, the constructor also requires an array of DataMatch objects, as the EM.learn method does.
After constructing a validator object, specify which nodes in the network are considered class nodes by calling the Validator.add_class_node method. Validation requires at least one class node.
During validation, the variables matched to non-class nodes are used to set evidence for each record in the dataset. Posterior probabilities are then calculated, and for each class node, the outcome with the highest probability is selected as the predicted outcome. Predictions are compared with the corresponding outcomes in the dataset. The results, including the number of matches and the posterior probabilities, are used to calculate accuracy, the confusion matrix, ROC curves, and calibration curves.
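The prediction and scoring step described above can be illustrated with a small, library-independent Python sketch. It does not use pysmile; all function names and the sample data are hypothetical. It also assumes that per-outcome accuracy means the fraction of records with a given actual outcome that were predicted correctly, which may differ in detail from what Validator computes.

```python
# Hypothetical sketch; none of these names are part of the pysmile API.

def predict(posteriors):
    """Return the index of the most probable outcome (ties go to the lowest index)."""
    return max(range(len(posteriors)), key=lambda i: posteriors[i])


def confusion_matrix(actual, posterior_rows, outcome_count):
    """Build matrix[actual][predicted] from per-record posteriors of one class node."""
    matrix = [[0] * outcome_count for _ in range(outcome_count)]
    for actual_idx, posteriors in zip(actual, posterior_rows):
        matrix[actual_idx][predict(posteriors)] += 1
    return matrix


def outcome_accuracy(matrix, outcome_idx):
    """Assumed definition: correct predictions / records with this actual outcome."""
    row_total = sum(matrix[outcome_idx])
    if row_total == 0:
        return float("nan")
    return matrix[outcome_idx][outcome_idx] / row_total


# Made-up data: four records, one binary class node.
actual = [0, 0, 1, 1]
posterior_rows = [[0.9, 0.1], [0.4, 0.6], [0.2, 0.8], [0.3, 0.7]]

matrix = confusion_matrix(actual, posterior_rows, 2)
acc0 = outcome_accuracy(matrix, 0)
acc1 = outcome_accuracy(matrix, 1)
```

With this data, one of the two records whose actual outcome is 0 is misclassified, so the accuracy for outcome 0 is 0.5 and the accuracy for outcome 1 is 1.0.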
Validation can be performed either without parameter learning using Validator.test, or with parameter learning using Validator.k_fold or Validator.leave_one_out. K-fold cross-validation divides the dataset into K parts of equal size, trains the network on K–1 parts, and tests it on the remaining part. This process is repeated K times, with each part used once for testing. Leave-one-out is an extreme case of K-fold, where K equals the number of records in the dataset.
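The fold structure of K-fold cross-validation can be sketched in plain Python as follows. This is an illustration only: the function name is hypothetical, records are assigned to folds round-robin, and the way Validator actually assigns records to folds may differ (for example, it may be randomized).

```python
# Hypothetical sketch; not part of the pysmile API.

def k_fold_indices(record_count, k):
    """Yield (train, test) index lists; each record is tested exactly once."""
    if not 2 <= k <= record_count:
        raise ValueError("k must be between 2 and the number of records")
    # Round-robin assignment, so fold sizes differ by at most one record.
    folds = [list(range(f, record_count, k)) for f in range(k)]
    for f in range(k):
        test = folds[f]
        # Training data is every record outside the current test fold.
        train = [i for g in range(k) if g != f for i in folds[g]]
        yield train, test


# 5-fold split of 10 records: each fold tests 2 records and trains on 8.
splits = list(k_fold_indices(10, 5))

# Leave-one-out is the case where k equals the number of records.
loo = list(k_fold_indices(4, 4))
```

In each iteration the network parameters would be learned from the `train` records and the predictions scored on the `test` records, so every record contributes to the final statistics exactly once.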
The example below demonstrates 5-fold cross-validation using a single class node. Accuracy is obtained for the outcome with index zero, corresponding to the first outcome of the node.
Python
ds = pysmile.learning.DataSet()
net = pysmile.Network()
# load network and data here
matching = ds.match_network(net)
validator = pysmile.learning.Validator(ds, net, matching)
classNodeHandle = net.get_node("someNodeId")
validator.add_class_node(classNodeHandle)
em = pysmile.learning.EM()
# optionally tweak EM options here
validator.k_fold(em, 5)
acc = validator.get_accuracy(classNodeHandle, 0)
Java
DataSet ds = new DataSet();
Network net = new Network();
// load network and data here
DataMatch[] matching = ds.matchNetwork(net);
Validator validator = new Validator(ds, net, matching);
int classNodeHandle = net.getNode("someNodeId");
validator.addClassNode(classNodeHandle);
EM em = new EM();
// optionally tweak EM options here
validator.kFold(em, 5);
double acc = validator.getAccuracy(classNodeHandle, 0);
C#
DataSet ds = new DataSet();
Network net = new Network();
// load network and data here
DataMatch[] matching = ds.MatchNetwork(net);
Validator validator = new Validator(ds, net, matching);
int classNodeHandle = net.GetNode("someNodeId");
validator.AddClassNode(classNodeHandle);
EM em = new EM();
// optionally tweak EM options here
validator.KFold(em, 5);
double acc = validator.GetAccuracy(classNodeHandle, 0);
R
ds <- DataSet()
net <- Network()
# load network and data here
matching <- ds$matchNetwork(net)
validator <- Validator(ds, net, matching)
classNodeHandle <- net$getNode("someNodeId")
validator$addClassNode(classNodeHandle)
em <- EM()
# optionally tweak EM options here
validator$kFold(em, 5)
acc <- validator$getAccuracy(classNodeHandle, 0)