From e40883b9e208201c97e5cba4f39c6bf14e715eab Mon Sep 17 00:00:00 2001 From: milancurcic Date: Wed, 22 Jan 2025 15:07:51 -0500 Subject: [PATCH 01/10] First stab at dropout; conflict with base type TODO --- src/nf/nf_dropout_layer.f90 | 87 ++++++++++++++++++++++ src/nf/nf_dropout_layer_submodule.f90 | 65 ++++++++++++++++ src/nf/nf_layer_constructors_submodule.f90 | 1 + test/test_dropout_layer.f90 | 20 +++++ 4 files changed, 173 insertions(+) create mode 100644 src/nf/nf_dropout_layer.f90 create mode 100644 src/nf/nf_dropout_layer_submodule.f90 create mode 100644 test/test_dropout_layer.f90 diff --git a/src/nf/nf_dropout_layer.f90 b/src/nf/nf_dropout_layer.f90 new file mode 100644 index 0000000..cab1ac3 --- /dev/null +++ b/src/nf/nf_dropout_layer.f90 @@ -0,0 +1,87 @@ +module nf_dropout_layer + + !! This module provides the concrete dropout layer type. + !! It is used internally by the layer type. + !! It is not intended to be used directly by the user. + + use nf_activation, only: activation_function + use nf_base_layer, only: base_layer + + implicit none + + private + public :: dropout_layer + + type, extends(base_layer) :: dropout_layer + + !! Concrete implementation of a dropout layer type + + integer :: input_size + integer :: output_size + + real, allocatable :: output(:) + real, allocatable :: gradient(:) + real :: dropout_rate ! probability of dropping a neuron + real, allocatable :: mask(:) ! binary mask for dropout + + class(activation_function), allocatable :: activation + + contains + + procedure :: backward + procedure :: forward + procedure :: init + + end type dropout_layer + + interface dropout_layer + module function dropout_layer_cons(rate) & + result(res) + !! This function returns the `dropout_layer` instance. + real, intent(in) :: rate + !! Dropout rate + type(dropout_layer) :: res + !! dropout_layer instance + end function dropout_layer_cons + end interface dropout_layer + + interface + + pure module subroutine backward(self, input, gradient) + !! Apply the backward gradient descent pass. + !! Only weight and bias gradients are updated in this subroutine, + !! while the weights and biases themselves are untouched. + class(dropout_layer), intent(in out) :: self + !! Dropout layer instance + real, intent(in) :: input(:) + !! Input from the previous layer + real, intent(in) :: gradient(:) + !! Gradient from the next layer + end subroutine backward + + pure module subroutine forward(self, input) + !! Propagate forward the layer. + !! Calling this subroutine updates the values of a few data components + !! of `dropout_layer` that are needed for the backward pass. + class(dropout_layer), intent(in out) :: self + !! Dense layer instance + real, intent(in) :: input(:) + !! Input from the previous layer + end subroutine forward + + module subroutine init(self, input_shape, training) + !! Initialize the layer data structures. + !! + !! This is a deferred procedure from the `base_layer` abstract type. + class(dropout_layer), intent(in out) :: self + !! Dropout layer instance + integer, intent(in) :: input_shape(:) + !! Shape of the input layer + logical, intent(in) :: training + !! Whether the layer is in training mode (.true. == dropping out neurons) + !! or in inference mode (.false. 
== doing nothing) + end subroutine init + + end interface + +end module nf_dropout_layer diff --git a/src/nf/nf_dropout_layer_submodule.f90 b/src/nf/nf_dropout_layer_submodule.f90 new file mode 100644 index 0000000..02610a6 --- /dev/null +++ b/src/nf/nf_dropout_layer_submodule.f90 @@ -0,0 +1,65 @@ +submodule (nf_dropout_layer) nf_dropout_layer_submodule + !! This submodule implements the procedures defined in the + !! nf_dropout_layer module. + +contains + + module function dropout_layer_cons(rate) result(res) + real, intent(in) :: rate + type(dropout_layer) :: res + + ! Initialize dropout rate + res % dropout_rate = rate + end function dropout_layer_cons + + module subroutine init(self, input_shape, training) + class(dropout_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + logical, intent(in) :: training + + ! Set input and output sizes (dropout preserves dimensions) + self % input_size = input_shape(1) + self % output_size = input_shape(1) + + ! Allocate arrays + if (allocated(self % output)) deallocate(self % output) + if (allocated(self % gradient)) deallocate(self % gradient) + if (allocated(self % mask)) deallocate(self % mask) + + allocate(self % output(self % output_size)) + allocate(self % gradient(self % input_size)) + allocate(self % mask(self % input_size)) + + ! Initialize arrays to zero + self % output = 0.0 + self % gradient = 0.0 + self % mask = 1.0 ! Default mask is all ones (no dropout) + end subroutine init + + pure module subroutine forward(self, input) + class(dropout_layer), intent(in out) :: self + real, intent(in) :: input(:) + real :: rand_vals(size(input)) + + ! Generate random mask for dropout + call random_number(rand_vals) + where (rand_vals < self % dropout_rate) + self % mask = 0 + elsewhere + self % mask = 1 / (1 - self % dropout_rate) ! Scale to preserve expected value + end where + + ! Apply dropout mask + self % output = input * self % mask + end subroutine forward + + pure module subroutine backward(self, input, gradient) + class(dropout_layer), intent(in out) :: self + real, intent(in) :: input(:) + real, intent(in) :: gradient(:) + + ! Backpropagate gradient through dropout mask + self % gradient = gradient * self % mask + end subroutine backward + +end submodule nf_dropout_layer_submodule \ No newline at end of file diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90 index 234b20b..86cef8a 100644 --- a/src/nf/nf_layer_constructors_submodule.f90 +++ b/src/nf/nf_layer_constructors_submodule.f90 @@ -3,6 +3,7 @@ use nf_layer, only: layer use nf_conv2d_layer, only: conv2d_layer use nf_dense_layer, only: dense_layer + use nf_dropout_layer, only: dropout_layer use nf_flatten_layer, only: flatten_layer use nf_input1d_layer, only: input1d_layer use nf_input3d_layer, only: input3d_layer diff --git a/test/test_dropout_layer.f90 b/test/test_dropout_layer.f90 new file mode 100644 index 0000000..c0f37d8 --- /dev/null +++ b/test/test_dropout_layer.f90 @@ -0,0 +1,20 @@ +program test_dropout_layer + use iso_fortran_env, only: stderr => error_unit + use nf, only: dropout, layer + type(layer) :: layer1 + + layer1 = dropout(0.5) + + if (.not. layer1 % name == 'dropout') then + ok = .false. + write(stderr, '(a)') 'dropout layer has its name set correctly.. failed' + end if + + if (ok) then + print '(a)', 'test_dropout_layer: All tests passed.' + else + write(stderr, '(a)') 'test_dropout_layer: One or more tests failed.' 
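+    ! Exit with a nonzero status so the test driver reports the failure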
+ stop 1 + end if + +end program test_dropout_layer From 37aa7a5db719c7cad1190653085469ab1ae700d5 Mon Sep 17 00:00:00 2001 From: milancurcic Date: Thu, 23 Jan 2025 12:53:14 -0500 Subject: [PATCH 02/10] Partial dropout integration --- src/nf.f90 | 2 +- src/nf/nf_dropout_layer.f90 | 13 +++------- src/nf/nf_dropout_layer_submodule.f90 | 30 ++++++++++------------ src/nf/nf_layer_constructors.f90 | 21 ++++++++++++++- src/nf/nf_layer_constructors_submodule.f90 | 10 ++++++++ test/test_dropout_layer.f90 | 1 + 6 files changed, 49 insertions(+), 28 deletions(-) diff --git a/src/nf.f90 b/src/nf.f90 index b97d9e6..d477f1b 100644 --- a/src/nf.f90 +++ b/src/nf.f90 @@ -3,7 +3,7 @@ module nf use nf_datasets_mnist, only: label_digits, load_mnist use nf_layer, only: layer use nf_layer_constructors, only: & - conv2d, dense, flatten, input, maxpool2d, reshape + conv2d, dense, dropout, flatten, input, maxpool2d, reshape use nf_loss, only: mse, quadratic use nf_metrics, only: corr, maxabs use nf_network, only: network diff --git a/src/nf/nf_dropout_layer.f90 b/src/nf/nf_dropout_layer.f90 index cab1ac3..9489ad6 100644 --- a/src/nf/nf_dropout_layer.f90 +++ b/src/nf/nf_dropout_layer.f90 @@ -4,7 +4,6 @@ module nf_dropout_layer !! It is used internally by the layer type. !! It is not intended to be used directly by the user. - use nf_activation, only: activation_function use nf_base_layer, only: base_layer implicit none @@ -17,14 +16,13 @@ module nf_dropout_layer !! Concrete implementation of a dropout layer type integer :: input_size - integer :: output_size real, allocatable :: output(:) real, allocatable :: gradient(:) - real :: dropout_rate ! probability of dropping a neuron real, allocatable :: mask(:) ! binary mask for dropout - class(activation_function), allocatable :: activation + real :: dropout_rate ! probability of dropping a neuron + logical :: training = .true. contains @@ -59,7 +57,7 @@ pure module subroutine backward(self, input, gradient) !! Gradient from the next layer end subroutine backward - pure module subroutine forward(self, input) + module subroutine forward(self, input) !! Propagate forward the layer. !! Calling this subroutine updates the values of a few data components !! of `dropout_layer` that are needed for the backward pass. @@ -69,7 +67,7 @@ pure module subroutine forward(self, input) !! Input from the previous layer end subroutine forward - module subroutine init(self, input_shape, training) + module subroutine init(self, input_shape) !! Initialize the layer data structures. !! !! This is a deferred procedure from the `base_layer` abstract type. @@ -77,9 +75,6 @@ module subroutine init(self, input_shape, training) !! Dropout layer instance integer, intent(in) :: input_shape(:) !! Shape of the input layer - logical, intent(in) :: training - !! Whether the layer is in training mode (.true. == dropping out neurons) - !! or in inference mode (.false. == doing nothing) end subroutine init end interface diff --git a/src/nf/nf_dropout_layer_submodule.f90 b/src/nf/nf_dropout_layer_submodule.f90 index 02610a6..e3a3cf2 100644 --- a/src/nf/nf_dropout_layer_submodule.f90 +++ b/src/nf/nf_dropout_layer_submodule.f90 @@ -12,38 +12,33 @@ module function dropout_layer_cons(rate) result(res) res % dropout_rate = rate end function dropout_layer_cons - module subroutine init(self, input_shape, training) + + module subroutine init(self, input_shape) class(dropout_layer), intent(in out) :: self integer, intent(in) :: input_shape(:) - logical, intent(in) :: training - ! 
Set input and output sizes (dropout preserves dimensions) self % input_size = input_shape(1) - self % output_size = input_shape(1) ! Allocate arrays - if (allocated(self % output)) deallocate(self % output) - if (allocated(self % gradient)) deallocate(self % gradient) - if (allocated(self % mask)) deallocate(self % mask) - - allocate(self % output(self % output_size)) + allocate(self % output(self % input_size)) allocate(self % gradient(self % input_size)) allocate(self % mask(self % input_size)) - ! Initialize arrays to zero - self % output = 0.0 - self % gradient = 0.0 - self % mask = 1.0 ! Default mask is all ones (no dropout) + ! Initialize arrays + self % output = 0 + self % gradient = 0 + self % mask = 1 ! Default mask is all ones (no dropout) + end subroutine init - pure module subroutine forward(self, input) + + module subroutine forward(self, input) class(dropout_layer), intent(in out) :: self real, intent(in) :: input(:) - real :: rand_vals(size(input)) ! Generate random mask for dropout - call random_number(rand_vals) - where (rand_vals < self % dropout_rate) + call random_number(self % mask) + where (self % mask < self % dropout_rate) self % mask = 0 elsewhere self % mask = 1 / (1 - self % dropout_rate) ! Scale to preserve expected value @@ -53,6 +48,7 @@ pure module subroutine forward(self, input) self % output = input * self % mask end subroutine forward + pure module subroutine backward(self, input, gradient) class(dropout_layer), intent(in out) :: self real, intent(in) :: input(:) diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90 index 309be6e..24fc7e6 100644 --- a/src/nf/nf_layer_constructors.f90 +++ b/src/nf/nf_layer_constructors.f90 @@ -8,7 +8,7 @@ module nf_layer_constructors implicit none private - public :: conv2d, dense, flatten, input, maxpool2d, reshape + public :: conv2d, dense, flatten, input, maxpool2d, reshape, dropout interface input @@ -85,6 +85,24 @@ module function dense(layer_size, activation) result(res) !! Resulting layer instance end function dense + module function dropout(rate) result(res) + !! Create a dropout layer with a given dropout rate. + !! + !! This layer is for randomly disabling neurons during training. + !! + !! Example: + !! + !! ``` + !! use nf, only :: dropout, layer + !! type(layer) :: dropout_layer + !! dropout_layer = dropout(rate=0.5) + !! ``` + real, intent(in) :: rate + !! Dropout rate - fraction of neurons to randomly disable during training + type(layer) :: res + !! Resulting layer instance + end function dropout + module function flatten() result(res) !! Flatten (3-d -> 1-d) layer constructor. !! @@ -166,6 +184,7 @@ module function reshape(output_shape) result(res) !! Resulting layer instance end function reshape + end interface end module nf_layer_constructors diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90 index 86cef8a..09c79e9 100644 --- a/src/nf/nf_layer_constructors_submodule.f90 +++ b/src/nf/nf_layer_constructors_submodule.f90 @@ -64,6 +64,14 @@ module function dense(layer_size, activation) result(res) end function dense + module function dropout(rate) result(res) + real, intent(in) :: rate + type(layer) :: res + res % name = 'dropout' + allocate(res % p, source=dropout_layer(rate)) + end function dropout + + module function flatten() result(res) type(layer) :: res res % name = 'flatten' @@ -92,6 +100,7 @@ module function input3d(layer_shape) result(res) res % initialized = .true. 
end function input3d + module function maxpool2d(pool_size, stride) result(res) integer, intent(in) :: pool_size integer, intent(in), optional :: stride @@ -120,6 +129,7 @@ module function maxpool2d(pool_size, stride) result(res) end function maxpool2d + module function reshape(output_shape) result(res) integer, intent(in) :: output_shape(:) type(layer) :: res diff --git a/test/test_dropout_layer.f90 b/test/test_dropout_layer.f90 index c0f37d8..3424730 100644 --- a/test/test_dropout_layer.f90 +++ b/test/test_dropout_layer.f90 @@ -2,6 +2,7 @@ program test_dropout_layer use iso_fortran_env, only: stderr => error_unit use nf, only: dropout, layer type(layer) :: layer1 + logical :: ok = .true. layer1 = dropout(0.5) From 820b081cb8af5e2cfa078283863025bbeaee8574 Mon Sep 17 00:00:00 2001 From: milancurcic Date: Thu, 23 Jan 2025 13:10:28 -0500 Subject: [PATCH 03/10] Test uninitialized dropout layer --- src/nf/nf_dropout_layer.f90 | 3 +-- src/nf/nf_layer_submodule.f90 | 7 +++++-- test/test_dropout_layer.f90 | 18 ++++++++++++++++++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/nf/nf_dropout_layer.f90 b/src/nf/nf_dropout_layer.f90 index 9489ad6..6761327 100644 --- a/src/nf/nf_dropout_layer.f90 +++ b/src/nf/nf_dropout_layer.f90 @@ -12,10 +12,9 @@ module nf_dropout_layer public :: dropout_layer type, extends(base_layer) :: dropout_layer - !! Concrete implementation of a dropout layer type - integer :: input_size + integer :: input_size = 0 real, allocatable :: output(:) real, allocatable :: gradient(:) diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index c672581..8064797 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -3,6 +3,7 @@ use iso_fortran_env, only: stderr => error_unit use nf_conv2d_layer, only: conv2d_layer use nf_dense_layer, only: dense_layer + use nf_dropout_layer, only: dropout_layer use nf_flatten_layer, only: flatten_layer use nf_input1d_layer, only: input1d_layer use nf_input3d_layer, only: input3d_layer @@ -240,15 +241,17 @@ impure elemental module subroutine init(self, input) call this_layer % init(input % layer_shape) end select - ! The shape of conv2d, maxpool2d, or flatten layers is not known + ! The shape of conv2d, dropout, flatten, or maxpool2d layers is not known ! until we receive an input layer. select type(this_layer => self % p) type is(conv2d_layer) self % layer_shape = shape(this_layer % output) - type is(maxpool2d_layer) + type is(dropout_layer) self % layer_shape = shape(this_layer % output) type is(flatten_layer) self % layer_shape = shape(this_layer % output) + type is(maxpool2d_layer) + self % layer_shape = shape(this_layer % output) end select self % input_layer_shape = input % layer_shape diff --git a/test/test_dropout_layer.f90 b/test/test_dropout_layer.f90 index 3424730..b46bd30 100644 --- a/test/test_dropout_layer.f90 +++ b/test/test_dropout_layer.f90 @@ -1,6 +1,7 @@ program test_dropout_layer use iso_fortran_env, only: stderr => error_unit use nf, only: dropout, layer + use nf_dropout_layer, only: dropout_layer type(layer) :: layer1 logical :: ok = .true. @@ -11,6 +12,23 @@ program test_dropout_layer write(stderr, '(a)') 'dropout layer has its name set correctly.. failed' end if + ! Dropout on its own is not initialized and its arrays not allocated. + select type(layer1_p => layer1 % p) + type is(dropout_layer) + + if (layer1_p % input_size /= 0) then + print *, 'input_size: ', layer1_p % input_size + ok = .false. + write(stderr, '(a)') 'dropout layer size should be zero.. 
failed' + end if + + if (allocated(layer1_p % output)) then + ok = .false. + write(stderr, '(a)') 'dropout layer output array should not be allocated.. failed' + end if + + end select + if (ok) then print '(a)', 'test_dropout_layer: All tests passed.' else From 75ef184c73c7d659df72b8cfd063d608f673bb19 Mon Sep 17 00:00:00 2001 From: milancurcic Date: Thu, 23 Jan 2025 13:16:05 -0500 Subject: [PATCH 04/10] Test dropout state that follows an input layer --- test/test_dropout_layer.f90 | 38 ++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/test/test_dropout_layer.f90 b/test/test_dropout_layer.f90 index b46bd30..5d092cb 100644 --- a/test/test_dropout_layer.f90 +++ b/test/test_dropout_layer.f90 @@ -1,8 +1,11 @@ program test_dropout_layer use iso_fortran_env, only: stderr => error_unit - use nf, only: dropout, layer + use nf, only: dropout, input, layer, network use nf_dropout_layer, only: dropout_layer type(layer) :: layer1 + type(network) :: net + integer :: input_size + logical :: ok = .true. layer1 = dropout(0.5) @@ -29,6 +32,39 @@ program test_dropout_layer end select + ! Now we're gonna initialize a minimal network with an input layer and a + ! dropout that follows and we'll check that the dropout layer has expected + ! state. + input_size = 10 + net = network([ & + input(input_size), & + dropout(0.5) & + ]) + + select type(layer1_p => net % layers(1) % p) + type is(dropout_layer) + if (layer1_p % input_size /= input_size) then + ok = .false. + write(stderr, '(a)') 'dropout layer input size should be the same as the input layer.. failed' + end if + + if (.not. allocated(layer1_p % output)) then + ok = .false. + write(stderr, '(a)') 'dropout layer output array should be allocated.. failed' + end if + + if (.not. allocated(layer1_p % gradient)) then + ok = .false. + write(stderr, '(a)') 'dropout layer gradient array should be allocated.. failed' + end if + + if (.not. allocated(layer1_p % mask)) then + ok = .false. + write(stderr, '(a)') 'dropout layer mask array should be allocated.. failed' + end if + + end select + if (ok) then print '(a)', 'test_dropout_layer: All tests passed.' else From 796ae74bd308b74c4231772c9f0ab505862c006f Mon Sep 17 00:00:00 2001 From: milancurcic Date: Thu, 23 Jan 2025 14:29:21 -0500 Subject: [PATCH 05/10] Enable forward pass for dropout; backward pass TODO --- src/nf/nf_dropout_layer_submodule.f90 | 9 ++++++++- src/nf/nf_layer.f90 | 2 +- src/nf/nf_layer_submodule.f90 | 14 +++++++++++++- src/nf/nf_network_submodule.f90 | 3 +++ test/test_dropout_layer.f90 | 26 ++++++++++++++++++++++++++ 5 files changed, 51 insertions(+), 3 deletions(-) diff --git a/src/nf/nf_dropout_layer_submodule.f90 b/src/nf/nf_dropout_layer_submodule.f90 index e3a3cf2..5a022a2 100644 --- a/src/nf/nf_dropout_layer_submodule.f90 +++ b/src/nf/nf_dropout_layer_submodule.f90 @@ -35,17 +35,24 @@ end subroutine init module subroutine forward(self, input) class(dropout_layer), intent(in out) :: self real, intent(in) :: input(:) + real :: scale ! Generate random mask for dropout call random_number(self % mask) where (self % mask < self % dropout_rate) self % mask = 0 elsewhere - self % mask = 1 / (1 - self % dropout_rate) ! Scale to preserve expected value + self % mask = 1 end where ! Apply dropout mask self % output = input * self % mask + + ! 
Scale output and mask to preserve the input sum + scale = sum(input) / sum(self % output) + self % output = self % output * scale + self % mask = self % mask * scale + end subroutine forward diff --git a/src/nf/nf_layer.f90 b/src/nf/nf_layer.f90 index ca5e960..18e8f76 100644 --- a/src/nf/nf_layer.f90 +++ b/src/nf/nf_layer.f90 @@ -76,7 +76,7 @@ end subroutine backward_3d interface - pure module subroutine forward(self, input) + module subroutine forward(self, input) !! Apply a forward pass on the layer. !! This changes the internal state of the layer. !! This is normally called internally by the `network % forward` diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index 8064797..d44ef17 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -107,7 +107,7 @@ pure module subroutine backward_3d(self, previous, gradient) end subroutine backward_3d - pure module subroutine forward(self, input) + module subroutine forward(self, input) implicit none class(layer), intent(in out) :: self class(layer), intent(in) :: input @@ -126,6 +126,18 @@ pure module subroutine forward(self, input) call this_layer % forward(prev_layer % output) end select + type is(dropout_layer) + + ! Upstream layers permitted: input1d, dense, flatten + select type(prev_layer => input % p) + type is(input1d_layer) + call this_layer % forward(prev_layer % output) + type is(dense_layer) + call this_layer % forward(prev_layer % output) + type is(flatten_layer) + call this_layer % forward(prev_layer % output) + end select + type is(conv2d_layer) ! Upstream layers permitted: input3d, conv2d, maxpool2d, reshape3d diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index 140c922..6aaaec3 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -2,6 +2,7 @@ use nf_conv2d_layer, only: conv2d_layer use nf_dense_layer, only: dense_layer + use nf_dropout_layer, only: dropout_layer use nf_flatten_layer, only: flatten_layer use nf_input1d_layer, only: input1d_layer use nf_input3d_layer, only: input3d_layer @@ -227,6 +228,8 @@ module function predict_1d(self, input) result(res) select type(output_layer => self % layers(num_layers) % p) type is(dense_layer) res = output_layer % output + type is(dropout_layer) + res = output_layer % output type is(flatten_layer) res = output_layer % output class default diff --git a/test/test_dropout_layer.f90 b/test/test_dropout_layer.f90 index 5d092cb..b9b4b2a 100644 --- a/test/test_dropout_layer.f90 +++ b/test/test_dropout_layer.f90 @@ -65,6 +65,32 @@ program test_dropout_layer end select + ! Now we're gonna run the forward pass and check that the dropout indeed + ! drops according to the requested dropout rate. + forward_pass: block + real :: input_data(5) + real :: output_data(size(input_data)) + integer :: n + + net = network([ & + input(size(input_data)), & + dropout(0.5) & + ]) + + call random_number(input_data) + do n = 1, 10000 + output_data = net % predict(input_data) + ! Check that sum of output matches sum of input within small tolerance + if (abs(sum(output_data) - sum(input_data)) > 1e-5) then + ok = .false. + exit + end if + end do + if (.not. ok) then + write(stderr, '(a)') 'dropout layer output sum should match input sum within 1% tolerance.. failed' + end if + end block forward_pass + if (ok) then print '(a)', 'test_dropout_layer: All tests passed.' 
else From b04d44725a329158b24ebe4363302583308dc77b Mon Sep 17 00:00:00 2001 From: milancurcic Date: Thu, 23 Jan 2025 14:31:20 -0500 Subject: [PATCH 06/10] Version bump and add dropout to the features table --- README.md | 1 + fpm.toml | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7e3a444..75a6649 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714). |------------|------------------|------------------------|----------------------|--------------|---------------| | Input | `input` | n/a | 1, 3 | n/a | n/a | | Dense (fully-connected) | `dense` | `input1d`, `flatten` | 1 | ✅ | ✅ | +| Dropout | `dropout` | Any | 1 | ✅ | ✅ | | Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅(*) | | Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅ | | Flatten | `flatten` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 | ✅ | ✅ | diff --git a/fpm.toml b/fpm.toml index 5f68f8f..368812c 100644 --- a/fpm.toml +++ b/fpm.toml @@ -1,6 +1,6 @@ name = "neural-fortran" -version = "0.18.0" +version = "0.19.0" license = "MIT" author = "Milan Curcic" maintainer = "milancurcic@hey.com" -copyright = "Copyright 2018-2024, neural-fortran contributors" +copyright = "Copyright 2018-2025, neural-fortran contributors" From 544b23a2911cdaccae87f15fed75a6b9cf2037d8 Mon Sep 17 00:00:00 2001 From: milancurcic Date: Thu, 23 Jan 2025 17:11:18 -0500 Subject: [PATCH 07/10] Add dropout to CMake --- CMakeLists.txt | 2 ++ test/CMakeLists.txt | 1 + test/test_dropout_layer.f90 | 2 +- 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 490f7ff..50a0f20 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,8 @@ add_library(neural-fortran src/nf/nf_reshape_layer_submodule.f90 src/nf/io/nf_io_binary.f90 src/nf/io/nf_io_binary_submodule.f90 + src/nf/nf_dropout_layer.f90 + src/nf/nf_dropout_layer_submodule.f90 ) target_link_libraries(neural-fortran PRIVATE) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index bfd3538..108dee6 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,6 +1,7 @@ foreach(execid input1d_layer input3d_layer + dropout_layer parametric_activation dense_layer conv2d_layer diff --git a/test/test_dropout_layer.f90 b/test/test_dropout_layer.f90 index b9b4b2a..9ed7b86 100644 --- a/test/test_dropout_layer.f90 +++ b/test/test_dropout_layer.f90 @@ -87,7 +87,7 @@ program test_dropout_layer end if end do if (.not. ok) then - write(stderr, '(a)') 'dropout layer output sum should match input sum within 1% tolerance.. failed' + write(stderr, '(a)') 'dropout layer output sum should match input sum within tolerance.. 
failed' end if end block forward_pass From 56dbd52377b96622c0caf53fe2a9e79d14c7ef84 Mon Sep 17 00:00:00 2001 From: milancurcic Date: Fri, 24 Jan 2025 10:49:00 -0500 Subject: [PATCH 08/10] Enable preprocessing in fpm.toml (needed with recent versions of fpm) --- fpm.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fpm.toml b/fpm.toml index 368812c..3df459f 100644 --- a/fpm.toml +++ b/fpm.toml @@ -4,3 +4,6 @@ license = "MIT" author = "Milan Curcic" maintainer = "milancurcic@hey.com" copyright = "Copyright 2018-2025, neural-fortran contributors" + +[preprocess] +[preprocess.cpp] From 3b5cc27f04867e24f64aa3df9aa0bbf494b1e85e Mon Sep 17 00:00:00 2001 From: milancurcic Date: Fri, 24 Jan 2025 10:57:27 -0500 Subject: [PATCH 09/10] Small change in scale implementation --- src/nf/nf_dropout_layer.f90 | 1 + src/nf/nf_dropout_layer_submodule.f90 | 13 +++++-------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/nf/nf_dropout_layer.f90 b/src/nf/nf_dropout_layer.f90 index 6761327..0f557d6 100644 --- a/src/nf/nf_dropout_layer.f90 +++ b/src/nf/nf_dropout_layer.f90 @@ -21,6 +21,7 @@ module nf_dropout_layer real, allocatable :: mask(:) ! binary mask for dropout real :: dropout_rate ! probability of dropping a neuron + real :: scale ! scale factor to preserve the input sum logical :: training = .true. contains diff --git a/src/nf/nf_dropout_layer_submodule.f90 b/src/nf/nf_dropout_layer_submodule.f90 index 5a022a2..568cbf2 100644 --- a/src/nf/nf_dropout_layer_submodule.f90 +++ b/src/nf/nf_dropout_layer_submodule.f90 @@ -35,7 +35,6 @@ end subroutine init module subroutine forward(self, input) class(dropout_layer), intent(in out) :: self real, intent(in) :: input(:) - real :: scale ! Generate random mask for dropout call random_number(self % mask) @@ -45,13 +44,11 @@ module subroutine forward(self, input) self % mask = 1 end where - ! Apply dropout mask - self % output = input * self % mask + ! Scale factor to preserve the input sum + self % scale = sum(input) / sum(self % output) ! scale == 1/P(keep) - ! Scale output and mask to preserve the input sum - scale = sum(input) / sum(self % output) - self % output = self % output * scale - self % mask = self % mask * scale + ! Apply dropout mask + self % output = input * self % mask * self % scale end subroutine forward @@ -62,7 +59,7 @@ pure module subroutine backward(self, input, gradient) real, intent(in) :: gradient(:) ! Backpropagate gradient through dropout mask - self % gradient = gradient * self % mask + self % gradient = gradient * self % mask * self % scale end subroutine backward end submodule nf_dropout_layer_submodule \ No newline at end of file From 703f8023a175a584005105c3be9a6051a695edad Mon Sep 17 00:00:00 2001 From: milancurcic Date: Fri, 24 Jan 2025 11:14:42 -0500 Subject: [PATCH 10/10] Integration of backward pass for dropout --- src/nf/nf_layer_submodule.f90 | 19 +++++++++++++++++-- src/nf/nf_network_submodule.f90 | 2 ++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index d44ef17..69b40d2 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -25,12 +25,14 @@ pure module subroutine backward_1d(self, previous, gradient) type is(dense_layer) - ! Upstream layers permitted: input1d, dense, flatten + ! 
Upstream layers permitted: input1d, dense, dropout, flatten select type(prev_layer => previous % p) type is(input1d_layer) call this_layer % backward(prev_layer % output, gradient) type is(dense_layer) call this_layer % backward(prev_layer % output, gradient) + type is(dropout_layer) + call this_layer % backward(prev_layer % output, gradient) type is(flatten_layer) call this_layer % backward(prev_layer % output, gradient) end select @@ -116,12 +118,14 @@ module subroutine forward(self, input) type is(dense_layer) - ! Upstream layers permitted: input1d, dense, flatten + ! Upstream layers permitted: input1d, dense, dropout, flatten select type(prev_layer => input % p) type is(input1d_layer) call this_layer % forward(prev_layer % output) type is(dense_layer) call this_layer % forward(prev_layer % output) + type is(dropout_layer) + call this_layer % forward(prev_layer % output) type is(flatten_layer) call this_layer % forward(prev_layer % output) end select @@ -299,6 +303,8 @@ elemental module function get_num_params(self) result(num_params) num_params = 0 type is (dense_layer) num_params = this_layer % get_num_params() + type is (dropout_layer) + num_params = size(this_layer % mask) type is (conv2d_layer) num_params = this_layer % get_num_params() type is (maxpool2d_layer) @@ -324,6 +330,8 @@ module function get_params(self) result(params) ! No parameters to get. type is (dense_layer) params = this_layer % get_params() + type is (dropout_layer) + ! No parameters to get. type is (conv2d_layer) params = this_layer % get_params() type is (maxpool2d_layer) @@ -349,6 +357,8 @@ module function get_gradients(self) result(gradients) ! No gradients to get. type is (dense_layer) gradients = this_layer % get_gradients() + type is (dropout_layer) + ! No gradients to get. type is (conv2d_layer) gradients = this_layer % get_gradients() type is (maxpool2d_layer) @@ -396,6 +406,11 @@ module subroutine set_params(self, params) type is (dense_layer) call this_layer % set_params(params) + type is (dropout_layer) + ! No parameters to set. + write(stderr, '(a)') 'Warning: calling set_params() ' & + // 'on a zero-parameter layer; nothing to do.' + type is (conv2d_layer) call this_layer % set_params(params) diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index 6aaaec3..0b076b9 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -135,6 +135,8 @@ module subroutine backward(self, output, loss) select type(next_layer => self % layers(n + 1) % p) type is(dense_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(dropout_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) type is(conv2d_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) type is(flatten_layer)
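
Usage note (illustrative sketch, not part of the patches above): after this series, a dropout layer can sit between any two 1-d layers. The sketch below uses only public names that appear in the diffs (input, dropout, dense, network, predict); the program name, layer sizes, and the 0.5 rate are arbitrary choices for illustration.

    program dropout_example
      use nf, only: dense, dropout, input, network
      implicit none
      type(network) :: net
      real :: x(20), y(5)

      ! Dropout preserves the shape of its input, so it slots between any
      ! two 1-d layers; rate=0.5 zeroes roughly half of the activations per
      ! forward pass and rescales the rest to preserve the input sum.
      net = network([ &
        input(20), &
        dropout(0.5), &
        dense(5) &
      ])

      call random_number(x)
      y = net % predict(x)
      print *, y

    end program dropout_example

Note that within the patches shown here the random mask is applied on every forward pass (the training component is declared but not yet consulted), so repeated predict calls on the same input produce different outputs.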