Skip to content

Commit

Permalink
Add speculative model serialisation fix
Browse files Browse the repository at this point in the history
  • Loading branch information
honnibal committed Jun 1, 2017
1 parent b772e06 commit ceb9993
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions thinc/neural/_classes/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,44 +278,47 @@ def to_bytes(self):
for layer in queue:
if hasattr(layer, '_mem'):
weights.append({
'dims': normalize_string_keys(getattr(layer, '_dims', {})),
'params': []})
b'dims': normalize_string_keys(getattr(layer, '_dims', {})),
b'params': []})
if hasattr(layer, 'seed'):
weights[-1]['seed'] = layer.seed
weights[-1][b'seed'] = layer.seed

for (id_, name), (start, row, shape) in layer._mem._offsets.items():
if row == 1:
continue
param = layer._mem.get((id_, name))
if not isinstance(layer._mem.weights, numpy.ndarray):
param = param.get()
weights[-1]['params'].append(
weights[-1][b'params'].append(
{
'name': name,
'offset': start,
'shape': shape,
'value': param,
b'name': name,
b'offset': start,
b'shape': shape,
b'value': param,
}
)
i += 1
if hasattr(layer, '_layers'):
queue.extend(layer._layers)
return msgpack.dumps({'weights': weights})
return msgpack.dumps({b'weights': weights})

def from_bytes(self, bytes_data):
data = msgpack.loads(bytes_data)
weights = data['weights']
weights = data[b'weights']
queue = [self]
i = 0
for layer in queue:
if hasattr(layer, '_mem'):
if 'seed' in weights[i]:
layer.seed = weights[i]['seed']
for dim, value in weights[i]['dims'].items():
if hasattr(layer, b'_mem'):
if b'seed' in weights[i]:
layer.seed = weights[i][b'seed']
for dim, value in weights[i][b'dims'].items():
setattr(layer, dim, value)
for param in weights[i]['params']:
dest = getattr(layer, param['name'])
copy_array(dest, param['value'])
for param in weights[i][b'params']:
name = param[b'name']
if isinstance(name, bytes):
name = name.decode('utf8')
dest = getattr(layer, name)
copy_array(dest, param[b'value'])
i += 1
if hasattr(layer, '_layers'):
queue.extend(layer._layers)

0 comments on commit ceb9993

Please sign in to comment.