Skip to content

Commit

Permalink
(var, var, number) implementation and testing.
Browse files Browse the repository at this point in the history
  • Loading branch information
bluescarni committed Oct 21, 2023
1 parent 84fb221 commit b472940
Show file tree
Hide file tree
Showing 2 changed files with 280 additions and 2 deletions.
217 changes: 215 additions & 2 deletions src/math/kepF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ llvm::Value *taylor_diff_kepF_impl(llvm_state &s, llvm::Type *fp_t, const std::v

auto &builder = s.builder();

// Fetch the index of the k and lam variable arguments.
// Fetch the indices of the k and lam variable arguments.
const auto k_idx = uname_to_index(var1.name());
const auto lam_idx = uname_to_index(var2.name());

Expand Down Expand Up @@ -533,7 +533,7 @@ taylor_diff_kepF_impl(llvm_state &s, llvm::Type *fp_t, const std::vector<std::ui

auto &builder = s.builder();

// Fetch the index of the k and lam variable arguments.
// Fetch the indices of the h and lam variable arguments.
const auto h_idx = uname_to_index(var1.name());
const auto lam_idx = uname_to_index(var2.name());

Expand Down Expand Up @@ -597,6 +597,86 @@ taylor_diff_kepF_impl(llvm_state &s, llvm::Type *fp_t, const std::vector<std::ui
return llvm_fdiv(s, dividend, divisor);
}

// Derivative of kepF(var, var, number).
//
// Emits the IR computing the Taylor coefficient of order 'order' of a = kepF(h, k, lam),
// where h (var1) and k (var2) are u variables and lam (num) is a number/parameter.
//
// - deps must contain exactly 4 u-variable indices, referred to below as c, d, e, f
//   (presumably h*sin(a), k*cos(a), sin(a) and cos(a) - confirm against the decomposition code),
// - arr holds the Taylor coefficients computed so far (read via taylor_fetch_diff()),
// - par_ptr is forwarded to the codegen of the number/parameter argument,
// - idx is the u-variable index of a itself.
//
// At order 0 the result is a call to the inverse Kepler solver. At order n > 0 the result is:
//
//   [n * (k^[n] * e^[0] - h^[n] * f^[0])
//    + sum_{j=1}^{n-1} j * (a^[j] * (c^[n-j] + d^[n-j]) + k^[j] * e^[n-j] - h^[j] * f^[n-j])]
//   / [n * (1 - c^[0] - d^[0])]
template <typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
llvm::Value *
taylor_diff_kepF_impl(llvm_state &s, llvm::Type *fp_t, const std::vector<std::uint32_t> &deps, const variable &var1,
                      const variable &var2, const U &num, const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr,
                      // NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
                      std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size)
{
    assert(deps.size() == 4u); // LCOV_EXCL_LINE

    auto &builder = s.builder();

    // Fetch the indices of the h and k variable arguments.
    const auto h_idx = uname_to_index(var1.name());
    const auto k_idx = uname_to_index(var2.name());

    if (order == 0u) {
        // Do the codegen for the number argument.
        // NOTE: lam is consumed only at order 0 (its derivatives of order > 0
        // are zero), hence its codegen is confined to this branch in order to
        // avoid emitting unused instructions at the higher orders.
        auto *lam = taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size);

        // Create/fetch the Kepler solver.
        auto *fkep = llvm_add_inv_kep_F(s, fp_t, batch_size);

        // Invoke and return.
        return builder.CreateCall(
            fkep, {taylor_fetch_diff(arr, h_idx, 0, n_uvars), taylor_fetch_diff(arr, k_idx, 0, n_uvars), lam});
    }

    // Splat the order.
    auto *n = vector_splat(builder, llvm_constantfp(s, fp_t, static_cast<double>(order)), batch_size);

    // Compute the divisor: n * (1 - c^[0] - d^[0]).
    const auto c_idx = deps[0], d_idx = deps[1];
    auto *one_fp = vector_splat(builder, llvm_constantfp(s, fp_t, 1.), batch_size);
    auto *divisor = llvm_fsub(s, one_fp, taylor_fetch_diff(arr, c_idx, 0, n_uvars));
    divisor = llvm_fsub(s, divisor, taylor_fetch_diff(arr, d_idx, 0, n_uvars));
    divisor = llvm_fmul(s, n, divisor);

    // Compute the first part of the dividend: n * (k^[n] * e^[0] - h^[n] * f^[0]) (the derivative of lam is zero
    // because here lam is constant and the order is > 0).
    const auto e_idx = deps[2], f_idx = deps[3];
    auto *div1 = llvm_fmul(s, taylor_fetch_diff(arr, k_idx, order, n_uvars), taylor_fetch_diff(arr, e_idx, 0, n_uvars));
    auto *div2 = llvm_fmul(s, taylor_fetch_diff(arr, h_idx, order, n_uvars), taylor_fetch_diff(arr, f_idx, 0, n_uvars));
    auto *dividend = llvm_fsub(s, div1, div2);
    dividend = llvm_fmul(s, n, dividend);

    // Compute the second part of the dividend only for order > 1, in order to avoid
    // an empty summation.
    if (order > 1u) {
        // NOTE: the summation has exactly order - 1 terms, allocate them upfront.
        std::vector<llvm::Value *> sum;
        sum.reserve(order - 1u);

        // NOTE: iteration in the [1, order) range.
        for (std::uint32_t j = 1; j < order; ++j) {
            // The multiplier j of the current term.
            auto *fac = vector_splat(builder, llvm_constantfp(s, fp_t, static_cast<double>(j)), batch_size);

            // Coefficients of order (order - j) of the dependencies.
            auto *cnj = taylor_fetch_diff(arr, c_idx, order - j, n_uvars);
            auto *dnj = taylor_fetch_diff(arr, d_idx, order - j, n_uvars);
            auto *enj = taylor_fetch_diff(arr, e_idx, order - j, n_uvars);
            auto *fnj = taylor_fetch_diff(arr, f_idx, order - j, n_uvars);

            // Coefficients of order j of a, h and k.
            auto *aj = taylor_fetch_diff(arr, idx, j, n_uvars);
            auto *hj = taylor_fetch_diff(arr, h_idx, j, n_uvars);
            auto *kj = taylor_fetch_diff(arr, k_idx, j, n_uvars);

            // j * (a^[j] * (c^[n-j] + d^[n-j]) + (k^[j] * e^[n-j] - h^[j] * f^[n-j])).
            auto *tmp1 = llvm_fadd(s, cnj, dnj);
            auto *tmp2 = llvm_fmul(s, kj, enj);
            auto *tmp3 = llvm_fmul(s, hj, fnj);
            auto *tmp4 = llvm_fmul(s, aj, tmp1);
            auto *tmp5 = llvm_fsub(s, tmp2, tmp3);
            auto *tmp6 = llvm_fadd(s, tmp4, tmp5);
            sum.push_back(llvm_fmul(s, fac, tmp6));
        }

        // Update the dividend.
        dividend = llvm_fadd(s, dividend, pairwise_sum(s, sum));
    }

    return llvm_fdiv(s, dividend, divisor);
}

// LCOV_EXCL_START

// All the other cases.
Expand Down Expand Up @@ -1290,6 +1370,139 @@ llvm::Function *taylor_c_diff_func_kepF_impl(llvm_state &s, llvm::Type *fp_t, co
return f;
}

// Derivative of kepF(var, var, number).
//
// Compact-mode counterpart of the Taylor derivative of a = kepF(h, k, lam), with h (var1) and
// k (var2) u variables and lam (n) a number/parameter. Creates (or fetches from the module, if
// already created) an internal LLVM function returning the Taylor coefficient of a at the order
// passed as its first argument.
//
// The arguments of the generated function that are read here are, by position:
// 0 = order, 1 = u index of a, 2 = pointer to the derivatives array, 3 = pointer to the parameters,
// 5 = u index of h, 6 = u index of k, 7 = the number/parameter argument, 8-11 = u indices of the
// 4 dependencies c, d, e, f. Argument 4 is not used by this implementation.
//
// At order 0 the function calls the inverse Kepler solver, at order n > 0 it computes:
//
//   [n * (k^[n] * e^[0] - h^[n] * f^[0])
//    + sum_j j * (a^[j] * (c^[n-j] + d^[n-j]) + k^[j] * e^[n-j] - h^[j] * f^[n-j])]
//   / [n * (1 - c^[0] - d^[0])]
//
// with j running over the range handled by llvm_loop_u32() (presumably [1, n) - confirm).
template <typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
llvm::Function *taylor_c_diff_func_kepF_impl(llvm_state &s, llvm::Type *fp_t, const variable &var1,
                                             const variable &var2, const U &n, std::uint32_t n_uvars,
                                             std::uint32_t batch_size)
{
    auto &md = s.module();
    auto &builder = s.builder();
    auto &context = s.context();

    // Fetch the vector floating-point type.
    auto *val_t = make_vector_type(fp_t, batch_size);

    // Fetch the function name and arguments.
    const auto na_pair = taylor_c_diff_func_name_args(context, fp_t, "kepF", n_uvars, batch_size, {var1, var2, n}, 4);
    const auto &fname = na_pair.first;
    const auto &fargs = na_pair.second;

    // Try to see if we already created the function.
    auto f = md.getFunction(fname);

    if (f != nullptr) {
        // The function was created already, return it.
        return f;
    }

    // The function was not created before, do it now.

    // Create/fetch the Kepler solver.
    auto *fkep = llvm_add_inv_kep_F(s, fp_t, batch_size);

    // Fetch the current insertion block.
    auto *orig_bb = builder.GetInsertBlock();

    // The return type is val_t.
    auto *ft = llvm::FunctionType::get(val_t, fargs, false);
    // Create the function
    f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &md);
    assert(f != nullptr);

    // Fetch the necessary function arguments.
    auto ord = f->args().begin();
    auto u_idx = f->args().begin() + 1;
    auto diff_ptr = f->args().begin() + 2;
    auto par_ptr = f->args().begin() + 3;
    auto h_idx = f->args().begin() + 5;
    auto k_idx = f->args().begin() + 6;
    auto num_lam = f->args().begin() + 7;
    auto c_idx = f->args().begin() + 8;
    auto d_idx = f->args().begin() + 9;
    auto e_idx = f->args().begin() + 10;
    auto f_idx = f->args().begin() + 11;

    // Create a new basic block to start insertion into.
    builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));

    // Create the return value.
    auto *retval = builder.CreateAlloca(val_t);

    // Create the accumulator.
    auto *acc = builder.CreateAlloca(val_t);

    // The constant zero order, used to load the order-0 coefficients.
    auto *zero_ord = builder.getInt32(0);

    // Small helper to load the Taylor coefficient of order 'order' of the u variable 'var_idx'.
    const auto load_diff = [&](llvm::Value *order, llvm::Value *var_idx) {
        return taylor_c_load_diff(s, val_t, diff_ptr, n_uvars, order, var_idx);
    };

    llvm_if_then_else(
        s, builder.CreateICmpEQ(ord, zero_ord),
        [&]() {
            // Order 0: invoke the Kepler solver on h^[0], k^[0] and lam.
            builder.CreateStore(builder.CreateCall(fkep, {load_diff(zero_ord, h_idx), load_diff(zero_ord, k_idx),
                                                          taylor_c_diff_numparam_codegen(s, fp_t, n, num_lam, par_ptr,
                                                                                         batch_size)}),
                                retval);
        },
        [&]() {
            // Create the FP vector version of the order.
            auto *ord_v = vector_splat(builder, llvm_ui_to_fp(s, ord, fp_t), batch_size);

            // Compute the divisor: ord * (1 - c^[0] - d^[0]).
            auto *one_fp = vector_splat(builder, llvm_constantfp(s, fp_t, 1.), batch_size);
            auto *divisor = llvm_fsub(s, one_fp, load_diff(zero_ord, c_idx));
            divisor = llvm_fsub(s, divisor, load_diff(zero_ord, d_idx));
            divisor = llvm_fmul(s, ord_v, divisor);

            // Init the dividend: ord * (k^[n] * e^[0] - h^[n] * f^[0]) (lam is constant here).
            // NOTE: the loads are emitted as separate statements (rather than directly as
            // function arguments) so that the order of the instructions in the generated IR
            // does not depend on the argument evaluation order chosen by the C++ compiler.
            auto *k_n = load_diff(ord, k_idx);
            auto *e_0 = load_diff(zero_ord, e_idx);
            auto *div1 = llvm_fmul(s, k_n, e_0);

            auto *h_n = load_diff(ord, h_idx);
            auto *f_0 = load_diff(zero_ord, f_idx);
            auto *div2 = llvm_fmul(s, h_n, f_0);

            auto *dividend = llvm_fsub(s, div1, div2);
            dividend = llvm_fmul(s, ord_v, dividend);

            // Init the accumulator.
            builder.CreateStore(vector_splat(builder, llvm_constantfp(s, fp_t, 0.), batch_size), acc);

            // Run the loop.
            llvm_loop_u32(s, builder.getInt32(1), ord, [&](llvm::Value *j) {
                auto *j_v = vector_splat(builder, llvm_ui_to_fp(s, j, fp_t), batch_size);

                // NOTE: compute ord - j only once and re-use it for all
                // the loads of the coefficients of order ord - j.
                auto *ord_m_j = builder.CreateSub(ord, j);

                auto *c_nj = load_diff(ord_m_j, c_idx);
                auto *d_nj = load_diff(ord_m_j, d_idx);
                auto *e_nj = load_diff(ord_m_j, e_idx);
                auto *f_nj = load_diff(ord_m_j, f_idx);
                auto *aj = load_diff(j, u_idx);
                auto *hj = load_diff(j, h_idx);
                auto *kj = load_diff(j, k_idx);

                // j * (a^[j] * (c^[n-j] + d^[n-j]) + (k^[j] * e^[n-j] - h^[j] * f^[n-j])).
                auto *tmp1 = llvm_fadd(s, c_nj, d_nj);
                auto *tmp2 = llvm_fmul(s, kj, e_nj);
                auto *tmp3 = llvm_fmul(s, hj, f_nj);
                auto *tmp4 = llvm_fmul(s, aj, tmp1);
                auto *tmp5 = llvm_fsub(s, tmp2, tmp3);
                auto *tmp6 = llvm_fadd(s, tmp4, tmp5);
                auto *tmp = llvm_fmul(s, j_v, tmp6);

                // Accumulate the current term.
                builder.CreateStore(llvm_fadd(s, builder.CreateLoad(val_t, acc), tmp), acc);
            });

            // Write the result.
            builder.CreateStore(llvm_fdiv(s, llvm_fadd(s, dividend, builder.CreateLoad(val_t, acc)), divisor), retval);
        });

    // Return the result.
    builder.CreateRet(builder.CreateLoad(val_t, retval));

    // Verify.
    s.verify_function(f);

    // Restore the original insertion block.
    builder.SetInsertPoint(orig_bb);

    return f;
}

// LCOV_EXCL_START

// All the other cases.
Expand Down
65 changes: 65 additions & 0 deletions test/taylor_kepF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,71 @@ TEST_CASE("taylor kepF")
REQUIRE(jet[14] == approximately((jet[8] * 2 + jet[10] * 2) / 6));
REQUIRE(jet[15] == approximately((jet[9] * 2 + jet[11] * 2) / 6));
}

// Var-var-number test.
//
// Integrates the system x' = kepF(x, y, par[0]), y' = x + y up to order 3 with batch size 2,
// and compares the jet of derivatives against expressions written out by hand.
// Jet layout (two batch lanes per entry): [0-1] x, [2-3] y, [4-5] x', [6-7] y',
// [8-9] x''/2, [10-11] y''/2, [12-13] x'''/6, [14-15] y'''/6.
{
    using jet_func_t = void (*)(fp_t *, const fp_t *, const fp_t *);

    llvm_state s{kw::opt_level = opt_level};

    taylor_add_jet<fp_t>(s, "jet", {kepF(x, y, par[0]), x + y}, 3, 2, high_accuracy, compact_mode);

    s.compile();

    // In unoptimised compact mode the dedicated var-var-par function must show up in the IR.
    if (compact_mode && opt_level == 0u) {
        REQUIRE(boost::contains(s.get_ir(), "@heyoka.taylor_c_diff.kepF.var_var_par"));
    }

    const auto jet_func = reinterpret_cast<jet_func_t>(s.jit_lookup("jet"));

    // Parameter values (one per batch lane) and initial state, padded to the full jet size.
    std::vector<fp_t> pars{fp_t(.2), fp_t(.2)};
    std::vector<fp_t> jet{fp_t{.5}, fp_t{-.125}, fp_t{0.1875}, fp_t{-0.3125}};
    jet.resize(16);

    jet_func(jet.data(), pars.data(), nullptr);

    // Order 0: the initial state must be left untouched.
    REQUIRE(jet[0] == .5);
    REQUIRE(jet[1] == -.125);

    REQUIRE(jet[2] == 0.1875);
    REQUIRE(jet[3] == -0.3125);

    // Order 1.
    REQUIRE(jet[4] == approximately(kepF_num(jet[0], jet[2], pars[0])));
    REQUIRE(jet[5] == approximately(kepF_num(jet[1], jet[3], pars[1])));

    REQUIRE(jet[6] == jet[0] + jet[2]);
    REQUIRE(jet[7] == jet[1] + jet[3]);

    // Sine/cosine of the order-1 kepF values, one pair per batch lane.
    const auto sF0 = sin(jet[4]), cF0 = cos(jet[4]);
    const auto sF1 = sin(jet[5]), cF1 = cos(jet[5]);

    // Denominators and numerators of the first time derivative of kepF.
    const auto den0 = 1 - jet[0] * sF0 - jet[2] * cF0;
    const auto den1 = 1 - jet[1] * sF1 - jet[3] * cF1;

    const auto num0 = jet[6] * sF0 - jet[4] * cF0;
    const auto num1 = jet[7] * sF1 - jet[5] * cF1;

    // Order 2.
    REQUIRE(jet[8] == approximately(num0 / den0 / 2));
    REQUIRE(jet[9] == approximately(num1 / den1 / 2));

    REQUIRE(jet[10] == (jet[4] + jet[6]) / 2);
    REQUIRE(jet[11] == (jet[5] + jet[7]) / 2);

    // Un-normalised second-order derivatives of the first state variable.
    const auto Fp0 = jet[8] * 2;
    const auto Fp1 = jet[9] * 2;

    // Time derivatives of the denominators.
    const auto dden0 = -jet[4] * sF0 - jet[0] * cF0 * Fp0 - jet[6] * cF0 + jet[2] * sF0 * Fp0;
    const auto dden1 = -jet[5] * sF1 - jet[1] * cF1 * Fp1 - jet[7] * cF1 + jet[3] * sF1 * Fp1;

    // Time derivatives of the numerators.
    const auto dnum0 = 2 * jet[10] * sF0 + jet[6] * cF0 * Fp0 - Fp0 * cF0 + jet[4] * sF0 * Fp0;
    const auto dnum1 = 2 * jet[11] * sF1 + jet[7] * cF1 * Fp1 - Fp1 * cF1 + jet[5] * sF1 * Fp1;

    // Order 3 (quotient rule on num / den).
    REQUIRE(jet[12] == approximately((dnum0 * den0 - num0 * dden0) / (den0 * den0) / 6));
    REQUIRE(jet[13] == approximately((dnum1 * den1 - num1 * dden1) / (den1 * den1) / 6));

    REQUIRE(jet[14] == approximately((jet[8] * 2 + jet[10] * 2) / 6));
    REQUIRE(jet[15] == approximately((jet[9] * 2 + jet[11] * 2) / 6));
}
};

for (auto cm : {false, true}) {
Expand Down

0 comments on commit b472940

Please sign in to comment.