Skip to content

Commit

Permalink
[python-package] fix bug in predict() function
Browse files Browse the repository at this point in the history
  • Loading branch information
fabsig committed Dec 5, 2022
1 parent 71b44ae commit caadd63
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
2 changes: 1 addition & 1 deletion VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.8.1
0.8.0.1
2 changes: 1 addition & 1 deletion python-package/gpboost/VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.8.0
0.8.0.1
28 changes: 18 additions & 10 deletions python-package/gpboost/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4350,7 +4350,6 @@ def __init__(self, likelihood="gaussian",
self.cov_fct_taper_range = cov_fct_taper_range
self.vecchia_approx = vecchia_approx
self.vecchia_ordering = vecchia_ordering
self.vecchia_pred_type = vecchia_pred_type
self.num_neighbors = num_neighbors
self.num_neighbors_pred = num_neighbors_pred
if self.cov_function == "wendland":
Expand Down Expand Up @@ -4387,6 +4386,10 @@ def __init__(self, likelihood="gaussian",
["GP_rand_coef_" + gp_rand_coef_data_names[ii] + "_var",
"GP_rand_coef_" + gp_rand_coef_data_names[ii] + "_range"])
self.re_comp_names.append("GP_rand_coef_" + gp_rand_coef_data_names[ii])
# Prediction type for Vecchia approximation
if vecchia_pred_type is not None:
self.vecchia_pred_type = vecchia_pred_type
vecchia_pred_type_c = c_str(vecchia_pred_type)
# Set IDs for independent processes (cluster_ids)
if cluster_ids is not None:
cluster_ids = _format_check_1D_data(cluster_ids, data_name="cluster_ids", check_data_type=False,
Expand All @@ -4407,9 +4410,6 @@ def __init__(self, likelihood="gaussian",
cluster_ids = cluster_ids.astype(np.int32)
cluster_ids_c = cluster_ids.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))

if self.vecchia_pred_type is not None:
vecchia_pred_type_c = c_str(self.vecchia_pred_type)

self.__determine_num_cov_pars(likelihood=likelihood)

_safe_call(_LIB.GPB_CreateREModel(
Expand Down Expand Up @@ -5086,8 +5086,6 @@ def predict(self,
if predict_cov_mat and predict_var:
predict_cov_mat = True
predict_var = False
if vecchia_pred_type is not None:
self.vecchia_pred_type = vecchia_pred_type
if num_neighbors_pred is not None:
self.num_neighbors_pred = num_neighbors_pred
if cg_delta_conv_pred is not None:
Expand All @@ -5113,6 +5111,7 @@ def predict(self,
gp_rand_coef_data_pred_c = ctypes.c_void_p()
cluster_ids_pred_c = ctypes.c_void_p()
X_pred_c = ctypes.c_void_p()
vecchia_pred_type_c = ctypes.c_void_p()
num_data_pred = 0
if not use_saved_data:
# Set data for grouped random effects
Expand Down Expand Up @@ -5165,6 +5164,10 @@ def predict(self,
if gp_rand_coef_data_pred.shape[1] != self.num_gp_rand_coef:
raise ValueError("Incorrect number of covariates in gp_rand_coef_data_pred")
gp_rand_coef_data_pred_c, _, _ = c_float_array(gp_rand_coef_data_pred.flatten(order='F'))
# Prediction type for Vecchia approximation
if vecchia_pred_type is not None:
self.vecchia_pred_type = vecchia_pred_type
vecchia_pred_type_c = c_str(vecchia_pred_type)
# Set IDs for independent processes (cluster_ids)
if cluster_ids_pred is not None:
cluster_ids_pred = _format_check_1D_data(cluster_ids_pred, data_name="cluster_ids_pred",
Expand Down Expand Up @@ -5193,6 +5196,7 @@ def predict(self,
cluster_ids_pred = np.array([cluster_ids_pred_map_to_int[x] for x in cluster_ids_pred])
cluster_ids_pred = cluster_ids_pred.astype(np.int32)
cluster_ids_pred_c = cluster_ids_pred.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))

# Set data for linear fixed-effects
if self.has_covariates:
if X_pred is None:
Expand Down Expand Up @@ -5270,7 +5274,7 @@ def predict(self,
cov_pars_c,
X_pred_c,
ctypes.c_bool(use_saved_data),
c_str(self.vecchia_pred_type),
vecchia_pred_type_c,
ctypes.c_int(self.num_neighbors_pred),
ctypes.c_double(self.cg_delta_conv_pred),
fixed_effects_c,
Expand Down Expand Up @@ -5368,6 +5372,7 @@ def set_prediction_data(self,
gp_rand_coef_data_pred_c = ctypes.c_void_p()
cluster_ids_pred_c = ctypes.c_void_p()
X_pred_c = ctypes.c_void_p()
vecchia_pred_type_c = ctypes.c_void_p()
num_data_pred = 0
# Set data for grouped random effects
if group_data_pred is not None:
Expand Down Expand Up @@ -5419,6 +5424,11 @@ def set_prediction_data(self,
if gp_rand_coef_data_pred.shape[1] != self.num_gp_rand_coef:
raise ValueError("Incorrect number of covariates in gp_rand_coef_data_pred")
gp_rand_coef_data_pred_c, _, _ = c_float_array(gp_rand_coef_data_pred.flatten(order='F'))
# Prediction type for Vecchia approximation
# Prediction type for Vecchia approximation
if vecchia_pred_type is not None:
self.vecchia_pred_type = vecchia_pred_type
vecchia_pred_type_c = c_str(vecchia_pred_type)
# Set IDs for independent processes (cluster_ids)
if cluster_ids_pred is not None:
cluster_ids_pred = _format_check_1D_data(cluster_ids_pred, data_name="cluster_ids_pred",
Expand Down Expand Up @@ -5462,8 +5472,6 @@ def set_prediction_data(self,
raise ValueError("Incorrect number of covariates in X_pred")
X_pred_c, _, _ = c_float_array(X_pred.flatten(order='F'))
self.num_data_pred = num_data_pred
if vecchia_pred_type is not None:
self.vecchia_pred_type = vecchia_pred_type
if num_neighbors_pred is not None:
self.num_neighbors_pred = num_neighbors_pred
if cg_delta_conv_pred is not None:
Expand All @@ -5479,7 +5487,7 @@ def set_prediction_data(self,
gp_coords_pred_c,
gp_rand_coef_data_pred_c,
X_pred_c,
c_str(self.vecchia_pred_type),
vecchia_pred_type_c,
ctypes.c_int(self.num_neighbors_pred),
ctypes.c_double(self.cg_delta_conv_pred)))
return self
Expand Down

0 comments on commit caadd63

Please sign in to comment.