Conversion Guide
The process of converting an existing TensorFlow application is fairly simple.
We have included several sample TensorFlow applications in this repo to
illustrate the conversion steps.
We highlight the main points below.
Every TensorFlow application will have a file containing a main() function and a call to tf.app.run(). In that file, please add the following imports:
from pyspark.context import SparkContext
from pyspark.conf import SparkConf
from com.yahoo.ml.tf import TFCluster, TFNode
from datetime import datetime
Next, change the signature of the main() function to accept two arguments, as shown below. The argv parameter will contain a full copy of the arguments supplied at the PySpark command line, while the ctx parameter will contain node metadata, like job_name and task_id. Also, make sure that the import tensorflow as tf occurs within this function, since it will be executed/imported on the executors. And, if there are any functions used by the main function, ensure that they are defined or imported inside the main_fun block.
# def main():
def main_fun(argv, ctx):
    import tensorflow as tf
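For illustration, here is a minimal sketch of how main_fun might read the node metadata from ctx; the attribute names (worker_num, job_name, task_index) follow the sample applications in this repo, and the training logic itself is elided:

def main_fun(argv, ctx):
    import tensorflow as tf       # import on the executor, not the driver

    worker_num = ctx.worker_num   # unique id of this executor
    job_name = ctx.job_name       # "ps" or "worker"
    task_index = ctx.task_index   # index of this node within its job

    # ... build the graph and run the training loop here,
    # branching on job_name/task_index as needed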
tf.app.run() executes the TensorFlow main function. Replace it with the following code to set up PySpark and launch TensorFlow on the executors. Note that we're using argparse here mostly because the tf.app.FLAGS mechanism is currently not an officially supported TensorFlow API.
if __name__ == '__main__':
    # tf.app.run()
    import argparse
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
    args, rem = parser.parse_known_args()

    sc = SparkContext(conf=SparkConf().setAppName("your_app_name"))
    num_executors = int(sc._conf.get("spark.executor.instances"))
    num_ps = 1
    tensorboard = args.tensorboard   # honor the --tensorboard flag parsed above

    cluster = TFCluster.run(sc, main_fun, sys.argv, num_executors, num_ps, tensorboard, TFCluster.InputMode.TENSORFLOW)
    cluster.shutdown()
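The last argument to TFCluster.run selects the input mode. This guide uses InputMode.TENSORFLOW, where each TensorFlow node reads its own data (e.g. directly from HDFS). As a rough sketch only (dataRDD is a hypothetical RDD of training records, and the train() call assumes the TFCluster API used by the sample applications), the alternative InputMode.SPARK feeds a Spark RDD to the cluster instead:

# Sketch: feed an existing Spark RDD to the TensorFlow nodes rather than
# having each worker read its input files directly.
cluster = TFCluster.run(sc, main_fun, sys.argv, num_executors, num_ps, tensorboard, TFCluster.InputMode.SPARK)
cluster.train(dataRDD, 1)   # dataRDD is a hypothetical RDD of training records; 1 epoch
cluster.shutdown()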
In distributed TensorFlow apps, there is typically code that:
- extracts the addresses of the ps and worker nodes from the command line args
- creates a cluster spec
- starts the TensorFlow server

These can all be replaced as follows.
# ps_hosts = FLAGS.ps_hosts.split(',')
# worker_hosts = FLAGS.worker_hosts.split(',')
# tf.logging.info('PS hosts are: %s' % ps_hosts)
# tf.logging.info('Worker hosts are: %s' % worker_hosts)
# cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts, 'worker': worker_hosts})
# server = tf.train.Server( {'ps': ps_hosts, 'worker': worker_hosts},
# job_name=FLAGS.job_name, task_index=FLAGS.task_id)
cluster_spec, server = TFNode.start_cluster_server(ctx, FLAGS.num_gpus, FLAGS.rdma)
# or use the following for default values of num_gpus=1 and rdma=False
# cluster_spec, server = TFNode.start_cluster_server(ctx)
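Once start_cluster_server() returns, the rest of the standard between-graph replication boilerplate is unchanged. A rough sketch of the usual pattern inside main_fun (assuming ctx exposes job_name and task_index, as in the sample applications):

if ctx.job_name == "ps":
    server.join()   # parameter servers block here, serving variables to the workers
else:
    # pin variables to the parameter servers and ops to this worker
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % ctx.task_index,
            cluster=cluster_spec)):
        pass   # ... build the model graph here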
Since most TensorFlow examples use the tf.app.FLAGS mechanism, we leverage it here to parse our TensorFlowOnSpark-specific arguments (on the executor side) for consistency. If your application uses another parsing mechanism, just add these two arguments accordingly.
tf.app.flags.DEFINE_integer('num_gpus', 1, 'Number of GPUs per node.')
tf.app.flags.DEFINE_boolean('rdma', False, 'Use RDMA between GPUs')
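If your application parses arguments with argparse instead, the equivalent additions would look like this (sketch); you would then pass args.num_gpus and args.rdma to TFNode.start_cluster_server():

# Equivalent arguments when using argparse instead of tf.app.flags
parser.add_argument("--num_gpus", help="number of GPUs per node", type=int, default=1)
parser.add_argument("--rdma", help="use RDMA between GPUs", action="store_true")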
Note: while these are required for the TFNode.start_cluster_server() function, your code must still be written specifically to leverage multiple GPUs (e.g. see the "tower" pattern in the CIFAR-10 example). And again, if using a single GPU per node with no RDMA, you can skip this step and just use TFNode.start_cluster_server(ctx).
Finally, if using TensorBoard, ensure that the summaries are saved to the local disk of the "chief" worker (by convention, "worker:0"), since TensorBoard currently cannot read directly from HDFS. Locate the tf.train.Supervisor() call and add a custom summary_writer as follows. Note: the TensorBoard process will look in this specific directory by convention, so do not change the path.
summary_writer = tf.summary.FileWriter("tensorboard_%d" %(ctx.worker_num), graph=tf.get_default_graph())
sv = tf.train.Supervisor(is_chief=is_chief,
                         logdir=FLAGS.train_dir,
                         init_op=init_op,
                         summary_op=None,
                         global_step=global_step,
                         summary_writer=summary_writer,
                         saver=saver,
                         save_model_secs=FLAGS.save_interval_secs)
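Since summary_op=None disables the Supervisor's automatic summary thread, the chief worker typically writes summaries itself from the training loop. A rough sketch, assuming a merged summary_op and a step counter from your existing training code:

# Sketch: periodic summary writes from the chief worker's training loop.
if is_chief and step % 100 == 0:
    summary = sess.run(summary_op)
    summary_writer.add_summary(summary, step)
    summary_writer.flush()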