Initial values for the hidden/cell state for LSTM and GRU models in Pytorch #1120
Conversation
const typename CONFIG_T::weight_t RWO[CONFIG_T::n_out * CONFIG_T::n_out],
const typename CONFIG_T::bias_t BI[CONFIG_T::n_out], const typename CONFIG_T::bias_t BF[CONFIG_T::n_out],
const typename CONFIG_T::bias_t BC[CONFIG_T::n_out], const typename CONFIG_T::bias_t BO[CONFIG_T::n_out]) {
res_T hidden_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1] hls_register;
Shouldn't hidden_state be of type data2_T and cell_state of type data3_T?
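For illustration, the suggested change would look roughly like this (a sketch only; the cell_state array shape is assumed to mirror hidden_state):

```cpp
// Sketch: type the state buffers after the initial-state inputs
// (data2_T / data3_T) instead of the result type res_T.
data2_T hidden_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1] hls_register;
data3_T cell_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1] hls_register;
```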
Indeed, fixed.
We need to fix the datatypes for oneAPI before we merge
…into initialRecurr
update types for lstm init state oneAPI
fix pytorch_order for GRU, recurrent bias for simpleNN, oneAPI
Fix pytorch simple RNN for oneAPI; add initial state version for Quartus and oneAPI
pre-commit.ci autofix
fix simple-rnn config for Keras; make test names unique
@@ -47,4 +47,6 @@ def parse_rnn_layer(keras_layer, input_names, input_shapes, data_reader):
     if layer['return_state']:
         raise Exception('"return_state" of {} layer is not yet supported.')
+
+    layer['pass_initial_states'] = False
Not quite, right? We find out from the number of inputs. Also, it would be good to add this as an expected attribute to the RNN layers in the IR.
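A hypothetical sketch of the "find out from the number of inputs" idea (the node structure and threshold here are assumptions for illustration, not the actual parser code):

```python
# If the torch.fx node carries more arguments than just the data input,
# the extra args are the initial hidden/cell state tensors.
layer['pass_initial_states'] = len(node.args) > 1
```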
-        input_shapes = [output_shapes[str(node.args[0])]]
+        input_shapes = []
+        input_names = []
+        for i in node.args:
Can we have more descriptive names for the variables i and y?
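For instance, something along these lines (the loop body is a guess from context, since it is cut off in the diff above):

```python
# Iterate over the node's inputs with self-explanatory names.
for input_node in node.args:
    input_shapes.append(output_shapes[str(input_node)])
    input_names.append(str(input_node))
```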
// SimpleRNN with pytorch biases
//----------------------

struct simpleRNN_pytorch_config {
A minor annoyance is that this breaks the naming convention; I was expecting simple_rnn_...
Actually, this is never used (and likewise not for Quartus). It seems like simpleRNN_config is used for all cases. Of course, that one also violates the naming convention.
My suggestion is that we remove simpleRNN_pytorch_config and rename simpleRNN_config. We could keep the two if we really wanted to, but that would complicate the templates a touch for no real gain. (Truthfully, I was never much of a fan of the base config templates that get overridden.)
JanFSchulte#9 takes the simple approach: it renames simpleRNN_config to simple_rnn_config and removes simpleRNN_pytorch_config altogether, if you agree with that approach.
@@ -235,6 +234,41 @@ void lstm_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CO
     }
 }
+
+template <class data_T, class data2_T, class data3_T, class res_T, typename CONFIG_T>
Any more descriptive names than data2_T and data3_T?
I tried to put more descriptive names in the oneAPI version. Maybe we can go more in that direction for the other backends (and potentially standardize).
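As a sketch of that direction (the names h_T/c_T and the argument list below are placeholders for illustration, not the actual names used in the oneAPI code):

```cpp
// Hypothetical renaming: state template parameters named after what they
// carry, e.g. h_T for the initial hidden state, c_T for the initial cell state.
template <class data_T, class h_T, class c_T, class res_T, typename CONFIG_T>
void lstm_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in],
                h_T h_init[CONFIG_T::n_out], c_T c_init[CONFIG_T::n_out],
                res_T res[CONFIG_T::n_sequence * CONFIG_T::n_out]);
```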
Looks good. I left some comments on cosmetics.
This PR addresses #1074 and implements the passing of initial values for the hidden and cell states in GRU and LSTM models, which is supported in PyTorch. This first version implements this only for the PyTorch parser, but it should be possible to extend it to Keras and other parsers.
I have tested this for Vivado, Vitis, and Quartus. Thanks to Jovan, this is also implemented for oneAPI. Nothing is done for Catapult.
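For context, this is the PyTorch behavior being mapped into hls4ml (the shapes and sizes below are illustrative only):

```python
import torch
import torch.nn as nn

# LSTM/GRU modules optionally take initial states in forward();
# this PR lets the converted hls4ml model accept those same extra inputs.
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(1, 10, 8)        # (batch, timesteps, features)
h0 = torch.randn(1, 1, 16)       # (num_layers, batch, hidden_size)
c0 = torch.randn(1, 1, 16)
y, (hn, cn) = lstm(x, (h0, c0))  # initial hidden/cell state passed explicitly
```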
Note that this currently only works in io_parallel. In io_stream I ran into some conceptual issues and was unsure whether to treat these initial states as streamed inputs or not. This might be good enough for now, and I can revisit io_stream if there are suggestions for how to tackle that.

Type of change
Tests
Tested in standalone scripts as well as in the pytests, to ensure that model parsing and evaluation work both with and without passing these optional tensors.
Checklist

I have run pre-commit on the files I edited or added.