- 收集每个GPU上的输出
在分布式训练时,每个GPU都会有一部分数据,当我们需要使用全部的数据进行计算时,我们需要收集所有GPU的tensor。
比如两个GPU,第一个GPU有16组数据,第二个GPU有16组数据, 在进行对比学习计算时,我们需要收集所有的输出来增加负样本的数量。
我们可以使用tensors_from_all = self.all_gather(my_tensor)
比如:
def training_step(self, batch, batch_idx):
outputs = self(batch)
...
all_outputs = self.all_gather(outputs, sync_grads=True)
loss = contrastive_loss_fn(all_outputs, ...)
return loss