diff --git a/chatlearn/utils/dist_utils.py b/chatlearn/utils/dist_utils.py index c3e65dc4..e5976a3d 100644 --- a/chatlearn/utils/dist_utils.py +++ b/chatlearn/utils/dist_utils.py @@ -33,7 +33,7 @@ def bucket_tensor_generator(tensor_generator, bucket_size_mb): """ size_limit = bucket_size_mb * 1024 * 1024 buf_dict = defaultdict(lambda: [[], 0]) - for tensor in tensor_generator: + for tensor in tensor_generator(): if tensor.is_sparse: yield tensor, False continue