Skip to content
Snippets Groups Projects
Commit a2e9808d authored by tgupta6's avatar tgupta6
Browse files

obj atr eval script now writes a json of obj predictions

parent 947cd847
No related branches found
No related tags found
No related merge requests found
......@@ -76,18 +76,21 @@ def create_batch_generator(region_ids_json):
class eval_mgr():
def __init__(
self,
scores_dirname,
attribute_scores_dirname,
object_scores_dirname,
inv_object_labels_dict,
vis_dirname,
genome_region_dir):
self.scores_dirname = scores_dirname
self.attribute_scores_dirname = attribute_scores_dirname
self.object_scores_dirname = object_scores_dirname
self.inv_object_labels_dict = inv_object_labels_dict
self.vis_dirname = vis_dirname
self.genome_region_dir = genome_region_dir
self.epsilon = 0.00001
self.num_iter = 0.0
self.object_accuracy = 0.0
self.obj_pred_json = dict()
self.precision = np.zeros([11], np.float32)
self.recall = np.zeros([11], np.float32)
self.fall_out = np.zeros([11], np.float32)
......@@ -133,20 +136,33 @@ class eval_mgr():
eval_vars_dict['attribute_prob'],
labels['attributes'])
pred_obj_labels = self.get_top_k_labels(eval_vars_dict['object_prob'],5)[0:10]
gt_obj_labels = self.get_gt_labels(labels['objects'])[0:10]
pred_obj_labels = self.get_top_k_labels(eval_vars_dict['object_prob'],5)
gt_obj_labels = self.get_gt_labels(labels['objects'])
region_paths = self.get_region_paths(image_ids, region_ids)
for i, region_id in enumerate(region_ids):
self.obj_pred_json[region_id] = {
'gt': gt_obj_labels[i],
'pred': pred_obj_labels[i],
}
if constants.visualize_object_predictions:
self.save_image_pred(
pred_obj_labels,
gt_obj_labels,
pred_obj_labels[0:10],
gt_obj_labels[0:10],
region_ids,
region_paths)
if iter%500 == 0:
self.write_scores()
filename = os.path.join(
self.object_scores_dirname,
'object_predictions.json')
with open(filename,'w') as file:
ujson.dump(self.obj_pred_json,file,indent=4)
def save_image_pred(
self,
pred_obj_labels,
......@@ -217,13 +233,13 @@ class eval_mgr():
def write_scores(self):
for i in xrange(10):
filename = os.path.join(
self.scores_dirname,
self.attribute_scores_dirname,
'scores_' + str(i) + '.json')
with open(filename, 'w') as file:
ujson.dump(self.scores_dict[i], file, indent=4)
filename = os.path.join(
self.scores_dirname,
self.attribute_scores_dirname,
'labels_' + str(i) + '.json')
with open(filename, 'w') as file:
ujson.dump(self.labels_dict[i], file, indent=4)
......@@ -434,6 +450,7 @@ if __name__=='__main__':
print 'Creating evaluator...'
evaluator = eval_mgr(
constants.region_attribute_scores_dirname,
constants.region_object_scores_dirname,
inv_object_labels_dict,
constants.region_pred_vis_dirname,
constants.image_dir)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment