Skip to content

Commit

Permalink
[v0.4.0] Support Uint8List, Bytebuffer in run, bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
am15h committed Jun 18, 2020
1 parent b8176dc commit e5e4003
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 26 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## 0.4.0 (Jun 18, 2020)
* run supports UintList8 and ByteBuffer objects
* Bug fix, resize input tensor
* Improved efficiency

## 0.3.0
* New features
* multi-dimensional reshape with type
Expand Down
2 changes: 1 addition & 1 deletion example/pubspec.lock
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ packages:
path: ".."
relative: true
source: path
version: "0.3.0"
version: "0.4.0"
typed_data:
dependency: transitive
description:
Expand Down
14 changes: 0 additions & 14 deletions example/test/tflite_flutter_plugin_example_e2e.dart
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,6 @@ void main() {
interpreter.allocateTensors();
});

test('allocate throws if already allocated', () {
interpreter.allocateTensors();
expect(() => interpreter.allocateTensors(), throwsA(isStateError));
});

test('invoke throws if not allocated', () {
expect(() => interpreter.invoke(), throwsA(isStateError));
});

test('invoke throws if not allocated after resized', () {
interpreter.allocateTensors();
interpreter.resizeInputTensor(0, [1, 2, 4]);
Expand Down Expand Up @@ -171,11 +162,6 @@ void main() {
expect(tensors[0].data, hasLength(4));
});

test('set throws if not allocated', () {
expect(() => tensors[0].data = Uint8List.fromList(const [0, 0, 0, 0]),
throwsA(isStateError));
});

test('set', () {
interpreter.allocateTensors();
tensors[0].data = Uint8List.fromList(const [0, 0, 0, 0]);
Expand Down
30 changes: 20 additions & 10 deletions lib/src/interpreter.dart
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ class Interpreter {
int get lastNativeInferenceDurationMicroSeconds =>
_lastNativeInferenceDurationMicroSeconds;

Interpreter._(this._interpreter);
Interpreter._(this._interpreter) {
// Allocate tensors when interpreter is created
allocateTensors();
}

/// Creates interpreter from model or throws if unsuccessful.
factory Interpreter._create(Model model, {InterpreterOptions options}) {
Expand Down Expand Up @@ -87,7 +90,7 @@ class Interpreter {

/// Updates allocations for all tensors.
void allocateTensors() {
checkState(!_allocated, message: 'Interpreter already allocated.');
// checkState(!_allocated, message: 'Interpreter already allocated.');
checkState(
tfLiteInterpreterAllocateTensors(_interpreter) == TfLiteStatus.ok);
_allocated = true;
Expand Down Expand Up @@ -115,25 +118,32 @@ class Interpreter {
throw ArgumentError('Input error: Outputs should not be null or empty.');
}

var inputTensors = getInputTensors();

for (int i = 0; i < inputs.length; i++) {
var tensor = inputTensors.elementAt(i);
final newShape = tensor.getInputShapeIfDifferent(inputs[i]);
if (newShape != null) {
resizeInputTensor(i, newShape);
}
}

if (!_allocated) {
allocateTensors();
_allocated = true;
}

var inputTensors = getInputTensors();
for (var i = 0; i < inputs.length; i++) {
if (inputTensors[i].shape != (inputs[i] as List).shape) {
resizeInputTensor(i, (inputs[i] as List).shape);
allocateTensors();
inputTensors = getInputTensors();
}
inputTensors[i].setTo(inputs[i]);
inputTensors = getInputTensors();

for (int i = 0; i < inputs.length; i++) {
inputTensors.elementAt(i).setTo(inputs[i]);
}

var inferenceStartNanos = DateTime.now().microsecondsSinceEpoch;
invoke();
_lastNativeInferenceDurationMicroSeconds =
DateTime.now().microsecondsSinceEpoch - inferenceStartNanos;

var outputTensors = getOutputTensors();
for (var i = 0; i < outputTensors.length; i++) {
outputTensors[i].copyTo(outputs[i]);
Expand Down
20 changes: 20 additions & 0 deletions lib/src/tensor.dart
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ class Tensor {
var obj;
if (dst is Uint8List) {
obj = bytes;
} else if (dst is ByteBuffer) {
obj = bytes.buffer;
} else {
obj = _convertBytesToObject(bytes);
}
Expand All @@ -182,6 +184,9 @@ class Tensor {
if (o is Uint8List) {
return o;
}
if (o is ByteBuffer) {
return o.asUint8List();
}
var bytes = <int>[];
if (o is List) {
for (var e in o) {
Expand Down Expand Up @@ -327,6 +332,21 @@ class Tensor {
}
}

List<int> getInputShapeIfDifferent(Object input) {
if (input == null) {
return null;
}
if (input is ByteBuffer || input is Uint8List) {
return null;
}

final inputShape = computeShapeOf(input);
if (inputShape == shape) {
return null;
}
return inputShape;
}

@override
String toString() {
return 'Tensor{_tensor: $_tensor, name: $name, type: $type, shape: $shape, data: ${data.length}';
Expand Down
2 changes: 1 addition & 1 deletion pubspec.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: tflite_flutter
description: TensorFlow Lite Flutter plugin provides easy, flexible and fast Dart API to integrate TFLite models in flutter apps.
version: 0.3.0
version: 0.4.0
homepage: https://github.com/am15h/tflite_flutter_plugin

environment:
Expand Down

0 comments on commit e5e4003

Please sign in to comment.