
Commit

Merge branch 'main' of github.com:DeepLink-org/DIOPI.dev into zmz/fix_memory_leak
hellozmz committed Aug 19, 2024
2 parents 5ad58a6 + b86d7ef commit eda003b
Showing 24 changed files with 954 additions and 32 deletions.
15 changes: 8 additions & 7 deletions diopi_test/python/configs/diopi_configs.py
@@ -959,6 +959,7 @@
"shape": ((), (16,), (72,),
(2, 11856), (2, 741, 80), (4, 4, 16, 20),
(0,), (4, 0), (9, 0, 16)),
"gen_fn": dict(fn='Genfunc.uniform', low=0, high=1),
},
{
"ins": ['weight'],
@@ -1119,7 +1120,7 @@
),

'pointwise_op_abs_input': dict(
name=['log', 'log2', 'log10', 'sqrt', 'rsqrt'],
name=['log', 'log2', 'log10', 'log1p', 'sqrt', 'rsqrt'],
interface=['torch'],
is_inplace=True,
dtype=[np.float16, np.float32, np.float64],
@@ -1137,7 +1138,7 @@
),

'log_integer_input': dict(
name=['log', 'log2', 'log10'],
name=['log', 'log2', 'log10', 'log1p'],
interface=['torch'],
dtype=[np.int16, np.int32, np.int64, np.uint8, np.int8],
tensor_para=dict(
@@ -1154,7 +1155,7 @@
),

'log_zero_input': dict(
name=['log', 'log2', 'log10'],
name=['log', 'log2', 'log10', 'log1p'],
interface=['torch'],
dtype=[np.float16, np.float32, np.float64,
np.int16, np.int32, np.int64,
@@ -1173,7 +1174,7 @@
),

'log_neg_input': dict(
name=['log', 'log2', 'log10'],
name=['log', 'log2', 'log10', 'log1p'],
interface=['torch'],
dtype=[np.float16, np.float32, np.float64,
np.int16, np.int32, np.int64,
@@ -5236,7 +5237,7 @@
},
{
"ins": ['max_exp_avg_sq'],
"shape": [None, None, (4, 8), (12, 4, 8)],
"shape": [(), (16,), (4, 8), (12, 4, 8)],
"gen_fn": 'Genfunc.rand',
},
]
@@ -6020,13 +6021,13 @@
{
"ins": ['input'],
"shape": ((8, 0), (0, 128), (256, 8)),
"dtype": [np.float32, np.float16, np.float64],
"dtype": [np.float16, np.float32, np.float64],
"gen_fn": 'Genfunc.randn',
},
{
"ins": ['mat2'],
"shape": ((0, 128), (128, 128), (8, 0)),
"dtype": [np.float16, np.float64, np.float32],
"dtype": [np.float16, np.float32, np.float64],
"gen_fn": 'Genfunc.randn',
},
],
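The changes above extend the log-family test groups (`pointwise_op_abs_input`, `log_integer_input`, `log_zero_input`, `log_neg_input`) to cover `log1p`. For reference, the edge cases those groups probe follow standard IEEE semantics; a minimal sketch using plain `<cmath>` (not the DIOPI kernels) behaves like this:

```cpp
// Illustration only: standard C++ math, not the DIOPI API. These are the
// domain edge cases the log_zero_input / log_neg_input configs exercise.
#include <cmath>
#include <cstdio>

int main() {
    std::printf("log(0)    = %f\n", std::log(0.0));     // -inf (zero input)
    std::printf("log(-1)   = %f\n", std::log(-1.0));    // nan  (negative input)
    std::printf("log1p(-1) = %f\n", std::log1p(-1.0));  // -inf: log1p's domain is x > -1
    std::printf("log1p(3)  = %f\n", std::log1p(3));     // integer argument promotes to double
    return 0;
}
```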
4 changes: 4 additions & 0 deletions diopi_test/python/conformance/diopi_functions.py
@@ -429,6 +429,10 @@ def log10(input, inplace=False) -> Tensor:
return unary_op(input, inplace, "diopiLog10", promote_type(input, Dtype.float32))


def log1p(input, inplace=False) -> Tensor:
return unary_op(input, inplace, "diopiLog1p", promote_type(input, Dtype.float32))


def erf(input, inplace=False) -> Tensor:
return unary_op(input, inplace, "diopiErf", promote_type(input, Dtype.float32))

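The new Python wrapper routes `log1p` through the same `unary_op` path as `log`/`log10`, promoting integer inputs to float32 before dispatching to `diopiLog1p`. The reason `log1p` exists as a separate op is numerical: `log(1 + x)` loses precision when `x` is tiny. A small sketch with plain C++ `<cmath>` (independent of DIOPI) shows the difference:

```cpp
// Why a dedicated log1p: adding a tiny x to 1.0 rounds away before log() runs,
// while log1p(x) keeps full precision. Plain <cmath>, independent of DIOPI.
#include <cmath>
#include <cstdio>

int main() {
    double x = 1e-17;
    std::printf("log(1 + x) = %.17e\n", std::log(1.0 + x));  // 0.0: x was absorbed by 1.0
    std::printf("log1p(x)   = %.17e\n", std::log1p(x));      // ~1e-17: accurate
    return 0;
}
```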
10 changes: 9 additions & 1 deletion impl/ascend/aclnn/adaptor.hpp
@@ -149,6 +149,10 @@ struct IsBoolStdArray<std::array<bool, N>> : std::true_type {};

inline aclIntArray* createAclIntArrayFromIntVector(const std::vector<int64_t>& vec) { return ::aclCreateIntArray(vec.data(), vec.size()); }

inline aclTensorList* createAclTensorListFromAclTensorVector(const std::vector<aclTensor*>& tensorsVec) {
return ::aclCreateTensorList(tensorsVec.data(), tensorsVec.size());
}

inline aclTensorList* createAclTensorListFromAscendTensorVector(const std::vector<AscendTensor>& tensorsVec) {
std::vector<const aclTensor*> tList(tensorsVec.size());
for (size_t i = 0; i < tensorsVec.size(); i++) {
@@ -175,7 +179,11 @@ inline aclTensorList* createAclTensorListFromConstDiopiTensorVector(const std::v

template <class T, class U = std::remove_cv_t<std::remove_reference_t<T>>>
decltype(auto) convertType(T&& param) {
if constexpr (std::is_same_v<U, AscendTensor>) {
if constexpr (std::is_same_v<U, aclTensor*>) {
return std::forward<T>(param);
} else if constexpr (std::is_same_v<U, std::vector<aclTensor*>>) {
return createAclTensorListFromAclTensorVector(std::forward<T>(param));
} else if constexpr (std::is_same_v<U, AscendTensor>) {
return createAclTensorFromAscendTensor(std::forward<T>(param));
} else if constexpr (std::is_same_v<U, diopiTensorHandle_t> || std::is_same_v<U, diopiConstTensorHandle_t>) {
return createAclTensorFromDiopiTensor(std::forward<T>(param));
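The adaptor additions let `convertType` accept two more forms: a raw `aclTensor*` is forwarded untouched, and a `std::vector<aclTensor*>` is packed into an `aclTensorList` through the new `createAclTensorListFromAclTensorVector` helper. A minimal sketch of the same `if constexpr` dispatch pattern follows; `Raw`, `RawList`, `pack`, and `wrap` are stand-in names, not part of the adaptor:

```cpp
// Sketch of the compile-time dispatch convertType uses: already-converted
// handles pass through, containers are packed into a list type, and anything
// else goes through a per-type conversion. All names here are stand-ins.
#include <type_traits>
#include <utility>
#include <vector>

struct Raw {};                                    // stand-in for aclTensor*
struct RawList { std::vector<Raw> items; };       // stand-in for aclTensorList*

inline RawList pack(std::vector<Raw> v) { return RawList{std::move(v)}; }
inline Raw wrap(int /*value*/) { return Raw{}; }  // stand-in for createAclTensorFrom*()

template <class T, class U = std::remove_cv_t<std::remove_reference_t<T>>>
decltype(auto) convert(T&& param) {
    if constexpr (std::is_same_v<U, Raw>) {
        return std::forward<T>(param);            // already converted: forward as-is
    } else if constexpr (std::is_same_v<U, std::vector<Raw>>) {
        return pack(std::forward<T>(param));      // container: pack into a list
    } else {
        return wrap(param);                       // single value: convert it
    }
}
```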
102 changes: 102 additions & 0 deletions impl/ascend/ascend_tensor.cpp
@@ -6,9 +6,11 @@

#include "ascend_tensor.hpp"

// #include <algorithm>
#include <array>
#include <cstdint>
#include <mutex>
#include <numeric>
#include <utility>

#include "common/debug.hpp"
@@ -82,6 +84,106 @@ AscendTensor& AscendTensor::asStrided(const std::vector<int64_t>& shape, const s
return *this;
}

AscendTensor& AscendTensor::permute(std::vector<int64_t> dims) {
ASCEND_CHECK_ABORT(this->dim() == dims.size(), "permute dims does not match the tensor dims.");

std::vector<int64_t> newShape(dims.size(), 0);
std::vector<int64_t> newStride(dims.size(), 0);

for (size_t i = 0; i < dims.size(); i++) {
newShape[i] = this->shape(dims[i]);
newStride[i] = this->stride(dims[i]);
}

this->shape_ = newShape;
this->stride_ = newStride;

return *this;
}

AscendTensor& AscendTensor::expand(std::vector<int64_t> shape) {
ASCEND_CHECK_ABORT(shape.size() >= this->dim(),
"the number of sizes provided[% ld] must be greater or eaqual to the number of dimensions of the tensor[% ld].",
shape.size(),
this->dim());

// todo: dim() == 0
int64_t expandDims = shape.size() - this->shape().size();
std::vector<int64_t> tShapeExp(expandDims, 0);
auto tShape = this->shape();
tShapeExp.insert(tShapeExp.end(), tShape.begin(), tShape.end());
std::vector<int64_t> newShape = shape;

for (int64_t i = 0; i < newShape.size(); i++) {
if (newShape[i] < 0 && i < expandDims) {
ASCEND_CHECK_ABORT(false, "The expanded size of the tensor (%ld) isn't allowed in a leading, non-existing dimension %ld", newShape[i], i);
}

if (i >= expandDims) {
if (newShape[i] == -1) {
newShape[i] = tShapeExp[i];
} else {
ASCEND_CHECK_ABORT(tShapeExp[i] == 1 || newShape[i] == tShapeExp[i],
"The expanded size of the tensor (%ld) must match the existing size (%ld) at non-singleton dimension %ld.",
newShape[i],
tShapeExp[i],
i);
}
}
}

    int64_t numElem = std::accumulate(newShape.begin(), newShape.end(), int64_t{1}, std::multiplies<>());
std::vector<int64_t> newStride(expandDims, 0);
auto tStride = this->stride();
newStride.insert(newStride.end(), tStride.begin(), tStride.end());
for (int64_t i = expandDims; i < shape.size(); i++) {
if (shape[i] == -1 || shape[i] == tShapeExp[i]) {
continue;
} else {
newStride[i] = 0;
}
}

this->numel_ = numElem;
this->shape_ = newShape;
this->stride_ = newStride;

return *this;
}

AscendTensor& AscendTensor::resize(const std::vector<int64_t>& shape) {
    int64_t numElem = std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<>());
std::vector<int64_t> stride(shape.size(), 1);
for (int64_t j = shape.size() - 2; j >= 0; j--) {
stride[j] = stride[j + 1] * shape[j + 1];
}

this->numel_ = numElem;
this->shape_ = shape;
this->stride_ = stride;

return *this;
}

AscendTensor& AscendTensor::select(int64_t dim, int64_t index) {
auto shape = this->shape();
auto stride = this->stride();

    ASCEND_CHECK_ABORT(dim >= 0 && dim < shape.size(), "selected dim [%ld] exceeds the tensor dims [%ld].", dim, shape.size());

if (dim < shape.size() - 1) {
int64_t offset = dim * shape[dim] * stride[dim];
this->storageOffset_ = offset;
}
this->numel_ /= shape[dim];

shape.erase(shape.begin() + dim);
stride.erase(stride.begin() + dim);
this->shape_ = shape;
this->stride_ = stride;

return *this;
}

AscendTensor& AscendTensor::unsqueeze(int dim) {
// Note: `channels_last` tensor uses this will become uncontiguous
// which is same with pytorch
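All four new `AscendTensor` methods are metadata-only view operations in the PyTorch sense: they rewrite `shape_`, `stride_`, `numel_` (and, for `select`, the storage offset) without copying device memory. A hedged usage sketch follows; the `impl::ascend` namespace and include path are assumed from the file layout in this diff, and the shape/stride values in the comments assume `t` starts as a contiguous (2, 3, 4) tensor:

```cpp
// Sketch only: assumes `t` wraps a contiguous (2, 3, 4) tensor,
// i.e. shape = {2, 3, 4}, stride = {12, 4, 1}. No device memory is touched.
#include "impl/ascend/ascend_tensor.hpp"  // path as laid out in this repository

namespace impl::ascend {

void viewOnlyReshapes(AscendTensor& t) {
    t.permute({2, 0, 1});    // shape -> {4, 2, 3}, stride -> {1, 12, 4}
    t.resize({2, 3, 4});     // back to {2, 3, 4}, contiguous strides {12, 4, 1} recomputed
    t.expand({5, 2, 3, 4});  // prepend a broadcast dim with stride 0; numel -> 120
    t.select(1, 0);          // drop dim 1; shape -> {5, 3, 4}, numel -> 60
}

}  // namespace impl::ascend
```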
4 changes: 4 additions & 0 deletions impl/ascend/ascend_tensor.hpp
@@ -245,6 +245,10 @@ class AscendTensor final {
AscendTensor& asStrided(const std::vector<int64_t>& shape, const std::vector<int64_t>& stride);
AscendTensor& unsqueeze(int dim);
AscendTensor& view(const std::vector<int64_t>& shape);
AscendTensor& resize(const std::vector<int64_t>& shape);
AscendTensor& select(int64_t dim, int64_t index);
AscendTensor& permute(std::vector<int64_t> dims);
AscendTensor& expand(std::vector<int64_t> shape);

private:
// diopi origin tensor
4 changes: 4 additions & 0 deletions impl/ascend/common/acloprunner.hpp
@@ -64,6 +64,10 @@ diopiError_t makeOnesLike(diopiContextHandle_t ctx, diopiTensorHandle_t* out, di

diopiTensorHandle_t hostToDevice(diopiContextHandle_t ctx, diopiConstTensorHandle_t src);

AscendTensor hostToDeviceAsync(diopiContextHandle_t ctx, const AscendTensor& hostTensor);

AscendTensor deviceToHostSync(diopiContextHandle_t ctx, const AscendTensor& deviceTensor);

inline std::vector<int64_t> calcStrides(int ndims, diopiSize_t size, diopiMemoryFormat_t format = diopiMemoryFormat_t::Contiguous) {
std::vector<int64_t> strides;
strides.resize(ndims);
43 changes: 43 additions & 0 deletions impl/ascend/common/utils.cpp
@@ -692,6 +692,49 @@ diopiTensorHandle_t hostToDevice(diopiContextHandle_t ctx, diopiConstTensorHandl
}
}

AscendTensor hostToDeviceAsync(diopiContextHandle_t ctx, const AscendTensor& hostTensor) {
diopiDevice_t device = hostTensor.device();

if (device == diopi_host) {
diopiTensorHandle_t dst;
diopiSize_t size{hostTensor.shape().data(), hostTensor.dim()};
diopiSize_t stride{hostTensor.stride().data(), (int64_t)hostTensor.stride().size()};
diopiDtype_t dtype = hostTensor.dtype();
diopiRequireTensor(ctx, &dst, &size, &stride, dtype, diopi_device);
const void* srcPtr = hostTensor.data();
void* dstPtr;
diopiGetTensorData(dst, &dstPtr);
diopiStreamHandle_t stream;
diopiGetStream(ctx, &stream);
int64_t elemsize = hostTensor.numel() * hostTensor.elemsize();
CALL_ACLRT(aclrtMemcpyAsync(dstPtr, elemsize, const_cast<void*>(srcPtr), elemsize, ACL_MEMCPY_HOST_TO_DEVICE, stream));
return AscendTensor(dst);
} else {
return hostTensor;
}
}

AscendTensor deviceToHostSync(diopiContextHandle_t ctx, const AscendTensor& deviceTensor) {
if (deviceTensor.device() == diopi_device) {
diopiTensorHandle_t dst;
diopiSize_t size{deviceTensor.shape().data(), deviceTensor.dim()};
diopiSize_t stride{deviceTensor.stride().data(), (int64_t)deviceTensor.stride().size()};
diopiDtype_t dtype = deviceTensor.dtype();
diopiRequireTensor(ctx, &dst, &size, &stride, dtype, diopi_host);
const void* srcPtr = deviceTensor.data();
void* dstPtr;
diopiGetTensorData(dst, &dstPtr);
diopiStreamHandle_t stream;
diopiGetStream(ctx, &stream);
int64_t elemsize = deviceTensor.numel() * deviceTensor.elemsize();
CALL_ACLRT(aclrtMemcpyAsync(dstPtr, elemsize, const_cast<void*>(srcPtr), elemsize, ACL_MEMCPY_DEVICE_TO_HOST, stream));
CALL_ACLRT(aclrtSynchronizeStream(stream));
return AscendTensor(dst);
} else {
return deviceTensor;
}
}

static diopiError_t choiceDtype(const std::set<diopiDtype_t>& opSupportedDtypes, diopiDtype_t* dtype) {
if (opSupportedDtypes.find(diopi_dtype_float32) != opSupportedDtypes.end()) {
*dtype = diopi_dtype_float32;
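`hostToDeviceAsync` enqueues the H2D copy on the context's stream and returns immediately (work submitted to the same stream afterwards is ordered after the copy), while `deviceToHostSync` calls `aclrtSynchronizeStream` after the D2H copy so the host can read the result right away. A round-trip sketch, assuming the `impl::ascend` namespace and the include path from this diff:

```cpp
// Round-trip sketch for the two new helpers; namespace and include path are
// assumed from this diff, and the kernel launch in the middle is elided.
#include "impl/ascend/common/acloprunner.hpp"

namespace impl::ascend {

AscendTensor roundTrip(diopiContextHandle_t ctx, const AscendTensor& hostIn) {
    // Async H2D: no host-side wait needed, because work enqueued later on the
    // same stream executes after this copy.
    AscendTensor onDevice = hostToDeviceAsync(ctx, hostIn);

    // ... enqueue kernels that consume `onDevice` on the same stream ...

    // Sync D2H: returns only after the stream is synchronized, so the host
    // buffer of the returned tensor is immediately readable.
    return deviceToHostSync(ctx, onDevice);
}

}  // namespace impl::ascend
```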
9 changes: 9 additions & 0 deletions impl/ascend/convert_config.yaml
@@ -479,3 +479,12 @@
- diopiMaxPool2dBackward:
tensor_dtype:
indices: (int64)->int32

- diopiBatchNormStats:
dtype: (float64)->float32

- diopiBatchNormGatherStatsWithCounts:
dtype: (float64)->float32

- diopiBatchNormBackwardReduce:
dtype: (float64)->float32