diff --git a/app/build.gradle b/app/build.gradle
index 0f17218cd..ab769cfb9 100644
--- a/app/build.gradle
+++ b/app/build.gradle
@@ -64,7 +64,7 @@ android {
applicationIdSuffix ".debug"
}
release {
-// debuggable true
+ debuggable false
shrinkResources false
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
diff --git a/app/proguard-rules.pro b/app/proguard-rules.pro
index 6ccb2012b..0e3606353 100644
--- a/app/proguard-rules.pro
+++ b/app/proguard-rules.pro
@@ -28,11 +28,14 @@
-dontwarn java.lang.invoke.*
-dontwarn **$$Lambda$*
+-keepnames class com.stardust.autojs.onnx.YoloV8Predictor
+
-keep class org.mozilla.javascript.** { *; }
-keep class com.jecelyin.editor.** { *; }
-keep class com.stardust.automator.** { *; }
-keep class com.stardust.autojs.** { *; }
-keep class org.greenrobot.eventbus.** { *; }
+-keep class org.autojs.autojs.** {*;}
-keep class * extends c
-keepattributes *Annotation*
# Event bus
@@ -138,4 +141,16 @@
# Bugly
-dontwarn com.tencent.bugly.**
--keep public class com.tencent.bugly.**{*;}
\ No newline at end of file
+-keep public class com.tencent.bugly.**{*;}
+
+# OnnxRuntime
+-dontwarn ai.onnxruntime.**
+-keep public class ai.onnxruntime.**{*;}
+
+# okHttp3
+-dontwarn okhttp3.**
+-keep public class okhttp3.**{*;}
+
+# paddleocr
+-dontwarn com.baidu.paddle.lite.ocr.**
+-keep public class com.baidu.paddle.lite.ocr.**{*;}
\ No newline at end of file
diff --git a/app/src/main/AndroidManifest.xml b/app/src/main/AndroidManifest.xml
index 662d32c5b..7b707473c 100644
--- a/app/src/main/AndroidManifest.xml
+++ b/app/src/main/AndroidManifest.xml
@@ -45,7 +45,9 @@
android:requestLegacyExternalStorage="true"
tools:replace="android:label, android:icon, android:allowBackup"
tools:targetApi="m">
-
+
+
diff --git a/autojs-aar/paddleocr/src/main/cpp/ocr_db_post_process.cpp b/autojs-aar/paddleocr/src/main/cpp/ocr_db_post_process.cpp
index 9816ea4ac..a2cf342f9 100644
--- a/autojs-aar/paddleocr/src/main/cpp/ocr_db_post_process.cpp
+++ b/autojs-aar/paddleocr/src/main/cpp/ocr_db_post_process.cpp
@@ -64,6 +64,13 @@ static cv::RotatedRect unclip(float **box) {
return res;
}
+// 用于动态释放内存 避免内存泄露
+static void free_array(float **array, int rows) {
+ for (int i = 0; i < rows; ++i) {
+ delete[] array[i];
+ }
+ delete[] array;
+}
static float **Mat2Vec(cv::Mat mat) {
auto **array = new float *[mat.rows];
@@ -267,6 +274,7 @@ boxes_from_bitmap(const cv::Mat &pred, const cv::Mat &bitmap) {
// end get_mini_box
if (ssid < min_size) {
+ free_array(array, 4);
continue;
}
@@ -274,6 +282,7 @@ boxes_from_bitmap(const cv::Mat &pred, const cv::Mat &bitmap) {
score = box_score_fast(array, pred);
// end box_score_fast
if (score < box_thresh) {
+ free_array(array, 4);
continue;
}
@@ -284,8 +293,11 @@ boxes_from_bitmap(const cv::Mat &pred, const cv::Mat &bitmap) {
cv::RotatedRect clipbox = points;
auto cliparray = get_mini_boxes(clipbox, ssid);
- if (ssid < min_size + 2)
+ if (ssid < min_size + 2) {
+ free_array(array, 4);
+ free_array(cliparray, 4);
continue;
+ }
int dest_width = pred.cols;
int dest_height = pred.rows;
@@ -301,6 +313,8 @@ boxes_from_bitmap(const cv::Mat &pred, const cv::Mat &bitmap) {
intcliparray.emplace_back(std::move(a));
}
boxes.emplace_back(std::move(intcliparray));
+ free_array(array, 4);
+ free_array(cliparray, 4);
} // end for
return boxes;
diff --git a/autojs/build.gradle b/autojs/build.gradle
index 7b516be01..b2e51b0bf 100644
--- a/autojs/build.gradle
+++ b/autojs/build.gradle
@@ -15,7 +15,7 @@ android {
}
buildTypes {
release {
- minifyEnabled false
+ minifyEnabled true
proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
}
}
diff --git a/autojs/proguard-rules.pro b/autojs/proguard-rules.pro
index 27b6816cb..e588db103 100644
--- a/autojs/proguard-rules.pro
+++ b/autojs/proguard-rules.pro
@@ -23,3 +23,5 @@
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
+#-keep public class com.stardust.autojs.onnx.YoloV8Predictor
+#-keepnames class com.stardust.autojs.onnx.YoloV8Predictor
diff --git a/autojs/src/main/assets/modules/__$ocr__.js b/autojs/src/main/assets/modules/__$ocr__.js
index 055deaa08..de0890308 100644
--- a/autojs/src/main/assets/modules/__$ocr__.js
+++ b/autojs/src/main/assets/modules/__$ocr__.js
@@ -31,11 +31,13 @@ module.exports = function(runtime, global) {
let region = options.region
if (region) {
let r = buildRegion(region, img)
- let o = img
img = images.clip(img, r.x, r.y, r.width, r.height)
- o.recycle()
}
let text = javaOcr.recognizeText(img, options.cpuThreadNum || 4, options.useSlim || true)
+ if (region) {
+ // 进行过区域截取,需要回收截取的图片 原始图片由外部管理
+ img.recycle()
+ }
if (text) {
return JSON.parse(JSON.stringify(text))
}
@@ -47,14 +49,16 @@ module.exports = function(runtime, global) {
let region = options.region
if (region) {
let r = buildRegion(region, img)
- let o = img
img = images.clip(img, r.x, r.y, r.width, r.height)
- o.recycle()
}
let resultList = runtime.bridges.bridges.toArray(javaOcr.detect(img, options.cpuThreadNum || 4, options.useSlim || true))
if (region && region.length > 1 && resultList && resultList.length > 0) {
resultList.forEach(r => r.bounds.offset(region[0], region[1]))
}
+ if (region) {
+ // 进行过区域截取,需要回收截取的图片 原始图片由外部管理
+ img.recycle()
+ }
return resultList
}
diff --git a/autojs/src/main/assets/modules/__$yolo__.js b/autojs/src/main/assets/modules/__$yolo__.js
index ac14e37dd..9e6aaaf80 100644
--- a/autojs/src/main/assets/modules/__$yolo__.js
+++ b/autojs/src/main/assets/modules/__$yolo__.js
@@ -1,5 +1,6 @@
module.exports = function (runtime, global) {
let yoloCreator = runtime.yolo
+ let instanceList = []
let $yolo = function () {
this.enabled = false
}
@@ -29,6 +30,7 @@ module.exports = function (runtime, global) {
}
this.yoloInstance = yoloCreator.createNcnn(options.paramPath, options.binPath, convertToList(options.labels), options.imageSize, !!options.useGpu);
this.enabled = this.yoloInstance.isInit();
+ instanceList.push(this.yoloInstance)
} else {
this.type = 'onnx';
@@ -39,6 +41,7 @@ module.exports = function (runtime, global) {
}
this.yoloInstance = yoloCreator.createOnnx(options.modelPath, convertToList(options.labels), options.imageSize);
this.enabled = this.yoloInstance != null;
+ instanceList.push(this.yoloInstance)
}
if (this.enabled) {
@@ -60,9 +63,7 @@ module.exports = function (runtime, global) {
}
if (region) {
let r = buildRegion(region, img)
- let o = img
img = images.clip(img, r.x, r.y, r.width, r.height)
- o.recycle()
}
let resultList = util.java.toJsArray(this.yoloInstance.predictYolo(img.mat))
if (region) {
@@ -94,6 +95,12 @@ module.exports = function (runtime, global) {
return $yolo;
+ events.on('exit', function () {
+ if (instanceList.length > 0) {
+ instanceList.forEach(instance => instance.release())
+ }
+ })
+
function wrapForward (yoloInstance) {
return {
forward: function (img, filterOption, region) {
diff --git a/autojs/src/main/assets/modules/__engines__.js b/autojs/src/main/assets/modules/__engines__.js
index a8104e42c..1346c47d5 100644
--- a/autojs/src/main/assets/modules/__engines__.js
+++ b/autojs/src/main/assets/modules/__engines__.js
@@ -1,47 +1,45 @@
-module.exports = function(__runtime__, scope){
+module.exports = function (__runtime__, scope) {
var rtEngines = __runtime__.engines;
- var execArgv = null;
-
var engines = {};
- engines.execScript = function(name, script, config){
+ engines.execScript = function (name, script, config) {
return rtEngines.execScript(name, script, fillConfig(config));
}
- engines.execScriptFile = function(path, config){
+ engines.execScriptFile = function (path, config) {
return rtEngines.execScriptFile(path, fillConfig(config));
}
- engines.execAutoFile = function(path, config){
+ engines.execAutoFile = function (path, config) {
return rtEngines.execAutoFile(path, fillConfig(config));
}
- engines.myEngine = function(){
+ engines.myEngine = function () {
return rtEngines.myEngine();
}
- engines.all = function(){
+ engines.all = function () {
return rtEngines.all();
}
engines.stopAll = rtEngines.stopAll.bind(rtEngines);
engines.stopAllAndToast = rtEngines.stopAllAndToast.bind(rtEngines);
- function fillConfig(c){
+ function fillConfig (c) {
var config = new com.stardust.autojs.execution.ExecutionConfig();
c = c || {};
c.path = c.path || files.cwd();
- if(c.path){
- config.workingDirectory = c.path;
+ if (c.path) {
+ config.workingDirectory = c.path;
}
config.delay = c.delay || 0;
config.interval = c.interval || 0;
- config.loopTimes = (c.loopTimes === undefined)? 1 : c.loopTimes;
- if(c.arguments){
+ config.loopTimes = (c.loopTimes === undefined) ? 1 : c.loopTimes;
+ if (c.arguments) {
var arguments = c.arguments;
- for(var key in arguments){
- if(arguments.hasOwnProperty(key)){
+ for (var key in arguments) {
+ if (arguments.hasOwnProperty(key)) {
config.setArgument(key, arguments[key]);
}
}
@@ -49,14 +47,15 @@ module.exports = function(__runtime__, scope){
return config;
}
- var engine = engines.myEngine();
- var execArgv = {};
- var iterator = engine.getTag("execution.config").arguments.entrySet().iterator();
- while(iterator.hasNext()){
- var entry = iterator.next();
- execArgv[entry.getKey()] = entry.getValue();
- }
- engine.execArgv = execArgv;
+ ((engine) => {
+ let execArgv = {}
+ let iterator = engine.getTag("execution.config").arguments.entrySet().iterator();
+ while (iterator.hasNext()) {
+ let entry = iterator.next();
+ execArgv[entry.getKey()] = entry.getValue();
+ }
+ engine.execArgv = execArgv
+ })(engines.myEngine())
return engines;
}
\ No newline at end of file
diff --git a/autojs/src/main/java/com/stardust/autojs/core/image/ImageWrapper.java b/autojs/src/main/java/com/stardust/autojs/core/image/ImageWrapper.java
index 846850a76..8b5f9da3a 100644
--- a/autojs/src/main/java/com/stardust/autojs/core/image/ImageWrapper.java
+++ b/autojs/src/main/java/com/stardust/autojs/core/image/ImageWrapper.java
@@ -108,49 +108,49 @@ public static Bitmap toBitmap(Image image) {
}
public static ImageWrapper ofImageByMat(Image image, int cvType) {
-
long start = System.currentTimeMillis();
// 获取Image的平面
Image.Plane[] planes = image.getPlanes();
- Log.d("ImageWrapper", "ofImageByMat: planes.length: " + planes.length);
+// Log.d("ImageWrapper", "ofImageByMat: planes.length: " + planes.length);
// 获取Image的宽高
int width = image.getWidth();
int height = image.getHeight();
+// Log.d("ImageWrapper", "ofImageByMat: width:" + width + " height:" + height);
- Log.d("ImageWrapper", "ofImageByMat: width:" + width + " height:" + height);
Image.Plane plane = planes[0];
// 获取平面的数据缓冲区
- ByteBuffer buffer = planes[0].getBuffer();
- // 创建一个byte数组来存储缓冲区的数据
- byte[] bytes = new byte[buffer.remaining()];
- buffer.position(0);
+ ByteBuffer buffer = plane.getBuffer();
int pixelStride = plane.getPixelStride();
- int rowPadding = plane.getRowStride() - pixelStride * image.getWidth();
- // 将数据从缓冲区拷贝到byte数组
- buffer.get(bytes);
+ int rowStride = plane.getRowStride();
+ int rowPadding = rowStride - pixelStride * width;
- // 创建一个Mat对象
+ long s2 = System.currentTimeMillis();
+ // 尽量避免使用临时数组
Mat mat = new Mat(height, width + rowPadding / pixelStride, CvType.CV_8UC4);
- // 将byte数组拷贝到Mat对象
- mat.put(0, 0, bytes);
+ byte[] rowData = new byte[rowStride];
+ for (int i = 0; i < height; i++) {
+ buffer.get(rowData, 0, rowStride);
+ mat.put(i, 0, rowData);
+ }
+// Log.d("ImageWrapper", "ofImageByMat: create mat by bytes cost: " + (System.currentTimeMillis() - s2) + "ms");
if (width != mat.width()) {
- Log.d("ImageWrapper", "ofImageByMat: mat width is not valid: " + mat.width() + " => " + width);
+// Log.d("ImageWrapper", "ofImageByMat: mat width is not valid: " + mat.width() + " => " + width);
// 定义裁切区域
Rect rect = new Rect(0, 0, width, height);
- // 裁切图像
Mat croppedImage = new Mat(mat, rect);
mat.release();
mat = croppedImage;
}
+
if (cvType != CvType.CV_8UC4) {
long convertStart = System.currentTimeMillis();
Imgproc.cvtColor(mat, mat, Imgproc.COLOR_RGBA2RGB);
- Log.d("ImageWrapper", "ofImageByMat: convert channel: " + (System.currentTimeMillis() - convertStart) + "ms");
+// Log.d("ImageWrapper", "ofImageByMat: convert channel: " + (System.currentTimeMillis() - convertStart) + "ms");
}
- Log.d("ImageWrapper", "ofImageByMat: create by mat cost: " + (System.currentTimeMillis() - start) + "ms");
- return new ImageWrapper(mat);
+// Log.d("ImageWrapper", "ofImageByMat: create by mat cost: " + (System.currentTimeMillis() - start) + "ms");
+ return new ImageWrapper(mat);
}
public int getWidth() {
diff --git a/autojs/src/main/java/com/stardust/autojs/execution/RunnableScriptExecution.java b/autojs/src/main/java/com/stardust/autojs/execution/RunnableScriptExecution.java
index 19d9d2bf3..c8775e1b0 100644
--- a/autojs/src/main/java/com/stardust/autojs/execution/RunnableScriptExecution.java
+++ b/autojs/src/main/java/com/stardust/autojs/execution/RunnableScriptExecution.java
@@ -2,7 +2,6 @@
import android.util.Log;
-import com.stardust.autojs.AutoJs;
import com.stardust.autojs.ScriptEngineService;
import com.stardust.autojs.engine.ScriptEngine;
import com.stardust.autojs.engine.ScriptEngineManager;
diff --git a/autojs/src/main/java/com/stardust/autojs/onnx/YoloV8Predictor.java b/autojs/src/main/java/com/stardust/autojs/onnx/YoloV8Predictor.java
index f063744f4..538003aab 100644
--- a/autojs/src/main/java/com/stardust/autojs/onnx/YoloV8Predictor.java
+++ b/autojs/src/main/java/com/stardust/autojs/onnx/YoloV8Predictor.java
@@ -9,6 +9,7 @@
import com.stardust.autojs.onnx.util.Letterbox;
import com.stardust.autojs.runtime.api.YoloPredictor;
+import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Size;
import org.opencv.imgcodecs.Imgcodecs;
@@ -17,11 +18,14 @@
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
import java.util.stream.Collectors;
import ai.onnxruntime.OnnxTensor;
@@ -39,6 +43,8 @@
@RequiresApi(api = Build.VERSION_CODES.N)
public class YoloV8Predictor extends YoloPredictor {
private static final String TAG = "YoloV8Predictor";
+ private static final Pattern IMG_SIZE_PATTERN = Pattern.compile("\\[(\\d+), \\d+]");
+ private static final Pattern LABEL_PATTERN = Pattern.compile("'([^']*)'");
private final String modelPath;
@@ -95,8 +101,53 @@ private void prepareSession() throws OrtException {
throw new RuntimeException(e);
}
});
+ // 如果入参labels无效或未定义,使用模型内置labels
+ if (labels == null || labels.size() == 0) {
+ labels = initLabels(session);
+ }
+ initShapeSize(session);
+ }
+
+ private List initLabels(OrtSession session) throws OrtException {
+ String meteStr = session.getMetadata().getCustomMetadata().get("names");
+ if (meteStr == null) {
+ Log.d(TAG, "initLabels: 读取names失败 无法自动修正labels");
+ return Collections.emptyList();
+ }
+ String[] labels = new String[meteStr.split(",").length];
+
+ Matcher matcher = LABEL_PATTERN.matcher(meteStr);
+
+ int h = 0;
+ while (matcher.find()) {
+ labels[h] = matcher.group(1);
+ h++;
+ }
+ return Arrays.asList(labels);
}
+ private void initShapeSize(OrtSession session) throws OrtException {
+ String meteStr = session.getMetadata().getCustomMetadata().get("imgsz");
+ Log.d(TAG, "initShapeSize: " + meteStr);
+ if (meteStr == null) {
+ Log.d(TAG, "initShapeSize: 读取imgsz失败 无法自动修正输入大小");
+ return;
+ }
+ Matcher matcher = IMG_SIZE_PATTERN.matcher(meteStr);
+ if (matcher.find()) {
+ String shapeSize = matcher.group(1);
+ if (shapeSize == null) {
+ Log.d(TAG, "initShapeSize: 读取imgsz格式异常 无法自动修正输入大小");
+ return;
+ }
+ this.shapeSize = new Size(Double.parseDouble(shapeSize), Double.parseDouble(shapeSize));
+ Log.d(TAG, "set shape size: " + shapeSize);
+ } else {
+ Log.d(TAG, "initShapeSize: 读取imgsz格式异常 无法自动修正输入大小");
+ }
+ }
+
+
private void addNNApiProvider(OrtSession.SessionOptions sessionOptions) {
if (!tryNpu) {
return;
@@ -140,19 +191,26 @@ private HashMap preprocessImage(Mat img) throws OrtException
int rows = letterbox.getHeight();
int cols = letterbox.getWidth();
int channels = image.channels();
+ // 转换Mat对象的数据类型为CV_64F,即64位浮点型
+ Mat convertedImage = new Mat();
+ image.convertTo(convertedImage, CvType.CV_64F);
+
+ // 获取整个像素数据
+ double[] pixelData = new double[rows * cols * channels];
+ convertedImage.get(0, 0, pixelData);
- // 将Mat对象的像素值赋值给Float[]对象
float[] pixels = new float[channels * rows * cols];
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
- double[] pixel = image.get(j, i);
for (int k = 0; k < channels; k++) {
// 这样设置相当于同时做了image.transpose((2, 0, 1))操作
- pixels[rows * cols * k + j * cols + i] = (float) pixel[k] / 255.0f;
+ // 重新组织内存访问模式,提高缓存效率
+ pixels[k * rows * cols + i * cols + j] = (float) (pixelData[(i * cols + j) * channels + k] / 255.0);
}
}
}
image.release();
+ convertedImage.release();
// 创建OnnxTensor对象
long[] shape = {1L, (long) channels, (long) rows, (long) cols};
OnnxTensor tensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(pixels), shape);
@@ -168,9 +226,8 @@ private List postProcessOutput(OrtSession.Result output) throws OrtEx
Map> class2Bbox = new HashMap<>();
for (float[] bbox : outputData) {
- float[] conditionalProbabilities = Arrays.copyOfRange(bbox, 4, outputData.length);
- int label = argmax(conditionalProbabilities);
- float conf = conditionalProbabilities[label];
+ int label = argmax(bbox, 4); // 直接在原数组上进行操作
+ float conf = bbox[label + 4];
if (conf < confThreshold) {
continue;
}
@@ -185,8 +242,7 @@ private List postProcessOutput(OrtSession.Result output) throws OrtEx
continue;
}
- class2Bbox.putIfAbsent(label, new ArrayList<>());
- class2Bbox.get(label).add(bbox);
+ class2Bbox.computeIfAbsent(label, k -> new ArrayList<>()).add(bbox);
}
List detections = new ArrayList<>();
@@ -214,12 +270,18 @@ public List predictYolo(String imagePath) throws OrtException {
public List predictYolo(Mat image) throws OrtException {
prepareSession();
long start_time = System.currentTimeMillis();
+ Map inputMap = preprocessImage(image);
// 运行推理
- OrtSession.Result output = session.run(preprocessImage(image));
- List detections = postProcessOutput(output);
- Log.d("YoloV8Predictor", String.format("onnx predict cost: %d ms", (System.currentTimeMillis() - start_time)));
- return detections.stream().map(detection -> new DetectResult(detection, letterbox))
- .collect(Collectors.toList());
+ try (OrtSession.Result output = session.run(inputMap)) {
+ Log.d(TAG, "predictYolo: onnx run cost " + (System.currentTimeMillis() - start_time) + "ms");
+ List detections = postProcessOutput(output);
+ Log.d("YoloV8Predictor", String.format("onnx predict cost: %d ms", (System.currentTimeMillis() - start_time)));
+ return detections.stream().map(detection -> new DetectResult(detection, letterbox))
+ .collect(Collectors.toList());
+ } finally {
+ // 释放资源
+ inputMap.values().forEach(OnnxTensor::close);
+ }
}
public static void xywh2xyxy(float[] bbox) {
@@ -245,7 +307,7 @@ public static float[][] transposeMatrix(float[][] m) {
}
public static List nonMaxSuppression(List bboxes, float iouThreshold) {
-
+ long start = System.currentTimeMillis();
List bestBboxes = new ArrayList<>();
bboxes.sort(Comparator.comparing(a -> a[4]));
@@ -253,9 +315,9 @@ public static List nonMaxSuppression(List bboxes, float iouThr
while (!bboxes.isEmpty()) {
float[] bestBbox = bboxes.remove(bboxes.size() - 1);
bestBboxes.add(bestBbox);
- bboxes = bboxes.stream().filter(a -> computeIOU(a, bestBbox) < iouThreshold).collect(Collectors.toList());
+ bboxes.removeIf(bbox -> computeIOU(bbox, bestBbox) >= iouThreshold);
}
-
+ Log.d(TAG, "nonMaxSuppression: cost " + (System.currentTimeMillis() - start) + "ms");
return bestBboxes;
}
@@ -269,22 +331,44 @@ public static float computeIOU(float[] box1, float[] box2) {
float right = Math.min(box1[2], box2[2]);
float bottom = Math.min(box1[3], box2[3]);
- float interArea = Math.max(right - left, 0) * Math.max(bottom - top, 0);
+ // 计算交集区域的宽度和高度
+ float width = Math.max(right - left, 0);
+ float height = Math.max(bottom - top, 0);
+
+ // 计算交集面积和并集面积
+ float interArea = width * height;
float unionArea = area1 + area2 - interArea;
- return Math.max(interArea / unionArea, 1e-8f);
+ // 计算交并比
+ return Math.max(interArea / unionArea, 1e-8f);
}
+
//返回最大值的索引
- public static int argmax(float[] a) {
+ // 优化后的 argmax 函数
+ public static int argmax(float[] a, int start) {
float re = -Float.MAX_VALUE;
int arg = -1;
- for (int i = 0; i < a.length; i++) {
+ for (int i = start; i < a.length; i++) {
if (a[i] >= re) {
re = a[i];
- arg = i;
+ arg = i - start;
}
}
return arg;
}
+
+ @Override
+ public void release() {
+ if (session != null) {
+ try {
+ session.close();
+ session = null;
+ } catch (OrtException e) {
+ Log.e(TAG, "close session failed" + e);
+ }
+ environment.close();
+ environment = null;
+ }
+ }
}
diff --git a/autojs/src/main/java/com/stardust/autojs/runtime/api/Images.java b/autojs/src/main/java/com/stardust/autojs/runtime/api/Images.java
index 0f1945b65..68fa24c3e 100644
--- a/autojs/src/main/java/com/stardust/autojs/runtime/api/Images.java
+++ b/autojs/src/main/java/com/stardust/autojs/runtime/api/Images.java
@@ -118,6 +118,7 @@ public synchronized ImageWrapper captureScreen() {
}
long start = System.currentTimeMillis();
mPreCaptureImage = ImageWrapper.ofImage(capture);
+ capture.close();
Log.d(TAG, "captureScreen: convert image cost: " + (System.currentTimeMillis() - start) + "ms");
return mPreCaptureImage;
}
diff --git a/autojs/src/main/java/com/stardust/autojs/runtime/api/Yolo.java b/autojs/src/main/java/com/stardust/autojs/runtime/api/Yolo.java
index ad8ffd8f8..ee70d19e2 100644
--- a/autojs/src/main/java/com/stardust/autojs/runtime/api/Yolo.java
+++ b/autojs/src/main/java/com/stardust/autojs/runtime/api/Yolo.java
@@ -103,7 +103,8 @@ public List captureAndPredict(ScriptRuntime runtime, Rect rect) {
Images images = (Images)runtime.getImages();
Image image = images.captureScreenRaw();
if (image != null) {
- ImageWrapper imageWrapper = ImageWrapper.ofImageByMat(image, CvType.CV_8UC3);
+ ImageWrapper imageWrapper = ImageWrapper.ofImageByMat(image, CvType.CV_8UC4);
+ image.close();
Mat mat = imageWrapper.getMat();
if (rect != null) {
// 裁切图像
diff --git a/autojs/src/main/java/com/stardust/autojs/runtime/api/YoloPredictor.java b/autojs/src/main/java/com/stardust/autojs/runtime/api/YoloPredictor.java
index bf514c88d..c4b62723b 100644
--- a/autojs/src/main/java/com/stardust/autojs/runtime/api/YoloPredictor.java
+++ b/autojs/src/main/java/com/stardust/autojs/runtime/api/YoloPredictor.java
@@ -57,4 +57,10 @@ public boolean isInit() {
public void release() {
}
+
+ @Override
+ protected void finalize() throws Throwable {
+ super.finalize();
+ release();
+ }
}