From 12cd77d82546674e4951b900e6c8076d94e397b4 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Thu, 30 Jan 2025 16:50:42 -0800 Subject: [PATCH] Don't create nested thunks when accumulating (#1555) * Don't create nested thunks when accumulating Otherwise, it's too easy to create massive types that freeze compilation and blow the stack. * bump version --- Project.toml | 2 +- src/lib/lib.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 6a75ef54a..f1fc5a957 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.7.3" +version = "0.7.4" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 90e596d95..b209fb02e 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -43,9 +43,9 @@ accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x accum(x::Nothing, y::AbstractThunk) = y accum(x::AbstractThunk, y::Nothing) = x -accum(x, y::AbstractThunk) = @thunk(accum(x, unthunk(y))) -accum(x::AbstractThunk, y) = @thunk(accum(unthunk(x), y)) -accum(x::AbstractThunk, y::AbstractThunk) = @thunk(accum(unthunk(x), unthunk(y))) +accum(x, y::AbstractThunk) = accum(x, unthunk(y)) +accum(x::AbstractThunk, y) = accum(unthunk(x), y) +accum(x::AbstractThunk, y::AbstractThunk) = accum(unthunk(x), unthunk(y)) # Core functions @_adjoint_keepthunks deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)