From 0306ffd839fb64bedd4657ba0976e63f802c229c Mon Sep 17 00:00:00 2001 From: Roy Oursler Date: Tue, 24 Sep 2024 13:13:35 -0700 Subject: [PATCH] common: fix undefined use of memory descriptor strides --- src/common/gemm_types.hpp | 14 ++++++++++---- src/common/gemm_utils.hpp | 9 ++++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/common/gemm_types.hpp b/src/common/gemm_types.hpp index c23a7c9f58d..d1f18f91d46 100644 --- a/src/common/gemm_types.hpp +++ b/src/common/gemm_types.hpp @@ -73,7 +73,8 @@ struct gemm_desc_t { // Simplified accessors that comply to GEMM API static transpose_t get_trans(const memory_desc_t &md) { if (!md.ndims) return transpose::notrans; // arbitrary - return md.format_desc.blocking.strides[md.ndims - 1] != 1 + return md.dims[md.ndims - 1] != 1 + && md.format_desc.blocking.strides[md.ndims - 1] != 1 ? transpose::trans : transpose::notrans; } @@ -116,9 +117,14 @@ struct gemm_desc_t { // This assumes that one of the dimensions has strides 1 static dnnl_dim_t get_ld(const memory_desc_t &md) { auto strides = md.format_desc.blocking.strides; - assert(strides[md.ndims - 1] == 1 || strides[md.ndims - 2] == 1); - return strides[md.ndims - 1] != 1 ? strides[md.ndims - 1] - : strides[md.ndims - 2]; + assert(md.dims[md.ndims - 1] == 1 || strides[md.ndims - 1] == 1 + || md.dims[md.ndims - 2] == 1 || strides[md.ndims - 2] == 1); + switch (get_trans(md)) { + case transpose::trans: + return md.dims[md.ndims - 1] > 1 ? strides[md.ndims - 1] : 1; + default: + return md.dims[md.ndims - 2] > 1 ? strides[md.ndims - 2] : 1; + } } // Leading dimension of A. dnnl_dim_t lda() const { return get_ld(b_desc); } diff --git a/src/common/gemm_utils.hpp b/src/common/gemm_utils.hpp index 65045a7d911..93ec571e325 100644 --- a/src/common/gemm_utils.hpp +++ b/src/common/gemm_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -156,8 +156,11 @@ static inline bool is_md_gemm_compatible_plain_format( if (blk_desc.inner_nblks != 0) return false; - return (blk_desc.strides[md->ndims - 1] == 1) - || (!is_dst && blk_desc.strides[md->ndims - 2] == 1); + return (md->dims[md->ndims - 1] == 1 + || blk_desc.strides[md->ndims - 1] == 1) + || (!is_dst + && (md->dims[md->ndims - 2] == 1 + || blk_desc.strides[md->ndims - 2] == 1)); } } // namespace impl