#!/usr/bin/env python
from __future__ import print_function, division
import argparse
import time
import os
from six.moves import cPickle
import tensorflow as tf
from utils import TextLoader
from model import Model
import simulate
from tensorflow.python.client import timeline
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Data and model checkpoints directories
parser.add_argument('--data_dir', type=str, default='data/tinyshakespeare',
help='data directory containing input.txt with training examples')
parser.add_argument('--save_dir', type=str, default='save',
help='directory to store checkpointed models')
parser.add_argument('--log_dir', type=str, default='logs',
help='directory to store tensorboard logs')
parser.add_argument('--save_every', type=int, default=1000,
                    help='save frequency: number of training steps (batches) between model checkpoints')
parser.add_argument('--init_from', type=str, default=None,
help="""continue training from saved model at this path (usually "save").
Path must contain files saved by previous training process:
'config.pkl' : configuration;
'chars_vocab.pkl' : vocabulary definitions;
'checkpoint' : paths to model file(s) (created by tf).
Note: this file contains absolute paths, be careful when moving files around;
'model.ckpt-*' : file(s) with model definition (created by tf)
Model params must be the same between multiple runs (model, rnn_size, num_layers and seq_length).
""")
# Model params
parser.add_argument('--model', type=str, default='lstm',
help='lstm, rnn, gru, or nas')
parser.add_argument('--rnn_size', type=int, default=128,
help='size of RNN hidden state')
parser.add_argument('--num_layers', type=int, default=2,
help='number of layers in the RNN')
# Optimization
parser.add_argument('--seq_length', type=int, default=50,
help='RNN sequence length. Number of timesteps to unroll for.')
parser.add_argument('--batch_size', type=int, default=50,
help="""minibatch size. Number of sequences propagated through the network in parallel.
Pick batch-sizes to fully leverage the GPU (e.g. until the memory is filled up)
commonly in the range 10-500.""")
parser.add_argument('--num_epochs', type=int, default=50,
help='number of epochs. Number of full passes through the training examples.')
parser.add_argument('--grad_clip', type=float, default=5.,
help='clip gradients at this value')
parser.add_argument('--learning_rate', type=float, default=0.002,
help='learning rate')
parser.add_argument('--decay_rate', type=float, default=0.97,
help='decay rate for rmsprop')
parser.add_argument('--output_keep_prob', type=float, default=1.0,
                    help='dropout keep probability on the RNN outputs (1.0 = keep everything)')
parser.add_argument('--input_keep_prob', type=float, default=1.0,
                    help='dropout keep probability on the RNN inputs (1.0 = keep everything)')
# distributed args
parser.add_argument('--distributed', help="Indicates running in distributed mode", action='store_true')
parser.add_argument('--ps_hosts', help="PS HOSTS", default=None)
parser.add_argument('--worker_hosts', help="WORKER HOSTS", default=None)
parser.add_argument('--job_name', help="Job name. Must be ps/worker", choices=['ps', 'worker'], default=None)
parser.add_argument('--task_index', help="Index of task for given job", type=int, default=None)
parser.add_argument("--tensor_file", help="Tensor file of the specific training data for given node.", default=None)
args = parser.parse_args()
# All tasks in this setup run as workers (the ps/worker split is disabled below and the
# ClusterSpec only lists worker hosts), so force the job name accordingly.
args.job_name = "worker"
def train(args):
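    # TextLoader (utils.py) builds the character vocabulary from input.txt in data_dir
    # and serves (input, target) index batches via next_batch(); its vocab_size, chars
    # and vocab attributes are consumed below when building and checkpointing the model.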
if args.distributed:
# PS nodes need not load data and take up ram.
if args.job_name == "worker":
            # TODO: require a per-node tensor file here. Having every node load the full
            # file and split it is wasteful; split the data in a preprocessing step
            # instead and pass the result via --tensor_file.
            # if args.tensor_file is None:
            #     raise Exception("Tensor file must be provided.")
print("running data loader ")
data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length, tensor_file=args.tensor_file)
print(data_loader.vocab_size)
args.vocab_size = data_loader.vocab_size
else:
data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
args.vocab_size = data_loader.vocab_size
# check compatibility if training is continued from previously saved model
if args.init_from is not None:
# check if all necessary files exist
        assert os.path.isdir(args.init_from), "%s must be a path" % args.init_from
assert os.path.isfile(os.path.join(args.init_from,"config.pkl")),"config.pkl file does not exist in path %s"%args.init_from
assert os.path.isfile(os.path.join(args.init_from,"chars_vocab.pkl")),"chars_vocab.pkl.pkl file does not exist in path %s" % args.init_from
ckpt = tf.train.latest_checkpoint(args.init_from)
assert ckpt, "No checkpoint found"
# open old config and check if models are compatible
with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
saved_model_args = cPickle.load(f)
need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
for checkme in need_be_same:
assert vars(saved_model_args)[checkme]==vars(args)[checkme],"Command line argument and saved model disagree on '%s' "%checkme
# open saved vocab/dict and check if vocabs/dicts are compatible
with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
saved_chars, saved_vocab = cPickle.load(f)
assert saved_chars==data_loader.chars, "Data and loaded model disagree on character set!"
assert saved_vocab==data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"
if not args.distributed or args.job_name == "worker":
if not os.path.isdir(args.save_dir):
os.makedirs(args.save_dir)
with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
cPickle.dump(args, f)
with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
cPickle.dump((data_loader.chars, data_loader.vocab), f)
if args.distributed:
# get the list of ps and worker hosts
##ps_hosts = args.ps_hosts.split(",")
worker_hosts = args.worker_hosts.split(",")
task_index = args.task_index # get current job and index.
        # hide the GPU from ps nodes, they don't need it.
# if job_name == "ps":
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
# create a cluster.
cluster = tf.train.ClusterSpec({"worker": worker_hosts})
# create a tensorflow server
server = tf.train.Server(cluster, job_name="worker", task_index=task_index)
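        # The ClusterSpec maps the job name "worker" to its host:port list; tf.train.Server
        # starts the in-process gRPC server for this task and exposes `server.target`,
        # which the sessions below connect to.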
# if job_name == "ps":
is_chief=(task_index==0)
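        # Non-chief workers only serve their in-process gRPC server (server.join() blocks
        # forever); the chief builds the graph and drives the training loop.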
if not is_chief:
server.join()
# define our graph
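        # replica_device_setter normally pins variables round-robin onto the cluster's
        # "ps" tasks and leaves compute ops on the local worker; with a worker-only
        # ClusterSpec like this one it is effectively a no-op for placement.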
with tf.device(tf.train.replica_device_setter(cluster=cluster)):
model = Model(args)
else:
model = Model(args)
# instrument for tensorboard, must be defined before monitored session
    # If your model defined its own tf.Graph object then the
    # following ops must be created in the same graph.
summaries = tf.summary.merge_all()
writer = tf.summary.FileWriter(
os.path.join(args.log_dir, time.strftime("%Y-%m-%d-%H-%M-%S")))
writer.add_graph(tf.get_default_graph())
    # don't need this in the distributed training scenario.
saver = tf.train.Saver(tf.global_variables())
# op to decay the learning rate
tf_epoch = tf.placeholder(tf.float32, shape=(), name="epoch_number")
learning_rate_modifier = tf.assign(model.lr, tf.constant(args.learning_rate) * (tf.constant(args.decay_rate) ** tf_epoch))
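    # The effective rate for epoch e is learning_rate * decay_rate**e; with the defaults
    # this is 0.002 * 0.97**10 ≈ 0.0015 by epoch 10.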
    # If you're on a multi-GPU system, set the GPU environment variable before launching,
    # and consider removing `per_process_gpu_memory_fraction`.
    num_worker_tasks = cluster.num_tasks("worker") if args.distributed else 1
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(
        allow_growth=True, per_process_gpu_memory_fraction=1 / num_worker_tasks))
if args.distributed:
# create a training session
# ALERT: No more ops can be defined after this point.
#
graph = tf.get_default_graph()
simulate.simulate(graph, 'etf', True, 2)
#
# new
        fileNameQueue = []
        for fname in os.listdir(args.log_dir):
            fileNameQueue.append(fname)
        print(fileNameQueue)
        # Note: filename_queue is created here but not consumed later in this script.
        filename_queue = tf.train.string_input_producer(fileNameQueue, shuffle=False)
init_op = tf.global_variables_initializer()
sv = tf.train.Supervisor(is_chief=is_chief, init_op=init_op)
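        # The Supervisor runs init_op on the chief and coordinates session creation for
        # the workers, so no manual variable initialization is needed in this branch.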
print("supervisor created")
        # Get the session directly from the Supervisor (rather than via a
        # `with sv.managed_session(...)` block) because the training loop that uses it
        # below is shared with the non-distributed branch.
        mon_sess = sv.prepare_or_wait_for_session(server.target, config=config)
        print("Start session")
        tf.train.write_graph(tf.get_default_graph(), "model/", "model.pb", as_text=True)
# new
# mon_sess = tf.train.MonitoredTrainingSession(master=server.target,
# is_chief=(task_index==0),
# checkpoint_dir=args.save_dir,
# chief_only_hooks=[],
# save_checkpoint_steps=args.save_every,
# save_summaries_steps=100
# )
"""We can set checkpoint_{secs,steps} and summaries_{secs,steps} to None
if we want to manually handle those. But that is not the recommended way
since in principle only master should be in charge of those operations.
"""
        # This is needed when running ops other than the main training step, or ops that
        # don't increment the global step (e.g. an accuracy calculation). Running mon_sess
        # for those might trigger the save-summary hook, which would fail because no
        # related operations are being run in that call.
        # Rule of thumb: if an op (or summary op) does NOT require the training data feed,
        # run it with normal_sess.
        # With the Supervisor, mon_sess is a plain tf.Session, so the "raw" handle is the
        # same object (a MonitoredTrainingSession would instead need mon_sess._tf_sess()).
        normal_sess = mon_sess
        # Note that we don't need to call global_variables_initializer() here: the
        # Supervisor runs init_op for us, and likewise restores a saved model if possible.
else:
        mon_sess = normal_sess = tf.Session(config=config)
normal_sess.run(tf.global_variables_initializer())
# restore model
if args.init_from is not None:
saver.restore(mon_sess, ckpt)
for e in range(args.num_epochs):
# this op goes in normal_sess since it is not a training step.
normal_sess.run(learning_rate_modifier, feed_dict={tf_epoch: e})
data_loader.reset_batch_pointer()
# again, not a training op.
state = normal_sess.run(model.initial_state)
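        # The RNN state is threaded through consecutive batches below (truncated BPTT
        # across the epoch), so fetch the zero-valued initial state once up front.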
        ## Profiling: run_metadata is filled by the traced run() calls in the batch loop
        ## below, so export the stats gathered during the previous epoch before starting
        ## a fresh trace for this one.
        if e > 0:
            writer.add_run_metadata(run_metadata, 'epoch %03d' % (e - 1))
        if e == 1:
            fetched_timeline = timeline.Timeline(run_metadata.step_stats)
            chrome_trace = fetched_timeline.generate_chrome_trace_format()
            with open('timeline.json', 'w') as f:
                f.write(chrome_trace)
            chrome_trace = fetched_timeline.generate_chrome_trace_format(show_memory=True)
            with open('timeline_memory.json', 'w') as f:
                f.write(chrome_trace)
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        ##
for b in range(data_loader.num_batches):
start = time.time()
x, y = data_loader.next_batch()
feed = {model.input_data: x, model.targets: y}
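            # model.initial_state is one LSTMStateTuple(c, h) per layer; feeding the
            # previous batch's final state carries context across batch boundaries.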
for i, (c, h) in enumerate(model.initial_state):
feed[c] = state[i].c
feed[h] = state[i].h
            summ, train_loss, state, _ = mon_sess.run(
                [summaries, model.cost, model.final_state, model.train_op], feed,
                options=run_options, run_metadata=run_metadata)
if not args.distributed:
# do this only in single mode, distributed takes care of this.
writer.add_summary(summ, e * data_loader.num_batches + b)
end = time.time()
print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
.format(normal_sess.run(model.global_step), #e * data_loader.num_batches + b,
args.num_epochs * data_loader.num_batches,
e, train_loss, end - start))
# save stuff manually only if NOT in distributed mode.
            if not args.distributed and ((e * data_loader.num_batches + b) % args.save_every == 0
                                         or (e == args.num_epochs - 1 and
                                             b == data_loader.num_batches - 1)):
# save for the last result
checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
saver.save(mon_sess, checkpoint_path,
global_step=e * data_loader.num_batches + b)
print("model saved to {}".format(checkpoint_path))
mon_sess.close()
if __name__ == '__main__':
train(args)