The forward function of the RGMS kernel is (relation-related information is ignored for simplicity):
$$ Y = AXW $$
We already have its implementation in SparseTIR using composable formats and tensor cores.
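For reference, here is a minimal pure-Python sketch of what the forward function computes. It is not the SparseTIR implementation. It assumes $A$ is stored in CSR form (`indptr`, `indices`, `data`), and the helper names `matmul`, `csr_spmm` and `rgms_forward` are hypothetical. Relations are ignored, as in the text. The sketch computes the dense product $XW$ first and then aggregates with the sparse $A$. This is one possible evaluation order, chosen here for illustration.

```python
# Reference sketch of the RGMS forward pass Y = A X W (relations ignored).
# A is assumed to be in CSR form (indptr, indices, data); helper names are
# hypothetical and this is NOT the SparseTIR implementation.

def matmul(P, Q):
    """Dense product of two row-major lists of lists."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q))) for j in range(len(Q[0]))]
            for i in range(len(P))]

def csr_spmm(indptr, indices, data, H):
    """Y = A H: each output row gathers the rows of H selected by A's nonzeros."""
    n_dst, d_out = len(indptr) - 1, len(H[0])
    Y = [[0.0] * d_out for _ in range(n_dst)]
    for i in range(n_dst):
        for p in range(indptr[i], indptr[i + 1]):
            j, a = indices[p], data[p]
            for c in range(d_out):
                Y[i][c] += a * H[j][c]
    return Y

def rgms_forward(indptr, indices, data, X, W):
    # Dense GEMM H = X W first, then the sparse aggregation Y = A H.
    return csr_spmm(indptr, indices, data, matmul(X, W))

# A = [[1, 0, 2],
#      [0, 3, 0]] in CSR form:
indptr, indices, data = [0, 2, 3], [0, 2, 1], [1.0, 2.0, 3.0]
X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
W = [[1.0, 0.0], [0.0, 1.0]]  # identity, so Y = A X
print(rgms_forward(indptr, indices, data, X, W))  # [[11.0, 14.0], [9.0, 12.0]]
```

With an identity `W`, each output row is just a weighted sum of the rows of `X` picked out by the nonzeros of the corresponding row of `A`.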
Writing $H = XW$ so that $Y = AH$, the backward function of the RGMS kernel needs to compute the gradients of both $X$ and $W$:
$$\nabla (XW) = A^T \nabla Y$$
$$\nabla X = \nabla (XW) W^T$$
$$\nabla W = X^T \nabla (XW)$$
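The three backward formulas can be checked numerically with a small pure-Python sketch. It assumes $A$ is in CSR form and computes $A^T \nabla Y$ by scattering along the rows of $A$ instead of materializing $A^T$. It uses the scalar loss $\sum Y \odot G$, so that $\nabla Y = G$, and compares the analytic gradients against central finite differences. All helper names (`spmm_t`, `rgms_backward`, `max_fd_error`, …) are hypothetical, and relations are ignored as above.

```python
# Sketch: backward pass of Y = A X W (relations ignored), checked against
# central finite differences. Pure Python; helper names are hypothetical.

def matmul(P, Q):
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q))) for j in range(len(Q[0]))]
            for i in range(len(P))]

def transpose(M):
    return [list(row) for row in zip(*M)]

def spmm(indptr, indices, data, H):
    """Y = A H with A in CSR (gather along each row of A)."""
    Y = [[0.0] * len(H[0]) for _ in range(len(indptr) - 1)]
    for i in range(len(indptr) - 1):
        for p in range(indptr[i], indptr[i + 1]):
            for c in range(len(H[0])):
                Y[i][c] += data[p] * H[indices[p]][c]
    return Y

def spmm_t(indptr, indices, data, dY, n_src):
    """dH = A^T dY, scattering along A's CSR rows instead of building A^T."""
    dH = [[0.0] * len(dY[0]) for _ in range(n_src)]
    for i in range(len(indptr) - 1):
        for p in range(indptr[i], indptr[i + 1]):
            for c in range(len(dY[0])):
                dH[indices[p]][c] += data[p] * dY[i][c]
    return dH

def rgms_backward(indptr, indices, data, X, W, dY):
    dH = spmm_t(indptr, indices, data, dY, n_src=len(X))  # grad of XW
    dX = matmul(dH, transpose(W))                         # dH W^T
    dW = matmul(transpose(X), dH)                         # X^T dH
    return dX, dW

def loss(indptr, indices, data, X, W, G):
    # Scalar loss sum(Y * G), so that dL/dY = G.
    Y = spmm(indptr, indices, data, matmul(X, W))
    return sum(Y[i][c] * G[i][c] for i in range(len(Y)) for c in range(len(Y[0])))

def max_fd_error(indptr, indices, data, X, W, G, eps=1e-6):
    """Largest gap between analytic and central-difference gradients."""
    dX, dW = rgms_backward(indptr, indices, data, X, W, G)
    worst = 0.0
    for M, grad in ((X, dX), (W, dW)):
        for r in range(len(M)):
            for c in range(len(M[0])):
                orig = M[r][c]
                M[r][c] = orig + eps
                up = loss(indptr, indices, data, X, W, G)
                M[r][c] = orig - eps
                down = loss(indptr, indices, data, X, W, G)
                M[r][c] = orig  # restore the perturbed entry
                fd = (up - down) / (2 * eps)
                worst = max(worst, abs(fd - grad[r][c]))
    return worst

indptr, indices, data = [0, 2, 3], [0, 2, 1], [1.0, 2.0, 3.0]  # A = [[1,0,2],[0,3,0]]
X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
W = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]
G = [[1.0, 0.0, -1.0], [2.0, 1.0, 0.5]]  # plays the role of dY
dX, dW = rgms_backward(indptr, indices, data, X, W, G)
print(dW)  # [[29.0, 9.0, -6.5], [38.0, 12.0, -8.0]]
print(max_fd_error(indptr, indices, data, X, W, G) < 1e-4)  # True
```

Because the loss is linear in each individual entry of $X$ or $W$, the central differences agree with the analytic gradients up to floating-point rounding.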
All three formulas could be computed inside the same kernel, with $\nabla (XW)$ kept in shared memory rather than written back to global memory. The same optimizations (composable formats + tensorization) could be applied to the backward kernel as well.
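One way such a fused kernel could be organized is sketched below in pure Python, not SparseTIR. Each "thread block" handles a tile of source rows $j$ and computes its slice of $\nabla(XW)$ into a local buffer, `dH_tile`, which stands in for shared memory. It consumes that slice immediately for both $\nabla X$ and $\nabla W$. The sketch assumes $A^T$ is available in CSR form (equivalently, $A$ in CSC), a layout choice the text above does not specify. The function name `fused_backward` and the `tile` parameter are hypothetical.

```python
# Sketch of a fused RGMS backward: each "thread block" handles a tile of
# source rows j, keeps its slice of dH = A^T dY in a local buffer (standing in
# for shared memory), and immediately consumes it for dX and dW.
# Assumes A^T is available in CSR form (i.e. A in CSC); names are hypothetical.

def fused_backward(at_indptr, at_indices, at_data, X, W, dY, tile=2):
    n_src, d_in, d_out = len(X), len(X[0]), len(W[0])
    dX = [[0.0] * d_in for _ in range(n_src)]
    dW = [[0.0] * d_out for _ in range(d_in)]
    for j0 in range(0, n_src, tile):
        rows = range(j0, min(j0 + tile, n_src))
        # 1) dH tile = (A^T dY)[rows]; never written back to global memory.
        dH_tile = []
        for j in rows:
            acc = [0.0] * d_out
            for p in range(at_indptr[j], at_indptr[j + 1]):
                i, a = at_indices[p], at_data[p]
                for c in range(d_out):
                    acc[c] += a * dY[i][c]
            dH_tile.append(acc)
        # 2) dX[rows] = dH_tile W^T: rows are disjoint across tiles, so no races.
        for t, j in enumerate(rows):
            for k in range(d_in):
                dX[j][k] = sum(dH_tile[t][c] * W[k][c] for c in range(d_out))
        # 3) dW += X[rows]^T dH_tile: partial sums from every tile must be
        #    reduced (e.g. atomics or a separate reduction pass on a GPU).
        for t, j in enumerate(rows):
            for k in range(d_in):
                for c in range(d_out):
                    dW[k][c] += X[j][k] * dH_tile[t][c]
    return dX, dW

# A = [[1, 0, 2], [0, 3, 0]], stored transposed: A^T = [[1, 0], [0, 3], [2, 0]]
at_indptr, at_indices, at_data = [0, 1, 2, 3], [0, 1, 0], [1.0, 3.0, 2.0]
X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
W = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]
dY = [[1.0, 0.0, -1.0], [2.0, 1.0, 0.5]]
dX, dW = fused_backward(at_indptr, at_indices, at_data, X, W, dY, tile=2)
print(dX)  # [[-1.5, 2.0], [3.0, 8.25], [-3.0, 4.0]]
print(dW)  # [[29.0, 9.0, -6.5], [38.0, 12.0, -8.0]]
```

In this layout, $\nabla X$ is owned tile by tile and written once. $\nabla W$, however, receives a contribution from every tile, so a GPU version would need a cross-block reduction for it.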