implemented TFLite inference processor
faressc committed Oct 25, 2023
1 parent 677a1d7 commit ae732d6
Showing 6 changed files with 85 additions and 14 deletions.
2 changes: 1 addition & 1 deletion source/dsp/inference/InferenceConfig.h
@@ -4,7 +4,7 @@
enum InferenceBackend {
LIBTORCH,
ONNX,
-TFLite
+TFLITE
};

#define MODEL_TFLITE "model_0/model_0-streaming.tflite"
3 changes: 3 additions & 0 deletions source/dsp/inference/InferenceThread.cpp
@@ -23,6 +23,7 @@ void InferenceThread::prepareToPlay(const juce::dsp::ProcessSpec &spec) {

onnxProcessor.prepareToPlay();
torchProcessor.prepareToPlay();
tfliteProcessor.prepareToPlay();
}

void InferenceThread::run() {
@@ -73,6 +74,8 @@ void InferenceThread::processModel() {
onnxProcessor.processBlock(processedModelInput, rawModelOutputBuffer);
} else if (currentBackend == LIBTORCH) {
torchProcessor.processBlock(processedModelInput, rawModelOutputBuffer);
} else if (currentBackend == TFLITE) {
tfliteProcessor.processBlock(processedModelInput, rawModelOutputBuffer);
}
}

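For reference, a self-contained sketch of the pattern processModel() uses above: an atomic backend selector dispatching to one of several per-backend processor objects. The processor type below is a stand-in for illustration only, not one of the real backend classes in this repository.

#include <atomic>
#include <cstdio>

enum InferenceBackend { LIBTORCH, ONNX, TFLITE };

struct DummyProcessor {
    const char* name;
    void processBlock() const { std::printf("running %s\n", name); }
};

int main() {
    DummyProcessor onnxProcessor  {"onnxruntime"};
    DummyProcessor torchProcessor {"libtorch"};
    DummyProcessor tfliteProcessor{"tflite"};

    std::atomic<InferenceBackend> currentBackend {TFLITE};   // same default the commit sets in InferenceThread.h

    // Equivalent of the if/else chain in processModel():
    switch (currentBackend.load()) {
        case ONNX:     onnxProcessor.processBlock();   break;
        case LIBTORCH: torchProcessor.processBlock();  break;
        case TFLITE:   tfliteProcessor.processBlock(); break;
    }
    return 0;
}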
7 changes: 4 additions & 3 deletions source/dsp/inference/InferenceThread.h
@@ -7,6 +7,7 @@
#include "InferenceConfig.h"
#include "backends/OnnxRuntimeProcessor.h"
#include "backends/LibtorchProcessor.h"
#include "backends/TFLiteProcessor.h"
// #include "processors/WindowingProcessor.h"

class InferenceThread : public juce::Thread {
@@ -19,7 +20,7 @@ class InferenceThread : public juce::Thread {
RingBuffer& getModelInputBuffer();
RingBuffer& getModelOutputBuffer();
void testInference(InferenceBackend backend);

private:
void run() override;
void inference();
@@ -31,14 +32,14 @@

OnnxRuntimeProcessor onnxProcessor;
LibtorchProcessor torchProcessor;

TFLiteProcessor tfliteProcessor;

RingBuffer rawModelInput;
RingBuffer processedModelOutput;
std::array<float, BATCH_SIZE * MODEL_OUTPUT_SIZE_BACKEND> rawModelOutputBuffer;
std::array<float, BATCH_SIZE * MODEL_INPUT_SIZE_BACKEND> processedModelInput;

-std::atomic<InferenceBackend> currentBackend {LIBTORCH};
+std::atomic<InferenceBackend> currentBackend {TFLITE};

juce::ListenerList<Listener> listeners;
};
30 changes: 30 additions & 0 deletions source/dsp/inference/backends/TFLiteProcessor.cpp
@@ -0,0 +1,30 @@
#include "TFLiteProcessor.h"

TFLiteProcessor::TFLiteProcessor()
{
model = TfLiteModelCreateFromFile(modelpath.c_str());
options = TfLiteInterpreterOptionsCreate();
interpreter = TfLiteInterpreterCreate(model, options);
}

TFLiteProcessor::~TFLiteProcessor()
{
TfLiteInterpreterDelete(interpreter);
TfLiteInterpreterOptionsDelete(options);
TfLiteModelDelete(model);
}

void TFLiteProcessor::prepareToPlay() {
TfLiteInterpreterAllocateTensors(interpreter);
inputTensor = TfLiteInterpreterGetInputTensor(interpreter, 0);
outputTensor = TfLiteInterpreterGetOutputTensor(interpreter, 0);
std::array<float, BATCH_SIZE * MODEL_INPUT_SIZE_BACKEND> input;
std::array<float, BATCH_SIZE * MODEL_OUTPUT_SIZE_BACKEND> output;
processBlock(input, output);
}

void TFLiteProcessor::processBlock(std::array<float, BATCH_SIZE * MODEL_INPUT_SIZE_BACKEND>& input, std::array<float, BATCH_SIZE * MODEL_OUTPUT_SIZE_BACKEND>& output) {
TfLiteTensorCopyFromBuffer(inputTensor, input.data(), input.size() * sizeof(float));
TfLiteInterpreterInvoke(interpreter);
TfLiteTensorCopyToBuffer(outputTensor, output.data(), output.size() * sizeof(float));
}
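The TensorFlow Lite C API calls above each return a TfLiteStatus, which this processor ignores. Below is a minimal standalone sketch (not part of the commit) of the same load/allocate/invoke sequence with the statuses checked; the model path and tensor sizes are placeholders and must match the actual model.

#include <array>
#include <cstdio>
#include "tensorflow/lite/c_api.h"   // include path as used in this repository

int main() {
    TfLiteModel* model = TfLiteModelCreateFromFile("model.tflite");   // placeholder path
    if (model == nullptr) { std::fprintf(stderr, "failed to load model\n"); return 1; }

    TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
    TfLiteInterpreterOptionsSetNumThreads(options, 1);                // single-threaded inference
    TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);

    if (TfLiteInterpreterAllocateTensors(interpreter) != kTfLiteOk) return 1;

    TfLiteTensor* inputTensor = TfLiteInterpreterGetInputTensor(interpreter, 0);
    const TfLiteTensor* outputTensor = TfLiteInterpreterGetOutputTensor(interpreter, 0);

    std::array<float, 150> input {};                                  // placeholder sizes; the buffers must
    std::array<float, 1> output {};                                   // match the model's tensor byte sizes

    if (TfLiteTensorCopyFromBuffer(inputTensor, input.data(), input.size() * sizeof(float)) != kTfLiteOk) return 1;
    if (TfLiteInterpreterInvoke(interpreter) != kTfLiteOk) return 1;
    if (TfLiteTensorCopyToBuffer(outputTensor, output.data(), output.size() * sizeof(float)) != kTfLiteOk) return 1;

    TfLiteInterpreterDelete(interpreter);
    TfLiteInterpreterOptionsDelete(options);
    TfLiteModelDelete(model);
    return 0;
}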
34 changes: 34 additions & 0 deletions source/dsp/inference/backends/TFLiteProcessor.h
@@ -0,0 +1,34 @@
#ifndef NN_INFERENCE_TEMPLATE_TFLITEPROCESSOR_H
#define NN_INFERENCE_TEMPLATE_TFLITEPROCESSOR_H

#include <JuceHeader.h>
#include "../InferenceConfig.h"
#include "tensorflow/lite/c_api.h"

class TFLiteProcessor {
public:
TFLiteProcessor();
~TFLiteProcessor();

void prepareToPlay();
void processBlock(std::array<float, BATCH_SIZE * MODEL_INPUT_SIZE_BACKEND>& input, std::array<float, BATCH_SIZE * MODEL_OUTPUT_SIZE_BACKEND>& output);

private:
std::string filepath = MODELS_PATH_TENSORFLOW;
std::string modelname = MODEL_TFLITE;
#ifdef _WIN32
std::string modelpathStr = filepath + modelname;
std::wstring modelpath = std::wstring(modelpathStr.begin(), modelpathStr.end());
#else
std::string modelpath = filepath + modelname;
#endif

TfLiteModel* model;
TfLiteInterpreterOptions* options;
TfLiteInterpreter* interpreter;

TfLiteTensor* inputTensor;
const TfLiteTensor* outputTensor;
};

#endif //NN_INFERENCE_TEMPLATE_TFLITEPROCESSOR_H
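A minimal usage sketch of this class, assuming the constants BATCH_SIZE, MODEL_INPUT_SIZE_BACKEND and MODEL_OUTPUT_SIZE_BACKEND, plus the MODELS_PATH_TENSORFLOW and MODEL_TFLITE macros, are supplied by InferenceConfig.h and the build, as in this repository. Note that TfLiteModelCreateFromFile() takes a const char*, so the std::wstring path in the _WIN32 branch presumably needs a narrow-string conversion before the constructor can pass it to that call.

#include <array>
#include "TFLiteProcessor.h"

void runOnce() {
    TFLiteProcessor processor;              // constructor loads the .tflite model and builds the interpreter

    std::array<float, BATCH_SIZE * MODEL_INPUT_SIZE_BACKEND> input {};
    std::array<float, BATCH_SIZE * MODEL_OUTPUT_SIZE_BACKEND> output {};

    processor.prepareToPlay();              // allocates tensors and runs one warm-up inference
    processor.processBlock(input, output);  // one inference: input buffer -> output buffer
}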
23 changes: 13 additions & 10 deletions test/benchmark.cpp
@@ -47,24 +47,27 @@ class InferenceFixture : public benchmark::Fixture
}
};

-BENCHMARK_DEFINE_F(InferenceFixture, BM_ONNX_INFERENCE)(benchmark::State& st)
-{
-for (auto _ : st)
-{
+BENCHMARK_DEFINE_F(InferenceFixture, BM_ONNX_INFERENCE)(benchmark::State& st) {
+for (auto _ : st) {
plugin->getInferenceThread().testInference(ONNX);
}
}

-BENCHMARK_DEFINE_F(InferenceFixture, BM_LIBTORCH_INFERENCE)(benchmark::State& st)
-{
-for (auto _ : st)
-{
+BENCHMARK_DEFINE_F(InferenceFixture, BM_LIBTORCH_INFERENCE)(benchmark::State& st) {
+for (auto _ : st) {
plugin->getInferenceThread().testInference(LIBTORCH);
}
}

BENCHMARK_DEFINE_F(InferenceFixture, BM_TFLITE_INFERENCE)(benchmark::State& st) {
for (auto _ : st) {
plugin->getInferenceThread().testInference(TFLITE);
}
}

// Register the function as a benchmark
BENCHMARK(BM_PROCESSOR)->Unit(benchmark::kMillisecond);
BENCHMARK(BM_EDITOR)->Unit(benchmark::kMillisecond);
-BENCHMARK_REGISTER_F(InferenceFixture, BM_ONNX_INFERENCE)->Unit(benchmark::kMillisecond)->Iterations(10)->Repetitions(1);
-BENCHMARK_REGISTER_F(InferenceFixture, BM_LIBTORCH_INFERENCE)->Unit(benchmark::kMillisecond)->Iterations(10)->Repetitions(1);
+BENCHMARK_REGISTER_F(InferenceFixture, BM_ONNX_INFERENCE)->Unit(benchmark::kMillisecond)->Iterations(10)->Repetitions(1)->Threads(1);
+BENCHMARK_REGISTER_F(InferenceFixture, BM_LIBTORCH_INFERENCE)->Unit(benchmark::kMillisecond)->Iterations(10)->Repetitions(1)->Threads(1);
// BENCHMARK_REGISTER_F(InferenceFixture, BM_TFLITE_INFERENCE)->Unit(benchmark::kMillisecond)->Iterations(10)->Repetitions(1)->Threads(1);
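The TFLite benchmark registration is left commented out in this commit. Enabling it would presumably mirror the two registrations above, for example:

BENCHMARK_REGISTER_F(InferenceFixture, BM_TFLITE_INFERENCE)->Unit(benchmark::kMillisecond)->Iterations(10)->Repetitions(1)->Threads(1);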
