cell output is not clearly distinguishable from the state #2548

Open
CarloLucibello opened this issue Dec 11, 2024 · 3 comments

As of Flux v0.15, after the redesign in #2500, recurrent cells (RNNCell, GRUCell, LSTMCell) behave like this

$$\mathrm{state}_t = \mathrm{cell}(x_t, \mathrm{state}_{t-1})$$

and the output at each timestep is either y_t = state_t = h_t (RNNCell and GRUCell) or y_t = state_t[1] = h_t (LSTMCell), that is

h = rnncell(x_t, h)
h = grucell(x_t, h)
h, c = lstmcell(x_t, (h, c))

In this, we follow the PyTorch style, but maybe we should have followed Flax and Lux instead and done

y_t, state = cell(x_t, state) 
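
To make the two conventions concrete, here is a minimal sketch in plain Julia with no Flux dependency. `ToyRNNCell`, `step_state_only` and `step_output_and_state` are hypothetical names invented for this illustration; they are not part of Flux.

```julia
# Toy Elman-style cell in plain Julia; all names here are illustrative assumptions.
struct ToyRNNCell
    Wx::Matrix{Float64}
    Wh::Matrix{Float64}
    b::Vector{Float64}
end

# v0.15-style convention: the cell returns only the new state.
step_state_only(c::ToyRNNCell, x, h) = tanh.(c.Wx * x .+ c.Wh * h .+ c.b)

# Flax/Lux-style convention: the cell returns (output, state) explicitly.
function step_output_and_state(c::ToyRNNCell, x, h)
    h_new = step_state_only(c, x, h)
    return h_new, h_new  # for a plain RNN the output coincides with the state
end

cell = ToyRNNCell(fill(0.1, 3, 2), fill(0.1, 3, 3), zeros(3))
x, h0 = ones(2), zeros(3)

h1 = step_state_only(cell, x, h0)            # caller must know that y = h
y1, s1 = step_output_and_state(cell, x, h0)  # output and state are named apart
```

For a plain RNN both conventions carry the same information; the difference only shows up once a generic wrapper has to decide which part of the return value is the output.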

The Problem

The problem is that the current return value of a cell's forward pass doesn't clearly distinguish the output from the state. Suppose we want to expose a wrapper layer around a cell, call it Recurrent, that processes an entire input sequence at once and returns the stacked outputs (what RNN, GRU and LSTM do, but for arbitrary cells). It is then unclear which interface we should ask cell types to comply with in order to be Recurrent-compatible.

What would be the rule for a Flux-compliant cell definition? Something like

  • It should return the state. The returned state will be passed as input in the next timestep.
  • If the state is a tuple, the first element of the tuple is considered as output.
  • Otherwise, the entire state is considered as output.

This seems a little odd though.
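
As a rough sketch, the rule above could be encoded with two dispatch methods. `output_from_state` is a hypothetical helper written for this illustration, not an existing Flux function.

```julia
# Hypothetical helper encoding the rule above; not part of Flux.
output_from_state(state::Tuple) = first(state)  # LSTM-like state (h, c): output is h
output_from_state(state) = state                # RNN/GRU-like state h: output is h

h = [1.0, 2.0]
c = [3.0, 4.0]

y_rnn = output_from_state(h)        # the whole state is the output
y_lstm = output_from_state((h, c))  # only the first element is the output
```

The oddity is visible here: a custom cell whose state happens to be a tuple, but whose output is not its first element, cannot express that under this rule.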

Proposals

Proposal 1

We do nothing, since we have only recently churned the recurrent layers and don't want to do it again. PyTorch has survived fine with an interface similar to the one we currently have. It doesn't have a de facto interface that allows people to define a cell and immediately extend it with recurrence, as Flax does.
We can have such an interface, but with the slightly odd rules described above.

Proposal 2

We release Flux v0.16 as soon as possible, changing the return of cells to

y_t, state = cell(x_t, state) 

so that we have a clear-cut interface that people can adopt to define custom cells.
I'm not happy about having another breaking release so soon, but since v0.15 has been out for only a few days, most people would probably skip it entirely and move directly to v0.16.
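
Under the `y_t, state = cell(x_t, state)` convention, a generic wrapper needs no knowledge of the state's structure. The following is a minimal sketch in plain Julia; `ToyCell` and `ToyRecurrent` are hypothetical names for illustration, not the implementation proposed for Flux.

```julia
# Hypothetical cell following the proposal-2 convention; not Flux's API.
struct ToyCell
    W::Matrix{Float64}
    U::Matrix{Float64}
end

function (c::ToyCell)(x, h)
    h_new = tanh.(c.W * x .+ c.U * h)
    return h_new, h_new  # (output, state)
end

# Hypothetical wrapper that works with any cell returning (output, state).
struct ToyRecurrent{C}
    cell::C
end

# `input` has size (in_dim, len, batch_size).
function (r::ToyRecurrent)(input::AbstractArray{<:Real,3}, state)
    ys = Matrix{Float64}[]
    for t in axes(input, 2)
        y, state = r.cell(input[:, t, :], state)  # state is opaque to the wrapper
        push!(ys, y)
    end
    # Stack the (out_dim, batch_size) outputs along a new time dimension.
    out = cat((reshape(y, size(y, 1), 1, size(y, 2)) for y in ys)...; dims=2)
    return out, state
end

in_dim, out_dim, len, batch_size = 2, 3, 5, 4
rnn = ToyRecurrent(ToyCell(fill(0.1, out_dim, in_dim), fill(0.1, out_dim, out_dim)))
out, stateT = rnn(ones(in_dim, len, batch_size), zeros(out_dim, batch_size))
```

An LSTM-like cell would return `h, (h, c)` and work with the same wrapper unchanged, since the wrapper never inspects the state.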

What Other Frameworks Do

As a reference, below I report the interface exposed by the different frameworks for cells and recurrent layers.

Flux v0.14

# cell types not exported and not fully documented
for xt in seq
     state, yt = cell(state, xt)
end

rnn = LSTM() # = Recur(LSTMCell())
# doesn't handle entire sequences
for xt in seq
     yt = rnn(xt) # state updated internally
end

Flux v0.15

for xt in seq
     state = cell(xt, state)
     # `yt = state` for RNNCell and GRUCell, `yt = first(state)` for LSTMCell
end

input = batch_time(seq) # size: (in_dim, len, batch_size)

rnn = LSTM() 
out = rnn(input, state0)
size(out) == (out_dim, len, batch_size)


rnn = LSTM(return_state=true) # option yet to be implemented 
out, stateT = rnn(input, state0) # also return the last state

Lux

# simplified lux interface where we don't have to juggle params
for xt in seq
     yt, state = cell(xt, state)   # omitting ps and st
end

input = batch_time(seq) # size: (in_dim, len, batch_size)

rnn = Recurrence(cell)
yT, stateT = rnn(input)
size(yT) == (out_dim, batch_size)

rnn = Recurrence(cell, return_sequence=true)
out, stateT = rnn(input)
size(out) == (out_dim, len, batch_size)

Flax

for xt in seq
     state, yt = cell(state, xt)
end

input = batch_time(seq)
input.shape == (batch_size, len, in_dim)

rnn = RNN(cell)
out = rnn(state0, input)
out.shape == (batch_size, len, out_dim)

rnn = RNN(cell, return_carry=True)
stateT, out = rnn(state0, input)

PyTorch

for xt in seq
     state = cell(xt, state)
     # `yt = state` for RNNCell and GRUCell, `yt = state[0]` for LSTMCell
end

input = batch_time(seq)

rnn = LSTM() # or GRU() or RNN()
out, stateT = rnn(input, state0)
out.shape == (batch_size, len, out_dim) # with batch_first=True; the default layout is (len, batch_size, out_dim)

@CarloLucibello (Member, Author)

@MartinuzziFrancesco

CarloLucibello commented Dec 13, 2024

I'm in favor of proposal 2 (implementation in #2551): provide the interface that makes sense now, tag v0.16, and hopefully be done with breaking changes in the recurrent layers forever.

Having a nice Recurrent layer able to wrap any cell is very convenient (as I just realized while porting many recurrent layers in GraphNeuralNetworks.jl), and the current cell interface makes the implementation odd.

@MartinuzziFrancesco (Contributor)

I do think proposal 2 would make things more straightforward for downstream implementations. The PyTorch approach makes sense for them, since they provide high-level RNN, LSTM, etc. layers with more options but less flexibility.

Yeah, I also realize that having the top-level Recurrent layer makes life easier in a lot of circumstances, and it should probably be the way to go for recurrent layers.
