Adding Training Strategies to NNDAE #876

Closed

hippyhippohops

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated
  • The new code follows the contributor guidelines, in particular the SciML Style Guide and COLPRAC.
  • Any new documentation only uses public API


Member

@sathvikbhagavan sathvikbhagavan left a comment

OK, now that you have written the strategies for DAEs, the next step is to refactor the code this shares with NNODE so that we don't repeat it. After that, actually try out the problem in #721 using what is implemented here.

src/dae_solve.jl Outdated
@@ -47,6 +47,25 @@ function NNDAE(chain, opt, init_params = nothing; strategy = nothing, autodiff =
NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)
end



Suggested change

src/dae_solve.jl Outdated
end
return loss
end


Suggested change

src/ode_solve.jl Outdated
@@ -304,6 +304,7 @@ function generate_loss(
return loss
end



Suggested change

src/dae_solve.jl Outdated
end

function dfdx(phi::ODEPhi{C, T, U}, t::Number, θ,
autodiff::Bool,differential_vars::AbstractVector) where {C, T, U <: AbstractVector}

Suggested change
- autodiff::Bool,differential_vars::AbstractVector) where {C, T, U <: AbstractVector}
+ autodiff::Bool, differential_vars::AbstractVector) where {C, T, U <: AbstractVector}

src/dae_solve.jl Outdated
if autodiff
ForwardDiff.jacobian(t -> phi(t, θ), t)
else
(phi(t + sqrt(eps(typeof(t))), θ) - phi(t, θ)) / sqrt(eps(typeof(t)))

Shouldn't this use only differential variables? See other methods of dfdx
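
A minimal sketch of what restricting the finite-difference branch to the differential variables could look like. It is self-contained and does not use NeuralPDE's `ODEPhi`. The helper name `masked_fd_derivative` and the toy function `f` are hypothetical, and zeroing the entries of the algebraic variables is an assumption about the intended behavior, not the implementation in this PR.

```julia
# Hypothetical helper (not NeuralPDE's dfdx): forward-difference derivative of a
# vector-valued function at a scalar time t, kept only for differential variables.
function masked_fd_derivative(f, t::Number, differential_vars::AbstractVector{Bool})
    h = sqrt(eps(typeof(float(t))))      # same step size as in the reviewed code
    d = (f(t + h) .- f(t)) ./ h          # forward difference for every component
    # Assumption: algebraic variables (false entries) contribute a zero derivative.
    return [dv ? d[i] : zero(eltype(d)) for (i, dv) in enumerate(differential_vars)]
end

# Toy stand-in for phi(t, θ): two outputs, only the second is differential.
f(t) = [sin(t), t^2]
d = masked_fd_derivative(f, 1.0, [false, true])
println(d)
```

Here the first component is forced to zero because it is flagged as algebraic, while the second approximates the derivative of `t^2` at `t = 1`.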

@sathvikbhagavan changed the title from "Adding Training Strategies to dae_solvers - 2.jl" to "Adding Training Strategies to NNDAE" on Jul 16, 2024
@hippyhippohops
Author

Hi, I am refactoring the code in dae_solve.jl and ode_solve.jl and am currently testing it.

For the following test:

@testset "WeightedIntervalTraining" begin
    function example2(du, u, p, t)
        du[1] = u[1] - t
        du[2] = u[2] - t
        nothing
    end
    M = [0.0 0.0
         0.0 1.0]
    u₀ = [0.0, 0.0]
    du₀ = [0.0, 0.0]
    tspan = (0.0, pi / 2.0)
    f = ODEFunction(example2, mass_matrix = M)
    prob_mm = ODEProblem(f, u₀, tspan)
    ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)
    example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
    differential_vars = [false, true]
    prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
    chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
    opt = OptimizationOptimisers.Adam(0.1)
    weights = [0.7, 0.2, 0.1]
    points = 200
    alg = NNDAE(chain, OptimizationOptimisers.Adam(0.1),
        strategy = WeightedIntervalTraining(weights, points); autodiff = false)
    sol = solve(prob,
        alg, verbose = false, dt = 1 / 100.0,
        maxiters = 3000, abstol = 1e-10)
    @test reduce(hcat, ground_sol(0:(1 / 100):(pi / 2.0)).u)≈reduce(hcat, sol.u) rtol=1e-2
end
I am getting this error:
WeightedIntervalTraining: Error During Test at /Users/aravinthkrishnan/Desktop/NeuralPDE_Mycopy.jl/test/NNDAE_tests.jl:71
  Got exception outside of a @test
  Default algorithm choices require DifferentialEquations.jl.
  Please specify an algorithm (e.g., `solve(prob, Tsit5())` or
  `init(prob, Tsit5())` for an ODE) or import DifferentialEquations
  directly.
  You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/
  and its associated pages.
  Some of the types have been truncated in the stacktrace for improved reading. To emit complete information
  in the stack trace, evaluate `TruncatedStacktraces.VERBOSE[] = true` and re-run the code.
  Stacktrace:
    [1] __solve(::DAEProblem{Vector{Float64}, Vector{Float64}, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, DAEFunction{false, SciMLBase.FullSpecialize, var"#8#10", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, @Kwargs{}, Vector{Bool}}, ::Nothing, ::Vararg{Any}; default_set::Bool, second_time::Bool, kwargs::@Kwargs{verbose::Bool, dt::Float64, maxiters::Int64, abstol::Float64})
      @ DiffEqBase ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1390
    [2] __solve(prob::DAEProblem{Vector{Float64}, Vector{Float64}, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, DAEFunction{false, SciMLBase.FullSpecialize, var"#8#10", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, @Kwargs{}, Vector{Bool}}, args::NNDAE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(sigmoid_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Optimisers.Adam, Nothing, @Kwargs{}, WeightedIntervalTraining{Float64}}; default_set::Bool, second_time::Bool, kwargs::@Kwargs{verbose::Bool, dt::Float64, maxiters::Int64, abstol::Float64})
      @ DiffEqBase ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1394
    [3] solve_call(_prob::DAEProblem{Vector{Float64}, Vector{Float64}, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, DAEFunction{false, SciMLBase.FullSpecialize, var"#8#10", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, @Kwargs{}, Vector{Bool}}, args::NNDAE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(sigmoid_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Optimisers.Adam, Nothing, @Kwargs{}, WeightedIntervalTraining{Float64}}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{verbose::Bool, dt::Float64, maxiters::Int64, abstol::Float64})
      @ DiffEqBase ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:612
    [4] solve_up(prob::DAEProblem{Vector{Float64}, Vector{Float64}, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, DAEFunction{false, SciMLBase.FullSpecialize, var"#8#10", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, @Kwargs{}, Vector{Bool}}, sensealg::Nothing, u0::Vector{Float64}, p::SciMLBase.NullParameters, args::NNDAE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(sigmoid_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Optimisers.Adam, Nothing, @Kwargs{}, WeightedIntervalTraining{Float64}}; kwargs::@Kwargs{verbose::Bool, dt::Float64, maxiters::Int64, abstol::Float64})
      @ DiffEqBase ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1080
    [5] solve(prob::DAEProblem{Vector{Float64}, Vector{Float64}, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, DAEFunction{false, SciMLBase.FullSpecialize, var"#8#10", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, @Kwargs{}, Vector{Bool}}, args::NNDAE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(sigmoid_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Optimisers.Adam, Nothing, @Kwargs{}, WeightedIntervalTraining{Float64}}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{true}, kwargs::@Kwargs{verbose::Bool, dt::Float64, maxiters::Int64, abstol::Float64})
      @ DiffEqBase ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003
    [6] macro expansion
      @ ~/Desktop/NeuralPDE_Mycopy.jl/test/NNDAE_tests.jl:96 [inlined]
    [7] macro expansion
      @ ~/.julia/juliaup/julia-1.10.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
    [8] top-level scope
      @ ~/Desktop/NeuralPDE_Mycopy.jl/test/NNDAE_tests.jl:72
    [9] eval
      @ ./boot.jl:385 [inlined]
   [10] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String)
      @ Base ./loading.jl:2076
   [11] include_string(m::Module, txt::String, fname::String)
      @ Base ./loading.jl:2086
   [12] invokelatest(::Any, ::Any, ::Vararg{Any}; kwargs::@Kwargs{})
      @ Base ./essentials.jl:892
   [13] invokelatest(::Any, ::Any, ::Vararg{Any})
      @ Base ./essentials.jl:889
   [14] inlineeval(m::Module, code::String, code_line::Int64, code_column::Int64, file::String; softscope::Bool)
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.73.2/scripts/packages/VSCodeServer/src/eval.jl:263
   [15] (::VSCodeServer.var"#67#72"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.73.2/scripts/packages/VSCodeServer/src/eval.jl:181
   [16] withpath(f::VSCodeServer.var"#67#72"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, path::String)
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.73.2/scripts/packages/VSCodeServer/src/repl.jl:274
   [17] (::VSCodeServer.var"#66#71"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.73.2/scripts/packages/VSCodeServer/src/eval.jl:179
   [18] hideprompt(f::VSCodeServer.var"#66#71"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.73.2/scripts/packages/VSCodeServer/src/repl.jl:38
   [19] (::VSCodeServer.var"#65#70"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.73.2/scripts/packages/VSCodeServer/src/eval.jl:150
   [20] with_logstate(f::Function, logstate::Any)
      @ Base.CoreLogging ./logging.jl:515
   [21] with_logger
      @ ./logging.jl:627 [inlined]
   [22] (::VSCodeServer.var"#64#69"{VSCodeServer.ReplRunCodeRequestParams})()
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.73.2/scripts/packages/VSCodeServer/src/eval.jl:255
   [23] #invokelatest#2
      @ ./essentials.jl:892 [inlined]
   [24] invokelatest(::Any)
      @ Base ./essentials.jl:889
   [25] (::VSCodeServer.var"#62#63")()
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.73.2/scripts/packages/VSCodeServer/src/eval.jl:34
Test Summary:            | Error  Total  Time
WeightedIntervalTraining |     1      1  1.3s
ERROR: Some tests did not pass: 0 passed, 0 failed, 1 errored, 0 broken. 

I tried to resolve this by following the suggestion in the error message above:

Got exception outside of a @test
  Default algorithm choices require DifferentialEquations.jl.
  Please specify an algorithm (e.g., `solve(prob, Tsit5())` or
  `init(prob, Tsit5())` for an ODE) or import DifferentialEquations
  directly.

But I still get the following error:

 WeightedIntervalTraining: Error During Test at /Users/aravinthkrishnan/Desktop/NeuralPDE_Mycopy.jl/test/NNDAE_tests.jl:71
  Got exception outside of a @test
  Incompatible problem+solver pairing.
  For example, this can occur if an ODE solver is passed with an SDEProblem.
  Solvers are only capable of handling specific problem types. Please double
  check that the chosen pairing is capable for handling the given problems.
  
  Problem type: DAEProblem
  Solver type: Tsit5
  Problem types compatible with the chosen solver: ODEProblem
  
  
  Some of the types have been truncated in the stacktrace for improved reading. To emit complete information
  in the stack trace, evaluate `TruncatedStacktraces.VERBOSE[] = true` and re-run the code.

So, I am unsure of how to proceed.
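
One possible reading of the first stack trace, not confirmed in this thread: `solve` found no `__solve` method matching the `NNDAE` value that was passed, so it fell through to the generic method that asks for a default algorithm. This can happen when a struct with the same name is defined a second time outside the package (for example in a separate script), because Julia treats it as a different type. The toy modules below illustrate only the dispatch behavior; `ToySolvers`, `Scratch`, `MyAlg`, and `solve_toy` are hypothetical names and not part of any SciML package.

```julia
# Stand-in for a package that owns both the algorithm type and its solve method.
module ToySolvers
struct MyAlg end
# Generic fallback, reached when no more specific method matches.
solve_toy(prob, alg) = "fallback: no method for $(nameof(typeof(alg)))"
# Specialised method, only matches ToySolvers.MyAlg.
solve_toy(prob, alg::MyAlg) = "specialised method"
end

# Stand-in for a scratch script that redefines a struct with the same name.
module Scratch
struct MyAlg end
end

r_pkg = ToySolvers.solve_toy(nothing, ToySolvers.MyAlg())
r_script = ToySolvers.solve_toy(nothing, Scratch.MyAlg())
println(r_pkg)     # uses the specialised method
println(r_script)  # same struct name, different type, so the fallback is used
```

If this is what happened here, passing `Tsit5()` would not help, since the second error message states that `Tsit5` is only compatible with `ODEProblem`.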

@hippyhippohops
Author

To work around the last issue, I tried a different approach: changing the code in ode_solve.jl and dae_solve.jl directly instead of creating a new script with the refactored code. But now I get the following error:
ERROR: invalid redefinition of constant Main.NeuralPDE
Stacktrace:
[1] eval
@ ./boot.jl:385 [inlined]
[2] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String)
@ Base ./loading.jl:2076
[3] include_string(m::Module, txt::String, fname::String)
@ Base ./loading.jl:2086
[4] invokelatest(::Any, ::Any, ::Vararg{Any}; kwargs::@Kwargs{})
@ Base ./essentials.jl:892
[5] invokelatest(::Any, ::Any, ::Vararg{Any})
@ Base ./essentials.jl:889
[6] inlineeval(m::Module, code::String, code_line::Int64, code_column::Int64, file::String; softscope::Bool)
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.73.2/scripts/packages/VSCodeServer/src/eval.jl:263
[7] (::VSCodeServer.var"#67#72"{…})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.73.2/scripts/packages/VSCodeServer/src/eval.jl:181
[8] withpath(f::VSCodeServer.var"#67#72"{…}, path::String)
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.73.2/scripts/packages/VSCodeServer/src/repl.jl:274
[9] (::VSCodeServer.var"#66#71"{…})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.73.2/scripts/packages/VSCodeServer/src/eval.jl:179
[10] hideprompt(f::VSCodeServer.var"#66#71"{…})
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.73.2/scripts/packages/VSCodeServer/src/repl.jl:38
[11] (::VSCodeServer.var"#65#70"{…})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.73.2/scripts/packages/VSCodeServer/src/eval.jl:150
[12] with_logstate(f::Function, logstate::Any)
@ Base.CoreLogging ./logging.jl:515
[13] with_logger
@ ./logging.jl:627 [inlined]
[14] (::VSCodeServer.var"#64#69"{VSCodeServer.ReplRunCodeRequestParams})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.73.2/scripts/packages/VSCodeServer/src/eval.jl:255
[15] #invokelatest#2
@ ./essentials.jl:892 [inlined]
[16] invokelatest(::Any)
@ Base ./essentials.jl:889
[17] (::VSCodeServer.var"#62#63")()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.73.2/scripts/packages/VSCodeServer/src/eval.jl:34
Some type information was truncated. Use show(err) to see complete types.
