Skip to content

Commit

Permalink
Allow build_CS_3D_BLOCK_RC to also have composite sub-schemes.
Browse files Browse the repository at this point in the history
Don't pad transpose kernels with more than 3 lengths

Co-authored-by: Steve Leung <[email protected]>
  • Loading branch information
malcolmroberts and evetsso authored Jun 24, 2022
1 parent 878f63e commit ba13730
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
3 changes: 3 additions & 0 deletions clients/tests/accuracy_test_adhoc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ std::vector<std::vector<size_t>> adhoc_sizes = {
// SBRC 192 with special param
{192, 192, 192},
{192, 84, 84},

// Regression case: this size previously failed with build_CS_3D_BLOCK_RC
{680, 128, 128},
};

// Stride configurations for the ad hoc tests: a single entry holding the one stride value 1.
static const std::vector<std::vector<size_t>> stride_range{std::vector<size_t>{1}};
Expand Down
8 changes: 8 additions & 0 deletions library/src/assignment_policy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,14 @@ void AssignmentPolicy::PadPlan(ExecPlan& execPlan)
// SBCR plans combine higher dimensions in ways that confuse padding
if(u.node.scheme == CS_KERNEL_STOCKHAM_BLOCK_CR)
return;
// Transpose kernels don't handle arbitrary strides, and
// with 4 or more lengths, padding along either candidate
// dimension triggers incorrect behaviour, so skip padding.
if((u.node.scheme == CS_KERNEL_TRANSPOSE
|| u.node.scheme == CS_KERNEL_TRANSPOSE_XY_Z
|| u.node.scheme == CS_KERNEL_TRANSPOSE_Z_XY)
&& u.node.length.size() > 3)
return;
}

// Ensure that if we're forced to pad along one dimension
Expand Down
15 changes: 6 additions & 9 deletions library/src/tree_node_3D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,13 +473,6 @@ void BLOCKRC3DNode::AssignParams_internal()
node->oDist = node->outStride[2] * node->length[0];
break;
}
case CS_KERNEL_STOCKHAM:
{
node->outStride = node->inStride;
node->oDist = node->iDist;
node->AssignParams();
break;
}
case CS_KERNEL_TRANSPOSE_XY_Z:
{
node->outStride.push_back(1);
Expand All @@ -497,8 +490,12 @@ void BLOCKRC3DNode::AssignParams_internal()
break;
}
default:
// build_CS_3D_BLOCK_RC should not have created any other node types
throw std::runtime_error("Scheme Assertion Failed, unexpected node scheme.");
{
node->outStride = node->inStride;
node->oDist = node->iDist;
node->AssignParams();
break;
}
}
prev_outStride = node->outStride;
prev_oDist = node->oDist;
Expand Down

0 comments on commit ba13730

Please sign in to comment.