Tutorial10.java

<< Click to Display Table of Contents >>

Navigation:  Tutorials > Tutorial 10: Structure learning >

Tutorial10.java

package tutorials;

 

import smile.*;

import smile.learning.*;

 

// Tutorial9 loads Credit10k.csv file

// and runs multiple structure learning algorithms

// using the loaded dataset.

// Use the link below to download the Credit10k.csv file:

// https://support.bayesfusion.com/docs/Examples/Learning/Credit10K.csv

 

public class Tutorial10 {

    public static void run() {

        System.out.println("Starting Tutorial9...");

        DataSet ds = new DataSet();

        try

        {

            ds.readFile("Credit10k.csv");

        } catch (SMILEException ex)

        {

            System.out.println("Dataset load failed");

            return;

        }

        System.out.printf("Dataset has %d variables (columns) and %d records (rows)\n", 

            ds.getVariableCount(), ds.getRecordCount());

        BayesianSearch bayesSearch = new BayesianSearch();

        bayesSearch.setIterationCount(50);

        bayesSearch.setRandSeed(9876543);

        Network net1;

        try

        {

            net1 = bayesSearch.learn(ds);

        } catch (SMILEException ex)

        {

            System.out.println("Bayesian Search failed");

            return;

        }

        System.out.printf("1st Bayesian Search finished, structure score: %f\n", 

            bayesSearch.getLastScore());

        net1.writeFile("tutorial9-bs1.xdsl");

 

        Network net2;

        bayesSearch.setRandSeed(3456789);

        try

        {

            net2 = bayesSearch.learn(ds);

        }

        catch (SMILEException ex)

        {

            System.out.println("Bayesian Search failed");

            return;

        }

        System.out.printf("2nd Bayesian Search finished, structure score: %f\n", 

            bayesSearch.getLastScore());

        net2.writeFile("tutorial9-bs2.xdsl");

 

        int idxAge = ds.findVariable("Age");

        int idxProfession = ds.findVariable("Profession");

        int idxCreditWorthiness = ds.findVariable("CreditWorthiness");

        if (idxAge < 0 || idxProfession < 0 || idxCreditWorthiness < 0)

        {

            System.out.println("Can't find dataset variables for background knowledge");

            System.out.println("The loaded file may not be Credit10k.csv");

            return;

        }

 

        BkKnowledge backgroundKnowledge = new BkKnowledge();

         backgroundKnowledge.matchData(ds);

        backgroundKnowledge.addForbiddenArc(idxAge, idxCreditWorthiness);

        backgroundKnowledge.addForcedArc(idxAge, idxProfession);

        bayesSearch.setBkKnowledge(backgroundKnowledge);

        Network net3;

        try

        {

            net3 = bayesSearch.learn(ds);

        }

        catch (SMILEException ex)

        {

            System.out.println("Bayesian Search failed");

            return;

        }

        System.out.printf("3rd Bayesian Search finished, structure score: %f\n", 

            bayesSearch.getLastScore());

        net3.writeFile("tutorial9-bs3.xdsl");

 

        Network net4;

        TAN tan = new TAN();

        tan.setRandSeed(777999);

        tan.setClassVariableId("CreditWorthiness");

        try

        {

            net4 = tan.learn(ds);

        }

        catch (SMILEException ex)

        {

            System.out.println("TAN failed");

            return;

        }

        System.out.println("Tree-augmented Naive Bayes finished");

        net4.writeFile("tutorial9-tan.xdsl");

 

        PC pc = new PC();

        Pattern pattern;

        try

        {

            pattern = pc.learn(ds);

        } catch (SMILEException ex)

        {

            System.out.println("PC failed");

            return;

        }

        Network net5 = pattern.makeNetwork(ds);

        System.out.println("PC finished, proceeding to parameter learning");

        net5.writeFile("tutorial9-pc.xdsl");

        EM em = new EM();

        DataMatch[] matching;

        try

        {

            matching = ds.matchNetwork(net5);

        }

        catch (SMILEException ex)

        {

            System.out.println("Can't automatically match network with dataset");

            return;

        }

        em.setUniformizeParameters(false);

        em.setRandomizeParameters(false);

        em.setEqSampleSize(0);

        try

        {

            em.learn(ds, net5, matching);

        }

        catch (SMILEException ex)

        {

            System.out.println("EM failed");

            return;

        }

        System.out.println("EM finished");

        net5.writeFile("tutorial9-pc-em.xdsl");

 

        System.out.println("Tutorial10 complete");

    }

}