tutorial6.cpp

<< Click to Display Table of Contents >>

Navigation:  Tutorials > Tutorial 6: A dynamic model >

tutorial6.cpp

// tutorial6.cpp

// Tutorial6 creates a dynamic Bayesian network (DBN),

// performs the inference, then saves the model to disk.

 

#include "smile.h"

#include <cstdio>

 

 

static int CreateCptNode(

    DSL_network& net, const char* id, const char* name,

    std::initializer_list<const char*> outcomes, int xPos, int yPos);

 

static void UpdateAndShowTemporalResults(DSL_network &net);

 

 

int Tutorial6()

{

    printf("Starting Tutorial6...\n");

 

    DSL_errorH().RedirectToFile(stdout);

    

    DSL_network net;

 

    int loc = CreateCptNode(net, "Location", "Location",

        { "Pittsburgh", "Sahara" }, 160, 360);

 

    int rain = CreateCptNode(net, "Rain", "Rain",

        { "true", "false" }, 380, 240);

 

    int umb = CreateCptNode(net, "Umbrella", "Umbrella",

        { "true", "false" }, 300, 100);

 

    net.SetTemporalType(rain, dsl_temporalType::dsl_plateNode);

    net.SetTemporalType(umb, dsl_temporalType::dsl_plateNode);

 

    net.AddArc(loc, rain);

    net.AddTemporalArc(rain, rain, 1);

    net.AddArc(rain, umb);

 

    DSL_nodeDef *rainDef = net.GetNode(rain)->Def();

    int res = rainDef->SetDefinition({

        0.7,  // P(Rain=true |Location=Pittsburgh)

        0.3,  // P(Rain=false|Location=Pittsburgh)

        0.01, // P(Rain=true |Location=Sahara)

        0.99  // P(Rain=false|Location=Sahara)

    });

    if (DSL_OKAY != res)

    {

        return res;

    }

 

    res = rainDef->SetTemporalDefinition(1, {

        0.7,   // P(Rain=true |Location=Pittsburgh,Rain[t-1]=true)

        0.3,   // P(Rain=false|Location=Pittsburgh,Rain[t-1]=true)

        0.3,   // P(Rain=true |Location=Pittsburgh,Rain[t-1]=false)

        0.7,   // P(Rain=false|Location=Pittsburgh,Rain[t-1]=false)

        0.001, // P(Rain=true |Location=Sahara,Rain[t-1]=true)

        0.999, // P(Rain=false|Location=Sahara,Rain[t-1]=true)

        0.01,  // P(Rain=true |Location=Sahara,Rain[t-1]=false)

        0.99   // P(Rain=false|Location=Sahara,Rain[t-1]=false)

    });

    if (DSL_OKAY != res)

    {

        return res;

    }

 

    res = net.GetNode(umb)->Def()->SetDefinition({

        0.9, // P(Umbrella=true |Rain=true)

        0.1, // P(Umbrella=false|Rain=true)

        0.2, // P(Umbrella=true |Rain=false)

        0.8  // P(Umbrella=false|Rain=false)

    });

    if (DSL_OKAY != res)

    {

        return res;

    }

 

    net.SetNumberOfSlices(5);

 

    printf("Performing update without evidence.\n");

    UpdateAndShowTemporalResults(net);

 

    printf("Setting Umbrella[t=1] to true and Umbrella[t=3] to false.\n");

    res = net.GetNode(umb)->Val()->SetTemporalEvidence(1, 0);

    if (DSL_OKAY != res)

    {

        return res;

    }

    res = net.GetNode(umb)->Val()->SetTemporalEvidence(3, 1);

    if (DSL_OKAY != res)

    {

        return res;

    }

    UpdateAndShowTemporalResults(net);

 

    res = net.WriteFile("tutorial6.xdsl");

    if (DSL_OKAY != res)

    {

        return res;

    }

    

    printf("Tutorial6 complete: Network written to tutorial6.xdsl\n");

    return DSL_OKAY;

}

 

 

static void UpdateAndShowTemporalResults(DSL_network &net)

{

    net.UpdateBeliefs();

    int sliceCount = net.GetNumberOfSlices();

    for (int h = net.GetFirstNode(); h >= 0; h = net.GetNextNode(h))

    {

        if (net.GetTemporalType(h) == dsl_temporalType::dsl_plateNode)

        {

            DSL_node *node = net.GetNode(h);

            int outcomeCount = node->Def()->GetNumberOfOutcomes();

            printf("Temporal beliefs for %s:\n", node->GetId());

            const DSL_Dmatrix& mtx = *node->Val()->GetMatrix();

            for (int sliceIdx = 0; sliceIdx < sliceCount; sliceIdx++)

            {

                printf("\tt=%d:", sliceIdx);

                for (int i = 0; i < outcomeCount; i++)

                {

                    printf(" %f", mtx[sliceIdx * outcomeCount + i]);

                }

                printf("\n");

            }

        }

    }

    printf("\n");

}

 

 

static int CreateCptNode(

    DSL_network& net, const char* id, const char* name,

    std::initializer_list<const char*> outcomes, int xPos, int yPos)

{

    int handle = net.AddNode(DSL_CPT, id);

    DSL_node* node = net.GetNode(handle);

    node->SetName(name);

    node->Def()->SetNumberOfOutcomes(outcomes);

    DSL_rectangle& position = node->Info().Screen().position;

    position.center_X = xPos;

    position.center_Y = yPos;

    position.width = 85;

    position.height = 55;

    return handle;

}