Skip to content
Snippets Groups Projects
hicopy.py 26.49 KiB
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 30 23:36:15 2017

@author: Xiaodan Du
"""
__author__ = 'Xiaodan Du'
__version__ = '1.0'

import json
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw 
from pprint import pprint
import fileinput
import time
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.image as mpimg
from matplotlib import rcParams
from matplotlib.lines import Line2D
from matplotlib.collections import PatchCollection
import numpy as np
import copy
import itertools
#from . import mask as maskUtils
import os
from collections import defaultdict
import sys
'''
This module reads the JSON annotation file of the HICO-DET dataset.
Image ids and cat (hoi) ids start from 1 instead of 0.
All indices/ids must be integers between 1 and the total number of images (or hois).
Default location of the JSON file:'C:\\Users\\du_xi\\Dropbox\\Research Shuai Tang\\list.json'
Default location of the images:'C:\\Users\\du_xi\\.spyder-py3\\hico_20160224_det'

Functions:
    
    test_dim: returns the dimension of the input list
    
    hicopy Class:
        -checkMultiHoi: checks if an image has multiple hois
        -createIndex: a helper function of the constructor to create the class variables
        -getAllNname: returns a list of all object names
        -getAllVname: returns a list of all verb names
        -getCatIds: gets the hoi ids that contains given verbs and objects
        -getHoiIds: returns the ids that contain certain verbs and objects
        -getHoiIdsBasedOnVname: returns a list of all hoi ids that contain the given verb name
        -getHoiIdsBasedOnNname: returns a list of all hoi ids that contain the given objective name
        -getHoiNames: returns a list of hoi ids given image ids
        -getHoiNum: returns the total number of hois
        -getImgHoi: returns a list of hoi ids and a list of how many humans in each hoi of a given image id   
        -getImgId: a helper function of getImgIds
        -getImgIds: given a list of image ids and a list of hoi ids, returns a list of the ids of the images that belong to the given hois
        -getNnames: returns the names of given object ids
        -getNnameIds: returns the ids of given object names
        -getNnameIdsBasedOnHoiId: returns the object id of a given hoi id
        -getNumOfNoInteraction: returns the number of "no_interaction" humans of a given image id
        -getObjectNum: returns the total number of objects
        -getVerbNum: returns the total number of verbs
        -getVnameIds: returns the ids of given verb names
        -getVnameIdsBasedOnHoiId: returns the verb id of a given hoi id
        -getVnames: returns the names of given verb ids
        -hicopy: Object that contains all information of the given JSON file      
        -loadCats: returns a list of strings in the form of "verb object" with given hoi ids
        -loadImgs: returns a list of images required
        -readImgs: returns a list of images required (images are in np array format)
        -visualize: visualize the bounding boxes and human-object interaction of a specific image
        -visualize_box_conn_one: a helper function of visualize to create bounding boxes and lines
'''
def test_dim(testlist, dim=0):
    """
    Return the nesting depth of a (possibly nested) list.

    Adapted from the internet:
    https://stackoverflow.com/questions/15985389/python-check-if-list-is-multidimensional-or-one-dimensional
    Only the first element at each level is inspected, so ragged lists are
    measured along their first branch.

    inputs:
        testlist: the object to test
        dim: depth already accumulated by the caller (normally left at 0)
    outputs:
        -1 if testlist is not a list and dim is 0;
        dim (0 by default) if testlist is an empty list;
        otherwise dim plus the number of nested list levels found
    """
    depth = dim
    current = testlist
    while isinstance(current, list):
        if not current:
            return depth
        depth += 1
        current = current[0]
    return -1 if depth == 0 else depth


# The "test_" prefix matches pytest's collection pattern; without this flag,
# importing this helper into a test module makes pytest try to run it as a
# test (and fail on the unknown "testlist" fixture).
test_dim.__test__ = False
          
class hicopy:
    def __init__(self, annotation_file='C:\\Users\\du_xi\\Dropbox\\Research Shuai Tang\\list.json'):
        """
        The constructor of hicopy class
        input: 
            annotation_file: the location of the JSON file that contains all information. Default location only works for me. Please
                            change it to where you put the JSON file.
        """
        self.dataset=dict()
        self.anno_test=np.empty
        self.anno_train=np.empty
        self.bbox_test=list()
        self.bbox_train=list()
        self.list_action=list()
        self.list_test=list()
        self.list_train=list()
        if not annotation_file == None:
            print('loading annotations into memory...')
            tic = time.time()
            dataset = json.load(open(annotation_file, 'r'))
            assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
            print('Done (t={:0.2f}s)'.format(time.time()- tic))
            self.dataset = dataset
            self.createIndex()
            

    def createIndex(self):
        """
        This function creates the seven class members of hicopy objects
        """
        # create index
        print('creating index...')
        anno_test=np.empty
        anno_train=np.empty
        bbox_test=list()
        bbox_train=list()
        list_action=list()
        list_test=list()
        list_train=list()
        if 'anno_test' in self.dataset:
            anno_test_list = self.dataset['anno_test']
            anno_test_list = [[x if x !='null' else np.nan for x in y] for y in anno_test_list]
            anno_test=np.array(anno_test_list)
        if 'anno_train' in self.dataset:
            anno_train_list = self.dataset['anno_train']
            anno_train_list = [[x if x !='null' else np.nan for x in y] for y in anno_train_list]
            anno_train=np.array(anno_train_list)
        if 'list_test' in self.dataset:
            list_test = self.dataset['list_test']
        if 'list_train' in self.dataset:
            list_train = self.dataset['list_train']
        if 'bbox_test' in self.dataset:
            bbox_test = self.dataset['bbox_test']
        if 'bbox_train' in self.dataset:
            bbox_train = self.dataset['bbox_train']
        if 'list_action' in self.dataset:
            list_action = self.dataset['list_action']
        print('index created!')

        # create class members
        self.anno_test = anno_test
        self.anno_train = anno_train
        self.list_test = list_test
        self.list_train = list_train
        self.bbox_test = bbox_test
        self.bbox_train = bbox_train
        self.list_action = list_action
        
    def getImgIds(self, imgIds=[], catIds=[]):
        """
        inputs:
            imgIds: a list of image ids
            catIds: a list of hoi ids
        outputs:
            list(ids): a list of image ids that contain given hois
        Note: If imgIds is empty, function will use all image ids as imgIds
              If catIds is empty, fucntion will use all hoi ids as catIds
              In case both imgIds and catIds are empty, function will return all hoi ids
        """
#        ids=list()
#        if dataType=='imgIds':
#            
#            for i in index:
#                ids.append(self.list_train[i])
#        elif dataYpe=='hoiIds':
#            for i in index:
        if len(imgIds) == len(catIds) == 0:
            ids = self.list_train
        else:
            ids = set(imgIds)
            for i, catId in enumerate(catIds):
                if i == 0 and len(ids) == 0:
                    ids = set(self.getImgId(catId,imgIds))
                else:
                    ids &= set(self.getImgId(catId,imgIds))
        return list(ids)
    def getImgId(self, catId,imgIds=[]):
        """
        inputs:
            catId: integer, id of hoi
            imgIds: a list of image ids
        outputs:
            ids: a list of image ids that contain the given hoi
        Note: If imgIds is empty, function will use all image ids as imgIds
        """
        ids=list()
        if len(imgIds)==0:
            for i in range(len(self.bbox_train)):
                hoiNum=self.checkMultiHoi('bbox_train',i)
                if hoiNum==0:
                    if self.bbox_train[i]['hoi']['id']==catId:
                        ids.append(i+1)
                else:
                    for j in range(hoiNum):
                        if self.bbox_train[i]['hoi'][j]['id']==catId:
                            ids.append(i+1)
                            break
        else:
            for i in imgIds:
                hoiNum=self.checkMultiHoi('bbox_train',i-1)
                if hoiNum==0:
                    if self.bbox_train[i-1]['hoi']['id']==catId:
                        ids.append(i)
                else:
                    for j in range(hoiNum):
                        if self.bbox_train[i-1]['hoi'][j]['id']==catId:
                            ids.append(i)
                            break
        return ids
                
    def checkMultiHoi(self,key,i):
        """
        inputs: 
            key: either 'bbox_train' or 'bbox_test'
            i: the index of an image in bbox_train or bbox_test. Note: i starts from 0, so if you have 
                the index of a certain image, i should be index-1
        outputs:
            0: if the image has only 1 hoi
            an integer n: if the image has n hois
            nothing: if key is incorrect
        """
        if key=='bbox_train':
            if isinstance(self.bbox_train[i]['hoi'],dict):
                return 0
            else:
                return len(self.bbox_train[i]['hoi'])
        elif key=='bbox_test':
            if isinstance(self.bbox_test[i]['hoi'],dict):
                return 0
            else:
                return len(self.bbox_test[i]['hoi'])
        else:
            print('Wrong Type of "key". "bbox_train" or "bbox_test" only!')
            return
            
    def loadImgs(self, ids=[],hicolocation='C:\\Users\\du_xi\\.spyder-py3\\hico_20160224_det'):
        """
        inputs:
            ids: a list of image ids that are to be loaded
            hicolocation: the location of all the images. Default location only works for my computer. Please change.
        outputs:
            imgs: a list of images
        Note: ids cannot be empty.
              This function uses Image.open() to load images. If you want a 3D np array, please chceck readImgs() in this class.
        """
        imgs=list()
        for i in ids:
            fileName=self.list_train[i-1]
            im_file=hicolocation+'\\images\\train2015\\'+fileName
            img = Image.open(im_file)
            imgs.append(img)
            #img.show()
        return imgs
    
    def readImgs(self,ids=[],hicolocation='C:\\Users\\du_xi\\.spyder-py3\\hico_20160224_det'):
        """
        inputs: 
            ids: a list of image ids that are to be loaded
            hicolocation: the location of all the images. Default location only works for my computer. Please change.
        outputs:
            imgs: a list of images (in np array format)
        Note:
            Dimension of the np array: (height,width,R-G-B)
        """
        imgs=list()
        for i in ids:
            fileName=self.list_train[i-1]
            im_file=hicolocation+'\\images\\train2015\\'+fileName
            img = mpimg.imread(im_file)
            imgs.append(img)
            #img.show()
        return imgs
    
    def getCatIds(self, catVNms=[],catNNms=[]):
        """
        inputs:
            catVNms: a list of string of verb names
            catNNms: a list of string of object names
        outputs:
            sorted(list(cats)): a sorted list of hoi ids that satisfy both given verbs and objects
        Note:
            If catVNms is empty, the function will use all verbs as catVNms
            If catNNms is empty, the function will use all objects as catNNms
            In case both inputs are empty, the function will directly return a list of all hoi ids
        """
        if len(catVNms) == len(catNNms) == 0:
            cats = list(range(1,len(self.list_action)+1))
        else:
            cats = list(range(1,len(self.list_action)+1))
            catsV = cats if len(catVNms) == 0 else [cat for cat in cats if self.list_action[cat-1]['vname']          in catVNms]
            catsN = cats if len(catNNms) == 0 else [cat for cat in cats if self.list_action[cat-1]['nname']          in catNNms]
            catsV=set(catsV)
            catsN=set(catsN)
            cats=catsV&catsN
        return sorted(list(cats))
    
    def loadCats(self, ids=[]):
        """
        Load cats with the specified ids.
        :param ids (int list)       : integer ids specifying hois
        :return: cats (string list) : "verb"+" "+"object"
        Note:
            If ids is empty, the function returns nothing.
        """
        if len(ids)==0:
            return
        else:
            hoi=list()
            for i in ids:
                hoi.append(self.list_action[i-1]['vname']+' '+self.list_action[i-1]['nname'])
            return hoi
        
    def getHoiNames(self, ids=[]):
        """
        inputs:
            ids: a list of hoi ids
        outputs:
            hoiIds: a list of string of names of given hoi ids
        Note:
            If ids is empty, the function will return nothing
            Elements in ids must be valid ids (integers larger than 0 and smaller or equal to 600)
        """
        if len(ids)==0:
            return
        else:
            hoiIds=list()
            for i in ids:
                hoiNum=self.checkMultiHoi('bbox_train',i-1)
                if hoiNum==0:
                    hoiIds.append(self.bbox_train[i-1]['hoi']['id'])
                else:
                    hoiIds.append([self.bbox_train[i-1]['hoi'][j]['id'] for j in range(hoiNum)])
            return hoiIds        
                                    
    def visualize(self,imageIndex,hoiIndex=[],hicolocation='C:\\Users\\du_xi\\.spyder-py3\\hico_20160224_det'):
        """
        inputs: 
            imageIndex: a list of an integer of image id
            hoiIndex: a list or an integer of hoi id
            hicolocation: the location of all the images. Default location only works for my computer. Please change.
        outputs: 
            visualization of annotations
        """
        if isinstance(imageIndex,list) and len(imageIndex)==0:
            return
        if isinstance(hoiIndex,int) and hoiIndex>0 and hoiIndex<=600:
            for i in imageIndex:
                fileName=self.bbox_train[i-1]['filename']
                im_file=hicolocation+'\\images\\train2015\\'+fileName
                
                if isinstance(self.bbox_train[i-1]['hoi'],dict):
                    hoi_invis=self.bbox_train[i-1]['hoi']['invis']
                    if self.bbox_train[i-1]['hoi']['id']==hoiIndex:
                        hoi_id = hoiIndex
                    else:
                        print('Image does not have such hoi.')
                        return
                    bboxhuman=self.bbox_train[i-1]['hoi']['bboxhuman']
                    bboxobject=self.bbox_train[i-1]['hoi']['bboxobject']
                    connection=self.bbox_train[i-1]['hoi']['connection']
                else:
                    imageHois=list()
                    for j, hoiId in enumerate(self.bbox_train[i-1]['hoi']):
                        imageHois.append(hoiId['id'])
                    if hoiIndex in imageHois:
                        k = imageHois.index(hoiIndex)
                        hoi_invis=self.bbox_train[i-1]['hoi'][k]['invis'] 
                        hoi_id = self.bbox_train[i-1]['hoi'][k]['id']
                        bboxhuman=self.bbox_train[i-1]['hoi'][k]['bboxhuman']
                        bboxobject=self.bbox_train[i-1]['hoi'][k]['bboxobject']
                        connection=self.bbox_train[i-1]['hoi'][k]['connection']
                    else:
                        print('Image does not have such hoi.')
                        return
                aname = self.list_action[hoi_id]['vname_ing']+' '+self.list_action[hoi_id]['nname']
                img=mpimg.imread(im_file)
                fig,ax = plt.subplots(1)
                imgplot = plt.imshow(img)
                plt.axis('off')
                plt.title(aname,fontSize=48,color='black')              
                print(self.bbox_train[i-1])
                if hoi_invis:
                    print('hoi not visible\n')
                else:
                    self.visualize_box_conn_one(ax,bboxhuman, bboxobject, connection, 'b','g','r')
                rcParams['figure.figsize'] = [20, 20]
                plt.show()
        else:
            for i in imageIndex:
                fileName=self.bbox_train[i-1]['filename']
                im_file=hicolocation+'\\images\\train2015\\'+fileName
                img = Image.open(im_file)
                img.show()
        
    def visualize_box_conn_one(self, ax,bboxhuman, bboxobject, connection, color1='b',color2='g',color3='r'):
        if isinstance(bboxhuman,list):
            for r in range(len(bboxhuman)):
                rt = [bboxhuman[r]['x1'], bboxhuman[r]['y1'], bboxhuman[r]['x2']-bboxhuman[r]['x1']+1, bboxhuman[r]['y2']-bboxhuman[r]['y1']+1]
                rect = patches.Rectangle((rt[0],rt[1]),rt[2],rt[3],linewidth=4,edgecolor=color1,facecolor='none')
                ax.add_patch(rect)
        else:
            rt = [bboxhuman['x1'], bboxhuman['y1'], bboxhuman['x2']-bboxhuman['x1']+1, bboxhuman['y2']-bboxhuman['y1']+1]
            rect=patches.Rectangle((rt[0],rt[1]),rt[2],rt[3],linewidth=4,edgecolor=color1,facecolor='none')
            ax.add_patch(rect)
        if isinstance(bboxobject,list):
            for r in range(len(bboxobject)):
                rt = [bboxobject[r]['x1'], bboxobject[r]['y1'], bboxobject[r]['x2']-bboxobject[r]['x1']+1, bboxobject[r]['y2']-bboxobject[r]['y1']+1]
                rect2 = patches.Rectangle((rt[0],rt[1]),rt[2],rt[3],linewidth=4,edgecolor=color2,facecolor='none')
                ax.add_patch(rect2)
        else:
            rt = [bboxobject['x1'], bboxobject['y1'], bboxobject['x2']-bboxobject['x1']+1, bboxobject['y2']-bboxobject['y1']+1]
            rect2=patches.Rectangle((rt[0],rt[1]),rt[2],rt[3],linewidth=4,edgecolor=color2,facecolor='none')
            ax.add_patch(rect2)    
        if test_dim(connection)==1:
            rt1=bboxhuman
            rt2=bboxobject
            ct1 = [(rt1['x1']+rt1['x2'])/2, (rt1['y1']+rt1['y2'])/2]
            ct2 = [(rt2['x1']+rt2['x2'])/2, (rt2['y1']+rt2['y2'])/2]
            markers_on = [0, -1]
            line = Line2D([ct1[0],ct2[0]],[ct1[1],ct2[1]],linewidth=4,color=color3,markevery=markers_on,marker='o',markersize=10,markerfacecolor=color3)
            ax.add_line(line)
        else:
            for r in range(len(connection)):
                rt1=bboxhuman[connection[r][0]-1]
                rt2=bboxobject[connection[r][1]-1]
                ct1 = [(rt1['x1']+rt1['x2'])/2, (rt1['y1']+rt1['y2'])/2]
                ct2 = [(rt2['x1']+rt2['x2'])/2, (rt2['y1']+rt2['y2'])/2]
                markers_on = [0, -1]
                line = Line2D([ct1[0],ct2[0]],[ct1[1],ct2[1]],linewidth=4,color=color3,markevery=markers_on,marker='o',markersize=10,markerfacecolor=color3)
                ax.add_line(line)
            
    def getAllNname(self):
        allNname=list()
        for i in range(len(self.list_action)):
            if self.list_action[i]['nname'] not in allNname:
                allNname.append(self.list_action[i]['nname'])
        return sorted(allNname)
    
    def getAllVname(self):
        allVname=list()
        for i in range(len(self.list_action)):
            if self.list_action[i]['vname'] not in allVname:
                allVname.append(self.list_action[i]['vname'])
        return sorted(allVname)        
    
    def getNnames(self,index):
        """
        inputs:
            index: a list or an integer of object ids
        outputs:
            a list of names of given objects
        Note: 
            If index is empty, the function will return all object names
        """
        allNname=self.getAllNname()
        if index==[]:
            return allNname
        return [allNname[i-1] for i in index] if isinstance(index,list) else allNname[index-1]
    
    def getVnames(self,index):
        """
        inputs:
            index: a list or an integer of verb ids
        outputs:
            a list of names of given verbs
        Note: 
            If index is empty, the function will return all verb names
        """
        allVname=self.getAllVname()
        if index==[]:
            return allVname
        return [allVname[i-1] for i in index] if isinstance(index,list) else allVname[index-1]
    
    def getNnameIds(self,nname):
        """
        inputs:
            nname: a list or an integer of object ids
        outputs:
            a list of ids of given object names
        Note:
            If nname is empty, the function will return a list of all object ids
        """
        allNname=self.getAllNname()
        if nname==[]:
            return sorted(range(1,self.getObjectNum()+1))
        return [allNname.index(i)+1 for i in nname] if isinstance(nname,list) else allNname.index(nname)+1
    
    def getVnameIds(self,vname):
        """
        inputs:
            vname: a list or an integer of verb ids
        outputs:
            a list of ids of given verb names
        Note:
            If vname is empty, the function will return a list of all verb ids
        """
        allVname=self.getAllVname()
        if vname==[]:
            return sorted(range(1,self.getVerbNum()+1))
        return [allVname.index(i)+1 for i in vname] if isinstance(vname,list) else allVname.index(vname)+1
    
    def getVerbNum(self):
        """
        outputs:
            an integer of the total types of verbs
        """
        return len(self.getAllVname())
    
    def getObjectNum(self):
        """
        outputs:
            an integer of the total types of objectives
        """
        return len(self.getAllNname())
    
    def getHoiNum(self):
        """
        outputs:
            an integer of the total types of hois
        """
        return len(self.list_action)
        
    def getHoiIds(self,**vn):
        """
        inputs:
            **vn: input keys can only be "vname" or "nname". 
                  "vname": a list of strings of verb names
                  "nname": a list of strings of object names
        outputs:
            sorted(HoiIds): a sorted list of hoi ids
        """
        if any(key=='vname' or 'nname' for key in vn)==0:
            return sorted(range(1,self.getHoiNum()+1))
        if all(key=='vname' or 'nname' for key in vn)==0:
            print('Inputs can only be "vname" or "nname"!')
            return
        if 'vname' in vn:
            vHoiIds=set(self.getHoiIdsBasedOnVname(vn['vname']))
        else:
            vHoiIds=set(range(1,self.getHoiNum()+1))
        if 'nname' in vn:
            nHoiIds=set(self.getHoiIdsBasedOnNname(vn['nname']))
        else:
            nHoiIds=set(range(1,self.getHoiNum()+1))
        HoiIds=vHoiIds&nHoiIds
        return sorted(HoiIds)
            
    def getHoiIdsBasedOnVname(self,vname):
        """
        inputs:
            vname: a list of verb names
        outputs:
            sorted(HoiIds): a sorted list of hoi ids
        """
        HoiIds=list()
        if vname==[]:
            vname = self.getVnameIds(vname)
        if isinstance(vname,list):
            for i in range(self.getHoiNum()):            
                if self.list_action[i]['vname'] in self.getVnames(vname):
                    HoiIds.append(i+1)
        else:
            for i in range(self.getHoiNum()):
                if self.list_action[i]['vname'] == self.getVnames(vname):
                    HoiIds.append(i+1)
        return sorted(HoiIds)
    
    def getHoiIdsBasedOnNname(self,nname):
        """
        inputs:
            nname: a list of object names
        outputs:
            sorted(HoiIds): a sorted list of hoi ids
        """
        HoiIds=list()
        if nname==[]:
            nname = self.getNnameIds(nname)
        if isinstance(nname,list):
            for i in range(self.getHoiNum()):            
                if self.list_action[i]['nname'] in self.getNnames(nname):
                    HoiIds.append(i+1)
        else:
            for i in range(self.getHoiNum()):
                if self.list_action[i]['nname'] == self.getNnames(nname):
                    HoiIds.append(i+1)
        return sorted(HoiIds)
    
    def getNnameIdsBasedOnHoiId(self,hoiId):
        """
        inputs: 
            hoiId: an integer of hoi id
        outputs:
            nnameId: an integer of the id of the object in the given hoi
        """
        nname=self.list_action[hoiId-1]['nname']
        nnameId=self.getNnameIds(nname)
        return nnameId
    
    def getVnameIdsBasedOnHoiId(self,hoiId):
        """
        inputs: 
            hoiId: an integer of hoi id
        outputs:
            vnameId: an integer of the id of the verb in the given hoi
        """
        vname=self.list_action[hoiId-1]['vname']
        vnameId=self.getVnameIds(vname)
        return vnameId
    
    def getImgHoi(self,index):
        """
        inputs: 
            index: an integer of the id of an image
        outputs:
            imgHoi: an integer of the hoi id of the given image
            personNum: the number of humans in the image
        """
        imgHoi=list()
        personNum=list()
        if isinstance(self.bbox_train[index-1]['hoi'],list):
            for currHoi in self.bbox_train[index-1]['hoi']:
                imgHoi.append(currHoi['id'])
                if isinstance(currHoi['bboxhuman'],dict):
                    personNum.append(1)
                elif isinstance(currHoi['bboxhuman'],list):
                    personNum.append(len(currHoi['bboxhuman']))
                else:
                    personNum.append(0)
        elif isinstance(self.bbox_train[index-1]['hoi'],dict):
            imgHoi.append(self.bbox_train[index-1]['hoi']['id'])
            if isinstance(self.bbox_train[index-1]['hoi']['bboxhuman'],dict):
                personNum.append(1)
            elif isinstance(self.bbox_train[index-1]['hoi']['bboxhuman'],list):
                personNum.append(len(self.bbox_train[index-1]['hoi']['bboxhuman']))
            else:
                personNum.append(0)
        return imgHoi,personNum
    
    def getNumOfNoInteraction(self,index):
        """
        inputs:
            index: an integer of the id of an image
        outputs:
            noInteraction: the number of "no_interaction" in a given image
        """
        imgHoi,personNum=self.getImgHoi(index)
        noInteraction=0
        noInteractionId=self.getVnameIds(['no_interaction'])[0]
        for i,hoi in enumerate(imgHoi):
            if self.getVnameIdsBasedOnHoiId(hoi)==noInteractionId:
                noInteraction=noInteraction+personNum[i]
        return noInteraction