Skip to content

Commit

Permalink
common: fix undefined use of memory descriptor strides
Browse files Browse the repository at this point in the history
  • Loading branch information
rjoursler authored and karturov committed Oct 2, 2024
1 parent 05303ea commit 0306ffd
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
14 changes: 10 additions & 4 deletions src/common/gemm_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ struct gemm_desc_t {
// Simplified accessors that comply to GEMM API
static transpose_t get_trans(const memory_desc_t &md) {
if (!md.ndims) return transpose::notrans; // arbitrary
return md.format_desc.blocking.strides[md.ndims - 1] != 1
return md.dims[md.ndims - 1] != 1
&& md.format_desc.blocking.strides[md.ndims - 1] != 1
? transpose::trans
: transpose::notrans;
}
Expand Down Expand Up @@ -116,9 +117,14 @@ struct gemm_desc_t {
// This assumes that one of the dimensions has strides 1
static dnnl_dim_t get_ld(const memory_desc_t &md) {
auto strides = md.format_desc.blocking.strides;
assert(strides[md.ndims - 1] == 1 || strides[md.ndims - 2] == 1);
return strides[md.ndims - 1] != 1 ? strides[md.ndims - 1]
: strides[md.ndims - 2];
assert(md.dims[md.ndims - 1] == 1 || strides[md.ndims - 1] == 1
|| md.dims[md.ndims - 2] == 1 || strides[md.ndims - 2] == 1);
switch (get_trans(md)) {
case transpose::trans:
return md.dims[md.ndims - 1] > 1 ? strides[md.ndims - 1] : 1;
default:
return md.dims[md.ndims - 2] > 1 ? strides[md.ndims - 2] : 1;
}
}
// Leading dimension of A.
dnnl_dim_t lda() const { return get_ld(b_desc); }
Expand Down
9 changes: 6 additions & 3 deletions src/common/gemm_utils.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2023 Intel Corporation
* Copyright 2019-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -156,8 +156,11 @@ static inline bool is_md_gemm_compatible_plain_format(

if (blk_desc.inner_nblks != 0) return false;

return (blk_desc.strides[md->ndims - 1] == 1)
|| (!is_dst && blk_desc.strides[md->ndims - 2] == 1);
return (md->dims[md->ndims - 1] == 1
|| blk_desc.strides[md->ndims - 1] == 1)
|| (!is_dst
&& (md->dims[md->ndims - 2] == 1
|| blk_desc.strides[md->ndims - 2] == 1));
}

} // namespace impl
Expand Down

0 comments on commit 0306ffd

Please sign in to comment.