Skip to content

Commit

Permalink
Recreate initializer per replica to make sure seed is properly set
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Jan 9, 2024
1 parent 2f3dec6 commit 1375daa
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions n3fit/src/n3fit/backends/keras_backend/multi_dense.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -161,7 +161,8 @@ class MultiInitializer(Initializer):
"""

  def __init__(self, single_initializer: Initializer, replica_seeds: List[int]):
- self.single_initializer = single_initializer
+ self.initializer_class = type(single_initializer)
+ self.initializer_config = single_initializer.get_config()
  self.base_seed = single_initializer.seed if hasattr(single_initializer, "seed") else None
  self.replica_seeds = replica_seeds

Expand All @@ -170,8 +171,9 @@ def __call__(self, shape, dtype=None, **kwargs):
  per_replica_weights = []
  for replica_seed in self.replica_seeds:
  if self.base_seed is not None:
- self.single_initializer.seed = self.base_seed + replica_seed
+ self.initializer_config["seed"] = self.base_seed + replica_seed
+ single_initializer = self.initializer_class.from_config(self.initializer_config)

- per_replica_weights.append(self.single_initializer(shape, dtype, **kwargs))
+ per_replica_weights.append(single_initializer(shape, dtype, **kwargs))

  return tf.stack(per_replica_weights, axis=0)

0 comments on commit 1375daa

Please sign in to comment.