
Initial values for the hidden/cell state for LSTM and GRU models in Pytorch #1120

Open · wants to merge 25 commits into main
Conversation

@JanFSchulte (Contributor) commented Nov 11, 2024

This PR addresses #1074 and implements passing initial values for the hidden and cell states to GRU and LSTM models, which PyTorch supports. This first version implements this only for the PyTorch parser, but it should be possible to extend it to Keras and other parsers.

I have tested this for Vivado, Vitis, and Quartus. Thanks to Jovan, this is also implemented for oneAPI. Nothing has been done for Catapult yet.

Note that this currently only works in io_parallel. In io_stream I ran into conceptual issues and was unsure whether to treat these initial states as streamed inputs. This might be good enough for now; I can revisit io_stream if there are suggestions on how to tackle it.
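To illustrate what the PR exposes, here is a minimal pure-Python sketch of an LSTM recurrence where the initial hidden/cell states are optional arguments. This is not the hls4ml or PyTorch implementation; the function names, scalar state, and toy weights are all hypothetical, chosen only to show that nonzero `h0`/`c0` change the result relative to the default zero initialization.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def lstm_cell_step(x, h_prev, c_prev, w_ih, w_hh, b):
    # One LSTM time step for scalar input/state; gate order: i, f, g, o.
    pre = [w_ih[k] * x + w_hh[k] * h_prev + b[k] for k in range(4)]
    i = sigmoid(pre[0])       # input gate
    f = sigmoid(pre[1])       # forget gate
    g = math.tanh(pre[2])     # candidate cell update
    o = sigmoid(pre[3])       # output gate
    c = f * c_prev + i * g
    h = o * math.tanh(c)
    return h, c

def run_lstm(xs, h0=0.0, c0=0.0,
             w_ih=(0.5, 0.5, 0.5, 0.5),
             w_hh=(0.1, 0.1, 0.1, 0.1),
             b=(0.0, 0.0, 0.0, 0.0)):
    # h0/c0 are the optional initial states; defaults reproduce the
    # usual zero-initialized behavior.
    h, c = h0, c0
    for x in xs:
        h, c = lstm_cell_step(x, h, c, w_ih, w_hh, b)
    return h, c
```

With the defaults this matches zero-initialized states; passing nonzero `h0`/`c0` yields a different trajectory, which is the behavior being made configurable here.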

Type of change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change which adds functionality)

Tests

Tested in standalone scripts and in the pytests to ensure that model parsing and evaluation work both with and without passing these optional tensors.

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works.

@JanFSchulte JanFSchulte added the please test Trigger testing by creating local PR branch label Nov 11, 2024
@JanFSchulte JanFSchulte added this to the v1.1.0 milestone Jan 8, 2025
const typename CONFIG_T::weight_t RWO[CONFIG_T::n_out * CONFIG_T::n_out],
const typename CONFIG_T::bias_t BI[CONFIG_T::n_out], const typename CONFIG_T::bias_t BF[CONFIG_T::n_out],
const typename CONFIG_T::bias_t BC[CONFIG_T::n_out], const typename CONFIG_T::bias_t BO[CONFIG_T::n_out]) {
res_T hidden_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1] hls_register;
Contributor
Shouldn't hidden_state be of type data2_T and cell_state of type data3_T?

Contributor Author

Indeed, fixed.

@jmitrevs (Contributor) commented Feb 7, 2025

We need to fix the datatypes for oneAPI before we merge

@JanFSchulte (Contributor Author)

pre-commit.ci autofix

@@ -47,4 +47,6 @@ def parse_rnn_layer(keras_layer, input_names, input_shapes, data_reader):
if layer['return_state']:
raise Exception('"return_state" of {} layer is not yet supported.'.format(keras_layer['class_name']))

layer['pass_initial_states'] = False
Contributor

Not quite, right? We find this out from the number of inputs. Also, it would be good to add this as an expected attribute to the RNN layers in the IR.
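A minimal sketch of the detection this comment suggests, assuming the parser already has the list of input tensor names (the function name and dict-based `layer` are hypothetical stand-ins for the parser's actual variables):

```python
def set_pass_initial_states(layer, input_names):
    # More than one input tensor means the model was traced with explicit
    # initial hidden/cell states; a single input means zero-initialized states.
    layer['pass_initial_states'] = len(input_names) > 1
    return layer
```

This would let `pass_initial_states` be derived from the parsed inputs rather than hardcoded to False.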

input_shapes = [output_shapes[str(node.args[0])]]
input_shapes = []
input_names = []
for i in node.args:
Contributor
Can we have more descriptive names for the variables i and y?
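One possible shape for the renamed loop, as a sketch only; `collect_rnn_inputs` is a hypothetical helper, and the real code works on torch.fx nodes, where each positional `node.args` entry references a producing node:

```python
def collect_rnn_inputs(node, output_shapes):
    # 'arg' rather than 'i': each positional argument of the traced node
    # is one input tensor (the data, then the optional h0/c0 states).
    input_names = []
    input_shapes = []
    for arg in node.args:
        input_names.append(str(arg))
        input_shapes.append(output_shapes[str(arg)])
    return input_names, input_shapes
```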

// SimpleRNN with pytorch biases
//----------------------

struct simpleRNN_pytorch_config {
Contributor
A minor annoyance is that this breaks the naming convention; I was expecting simple_rnn_...

Contributor

Actually, this is never used (and similarly not for Quartus). It seems that simpleRNN_config is used for all cases. Of course, that one also violates the naming convention.

Contributor

My suggestion is that we remove simpleRNN_pytorch_config and rename simpleRNN_config, though we could keep the two if we really wanted to. That would complicate the templates a touch for no real gain. (Truthfully, I was never much of a fan of the base config templates that get overridden.)

Contributor

JanFSchulte#9 takes the simple approach: it renames simpleRNN_config to simple_rnn_config and removes simpleRNN_pytorch_config altogether, if you agree with that approach.

@@ -235,6 +234,41 @@ void lstm_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CO
}
}

template <class data_T, class data2_T, class data3_T, class res_T, typename CONFIG_T>
Contributor

Any more descriptive names than data2_T and data3_T?

Contributor

I tried to put more descriptive names in the oneAPI version. Maybe we can go more in that direction for the other backends (and potentially standardize).

@vloncar (Contributor) left a comment

Looks good; I left some comments on cosmetics.

Labels: please test (Trigger testing by creating local PR branch)
3 participants