From f6cac43edce6c035e0d26b83e2a18bbd36240d21 Mon Sep 17 00:00:00 2001
From: Keith Battocchi <kebatt@microsoft.com>
Date: Wed, 27 Jan 2021 21:24:01 -0500
Subject: [PATCH] Relax tensorflow version limit

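Keras 2.4 drops logsumexp from its backend, so emulate it directly in
the mixture-of-Gaussians loss and add a test checking that loss against
a direct numpy computation.  Model.fit in newer Keras also no longer
accepts the empty target list previously passed when all losses come
from add_loss, so drop those arguments.  Relax the pins in setup.cfg to
match: allow tensorflow up to (but not including) 2.4 and remove the
cap on keras.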
---
 econml/iv/nnet/_deepiv.py   | 12 ++++++++----
 econml/tests/test_deepiv.py | 26 +++++++++++++++++++++++---
 setup.cfg                   |  8 ++++----
 3 files changed, 35 insertions(+), 11 deletions(-)
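
Note for reviewers: the inline replacement in _deepiv.py uses the
standard max-shift logsumexp identity.  A minimal numpy sketch of the
same computation (illustrative only, not part of the patch; assumes
numpy and scipy are available):

    import numpy as np
    from scipy.special import logsumexp  # reference implementation

    values = np.random.normal(size=(4, 3))   # (batch, n_components)
    mx = values.max(axis=-1)                  # shift k = max(v) per row
    # log(sum(exp(v))) == log(sum(exp(v - k))) + k; subtracting k keeps
    # the largest exponent at 0, so exp cannot overflow
    ours = np.log(np.exp(values - mx[:, None]).sum(axis=-1)) + mx
    assert np.allclose(ours, logsumexp(values, axis=-1))

The Reshape in the Keras version plays the role of mx[:, None] above,
restoring the component axis so the subtraction broadcasts.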

diff --git a/econml/iv/nnet/_deepiv.py b/econml/iv/nnet/_deepiv.py
index 18123b0d7..040614b40 100644
--- a/econml/iv/nnet/_deepiv.py
+++ b/econml/iv/nnet/_deepiv.py
@@ -93,9 +93,13 @@ def mog_loss_model(n_components, d_t):
     # LL = C - log(sum(pi_i/sig^d * exp(-d2/(2*sig^2))))
     # Use logsumexp for numeric stability:
     # LL = C - log(sum(exp(-d2/(2*sig^2) + log(pi_i/sig^d))))
-    # TODO: does the numeric stability actually make any difference?
     def make_logloss(d2, sig, pi):
-        return -K.logsumexp(-d2 / (2 * K.square(sig)) + K.log(pi / K.pow(sig, d_t)), axis=-1)
+        # K.logsumexp was removed in keras 2.4; emulate it for numeric stability
+        values = -d2 / (2 * K.square(sig)) + K.log(pi / K.pow(sig, d_t))
+        # logsumexp(v) = log(sum(exp(v - k))) + k for any constant k; taking
+        # k = max(v) keeps the largest exponent at 0, so exp cannot overflow
+        mx = K.max(values, axis=-1)
+        return -K.log(K.sum(K.exp(values - L.Reshape((-1, 1))(mx)), axis=-1)) - mx
 
     ll = L.Lambda(lambda dsp: make_logloss(*dsp), output_shape=(1,))([d2, sig, pi])
 
@@ -350,7 +354,7 @@ def fit(self, Y, T, X, Z, *, inference=None):
         model.add_loss(L.Lambda(K.mean)(ll))
         model.compile(self._optimizer)
         # TODO: do we need to give the user more control over other arguments to fit?
-        model.fit([Z, X, T], [], **self._first_stage_options)
+        model.fit([Z, X, T], **self._first_stage_options)
 
         lm = response_loss_model(lambda t, x: self._h(t, x),
                                  lambda z, x: Model([z_in, x_in],
@@ -365,7 +369,7 @@ def fit(self, Y, T, X, Z, *, inference=None):
         response_model.add_loss(L.Lambda(K.mean)(rl))
         response_model.compile(self._optimizer)
         # TODO: do we need to give the user more control over other arguments to fit?
-        response_model.fit([Z, X, Y], [], **self._second_stage_options)
+        response_model.fit([Z, X, Y], **self._second_stage_options)
 
         self._effect_model = Model([t_in, x_in], [self._h(t_in, x_in)])
 
diff --git a/econml/tests/test_deepiv.py b/econml/tests/test_deepiv.py
index 4bc55e277..27c51b7ae 100644
--- a/econml/tests/test_deepiv.py
+++ b/econml/tests/test_deepiv.py
@@ -32,7 +32,27 @@ def test_stop_grad(self):
         model = keras.Model([x_input, y_input, z_input], [loss])
         model.add_loss(K.mean(loss))
         model.compile('nadam')
-        model.fit([np.array([[1]]), np.array([[2]]), np.array([[0]])], [])
+        model.fit([np.array([[1]]), np.array([[2]]), np.array([[0]])])
+
+    def test_mog_loss(self):
+        inputs = [keras.layers.Input(shape=s) for s in [(3,), (3, 2), (3,), (2,)]]
+        ll_model = keras.engine.Model(inputs, mog_loss_model(3, 2)(inputs))
+
+        for n in range(10):
+            ps = -np.log(np.random.uniform(size=(3,)))
+            pi = ps / np.sum(ps)
+            mu = np.random.normal(size=(3, 2))
+            sig = np.exp(np.random.normal(size=(3,)))
+            t = np.random.normal(size=(2,))
+
+            pred = ll_model.predict([pi.reshape(1, 3), mu.reshape(1, 3, 2), sig.reshape(1, 3), t.reshape(1, 2)])
+
+            # LL = C - log(sum(pi_i/sig^d * exp(-d2/(2*sig^2))))
+            d = mu - t.reshape(-1, 2)
+            d2 = np.sum(d * d, axis=-1)
+            ll = -np.log(np.sum(pi / (sig * sig) * np.exp(-d2 / (2 * sig * sig)), axis=0))
+
+            assert np.allclose(ll, pred[0])
 
     @pytest.mark.slow
     def test_deepiv_shape(self):
@@ -500,7 +520,7 @@ def norm(lr):
         model = keras.engine.Model([x_input, t_input], [ll])
         model.add_loss(K.mean(ll))
         model.compile('nadam')
-        model.fit([x, t], [], epochs=5)
+        model.fit([x, t], epochs=5)
 
         # For some reason this doesn't work at all when run against the CNTK backend...
         # model.compile('nadam', loss=lambda _,l:l)
@@ -559,7 +579,7 @@ def sample(n):
         model = keras.engine.Model([x_input, t_input], [ll])
         model.add_loss(K.mean(ll))
         model.compile('nadam')
-        model.fit([x, t], [], epochs=100)
+        model.fit([x, t], epochs=100)
 
         model2 = keras.engine.Model([x_input], [pi, mu, sig])
         import matplotlib
diff --git a/setup.cfg b/setup.cfg
index 1a628616e..57818bb49 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -60,14 +60,14 @@ automl =
     ; azureml-sdk[explain,automl] == 1.0.83
     azure-cli
 tf =
-    keras < 2.4
-    tensorflow > 1.10, < 2.3
+    keras
+    tensorflow > 1.10, < 2.4
 plt =
     matplotlib
 all =
     azure-cli
-    keras < 2.4
-    tensorflow > 1.10, < 2.3
+    keras
+    tensorflow > 1.10, < 2.4
     matplotlib
     
 [options.packages.find]