Addition of the MSE loss function #173

Closed · wants to merge 12 commits
1 change: 1 addition & 0 deletions src/nf.f90
@@ -4,6 +4,7 @@ module nf
use nf_layer, only: layer
use nf_layer_constructors, only: &
conv2d, dense, flatten, input, maxpool2d, reshape
use nf_loss, only: quadratic_derivative, mse_derivative
use nf_network, only: network
use nf_optimizers, only: sgd, rmsprop, adam, adagrad
use nf_activation, only: activation_function, elu, exponential, &
39 changes: 39 additions & 0 deletions src/nf/nf_loss.f90
@@ -7,10 +7,23 @@ module nf_loss
implicit none

private
public :: loss_derivative_interface
public :: mse, mse_derivative
public :: quadratic, quadratic_derivative

interface

pure function loss_derivative_interface(true, predicted) result(res)
!! Interface for the first derivative of a loss function
!!
real, intent(in) :: true(:)
!! True values, i.e. labels from training datasets
real, intent(in) :: predicted(:)
!! Values predicted by the network
real :: res(size(true))
!! Resulting loss values
end function loss_derivative_interface

pure module function quadratic(true, predicted) result(res)
!! Quadratic loss function:
!!
@@ -37,6 +50,32 @@ pure module function quadratic_derivative(true, predicted) result(res)
!! Resulting loss values
end function quadratic_derivative

pure module function mse(true, predicted) result(res)
!! Mean squared error loss function:
!!
!! L = (predicted - true)**2 / n
!!
!! where n is the size of the input arrays.
!!
real, intent(in) :: true(:)
!! True values, i.e. labels from training datasets
real, intent(in) :: predicted(:)
!! Values predicted by the network
real :: res(size(true))
!! Resulting loss values
end function mse

pure module function mse_derivative(true, predicted) result(res)
!! First derivative of the mean squared error loss function:
!!
!! L' = 2 * (predicted - true) / n
!!
real, intent(in) :: true(:)
!! True values, i.e. labels from training datasets
real, intent(in) :: predicted(:)
!! Values predicted by the network
real :: res(size(true))
!! Resulting loss values
end function mse_derivative

end interface

end module nf_loss
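
Note for reviewers: any user-supplied loss derivative only needs to conform to `loss_derivative_interface` above. A minimal sketch of what that looks like, using a hypothetical mean absolute error derivative (the name `mae_derivative` and the sign-based gradient are illustrative, not part of this PR):

pure function mae_derivative(true, predicted) result(res)
  !! Hypothetical first derivative of a mean absolute error loss,
  !! shaped to match loss_derivative_interface.
  real, intent(in) :: true(:)
  real, intent(in) :: predicted(:)
  real :: res(size(true))
  ! sign(1., x) is +1 or -1 elementwise, matching d|x|/dx away from 0
  res = sign(1., predicted - true) / size(true)
end function mae_derivative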
14 changes: 14 additions & 0 deletions src/nf/nf_loss_submodule.f90
@@ -18,4 +18,18 @@ pure module function quadratic_derivative(true, predicted) result(res)
res = predicted - true
end function quadratic_derivative

pure module function mse(true, predicted) result(res)
real, intent(in) :: true(:)
real, intent(in) :: predicted(:)
real :: res(size(true))
res = (predicted - true)**2 / size(true)
end function mse

pure module function mse_derivative(true, predicted) result(res)
real, intent(in) :: true(:)
real, intent(in) :: predicted(:)
real :: res(size(true))
res = 2 * (predicted - true) / size(true)
end function mse_derivative

end submodule nf_loss_submodule
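
A quick way to sanity-check the pair above is to compare `mse_derivative` against a central finite difference of the summed loss. A standalone sketch, assuming it is compiled and linked against `nf_loss` (the step size and test values are arbitrary):

program check_mse_gradient
  use nf_loss, only: mse, mse_derivative
  implicit none
  real, parameter :: h = 1e-3
  real :: true(3), pred(3), pred_hi(3), pred_lo(3)
  real :: analytic(3), numeric(3)
  integer :: i
  true = [1., 2., 3.]
  pred = [1.1, 1.9, 3.5]
  analytic = mse_derivative(true, pred)
  ! Central difference of the total loss sum(mse) w.r.t. each prediction
  do i = 1, size(pred)
    pred_hi = pred; pred_hi(i) = pred_hi(i) + h
    pred_lo = pred; pred_lo(i) = pred_lo(i) - h
    numeric(i) = (sum(mse(true, pred_hi)) - sum(mse(true, pred_lo))) / (2 * h)
  end do
  print *, 'max abs diff:', maxval(abs(analytic - numeric))
end program check_mse_gradient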
12 changes: 10 additions & 2 deletions src/nf/nf_network.f90
@@ -3,6 +3,7 @@ module nf_network
!! This module provides the network type to create new models.

use nf_layer, only: layer
use nf_loss, only: loss_derivative_interface
use nf_optimizers, only: optimizer_base_type

implicit none
@@ -14,6 +15,7 @@ module nf_network

type(layer), allocatable :: layers(:)
class(optimizer_base_type), allocatable :: optimizer
procedure(loss_derivative_interface), pointer, nopass :: loss_derivative => null()

contains

@@ -138,13 +140,16 @@ end function predict_batch_3d

interface

pure module subroutine backward(self, output)
pure module subroutine backward(self, output, loss_derivative)
!! Apply one backward pass through the network.
!! This changes the state of layers on the network.
!! Typically used only internally from the `train` method,
!! but can be invoked by the user when creating custom optimizers.
class(network), intent(in out) :: self
!! Network instance
real, intent(in) :: output(:)
!! Output data
procedure(loss_derivative_interface), optional :: loss_derivative
!! First derivative of the loss function to use.
!! If not provided, the default is `quadratic_derivative`.
end subroutine backward
Expand Down Expand Up @@ -185,7 +190,7 @@ module subroutine print_info(self)
end subroutine print_info

module subroutine train(self, input_data, output_data, batch_size, &
epochs, optimizer)
epochs, optimizer, loss_derivative)
class(network), intent(in out) :: self
!! Network instance
real, intent(in) :: input_data(:,:)
@@ -204,6 +209,9 @@ module subroutine train(self, input_data, output_data, batch_size, &
!! Number of epochs to run
class(optimizer_base_type), intent(in), optional :: optimizer
!! Optimizer instance to use. If not provided, the default is sgd().
procedure(loss_derivative_interface), optional :: loss_derivative
!! First derivative of the loss function to use.
!! If not provided, the default is `quadratic_derivative`.
end subroutine train

module subroutine update(self, optimizer, batch_size)
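
With the new optional argument, switching the training loss is a one-liner at the call site. A usage sketch (the toy network shape, data, and hyperparameters are illustrative, not from this PR):

program train_with_mse
  use nf, only: dense, input, network, sgd, mse_derivative
  implicit none
  type(network) :: net
  real :: x(3, 100), y(1, 100)
  call random_number(x)
  y(1,:) = sum(x, dim=1)  ! toy regression target
  net = network([input(3), dense(5), dense(1)])
  ! Pass the MSE derivative instead of the default quadratic one
  call net % train(x, y, batch_size=10, epochs=20, &
    optimizer=sgd(learning_rate=0.01), loss_derivative=mse_derivative)
end program train_with_mse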
24 changes: 20 additions & 4 deletions src/nf/nf_network_submodule.f90
@@ -11,7 +11,7 @@
use nf_keras, only: get_keras_h5_layers, keras_layer
use nf_layer, only: layer
use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
use nf_loss, only: quadratic_derivative
use nf_loss, only: quadratic_derivative, mse_derivative
use nf_optimizers, only: optimizer_base_type, sgd
use nf_parallel, only: tile_indices
use nf_activation, only: activation_function, &
@@ -280,11 +280,20 @@ pure function get_activation_by_name(activation_name) result(res)

end function get_activation_by_name

pure module subroutine backward(self, output)
pure module subroutine backward(self, output, loss_derivative)
class(network), intent(in out) :: self
real, intent(in) :: output(:)
procedure(loss_derivative_interface), optional :: loss_derivative
integer :: n, num_layers

! Fall back to the quadratic loss derivative unless one was
! supplied here or already set by train()
if (.not. associated(self % loss_derivative)) then
if (present(loss_derivative)) then
self % loss_derivative => loss_derivative
else
self % loss_derivative => quadratic_derivative
end if
end if

num_layers = size(self % layers)

! Iterate backward over layers, from the output layer
@@ -297,7 +306,7 @@ pure module subroutine backward(self, output)
type is(dense_layer)
call self % layers(n) % backward( &
self % layers(n - 1), &
quadratic_derivative(output, this_layer % output) &
self % loss_derivative(output, this_layer % output) &
)
end select
else
@@ -540,13 +549,14 @@ end subroutine set_params


module subroutine train(self, input_data, output_data, batch_size, &
epochs, optimizer)
epochs, optimizer, loss_derivative)
class(network), intent(in out) :: self
real, intent(in) :: input_data(:,:)
real, intent(in) :: output_data(:,:)
integer, intent(in) :: batch_size
integer, intent(in) :: epochs
class(optimizer_base_type), intent(in), optional :: optimizer
procedure(loss_derivative_interface), optional :: loss_derivative
class(optimizer_base_type), allocatable :: optimizer_

real :: pos
@@ -565,6 +575,12 @@ module subroutine train(self, input_data, output_data, batch_size, &

call self % optimizer % init(self % get_num_params())

! Select the loss derivative for this run; backward() reuses it
if (present(loss_derivative)) then
self % loss_derivative => loss_derivative
else
self % loss_derivative => quadratic_derivative
end if

dataset_size = size(output_data, dim=2)

epoch_loop: do n = 1, epochs
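
Since `backward` takes the same optional argument, hand-rolled training loops (e.g. for custom optimizers, as the `backward` docstring suggests) can also pick the loss. A minimal sketch, reusing `net`, `x`, and `y` from the previous example (per-sample updates; the optimizer settings are illustrative):

integer :: i
do i = 1, size(y, dim=2)
  call net % forward(x(:,i))
  ! Gradient of the MSE loss flows back through the network
  call net % backward(y(:,i), loss_derivative=mse_derivative)
  call net % update(optimizer=sgd(learning_rate=0.01), batch_size=1)
end do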