Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
castelao committed Nov 14, 2023
1 parent 98d12b8 commit a9111b3
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 36 deletions.
9 changes: 5 additions & 4 deletions src/nf/nf_layer.f90
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ module nf_layer
procedure :: get_params
procedure :: get_gradients
procedure :: set_params
procedure :: set_state
procedure :: init
procedure :: print_info
procedure :: reset

! Specific subroutines for different array ranks
procedure, private :: backward_1d
Expand Down Expand Up @@ -154,9 +154,10 @@ module subroutine set_params(self, params)
!! Parameters of this layer
end subroutine set_params

module subroutine set_state(self, state)
  !! Set (or reset) the internal state of stateful layers.
  !!
  !! If `state` is absent, the state is reset to zero.
  !! Currently this only affects the RNN layer type; it is a
  !! no-op for stateless layers.
  class(layer), intent(inout) :: self
    !! Layer instance
  real, intent(in), optional :: state(:)
    !! New state values; if absent, the state is zeroed
end subroutine set_state

end interface

Expand Down
1 change: 0 additions & 1 deletion src/nf/nf_layer_constructors_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ pure module function input1d(layer_size) result(res)
res % initialized = .true.
end function input1d


pure module function input3d(layer_shape) result(res)
integer, intent(in) :: layer_shape(3)
type(layer) :: res
Expand Down
16 changes: 10 additions & 6 deletions src/nf/nf_layer_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -442,14 +442,18 @@ module subroutine set_params(self, params)

end subroutine set_params

module subroutine set_state(self, state)
  !! Set (or reset to zero) the internal state of this layer.
  !! Only rnn_layer carries state; all other layer types are a no-op.
  ! NOTE: the `module` prefix is required here — a plain `subroutine`
  ! inside a submodule defines a new local procedure and leaves the
  ! separate module interface declared in nf_layer unimplemented.
  class(layer), intent(inout) :: self
  real, intent(in), optional :: state(:)

  select type (this_layer => self % p)
    type is (rnn_layer)
      if (present(state)) then
        this_layer % state = state
      else
        ! Absent state means reset: zero the hidden state.
        this_layer % state = 0
      end if
  end select

end subroutine set_state

end submodule nf_layer_submodule
8 changes: 0 additions & 8 deletions src/nf/nf_network.f90
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ module nf_network
procedure :: get_params
procedure :: print_info
procedure :: set_params
procedure :: reset
procedure :: train
procedure :: update

Expand Down Expand Up @@ -224,13 +223,6 @@ module subroutine update(self, optimizer, batch_size)
!! Set to `size(input_data, dim=2)` for a batch gradient descent.
end subroutine update

module subroutine reset(self)
!! Reset network state
!!
!! Currently only affect RNN layer type
class(network), intent(in out) :: self
end subroutine reset

end interface

end module nf_network
17 changes: 3 additions & 14 deletions src/nf/nf_network_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -676,23 +676,12 @@ module subroutine update(self, optimizer, batch_size)
type is(conv2d_layer)
this_layer % dw = 0
this_layer % db = 0
end select
end do

end subroutine update

module subroutine reset(self)
class(network), intent(in out) :: self
integer :: n, num_layers

num_layers = size(self % layers)
do n = 2, num_layers
select type(this_layer => self % layers(n) % p)
type is(rnn_layer)
call self % layers(n) % reset()
this_layer % dw = 0
this_layer % db = 0
end select
end do

end subroutine reset
end subroutine update

end submodule nf_network_submodule
16 changes: 13 additions & 3 deletions src/nf/nf_rnn_layer.f90
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ module nf_rnn_layer

type, extends(base_layer) :: rnn_layer

!! Concrete implementation of a dense (fully-connected) layer type
!! Concrete implementation of an RNN (fully-connected) layer type

integer :: input_size
integer :: output_size
Expand All @@ -40,7 +40,7 @@ module nf_rnn_layer
procedure :: get_params
procedure :: init
procedure :: set_params
procedure :: reset
procedure :: set_state

end type rnn_layer

Expand Down Expand Up @@ -94,7 +94,7 @@ pure module function get_params(self) result(params)
!! Return the parameters (weights and biases) of this layer.
!! The parameters are ordered as weights first, biases second.
class(rnn_layer), intent(in) :: self
!! Dense layer instance
!! RNN layer instance
real, allocatable :: params(:)
!! Parameters of this layer
end function get_params
Expand Down Expand Up @@ -137,4 +137,14 @@ end subroutine reset

end interface

subroutine set_state(self, state)
  !! Set the hidden state of the RNN layer, or zero it when no
  !! state is provided.
  ! NOTE(review): this procedure is bound via `procedure :: set_state`,
  ! so the passed-object dummy must be polymorphic — class(rnn_layer),
  ! not type(rnn_layer) — or the binding is invalid.
  ! NOTE(review): this definition appears after `end interface`; confirm
  ! it sits after a `contains` statement in the module — TODO verify.
  class(rnn_layer), intent(inout) :: self
    !! RNN layer instance
  real, intent(in), optional :: state(:)
    !! New state values; if absent, the state is zeroed
  if (present(state)) then
    self % state = state
  else
    self % state = 0
  end if
end subroutine set_state

end module nf_rnn_layer

0 comments on commit a9111b3

Please sign in to comment.