-
Notifications
You must be signed in to change notification settings - Fork 0
/
06-helpers.R
48 lines (40 loc) · 1.49 KB
/
06-helpers.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#' Point the Keras backend at a multi-threaded TensorFlow session
#'
#' Attaches tensorflow and keras, clears the current Keras session, then
#' builds a TensorFlow session whose intra-op and inter-op parallelism
#' thread counts are both set to `threads` and hands it to
#' `k_set_session()`.
#'
#' @param threads Value used for both `intra_op_parallelism_threads` and
#'   `inter_op_parallelism_threads`.
#' @return Whatever `k_set_session()` returns.
use_multi_cpu <- function(threads) {
  library(tensorflow)
  library(keras)

  # Discard the existing session before installing the new one.
  k_clear_session()

  thread_config <- tf$ConfigProto(
    intra_op_parallelism_threads = threads,
    inter_op_parallelism_threads = threads
  )
  k_set_session(tf$Session(config = thread_config))
}
#' Split samples into a leading training part and a trailing test part
#'
#' The first `floor((1 - fraction) * n)` samples (along the first dimension
#' of `x_data`) form the training set; the remaining samples form the test
#' set. No shuffling is performed, so the original order is preserved.
#'
#' @param x_data A 3-dimensional array; it is indexed as `x_data[i, , ]`.
#' @param y_data Targets, indexed as `y_data[i]` with the same sample
#'   indices as `x_data`.
#' @param fraction Share of samples assigned to the test set; a single
#'   number in `[0, 1]`.
#' @return A named list with elements `x_train`, `y_train`, `x_test` and
#'   `y_test`. A partition with no samples is returned empty.
split_dataset <- function(x_data, y_data, fraction = 0.2) {
  stopifnot(
    is.numeric(fraction),
    length(fraction) == 1L,
    !is.na(fraction),
    fraction >= 0,
    fraction <= 1
  )

  n_samples <- dim(x_data)[1]
  n_train <- floor((1 - fraction) * n_samples)

  # seq_len() gives an empty index for an empty partition. The `a:b` form
  # counts backwards instead (1:0 is c(1, 0)), which leaked row 1 into the
  # training set for fraction = 1 and indexed out of bounds for fraction = 0.
  train_rows <- seq_len(n_train)
  test_rows <- n_train + seq_len(n_samples - n_train)

  # The default `drop` behaviour of `[` is kept on purpose so the shapes
  # returned for non-degenerate splits are unchanged.
  list(
    x_train = x_data[train_rows, , ],
    y_train = y_data[train_rows],
    x_test = x_data[test_rows, , ],
    y_test = y_data[test_rows]
  )
}
# Keras callback that prints a one-line summary at the end of every epoch.
# The line ends with a carriage return, so each summary is written over the
# previous one on the console.
Progress <- R6::R6Class(
  "Progress",
  inherit = KerasCallback,
  public = list(
    num_epochs = NULL,
    update_frequency = NULL,
    epoch = NULL,
    batch = NULL,

    # Start the stored epoch counter at 1.
    initialize = function() {
      self$epoch <- 1
    },

    # Print loss/accuracy for the finished epoch, plus the validation
    # figures when `logs` has a 'val_loss' entry. `epoch` is shown as
    # `epoch + 1`.
    on_epoch_end = function(epoch, logs = list()) {
      val_suffix <- if ('val_loss' %in% names(logs)) {
        paste0(
          ', val loss: ', logs[['val_loss']],
          ', val acc.: ', logs[['val_acc']]
        )
      } else {
        ''
      }

      cat(
        'Epoch ', epoch + 1,
        ' - loss: ', logs[['loss']],
        ', acc.: ', logs[['acc']],
        val_suffix,
        ' \r',
        sep = ''
      )
      flush.console()
    }
  )
)