diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.cc
index 7a044209677d30..222e918ae540b7 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.cc
@@ -40,9 +40,11 @@ ProcessMesh GetSubProcessMesh(const ProcessMesh& mesh, int64_t axis) {
   std::vector<int64_t> process_ids;
   for (int64_t i = 0; i < shape_of_axis; ++i) {
     coord[axis] = i;
-    int64_t rank = coord.back();
-    for (int64_t j = static_cast<int64_t>(coord.size() - 2); j >= 0; --j) {
-      rank += coord[j] * mesh.dim_size(j + 1);
+    int64_t rank = 0;
+    int64_t degree = 1;
+    for (int64_t j = static_cast<int64_t>(coord.size() - 1); j >= 0; --j) {
+      rank += coord[j] * degree;
+      degree *= mesh.dim_size(j);
     }
     process_ids.emplace_back(mesh.process_ids()[rank]);
   }
diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_3d_global_mesh_reshard.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_3d_global_mesh_reshard.py
index bdc256a8a6493b..9f15b4c36c234d 100644
--- a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_3d_global_mesh_reshard.py
+++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_3d_global_mesh_reshard.py
@@ -64,8 +64,18 @@ def test_basic(self):
             verbose=True,
         )
 
+    def test_3d_mesh_with_any_status(self):
+        dense_tensor = paddle.ones(shape=[2, 6], dtype='float32')
+        dist_tensor = dist.shard_tensor(
+            dense_tensor,
+            self._global_mesh,
+            [dist.Replicate(), dist.Shard(0), dist.Replicate()],
+        )
+        np.testing.assert_equal(dist_tensor._local_shape, [1, 6])
+
     def run_test_case(self):
         self.test_basic()
+        self.test_3d_mesh_with_any_status()
 
 
 if __name__ == '__main__':