Skip to content

Commit

Permalink
Merge pull request #530 from abergeron/fix_fix
Browse the repository at this point in the history
Fix potential race conditions
  • Loading branch information
nouiz authored Sep 15, 2017
2 parents a18251f + 957301f commit 3d1c382
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
24 changes: 22 additions & 2 deletions src/gpuarray_blas_cuda_cublas.c
Original file line number Diff line number Diff line change
@@ -835,14 +835,23 @@ static int sgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
return ctx->err->code;
}

GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(Ta, CUDA_WAIT_READ));
if (cuda_wait(Ta, CUDA_WAIT_READ) != GA_NO_ERROR) {
gpudata_release(Ta);
cuda_exit(ctx);
return ctx->err->code;
}

err = cublasSgemmBatched(h->h,
convT(transA), convT(transB),
M, N, K, &alpha,
(const float **)Aa, lda,
(const float **)Ba, ldb, &beta,
(float **)Ca, ldc, batchCount);
if (cuda_record(Ta, CUDA_WAIT_READ) != GA_NO_ERROR) {
gpudata_release(Ta);
cuda_exit(ctx);
return ctx->err->code;
}
gpudata_release(Ta);
if (err != CUBLAS_STATUS_SUCCESS) {
cuda_exit(ctx);
@@ -964,15 +973,26 @@ static int dgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
return ctx->err->code;
}

GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(Ta, CUDA_WAIT_READ));
if (cuda_wait(Ta, CUDA_WAIT_READ) != GA_NO_ERROR) {
gpudata_release(Ta);
cuda_exit(ctx);
return ctx->err->code;
}

err = cublasDgemmBatched(h->h,
convT(transA), convT(transB),
M, N, K, &alpha,
(const double **)Aa, lda,
(const double **)Ba, ldb, &beta,
(double **)Ca, ldc, batchCount);

if (cuda_record(Ta, CUDA_WAIT_READ) != GA_NO_ERROR) {
gpudata_release(Ta);
cuda_exit(ctx);
return ctx->err->code;
}
gpudata_release(Ta);

if (err != CUBLAS_STATUS_SUCCESS) {
cuda_exit(ctx);
return error_cublas(ctx->err, "cublasDgemmBatched", err);
Expand Down
4 changes: 2 additions & 2 deletions src/gpuarray_buffer_cuda.c
Original file line number Diff line number Diff line change
@@ -851,8 +851,8 @@ static void cuda_free(gpudata *d) {
d->ptr + d->sz == next->ptr) {
d->sz = d->sz + next->sz;
d->next = next->next;
cuda_wait(next, CUDA_WAIT_ALL);
cuda_record(d, CUDA_WAIT_ALL);
cuda_waits(next, CUDA_WAIT_ALL, d->ls);
cuda_records(d, CUDA_WAIT_ALL, d->ls);
deallocate(next);
} else {
d->next = next;
Expand Down

0 comments on commit 3d1c382

Please sign in to comment.