refined InferenceConfig.h
faressc committed Oct 23, 2023
1 parent f6d51e5 commit 312dcd1
Showing 4 changed files with 14 additions and 5 deletions.
9 changes: 9 additions & 0 deletions source/dsp/inference/InferenceConfig.h
@@ -7,10 +7,19 @@ enum InferenceBackend {
TFLite
};

#define MODEL_TFLITE "model_0/model_0.tflite"
#define MODEL_LIBTORCH "model_0/model_0.pt"
#define MODELS_PATH_ONNX MODELS_PATH_TENSORFLOW
#define MODEL_ONNX "model_0/model_0-tflite.onnx"


#define MODEL_INPUT_SIZE 1
#define MODEL_INPUT_SIZE_BACKEND 150 // Same as MODEL_INPUT_SIZE, but for streamable models
#define MODEL_INPUT_SHAPE {1, MODEL_INPUT_SIZE_BACKEND, 1}

#define MODEL_OUTPUT_SIZE_BACKEND 1


#define MAX_INFERENCE_TIME 128
#define MODEL_LATENCY 0
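For orientation, a minimal self-contained sketch of how the new macros compose. MODELS_PATH_TENSORFLOW is not defined in this header (it is presumably injected by the build system), so the placeholder value below, like main and the variable names, is illustrative only:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#define MODELS_PATH_TENSORFLOW "models/tensorflow/" // assumed placeholder; normally supplied by the build system
#define MODELS_PATH_ONNX MODELS_PATH_TENSORFLOW
#define MODEL_ONNX "model_0/model_0-tflite.onnx"
#define MODEL_INPUT_SIZE_BACKEND 150
#define MODEL_INPUT_SHAPE {1, MODEL_INPUT_SIZE_BACKEND, 1}

int main() {
    // Full model path, assembled the same way the ONNX backend does below.
    std::string modelPath = std::string(MODELS_PATH_ONNX) + MODEL_ONNX;

    // The shape macro expands to a braced initializer list: {1, 150, 1}.
    std::vector<int64_t> inputShape = MODEL_INPUT_SHAPE;

    std::cout << modelPath << ", rank " << inputShape.size() << '\n';
    return 0;
}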

2 changes: 1 addition & 1 deletion source/dsp/inference/InferenceThread.h
@@ -4,7 +4,7 @@
#include <JuceHeader.h>

#include "../utils/RingBuffer.h"
#include "utils/InferenceConfig.h"
#include "InferenceConfig.h"
#include "backends/OnnxRuntimeProcessor.h"
// #include "backends/LibtorchProcessor.h"
// #include "processors/WindowingProcessor.h"
2 changes: 1 addition & 1 deletion source/dsp/inference/backends/OnnxRuntimeProcessor.cpp
@@ -11,7 +11,7 @@ OnnxRuntimeProcessor::~OnnxRuntimeProcessor()

void OnnxRuntimeProcessor::prepareToPlay() {
// Define the shape of input tensor
inputShape = {1, MODEL_INPUT_SIZE_BACKEND, 1};
inputShape = MODEL_INPUT_SHAPE;
}

void OnnxRuntimeProcessor::processBlock(std::array<float, MODEL_INPUT_SIZE_BACKEND>& input, std::array<float, MODEL_OUTPUT_SIZE_BACKEND>& output) {
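For reference, a shape vector like inputShape is what ONNX Runtime's tensor constructor consumes. A sketch under the assumption that processBlock wraps the block in an Ort::Value; makeInputTensor and kInputSize are illustrative names, not this project's code:

#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>
#include "onnxruntime_cxx_api.h"

constexpr std::size_t kInputSize = 150; // stands in for MODEL_INPUT_SIZE_BACKEND

// Wraps an audio block in a tensor without copying; the product of the
// shape's dimensions must equal input.size().
Ort::Value makeInputTensor(std::array<float, kInputSize>& input,
                           const std::vector<int64_t>& inputShape) {
    Ort::MemoryInfo memoryInfo =
        Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    return Ort::Value::CreateTensor<float>(memoryInfo, input.data(), input.size(),
                                           inputShape.data(), inputShape.size());
}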
6 changes: 3 additions & 3 deletions source/dsp/inference/backends/OnnxRuntimeProcessor.h
@@ -2,7 +2,7 @@
#define NN_INFERENCE_TEMPLATE_ONNXRUNTIMEPROCESSOR_H

#include <JuceHeader.h>
#include "../utils/InferenceConfig.h"
#include "../InferenceConfig.h"
#include "onnxruntime_cxx_api.h"

class OnnxRuntimeProcessor {
@@ -14,8 +14,8 @@ class OnnxRuntimeProcessor {
void processBlock(std::array<float, MODEL_INPUT_SIZE_BACKEND>& input, std::array<float, MODEL_OUTPUT_SIZE_BACKEND>& output);

private:
std::string filepath = MODELS_PATH_TENSORFLOW;
std::string modelname = "model_0/model_0-tflite.onnx";
std::string filepath = MODELS_PATH_ONNX;
std::string modelname = MODEL_ONNX;
#ifdef _WIN32
std::string modelpathStr = filepath + modelname;
std::wstring modelpath = std::wstring(modelpathStr.begin(), modelpathStr.end());
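The begin/end conversion above exists because Ort::Session takes a wide-character path on Windows (ORTCHAR_T is wchar_t there) and a narrow one elsewhere; note that this naive widening is only safe for ASCII paths. A sketch of how the assembled path typically reaches the session; loadModel is an illustrative helper, not this project's code:

#include <string>
#include "onnxruntime_cxx_api.h"

// Constructs a session from the assembled model path.
Ort::Session loadModel(Ort::Env& env, const std::string& modelpathStr) {
    Ort::SessionOptions options;
#ifdef _WIN32
    // wchar_t path required on Windows; this widening assumes an ASCII path.
    std::wstring modelpath(modelpathStr.begin(), modelpathStr.end());
    return Ort::Session(env, modelpath.c_str(), options);
#else
    return Ort::Session(env, modelpathStr.c_str(), options);
#endif
}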
