tf.data 最佳实践摘要
具体理解参考:https://tensorflow.juejin.im/performance/datasets_performance.html
以下是设计输入管道的最佳实践总结:
- 使用 prefetch 转换来合并训练和开销的工作。 特别是,我们建议在输入管道的末端添加 prefetch(n)(其中 n 是训练步骤消耗的元素/批次数),以将 CPU 上执行的转换与加速器上的训练合并。
- 通过设置 num_parallel_calls 参数来并行化 map 转换。我们建议使用可用的 CPU 内核数量作为其值。
- 如果要使用 batch 转换将预处理元素组合到批处理中,我们建议使用融合的 map_and_batch 转换;特别是在你使用大批量数据的情况下。
- 如果您正在处理云端存储的数据和/或需要反序列化的数据,我们建议使用 parallel_interleave 转换来重叠读取(和反序列化)来自不同文件的数据。
- 向传递给 map 转换的轻量用户定义函数进行矢量化,以分摊与调度和执行函数相关的开销。
- 如果你的数据可以放入内存,在第一个迭代次数期间使用 cache 转换将其缓存在内存中,这样后续的迭代次数可以避免产生与读取,解析和转换相关的开销。
- 如果预处理增加了数据的大小,我们建议首先应用 interleave,prefetch,和 shuffle 如果可以的话)以减少内存占用量。
- 我们建议在 repeat 转换之前应用 shuffle 转换,理想情况下使用融合的 shuffle_and_repeat 转换。
tf.data.Dataset.shuffle用于打乱数据
tf.data.Dataset.shuffle(buffer_size)中buffer_size的理解
参考https://juejin.im/post/5b855d016fb9a01a1a27d035
首先,Dataset会取所有数据的前buffer_size数据项,填充 buffer,如下图
然后,从buffer中随机选择一条数据输出,比如这里随机选中了item 7,那么buffer中item 7对应的位置就空出来了
然后,从Dataset中顺序选择最新的一条数据填充到buffer中,这里是item 10
然后在从Buffer中随机选择下一条数据输出。
需要说明的是,这里的数据项item,并不只是单单一条真实数据,如果有batch size
,则一条数据项item包含了batch size
条真实数据。
shuffle是防止数据过拟合的重要手段,然而不当的buffer size,会导致shuffle无意义,具体可以参考这篇Importance of buffer_size in shuffle()