对于多模态异构的数据读取,存在很多痛点:
- 对于图像等非定长数据,通常采取连续存储,流式读取的方法,例如webdataset;同时对于文本、或者SFT阶段需要经常可视化和编辑的小数据,通常使用随机读取(
__getitem__
)的方法。 - 不同数据集来源复杂,需要配置文件方便管理和设定比例。
- 流式、随机读取均需要引入随机性;且需要恢复进度。
- 需要对不同数据集灵活设置处理函数,并且在配置文件中可以更改,部分处理参数,例如序列长度、图像大小等,在运行时传入才知道,不能在配置文件中写死。
Streaming API采取yaml格式来管理配置不同数据集,使用稍加修改的Omegaconf读取和初始化对象。提供MergedDataset
类来融合不同类型的数据,并产生统一的读取、加载、切分等接口。
from cogdata.streaming import instantiate_from_yaml, to_state, mixed_collate
本模块提供了几个关键函数,用于处理配置文件和数据处理:
-
instantiate_from_yaml(config_path: str, variables: dict = {}) -> Any:
此函数用于从YAML配置文件加载并实例化对象。它接受一个配置文件路径和一个可选的变量字典,这些变量将在配置文件中被解析。返回的是根据YAML文件中定义的指令创建的Python对象。详见创建配置文件章节。 -
to_state(obj: Any) -> dict:
将给定的Python对象转换为状态字典,便于保存当前迭代状态。详见状态保存和恢复章节。 -
mixed_collate(batch: list) -> Any:
用于Dataloader混合成batch数据,得到数据输出格式中的返回格式。
配置文件使用yaml来管理,通过instantiate_from_yaml
来读取,需要注意的是:
- 如果某个yaml的字典对象包含
target
和params
两个键,将会在最终被转化为一个target
对应的类的对象,即target(**params)
. 其他的键值对将变成这个对象的属性。target
可以为任何能import到的类路径,例如cogdata.streaming.MetaDistributedWebDataset
. - 如果某个值为
${variables:vname}
,则它会在最终被使用时解析为instantiate_from_yaml(config_path, variables)
函数调用的variables
字典中对应的值。这个功能是为了处理读数据时依赖训练超参的情况,例如图像大小和序列长度。 - 如果某个值为
${dynamic_objs:vname}
,则它会在最终被使用时解析为instantiate_from_yaml(config_path, variables, dynamic_objs)
函数调用的dynamic_objs
字典中对应的值。与variables不同的是这里可以是任意的object,但是尽量不要使用以防止滥用,导致数据集的读取依赖于难以找到源码的动态传入的函数。 - 如果某个yaml的字典对象包含
include
键,则值为一个yaml路径。最终会将该yaml instantiate的结果作为此处的对象,并添加/覆盖其他的键值对。这个一般用在混合多个数据集的时候。
- 任何Pytorch的Dataset的子类,支持
__len__()
和__getitem__()
. cogdata.streaming.MetaDistributedWebDataset
. webdataset加强版,每个tar带有一个同名的jsonl文件记录附属信息的格式。cogdata.streaming.JsonlIterableDataset
. 类似webdataset,但是每次读一行jsonl。cogdata.streaming.MergedDataset
. 融合多个数据集的IterableDataset
,支持嵌套和按权重采样,也是本库唯一支持的顶层数据集(即使只使用一个数据集,也要包裹一层MergedDataset)。支持创建时参数customized_yield_fn
来处理混合samples为新的sample的情况(例子)
一个混合多种不同类型数据的配置样例。
MergedDataset
返回的数据,期望的每个batch的输出格式一个字典列表,每个字典是一个样本。
例如:
[
# for iterable webdataset
{
'img': tensor1,
'caption': str1,
...,
'__datasetname__': 'my_img_txt_pair_dataset',
'__dprank__': 0,
'__workerid__': 1,
'__seed__', 1234,
'__url__': '/path/t001.tar',
'__key__': '0000001'
},
# for itemized Pytorch Dataset
{
'tokens': tensor1,
'position_ids': tensor2,
...,
'__datasetname__': 'my_txt_binary_dataset',
'__dprank__': 0,
'__workerid__': 1,
'__seed__', 1234,
'__index__': 55
},
...
]
后续读取之后在机器学习程序中灵活处理。
可以使用一个dict来更新训练过数据中记录的状态来持续跟踪数据读取状态。
to_state(sample)
函数会返回一个字典,用来更新状态。
dataloader_states = {}
iterator = iter(dataset)
sample = next(iterator)
dataloader_states.update(to_state(sample))
# ...
# save & reload dataloader_states
iterator = dataset.__iter_from__(dataloader_states)
# iterator will be reloaded to the previous progress
-
注意对于状态,每个数据并行rank和Dataloader worker_id,都会变成不同的key分开存储。因此reload的时候不能改变dp_size 和num_workers。
-
当改变序列长度时(这意味着对于流式数据,特别是有在线随机的情况,分sample的情况会被改变;对于itemize数据暂时不支持),对于__iter_from__方法,需要传入
reload_from_url_level=True
从而在url级别重新加载。 -
重新加载后,对于每个数据集的进度可以保持,但是初始sample各个数据集的顺序可能会发生变化。
在MergedDataset
的构造函数中,有shuffle_unstreaming_subsets
和shuffle_unstreaming_subsets
选项。
shuffle_unstreaming_subsets
对于存在__getitem__
的数据集的shuffle方法,目前是每2000000个都进行某种相同的permute。shuffle_streaming_subsets
对于IterableDataset
的每个url内部储存buffer进行online shuffling。- 选择数据集的时候,根据percent作为权重,每次吐最低累积权重的数据集,保证均匀性。
- 每个epoch后刷新该数据集的random seed。
TODO
TODO