Skip to content

Commit

Permalink
Merge pull request #12683 from ggouaillardet/topic/op_aarch64_sve_refactor
Browse files Browse the repository at this point in the history

op/aarch64: refactor SVE functions
  • Loading branch information
bosilca authored Jul 17, 2024
2 parents 486ef1b + ba59533 commit 069a8c4
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions ompi/mca/op/aarch64/op_aarch64_functions.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
* reserved.
* Copyright (c) 2019 Arm Ltd. All rights reserved.
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2024 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
*
* $COPYRIGHT$
*
Expand Down Expand Up @@ -140,20 +142,18 @@ _Generic((*(out)), \
struct ompi_datatype_t **dtype, \
struct ompi_op_base_module_1_0_0_t *module) \
{ \
int types_per_step = svcnt(*((type##type_size##_t *) _in)); \
size_t idx = 0, left_over = *count; \
const int types_per_step = svcnt(*((type##type_size##_t *) _in)); \
const int cnt = *count; \
type##type_size##_t *in = (type##type_size##_t *) _in, \
*out = (type##type_size##_t *) _out; \
OP_CONCAT(OMPI_OP_TYPE_PREPEND, type##type_size##_t) vsrc, vdst; \
svbool_t pred = svwhilelt_b##type_size(idx, left_over); \
do { \
for (int idx=0; idx < cnt; idx += types_per_step) { \
svbool_t pred = svwhilelt_b##type_size(idx, cnt); \
vsrc = svld1(pred, &in[idx]); \
vdst = svld1(pred, &out[idx]); \
vdst = OP_CONCAT(OMPI_OP_OP_PREPEND, op##_x)(pred, vdst, vsrc); \
OP_CONCAT(OMPI_OP_OP_PREPEND, st1)(pred, &out[idx], vdst); \
idx += types_per_step; \
pred = svwhilelt_b##type_size(idx, left_over); \
} while (svptest_any(svptrue_b##type_size(), pred)); \
} \
}
#endif

Expand Down Expand Up @@ -308,21 +308,19 @@ static void OP_CONCAT(ompi_op_aarch64_3buff_##name##_##type##type_size##_t, APPE
struct ompi_datatype_t **dtype, \
struct ompi_op_base_module_1_0_0_t *module) \
{ \
int types_per_step = svcnt(*((type##type_size##_t *) _in1)); \
const int types_per_step = svcnt(*((type##type_size##_t *) _in1)); \
type##type_size##_t *in1 = (type##type_size##_t *) _in1, \
*in2 = (type##type_size##_t *) _in2, \
*out = (type##type_size##_t *) _out; \
size_t idx = 0, left_over = *count; \
const int cnt = *count; \
OP_CONCAT(OMPI_OP_TYPE_PREPEND, type##type_size##_t) vsrc, vdst; \
svbool_t pred = svwhilelt_b##type_size(idx, left_over); \
do { \
for (int idx=0; idx < cnt; idx += types_per_step) { \
svbool_t pred = svwhilelt_b##type_size(idx, cnt); \
vsrc = svld1(pred, &in1[idx]); \
vdst = svld1(pred, &in2[idx]); \
vdst = OP_CONCAT(OMPI_OP_OP_PREPEND, op##_x)(pred, vdst, vsrc); \
OP_CONCAT(OMPI_OP_OP_PREPEND, st1)(pred, &out[idx], vdst); \
idx += types_per_step; \
pred = svwhilelt_b##type_size(idx, left_over); \
} while (svptest_any(svptrue_b##type_size(), pred)); \
} \
}
#endif /* defined(GENERATE_SVE_CODE) */

Expand Down

0 comments on commit 069a8c4

Please sign in to comment.