-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding ortvalue features support for MGX EP #81
Adding ortvalue features support for MGX EP #81
Conversation
if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { | ||
throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); | ||
} | ||
allocator = GetRocmAllocator(device.Id()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might be in an odd situation here as our offering has both MIGraphx and ROCm EPs include, thus we should we get both allocators? Did you test this when we build both MIGraphX and ROCm EPs? How does the allocator work for that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tested it with both included USE_ROCM and USE_MIGRAPHX and it's working well. In that case I think it's the same which allocator is used because both of them are for GPU. If you think it can be written better, I can change it or put additional condition for both of them to use specified ROCm or MGX EP.
// make it stream aware | ||
true, | ||
// enable cross stream sharing? | ||
false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this something we want to make controllable from he API later?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be. I created it watching the implementation for ROCm EP in rocm_execution_provider.cc
. It's quite same, if you think this is the situation in ROCm EP implementation then the answer is yes.
onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution!
Few questions about this. Overall looks good.
I've added questions/comments. One detail about combined ROCm/MIGraphX EP builds and if you've tested this with both.
also if you can, download and use lintrunner in your env to solve the lint issue. It'll make upstreaming easier
|
Thanks for the review. I changed what was needed and hope gave you the answers. I'll use lintrunner to solve the issues in next days. |
Lint Python format issue has been fixed. |
Have you done a build with both ROCm and MIGraphX EP? Seeing this during baseline unit tests from a clean build (rocm6.3_internal_testing + this branch). Can you resolve this for the linux side and both EPs? ` ` |
I have used onnxruntime_USE_COMPOSABLE_KERNEL=OFF. I suppose that's the difference causing the error because I'm not seeing any errors; all tests have been passed, and I've tested functionality. I'm rebuilding it with onnxruntime_USE_COMPOSABLE_KERNEL=ON now to see what's happening. |
f661e4f
into
rocm6.3_internal_testing
Confirmed build on a stock ROCm 6.3 Onnxruntime image from upstream. Will also cherry pick this into ROCm 6.4 internal_testing so this is part of QA's cycle |
* Adding ourtvalue support for MGX EP --------- authored-by: Uros Petkovic <[email protected]>
* Adding ourtvalue support for MGX EP --------- authored-by: Uros Petkovic <[email protected]>
* Adding ourtvalue support for MGX EP --------- authored-by: Uros Petkovic <[email protected]>
Created PR request with implementation of
ortvalue_from_numpy()
andortvalue_from_shape_and_type()
features for MGX EP on Windows and Linux in order of getting better performance forllama2 int 4
model execution. Some methods have been overridden and some of them implemented similar like it was done in ROCm EP. Implementing these features we significantly decreased amount of time needed for creating and copying tensors, almost whole time is dedicated to GPU now, which caused much better performance in tok/s for our GPUs. Similar option added for ROCM EP.