-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathinit_code.py
40 lines (25 loc) · 899 Bytes
/
init_code.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
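# Serialization/deserialization hooks for Keras models used in DeepForge.
# A model is stored as a gzipped tar archive whose top-level entry, "SavedModel",
# holds whatever `model.save` wrote to disk.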
import os
import time
import tarfile
import shutil
import tensorflow.keras as keras
import deepforge
def dump_model(model, outfile):
    """Write ``model`` to ``outfile`` as a gzipped tar archive.

    The model is first saved to a scratch directory placed next to the output
    path, that directory is archived under the top-level name ``SavedModel``,
    and the scratch directory is removed afterwards.

    Args:
        model: Object exposing ``save(path)``; it is called with the scratch
            directory path.
        outfile: Object with a ``name`` attribute holding the destination
            path. Only ``outfile.name`` is used; the archive is opened by
            path rather than written through the handle.

    Raises:
        Whatever ``model.save`` or ``tarfile`` raises; the scratch directory
        is cleaned up before the exception propagates.
    """
    # The timestamp suffix keeps the scratch path distinct from the output file.
    tmp_dir = outfile.name + '-tmp-' + str(time.time())
    try:
        model.save(tmp_dir)
        with tarfile.open(outfile.name, 'w:gz') as tar:
            tar.add(tmp_dir, arcname='SavedModel')
    finally:
        # Always drop the scratch copy, even when saving or archiving failed.
        # ignore_errors: the directory may not exist if ``save`` raised early,
        # and a cleanup error must not mask the original exception.
        shutil.rmtree(tmp_dir, ignore_errors=True)
def load_model(infile):
    """Load a Keras model from the tar archive at ``infile.name``.

    The archive is unpacked into a scratch directory placed next to the input
    path, the model is loaded from the ``SavedModel`` entry inside it, and the
    scratch directory is removed afterwards.

    Args:
        infile: Object with a ``name`` attribute holding the archive path.
            Only ``infile.name`` is used; the archive is opened by path.

    Returns:
        The object returned by ``keras.models.load_model``.

    Raises:
        Whatever ``tarfile`` or ``keras.models.load_model`` raises; the
        scratch directory is cleaned up before the exception propagates.
    """
    tmp_dir = infile.name + '-tmp-' + str(time.time())
    os.makedirs(tmp_dir)
    try:
        with tarfile.open(infile.name) as tar:
            # NOTE(review): extractall trusts the member paths stored in the
            # archive. If archives can come from an untrusted source, validate
            # members (or use an extraction filter where available) first.
            tar.extractall(path=tmp_dir)
        return keras.models.load_model(os.path.join(tmp_dir, 'SavedModel'))
    finally:
        # Remove the unpacked copy whether or not loading succeeded.
        shutil.rmtree(tmp_dir)
# Hook the tar-based (de)serializers up for each class that directly subclasses
# keras.Model at import time (``__subclasses__`` lists immediate subclasses only).
for model_cls in keras.Model.__subclasses__():
    deepforge.serialization.register(model_cls, dump_model, load_model)