Rule for mixed precision training #152
base: master
Conversation
This will need a custom …
Good to go?
Good for me
One more minor thing
Would like to ask for some time to read this closely. Haven't understood why it doesn't use `OptimiserChain`. Not sure that exposing …
It’s unfortunate that the API here doesn’t use … Maybe I’m not seeing it, and indeed we should think about this feature carefully before releasing it. It’s an important one to get right. It makes me think that …
Apologies for the train of comments. I guess one distinction here is a gradient transformation vs. an optimizer transformation. Something like …
Co-authored-by: Kyle Daruwalla <[email protected]>
Can be …
@mcabbott good to go?
I have questions but haven't had time to read up more.

First, this setup starts with the low-precision model, and temporarily stores a higher-precision copy for the purpose of accumulation without overflow. Is this standard? There's no very easy way to get back the high-precision model. Somehow I thought the default was the reverse: to start and end with the high-precision one, and treat the low-precision model as a temporary step for cheaper gradients. This could not, I think, be implemented as a rule here.

Second, I thought it was conventional to scale the loss used for low-precision steps, and then scale the gradients. There's no sign of that here. Am I mis-remembering? Scaling the gradients could be done by composing with `Descent`, but I'm not entirely sure that's the right place. And perhaps if it's standard it should be made easy.
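For context on the second question, here is a minimal, self-contained sketch of the loss-scaling idea from the paper the PR cites (Micikevicius et al.). The concrete numbers and the choice of scale are illustrative, not code from this PR:

```julia
# Why loss scaling matters for Float16 gradients: small values flush
# to zero below Float16's subnormal range (~6e-8).
g = 1f-8                          # a small Float32 gradient value
Float16(g) == 0                   # true: underflows to zero in Float16

# Scaling the loss by a constant scales every gradient by the same factor.
scale = 65536f0                   # a power of two, so unscaling is exact in the exponent
g16 = Float16(scale * g)          # now well inside Float16's representable range
g32 = Float32(g16) / scale        # unscale in Float32 before the optimiser step
isapprox(g32, g; rtol = 1e-3)     # true: the gradient survives the round trip
```

This is why the composing-with-`Descent` idea above is not quite enough on its own: the scaling must happen before the backward pass, on the loss, and only the unscaling can live in the optimiser.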
```julia
adjust(o::MixedPrecision, eta::Real) = MixedPrecision(adjust(o.opt, eta))
adjust(o::MixedPrecision; kw...) = MixedPrecision(adjust(o.opt; kw...))
```
Don't these forget `T`?
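A possible fix, sketched with minimal stand-ins for the PR's types so the example is self-contained (the stand-in `Descent` and the `MixedPrecision{T}(opt)` convenience constructor are assumptions, not the PR's code): threading the type parameter through explicitly preserves `T`.

```julia
# Minimal stand-ins, only to make the sketch runnable on its own:
abstract type AbstractRule end
struct Descent <: AbstractRule
    eta::Float64
end
adjust(o::Descent, eta::Real) = Descent(eta)

struct MixedPrecision{T<:Number, O<:AbstractRule} <: AbstractRule
    opt::O
end
MixedPrecision{T}(opt::O) where {T, O<:AbstractRule} = MixedPrecision{T, O}(opt)

# Calling MixedPrecision(...) without T would fall back to some default;
# passing T through explicitly keeps the chosen high-precision eltype:
adjust(o::MixedPrecision{T}, eta::Real) where {T} = MixedPrecision{T}(adjust(o.opt, eta))

rule = adjust(MixedPrecision{Float32}(Descent(0.1)), 0.01)
rule isa MixedPrecision{Float32}   # true: T survives adjust
```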
````julia
```
"""
struct MixedPrecision{T<:Number, O<:AbstractRule} <: AbstractRule
    opt::O
````
Maybe this should be `rule`?
```diff
- opt::O
+ rule::O
```
In the `update!(opt_state, x, g)` call, `opt` is used to update `xT` instead of `x`, then `x` is updated with the value of `xT`.

[1] Micikevicius et al. '17, "Mixed Precision Training", https://arxiv.org/abs/1710.03740 .
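For illustration, here is a self-contained sketch of the mechanism the docstring describes: the wrapped rule steps a high-precision copy `xT`, and `x` is then overwritten with its value. This is plain Julia, not the PR's `update!` implementation; the step size and the constant gradient are invented.

```julia
# Accumulating in Float32 preserves tiny updates that pure Float16
# arithmetic would round away entirely.
eta = 1f-4
g = Float32[1e-3, 1e-3]            # a made-up, constant gradient

x  = Float16[0.1, 0.2]             # low-precision parameters
xT = Float32.(x)                   # high-precision copy (optimiser state)
x_naive = copy(x)                  # the same schedule, run purely in Float16

for _ in 1:1000
    xT .= xT .- eta .* g           # the wrapped rule updates xT, not x
    x  .= Float16.(xT)             # then x is updated with the value of xT
    x_naive .= x_naive .- Float16(eta) .* Float16.(g)
end

x_naive == Float16[0.1, 0.2]       # true: each ~1e-7 step rounds away
x != Float16[0.1, 0.2]             # true: Float32 accumulation made progress
```

Each individual step (~1e-7) is far below the Float16 spacing near 0.1 (~6e-5), so the naive path never moves; only the accumulated Float32 copy makes visible progress.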
```diff
- [1] Micikevicius et al. '17, "Mixed Precision Training", https://arxiv.org/abs/1710.03740 .
+ # Reference
+ [1] Micikevicius et al., "Mixed Precision Training", ICLR 2018, https://arxiv.org/abs/1710.03740 .
```
```julia
xT = subtract!(xT, dx′)
if maywrite(x)
    x .= xT
    dx′ = nothing
```
I think this is correct. But perhaps weird things will happen if you try to compose it, e.g. `OptimiserChain(MixedPrecision(...), ClipGrad())`. If so, then we should make sure such things give an error.
```julia
    x .= xT
    dx′ = nothing
else
    dx′ = x .- eltype(x).(xT)
```
On this path, should the subtraction happen in high or low precision, and does it matter?

This is the sort of place where I worry about scaling and the range of Float16, but I haven't thought hard about it.
Just to sketch another possibility, Flux could instead wrap pairs of low+high precision copies of the same model:

```julia
bimodel = MixedPrec(Float16, model32)  # makes a Float16 copy, stores a scale for loss
bimodel(x)                             # calls variant matching eltype(x)? Default Float16?
opt_state = setup(Adam(), bimodel)     # special method?
gs = gradient(loss, bimodel)           # this could be a special method which scales the loss?
update!(opt_state, bimodel, gs[1])     # this knows about the scale, updates both halves
```
Implements a fundamental strategy for training big models.
In this implementation, the high-precision weights are part of the optimiser's state.
The rule introduced here should be coupled with a loss-scaling strategy (not in this PR, and probably not ever in this repo) to obtain robust mixed precision training.