We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
R1-V/src/r1-v/src/open_r1/trainer/vllm_grpo_trainer.py
Lines 700 to 703 in 2e4d05d
这里应该是: reward_kwargs[key].extend([example[key]])
其实这个bug在per_device_batch_size <= num_generation时不会引发太多实际问题。VLLMTrainer设置了一个RepeatSampler,这个sampler已经会把每个dataset item重复num_generation次,不需要我们自己复制。那么按照原来的方案,实际上是item复制了num_generation * num_generation次。 当per_device_batch_size > num_generation时,会导致这个rank上,第二个实际item【每num_generation个item其实对应的是原来的一个item】,也就是num_generation + 1开始的rollout获得错误的reward_kwargs,尤其是solution字段。也就是错位问题。
可能表述的不太清楚,需要结合trl的实现逻辑来看,但简单的说,就是错位问题。
The text was updated successfully, but these errors were encountered:
比如gpu_per_node=8,那么实际上只有前7张卡在训模型,这时候设置超参数num_generation=7,pbs=14(由于RepeatRandomSampler的存在,实际上每张卡的实际item是 pbs / num_generation = 14 / 7 = 2),那么这时候,repeat_kwargs[key]的实际长度为2 * 7(repeat sample重复)* 7([example[key]] * self.num_generations ,第二次重复),但rollout一共是14个,前7个为第一个dataset item的rollout,后7个为第二个item的,而repeat_kwargs[key]中的前14个,其实都来自于第一个dataset item的key,那么就会错位了。
Sorry, something went wrong.
No branches or pull requests
R1-V/src/r1-v/src/open_r1/trainer/vllm_grpo_trainer.py
Lines 700 to 703 in 2e4d05d
这里应该是:
reward_kwargs[key].extend([example[key]])
其实这个bug在per_device_batch_size <= num_generation时不会引发太多实际问题。VLLMTrainer设置了一个RepeatSampler,这个sampler已经会把每个dataset item重复num_generation次,不需要我们自己复制。那么按照原来的方案,实际上是item复制了num_generation * num_generation次。
当per_device_batch_size > num_generation时,会导致这个rank上,第二个实际item【每num_generation个item其实对应的是原来的一个item】,也就是num_generation + 1开始的rollout获得错误的reward_kwargs,尤其是solution字段。也就是错位问题。
可能表述的不太清楚,需要结合trl的实现逻辑来看,但简单的说,就是错位问题。
The text was updated successfully, but these errors were encountered: