diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index d646ddbf6..aeb85ad6f 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -184,7 +184,7 @@ function automatic_sensealg_choice( # QuadratureAdjoint skips all p calculations until the end # So it's the fastest when there are no parameters QuadratureAdjoint(autodiff = false, autojacvec = vjp) - elseif prob isa ODEProblem + elseif prob isa ODEProblem && !(vjp isa TrackerVJP) GaussAdjoint(autodiff = false, autojacvec = vjp) else InterpolatingAdjoint(autodiff = false, autojacvec = vjp) @@ -194,7 +194,7 @@ function automatic_sensealg_choice( # QuadratureAdjoint skips all p calculations until the end # So it's the fastest when there are no parameters QuadratureAdjoint(autojacvec = vjp) - elseif prob isa ODEProblem + elseif prob isa ODEProblem && !(vjp isa TrackerVJP) GaussAdjoint(autojacvec = vjp) else InterpolatingAdjoint(autojacvec = vjp) @@ -209,7 +209,7 @@ function automatic_sensealg_choice( # If reverse-mode isn't working, just fallback to numerical vjps if p === nothing || p === SciMLBase.NullParameters() QuadratureAdjoint(autodiff = false, autojacvec = vjp) - elseif prob isa ODEProblem + elseif prob isa ODEProblem && !(vjp isa TrackerVJP) GaussAdjoint(autodiff = false, autojacvec = vjp) else InterpolatingAdjoint(autodiff = false, autojacvec = vjp) @@ -217,7 +217,7 @@ function automatic_sensealg_choice( else if p === nothing || p === SciMLBase.NullParameters() QuadratureAdjoint(autojacvec = vjp) - elseif prob isa ODEProblem + elseif prob isa ODEProblem && !(vjp isa TrackerVJP) GaussAdjoint(autojacvec = vjp) else InterpolatingAdjoint(autojacvec = vjp)