
Rule for mixed precision training #152

Open · wants to merge 9 commits into master
Conversation

@CarloLucibello (Member)

Implements mixed precision training, a fundamental strategy for training large models.
In this implementation, the high-precision weights are kept as part of the optimiser's state.

The optimiser introduced here should be coupled with a loss-scaling strategy (not in this PR, and probably never in this repo) to obtain robust mixed precision training.
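For readers unfamiliar with the technique, the mechanism can be sketched in a few lines of plain Julia. This is a hedged illustration only: `mixed_precision_step!` is a made-up name, and the real rule delegates the update to an arbitrary inner optimiser rather than hard-coding gradient descent.

```julia
# Sketch of one mixed-precision update step (illustrative, not this PR's code).
# `x` is the low-precision weight array the model computes with; `xT` is the
# high-precision master copy that would live in the optimiser's state.
function mixed_precision_step!(x::AbstractArray, xT::AbstractArray, dx; eta=0.01)
    dxT = eltype(xT).(dx)   # promote the gradient to the master precision
    xT .-= eta .* dxT       # apply the (here: plain descent) update in high precision
    x .= eltype(x).(xT)     # sync the low-precision working copy back
    return x, xT
end

x  = Float16[1.0, 2.0]   # low-precision weights used in forward/backward passes
xT = Float32.(x)         # high-precision master weights (optimiser state)
dx = Float16[0.5, -0.5]  # low-precision gradient
mixed_precision_step!(x, xT, dx)
```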

PR Checklist

  • Tests are added
  • Documentation, if applicable

@CarloLucibello CarloLucibello requested a review from mcabbott July 26, 2023 12:57
@darsnack (Member)

This will need a custom `adjust` too.

@CarloLucibello (Member Author)

Good to go?

@darsnack (Member) left a comment

Good for me

src/rules.jl (outdated; resolved)
@darsnack (Member) left a comment

One more minor thing

src/rules.jl (resolved)
@mcabbott (Member)

Would like to ask for some time to read this closely.

Haven't understood why it doesn't use OptimiserChain. Not sure that exposing MixedPrecision{Float32} as the official way to specify the type is so nice.

@darsnack (Member)

darsnack commented Jul 27, 2023

It's unfortunate that the API here doesn't use OptimiserChain like AccumGrad does, but I think it is unavoidable. We need to invoke the inner optimizer's apply! at higher precision, which we could do by promoting x and dx as the first thing in the chain (along with #151). But then we still need subtract! to happen at higher precision, and the result synced back to the high-precision copy of x in MixedPrecision's state. So doing it would require a big change to update!/subtract!.

Maybe I’m not seeing it, and indeed we should think about this feature carefully before releasing it. It’s an important one to get right.

It makes me think that OptimiserChain is not the best API in general. Yes, it makes writing the rule easier. But it doesn't work well for cases like this, and it has a potential footgun where you can place the rule in the wrong spot in a chain. Wrapping seems more intuitive as an API ("whatever this rule that I'm wrapping does, do that in precision T" vs. grokking how the update transforms in a chain).
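For concreteness, here is a standalone toy of the wrapping shape described above, with `init`/`apply!` names mirroring the Optimisers.jl interface. Everything below is redefined locally as a sketch; it is not the PR's implementation.

```julia
# Toy rule interface (standalone; mirrors Optimisers.jl naming as an assumption).
abstract type AbstractRule end

struct Descent <: AbstractRule
    eta::Float64
end
init(::Descent, x) = nothing
apply!(o::Descent, state, x, dx) = (state, o.eta .* dx)  # returns (state, step)

# The wrapper keeps the high-precision master weights in its own state.
struct MixedPrecision{T<:Number,O<:AbstractRule} <: AbstractRule
    opt::O
end
MixedPrecision(T::Type, opt::AbstractRule) = MixedPrecision{T,typeof(opt)}(opt)

function init(o::MixedPrecision{T}, x) where T
    xT = T.(x)                   # high-precision copy of the weights
    return (xT, init(o.opt, xT))
end

function apply!(o::MixedPrecision{T}, state, x, dx) where T
    xT, inner = state
    inner, step = apply!(o.opt, inner, xT, T.(dx))  # inner rule runs at precision T
    xT .-= step                                     # the subtract! happens in high precision
    x .= eltype(x).(xT)                             # sync the low-precision weights
    return (xT, inner), zero(dx)                    # nothing left for the caller to subtract
end

o  = MixedPrecision(Float32, Descent(0.1))
x  = Float16[1.0, 2.0]
st = init(o, x)
st, _ = apply!(o, st, x, Float16[1.0, -1.0])
```

Note how the sync back to `x` happens inside `apply!`: that is exactly the step a plain gradient-transformation chain cannot express.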

@darsnack (Member)

Apologies for the train of comments.

I guess one distinction here is a gradient transformation vs. an optimizer transformation. Something like AccumGrad is a gradient transformation. But something like a scheduler or MixedPrecision is an optimizer transformation, where we transform or augment an optimizer's state, parameters, or invocation before running the rule. The arrow goes in one direction: you can write a gradient transformation as an optimizer transformation, but not vice versa.
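A tiny illustration of the distinction, with made-up names (`clip`, `with_precision`): a gradient transformation only rewrites `dx` before the rule sees it, while an optimizer transformation controls how the rule itself is invoked.

```julia
# Gradient transformation: dx in, dx out. Composes naturally in a chain.
clip(dx; delta=1.0) = clamp.(dx, -delta, delta)

# Optimizer transformation: wraps the whole update, so it can decide the
# precision (or schedule, or state) the inner rule runs with.
function with_precision(T::Type, update, x, dx)
    xT = update(T.(x), T.(dx))  # inner rule runs entirely at precision T
    return eltype(x).(xT)       # demote back afterwards
end

sgd(x, dx) = x .- 0.1 .* dx     # a stand-in inner rule

x  = Float16[1.0, 2.0]
dx = Float16[4.0, -0.25]

x1 = sgd(x, clip(dx))                     # gradient transform, then rule
x2 = with_precision(Float32, sgd, x, dx)  # rule invoked inside the transform
```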

CarloLucibello and others added 2 commits July 30, 2023 07:26
Co-authored-by: Kyle Daruwalla <[email protected]>
Co-authored-by: Kyle Daruwalla <[email protected]>
@CarloLucibello (Member Author)

Not sure that exposing MixedPrecision{Float32} as the official way to specify the type is so nice.

Can be MixedPrecision(Float32, opt) instead

@CarloLucibello (Member Author)

@mcabbott good to go?

@mcabbott (Member)

mcabbott commented Aug 1, 2023

I have questions but haven't had time to read up more.

First, this setup starts with the low-precision model, and temporarily stores a higher-precision copy for the purpose of accumulation without overflow. Is this standard? There's no very easy way to get back the high-precision model.

Somehow I thought the default was the reverse, to start and end with the high-precision one, and treat the low-precision model as a temporary step for cheaper gradients. This could not, I think, be implemented as a rule here.

Second, I thought it was conventional to scale up the loss used for the low-precision backward pass, and then scale the gradients back down before the update. There's no sign of that here. Am I mis-remembering? Scaling the gradients could be done by composing with Descent, but I'm not entirely sure that's the right place. And perhaps, if it's standard, it should be made easy.
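For reference, the convention described here looks roughly like the static loss-scaling sketch below. All names are illustrative and nothing here is in the PR; real implementations usually make the scale dynamic, shrinking it when gradients overflow.

```julia
# Static loss scaling: multiply the loss by S before the backward pass so
# small gradients survive Float16's limited range, then divide the gradients
# by S before handing them to the optimiser. S is a power of two so the
# scaling itself is exact.
const S = 1024.0f0

loss(w, x)        = sum(abs2, w .* x)
scaled_loss(w, x) = S * loss(w, x)

# For this loss, d/dw_i = 2 * w_i * x_i^2, so the scaled gradient is S times that.
grad_scaled(w, x) = S .* (2 .* w .* x .^ 2)

w = Float32[1.0, 2.0]
x = Float32[3.0, 4.0]
g = grad_scaled(w, x) ./ S   # un-scale before the update
```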

Comment on lines +816 to +817
```julia
adjust(o::MixedPrecision, eta::Real) = MixedPrecision(adjust(o.opt, eta))
adjust(o::MixedPrecision; kw...) = MixedPrecision(adjust(o.opt; kw...))
```

Don't these forget T?
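To make the reviewer's point concrete, here is a standalone toy (local redefinitions, not the PR's code) showing how the quoted methods rebuild the wrapper with the default `T`, and a `where T` variant that carries it through. It assumes a single-parameter `MixedPrecision{T}(rule)` convenience constructor, which is hypothetical.

```julia
abstract type AbstractRule end

struct Descent <: AbstractRule
    eta::Float64
end
adjust(o::Descent, eta::Real) = Descent(eta)

struct MixedPrecision{T<:Number,O<:AbstractRule} <: AbstractRule
    opt::O
end
# Hypothetical convenience constructors for this sketch:
MixedPrecision{T}(opt::O) where {T,O<:AbstractRule} = MixedPrecision{T,O}(opt)
MixedPrecision(opt::AbstractRule) = MixedPrecision{Float32}(opt)  # default T

# As in the diff: rebuilding through the plain constructor falls back to the default T.
adjust_forgetful(o::MixedPrecision, eta::Real) = MixedPrecision(adjust(o.opt, eta))

# Type-preserving alternative: destructure T and carry it through.
adjust(o::MixedPrecision{T}, eta::Real) where T = MixedPrecision{T}(adjust(o.opt, eta))

o = MixedPrecision{Float64}(Descent(0.1))
```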

```julia
struct MixedPrecision{T<:Number, O<:AbstractRule} <: AbstractRule
  opt::O
```

Maybe this should be rule?

Suggested change:
```diff
-  opt::O
+  rule::O
```

In the `update!(opt_state, x, g)` call, `opt` is used to update `xT` instead of `x`,
then `x` is updated with the value of `xT`.

[1] Micikevicius et al. '17, "Mixed Precision Training", https://arxiv.org/abs/1710.03740 .

Suggested change:
```diff
-[1] Micikevicius et al. '17, "Mixed Precision Training", https://arxiv.org/abs/1710.03740 .
+# Reference
+[1] Micikevicius et al., "Mixed Precision Training", ICLR 2018, https://arxiv.org/abs/1710.03740 .
```

```julia
xT = subtract!(xT, dx′)
if maywrite(x)
    x .= xT
    dx′ = nothing
```

I think this is correct.

But perhaps weird things will happen if you try to compose it, e.g. OptimiserChain(MixedPrecision(...), ClipGrad()). If so then we should make sure such things give an error.
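One way to "make sure such things give an error" would be a constructor check, sketched here with standalone toy types (not Optimisers.jl's actual definitions):

```julia
abstract type AbstractRule end

struct MixedPrecision{T,O<:AbstractRule} <: AbstractRule
    opt::O
end
struct ClipGrad <: AbstractRule
    delta::Float64
end

struct OptimiserChain{T<:Tuple} <: AbstractRule
    opts::T
    function OptimiserChain(opts::AbstractRule...)
        # Refuse to bury a MixedPrecision rule inside a chain: it syncs the
        # weights itself, so later chain elements would see stale values.
        any(o -> o isa MixedPrecision, opts) && throw(ArgumentError(
            "MixedPrecision cannot appear inside an OptimiserChain; wrap the whole chain instead"))
        new{typeof(opts)}(opts)
    end
end

chain_ok = OptimiserChain(ClipGrad(1.0))
errored  = try
    OptimiserChain(MixedPrecision{Float32,ClipGrad}(ClipGrad(1.0)))
    false
catch e
    e isa ArgumentError
end
```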

```julia
    x .= xT
    dx′ = nothing
else
    dx′ = x .- eltype(x).(xT)
```

On this path, should the subtraction happen in high or low precision, does it matter?

This is the sort of place that I worry about scaling & the range of Float16. But haven't thought hard.
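The worry is easy to demonstrate: near 1.0, Float16's spacing is about 5e-4, so updates smaller than that vanish entirely if the accumulation happens in low precision. A standalone sketch (not the PR's code):

```julia
# Repeatedly subtract a small step; in Float16 each subtraction rounds back
# to the starting value, while the Float32 accumulation drifts as it should.
function drift(x, step, n)
    for _ in 1:n
        x -= oftype(x, step)
    end
    return x
end

x16 = drift(Float16(1.0), 1e-4, 1000)  # every step is lost to rounding
x32 = drift(1.0f0, 1e-4, 1000)         # accumulates to roughly 0.9
```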

@mcabbott (Member)

mcabbott commented Aug 1, 2023

Just to sketch another possibility, Flux could instead wrap pairs of low+high precision copies of the same model:

```julia
bimodel = MixedPrec(Float16, model32)  # makes a Float16 copy, stores a scale for loss.
bimodel(x)  # calls variant matching eltype(x)? Default Float16?

opt_state = setup(Adam(), bimodel)  # special method?
gs = gradient(loss, bimodel)  # this could be a special method which scales the loss?
update!(opt_state, bimodel, gs[1]) # this knows about the scale, updates both halves.
```
