diff --git a/tf/net_to_coreml.py b/tf/net_to_coreml.py
new file mode 100644
index 00000000..69def42e
--- /dev/null
+++ b/tf/net_to_coreml.py
@@ -0,0 +1,67 @@
+#!/usr/bin/env python3
+import os
+from net_to_model import convert
+import coremltools as ct
+
+if __name__ == "__main__":
+    ##############
+    # NET TO MODEL
+    args, root_dir, tfp = convert(include_attn_wts_output=False, rescale_rule50=False)
+
+    #################
+    # MODEL TO COREML
+    input_shape = ct.Shape(shape=(1, 112, 8, 8))
+
+    # Set the compute precision
+    compute_precision = ct.precision.FLOAT16
+    # compute_precision = ct.precision.FLOAT32
+
+    # Convert the model to CoreML
+    coreml_model = ct.convert(
+        tfp.model,
+        convert_to="mlprogram",
+        inputs=[ct.TensorType(shape=input_shape, name="input_1")],
+        compute_precision=compute_precision,
+    )
+
+    # Get the protobuf spec
+    spec = coreml_model._spec
+
+    # Rename the input
+    ct.utils.rename_feature(spec, "input_1", "input_planes")
+
+    # Get input names (loop variable chosen to avoid shadowing builtin input)
+    input_names = [inp.name for inp in spec.description.input]
+
+    # Print the input names
+    print(f"Renamed input: {input_names}")
+
+    # Set output names
+    output_names = ["output_policy", "output_value"]
+
+    if tfp.moves_left:
+        output_names.append("output_moves_left")
+
+    # Rename output names
+    for i, name in enumerate(output_names):
+        # Rename the output
+        ct.utils.rename_feature(spec, spec.description.output[i].name, name)
+
+    # Print the output names
+    print(f"Renamed output: {[output_i.name for output_i in spec.description.output]}")
+
+    # Set model description
+    coreml_model.short_description = f"Lc0 converted from {args.net}"
+
+    # Rebuild the model so the in-place spec edits (renames, description) take effect
+    print("Rebuilding model with updated spec ...")
+    rebuilt_mlmodel = ct.models.MLModel(
+        coreml_model._spec, weights_dir=coreml_model._weights_dir
+    )
+
+    # Save the rebuilt CoreML model (not the stale pre-rebuild handle)
+    print("Saving model ...")
+    coreml_model_path = os.path.join(root_dir, f"{args.net}.mlpackage")
+    rebuilt_mlmodel.save(coreml_model_path)
+
+    print(f"CoreML model saved at {coreml_model_path}")
diff --git a/tf/net_to_model.py b/tf/net_to_model.py
index 02772f82..dea69182 100755
--- a/tf/net_to_model.py
+++ b/tf/net_to_model.py
@@ -4,33 +4,38 @@
 import yaml
 import tfprocess
-argparser = argparse.ArgumentParser(description='Convert net to model.')
-argparser.add_argument('net',
-                       type=str,
-                       help='Net file to be converted to a model checkpoint.')
-argparser.add_argument('--start',
-                       type=int,
-                       default=0,
-                       help='Offset to set global_step to.')
-argparser.add_argument('--cfg',
-                       type=argparse.FileType('r'),
-                       help='yaml configuration with training parameters')
-argparser.add_argument('-e',
-                       '--ignore-errors',
-                       action='store_true',
-                       help='Ignore missing and wrong sized values.')
-args = argparser.parse_args()
-cfg = yaml.safe_load(args.cfg.read())
-print(yaml.dump(cfg, default_flow_style=False))
-START_FROM = args.start
+def convert(include_attn_wts_output=True, rescale_rule50=True):
+    argparser = argparse.ArgumentParser(description='Convert net to model.')
+    argparser.add_argument('net',
+                           type=str,
+                           help='Net file to be converted to a model checkpoint.')
+    argparser.add_argument('--start',
+                           type=int,
+                           default=0,
+                           help='Offset to set global_step to.')
+    argparser.add_argument('--cfg',
+                           type=argparse.FileType('r'),
+                           help='yaml configuration with training parameters')
+    argparser.add_argument('-e',
+                           '--ignore-errors',
+                           action='store_true',
+                           help='Ignore missing and wrong sized values.')
+    args = argparser.parse_args()
+    cfg = yaml.safe_load(args.cfg.read())
+    print(yaml.dump(cfg, default_flow_style=False))
+    START_FROM = args.start
 
-tfp = tfprocess.TFProcess(cfg)
-tfp.init_net()
-tfp.replace_weights(args.net, args.ignore_errors)
-tfp.global_step.assign(START_FROM)
+    tfp = tfprocess.TFProcess(cfg)
+    tfp.init_net(include_attn_wts_output)
+    tfp.replace_weights(args.net, args.ignore_errors, rescale_rule50)
+    tfp.global_step.assign(START_FROM)
 
-root_dir = os.path.join(cfg['training']['path'],
-                        cfg['name'])
-if not os.path.exists(root_dir):
-    os.makedirs(root_dir)
-tfp.manager.save(checkpoint_number=START_FROM)
-print("Wrote model to {}".format(tfp.manager.latest_checkpoint))
+    root_dir = os.path.join(cfg['training']['path'], cfg['name'])
+    if not os.path.exists(root_dir):
+        os.makedirs(root_dir)
+    tfp.manager.save(checkpoint_number=START_FROM)
+    print("Wrote model to {}".format(tfp.manager.latest_checkpoint))
+    return args, root_dir, tfp
+
+
+if __name__ == "__main__":
+    convert()
diff --git a/tf/tfprocess.py b/tf/tfprocess.py
index 17d198fe..52f3d7ea 100644
--- a/tf/tfprocess.py
+++ b/tf/tfprocess.py
@@ -377,10 +377,10 @@ def init(self, train_dataset, test_dataset, validation_dataset=None):
         else:
             self.init_net()
 
-    def init_net(self):
+    def init_net(self, include_attn_wts_output=True):
         self.l2reg = tf.keras.regularizers.l2(l=0.5 * (0.0001))
         input_var = tf.keras.Input(shape=(112, 8, 8))
-        outputs = self.construct_net(input_var)
+        outputs = self.construct_net(input_var, include_attn_wts_output=include_attn_wts_output)
         self.model = tf.keras.Model(inputs=input_var, outputs=outputs)
 
         # swa_count initialized regardless to make checkpoint code simpler.
@@ -628,7 +628,7 @@ def accuracy(target, output):
                                                   keep_checkpoint_every_n_hours=24,
                                                   checkpoint_name=self.cfg['name'])
 
-    def replace_weights(self, proto_filename, ignore_errors=False):
+    def replace_weights(self, proto_filename, ignore_errors=False, rescale_rule50=True):
         self.net.parse_proto(proto_filename)
 
         filters, blocks = self.net.filters(), self.net.blocks()
@@ -676,7 +676,7 @@ def replace_weights(self, proto_filename, ignore_errors=False):
 
             if weight.shape.ndims == 4:
                 # Rescale rule50 related weights as clients do not normalize the input.
-                if weight.name == 'input/conv2d/kernel:0' and self.net.pb.format.network_format.input < pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES:
+                if rescale_rule50 and weight.name == 'input/conv2d/kernel:0' and self.net.pb.format.network_format.input < pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES:
                     num_inputs = 112
                     # 50 move rule is the 110th input, or 109 starting from 0.
                     rule50_input = 109
@@ -1520,7 +1520,7 @@ def apply_promotion_logits(self, queries, keys, attn_wts):
             h_fc1 = ApplyAttentionPolicyMap()(policy_attn_logits, promotion_logits)
         return h_fc1
 
-    def construct_net(self, inputs, name=''):
+    def construct_net(self, inputs, name='', include_attn_wts_output=True):
         if self.encoder_layers > 0:
             flow, attn_wts = self.create_encoder_body(inputs,
@@ -1665,9 +1665,11 @@ def construct_net(self, inputs, name=''):
         # attention weights added as optional output for analysis -- ignored by backend
         if self.POLICY_HEAD == pb.NetworkFormat.POLICY_ATTENTION:
             if self.moves_left:
-                outputs = [h_fc1, h_fc3, h_fc5, attn_wts]
+                outputs = [h_fc1, h_fc3, h_fc5]
             else:
-                outputs = [h_fc1, h_fc3, attn_wts]
+                outputs = [h_fc1, h_fc3]
+            if include_attn_wts_output:
+                outputs.append(attn_wts)
         elif self.moves_left:
             outputs = [h_fc1, h_fc3, h_fc5]
         else: