Skip to content
Snippets Groups Projects
data.py 2.93 KiB
"""Utility functions for making batch generation easier."""
import numpy as np
from multiprocessing import Process, Queue, Manager, Pool
import time

def sequential(batch_size, num_samples, num_epochs=1, offset=0):
    """Yield batches of indices in ascending order, epoch after epoch.

    The index range is padded up to a whole multiple of ``batch_size`` by
    wrapping around to the first samples, so every batch holds exactly
    ``batch_size`` indices.

    Args:
        batch_size: number of indices per batch.
        num_samples: number of distinct samples; generated indices lie in
            ``[offset, offset + num_samples)``.
        num_epochs: number of passes over the samples.
        offset: value added to every index.

    Yields:
        list of int: one batch of indices.
    """
    padded_total = int(batch_size * np.ceil(num_samples / float(batch_size)))
    for _ in range(num_epochs):
        # Modulo wraps the padding positions back onto the first samples.
        wrapped = (np.arange(padded_total) % num_samples + offset).tolist()
        batch_starts = range(0, padded_total - batch_size + 1, batch_size)
        for start in batch_starts:
            yield wrapped[start:start + batch_size]


def random(batch_size, num_samples, num_epochs=1, offset=0):
    """Yield batches of shuffled indices, reshuffling every epoch.

    NOTE: this function shadows the stdlib ``random`` module name inside
    this module.

    The index range is padded up to a whole multiple of ``batch_size``; the
    padding positions wrap (via modulo) onto the first samples, so every
    batch holds exactly ``batch_size`` indices and a few samples may appear
    twice in an epoch.

    Args:
        batch_size: number of indices per batch.
        num_samples: number of distinct samples; generated indices lie in
            ``[offset, offset + num_samples)``.
        num_epochs: number of passes over the samples (defaults to 1, like
            ``sequential``).
        offset: value added to every index.

    Yields:
        list of int: one batch of indices.
    """
    num_samples_ = int(batch_size * np.ceil(num_samples / float(batch_size)))
    for _ in range(num_epochs):
        # A fresh permutation per epoch; modulo folds padding onto real samples.
        indices = (np.random.permutation(num_samples_) % num_samples + offset).tolist()
        for i in range(0, num_samples_ - batch_size + 1, batch_size):
            yield indices[i:i + batch_size]


def batch_generator(data, index_generator, batch_function=None):
    """Yield one batch of data for every group of indices.

    Args:
        data: object exposing ``get_parallel(indices)``; it is called once
            per item produced by ``index_generator``.
        index_generator: iterable of index groups.
        batch_function: optional callable applied to each fetched batch.
            When it is falsy, the fetched batch is yielded unchanged.

    Yields:
        ``batch_function(batch)`` if a batch function is given, else the
        raw batch returned by ``data.get_parallel``.
    """
    for samples in index_generator:
        fetched = data.get_parallel(samples)
        yield batch_function(fetched) if batch_function else fetched


def async_batch_generator(data, index_generator, queue_maxsize, batch_function=None):
    """Produce batches in a background process and stream them back.

    Args:
        data: passed through to ``batch_generator``.
        index_generator: passed through to ``batch_generator``.
        queue_maxsize: capacity of the ``multiprocessing.Queue`` that
            buffers batches between the fetcher process and the consumer.
        batch_function: optional per-batch transform, passed through to
            ``batch_generator``.

    Returns:
        A generator that yields batches read from the queue until the
        fetcher signals completion.
    """
    # NOTE(review): the generator object is handed to a Process; this
    # presumably relies on the 'fork' start method, since generators cannot
    # be pickled for 'spawn' -- confirm on the target platforms.
    producer = batch_generator(data, index_generator, batch_function)
    batch_queue = Queue(maxsize=queue_maxsize)
    BatchFetcher(batch_queue, producer).start()
    # BatchFetcher ends the stream by putting None, which matches
    # queue_generator's default sentinel; a batch_function that returns None
    # would therefore end iteration early.
    return queue_generator(batch_queue)


def queue_generator(queue, sentinel=None):
    """Turn a queue into a generator.

    Items are read with ``queue.get()`` (blocking) and yielded in order
    until an item that *is* ``sentinel`` (identity, not equality) arrives;
    the sentinel itself is not yielded.

    Args:
        queue: object with a ``get()`` method.
        sentinel: end-of-stream marker, compared with ``is``.

    Yields:
        Every item preceding the sentinel.
    """
    while True:
        item = queue.get()
        # Identity test: ``==`` could be ambiguous for numpy array items.
        if item is sentinel:
            break
        yield item


class BatchFetcher(Process):
    """Process that forwards every item of a batch generator onto a queue.

    After the generator is exhausted -- or if it raises -- ``None`` is put
    on the queue as an end-of-stream marker, so a consumer blocked in
    ``queue.get()`` is always released.
    """

    def __init__(self, queue, batch_generator):
        super(BatchFetcher, self).__init__()
        self.queue = queue  # destination for batches and the final None marker
        self.batch_generator = batch_generator  # iterable of batches to forward

    def run(self):
        """Put each batch on the queue, then put ``None``.

        The sentinel is emitted in ``finally`` so that a failure while
        producing a batch does not leave the consumer waiting forever; the
        exception still propagates afterwards.
        """
        try:
            for batch in self.batch_generator:
                self.queue.put(batch)
        finally:
            # Signal that fetcher is done (also on failure).
            self.queue.put(None)


class NumpyData(object):
    """Wrap an in-memory array and fetch entries by index.

    Args:
        array: object indexed with ``array[indices]`` -- presumably a numpy
            ndarray indexed along its first axis; TODO confirm callers never
            pass other array types.
    """

    def __init__(self, array):
        self.array = array

    def get(self, indices):
        """Return ``self.array[indices]`` using numpy array indexing.

        Args:
            indices: integer (or boolean) indices as a list, tuple or
                ndarray; non-ndarray inputs are converted with ``np.array``.

        Returns:
            The selected entries; an empty selection yields an empty result.
        """
        if not isinstance(indices, np.ndarray):
            indices = np.array(indices)
            # np.array([]) is float64, which numpy rejects as an index
            # (IndexError); give an empty selection an integer dtype.
            if indices.size == 0:
                indices = indices.astype(np.intp)
        return self.array[indices]

    def get_parallel(self, indices):
        """Alias of ``get``.

        ``batch_generator`` fetches batches via ``data.get_parallel``; this
        alias lets a NumpyData instance be passed to it directly.
        """
        return self.get(indices)


# def write_to_queue((queue, samples, data)):
#     print samples
#     if samples:
#         batch = data.get(samples)
#         queue.put(batch)
#     else:
#         queue.put(None)


if __name__ == '__main__':
    # Quick demo: print the shuffled index batches for a few epochs.
    batch_size = 10
    num_samples = 20
    num_epochs = 3
    offset = 0
    index_generator = random(batch_size, num_samples, num_epochs, offset)
    for samples in index_generator:
        # print(x) with a single argument prints identically on Python 2
        # and, unlike the print statement, is valid syntax on Python 3.
        print(samples)