Skip to content

Commit

Permalink
Add sortcomps/sortcomps!
Browse files Browse the repository at this point in the history
  • Loading branch information
dahong67 committed Aug 22, 2024
1 parent 4900abe commit 9b8d1bf
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/GCPDecompositions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ import Base: ndims, size, show, summary
import Base: getindex
import Base: AbstractArray, Array
import LinearAlgebra: norm
using Base.Order: Ordering, Reverse
using IntervalSets: Interval
using Random: default_rng

# Exports
export CPD
export ncomps, normalizecomps, normalizecomps!, permutecomps, permutecomps!
export ncomps,
normalizecomps, normalizecomps!, permutecomps, permutecomps!, sortcomps, sortcomps!
export gcp
export GCPLosses, GCPConstraints, GCPAlgorithms

Expand Down
19 changes: 19 additions & 0 deletions src/cpd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,22 @@ function permutecomps!(M::CPD, perm::Vector)
# Return CPD with permuted components
return M
end

sortcomps(M::CPD; dims = , order::Ordering = Reverse, kwargs...) =

Check warning on line 240 in src/cpd.jl

View check run for this annotation

Codecov / codecov/patch

src/cpd.jl#L240

Added line #L240 was not covered by tests
permutecomps(M, sortperm(_sortvals(M, dims); order, kwargs...))
sortcomps!(M::CPD; dims = , order::Ordering = Reverse, kwargs...) =

Check warning on line 242 in src/cpd.jl

View check run for this annotation

Codecov / codecov/patch

src/cpd.jl#L242

Added line #L242 was not covered by tests
permutecomps!(M, sortperm(_sortvals(M, dims); order, kwargs...))

function _sortvals(M::CPD, dims)

Check warning on line 245 in src/cpd.jl

View check run for this annotation

Codecov / codecov/patch

src/cpd.jl#L245

Added line #L245 was not covered by tests
# Check dims
dims_iterable = dims isa Symbol ? (dims,) : dims
all(d -> d === || (d isa Integer && d in 1:ndims(M)), dims_iterable) || throw(

Check warning on line 248 in src/cpd.jl

View check run for this annotation

Codecov / codecov/patch

src/cpd.jl#L247-L248

Added lines #L247 - L248 were not covered by tests
ArgumentError(
"`dims` must be `:λ`, an integer specifying a mode, or a collection, got $dims",
),
)

# Return vector of values to sort by
return dims === ? M.λ :
[map(d -> d === ? M.λ[j] : view(M.U[d], :, j), dims) for j in 1:ncomps(M)]

Check warning on line 256 in src/cpd.jl

View check run for this annotation

Codecov / codecov/patch

src/cpd.jl#L255-L256

Added lines #L255 - L256 were not covered by tests
end

0 comments on commit 9b8d1bf

Please sign in to comment.