diff --git a/SS20/Bayes_Tests/test_internal.py b/SS20/Bayes_Tests/test_internal.py new file mode 100644 index 0000000000000000000000000000000000000000..105ab09637c04daeb264f185db1d6d5efc7b9dca --- /dev/null +++ b/SS20/Bayes_Tests/test_internal.py @@ -0,0 +1,129 @@ +import sys +import os +import importlib.util +import csv + +def makenode(name, parents, *probs): + probabilities = {} + for i in range(len(probs)//2): + combination = tuple({'f' : False, 't' : True}[e] for e in probs[2*i]) + assert len(combination) == len(parents) + assert combination not in probabilities + probabilities[combination] = float(probs[2*i+1]) + assert len(probabilities) == 2**len(parents) + return {'name' : name, 'parents' : parents, 'probabilities' : probabilities} + +network1 = { + 'Asia' : makenode('Asia', [], '', 0.05), + 'Smoke' : makenode('Smoke', [], '', 0.3), + 'TBC' : makenode('TBC', ['Asia'], 't', 0.01, 'f', 0.001), + 'LC' : makenode('LC', ['Smoke'], 't', 0.2, 'f', 0.08), + 'Bron' : makenode('Bron', ['Smoke'], 't', 0.4, 'f', 0.1), + 'Xray' : makenode('Xray', ['TBC', 'LC'], 'tt', 0.98, 'tf', 0.94, 'ft', 0.92, 'ff', 0.02), + 'Dysp' : makenode('Dysp', ['TBC', 'LC', 'Bron'], + 'ttt', 0.99, 'ttf', 0.97, 'tft', 0.98, 'tff', 0.9, + 'ftt', 0.98, 'ftf', 0.92, 'fft', 0.95, 'fff', 0.07), + } + +network2 = { + 'Burglary' : makenode('Burglary', [], '', 0.001), + 'Earthquake' : makenode('Earthquake', [], '', 0.002), + 'Alarm' : makenode('Alarm', ['Burglary', 'Earthquake'], 'tt', 0.95, 'tf', 0.94, 'ft', 0.29, 'ff', 0.001), + 'JohnCalls' : makenode('JohnCalls', ['Alarm'], 't', 0.9, 'f', 0.05), + 'MaryCalls' : makenode('MaryCalls', ['Alarm'], 't', 0.7, 'f', 0.01), + } + +network3 = { + 'Trivial' : makenode('Trivial', [], '', 0.08), + } + +network4 = { + 'Deterministic' : makenode('Deterministic', [], '', 1), + 'Earthquake' : makenode('Earthquake', [], '', 0.002), + 'Alarm' : makenode('Alarm', ['Deterministic', 'Earthquake'], 'tt', 0.95, 'tf', 0.94, 'ft', 0.29, 'ff', 0.001), + 'JohnCalls' : makenode('JohnCalls', ['Alarm'], 't', 0.9, 'f', 0.05), + 'MaryCalls' : makenode('MaryCalls', ['Alarm'], 't', 0.7, 'f', 0.01), + } + +network5 = { + 'Burglary' : makenode('Burglary', [], '', 0.001), + 'Earthquake' : makenode('Earthquake', [], '', 0.002), + 'Alarm' : makenode('Alarm', ['Burglary', 'Earthquake'], 'tt', 1, 'tf', 0.94, 'ft', 0.29, 'ff', 0.001), + 'JohnCalls' : makenode('JohnCalls', ['Alarm'], 't', 0.9, 'f', 0.05), + 'MaryCalls' : makenode('MaryCalls', ['Alarm'], 't', 0.7, 'f', 0.01), + } + +def give_points(network): + if network == network1 : + points = 35 + return points + elif network == network2 : + points = 35 + return points + elif network == network3 : + points = 10 + return points + elif network == network4 : + points = 10 + return points + elif network == network5 : + points = 10 + return points + + +def testquery(network, node, evidence, expectedresult): + try: + estr = ', '.join([e.lower() if evidence[e] else "¬"+e.lower() for e in evidence.keys()]) + print(f'Query: P({node.lower()} | {estr})') + result = bayes.query(network, node, evidence) + print(f'Result: {result}') + if abs(result - expectedresult) < 1e-10: + print('SUCCESS!') + points = give_points(network) + print("Points: " + str(points)) + return points + else: + print(f'I expected {expectedresult}') + points = 0 + print("Points: " + str(points)) + return points + except: + e = sys.exc_info()[0] + print( "<p>Error: %s</p>" % e ) + print("Bug in solution! Skipped") + points = 0 + print("Points: " + str(points)) + return points + + +if __name__ == '__main__': + with open(str(os.pardir)+'\\grades.csv', 'a', newline = '') as csvfile: + gradeswriter = csv.writer(csvfile, dialect='excel') + sys.path.append(os.getcwd()) + orig_stdout = sys.stdout + #Write feedback file: + f = open('correction.txt', 'w') + sys.stdout = f + try: + import bayes + points1=testquery(network1, 'TBC', {'Asia':False, 'Xray':True, 'Dysp':True}, 0.008321327685349975) + print() + points2=testquery(network2, 'Burglary', {'MaryCalls':True, 'JohnCalls':True}, 0.2841718353643929) + print() + points3=testquery(network3, 'Trivial', {}, 0.08) + print() + points4=testquery(network4, 'Alarm', {'Deterministic':True, 'Earthquake':True}, 0.9499999999999998) + print() + points5=testquery(network5, 'Alarm', {'Burglary':True, 'Earthquake':True}, 1) + print() + finalpoints=points1+points2+points3+points4+points5 + print("points: "+ str(finalpoints)+"/100") + sys.stdout = orig_stdout + f.close() + # Submissions with non-vanilla dependencies are graded with 0 points. + except ModuleNotFoundError: + print("Not-included library used, grading not possible.") + finalpoints=0 + gradeswriter.writerow([str(os.path.basename(os.getcwd())), str(finalpoints)]) + +