Tutorial8.java

<< Click to Display Table of Contents >>

Navigation:  Tutorials > Tutorial 8: Hybrid model >

Tutorial8.java

package tutorials;

 

import smile.*;

 

// Tutorial8 loads continuous model from the XDSL file written by Tutorial7,

// then adds discrete nodes to create a hybrid model. Inference is performed

// and model is saved to disk.

 

public class Tutorial8 {

    public static void run() {

        System.out.println("Starting Tutorial8...");

        Network net = new Network();

        

        net.readFile("tutorial7.xdsl");

        

        createCptNode(

                net, "zone", "Climate Zone",

                new String[] { "Temperate", "Desert" },

            60, 20);

 

        int toaHandle = net.getNode("toa");

        net.setNodeEquation(toaHandle, "toa=If(zone=\"Desert\",Normal(22,5),Normal(11,10))");

 

        int perceivedHandle = createCptNode(

                net, "perceived", "Perceived Temperature", 

                new String[] { "Hot", "Warm", "Cold" },

            60, 300);

        net.addArc(toaHandle, perceivedHandle);

        

        double[] perceivedProbs = new double[] {

                    0,    // P(perceived=Hot |toa in -10..0)

            0.02, // P(perceived=Warm|toa in -10..0)

            0.98, // P(perceived=Cold|toa in -10..0)

            0.05, // P(perceived=Hot |toa in 0..10)

            0.15, // P(perceived=Warm|toa in 0..10)

            0.80, // P(perceived=Cold|toa in 0..10)

            0.10, // P(perceived=Hot |toa in 10..20)

            0.80, // P(perceived=Warm|toa in 10..20)

            0.10, // P(perceived=Cold|toa in 10..20)

            0.80, // P(perceived=Hot |toa in 20..30)

            0.15, // P(perceived=Warm|toa in 20..30)

            0.05, // P(perceived=Cold|toa in 20..30)

            0.98, // P(perceived=Hot |toa in 30..40)

            0.02, // P(perceived=Warm|toa in 30..40)

            0     // P(perceived=Cold|toa in 30..40)

        };

        net.setNodeDefinition(perceivedHandle, perceivedProbs);

 

        net.setEvidence("zone", "Temperate");

        System.out.println("Results in temperate zone:");

        updateAndShowStats(net);

 

        net.setEvidence("zone", "Desert");

        System.out.println("Results in desert zone:\n");

        updateAndShowStats(net);

        

        net.writeFile("tutorial8.xdsl");

        System.out.println("Tutorial8 complete: Network written to tutorial8.xdsl");

    }

 

    static void showStats(Network net, int nodeHandle) {

         String nodeId = net.getNodeId(nodeHandle);

         

         if (net.isEvidence(nodeHandle)) {

                 double v = net.getContEvidence(nodeHandle);

                 System.out.printf("%s has evidence set (%g)\n", nodeId, v);

                 return;

         

         

         if (net.isValueDiscretized(nodeHandle)) {

                 System.out.printf("%s is discretized.\n", nodeId);

                 DiscretizationInterval[] iv = 

                net.getNodeEquationDiscretization(nodeHandle);

                 double[] bounds = net.getNodeEquationBounds(nodeHandle);

                 double[] discBeliefs = net.getNodeValue(nodeHandle);

                 double lo = bounds[0];

             for (int i = 0; i < discBeliefs.length; i++) {

                 double hi = iv[i].boundary;

                 System.out.printf(

                    "\tP(%s in %g..%g)=%g\n", nodeId, lo, hi, discBeliefs[i]);

                 lo = hi;

             }

         } else {

                 double[] stats = net.getNodeSampleStats(nodeHandle);

         System.out.printf("%s: mean=%g stddev=%g min=%g max=%g\n",

             nodeId, stats[0], stats[1], stats[2], stats[3]);

         }

    }

   

    static void updateAndShowStats(Network net) {

        net.updateBeliefs();

        for (int h = net.getFirstNode(); h >= 0; h = net.getNextNode(h))

        {

            if (net.getNodeType(h) == Network.NodeType.EQUATION) {

                    showStats(net, h);

            }

        }

        System.out.println();

    }

    

    private static int createCptNode(

            Network net, String id, String name, 

            String[] outcomes, int xPos, int yPos) {

        int handle = net.addNode(Network.NodeType.CPT, id);

        

        net.setNodeName(handle, name);

        net.setNodePosition(handle, xPos, yPos, 85, 55);

        

        int initialOutcomeCount = net.getOutcomeCount(handle); 

        for (int i = 0; i < initialOutcomeCount; i ++) {

            net.setOutcomeId(handle, i, outcomes[i]);

        }

        

        for (int i = initialOutcomeCount; i < outcomes.length; i ++) {

            net.addOutcome(handle, outcomes[i]);

        }

        

        return handle;

    }

}