Skip to content
Snippets Groups Projects
Commit 36fec8aa authored by tgupta6's avatar tgupta6
Browse files

Merge branch 'use_detector_scores' of...

Merge branch 'use_detector_scores' of gitlab-beta.engr.illinois.edu:Vision/GenVQA into use_detector_scores
parents d3ced47a a2e9808d
No related branches found
No related tags found
No related merge requests found
...@@ -76,18 +76,21 @@ def create_batch_generator(region_ids_json): ...@@ -76,18 +76,21 @@ def create_batch_generator(region_ids_json):
class eval_mgr(): class eval_mgr():
def __init__( def __init__(
self, self,
scores_dirname, attribute_scores_dirname,
object_scores_dirname,
inv_object_labels_dict, inv_object_labels_dict,
vis_dirname, vis_dirname,
genome_region_dir): genome_region_dir):
self.scores_dirname = scores_dirname self.attribute_scores_dirname = attribute_scores_dirname
self.object_scores_dirname = object_scores_dirname
self.inv_object_labels_dict = inv_object_labels_dict self.inv_object_labels_dict = inv_object_labels_dict
self.vis_dirname = vis_dirname self.vis_dirname = vis_dirname
self.genome_region_dir = genome_region_dir self.genome_region_dir = genome_region_dir
self.epsilon = 0.00001 self.epsilon = 0.00001
self.num_iter = 0.0 self.num_iter = 0.0
self.object_accuracy = 0.0 self.object_accuracy = 0.0
self.obj_pred_json = dict()
self.precision = np.zeros([11], np.float32) self.precision = np.zeros([11], np.float32)
self.recall = np.zeros([11], np.float32) self.recall = np.zeros([11], np.float32)
self.fall_out = np.zeros([11], np.float32) self.fall_out = np.zeros([11], np.float32)
...@@ -133,20 +136,33 @@ class eval_mgr(): ...@@ -133,20 +136,33 @@ class eval_mgr():
eval_vars_dict['attribute_prob'], eval_vars_dict['attribute_prob'],
labels['attributes']) labels['attributes'])
pred_obj_labels = self.get_top_k_labels(eval_vars_dict['object_prob'],5)[0:10] pred_obj_labels = self.get_top_k_labels(eval_vars_dict['object_prob'],5)
gt_obj_labels = self.get_gt_labels(labels['objects'])[0:10] gt_obj_labels = self.get_gt_labels(labels['objects'])
region_paths = self.get_region_paths(image_ids, region_ids) region_paths = self.get_region_paths(image_ids, region_ids)
for i, region_id in enumerate(region_ids):
self.obj_pred_json[region_id] = {
'gt': gt_obj_labels[i],
'pred': pred_obj_labels[i],
}
if constants.visualize_object_predictions: if constants.visualize_object_predictions:
self.save_image_pred( self.save_image_pred(
pred_obj_labels, pred_obj_labels[0:10],
gt_obj_labels, gt_obj_labels[0:10],
region_ids, region_ids,
region_paths) region_paths)
if iter%500 == 0: if iter%500 == 0:
self.write_scores() self.write_scores()
filename = os.path.join(
self.object_scores_dirname,
'object_predictions.json')
with open(filename,'w') as file:
ujson.dump(self.obj_pred_json,file,indent=4)
def save_image_pred( def save_image_pred(
self, self,
pred_obj_labels, pred_obj_labels,
...@@ -217,13 +233,13 @@ class eval_mgr(): ...@@ -217,13 +233,13 @@ class eval_mgr():
def write_scores(self): def write_scores(self):
for i in xrange(10): for i in xrange(10):
filename = os.path.join( filename = os.path.join(
self.scores_dirname, self.attribute_scores_dirname,
'scores_' + str(i) + '.json') 'scores_' + str(i) + '.json')
with open(filename, 'w') as file: with open(filename, 'w') as file:
ujson.dump(self.scores_dict[i], file, indent=4) ujson.dump(self.scores_dict[i], file, indent=4)
filename = os.path.join( filename = os.path.join(
self.scores_dirname, self.attribute_scores_dirname,
'labels_' + str(i) + '.json') 'labels_' + str(i) + '.json')
with open(filename, 'w') as file: with open(filename, 'w') as file:
ujson.dump(self.labels_dict[i], file, indent=4) ujson.dump(self.labels_dict[i], file, indent=4)
...@@ -434,6 +450,7 @@ if __name__=='__main__': ...@@ -434,6 +450,7 @@ if __name__=='__main__':
print 'Creating evaluator...' print 'Creating evaluator...'
evaluator = eval_mgr( evaluator = eval_mgr(
constants.region_attribute_scores_dirname, constants.region_attribute_scores_dirname,
constants.region_object_scores_dirname,
inv_object_labels_dict, inv_object_labels_dict,
constants.region_pred_vis_dirname, constants.region_pred_vis_dirname,
constants.image_dir) constants.image_dir)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment