Support strided output for GEMM. Support strided input / output for rms / layer norm.
liuliu committed Aug 18, 2024
1 parent ddd3f97 commit 76d2b6c
Showing 6 changed files with 71 additions and 20 deletions.
2 changes: 1 addition & 1 deletion lib/nnc/ccv_nnc_easy.h
@@ -447,7 +447,7 @@ static inline CCV_WARN_UNUSED(ccv_nnc_tensor_view_t) ccv_nnc_get_tensor_view(con
{
if (CCV_IS_TENSOR_VIEW(tensor))
return (ccv_nnc_tensor_view_t)*(const ccv_nnc_tensor_view_t*)tensor;
- ccv_nnc_tensor_view_t tv;
+ ccv_nnc_tensor_view_t tv = {0};
memcpy(&tv, tensor, sizeof(ccv_nnc_tensor_t));
return tv;
}
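The = {0} matters because the memcpy copies only sizeof(ccv_nnc_tensor_t) bytes into the larger ccv_nnc_tensor_view_t, leaving the view-only fields (the stride data among them) untouched; zero-initializing first guarantees those fields read as empty rather than as stack garbage. A minimal sketch of the hazard, using hypothetical stand-in structs rather than the real ccv definitions:

    #include <string.h>

    typedef struct { int dim[4]; } base_t;                 /* stand-in for ccv_nnc_tensor_t */
    typedef struct { int dim[4]; int stride[4]; } view_t;  /* stand-in for ccv_nnc_tensor_view_t */

    static view_t make_view(const base_t* t)
    {
        view_t v = {0};                /* without this, v.stride holds stack garbage */
        memcpy(&v, t, sizeof(base_t)); /* copies only the base fields */
        return v;                      /* v.stride is now all zeros, i.e. treated as compact */
    }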
22 changes: 15 additions & 7 deletions lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m
@@ -34,7 +34,7 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
assert(a_cols == w_rows);
assert(w_cols == b_cols);
int adim[CCV_NNC_MAX_DIM_ALLOC];
- int astride[CCV_NNC_MAX_DIM_ALLOC];
+ int astride[CCV_NNC_MAX_DIM_ALLOC] = {0};
memcpy(adim, a->info.dim, sizeof(adim));
if (CCV_IS_TENSOR_VIEW(a))
memcpy(astride, a->stride, sizeof(astride));
@@ -53,6 +53,9 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
astride[0] = astride[1];
}
}
+ int bstride[CCV_NNC_MAX_DIM_ALLOC] = {0};
+ if (CCV_IS_TENSOR_VIEW(b))
+ memcpy(bstride, b->stride, sizeof(bstride));
const int is_transpose_w = ccv_nnc_is_matrix_transpose(w->info, cmd.info.blas.transpose_b);
int biasdim[CCV_NNC_MAX_DIM_ALLOC] = {0};
int biasstride[CCV_NNC_MAX_DIM_ALLOC] = {0};
@@ -122,16 +125,21 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
w_batch_inc = 0;
@autoreleasepool {
// Fake the astride at a_nd - 3. For this one, we have flexibility to change fo v2 GEMM kernels.
- const int astride_a_nd_3 = astride[a_nd - 3];
+ const int a_batch_stride = astride[a_nd - 3];
// Only fake it if it is larger than the expected compact stride.
- if (astride_a_nd_3 > astride[a_nd - 2] * adim[a_nd - 2])
+ if (a_batch_stride > astride[a_nd - 2] * adim[a_nd - 2])
astride[a_nd - 3] = astride[a_nd - 2] * adim[a_nd - 2];
+ const int b_batch_stride = bstride[b_nd - 3];
+ // Only fake it if it is larger than the expected compact stride.
+ if (b_batch_stride > bstride[b_nd - 2] * b->info.dim[b_nd - 2])
+ bstride[b_nd - 3] = bstride[b_nd - 2] * b->info.dim[b_nd - 2];
const int is_contiguous =
(!CCV_IS_TENSOR_VIEW(a) || ccv_nnc_tensor_view_is_contiguous(adim, astride)) &&
(!CCV_IS_TENSOR_VIEW(w) || ccv_nnc_tensor_view_is_contiguous(w->info.dim, w->stride)) &&
- (!CCV_IS_TENSOR_VIEW(b) || ccv_nnc_tensor_view_is_contiguous(b->info.dim, b->stride)) &&
+ (!CCV_IS_TENSOR_VIEW(b) || ccv_nnc_tensor_view_is_contiguous(b->info.dim, bstride)) &&
(bias ? (!CCV_IS_TENSOR_VIEW(bias) || ccv_nnc_tensor_view_is_contiguous(bias->info.dim, bias->stride)) : 1);
- astride[a_nd - 3] = astride_a_nd_3;
+ astride[a_nd - 3] = a_batch_stride;
+ bstride[b_nd - 3] = b_batch_stride;

const int a_datatype = CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX ? ((a->info.datatype & 0xff) << 12) : a->info.datatype;
const int w_datatype = CCV_GET_DATA_TYPE(w->info.datatype) == CCV_QX ? ((w->info.datatype & 0xff) << 12) : w->info.datatype;
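The pattern in this hunk is save / clamp / check / restore: remember the real batch stride, temporarily clamp astride[a_nd - 3] (and now bstride[b_nd - 3]) down to the compact rows * cols value so the contiguity check only rejects views that are strided inside a matrix, then put the real stride back so it can be handed to the kernel. A minimal sketch of that flow, with a hypothetical is_contiguous() standing in for ccv_nnc_tensor_view_is_contiguous():

    /* Hypothetical 3-D view: dim = {batch, rows, cols} with matching stride[]. */
    static int batch_padded_view_ok(int dim[3], int stride[3],
                                    int (*is_contiguous)(const int* dim, const int* stride))
    {
        const int saved = stride[0];
        const int compact = dim[1] * dim[2];
        if (saved > compact)
            stride[0] = compact;   /* hide pure between-batch padding from the check */
        const int ok = is_contiguous(dim, stride);
        stride[0] = saved;         /* restore: the kernel still needs the real batch stride */
        return ok;
    }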
@@ -380,9 +388,9 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.register_float = (is_upcast ? 1 : 0),

.batch_dimension = b_batch_size,
- .batch_stride_a = a_batch_size > 1 ? ccv_max(astride_a_nd_3, b_rows * w_rows) : 0,
+ .batch_stride_a = a_batch_size > 1 ? ccv_max(a_batch_stride, b_rows * w_rows) : 0,
.batch_stride_b = w_batch_size > 1 ? b_cols * w_rows : 0,
- .batch_stride_c = b_batch_size > 1 ? b_rows * b_cols : 0,
+ .batch_stride_c = b_batch_size > 1 ? ccv_max(b_batch_stride, b_rows * b_cols) : 0,
.batch_stride_d = bias_batch_size > 1 ? b_cols : 0,
};
ccv_nnc_mfa_prepare_gemm(context, params);
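With the real strides restored, batch_stride_a and batch_stride_c can now describe a padded layout instead of assuming the compact b_rows * w_rows / b_rows * b_cols sizes. A hypothetical worked example for the output side:

    /* Hypothetical: the output view has dim {4, 32, 64} and its batches sit 4096 elements apart. */
    const int b_rows = 32, b_cols = 64;
    const int b_batch_stride = 4096;      /* bstride[b_nd - 3] */
    const int compact = b_rows * b_cols;  /* 2048 */
    /* mirrors .batch_stride_c = b_batch_size > 1 ? ccv_max(b_batch_stride, b_rows * b_cols) : 0 */
    const int batch_stride_c = b_batch_stride > compact ? b_batch_stride : compact;  /* 4096 */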
16 changes: 14 additions & 2 deletions lib/nnc/cmd/norm/mps/ccv_nnc_layer_norm_mps.m
@@ -72,9 +72,17 @@ static int _ccv_nnc_layer_norm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_
}
}

+ const int a_batch_stride = at.stride[0];
+ // Only fake it if it is larger than the expected compact stride.
+ if (a_batch_stride > at.stride[1] * at.info.dim[1])
+ at.stride[0] = at.stride[1] * at.info.dim[1];
+ const int b_batch_stride = bt.stride[0];
+ // Only fake it if it is larger than the expected compact stride.
+ if (b_batch_stride > bt.stride[1] * bt.info.dim[1])
+ bt.stride[0] = bt.stride[1] * bt.info.dim[1];
if (use_mfa) {
- if (!CCV_IS_TENSOR_CONTIGUOUS(inputs[0]) ||
- !CCV_IS_TENSOR_CONTIGUOUS(outputs[0]) ||
+ if (!(!CCV_IS_TENSOR_VIEW(&at) || ccv_nnc_tensor_view_is_contiguous(at.info.dim, at.stride)) ||
+ !(!CCV_IS_TENSOR_VIEW(&bt) || ccv_nnc_tensor_view_is_contiguous(bt.info.dim, bt.stride)) ||
!CCV_IS_TENSOR_CONTIGUOUS(outputs[1]) ||
!CCV_IS_TENSOR_CONTIGUOUS(outputs[2]) ||
(elementwise_affine &&
@@ -85,6 +93,8 @@ static int _ccv_nnc_layer_norm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_
fallback_reason = "Strided.";
}
}
+ at.stride[0] = a_batch_stride;
+ bt.stride[0] = b_batch_stride;

int channel_count;
const int channel_groups = 1;
@@ -146,6 +156,8 @@ static int _ccv_nnc_layer_norm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_

.batch_dims_data = { 0 },
.batch_dims_scale_translation = { 0 },
+ .src_batch_stride = ccv_max(a_batch_stride, channel_count * sequence_count),
+ .dst_batch_stride = ccv_max(b_batch_stride, channel_count * sequence_count),
};

// Create a null-terminated list of batch dimensions.
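The layer norm path reuses the same save / clamp / restore trick on at.stride[0] and bt.stride[0], then forwards the real per-batch strides to the kernel as src_batch_stride / dst_batch_stride; the ccv_max covers the case where the saved stride is already compact. A hypothetical numeric example of what ends up in the params:

    /* Hypothetical: sequence_count = 128, channel_count = 512,
     * with 4096 elements of padding between batches in the source tensor. */
    const int sequence_count = 128, channel_count = 512;
    const int compact = channel_count * sequence_count;  /* 65536 */
    const int a_batch_stride = compact + 4096;           /* 69632, from at.stride[0] */
    /* mirrors .src_batch_stride = ccv_max(a_batch_stride, channel_count * sequence_count) */
    const int src_batch_stride = a_batch_stride > compact ? a_batch_stride : compact;  /* 69632 */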
16 changes: 14 additions & 2 deletions lib/nnc/cmd/norm/mps/ccv_nnc_rmsnorm_mps.m
@@ -56,16 +56,26 @@ static int _ccv_nnc_rmsnorm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t h
}
}

+ const int a_batch_stride = at.stride[0];
+ // Only fake it if it is larger than the expected compact stride.
+ if (a_batch_stride > at.stride[1] * at.info.dim[1])
+ at.stride[0] = at.stride[1] * at.info.dim[1];
+ const int b_batch_stride = bt.stride[0];
+ // Only fake it if it is larger than the expected compact stride.
+ if (b_batch_stride > bt.stride[1] * bt.info.dim[1])
+ bt.stride[0] = bt.stride[1] * bt.info.dim[1];
if (use_mfa) {
- if (!CCV_IS_TENSOR_CONTIGUOUS(inputs[0]) ||
- !CCV_IS_TENSOR_CONTIGUOUS(outputs[0]) ||
+ if (!(!CCV_IS_TENSOR_VIEW(&at) || ccv_nnc_tensor_view_is_contiguous(at.info.dim, at.stride)) ||
+ !(!CCV_IS_TENSOR_VIEW(&bt) || ccv_nnc_tensor_view_is_contiguous(bt.info.dim, bt.stride)) ||
!CCV_IS_TENSOR_CONTIGUOUS(outputs[1]) ||
!CCV_IS_TENSOR_CONTIGUOUS(inputs[1]))
{
use_mfa = false;
fallback_reason = "Strided.";
}
}
+ at.stride[0] = a_batch_stride;
+ bt.stride[0] = b_batch_stride;

int channel_count;
const int channel_groups = 1;
@@ -127,6 +137,8 @@ static int _ccv_nnc_rmsnorm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t h

.batch_dims_data = { 0 },
.batch_dims_scale_translation = { 0 },
+ .src_batch_stride = ccv_max(a_batch_stride, channel_count * sequence_count),
+ .dst_batch_stride = ccv_max(b_batch_stride, channel_count * sequence_count),
};

// Create a null-terminated list of batch dimensions.
31 changes: 23 additions & 8 deletions lib/nnc/mfa/ccv_nnc_mfa_normalization.cpp
@@ -164,6 +164,8 @@ mfa::normalization::hash::hash(ccv_nnc_mfa_normalization_params_t params) {
scale_translation_batched = params.scale_translation_batched;
normalization_type = (Type)params.normalization_type;
reuse_saved_statistics = params.reuse_saved_statistics;
+ src_batch_stride = params.src_batch_stride;
+ dst_batch_stride = params.dst_batch_stride;
}

bool mfa::normalization::hash::operator==(const mfa::normalization::hash& hash) const {
@@ -176,7 +178,9 @@ bool mfa::normalization::hash::operator==(const mfa::normalization::hash& hash)
(elementwise_affine == hash.elementwise_affine) &&
(scale_translation_batched == hash.scale_translation_batched) &&
(normalization_type == hash.normalization_type) &&
- (reuse_saved_statistics == hash.reuse_saved_statistics);
+ (reuse_saved_statistics == hash.reuse_saved_statistics) &&
+ (src_batch_stride == hash.src_batch_stride) &&
+ (dst_batch_stride == hash.dst_batch_stride);
}

std::ostream& operator<<(std::ostream& os, const mfa::normalization::hash& hash) {
@@ -189,7 +193,9 @@ std::ostream& operator<<(std::ostream& os, const mfa::normalization::hash& hash)
os << " .elementwise_affine = " << bool(hash.elementwise_affine) << ',';
os << " .scale_translation_batched = " << bool(hash.scale_translation_batched) << ',';
os << " .normalization_type = " << hash.normalization_type << ',';
- os << " .reuse_saved_statistics = " << bool(hash.reuse_saved_statistics) << " ";
+ os << " .reuse_saved_statistics = " << bool(hash.reuse_saved_statistics) << ",";
+ os << " .src_batch_stride = " << hash.src_batch_stride << ',';
+ os << " .dst_batch_stride = " << hash.dst_batch_stride << ' ';
os << "}";
return os;
}
@@ -207,6 +213,7 @@ std::size_t std::hash<mfa::normalization::hash>::operator()(const mfa::normaliza
combine_64(seed, pack_64(simd::uint2 { hash.channel_count, hash.channel_groups }));
combine_64(seed, pack_64(simd::uint2 { hash.sequence_count, *reinterpret_cast<const uint32_t*>(&hash.epsilon) }));
combine_32(seed, pack_32(simd::uchar4 { hash.elementwise_affine, hash.scale_translation_batched, hash.normalization_type, hash.reuse_saved_statistics }));
+ combine_64(seed, pack_64(simd::uint2 { hash.src_batch_stride, hash.dst_batch_stride }));
return seed;
}
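Folding the two strides into std::hash matters because (as the generated defines further down show) the strides become compile-time constants of the Metal source, so two dispatches that differ only in batch stride must not share a cached pipeline. The new line follows the file's existing pattern of packing two 32-bit fields into one 64-bit word; a hedged sketch of what such a pack-and-combine step typically looks like (not the library's exact pack_64 / combine_64):

    #include <stddef.h>
    #include <stdint.h>

    /* Sketch only: pack two 32-bit values, then mix them into the running seed. */
    static uint64_t pack_64_sketch(uint32_t lo, uint32_t hi)
    {
        return (uint64_t)lo | ((uint64_t)hi << 32);
    }

    static void combine_64_sketch(size_t* seed, uint64_t value)
    {
        /* boost-style hash_combine widened to 64 bits */
        *seed ^= (size_t)(value + 0x9e3779b97f4a7c15ULL + ((uint64_t)*seed << 6) + (*seed >> 2));
    }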

@@ -242,9 +249,9 @@ kernel void normalization(
) {
uint threadgroup_index = tgid.z * sequence_count + tgid.x;
{
- uint io_offset = threadgroup_index * channel_count + lid;
- source += io_offset;
- destination += io_offset;
+ uint io_offset = tgid.x * channel_count + lid;
+ source += tgid.z * src_batch_stride + io_offset;
+ destination += tgid.z * dst_batch_stride + io_offset;
}
#if ELEMENTWISE_AFFINE
channel_scales += lid;
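The indexing change splits the old flat offset into a batch term and an in-batch term: the old code implicitly assumed the batch offset was tgid.z * sequence_count * channel_count, which only holds for compact tensors, while the new code multiplies tgid.z by the per-tensor batch stride. Written out in C for illustration (names mirror the kernel; this is not the shader itself):

    /* Element offset for threadgroup coordinates (tgid_x, tgid_z) and lane lid. */
    static unsigned old_offset(unsigned tgid_x, unsigned tgid_z, unsigned lid,
                               unsigned sequence_count, unsigned channel_count)
    {
        return (tgid_z * sequence_count + tgid_x) * channel_count + lid;
    }

    static unsigned new_src_offset(unsigned tgid_x, unsigned tgid_z, unsigned lid,
                                   unsigned channel_count, unsigned src_batch_stride)
    {
        /* equal to old_offset() exactly when src_batch_stride == sequence_count * channel_count */
        return tgid_z * src_batch_stride + tgid_x * channel_count + lid;
    }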
@@ -355,9 +362,9 @@ kernel void normalization(
) {
uint threadgroup_index = tgid.z * sequence_count + tgid.x;
{
- uint io_offset = threadgroup_index * channel_count + lid;
- source += io_offset;
- destination += io_offset;
+ uint io_offset = tgid.x * channel_count + lid;
+ source += tgid.z * src_batch_stride + io_offset;
+ destination += tgid.z * dst_batch_stride + io_offset;
}
#if ELEMENTWISE_AFFINE
channel_scales += lid;
@@ -533,6 +540,14 @@ kernel void normalization(
uint x_dim = (hash.sequence_count + sample_count - 1) / sample_count * sample_count;
this->grid_size = MTL::Size(x_dim, hash.channel_groups, 1);
}

+ defines += "constant uint src_batch_stride = ";
+ defines += std::to_string(hash.src_batch_stride) + ";";
+ defines += "\n";
+
+ defines += "constant uint dst_batch_stride = ";
+ defines += std::to_string(hash.dst_batch_stride) + ";";
+ defines += "\n";
+
defines += "constant ushort threadgroup_size = ";
defines += std::to_string(threadgroup_size) + ";";
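Like the other normalization parameters in this file, the strides are baked into the generated Metal source as constants rather than passed in a buffer, which is exactly why they also had to join the pipeline hash above. For hypothetical values of 69632 and 65536, the emitted preamble lines would read:

    constant uint src_batch_stride = 69632;
    constant uint dst_batch_stride = 65536;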
4 changes: 4 additions & 0 deletions lib/nnc/mfa/ccv_nnc_mfa_normalization.hpp
@@ -14,6 +14,8 @@ typedef struct {

uint32_t batch_dims_data[CCV_NNC_MAX_DIM_ALLOC];
uint32_t batch_dims_scale_translation[CCV_NNC_MAX_DIM_ALLOC];
+ uint32_t src_batch_stride;
+ uint32_t dst_batch_stride;
} ccv_nnc_mfa_normalization_params_t;

#ifdef __cplusplus
@@ -43,6 +45,8 @@ class hash {
uint8_t scale_translation_batched;
Type normalization_type;
uint8_t reuse_saved_statistics;
+ uint32_t src_batch_stride;
+ uint32_t dst_batch_stride;

hash(ccv_nnc_mfa_normalization_params_t);
