As of Flux v0.15, after the redesign in #2500, recurrent cells (RNNCell, GRUCell, LSTMCell) behave like this
$$state_t = cell(x_t, state_{t-1})$$
and the output at each timestep is either y_t = state_t = h_t (RNNCell and GRUCell) or y_t = state_t[1] = h_t (LSTMCell), that is
h = rnncell(x_t, h)
h = grucell(x_t, h)
h, c = lstmcell(x_t, (h, c))
Here we follow the PyTorch style, but maybe we should have followed Flax and Lux instead and done
y_t, state = cell(x_t, state)
The Problem
The problem is that the current return value of the cell's forward pass doesn't clearly distinguish the output from the state. So if we want to expose a wrapper layer around a cell, let's call it Recurrent, that processes an entire input sequence at once and returns the stacked outputs (what RNN, GRU and LSTM do, but now for arbitrary cells), it is not clear which interface we should ask the cell types to comply with in order to be Recurrent-compatible.
What would be the rules for a Flux-compliant cell definition? Something like:
- It should return the state; the returned state will be passed back as input at the next timestep.
- If the state is a tuple, the first element of the tuple is considered the output.
- Otherwise, the entire state is considered the output.
This seems a little odd though.
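For concreteness, here is a minimal sketch of what a hypothetical Recurrent wrapper would have to do to comply with these rules under the current interface. The `Recurrent` struct and the `output_from_state` helper are illustrative names for this sketch, not existing Flux API.

```julia
# Hypothetical wrapper around any cell following the current (v0.15) interface.
struct Recurrent{C}
    cell::C
end

# The odd rule: if the state is a tuple (e.g. LSTMCell's (h, c)),
# its first element is the output; otherwise the whole state is the output.
output_from_state(state::Tuple) = first(state)
output_from_state(state) = state

function (layer::Recurrent)(x, state)
    # x has size (in_dim, len, batch_size); scan over the time dimension
    outputs = []
    for t in axes(x, 2)
        state = layer.cell(x[:, t, :], state)
        push!(outputs, output_from_state(state))
    end
    return stack(outputs; dims = 2)   # size (out_dim, len, batch_size)
end
```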
Proposals
Proposal 1
We do nothing, since we have already churned the recurrent layers recently and we don't want to do it again. PyTorch has survived fine with an interface similar to the one we currently have, although it doesn't have a de facto interface that allows people to define a cell and immediately extend it with recursion, as Flax does.
We can still have this interface, but with the slightly odd rules described above.
Proposal 2
We release Flux v0.16 as soon as possible, changing the return value of the cells to
y_t, state = cell(x_t, state)
so that we have a clear-cut interface that people can adopt to define custom cells.
I'm not happy about having another breaking release so soon, but since v0.15 has been around for only a few days, most people will probably skip it entirely and move directly to v0.16.
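As an illustration of the proposed contract, a custom cell and a generic scan over a sequence could look like the following sketch. `MyCell` and `scan_cell` are hypothetical names used only for illustration, not part of Flux.

```julia
# Sketch of a custom cell under the proposed contract: the forward pass
# returns the output and the new state as two separate values.
struct MyCell{W}
    Wx::W
    Wh::W
end

function (m::MyCell)(x, h)
    h = tanh.(m.Wx * x .+ m.Wh * h)
    return h, h   # (output, state); for a vanilla RNN they coincide
end

# A generic wrapper then needs no special rule to recover the output.
function scan_cell(cell, x, state)
    outputs = []
    for t in axes(x, 2)
        y, state = cell(x[:, t, :], state)
        push!(outputs, y)
    end
    return stack(outputs; dims = 2), state
end
```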
What Other Frameworks Do
As a reference, below I report the interfaces exposed by the different frameworks for cells and recurrent layers.
Flux v0.14
# cell types not exported and not fully documented
for xt in seq
state, yt = cell(state, xt)
end
rnn = LSTM() # = Recur(LSTMCell())
# doesn't handle entire sequences
for xt in seq
yt = rnn(xt) # state updated internally
end
Flux v0.15
for xt in seq
state = cell(xt, state)
# `yt = state` for RNNCell and GRUCell, `yt = first(state)` for LSTMCell
end
input = batch_time(seq) # size: (in_dim, len, batch_size)
rnn = LSTM()
out = rnn(input, state0)
size(out) == (out_dim, len, batch_size)
rnn = LSTM(return_state=true) # option yet to be implemented
out, stateT = rnn(input, state0) # also return the last state
Lux
# simplified lux interface where we don't have to juggle params
for xt in seq
yt, state = cell(xt, state) # omitting ps and st
end
input = batch_time(seq) # size: (in_dim, len, batch_size)
rnn = Recurrence(cell)
yT, stateT = rnn(input)
size(yT) == (out_dim, batch_size)
rnn = Recurrence(cell, return_sequence=true)
out, stateT = rnn(input)
size(out) == (out_dim, len, batch_size)
Flax
for xt in seq
state, yt = cell(state, xt)
end
input = batch_time(seq)
input.shape == (batch_size, len, in_dim)
rnn = RNN(cell)
out = rnn(state0, input)
out.shape == (batch_size, len, out_dim)
rnn = RNN(cell, return_carry=True)
stateT, out = rnn(state0, input)
PyTorch
for xt in seq
state = cell(xt, state)
# `yt = state` for RNNCell and GRUCell, `yt = state[0]` for LSTMCell
end
input = batch_time(seq)
rnn = LSTM() # or GRU() or RNN()
out, stateT = rnn(input, state0)
out.shape == (batch_size, len, out_dim)
I'm in favor of proposal 2 (implementation in #2551): provide the interface that makes sense now, tag v0.16, and hopefully be done with breaking changes in the recurrent layers forever.
Having a nice Recurrent layer able to wrap any cell is very convenient (as I just realized while porting many recurrent layers to GraphNeuralNetworks.jl), and the current cell interface makes the implementation odd.
I do think proposal 2 would make things more straightforward for downstream implementations. The PyTorch approach makes sense for them since they provide high-level RNN, LSTM, etc. with more options but less flexibility.
Yeah, I also realize that having the Recurrent top-level layer makes life easier in a lot of circumstances, and it should probably be the way to go for recurrent layers.