diff --git a/loki/transformations/single_column/tests/test_scc_vector.py b/loki/transformations/single_column/tests/test_scc_vector.py index 66238d29b..96378d869 100644 --- a/loki/transformations/single_column/tests/test_scc_vector.py +++ b/loki/transformations/single_column/tests/test_scc_vector.py @@ -573,20 +573,30 @@ def test_scc_devector_section_special_case(frontend, horizontal, vertical, block """ fcode_kernel = """ - subroutine some_kernel(start, end, nlon, flag0, flag1) + subroutine some_kernel(start, end, nlon, flag0, flag1, flag2) implicit none integer, intent(in) :: nlon, start, end - logical, intent(in) :: flag0, flag1 + logical, intent(in) :: flag0, flag1, flag2 real, dimension(nlon) :: work integer :: jl - if(flag0)then + if (flag0) then call some_other_kernel() - elseif(flag1)then + elseif (flag1) then + do jl=start,end + work(jl) = 1. + enddo + elseif (flag2) then do jl=start,end work(jl) = 1. + work(jl) = 2. + enddo + else + do jl=start,end + work(jl) = 41. + work(jl) = 42. enddo endif @@ -595,7 +605,7 @@ def test_scc_devector_section_special_case(frontend, horizontal, vertical, block routine = Subroutine.from_source(fcode_kernel, frontend=frontend) - # check whether pipeline can be applied and works as expected + # check whether pipeline can be applied and works as expected scc_pipeline = SCCVectorPipeline( horizontal=horizontal, vertical=vertical, block_dim=blocking, directive='openacc', trim_vector_sections=trim_vector_sections @@ -611,3 +621,11 @@ def test_scc_devector_section_special_case(frontend, horizontal, vertical, block assert isinstance(conditional.else_body[0].body[0], ir.Comment) assert isinstance(conditional.else_body[0].body[1], ir.Loop) assert conditional.else_body[0].body[1].pragma[0].content.lower() == 'loop vector' + + # Check that all else-bodies have been wrapped + else_bodies = conditional.else_bodies + assert(len(else_bodies) == 3) + for body in else_bodies: + assert isinstance(body[0], ir.Comment) + assert isinstance(body[1], ir.Loop) + assert body[1].pragma[0].content.lower() == 'loop vector' diff --git a/loki/transformations/single_column/vector.py b/loki/transformations/single_column/vector.py index 2764c019f..64782aaf0 100644 --- a/loki/transformations/single_column/vector.py +++ b/loki/transformations/single_column/vector.py @@ -141,7 +141,9 @@ def extract_vector_sections(cls, section, horizontal): # as 'Conditional's rely on the fact that the first element of the 'else_body' # (if 'has_elseif') is a Conditional itself if separator.has_elseif and separator.else_body: - subsec_else = cls.extract_vector_sections(separator.else_body[0].body, horizontal) + subsec_else = [] + for ebody in separator.else_bodies: + subsec_else += cls.extract_vector_sections(ebody, horizontal) else: subsec_else = cls.extract_vector_sections(separator.else_body, horizontal) if subsec_else: