Describe the bug
There is an accuracy issue in the interaction between the current PyTorch implementation of torch.compile and DeepSpeed's use of its own checkpoint function (which was derived from an older version of the PyTorch checkpoint function). The issue shows up as incorrect preservation of the random state for the recomputation of checkpointed dropout nodes: the state is saved and restored correctly by the DeepSpeed checkpoint function, but the actual seed values passed down by PyTorch differ, leading to different dropout results between the compute in the forward step and the recompute in the backward step.
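As a toy illustration (plain PyTorch, not DeepSpeed or compiled code), this is why a mismatched RNG state between the forward pass and the backward recompute corrupts the gradients: the dropout mask drawn during recomputation no longer matches the one the loss was computed with.

```python
import torch

x = torch.randn(4, 8)

# Mask drawn in the forward pass.
torch.manual_seed(0)
mask_fwd = (torch.rand_like(x) > 0.1).float()

# Mask drawn during the backward recompute, but from a different RNG state --
# effectively what happens when the seed values handed down by the compiled
# graph differ between the two phases.
torch.manual_seed(1)
mask_recompute = (torch.rand_like(x) > 0.1).float()

print(torch.equal(mask_fwd, mask_recompute))  # False -> fwd and bwd see different dropout
```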
We encountered this issue when training BERT-tiny with ZeRO-1, where the loss values diverge over time.
In our example, after 40000 steps we got a loss of 6.09375 vs. the expected 5.125, which we get with torch.compile disabled for the checkpointing function.
[Graph: deviation in loss over time]
Due to the nature of the issue (see below), we believe it will affect other workloads as well. Current DeepSpeed master explicitly disables torch.compile on the checkpointing function as a workaround for other issues (see https://github.com/microsoft/DeepSpeed/blob/f743feca033515fdded50a98093da5a48eb41e74/deepspeed/runtime/activation_checkpointing/checkpointing.py#L945C1-L945C78, "@compiler.disable # WA from Pytorch repo for compile + zero 3 accuracy issue"), so this issue will not be visible to users unless they try to enable torch.compile on the checkpoint function, but it will need to be addressed before that workaround can be removed (which is desirable for performance reasons).
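For reference, the workaround in current DeepSpeed master looks roughly like this (paraphrased from the checkpointing.py link above; the decorated function and its body are abbreviated here, not copied verbatim): the checkpoint entry point is excluded from compilation, so the checkpointed region always runs eagerly and the RNG mismatch is never hit. Removing this decorator to regain the performance benefit of compiling the region is what exposes the issue described here.

```python
from deepspeed.runtime import compiler  # provides the compile-disable decorator

@compiler.disable  # WA from Pytorch repo for compile + zero 3 accuracy issue
def checkpoint(function, *args):
    # ... DeepSpeed's activation-checkpointing implementation ...
    ...
```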
Investigation and background:
Our internal investigation indicates that the PyTorch implementation for compiling checkpointing with torch.compile works differently than before: previously, the RNG state was stored before the forward compute and the same state was restored before the backward recompute. The new implementation wraps random ops in forward and backward in special ops, run_and_save_rng_state and run_with_rng_state (specifically in the function functionalize_rng_ops, https://github.com/pytorch/pytorch/blob/main/torch/_functorch/partitioners.py#L615). The issue is that this special treatment is defined in PyTorch to apply only to the PyTorch checkpoint function (torch.utils.checkpoint) and not to the DeepSpeed checkpoint function.
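To make the difference concrete, below is a minimal sketch of the save/restore pattern that a reentrant checkpoint implementation such as DeepSpeed's relies on. This is not the actual DeepSpeed code (which also captures device RNG states, handles partitioned activations, etc.); it only shows the invariant that the RNG state captured before the forward compute is the one restored before the backward recompute, which is exactly the invariant the new run_and_save_rng_state/run_with_rng_state path preserves only for torch.utils.checkpoint.

```python
import torch

class CheckpointSketch(torch.autograd.Function):
    """Simplified reentrant activation checkpoint: recompute forward inside backward."""

    @staticmethod
    def forward(ctx, run_function, *args):
        ctx.run_function = run_function
        # Capture the CPU RNG state (a real implementation also captures the
        # device RNG state, e.g. via torch.cuda.get_rng_state()).
        ctx.fwd_rng_state = torch.get_rng_state()
        ctx.save_for_backward(*args)
        with torch.no_grad():
            return run_function(*args)

    @staticmethod
    def backward(ctx, *grad_outputs):
        inputs = ctx.saved_tensors
        # Restore the forward RNG state so dropout draws the same mask during
        # recomputation -- the invariant that is broken when the seeds handed
        # down by the compiled graph differ between forward and backward.
        torch.set_rng_state(ctx.fwd_rng_state)
        detached = [t.detach().requires_grad_(t.requires_grad) for t in inputs]
        with torch.enable_grad():
            outputs = ctx.run_function(*detached)
        torch.autograd.backward(outputs, grad_outputs)
        return (None,) + tuple(t.grad for t in detached)
```

A layer would then be invoked as CheckpointSketch.apply(layer, x) instead of layer(x).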
Note: switching the model to call the PyTorch checkpoint function instead of the DeepSpeed function makes the issue go away.
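As a sanity check, a minimal standalone version of that switch is sketched below (a hypothetical toy block, not the actual BERT-tiny model): with torch.utils.checkpoint and use_reentrant=False, functionalize_rng_ops recognizes the checkpointed region and keeps the dropout RNG consistent between forward and the backward recompute under torch.compile.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint

class TinyBlock(nn.Module):
    def __init__(self, dim=128, p=0.1):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(p)

    def forward(self, x):
        return self.dropout(torch.relu(self.linear(x)))

block = TinyBlock()

@torch.compile
def step(x):
    # Previously: deepspeed.checkpointing.checkpoint(block, x) -- that variant is not
    # recognized by functionalize_rng_ops, so the dropout seeds can differ between the
    # forward compute and the backward recompute.
    y = torch_checkpoint(block, x, use_reentrant=False)
    return y.sum()

x = torch.randn(4, 128, requires_grad=True)
loss = step(x)
loss.backward()
```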
To Reproduce
Steps to reproduce the behavior:
Run BERT-tiny with ZeRO-1 (we ran on an 8-card Gaudi setup, but the issue is reproducible on a single card as well) for 40000 steps.
Over time, the loss values of the two runs (with and without torch.compile on the checkpoint function) deviate significantly.
Notes:
Initially the loss values are identical and only diverge over time; the time to divergence can be shortened by increasing the learning rate substantially.
Setting dropout to 0 makes the issue go away, since at least in our case the issue is expressed in the dropout node calculations.
Expected behavior
It is expected that when using torch.compile the results will be the same as without it, or at least close to each other.
ds_report output
Please run ds_report to give us details about your setup.
Screenshots
If applicable, add screenshots to help explain your problem.
System info (please complete the following information):
Ubuntu 22
8x Gaudi2 cards (reproducible with a single card as well)
Python 3.10