From 22e54b1e127a5921a9a04a590506f7652fc5dce7 Mon Sep 17 00:00:00 2001 From: tgupta6 <tgupta6@illinois.edu> Date: Thu, 9 Jun 2016 15:19:36 -0500 Subject: [PATCH] Save word vectors for the vocabulary --- constants.py | 27 ++++ object_attribute_classifier/__init__.py | 0 object_attribute_classifier/inference.py | 94 ++++++++++++ resnet/inference.py | 6 +- tftools/batch_normalizer.py | 71 +++++++++ tftools/layers.py | 112 ++++++++++++++ tftools/placeholder_management.py | 182 +++++++++++++++++++++++ visual_genome_parser.py | 119 ++++++++++++++- word2vec/__init__.py | 0 word2vec/get_vocab_word_vectors.py | 45 ++++++ 10 files changed, 652 insertions(+), 4 deletions(-) create mode 100644 object_attribute_classifier/__init__.py create mode 100644 object_attribute_classifier/inference.py create mode 100644 tftools/batch_normalizer.py create mode 100644 tftools/layers.py create mode 100644 tftools/placeholder_management.py create mode 100644 word2vec/__init__.py create mode 100644 word2vec/get_vocab_word_vectors.py diff --git a/constants.py b/constants.py index adbc94a..8aac6ea 100644 --- a/constants.py +++ b/constants.py @@ -8,19 +8,30 @@ unknown_token = 'UNK' # Data paths data_absolute_path = '/home/tanmay/Data/VisualGenome' + image_dir = os.path.join(data_absolute_path, 'cropped_regions') + object_labels_json = os.path.join( data_absolute_path, 'restructured/object_labels.json') + attribute_labels_json = os.path.join( data_absolute_path, 'restructured/attribute_labels.json') + regions_json = os.path.join( data_absolute_path, 'restructured/region_with_labels.json') + mean_image_filename = os.path.join( data_absolute_path, 'restructured/mean_image.jpg') + +# Vocabulary +vocab_json = os.path.join( + data_absolute_path, + 'restructured/vocab_subset.json') + # Regions data partition # First 70% meant to be used for training # Next 10% is set aside for validation @@ -32,6 +43,22 @@ num_test_regions = num_total_regions \ - num_train_regions \ - num_val_regions +# Pretrained resnet ckpt +resnet_ckpt = '/home/tanmay/Downloads/pretrained_networks/' + \ + 'Resnet/tensorflow-resnet-pretrained-20160509/' + \ + 'ResNet-L50.ckpt' + +# Pretrained word vectors +word2vec_binary = '/home/tanmay/Code/word2vec/word2vec-api-master/' + \ + 'GoogleNews-vectors-negative300.bin' + +word_vector_size = 300 + +# Numpy matrix storing vocabulary word vectors +vocab_word_vectors_npy = os.path.join( + data_absolute_path, + 'restructured/vocab_word_vectors.npy') + diff --git a/object_attribute_classifier/__init__.py b/object_attribute_classifier/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/object_attribute_classifier/inference.py b/object_attribute_classifier/inference.py new file mode 100644 index 0000000..2ce3910 --- /dev/null +++ b/object_attribute_classifier/inference.py @@ -0,0 +1,94 @@ +import pdb + +import resnet.inference as resnet_inference +from tftools import var_collect, placeholder_management, layers +import constants + +import tensorflow as tf + +class ObjectAttributeInference(): + def __init__( + self, + image_regions, + wordvecs, + training): + + self.image_regions = image_regions + self.wordvecs = wordvecs + self.training = training + avg_pool_feat = resnet_inference.inference( + self.image_regions, + self.training) + + object_feat = self.add_object_graph(avg_pool_feat) + attribute_feat = self.add_attribute_graph(avg_pool_feat) + pdb.set_trace() + + def add_object_graph(self, input): + with tf.variable_scope('object_graph') as object_graph: + with tf.variable_scope('fc1') as fc1: + in_dim = input.get_shape().as_list()[-1] + out_dim = in_dim/2 + fc1_out = layers.full( + input, + out_dim, + 'fc', + func = None) + fc1_out = layers.batch_norm( + fc1_out, + tf.constant(self.training)) + fc1_out = tf.nn.relu(fc1_out) + + with tf.variable_scope('fc2') as fc2: + in_dim = fc1_out.get_shape().as_list()[-1] + out_dim = in_dim/2 + fc2_out = layers.full( + fc1_out, + out_dim, + 'fc') + + return fc2_out + + def add_attribute_graph(self, input): + with tf.variable_scope('attribute_graph') as attribute_graph: + with tf.variable_scope('fc1') as fc1: + in_dim = input.get_shape().as_list()[-1] + out_dim = in_dim/2 + fc1_out = layers.full( + input, + out_dim, + 'fc', + func = None) + fc1_out = layers.batch_norm( + fc1_out, + tf.constant(self.training)) + fc1_out = tf.nn.relu(fc1_out) + + with tf.variable_scope('fc2') as fc2: + in_dim = fc1_out.get_shape().as_list()[-1] + out_dim = in_dim/2 + fc2_out = layers.full( + fc1_out, + out_dim, + 'fc') + + return fc2_out + + + + +if __name__=='__main__': + im_h, im_w = constants.image_size + plh = placeholder_management.PlaceholderManager() + plh.add_placeholder( + name = 'image_regions', + dtype = tf.float32, + shape = [None, im_h, im_w, 3]) + + training = False + ObjectAttributeInference( + plh['image_regions'], + [], + training) + + diff --git a/resnet/inference.py b/resnet/inference.py index 0d04bd4..012a6d9 100644 --- a/resnet/inference.py +++ b/resnet/inference.py @@ -79,9 +79,9 @@ def inference(x, is_training, # post-net x = tf.reduce_mean(x, reduction_indices=[1, 2], name="avg_pool") - if num_classes != None: - with tf.variable_scope('fc'): - x = fc(x, c) + # if num_classes != None: + # with tf.variable_scope('fc'): + # x = fc(x, c) return x diff --git a/tftools/batch_normalizer.py b/tftools/batch_normalizer.py new file mode 100644 index 0000000..5eda8df --- /dev/null +++ b/tftools/batch_normalizer.py @@ -0,0 +1,71 @@ +import pdb +import numpy as np +import tensorflow as tf +from tensorflow.python import control_flow_ops + + +class BatchNorm(): + def __init__(self, + input, + training, + decay=0.95, + epsilon=1e-4, + name='bn', + reuse_vars=False): + + self.decay = decay + self.epsilon = epsilon + self.batchnorm(input, training, name, reuse_vars) + + def batchnorm(self, input, training, name, reuse_vars): + with tf.variable_scope(name, reuse=reuse_vars) as bn: + rank = len(input.get_shape().as_list()) + in_dim = input.get_shape().as_list()[-1] + + if rank == 2: + self.axes = [0] + elif rank == 4: + self.axes = [0, 1, 2] + else: + raise ValueError('Input tensor must have rank 2 or 4.') + + self.offset = tf.get_variable( + 'offset', + shape=[in_dim], + initializer=tf.constant_initializer(0.0)) + + self.scale = tf.get_variable( + 'scale', + shape=[in_dim], + initializer=tf.constant_initializer(1.0)) + + self.ema = tf.train.ExponentialMovingAverage(decay=self.decay) + + self.output = tf.cond(training, + lambda: self.get_normalizer(input, True), + lambda: self.get_normalizer(input, False)) + + def get_normalizer(self, input, train_flag): + if train_flag: + self.mean, self.variance = tf.nn.moments(input, self.axes) + ema_apply_op = self.ema.apply([self.mean, self.variance]) + with tf.control_dependencies([ema_apply_op]): + self.output_training = tf.nn.batch_normalization( + input, self.mean, self.variance, self.offset, self.scale, + self.epsilon, 'normalizer_train'), + return self.output_training + else: + self.output_test = tf.nn.batch_normalization( + input, self.ema.average(self.mean), + self.ema.average(self.variance), self.offset, self.scale, + self.epsilon, 'normalizer_test') + return self.output_test + + def get_batch_moments(self): + return self.mean, self.variance + + def get_ema_moments(self): + return self.ema.average(self.mean), self.ema.average(self.variance) + + def get_offset_scale(self): + return self.offset, self.scale diff --git a/tftools/layers.py b/tftools/layers.py new file mode 100644 index 0000000..67ee6d1 --- /dev/null +++ b/tftools/layers.py @@ -0,0 +1,112 @@ +import numpy as np +import tensorflow as tf +from batch_normalizer import BatchNorm + +def full(input, out_dim, name, gain=np.sqrt(2), func=tf.nn.relu, reuse_vars=False): + """ Fully connected layer helper. + Creates weights and bias parameters with good initial values. Then applies the matmul op and func. + Args: + input (tensor): Input to the layer. + Should have shape `[batch, in_dim]`. + Must be one of the following types: `float32`, `float64`. + out_dim (int): Number of output neurons. + name (string): Name used by the `tf.variable_scope`. + gain (float): Gain used when calculating stddev of weights. + Suggest values: sqrt(2) for relu, 1.0 for identity. + func (function): Function used to calculate neural activations. + If `None` uses identity. + reuse_vars (bool): Determine whether the layer should reuse variables + or construct new ones. Equivalent to setting reuse in variable_scope + Returns: + output (tensor): The neural activations for this layer. + Will have shape `[batch, out_dim]`. + """ + in_dim = input.get_shape().as_list()[-1] + stddev = 1.0 * gain / np.sqrt(in_dim) + with tf.variable_scope(name, reuse=reuse_vars): + w_init = tf.random_normal_initializer(stddev=stddev) + b_init = tf.constant_initializer() + w = tf.get_variable('w', shape=[in_dim, out_dim], initializer=w_init) + b = tf.get_variable('b', shape=[out_dim], initializer=b_init) + + output = tf.matmul(input, w) + b + if func is not None: + output = func(output) + + tf.add_to_collection('to_regularize', w) + return output + + +def conv2d(input, + filter_size, + out_dim, + name, + strides=[1, 1, 1, 1], + padding='SAME', + gain=np.sqrt(2), + func=tf.nn.relu, + reuse_vars=False): + """ Conv2d layer helper. + Creates filter and bias parameters with good initial values. Then applies the conv op and func. + Args: + input (tensor): Input to the layer. + Should have shape `[batch, in_height, in_width, in_dim]`. + Must be one of the following types: `float32`, `float64`. + filter_size (int): Width and height of square convolution filter. + out_dim (int): Number of output filters. + name (str): Name used by the `tf.variable_scope`. + strides (List[int]): The stride of the sliding window for each dimension + of `input`. Must be in the same order as the dimension specified with format. + padding (str): A `string` from: `'SAME', 'VALID'`. + The type of padding algorithm to use. + gain (float): Gain used when calculating stddev of weights. + Suggest values: sqrt(2) for relu, 1.0 for identity. + func (function): Function used to calculate neural activations. + If `None` uses identity. + reuse_vars (bool): Determine whether the layer should reuse variables + or construct new ones. Equivalent to setting reuse in variable_scope + Returns: + output (tensor): The neural activations for this layer. + Will have shape `[batch, in_weight, in_width, out_dim]`. + """ + in_dim = input.get_shape().as_list()[-1] + stddev = 1.0 * gain / np.sqrt(filter_size * filter_size * in_dim) + with tf.variable_scope(name, reuse=reuse_vars): + w_init = tf.random_normal_initializer(stddev=stddev) + b_init = tf.constant_initializer() + w = tf.get_variable('w', + shape=[filter_size, filter_size, in_dim, out_dim], + initializer=w_init) + b = tf.get_variable('b', shape=[out_dim], initializer=b_init) + + output = tf.nn.conv2d(input, w, strides=strides, padding=padding) + b + if func is not None: + output = func(output) + + tf.add_to_collection('to_regularize', w) + return output + + +def batch_norm(input, + training, + decay=0.95, + epsilon=1e-4, + name='bn', + reuse_vars=False): + """Adds a batch normalization layer. + Args: + input (tensor): Tensor to be batch normalized + training (bool tensor): Boolean tensor of shape [] + decay (float): Decay used for exponential moving average + epsilon (float): Small constant added to variance to prevent + division of the form 0/0 + name (string): variable scope name + reuse_vars (bool): Value passed to reuse keyword argument of + tf.variable_scope + Returns: + output (tensor): Batch normalized output tensor + """ + bn = BatchNorm(input, training, decay, epsilon, name) + output = bn.output + + return output diff --git a/tftools/placeholder_management.py b/tftools/placeholder_management.py new file mode 100644 index 0000000..fe40ab5 --- /dev/null +++ b/tftools/placeholder_management.py @@ -0,0 +1,182 @@ +"""This module defines a helper class for managing placeholders in Tensorflow. +The PlholderManager class allows easy management of placeholders including +adding placeholders to a Tensorflow graph, producing easy access to the added +placeholders using a dictionary with placeholder names as keys, create feed +dictionary from a given input dictionary, and print placeholders and feed dict +to standard output. +More importantly, the class allows sparse scipy matrices to +be passed into graphs (Tensorflow currently allows only dense matrices to be fed +into placeholders). +Usage: + pm = PlaceholderManager() + pm.add_placeholder('x1', tf.float64, [1,2]) + pm.add_placeholder('x2', tf.float64, [1,2]) + # Use placeholders in your graph + y = pm['x1'] + pm['x2'] + # Create feed dictionary + feed_dict = pm.get_feed_dict({'x1': np.array([3.0, 4.0]), + 'x2': np.array([5.0, 2.0])}) + y.eval(feed_dict) +""" +import tensorflow as tf +import numpy as np +import scipy.sparse as sps +import pdb + + +class PlaceholderManager(): + """Class for managing placeholders.""" + + def __init__(self): + self._placeholders = dict() + self.issparse = dict() + + def add_placeholder(self, name, dtype, shape=None, sparse=False): + """Add placeholder. + If the sparse is True then 3 placeholders are automatically created + corresponding to the indices and values of the non-zero entries, and + shape of the sparse matrix. The user does not need to keep track of + these and can directly pass a sparse scipy matrix as input and a + Tensorflow SparseTensor object is made available for use in the graph. + Args: + name (str): Name of the placeholder. + dtype (tf.Dtype): Data type for the placeholder. + shape (list of ints): Shape of the placeholder. + sparse (bool): Specifies if the placeholder takes sparse inputs. + """ + + self.issparse[name] = sparse + if not sparse: + self._placeholders[name] = tf.placeholder(dtype, shape, name) + + else: + name_indices = name + '_indices' + name_values = name + '_values' + name_shape = name + '_shape' + + self._placeholders[name_indices] = tf.placeholder(tf.int64, [None, 2], + name_indices) + self._placeholders[name_values] = tf.placeholder(dtype, [None], + name_values) + self._placeholders[name_shape] = tf.placeholder(tf.int64, [2], + name_shape) + + def __getitem__(self, name): + """Returns placeholder with the given name. + Usage: + plh_mgr = PlaceholderManager() + plh_mgr.add_placeholder('var_name', tf.int64, sparse=True) + placeholder = plh_mgr['var_name'] + """ + sparse = self.issparse[name] + if not sparse: + placeholder = self._placeholders[name] + else: + placeholder_indices = self._placeholders[name + '_indices'] + placeholder_values = self._placeholders[name + '_values'] + placeholder_shape = self._placeholders[name + '_shape'] + sparse_tensor = tf.SparseTensor( + placeholder_indices, placeholder_values, placeholder_shape) + placeholder = sparse_tensor + + return placeholder + + def get_placeholders(self): + """Returns a dictionary of placeholders with names as keys. + The returned dictionary provides an easy way of refering to the + placeholders and passing them to graph construction or evaluation + functions. + """ + placeholders = dict() + for name in self.issparse.keys(): + placeholders[name] = self[name] + + return placeholders + + def get_feed_dict(self, inputs): + """Returns a feed dictionary that can be passed to eval() or run(). + This method creates a feed dictionary from provided inputs that can be + passed directly into eval() or run() routines in Tensorflow. + Usage: + pm = PlaceholderManager() + pm.add_placeholder('x', tf.float64, [1,2]) + pm.add_placeholder('y', tf.float64, [1,2]) + z = pm['x'] + pm['y'] + inputs = { + 'x': np.array([3.0, 4.0]), + 'y': np.array([5.0, 2.0]) + } + feed_dict = pm.get_feed_dict(inputs) + z.eval(feed_dict) + Args: + inputs (dict): A dictionary with placeholder names as keys and the + inputs to be passed in as the values. For 'sparse' placeholders + only the sparse scipy matrix needs to be passed in instead of + 3 separate dense matrices of indices, values and shape. + """ + feed_dict = dict() + for name, input_value in inputs.items(): + try: + placeholder_sparsity = self.issparse[name] + input_sparsity = sps.issparse(input_value) + assert_str = 'Sparsity of placeholder and input do not match' + assert (placeholder_sparsity == input_sparsity), assert_str + if not input_sparsity: + placeholder = self._placeholders[name] + feed_dict[placeholder] = input_value + + else: + I, J, V = sps.find(input_value) + + placeholder_indices = self._placeholders[name + '_indices'] + placeholder_values = self._placeholders[name + '_values'] + placeholder_shape = self._placeholders[name + '_shape'] + + feed_dict[placeholder_indices] = \ + np.column_stack([I, J]).astype(np.int64) + feed_dict[placeholder_shape] = \ + np.array(input_value.shape).astype(np.int64) + if placeholder_values.dtype == tf.int64: + feed_dict[placeholder_values] = V.astype(np.int64) + else: + feed_dict[placeholder_values] = V + + except KeyError: + print "No placeholder with name '{}'".format(name) + raise + + return feed_dict + + def feed_dict_debug_string(self, feed_dict): + """Returns feed dictionary as a neat string. + Args: + feed_dict (dict): Output of get_feed_dict() or a dictionary with + placeholders as keys and the values to be fed into the graph as + values. + """ + + debug_str = 'feed_dict={\n' + for plh, value in feed_dict.items(): + debug_str += '{}: \n{}\n'.format(plh, value) + debug_str += '}' + return debug_str + + def placeholder_debug_string(self, placeholders=None): + """Returns placeholder information as a neat string. + Args: + placeholders (dict): Output of get_placeholders() or None in which + case self._placeholders is used. + """ + + if not placeholders: + placeholders = self._placeholders + + debug_str = 'placeholders={\n' + for name, plh in placeholders.items(): + debug_str += " '{}': {}\n".format(name, plh) + debug_str += '}' + return debug_str + + def __str__(self): + """Allows PlaceholderManager object to be used with print or str().""" + return self.placeholder_debug_string() diff --git a/visual_genome_parser.py b/visual_genome_parser.py index 5a57cf9..3a21864 100644 --- a/visual_genome_parser.py +++ b/visual_genome_parser.py @@ -17,6 +17,8 @@ _attributes = 'attributes.json' _objects_in_image = 'objects_in_image.json' _regions_in_image = 'regions_in_image.json' _regions_with_attributes = 'regions_with_attributes.json' +_region_descriptions = 'region_descriptions.json' +_question_answers = 'question_answers.json' _regions = 'regions.json' _raw_object_labels = 'raw_object_labels.json' _raw_attribute_labels = 'raw_attribute_labels.json' @@ -25,6 +27,10 @@ _attribute_labels = 'attribute_labels.json' _regions_with_labels = 'region_with_labels.json' _unknown_token = 'UNK' _unopenable_images = 'unopenable_images.json' +_vocab = 'vocab.json' +_answer_vocab = 'answer_vocab.json' +_vocab_subset = 'vocab_subset.json' +_answer_vocab_subset = 'answer_vocab_subset.json' _im_w = 224 _im_h = 224 _pool_size = 10 @@ -185,6 +191,7 @@ def normalized_labels(): print 'Number of attribute labels: {}'.format(attribute_count) + def normalize_region_object_attribute_labels(): regions_with_attributes_filename = os.path.join(_outdir, _regions_with_attributes) @@ -377,7 +384,114 @@ def crop_regions(): print region_id, region_data raise - + +def construct_vocabulary(): + question_answers_filename = os.path.join(_datadir, _question_answers) + with open(question_answers_filename) as file: + question_answers = json.load(file) + + vocab = dict() + answer_vocab = dict() + tokenizer = nltk.tokenize.RegexpTokenizer("([^\W\d]+'[^\W\d]+)|([^\W\d]+)") +# tokenizer = nltk.tokenize.RegexpTokenizer("[^-?.,:* \d\"]+") + for image_qas in question_answers: + for qa in image_qas['qas']: + answer_words = tokenizer.tokenize(qa['answer']) + question_words = tokenizer.tokenize(qa['question']) + for word in question_words + answer_words: + word_lower ="".join(word).lower() + if word_lower in vocab: + vocab[word_lower] += 1 + else: + vocab[word_lower] = 1 + + answer = [] + for word in answer_words: + answer.append("".join(word).lower()) + answer = " ".join(answer) + if answer in answer_vocab: + answer_vocab[answer] += 1 + else: + answer_vocab[answer] = 1 + + vocab_filename = os.path.join(_outdir, _vocab) + with open(vocab_filename, 'w') as outfile: + json.dump(vocab, outfile, sort_keys=True, indent=4) + + answer_vocab_filename = os.path.join(_outdir, _answer_vocab) + with open(answer_vocab_filename, 'w') as outfile: + json.dump(answer_vocab, outfile, sort_keys=True, indent=4) + + print "Vocab Size: {}".format(len(vocab)) + print "Answer Vocab Size: {}".format(len(answer_vocab)) + + +def select_vocab_subset(k): + vocab_filename = os.path.join(_outdir, _vocab) + with open(vocab_filename, 'r') as file: + vocab = json.load(file) + + sorted_vocab = \ + [key for key, value in sorted(vocab.items(), + key = operator.itemgetter(1), + reverse = True)] + + vocab_subset = dict() + vocab_subset_chars_only_size = min(k,len(sorted_vocab)) + + for i in xrange(vocab_subset_chars_only_size): + vocab_subset[sorted_vocab[i]] = i + vocab_subset[_unknown_token] = vocab_subset_chars_only_size + + + for i in xrange(10): + str_i = str(i) + vocab_subset[str_i] = vocab_subset_chars_only_size + i + + vocab_subset_filename = os.path.join(_outdir, _vocab_subset) + with open(vocab_subset_filename, 'w') as outfile: + json.dump(vocab_subset, outfile, sort_keys=True, indent=4) + + recalled = 0.0 + not_recalled = 0.0 + for word in vocab: + if word in vocab_subset: + recalled += vocab[word] + else: + not_recalled += vocab[word] + + print 'Recall: {}'.format(recalled/(recalled + not_recalled)) + +def select_answer_subset(k): + answer_vocab_filename = os.path.join(_outdir, _answer_vocab) + with open(answer_vocab_filename, 'r') as file: + answer_vocab = json.load(file) + + sorted_answer_vocab = \ + [key for key, value in sorted(answer_vocab.items(), + key = operator.itemgetter(1), + reverse = True)] + + answer_vocab_subset = dict() + for i in xrange(min(k,len(sorted_answer_vocab))): + answer_vocab_subset[sorted_answer_vocab[i]] = i + + answer_vocab_subset_filename = os.path.join(_outdir, _answer_vocab_subset) + with open(answer_vocab_subset_filename, 'w') as outfile: + json.dump(answer_vocab_subset, outfile, sort_keys=True, indent=4) + + recalled = 0.0 + not_recalled = 0.0 + for answer in answer_vocab: + if answer in answer_vocab_subset: + recalled += answer_vocab[answer] + else: + print answer + not_recalled += answer_vocab[answer] + + print 'Recall: {}'.format(recalled/(recalled + not_recalled)) + + if __name__=='__main__': # parse_objects() # parse_attributes() @@ -388,3 +502,6 @@ if __name__=='__main__': # top_k_object_labels(1000) # top_k_attribute_labels(1000) # crop_regions_parallel() + # construct_vocabulary() + select_vocab_subset(10000) + # select_answer_subset(5000) diff --git a/word2vec/__init__.py b/word2vec/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/word2vec/get_vocab_word_vectors.py b/word2vec/get_vocab_word_vectors.py new file mode 100644 index 0000000..1539b21 --- /dev/null +++ b/word2vec/get_vocab_word_vectors.py @@ -0,0 +1,45 @@ +from gensim.models import word2vec +import numpy as np +import json +import pdb +import constants + +def get_vocab_word_vectors( + model, + vocab): + + vocab_size = len(vocab) + + assert_str = 'predefined word vector size does not match model' + assert (model['test'].shape[0]==constants.word_vector_size), assert_str + + vocab_word_vectors = 2*np.random.rand( + vocab_size, + constants.word_vector_size) + vocab_word_vectors -= 0.5 + + found_word_vec = 0 + for word, index in vocab.items(): + if word in model: + found_word_vec += 1 + vocab_word_vectors[index,:] = model[word] + + np.save( + constants.vocab_word_vectors_npy, + vocab_word_vectors) + + print 'Found word vectors for {} out of {} words'.format( + found_word_vec, + vocab_size) + + +if __name__=='__main__': + model = word2vec.Word2Vec.load_word2vec_format( + constants.word2vec_binary, + binary=True) + + with open(constants.vocab_json, 'r') as file: + vocab = json.load(file) + + get_vocab_word_vectors(model, vocab) + -- GitLab