Reco speedup: fix data_batch and seed_array for minimizer function
mhuen committed Oct 31, 2024
1 parent 50ac3b1 commit 2cf3461
Showing 1 changed file with 68 additions and 68 deletions.
136 changes: 68 additions & 68 deletions egenerator/manager/source.py
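What changed: the minimizer helper functions (`func`, `get_hessian`) previously received `data_batch` and the seed through scipy's `args=` mechanism; they are now plain closures that capture `data_batch` and `seed_array` from the enclosing scope, and the seed lookup/transform happens once, before the helpers are defined, instead of being redone inside them. A minimal sketch of the closure pattern with a toy quadratic loss (all names below are illustrative, not taken from the repository):

    import numpy as np
    from scipy import optimize

    data_batch = np.array([1.0, 2.0, 3.0])  # captured by the closure

    # Before: func(x, data_batch, seed) plus args=(data_batch, seed_array).
    # After: the closure captures the batch, so scipy only ever calls func(x).
    def func(x):
        residual = x - data_batch
        loss = float(np.sum(residual**2))
        grad = 2.0 * residual  # analytic gradient
        return loss, grad

    # jac=True tells scipy that func returns (loss, gradient) together.
    result = optimize.minimize(fun=func, x0=np.zeros(3), jac=True, method="L-BFGS-B")
    print(result.x)  # approximately [1. 2. 3.]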
@@ -796,11 +796,44 @@ def reconstruct_events(
                 msg.format(param_tensor.shape[1], len(fit_parameter_list))
             )
 
+        # transform seed if minimization is performed in trafo space
+        if isinstance(seed, str):
+            seed_index = self.data_handler.tensors.get_index(seed)
+            seed_array = data_batch[seed_index]
+        else:
+            seed_array = seed
+        if minimize_in_trafo_space:
+
+            # transform bounds if provided
+            if "bounds" in kwargs:
+                bounds = self.data_trafo.transform(
+                    data=np.array(kwargs["bounds"]).T,
+                    tensor_name=parameter_tensor_name,
+                ).T
+                for i, bound in enumerate(bounds):
+                    for j in range(2):
+                        if not np.isfinite(bound[j]):
+                            bounds[i, j] = None
+                kwargs["bounds"] = bounds
+
+            seed_array_trafo = self.data_trafo.transform(
+                data=seed_array, tensor_name=parameter_tensor_name
+            )
+        else:
+            seed_array_trafo = seed_array
+
+        # get seed parameters
+        if np.all(fit_parameter_list):
+            x0 = seed_array_trafo
+        else:
+            # get seed parameters
+            x0 = seed_array_trafo[:, fit_parameter_list]
+
         # define helper function
-        def func(x, data_batch, seed):
+        def func(x):
             # reshape and convert to proper
             x = np.reshape(x, param_shape).astype(param_dtype)
-            seed = np.reshape(seed, param_shape_full).astype(param_dtype)
+            seed = np.reshape(seed_array, param_shape_full).astype(param_dtype)
             loss, grad = loss_and_gradients_function(x, data_batch, seed=seed)
             loss = loss.numpy().astype("float64")
             grad = grad.numpy().astype("float64")
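In the block above, user-supplied bounds are mapped into the transformed (trafo) parameter space, and any side that comes out non-finite is replaced by None, which is how scipy's bounds format denotes an unbounded side. A standalone sketch of that substitution, using a hypothetical affine map as a stand-in for `data_trafo.transform`:

    import numpy as np

    # Hypothetical stand-in for data_trafo.transform: per-parameter affine map.
    def transform(data, mean=np.array([1.0, 5.0]), scale=np.array([2.0, 10.0])):
        return (data - mean) / scale

    bounds = np.array([(0.0, np.inf), (-np.inf, 50.0)])
    bounds = transform(bounds.T).T  # transform the (lower, upper) rows, then back

    # scipy expects None, not inf/nan, for an unbounded side
    bounds = [[b if np.isfinite(b) else None for b in pair] for pair in bounds]
    print(bounds)  # [[-0.5, None], [None, 4.5]]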
@@ -810,10 +843,12 @@ def func(x, data_batch, seed):
 
         if hessian_function is not None:
 
-            def get_hessian(x, data_batch, seed):
+            def get_hessian(x):
                 # reshape and convert to tensor
                 x = np.reshape(x, param_shape).astype(param_dtype)
-                seed = np.reshape(seed, param_shape_full).astype(param_dtype)
+                seed = np.reshape(seed_array, param_shape_full).astype(
+                    param_dtype
+                )
                 hessian = hessian_function(x, data_batch, seed=seed)
                 hessian = hessian.numpy().astype("float64")
                 return hessian
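`get_hessian` follows the same closure pattern and, when a `hessian_function` is available, is handed to scipy so that second-order methods can use an analytic Hessian. A toy sketch of wiring a Hessian callback into `optimize.minimize` (quadratic loss; names are illustrative):

    import numpy as np
    from scipy import optimize

    target = np.array([1.0, -2.0])

    def func(x):
        loss = float(np.sum((x - target) ** 2))
        return loss, 2.0 * (x - target)

    def get_hessian(x):
        # Hessian of the quadratic loss is constant: 2 * identity
        return 2.0 * np.eye(len(x))

    result = optimize.minimize(
        fun=func, x0=np.zeros(2), jac=True, hess=get_hessian, method="Newton-CG"
    )
    print(result.x)  # approximately [ 1. -2.]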
@@ -844,47 +879,9 @@ def get_hessian(x, data_batch, seed):
 
         # kwargs['callback'] = tolerance_func
 
-        # transform seed if minimization is performed in trafo space
-        if isinstance(seed, str):
-            seed_index = self.data_handler.tensors.get_index(seed)
-            seed_array = data_batch[seed_index]
-        else:
-            seed_array = seed
-        if minimize_in_trafo_space:
-
-            # transform bounds if provided
-            if "bounds" in kwargs:
-                bounds = self.data_trafo.transform(
-                    data=np.array(kwargs["bounds"]).T,
-                    tensor_name=parameter_tensor_name,
-                ).T
-                for i, bound in enumerate(bounds):
-                    for j in range(2):
-                        if not np.isfinite(bound[j]):
-                            bounds[i, j] = None
-                kwargs["bounds"] = bounds
-
-            seed_array_trafo = self.data_trafo.transform(
-                data=seed_array, tensor_name=parameter_tensor_name
-            )
-        else:
-            seed_array_trafo = seed_array
-
-        # get seed parameters
-        if np.all(fit_parameter_list):
-            x0 = seed_array_trafo
-        else:
-            # get seed parameters
-            x0 = seed_array_trafo[:, fit_parameter_list]
-
         x0_flat = np.reshape(x0, [-1])
         result = optimize.minimize(
-            fun=func,
-            x0=x0_flat,
-            jac=jac,
-            method=method,
-            args=(data_batch, seed_array),
-            **kwargs
+            fun=func, x0=x0_flat, jac=jac, method=method, **kwargs
         )
 
         best_fit = np.reshape(result.x, param_shape)
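Since scipy minimizes over a flat 1-D vector, the seed parameters are flattened with `np.reshape(x0, [-1])` before the call and the optimizer's result is reshaped back to `param_shape` afterwards. A small round-trip illustration (the shape is made up for the example):

    import numpy as np

    param_shape = (4, 7)  # e.g. (n_events, n_fit_parameters)
    x0 = np.zeros(param_shape)

    x0_flat = np.reshape(x0, [-1])               # what scipy sees: shape (28,)
    best_fit = np.reshape(x0_flat, param_shape)  # back to the per-event layout
    assert best_fit.shape == param_shape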
@@ -1121,11 +1118,31 @@ def scipy_global_reconstruct_events(
             minimizer_kwargs["jac"] = jac
             options["jac"] = jac
 
+        # get seed tensor
+        if isinstance(seed, str):
+            seed_index = self.data_handler.tensors.get_index(seed)
+            seed_array = data_batch[seed_index]
+        else:
+            seed_array = seed
+
+        # transform seed if minimization is performed in trafo space
+        if minimize_in_trafo_space:
+            seed_array_trafo = self.data_trafo.transform(
+                data=seed_array, tensor_name=parameter_tensor_name
+            )
+        else:
+            seed_array_trafo = seed_array
+
+        # For now: add +- 1 in trafo space
+        # ToDo: allow to pass proper boundaries and uncertainties
+        assert minimize_in_trafo_space, "currently only for trafo space"
+        bounds = np.concatenate((seed_array_trafo - 1, seed_array_trafo + 1)).T
+
         # define helper function
-        def func(x, data_batch, seed):
+        def func(x):
             # reshape and convert to tensor
             x = np.reshape(x, param_shape).astype(param_dtype)
-            seed = np.reshape(seed, param_shape_full).astype(param_dtype)
+            seed = np.reshape(seed_array, param_shape_full).astype(param_dtype)
             loss, grad = loss_and_gradients_function(x, data_batch, seed=seed)
             loss = loss.numpy().astype("float64")
             grad = grad.numpy().astype("float64")
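The global search box is simply the transformed seed ±1 in each parameter, stacked into scipy's per-parameter (low, high) layout; the `options`/`minimizer_kwargs`/`callback` arguments used further down are consistent with a global optimizer such as `scipy.optimize.shgo`, though the exact call is outside these hunks. How the bounds construction behaves, assuming a seed of shape (1, n_params):

    import numpy as np

    seed_array_trafo = np.array([[0.2, -1.0, 3.5]])  # shape (1, n_params)

    # Stack the lower and upper rows, then transpose to shape (n_params, 2).
    bounds = np.concatenate((seed_array_trafo - 1, seed_array_trafo + 1)).T
    print(bounds)
    # [[-0.8  1.2]
    #  [-2.   0. ]
    #  [ 2.5  4.5]]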
@@ -1135,37 +1152,19 @@ def func(x, data_batch, seed):
 
         if hessian_function is not None:
 
-            def get_hessian(x, data_batch, seed):
+            def get_hessian(x):
                 # reshape and convert to tensor
                 x = np.reshape(x, param_shape).astype(param_dtype)
-                seed = np.reshape(seed, param_shape_full).astype(param_dtype)
+                seed = np.reshape(seed_array, param_shape_full).astype(
+                    param_dtype
+                )
                 hessian = hessian_function(x, data_batch, seed=seed)
                 hessian = hessian.numpy().astype("float64")
                 return hessian
 
             minimizer_kwargs["hess"] = get_hessian
             options["hess"] = get_hessian
 
-        # get seed tensor
-        if isinstance(seed, str):
-            seed_index = self.data_handler.tensors.get_index(seed)
-            seed_array = data_batch[seed_index]
-        else:
-            seed_array = seed
-
-        # transform seed if minimization is performed in trafo space
-        if minimize_in_trafo_space:
-            seed_array_trafo = self.data_trafo.transform(
-                data=seed_array, tensor_name=parameter_tensor_name
-            )
-        else:
-            seed_array_trafo = seed_array
-
-        # For now: add +- 1 in trafo space
-        # ToDo: allow to pass proper boundaries and uncertainties
-        assert minimize_in_trafo_space, "currently only for trafo space"
-        bounds = np.concatenate((seed_array_trafo - 1, seed_array_trafo + 1)).T
-
         def callback(xk):
             print(xk)
 
@@ -1175,7 +1174,6 @@ def callback(xk):
             options=options,
             minimizer_kwargs=minimizer_kwargs,
             callback=callback,
-            args=(data_batch, seed_array),
             **kwargs
         )
 
@@ -1282,6 +1280,8 @@ def const_loss_and_gradients_function(x):
             # convert to tensors
             loss, grad = loss_and_gradients_function(x, data_batch, seed_array)
             loss = tf.reshape(loss, [1])
+            loss = tf.cast(loss, param_tensor.dtype_tf)
+            grad = tf.cast(grad, param_tensor.dtype_tf)
             return loss, grad
 
         if hessian_function is not None:
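The two added `tf.cast` lines make the wrapped loss and gradient come back in the parameter tensor's dtype, which matters when the model internally computes in a different precision (e.g. float32) than the surrounding optimizer code expects. A minimal sketch of that dtype harmonization, with `tf.float64` standing in for `param_tensor.dtype_tf`:

    import tensorflow as tf

    param_dtype_tf = tf.float64  # stand-in for param_tensor.dtype_tf

    def harmonize(loss, grad):
        # cast model outputs (here float32) to the parameter dtype
        loss = tf.reshape(loss, [1])
        loss = tf.cast(loss, param_dtype_tf)
        grad = tf.cast(grad, param_dtype_tf)
        return loss, grad

    loss, grad = harmonize(
        tf.constant(1.5, tf.float32), tf.constant([0.1, -0.2], tf.float32)
    )
    print(loss.dtype, grad.dtype)  # <dtype: 'float64'> <dtype: 'float64'>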
