From fecc37cd1075d4f1ec580e58616ca7fb88ccf8a1 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 11 Feb 2023 13:56:56 -0500 Subject: [PATCH] give meaning to batchsize=0 --- src/eachobs.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/eachobs.jl b/src/eachobs.jl index 99c96c4..4932df7 100644 --- a/src/eachobs.jl +++ b/src/eachobs.jl @@ -56,7 +56,8 @@ The original data is preserved in the `data` field of the DataLoader. - `data`: The data to be iterated over. The data type has to be supported by [`numobs`](@ref) and [`getobs`](@ref). - `batchsize`: If less than 0, iterates over individual observations. - Otherwise, each iteration (except possibly the last) yields a mini-batch + If 0, then one mini-batch containing all `numobs(x)` observations. + If larger than 0, each iteration (except possibly the last) yields a mini-batch containing `batchsize` observations. Default `1`. - `buffer`: If `buffer=true` and supported by the type of `data`, a buffer will be allocated and reused for memory efficiency. @@ -149,6 +150,7 @@ function DataLoader( if !(collate ∈ (Val(nothing), Val(true), Val(false))) throw(ArgumentError("`collate` must be one of `nothing`, `true` or `false`.")) end + batchsize = batchsize == 0 ? numobs(data) : batchsize return DataLoader(data, batchsize, buffer, partial, shuffle, parallel, collate, rng) end