Skip to content
Snippets Groups Projects
Commit 319b70d3 authored by tgupta6's avatar tgupta6
Browse files

obj atr eval uses new genome split

parent b4e42ac1
No related branches found
No related tags found
No related merge requests found
......@@ -128,8 +128,8 @@ region_model_accuracies_txt = os.path.join(
'model_accuracies.txt')
# Object Attribute Classifier Evaluation Params
region_eval_on = 'val' # One of {'val','test','train'}
region_model_to_eval = region_model + '-' + '77500'
region_eval_on = 'test' # One of {'test','train_held_out'}
region_model_to_eval = region_model + '-' + '34000'
region_attribute_scores_dirname = os.path.join(
region_output_dir,
......
......@@ -41,23 +41,27 @@ def create_initializer(graph, sess, model_to_eval):
return initializer()
def create_batch_generator(num_samples, num_epochs, offset):
def create_batch_generator(region_ids_json):
data_mgr = cropped_regions.data(
constants.genome_resnet_feat_dir,
constants.image_dir,
constants.object_labels_json,
constants.attribute_labels_json,
constants.regions_json,
region_ids_json,
constants.image_size,
channels=3,
resnet_feat_dim = constants.resnet_feat_dim,
mean_image_filename=None)
num_regions = len(data_mgr.region_ids)
print num_regions
index_generator = tftools.data.random(
constants.region_batch_size,
num_samples,
num_epochs,
offset)
num_regions,
1,
0)
batch_generator = tftools.data.async_batch_generator(
data_mgr,
......@@ -253,24 +257,16 @@ def eval(
if __name__=='__main__':
num_epochs = 1
if constants.region_eval_on=='val':
num_samples = constants.num_val_regions
offset = constants.num_train_regions
if constants.region_eval_on=='train_held_out':
region_ids_json = constants.genome_train_held_out_region_ids
elif constants.region_eval_on=='test':
num_samples = constants.num_test_regions
offset = constants.num_train_regions + \
constants.num_val_regions
elif constants.region_eval_on=='train':
num_samples = constants.num_train_regions
offset = 0
region_ids_json = constants.genome_test_region_ids
else:
print "eval_on can only be either 'val' or 'test' or 'train'"
print "eval_on can only be 'test' or 'train_held_out'"
raise
print 'Creating batch generator...'
batch_generator = create_batch_generator(
num_samples,
num_epochs,
offset)
batch_generator = create_batch_generator(region_ids_json)
print 'Creating computation graph...'
graph = train.graph_creator(
......@@ -293,8 +289,8 @@ if __name__=='__main__':
initializer = create_initializer(
graph,
sess,
constants.answer_model_to_eval)
# constants.region_model_to_eval)
#constants.answer_model_to_eval)
constants.region_model_to_eval)
print 'Creating feed dict creator...'
feed_dict_creator = train.create_feed_dict_creator(graph.plh)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment