DataLoader

 
from torch.utils.data import DataLoader 

```
Init signature:
DataLoader(
    dataset: torch.utils.data.dataset.Dataset[+T_co],
    batch_size: Optional[int] = 1,
    shuffle: Optional[bool] = None,
    sampler: Union[torch.utils.data.sampler.Sampler, Iterable, NoneType] = None,
    batch_sampler: Union[torch.utils.data.sampler.Sampler[List], Iterable[List], NoneType] = None,
    num_workers: int = 0,
    collate_fn: Optional[Callable[[List[~T]], Any]] = None)
    ....
```


 

    

 

    

 
如何将多个样本数据(从 dataset 中获取的)合并成一个批次(batch)
    

默认行为

 
如果不指定 collate_fn,DataLoader 会使用默认的合并函数,其逻辑如下:

如果数据是数值张量(如 torch.Tensor)、NumPy 数组或类似结构,会直接按第 0 维度堆叠(torch.stack)。

如果数据是可变长度的(如列表、字符串或非对齐的张量),会以列表形式返回(不堆叠)。

如果数据是字典或元组,会对每个字段/元素递归应用上述规则。
    

假设 dataset 返回的是张量 (x, y)

 
# 单个样本:x.shape=(3, 4), y.shape=(5,)
# 默认 collate_fn 会将 batch_size=2 的数据合并为:
#   x_batch.shape=(2, 3, 4)  # 自动堆叠
#   y_batch.shape=(2, 5)      # 自动堆叠

自定义 collate_fn

 
当默认行为不满足需求时,可以通过 collate_fn 自定义合并逻辑。常见场景包括:

处理可变长度数据(如文本、音频):
    
    用填充(padding)将所有样本对齐到相同长度。
    
    例如:将不同长度的句子填充到相同长度,并生成注意力掩码(attention mask)。
    
过滤异常数据:
    
    在合并前检查数据有效性(如跳过 None 或损坏的样本)。
    
复杂数据结构:
    
    处理嵌套的字典、元组,或非张量数据(如自定义对象)

 


 


 


 


 


 

  

 


参考