diff --git a/metrics.py b/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..a7211f4dfeb610e8887936be3ba4353e41a22fcc --- /dev/null +++ b/metrics.py @@ -0,0 +1,25 @@ +def fraction_in_top_k(y, y_pred): + y_list = tf.unpack(y) + y_pred_list = tf.unpack(y_pred) + accuracy = 0.0 + for y_, y_pred_ in zip(y_list, y_pred_list): + k = tf.reduce_sum(y_) + pos_label_ids = tf.nn.top_k(y_,k) + pos_label_ids_list = tf.unpack(pos_label_ids) + num_pos_labels_in_top_k = 0.0 + for pos_label_id in pos_label_ids_list: + is_present = tf.nn.in_top_k(y_pred_,pos_label_id,k) + num_pos_labels_in_top_k += tf.cast(is_present, tf.float32) + + frac_pos_labels_in_top_k = num_pos_labels_in_top_k/tf.maximum(k,1e-5) + accuracy += frac_pos_labels_in_top_k + + accuracy /= len(y_list) + return accuracy + +def presence_in_top_k(y, y_pred, k): + max_ids = tf.argmax(y,1) + is_present = tf.nn.in_top_k(y_pred, max_ids, k) + is_present = tf.cast(is_present, tf.float32) + return tf.reduce_mean(is_present) +