Skip to content
Snippets Groups Projects
eval.py 9.98 KiB
import pdb
import os
import ujson
import numpy as np

import data.cropped_regions_cached_features as cropped_regions
import tftools.data
from tftools import var_collect, placeholder_management
from object_attribute_classifier_cached_features import inference
from word2vec.word_vector_management import word_vector_manager
import train
import losses
import constants

import tensorflow as tf


def create_initializer(graph, sess, model_to_eval):
    """Restore `model_to_eval` into `sess` and return an initializer object.

    The checkpoint is restored eagerly, while the returned object is being
    constructed: a Saver built over `graph.vars_to_save` restores those
    variables, and both the restored variables and every other graph
    variable are printed through `var_collect.print_var_list`. The returned
    object's `initialize()` is therefore a no-op.
    """
    class initializer():
        def __init__(self):
            saved_vars = graph.vars_to_save
            with graph.tf_graph.as_default():
                restorer = tf.train.Saver(saved_vars)
                print('Restoring model {}'.format(model_to_eval))
                restorer.restore(sess, model_to_eval)
                # Every graph variable the checkpoint did not cover.
                unrestored = [v for v in tf.all_variables()
                              if v not in saved_vars]

            var_collect.print_var_list(saved_vars, 'Restored Variables')
            var_collect.print_var_list(unrestored, 'Unrestored Variables')

        def initialize(self):
            # Restoration already happened in __init__; nothing left to do.
            pass

    return initializer()


def create_batch_generator(num_samples, num_epochs, offset):
    """Build an asynchronous batch generator over cached region features.

    Args:
        num_samples: number of region indices to draw from.
        num_epochs: number of passes over those indices.
        offset: index of the first region of the selected split.

    Returns:
        The generator produced by `tftools.data.async_batch_generator`,
        using `constants.region_batch_size` and
        `constants.region_queue_size`.
    """
    region_data = cropped_regions.data(
        constants.genome_resnet_feat_dir,
        constants.image_dir,
        constants.object_labels_json,
        constants.attribute_labels_json,
        constants.regions_json,
        constants.image_size,
        channels=3,
        resnet_feat_dim=constants.resnet_feat_dim,
        mean_image_filename=None)

    sample_indices = tftools.data.random(
        constants.region_batch_size, num_samples, num_epochs, offset)

    return tftools.data.async_batch_generator(
        region_data, sample_indices, constants.region_queue_size)


class eval_mgr():
    """Accumulate object accuracy and attribute precision/recall per batch.

    Each call to `eval` adds one batch's metrics to running sums; the
    `get_*` accessors return those sums divided by the number of batches
    seen. Raw scores and labels of a fixed subset of attribute columns are
    also collected and periodically dumped as JSON into `scores_dirname`.
    """

    def __init__(self,
                 scores_dirname,
                 num_dumped_attributes=10,
                 dumped_attribute_stride=10,
                 write_interval=500):
        """
        Args:
            scores_dirname: directory receiving scores_<i>.json and
                labels_<i>.json.
            num_dumped_attributes: number of attribute columns whose raw
                scores/labels are recorded.
            dumped_attribute_stride: column stride between recorded
                attributes (the defaults select columns 0, 10, ..., 90).
            write_interval: `eval` writes the JSON files whenever its
                iteration index is a multiple of this value.
        """
        # Additive smoothing keeps ratios finite when a count is zero.
        self.epsilon = 0.00001
        self.num_iter = 0.0
        self.object_accuracy = 0.0
        # Decision thresholds 0.0, 0.1, ..., 1.0; every per-threshold
        # accumulator is sized from this single list.
        self.thresholds = np.arange(0.0, 1.1, 0.1).tolist()
        num_thresholds = len(self.thresholds)
        self.precision = np.zeros([num_thresholds], np.float32)
        self.recall = np.zeros([num_thresholds], np.float32)
        self.fall_out = np.zeros([num_thresholds], np.float32)
        self.attribute_ids = \
            np.arange(num_dumped_attributes) * dumped_attribute_stride
        self.scores_dict = {i: [] for i in range(num_dumped_attributes)}
        self.labels_dict = {i: [] for i in range(num_dumped_attributes)}
        self.scores_dirname = scores_dirname
        self.write_interval = write_interval

    def eval(self,
             iter,
             eval_vars_dict,
             labels):
        """Accumulate metrics for one batch.

        Args:
            iter: index of the batch; scores are written to disk when it is
                a multiple of `write_interval`.
            eval_vars_dict: must hold 'object_prob' and 'attribute_prob'.
            labels: must hold 'objects' and 'attributes'.
        """
        self.num_iter += 1.0

        self.eval_object_accuracy(
            eval_vars_dict['object_prob'],
            labels['objects'])

        attribute_prob = eval_vars_dict['attribute_prob']
        attribute_labels = labels['attributes']
        self.eval_attribute_pr(attribute_prob, attribute_labels)
        self.append_to_scores_labels_list(attribute_prob, attribute_labels)

        if iter % self.write_interval == 0:
            self.write_scores()

    def append_to_scores_labels_list(self, prob, labels):
        """Record the batch's scores/labels for each tracked attribute column."""
        for i, attribute_id in enumerate(self.attribute_ids):
            self.scores_dict[i].append(prob[:, attribute_id].tolist())
            self.labels_dict[i].append(labels[:, attribute_id].tolist())

    def write_scores(self):
        """Dump everything recorded so far, overwriting earlier dumps."""
        for i in range(len(self.attribute_ids)):
            for prefix, values in (('scores_', self.scores_dict[i]),
                                   ('labels_', self.labels_dict[i])):
                filename = os.path.join(
                    self.scores_dirname,
                    prefix + str(i) + '.json')
                with open(filename, 'w') as out_file:
                    ujson.dump(values, out_file, indent=4)

    def eval_object_accuracy(
            self,
            prob,
            labels):
        """Add the batch's top-1 accuracy (argmax along axis 1) to the sum."""
        matches = np.equal(
            np.argmax(prob, 1),
            np.argmax(labels, 1)).astype(np.float32)

        current_object_accuracy = np.sum(matches) / matches.shape[0]
        self.object_accuracy += current_object_accuracy

    def eval_attribute_pr(
            self,
            prob,
            labels):
        """Add the batch's per-threshold recall, precision and fall-out.

        Every metric is computed per attribute column (sums over axis 0),
        epsilon-smoothed, and then averaged over the columns.
        """
        num_thresholds = len(self.thresholds)
        current_recall = np.zeros([num_thresholds], dtype=np.float32)
        current_precision = np.zeros([num_thresholds], dtype=np.float32)
        current_fall_out = np.zeros([num_thresholds], dtype=np.float32)

        eps = self.epsilon
        # Positives per column do not depend on the threshold.
        positive_attributes = np.sum(labels, 0)
        for i, threshold in enumerate(self.thresholds):
            thresholded = (prob > threshold).astype(np.int32)

            tp_attributes = np.sum(labels * thresholded, 0)
            tn_attributes = np.sum((1 - labels) * (1 - thresholded), 0)
            fp_attributes = np.sum((1 - labels) * thresholded, 0)

            current_recall[i] = np.mean(
                (tp_attributes + eps) /
                (positive_attributes + eps))

            current_precision[i] = np.mean(
                (tp_attributes + eps) /
                (tp_attributes + fp_attributes + eps))

            current_fall_out[i] = np.mean(
                (fp_attributes + eps) /
                (tn_attributes + fp_attributes + eps))

        self.recall += current_recall
        self.precision += current_precision
        self.fall_out += current_fall_out

    def get_object_accuracy(self):
        return self.object_accuracy / self.num_iter

    def get_precision(self):
        return self.precision / self.num_iter

    def get_recall(self):
        return self.recall / self.num_iter

    def get_fall_out(self):
        return self.fall_out / self.num_iter

    def get_ap(self):
        """Area under the averaged precision/recall curve (trapezoid rule)."""
        precision = self.get_precision()
        recall = self.get_recall()
        last = precision.size - 1
        ap = 0.0
        # Recall is non-increasing as the threshold grows, so
        # recall[i] - recall[i + 1] is the width of each slice.
        for i in range(last):
            area = (precision[i + 1] + precision[i]) * \
                   (recall[i] - recall[i + 1]) / 2
            ap += area

        # Close the curve down to recall 0, where precision is taken as 1.
        area = (1.0 + precision[last]) * \
               (recall[last] - 0.0) / 2
        ap += area

        return ap
        
def eval(
        batch_generator,
        sess,
        initializer,
        vars_to_eval_dict,
        feed_dict_creator,
        evaluator):
    """Run the model on every batch and feed the outputs to `evaluator`.

    Args:
        batch_generator: iterable of batch dicts; each must contain
            'object_labels' and 'attribute_labels'.
        sess: session used as default and to run the fetches.
        initializer: object whose `initialize()` is called once up front.
        vars_to_eval_dict: name -> tensor to fetch on every batch.
        feed_dict_creator: maps a batch dict to the feed_dict for `sess.run`.
        evaluator: receives `eval(batch_index, outputs, labels)` per batch,
            and `write_scores()` once after the last batch.
    """
    # Fix a fetch order so the results can be zipped back to their names.
    vars_to_eval_names = list(vars_to_eval_dict.keys())
    vars_to_eval = [vars_to_eval_dict[name] for name in vars_to_eval_names]

    with sess.as_default():
        initializer.initialize()

        for batch_index, batch in enumerate(batch_generator):
            print(batch_index)
            feed_dict = feed_dict_creator(batch)
            eval_vars = sess.run(
                vars_to_eval,
                feed_dict=feed_dict)
            eval_vars_dict = dict(zip(vars_to_eval_names, eval_vars))
            labels = {
                'objects': batch['object_labels'],
                'attributes': batch['attribute_labels'],
            }
            evaluator.eval(batch_index, eval_vars_dict, labels)
            print('Object accuracy: {}'.format(
                evaluator.get_object_accuracy()))

        # evaluator.eval only dumps scores on a periodic schedule; flush
        # once more so the batches after the last periodic dump are saved.
        evaluator.write_scores()


if __name__=='__main__':
    num_epochs = 1

    # Select the region split to evaluate. The offsets below index regions
    # as train first, then val, then test.
    if constants.region_eval_on=='val':
        num_samples = constants.num_val_regions
        offset = constants.num_train_regions
    elif constants.region_eval_on=='test':
        num_samples = constants.num_test_regions
        offset = constants.num_train_regions + \
                 constants.num_val_regions
    elif constants.region_eval_on=='train':
        num_samples = constants.num_train_regions
        offset = 0
    else:
        # num_samples/offset would be undefined below; fail here with a
        # clear message instead of a later NameError.
        raise ValueError(
            "eval_on can only be either 'val' or 'test' or 'train'")

    print('Creating batch generator...')
    batch_generator = create_batch_generator(
        num_samples,
        num_epochs,
        offset)

    print('Creating computation graph...')
    graph = train.graph_creator(
        constants.region_batch_size,
        constants.tb_log_dir,
        constants.image_size,
        constants.num_object_labels,
        constants.num_attribute_labels,
        constants.region_regularization_coeff,
        resnet_feat_dim=constants.resnet_feat_dim,
        training=False)

    print('Starting a session...')
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    sess = tf.Session(config=config, graph=graph.tf_graph)

    try:
        print('Creating initializer...')
        # NOTE(review): this restores constants.answer_model_to_eval; an
        # earlier version used constants.region_model_to_eval here. Confirm
        # which checkpoint this region evaluation is meant to load.
        initializer = create_initializer(
            graph,
            sess,
            constants.answer_model_to_eval)

        print('Creating feed dict creator...')
        feed_dict_creator = train.create_feed_dict_creator(graph.plh)

        print('Creating dict of vars to be evaluated...')
        vars_to_eval_dict = {
            'object_prob': graph.obj_atr_inference.object_prob,
            'attribute_prob': graph.obj_atr_inference.attribute_prob,
        }

        print('Creating evaluator...')
        evaluator = eval_mgr(constants.region_attribute_scores_dirname)

        print('Start evaluating...')
        eval(
            batch_generator,
            sess,
            initializer,
            vars_to_eval_dict,
            feed_dict_creator,
            evaluator)
    finally:
        # Release the resources held by the session even if evaluation fails.
        sess.close()