Tutorial6.java

<< Click to Display Table of Contents >>

Navigation:  Tutorials > Tutorial 6: Dynamic model >

Tutorial6.java

package tutorials;

 

import smile.*;

 

// Tutorial6 creates a dynamic Bayesian network (DBN),

// performs the inference, then saves the model to disk.

 

public class Tutorial6 {

    public static void run() {

        System.out.println("Starting Tutorial6...");

        Network net = new Network();

        

        int loc = createCptNode(

                net, "Location", "Location",

                new String[] { "Pittsburgh", "Sahara" }, 

                160, 360);

        

        int rain = createCptNode(

                net, "Rain", "Rain",

                new String[] { "true", "false" }, 

                380, 240);

        

        int umb = createCptNode(

                net, "Umbrella", "Umbrella",

                new String[] { "true", "false" },

                300, 100);

        

        net.setNodeTemporalType(rain, Network.NodeTemporalType.PLATE);

        net.setNodeTemporalType(umb, Network.NodeTemporalType.PLATE);

        

        net.addArc(loc, rain);

        net.addTemporalArc(rain, rain, 1);

        net.addArc(rain, umb);

    

        double[] rainDef = new double[] {

            0.7,  // P(Rain=true |Location=Pittsburgh)

            0.3,  // P(Rain=false|Location=Pittsburgh)

            0.01, // P(Rain=true |Location=Sahara)

            0.99  // P(Rain=false|Location=Sahara)

        };

        net.setNodeDefinition(rain, rainDef);

 

        double[] rainDefTemporal = new double[] {

            0.7,   // P(Rain=true |Location=Pittsburgh,Rain[t-1]=true)

            0.3,   // P(Rain=false|Location=Pittsburgh,Rain[t-1]=true)

            0.3,   // P(Rain=true |Location=Pittsburgh,Rain[t-1]=false)

            0.7,   // P(Rain=false|Location=Pittsburgh,Rain[t-1]=false)

            0.001, // P(Rain=true |Location=Sahara,Rain[t-1]=true)

            0.999, // P(Rain=false|Location=Sahara,Rain[t-1]=true)

            0.01,  // P(Rain=true |Location=Sahara,Rain[t-1]=false)

            0.99  // P(Rain=false|Location=Sahara,Rain[t-1]=false)

  };

        net.setNodeTemporalDefinition(rain, 1, rainDefTemporal);

        

        double[] umbDef = new double[] {

                    0.9, // P(Umbrella=true |Rain=true)

            0.1, // P(Umbrella=false|Rain=true)

            0.2, // P(Umbrella=true |Rain=false)

            0.8  // P(Umbrella=false|Rain=false)                

        };

        net.setNodeDefinition(umb, umbDef);

        

        net.setSliceCount(5);

 

        System.out.println("Performing update without evidence.");

        updateAndShowTemporalResults(net);

 

        System.out.println(

            "Setting Umbrella[t=1] to true and Umbrella[t=3] to false.");

        net.setTemporalEvidence(umb, 1, 0);

        net.setTemporalEvidence(umb, 3, 1);

        updateAndShowTemporalResults(net);

 

        net.writeFile("tutorial6.xdsl");

        System.out.println(

            "Tutorial6 complete: Network written to tutorial6.xdsl");

    }

    

    

    private static void updateAndShowTemporalResults(Network net) {

            net.updateBeliefs();

     int sliceCount = net.getSliceCount();

     for (int h = net.getFirstNode(); h >= 0; h = net.getNextNode(h)) {

             if (net.getNodeTemporalType(h) == Network.NodeTemporalType.PLATE) {

                     int outcomeCount = net.getOutcomeCount(h);

             System.out.printf(

                    "Temporal beliefs for %s:\n", net.getNodeId(h));

             double[] v = net.getNodeValue(h);

                         for (int sliceIdx = 0; sliceIdx < sliceCount; sliceIdx++) {

                                 System.out.printf("\tt=%d:", sliceIdx);

                                 for (int i = 0; i < outcomeCount; i++) {

                                     System.out.printf(" %f", v[sliceIdx*outcomeCount+i]);

                                 }

                                 System.out.println();

                         }

         }

     }

     System.out.println();

    }

    

    

    private static int createCptNode(

            Network net, String id, String name, 

            String[] outcomes, int xPos, int yPos) {

        int handle = net.addNode(Network.NodeType.CPT, id);

        

        net.setNodeName(handle, name);

        net.setNodePosition(handle, xPos, yPos, 85, 55);

        

        int initialOutcomeCount = net.getOutcomeCount(handle); 

        for (int i = 0; i < initialOutcomeCount; i ++) {

            net.setOutcomeId(handle, i, outcomes[i]);

        }

        

        for (int i = initialOutcomeCount; i < outcomes.length; i ++) {

            net.addOutcome(handle, outcomes[i]);

        }

        

        return handle;

    }

 

}