-
Notifications
You must be signed in to change notification settings - Fork 0
/
library.r
40 lines (33 loc) · 1.26 KB
/
library.r
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
library(repr)
# Set the default width and height that the repr package uses for the
# plots it renders.
options(
  repr.plot.width = 6,
  repr.plot.height = 4
)
# Keras callback that prints compact, single-line training progress.
#
# While an epoch runs, every `update_frequency`-th batch rewrites the current
# console line (trailing "\r") with the epoch, batch and training loss. When
# the epoch ends, a final newline-terminated line is printed that also shows
# the validation loss if `logs` contains one.
#
# Epoch and batch indices are displayed as `index + 1`, i.e. this assumes
# Keras passes 0-based indices to the hooks -- TODO confirm for the keras
# version in use.
#
# Fields:
#   num_batches      Batches per epoch; used only for display.
#   num_epochs       Total epochs; used only for display.
#   update_frequency Report every this many batches (must be non-zero).
#   epoch            Current epoch index as received by on_epoch_begin().
#   batch            Index of the most recently completed batch.
Progress <- R6::R6Class("Progress",
  inherit = KerasCallback,
  public = list(
    num_batches = NULL,
    num_epochs = NULL,
    update_frequency = NULL,
    epoch = NULL,
    batch = NULL,

    # num_batches, num_epochs: totals shown in the "x/total" counters.
    # update_frequency: print a progress line every this many batches.
    initialize = function(num_batches, num_epochs, update_frequency) {
      # Fail fast here: a zero/NA/non-scalar value would otherwise surface
      # later as an opaque `if ()` error inside on_batch_end().
      stopifnot(
        "`update_frequency` must be a single non-zero number" =
          length(update_frequency) == 1L &&
            !is.na(update_frequency) &&
            update_frequency != 0
      )
      self$num_batches <- num_batches
      self$num_epochs <- num_epochs
      self$update_frequency <- update_frequency
      # 0-based, consistent with the `+ 1` applied when displaying, so any
      # report before the first on_epoch_begin() reads "Epoch 1".
      self$epoch <- 0
    },

    # Print an in-place progress line every `update_frequency` batches and
    # remember the batch index for the end-of-epoch summary.
    on_batch_end = function(batch, logs = list()) {
      if ((batch + 1) %% self$update_frequency == 0) {
        # "\r" returns the cursor to the line start so the next report
        # overwrites this one instead of scrolling.
        private$write_progress(batch, logs, "\r")
      }
      self$batch <- batch
    },

    # Record the epoch index supplied by Keras.
    on_epoch_begin = function(epoch, logs = list()) {
      self$epoch <- epoch
    },

    # Print the final, newline-terminated line for the epoch. The validation
    # loss is appended only when `logs` contains it, avoiding an empty
    # " - validation loss: " label.
    on_epoch_end = function(epoch, logs = list()) {
      val_loss <- logs[["val_loss"]]
      if (is.null(val_loss)) {
        private$write_progress(self$batch, logs, "\n")
      } else {
        private$write_progress(
          self$batch, logs, " - validation loss: ", val_loss, "\n"
        )
      }
    }
  ),
  private = list(
    # Write "Epoch e/E: b/B - loss: l" followed by `...`, then flush so the
    # text appears immediately. Uses cat() so numbers keep cat's formatting.
    write_progress = function(batch, logs, ...) {
      cat("Epoch ", self$epoch + 1, "/", self$num_epochs, ": ",
        batch + 1, "/", self$num_batches,
        " - loss: ", logs[["loss"]], ...,
        sep = ""
      )
      flush.console()
    }
  )
)