Skip to content
Snippets Groups Projects
Commit d6603f7c authored by Max Rapp's avatar Max Rapp
Browse files

Bayes Tests

parent 3e21a9f5
No related branches found
No related tags found
No related merge requests found
import sys
import os
import importlib.util
import csv
def makenode(name, parents, *probs):
probabilities = {}
for i in range(len(probs)//2):
combination = tuple({'f' : False, 't' : True}[e] for e in probs[2*i])
assert len(combination) == len(parents)
assert combination not in probabilities
probabilities[combination] = float(probs[2*i+1])
assert len(probabilities) == 2**len(parents)
return {'name' : name, 'parents' : parents, 'probabilities' : probabilities}
network1 = {
'Asia' : makenode('Asia', [], '', 0.05),
'Smoke' : makenode('Smoke', [], '', 0.3),
'TBC' : makenode('TBC', ['Asia'], 't', 0.01, 'f', 0.001),
'LC' : makenode('LC', ['Smoke'], 't', 0.2, 'f', 0.08),
'Bron' : makenode('Bron', ['Smoke'], 't', 0.4, 'f', 0.1),
'Xray' : makenode('Xray', ['TBC', 'LC'], 'tt', 0.98, 'tf', 0.94, 'ft', 0.92, 'ff', 0.02),
'Dysp' : makenode('Dysp', ['TBC', 'LC', 'Bron'],
'ttt', 0.99, 'ttf', 0.97, 'tft', 0.98, 'tff', 0.9,
'ftt', 0.98, 'ftf', 0.92, 'fft', 0.95, 'fff', 0.07),
}
network2 = {
'Burglary' : makenode('Burglary', [], '', 0.001),
'Earthquake' : makenode('Earthquake', [], '', 0.002),
'Alarm' : makenode('Alarm', ['Burglary', 'Earthquake'], 'tt', 0.95, 'tf', 0.94, 'ft', 0.29, 'ff', 0.001),
'JohnCalls' : makenode('JohnCalls', ['Alarm'], 't', 0.9, 'f', 0.05),
'MaryCalls' : makenode('MaryCalls', ['Alarm'], 't', 0.7, 'f', 0.01),
}
network3 = {
'Trivial' : makenode('Trivial', [], '', 0.08),
}
network4 = {
'Deterministic' : makenode('Deterministic', [], '', 1),
'Earthquake' : makenode('Earthquake', [], '', 0.002),
'Alarm' : makenode('Alarm', ['Deterministic', 'Earthquake'], 'tt', 0.95, 'tf', 0.94, 'ft', 0.29, 'ff', 0.001),
'JohnCalls' : makenode('JohnCalls', ['Alarm'], 't', 0.9, 'f', 0.05),
'MaryCalls' : makenode('MaryCalls', ['Alarm'], 't', 0.7, 'f', 0.01),
}
network5 = {
'Burglary' : makenode('Burglary', [], '', 0.001),
'Earthquake' : makenode('Earthquake', [], '', 0.002),
'Alarm' : makenode('Alarm', ['Burglary', 'Earthquake'], 'tt', 1, 'tf', 0.94, 'ft', 0.29, 'ff', 0.001),
'JohnCalls' : makenode('JohnCalls', ['Alarm'], 't', 0.9, 'f', 0.05),
'MaryCalls' : makenode('MaryCalls', ['Alarm'], 't', 0.7, 'f', 0.01),
}
def give_points(network):
if network == network1 :
points = 35
return points
elif network == network2 :
points = 35
return points
elif network == network3 :
points = 10
return points
elif network == network4 :
points = 10
return points
elif network == network5 :
points = 10
return points
def testquery(network, node, evidence, expectedresult):
try:
estr = ', '.join([e.lower() if evidence[e] else "¬"+e.lower() for e in evidence.keys()])
print(f'Query: P({node.lower()} | {estr})')
result = bayes.query(network, node, evidence)
print(f'Result: {result}')
if abs(result - expectedresult) < 1e-10:
print('SUCCESS!')
points = give_points(network)
print("Points: " + str(points))
return points
else:
print(f'I expected {expectedresult}')
points = 0
print("Points: " + str(points))
return points
except:
e = sys.exc_info()[0]
print( "<p>Error: %s</p>" % e )
print("Bug in solution! Skipped")
points = 0
print("Points: " + str(points))
return points
if __name__ == '__main__':
with open(str(os.pardir)+'\\grades.csv', 'a', newline = '') as csvfile:
gradeswriter = csv.writer(csvfile, dialect='excel')
sys.path.append(os.getcwd())
orig_stdout = sys.stdout
#Write feedback file:
f = open('correction.txt', 'w')
sys.stdout = f
try:
import bayes
points1=testquery(network1, 'TBC', {'Asia':False, 'Xray':True, 'Dysp':True}, 0.008321327685349975)
print()
points2=testquery(network2, 'Burglary', {'MaryCalls':True, 'JohnCalls':True}, 0.2841718353643929)
print()
points3=testquery(network3, 'Trivial', {}, 0.08)
print()
points4=testquery(network4, 'Alarm', {'Deterministic':True, 'Earthquake':True}, 0.9499999999999998)
print()
points5=testquery(network5, 'Alarm', {'Burglary':True, 'Earthquake':True}, 1)
print()
finalpoints=points1+points2+points3+points4+points5
print("points: "+ str(finalpoints)+"/100")
sys.stdout = orig_stdout
f.close()
# Submissions with non-vanilla dependencies are graded with 0 points.
except ModuleNotFoundError:
print("Not-included library used, grading not possible.")
finalpoints=0
gradeswriter.writerow([str(os.path.basename(os.getcwd())), str(finalpoints)])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment