# cache_resnet_features_vqa.py
# Cache ResNet average-pool features for cropped VQA image regions.
import tftools.var_collect as var_collect
from tftools.placeholder_management import PlaceholderManager
import tftools
import resnet.inference as resnet_inference
import image_io
import os
import pdb
import glob
import numpy as np
import tensorflow as tf


class graph_creator():
    """Own a private TF graph mapping a batch of image regions to ResNet output.

    Attributes set by the constructor:
        im_h, im_w: spatial size every input region must have.
        tf_graph: the `tf.Graph` in which all ops below are created.
        plh: `PlaceholderManager` holding the 'image_regions' input.
        avg_pool_feat: tensor returned by `resnet_inference.inference` for
            that input (called with num_classes=None; presumably this yields
            the pooled features rather than class logits -- TODO confirm).
    """

    def __init__(self, image_size):
        height, width = image_size
        self.im_h = height
        self.im_w = width

        graph = tf.Graph()
        self.tf_graph = graph
        with graph.as_default():
            self.create_placeholders()
            regions = self.plh['image_regions']
            # The positional False is presumably the is_training flag
            # (inference mode) -- TODO confirm against resnet.inference.
            self.avg_pool_feat = resnet_inference.inference(
                regions, False, num_classes=None)

    def create_placeholders(self):
        """Create `self.plh` with a float32 (batch, im_h, im_w, 3) input."""
        self.plh = PlaceholderManager()
        region_shape = [None, self.im_h, self.im_w, 3]
        self.plh.add_placeholder(
            'image_regions', tf.float32, shape=region_shape)


class graph_initializer():
    """Restore pretrained ResNet weights into a `graph_creator` graph.

    Only variables collected from the scopes ``scale1`` .. ``scale5`` (via
    `var_collect.collect_scope`) are restored from the checkpoint.
    """
    def __init__(self, graph, ckpt_path):
        # Keep the checkpoint path on the instance so `initialize` does not
        # depend on a module-level `ckpt_path`, which only exists when this
        # file is executed as a script.
        self.ckpt_path = ckpt_path
        with graph.tf_graph.as_default():
            self.vars_to_restore = []
            # `range` rather than the Python-2-only `xrange`.
            for s in range(5):
                self.vars_to_restore += var_collect.collect_scope(
                    'scale' + str(s + 1))
            self.saver = tf.train.Saver(self.vars_to_restore)

    def initialize(self):
        """Restore the collected variables into the current default session.

        Call inside ``with sess.as_default():`` for a session built on the
        graph that was passed to the constructor.
        """
        sess = tf.get_default_session()
        self.saver.restore(sess, self.ckpt_path)


def create_feed_dict_creator(plh):
    """Return a function that turns an image-region batch into a feed dict.

    Args:
        plh: placeholder manager; its `get_feed_dict` receives a dict keyed
            by placeholder name.

    Returns:
        A one-argument callable: ``feed_dict_creator(image_regions)``.
    """
    def feed_dict_creator(image_regions):
        return plh.get_feed_dict({'image_regions': image_regions})

    return feed_dict_creator


def read_image_regions(
        image_sub_dir,
        num_regions,
        image_size,
        mean_image):
    """Load up to `num_regions` mean-subtracted region crops from a directory.

    Crops are read from files named ``1.jpg``, ``2.jpg``, ... inside
    `image_sub_dir`. Reading stops at the first index whose file does not
    exist; every slot from that index on is left as zeros.

    Args:
        image_sub_dir: directory holding the region crops of one image.
        num_regions: number of slots in the returned batch.
        image_size: (height, width) that every crop must have.
        mean_image: per-channel mean subtracted after the channel reorder;
            broadcast against the last axis (the script passes a 3-element
            BGR list).

    Returns:
        float32 array of shape (num_regions, height, width, 3).
    """
    im_h, im_w = image_size
    image_regions = np.zeros([num_regions, im_h, im_w, 3], np.float32)
    # `range` (not the Python-2-only `xrange`) keeps this portable.
    for i in range(num_regions):
        image_name = os.path.join(image_sub_dir, str(i + 1) + '.jpg')
        if not os.path.exists(image_name):
            # The first missing index ends the list; remaining slots stay 0.
            break
        # Reverse the channel order; presumably image_io.imread yields RGB
        # and this converts to BGR to match the BGR mean -- TODO confirm.
        image_regions[i, :, :, :] = \
            image_io.imread(image_name)[:, :, [2, 1, 0]] - mean_image

    return image_regions


class cache_features():
    """Compute and save ResNet features for every image sub-directory.

    All work happens in the constructor:
        1. Build the graph.
        2. Restore the checkpoint.
        3. For each immediate sub-directory of `image_dir`:
            a. Read its region crops.
            b. Run them through the network.
            c. Save the result with `np.save` under `feat_dir/<sub_dir>`
               (np.save appends ``.npy``).
    """
    def __init__(
            self,
            image_dir,
            feat_dir,
            num_regions,
            image_size,
            mean_image,
            ckpt_path):
        self.image_dir = image_dir
        self.feat_dir = feat_dir
        self.num_regions = num_regions
        self.image_size = image_size
        self.mean_image = mean_image

        self.graph = graph_creator(self.image_size)
        self.feed_dict_creator = create_feed_dict_creator(self.graph.plh)
        self.initializer = graph_initializer(self.graph, ckpt_path)

        self.sess = self.create_session()

        # Both the restore and `Tensor.eval` below use the default session.
        with self.sess.as_default():
            self.initializer.initialize()

            sub_dirs = self.get_list_image_sub_dirs()
            for sub_dir in sub_dirs:
                print(sub_dir)
                self.compute_and_save_features(sub_dir)

    def create_session(self):
        """Create a session on the feature graph with GPU memory settings.

        GPU memory is grown on demand (allow_growth) and capped at 0.8 of
        the device memory (per_process_gpu_memory_fraction).
        """
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.8
        sess = tf.Session(graph=self.graph.tf_graph, config=config)
        return sess

    def get_list_image_sub_dirs(self):
        """Return names of the immediate sub-directories of `image_dir`."""
        sub_dirs = [sub_dir for sub_dir in os.listdir(self.image_dir)
                    if os.path.isdir(os.path.join(self.image_dir, sub_dir))]
        # Report the count after filtering so plain files are not counted.
        print('Found {} image sub directories'.format(len(sub_dirs)))
        return sub_dirs

    def compute_and_save_features(self, image_name):
        """Run ResNet on one sub-directory's regions and save the output.

        Args:
            image_name: name of a sub-directory of `image_dir`. It is also
                used as the output file stem inside `feat_dir`.
        """
        image_sub_dir = os.path.join(
            self.image_dir,
            image_name)

        image_regions = read_image_regions(
            image_sub_dir,
            self.num_regions,
            self.image_size,
            self.mean_image)

        feed_dict = self.feed_dict_creator(image_regions)

        resnet_feat = self.graph.avg_pool_feat.eval(feed_dict)

        # Use this instance's output directory, not a script-level global.
        feat_path = os.path.join(self.feat_dir, image_name)
        np.save(feat_path, resnet_feat)


if __name__ == '__main__':
    # Input: one sub-directory of cropped region images per source image.
    basedir = '/home/ssd/VQA'
    image_dir = os.path.join(basedir, 'val2014_cropped_large')
    # Output: a sibling directory named after the input directory.
    feat_dir = image_dir + '_resnet_features'

    num_regions = 100
    image_size = (224, 224)
    ckpt_path = ('/home/tanmay/Downloads/pretrained_networks/'
                 'Resnet/tensorflow-resnet-pretrained-20160509/'
                 'ResNet-L50.ckpt')
    # Per-channel mean, in BGR order.
    imagenet_mean = [103.062623801, 115.902882574, 123.151630838]

    if not os.path.exists(feat_dir):
        os.mkdir(feat_dir)

    cache_features(
        image_dir=image_dir,
        feat_dir=feat_dir,
        num_regions=num_regions,
        image_size=image_size,
        mean_image=imagenet_mean,
        ckpt_path=ckpt_path)