[BuddyWhisper] Add BuddyWhisper and Conv1d to buddy-mlir. #321
Conversation
No need to update this PR. I have modified it and pushed to upstream.
Thanks!
from buddy.compiler.frontend import DynamoCompiler

# ===- import-whisper.py --------------------------------------------------------
Wrong format.
model_path = "/home/liweijia/whisper-base"
if model_path is None:
    raise EnvironmentError(
        "The environment variable 'LLAMA_MODEL_PATH' is not set or is invalid."
LLAMA -> WHISPER
from buddy.compiler.graph.transform import simply_fuse

# Retrieve the LLaMA model path from environment variables.
model_path = "/home/liweijia/whisper-base"
Please do not use absolute path.
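A minimal sketch of how both comments above could be addressed, assuming a hypothetical `WHISPER_MODEL_PATH` environment variable (the variable name is illustrative, not part of the PR):

```python
import os

def get_model_path() -> str:
    """Resolve the Whisper checkpoint directory from the environment.

    WHISPER_MODEL_PATH is a hypothetical name; it replaces both the
    hard-coded "/home/liweijia/whisper-base" path and the stale
    'LLAMA_MODEL_PATH' wording in the error message.
    """
    model_path = os.environ.get("WHISPER_MODEL_PATH")
    if model_path is None:
        raise EnvironmentError(
            "The environment variable 'WHISPER_MODEL_PATH' is not set or is invalid."
        )
    return model_path
```

Callers would then use `model_path = get_model_path()` instead of the absolute path, so the script runs unchanged on any machine.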
model.config.use_cache = False

ds = load_dataset(
    "/home/liweijia/librispeech_asr_dummy", "clean", split="validation"
Please do not use absolute path.
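One way to drop the absolute dataset path, sketched under the assumption that the user-local copy mirrors the public Hub dataset `hf-internal-testing/librispeech_asr_dummy` (the dummy LibriSpeech split commonly used in Transformers examples); `LIBRISPEECH_DUMMY_PATH` is a hypothetical override variable:

```python
import os

def resolve_dataset_id(
    default_id: str = "hf-internal-testing/librispeech_asr_dummy",
) -> str:
    # Prefer a locally cached copy if the user points to one via the
    # (hypothetical) LIBRISPEECH_DUMMY_PATH variable; otherwise fall back
    # to the public Hub id so the script contains no absolute paths.
    return os.environ.get("LIBRISPEECH_DUMMY_PATH", default_id)
```

The load call then becomes `ds = load_dataset(resolve_dataset_id(), "clean", split="validation")`, which works both offline (with the override set) and online.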
@@ -0,0 +1,198 @@
//===- whisper-main.cpp |
Wrong format.
from buddy.compiler.graph import GraphDriver
from buddy.compiler.graph.transform import simply_fuse

# Retrieve the LLaMA model path from environment variables.
LLaMA -> Whisper
for (size_t i = 0; i < this->tokenCnt; i++) {
  int id = this->aligned[i];
  if (id == PAD_ID || id == CLS_ID || id == TRAN_ID || id == NOTIMESTAMPS_ID || (id >= 50259 && id <= 50357)) // pad, start, type timestamps and language
Wrong format, please use clang-format.
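For illustration, this is roughly what the over-long condition looks like after clang-format (default LLVM style, 80 columns). The `*_ID` constants are stub values chosen only so the sketch compiles standalone; the real definitions live elsewhere in whisper-main.cpp and may differ:

```cpp
// Stub token-id constants for this standalone sketch; the PR defines the
// real values elsewhere in whisper-main.cpp.
constexpr int PAD_ID = 50256;
constexpr int CLS_ID = 50258;
constexpr int TRAN_ID = 50359;
constexpr int NOTIMESTAMPS_ID = 50363;

// True for pad, start, task-type, no-timestamps and language tokens
// (ids 50259-50357) that decoding should skip. Note how clang-format
// wraps the long boolean expression instead of leaving one long line.
bool isSpecialToken(int id) {
  return id == PAD_ID || id == CLS_ID || id == TRAN_ID ||
         id == NOTIMESTAMPS_ID || (id >= 50259 && id <= 50357);
}
```

Running `clang-format -i` on the file (with the project's .clang-format, if any) applies this wrapping automatically.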
constexpr size_t MaxTokenLength = 448;
constexpr size_t HiddenSize = 512;

/// Declare LLaMA forward function.
LLaMA -> Whisper
@@ -9,3 +9,6 @@ protobuf
pybind11 == 2.11.1
torchvision
tabulate
datasets
soundfile
librosa
Please add an empty line here.
No description provided.