diff --git a/opal/mca/common/cuda/common_cuda.c b/opal/mca/common/cuda/common_cuda.c index 8d4c7033b32..b8ce5a7bea6 100644 --- a/opal/mca/common/cuda/common_cuda.c +++ b/opal/mca/common/cuda/common_cuda.c @@ -108,6 +108,10 @@ struct cudaFunctionTable { #if OPAL_CUDA_GET_ATTRIBUTES int (*cuPointerGetAttributes)(unsigned int, CUpointer_attribute *, void **, CUdeviceptr); #if OPAL_CUDA_VMM_SUPPORT + int (*cuDevicePrimaryCtxRetain)(CUcontext*, CUdevice); + int (*cuDevicePrimaryCtxGetState)(CUdevice, unsigned int*, int*); + int (*cuMemPoolGetAccess)(CUmemAccess_flags*, CUmemoryPool, CUmemLocation*); + int (*cuDeviceGetAttribute)(int*, CUdevice_attribute, CUdevice); int (*cuDeviceGetCount)(int*); int (*cuMemRelease)(CUmemGenericAllocationHandle); int (*cuMemRetainAllocationHandle)(CUmemGenericAllocationHandle*, void*); @@ -488,6 +492,10 @@ int mca_common_cuda_stage_one_init(void) OPAL_CUDA_DLSYM(libcuda_handle, cuPointerGetAttributes); #endif /* OPAL_CUDA_GET_ATTRIBUTES */ #if OPAL_CUDA_VMM_SUPPORT + OPAL_CUDA_DLSYM(libcuda_handle, cuDevicePrimaryCtxRetain); + OPAL_CUDA_DLSYM(libcuda_handle, cuDevicePrimaryCtxGetState); + OPAL_CUDA_DLSYM(libcuda_handle, cuMemPoolGetAccess); + OPAL_CUDA_DLSYM(libcuda_handle, cuDeviceGetAttribute); OPAL_CUDA_DLSYM(libcuda_handle, cuDeviceGetCount); OPAL_CUDA_DLSYM(libcuda_handle, cuMemRelease); OPAL_CUDA_DLSYM(libcuda_handle, cuMemRetainAllocationHandle); @@ -1745,7 +1753,90 @@ static float mydifftime(opal_timer_t ts_start, opal_timer_t ts_end) { } #endif /* OPAL_ENABLE_DEBUG */ -static int mca_common_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type) +static int mca_common_cuda_check_mpool(CUdeviceptr dbuf, CUmemorytype *mem_type, + int *dev_id) +{ +#if OPAL_CUDA_VMM_SUPPORT + static int device_count = -1; + static int mpool_supported = -1; + CUresult result; + CUmemoryPool mpool; + CUmemAccess_flags flags; + CUmemLocation location; + + if (mpool_supported <= 0) { + if (mpool_supported == -1) { + if (device_count == -1) { + result = 
cuFunc.cuDeviceGetCount(&device_count); + if (result != CUDA_SUCCESS || (0 == device_count)) { + mpool_supported = 0; /* never check again */ + device_count = 0; + return 0; + } + } + + /* assume uniformity of devices */ + result = cuFunc.cuDeviceGetAttribute(&mpool_supported, + CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, 0); + if (result != CUDA_SUCCESS) { + mpool_supported = 0; + } + } + if (0 == mpool_supported) { + return 0; + } + } + + result = cuFunc.cuPointerGetAttribute(&mpool, + CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE, + dbuf); + if (CUDA_SUCCESS != result) { + return 0; + } + + /* check if device has access */ + for (int i = 0; i < device_count; i++) { + location.type = CU_MEM_LOCATION_TYPE_DEVICE; + location.id = i; + result = cuFunc.cuMemPoolGetAccess(&flags, mpool, &location); + if ((CUDA_SUCCESS == result) && + (CU_MEM_ACCESS_FLAGS_PROT_READWRITE == flags)) { + *mem_type = CU_MEMORYTYPE_DEVICE; + *dev_id = i; + return 1; + } + } + + /* host must have access as device access possibility is exhausted */ + *mem_type = CU_MEMORYTYPE_HOST; + *dev_id = -1; + return 0; +#endif + + return 0; +} + +static int mca_common_cuda_get_primary_context(CUdevice dev_id, CUcontext *pctx) +{ + CUresult result; + unsigned int flags; + int active; + + result = cuFunc.cuDevicePrimaryCtxGetState(dev_id, &flags, &active); + if (CUDA_SUCCESS != result) { + return OPAL_ERROR; + } + + if (active) { + result = cuFunc.cuDevicePrimaryCtxRetain(pctx, dev_id); + return (CUDA_SUCCESS == result) ? OPAL_SUCCESS : OPAL_ERROR; + } + + return OPAL_ERROR; +} + +static int mca_common_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type, + int *dev_id) { #if OPAL_CUDA_VMM_SUPPORT static int device_count = -1; @@ -1775,6 +1866,7 @@ static int mca_common_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type) if (prop.location.type == CU_MEM_LOCATION_TYPE_DEVICE) { *mem_type = CU_MEMORYTYPE_DEVICE; + *dev_id = prop.location.id; cuFunc.cuMemRelease(alloc_handle); return 1; } @@ -1788,6 +1880,7 @@ static int 
mca_common_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type) if ((CUDA_SUCCESS == result) && (CU_MEM_ACCESS_FLAGS_PROT_READWRITE == flags)) { *mem_type = CU_MEMORYTYPE_DEVICE; + *dev_id = i; cuFunc.cuMemRelease(alloc_handle); return 1; } @@ -1796,6 +1889,7 @@ static int mca_common_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type) /* host must have access as device access possibility is exhausted */ *mem_type = CU_MEMORYTYPE_HOST; + *dev_id = -1; cuFunc.cuMemRelease(alloc_handle); return 1; @@ -1809,12 +1903,17 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t { int res; int is_vmm = 0; + int is_mpool = 0; CUmemorytype vmm_mem_type = 0; + CUmemorytype mpool_mem_type = 0; CUmemorytype memType = 0; + int vmm_dev_id = -1; + int mpool_dev_id = -1; CUdeviceptr dbuf = (CUdeviceptr)pUserBuf; CUcontext ctx = NULL, memCtx = NULL; - is_vmm = mca_common_cuda_check_vmm(dbuf, &vmm_mem_type); + is_vmm = mca_common_cuda_check_vmm(dbuf, &vmm_mem_type, &vmm_dev_id); + is_mpool = mca_common_cuda_check_mpool(dbuf, &mpool_mem_type, &mpool_dev_id); #if OPAL_CUDA_GET_ATTRIBUTES uint32_t isManaged = 0; @@ -1844,6 +1943,8 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t } else if (memType == CU_MEMORYTYPE_HOST) { if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE)) { memType = CU_MEMORYTYPE_DEVICE; + } else if (is_mpool && (mpool_mem_type == CU_MEMORYTYPE_DEVICE)) { + memType = CU_MEMORYTYPE_DEVICE; } else { /* Host memory, nothing to do here */ return 0; @@ -1864,6 +1965,8 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t } else if (memType == CU_MEMORYTYPE_HOST) { if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE)) { memType = CU_MEMORYTYPE_DEVICE; + } else if (is_mpool && (mpool_mem_type == CU_MEMORYTYPE_DEVICE)) { + memType = CU_MEMORYTYPE_DEVICE; } else { /* Host memory, nothing to do here */ return 0; @@ -1893,14 +1996,18 @@ static int 
mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t return OPAL_ERROR; } #endif /* OPAL_CUDA_GET_ATTRIBUTES */ - if (is_vmm) { - /* This function is expected to set context if pointer is device - * accessible but VMM allocations have NULL context associated - * which cannot be set against the calling thread */ - opal_output(0, - "CUDA: unable to set context with the given pointer" - "ptr=%p aborting...", dbuf); - return OPAL_ERROR; + if (is_vmm || is_mpool) { + if (OPAL_SUCCESS == + mca_common_cuda_get_primary_context( + is_vmm ? vmm_dev_id : mpool_dev_id, &memCtx)) { + /* As VMM/mempool allocations have no context associated + * with them, check if device primary context can be set */ + } else { + opal_output(0, + "CUDA: unable to set ctx with the given pointer " + "ptr=%p aborting...", pUserBuf); + return OPAL_ERROR; + } } res = cuFunc.cuCtxSetCurrent(memCtx);