Merge pull request #65 from modichirag/code-style
Reformatting code as pep8 + 2 space tab
EiffL authored Mar 3, 2021
2 parents 2910bd9 + 5f8e45a commit 57c346e
Showing 35 changed files with 3,694 additions and 2,901 deletions.
3 changes: 3 additions & 0 deletions .style.yapf
@@ -0,0 +1,3 @@
+[style]
+based_on_style = pep8
+spaces_before_comment = 2
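These two settings drive the whole diff: yapf starts from the pep8 preset and puts two spaces before trailing comments (visible below in lines such as `final_state = state  #mtfpm.nbody(...)`). A minimal sketch of exercising the config programmatically, assuming yapf is installed (`pip install yapf`) and run from the repository root; the input snippet and printed output are illustrative:

# Minimal sketch: apply the repository's .style.yapf to a code snippet.
# Assumes yapf is installed and the working directory holds .style.yapf.
from yapf.yapflib.yapf_api import FormatCode

messy = "x = {  'a':37,'b':42}   # a dict\n"
formatted, changed = FormatCode(messy, style_config='.style.yapf')
print(formatted)  # x = {'a': 37, 'b': 42}  # a dict  (two spaces before '#')
print(changed)  # True, because yapf rewrote the snippet

The same result is available from the command line with `yapf --in-place --recursive .`, which picks up .style.yapf automatically.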
2 changes: 1 addition & 1 deletion README.md
@@ -1,4 +1,4 @@
-# flowpm [![Build Status](https://travis-ci.org/modichirag/flowpm.svg?branch=master)](https://travis-ci.org/modichirag/flowpm) [![PyPI version](https://badge.fury.io/py/flowpm.svg)](https://badge.fury.io/py/flowpm) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/modichirag/flowpm/blob/master/notebooks/flowpm_tutorial.ipynb) [![arXiv:2010.11847](https://img.shields.io/badge/astro--ph.IM-arXiv%3A2010.11847-B31B1B.svg)](https://arxiv.org/abs/2010.11847) [![youtube](https://img.shields.io/badge/-youtube-red?logo=youtube&labelColor=grey)](https://youtu.be/DHOaHTU61hM)
+# flowpm [![Build Status](https://travis-ci.org/modichirag/flowpm.svg?branch=master)](https://travis-ci.org/modichirag/flowpm) [![PyPI version](https://badge.fury.io/py/flowpm.svg)](https://badge.fury.io/py/flowpm) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/modichirag/flowpm/blob/master/notebooks/flowpm_tutorial.ipynb) [![arXiv:2010.11847](https://img.shields.io/badge/astro--ph.IM-arXiv%3A2010.11847-B31B1B.svg)](https://arxiv.org/abs/2010.11847) [![youtube](https://img.shields.io/badge/-youtube-red?logo=youtube&labelColor=grey)](https://youtu.be/DHOaHTU61hM) [![PEP8](https://img.shields.io/badge/code%20style-pep8-blue.svg)](https://www.python.org/dev/peps/pep-0008/)


Particle Mesh Simulation in TensorFlow, based on [fastpm-python](https://github.com/rainwoodman/fastpm-python) simulations
164 changes: 96 additions & 68 deletions examples/mesh_lpt_TPU.py
@@ -22,17 +22,20 @@

 # Cloud TPU Cluster Resolver flags
 tf.flags.DEFINE_string(
-    "tpu", default="flowpm",
+    "tpu",
+    default="flowpm",
     help="The Cloud TPU to use for training. This should be either the name "
     "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
     "url.")
 tf.flags.DEFINE_string(
-    "tpu_zone", default="europe-west4-a",
+    "tpu_zone",
+    default="europe-west4-a",
     help="[Optional] GCE zone where the Cloud TPU is located. If not "
     "specified, we will attempt to automatically detect the zone from "
     "metadata.")
 tf.flags.DEFINE_string(
-    "gcp_project", default="flowpm",
+    "gcp_project",
+    default="flowpm",
     help="[Optional] Project name for the Cloud TPU-enabled project. If not "
     "specified, we will attempt to automatically detect the GCE project from "
     "metadata.")
@@ -45,15 +48,19 @@
 tf.flags.DEFINE_float("a0", 0.1, "Scale factor of linear field.")
 tf.flags.DEFINE_integer("pm_steps", 10, "Number of PM steps.")
 
-tf.flags.DEFINE_integer("batch_size", 128,
-                        "Mini-batch size for the training. Note that this "
-                        "is the global batch size and not the per-shard batch.")
+tf.flags.DEFINE_integer(
+    "batch_size", 128, "Mini-batch size for the training. Note that this "
+    "is the global batch size and not the per-shard batch.")
 
 tf.flags.DEFINE_string("mesh_shape", "b1:8,b2:4", "mesh shape")
-tf.flags.DEFINE_string("layout", "nx:b1,ny:b2,nx_lr:b1,ny_lr:b2,ty:b1,tz:b2,ty_lr:b1,tz_lr:b2,nx_block:b1,ny_block:b2", "layout rules")
+tf.flags.DEFINE_string(
+    "layout",
+    "nx:b1,ny:b2,nx_lr:b1,ny_lr:b2,ty:b1,tz:b2,ty_lr:b1,tz_lr:b2,nx_block:b1,ny_block:b2",
+    "layout rules")
 
 FLAGS = tf.flags.FLAGS
 
+
 def nbody_model(mesh):
   """
   Initializes a 3D volume with random noise and executes a forward FFT
@@ -76,10 +83,11 @@ def nbody_model(mesh):
   stages = np.linspace(FLAGS.a0, 1.0, FLAGS.pm_steps, endpoint=True)
 
   # Generate a batch of 3D initial conditions
-  initial_conditions = flowpm.linear_field(nc,  # size of the cube
-                                           bs,  # Physical size of the cube
-                                           ipklin,  # Initial power spectrum
-                                           batch_size=batch_size)
+  initial_conditions = flowpm.linear_field(
+      nc,  # size of the cube
+      bs,  # Physical size of the cube
+      ipklin,  # Initial power spectrum
+      batch_size=batch_size)
 
   # Compute necessary Fourier kernels
   kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
@@ -94,7 +102,7 @@
   n_block_x = 8
   n_block_y = 4
   n_block_z = 1
-  halo_size  = 4
+  halo_size = 4
 
   # Parameters of the large scales decomposition
   downsampling_factor = 2
@@ -112,78 +120,95 @@
   tx_dim = mtf.Dimension("tx_lr", nc)
   ty_dim = mtf.Dimension("ty_lr", nc)
   tz_dim = mtf.Dimension("tz_lr", nc)
 
   nx_dim = mtf.Dimension('nx_block', n_block_x)
   ny_dim = mtf.Dimension('ny_block', n_block_y)
   nz_dim = mtf.Dimension('nz_block', n_block_z)
 
-  sx_dim = mtf.Dimension('sx_block', nc//n_block_x)
-  sy_dim = mtf.Dimension('sy_block', nc//n_block_y)
-  sz_dim = mtf.Dimension('sz_block', nc//n_block_z)
+  sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
+  sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
+  sz_dim = mtf.Dimension('sz_block', nc // n_block_z)
 
   batch_dim = mtf.Dimension("batch", batch_size)
   pk_dim = mtf.Dimension("npk", len(plin))
   pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])
 
-
-
   # Compute necessary Fourier kernels
   kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
-  kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim])
-  ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim])
-  kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim])
+  kx = mtf.import_tf_tensor(mesh,
+                            kvec[0].squeeze().astype('float32'),
+                            shape=[tfx_dim])
+  ky = mtf.import_tf_tensor(mesh,
+                            kvec[1].squeeze().astype('float32'),
+                            shape=[tfy_dim])
+  kz = mtf.import_tf_tensor(mesh,
+                            kvec[2].squeeze().astype('float32'),
+                            shape=[tfz_dim])
   kv = [ky, kz, kx]
 
-
   kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
-  kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype('float32'), shape=[tx_dim])
-  ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype('float32'), shape=[ty_dim])
-  kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype('float32'), shape=[tz_dim])
+  kx_lr = mtf.import_tf_tensor(mesh,
+                               kvec_lr[0].squeeze().astype('float32'),
+                               shape=[tx_dim])
+  ky_lr = mtf.import_tf_tensor(mesh,
+                               kvec_lr[1].squeeze().astype('float32'),
+                               shape=[ty_dim])
+  kz_lr = mtf.import_tf_tensor(mesh,
+                               kvec_lr[2].squeeze().astype('float32'),
+                               shape=[tz_dim])
   kv_lr = [ky_lr, kz_lr, kx_lr]
 
   # kvec for high resolution blocks
   shape = [batch_dim, fx_dim, fy_dim, fz_dim]
   lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
   hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
   part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
 
-
-
-
   initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)
   #initc = mtf.import_tf_tensor(mesh, ic, shape=shape)
   #return initc
-  state = mtfpm.lpt_init_single(initc, a0, kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True,)
+  state = mtfpm.lpt_init_single(
+      initc,
+      a0,
+      kv_lr,
+      halo_size,
+      lr_shape,
+      hr_shape,
+      part_shape[1:],
+      antialias=True,
+  )
   #state = mtfpm.lpt_init(low, high, 0.1, kv_lr, kv_hr, halo_size, hr_shape, lr_shape,
   #                       part_shape[1:], downsampling_factor=downsampling_factor, antialias=True,)
 
   # Here we can run our nbody
-  final_state = state #mtfpm.nbody(state, stages, lr_shape, hr_shape, kv_lr, kv_hr, halo_size, downsampling_factor=downsampling_factor)
+  final_state = state  #mtfpm.nbody(state, stages, lr_shape, hr_shape, kv_lr, kv_hr, halo_size, downsampling_factor=downsampling_factor)
 
   # paint the field
   final_field = mtf.zeros(mesh, shape=hr_shape)
   for block_size_dim in hr_shape[-3:]:
-    final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name)
+    final_field = mtf.pad(final_field, [halo_size, halo_size],
+                          block_size_dim.name)
   final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
   # Halo exchange
   for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]):
-    final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size)
+    final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
+                                  halo_size)
   # Remove borders
   for block_size_dim in hr_shape[-3:]:
-    final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name)
+    final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
+                            block_size_dim.name)
 
   #final_field = mtf.reshape(final_field, [batch_dim, fx_dim, fy_dim, fz_dim])
-  # Hack usisng custom reshape because mesh is pretty dumb
-  final_field = mtf.slicewise(lambda x: x[:,0,0,0],
-                              [final_field],
-                              output_dtype=tf.float32,
-                              output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
-                              name='my_dumb_reshape',
-                              splittable_dims=part_shape[:-1]+hr_shape[:4])
+  # Hack using a custom reshape because mesh is pretty dumb
+  final_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [final_field],
+                              output_dtype=tf.float32,
+                              output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
+                              name='my_dumb_reshape',
+                              splittable_dims=part_shape[:-1] + hr_shape[:4])
 
   return final_field
 
 
 def model_fn(features, labels, mode, params):
   """A model is called by TpuEstimator."""
   del labels
@@ -196,11 +221,12 @@ def model_fn(features, labels, mode, params):
   num_hosts = ctx.num_hosts
   host_placement_fn = ctx.tpu_host_placement_function
   device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
-  tf.logging.info('device_list = %s' % device_list,)
+  tf.logging.info('device_list = %s' % device_list, )
 
   mesh_devices = [''] * mesh_shape.size
-  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
-      mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
+  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
+                                              mesh_devices,
+                                              ctx.device_assignment)
 
   graph = mtf.Graph()
   mesh = mtf.Mesh(graph, "fft_mesh")
@@ -212,15 +238,19 @@
   y_dim_nosplit = mtf.Dimension("ny_nosplit", FLAGS.cube_size)
 
   # Until we implement distributed outputs, we only return one example
-  field_slice, _ = mtf.split(field, batch_dim, [1, FLAGS.batch_size-1])
-  field_slice = mtf.reshape(field_slice, [mtf.Dimension("bs", 1), x_dim_nosplit, y_dim_nosplit, z_dim])
+  field_slice, _ = mtf.split(field, batch_dim, [1, FLAGS.batch_size - 1])
+  field_slice = mtf.reshape(
+      field_slice,
+      [mtf.Dimension("bs", 1), x_dim_nosplit, y_dim_nosplit, z_dim])
   #field_slice = field
 
   lowering = mtf.Lowering(graph, {mesh: mesh_impl})
   tf_field = tf.to_float(lowering.export_to_tf_tensor(field_slice))
 
   with mtf.utils.outside_all_rewrites():
-    return tpu_estimator.TPUEstimatorSpec(mode, predictions={'field': tf_field})
+    return tpu_estimator.TPUEstimatorSpec(mode,
+                                          predictions={'field': tf_field})
 
 
 def main(_):
 
@@ -229,10 +259,7 @@

   # Resolve the TPU environment
   tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
-      FLAGS.tpu,
-      zone=FLAGS.tpu_zone,
-      project=FLAGS.gcp_project
-  )
+      FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
 
   run_config = tf.estimator.tpu.RunConfig(
       cluster=tpu_cluster_resolver,
@@ -241,29 +268,30 @@
       save_checkpoints_secs=None,  # Disable the default saver
       log_step_count_steps=100,
       save_summary_steps=100,
-      tpu_config=tpu_config.TPUConfig(
-          num_shards=mesh_shape.size,
-          iterations_per_loop=100,
-          num_cores_per_replica=1,
-          per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))
+      tpu_config=tpu_config.TPUConfig(num_shards=mesh_shape.size,
+                                      iterations_per_loop=100,
+                                      num_cores_per_replica=1,
+                                      per_host_input_for_training=tpu_config.
+                                      InputPipelineConfig.BROADCAST))
 
-  model = tpu_estimator.TPUEstimator(
-      use_tpu=True,
-      model_fn=model_fn,
-      config=run_config,
-      predict_batch_size=1,
-      train_batch_size=FLAGS.batch_size,
-      eval_batch_size=FLAGS.batch_size)
+  model = tpu_estimator.TPUEstimator(use_tpu=True,
+                                     model_fn=model_fn,
+                                     config=run_config,
+                                     predict_batch_size=1,
+                                     train_batch_size=FLAGS.batch_size,
+                                     eval_batch_size=FLAGS.batch_size)
 
   def dummy_input_fn(params):
-    dset = tf.data.Dataset.from_tensor_slices(tf.zeros(shape=[params['batch_size'],1],
-                                                       dtype=tf.float32))
+    dset = tf.data.Dataset.from_tensor_slices(
+        tf.zeros(shape=[params['batch_size'], 1], dtype=tf.float32))
     return dset
 
   # Run the evaluation loop forever; we will connect to this process using a profiler
   for i, f in enumerate(model.predict(input_fn=dummy_input_fn)):
     print(i)
-    np.save(file_io.FileIO(FLAGS.output_dir+'/field_%d.npy'%i, 'w'), f['field'])
+    np.save(file_io.FileIO(FLAGS.output_dir + '/field_%d.npy' % i, 'w'),
+            f['field'])
 
 
 if __name__ == "__main__":
   tf.disable_v2_behavior()
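To confirm the tree stays pep8-clean after this change, yapf can report any file it would still rewrite. A hedged sketch, assuming yapf is installed and that FormatFile returns a (diff, encoding, changed) tuple when print_diff=True; the directory walk and paths are illustrative, not part of this commit:

# Minimal sketch: report any .py file that yapf would still reformat.
import os
from yapf.yapflib.yapf_api import FormatFile

for root, _, files in os.walk('.'):
  for name in files:
    if not name.endswith('.py'):
      continue
    path = os.path.join(root, name)
    # print_diff=True returns a unified diff instead of rewriting the file
    diff, _, changed = FormatFile(path, style_config='.style.yapf',
                                  print_diff=True)
    if changed:
      print('would reformat:', path)

On the command line, `yapf --diff --recursive .` performs the same check.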