[NVIDIA] Refine API difference of all_gather_object (#6892)
* Refine API difference of `all_gather_object`

In `paddle`, there is no need to pre-allocate `object_list` to match the group size

* Add spaces between EN and CN chars in doc

fix doc hook `insert-whitespace-between-cn-and-en-char`
anderson101866 authored Oct 20, 2024
1 parent 6e749a9 commit db52460
Showing 1 changed file with 25 additions and 2 deletions.
@@ -12,12 +12,35 @@ torch.distributed.all_gather_object(object_list, obj, group=None)
paddle.distributed.all_gather_object(object_list, obj, group=None)
```

- The functionality is consistent and the parameters are fully consistent. Details are as follows:
+ The functionality is consistent and the parameters are nearly identical, but `object_list` is initialized differently. Details are as follows:

### Parameter Mapping

| PyTorch | PaddlePaddle | Notes |
| -------- | ------------ | --------------------------------------------- |
- | object_list | object_list | The list that holds the gathered results. |
+ | object_list | | The list that holds the gathered results; it must be pre-allocated as a list whose length equals the group size. |
+ | | object_list | The list that holds the gathered results; it must be initialized as an empty list. |
| obj | obj | The object to be gathered. |
| group | group | The process group instance on which the operation is performed. |

### Conversion Example

```python
# PyTorch version
import torch.distributed as dist
# Assumes the default process group is already initialized and the world size is 2.
object_list = [{}, {}]  # pre-allocate one placeholder per rank
if dist.get_rank() == 0:
    obj = {"foo": [1, 2, 3]}
else:
    obj = {"bar": [4, 5, 6]}
dist.all_gather_object(object_list, obj)

# Paddle version
import paddle.distributed as dist
# Assumes the parallel environment is already initialized and the world size is 2.
object_list = []  # no need to pre-allocate; the gathered objects are appended
if dist.get_rank() == 0:
    obj = {"foo": [1, 2, 3]}
else:
    obj = {"bar": [4, 5, 6]}
dist.all_gather_object(object_list, obj)
```
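
The sketch below is not part of the diff above; it is a hedged illustration, assuming either script is launched with exactly two processes, of what both frameworks leave in `object_list` once the call returns.

```python
# Assumption: the example above ran on a 2-process job (ranks 0 and 1).
# After all_gather_object returns, every rank holds the same list,
# ordered by the rank that contributed each object.
expected = [{"foo": [1, 2, 3]}, {"bar": [4, 5, 6]}]
assert object_list == expected
```

In both frameworks the gathered objects are ordered by rank, which is why the PyTorch pre-allocated list and the Paddle empty list end up with identical contents.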
