[BUG] DeepSpeed accuracy issue for torch.compile if activation checkpoint function not compiler disabled #6811

Open
NirSonnenschein opened this issue Dec 1, 2024 · 0 comments
Labels
bug (Something isn't working), training

Describe the bug
There is an accuracy issue when the current version of PyTorch's torch.compile is used together with DeepSpeed's separate checkpoint function (which was derived from an older version of the corresponding PyTorch function). The issue shows up as incorrect preservation of the random state for the recomputation of checkpointed dropout nodes: the RNG state is handled correctly by the DeepSpeed checkpoint function, but the actual seed values passed down by PyTorch are different, leading to different dropout results between the compute in the fwd step and the recompute in the bwd step.
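For illustration, here is a minimal standalone sketch (plain eager PyTorch, not the actual BERT workload) of the failure mode: if the recompute does not see the exact RNG state used in the forward, a different dropout mask is drawn and the recomputed activations (and therefore the gradients) no longer match.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
drop = torch.nn.Dropout(p=0.5)  # module is in training mode by default

# Forward: dropout draws its mask from the current CPU RNG state.
saved_state = torch.get_rng_state()
out_forward = drop(x)

# Correct recompute: restore the saved RNG state first,
# so the recomputed activations match the forward exactly.
torch.set_rng_state(saved_state)
assert torch.equal(out_forward, drop(x))

# Faulty recompute (the symptom reported here): the RNG state actually used
# during the backward recompute differs, so a different dropout mask is drawn
# and fwd/bwd activations no longer agree.
print(torch.equal(out_forward, drop(x)))  # typically False
```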

We encountered this issue on BERT-tiny training with ZeRO-1, where the loss values diverge over time.
In our example, after 40000 steps we got a loss of 6.09375 vs. the expected 5.125, which we get with torch.compile disabled for checkpointing.

Here is a graph describing the deviation in loss over time:
[Image: loss curves of the two runs, diverging over training steps]

Due to the nature of the issue (see below), we believe it will affect other workloads as well. The current DeepSpeed master explicitly disables torch.compile on the checkpointing function (as a workaround for other issues), so the issue will not be visible to users unless they enable torch.compile on the checkpoint function. It will, however, need to be addressed before that workaround can be removed (which is desirable for performance reasons).
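For reference, the workaround on current master is a compiler-disable decorator on the DeepSpeed checkpoint entry point (the exact line is quoted in the reproduction steps below). Conceptually it amounts to the following hedged sketch using the public torch.compiler.disable API; `my_checkpoint` is a hypothetical stand-in for DeepSpeed's function, not the real code:

```python
import torch

@torch.compiler.disable  # graph break: torch.compile never traces this function
def my_checkpoint(function, *args):
    # Because the compiler is disabled here, the RNG save/restore logic inside
    # the checkpoint implementation runs eagerly and stays correct, at the cost
    # of losing compilation (and its performance benefit) for this region.
    return function(*args)
```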

Investigation and background:
Our internal investigation indicates that PyTorch's support for compiling checkpointing with torch.compile works differently than before (previously the RNG state was stored before the fwd compute and the same state was restored before the bwd recompute). The new implementation wraps random ops in the forward and backward in special ops, run_and_save_rng_state and run_with_rng_state (specifically in the function functionalize_rng_ops, https://github.com/pytorch/pytorch/blob/main/torch/_functorch/partitioners.py#L615). The issue is that PyTorch applies this special treatment only to the PyTorch checkpoint function (torch.utils.checkpoint) and not to the DeepSpeed checkpoint function.
Note: switching to call the PyTorch checkpoint function instead of the DeepSpeed one makes the issue go away (a hedged sketch of the swap follows).
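A sketch of that swap under stated assumptions: `block` and the tensor shapes are illustrative, in the real workload this call sits inside the compiled model, and DeepSpeed is initialized as usual.

```python
import torch
import deepspeed

def block(hidden):
    # Stand-in for a transformer sub-block that contains dropout.
    return torch.nn.functional.dropout(hidden, p=0.1, training=True)

hidden = torch.randn(4, 16, requires_grad=True)

# Affected path: DeepSpeed's checkpoint function. Under torch.compile its
# random ops are not wrapped by functionalize_rng_ops, so the backward
# recompute can see a different RNG state than the forward.
out_ds = deepspeed.checkpointing.checkpoint(block, hidden)

# Unaffected path: PyTorch's own checkpoint function. torch.compile
# recognizes torch.utils.checkpoint and wraps its random ops in
# run_and_save_rng_state / run_with_rng_state, so fwd and recompute agree.
out_pt = torch.utils.checkpoint.checkpoint(block, hidden, use_reentrant=False)
```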

To Reproduce
Steps to reproduce the behavior:

  1. Run BERT-tiny with ZeRO-1 for 40000 steps (we ran on an 8-card Gaudi setup, but the issue is reproducible on a single card as well).
  2. Run the same workload with torch.compile enabled on the checkpointing function. To do this, comment out the following line:
    https://github.com/microsoft/DeepSpeed/blob/f743feca033515fdded50a98093da5a48eb41e74/deepspeed/runtime/activation_checkpointing/checkpointing.py#L945C1-L945C78
    ("@compiler.disable # WA from Pytorch repo for compile + zero 3 accuracy issue")
  3. Over time the loss values of the two runs will deviate significantly.

  Notes:
  • Initially the loss values are identical and only diverge over time; the time it takes to diverge can be shortened by increasing the learning rate substantially.
  • Setting dropout to 0 makes the issue go away, since in our case the issue shows up in the dropout node calculations (see the sketch after these notes).
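As a quick way to confirm the dropout dependence from the second note, the dropout probabilities can be zeroed in the model config. This is a hedged, illustrative sketch assuming a Hugging Face-style BERT config; our actual workload builds its BERT-tiny model differently, so adapt the field names to your model definition.

```python
from transformers import BertConfig, BertForPreTraining

# Hypothetical "tiny"-sized config; only the two dropout fields matter here.
config = BertConfig(
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=512,
    hidden_dropout_prob=0.0,           # disable dropout on hidden states
    attention_probs_dropout_prob=0.0,  # disable dropout on attention probs
)
model = BertForPreTraining(config)
# With all dropout disabled, the checkpoint recompute has no randomness to
# reproduce, and the loss divergence described above disappears.
```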

Expected behavior
It is expected that the results with torch.compile enabled on the checkpoint function will be the same as, or at least close to, the results without it.

ds_report output
Please run ds_report to give us details about your setup.

Screenshots
If applicable, add screenshots to help explain your problem.

System info (please complete the following information):

  • Ubuntu 22
  • 8x Gaudi2 cards (reproducible with a single card as well)
  • Python 3.10