Support of optional loss derivative function in the DT network
Vandenplas, Jeremie committed Apr 16, 2024
2 parents f5fc636 + e1cb7dd commit ecd2979
Showing 4 changed files with 39 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/nf.f90
@@ -4,6 +4,7 @@ module nf
use nf_layer, only: layer
use nf_layer_constructors, only: &
conv2d, dense, flatten, input, maxpool2d, reshape
use nf_loss, only: quadratic_derivative, mse_derivative
use nf_network, only: network
use nf_optimizers, only: sgd, rmsprop, adam, adagrad
use nf_activation, only: activation_function, elu, exponential, &
10 changes: 10 additions & 0 deletions src/nf/nf_loss.f90
@@ -7,11 +7,21 @@ module nf_loss
implicit none

private
public :: loss_derivative_interface
public :: mse, mse_derivative
public :: quadratic, quadratic_derivative

interface

pure function loss_derivative_interface(true, predicted) result(res)
real, intent(in) :: true(:)
!! True values, i.e. labels from training datasets
real, intent(in) :: predicted(:)
!! Values predicted by the network
real :: res(size(true))
!! Resulting loss derivative values
end function loss_derivative_interface

pure module function quadratic(true, predicted) result(res)
!! Quadratic loss function:
!!
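The new loss_derivative_interface defines the shape any user-supplied loss derivative must have. As an illustration only (this function is not part of the commit), a mean-absolute-error derivative conforming to the interface could look like:

  pure function mae_derivative(true, predicted) result(res)
    !! Hypothetical example, not part of this commit: derivative of the
    !! mean absolute error with respect to the predictions, written to
    !! conform to loss_derivative_interface.
    real, intent(in) :: true(:)
      !! True values, i.e. labels from training datasets
    real, intent(in) :: predicted(:)
      !! Values predicted by the network
    real :: res(size(true))
      !! Derivative of the loss with respect to each prediction
    res = sign(1., predicted - true) / size(true)
  end function mae_derivative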
10 changes: 8 additions & 2 deletions src/nf/nf_network.f90
@@ -3,6 +3,7 @@ module nf_network
!! This module provides the network type to create new models.

use nf_layer, only: layer
use nf_loss, only: loss_derivative_interface
use nf_optimizers, only: optimizer_base_type

implicit none
@@ -15,6 +16,8 @@ module nf_network
type(layer), allocatable :: layers(:)
class(optimizer_base_type), allocatable :: optimizer

procedure(loss_derivative_interface), pointer, nopass :: loss_derivative => null()

contains

procedure :: backward
@@ -26,6 +29,7 @@ module nf_network
procedure :: train
procedure :: update


procedure, private :: forward_1d
procedure, private :: forward_3d
procedure, private :: predict_1d
@@ -185,7 +189,7 @@ module subroutine print_info(self)
end subroutine print_info

module subroutine train(self, input_data, output_data, batch_size, &
epochs, optimizer)
epochs, optimizer, loss_derivative)
class(network), intent(in out) :: self
!! Network instance
real, intent(in) :: input_data(:,:)
@@ -204,9 +208,10 @@ module subroutine train(self, input_data, output_data, batch_size, &
!! Number of epochs to run
class(optimizer_base_type), intent(in), optional :: optimizer
!! Optimizer instance to use. If not provided, the default is sgd().
procedure(loss_derivative_interface), optional :: loss_derivative
end subroutine train

module subroutine update(self, optimizer, batch_size)
module subroutine update(self, optimizer, batch_size, loss_derivative)
!! Update the weights and biases on all layers using the stored
!! gradients (from backward passes) on those layers, and flush those
!! same stored gradients to zero.
@@ -221,6 +226,7 @@ module subroutine update(self, optimizer, batch_size)
!! Batch size to use.
!! Set to 1 for a pure stochastic gradient descent (default).
!! Set to `size(input_data, dim=2)` for a batch gradient descent.
procedure(loss_derivative_interface), optional :: loss_derivative
end subroutine update

end interface
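With loss_derivative now an optional dummy argument of train (and update), the derivative can be chosen at the call site; when it is omitted, the implementation below falls back to quadratic_derivative. A minimal usage sketch, assuming the usual nf constructors and purely illustrative layer sizes and data:

  program train_with_custom_loss
    use nf, only: dense, input, network, mse_derivative
    implicit none
    type(network) :: net
    real :: x(3, 100), y(1, 100)

    ! Illustrative data: targets are the sum of the inputs.
    call random_number(x)
    y(1,:) = sum(x, dim=1)

    net = network([input(3), dense(5), dense(1)])

    ! Pass the exported mse_derivative; omitting loss_derivative selects
    ! the quadratic_derivative default set in the train implementation.
    call net % train(x, y, batch_size=10, epochs=5, loss_derivative=mse_derivative)
  end program train_with_custom_loss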
24 changes: 20 additions & 4 deletions src/nf/nf_network_submodule.f90
@@ -11,7 +11,7 @@
use nf_keras, only: get_keras_h5_layers, keras_layer
use nf_layer, only: layer
use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
use nf_loss, only: loss_derivative => mse_derivative
use nf_loss, only: quadratic_derivative, mse_derivative
use nf_optimizers, only: optimizer_base_type, sgd
use nf_parallel, only: tile_indices
use nf_activation, only: activation_function, &
@@ -297,7 +297,7 @@ pure module subroutine backward(self, output)
type is(dense_layer)
call self % layers(n) % backward( &
self % layers(n - 1), &
loss_derivative(output, this_layer % output) &
self % loss_derivative(output, this_layer % output) &
)
end select
else
@@ -540,13 +540,14 @@ end subroutine set_params


module subroutine train(self, input_data, output_data, batch_size, &
epochs, optimizer)
epochs, optimizer, loss_derivative)
class(network), intent(in out) :: self
real, intent(in) :: input_data(:,:)
real, intent(in) :: output_data(:,:)
integer, intent(in) :: batch_size
integer, intent(in) :: epochs
class(optimizer_base_type), intent(in), optional :: optimizer
procedure(loss_derivative_interface), optional :: loss_derivative
class(optimizer_base_type), allocatable :: optimizer_

real :: pos
@@ -565,6 +566,12 @@ module subroutine train(self, input_data, output_data, batch_size, &

call self % optimizer % init(self % get_num_params())

if (present(loss_derivative)) then
self % loss_derivative => loss_derivative
else
self % loss_derivative => quadratic_derivative
end if

dataset_size = size(output_data, dim=2)

epoch_loop: do n = 1, epochs
@@ -597,12 +604,13 @@ module subroutine train(self, input_data, output_data, batch_size, &
end subroutine train


module subroutine update(self, optimizer, batch_size)
module subroutine update(self, optimizer, batch_size, loss_derivative)
class(network), intent(in out) :: self
class(optimizer_base_type), intent(in), optional :: optimizer
integer, intent(in), optional :: batch_size
class(optimizer_base_type), allocatable :: optimizer_
integer :: batch_size_
procedure(loss_derivative_interface), optional :: loss_derivative
real, allocatable :: params(:)
integer :: n

@@ -622,6 +630,14 @@ module subroutine update(self, optimizer, batch_size)
call self % optimizer % init(self % get_num_params())
end if

if (.not.associated(self % loss_derivative)) then
if (present(loss_derivative)) then
self % loss_derivative => loss_derivative
else
self % loss_derivative => quadratic_derivative
end if
endif

if (present(batch_size)) then
batch_size_ = batch_size
else
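The submodule stores the chosen function in the loss_derivative procedure pointer on the network type and points it at quadratic_derivative when none is supplied. A self-contained sketch of that pointer-with-fallback pattern, independent of the library (all names here are illustrative, not part of the commit):

  module loss_pointer_demo
    !! Standalone illustration of the pattern used above: a procedure
    !! pointer component with a default fallback.
    implicit none

    abstract interface
      pure function deriv_interface(true, predicted) result(res)
        real, intent(in) :: true(:), predicted(:)
        real :: res(size(true))
      end function deriv_interface
    end interface

    type :: model
      procedure(deriv_interface), pointer, nopass :: loss_derivative => null()
    contains
      procedure :: set_loss_derivative
    end type model

  contains

    pure function default_derivative(true, predicted) result(res)
      !! Plays the role of quadratic_derivative in the commit.
      real, intent(in) :: true(:), predicted(:)
      real :: res(size(true))
      res = predicted - true
    end function default_derivative

    subroutine set_loss_derivative(self, loss_derivative)
      !! Point at the user-supplied derivative if present, otherwise at
      !! the default, mirroring the logic in train and update above.
      class(model), intent(in out) :: self
      procedure(deriv_interface), optional :: loss_derivative
      if (present(loss_derivative)) then
        self % loss_derivative => loss_derivative
      else
        self % loss_derivative => default_derivative
      end if
    end subroutine set_loss_derivative

  end module loss_pointer_demo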
