diff --git a/econml/iv/dml/_dml.py b/econml/iv/dml/_dml.py index 835bbb522..0c0b90655 100644 --- a/econml/iv/dml/_dml.py +++ b/econml/iv/dml/_dml.py @@ -761,7 +761,7 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None) TX_pred = np.tile(TX_pred.reshape(1, -1), (T.shape[0], 1)) Y_res = Y - Y_pred.reshape(Y.shape) T_res = TXZ_pred.reshape(T.shape) - TX_pred.reshape(T.shape) - if T_res.sum() == 0: + if not T_res.any(): raise ValueError( """ All values of the treatment residual are 0,