diff --git a/include/dali/core/tensor_shape.h b/include/dali/core/tensor_shape.h index b2399370e3..1921af36a7 100644 --- a/include/dali/core/tensor_shape.h +++ b/include/dali/core/tensor_shape.h @@ -641,12 +641,17 @@ struct TensorListShapeBase { */ template void set_tensor_shape(int64_t sample, const SampleShape &sample_shape) { - detail::check_compatible_ndim::value>(); - assert(static_cast(dali::size(sample_shape)) == static_cast(sample_dim())); + constexpr int rhs_sample_ndim = compile_time_size::value; + detail::check_compatible_ndim(); + constexpr bool is_scalar = sample_ndim == 0 || rhs_sample_ndim == 0; assert(sample >= 0 && sample < nsamples && "Sample index out of range"); - int64_t base = sample_dim() * sample; - for (int i = 0; i < sample_dim(); i++) { - shapes[base + i] = sample_shape[i]; + assert(static_cast(dali::size(sample_shape)) == static_cast(sample_dim())); + assert(static_cast(shapes.size()) == nsamples * sample_dim() && "shapes size mismatch"); + if constexpr (!is_scalar) { + int64_t base = sample_dim() * sample; + for (int i = 0; i < sample_dim(); i++) { + shapes[base + i] = sample_shape[i]; + } } }