
Conversion Guide

leewyang edited this page Jul 24, 2017 · 5 revisions

Converting an existing TensorFlow application to run on TensorFlowOnSpark is fairly simple.
We have included several sample TensorFlow applications in this repo to illustrate the conversion steps.

We highlight the main points below.

1. Add PySpark and TensorFlowOnSpark imports

Every TensorFlow application will have a file containing a main() function and a call to tf.app.run(). In that file, add the following imports:

from pyspark.context import SparkContext
from pyspark.conf import SparkConf
from tensorflowonspark import TFCluster, TFNode
from datetime import datetime

2. Replace the main() function with main_fun(argv, ctx)

The argv parameter will contain a full copy of the arguments supplied at the PySpark command line, while the ctx parameter will contain node metadata, such as job_name and task_index. Also, make sure that import tensorflow as tf occurs within this function, since this function is executed on the executors. Likewise, any functions used by the main function must be defined or imported inside main_fun.

# def main():
def main_fun(argv, ctx):
  import tensorflow as tf
  # ... the rest of the original main() body follows here ...
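Because main_fun now receives its configuration through argv and ctx, it can be smoke-tested locally before submitting to Spark. The sketch below is illustrative only: FakeContext is a hypothetical namedtuple standing in for the real ctx, its field names (job_name, task_index, worker_num) are assumptions, and node_name is a made-up helper that shows where nested helpers go. TensorFlow is left out entirely.

```python
from collections import namedtuple

# Hypothetical stand-in for the ctx object TensorFlowOnSpark passes in.
# Field names are assumptions; only what this sketch needs is modeled.
FakeContext = namedtuple("FakeContext", ["job_name", "task_index", "worker_num"])


def main_fun(argv, ctx):
    # In the real application, "import tensorflow as tf" goes here so that
    # it is imported on the executor rather than at module level.

    # Helpers used by main_fun are defined inside it, following the
    # guidance above, so they travel with main_fun to the executors.
    def node_name():
        return "%s:%d" % (ctx.job_name, ctx.task_index)

    # argv is the full command line; argv[0] is the script name.
    return node_name(), argv[1:]


if __name__ == "__main__":
    ctx = FakeContext(job_name="worker", task_index=0, worker_num=1)
    print(main_fun(["my_app.py", "--steps", "100"], ctx))
```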

3. Replace tf.app.run() to launch the TensorFlowOnSpark cluster

The tf.app.run() method executes the TensorFlow main function. Replace it with the following code to set up PySpark and launch TensorFlow on the executors. Note that we're using argparse here mostly because the tf.app.flags mechanism is currently not an officially supported TensorFlow API.

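if __name__ == '__main__':
    import argparse
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
    args, rem = parser.parse_known_args()

    sc = SparkContext(conf=SparkConf().setAppName("your_app_name"))
    # assumes spark.executor.instances is set, e.g. via --num-executors
    num_executors = int(sc.getConf().get("spark.executor.instances"))
    num_ps = 1
    tensorboard = args.tensorboard  # honor the --tensorboard flag parsed above

    # run main_fun on the executors; each node receives the full sys.argv as argv
    cluster = TFCluster.run(sc, main_fun, sys.argv, num_executors, num_ps, tensorboard, TFCluster.InputMode.TENSORFLOW)
    cluster.shutdown()  # wait for the TensorFlow nodes to finish, then stop the cluster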
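Because the driver uses parse_known_args(), it consumes only --tensorboard and leaves every other argument alone, while the executors still receive the complete sys.argv. A standalone sketch of the driver-side parsing, where parse_driver_args and the --steps flag are hypothetical stand-ins for your own code:

```python
import argparse


def parse_driver_args(argv):
    # Only driver-side options are declared; application flags are not.
    parser = argparse.ArgumentParser()
    parser.add_argument("--tensorboard", help="launch tensorboard process",
                        action="store_true")
    # parse_known_args returns (namespace, leftover_args) instead of
    # exiting with an error on flags it does not recognize.
    return parser.parse_known_args(argv)


if __name__ == "__main__":
    # "--steps 100" is a hypothetical application flag for illustration.
    args, rem = parse_driver_args(["--tensorboard", "--steps", "100"])
    print(args.tensorboard, rem)
```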

4. Replace the tf.train.Server() with TFNode.start_cluster_server()

In distributed TensorFlow apps, there is typically code that:

  1. extracts the addresses for the ps and worker nodes from the command line args
  2. creates a cluster spec
  3. starts the TensorFlow server

These can all be replaced as follows.

    # ps_hosts = FLAGS.ps_hosts.split(',')
    # worker_hosts = FLAGS.worker_hosts.split(',')
    # tf.logging.info('PS hosts are: %s' % ps_hosts)
    # tf.logging.info('Worker hosts are: %s' % worker_hosts)
    # cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts, 'worker': worker_hosts})
    # server = tf.train.Server( {'ps': ps_hosts, 'worker': worker_hosts},
    #    job_name=FLAGS.job_name, task_index=FLAGS.task_id)
    cluster_spec, server = TFNode.start_cluster_server(ctx, FLAGS.num_gpus, FLAGS.rdma)
    # or use the following for default values of num_gpus=1 and rdma=False
    # cluster_spec, server = TFNode.start_cluster_server(ctx)
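For reference, the first two steps that TFNode.start_cluster_server() takes over amount to turning comma-separated host lists into the {'ps': [...], 'worker': [...]} dictionary that tf.train.ClusterSpec accepts; with TensorFlowOnSpark this information comes from ctx instead of the command line. The plain-Python sketch below shows only that bookkeeping: build_cluster_dict and task_address are hypothetical helpers, the host strings are made up, and no server is started.

```python
def build_cluster_dict(ps_hosts, worker_hosts):
    """Turn comma-separated host:port strings into a cluster dictionary."""
    def split_hosts(hosts):
        # Ignore empty entries, e.g. from a trailing comma.
        return [h.strip() for h in hosts.split(",") if h.strip()]

    return {"ps": split_hosts(ps_hosts), "worker": split_hosts(worker_hosts)}


def task_address(cluster, job_name, task_index):
    # A task is identified by its job name and its position in that job's list.
    return cluster[job_name][task_index]


if __name__ == "__main__":
    # Hypothetical hosts, as they might appear in --ps_hosts / --worker_hosts.
    cluster = build_cluster_dict("ps0:2222", "worker0:2222,worker1:2222")
    print(cluster)
    print(task_address(cluster, "worker", 1))
```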

5. Add TensorFlowOnSpark-specific arguments

Since most TensorFlow examples use the tf.app.flags mechanism, we leverage it here to parse our TensorFlowOnSpark-specific arguments (on the executor side) for consistency. If your application uses another parsing mechanism, just add these two arguments accordingly.

tf.app.flags.DEFINE_integer('num_gpus', 1, 'Number of GPUs per node.')
tf.app.flags.DEFINE_boolean('rdma', False, 'Use RDMA between GPUs')

Note: although these arguments are passed to TFNode.start_cluster_server(), your code must still be written specifically to leverage multiple GPUs (e.g., see the "tower" pattern in the CIFAR-10 example). Again, if you use a single GPU per node with no RDMA, you can skip this step and just call TFNode.start_cluster_server(ctx).
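If your application parses its arguments with argparse rather than tf.app.flags, the two arguments might be added as in the sketch below (parse_node_args is a hypothetical helper). It would run inside main_fun on argv; because argv also carries driver-side options such as --tensorboard, parse_known_args() is used so unrecognized flags are ignored. The rdma option uses action="store_true", since type=bool would treat any non-empty string, including "False", as True. The start_cluster_server() call is shown only as a comment because it needs a live TensorFlowOnSpark context.

```python
import argparse


def parse_node_args(argv):
    parser = argparse.ArgumentParser()
    # Same names and defaults as the num_gpus / rdma flags in step 5.
    parser.add_argument("--num_gpus", type=int, default=1,
                        help="Number of GPUs per node.")
    parser.add_argument("--rdma", action="store_true",
                        help="Use RDMA between GPUs")
    # argv[0] is the script name, so skip it; ignore flags meant for others.
    args, _unknown = parser.parse_known_args(argv[1:])
    return args


# Inside main_fun(argv, ctx), the parsed values would then be used as:
#   args = parse_node_args(argv)
#   cluster_spec, server = TFNode.start_cluster_server(ctx, args.num_gpus, args.rdma)

if __name__ == "__main__":
    args = parse_node_args(["my_app.py", "--num_gpus", "2", "--tensorboard"])
    print(args.num_gpus, args.rdma)
```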

6. Enable TensorBoard

Finally, if using TensorBoard, ensure that the summaries are saved to the local disk of the "chief" worker (by convention "worker:0"), since TensorBoard currently cannot read directly from HDFS. Locate the tf.train.Supervisor() call and pass it a custom summary_writer as follows. Note: the TensorBoard process will look in this specific directory by convention, so do not change the path.

  summary_writer = tf.summary.FileWriter("tensorboard_%d" % ctx.worker_num, graph=tf.get_default_graph())
  sv = tf.train.Supervisor(is_chief=is_chief,
                           # ... keep the other existing Supervisor arguments here ...
                           summary_writer=summary_writer)
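The log directory name comes from the node's worker number, so each node writes to its own local tensorboard_<N> directory, while the chief is identified by its job name and task index. A small sketch of both conventions, assuming (as this guide does) that the chief is "worker:0"; the helper names and the sample numbering in the demo loop are hypothetical, so check them against how your code actually sets is_chief.

```python
def tensorboard_logdir(worker_num):
    # Same naming convention as the FileWriter call: "tensorboard_%d".
    return "tensorboard_%d" % worker_num


def is_chief_worker(job_name, task_index):
    # Assumed convention from this guide: the chief is "worker:0".
    return job_name == "worker" and task_index == 0


if __name__ == "__main__":
    # Hypothetical (job_name, task_index, worker_num) triples for illustration.
    for job, idx, num in [("ps", 0, 0), ("worker", 0, 1), ("worker", 1, 2)]:
        role = "chief" if is_chief_worker(job, idx) else "non-chief"
        print("%s:%d -> %s (%s)" % (job, idx, tensorboard_logdir(num), role))
```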