refined inference
faressc committed Oct 23, 2023
1 parent 0698cab commit 826be33
Showing 7 changed files with 55 additions and 58 deletions.
34 changes: 16 additions & 18 deletions source/dsp/inference/InferenceManager.cpp
@@ -1,11 +1,11 @@
#include "InferenceManager.h"

InferenceManager::InferenceManager() {
inferenceThread.addInferenceListener(this);
// inferenceThread.addInferenceListener(this);
}

InferenceManager::~InferenceManager() {
inferenceThread.removeInferenceListener(this);
// inferenceThread.removeInferenceListener(this);
}

void InferenceManager::parameterChanged(const juce::String &parameterID, float newValue) {
@@ -15,13 +15,14 @@ void InferenceManager::parameterChanged(const juce::String &parameterID, float n
}
}

void InferenceManager::prepareToPlay(const juce::dsp::ProcessSpec &spec) {
void InferenceManager::prepareToPlay(const juce::dsp::ProcessSpec &newSpec) {
spec = const_cast<juce::dsp::ProcessSpec &>(newSpec);
numInferencedBufferAvailable.store(0);

sendRingBuffer.initialise(1, (int) spec.sampleRate * 6); // TODO how big does the ringbuffer need to be?
receiveRingBuffer.initialise(1, (int) spec.sampleRate * 6); // TODO how big does the ringbuffer need to be?

inferenceThread.prepareToPlay();
inferenceThread.prepareToPlay(spec);
inferenceCounter = 0;

init = true;
@@ -32,23 +33,20 @@ void InferenceManager::prepareToPlay(const juce::dsp::ProcessSpec &spec) {
void InferenceManager::processBlock(juce::AudioBuffer<float> &buffer) {
if (init) {
init_samples += buffer.getNumSamples();
if (init && init_samples >= MODEL_INPUT_SIZE + MAX_INFERENCE_TIME + MODEL_LATENCY) init = false;
if (init && init_samples >= MODEL_INPUT_SIZE + (int) spec.maximumBlockSize + MAX_INFERENCE_TIME + MODEL_LATENCY) init = false;
}
for (int sample = 0; sample < buffer.getNumSamples(); ++sample) {
sendRingBuffer.pushSample(buffer.getSample(0, sample), 0);
}
if (!inferenceThread.isThreadRunning()) {
if (!inferenceThread.isThreadRunning()) { // TODO fix if Thread runs too long because of multiple iterations
auto &receiveBuffer = inferenceThread.getModelOutputBuffer();
while (numInferencedBufferAvailable.load() > 0) {
for (size_t sample = 0; sample < MODEL_INPUT_SIZE; ++sample) {
receiveRingBuffer.pushSample(receiveBuffer[sample], 0);
}
numInferencedBufferAvailable.store(numInferencedBufferAvailable.load() - 1);
while (receiveBuffer.getAvailableSamples(0) > 0) {
receiveRingBuffer.pushSample(receiveBuffer.popSample(0), 0);
}
auto &sendBuffer = inferenceThread.getModelInputBuffer();
if (sendRingBuffer.getAvailableSamples(0) >= MODEL_INPUT_SIZE) { // TODO: refine this dynamic modelinputsize
for (size_t sample = 0; sample < MODEL_INPUT_SIZE; ++sample) {
sendBuffer[sample] = sendRingBuffer.popSample(0);
if (sendRingBuffer.getAvailableSamples(0) >= MODEL_INPUT_SIZE) {
while (sendRingBuffer.getAvailableSamples(0) > 0) {
sendBuffer.pushSample(sendRingBuffer.popSample(0), 0);
}
if (!inferenceThread.startThread(juce::Thread::Priority::highest)) {
std::cout << "Inference thread could not be started" << std::endl;
@@ -87,13 +85,13 @@ void InferenceManager::processOutput(juce::AudioBuffer<float> &buffer) {
}

void InferenceManager::calculateLatency(int maxSamplesPerBuffer) {
latencyInSamples = MODEL_INPUT_SIZE + MAX_INFERENCE_TIME + MODEL_LATENCY - maxSamplesPerBuffer;
// latencyInSamples = MODEL_INPUT_SIZE + MAX_INFERENCE_TIME + MODEL_LATENCY - maxSamplesPerBuffer;
}

int InferenceManager::getLatency() const {
return latencyInSamples;
}

void InferenceManager::inferenceThreadFinished() {
numInferencedBufferAvailable.store(numInferencedBufferAvailable.load() + 1);
}
// void InferenceManager::inferenceThreadFinished() {
// numInferencedBufferAvailable.store(numInferencedBufferAvailable.load() + 1);
// }
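As a quick sanity check on the new warm-up condition in processBlock above: with the constants this commit sets in InferenceConfig.h further down in the diff (MODEL_INPUT_SIZE = 1, MAX_INFERENCE_TIME = 1024, MODEL_LATENCY = 0) and a hypothetical host block size of 512 samples, init is cleared once init_samples reaches 1 + 512 + 1024 + 0 = 1537 samples.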
8 changes: 5 additions & 3 deletions source/dsp/inference/InferenceManager.h
@@ -6,10 +6,11 @@
#include "InferenceThread.h"
#include "../utils/RingBuffer.h"

class InferenceManager : private InferenceThread::Listener {
// class InferenceManager : private InferenceThread::Listener {
class InferenceManager {
public:
InferenceManager();
~InferenceManager() override;
~InferenceManager();

void prepareToPlay(const juce::dsp::ProcessSpec& spec);
void processBlock(juce::AudioBuffer<float>& buffer);
@@ -21,7 +22,7 @@ class InferenceManager : private InferenceThread::Listener {
private:
void processOutput(juce::AudioBuffer<float>& buffer);
void calculateLatency(int maxSamplesPerBuffer);
void inferenceThreadFinished() override;
// void inferenceThreadFinished() override;

private:
bool init = true;
@@ -32,6 +33,7 @@
RingBuffer receiveRingBuffer;

std::atomic<int> numInferencedBufferAvailable;
juce::dsp::ProcessSpec spec;

int latencyInSamples = 0;
int inferenceCounter = 0;
32 changes: 18 additions & 14 deletions source/dsp/inference/InferenceThread.cpp
@@ -7,16 +7,17 @@ InferenceThread::~InferenceThread() {
stopThread(100);
}

void InferenceThread::prepareToPlay() {
for (size_t i = 0; i < MODEL_INPUT_SIZE; i++) {
void InferenceThread::prepareToPlay(const juce::dsp::ProcessSpec &spec) {
for (size_t i = 0; i < MODEL_INPUT_SIZE_BACKEND; i++) {
processedModelInput[i] = 0.f;
rawModelInputBuffer[i] = 0.f;
processedModelOutput[i] = 0.f;
}
for (size_t i = 0; i < MODEL_OUTPUT_SIZE; i++) {
for (size_t i = 0; i < MODEL_OUTPUT_SIZE_BACKEND; i++) {
rawModelOutputBuffer[i] = 0.f;
}

rawModelInput.initialise(1, (int) spec.sampleRate * 6); // TODO how big does the ringbuffer need to be?
processedModelOutput.initialise(1, (int) spec.sampleRate * 6); // TODO how big does the ringbuffer need to be?

onnxProcessor.prepareToPlay();
}

@@ -28,24 +29,27 @@ void InferenceThread::run() {
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);

processingTime.store(duration.count());

listeners.call(&Listener::inferenceThreadFinished);
}

void InferenceThread::inference() {

for (size_t i = 0; i < MODEL_INPUT_SIZE; i++) {
size_t numInferences = (size_t) (rawModelInput.getAvailableSamples(0) / MODEL_INPUT_SIZE);

for (size_t i = 0; i < numInferences; i++) {

// pre-processing
for (size_t j = 1; j < MODEL_INPUT_SIZE; j++) {
for (size_t j = 1; j < MODEL_INPUT_SIZE_BACKEND; j++) {
processedModelInput[j-1] = processedModelInput[j];
}
processedModelInput[MODEL_INPUT_SIZE-1] = rawModelInputBuffer[i];
processedModelInput[MODEL_INPUT_SIZE_BACKEND-1] = rawModelInput.popSample(0);

// actual inference
processModel();

// post-processing
processedModelOutput[i] = rawModelOutputBuffer[0];
for (size_t j = 0; j < MODEL_OUTPUT_SIZE_BACKEND; j++) {
processedModelOutput.pushSample(rawModelOutputBuffer[j], 0);
}
}
}

@@ -57,11 +61,11 @@ void InferenceThread::processModel() {
}
}

std::array<float, MODEL_INPUT_SIZE>& InferenceThread::getModelInputBuffer() {
return rawModelInputBuffer;
RingBuffer& InferenceThread::getModelInputBuffer() {
return rawModelInput;
}

std::array<float, MODEL_INPUT_SIZE>& InferenceThread::getModelOutputBuffer() {
RingBuffer& InferenceThread::getModelOutputBuffer() {
return processedModelOutput;
}

22 changes: 7 additions & 15 deletions source/dsp/inference/InferenceThread.h
@@ -14,18 +14,10 @@ class InferenceThread : public juce::Thread {
InferenceThread();
~InferenceThread() override;

class Listener {
public:
virtual ~Listener() = default;
virtual void inferenceThreadFinished() = 0;
};
void addInferenceListener(Listener* listenerToAdd) {listeners.add(listenerToAdd);}
void removeInferenceListener(Listener* listenerToRemove) {listeners.remove(listenerToRemove);}

void prepareToPlay();
void prepareToPlay(const juce::dsp::ProcessSpec &spec);
void setBackend(InferenceBackend backend);
std::array<float, MODEL_INPUT_SIZE>& getModelInputBuffer();
std::array<float, MODEL_INPUT_SIZE>& getModelOutputBuffer();
RingBuffer& getModelInputBuffer();
RingBuffer& getModelOutputBuffer();

private:
void run() override;
@@ -39,10 +31,10 @@ class InferenceThread : public juce::Thread {
// LibtorchProcessor torchProcessor;


std::array<float, MODEL_INPUT_SIZE> rawModelInputBuffer;
std::array<float, MODEL_OUTPUT_SIZE> rawModelOutputBuffer;
std::array<float, MODEL_INPUT_SIZE> processedModelInput;
std::array<float, MODEL_INPUT_SIZE> processedModelOutput;
RingBuffer rawModelInput;
RingBuffer processedModelOutput;
std::array<float, MODEL_OUTPUT_SIZE_BACKEND> rawModelOutputBuffer;
std::array<float, MODEL_INPUT_SIZE_BACKEND> processedModelInput;

std::atomic<InferenceBackend> currentBackend {ONNX};

6 changes: 3 additions & 3 deletions source/dsp/inference/backends/OnnxRuntimeProcessor.cpp
@@ -11,10 +11,10 @@ OnnxRuntimeProcessor::~OnnxRuntimeProcessor()

void OnnxRuntimeProcessor::prepareToPlay() {
// Define the shape of input tensor
inputShape = {1, MODEL_INPUT_SIZE, 1};
inputShape = {1, MODEL_INPUT_SIZE_BACKEND, 1};
}

void OnnxRuntimeProcessor::processBlock(std::array<float, MODEL_INPUT_SIZE>& input, std::array<float, MODEL_OUTPUT_SIZE>& output) {
void OnnxRuntimeProcessor::processBlock(std::array<float, MODEL_INPUT_SIZE_BACKEND>& input, std::array<float, MODEL_OUTPUT_SIZE_BACKEND>& output) {

// Create input tensor object from input data values and shape
const Ort::Value inputTensor = Ort::Value::CreateTensor<float> (memory_info,
@@ -39,7 +39,7 @@ void OnnxRuntimeProcessor::processBlock(std::array<float, MODEL_INPUT_SIZE>& inp
}

// Extract the output tensor data
for (size_t i = 0; i < MODEL_OUTPUT_SIZE; i++) {
for (size_t i = 0; i < MODEL_OUTPUT_SIZE_BACKEND; i++) {
output[i] = outputTensors[0].GetTensorMutableData<float>()[i];
}
}
2 changes: 1 addition & 1 deletion source/dsp/inference/backends/OnnxRuntimeProcessor.h
@@ -11,7 +11,7 @@ class OnnxRuntimeProcessor {
~OnnxRuntimeProcessor();

void prepareToPlay();
void processBlock(std::array<float, MODEL_INPUT_SIZE>& input, std::array<float, MODEL_OUTPUT_SIZE>& output);
void processBlock(std::array<float, MODEL_INPUT_SIZE_BACKEND>& input, std::array<float, MODEL_OUTPUT_SIZE_BACKEND>& output);

private:
std::string filepath = MODELS_PATH_TENSORFLOW;
9 changes: 5 additions & 4 deletions source/dsp/inference/utils/InferenceConfig.h
@@ -7,10 +7,11 @@ enum InferenceBackend {
TFLite
};

#define MODEL_INPUT_SIZE 150
#define MAX_INFERENCE_TIME 4096
#define MODEL_LATENCY 0
#define MODEL_INPUT_SIZE 1
#define MODEL_INPUT_SIZE_BACKEND 150 // Same as MODEL_INPUT_SIZE, but for streamable models
#define MODEL_OUTPUT_SIZE_BACKEND 1

#define MODEL_OUTPUT_SIZE 1
#define MAX_INFERENCE_TIME 1024
#define MODEL_LATENCY 0

#endif //NN_INFERENCE_TEMPLATE_INFERENCECONFIG_H
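
Taken together, these changes replace the fixed-size exchange arrays with ring buffers and run one model pass per incoming sample over a sliding 150-sample window. The standalone sketch below illustrates that loop; std::deque stands in for the project's RingBuffer and dummyModel for the ONNX backend call, so both are illustrative assumptions rather than the repository's API.

// Standalone sketch (not the repository's API): std::deque stands in for the
// project's RingBuffer and dummyModel for OnnxRuntimeProcessor::processBlock.
#include <array>
#include <cstddef>
#include <deque>
#include <iostream>

constexpr std::size_t kModelInputSizeBackend  = 150; // MODEL_INPUT_SIZE_BACKEND
constexpr std::size_t kModelOutputSizeBackend = 1;   // MODEL_OUTPUT_SIZE_BACKEND

// Placeholder for the real backend inference call.
void dummyModel(const std::array<float, kModelInputSizeBackend>& input,
                std::array<float, kModelOutputSizeBackend>& output) {
    output[0] = input.back(); // identity on the newest sample, for illustration
}

int main() {
    std::array<float, kModelInputSizeBackend>  window{};   // processedModelInput
    std::array<float, kModelOutputSizeBackend> modelOut{}; // rawModelOutputBuffer
    std::deque<float> rawInput;         // stands in for rawModelInput
    std::deque<float> processedOutput;  // stands in for processedModelOutput

    // Pretend one audio block has been pushed into the input ring buffer.
    for (int i = 0; i < 512; ++i)
        rawInput.push_back(static_cast<float>(i) / 512.0f);

    // One model pass per available input sample, as in InferenceThread::inference().
    while (!rawInput.empty()) {
        // Shift the 150-sample window left by one and append the newest sample.
        for (std::size_t j = 1; j < kModelInputSizeBackend; ++j)
            window[j - 1] = window[j];
        window[kModelInputSizeBackend - 1] = rawInput.front();
        rawInput.pop_front();

        dummyModel(window, modelOut);

        // Collect the model output sample(s).
        for (std::size_t j = 0; j < kModelOutputSizeBackend; ++j)
            processedOutput.push_back(modelOut[j]);
    }

    std::cout << "produced " << processedOutput.size() << " output samples\n";
    return 0;
}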
