Help in generating .tflite for CNN-Text-Classification #10

Open
skmalviya opened this issue Mar 6, 2020 · 0 comments


skmalviya commented Mar 6, 2020

Can you help me out with the code in this repository (https://github.com/dennybritz/cnn-text-classification-tf)?
I am new to TensorFlow. I want to create a .tflite file for the model in train.py. As you mentioned in the video, the process starts with saving a checkpoint, writing its graph definition to a .pbtxt file, freezing that into a .pb file, and finally converting the frozen graph to the .tflite file I am after.
I am running in CPU mode with TensorFlow 1.13.1.
I can generate both the .pbtxt and .pb files successfully for the very first checkpoint, but I get an error at the tf.lite.TocoConverter.from_frozen_graph() line of my code:

            # Training loop. For each batch...
            for batch in batches:
                x_batch, y_batch = zip(*batch)
                train_step(x_batch, y_batch)
                current_step = tf.train.global_step(sess, global_step)
                if current_step % FLAGS.evaluate_every == 0:
                    print("\nEvaluation:")
                    dev_step(x_dev, y_dev, writer=dev_summary_writer)
                    print("")
                if current_step % FLAGS.checkpoint_every == 0:
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    tf.train.write_graph(sess.graph_def, checkpoint_dir, 'savegraph.pbtxt') #saving the model's tensorflow graph definition
                    freez_grph(checkpoint_dir)
                    inp_node = ['input_x']
                    out_node = ['output']
                    #nodes = [e.name + '=>' +  e.op for e in tf.get_default_graph().as_graph_def().node if e.op in  (( 'Softmax','Placeholder'))]
                    #print (nodes)
                    #converter = tf.lite.TFLiteConverter.from_session(sess, [cnn.embedded_chars_expanded], [cnn.input_y])
                    converter = tf.lite.TocoConverter.from_frozen_graph(checkpoint_dir+'/frozen_model_TextCNN Model.pb',inp_node, out_node)
                    tflite_model = converter.convert()
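                    # Write the converted model and close the file handle when done.
                    with open("TextCNN.tflite", "wb") as f:
                        f.write(tflite_model)
                    print("Saved model checkpoint to {}\n".format(path))
                    exit()  # stop after converting the first checkpoint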
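
One thing that may be worth checking before the conversion step is whether 'input_x' and 'output' are actual node names in the saved graph, since the converter looks up its input and output arrays by node name. If an op is created inside a name scope, its node name carries the scope as a prefix, so 'output' alone might not match any node. The sketch below scans a text-format graph file such as savegraph.pbtxt and lists node names with their op types. It uses only the standard library and assumes the usual two-space-indented layout of a text-format GraphDef. The function name `list_nodes` and the `SAMPLE` graph are hypothetical, made up for illustration, and are not taken from the repository.

```python
import re

# Matches the top-level "name" and "op" fields of a node block.
# Assumption: the file uses the standard two-space indentation,
# so fields nested inside attr { ... } are indented deeper and are skipped.
FIELD = re.compile(r'^  (name|op): "(.*)"$')


def list_nodes(pbtxt_text, ops=None):
    """Return (name, op) pairs for each node in a text-format GraphDef.

    If `ops` is given, keep only nodes whose op type is in that set.
    """
    nodes = []
    name = op = None
    in_node = False
    for raw in pbtxt_text.splitlines():
        line = raw.rstrip()
        if line == "node {":
            in_node, name, op = True, None, None
        elif in_node and line == "}":
            # An unindented closing brace ends the current node block.
            if name is not None and (ops is None or op in ops):
                nodes.append((name, op))
            in_node = False
        elif in_node:
            match = FIELD.match(line)
            if match and match.group(1) == "name":
                name = match.group(2)
            elif match:
                op = match.group(2)
    return nodes


# Hypothetical two-node graph, used only to illustrate the file layout.
SAMPLE = '''node {
  name: "input_x"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "output/predictions"
  op: "ArgMax"
  input: "output/scores"
  input: "output/predictions/dimension"
}
'''

if __name__ == "__main__":
    for node_name, node_op in list_nodes(SAMPLE):
        print(node_name, "=>", node_op)
```

To run it on the real graph, read the file written by tf.train.write_graph and pass its contents, for example `list_nodes(open(path_to_pbtxt).read(), ops={"Placeholder", "Softmax", "ArgMax"})`, where `path_to_pbtxt` is a placeholder for the location of savegraph.pbtxt inside the checkpoint directory. The names printed for the placeholder and the final op are the candidates for inp_node and out_node.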