Skip to content

Commit

Permalink
wrap procedure in a function
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes-intel committed Nov 19, 2024
1 parent 58ebb83 commit 4e241dc
Showing 1 changed file with 32 additions and 20 deletions.
52 changes: 32 additions & 20 deletions cpp/oneapi/dal/backend/primitives/lapack/solve_dpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,30 @@ Float diagonal_minimum(sycl::queue& queue,
return ndview<Float, 1>::wrap(diag_min_holder).at_device(queue, 0, { diag_min_event });
}

template <mkl::uplo uplo, typename Float, ndorder xlayout, ndorder ylayout>
sycl::event solve_with_fallback(
sycl::queue& queue,
const ndview<Float, 2, xlayout>& xtx,
const ndview<Float, 2, ylayout>& xty,
ndview<Float, 2, ndorder::c>& nxtx,
ndview<Float, 2, ndorder::c>& nxty, /// solution will be written here
const event_vector& dependencies) {
const std::int64_t dim_xtx = xtx.get_dimension(0);
const std::int64_t nrhs = nxty.get_dimension(0);
/// Note: this templated version of 'copy' reuses a layout that should have been
/// specified in a previous copy before the fallback.
sycl::event xtx_event_new = copy(queue, nxtx, xtx, dependencies);
sycl::event xty_event_new = copy(queue, nxty, xty, dependencies);

return solve_spectral_decomposition<uplo, Float>(queue,
nxtx,
xtx_event_new,
nxty,
xty_event_new,
dim_xtx,
nrhs);
}

template <mkl::uplo uplo, bool beta, typename Float, ndorder xlayout, ndorder ylayout>
sycl::event solve_system(sycl::queue& queue,
const ndview<Float, 2, xlayout>& xtx,
Expand Down Expand Up @@ -239,28 +263,16 @@ sycl::event solve_system(sycl::queue& queue,
/// singular values that the fallback will discard, but there's no guaranteed match
/// between singular values and entries in the Cholesky diagonal. This is just a guess.
const Float threshold_diagonal_min = 1e-6;
if (diag_min <= threshold_diagonal_min)
goto fallback_solver;
solution_event = potrs_solution<uplo>(queue, nxtx, nxty, dummy, { potrf_event, xty_event });
if (diag_min > threshold_diagonal_min) {
solution_event =
potrs_solution<uplo>(queue, nxtx, nxty, dummy, { potrf_event, xty_event });
}
else {
solution_event = solve_with_fallback<uplo>(queue, xtx, xty, nxtx, nxty, dependencies);
}
}
catch (mkl::lapack::computation_error& ex) {
goto fallback_solver;
}
/// Note: this block is structured so that it will only be entered into through a 'goto'
if (false) {
fallback_solver:
const std::int64_t nrhs = nxty.get_dimension(0);
/// Note: this templated version of 'copy' reuses the layout that was specified in the previous copy
sycl::event xtx_event_new = copy(queue, nxtx, xtx, dependencies);
sycl::event xty_event_new = copy(queue, nxty, xty, dependencies);

solution_event = solve_spectral_decomposition<uplo, Float>(queue,
nxtx,
xtx_event_new,
nxty,
xty_event_new,
dim_xtx,
nrhs);
solution_event = solve_with_fallback<uplo>(queue, xtx, xty, nxtx, nxty, dependencies);
}

return beta_copy_transform<beta>(queue, nxty, final_xty, { solution_event });
Expand Down

0 comments on commit 4e241dc

Please sign in to comment.