Commit ad0b83f: fix compilation issues
xadupre committed Jun 5, 2024 (1 parent: 701bc3a)
Showing 6 changed files with 29 additions and 20 deletions.
2 changes: 1 addition & 1 deletion operators/cuda/add_mul.h
@@ -29,7 +29,7 @@ struct AddOrMulSharedInput {
     auto length_c = tensor_c.NumberOfElement();

     T* output_data_ab = output_ab.Allocate(length_a <= length_b ? tensor_b.Shape() : tensor_a.Shape());
-    T* output_data_ac = output_ab.Allocate(length_a <= length_c ? tensor_c.Shape() : tensor_a.Shape());
+    T* output_data_ac = output_ac.Allocate(length_a <= length_c ? tensor_c.Shape() : tensor_a.Shape());

     if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) {
       return {};
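Note: the fix above corrects a copy-paste error: both outputs were allocated from output_ab, so output_ac never received its own buffer. For orientation, a host-side sketch of what the operator computes; the boolean addition switch, flat length n, and the absence of broadcasting are assumptions for illustration, while the real kernel runs on CUDA:

// Reference semantics only: two outputs sharing the first input a.
// output_ab = a (+ or *) b, output_ac = a (+ or *) c.
for (int64_t i = 0; i < n; ++i) {
  output_data_ab[i] = addition ? a[i] + b[i] : a[i] * b[i];
  output_data_ac[i] = addition ? a[i] + c[i] : a[i] * c[i];
}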
2 changes: 1 addition & 1 deletion operators/cuda/cuda_ops.cc
@@ -7,7 +7,7 @@
 #include "cuda/add_mul.h"
 #include "cuda/fast_gelu.h"
 #include "cuda/negxplus1.h"
-#incluce "cuda/transpose_cast.h"
+#include "cuda/transpose_cast.h"
 #endif

 FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
15 changes: 9 additions & 6 deletions operators/cuda/transpose_cast.h
@@ -9,7 +9,7 @@
 namespace contrib {

 template <typename TIN, typename TOUT>
-struct TransposeCast2D {
+struct Transpose2DCast {
   template <typename TDict>
   OrtxStatus OnModelAttach(const TDict& /*dict*/) {
     return {};
@@ -19,16 +19,19 @@ struct TransposeCast2D {
                ortc::Tensor<TOUT>& output) const {
     const TIN* input_data = input.Data();
     auto shape = input.Shape();
-    if (shape.size() != 2)
+    if (shape.size() != 2) {
       ORTX_CXX_API_THROW("Input must be a 2D tensor", ORT_RUNTIME_EXCEPTION);
+    }
     size_t n_rows = shape[0];
-    size_t n_cols = shape[1];
-    TOUT* output_data = output.Allocate(shape);
+    size_t n_cols = shape[1];
+
+    std::vector<int64_t> new_shape{n_cols, n_rows};
+    TOUT* output_data = output.Allocate(new_shape);
     if (0 == n_rows || 0 == n_cols) {
       return {};
     }
-    TransposeCast2DKernel<TIN, TOUT>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
-                                     n_rows, n_cols, input_data, output_data);
+    LaunchTranspose2DCastKernel<TIN, TOUT>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
+                                           n_rows, n_cols, input_data, output_data);
     return {};
   }
 };
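Note: besides renaming the struct to Transpose2DCast, the fix allocates the output with the transposed shape {n_cols, n_rows} instead of the input shape, and calls the launcher LaunchTranspose2DCastKernel rather than the kernel symbol itself. A minimal CPU reference of the intended semantics (a sketch for illustration, not code from the repository):

// An n_rows x n_cols input becomes an n_cols x n_rows output,
// casting each element from TIN to TOUT (e.g. float -> half).
template <typename TIN, typename TOUT>
void Transpose2DCastRef(const TIN* in, TOUT* out, size_t n_rows, size_t n_cols) {
  for (size_t i = 0; i < n_rows; ++i)
    for (size_t j = 0; j < n_cols; ++j)
      out[j * n_rows + i] = static_cast<TOUT>(in[i * n_cols + j]);
}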
24 changes: 15 additions & 9 deletions operators/cuda/transpose_cast_impl.cu
@@ -8,8 +8,11 @@

 using namespace Ort::Custom;

+#define TILE_DIM 32
+#define BLOCK_ROWS 8
+
 template <typename TOUT, typename TIN>
-__global__ void TransposeCast2DKernel(TOUT *output_data, const TIN *input_data, int n_rows, int n_cols) {
+__global__ void Transpose2DCastKernel(TOUT *output_data, const TIN *input_data, int n_rows, int n_cols) {
   __shared__ TIN tile[TILE_DIM][TILE_DIM + 1];

   int x = blockIdx.x * TILE_DIM + threadIdx.x;
@@ -29,23 +32,26 @@ __global__ void TransposeCast2DKernel(TOUT *output_data, const TIN *input_data,
 }

 template <typename TIN, typename TOUT>
-cudaError_t _LaunchTransposeCast2DKernel(cudaStream_t stream, size_t n_rows, size_t n_cols, const TIN* input, TOUT* output) {
-  constexpr int blockSize = 256;
-  const int gridSize = (input_length + blockSize - 1) / blockSize;
+cudaError_t _LaunchTranspose2DCastKernel(cudaStream_t stream, size_t n_rows, size_t n_cols,
+                                         const TIN* input, TOUT* output) {
+  dim3 dimGrid((n_cols + TILE_DIM - 1) / TILE_DIM, (n_rows + TILE_DIM - 1) / TILE_DIM, 1);
+  dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1);
   using TTIN = typename contrib::CudaT<TIN>::MappedType;
   using TTOUT = typename contrib::CudaT<TOUT>::MappedType;
-  TransposeCast2DKernel<TTOUT, TTIN><<<gridSize, blockSize, 0, stream>>>(
+  Transpose2DCastKernel<TTOUT, TTIN><<<dimGrid, dimBlock, TILE_DIM * TILE_DIM + TILE_DIM, stream>>>(
       reinterpret_cast<TTOUT*>(output), reinterpret_cast<const TTIN*>(input),
       static_cast<int>(n_rows), static_cast<int>(n_cols));
   return cudaGetLastError();
 }

 template <>
-cudaError_t LaunchTransposeCast2DKernel<float, ortc::MFloat16>(cudaStream_t stream, size_t n_rows, size_t n_cols, const float* input, ortc::MFloat16* output) {
-  return _LaunchTransposeCast2DKernel(stream, n_rows, n_cols, , input, output);
+cudaError_t LaunchTranspose2DCastKernel<float, ortc::MFloat16>(cudaStream_t stream, size_t n_rows, size_t n_cols,
+                                                               const float* input, ortc::MFloat16* output) {
+  return _LaunchTranspose2DCastKernel(stream, n_rows, n_cols, input, output);
 }

 template <>
-cudaError_t LaunchTransposeCast2DKernel<ortc::MFloat16, float>(cudaStream_t stream, size_t n_rows, size_t n_cols, const ortc::MFloat16* input, float* output) {
-  return _LaunchTransposeCast2DKernel(stream, n_rows, n_cols, input, output);
+cudaError_t LaunchTranspose2DCastKernel<ortc::MFloat16, float>(cudaStream_t stream, size_t n_rows, size_t n_cols,
+                                                               const ortc::MFloat16* input, float* output) {
+  return _LaunchTranspose2DCastKernel(stream, n_rows, n_cols, input, output);
 }
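Note: the kernel body between the two hunks is collapsed in this diff. The TILE_DIM x (TILE_DIM + 1) shared tile (the +1 pads the row stride to avoid shared-memory bank conflicts) and the TILE_DIM x BLOCK_ROWS block shape are the classic coalesced tiled-transpose pattern, so the elided body is likely close to this sketch (an illustration of the standard pattern, not the exact committed code):

template <typename TOUT, typename TIN>
__global__ void Transpose2DCastSketch(TOUT* out, const TIN* in, int n_rows, int n_cols) {
  __shared__ TIN tile[TILE_DIM][TILE_DIM + 1];  // +1 avoids bank conflicts

  int x = blockIdx.x * TILE_DIM + threadIdx.x;  // input column
  int y = blockIdx.y * TILE_DIM + threadIdx.y;  // input row
  // Each thread copies TILE_DIM / BLOCK_ROWS elements; reads are coalesced.
  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
    if (x < n_cols && y + j < n_rows)
      tile[threadIdx.y + j][threadIdx.x] = in[(y + j) * n_cols + x];
  __syncthreads();

  x = blockIdx.y * TILE_DIM + threadIdx.x;  // output column
  y = blockIdx.x * TILE_DIM + threadIdx.y;  // output row
  // Writes are coalesced too: adjacent threads hit adjacent output columns.
  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
    if (x < n_rows && y + j < n_cols)
      out[(y + j) * n_rows + x] = static_cast<TOUT>(tile[threadIdx.x][threadIdx.y + j]);
}

Since the tile is declared statically, this pattern needs no dynamic shared memory; the TILE_DIM * TILE_DIM + TILE_DIM bytes passed as the third launch parameter would only matter for an extern __shared__ (dynamically sized) tile.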
2 changes: 1 addition & 1 deletion operators/cuda/transpose_cast_impl.cuh
@@ -6,4 +6,4 @@
 #include <cuda_runtime.h>

 template <typename TIN, typename TOUT>
-cudaError_t TransposeCast2DKernel(cudaStream_t stream, size_t n_rows, size_t n_cols, const TIN* input, TOUT* output);
+cudaError_t LaunchTranspose2DCastKernel(cudaStream_t stream, size_t n_rows, size_t n_cols, const TIN* input, TOUT* output);
4 changes: 2 additions & 2 deletions test/cuda/test_cudaops.py
@@ -323,7 +323,7 @@ def _transpose_cast_cuda(self, itype):
         opts.register_custom_ops_library(_get_library_path())
         sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
         got = sess.run(None, feeds1)[0]
-        self.assertEqualArray(expected, got, atol=1e-5)
+        assert_almost_equal(expected, got, decimal=5)

     @unittest.skipIf(not has_cuda(), reason="cuda not available")
     def test_transpose_cast_cuda(self):
@@ -332,4 +332,4 @@ def test_transpose_cast_cuda(self):


 if __name__ == "__main__":
-    unittest.main()
+    unittest.main(verbosity=2)
