diff --git a/js/react_native/android/CMakeLists.txt b/js/react_native/android/CMakeLists.txt new file mode 100644 index 0000000000000..98f30daac6372 --- /dev/null +++ b/js/react_native/android/CMakeLists.txt @@ -0,0 +1,37 @@ +project(OnnxruntimeJSIHelper) +cmake_minimum_required(VERSION 3.9.0) + +set (PACKAGE_NAME "onnxruntime-react-native") +set (BUILD_DIR ${CMAKE_SOURCE_DIR}/build) +set(CMAKE_VERBOSE_MAKEFILE ON) +set(CMAKE_CXX_STANDARD 17) + +file(TO_CMAKE_PATH "${NODE_MODULES_DIR}/react-native/ReactCommon/jsi/jsi/jsi.cpp" libPath) + +include_directories( + "${NODE_MODULES_DIR}/react-native/React" + "${NODE_MODULES_DIR}/react-native/React/Base" + "${NODE_MODULES_DIR}/react-native/ReactCommon/jsi" +) + +add_library(onnxruntimejsihelper + SHARED + ${libPath} + src/main/cpp/cpp-adapter.cpp +) + +# Configure C++ 17 +set_target_properties( + onnxruntimejsihelper PROPERTIES + CXX_STANDARD 17 + CXX_EXTENSIONS OFF + POSITION_INDEPENDENT_CODE ON +) + +find_library(log-lib log) + +target_link_libraries( + onnxruntimejsihelper + ${log-lib} # <-- Logcat logger + android # <-- Android JNI core +) diff --git a/js/react_native/android/build.gradle b/js/react_native/android/build.gradle index abf56a59a09ae..7a99a0a2671d5 100644 --- a/js/react_native/android/build.gradle +++ b/js/react_native/android/build.gradle @@ -1,3 +1,5 @@ +import java.nio.file.Paths + buildscript { repositories { google() @@ -20,6 +22,32 @@ def getExtOrIntegerDefault(name) { return rootProject.ext.has(name) ? rootProject.ext.get(name) : (project.properties['OnnxruntimeModule_' + name]).toInteger() } +def reactNativeArchitectures() { + def value = project.getProperties().get("reactNativeArchitectures") + return value ? value.split(",") : ["armeabi-v7a", "x86", "x86_64", "arm64-v8a"] +} + +def resolveBuildType() { + Gradle gradle = getGradle() + String tskReqStr = gradle.getStartParameter().getTaskRequests()['args'].toString() + return tskReqStr.contains('Release') ? 'release' : 'debug' +} + +static def findNodeModules(baseDir) { + def basePath = baseDir.toPath().normalize() + while (basePath) { + def nodeModulesPath = Paths.get(basePath.toString(), "node_modules") + def reactNativePath = Paths.get(nodeModulesPath.toString(), "react-native") + if (nodeModulesPath.toFile().exists() && reactNativePath.toFile().exists()) { + return nodeModulesPath.toString() + } + basePath = basePath.getParent() + } + throw new GradleException("onnxruntime-react-native: Failed to find node_modules/ path!") +} + +def nodeModules = findNodeModules(projectDir); + def checkIfOrtExtensionsEnabled() { // locate user's project dir def reactnativeRootDir = project.rootDir.parentFile @@ -38,6 +66,9 @@ def checkIfOrtExtensionsEnabled() { boolean ortExtensionsEnabled = checkIfOrtExtensionsEnabled() +def REACT_NATIVE_VERSION = ['node', '--print', "JSON.parse(require('fs').readFileSync(require.resolve('react-native/package.json'), 'utf-8')).version"].execute(null, rootDir).text.trim() +def REACT_NATIVE_MINOR_VERSION = REACT_NATIVE_VERSION.split("\\.")[1].toInteger() + android { compileSdkVersion getExtOrIntegerDefault('compileSdkVersion') buildToolsVersion getExtOrDefault('buildToolsVersion') @@ -47,6 +78,44 @@ android { versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + externalNativeBuild { + cmake { + cppFlags "-O2 -frtti -fexceptions -Wall -Wno-unused-variable -fstack-protector-all" + if (REACT_NATIVE_MINOR_VERSION >= 71) { + // fabricjni required c++_shared + arguments "-DANDROID_STL=c++_shared", "-DNODE_MODULES_DIR=${nodeModules}", "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}" + } else { + arguments "-DNODE_MODULES_DIR=${nodeModules}", "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}" + } + abiFilters (*reactNativeArchitectures()) + } + } + } + + if (rootProject.hasProperty("ndkPath")) { + ndkPath rootProject.ext.ndkPath + } + if (rootProject.hasProperty("ndkVersion")) { + ndkVersion rootProject.ext.ndkVersion + } + + buildFeatures { + prefab true + } + + externalNativeBuild { + cmake { + path "CMakeLists.txt" + } + } + + packagingOptions { + doNotStrip resolveBuildType() == 'debug' ? "**/**/*.so" : '' + excludes = [ + "META-INF", + "META-INF/**", + "**/libjsi.so", + ] } buildTypes { @@ -149,8 +218,6 @@ repositories { } } -def REACT_NATIVE_VERSION = new File(['node', '--print', "JSON.parse(require('fs').readFileSync(require.resolve('react-native/package.json'), 'utf-8')).version"].execute(null, rootDir).text.trim()) - dependencies { api "com.facebook.react:react-native:" + REACT_NATIVE_VERSION api "org.mockito:mockito-core:2.28.2" diff --git a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/FakeBlobModule.java b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/FakeBlobModule.java new file mode 100644 index 0000000000000..82d063ad51e3f --- /dev/null +++ b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/FakeBlobModule.java @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package ai.onnxruntime.reactnative; + +import com.facebook.react.bridge.Arguments; +import com.facebook.react.bridge.JavaOnlyMap; +import com.facebook.react.bridge.ReactApplicationContext; +import com.facebook.react.bridge.ReadableMap; +import com.facebook.react.modules.blob.BlobModule; + +public class FakeBlobModule extends BlobModule { + + public FakeBlobModule(ReactApplicationContext context) { super(null); } + + @Override + public String getName() { + return "BlobModule"; + } + + public JavaOnlyMap testCreateData(byte[] bytes) { + String blobId = store(bytes); + JavaOnlyMap data = new JavaOnlyMap(); + data.putString("blobId", blobId); + data.putInt("offset", 0); + data.putInt("size", bytes.length); + return data; + } + + public byte[] testGetData(ReadableMap data) { + String blobId = data.getString("blobId"); + int offset = data.getInt("offset"); + int size = data.getInt("size"); + return resolve(blobId, offset, size); + } +} diff --git a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/OnnxruntimeModuleTest.java b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/OnnxruntimeModuleTest.java index 112d1c98608ec..12b790444975b 100644 --- a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/OnnxruntimeModuleTest.java +++ b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/OnnxruntimeModuleTest.java @@ -10,11 +10,14 @@ import android.util.Base64; import androidx.test.platform.app.InstrumentationRegistry; import com.facebook.react.bridge.Arguments; +import com.facebook.react.bridge.CatalystInstance; import com.facebook.react.bridge.JavaOnlyArray; import com.facebook.react.bridge.JavaOnlyMap; import com.facebook.react.bridge.ReactApplicationContext; import com.facebook.react.bridge.ReadableArray; import com.facebook.react.bridge.ReadableMap; +import com.facebook.react.bridge.WritableMap; +import com.facebook.react.modules.blob.BlobModule; import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.nio.ByteBuffer; @@ -29,12 +32,17 @@ public class OnnxruntimeModuleTest { private ReactApplicationContext reactContext = new ReactApplicationContext(InstrumentationRegistry.getInstrumentation().getContext()); + private FakeBlobModule blobModule; + @Before - public void setUp() {} + public void setUp() { + blobModule = new FakeBlobModule(reactContext); + } @Test public void getName() throws Exception { OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext); + ortModule.blobModule = blobModule; String name = "Onnxruntime"; Assert.assertEquals(ortModule.getName(), name); } @@ -47,6 +55,7 @@ public void onnxruntime_module() throws Exception { when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray()); OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext); + ortModule.blobModule = blobModule; String sessionKey = ""; // test loadModel() @@ -104,8 +113,7 @@ public void onnxruntime_module() throws Exception { floatBuffer.put(value); } floatBuffer.rewind(); - String dataEncoded = Base64.encodeToString(buffer.array(), Base64.DEFAULT); - inputTensorMap.putString("data", dataEncoded); + inputTensorMap.putMap("data", blobModule.testCreateData(buffer.array())); inputDataMap.putMap("input", inputTensorMap); } @@ -124,10 +132,9 @@ public void onnxruntime_module() throws Exception { Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); } Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeFloat); - String dataEncoded = outputMap.getString("data"); - FloatBuffer buffer = ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT)) - .order(ByteOrder.nativeOrder()) - .asFloatBuffer(); + ReadableMap data = outputMap.getMap("data"); + FloatBuffer buffer = + ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asFloatBuffer(); for (int i = 0; i < 5; ++i) { Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f); } diff --git a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java index f8caae96bbf86..76fd608e4362b 100644 --- a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java +++ b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java @@ -20,7 +20,9 @@ import com.facebook.react.bridge.Arguments; import com.facebook.react.bridge.JavaOnlyArray; import com.facebook.react.bridge.JavaOnlyMap; +import com.facebook.react.bridge.ReactApplicationContext; import com.facebook.react.bridge.ReadableMap; +import com.facebook.react.modules.blob.BlobModule; import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.nio.ByteBuffer; @@ -39,11 +41,17 @@ @SmallTest public class TensorHelperTest { + private ReactApplicationContext reactContext = + new ReactApplicationContext(InstrumentationRegistry.getInstrumentation().getContext()); + private OrtEnvironment ortEnvironment; + private FakeBlobModule blobModule; + @Before public void setUp() { ortEnvironment = OrtEnvironment.getEnvironment("TensorHelperTest"); + blobModule = new FakeBlobModule(reactContext); } @Test @@ -64,10 +72,9 @@ public void createInputTensor_float32() throws Exception { dataFloatBuffer.put(Float.MIN_VALUE); dataFloatBuffer.put(2.0f); dataFloatBuffer.put(Float.MAX_VALUE); - String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT); - inputTensorMap.putString("data", dataEncoded); + inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment); + OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); @@ -94,10 +101,9 @@ public void createInputTensor_int8() throws Exception { dataByteBuffer.put(Byte.MIN_VALUE); dataByteBuffer.put((byte)2); dataByteBuffer.put(Byte.MAX_VALUE); - String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT); - inputTensorMap.putString("data", dataEncoded); + inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment); + OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8); Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8); @@ -125,10 +131,9 @@ public void createInputTensor_uint8() throws Exception { dataByteBuffer.put((byte)0); dataByteBuffer.put((byte)2); dataByteBuffer.put((byte)255); - String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT); - inputTensorMap.putString("data", dataEncoded); + inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment); + OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); @@ -157,10 +162,9 @@ public void createInputTensor_int32() throws Exception { dataIntBuffer.put(Integer.MIN_VALUE); dataIntBuffer.put(2); dataIntBuffer.put(Integer.MAX_VALUE); - String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT); - inputTensorMap.putString("data", dataEncoded); + inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment); + OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32); Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32); @@ -189,10 +193,9 @@ public void createInputTensor_int64() throws Exception { dataLongBuffer.put(Long.MIN_VALUE); dataLongBuffer.put(15000000001L); dataLongBuffer.put(Long.MAX_VALUE); - String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT); - inputTensorMap.putString("data", dataEncoded); + inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment); + OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64); Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64); @@ -221,10 +224,9 @@ public void createInputTensor_double() throws Exception { dataDoubleBuffer.put(Double.MIN_VALUE); dataDoubleBuffer.put(1.8e+30); dataDoubleBuffer.put(Double.MAX_VALUE); - String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT); - inputTensorMap.putString("data", dataEncoded); + inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); - OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment); + OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE); Assert.assertEquals(outputTensor.getInfo().onnxType, @@ -258,14 +260,14 @@ public void createOutputTensor_bool() throws Exception { OrtSession.Result result = session.run(container); - ReadableMap resultMap = TensorHelper.createOutputTensor(result); + ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); ReadableMap outputMap = resultMap.getMap("output"); for (int i = 0; i < 2; ++i) { Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); } Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeBool); - String dataEncoded = outputMap.getString("data"); - ByteBuffer buffer = ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT)); + ReadableMap data = outputMap.getMap("data"); + ByteBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)); for (int i = 0; i < 5; ++i) { Assert.assertEquals(buffer.get(i) == 1, inputData[i]); } @@ -298,15 +300,15 @@ public void createOutputTensor_double() throws Exception { OrtSession.Result result = session.run(container); - ReadableMap resultMap = TensorHelper.createOutputTensor(result); + ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); ReadableMap outputMap = resultMap.getMap("output"); for (int i = 0; i < 2; ++i) { Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); } Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeDouble); - String dataEncoded = outputMap.getString("data"); + ReadableMap data = outputMap.getMap("data"); DoubleBuffer buffer = - ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT)).order(ByteOrder.nativeOrder()).asDoubleBuffer(); + ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asDoubleBuffer(); for (int i = 0; i < 5; ++i) { Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f); } @@ -339,15 +341,14 @@ public void createOutputTensor_float() throws Exception { OrtSession.Result result = session.run(container); - ReadableMap resultMap = TensorHelper.createOutputTensor(result); + ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); ReadableMap outputMap = resultMap.getMap("output"); for (int i = 0; i < 2; ++i) { Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); } Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeFloat); - String dataEncoded = outputMap.getString("data"); - FloatBuffer buffer = - ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT)).order(ByteOrder.nativeOrder()).asFloatBuffer(); + ReadableMap data = outputMap.getMap("data"); + FloatBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asFloatBuffer(); for (int i = 0; i < 5; ++i) { Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f); } @@ -380,14 +381,14 @@ public void createOutputTensor_int8() throws Exception { OrtSession.Result result = session.run(container); - ReadableMap resultMap = TensorHelper.createOutputTensor(result); + ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); ReadableMap outputMap = resultMap.getMap("output"); for (int i = 0; i < 2; ++i) { Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); } Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeByte); - String dataEncoded = outputMap.getString("data"); - ByteBuffer buffer = ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT)); + ReadableMap data = outputMap.getMap("data"); + ByteBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)); for (int i = 0; i < 5; ++i) { Assert.assertEquals(buffer.get(i), inputData[i]); } @@ -420,15 +421,14 @@ public void createOutputTensor_int32() throws Exception { OrtSession.Result result = session.run(container); - ReadableMap resultMap = TensorHelper.createOutputTensor(result); + ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); ReadableMap outputMap = resultMap.getMap("output"); for (int i = 0; i < 2; ++i) { Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); } Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeInt); - String dataEncoded = outputMap.getString("data"); - IntBuffer buffer = - ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT)).order(ByteOrder.nativeOrder()).asIntBuffer(); + ReadableMap data = outputMap.getMap("data"); + IntBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asIntBuffer(); for (int i = 0; i < 5; ++i) { Assert.assertEquals(buffer.get(i), inputData[i]); } @@ -461,15 +461,14 @@ public void createOutputTensor_int64() throws Exception { OrtSession.Result result = session.run(container); - ReadableMap resultMap = TensorHelper.createOutputTensor(result); + ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); ReadableMap outputMap = resultMap.getMap("output"); for (int i = 0; i < 2; ++i) { Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); } Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeLong); - String dataEncoded = outputMap.getString("data"); - LongBuffer buffer = - ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT)).order(ByteOrder.nativeOrder()).asLongBuffer(); + ReadableMap data = outputMap.getMap("data"); + LongBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asLongBuffer(); for (int i = 0; i < 5; ++i) { Assert.assertEquals(buffer.get(i), inputData[i]); } @@ -502,14 +501,14 @@ public void createOutputTensor_uint8() throws Exception { OrtSession.Result result = session.run(container); - ReadableMap resultMap = TensorHelper.createOutputTensor(result); + ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); ReadableMap outputMap = resultMap.getMap("output"); for (int i = 0; i < 2; ++i) { Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]); } Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeUnsignedByte); - String dataEncoded = outputMap.getString("data"); - ByteBuffer buffer = ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT)); + ReadableMap data = outputMap.getMap("data"); + ByteBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)); for (int i = 0; i < 5; ++i) { Assert.assertEquals(buffer.get(i), inputData[i]); } diff --git a/js/react_native/android/src/main/cpp/cpp-adapter.cpp b/js/react_native/android/src/main/cpp/cpp-adapter.cpp new file mode 100644 index 0000000000000..be1228bbfe959 --- /dev/null +++ b/js/react_native/android/src/main/cpp/cpp-adapter.cpp @@ -0,0 +1,127 @@ +#include +#include +#include + +using namespace facebook; + +typedef u_int8_t byte; + +std::string jstring2string(JNIEnv *env, jstring jStr) { + if (!jStr) return ""; + + jclass stringClass = env->GetObjectClass(jStr); + jmethodID getBytes = env->GetMethodID(stringClass, "getBytes", "(Ljava/lang/String;)[B"); + const auto stringJbytes = (jbyteArray) env->CallObjectMethod(jStr, getBytes, env->NewStringUTF("UTF-8")); + + auto length = (size_t) env->GetArrayLength(stringJbytes); + jbyte* pBytes = env->GetByteArrayElements(stringJbytes, nullptr); + + std::string ret = std::string((char *)pBytes, length); + env->ReleaseByteArrayElements(stringJbytes, pBytes, JNI_ABORT); + + env->DeleteLocalRef(stringJbytes); + env->DeleteLocalRef(stringClass); + return ret; +} + +byte* getBytesFromBlob(JNIEnv *env, jobject instanceGlobal, const std::string& blobId, int offset, int size) { + if (!env) throw std::runtime_error("JNI Environment is gone!"); + + // get java class + jclass clazz = env->GetObjectClass(instanceGlobal); + // get method in java class + jmethodID getBufferJava = env->GetMethodID(clazz, "getBlobBuffer", "(Ljava/lang/String;II)[B"); + // call method + auto jstring = env->NewStringUTF(blobId.c_str()); + auto boxedBytes = (jbyteArray) env->CallObjectMethod(instanceGlobal, + getBufferJava, + // arguments + jstring, + offset, + size); + env->DeleteLocalRef(jstring); + + jboolean isCopy = true; + jbyte* bytes = env->GetByteArrayElements(boxedBytes, &isCopy); + env->DeleteLocalRef(boxedBytes); + return reinterpret_cast(bytes); +}; + +std::string createBlob(JNIEnv *env, jobject instanceGlobal, byte* bytes, size_t size) { + if (!env) throw std::runtime_error("JNI Environment is gone!"); + + // get java class + jclass clazz = env->GetObjectClass(instanceGlobal); + // get method in java class + jmethodID getBufferJava = env->GetMethodID(clazz, "createBlob", "([B)Ljava/lang/String;"); + // call method + auto byteArray = env->NewByteArray(size); + env->SetByteArrayRegion(byteArray, 0, size, reinterpret_cast(bytes)); + auto blobId = (jstring) env->CallObjectMethod(instanceGlobal, getBufferJava, byteArray); + env->DeleteLocalRef(byteArray); + + return jstring2string(env, blobId); +}; + +extern "C" +JNIEXPORT void JNICALL +Java_ai_onnxruntime_reactnative_OnnxruntimeJSIHelper_nativeInstall(JNIEnv *env, jclass _, jlong jsiPtr, jobject instance) { + auto jsiRuntime = reinterpret_cast(jsiPtr); + + auto& runtime = *jsiRuntime; + + auto instanceGlobal = env->NewGlobalRef(instance); + + auto resolveArrayBuffer = jsi::Function::createFromHostFunction(runtime, + jsi::PropNameID::forAscii(runtime, "jsiOnnxruntimeResolveArrayBuffer"), + 1, + [=](jsi::Runtime& runtime, + const jsi::Value& thisValue, + const jsi::Value* arguments, + size_t count) -> jsi::Value { + if (count != 1) { + throw jsi::JSError(runtime, "jsiOnnxruntimeResolveArrayBuffer(..) expects one argument (object)!"); + } + + jsi::Object data = arguments[0].asObject(runtime); + auto blobId = data.getProperty(runtime, "blobId").asString(runtime); + auto offset = data.getProperty(runtime, "offset").asNumber(); + auto size = data.getProperty(runtime, "size").asNumber(); + + auto bytes = getBytesFromBlob(env, instanceGlobal, blobId.utf8(runtime), offset, size); + + size_t totalSize = size - offset; + jsi::Function arrayBufferCtor = runtime.global().getPropertyAsFunction(runtime, "ArrayBuffer"); + jsi::Object o = arrayBufferCtor.callAsConstructor(runtime, (int) totalSize).getObject(runtime); + jsi::ArrayBuffer buf = o.getArrayBuffer(runtime); + memcpy(buf.data(runtime), reinterpret_cast(bytes), totalSize); + + return buf; + }); + runtime.global().setProperty(runtime, "jsiOnnxruntimeResolveArrayBuffer", std::move(resolveArrayBuffer)); + + auto storeArrayBuffer = jsi::Function::createFromHostFunction(runtime, + jsi::PropNameID::forAscii(runtime, "jsiOnnxruntimeStoreArrayBuffer"), + 1, + [=](jsi::Runtime& runtime, + const jsi::Value& thisValue, + const jsi::Value* arguments, + size_t count) -> jsi::Value { + if (count != 1) { + throw jsi::JSError(runtime, "jsiOnnxruntimeStoreArrayBuffer(..) expects one argument (object)!"); + } + + auto arrayBuffer = arguments[0].asObject(runtime).getArrayBuffer(runtime); + auto size = arrayBuffer.size(runtime); + + std::string blobId = createBlob(env, instanceGlobal, arrayBuffer.data(runtime), size); + + jsi::Object result(runtime); + auto blobIdString = jsi::String::createFromUtf8(runtime, blobId); + result.setProperty(runtime, "blobId", blobIdString); + result.setProperty(runtime, "offset", jsi::Value(0)); + result.setProperty(runtime, "size", jsi::Value(static_cast(size))); + return result; + }); + runtime.global().setProperty(runtime, "jsiOnnxruntimeStoreArrayBuffer", std::move(storeArrayBuffer)); +} diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeJSIHelper.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeJSIHelper.java new file mode 100644 index 0000000000000..93b37df0768b4 --- /dev/null +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeJSIHelper.java @@ -0,0 +1,70 @@ +package ai.onnxruntime.reactnative; + +import androidx.annotation.NonNull; +import com.facebook.react.bridge.JavaScriptContextHolder; +import com.facebook.react.bridge.ReactApplicationContext; +import com.facebook.react.bridge.ReactContextBaseJavaModule; +import com.facebook.react.bridge.ReactMethod; +import com.facebook.react.module.annotations.ReactModule; +import com.facebook.react.modules.blob.BlobModule; + +@ReactModule(name = OnnxruntimeJSIHelper.NAME) +public class OnnxruntimeJSIHelper extends ReactContextBaseJavaModule { + public static final String NAME = "OnnxruntimeJSIHelper"; + + private static ReactApplicationContext reactContext; + protected BlobModule blobModule; + + public OnnxruntimeJSIHelper(ReactApplicationContext context) { + super(context); + reactContext = context; + } + + @Override + @NonNull + public String getName() { + return NAME; + } + + public void checkBlobModule() { + if (blobModule == null) { + blobModule = getReactApplicationContext().getNativeModule(BlobModule.class); + if (blobModule == null) { + throw new RuntimeException("BlobModule is not initialized"); + } + } + } + + @ReactMethod(isBlockingSynchronousMethod = true) + public boolean install() { + try { + System.loadLibrary("onnxruntimejsihelper"); + JavaScriptContextHolder jsContext = getReactApplicationContext().getJavaScriptContextHolder(); + nativeInstall(jsContext.get(), this); + return true; + } catch (Exception exception) { + return false; + } + } + + public byte[] getBlobBuffer(String blobId, int offset, int size) { + checkBlobModule(); + byte[] bytes = blobModule.resolve(blobId, offset, size); + blobModule.remove(blobId); + if (bytes == null) { + throw new RuntimeException("Failed to resolve Blob #" + blobId + "! Not found."); + } + return bytes; + } + + public String createBlob(byte[] buffer) { + checkBlobModule(); + String blobId = blobModule.store(buffer); + if (blobId == null) { + throw new RuntimeException("Failed to create Blob!"); + } + return blobId; + } + + public static native void nativeInstall(long jsiPointer, OnnxruntimeJSIHelper instance); +} diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java index 685c8d0643f28..6ef550cde23f6 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java @@ -13,7 +13,6 @@ import ai.onnxruntime.OrtSession.SessionOptions; import android.net.Uri; import android.os.Build; -import android.util.Base64; import android.util.Log; import androidx.annotation.NonNull; import androidx.annotation.RequiresApi; @@ -28,6 +27,7 @@ import com.facebook.react.bridge.ReadableType; import com.facebook.react.bridge.WritableArray; import com.facebook.react.bridge.WritableMap; +import com.facebook.react.modules.blob.BlobModule; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; @@ -56,6 +56,8 @@ private static String getNextSessionKey() { return key; } + protected BlobModule blobModule; + public OnnxruntimeModule(ReactApplicationContext context) { super(context); reactContext = context; @@ -67,6 +69,15 @@ public String getName() { return "Onnxruntime"; } + public void checkBlobModule() { + if (blobModule == null) { + blobModule = getReactApplicationContext().getNativeModule(BlobModule.class); + if (blobModule == null) { + throw new RuntimeException("BlobModule is not initialized"); + } + } + } + /** * React native binding API to load a model using given uri. * @@ -87,19 +98,22 @@ public void loadModel(String uri, ReadableMap options, Promise promise) { } /** - * React native binding API to load a model using the BASE64 encoded model data. + * React native binding API to load a model using blob object that data stored in BlobModule. * - * @param data the BASE64 encoded model data. + * @param data the blob object * @param options onnxruntime session options * @param promise output returning back to react native js * @note the value provided to `promise` includes a key representing the session. * when run() is called, the key must be passed into the first parameter. */ @ReactMethod - public void loadModelFromBase64EncodedBuffer(String data, ReadableMap options, Promise promise) { + public void loadModelFromBlob(ReadableMap data, ReadableMap options, Promise promise) { try { - byte[] modelData = Base64.decode(data, Base64.DEFAULT); - WritableMap resultMap = loadModel(modelData, options); + checkBlobModule(); + String blobId = data.getString("blobId"); + byte[] bytes = blobModule.resolve(blobId, data.getInt("offset"), data.getInt("size")); + blobModule.remove(blobId); + WritableMap resultMap = loadModel(bytes, options); promise.resolve(resultMap); } catch (Exception e) { promise.reject("Failed to load model from buffer: " + e.getMessage(), e); @@ -242,6 +256,8 @@ public WritableMap run(String key, ReadableMap input, ReadableArray output, Read RunOptions runOptions = parseRunOptions(options); + checkBlobModule(); + long startTime = System.currentTimeMillis(); Map feed = new HashMap<>(); Iterator iterator = ortSession.getInputNames().iterator(); @@ -255,19 +271,7 @@ public WritableMap run(String key, ReadableMap input, ReadableArray output, Read throw new Exception("Can't find input: " + inputName); } - if (inputMap.getType("data") != ReadableType.String) { - // NOTE: - // - // tensor data should always be a BASE64 encoded string. - // This is because the current React Native bridge supports limited data type as arguments. - // In order to pass data from JS to Java, we have to encode them into string. - // - // see also: - // https://reactnative.dev/docs/native-modules-android#argument-types - throw new Exception("Non string type of a tensor data is not allowed"); - } - - OnnxTensor onnxTensor = TensorHelper.createInputTensor(inputMap, ortEnvironment); + OnnxTensor onnxTensor = TensorHelper.createInputTensor(blobModule, inputMap, ortEnvironment); feed.put(inputName, onnxTensor); } @@ -292,7 +296,7 @@ public WritableMap run(String key, ReadableMap input, ReadableArray output, Read Log.d("Duration", "inference: " + duration); startTime = System.currentTimeMillis(); - WritableMap resultMap = TensorHelper.createOutputTensor(result); + WritableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); duration = System.currentTimeMillis() - startTime; Log.d("Duration", "createOutputTensor: " + duration); diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimePackage.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimePackage.java index b2ccbf10c7f9a..bb4386a0953f3 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimePackage.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimePackage.java @@ -9,6 +9,7 @@ import com.facebook.react.ReactPackage; import com.facebook.react.bridge.NativeModule; import com.facebook.react.bridge.ReactApplicationContext; +import com.facebook.react.modules.blob.BlobModule; import com.facebook.react.uimanager.ViewManager; import java.util.ArrayList; import java.util.Collections; @@ -21,6 +22,7 @@ public class OnnxruntimePackage implements ReactPackage { public List createNativeModules(@NonNull ReactApplicationContext reactContext) { List modules = new ArrayList<>(); modules.add(new OnnxruntimeModule(reactContext)); + modules.add(new OnnxruntimeJSIHelper(reactContext)); return modules; } diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java index 500141ab51c49..d9c2e3bac5d9b 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java @@ -16,6 +16,7 @@ import com.facebook.react.bridge.ReadableMap; import com.facebook.react.bridge.WritableArray; import com.facebook.react.bridge.WritableMap; +import com.facebook.react.modules.blob.BlobModule; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.DoubleBuffer; @@ -45,9 +46,10 @@ public class TensorHelper { /** * It creates an input tensor from a map passed by react native js. - * 'data' must be a string type as data is encoded as base64. It first decodes it and creates a tensor. + * 'data' is blob object and the buffer is stored in BlobModule. It first resolve it and creates a tensor. */ - public static OnnxTensor createInputTensor(ReadableMap inputTensor, OrtEnvironment ortEnvironment) throws Exception { + public static OnnxTensor createInputTensor(BlobModule blobModule, ReadableMap inputTensor, + OrtEnvironment ortEnvironment) throws Exception { // shape ReadableArray dimsArray = inputTensor.getArray("dims"); long[] dims = new long[dimsArray.size()]; @@ -68,8 +70,11 @@ public static OnnxTensor createInputTensor(ReadableMap inputTensor, OrtEnvironme } onnxTensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims); } else { - String data = inputTensor.getString("data"); - ByteBuffer values = ByteBuffer.wrap(Base64.decode(data, Base64.DEFAULT)).order(ByteOrder.nativeOrder()); + ReadableMap data = inputTensor.getMap("data"); + String blobId = data.getString("blobId"); + byte[] bytes = blobModule.resolve(blobId, data.getInt("offset"), data.getInt("size")); + blobModule.remove(blobId); + ByteBuffer values = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()); onnxTensor = createInputTensor(tensorType, dims, values, ortEnvironment); } @@ -78,9 +83,9 @@ public static OnnxTensor createInputTensor(ReadableMap inputTensor, OrtEnvironme /** * It creates an output map from an output tensor. - * a data array is encoded as base64 string. + * a data array is store in BlobModule. */ - public static WritableMap createOutputTensor(OrtSession.Result result) throws Exception { + public static WritableMap createOutputTensor(BlobModule blobModule, OrtSession.Result result) throws Exception { WritableMap outputTensorMap = Arguments.createMap(); Iterator> iterator = result.iterator(); @@ -115,8 +120,13 @@ public static WritableMap createOutputTensor(OrtSession.Result result) throws Ex } outputTensor.putArray("data", dataArray); } else { - String data = createOutputTensor(onnxTensor); - outputTensor.putString("data", data); + // Store in BlobModule then create a blob object as data + byte[] bufferArray = createOutputTensor(onnxTensor); + WritableMap data = Arguments.createMap(); + data.putString("blobId", blobModule.store(bufferArray)); + data.putInt("offset", 0); + data.putInt("size", bufferArray.length); + outputTensor.putMap("data", data); } outputTensorMap.putMap(outputName, outputTensor); @@ -177,7 +187,7 @@ private static OnnxTensor createInputTensor(TensorInfo.OnnxTensorType tensorType return tensor; } - private static String createOutputTensor(OnnxTensor onnxTensor) throws Exception { + private static byte[] createOutputTensor(OnnxTensor onnxTensor) throws Exception { TensorInfo tensorInfo = onnxTensor.getInfo(); ByteBuffer buffer = null; @@ -224,8 +234,7 @@ private static String createOutputTensor(OnnxTensor onnxTensor) throws Exception throw new IllegalStateException("Unexpected type: " + tensorInfo.onnxType.toString()); } - String data = Base64.encodeToString(buffer.array(), Base64.DEFAULT); - return data; + return buffer.array(); } private static final Map JsTensorTypeToOnnxTensorTypeMap = diff --git a/js/react_native/ios/OnnxruntimeJSIHelper.h b/js/react_native/ios/OnnxruntimeJSIHelper.h new file mode 100644 index 0000000000000..990a4d1879ece --- /dev/null +++ b/js/react_native/ios/OnnxruntimeJSIHelper.h @@ -0,0 +1,5 @@ +#import + +@interface OnnxruntimeJSIHelper : NSObject + +@end diff --git a/js/react_native/ios/OnnxruntimeJSIHelper.mm b/js/react_native/ios/OnnxruntimeJSIHelper.mm new file mode 100644 index 0000000000000..6fac00cefaedb --- /dev/null +++ b/js/react_native/ios/OnnxruntimeJSIHelper.mm @@ -0,0 +1,85 @@ +#import "OnnxruntimeJSIHelper.h" +#import +#import +#import + +@implementation OnnxruntimeJSIHelper + +RCT_EXPORT_MODULE() + +RCT_EXPORT_BLOCKING_SYNCHRONOUS_METHOD(install) { + RCTBridge *bridge = [RCTBridge currentBridge]; + RCTCxxBridge *cxxBridge = (RCTCxxBridge *)bridge; + if (cxxBridge == nil) { + return @false; + } + + using namespace facebook; + + auto jsiRuntime = (jsi::Runtime *)cxxBridge.runtime; + if (jsiRuntime == nil) { + return @false; + } + auto &runtime = *jsiRuntime; + + auto resolveArrayBuffer = jsi::Function::createFromHostFunction( + runtime, jsi::PropNameID::forUtf8(runtime, "jsiOnnxruntimeResolveArrayBuffer"), 1, + [](jsi::Runtime &runtime, const jsi::Value &thisArg, const jsi::Value *args, size_t count) -> jsi::Value { + if (count != 1) { + throw jsi::JSError(runtime, "jsiOnnxruntimeResolveArrayBuffer(..) expects one argument (object)!"); + } + + auto data = args[0].asObject(runtime); + auto blobId = data.getProperty(runtime, "blobId").asString(runtime).utf8(runtime); + auto size = data.getProperty(runtime, "size").asNumber(); + auto offset = data.getProperty(runtime, "offset").asNumber(); + + RCTBlobManager *blobManager = [[RCTBridge currentBridge] moduleForClass:RCTBlobManager.class]; + if (blobManager == nil) { + throw jsi::JSError(runtime, "RCTBlobManager is not initialized"); + } + + NSString *blobIdStr = [NSString stringWithUTF8String:blobId.c_str()]; + auto blob = [blobManager resolve:blobIdStr offset:(long)offset size:(long)size]; + + jsi::Function arrayBufferCtor = runtime.global().getPropertyAsFunction(runtime, "ArrayBuffer"); + jsi::Object o = arrayBufferCtor.callAsConstructor(runtime, (int)blob.length).getObject(runtime); + jsi::ArrayBuffer buf = o.getArrayBuffer(runtime); + memcpy(buf.data(runtime), blob.bytes, blob.length); + [blobManager remove:blobIdStr]; + return buf; + }); + runtime.global().setProperty(runtime, "jsiOnnxruntimeResolveArrayBuffer", resolveArrayBuffer); + + auto storeArrayBuffer = jsi::Function::createFromHostFunction( + runtime, jsi::PropNameID::forUtf8(runtime, "jsiOnnxruntimeStoreArrayBuffer"), 1, + [](jsi::Runtime &runtime, const jsi::Value &thisArg, const jsi::Value *args, size_t count) -> jsi::Value { + if (count != 1) { + throw jsi::JSError(runtime, "jsiOnnxruntimeStoreArrayBuffer(..) expects one argument (object)!"); + } + + auto arrayBuffer = args[0].asObject(runtime).getArrayBuffer(runtime); + auto size = arrayBuffer.length(runtime); + NSData *data = [NSData dataWithBytesNoCopy:arrayBuffer.data(runtime) length:size freeWhenDone:NO]; + + RCTBlobManager *blobManager = [[RCTBridge currentBridge] moduleForClass:RCTBlobManager.class]; + if (blobManager == nil) { + throw jsi::JSError(runtime, "RCTBlobManager is not initialized"); + } + + NSString *blobId = [blobManager store:data]; + + jsi::Object result(runtime); + auto blobIdString = jsi::String::createFromUtf8(runtime, [blobId cStringUsingEncoding:NSUTF8StringEncoding]); + result.setProperty(runtime, "blobId", blobIdString); + result.setProperty(runtime, "offset", jsi::Value(0)); + result.setProperty(runtime, "size", jsi::Value(static_cast(size))); + return result; + }); + + runtime.global().setProperty(runtime, "jsiOnnxruntimeStoreArrayBuffer", storeArrayBuffer); + + return @true; +} + +@end diff --git a/js/react_native/ios/OnnxruntimeModule.h b/js/react_native/ios/OnnxruntimeModule.h index 4ea3d21f2088f..24603cc648525 100644 --- a/js/react_native/ios/OnnxruntimeModule.h +++ b/js/react_native/ios/OnnxruntimeModule.h @@ -5,9 +5,12 @@ #define OnnxruntimeModule_h #import +#import @interface OnnxruntimeModule : NSObject +- (void)setBlobManager:(RCTBlobManager *)manager; + -(NSDictionary*)loadModel:(NSString*)modelPath options:(NSDictionary*)options; diff --git a/js/react_native/ios/OnnxruntimeModule.mm b/js/react_native/ios/OnnxruntimeModule.mm index fa34583cb0281..117b7379748f2 100644 --- a/js/react_native/ios/OnnxruntimeModule.mm +++ b/js/react_native/ios/OnnxruntimeModule.mm @@ -5,6 +5,8 @@ #import "TensorHelper.h" #import +#import +#import #import // Note: Using below syntax for including ort c api and ort extensions headers to resolve a compiling error happened @@ -44,6 +46,21 @@ - (NSString *)getNextSessionKey { RCT_EXPORT_MODULE(Onnxruntime) +RCTBlobManager *blobManager = nil; + +- (void)checkBlobManager { + if (blobManager == nil) { + blobManager = [[RCTBridge currentBridge] moduleForClass:RCTBlobManager.class]; + if (blobManager == nil) { + @throw @"RCTBlobManager is not initialized"; + } + } +} + +- (void)setBlobManager:(RCTBlobManager *)manager { + blobManager = manager; +} + /** * React native binding API to load a model using given uri. * @@ -68,22 +85,27 @@ - (NSString *)getNextSessionKey { } /** - * React native binding API to load a model using BASE64 encoded model data string. + * React native binding API to load a model using blob object that data stored in RCTBlobManager. * - * @param modelData the BASE64 encoded model data string + * @param modelDataBlob a model data blob object * @param options onnxruntime session options * @param resolve callback for returning output back to react native js * @param reject callback for returning an error back to react native js * @note when run() is called, the same modelPath must be passed into the first parameter. */ -RCT_EXPORT_METHOD(loadModelFromBase64EncodedBuffer - : (NSString *)modelDataBase64EncodedString options +RCT_EXPORT_METHOD(loadModelFromBlob + : (NSDictionary *)modelDataBlob options : (NSDictionary *)options resolver : (RCTPromiseResolveBlock)resolve rejecter : (RCTPromiseRejectBlock)reject) { @try { - NSData *modelDataDecoded = [[NSData alloc] initWithBase64EncodedString:modelDataBase64EncodedString options:0]; - NSDictionary *resultMap = [self loadModelFromBuffer:modelDataDecoded options:options]; + [self checkBlobManager]; + NSString *blobId = [modelDataBlob objectForKey:@"blobId"]; + long size = [[modelDataBlob objectForKey:@"size"] longValue]; + long offset = [[modelDataBlob objectForKey:@"offset"] longValue]; + auto modelData = [blobManager resolve:blobId offset:offset size:size]; + NSDictionary *resultMap = [self loadModelFromBuffer:modelData options:options]; + [blobManager remove:blobId]; resolve(resultMap); } @catch (...) { reject(@"onnxruntime", @"failed to load model from buffer", nil); @@ -255,6 +277,8 @@ - (NSDictionary *)run:(NSString *)url } SessionInfo *sessionInfo = (SessionInfo *)[value pointerValue]; + [self checkBlobManager]; + std::vector feeds; std::vector allocations; feeds.reserve(sessionInfo->inputNames.size()); @@ -265,7 +289,10 @@ - (NSDictionary *)run:(NSString *)url @throw exception; } - Ort::Value value = [TensorHelper createInputTensor:inputTensor ortAllocator:ortAllocator allocations:allocations]; + Ort::Value value = [TensorHelper createInputTensor:blobManager + input:inputTensor + ortAllocator:ortAllocator + allocations:allocations]; feeds.emplace_back(std::move(value)); } @@ -280,7 +307,7 @@ - (NSDictionary *)run:(NSString *)url sessionInfo->session->Run(runOptions, sessionInfo->inputNames.data(), feeds.data(), sessionInfo->inputNames.size(), requestedOutputs.data(), requestedOutputs.size()); - NSDictionary *resultMap = [TensorHelper createOutputTensor:requestedOutputs values:result]; + NSDictionary *resultMap = [TensorHelper createOutputTensor:blobManager outputNames:requestedOutputs values:result]; return resultMap; } @@ -378,6 +405,7 @@ - (void)dealloc { while (NSString *key = [iterator nextObject]) { [self dispose:key]; } + blobManager = nullptr; } @end diff --git a/js/react_native/ios/OnnxruntimeModule.xcodeproj/project.pbxproj b/js/react_native/ios/OnnxruntimeModule.xcodeproj/project.pbxproj index 23f33be0cdc26..2a093b2b89c95 100644 --- a/js/react_native/ios/OnnxruntimeModule.xcodeproj/project.pbxproj +++ b/js/react_native/ios/OnnxruntimeModule.xcodeproj/project.pbxproj @@ -7,6 +7,9 @@ objects = { /* Begin PBXBuildFile section */ + 0105483CF04B9471894F3EAA /* Pods_OnnxruntimeModuleTest.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 38EB61A518C2DF782F7CD433 /* Pods_OnnxruntimeModuleTest.framework */; }; + 7FD234672A1F221700734B71 /* FakeRCTBlobManager.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FD234662A1F221700734B71 /* FakeRCTBlobManager.m */; }; + C60033360456900E26D6F96F /* Pods_OnnxruntimeModule.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 49D0ADD02E7162A5F0DE8BAB /* Pods_OnnxruntimeModule.framework */; }; DB8FC9B525C2867800C72F26 /* OnnxruntimeModule.mm in Sources */ = {isa = PBXBuildFile; fileRef = DB8FC9B425C2867800C72F26 /* OnnxruntimeModule.mm */; }; DB8FC9B825C2868700C72F26 /* TensorHelper.mm in Sources */ = {isa = PBXBuildFile; fileRef = DB8FC9B725C2868700C72F26 /* TensorHelper.mm */; }; DBDB57DA2603211A004F16BE /* TensorHelperTest.mm in Sources */ = {isa = PBXBuildFile; fileRef = DBDB57D92603211A004F16BE /* TensorHelperTest.mm */; }; @@ -39,6 +42,14 @@ /* Begin PBXFileReference section */ 134814201AA4EA6300B7C361 /* libOnnxruntimeModule.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = libOnnxruntimeModule.a; sourceTree = BUILT_PRODUCTS_DIR; }; + 38EB61A518C2DF782F7CD433 /* Pods_OnnxruntimeModuleTest.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_OnnxruntimeModuleTest.framework; sourceTree = BUILT_PRODUCTS_DIR; }; + 49D0ADD02E7162A5F0DE8BAB /* Pods_OnnxruntimeModule.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_OnnxruntimeModule.framework; sourceTree = BUILT_PRODUCTS_DIR; }; + 5391B4C0B7C168594AA0DD0B /* Pods-OnnxruntimeModuleTest.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-OnnxruntimeModuleTest.debug.xcconfig"; path = "Target Support Files/Pods-OnnxruntimeModuleTest/Pods-OnnxruntimeModuleTest.debug.xcconfig"; sourceTree = ""; }; + 548638FE75FCC69C842C9545 /* Pods-OnnxruntimeModule.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-OnnxruntimeModule.release.xcconfig"; path = "Target Support Files/Pods-OnnxruntimeModule/Pods-OnnxruntimeModule.release.xcconfig"; sourceTree = ""; }; + 63B05EB079B0A4D99448F1D3 /* Pods-OnnxruntimeModule.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-OnnxruntimeModule.debug.xcconfig"; path = "Target Support Files/Pods-OnnxruntimeModule/Pods-OnnxruntimeModule.debug.xcconfig"; sourceTree = ""; }; + 7FD234662A1F221700734B71 /* FakeRCTBlobManager.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = FakeRCTBlobManager.m; sourceTree = ""; }; + 7FD234682A1F234500734B71 /* FakeRCTBlobManager.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = FakeRCTBlobManager.h; sourceTree = ""; }; + 8529D8A6F40E462E62B38B52 /* Pods-OnnxruntimeModuleTest.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-OnnxruntimeModuleTest.release.xcconfig"; path = "Target Support Files/Pods-OnnxruntimeModuleTest/Pods-OnnxruntimeModuleTest.release.xcconfig"; sourceTree = ""; }; DB8FC9B425C2867800C72F26 /* OnnxruntimeModule.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = OnnxruntimeModule.mm; sourceTree = SOURCE_ROOT; }; DB8FC9B725C2868700C72F26 /* TensorHelper.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = TensorHelper.mm; sourceTree = SOURCE_ROOT; }; DBDB57D72603211A004F16BE /* OnnxruntimeModuleTest.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = OnnxruntimeModuleTest.xctest; sourceTree = BUILT_PRODUCTS_DIR; }; @@ -53,6 +64,7 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( + C60033360456900E26D6F96F /* Pods_OnnxruntimeModule.framework in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -61,6 +73,7 @@ buildActionMask = 2147483647; files = ( DBDB57DC2603211A004F16BE /* libOnnxruntimeModule.a in Frameworks */, + 0105483CF04B9471894F3EAA /* Pods_OnnxruntimeModuleTest.framework in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -84,16 +97,30 @@ 134814211AA4EA7D00B7C361 /* Products */, 62ED2272D9F9CF7E3D0A8F87 /* Pods */, DBDB57D72603211A004F16BE /* OnnxruntimeModuleTest.xctest */, + 6FFDF1594C99DA125B013E34 /* Frameworks */, ); sourceTree = ""; }; 62ED2272D9F9CF7E3D0A8F87 /* Pods */ = { isa = PBXGroup; children = ( + 63B05EB079B0A4D99448F1D3 /* Pods-OnnxruntimeModule.debug.xcconfig */, + 548638FE75FCC69C842C9545 /* Pods-OnnxruntimeModule.release.xcconfig */, + 5391B4C0B7C168594AA0DD0B /* Pods-OnnxruntimeModuleTest.debug.xcconfig */, + 8529D8A6F40E462E62B38B52 /* Pods-OnnxruntimeModuleTest.release.xcconfig */, ); path = Pods; sourceTree = ""; }; + 6FFDF1594C99DA125B013E34 /* Frameworks */ = { + isa = PBXGroup; + children = ( + 49D0ADD02E7162A5F0DE8BAB /* Pods_OnnxruntimeModule.framework */, + 38EB61A518C2DF782F7CD433 /* Pods_OnnxruntimeModuleTest.framework */, + ); + name = Frameworks; + sourceTree = ""; + }; DB8FC9B325C2861300C72F26 /* OnnxruntimeModule */ = { isa = PBXGroup; children = ( @@ -109,6 +136,8 @@ DBDB57D92603211A004F16BE /* TensorHelperTest.mm */, DBDB57DB2603211A004F16BE /* Info.plist */, DBDB58AF262A92D6004F16BE /* OnnxruntimeModuleTest.mm */, + 7FD234662A1F221700734B71 /* FakeRCTBlobManager.m */, + 7FD234682A1F234500734B71 /* FakeRCTBlobManager.h */, ); path = OnnxruntimeModuleTest; sourceTree = ""; @@ -120,6 +149,7 @@ isa = PBXNativeTarget; buildConfigurationList = 58B511EF1A9E6C8500147676 /* Build configuration list for PBXNativeTarget "OnnxruntimeModule" */; buildPhases = ( + FA8BD7B76BD8BD02A6DB750A /* [CP] Check Pods Manifest.lock */, 58B511D71A9E6C8500147676 /* Sources */, 58B511D81A9E6C8500147676 /* Frameworks */, 58B511D91A9E6C8500147676 /* CopyFiles */, @@ -137,9 +167,11 @@ isa = PBXNativeTarget; buildConfigurationList = DBDB57E12603211A004F16BE /* Build configuration list for PBXNativeTarget "OnnxruntimeModuleTest" */; buildPhases = ( + 896E89AEC864CBD0CC7E0AF1 /* [CP] Check Pods Manifest.lock */, DBDB57D32603211A004F16BE /* Sources */, DBDB57D42603211A004F16BE /* Frameworks */, DBDB57D52603211A004F16BE /* Resources */, + 015C75E59BC80D4507FB6E8A /* [CP] Embed Pods Frameworks */, ); buildRules = ( ); @@ -200,6 +232,119 @@ }; /* End PBXResourcesBuildPhase section */ +/* Begin PBXShellScriptBuildPhase section */ + 015C75E59BC80D4507FB6E8A /* [CP] Embed Pods Frameworks */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + "${PODS_ROOT}/Target Support Files/Pods-OnnxruntimeModuleTest/Pods-OnnxruntimeModuleTest-frameworks.sh", + "${BUILT_PRODUCTS_DIR}/DoubleConversion/DoubleConversion.framework", + "${BUILT_PRODUCTS_DIR}/RCT-Folly/folly.framework", + "${BUILT_PRODUCTS_DIR}/RCTTypeSafety/RCTTypeSafety.framework", + "${BUILT_PRODUCTS_DIR}/React-Codegen/React_Codegen.framework", + "${BUILT_PRODUCTS_DIR}/React-Core/React.framework", + "${BUILT_PRODUCTS_DIR}/React-CoreModules/CoreModules.framework", + "${BUILT_PRODUCTS_DIR}/React-RCTAnimation/RCTAnimation.framework", + "${BUILT_PRODUCTS_DIR}/React-RCTBlob/RCTBlob.framework", + "${BUILT_PRODUCTS_DIR}/React-RCTImage/RCTImage.framework", + "${BUILT_PRODUCTS_DIR}/React-RCTLinking/RCTLinking.framework", + "${BUILT_PRODUCTS_DIR}/React-RCTNetwork/RCTNetwork.framework", + "${BUILT_PRODUCTS_DIR}/React-RCTSettings/RCTSettings.framework", + "${BUILT_PRODUCTS_DIR}/React-RCTText/RCTText.framework", + "${BUILT_PRODUCTS_DIR}/React-RCTVibration/RCTVibration.framework", + "${BUILT_PRODUCTS_DIR}/React-bridging/react_bridging.framework", + "${BUILT_PRODUCTS_DIR}/React-cxxreact/cxxreact.framework", + "${BUILT_PRODUCTS_DIR}/React-jsi/jsi.framework", + "${BUILT_PRODUCTS_DIR}/React-jsiexecutor/jsireact.framework", + "${BUILT_PRODUCTS_DIR}/React-jsinspector/jsinspector.framework", + "${BUILT_PRODUCTS_DIR}/React-logger/logger.framework", + "${BUILT_PRODUCTS_DIR}/React-perflogger/reactperflogger.framework", + "${BUILT_PRODUCTS_DIR}/ReactCommon/ReactCommon.framework", + "${BUILT_PRODUCTS_DIR}/Yoga/yoga.framework", + "${BUILT_PRODUCTS_DIR}/fmt/fmt.framework", + "${BUILT_PRODUCTS_DIR}/glog/glog.framework", + ); + name = "[CP] Embed Pods Frameworks"; + outputPaths = ( + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/DoubleConversion.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/folly.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTTypeSafety.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/React_Codegen.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/React.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/CoreModules.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTAnimation.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTBlob.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTImage.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTLinking.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTNetwork.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTSettings.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTText.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTVibration.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/react_bridging.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/cxxreact.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/jsi.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/jsireact.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/jsinspector.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/logger.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/reactperflogger.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/ReactCommon.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/yoga.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/fmt.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/glog.framework", + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${PODS_ROOT}/Target Support Files/Pods-OnnxruntimeModuleTest/Pods-OnnxruntimeModuleTest-frameworks.sh\"\n"; + showEnvVarsInLog = 0; + }; + 896E89AEC864CBD0CC7E0AF1 /* [CP] Check Pods Manifest.lock */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputFileListPaths = ( + ); + inputPaths = ( + "${PODS_PODFILE_DIR_PATH}/Podfile.lock", + "${PODS_ROOT}/Manifest.lock", + ); + name = "[CP] Check Pods Manifest.lock"; + outputFileListPaths = ( + ); + outputPaths = ( + "$(DERIVED_FILE_DIR)/Pods-OnnxruntimeModuleTest-checkManifestLockResult.txt", + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n"; + showEnvVarsInLog = 0; + }; + FA8BD7B76BD8BD02A6DB750A /* [CP] Check Pods Manifest.lock */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputFileListPaths = ( + ); + inputPaths = ( + "${PODS_PODFILE_DIR_PATH}/Podfile.lock", + "${PODS_ROOT}/Manifest.lock", + ); + name = "[CP] Check Pods Manifest.lock"; + outputFileListPaths = ( + ); + outputPaths = ( + "$(DERIVED_FILE_DIR)/Pods-OnnxruntimeModule-checkManifestLockResult.txt", + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n"; + showEnvVarsInLog = 0; + }; +/* End PBXShellScriptBuildPhase section */ + /* Begin PBXSourcesBuildPhase section */ 58B511D71A9E6C8500147676 /* Sources */ = { isa = PBXSourcesBuildPhase; @@ -216,6 +361,7 @@ files = ( DBDB57DA2603211A004F16BE /* TensorHelperTest.mm in Sources */, DBDB58B0262A92D7004F16BE /* OnnxruntimeModuleTest.mm in Sources */, + 7FD234672A1F221700734B71 /* FakeRCTBlobManager.m in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -329,6 +475,7 @@ }; 58B511F01A9E6C8500147676 /* Debug */ = { isa = XCBuildConfiguration; + baseConfigurationReference = 63B05EB079B0A4D99448F1D3 /* Pods-OnnxruntimeModule.debug.xcconfig */; buildSettings = { HEADER_SEARCH_PATHS = ( "$(inherited)", @@ -352,6 +499,7 @@ }; 58B511F11A9E6C8500147676 /* Release */ = { isa = XCBuildConfiguration; + baseConfigurationReference = 548638FE75FCC69C842C9545 /* Pods-OnnxruntimeModule.release.xcconfig */; buildSettings = { HEADER_SEARCH_PATHS = ( "$(inherited)", @@ -374,6 +522,7 @@ }; DBDB57DF2603211A004F16BE /* Debug */ = { isa = XCBuildConfiguration; + baseConfigurationReference = 5391B4C0B7C168594AA0DD0B /* Pods-OnnxruntimeModuleTest.debug.xcconfig */; buildSettings = { CLANG_ANALYZER_NONNULL = YES; CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; @@ -446,6 +595,7 @@ }; DBDB57E02603211A004F16BE /* Release */ = { isa = XCBuildConfiguration; + baseConfigurationReference = 8529D8A6F40E462E62B38B52 /* Pods-OnnxruntimeModuleTest.release.xcconfig */; buildSettings = { CLANG_ANALYZER_NONNULL = YES; CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; diff --git a/js/react_native/ios/OnnxruntimeModuleTest/FakeRCTBlobManager.h b/js/react_native/ios/OnnxruntimeModuleTest/FakeRCTBlobManager.h new file mode 100644 index 0000000000000..c6069b1a1d26d --- /dev/null +++ b/js/react_native/ios/OnnxruntimeModuleTest/FakeRCTBlobManager.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef FakeRCTBlobManager_h +#define FakeRCTBlobManager_h + +#import + +@interface FakeRCTBlobManager : RCTBlobManager + +@property (nonatomic, strong) NSMutableDictionary *blobs; + +- (NSString *)store:(NSData *)data; + +- (NSData *)resolve:(NSString *)blobId offset:(long)offset size:(long)size; + +- (NSDictionary *)testCreateData:(NSData *)buffer; + +- (NSString *)testGetData:(NSDictionary *)data; + +@end + +#endif /* FakeRCTBlobManager_h */ diff --git a/js/react_native/ios/OnnxruntimeModuleTest/FakeRCTBlobManager.m b/js/react_native/ios/OnnxruntimeModuleTest/FakeRCTBlobManager.m new file mode 100644 index 0000000000000..5df902df03534 --- /dev/null +++ b/js/react_native/ios/OnnxruntimeModuleTest/FakeRCTBlobManager.m @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#import +#import "FakeRCTBlobManager.h" + +@implementation FakeRCTBlobManager + +- (instancetype)init { + if (self = [super init]) { + _blobs = [NSMutableDictionary new]; + } + return self; +} + +- (NSString *)store:(NSData *)data { + NSString *blobId = [[NSUUID UUID] UUIDString]; + _blobs[blobId] = data; + return blobId; +} + +- (NSData *)resolve:(NSString *)blobId offset:(long)offset size:(long)size { + NSData *data = _blobs[blobId]; + if (data == nil) { + return nil; + } + return [data subdataWithRange:NSMakeRange(offset, size)]; +} + +- (NSDictionary *)testCreateData:(NSData *)buffer { + NSString* blobId = [self store:buffer]; + return @{ + @"blobId": blobId, + @"offset": @0, + @"size": @(buffer.length), + }; +} + +- (NSString *)testGetData:(NSDictionary *)data { + NSString *blobId = [data objectForKey:@"blobId"]; + long size = [[data objectForKey:@"size"] longValue]; + long offset = [[data objectForKey:@"offset"] longValue]; + [self resolve:blobId offset:offset size:size]; + return blobId; +} + +@end diff --git a/js/react_native/ios/OnnxruntimeModuleTest/OnnxruntimeModuleTest.mm b/js/react_native/ios/OnnxruntimeModuleTest/OnnxruntimeModuleTest.mm index 86bf35229da5a..03aa40bcc339d 100644 --- a/js/react_native/ios/OnnxruntimeModuleTest/OnnxruntimeModuleTest.mm +++ b/js/react_native/ios/OnnxruntimeModuleTest/OnnxruntimeModuleTest.mm @@ -2,6 +2,7 @@ // Licensed under the MIT License. #import "OnnxruntimeModule.h" +#import "FakeRCTBlobManager.h" #import "TensorHelper.h" #import @@ -13,6 +14,14 @@ @interface OnnxruntimeModuleTest : XCTestCase @implementation OnnxruntimeModuleTest +FakeRCTBlobManager *fakeBlobManager = nil; + ++ (void)initialize { + if (self == [OnnxruntimeModuleTest class]) { + fakeBlobManager = [FakeRCTBlobManager new]; + } +} + - (void)testOnnxruntimeModule { NSBundle *bundle = [NSBundle bundleForClass:[OnnxruntimeModuleTest class]]; NSString *dataPath = [bundle pathForResource:@"test_types_float" ofType:@"ort"]; @@ -20,6 +29,7 @@ - (void)testOnnxruntimeModule { NSString *sessionKey2 = @""; OnnxruntimeModule *onnxruntimeModule = [OnnxruntimeModule new]; + [onnxruntimeModule setBlobManager:fakeBlobManager]; { // test loadModelFromBuffer() @@ -70,8 +80,8 @@ - (void)testOnnxruntimeModule { } floatPtr = (float *)[byteBufferRef bytes]; - NSString *dataEncoded = [byteBufferRef base64EncodedStringWithOptions:0]; - inputTensorMap[@"data"] = dataEncoded; + XCTAssertNotNil(fakeBlobManager); + inputTensorMap[@"data"] = [fakeBlobManager testCreateData:byteBufferRef]; NSMutableDictionary *inputDataMap = [NSMutableDictionary dictionary]; inputDataMap[@"input"] = inputTensorMap; @@ -84,8 +94,18 @@ - (void)testOnnxruntimeModule { NSDictionary *resultMap = [onnxruntimeModule run:sessionKey input:inputDataMap output:output options:options]; NSDictionary *resultMap2 = [onnxruntimeModule run:sessionKey2 input:inputDataMap output:output options:options]; - XCTAssertTrue([[resultMap objectForKey:@"output"] isEqualToDictionary:inputTensorMap]); - XCTAssertTrue([[resultMap2 objectForKey:@"output"] isEqualToDictionary:inputTensorMap]); + // Compare output & input, but data.blobId is different + // dims + XCTAssertTrue([[resultMap objectForKey:@"output"][@"dims"] isEqualToArray:inputTensorMap[@"dims"]]); + XCTAssertTrue([[resultMap2 objectForKey:@"output"][@"dims"] isEqualToArray:inputTensorMap[@"dims"]]); + + // type + XCTAssertEqual([resultMap objectForKey:@"output"][@"type"], JsTensorTypeFloat); + XCTAssertEqual([resultMap2 objectForKey:@"output"][@"type"], JsTensorTypeFloat); + + // data ({ blobId, offset, size }) + XCTAssertEqual([[resultMap objectForKey:@"output"][@"data"][@"offset"] longValue], 0); + XCTAssertEqual([[resultMap2 objectForKey:@"output"][@"data"][@"size"] longValue], byteBufferSize); } // test dispose diff --git a/js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm b/js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm index 3ed082e22237f..42e903a9b45e3 100644 --- a/js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm +++ b/js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm @@ -3,6 +3,7 @@ #import "TensorHelper.h" +#import "FakeRCTBlobManager.h" #import #import #include @@ -13,6 +14,14 @@ @interface TensorHelperTest : XCTestCase @implementation TensorHelperTest +FakeRCTBlobManager *testBlobManager = nil; + ++ (void)initialize { + if (self == [TensorHelperTest class]) { + testBlobManager = [FakeRCTBlobManager new]; + } +} + template static void testCreateInputTensorT(const std::array &outValues, std::function &convert, ONNXTensorElementDataType onnxType, NSString *jsTensorType) { @@ -34,12 +43,13 @@ static void testCreateInputTensorT(const std::array &outValues, std::funct typePtr[i] = outValues[i]; } - NSString *dataEncoded = [byteBufferRef base64EncodedStringWithOptions:0]; - inputTensorMap[@"data"] = dataEncoded; + XCTAssertNotNil(testBlobManager); + inputTensorMap[@"data"] = [testBlobManager testCreateData:byteBufferRef]; Ort::AllocatorWithDefaultOptions ortAllocator; std::vector allocations; - Ort::Value inputTensor = [TensorHelper createInputTensor:inputTensorMap + Ort::Value inputTensor = [TensorHelper createInputTensor:testBlobManager + input:inputTensorMap ortAllocator:ortAllocator allocations:allocations]; @@ -126,7 +136,8 @@ - (void)testCreateInputTensorString { Ort::AllocatorWithDefaultOptions ortAllocator; std::vector allocations; - Ort::Value inputTensor = [TensorHelper createInputTensor:inputTensorMap + Ort::Value inputTensor = [TensorHelper createInputTensor:testBlobManager + input:inputTensorMap ortAllocator:ortAllocator allocations:allocations]; @@ -194,10 +205,11 @@ static void testCreateOutputTensorT(const std::array &outValues, std::func typePtr[i] = outValues[i]; } - NSString *dataEncoded = [byteBufferRef base64EncodedStringWithOptions:0]; - inputTensorMap[@"data"] = dataEncoded; + inputTensorMap[@"data"] = [testBlobManager testCreateData:byteBufferRef]; + ; std::vector allocations; - Ort::Value inputTensor = [TensorHelper createInputTensor:inputTensorMap + Ort::Value inputTensor = [TensorHelper createInputTensor:testBlobManager + input:inputTensorMap ortAllocator:ortAllocator allocations:allocations]; @@ -208,9 +220,24 @@ static void testCreateOutputTensorT(const std::array &outValues, std::func auto output = session.Run(runOptions, inputNames.data(), feeds.data(), inputNames.size(), outputNames.data(), outputNames.size()); - NSDictionary *resultMap = [TensorHelper createOutputTensor:outputNames values:output]; + NSDictionary *resultMap = [TensorHelper createOutputTensor:testBlobManager outputNames:outputNames values:output]; + + // Compare output & input, but data.blobId is different + + NSDictionary *outputMap = [resultMap objectForKey:@"output"]; + + // dims + XCTAssertTrue([outputMap[@"dims"] isEqualToArray:inputTensorMap[@"dims"]]); + + // type + XCTAssertEqual(outputMap[@"type"], jsTensorType); + + // data ({ blobId, offset, size }) + NSDictionary *data = outputMap[@"data"]; - XCTAssertTrue([[resultMap objectForKey:@"output"] isEqualToDictionary:inputTensorMap]); + XCTAssertNotNil(data[@"blobId"]); + XCTAssertEqual([data[@"offset"] longValue], 0); + XCTAssertEqual([data[@"size"] longValue], byteBufferSize); } - (void)testCreateOutputTensorFloat { diff --git a/js/react_native/ios/TensorHelper.h b/js/react_native/ios/TensorHelper.h index 73794921263f3..4d2aad5f4fcb8 100644 --- a/js/react_native/ios/TensorHelper.h +++ b/js/react_native/ios/TensorHelper.h @@ -5,6 +5,7 @@ #define TensorHelper_h #import +#import // Note: Using below syntax for including ort c api and ort extensions headers to resolve a compiling error happened // in an expo react native ios app (a redefinition error happened with multiple object types defined within @@ -36,17 +37,19 @@ FOUNDATION_EXPORT NSString* const JsTensorTypeString; /** * It creates an input tensor from a map passed by react native js. - * 'data' must be a string type as data is encoded as base64. It first decodes it and creates a tensor. + * 'data' is blob object and the buffer is stored in RCTBlobManager. It first resolve it and creates a tensor. */ -+(Ort::Value)createInputTensor:(NSDictionary*)input ++(Ort::Value)createInputTensor:(RCTBlobManager *)blobManager + input:(NSDictionary*)input ortAllocator:(OrtAllocator*)ortAllocator - allocations:(std::vector&)allocatons; + allocations:(std::vector&)allocations; /** * It creates an output map from an output tensor. - * a data array is encoded as base64 string. + * a data array is store in RCTBlobManager. */ -+(NSDictionary*)createOutputTensor:(const std::vector&)outputNames ++(NSDictionary*)createOutputTensor:(RCTBlobManager *)blobManager + outputNames:(const std::vector&)outputNames values:(const std::vector&)values; @end diff --git a/js/react_native/ios/TensorHelper.mm b/js/react_native/ios/TensorHelper.mm index 00c1c79defd88..1d6a3a3b79ed4 100644 --- a/js/react_native/ios/TensorHelper.mm +++ b/js/react_native/ios/TensorHelper.mm @@ -21,11 +21,12 @@ @implementation TensorHelper /** * It creates an input tensor from a map passed by react native js. - * 'data' must be a string type as data is encoded as base64. It first decodes it and creates a tensor. + * 'data' is blob object and the buffer is stored in RCTBlobManager. It first resolve it and creates a tensor. */ -+ (Ort::Value)createInputTensor:(NSDictionary *)input ++ (Ort::Value)createInputTensor:(RCTBlobManager *)blobManager + input:(NSDictionary *)input ortAllocator:(OrtAllocator *)ortAllocator - allocations:(std::vector &)allocatons { + allocations:(std::vector &)allocations { // shape NSArray *dimsArray = [input objectForKey:@"dims"]; std::vector dims; @@ -48,22 +49,27 @@ @implementation TensorHelper } return inputTensor; } else { - NSString *data = [input objectForKey:@"data"]; - NSData *buffer = [[NSData alloc] initWithBase64EncodedString:data options:0]; + NSDictionary *data = [input objectForKey:@"data"]; + NSString *blobId = [data objectForKey:@"blobId"]; + long size = [[data objectForKey:@"size"] longValue]; + long offset = [[data objectForKey:@"offset"] longValue]; + auto buffer = [blobManager resolve:blobId offset:offset size:size]; Ort::Value inputTensor = [self createInputTensor:tensorType dims:dims buffer:buffer ortAllocator:ortAllocator - allocations:allocatons]; + allocations:allocations]; + [blobManager remove:blobId]; return inputTensor; } } /** * It creates an output map from an output tensor. - * a data array is encoded as base64 string. + * a data array is store in RCTBlobManager. */ -+ (NSDictionary *)createOutputTensor:(const std::vector &)outputNames ++ (NSDictionary *)createOutputTensor:(RCTBlobManager *)blobManager + outputNames:(const std::vector &)outputNames values:(const std::vector &)values { if (outputNames.size() != values.size()) { NSException *exception = [NSException exceptionWithName:@"create output tensor" @@ -109,8 +115,13 @@ + (NSDictionary *)createOutputTensor:(const std::vector &)outputNa } outputTensor[@"data"] = buffer; } else { - NSString *data = [self createOutputTensor:value]; - outputTensor[@"data"] = data; + NSData *data = [self createOutputTensor:value]; + NSString *blobId = [blobManager store:data]; + outputTensor[@"data"] = @{ + @"blobId" : blobId, + @"offset" : @0, + @"size" : @(data.length), + }; } outputTensorMap[[NSString stringWithUTF8String:outputName]] = outputTensor; @@ -170,15 +181,14 @@ + (NSDictionary *)createOutputTensor:(const std::vector &)outputNa } } -template static NSString *createOutputTensorT(const Ort::Value &tensor) { +template static NSData *createOutputTensorT(const Ort::Value &tensor) { const auto data = tensor.GetTensorData(); - NSData *buffer = [NSData dataWithBytesNoCopy:(void *)data - length:tensor.GetTensorTypeAndShapeInfo().GetElementCount() * sizeof(T) - freeWhenDone:false]; - return [buffer base64EncodedStringWithOptions:0]; + return [NSData dataWithBytesNoCopy:(void *)data + length:tensor.GetTensorTypeAndShapeInfo().GetElementCount() * sizeof(T) + freeWhenDone:false]; } -+ (NSString *)createOutputTensor:(const Ort::Value &)tensor { ++ (NSData *)createOutputTensor:(const Ort::Value &)tensor { ONNXTensorElementDataType tensorType = tensor.GetTensorTypeAndShapeInfo().GetElementType(); switch (tensorType) { diff --git a/js/react_native/lib/backend.ts b/js/react_native/lib/backend.ts index fa6d5ceaef994..b3f0c466308a5 100644 --- a/js/react_native/lib/backend.ts +++ b/js/react_native/lib/backend.ts @@ -1,11 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Buffer} from 'buffer'; import {Backend, InferenceSession, SessionHandler, Tensor,} from 'onnxruntime-common'; import {Platform} from 'react-native'; -import {binding, Binding} from './binding'; +import {binding, Binding, JSIBlob, jsiHelper} from './binding'; type SupportedTypedArray = Exclude; @@ -69,11 +68,11 @@ class OnnxruntimeSessionHandler implements SessionHandler { if (typeof this.#pathOrBuffer === 'string') { results = await this.#inferenceSession.loadModel(normalizePath(this.#pathOrBuffer), options); } else { - if (!this.#inferenceSession.loadModelFromBase64EncodedBuffer) { - throw new Error('Native module method "loadModelFromBase64EncodedBuffer" is not defined'); + if (!this.#inferenceSession.loadModelFromBlob) { + throw new Error('Native module method "loadModelFromBlob" is not defined'); } - const modelInBase64String = Buffer.from(this.#pathOrBuffer).toString('base64'); - results = await this.#inferenceSession.loadModelFromBase64EncodedBuffer(modelInBase64String, options); + const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer); + results = await this.#inferenceSession.loadModelFromBlob(modelBlob, options); } // resolve promise if onnxruntime session is successfully created this.#key = results.key; @@ -113,18 +112,18 @@ class OnnxruntimeSessionHandler implements SessionHandler { return output; } + encodeFeedsType(feeds: SessionHandler.FeedsType): Binding.FeedsType { const returnValue: {[name: string]: Binding.EncodedTensorType} = {}; for (const key in feeds) { if (Object.hasOwnProperty.call(feeds, key)) { - let data: string|string[]; + let data: JSIBlob|string[]; if (Array.isArray(feeds[key].data)) { data = feeds[key].data as string[]; } else { - // Base64-encode tensor data const buffer = (feeds[key].data as SupportedTypedArray).buffer; - data = Buffer.from(buffer, 0, buffer.byteLength).toString('base64'); + data = jsiHelper.storeArrayBuffer(buffer); } returnValue[key] = { @@ -146,9 +145,9 @@ class OnnxruntimeSessionHandler implements SessionHandler { if (Array.isArray(results[key].data)) { tensorData = results[key].data as string[]; } else { - const buffer: Buffer = Buffer.from(results[key].data as string, 'base64'); + const buffer = jsiHelper.resolveArrayBuffer(results[key].data as JSIBlob) as SupportedTypedArray; const typedArray = tensorTypeToTypedArray(results[key].type as Tensor.Type); - tensorData = new typedArray(buffer.buffer, buffer.byteOffset, buffer.length / typedArray.BYTES_PER_ELEMENT); + tensorData = new typedArray(buffer, buffer.byteOffset, buffer.byteLength / typedArray.BYTES_PER_ELEMENT); } returnValue[key] = new Tensor(results[key].type as Tensor.Type, tensorData, results[key].dims); diff --git a/js/react_native/lib/binding.ts b/js/react_native/lib/binding.ts index 6d621fa3dc3df..5ecf85dcd25ab 100644 --- a/js/react_native/lib/binding.ts +++ b/js/react_native/lib/binding.ts @@ -26,7 +26,14 @@ interface ModelLoadInfo { } /** - * Tensor type for react native, which doesn't allow ArrayBuffer, so data will be encoded as Base64 string. + * JSIBlob is a blob object that exchange ArrayBuffer by OnnxruntimeJSIHelper. + */ +export type JSIBlob = { + blobId: string; offset: number; size: number; +}; + +/** + * Tensor type for react native, which doesn't allow ArrayBuffer in native bridge, so data will be stored as JSIBlob. */ interface EncodedTensor { /** @@ -38,10 +45,10 @@ interface EncodedTensor { */ readonly type: string; /** - * the Base64 encoded string of the buffer data of the tensor. - * if data is string array, it won't be encoded as Base64 string. + * the JSIBlob object of the buffer data of the tensor. + * if data is string array, it won't be stored as JSIBlob. */ - readonly data: string|string[]; + readonly data: JSIBlob|string[]; } /** @@ -64,12 +71,41 @@ export declare namespace Binding { interface InferenceSession { loadModel(modelPath: string, options: SessionOptions): Promise; - loadModelFromBase64EncodedBuffer?(buffer: string, options: SessionOptions): Promise; + loadModelFromBlob?(blob: JSIBlob, options: SessionOptions): Promise; dispose(key: string): Promise; run(key: string, feeds: FeedsType, fetches: FetchesType, options: RunOptions): Promise; } } // export native binding -const {Onnxruntime} = NativeModules; +const {Onnxruntime, OnnxruntimeJSIHelper} = NativeModules; export const binding = Onnxruntime as Binding.InferenceSession; + +// install JSI helper global functions +OnnxruntimeJSIHelper.install(); + +declare global { + // eslint-disable-next-line no-var + var jsiOnnxruntimeStoreArrayBuffer: ((buffer: ArrayBuffer) => JSIBlob)|undefined; + // eslint-disable-next-line no-var + var jsiOnnxruntimeResolveArrayBuffer: ((blob: JSIBlob) => ArrayBuffer)|undefined; +} + +export const jsiHelper = { + storeArrayBuffer: globalThis.jsiOnnxruntimeStoreArrayBuffer || (() => { + throw new Error( + 'jsiOnnxruntimeStoreArrayBuffer is not found, ' + + 'please make sure OnnxruntimeJSIHelper installation is successful.'); + }), + resolveArrayBuffer: globalThis.jsiOnnxruntimeResolveArrayBuffer || (() => { + throw new Error( + 'jsiOnnxruntimeResolveArrayBuffer is not found, ' + + 'please make sure OnnxruntimeJSIHelper installation is successful.'); + }), +}; + +// Remove global functions after installation +{ + delete globalThis.jsiOnnxruntimeStoreArrayBuffer; + delete globalThis.jsiOnnxruntimeResolveArrayBuffer; +}