Tutorial3.java

<< Click to Display Table of Contents >>

Navigation:  Tutorials > Tutorial 3: Exploring the contents of a model >

Tutorial3.java

package tutorials;

 

import smile.*;

 

// Tutorial3 loads the XDSL file and prints the information

// about the structure (nodes and arcs) and the parameters 

// (conditional probabilities of the nodes) of the network.

 

public class Tutorial3 {

    public static void run() {

        System.out.println("Starting Tutorial3...");

        Network net = new Network();

        

        // load the network created by Tutorial1

        net.readFile("tutorial1.xdsl");

        

        for (int h = net.getFirstNode(); h >= 0; h = net.getNextNode(h)) {

            printNodeInfo(net, h);

        }

 

        System.out.println("Tutorial3 complete.");        

    }

    

    

    private static void printNodeInfo(Network net, int nodeHandle) {

        System.out.printf("Node id/name: %s/%s\n", 

            net.getNodeId(nodeHandle), 

            net.getNodeName(nodeHandle));

        

        System.out.print("  Outcomes:");

        for (String outcomeId: net.getOutcomeIds(nodeHandle)) {

            System.out.print(" " + outcomeId);

        }

        System.out.println();

 

        String[] parentIds = net.getParentIds(nodeHandle);

        if (parentIds.length > 0) {

            System.out.print("  Parents:");

            for (String parentId: parentIds) {

                System.out.print(" " + parentId);

            }

            System.out.println();

        }

 

        String[] childIds = net.getChildIds(nodeHandle);

        if (childIds.length > 0) {

            System.out.print("  Children:");

            for (String childId: childIds) {

                System.out.print(" " + childId);

            }

            System.out.println();

        }

        

        printCptMatrix(net, nodeHandle);

    }

 

    

    private static void printCptMatrix(Network net, int nodeHandle) {

        double[] cpt = net.getNodeDefinition(nodeHandle);

        int[] parents = net.getParents(nodeHandle);

        int dimCount = 1 + parents.length;

        

        int[] dimSizes = new int[dimCount];

        for (int i = 0; i < dimCount - 1; i ++) {

            dimSizes[i] = net.getOutcomeCount(parents[i]);

        }

        dimSizes[dimSizes.length - 1] = net.getOutcomeCount(nodeHandle);

        

        int[] coords = new int[dimCount];

        for (int elemIdx = 0; elemIdx < cpt.length; elemIdx ++) {

            indexToCoords(elemIdx, dimSizes, coords);

            

            String outcome = net.getOutcomeId(nodeHandle, coords[dimCount - 1]);

            System.out.printf("    P(%s", outcome);

            

            if (dimCount > 1) {

                System.out.print(" | ");

                for (int parentIdx = 0; parentIdx < parents.length; parentIdx++)

                {

                    if (parentIdx > 0) System.out.print(",");

                    int parentHandle = parents[parentIdx];

                    System.out.printf("%s=%s", 

                        net.getNodeId(parentHandle), 

                        net.getOutcomeId(parentHandle, coords[parentIdx]));

                }

            }

 

            double prob = cpt[elemIdx];

            System.out.printf(")=%f\n", prob);

        }

    }

 

    

    private static void indexToCoords(int index, int[] dimSizes, int[] coords) {

        int prod = 1;

        for (int i = dimSizes.length - 1; i >= 0; i --) {

            coords[i] = (index / prod) % dimSizes[i];

            prod *= dimSizes[i];

        }

    }

    

}