From 6f9dfb9d4c4257861c3354267a4701c3354370f0 Mon Sep 17 00:00:00 2001 From: zhai_pro Date: Sun, 27 Jan 2019 17:30:49 +0800 Subject: [PATCH] =?UTF-8?q?=E7=9C=8B=E7=9C=8B=E6=B7=B1=E5=BA=A6=E5=AD=A6?= =?UTF-8?q?=E4=B9=A0=E5=90=8E=E7=9A=84=E6=A1=86=E6=9E=B6=E6=98=AF=E5=90=A6?= =?UTF-8?q?=E8=83=BD=E8=AF=86=E5=88=AB=E5=89=8D=E6=89=80=E6=9C=AA=E8=A7=81?= =?UTF-8?q?=E7=9A=84=E5=9B=BE=E7=89=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mlearn_for_image.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/mlearn_for_image.py b/mlearn_for_image.py index 2fb483b..1b7b8f0 100644 --- a/mlearn_for_image.py +++ b/mlearn_for_image.py @@ -1,4 +1,6 @@ # coding: utf-8 +import sys + import cv2 import numpy as np from keras import models @@ -31,7 +33,7 @@ def load_data(): return (train_x, train_y, sample_weight), (test_x, test_y) -def main(): +def learn(): (train_x, train_y, sample_weight), (test_x, test_y) = load_data() datagen = ImageDataGenerator(horizontal_flip=True, vertical_flip=True) @@ -62,5 +64,19 @@ def main(): model.save('12306.image.model.h5', include_optimizer=False) +def predict(fn): + imgs = cv2.imread(fn) + imgs = cv2.resize(imgs, (67, 67)) + imgs = imgs / 255.0 + imgs.shape = (-1, 67, 67, 3) + model = models.load_model('12306.image.model.h5') + labels = model.predict(imgs) + print(labels.max(axis=1)) + print(labels.argmax(axis=1)) + + if __name__ == '__main__': - main() + if len(sys.argv) >= 2: + predict(sys.argv[1]) + else: + learn()