Skip to content

Commit

Permalink
[DAP/Whisper] Connecting the preprocessor and model.
Browse files Browse the repository at this point in the history
  • Loading branch information
taiqzheng committed Jun 28, 2024
1 parent 8a30be9 commit 5d08f61
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 109 deletions.
8 changes: 7 additions & 1 deletion examples/BuddyWhisper/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,17 @@ SET_TARGET_PROPERTIES(
PROPERTIES
LINKER_LANGUAGE C)

add_executable(buddy-whisper-run whisper-main.cpp)
set(BUDDY_WHISPER_FILES
whisper-main.h
whisper-main.cpp
)

add_executable(buddy-whisper-run ${BUDDY_WHISPER_FILES})
target_link_directories(buddy-whisper-run PRIVATE ${LLVM_MLIR_LIBRARY_DIR})

set(BUDDY_WHISPER_LIBS
WHISPER
BuddyLibDAP
mlir_c_runner_utils
omp
)
Expand Down
7 changes: 2 additions & 5 deletions examples/BuddyWhisper/import-whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,14 @@
# Retrieve the Whisper model path from environment variables.
model_path = os.environ.get("WHISPER_MODEL_PATH")
if model_path is None:
raise EnvironmentError(
"The environment variable 'WHISPER_MODEL_PATH' is not set or is invalid."
)
model_path = "openai/whisper-base"

# Initialize the tokenizer and model from the specified model path.
processor = WhisperProcessor.from_pretrained(model_path)
model = WhisperForConditionalGeneration.from_pretrained(model_path)
model.config.use_cache = False

dataset_path = os.environ.get("AUDIO_DATASET_PATH")
ds = load_dataset(dataset_path, "clean", split="validation")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[1]["audio"]
input_features = processor(
sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
Expand Down
107 changes: 4 additions & 103 deletions examples/BuddyWhisper/whisper-main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,107 +14,7 @@
//
//===----------------------------------------------------------------------===//

#include <buddy/Core/Container.h>
#include <buddy/LLM/TextContainer.h>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <filesystem>
#include <fstream>
#include <iostream>
using namespace buddy;

constexpr size_t ParamsSize = 99148800;
constexpr size_t MaxVocabSize = 51865;
constexpr size_t MaxTokenLength = 448;
constexpr size_t HiddenSize = 512;

/// Declare Whisper forward function.
extern "C" void _mlir_ciface_forward(MemRef<float, 3> *, MemRef<float, 1> *,
MemRef<float, 3> *, MemRef<size_t, 2> *);

// -----------------------------------------------------------------------------
// Helper Functions
// -----------------------------------------------------------------------------

/// Prompt the user on stdout and read one whole line from stdin.
///
/// \param inputStr Receives the line typed by the user (without the trailing
///                 newline); any previous content is replaced.
void getUserInput(std::string &inputStr) {
  // Header line, then the ">>> " prompt marker on its own line.
  std::cout << "\nPlease send a message:" << std::endl << ">>> ";
  std::getline(std::cin, inputStr);
  // Blank line separating the prompt from whatever is printed next.
  std::cout << std::endl;
}

/// Write the "[Log] " tag to stdout, wrapped in the ANSI escape sequences for
/// bold blue text followed by a reset. No newline is emitted.
void printLogLabel() {
  constexpr char logTag[] = "\033[34;1m[Log] \033[0m";
  std::cout << logTag;
}

/// Print a one-line summary of a single iteration to stdout.
///
/// The line starts with a bold green "[Iteration N] " label and is followed
/// by the token text and the elapsed time.
///
/// \param iterIdx Index of the iteration being reported.
/// \param str     Token text to display.
/// \param time    Elapsed time, printed with an "s" suffix.
void printIterInfo(size_t iterIdx, std::string str, double time) {
  std::ostream &out = std::cout;
  out << "\033[32;1m[Iteration " << iterIdx << "] \033[0m"
      << "Token: " << str << " | "
      << "Time: " << time << "s" << std::endl;
}

/// Fill a parameter container with raw floats read from a binary file.
///
/// Exactly `sizeof(float) * params.getSize()` bytes are read straight into
/// the storage returned by `params.getData()`. Progress and the elapsed load
/// time are logged to stdout.
///
/// \param paramFilePath Path of the binary parameter file.
/// \param params        Destination container; its current size determines
///                      how many floats are read.
/// \throws std::runtime_error if the file cannot be opened or the read fails
///         (for example when the file holds fewer bytes than requested).
void loadParameters(const std::string &paramFilePath,
                    MemRef<float, 1> &params) {
  using Clock = std::chrono::high_resolution_clock;
  const Clock::time_point begin = Clock::now();

  std::ifstream input(paramFilePath, std::ios::in | std::ios::binary);
  if (!input.is_open())
    throw std::runtime_error("[Error] Failed to open params file!");

  printLogLabel();
  std::cout << "Loading params..." << std::endl;
  printLogLabel();
  std::cout << "Params file: " << std::filesystem::canonical(paramFilePath)
            << std::endl;

  // Bulk-read the whole payload directly into the container's buffer.
  char *const destination = reinterpret_cast<char *>(params.getData());
  input.read(destination, sizeof(float) * (params.getSize()));
  if (input.fail())
    throw std::runtime_error("Error occurred while reading params file!");
  input.close();

  // Measured in milliseconds, reported in seconds.
  const std::chrono::duration<double, std::milli> elapsedMs =
      Clock::now() - begin;
  printLogLabel();
  std::cout << "Params load time: "
            << static_cast<double>(elapsedMs.count()) / 1000 << "s\n"
            << std::endl;
}

/// Load input features from a raw binary file into a container.
///
/// Exactly `sizeof(float) * params.getSize()` bytes are read straight into
/// the storage returned by `params.getData()`. Progress and the elapsed load
/// time are logged to stdout.
///
/// \param paramFilePath Path of the binary input_features file.
/// \param params        Destination container for the features; its current
///                      size determines how many floats are read.
/// \throws std::runtime_error if the file cannot be opened or the read fails
///         (for example when the file holds fewer bytes than requested).
void loadAudio(const std::string &paramFilePath, MemRef<float, 3> &params) {
  const auto loadStart = std::chrono::high_resolution_clock::now();
  std::ifstream paramFile(paramFilePath, std::ios::in | std::ios::binary);
  if (!paramFile.is_open()) {
    throw std::runtime_error("[Error] Failed to open input_features file!");
  }
  printLogLabel();
  std::cout << "Loading input_features..." << std::endl;
  printLogLabel();
  std::cout << "input_features file: "
            << std::filesystem::canonical(paramFilePath) << std::endl;

  // Bulk-read the whole payload directly into the container's buffer.
  paramFile.read(reinterpret_cast<char *>(params.getData()),
                 sizeof(float) * (params.getSize()));

  if (paramFile.fail()) {
    // Name the file actually being read; the previous message was copied
    // from loadParameters and wrongly blamed the params file.
    throw std::runtime_error(
        "Error occurred while reading input_features file!");
  }
  paramFile.close();

  // Measured in milliseconds, reported in seconds.
  const auto loadEnd = std::chrono::high_resolution_clock::now();
  const std::chrono::duration<double, std::milli> loadTime =
      loadEnd - loadStart;
  printLogLabel();
  std::cout << "input_features load time: "
            << static_cast<double>(loadTime.count()) / 1000 << "s\n"
            << std::endl;
}

/// Return the zero-based position of the largest value in `[start, end)`.
///
/// When several elements share the maximum, the first one wins. An empty
/// range yields 0.
///
/// \param start Pointer to the first element.
/// \param end   Pointer one past the last element.
int findMaxIndex(const float *start, const float *end) {
  const float *const peak = std::max_element(start, end);
  return static_cast<int>(peak - start);
}
#include "whisper-main.h"

// -----------------------------------------------------------------------------
// Whisper Inference Main Entry
Expand Down Expand Up @@ -147,10 +47,11 @@ int main() {

/// Fill data into containers
// - Output: register vocabulary.
// - Parameters: load parameters from the `arg0` file into the container.
// - Parameters: generate audioInput from rawAudioData.
outputContainer.loadVocab(vocabDir);
loadParameters(paramsDir, paramsContainer);
loadAudio(input_featuresDir, audioInput);
rawAudioData = std::move(MemRef<double, 1>(rawSpeech, inputShape));
dap::WhisperPreprocess(&rawAudioData, &audioInput);

/// Run Whisper Inference
// - Perform the forward function.
Expand Down
139 changes: 139 additions & 0 deletions examples/BuddyWhisper/whisper-main.h

Large diffs are not rendered by default.

0 comments on commit 5d08f61

Please sign in to comment.