diff --git a/src/gcp-losses.jl b/src/gcp-losses.jl index 1f2cf54..dd9674c 100644 --- a/src/gcp-losses.jl +++ b/src/gcp-losses.jl @@ -118,6 +118,7 @@ function grad_U!( for j in 1:K if sym_data mttkrp!(GU[j], Y, tuple([M.U[k] for k in M.S]...), findall(M.S .== j)[1]) + rmul!(GU[j], count(M.S .== j)) else for (index, mode) in enumerate(findall(M.S .== j)) if index == 1