diff --git a/n3fit/src/n3fit/layers/msr_normalization.py b/n3fit/src/n3fit/layers/msr_normalization.py index 0755cad0b4..719089cce3 100644 --- a/n3fit/src/n3fit/layers/msr_normalization.py +++ b/n3fit/src/n3fit/layers/msr_normalization.py @@ -40,6 +40,7 @@ def __init__(self, mode: str = "ALL", replicas: int = 1, **kwargs): else: raise ValueError(f"Mode {mode} not accepted for sum rules") + self.replicas = replicas indices = [] self.divisor_indices = [] if self._msr_enabled: @@ -83,6 +84,7 @@ def call(self, pdf_integrated, photon_integral): reshape = lambda x: op.transpose(x[0]) y = reshape(pdf_integrated) photon_integral = reshape(photon_integral) + numerators = [] if self._msr_enabled: @@ -96,8 +98,10 @@ def call(self, pdf_integrated, photon_integral): divisors = op.gather(y, self.divisor_indices, axis=0) # Fill in the rest of the flavours with 1 + # (Note: using y.shape in the output_shape below gives an error in Python 3.11) + num_flavours = y.shape[0] norm_constants = op.scatter_to_one( - numerators / divisors, indices=self.indices, output_shape=y.shape + numerators / divisors, indices=self.indices, output_shape=(num_flavours, self.replicas) ) return op.batchit(op.transpose(norm_constants), batch_dimension=1)