Skip to content
Snippets Groups Projects
Commit ecffbdb0 authored by tgupta6's avatar tgupta6
Browse files

select_best_model computes top_k accuracy

parent 1a015d00
No related branches found
No related tags found
No related merge requests found
...@@ -48,7 +48,7 @@ def create_batch_generator(): ...@@ -48,7 +48,7 @@ def create_batch_generator():
constants.object_labels_json, constants.object_labels_json,
constants.attribute_labels_json, constants.attribute_labels_json,
constants.regions_json, constants.regions_json,
constants.genome_train_held_out_region_ids, constants.genome_val_region_ids,
constants.image_size, constants.image_size,
channels=3, channels=3,
resnet_feat_dim = constants.resnet_feat_dim, resnet_feat_dim = constants.resnet_feat_dim,
...@@ -93,9 +93,13 @@ class eval_mgr(): ...@@ -93,9 +93,13 @@ class eval_mgr():
labels): labels):
self.num_iter += 1.0 self.num_iter += 1.0
self.eval_object_accuracy( # self.eval_object_accuracy(
# eval_vars_dict['object_prob'],
# labels['objects'])
self.top_k_accuracy(
eval_vars_dict['object_prob'], eval_vars_dict['object_prob'],
labels['objects']) labels['objects'],
5)
self.eval_attribute_pr( self.eval_attribute_pr(
eval_vars_dict['attribute_prob'], eval_vars_dict['attribute_prob'],
...@@ -129,6 +133,25 @@ class eval_mgr(): ...@@ -129,6 +133,25 @@ class eval_mgr():
with open(filename, 'w') as file: with open(filename, 'w') as file:
ujson.dump(self.labels_dict[i], file, indent=4) ujson.dump(self.labels_dict[i], file, indent=4)
def top_k_accuracy(
self,
prob,
labels,
k):
num_samples, num_classes = prob.shape
ids = np.arange(num_classes)
accuracy = 0.0
for i in xrange(num_samples):
gt_ids = set(np.where(labels[i,:]>0.5)[0].tolist())
top_k = set(np.argsort(prob[i,:]).tolist()[-1:-1-k:-1])
count = 0.0
for idx in gt_ids:
if idx in top_k:
count += 1.0
accuracy += count/max(len(gt_ids),1)
self.object_accuracy += accuracy/num_samples
def eval_object_accuracy( def eval_object_accuracy(
self, self,
prob, prob,
...@@ -316,7 +339,8 @@ def model_path_generator(models_dir, start_model, step_size): ...@@ -316,7 +339,8 @@ def model_path_generator(models_dir, start_model, step_size):
if __name__=='__main__': if __name__=='__main__':
model_paths = model_path_generator( model_paths = model_path_generator(
constants.region_output_dir, constants.answer_output_dir,
# constants.region_output_dir,
constants.region_start_model, constants.region_start_model,
constants.region_step_size) constants.region_step_size)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment