Skip to content
Snippets Groups Projects
Commit a5fdc5d0 authored by Kevin Shih's avatar Kevin Shih
Browse files

adding eval code for shapes dataset

parent f5c5a4c3
No related branches found
No related tags found
No related merge requests found
import json
import sys
import re
# Questions mentioning one of these colour/shape pairs form the held-out subset.
HELD_OUT_PATTERN = re.compile(r'red squares?|blue circles?|green triangles?')


def safe_ratio(count, total):
    """Return count / total as a float, or NaN when total is zero.

    Guards the report against ZeroDivisionError when a question category
    (or the whole annotation file) is empty.
    """
    if not total:
        return float('nan')
    return count / float(total)


def question_category(question):
    """Classify a question by the fixed phrases it contains.

    Returns 'relation', 'existence', 'color', 'counting', or None when no
    known phrase is present. "Is there a" takes precedence over the other
    phrases; within it, "below a" marks a spatial-relation question.
    """
    if "Is there a" in question:
        return 'relation' if "below a" in question else 'existence'
    if "What color" in question:
        return 'color'
    if "How many" in question:
        return 'counting'
    return None


def tally_results(res_data, anno_data):
    """Count correct answers overall, per question type and on the held-out subset.

    res_data:  list of dicts with 'question_id' and 'answer' (predictions).
    anno_data: list of dicts with 'question_id', 'question' and 'answer'
               (ground truth).

    Returns a dict of integer counters: 'total', 'correct', '<type>' and
    '<type>_total' for each question type, plus the 'held_out*' counters.

    Raises ValueError when the two lists differ in length, and KeyError when
    an annotated question has no entry in res_data.
    """
    if len(res_data) != len(anno_data):
        raise ValueError('result file has %d entries but annotation file has %d'
                         % (len(res_data), len(anno_data)))

    # Index predictions by question id; a repeated id keeps its last entry.
    predictions = {}
    for entry in res_data:
        predictions[entry['question_id']] = entry

    stats = {'total': len(anno_data), 'correct': 0,
             'held_out': 0, 'held_out_total': 0, 'held_out_no': 0}
    for name in ('existence', 'relation', 'color', 'counting'):
        stats[name] = 0
        stats[name + '_total'] = 0
    for name in ('existence', 'relation', 'counting'):
        stats['held_out_' + name] = 0
        stats['held_out_' + name + '_total'] = 0
    for name in ('existence', 'relation'):
        stats['held_out_' + name + '_no'] = 0

    for ann in anno_data:
        question = ann['question']
        ans_is_no = ann['answer'] == 'no'
        is_correct = ann['answer'] == predictions[ann['question_id']]['answer']
        if is_correct:
            stats['correct'] += 1

        category = question_category(question)
        if category is not None:
            stats[category + '_total'] += 1
            if is_correct:
                stats[category] += 1

        if HELD_OUT_PATTERN.search(question) is None:
            continue

        # NOTE(review): the "no" counters are taken to count held-out questions
        # whose ground-truth answer is 'no', independent of correctness. The
        # nesting in the original listing was lost -- confirm.
        stats['held_out_total'] += 1
        if is_correct:
            stats['held_out'] += 1
        if ans_is_no:
            stats['held_out_no'] += 1

        if category in ('existence', 'relation'):
            prefix = 'held_out_' + category
            stats[prefix + '_total'] += 1
            if is_correct:
                stats[prefix] += 1
            if ans_is_no:
                stats[prefix + '_no'] += 1
        elif "How many" in question:
            # Checked on the text rather than the category, as the original
            # did: within the held-out subset "How many" is tested right
            # after "Is there a", without a "What color" check in between.
            stats['held_out_counting_total'] += 1
            if is_correct:
                stats['held_out_counting'] += 1

    return stats


def format_report(stats):
    """Return the report lines for the counters produced by tally_results."""
    rows = (
        ('Overall Accuracy:\t %f', 'correct', 'total'),
        ('Existence:\t\t %f', 'existence', 'existence_total'),
        ('Spatial relation:\t %f', 'relation', 'relation_total'),
        ('What color:\t\t %f', 'color', 'color_total'),
        ('How many:\t\t %f', 'counting', 'counting_total'),
        ('Held out:\t\t %f', 'held_out', 'held_out_total'),
        ('Held out (no):\t\t %f', 'held_out_no', 'held_out_total'),
        ('\tExistence:\t\t %f', 'held_out_existence', 'held_out_existence_total'),
        ('\tExistence (no):\t\t %f', 'held_out_existence_no', 'held_out_existence_total'),
        ('\tSpatial relation:\t %f', 'held_out_relation', 'held_out_relation_total'),
        ('\tSpatial relation (no):\t %f', 'held_out_relation_no', 'held_out_relation_total'),
        ('\tHow many:\t\t %f', 'held_out_counting', 'held_out_counting_total'),
    )
    return [template % safe_ratio(stats[count_key], stats[total_key])
            for template, count_key, total_key in rows]


def main(argv):
    """Read the result and annotation JSON files named in argv and print the report.

    argv[1] is the result file, argv[2] the test annotation file.
    """
    if len(argv) < 3:
        sys.exit('usage: %s RESULTS_JSON ANNOTATIONS_JSON' % argv[0])
    # Context managers close both files even if JSON parsing fails.
    with open(argv[1], 'r') as f_res:
        res_data = json.load(f_res)
    with open(argv[2], 'r') as f_anno:
        anno_data = json.load(f_anno)
    for line in format_report(tally_results(res_data, anno_data)):
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(line)


if __name__ == "__main__":
    main(sys.argv)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment