Something went wrong on our end
data.py 2.93 KiB
"""Utility functions for making batch generation easier."""
import numpy as np
from multiprocessing import Process, Queue, Manager, Pool
import time
def sequential(batch_size, num_samples, num_epochs=1, offset=0):
    """Generate batches of sample indices in sequential order.

    The range ``0 .. num_samples - 1`` is padded up to the next multiple of
    ``batch_size`` by wrapping around to the start, so every batch holds
    exactly ``batch_size`` indices.  ``offset`` is added to every index.

    Args:
        batch_size: Number of indices per batch; must be at least 1.
        num_samples: Number of distinct samples.
        num_epochs: Number of passes over the samples.
        offset: Constant added to each index.

    Yields:
        Lists of ``batch_size`` Python ints.

    Raises:
        ValueError: If ``batch_size`` is less than 1 (raised on the first
            iteration, since this is a generator).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1, got %r" % (batch_size,))
    # Round up with integer arithmetic rather than float division + ceil,
    # which can lose precision for very large sample counts.
    num_samples_ = int(-(-num_samples // batch_size) * batch_size)
    # The ordering is identical in every epoch, so build it only once.
    indices = (np.arange(num_samples_) % num_samples + offset).tolist()
    for _ in range(num_epochs):
        for start in range(0, num_samples_, batch_size):
            # Slicing returns a fresh list, so callers may mutate a batch
            # without affecting later epochs.
            yield indices[start:start + batch_size]
def random(batch_size, num_samples, num_epochs=1, offset=0):
    """Generate batches of shuffled sample indices.

    Each epoch draws a fresh ``np.random.permutation`` of the index range
    padded up to the next multiple of ``batch_size``.  Padded positions wrap
    around modulo ``num_samples``, so every batch holds exactly ``batch_size``
    indices, and the first few samples may appear twice within an epoch.
    ``offset`` is added to every index.

    Args:
        batch_size: Number of indices per batch; must be at least 1.
        num_samples: Number of distinct samples.
        num_epochs: Number of passes over the samples (defaults to 1, like
            ``sequential``).
        offset: Constant added to each index.

    Yields:
        Lists of ``batch_size`` Python ints.

    Raises:
        ValueError: If ``batch_size`` is less than 1 (raised on the first
            iteration, since this is a generator).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1, got %r" % (batch_size,))
    # Integer ceiling avoids float rounding for very large sample counts.
    num_samples_ = int(-(-num_samples // batch_size) * batch_size)
    for _ in range(num_epochs):
        # Reshuffle every epoch; the modulo maps the padding back onto real samples.
        indices = (np.random.permutation(num_samples_) % num_samples + offset).tolist()
        # num_samples_ is a multiple of batch_size, so every slice is full.
        for start in range(0, num_samples_, batch_size):
            yield indices[start:start + batch_size]
def batch_generator(data, index_generator, batch_function=None):
    """Generate batches of data, one per group of sample indices.

    Args:
        data: Data source.  Its ``get_parallel(indices)`` method is used when
            present; otherwise its ``get(indices)`` method is used.
        index_generator: Iterable yielding index groups, e.g. the output of
            ``sequential`` or ``random``.
        batch_function: Optional callable applied to each fetched batch.

    Yields:
        ``batch_function(batch)`` when ``batch_function`` is given, otherwise
        the fetched batch itself.
    """
    # Fall back to ``get`` for data sources that do not provide
    # ``get_parallel``, instead of failing with AttributeError.
    fetch = getattr(data, "get_parallel", None)
    if fetch is None:
        fetch = data.get
    for samples in index_generator:
        batch = fetch(samples)
        yield batch if batch_function is None else batch_function(batch)
def async_batch_generator(data, index_generator, queue_maxsize, batch_function=None):
    """Produce batches in a background process and read them back lazily.

    A ``BatchFetcher`` process drains ``batch_generator(data, index_generator,
    batch_function)`` into a ``multiprocessing.Queue`` created with
    ``maxsize=queue_maxsize`` (a value <= 0 means unbounded).  The returned
    generator yields queued batches until the ``None`` end-of-stream marker
    arrives, so a batch that is itself ``None`` would end the stream early.

    NOTE(review): the fetcher is handed live generator objects, which cannot
    be pickled; this presumably relies on the ``fork`` start method and would
    fail under ``spawn`` — confirm on the target platforms.
    """
    shared_queue = Queue(maxsize=queue_maxsize)
    producer = BatchFetcher(
        shared_queue,
        batch_generator(data, index_generator, batch_function),
    )
    producer.start()
    return queue_generator(shared_queue)
def queue_generator(queue, sentinel=None):
    """Yield items taken from ``queue`` until ``sentinel`` is received.

    Items are compared to ``sentinel`` by identity (``is``), so falsy values
    such as ``0`` or empty containers are passed through.  The sentinel
    itself is consumed but never yielded.
    """
    item = queue.get()
    while item is not sentinel:
        yield item
        item = queue.get()
class BatchFetcher(Process):
    """Process that drains a batch generator into a queue.

    Every item produced by ``batch_generator`` is ``put`` on ``queue``,
    followed by a single ``None`` end-of-stream marker.
    """

    def __init__(self, queue, batch_generator):
        super(BatchFetcher, self).__init__()
        self.queue = queue  # destination for fetched batches
        self.batch_generator = batch_generator  # iterable of batches to forward

    def run(self):
        """Forward every batch to the queue, then signal completion.

        The ``None`` marker is put in a ``finally`` block so that a consumer
        blocked on the queue is released even if the generator raises; the
        exception still propagates out of ``run``.
        """
        try:
            for batch in self.batch_generator:
                self.queue.put(batch)
        finally:
            # Signal that fetcher is done -- also on error, so readers never
            # wait forever for a marker that would otherwise never arrive.
            self.queue.put(None)
class NumpyData(object):
    """Index-addressable wrapper around an in-memory NumPy array."""

    def __init__(self, array):
        self.array = array  # backing array, indexed by ``get``

    def get(self, indices):
        """Return ``self.array[indices]`` using NumPy fancy indexing.

        ``indices`` may be an ndarray or anything ``np.array`` accepts (list,
        tuple, integer or boolean values).  An empty index sequence is cast to
        an integer index so it selects nothing instead of raising
        ``IndexError`` (``np.array([])`` defaults to float64, which NumPy
        rejects as an index).
        """
        if not isinstance(indices, np.ndarray):
            indices = np.array(indices)
        if indices.size == 0:
            indices = indices.astype(np.intp)
        return self.array[indices]

    # ``batch_generator`` looks up ``get_parallel``; an in-memory array needs
    # no parallel fetch, so it is the same lookup as ``get``.
    get_parallel = get
# def write_to_queue((queue, samples, data)):
# print samples
# if samples:
# batch = data.get(samples)
# queue.put(batch)
# else:
# queue.put(None)
if __name__ == '__main__':
    # Demo: print the shuffled index batches for a small configuration.
    batch_size = 10
    num_samples = 20
    num_epochs = 3
    offset = 0
    index_generator = random(batch_size, num_samples, num_epochs, offset)
    for samples in index_generator:
        # Function-call form parses on both Python 2 and Python 3; the
        # original ``print samples`` statement is a Python 3 syntax error.
        print(samples)