Skip to content

Commit

Permalink
fix mse not a builtin function while loading model
Browse files Browse the repository at this point in the history
  • Loading branch information
oaksharks committed Nov 15, 2024
1 parent 9d26725 commit 8f79332
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
22 changes: 16 additions & 6 deletions deeptables/models/deepmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections import OrderedDict
from typing import List, Union

import keras
import tensorflow as tf
from keras import backend as K
from keras.api.layers import Dense, Concatenate, Flatten, Input, Add, BatchNormalization, Dropout
Expand Down Expand Up @@ -317,22 +318,31 @@ def __build_model(self, task, num_classes, nets, categorical_columns, continuous
return model

def __compile_model(self, model, task, num_classes, optimizer, loss, metrics):
import keras

if optimizer == 'auto':
optimizer = keras.optimizers.Adam(learning_rate=0.001)

loss_name = None
if loss == 'auto':
if task == consts.TASK_BINARY or task == consts.TASK_MULTILABEL:
loss = 'binary_crossentropy'
loss_name = 'binary_crossentropy'
loss = keras.losses.BinaryCrossentropy()
elif task == consts.TASK_REGRESSION:
loss = 'mse'
loss_name = 'mse'
loss = keras.losses.MeanSquaredError()
elif task == consts.TASK_MULTICLASS:
if num_classes == 2:
loss = 'binary_crossentropy'
loss_name = 'binary_crossentropy'
loss = keras.losses.BinaryCrossentropy()
else:
loss = 'categorical_crossentropy'
loss_name = 'categorical_crossentropy'
loss = keras.losses.CategoricalCrossentropy()
else:
raise RuntimeError(f'unseen task "{task}"')
assert loss_name
assert loss
self.model_desc.optimizer = optimizer
self.model_desc.loss = loss
self.model_desc.loss = loss_name
model.compile(optimizer, loss, metrics=metrics)
return model

Expand Down
8 changes: 5 additions & 3 deletions deeptables/models/hyper_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,12 @@ def save(self, model_path):

self.model.save(model_path)

stub = copy.copy(self)
stub.model = None
# Note: copy.copy(self) may cause self.model is None
# self.model = None # already use __getstate__ to avoid persist model
stub_path = model_path + 'dt_estimator.pkl'
with fs.open(stub_path, 'wb') as f:
pickle.dump(stub, f, protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)


@staticmethod
def load(model_path):
Expand Down Expand Up @@ -245,6 +246,7 @@ def get_iteration_scores(self):
def __getstate__(self):
try:
state = super().__getstate__()
state = copy.copy(state) # if not make a copy, will delete model in self
except AttributeError:
state = self.__dict__.copy()

Expand Down

0 comments on commit 8f79332

Please sign in to comment.