diff --git a/CMakeLists.txt b/CMakeLists.txt index 490f7ff1..50a0f208 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,8 @@ add_library(neural-fortran src/nf/nf_reshape_layer_submodule.f90 src/nf/io/nf_io_binary.f90 src/nf/io/nf_io_binary_submodule.f90 + src/nf/nf_dropout_layer.f90 + src/nf/nf_dropout_layer_submodule.f90 ) target_link_libraries(neural-fortran PRIVATE) diff --git a/README.md b/README.md index 7e3a4445..75a66491 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714). |------------|------------------|------------------------|----------------------|--------------|---------------| | Input | `input` | n/a | 1, 3 | n/a | n/a | | Dense (fully-connected) | `dense` | `input1d`, `flatten` | 1 | ✅ | ✅ | +| Dropout | `dropout` | Any | 1 | ✅ | ✅ | | Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅(*) | | Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅ | | Flatten | `flatten` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 | ✅ | ✅ | diff --git a/fpm.toml b/fpm.toml index 5f68f8f6..3df459fb 100644 --- a/fpm.toml +++ b/fpm.toml @@ -1,6 +1,9 @@ name = "neural-fortran" -version = "0.18.0" +version = "0.19.0" license = "MIT" author = "Milan Curcic" maintainer = "milancurcic@hey.com" -copyright = "Copyright 2018-2024, neural-fortran contributors" +copyright = "Copyright 2018-2025, neural-fortran contributors" + +[preprocess] +[preprocess.cpp] diff --git a/src/nf.f90 b/src/nf.f90 index b97d9e62..d477f1b5 100644 --- a/src/nf.f90 +++ b/src/nf.f90 @@ -3,7 +3,7 @@ module nf use nf_datasets_mnist, only: label_digits, load_mnist use nf_layer, only: layer use nf_layer_constructors, only: & - conv2d, dense, flatten, input, maxpool2d, reshape + conv2d, dense, dropout, flatten, input, maxpool2d, reshape use nf_loss, only: mse, quadratic use nf_metrics, only: corr, maxabs use nf_network, only: network diff --git a/src/nf/nf_dropout_layer.f90 b/src/nf/nf_dropout_layer.f90 new file mode 100644 index 00000000..bffca5f0 --- /dev/null +++ b/src/nf/nf_dropout_layer.f90 @@ -0,0 +1,84 @@ +module nf_dropout_layer + + !! This module provides the concrete dropout layer type. + !! It is used internally by the layer type. + !! It is not intended to be used directly by the user. + + use nf_base_layer, only: base_layer + + implicit none + + private + public :: dropout_layer + + type, extends(base_layer) :: dropout_layer + !! Concrete implementation of a dropout layer type + + integer :: input_size = 0 + + real, allocatable :: output(:) + real, allocatable :: gradient(:) + real, allocatable :: mask(:) ! binary mask for dropout + + real :: dropout_rate ! probability of dropping a neuron + real :: scale ! scale factor to preserve the input sum + logical :: training = .false. ! set to .true. in training mode + + contains + + procedure :: backward + procedure :: forward + procedure :: init + + end type dropout_layer + + interface dropout_layer + module function dropout_layer_cons(rate, training) & + result(res) + !! This function returns the `dropout_layer` instance. + real, intent(in) :: rate + !! Dropout rate + logical, intent(in), optional :: training + !! Training mode (default .false.) + type(dropout_layer) :: res + !! dropout_layer instance + end function dropout_layer_cons + end interface dropout_layer + + interface + + pure module subroutine backward(self, input, gradient) + !! Apply the backward gradient descent pass. + !! Only weight and bias gradients are updated in this subroutine, + !! while the weights and biases themselves are untouched. + class(dropout_layer), intent(in out) :: self + !! Dropout layer instance + real, intent(in) :: input(:) + !! Input from the previous layer + real, intent(in) :: gradient(:) + !! Gradient from the next layer + end subroutine backward + + module subroutine forward(self, input) + !! Propagate forward the layer. + !! Calling this subroutine updates the values of a few data components + !! of `dropout_layer` that are needed for the backward pass. + class(dropout_layer), intent(in out) :: self + !! Dense layer instance + real, intent(in) :: input(:) + !! Input from the previous layer + end subroutine forward + + module subroutine init(self, input_shape) + !! Initialize the layer data structures. + !! + !! This is a deferred procedure from the `base_layer` abstract type. + class(dropout_layer), intent(in out) :: self + !! Dropout layer instance + integer, intent(in) :: input_shape(:) + !! Shape of the input layer + end subroutine init + + end interface + +end module nf_dropout_layer diff --git a/src/nf/nf_dropout_layer_submodule.f90 b/src/nf/nf_dropout_layer_submodule.f90 new file mode 100644 index 00000000..6e7e35a0 --- /dev/null +++ b/src/nf/nf_dropout_layer_submodule.f90 @@ -0,0 +1,78 @@ +submodule (nf_dropout_layer) nf_dropout_layer_submodule + !! This submodule implements the procedures defined in the + !! nf_dropout_layer module. + +contains + + module function dropout_layer_cons(rate, training) result(res) + real, intent(in) :: rate + logical, intent(in), optional :: training + type(dropout_layer) :: res + res % dropout_rate = rate + if (present(training)) res % training = training + end function dropout_layer_cons + + + module subroutine init(self, input_shape) + class(dropout_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + + self % input_size = input_shape(1) + + ! Allocate arrays + allocate(self % output(self % input_size)) + allocate(self % gradient(self % input_size)) + allocate(self % mask(self % input_size)) + + ! Initialize arrays + self % output = 0 + self % gradient = 0 + self % mask = 1 ! Default mask is all ones (no dropout) + + end subroutine init + + + module subroutine forward(self, input) + class(dropout_layer), intent(in out) :: self + real, intent(in) :: input(:) + + ! Generate random mask for dropout, training mode only + if (self % training) then + + call random_number(self % mask) + where (self % mask < self % dropout_rate) + self % mask = 0 + elsewhere + self % mask = 1 + end where + + ! Scale factor to preserve the input sum + self % scale = sum(input) / sum(input * self % mask) + + ! Apply dropout mask + self % output = input * self % mask * self % scale + + else + ! In inference mode, we don't apply dropout; simply pass through the input + self % output = input + + end if + + end subroutine forward + + + pure module subroutine backward(self, input, gradient) + class(dropout_layer), intent(in out) :: self + real, intent(in) :: input(:) + real, intent(in) :: gradient(:) + + if (self % training) then + ! Backpropagate gradient through dropout mask + self % gradient = gradient * self % mask * self % scale + else + ! In inference mode, pass through the gradient unchanged + self % gradient = gradient + end if + end subroutine backward + +end submodule nf_dropout_layer_submodule \ No newline at end of file diff --git a/src/nf/nf_layer.f90 b/src/nf/nf_layer.f90 index ca5e9606..18e8f76a 100644 --- a/src/nf/nf_layer.f90 +++ b/src/nf/nf_layer.f90 @@ -76,7 +76,7 @@ end subroutine backward_3d interface - pure module subroutine forward(self, input) + module subroutine forward(self, input) !! Apply a forward pass on the layer. !! This changes the internal state of the layer. !! This is normally called internally by the `network % forward` diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90 index 309be6e4..fcc49342 100644 --- a/src/nf/nf_layer_constructors.f90 +++ b/src/nf/nf_layer_constructors.f90 @@ -8,7 +8,7 @@ module nf_layer_constructors implicit none private - public :: conv2d, dense, flatten, input, maxpool2d, reshape + public :: conv2d, dense, flatten, input, maxpool2d, reshape, dropout interface input @@ -85,6 +85,26 @@ module function dense(layer_size, activation) result(res) !! Resulting layer instance end function dense + module function dropout(rate, training) result(res) + !! Create a dropout layer with a given dropout rate. + !! + !! This layer is for randomly disabling neurons during training. + !! + !! Example: + !! + !! ``` + !! use nf, only :: dropout, layer + !! type(layer) :: dropout_layer + !! dropout_layer = dropout(rate=0.5) + !! ``` + real, intent(in) :: rate + !! Dropout rate - fraction of neurons to randomly disable during training + logical, intent(in), optional :: training + !! Training mode (default .false.) + type(layer) :: res + !! Resulting layer instance + end function dropout + module function flatten() result(res) !! Flatten (3-d -> 1-d) layer constructor. !! @@ -166,6 +186,7 @@ module function reshape(output_shape) result(res) !! Resulting layer instance end function reshape + end interface end module nf_layer_constructors diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90 index 234b20b1..5203497d 100644 --- a/src/nf/nf_layer_constructors_submodule.f90 +++ b/src/nf/nf_layer_constructors_submodule.f90 @@ -3,6 +3,7 @@ use nf_layer, only: layer use nf_conv2d_layer, only: conv2d_layer use nf_dense_layer, only: dense_layer + use nf_dropout_layer, only: dropout_layer use nf_flatten_layer, only: flatten_layer use nf_input1d_layer, only: input1d_layer use nf_input3d_layer, only: input3d_layer @@ -63,6 +64,15 @@ module function dense(layer_size, activation) result(res) end function dense + module function dropout(rate, training) result(res) + real, intent(in) :: rate + logical, intent(in), optional :: training + type(layer) :: res + res % name = 'dropout' + allocate(res % p, source=dropout_layer(rate, training)) + end function dropout + + module function flatten() result(res) type(layer) :: res res % name = 'flatten' @@ -91,6 +101,7 @@ module function input3d(layer_shape) result(res) res % initialized = .true. end function input3d + module function maxpool2d(pool_size, stride) result(res) integer, intent(in) :: pool_size integer, intent(in), optional :: stride @@ -119,6 +130,7 @@ module function maxpool2d(pool_size, stride) result(res) end function maxpool2d + module function reshape(output_shape) result(res) integer, intent(in) :: output_shape(:) type(layer) :: res diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index c672581a..8bf94ea5 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -3,6 +3,7 @@ use iso_fortran_env, only: stderr => error_unit use nf_conv2d_layer, only: conv2d_layer use nf_dense_layer, only: dense_layer + use nf_dropout_layer, only: dropout_layer use nf_flatten_layer, only: flatten_layer use nf_input1d_layer, only: input1d_layer use nf_input3d_layer, only: input3d_layer @@ -24,12 +25,14 @@ pure module subroutine backward_1d(self, previous, gradient) type is(dense_layer) - ! Upstream layers permitted: input1d, dense, flatten + ! Upstream layers permitted: input1d, dense, dropout, flatten select type(prev_layer => previous % p) type is(input1d_layer) call this_layer % backward(prev_layer % output, gradient) type is(dense_layer) call this_layer % backward(prev_layer % output, gradient) + type is(dropout_layer) + call this_layer % backward(prev_layer % output, gradient) type is(flatten_layer) call this_layer % backward(prev_layer % output, gradient) end select @@ -106,7 +109,7 @@ pure module subroutine backward_3d(self, previous, gradient) end subroutine backward_3d - pure module subroutine forward(self, input) + module subroutine forward(self, input) implicit none class(layer), intent(in out) :: self class(layer), intent(in) :: input @@ -115,6 +118,20 @@ pure module subroutine forward(self, input) type is(dense_layer) + ! Upstream layers permitted: input1d, dense, dropout, flatten + select type(prev_layer => input % p) + type is(input1d_layer) + call this_layer % forward(prev_layer % output) + type is(dense_layer) + call this_layer % forward(prev_layer % output) + type is(dropout_layer) + call this_layer % forward(prev_layer % output) + type is(flatten_layer) + call this_layer % forward(prev_layer % output) + end select + + type is(dropout_layer) + ! Upstream layers permitted: input1d, dense, flatten select type(prev_layer => input % p) type is(input1d_layer) @@ -240,15 +257,17 @@ impure elemental module subroutine init(self, input) call this_layer % init(input % layer_shape) end select - ! The shape of conv2d, maxpool2d, or flatten layers is not known + ! The shape of conv2d, dropout, flatten, or maxpool2d layers is not known ! until we receive an input layer. select type(this_layer => self % p) type is(conv2d_layer) self % layer_shape = shape(this_layer % output) - type is(maxpool2d_layer) + type is(dropout_layer) self % layer_shape = shape(this_layer % output) type is(flatten_layer) self % layer_shape = shape(this_layer % output) + type is(maxpool2d_layer) + self % layer_shape = shape(this_layer % output) end select self % input_layer_shape = input % layer_shape @@ -284,6 +303,8 @@ elemental module function get_num_params(self) result(num_params) num_params = 0 type is (dense_layer) num_params = this_layer % get_num_params() + type is (dropout_layer) + num_params = 0 type is (conv2d_layer) num_params = this_layer % get_num_params() type is (maxpool2d_layer) @@ -309,6 +330,8 @@ module function get_params(self) result(params) ! No parameters to get. type is (dense_layer) params = this_layer % get_params() + type is (dropout_layer) + ! No parameters to get. type is (conv2d_layer) params = this_layer % get_params() type is (maxpool2d_layer) @@ -334,6 +357,8 @@ module function get_gradients(self) result(gradients) ! No gradients to get. type is (dense_layer) gradients = this_layer % get_gradients() + type is (dropout_layer) + ! No gradients to get. type is (conv2d_layer) gradients = this_layer % get_gradients() type is (maxpool2d_layer) @@ -381,6 +406,11 @@ module subroutine set_params(self, params) type is (dense_layer) call this_layer % set_params(params) + type is (dropout_layer) + ! No parameters to set. + write(stderr, '(a)') 'Warning: calling set_params() ' & + // 'on a zero-parameter layer; nothing to do.' + type is (conv2d_layer) call this_layer % set_params(params) diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index 140c9226..f28a98e9 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -2,6 +2,7 @@ use nf_conv2d_layer, only: conv2d_layer use nf_dense_layer, only: dense_layer + use nf_dropout_layer, only: dropout_layer use nf_flatten_layer, only: flatten_layer use nf_input1d_layer, only: input1d_layer use nf_input3d_layer, only: input3d_layer @@ -134,6 +135,8 @@ module subroutine backward(self, output, loss) select type(next_layer => self % layers(n + 1) % p) type is(dense_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(dropout_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) type is(conv2d_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) type is(flatten_layer) @@ -218,15 +221,26 @@ module function predict_1d(self, input) result(res) class(network), intent(in out) :: self real, intent(in) :: input(:) real, allocatable :: res(:) - integer :: num_layers + integer :: n, num_layers num_layers = size(self % layers) + ! predict is run in inference mode only; + ! set all dropout layers' training mode to false. + do n = 2, num_layers + select type(this_layer => self % layers(n) % p) + type is(dropout_layer) + this_layer % training = .false. + end select + end do + call self % forward(input) select type(output_layer => self % layers(num_layers) % p) type is(dense_layer) res = output_layer % output + type is(dropout_layer) + res = output_layer % output type is(flatten_layer) res = output_layer % output class default @@ -240,10 +254,19 @@ module function predict_3d(self, input) result(res) class(network), intent(in out) :: self real, intent(in) :: input(:,:,:) real, allocatable :: res(:) - integer :: num_layers + integer :: n, num_layers num_layers = size(self % layers) + ! predict is run in inference mode only; + ! set all dropout layers' training mode to false. + do n = 2, num_layers + select type(this_layer => self % layers(n) % p) + type is(dropout_layer) + this_layer % training = .false. + end select + end do + call self % forward(input) select type(output_layer => self % layers(num_layers) % p) @@ -265,12 +288,21 @@ module function predict_batch_1d(self, input) result(res) class(network), intent(in out) :: self real, intent(in) :: input(:,:) real, allocatable :: res(:,:) - integer :: i, batch_size, num_layers, output_size + integer :: i, n, batch_size, num_layers, output_size num_layers = size(self % layers) batch_size = size(input, dim=rank(input)) output_size = product(self % layers(num_layers) % layer_shape) + ! predict is run in inference mode only; + ! set all dropout layers' training mode to false. + do n = 2, num_layers + select type(this_layer => self % layers(n) % p) + type is(dropout_layer) + this_layer % training = .false. + end select + end do + allocate(res(output_size, batch_size)) batch: do i = 1, size(res, dim=2) @@ -295,12 +327,21 @@ module function predict_batch_3d(self, input) result(res) class(network), intent(in out) :: self real, intent(in) :: input(:,:,:,:) real, allocatable :: res(:,:) - integer :: i, batch_size, num_layers, output_size + integer :: i, n, batch_size, num_layers, output_size num_layers = size(self % layers) batch_size = size(input, dim=rank(input)) output_size = product(self % layers(num_layers) % layer_shape) + ! predict is run in inference mode only; + ! set all dropout layers' training mode to false. + do n = 2, num_layers + select type(this_layer => self % layers(n) % p) + type is(dropout_layer) + this_layer % training = .false. + end select + end do + allocate(res(output_size, batch_size)) batch: do i = 1, batch_size @@ -434,6 +475,14 @@ module subroutine train(self, input_data, output_data, batch_size, & self % loss = quadratic() end if + ! Set all dropout layers' training mode to true. + do n = 2, size(self % layers) + select type(this_layer => self % layers(n) % p) + type is(dropout_layer) + this_layer % training = .true. + end select + end do + dataset_size = size(output_data, dim=2) epoch_loop: do n = 1, epochs diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index bfd3538a..108dee66 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,6 +1,7 @@ foreach(execid input1d_layer input3d_layer + dropout_layer parametric_activation dense_layer conv2d_layer diff --git a/test/test_conv2d_network.f90 b/test/test_conv2d_network.f90 index 47c9a819..28dce100 100644 --- a/test/test_conv2d_network.f90 +++ b/test/test_conv2d_network.f90 @@ -39,7 +39,7 @@ program test_conv2d_network type(network) :: cnn real :: y(1) - real :: tolerance = 1e-5 + real :: tolerance = 1e-4 integer :: n integer, parameter :: num_iterations = 1000 @@ -76,7 +76,7 @@ program test_conv2d_network type(network) :: cnn real :: x(1, 8, 8) real :: y(1) - real :: tolerance = 1e-5 + real :: tolerance = 1e-4 integer :: n integer, parameter :: num_iterations = 1000 @@ -111,7 +111,7 @@ program test_conv2d_network type(network) :: cnn real :: x(1, 12, 12) real :: y(9) - real :: tolerance = 1e-5 + real :: tolerance = 1e-4 integer :: n integer, parameter :: num_iterations = 5000 diff --git a/test/test_dropout_layer.f90 b/test/test_dropout_layer.f90 new file mode 100644 index 00000000..b0ad0664 --- /dev/null +++ b/test/test_dropout_layer.f90 @@ -0,0 +1,130 @@ +program test_dropout_layer + use iso_fortran_env, only: stderr => error_unit + use nf, only: dropout, input, layer, network + use nf_dropout_layer, only: dropout_layer + type(layer) :: layer1 + type(network) :: net + integer :: input_size + + logical :: ok = .true. + + layer1 = dropout(0.5) + + if (.not. layer1 % name == 'dropout') then + ok = .false. + write(stderr, '(a)') 'dropout layer has its name set correctly.. failed' + end if + + ! Dropout on its own is not initialized and its arrays not allocated. + select type(layer1_p => layer1 % p) + type is(dropout_layer) + + if (layer1_p % dropout_rate /= 0.5) then + ok = .false. + write(stderr, '(a)') 'dropout layer dropout rate should be 0.5.. failed' + end if + + if (layer1_p % training) then + ok = .false. + write(stderr, '(a)') 'dropout layer default training mode should be false.. failed' + end if + + if (layer1_p % input_size /= 0) then + print *, 'input_size: ', layer1_p % input_size + ok = .false. + write(stderr, '(a)') 'dropout layer size should be zero.. failed' + end if + + if (allocated(layer1_p % output)) then + ok = .false. + write(stderr, '(a)') 'dropout layer output array should not be allocated.. failed' + end if + + end select + + ! Test setting training mode explicitly. + layer1 = dropout(0.5, training=.true.) + select type(layer1_p => layer1 % p) + type is(dropout_layer) + if (.not. layer1_p % training) then + ok = .false. + write(stderr, '(a)') 'dropout layer training mode should be true.. failed' + end if + end select + + layer1 = dropout(0.5, training=.false.) + select type(layer1_p => layer1 % p) + type is(dropout_layer) + if (layer1_p % training) then + ok = .false. + write(stderr, '(a)') 'dropout layer training mode should be false.. failed' + end if + end select + + ! Now we're gonna initialize a minimal network with an input layer and a + ! dropout that follows and we'll check that the dropout layer has expected + ! state. + input_size = 10 + net = network([ & + input(input_size), & + dropout(0.5) & + ]) + + select type(layer1_p => net % layers(1) % p) + type is(dropout_layer) + if (layer1_p % input_size /= input_size) then + ok = .false. + write(stderr, '(a)') 'dropout layer input size should be the same as the input layer.. failed' + end if + + if (.not. allocated(layer1_p % output)) then + ok = .false. + write(stderr, '(a)') 'dropout layer output array should be allocated.. failed' + end if + + if (.not. allocated(layer1_p % gradient)) then + ok = .false. + write(stderr, '(a)') 'dropout layer gradient array should be allocated.. failed' + end if + + if (.not. allocated(layer1_p % mask)) then + ok = .false. + write(stderr, '(a)') 'dropout layer mask array should be allocated.. failed' + end if + + end select + + ! Now we're gonna run the forward pass and check that the dropout indeed + ! drops according to the requested dropout rate. + forward_pass: block + real :: input_data(5) + real :: output_data(size(input_data)) + integer :: n + + net = network([ & + input(size(input_data)), & + dropout(0.5) & + ]) + + call random_number(input_data) + do n = 1, 10000 + output_data = net % predict(input_data) + ! Check that sum of output matches sum of input within small tolerance + if (abs(sum(output_data) - sum(input_data)) > 1e-6) then + ok = .false. + exit + end if + end do + if (.not. ok) then + write(stderr, '(a)') 'dropout layer output sum should match input sum within tolerance.. failed' + end if + end block forward_pass + + if (ok) then + print '(a)', 'test_dropout_layer: All tests passed.' + else + write(stderr, '(a)') 'test_dropout_layer: One or more tests failed.' + stop 1 + end if + +end program test_dropout_layer