diff --git a/apex/contrib/csrc/xentropy/xentropy_kernel.cu b/apex/contrib/csrc/xentropy/xentropy_kernel.cu
index 8ef62a334..13c49bdfc 100644
--- a/apex/contrib/csrc/xentropy/xentropy_kernel.cu
+++ b/apex/contrib/csrc/xentropy/xentropy_kernel.cu
@@ -574,7 +574,7 @@ std::vector<Tensor> host_softmax_xentropy(
     const Tensor & labels_,
     const float smoothing,
     const bool half_to_float){
-  if (half_to_float) TORCH_CHECK(input_.scalar_type() == ScalarType::Half,"conversion is supported for Half type only");
+  if (half_to_float) TORCH_CHECK(input_.scalar_type() == ScalarType::Half || input_.scalar_type() == ScalarType::BFloat16,"conversion is supported for Half and BFloat16 types only");
   TORCH_CHECK(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long");
 
   auto input = input_.contiguous();
@@ -712,7 +712,7 @@ at::Tensor softmax_xentropy_backward_cuda(
     const float smoothing) {
   bool half_to_float = grad_loss.scalar_type() != logits.scalar_type();
   if (half_to_float) {
-    TORCH_CHECK((grad_loss.scalar_type() == ScalarType::Float && logits.scalar_type() == ScalarType::Half), "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
+    TORCH_CHECK((grad_loss.scalar_type() == ScalarType::Float && (logits.scalar_type() == ScalarType::Half || logits.scalar_type() == ScalarType::BFloat16)), "expected input and grad types to match, or input to be at::Half or at::BFloat16 and grad to be at::Float");
   }
   return host_softmax_xentropy_backward(grad_loss, logits, max_log_sum_exp, labels, smoothing, half_to_float);
 }
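
For context only, a minimal standalone sketch of the type check this patch extends: when half_to_float promotion is requested, the input may now be either fp16 or bf16. The helper names below (is_reduced_float, check_half_to_float) are hypothetical and not part of the apex source; TORCH_CHECK, at::ScalarType, and Tensor::scalar_type() are standard PyTorch C++ API.

    // Sketch, not part of the patch.
    #include <torch/extension.h>

    using at::ScalarType;
    using at::Tensor;

    // True for the reduced-precision floating types that may be promoted to float.
    static bool is_reduced_float(ScalarType t) {
      return t == ScalarType::Half || t == ScalarType::BFloat16;
    }

    // Mirrors the forward-path check: with half_to_float set, the input must be
    // Half or BFloat16; otherwise no constraint is imposed here.
    static void check_half_to_float(const Tensor& input, bool half_to_float) {
      if (half_to_float) {
        TORCH_CHECK(is_reduced_float(input.scalar_type()),
                    "conversion is supported for Half and BFloat16 types only");
      }
    }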