Enable pointer-generator T5 models in BeamSearch #23134
onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
ORT_RETURN_IF(first_past_input_index_ != 2 && first_past_input_index_ != 3,
              "kFirstPastInputIndex currently only supports 2 or 3");
ORT_RETURN_IF(first_past_input_index_ != 2 && first_past_input_index_ != 3 && first_past_input_index_ != 4,
Given the SetPastInputIndex implementation, this check on first_past_input_index_ always passes, so we can remove it.
sure, will do
/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline
/azp run Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-linux-gpu-ci-pipeline,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline,Android CI Pipeline
/azp run iOS CI Pipeline,ONNX Runtime React Native CI Pipeline,CoreML CI Pipeline,Linux DNNL CI Pipeline,Linux MIGraphX CI Pipeline,Linux ROCm CI Pipeline
Azure Pipelines successfully started running 6 pipeline(s).
Azure Pipelines successfully started running 9 pipeline(s).
Azure Pipelines successfully started running 10 pipeline(s).
@tianleiwu I can't tell what the problem is in the iOS failure. All the relevant tests seem to pass there. Do you have any insight?
subgraph_inputs[2]->Name());
const int enc_attn_mask_index = 1 + has_encoder_input_ids_;
const int enc_hidden_state_index = enc_attn_mask_index + 1;
if (has_encoder_input_ids_) {
Remove this check: has_encoder_input_ids is defined as subgraph_inputs[1]->Name() == "encoder_input_ids", so it is not necessary.
@@ -49,11 +49,12 @@ namespace transformers {

Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
                                   const std::vector<const NodeArg*>& subgraph_outputs) {
bool has_hidden_state = subgraph_inputs[2]->Name() == "encoder_hidden_states" ? true : false;
SetPastInputIndex(has_hidden_state);
bool has_encoder_input_ids = subgraph_inputs[1]->Name() == "encoder_input_ids";
I recommend adding a comment that lists example inputs:
input_ids, encoder_input_ids (optional), encoder_attention_mask, encoder_hidden_states (optional),
past_self_key_0, past_self_value_0, past_cross_key_0, past_cross_value_0,
...
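The input layout listed above can be sketched as a small standalone function. This is only an illustration, not the onnxruntime code: `DecoderInputLayout` and `InferDecoderInputLayout` are hypothetical names, and the rule `first_past_input_index = 2 + (number of optional inputs present)` is an assumption inferred from the assertion in this PR accepting 2, 3, or 4.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper (not part of onnxruntime). Assumed decoder input layout:
//   input_ids, [encoder_input_ids], encoder_attention_mask,
//   [encoder_hidden_states], past_self_key_0, past_self_value_0, ...
struct DecoderInputLayout {
  bool has_encoder_input_ids;
  bool has_hidden_state;
  int first_past_input_index;
};

inline DecoderInputLayout InferDecoderInputLayout(const std::vector<std::string>& names) {
  // input_ids and encoder_attention_mask are always expected.
  assert(names.size() >= 2);

  DecoderInputLayout layout{};
  layout.has_encoder_input_ids = names[1] == "encoder_input_ids";

  // The attention mask shifts by one when encoder_input_ids is present.
  const std::size_t enc_attn_mask_index = 1 + (layout.has_encoder_input_ids ? 1 : 0);
  const std::size_t enc_hidden_state_index = enc_attn_mask_index + 1;
  layout.has_hidden_state = enc_hidden_state_index < names.size() &&
                            names[enc_hidden_state_index] == "encoder_hidden_states";

  // Assumption: past states start right after the two required inputs
  // plus whichever optional inputs are present, giving 2, 3, or 4.
  layout.first_past_input_index = 2 + (layout.has_encoder_input_ids ? 1 : 0) +
                                  (layout.has_hidden_state ? 1 : 0);
  return layout;
}
```

Under this assumption the index can only be 2, 3, or 4, which is why a separate range check on it would be redundant.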
@@ -238,7 +268,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
// When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output
// of encoder.
// When first_past_input_index_ == 2, the past states are copied from the second output of encoder.
for (size_t j = static_cast<size_t>(4) - first_past_input_index_; j < encoder_fetches.size(); j++) {
for (size_t j = static_cast<size_t>(2) - has_hidden_state_; j < encoder_fetches.size(); j++) {
The decoder inputs are input_ids, encoder_input_ids (optional), encoder_attention_mask, encoder_hidden_states (optional), past_self_key_0, past_self_value_0, past_cross_key_0, past_cross_value_0, ...
This loop adds feeds for encoder_hidden_states (optional), past_self_key_0, past_self_value_0, past_cross_key_0, past_cross_value_0, ... from the encoder output.
The encoder output looks like logits, encoder_hidden_states (optional), past_self_value_0, past_cross_key_0, past_cross_value_0, ..., so j should start from 1 (the second output). This assumes that, if the encoder hidden state is not used in the decoder, the encoder should not output it, for best performance.
I understand that we might also need to change some code in the encoder output validation to make sure all outputs are used by the decoder. That is, if encoder_hidden_states is not used by the decoder, it should not exist in the encoder output.
Another possible implementation is to match by name and construct a mapping from encoder input/output index to decoder input index. That could be more flexible.
I suggest updating the comment before this loop.
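As a rough sketch of the name-matching alternative mentioned above, the function below pairs each encoder output with the decoder input of the same name. `MapEncoderOutputsToDecoderInputs` is a hypothetical helper, not an onnxruntime API, and it assumes encoder outputs and decoder inputs share identical names, as the lists in this comment are written; a real model may need a renaming step first.

```cpp
#include <cstddef>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical helper (not part of onnxruntime).
// Returns (encoder_output_index, decoder_input_index) pairs in encoder output
// order. Encoder outputs with no same-named decoder input (e.g. logits) are
// skipped.
inline std::vector<std::pair<std::size_t, std::size_t>> MapEncoderOutputsToDecoderInputs(
    const std::vector<std::string>& encoder_outputs,
    const std::vector<std::string>& decoder_inputs) {
  // Index decoder inputs by name for constant-time lookup.
  std::unordered_map<std::string, std::size_t> decoder_index_by_name;
  for (std::size_t i = 0; i < decoder_inputs.size(); ++i) {
    decoder_index_by_name.emplace(decoder_inputs[i], i);
  }

  std::vector<std::pair<std::size_t, std::size_t>> mapping;
  for (std::size_t j = 0; j < encoder_outputs.size(); ++j) {
    const auto it = decoder_index_by_name.find(encoder_outputs[j]);
    if (it != decoder_index_by_name.end()) {
      mapping.emplace_back(j, it->second);
    }
  }
  return mapping;
}
```

With such a mapping, the feed-building loop would no longer depend on a fixed starting offset, and an encoder output that the decoder never consumes would show up as an index missing from the result.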
Description
Introduces a new optional input (encoder_input_ids) in the decoder graph of the T5 implementation for BeamSearch. This allows the use of pointer-generator networks in the decoder graph.
Motivation and Context