From e5e4003a66ce6d80aabc58d3e9ea874483361c9a Mon Sep 17 00:00:00 2001 From: am15h Date: Thu, 18 Jun 2020 21:06:53 +0530 Subject: [PATCH] [v0.4.0] Support Uint8List, Bytebuffer in run, bugfixes --- CHANGELOG.md | 5 ++++ example/pubspec.lock | 2 +- .../tflite_flutter_plugin_example_e2e.dart | 14 --------- lib/src/interpreter.dart | 30 ++++++++++++------- lib/src/tensor.dart | 20 +++++++++++++ pubspec.yaml | 2 +- 6 files changed, 47 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 148f0a2..cfadd06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## 0.4.0 (Jun 18, 2020) +* run supports UintList8 and ByteBuffer objects +* Bug fix, resize input tensor +* Improved efficiency + ## 0.3.0 * New features * multi-dimensional reshape with type diff --git a/example/pubspec.lock b/example/pubspec.lock index 5dd327f..c5ecc3c 100644 --- a/example/pubspec.lock +++ b/example/pubspec.lock @@ -466,7 +466,7 @@ packages: path: ".." relative: true source: path - version: "0.3.0" + version: "0.4.0" typed_data: dependency: transitive description: diff --git a/example/test/tflite_flutter_plugin_example_e2e.dart b/example/test/tflite_flutter_plugin_example_e2e.dart index b08c501..e156044 100644 --- a/example/test/tflite_flutter_plugin_example_e2e.dart +++ b/example/test/tflite_flutter_plugin_example_e2e.dart @@ -81,15 +81,6 @@ void main() { interpreter.allocateTensors(); }); - test('allocate throws if already allocated', () { - interpreter.allocateTensors(); - expect(() => interpreter.allocateTensors(), throwsA(isStateError)); - }); - - test('invoke throws if not allocated', () { - expect(() => interpreter.invoke(), throwsA(isStateError)); - }); - test('invoke throws if not allocated after resized', () { interpreter.allocateTensors(); interpreter.resizeInputTensor(0, [1, 2, 4]); @@ -171,11 +162,6 @@ void main() { expect(tensors[0].data, hasLength(4)); }); - test('set throws if not allocated', () { - expect(() => tensors[0].data = Uint8List.fromList(const [0, 0, 0, 0]), - throwsA(isStateError)); - }); - test('set', () { interpreter.allocateTensors(); tensors[0].data = Uint8List.fromList(const [0, 0, 0, 0]); diff --git a/lib/src/interpreter.dart b/lib/src/interpreter.dart index 85e7230..eb04016 100644 --- a/lib/src/interpreter.dart +++ b/lib/src/interpreter.dart @@ -24,7 +24,10 @@ class Interpreter { int get lastNativeInferenceDurationMicroSeconds => _lastNativeInferenceDurationMicroSeconds; - Interpreter._(this._interpreter); + Interpreter._(this._interpreter) { + // Allocate tensors when interpreter is created + allocateTensors(); + } /// Creates interpreter from model or throws if unsuccessful. factory Interpreter._create(Model model, {InterpreterOptions options}) { @@ -87,7 +90,7 @@ class Interpreter { /// Updates allocations for all tensors. void allocateTensors() { - checkState(!_allocated, message: 'Interpreter already allocated.'); +// checkState(!_allocated, message: 'Interpreter already allocated.'); checkState( tfLiteInterpreterAllocateTensors(_interpreter) == TfLiteStatus.ok); _allocated = true; @@ -115,25 +118,32 @@ class Interpreter { throw ArgumentError('Input error: Outputs should not be null or empty.'); } + var inputTensors = getInputTensors(); + + for (int i = 0; i < inputs.length; i++) { + var tensor = inputTensors.elementAt(i); + final newShape = tensor.getInputShapeIfDifferent(inputs[i]); + if (newShape != null) { + resizeInputTensor(i, newShape); + } + } + if (!_allocated) { allocateTensors(); _allocated = true; } - var inputTensors = getInputTensors(); - for (var i = 0; i < inputs.length; i++) { - if (inputTensors[i].shape != (inputs[i] as List).shape) { - resizeInputTensor(i, (inputs[i] as List).shape); - allocateTensors(); - inputTensors = getInputTensors(); - } - inputTensors[i].setTo(inputs[i]); + inputTensors = getInputTensors(); + + for (int i = 0; i < inputs.length; i++) { + inputTensors.elementAt(i).setTo(inputs[i]); } var inferenceStartNanos = DateTime.now().microsecondsSinceEpoch; invoke(); _lastNativeInferenceDurationMicroSeconds = DateTime.now().microsecondsSinceEpoch - inferenceStartNanos; + var outputTensors = getOutputTensors(); for (var i = 0; i < outputTensors.length; i++) { outputTensors[i].copyTo(outputs[i]); diff --git a/lib/src/tensor.dart b/lib/src/tensor.dart index 9277b1a..573780d 100644 --- a/lib/src/tensor.dart +++ b/lib/src/tensor.dart @@ -166,6 +166,8 @@ class Tensor { var obj; if (dst is Uint8List) { obj = bytes; + } else if (dst is ByteBuffer) { + obj = bytes.buffer; } else { obj = _convertBytesToObject(bytes); } @@ -182,6 +184,9 @@ class Tensor { if (o is Uint8List) { return o; } + if (o is ByteBuffer) { + return o.asUint8List(); + } var bytes = []; if (o is List) { for (var e in o) { @@ -327,6 +332,21 @@ class Tensor { } } + List getInputShapeIfDifferent(Object input) { + if (input == null) { + return null; + } + if (input is ByteBuffer || input is Uint8List) { + return null; + } + + final inputShape = computeShapeOf(input); + if (inputShape == shape) { + return null; + } + return inputShape; + } + @override String toString() { return 'Tensor{_tensor: $_tensor, name: $name, type: $type, shape: $shape, data: ${data.length}'; diff --git a/pubspec.yaml b/pubspec.yaml index 1ab1243..5d88a1f 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -1,6 +1,6 @@ name: tflite_flutter description: TensorFlow Lite Flutter plugin provides easy, flexible and fast Dart API to integrate TFLite models in flutter apps. -version: 0.3.0 +version: 0.4.0 homepage: https://github.com/am15h/tflite_flutter_plugin environment: