-
-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix #62 #70
base: master
Are you sure you want to change the base?
Fix #62 #70
Conversation
…`s of offsets (also simplifying `_aux_children`); fix broken test for issue FluxML#62
src/destructure.jl
Outdated
y = _trainmap(f, ch, _trainable(x), au) | ||
y isa Tuple{} && return NoT | ||
p = ProjectTo(x) | ||
if p isa ProjectTo # e.g. Array, NamedTuple | ||
p(y) | ||
else # p === identity for unknown structs | ||
y = backing(re(y)) # extract NamedTuple backing from re(y); required if x has children which aren't its own fields |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note to self, this I need to think about. Some of this complication was working around things that are now fixed in CRC.jl, if I remember right.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, admittedly this line took some trial and error and is a little bit above my pay-grade. I managed to convince myself, but perhaps there's something cleaner.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I think I finally understand what's going on. Sorry it took a while.
re
constructs another Skip containing the gradient, and backing
turns that into a NamedTuple with the same field names, which is what Tangent wants.
The only way I can see this failing is this: If the primal type's constructor is fussy about what types it can accept, then it may not be happy to accept something which is valid as its gradient. E.g. if there is only Skip(::AbstractLayer)
, and re
tries to make one with a Tangent
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No worries! Yes, I struggled with that edge case too. Unfortunately I think it's quite tricky to work around. For example, suppose you have a user-defined functor(m::MyModel) = (m.w,), w -> ...
. Then:
- In general there's no way to reconstruct
MyModel
(or even aNamedTuple
of fields/values) withoutre
, as you do not know the corresponding field name given only(m.w,)
, but - As you say, if the primal constructor isn't sufficiently generic then it won't be able to store
Tangent
/Nothing
/etc. values in it's fields and will error beforebacking
can unpack it again
Avoiding re
would be ideal, but I think that would require functor
to always return NamedTuple
s on custom structs. I noticed that this is the default in @functor
, though, so maybe it's not such a painful requirement? In the mean time I can at least add a branch that would avoid re
for structs that are functor
ed to NamedTuple
s.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In fact there's another problem I didn't spot before, what a mess:
julia> ac = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0]); # from tests: a,c are functor-ed, and only a is trainable
julia> v2, re2 = destructure(ac)
([1.0, 2.0], Restructure(TwoThirds, ..., 2))
julia> gradient(ac) do x # with Tangent{typeof(x), typeof(y)}(y)
w2, _ = destructure(x)
w2[2]^2
end
((a = [0.0, 4.0], b = nothing, c = [4.0, 5.0]),)
# Same, with z = backing(re(y)) :
julia> gradient(ac) do x
w2, _ = destructure(x)
w2[2]^2
end
┌ Info: last case
│ x = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])
│ y = (a = [0.0, 4.0], c = [4.0, 5.0])
└ z = NamedTuple{(:a, :b, :c), Tuple{Any, Any, Any}}(([0.0, 4.0], [3.0], [4.0, 5.0]))
((a = [0.0, 4.0], b = [3.0], c = [4.0, 5.0]),)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yikes. That's a good example, hits all the pain points at once. If I'm understanding correctly, the gradient should be ((a = [0.0, 4.0], b = nothing, c = nothing),)
, right?
I think the problem is the _trainmap
above; it populates the nothing
values from _trainable
(non-trainable fields) with the primal values, when they should be NoT
. That's how the b
and/or c
values get back in there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think _trainmap
needs to do something isnothing(t) ? NoT : f(t, a)
here. That's where c = [4.0, 5.0]
is coming from.
But b = [3.0]
is coming from this PR's trick of calling the reconstructor made by @functor
:
julia> ch, re = Functors.functor(ac)
((a = [1.0, 2.0], c = [4.0, 5.0]), var"#1#2"{TwoThirds}(TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])))
julia> re((a = [10, 20], c = nothing))
TwoThirds([10, 20], [3.0], nothing)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gotcha. So on top of the modified _trainmap
to fix c
, one would still have to filter backing(re(y))
to replace repopulated primal values which aren't functor
-ed with NoT
in order to fix b
.
EDIT: But, based on the output of Tangent{typeof(x), typeof(y)}(y)
, maybe the modified _trainmap
alone would be enough and backing(re(y))
isn't needed after all, as Tangent
will assign NoT
to omitted fields in y
automatically.
EDIT 2: Never mind, that would still fail for children which aren't fields, like Skip
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alright pushed something that works for both Skip
and your TwoThirds
example (modified _trainmap
+ filtering backing(re(y))
). But since it uses re
it would still fail for fussy constructors.
…h are not `trainable`; filter primal values from `backing(re(y))`
This adds a couple small changes on top of this draft PR in order to fix #62:
Offset
to fix the issue mentioned in Attempt to fix #62 #63 for array of arrays. For example, the offset structure forx = [[1.0, 2.0]]
is now something likeo = [Offset(4)]
which is not leaflike, compared too = [4]
previously. This also opens the door to storing more information in this wrapper struct (original array size? eltype?), but that doesn't seem necessary at this timey = backing(re(y))
allows forfunctor(x)
to return children which aren't its own fields:y
is first restructured to match the structure ofx
, and then theNamedTuple
backing forre(y)
is extracted and passed toTangent
. It has the added benefit of adding some symmetry with_trainable_biwalk
which naturally restructures the output of_trainmap
, whereas_Tangent_biwalk
previously did notCloses #63 (replaces).