from torch.utils.data import DataLoader
```
Init signature:
DataLoader(
dataset: torch.utils.data.dataset.Dataset[+T_co],
batch_size: Optional[int] = 1,
shuffle: Optional[bool] = None,
sampler: Union[torch.utils.data.sampler.Sampler, Iterable, NoneType] = None,
batch_sampler: Union[torch.utils.data.sampler.Sampler[List], Iterable[List], NoneType] = None,
num_workers: int = 0,
collate_fn: Optional[Callable[[List[~T]], Any]] = None)
....
```
|
如何将多个样本数据(从 dataset 中获取的)合并成一个批次(batch)
默认行为
如果不指定 collate_fn,DataLoader 会使用默认的合并函数,其逻辑如下:
如果数据是数值张量(如 torch.Tensor)、NumPy 数组或类似结构,会直接按第 0 维度堆叠(torch.stack)。
如果数据是可变长度的(如列表、字符串或非对齐的张量),会以列表形式返回(不堆叠)。
如果数据是字典或元组,会对每个字段/元素递归应用上述规则。
假设 dataset 返回的是张量 (x, y) # 单个样本:x.shape=(3, 4), y.shape=(5,) # 默认 collate_fn 会将 batch_size=2 的数据合并为: # x_batch.shape=(2, 3, 4) # 自动堆叠 # y_batch.shape=(2, 5) # 自动堆叠 自定义 collate_fn
当默认行为不满足需求时,可以通过 collate_fn 自定义合并逻辑。常见场景包括:
处理可变长度数据(如文本、音频):
用填充(padding)将所有样本对齐到相同长度。
例如:将不同长度的句子填充到相同长度,并生成注意力掩码(attention mask)。
过滤异常数据:
在合并前检查数据有效性(如跳过 None 或损坏的样本)。
复杂数据结构:
处理嵌套的字典、元组,或非张量数据(如自定义对象)
|
|
|
|
|
|
|