Skip to content

Commit

Permalink
Rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
aacostadiaz committed Jul 15, 2024
1 parent 3552110 commit c361ed1
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 72 deletions.
1 change: 0 additions & 1 deletion examples/sycl/pvc/pvc_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
*
**************************************************************************************************/

#include "cutlass/gemm/device/gemm.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/collective/intel_pvc_epilogue.hpp"
#include "cutlass/epilogue/fusion/intel_pvc_callbacks.hpp"
Expand Down
4 changes: 2 additions & 2 deletions examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
*
**************************************************************************************************/

#include "cutlass/gemm/device/gemm.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/collective/intel_pvc_epilogue.hpp"
#include "cutlass/epilogue/fusion/intel_pvc_callbacks.hpp"
Expand Down Expand Up @@ -210,7 +209,8 @@ struct ExampleRunner {
syclcompat::wait();

using TensorView = cutlass::TensorView<ElementOutput, LayoutD>;
cutlass::reference::device::TensorReLu(TensorView(block_ref_D.get(), LayoutD::packed({M, N}), cutlass::make_Coord(M, N)));
cutlass::reference::device::TensorReLu(TensorView(block_ref_D.get(), LayoutD::packed({M, N}),
cutlass::make_Coord(M, N)));

syclcompat::wait();

Expand Down
5 changes: 0 additions & 5 deletions include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,6 @@ class CollectiveEpilogue<

cst_callbacks.begin();

if (is_C_load_needed) {
copy(params.xe_load_c, tOuti(_,_,_,l_coord), trC);
}

auto acc_frag = recast<Array<ElementOutput, FragmentSize>>(accumulators);
auto trD_frag = recast<Array<ElementOutput, FragmentSize>>(trD);

Expand All @@ -321,7 +317,6 @@ class CollectiveEpilogue<
for (int epi_v = 0; epi_v < FragmentSize; ++epi_v) {
trD_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
}

copy(params.xe_store_d, trD, rw_coord(_, epi_m, epi_n));
}
}
Expand Down
122 changes: 59 additions & 63 deletions include/cutlass/epilogue/fusion/intel_pvc_callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,40 +40,36 @@
#include "cute/tensor.hpp"

#include "cutlass/epilogue/dispatch_policy.hpp"
// #include "cutlass/epilogue/fusion/callbacks.hpp"
// #include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp"
// #include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
// #include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp"
// #include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp"
// #include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp"
#include "cutlass/epilogue/fusion/callbacks.hpp"
#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::epilogue::fusion {

/////////////////////////////////////////////////////////////////////////////////////////////////

// template <class NodeOp, class... ChildOps>
// using Sm90EVT = Sm90TreeVisitor<NodeOp, ChildOps...>;

template <
class ElementOutput_,
class ElementCompute_,
class ElementSource_,
class ElementScalar_,
FloatRoundStyle RoundStyle_,
class CtaTileShapeMNK_,
class EpilogueTile_
class ElementOutput_,
class ElementCompute_,
class ElementSource_,
class ElementScalar_,
FloatRoundStyle RoundStyle_,
class CtaTileShapeMNK_,
class EpilogueTile_
>
struct FusionCallbacks<
epilogue::IntelPVCEpilogue,
fusion::LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>,
CtaTileShapeMNK_,
EpilogueTile_,
void, void
> {//: Sm90LinearCombination<typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementSource, ElementScalar, RoundStyle_> {

// using Impl = Sm90LinearCombination<typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementSource, ElementScalar, RoundStyle_>;
epilogue::IntelPVCEpilogue,
fusion::LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>,
CtaTileShapeMNK_,
EpilogueTile_
> : Sm90LinearCombination<typename cutlass::detail::get_unpacked_element_type<ElementOutput_>::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {

using Impl = Sm90LinearCombination<typename cutlass::detail::get_unpacked_element_type<ElementOutput_>::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>;
using ElementOutput = ElementOutput_;
using ElementCompute = ElementCompute_;
using ElementSource = ElementSource_;
Expand All @@ -86,41 +82,41 @@ struct FusionCallbacks<
ElementScalar const* alpha_ptr = nullptr;
ElementScalar const* beta_ptr = nullptr;

// operator typename Impl::Arguments() const {
// return
// { // ternary op : beta * C + (alpha * acc)
// {{beta}, {beta_ptr}}, // leaf args : beta
// {}, // leaf args : C
// { // binary op : alpha * acc
// {{alpha}, {alpha_ptr}}, // leaf args : alpha
// {}, // leaf args : acc
// {} // binary args : multiplies
// }, // end binary op
// {} // ternary args : multiply_add
// }; // end ternary op
// }
operator typename Impl::Arguments() const {
return
{ // ternary op : beta * C + (alpha * acc)
{{beta}, {beta_ptr}}, // leaf args : beta
{}, // leaf args : C
{ // binary op : alpha * acc
{{alpha}, {alpha_ptr}}, // leaf args : alpha
{}, // leaf args : acc
{} // binary args : multiplies
}, // end binary op
{} // ternary args : multiply_add
}; // end ternary op
}
};

// Ctor inheritance
// using Impl::Impl;
using Impl::Impl;
};


template <
template <class> class ActivationFn_,
class ElementOutput_,
class ElementCompute_,
class ElementSource_,
class ElementScalar_,
FloatRoundStyle RoundStyle_,
class CtaTileShapeMNK_,
class EpilogueTile_
template <class> class ActivationFn_,
class ElementOutput_,
class ElementCompute_,
class ElementSource_,
class ElementScalar_,
FloatRoundStyle RoundStyle_,
class CtaTileShapeMNK_,
class EpilogueTile_
>
struct FusionCallbacks<
epilogue::IntelPVCEpilogue,
fusion::LinCombEltAct<ActivationFn_, ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>,
CtaTileShapeMNK_,
EpilogueTile_
epilogue::IntelPVCEpilogue,
fusion::LinCombEltAct<ActivationFn_, ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>,
CtaTileShapeMNK_,
EpilogueTile_
> : Sm90LinCombEltAct<ActivationFn_, ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {

using Impl = Sm90LinCombEltAct<ActivationFn_, typename cutlass::detail::get_unpacked_element_type<ElementOutput_>::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>;
Expand All @@ -141,19 +137,19 @@ struct FusionCallbacks<

operator typename Impl::Arguments() const {
return
{ // unary op: activation(beta * C + (alpha * acc))
{ // ternary op : beta * C + (alpha * acc)
{{beta}, {beta_ptr}}, // leaf args : beta
{}, // leaf args : C
{ // binary op : alpha * acc
{{alpha}, {alpha_ptr}}, // leaf args : alpha
{}, // leaf args : acc
{} // binary args : multiplies
}, // end binary op
{} // ternary args : multiply_add
}, // end ternary op
activation // unary args: activation
}; // end unary op
{ // unary op: activation(beta * C + (alpha * acc))
{ // ternary op : beta * C + (alpha * acc)
{{beta}, {beta_ptr}}, // leaf args : beta
{}, // leaf args : C
{ // binary op : alpha * acc
{{alpha}, {alpha_ptr}}, // leaf args : alpha
{}, // leaf args : acc
{} // binary args : multiplies
}, // end binary op
{} // ternary args : multiply_add
}, // end ternary op
activation // unary args: activation
}; // end unary op
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ __global__ void
Element b = cutlass::ReferenceFactory<Element>::get(ptr_B, idx);

if (!relatively_equal(a, b, epsilon, nonzero_floor)) {
printf("idx :%lu | a: %f | b: %f\n", idx, a, b);
*equal = 0;
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,11 @@ void TensorReLu(
} // namespace reference
} // namespace cutlass

// Compiled only when CUTLASS_ENABLE_SYCL evaluates to non-zero, so builds
// without SYCL never see the sycl namespace below.
#if (CUTLASS_ENABLE_SYCL)
namespace sycl {
// Explicit specialization of sycl::is_device_copyable that derives from
// std::true_type for the Params type of the ReLU reference functor.
// Only the <float, cutlass::layout::RowMajor> instantiation is covered here.
// NOTE(review): other element types / layouts would presumably need their own
// specialization (or a partial specialization) — confirm against callers.
template <>
struct is_device_copyable <
cutlass::reference::device::detail::TensorReLuFunc<float,
cutlass::layout::RowMajor>::Params> : std::true_type {};
}
#endif

0 comments on commit c361ed1

Please sign in to comment.