Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sortcomps/sortcomps! #67

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/man/main.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ normalizecomps
normalizecomps!
permutecomps
permutecomps!
sortcomps
sortcomps!
GCPDecompositions.default_constraints
GCPDecompositions.default_algorithm
GCPDecompositions.default_init
Expand Down
7 changes: 5 additions & 2 deletions src/GCPDecompositions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@ import Base: ndims, size, show, summary
import Base: getindex
import Base: AbstractArray, Array
import LinearAlgebra: norm
using Base.Order: Ordering, Reverse
using IntervalSets: Interval
using Random: default_rng

# Exports
export CPD
export ncomps, normalizecomps, normalizecomps!, permutecomps, permutecomps!
export CPD, CPDComp
export ncomps,
normalizecomps, normalizecomps!, permutecomps, permutecomps!, sortcomps, sortcomps!
export gcp
export GCPLosses, GCPConstraints, GCPAlgorithms

include("tensor-kernels.jl")
include("cpd.jl")
include("cpdcomp.jl")
include("gcp-losses.jl")
include("gcp-constraints.jl")
include("gcp-algorithms.jl")
Expand Down
44 changes: 42 additions & 2 deletions src/cpd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@
Permute the components of `M`.
`perm` is a vector or a tuple of length `ncomps(M)` specifying the permutation.

See also: `permutecomps!`.
See also: `permutecomps!`, `sortcomps`, `sortcomps!`.
"""
permutecomps(M::CPD, perm) = permutecomps!(deepcopy(M), perm)

Expand All @@ -219,7 +219,7 @@
Permute the components of `M` in-place.
`perm` is a vector or a tuple of length `ncomps(M)` specifying the permutation.

See also: `permutecomps`.
See also: `permutecomps`, `sortcomps`, `sortcomps!`.
"""
permutecomps!(M::CPD, perm) = permutecomps!(M, collect(perm))
function permutecomps!(M::CPD, perm::Vector)
Expand All @@ -236,3 +236,43 @@
# Return CPD with permuted components
return M
end

"""
sortcomps(M::CPD; dims=:λ, alg::Algorithm=DEFAULT_UNSTABLE, lt=isless, by=identity, rev::Bool=false, order::Ordering=Reverse)

Sort the components of `M`. `dims` specifies what part to sort by;
it must be the symbol `:λ`, an integer in `1:ndims(M)`, or a collection of these.

For the remaining keyword arguments, see the documentation of `sort!`.

See also: `sortcomps!`, `sort`, `sort!`.
"""
sortcomps(M::CPD; dims = :λ, order::Ordering = Reverse, kwargs...) =

Check warning on line 250 in src/cpd.jl

View check run for this annotation

Codecov / codecov/patch

src/cpd.jl#L250

Added line #L250 was not covered by tests
permutecomps(M, sortperm(_sortvals(M, dims); order, kwargs...))

"""
sortcomps!(M::CPD; dims=:λ, alg::Algorithm=DEFAULT_UNSTABLE, lt=isless, by=identity, rev::Bool=false, order::Ordering=Reverse)

Sort the components of `M` in-place. `dims` specifies what part to sort by;
it must be the symbol `:λ`, an integer in `1:ndims(M)`, or a collection of these.

For the remaining keyword arguments, see the documentation of `sort!`.

See also: `sortcomps`, `sort`, `sort!`.
"""
sortcomps!(M::CPD; dims = :λ, order::Ordering = Reverse, kwargs...) =

Check warning on line 263 in src/cpd.jl

View check run for this annotation

Codecov / codecov/patch

src/cpd.jl#L263

Added line #L263 was not covered by tests
permutecomps!(M, sortperm(_sortvals(M, dims); order, kwargs...))

function _sortvals(M::CPD, dims)

Check warning on line 266 in src/cpd.jl

View check run for this annotation

Codecov / codecov/patch

src/cpd.jl#L266

Added line #L266 was not covered by tests
# Check dims
dims_iterable = dims isa Symbol ? (dims,) : dims
all(d -> d === :λ || (d isa Integer && d in 1:ndims(M)), dims_iterable) || throw(

Check warning on line 269 in src/cpd.jl

View check run for this annotation

Codecov / codecov/patch

src/cpd.jl#L268-L269

Added lines #L268 - L269 were not covered by tests
ArgumentError(
"`dims` must be `:λ`, an integer specifying a mode, or a collection, got $dims",
),
)

# Return vector of values to sort by
return dims === :λ ? M.λ :
[map(d -> d === :λ ? M.λ[j] : view(M.U[d], :, j), dims) for j in 1:ncomps(M)]

Check warning on line 277 in src/cpd.jl

View check run for this annotation

Codecov / codecov/patch

src/cpd.jl#L276-L277

Added lines #L276 - L277 were not covered by tests
end
62 changes: 62 additions & 0 deletions src/cpdcomp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
## CPD component type

"""
CPDComp

Type for a single component of a canonical polyadic decompositions (CPD).

If `M::CPDComp` is the component object,
the scalar weight `λ` and the factor vectors `u = (u[1],...,u[N])`
can be obtained via `M.λ` and `M.u`.
"""
struct CPDComp{T,N,Tu<:AbstractVector{T}}
λ::T
u::NTuple{N,Tu}
function CPDComp{T,N,Tu}(λ, u) where {T,N,Tu<:AbstractVector{T}}
Base.require_one_based_indexing(u...)
return new{T,N,Tu}(λ, u)
end
end
CPDComp(λ::T, u::NTuple{N,Tu}) where {T,N,Tu<:AbstractVector{T}} = CPDComp{T,N,Tu}(λ, u)

ndims(::CPDComp{T,N}) where {T,N} = N
size(M::CPDComp{T,N}, dim::Integer) where {T,N} = dim <= N ? length(M.u[dim]) : 1
size(M::CPDComp{T,N}) where {T,N} = ntuple(d -> size(M, d), N)

function show(io::IO, mime::MIME{Symbol("text/plain")}, M::CPDComp{T,N}) where {T,N}
# Compute displaysize for showing fields
LINES, COLUMNS = displaysize(io)
LINES_FIELD = max(LINES - 2 - N, 0) ÷ (1 + N)
io_field = IOContext(io, :displaysize => (LINES_FIELD, COLUMNS))

# Show summary and fields
summary(io, M)
println(io)
println(io, "λ weight:")
show(io_field, mime, M.λ)
for k in Base.OneTo(N)
println(io, "\nu[$k] factor vector:")
show(io_field, mime, M.u[k])
end
end

function summary(io::IO, M::CPDComp)
dimstring =
ndims(M) == 0 ? "0-dimensional" :
ndims(M) == 1 ? "$(size(M,1))-element" : join(map(string, size(M)), '×')
return print(io, dimstring, " ", typeof(M))
end

function getindex(M::CPDComp{T,N}, I::Vararg{Int,N}) where {T,N}
@boundscheck Base.checkbounds_indices(Bool, axes(M), I) || Base.throw_boundserror(M, I)
return M.λ * prod(M.u[k][I[k]] for k in Base.OneTo(ndims(M)))
end
getindex(M::CPDComp{T,N}, I::CartesianIndex{N}) where {T,N} = getindex(M, Tuple(I)...)

AbstractArray(A::CPDComp) =
reshape(TensorKernels.khatrirao(reverse(reshape.(A.u, :, 1))...) * A.λ, size(A))
Array(A::CPDComp) = Array(AbstractArray(A))

norm(M::CPDComp, p::Real = 2) =
p == 2 ? norm2(M) : norm((M[I] for I in CartesianIndices(size(M))), p)
norm2(M::CPDComp{T,N}) where {T,N} = sqrt(abs2(M.λ) * prod(sum(abs2, M.u[i]) for i in 1:N))
127 changes: 127 additions & 0 deletions test/items/cpdcomp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
## CPD component type

@testitem "constructors" begin
using OffsetArrays

@testset "T=$T" for T in [Float64, Float16]
λ = T(100)
u1, u2, u3 = T[1, 4], T[-1], T[2, 5, 8]

# Check type for various orders
@test CPDComp{T,0,Vector{T}}(λ, ()) isa CPDComp{T,0,Vector{T}}
@test CPDComp(λ, (u1,)) isa CPDComp{T,1,Vector{T}}
@test CPDComp(λ, (u1, u2)) isa CPDComp{T,2,Vector{T}}
@test CPDComp(λ, (u1, u2, u3)) isa CPDComp{T,3,Vector{T}}

# Check requirement of one-based indexing
O1, O2 = OffsetArray(u1, 0:1), OffsetArray(u2, 0:0)
@test_throws ArgumentError CPDComp(λ, (O1, O2))
end
end

@testitem "ndims" begin
λ = 100
u1, u2, u3 = [1, 4], [-1], [2, 5, 8]

@test ndims(CPDComp{Int,0,Vector{Int}}(λ, ())) == 0
@test ndims(CPDComp(λ, (u1,))) == 1
@test ndims(CPDComp(λ, (u1, u2))) == 2
@test ndims(CPDComp(λ, (u1, u2, u3))) == 3
end

@testitem "size" begin
λ = 100
u1, u2, u3 = [1, 4], [-1], [2, 5, 8]

@test size(CPDComp(λ, (u1,))) == (length(u1),)
@test size(CPDComp(λ, (u1, u2))) == (length(u1), length(u2))
@test size(CPDComp(λ, (u1, u2, u3))) == (length(u1), length(u2), length(u3))

M = CPDComp(λ, (u1, u2, u3))
@test size(M, 1) == 2
@test size(M, 2) == 1
@test size(M, 3) == 3
@test size(M, 4) == 1
end

@testitem "show / summary" begin
M = CPDComp(rand(), rand.((3, 4, 5)))
Mstring = sprint((t, s) -> show(t, "text/plain", s), M)
λstring = sprint((t, s) -> show(t, "text/plain", s), M.λ)
ustrings = sprint.((t, s) -> show(t, "text/plain", s), M.u)
@test Mstring == string(
"$(summary(M))\nλ weight:\n$λstring",
["\nu[$k] factor vector:\n$ustring" for (k, ustring) in enumerate(ustrings)]...,
)
end

@testitem "getindex" begin
T = Float64
λ = T(100)
u1, u2, u3 = T[1, 4], T[-1], T[2, 5, 8]

M = CPDComp(λ, (u1, u2, u3))
for i1 in axes(u1, 1), i2 in axes(u2, 1), i3 in axes(u3, 1)
Mi = λ * u1[i1] * u2[i2] * u3[i3]
@test Mi == M[i1, i2, i3]
@test Mi == M[CartesianIndex((i1, i2, i3))]
end
@test_throws BoundsError M[length(u1)+1, 1, 1]
@test_throws BoundsError M[1, length(u2)+1, 1]
@test_throws BoundsError M[1, 1, length(u3)+1]

M = CPDComp(λ, (u1, u2))
for i1 in axes(u1, 1), i2 in axes(u2, 1)
Mi = λ * u1[i1] * u2[i2]
@test Mi == M[i1, i2]
@test Mi == M[CartesianIndex((i1, i2))]
end
@test_throws BoundsError M[length(u1)+1, 1]
@test_throws BoundsError M[1, length(u2)+1]

M = CPDComp(λ, (u1,))
for i1 in axes(u1, 1)
Mi = λ * u1[i1]
@test Mi == M[i1]
@test Mi == M[CartesianIndex((i1,))]
end
@test_throws BoundsError M[length(u1)+1]
end

@testitem "Array" begin
@testset "N=$N" for N in 1:3
T = Float64
λ = T(100)
u1, u2, u3 = T[1, 4], T[-1], T[2, 5, 8]
M = CPDComp(λ, (u1, u2, u3))

X = Array(M)
@test all(I -> M[I] == X[I], CartesianIndices(X))
end
end

@testitem "norm" begin
using LinearAlgebra

T = Float64
λ = T(100)
u1, u2, u3 = T[1, 4], T[-1], T[2, 5, 8]

M = CPDComp(λ, (u1, u2, u3))
@test norm(M) == norm(M, 2) == sqrt(sum(abs2, M[I] for I in CartesianIndices(size(M))))
@test norm(M, 1) == sum(abs, M[I] for I in CartesianIndices(size(M)))
@test norm(M, 3) ==
(sum(m -> abs(m)^3, M[I] for I in CartesianIndices(size(M))))^(1 / 3)

M = CPDComp(λ, (u1, u2))
@test norm(M) == norm(M, 2) == sqrt(sum(abs2, M[I] for I in CartesianIndices(size(M))))
@test norm(M, 1) == sum(abs, M[I] for I in CartesianIndices(size(M)))
@test norm(M, 3) ==
(sum(m -> abs(m)^3, M[I] for I in CartesianIndices(size(M))))^(1 / 3)

M = CPDComp(λ, (u1,))
@test norm(M) == norm(M, 2) == sqrt(sum(abs2, M[I] for I in CartesianIndices(size(M))))
@test norm(M, 1) == sum(abs, M[I] for I in CartesianIndices(size(M)))
@test norm(M, 3) ==
(sum(m -> abs(m)^3, M[I] for I in CartesianIndices(size(M))))^(1 / 3)
end
Loading