implemented batching for inference
faressc committed Oct 24, 2023
1 parent 9958160 commit ae99bcf
Showing 9 changed files with 54 additions and 34 deletions.
15 changes: 8 additions & 7 deletions source/dsp/inference/InferenceConfig.h
@@ -7,22 +7,23 @@ enum InferenceBackend {
TFLite
};

#define MODEL_TFLITE "model_0/model_0.tflite"
#define MODEL_LIBTORCH "model_0/model_0.pt"
#define MODEL_TFLITE "model_0/model_0-streaming.tflite"
#define MODEL_LIBTORCH "model_0/model_0-streaming.pt"
#define MODELS_PATH_ONNX MODELS_PATH_TENSORFLOW
#define MODEL_ONNX "model_0/model_0-tflite.onnx"
#define MODEL_ONNX "model_0/model_0-tflite-streaming.onnx"


#define BATCH_SIZE 128
#define MODEL_INPUT_SIZE 1
#define MODEL_INPUT_SIZE_BACKEND 150 // Same as MODEL_INPUT_SIZE, but for streamable models
#define MODEL_INPUT_SHAPE_ONNX {1, MODEL_INPUT_SIZE_BACKEND, 1}
#define MODEL_INPUT_SHAPE_TFLITE {1, MODEL_INPUT_SIZE_BACKEND, 1}
#define MODEL_INPUT_SHAPE_LIBTORCH {1, 1, MODEL_INPUT_SIZE_BACKEND}
#define MODEL_INPUT_SHAPE_ONNX {BATCH_SIZE, MODEL_INPUT_SIZE_BACKEND, 1}
#define MODEL_INPUT_SHAPE_TFLITE {BATCH_SIZE, MODEL_INPUT_SIZE_BACKEND, 1}
#define MODEL_INPUT_SHAPE_LIBTORCH {BATCH_SIZE, 1, MODEL_INPUT_SIZE_BACKEND}


#define MODEL_OUTPUT_SIZE_BACKEND 1

#define MAX_INFERENCE_TIME 1024
#define MAX_INFERENCE_TIME 2048
#define MODEL_LATENCY 0

#endif //NN_INFERENCE_TEMPLATE_INFERENCECONFIG_H
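
The batching change is easiest to read as a sample budget: one inference call now consumes BATCH_SIZE * MODEL_INPUT_SIZE = 128 fresh samples, while each of the 128 batch rows is fed a 150-sample window (MODEL_INPUT_SIZE_BACKEND) for the streaming models. A compile-only C++ sketch of that arithmetic (macro values copied from above; the local names are illustrative, not from the repository):

#include <cstddef>

constexpr std::size_t batchSize      = 128; // BATCH_SIZE
constexpr std::size_t modelInputSize = 1;   // MODEL_INPUT_SIZE, fresh samples per batch row
constexpr std::size_t backendWindow  = 150; // MODEL_INPUT_SIZE_BACKEND, window seen by each row

// Fresh audio consumed per inference call vs. floats handed to the backend.
constexpr std::size_t samplesPerInference  = batchSize * modelInputSize;  // 128
constexpr std::size_t floatsPerBackendCall = batchSize * backendWindow;   // 19200

static_assert(samplesPerInference  == 128,   "one fresh sample per batch row");
static_assert(floatsPerBackendCall == 19200, "size of the flat buffer passed to the backend");

// Resulting input shapes: LibTorch {128, 1, 150}; ONNX and TFLite {128, 150, 1}.
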
17 changes: 12 additions & 5 deletions source/dsp/inference/InferenceManager.cpp
@@ -26,14 +26,20 @@ void InferenceManager::prepareToPlay(const juce::dsp::ProcessSpec &newSpec) {
inferenceCounter = 0;

init = true;
init_samples = 0;
bufferCount = 0;
if ((int) spec.maximumBlockSize % (BATCH_SIZE * MODEL_INPUT_SIZE) != 0 && (int) spec.maximumBlockSize > (BATCH_SIZE * MODEL_INPUT_SIZE)) {
initSamples = (BATCH_SIZE * MODEL_INPUT_SIZE) + (2 * (int) spec.maximumBlockSize) + MAX_INFERENCE_TIME + MODEL_LATENCY;
} else {
initSamples = (BATCH_SIZE * MODEL_INPUT_SIZE) + (int) spec.maximumBlockSize + MAX_INFERENCE_TIME + MODEL_LATENCY;
}

calculateLatency((int) spec.maximumBlockSize);
}

void InferenceManager::processBlock(juce::AudioBuffer<float> &buffer) {
if (init) {
init_samples += buffer.getNumSamples();
if (init && init_samples >= MODEL_INPUT_SIZE + (int) spec.maximumBlockSize + MAX_INFERENCE_TIME + MODEL_LATENCY) init = false;
bufferCount += buffer.getNumSamples();
if (bufferCount >= initSamples) init = false;
}
for (int sample = 0; sample < buffer.getNumSamples(); ++sample) {
sendRingBuffer.pushSample(buffer.getSample(0, sample), 0);
@@ -45,8 +51,9 @@ void InferenceManager::processBlock(juce::AudioBuffer<float> &buffer) {
}
auto &sendBuffer = inferenceThread.getModelInputBuffer();
// add the available samples from the sendBuffer, otherwise samples get stuck there if MODEL_INPUT_SIZE % spec.maximumBlockSize != 0
if (sendRingBuffer.getAvailableSamples(0) + sendBuffer.getAvailableSamples(0) >= MODEL_INPUT_SIZE) {
while (sendRingBuffer.getAvailableSamples(0) > 0) {
if (sendRingBuffer.getAvailableSamples(0) + sendBuffer.getAvailableSamples(0) >= (BATCH_SIZE * MODEL_INPUT_SIZE)) {
int rest = (sendRingBuffer.getAvailableSamples(0) + sendBuffer.getAvailableSamples(0)) % (BATCH_SIZE * MODEL_INPUT_SIZE);
while (sendRingBuffer.getAvailableSamples(0) > rest) {
sendBuffer.pushSample(sendRingBuffer.popSample(0), 0);
}
if (!inferenceThread.startThread(juce::Thread::Priority::highest)) {
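
The warm-up length added in prepareToPlay above can be read as a small standalone helper. This is a sketch only, mirroring the two branches of the new code; computeInitSamples is a hypothetical name and the constants are the InferenceConfig.h macro values:

constexpr int BATCH_SIZE = 128;
constexpr int MODEL_INPUT_SIZE = 1;
constexpr int MAX_INFERENCE_TIME = 2048;
constexpr int MODEL_LATENCY = 0;

int computeInitSamples(int maximumBlockSize) {
    const int batchWindow = BATCH_SIZE * MODEL_INPUT_SIZE; // samples consumed per batched inference
    if (maximumBlockSize % batchWindow != 0 && maximumBlockSize > batchWindow) {
        // Host blocks that are larger than the batch window but not aligned to it
        // need an extra block of headroom before output is guaranteed to be ready.
        return batchWindow + 2 * maximumBlockSize + MAX_INFERENCE_TIME + MODEL_LATENCY;
    }
    return batchWindow + maximumBlockSize + MAX_INFERENCE_TIME + MODEL_LATENCY;
}

Once bufferCount reaches this value, processBlock clears the init flag, as shown in the first hunk above.
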
3 changes: 2 additions & 1 deletion source/dsp/inference/InferenceManager.h
@@ -24,7 +24,8 @@ class InferenceManager {

private:
bool init = true;
int init_samples = 0;
int bufferCount = 0;
int initSamples = 0;

InferenceThread inferenceThread;
RingBuffer sendRingBuffer;
25 changes: 17 additions & 8 deletions source/dsp/inference/InferenceThread.cpp
@@ -8,17 +8,18 @@ InferenceThread::~InferenceThread() {
}

void InferenceThread::prepareToPlay(const juce::dsp::ProcessSpec &spec) {
for (size_t i = 0; i < MODEL_INPUT_SIZE_BACKEND; i++) {
for (size_t i = 0; i < processedModelInput.size(); i++) {
processedModelInput[i] = 0.f;
}
for (size_t i = 0; i < MODEL_OUTPUT_SIZE_BACKEND; i++) {
for (size_t i = 0; i < rawModelOutputBuffer.size(); i++) {
rawModelOutputBuffer[i] = 0.f;
}

rawModelInput.initialise(1, (int) spec.sampleRate * 6); // TODO how big does the ringbuffer need to be?
processedModelOutput.initialise(1, (int) spec.sampleRate * 6); // TODO how big does the ringbuffer need to be?

maxInferencesPerBlock = std::max((unsigned int) 1, (spec.maximumBlockSize / MODEL_INPUT_SIZE) + 1);
// TODO: think about how to calculate maxInferencesPerBlock
maxInferencesPerBlock = std::max((unsigned int) 1, (spec.maximumBlockSize / (BATCH_SIZE * MODEL_INPUT_SIZE)) + 1);

onnxProcessor.prepareToPlay();
torchProcessor.prepareToPlay();
@@ -38,22 +39,30 @@ void InferenceThread::run() {

void InferenceThread::inference() {

size_t numInferences = (size_t) (rawModelInput.getAvailableSamples(0) / MODEL_INPUT_SIZE);
size_t numInferences = (size_t) (rawModelInput.getAvailableSamples(0) / (MODEL_INPUT_SIZE * BATCH_SIZE));
numInferences = std::min(numInferences, maxInferencesPerBlock);

for (size_t i = 0; i < numInferences; i++) {

// pre-processing
for (size_t j = 1; j < MODEL_INPUT_SIZE_BACKEND; j++) {
processedModelInput[j-1] = processedModelInput[j];
for (size_t batch = 0; batch < BATCH_SIZE; batch++) {
if (batch == 0) {
for (size_t j = 1; j < MODEL_INPUT_SIZE_BACKEND; j++) {
processedModelInput[j-1] = processedModelInput[((BATCH_SIZE-1) * MODEL_INPUT_SIZE_BACKEND) + j];
}
} else {
for (size_t j = 1; j < MODEL_INPUT_SIZE_BACKEND; j++) {
processedModelInput[(batch*MODEL_INPUT_SIZE_BACKEND) + (j-1)] = processedModelInput[((batch-1)*MODEL_INPUT_SIZE_BACKEND) + j];
}
}
processedModelInput[(batch*MODEL_INPUT_SIZE_BACKEND) + (MODEL_INPUT_SIZE_BACKEND-1)] = rawModelInput.popSample(0);
}
processedModelInput[MODEL_INPUT_SIZE_BACKEND-1] = rawModelInput.popSample(0);

// actual inference
processModel();

// post-processing
for (size_t j = 0; j < MODEL_OUTPUT_SIZE_BACKEND; j++) {
for (size_t j = 0; j < BATCH_SIZE * MODEL_OUTPUT_SIZE_BACKEND; j++) {
processedModelOutput.pushSample(rawModelOutputBuffer[j], 0);
}
}
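
The reworked inference() fills the batched input with overlapping windows: each batch row is the previous row shifted left by one sample (MODEL_INPUT_SIZE is 1) with one fresh sample appended at the end, and row 0 continues from the last row of the previous batch. A self-contained sketch with tiny sizes so the layout is easy to inspect; kBatch and kWindow stand in for BATCH_SIZE and MODEL_INPUT_SIZE_BACKEND, and a running counter stands in for rawModelInput.popSample(0):

#include <array>
#include <cstdio>

constexpr size_t kBatch  = 3; // stands in for BATCH_SIZE
constexpr size_t kWindow = 4; // stands in for MODEL_INPUT_SIZE_BACKEND

int main() {
    // Persists across calls, like processedModelInput in InferenceThread.
    std::array<float, kBatch * kWindow> input {};
    float nextSample = 1.f;

    // One batched "inference": shift each row left by one and append a fresh sample.
    for (size_t batch = 0; batch < kBatch; ++batch) {
        const size_t prevRow = (batch == 0) ? kBatch - 1 : batch - 1;
        for (size_t j = 1; j < kWindow; ++j)
            input[batch * kWindow + (j - 1)] = input[prevRow * kWindow + j];
        input[batch * kWindow + (kWindow - 1)] = nextSample++; // rawModelInput.popSample(0)
    }

    // Prints 0 0 0 1 / 0 0 1 2 / 0 1 2 3: consecutive rows are one sample apart.
    for (size_t batch = 0; batch < kBatch; ++batch) {
        for (size_t j = 0; j < kWindow; ++j)
            std::printf("%4.0f", input[batch * kWindow + j]);
        std::printf("\n");
    }
    return 0;
}
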
4 changes: 2 additions & 2 deletions source/dsp/inference/InferenceThread.h
@@ -34,8 +34,8 @@ class InferenceThread : public juce::Thread {

RingBuffer rawModelInput;
RingBuffer processedModelOutput;
std::array<float, MODEL_OUTPUT_SIZE_BACKEND> rawModelOutputBuffer;
std::array<float, MODEL_INPUT_SIZE_BACKEND> processedModelInput;
std::array<float, BATCH_SIZE * MODEL_OUTPUT_SIZE_BACKEND> rawModelOutputBuffer;
std::array<float, BATCH_SIZE * MODEL_INPUT_SIZE_BACKEND> processedModelInput;

std::atomic<InferenceBackend> currentBackend {LIBTORCH};

13 changes: 7 additions & 6 deletions source/dsp/inference/backends/LibtorchProcessor.cpp
@@ -14,20 +14,21 @@ LibtorchProcessor::~LibtorchProcessor() {
}

void LibtorchProcessor::prepareToPlay() {
inputs.clear();
inputs.push_back(torch::zeros(MODEL_INPUT_SHAPE_LIBTORCH));
}

void LibtorchProcessor::processBlock(std::array<float, MODEL_INPUT_SIZE_BACKEND>& input, std::array<float, MODEL_OUTPUT_SIZE_BACKEND>& output) {

void LibtorchProcessor::processBlock(std::array<float, BATCH_SIZE * MODEL_INPUT_SIZE_BACKEND>& input, std::array<float, BATCH_SIZE * MODEL_OUTPUT_SIZE_BACKEND>& output) {
// Create input tensor object from input data values and shape
frame = torch::from_blob(input.data(), MODEL_INPUT_SIZE_BACKEND).reshape(MODEL_INPUT_SHAPE_LIBTORCH);
inputs[0] = frame;
inputTensor = torch::from_blob(input.data(), (const long long) input.size()).reshape(MODEL_INPUT_SHAPE_LIBTORCH);

inputs[0] = inputTensor;

// Run inference
at::Tensor outputTensor = module.forward(inputs).toTensor();
outputTensor = module.forward(inputs).toTensor();

// Extract the output tensor data
for (size_t i = 0; i < MODEL_OUTPUT_SIZE_BACKEND; i++) {
for (size_t i = 0; i < BATCH_SIZE * MODEL_OUTPUT_SIZE_BACKEND; i++) {
output[i] = outputTensor[(int64_t) i].item<float>();
}
}
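
Put together, the LibTorch path now wraps the entire flat input array in a single batched tensor. A minimal sketch of that call sequence, not the repository code: "model.pt" is a placeholder path, and the output handling assumes MODEL_OUTPUT_SIZE_BACKEND == 1 as defined above:

#include <torch/script.h>
#include <array>
#include <vector>

int main() {
    constexpr int64_t batchSize  = 128; // BATCH_SIZE
    constexpr int64_t windowSize = 150; // MODEL_INPUT_SIZE_BACKEND

    torch::jit::script::Module module = torch::jit::load("model.pt"); // placeholder path
    std::array<float, batchSize * windowSize> input {};               // flat, one window per batch row

    // from_blob wraps the flat array without copying; reshape gives the
    // {BATCH_SIZE, 1, MODEL_INPUT_SIZE_BACKEND} layout of MODEL_INPUT_SHAPE_LIBTORCH.
    at::Tensor inputTensor = torch::from_blob(input.data(), {batchSize * windowSize})
                                 .reshape({batchSize, 1, windowSize});

    std::vector<torch::jit::IValue> inputs {inputTensor};
    at::Tensor output = module.forward(inputs).toTensor();

    // Flatten and copy the batched output back to plain floats, one per batch row.
    at::Tensor flat = output.flatten();
    std::vector<float> samples (static_cast<size_t>(batchSize));
    for (int64_t i = 0; i < batchSize; ++i)
        samples[static_cast<size_t>(i)] = flat[i].item<float>();
    return 0;
}

Because from_blob neither copies nor takes ownership, the input array has to stay alive and unmodified until forward() has consumed the tensor.
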
5 changes: 3 additions & 2 deletions source/dsp/inference/backends/LibtorchProcessor.h
@@ -11,15 +11,16 @@ class LibtorchProcessor {
~LibtorchProcessor();

void prepareToPlay();
void processBlock(std::array<float, MODEL_INPUT_SIZE_BACKEND>& input, std::array<float, MODEL_OUTPUT_SIZE_BACKEND>& output);
void processBlock(std::array<float, BATCH_SIZE * MODEL_INPUT_SIZE_BACKEND>& input, std::array<float, BATCH_SIZE * MODEL_OUTPUT_SIZE_BACKEND>& output);

private:
std::string filepath = MODELS_PATH_PYTORCH;
std::string modelname = MODEL_LIBTORCH;

torch::jit::script::Module module;

at::Tensor frame;
at::Tensor inputTensor;
at::Tensor outputTensor;
std::vector<torch::jit::IValue> inputs;
};

4 changes: 2 additions & 2 deletions source/dsp/inference/backends/OnnxRuntimeProcessor.cpp
@@ -14,7 +14,7 @@ void OnnxRuntimeProcessor::prepareToPlay() {
inputShape = MODEL_INPUT_SHAPE_ONNX;
}

void OnnxRuntimeProcessor::processBlock(std::array<float, MODEL_INPUT_SIZE_BACKEND>& input, std::array<float, MODEL_OUTPUT_SIZE_BACKEND>& output) {
void OnnxRuntimeProcessor::processBlock(std::array<float, BATCH_SIZE * MODEL_INPUT_SIZE_BACKEND>& input, std::array<float, BATCH_SIZE * MODEL_OUTPUT_SIZE_BACKEND>& output) {

// Create input tensor object from input data values and shape
const Ort::Value inputTensor = Ort::Value::CreateTensor<float> (memory_info,
@@ -39,7 +39,7 @@ void OnnxRuntimeProcessor::processBlock(std::array<float, MODEL_INPUT_SIZE_BACKE
}

// Extract the output tensor data
for (size_t i = 0; i < MODEL_OUTPUT_SIZE_BACKEND; i++) {
for (size_t i = 0; i < BATCH_SIZE * MODEL_OUTPUT_SIZE_BACKEND; i++) {
output[i] = outputTensors[0].GetTensorMutableData<float>()[i];
}
}
2 changes: 1 addition & 1 deletion source/dsp/inference/backends/OnnxRuntimeProcessor.h
@@ -11,7 +11,7 @@ class OnnxRuntimeProcessor {
~OnnxRuntimeProcessor();

void prepareToPlay();
void processBlock(std::array<float, MODEL_INPUT_SIZE_BACKEND>& input, std::array<float, MODEL_OUTPUT_SIZE_BACKEND>& output);
void processBlock(std::array<float, BATCH_SIZE * MODEL_INPUT_SIZE_BACKEND>& input, std::array<float, BATCH_SIZE * MODEL_OUTPUT_SIZE_BACKEND>& output);

private:
std::string filepath = MODELS_PATH_ONNX;
