Skip to content

Commit

Permalink
Merge pull request #41 from dahong67/dahong67/issue34
Browse files Browse the repository at this point in the history
Faster Khatri-Rao
  • Loading branch information
dahong67 authored Mar 1, 2024
2 parents 8f57eef + 11c156f commit 20b2bfb
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 6 deletions.
1 change: 1 addition & 0 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const SUITE_MODULES = Dict(
"gcp" => :BenchmarkGCP,
"mttkrp" => :BenchmarkMTTKRP,
"mttkrp-large" => :BenchmarkMTTKRPLarge,
"khatrirao" => :BenchmarkKhatriRao,
)

# Create top-level suite including only sub-suites
Expand Down
73 changes: 73 additions & 0 deletions benchmark/suites/khatrirao.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
module BenchmarkKhatriRao

using BenchmarkTools, GCPDecompositions
using Random

const SUITE = BenchmarkGroup()

# Collect setups
const SETUPS = []

## N=1 matrix
append!(
SETUPS,
[
(; size = sz, rank = r) for sz in [ntuple(n -> In, 1) for In in 30:30:90],
r in [5; 30:30:90]
],
)

## N=2 matrices (balanced)
append!(
SETUPS,
[
(; size = sz, rank = r) for sz in [ntuple(n -> In, 2) for In in 30:30:90],
r in [5; 30:30:90]
],
)

## N=3 matrices (balanced)
append!(
SETUPS,
[
(; size = sz, rank = r) for sz in [ntuple(n -> In, 3) for In in 30:30:90],
r in [5; 30:30:90]
],
)

## N=3 matrices (imbalanced)
append!(
SETUPS,
[
(; size = sz, rank = r) for
sz in [Tuple(circshift([30, 100, 1000], c)) for c in 0:2], r in [5; 30:30:90]
],
)

## N=4 matrices (balanced)
append!(
SETUPS,
[
(; size = sz, rank = r) for sz in [ntuple(n -> In, 4) for In in 30:30:90],
r in [5; 30:30:90]
],
)

## N=4 matrices (imbalanced)
append!(
SETUPS,
[
(; size = sz, rank = r) for
sz in [Tuple(circshift([20, 40, 80, 500], c)) for c in 0:3], r in [5; 30:30:90]
],
)

# Generate random benchmarks
for SETUP in SETUPS
Random.seed!(0)
U = [randn(In, SETUP.rank) for In in SETUP.size]
SUITE["size=$(SETUP.size), rank=$(SETUP.rank)"] =
@benchmarkable(GCPDecompositions.khatrirao($U...), seconds = 2, samples = 5,)
end

end
15 changes: 9 additions & 6 deletions src/gcp-opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,14 @@ function khatrirao(A::Vararg{T,N}) where {T<:AbstractMatrix,N}
return A[1]
end

# General case: N > 1
r = (only unique)(size.(A, 2))
K = similar(A[1], prod(size.(A, 1)), r)
for j in 1:r
K[:, j] = reduce(kron, [view(A[i], :, j) for i in 1:N])
# Base case: N = 2
if N == 2
r = (only unique)(size.(A, 2))
return reshape(reshape(A[2], :, 1, r) .* reshape(A[1], 1, :, r), :, r)
end
return K

# Recursive case: N > 2
I, r = size.(A, 1), (only unique)(size.(A, 2))
n = argmin(n -> I[n] * I[n+1], 1:N-1)
return khatrirao(A[1:n-1]..., khatrirao(A[n], A[n+1]), A[n+2:end]...)
end
17 changes: 17 additions & 0 deletions test/items/gcp-opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,23 @@ end
Mh = gcp(X, r) # test default (least-squares) loss
@test maximum(I -> abs(Mh[I] - X[I]), CartesianIndices(X)) <= 1e-5
end

# 4-way tensor to exercise recursive part of the Khatri-Rao code
@testset "size(X)=$sz, rank(X)=$r" for sz in [(50, 40, 30, 2)], r in 1:2
Random.seed!(0)
M = CPD(ones(r), rand.(sz, r))
X = [M[I] for I in CartesianIndices(size(M))]
Mh = gcp(X, r, LeastSquaresLoss())
@test maximum(I -> abs(Mh[I] - X[I]), CartesianIndices(X)) <= 1e-5

Xm = convert(Array{Union{Missing,eltype(X)}}, X)
Xm[1, 1, 1, 1] = missing
Mm = gcp(Xm, r, LeastSquaresLoss())
@test maximum(I -> abs(Mm[I] - X[I]), CartesianIndices(X)) <= 1e-5

Mh = gcp(X, r) # test default (least-squares) loss
@test maximum(I -> abs(Mh[I] - X[I]), CartesianIndices(X)) <= 1e-5
end
end

@testitem "NonnegativeLeastSquaresLoss" begin
Expand Down

0 comments on commit 20b2bfb

Please sign in to comment.