Skip to content

Commit

Permalink
Replace old dense layer everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Jan 11, 2024
1 parent c23a59f commit 472b674
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 15 deletions.
7 changes: 4 additions & 3 deletions n3fit/src/n3fit/backends/keras_backend/base_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def dense_per_flavour(basis_size=8, kernel_initializer="glorot_normal", **dense_

# Need to generate a list of dense layers
dense_basis = [
base_layer_selector("dense", kernel_initializer=initializer, **dense_kwargs)
base_layer_selector("single_dense", kernel_initializer=initializer, **dense_kwargs)
for initializer in kernel_initializer
]

Expand Down Expand Up @@ -122,7 +122,7 @@ def apply_dense(xinput):


layers = {
"multi_dense": (
"dense": (
MultiDense,
{
"input_shape": (1,),
Expand All @@ -134,7 +134,8 @@ def apply_dense(xinput):
"replica_input": True,
},
),
"dense": (
# This one is only used inside dense_per_flavour
"single_dense": (
Dense,
{
"input_shape": (1,),
Expand Down
7 changes: 3 additions & 4 deletions n3fit/src/n3fit/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def check_initializer(initializer):
def check_layer_type_implemented(parameters):
"""Checks whether the layer_type is implemented"""
layer_type = parameters.get("layer_type")
implemented_types = ["dense", "multi_dense", "dense_per_flavour"]
implemented_types = ["dense", "dense_per_flavour"]
if layer_type not in implemented_types:
raise CheckError(
f"Layer type {layer_type} not implemented, must be one of {implemented_types}"
Expand Down Expand Up @@ -427,10 +427,9 @@ def check_fiatlux_pdfs_id(replicas, fiatlux):
f"Cannot generate a photon replica with id larger than the number of replicas of the PDFs set {luxset.name}:\nreplica id={max_id}, replicas of {luxset.name} = {pdfs_ids}"
)


@make_argcheck
def check_multireplica_qed(replicas, fiatlux):
if fiatlux is not None:
if len(replicas) > 1:
raise CheckError(
"At the moment, running a multireplica QED fits is not allowed."
)
raise CheckError("At the moment, running a multireplica QED fits is not allowed.")
4 changes: 2 additions & 2 deletions n3fit/src/n3fit/model_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def initializer_generator(seed, i_layer):
# list_of_pdf_layers[d][r] is the layer at depth d for replica r
list_of_pdf_layers = []
for i_layer, (nodes_out, activation) in enumerate(zip(nodes_list, activations)):
if layer_type == "multi_dense":
if layer_type == "dense":
layers = base_layer_selector(
layer_type,
replica_seeds=replica_seeds,
Expand Down Expand Up @@ -777,7 +777,7 @@ def initializer_generator(seed, i_layer):
list_of_pdf_layers[-1] = [lambda x: concat(layer(x)) for layer in list_of_pdf_layers[-1]]

# Apply all layers to the input to create the models
if layer_type == "multi_dense":
if layer_type == "dense":
pdfs = x_input
for layer in list_of_pdf_layers:
pdfs = layer(pdfs)
Expand Down
Binary file modified n3fit/src/n3fit/tests/regressions/weights_1.h5
Binary file not shown.
Binary file modified n3fit/src/n3fit/tests/regressions/weights_2.h5
Binary file not shown.
13 changes: 7 additions & 6 deletions n3fit/src/n3fit/tests/test_modelgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,17 @@


def test_generate_dense_network():
nn = generate_nn("dense", **COMMON_ARGS).get_layer("NN_0")
nn = generate_nn("dense", **COMMON_ARGS)

# The number of layers should be input layer + len(OUT_SIZES)
assert len(nn.layers) == len(OUT_SIZES) + 1
# Check that the number of parameters is as expected
# We expect 4 weights where the two first ones are
# (INSIZE, OUT_SIZE[0]) (OUT_SIZE[0],)
# and the second one
# (OUT_SIZE[0], OUT_SIZE[1]) (OUT_SIZE[1],)
expected_sizes = [(INSIZE, OUT_SIZES[0]), (OUT_SIZES[0],), OUT_SIZES, (OUT_SIZES[1],)]
expected_sizes = [
(1, INSIZE, OUT_SIZES[0]),
(1, 1, OUT_SIZES[0]),
(1, *OUT_SIZES),
(1, 1, OUT_SIZES[1]),
]
for weight, esize in zip(nn.weights, expected_sizes):
assert weight.shape == esize

Expand Down

0 comments on commit 472b674

Please sign in to comment.