Skip to content

Commit

Permalink
Added view opt for tensor_split
Browse files Browse the repository at this point in the history
  • Loading branch information
= committed Oct 30, 2023
1 parent a5eb855 commit 52ebe2f
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions src/utils/dimensionality_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,37 @@
export squeeze, unsqueeze, wavelet_squeeze, wavelet_unsqueeze, Haar_squeeze, invHaar_unsqueeze
export tensor_split, tensor_cat
export cat_states, split_states
export ShuffleLayer, WaveletLayer, HaarLayer
export Squeezer, ShuffleLayer, WaveletLayer, HaarLayer


###############################################################################
# Custom type for squeezer functions

struct Squeezer
abstract type Squeezer end

struct ShuffleLayer<:Squeezer
pattern
forward::Function
inverse::Function
end

function ShuffleLayer(;pattern="checkerboard")
return Squeezer(x -> squeeze(x;pattern=pattern), x -> unsqueeze(x;pattern=pattern))
end
ShuffleLayer(; pattern="checkerboard") = ShuffleLayer(pattern, x -> squeeze(x; pattern=pattern), x -> unsqueeze(x;pattern=pattern))

function WaveletLayer(;type=WT.db1)
return Squeezer(x -> wavelet_squeeze(x;type=type), x -> wavelet_unsqueeze(x;type=type))
struct WaveletLayer<:Squeezer
type
forward::Function
inverse::Function
end

function HaarLayer()
return Squeezer(x -> Haar_squeeze(x), x -> invHaar_unsqueeze(x))
WaveletLayer(; type=WT.db1) = WaveletLayer(type, x -> wavelet_squeeze(x; type=type), x -> wavelet_unsqueeze(x; type=type))

struct HaarLayer<:Squeezer
forward::Function
inverse::Function
end

HaarLayer() = HaarLayer(x -> Haar_squeeze(x), x -> invHaar_unsqueeze(x))


####################################################################################################
# Squeeze and unsqueeze
Expand Down Expand Up @@ -410,7 +420,7 @@ invHaar_unsqueeze(x::AbstractArray{T, 3}) where T = invHaarLift(x, 1)
See also: [`tensor_cat`](@ref)
"""
function tensor_split(X::AbstractArray{T, N}; split_index=nothing) where {T, N}
function tensor_split(X::AbstractArray{T, N}; split_index=nothing, view::Bool=false) where {T, N}
d = max(1, N-1)
if isnothing(split_index)
k = Int(round(size(X, d)/2))
Expand All @@ -421,7 +431,8 @@ function tensor_split(X::AbstractArray{T, N}; split_index=nothing) where {T, N}
indsl = [i==d ? (1:k) : Colon() for i=1:N]
indsr = [i==d ? (k+1:size(X, d)) : Colon() for i=1:N]

return X[indsl...], X[indsr...]
view ? (Xl = Base.view(X, indsl...); Xr = Base.view(X, indsr...)) : (Xl = X[indsl...]; Xr = X[indsr...])
return Xl, Xr
end

"""
Expand Down

0 comments on commit 52ebe2f

Please sign in to comment.