Tutorial9.py

<< Click to Display Table of Contents >>

Navigation:  Tutorials > Tutorial 9: Diagnosis >

Tutorial9.py

import pysmile

 

# Tutorial9 loads HeparII.xdsl file

# and runs multiple diagnostic algorithms

# Use the link below to download the HeparII.xdsl file:

# https://support.bayesfusion.com/docs/Examples/Discrete%20Bayesian%20Networks/HeparII.xdsl

 

class Tutorial9:

    def __init__(self):

        print("Starting tutorial 9...")

        net = pysmile.Network()

        net.read_file("HeparII.xdsl")

        print("Hepar model loaded")

        print_diag_types(net)

 

        print("Creating diagnostic network object")

        diag = pysmile.DiagNetwork(net)

        print_fault_indices(diag)

 

        pursued_fault_idx = diag.get_pursued_fault()

        print(f"The default (most likely) pursued fault is at index {pursued_fault_idx}: " 

              + f"{diag.get_fault_node_id(pursued_fault_idx)}" 

              + f"={diag.get_fault_outcome_id(pursued_fault_idx)}")

        

        print("Running diagnosis with no instantiated observations")

        diag_results = diag.update() 

        print_diag_results(net, diag_results)

 

        print("Running diagnosis with three observations")

        diag.instantiate_observation("jaundice", "present")

        diag.instantiate_observation("nausea", "absent")

        diag.instantiate_observation("obesity", "present")

        diag_results = diag.update() 

        print_diag_results(net, diag_results)

 

        print("Running diagnosis with two observations and focusing on Hyperbilirubinemia")

        diag.release_observation("nausea")

        diag.instantiate_observation("obesity", "absent")

        hyperbilirubinemia_idx = diag.get_fault_index("Hyperbilirubinemia", "present")

        diag.set_pursued_fault(hyperbilirubinemia_idx)

        diag_results = diag.update() 

        print_diag_results(net, diag_results)

 

        print("Switching algorithm to cross-entropy, observations and pursued fault unchanged")

        diag.set_single_fault_algorithm(pysmile.SingleFaultAlgorithmType.CROSSENTROPY)

        diag_results = diag.update() 

        print_diag_results(net, diag_results)

    

        print("Running diagnosis with two observations and "

            + "focusing on both Hyperbilirubinemia and Steatosis")

        steatosis_idx = diag.get_fault_index("Steatosis", "present")

        diag.set_pursued_faults([steatosis_idx, hyperbilirubinemia_idx])

        diag_results = diag.update() 

        print_diag_results(net, diag_results)

        

        print("Swtiching algorithm to L2 distance, observations and pursued faults unchanged")

        diag.set_multi_fault_algorithm(pysmile.MultiFaultAlgorithmType.L2_NORMALIZED_DISTANCE)

        diag_results = diag.update() 

        print_diag_results(net, diag_results)

 

        print("Tutorial 9 complete.")            

 

def print_diag_types(net):

    print(f"Network has {net.get_node_count()} nodes")

    print("{:<20} {:<30} {:<30}".format("Node Id", "Diagnostic Type", "Fault Outcomes"))

    for node_id in net.get_all_node_ids():

        diag_type = net.get_node_diag_type(node_id)

        fault_outcomes = []

        if diag_type == pysmile.NodeDiagType.FAULT:

            for outcome_id in net.get_outcome_ids(node_id):

                if net.is_fault_outcome(node_id, outcome_id):

                    fault_outcomes.append(outcome_id)

        print("{:<20} {:<30} {:<30}".format(node_id, str(net.get_node_diag_type(node_id)), 

            str(fault_outcomes) if diag_type == pysmile.NodeDiagType.FAULT else "n/a"))

 

def print_fault_indices(diag):

    fault_count = diag.get_fault_count()

    print(f"Diagnostic network has {fault_count} faults (node/outcome pairs)")

    print("{:<12} {:<20} {:<20}".format("Fault Index", "Fault Node Id", "Fault Outcome Id"))

    for fidx in range(fault_count):

        print("{:<12} {:<20} {:<20}".format(

            fidx, diag.get_fault_node_id(fidx), diag.get_fault_outcome_id(fidx)))

 

def print_fault_info(net, fault):

    print("{:<20} {:<20} {:<24} {:<8}".format(

        net.get_node_id(fault.node), net.get_outcome_id(fault.node, fault.outcome),

        fault.probability, "Yes" if fault.is_pursued else "No"))

 

def print_observation_info(net, observation):

    print("{:<20} {:<24} {:<8} {:<20}".format(net.get_node_id(observation.node), 

        observation.measure, observation.cost, observation.info_gain))

 

def print_diag_results(net, diag_results):

    print("Diag results start\n\nFaults:")

    print("{:<20} {:<20} {:<24} {:<8}".format(

        "Node Id", "Outcome Id", "Probability", "Is Pursued"))

    for fault_info in diag_results.faults:

        print_fault_info(net, fault_info)

    print("\nObservations:")

    print("{:<20} {:<24} {:<8} {:<20}".format("Node Id", "Measure", "Cost", "InfoGain"))

    for observation_info in diag_results.observations:

        print_observation_info(net, observation_info)

    print("\nDiag results end\n")