Skip to content

Commit

Permalink
Version based on broadcasting
Browse files Browse the repository at this point in the history
Avoids allocating intermediate arrays, at a cost of additional multiplies.
  • Loading branch information
dahong67 committed Mar 1, 2024
1 parent a071fbd commit f9a78c9
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions src/gcp-opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,11 @@ function khatrirao(A::Vararg{T,N}) where {T<:AbstractMatrix,N}
return A[1]
end

# Base case: N = 2
if N == 2
r = (only unique)(size.(A, 2))
return reshape(reshape(A[1], :, 1, r) .* reshape(A[2], 1, :, r), :, r)
# General case: N > 1
r = (only unique)(size.(A, 2))
R = ntuple(Val(N)) do k
dims = (ntuple(i -> 1, Val(N - k))..., :, ntuple(i -> 1, Val(k - 1))..., r)
return reshape(A[k], dims)
end

# Recursive case: N > 2
I, r = size.(A, 1), (only unique)(size.(A, 2))
n = argmin(n -> I[n] * I[n+1], 1:N-1)
return khatrirao(A[1:n-1]..., khatrirao(A[n], A[n+1]), A[n+2:end]...)
return reshape(broadcast(*, R...), :, r)
end

0 comments on commit f9a78c9

Please sign in to comment.