diff --git a/src/beanmachine/graph/operator/backward.cpp b/src/beanmachine/graph/operator/backward.cpp
index 64743ba588..820d666f73 100644
--- a/src/beanmachine/graph/operator/backward.cpp
+++ b/src/beanmachine/graph/operator/backward.cpp
@@ -299,7 +299,51 @@ void BroadcastAdd::backward() {
 }

 void Cholesky::backward() {
-  // TODO: fill this in
+  assert(in_nodes.size() == 1);
+  // We compute the gradient in place on a copy of the upstream gradient
+  // according to the algorithm described in section 4.1 of
+  // https://homepages.inf.ed.ac.uk/imurray2/pub/16choldiff/choldiff.pdf
+  if (in_nodes[0]->needs_gradient()) {
+    uint n = in_nodes[0]->value.type.rows;
+    Eigen::MatrixXd L = value._matrix;
+    Eigen::MatrixXd dS = back_grad1.as_matrix().triangularView<Eigen::Lower>();
+    for (int i = n - 1; i >= 0; i--) {
+      // update grad dS at lower-triangular col i, including (i, i)
+      Eigen::VectorXd L_c = L(Eigen::seq(i + 1, Eigen::last), i);
+      Eigen::VectorXd dS_c = dS(Eigen::seq(i + 1, Eigen::last), i);
+      dS(i, i) -= L_c.dot(dS_c) / L(i, i);
+      dS(Eigen::seq(i, Eigen::last), i) /= L(i, i);
+
+      if (i > 0) {
+        // update grad dS at lower-triangular row i (excluding (i, i))
+        Eigen::MatrixXd L_r = L(i, Eigen::seq(0, i - 1));
+        Eigen::MatrixXd L_rB =
+            L(Eigen::seq(i, Eigen::last), Eigen::seq(0, i - 1));
+
+        dS(i, Eigen::seq(0, i - 1)) -=
+            dS(Eigen::seq(i, Eigen::last), i).transpose() * L_rB;
+
+        // update grad dS at the lower-triangular block left of/below (i, i)
+        dS(Eigen::seq(i + 1, Eigen::last), Eigen::seq(0, i - 1)) -=
+            dS(Eigen::seq(i + 1, Eigen::last), i) * L_r;
+      }
+
+      dS(i, i) /= 2;
+    }
+
+    // Split the gradient between the upper and lower triangular parts of the
+    // input, which are symmetric. This follows the convention used by PyTorch,
+    // while the Iain Murray description accumulates all gradients
+    // to the lower triangular part.
+    for (uint i = 0; i < n; i++) {
+      for (uint j = i + 1; j < n; j++) {
+        dS(j, i) /= 2;
+        dS(i, j) = dS(j, i);
+      }
+    }
+
+    in_nodes[0]->back_grad1 += dS;
+  }
 }

 void MatrixExp::backward() {
diff --git a/src/beanmachine/graph/operator/tests/gradient_test.cpp b/src/beanmachine/graph/operator/tests/gradient_test.cpp
index 78e98e0280..81fc349107 100644
--- a/src/beanmachine/graph/operator/tests/gradient_test.cpp
+++ b/src/beanmachine/graph/operator/tests/gradient_test.cpp
@@ -916,6 +916,74 @@ TEST(testgradient, forward_cholesky) {
   _expect_near_matrix(second_grad1, expected_second_grad1);
 }

+TEST(testgradient, backward_cholesky) {
+  /*
+
+  PyTorch validation code:
+
+  x = tensor([[1.0, 0.2, 0.98], [0.2, 0.98, 1.0], [0.98, 1.0, 2.1]],
+      requires_grad=True)
+  choleskySum = cholesky(x).sum()
+  log_p = (
+      dist.Normal(choleskySum, tensor(1.0)).log_prob(tensor(1.7))
+      + dist.Normal(tensor(0.0), tensor(1.0)).log_prob(x).sum()
+  )
+  autograd.grad(log_p, x)
+  */
+  Graph g;
+  auto zero = g.add_constant(0.0);
+  auto pos1 = g.add_constant_pos_real(1.0);
+  auto one = g.add_constant((natural_t)1);
+  auto three = g.add_constant((natural_t)3);
+  auto normal_dist = g.add_distribution(
+      DistributionType::NORMAL,
+      AtomicType::REAL,
+      std::vector<uint>{zero, pos1});
+
+  auto sigma_sample = g.add_operator(
+      OperatorType::IID_SAMPLE, std::vector<uint>{normal_dist, three, three});
+  auto L =
+      g.add_operator(OperatorType::CHOLESKY, std::vector<uint>{sigma_sample});
+
+  Eigen::MatrixXd sigma(3, 3);
+  sigma << 1.0, 0.2, 0.98, 0.2, 0.98, 1.0, 0.98, 1.0, 2.1;
+
+  g.observe(sigma_sample, sigma);
+
+  // Uses two matrix multiplications to sum the result of Cholesky
+  auto col_sum_m = g.add_operator(
+      OperatorType::IID_SAMPLE, std::vector<uint>{normal_dist, three, one});
+  auto row_sum_m = g.add_operator(
+      OperatorType::IID_SAMPLE, std::vector<uint>{normal_dist, one, three});
+  Eigen::MatrixXd m1(3, 1);
+  m1 << 1.0, 1.0, 1.0;
+  Eigen::MatrixXd m2(1, 3);
+  m2 << 1.0, 1.0, 1.0;
+  g.observe(col_sum_m, m1);
+  g.observe(row_sum_m, m2);
+  auto sum_rows = g.add_operator(
+      OperatorType::MATRIX_MULTIPLY, std::vector<uint>{L, col_sum_m});
+  auto sum_chol = g.add_operator(
+      OperatorType::MATRIX_MULTIPLY, std::vector<uint>{row_sum_m, sum_rows});
+  auto sum_dist = g.add_distribution(
+      DistributionType::NORMAL,
+      AtomicType::REAL,
+      std::vector<uint>{sum_chol, pos1});
+
+  auto sum_sample =
+      g.add_operator(OperatorType::SAMPLE, std::vector<uint>{sum_dist});
+  g.observe(sum_sample, 1.7);
+
+  std::vector<DoubleMatrix*> grad(4);
+  Eigen::MatrixXd expected_grad(3, 3);
+  expected_grad << -2.7761, -1.6587, -0.3756, -1.6587, -2.8059, -0.6445,
+      -0.3756, -0.6445, -4.2949;
+
+  g.eval_and_grad(grad);
+  EXPECT_EQ(grad.size(), 4);
+  _expect_near_matrix(grad[0]->as_matrix(), expected_grad);
+}
+
 TEST(testgradient, matrix_exp_grad) {
   Graph g;
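
Reviewer-side note, not part of the patch: below is a minimal NumPy sketch of the same reverse-mode Cholesky rule, transcribed from the C++ loop in Cholesky::backward() above, plus a quick central-difference check of the d(sum(L))/dS direction that the new test exercises. The name cholesky_vjp and the check itself are illustrative assumptions, not code from this repository.

import numpy as np

def cholesky_vjp(L, dL):
    # Propagate an upstream gradient dL (w.r.t. the lower factor L) back to a
    # symmetric gradient w.r.t. S = L @ L.T, mirroring Cholesky::backward()
    # (section 4.1 of the Iain Murray note).
    n = L.shape[0]
    dS = np.tril(dL)
    for i in range(n - 1, -1, -1):
        # column i, including the diagonal entry
        dS[i, i] -= L[i + 1:, i] @ dS[i + 1:, i] / L[i, i]
        dS[i:, i] /= L[i, i]
        if i > 0:
            # row i (excluding the diagonal) and the block below/left of (i, i)
            dS[i, :i] -= dS[i:, i] @ L[i:, :i]
            dS[i + 1:, :i] -= np.outer(dS[i + 1:, i], L[i, :i])
        dS[i, i] /= 2
    # split the gradient between the two symmetric halves, as the C++ code does
    off = np.tril(dS, -1) / 2
    return off + off.T + np.diag(np.diag(dS))

S = np.array([[1.0, 0.2, 0.98], [0.2, 0.98, 1.0], [0.98, 1.0, 2.1]])
grad = cholesky_vjp(np.linalg.cholesky(S), np.ones((3, 3)))

# central-difference check of d(sum(cholesky(S))) under a symmetric bump at (0, 1)
eps = 1e-6
dSp = np.zeros((3, 3))
dSp[0, 1] = dSp[1, 0] = eps
fd = (np.linalg.cholesky(S + dSp).sum() -
      np.linalg.cholesky(S - dSp).sum()) / (2 * eps)
assert abs(fd - (grad[0, 1] + grad[1, 0])) < 1e-5

The expected_grad values in the new test can then be reproduced by running the PyTorch snippet quoted in the test comment, given the usual torch imports.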