
Merge pull request #1881 from NNPDF/parallel-prefactor
Merge prefactors into single layer
scarlehoff authored Jan 26, 2024
2 parents f706be2 + 8810e11 commit c3f896a
Showing 12 changed files with 202 additions and 161 deletions.
Binary file modified n3fit/runcards/examples/developing_weights.h5
25 changes: 14 additions & 11 deletions n3fit/src/n3fit/backends/__init__.py
@@ -1,20 +1,23 @@
-from n3fit.backends.keras_backend.internal_state import (
-    set_initial_state,
-    clear_backend_state,
-    set_eager
-)
+from n3fit.backends.keras_backend import callbacks, constraints, operations
from n3fit.backends.keras_backend.MetaLayer import MetaLayer
-from n3fit.backends.keras_backend.MetaModel import MetaModel
+from n3fit.backends.keras_backend.MetaModel import (
+    NN_LAYER_ALL_REPLICAS,
+    NN_PREFIX,
+    PREPROCESSING_LAYER_ALL_REPLICAS,
+    MetaModel,
+)
from n3fit.backends.keras_backend.base_layers import (
+    Concatenate,
    Input,
-    concatenate,
    Lambda,
    base_layer_selector,
+    concatenate,
    regularizer_selector,
-    Concatenate,
)
-from n3fit.backends.keras_backend import operations
-from n3fit.backends.keras_backend import constraints
-from n3fit.backends.keras_backend import callbacks
+from n3fit.backends.keras_backend.internal_state import (
+    clear_backend_state,
+    set_eager,
+    set_initial_state,
+)

print("Using Keras backend")
140 changes: 85 additions & 55 deletions n3fit/src/n3fit/backends/keras_backend/MetaModel.py
@@ -46,7 +46,8 @@
}

NN_PREFIX = "NN"
-PREPROCESSING_PREFIX = "preprocessing_factor"
+NN_LAYER_ALL_REPLICAS = "all_NNs"
+PREPROCESSING_LAYER_ALL_REPLICAS = "preprocessing_factor"

# Some keys need to work for everyone
for k, v in optimizers.items():
@@ -156,7 +157,7 @@ def perform_fit(self, x=None, y=None, epochs=1, **kwargs):
of the model (the loss functions) to the partial losses.
If the model was compiled with input and output data, they will not be passed through.
-In this case by default the number of `epochs` will be set to 1
+In this case by default the number of ``epochs`` will be set to 1
ex:
{'loss': [100], 'dataset_a_loss1' : [67], 'dataset_2_loss': [33]}
@@ -228,7 +229,7 @@ def compile(
):
"""
Compile the model given an optimizer and a list of loss functions.
-The optimizer must be one of those implemented in the `optimizer` attribute of this class.
+The optimizer must be one of those implemented in the ``optimizer`` attribute of this class.
Options:
- A learning rate and a list of target outpout can be defined.
@@ -353,14 +354,10 @@ def get_replica_weights(self, i_replica):
dict
dictionary with the weights of the replica
"""
-        NN_weights = [
-            tf.Variable(w, name=w.name) for w in self.get_layer(f"{NN_PREFIX}_{i_replica}").weights
-        ]
-        prepro_weights = [
-            tf.Variable(w, name=w.name)
-            for w in self.get_layer(f"{PREPROCESSING_PREFIX}_{i_replica}").weights
-        ]
-        weights = {NN_PREFIX: NN_weights, PREPROCESSING_PREFIX: prepro_weights}
+        weights = {}
+        for layer_type in [NN_LAYER_ALL_REPLICAS, PREPROCESSING_LAYER_ALL_REPLICAS]:
+            layer = self.get_layer(layer_type)
+            weights[layer_type] = get_layer_replica_weights(layer, i_replica)

return weights

@@ -378,10 +375,9 @@ def set_replica_weights(self, weights, i_replica=0):
i_replica: int
the replica number to set, defaulting to 0
"""
self.get_layer(f"{NN_PREFIX}_{i_replica}").set_weights(weights[NN_PREFIX])
self.get_layer(f"{PREPROCESSING_PREFIX}_{i_replica}").set_weights(
weights[PREPROCESSING_PREFIX]
)
for layer_type in [NN_LAYER_ALL_REPLICAS, PREPROCESSING_LAYER_ALL_REPLICAS]:
layer = self.get_layer(layer_type)
set_layer_replica_weights(layer=layer, weights=weights[layer_type], i_replica=i_replica)

def split_replicas(self):
"""
@@ -411,51 +407,85 @@ def load_identical_replicas(self, model_file):
"""
From a single replica model, load the same weights into all replicas.
"""
-        weights = self._format_weights_from_file(model_file)
+        single_replica = self.single_replica_generator()
+        single_replica.load_weights(model_file)
+        weights = single_replica.get_replica_weights(0)

for i_replica in range(self.num_replicas):
self.set_replica_weights(weights, i_replica)

-    def _format_weights_from_file(self, model_file):
-        """Read weights from a .h5 file and format into a dictionary of tf.Variables"""
-        weights = {}
-
-        with h5py.File(model_file, 'r') as f:
-            # look at layers of the form NN_i and take the lowest i
-            i_replica = 0
-            while f"{NN_PREFIX}_{i_replica}" not in f:
-                i_replica += 1
+def is_stacked_single_replicas(layer):
+    """
+    Check if the layer consists of stacked single replicas (Only happens for NN layers),
+    to determine how to extract single replica weights.
-            weights[NN_PREFIX] = self._extract_weights(
-                f[f"{NN_PREFIX}_{i_replica}"], NN_PREFIX, i_replica
-            )
-            weights[PREPROCESSING_PREFIX] = self._extract_weights(
-                f[f"{PREPROCESSING_PREFIX}_{i_replica}"], PREPROCESSING_PREFIX, i_replica
-            )
+    Parameters
+    ----------
+        layer: MetaLayer
+            the layer to check
-        return weights
+    Returns
+    -------
+        bool
+            True if the layer consists of stacked single replicas
+    """
+    if not isinstance(layer, MetaModel):
+        return False
+    return f"{NN_PREFIX}_0" in [sublayer.name for sublayer in layer.layers]
+
+
+def get_layer_replica_weights(layer, i_replica: int):
+    """
+    Get the weights for the given single replica ``i_replica``,
+    from a ``layer`` that contains the weights of all the replicas.
+    Note that the layer could be a complete NN with many separated sub_layers
+    each of which containing weights for all replicas together.
+    This functions separates the per-replica weights and returns the list of weight as if the
+    input ``layer`` were made of _only_ replica ``i_replica``.
+    Parameters
+    ----------
+        layer: MetaLayer
+            the layer to get the weights from
+        i_replica: int
+            the replica number
+    Returns
+    -------
+        weights: list
+            list of weights for the replica
+    """
+    if is_stacked_single_replicas(layer):
+        weights = layer.get_layer(f"{NN_PREFIX}_{i_replica}").weights
+    else:
+        weights = [tf.Variable(w[i_replica : i_replica + 1], name=w.name) for w in layer.weights]
+
+    return weights
+
+
+def set_layer_replica_weights(layer, weights, i_replica: int):
+    """
+    Set the weights for the given single replica ``i_replica``.
+    When the input ``layer`` contains weights for many replicas, ensures that
+    only those corresponding to replica ``i_replica`` are updated.
+    Parameters
+    ----------
+        layer: MetaLayer
+            the layer to set the weights for
+        weights: list
+            list of weights for the replica
+        i_replica: int
+            the replica number
+    """
+    if is_stacked_single_replicas(layer):
+        layer.get_layer(f"{NN_PREFIX}_{i_replica}").set_weights(weights)
+        return
+
+    full_weights = [w.numpy() for w in layer.weights]
+    for w_old, w_new in zip(full_weights, weights):
+        w_old[i_replica : i_replica + 1] = w_new
+
-    def _extract_weights(self, h5_group, weights_key, i_replica):
-        """Extract weights from a h5py group, turning them into Tensorflow variables"""
-        weights = []
-
-        def append_weights(name, node):
-            if isinstance(node, h5py.Dataset):
-                weight_name = node.name.split("/", 2)[-1]
-                weight_name = weight_name.replace(f"{NN_PREFIX}_{i_replica}", f"{NN_PREFIX}_0")
-                weight_name = weight_name.replace(
-                    f"{PREPROCESSING_PREFIX}_{i_replica}", f"{PREPROCESSING_PREFIX}_0"
-                )
-                weights.append(tf.Variable(node[()], name=weight_name))
-
-        h5_group.visititems(append_weights)
-
-        # have to put them in the same order
-        weights_ordered = []
-        weights_model_order = [w.name for w in self.get_replica_weights(0)[weights_key]]
-        for w in weights_model_order:
-            for w_h5 in weights:
-                if w_h5.name == w:
-                    weights_ordered.append(w_h5)
-
-        return weights_ordered
+    layer.set_weights(full_weights)
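
For orientation (an illustrative aside, not part of the diff): when the weights of all replicas live in a single tensor with a leading replica axis, the new helpers read and write one replica by slicing with ``i_replica : i_replica + 1``. A minimal NumPy sketch with made-up shapes:

import numpy as np

# Hypothetical stacked weights: one row per replica (4 replicas, 3 weights each).
full = np.arange(12, dtype=float).reshape(4, 3)
i_replica = 2

# "get": the slice keeps the leading replica axis, as in get_layer_replica_weights above.
single = full[i_replica : i_replica + 1].copy()

# "set": write the (possibly updated) weights back into only that replica's slot,
# as in set_layer_replica_weights above.
single[:] = -1.0
full[i_replica : i_replica + 1] = single
print(full)  # only row 2 has changed
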
5 changes: 4 additions & 1 deletion n3fit/src/n3fit/layers/msr_normalization.py
@@ -40,6 +40,7 @@ def __init__(self, mode: str = "ALL", replicas: int = 1, **kwargs):
else:
raise ValueError(f"Mode {mode} not accepted for sum rules")

+        self.replicas = replicas
indices = []
self.divisor_indices = []
if self._msr_enabled:
@@ -83,6 +84,7 @@ def call(self, pdf_integrated, photon_integral):
reshape = lambda x: op.transpose(x[0])
y = reshape(pdf_integrated)
photon_integral = reshape(photon_integral)

numerators = []

if self._msr_enabled:
@@ -96,8 +98,9 @@
divisors = op.gather(y, self.divisor_indices, axis=0)

# Fill in the rest of the flavours with 1
+        num_flavours = y.shape[0]
norm_constants = op.scatter_to_one(
-            numerators / divisors, indices=self.indices, output_shape=y.shape
+            numerators / divisors, indices=self.indices, output_shape=(num_flavours, self.replicas)
)

return op.batchit(op.transpose(norm_constants), batch_dimension=1)
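
As an aside (not part of the diff), the intent of passing an explicit ``(num_flavours, replicas)`` output shape can be sketched in NumPy, assuming ``op.scatter_to_one`` scatters the computed constants into a tensor of ones; the indices and values below are made up:

import numpy as np

num_flavours, replicas = 14, 2
indices = [1, 2, 3]  # flavours that receive a computed sum-rule constant
values = np.array([[0.9, 1.1], [1.05, 0.95], [0.98, 1.02]])  # shape (len(indices), replicas)

# Rough stand-in for op.scatter_to_one with output_shape=(num_flavours, replicas):
norm_constants = np.ones((num_flavours, replicas))
norm_constants[indices] = values  # every other flavour stays at 1
print(norm_constants.shape)  # (14, 2)
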
16 changes: 12 additions & 4 deletions n3fit/src/n3fit/layers/preprocessing.py
@@ -33,13 +33,16 @@ class Preprocessing(MetaLayer):
Whether large x preprocessing factor should be active
seed: int
seed for the initializer of the random alpha and beta values
+    num_replicas: int (default 1)
+        The number of replicas
"""

def __init__(
self,
flav_info: Optional[list] = None,
seed: int = 0,
large_x: bool = True,
+        num_replicas: int = 1,
**kwargs,
):
if flav_info is None:
@@ -49,6 +52,8 @@ def __init__(
self.flav_info = flav_info
self.seed = seed
self.large_x = large_x
+        self.num_replicas = num_replicas
+
self.alphas = []
self.betas = []
super().__init__(**kwargs)
@@ -87,7 +92,7 @@ def generate_weight(self, name: str, kind: str, dictionary: dict, set_to_zero: b
# Generate the new trainable (or not) parameter
newpar = self.builder_helper(
name=name,
-            kernel_shape=(1,),
+            kernel_shape=(self.num_replicas, 1),
initializer=initializer,
trainable=trainable,
constraint=constraint,
@@ -117,9 +122,12 @@ def call(self, x):
Returns
-------
-            prefactor: tensor(shape=[1,N,F])
+            prefactor: tensor(shape=[1,R,N,F])
        """
-        alphas = op.stack(self.alphas, axis=1)
-        betas = op.stack(self.betas, axis=1)
+        # weight tensors of shape (R, 1, F)
+        alphas = op.stack(self.alphas, axis=-1)
+        betas = op.stack(self.betas, axis=-1)
+
+        x = op.batchit(x, batch_dimension=0)

return x ** (1 - alphas) * (1 - x) ** betas
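
To see the broadcasting this change relies on (an illustrative sketch, not part of the diff; sizes are made up): the exponents are stacked to shape (R, 1, F) and broadcast against the batched x grid, giving one preprocessing factor per replica:

import numpy as np

R, N, F = 3, 5, 8  # replicas, x-grid points, flavours (hypothetical sizes)
alphas = np.random.rand(R, 1, F)
betas = np.random.rand(R, 1, F)

x = np.linspace(0.1, 0.9, N).reshape(1, N, 1)  # x grid with a batch dimension
x = x[np.newaxis, ...]                         # the extra batching step, shape (1, 1, N, 1)

prefactor = x ** (1 - alphas) * (1 - x) ** betas
print(prefactor.shape)  # (1, 3, 5, 8), i.e. (1, R, N, F) as in the docstring
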

0 comments on commit c3f896a
