Skip to content

Commit

Permalink
Merge pull request #39 from alexmul1114/FastALS-Method
Browse files Browse the repository at this point in the history
FastALS method
  • Loading branch information
dahong67 authored Mar 20, 2024
2 parents 8bb9201 + 6c68d82 commit 24886d8
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 2 deletions.
1 change: 1 addition & 0 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ const SUITE_MODULES = Dict(
"mttkrp" => :BenchmarkMTTKRP,
"mttkrp-large" => :BenchmarkMTTKRPLarge,
"khatrirao" => :BenchmarkKhatriRao,
"leastsquares" => :BenchmarkLeastSquares,
)

# Create top-level suite including only sub-suites
Expand Down
37 changes: 37 additions & 0 deletions benchmark/suites/leastsquares.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
module BenchmarkLeastSquares

using BenchmarkTools, GCPDecompositions
using Random, Distributions

const SUITE = BenchmarkGroup()

# More thorough benchmarks for least squares than gcp benchmarks

# Order-3 tensors
for sz in [(15, 20, 25), (30, 40, 50), (60, 70, 80)], r in [1, 10, 50]
Random.seed!(0)
M = CPD(ones(r), rand.(sz, r))
X = [M[I] for I in CartesianIndices(size(M))]
SUITE["size(X)=$sz, rank(X)=$r"] =
@benchmarkable gcp($X, $r; loss = GCPLosses.LeastSquaresLoss())
end

# Order-4 tensors
for sz in [(15, 20, 25, 30), (30, 40, 50, 60)], r in [1, 10, 50]
Random.seed!(0)
M = CPD(ones(r), rand.(sz, r))
X = [M[I] for I in CartesianIndices(size(M))]
SUITE["least-squares-size(X)=$sz, rank(X)=$r"] =
@benchmarkable gcp($X, $r; loss = GCPLosses.LeastSquaresLoss())
end

# Order-5 tensors
for sz in [(15, 20, 25, 30, 35), (30, 30, 30, 30, 30)], r in [1, 10, 50]
Random.seed!(0)
M = CPD(ones(r), rand.(sz, r))
X = [M[I] for I in CartesianIndices(size(M))]
SUITE["least-squares-size(X)=$sz, rank(X)=$r"] =
@benchmarkable gcp($X, $r; loss = GCPLosses.LeastSquaresLoss())
end

end
2 changes: 1 addition & 1 deletion src/GCPDecompositions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ default_algorithm(
r,
loss::GCPLosses.LeastSquaresLoss,
constraints::Tuple{},
) = GCPAlgorithms.ALS()
) = GCPAlgorithms.FastALS()
default_algorithm(X, r, loss, constraints) = GCPAlgorithms.LBFGSB()

"""
Expand Down
4 changes: 3 additions & 1 deletion src/gcp-algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ module GCPAlgorithms

using ..GCPDecompositions
using ..TensorKernels: create_mttkrp_buffer, mttkrp!
using ..TensorKernels: khatrirao!, khatrirao
using IntervalSets: Interval
using LinearAlgebra: lu!, norm, rdiv!
using LinearAlgebra: lu!, mul!, norm, rdiv!
using LBFGSB: lbfgsb

"""
Expand All @@ -34,5 +35,6 @@ function _gcp end

include("gcp-algorithms/lbfgsb.jl")
include("gcp-algorithms/als.jl")
include("gcp-algorithms/fastals.jl")

end
225 changes: 225 additions & 0 deletions src/gcp-algorithms/fastals.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
## Algorithm: FastALS

"""
FastALS
**Fast** **A**lternating **L**east **S**quares.
Efficient ALS algorithm proposed in:
> **Fast Alternating LS Algorithms for High Order
> CANDECOMP/PARAFAC Tensor Factorizations**.
> Anh-Huy Phan, Petr Tichavský, Andrzej Cichocki.
> *IEEE Transactions on Signal Processing*, 2013.
> DOI: 10.1109/TSP.2013.2269903
Algorithm parameters:
- `maxiters::Int` : max number of iterations (default: `200`)
"""
Base.@kwdef struct FastALS <: AbstractAlgorithm
maxiters::Int = 200
end

function _gcp(
X::Array{TX,N},
r,
loss::GCPLosses.LeastSquaresLoss,
constraints::Tuple{},
algorithm::GCPAlgorithms.FastALS,
init,
) where {TX<:Real,N}
# Initialization
M = deepcopy(init)

# Determine order of modes of MTTKRP to compute
Jns = [prod(size(X)[1:n]) for n in 1:N]
Kns = [prod(size(X)[n+1:end]) for n in 1:N]
Kn_minus_ones = [prod(size(X)[n:end]) for n in 1:N]
n_star = findlast(n -> Jns[n] <= Kn_minus_ones[n], 1:N)
order = vcat([i for i in n_star:-1:1], [i for i in n_star+1:N])

buffers = create_FastALS_buffers(M.U, order, Jns, Kns)

for _ in 1:algorithm.maxiters
FastALS_iter!(X, M, order, Jns, Kns, buffers)
end

return M
end

"""
FastALS_iter!(X, U, λ)
Algorithm for computing MTTKRP sequences is from "Fast Alternating LS Algorithms
for High Order CANDECOMP/PARAFAC Tensor Factorizations" by Phan et al., specifically
section III-C.
"""
function FastALS_iter!(X, M, order, Jns, Kns, buffers)
N = ndims(X)
R = size(M.U[1])[2]

# Compute MTTKRPs recursively
n_star = order[1]

for n in order
if n == n_star
kr_right = khatrirao!(buffers.kr_buffer_descending, M.U[reverse(n_star+1:N)]...)
if n_star == 1
mul!(M.U[n], reshape(X, (Jns[n], Kns[n])), kr_right)
else
mul!(buffers.descending_buffers[1], reshape(X, (Jns[n], Kns[n])), kr_right)
_rl_outer_multiplication!(
buffers.descending_buffers[1],
M.U,
buffers.helper_buffers_descending[n_star-n+1],
n,
)
end
elseif n == n_star + 1
kr_left = khatrirao!(buffers.kr_buffer_ascending, M.U[reverse(1:n-1)]...)
if n == N
mul!(M.U[n], reshape(X, (Jns[n-1], Kns[n-1]))', kr_left)
else
mul!(
buffers.ascending_buffers[1],
(reshape(X, (Jns[n-1], Kns[n-1])))',
kr_left,
)
_lr_outer_multiplication!(
buffers.ascending_buffers[1],
M.U,
buffers.helper_buffers_ascending[n-n_star],
n,
)
end
elseif n < n_star
if n == 1
for r in 1:R
mul!(
view(M.U[n], :, r),
reshape(
view(buffers.descending_buffers[n_star-n], :, r),
(Jns[n], size(X)[n+1]),
),
view(M.U[n+1], :, r),
)
end
else
for r in 1:R
mul!(
view(buffers.descending_buffers[n_star-n+1], :, r),
reshape(
view(buffers.descending_buffers[n_star-n], :, r),
(Jns[n], size(X)[n+1]),
),
view(M.U[n+1], :, r),
)
end
_rl_outer_multiplication!(
buffers.descending_buffers[n_star-n+1],
M.U,
buffers.helper_buffers_descending[n_star-n+1],
n,
)
end
else
if n == N
for r in 1:R
mul!(
view(M.U[n], :, r),
reshape(
view(buffers.ascending_buffers[N-n_star-1], :, r),
(size(X)[n-1], Kns[n-1]),
)',
view(M.U[n-1], :, r),
)
end
else
for r in 1:R
mul!(
view(buffers.ascending_buffers[n-n_star], :, r),
reshape(
view(buffers.ascending_buffers[n-n_star-1], :, r),
(size(X)[n-1], Kns[n-1]),
)',
view(M.U[n-1], :, r),
)
end
_lr_outer_multiplication!(
buffers.ascending_buffers[n-n_star],
M.U,
buffers.helper_buffers_ascending[n-n_star],
n,
)
end
end
# Normalization, update weights
V = reduce(.*, M.U[i]'M.U[i] for i in setdiff(1:N, n))
rdiv!(M.U[n], lu!(V))
M.λ .= norm.(eachcol(M.U[n]))
M.U[n] ./= permutedims(M.λ)
end
end

# Helper function for right-to-left outer multiplications
function _rl_outer_multiplication!(Zn, U, kr_buffer, n)
khatrirao!(kr_buffer, U[reverse(1:n-1)]...)
for r in 1:size(U[n])[2]
mul!(
view(U[n], :, r),
reshape(view(Zn, :, r), (prod(size(U[i])[1] for i in 1:n-1), size(U[n])[1]))',
view(kr_buffer, :, r),
)
end
end

# Helper function for left-to-right outer multiplications
function _lr_outer_multiplication!(Zn, U, kr_buffer, n)
khatrirao!(kr_buffer, U[reverse(n+1:length(U))]...)
for r in 1:size(U[n])[2]
mul!(
view(U[n], :, r),
reshape(
view(Zn, :, r),
(size(U[n])[1], prod(size(U[i])[1] for i in n+1:length(U))),
),
view(kr_buffer, :, r),
)
end
end

function create_FastALS_buffers(
U::NTuple{N,TM},
order,
Jns,
Kns,
) where {TM<:AbstractMatrix,N}
n_star = order[1]
r = size(U[1])[2]
dims = [size(U[u])[1] for u in 1:length(U)]

# Allocate buffers
# Buffer for saved products between modes
descending_buffers =
n_star < 2 ? nothing : [similar(U[1], (Jns[n], r)) for n in n_star:-1:2]
ascending_buffers =
N - n_star - 1 < 1 ? nothing : [similar(U[1], (Kns[n], r)) for n in n_star:N]
# Buffers for khatri-rao products
kr_buffer_descending = similar(U[1], (Kns[n_star], r))
kr_buffer_ascending = similar(U[1], (Jns[n_star], r))
# Buffers for khatri-rao product in helper function
helper_buffers_descending =
n_star < 2 ? nothing : [similar(U[1], (prod(dims[1:n-1]), r)) for n in n_star:-1:2]
helper_buffers_ascending =
n_star >= N - 1 ? nothing :
[similar(U[1], (prod(dims[n+1:N]), r)) for n in n_star+1:N-1]
return (;
descending_buffers,
ascending_buffers,
kr_buffer_descending,
kr_buffer_ascending,
helper_buffers_descending,
helper_buffers_ascending,
)
end
29 changes: 29 additions & 0 deletions test/items/gcp-opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,35 @@ end
Mh = gcp(X, r) # test default (least-squares) loss
@test maximum(I -> abs(Mh[I] - X[I]), CartesianIndices(X)) <= 1e-5
end

# 5 way tensor to exercise else case in FastALS
@testset "size(X)=$sz, rank(X)=$r" for sz in [(10, 15, 20, 25, 30), (30, 25, 5, 5, 5)],
r in [2]

r = 2
Random.seed!(0)
M = CPD(ones(r), rand.(sz, r))
X = [M[I] for I in CartesianIndices(size(M))]
Mh = gcp(X, r; loss = GCPLosses.LeastSquaresLoss())
@test maximum(I -> abs(Mh[I] - X[I]), CartesianIndices(X)) <= 1e-5

Xm = convert(Array{Union{Missing,eltype(X)}}, X)
Xm[1, 1, 1, 1, 1] = missing
Mm = gcp(Xm, r; loss = GCPLosses.LeastSquaresLoss())
@test maximum(I -> abs(Mm[I] - X[I]), CartesianIndices(X)) <= 1e-5

Mh = gcp(X, r) # test default (least-squares) loss
@test maximum(I -> abs(Mh[I] - X[I]), CartesianIndices(X)) <= 1e-5
end

# Test old ALS method
@testset "size(X)=$sz, rank(X)=$r" for sz in [(15, 20, 25)], r in [2]
Random.seed!(0)
M = CPD(ones(r), rand.(sz, r))
X = [M[I] for I in CartesianIndices(size(M))]
Mh = gcp(X, r; loss = GCPLosses.LeastSquaresLoss(), algorithm = GCPAlgorithms.ALS())
@test maximum(I -> abs(Mh[I] - X[I]), CartesianIndices(X)) <= 1e-5
end
end

@testitem "NonnegativeLeastSquaresLoss" begin
Expand Down

0 comments on commit 24886d8

Please sign in to comment.