diff --git a/src/conditional_layers/conditional_layer_glow.jl b/src/conditional_layers/conditional_layer_glow.jl
index ab36e064..2c3b895e 100644
--- a/src/conditional_layers/conditional_layer_glow.jl
+++ b/src/conditional_layers/conditional_layer_glow.jl
@@ -78,7 +78,12 @@ function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64;freeze
 
     # 1x1 Convolution and residual block for invertible layers
     C = Conv1x1(n_in; freeze=freeze_conv)
-    RB = ResidualBlock(Int(n_in/2)+n_cond, n_hidden; n_out=n_in, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims)
+
+    split_num = Int(round(n_in/2))
+    in_split = n_in-split_num
+    out_chan = 2*split_num
+
+    RB = ResidualBlock(in_split+n_cond, n_hidden; n_out=out_chan, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims)
 
     return ConditionalLayerGlow(C, RB, logdet, activation)
 end
@@ -143,7 +148,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractA
 
     # Backpropagate RB
     ΔX2_ΔC = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), (tensor_cat(X2, C)))
-    ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=Int(size(ΔY)[N-1]/2))
+    ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=size(ΔY2)[N-1])
     ΔX2 += ΔY2
 
     # Backpropagate 1x1 conv
diff --git a/test/test_networks/test_conditional_glow_network.jl b/test/test_networks/test_conditional_glow_network.jl
index a2a0380d..946b39b0 100644
--- a/test/test_networks/test_conditional_glow_network.jl
+++ b/test/test_networks/test_conditional_glow_network.jl
@@ -9,6 +9,49 @@ device = InvertibleNetworks.CUDA.functional() ? gpu : cpu
 # Random seed
 Random.seed!(3);
 
+# Define network
+nx = 32; ny = 32; nz = 32
+n_in = 3
+n_cond = 3
+n_hidden = 4
+batchsize = 2
+L = 2
+K = 2
+split_scales = false
+N = (nx,ny)
+
+########################################### Test with split_scales = false N = (nx,ny) #########################
+# Invertibility
+
+# Network and input
+G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)) |> device
+X = rand(Float32, N..., n_in, batchsize) |> device
+Cond = rand(Float32, N..., n_cond, batchsize) |> device
+
+Y, Cond = G.forward(X,Cond)
+X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of reshapes
+
+@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)
+
+# Test gradients are set and cleared
+G.backward(Y, Y, Cond)
+
+P = get_params(G)
+gsum = 0
+for p in P
+    ~isnothing(p.grad) && (global gsum += 1)
+end
+@test isequal(gsum, L*K*10+2)
+
+clear_grad!(G)
+gsum = 0
+for p in P
+    ~isnothing(p.grad) && (global gsum += 1)
+end
+@test isequal(gsum, 0)
+
+
+Random.seed!(3);
 # Define network
 nx = 32; ny = 32; nz = 32
 n_in = 2
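
For reference, a minimal sketch of the channel bookkeeping the patched constructor relies on. `glow_split_sizes` is a hypothetical helper written only for illustration (it is not part of InvertibleNetworks.jl); it mirrors the `split_num`/`in_split`/`out_chan` arithmetic from the hunk above and shows why odd channel counts now work, since the old `Int(n_in/2)` throws an `InexactError` whenever `n_in` is odd.

# Sketch of the split arithmetic introduced in ConditionalLayerGlow (hypothetical helper).
# One branch of the coupling split keeps split_num channels and is rescaled/shifted by the
# residual block's output (S and T, hence 2*split_num output channels); the other branch
# keeps in_split channels and feeds the residual block together with the condition.
function glow_split_sizes(n_in::Int, n_cond::Int)
    split_num = Int(round(n_in / 2))   # channels in the transformed branch
    in_split  = n_in - split_num       # channels in the passthrough branch
    out_chan  = 2 * split_num          # residual block output: S and T
    return (rb_in = in_split + n_cond, rb_out = out_chan)
end

@assert glow_split_sizes(2, 3) == (rb_in = 4, rb_out = 2)   # even n_in: same sizes as the old code
@assert glow_split_sizes(3, 3) == (rb_in = 4, rb_out = 4)   # odd n_in: Int(3/2) used to raise InexactError

The backward fix follows the same logic: the gradient coming out of the residual block is split at the passthrough branch's channel count, size(ΔY2)[N-1], instead of assuming it is exactly half of ΔY's channels.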