Improved Tensor Dimension Handling in predict_masks Method #581

Open
sushmanthreddy opened this issue Sep 29, 2023 · 0 comments · May be fixed by #580

@sushmanthreddy

Issue:
This issue proposes an enhancement to how tensor dimensions are handled in the predict_masks method of the MaskDecoder class. The change breaks down as follows:

  1. Conditional Check:

    • A check, if image_embeddings.shape[0] != tokens.shape[0]:, is added so that the batch dimensions of the two tensors are compared before torch.repeat_interleave is applied.
  2. Usage of torch.repeat_interleave:

    • When the batch sizes differ, torch.repeat_interleave expands image_embeddings along the batch dimension so that its batch size matches that of tokens.
  3. Ensuring Consistency:

    • In the original implementation, torch.repeat_interleave is applied unconditionally. With the check, it runs only when it is needed, so image embeddings whose batch size already matches tokens are passed through as-is instead of being repeated.
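A minimal sketch of the conditional expansion described above, assuming PyTorch and small toy tensor shapes. The helper name expand_to_token_batch is hypothetical and is not taken from #580 or from the MaskDecoder source; it only isolates the check so its effect on the batch dimension is easy to see.

```python
import torch


def expand_to_token_batch(image_embeddings: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Return image embeddings whose batch size (dim 0) matches that of `tokens`.

    Hypothetical helper illustrating the check proposed in this issue;
    it is not the code from the linked pull request.
    """
    if image_embeddings.shape[0] != tokens.shape[0]:
        # Batch sizes differ: repeat each embedding tokens.shape[0] times along dim 0.
        return torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    # Batch sizes already agree: return the tensor unchanged.
    return image_embeddings


# Toy shapes (assumed for illustration): 3 prompt sets of 5 tokens with 8 channels,
# and image embeddings with 8 channels on a 4x4 grid.
tokens = torch.zeros(3, 5, 8)
single = torch.randn(1, 8, 4, 4)   # one embedding shared by all prompt sets
batched = torch.randn(3, 8, 4, 4)  # already one embedding per prompt set

print(expand_to_token_batch(single, tokens).shape)   # torch.Size([3, 8, 4, 4])
print(expand_to_token_batch(batched, tokens).shape)  # torch.Size([3, 8, 4, 4])

# Unconditional repetition, by contrast, multiplies an already-matching batch:
print(torch.repeat_interleave(batched, tokens.shape[0], dim=0).shape)  # torch.Size([9, 8, 4, 4])
```

The check covers the two cases discussed here: a single image embedding (batch size 1) and embeddings that already match tokens. If image_embeddings had some other batch size, repeating it tokens.shape[0] times would still not yield a match, so a stricter guard may be worth considering.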