diff --git a/src/gcp-opt.jl b/src/gcp-opt.jl index 7fc5412..e7725dc 100644 --- a/src/gcp-opt.jl +++ b/src/gcp-opt.jl @@ -248,14 +248,11 @@ function khatrirao(A::Vararg{T,N}) where {T<:AbstractMatrix,N} return A[1] end - # Base case: N = 2 - if N == 2 - r = (only ∘ unique)(size.(A, 2)) - return reshape(reshape(A[1], :, 1, r) .* reshape(A[2], 1, :, r), :, r) + # General case: N > 1 + r = (only ∘ unique)(size.(A, 2)) + R = ntuple(Val(N)) do k + dims = (ntuple(i -> 1, Val(N - k))..., :, ntuple(i -> 1, Val(k - 1))..., r) + return reshape(A[k], dims) end - - # Recursive case: N > 2 - I, r = size.(A, 1), (only ∘ unique)(size.(A, 2)) - n = argmin(n -> I[n] * I[n+1], 1:N-1) - return khatrirao(A[1:n-1]..., khatrirao(A[n], A[n+1]), A[n+2:end]...) + return reshape(broadcast(*, R...), :, r) end