
Cannot resolve operator 'LSTM' with webgl backend #23083

Open
mrdrprofuroboros opened this issue Dec 11, 2024 · 2 comments
Labels
  • ep:WebGPU: ort-web webgpu provider
  • ep:WebNN: WebNN execution provider
  • platform:mobile: issues related to ONNX Runtime mobile; typically submitted using template
  • platform:web: issues related to ONNX Runtime web; typically submitted using template

Comments


mrdrprofuroboros commented Dec 11, 2024

Describe the issue

I'm exporting the following model from PyTorch to ONNX:

import torch.nn as nn

# n_barks, n_samples, n_visemes are module-level constants (defined elsewhere)
class MyLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.batch_size = None
        self.conv_channels = 32
        self.lstm_layers = 1
        self.lstm_hidden = 64

        self.conv11 = nn.Conv1d(n_barks, self.conv_channels, kernel_size=3, padding=1)
        self.conv21 = nn.Conv2d(1, 1, kernel_size=5, padding=2)
        self.conv12 = nn.Conv1d(self.conv_channels, n_barks, kernel_size=3, padding=1)
        self.lstm = nn.LSTM(
            input_size=n_samples * n_barks,          
            hidden_size=self.lstm_hidden,
            num_layers=self.lstm_layers,
            batch_first=True,
            bidirectional=False     
        )
        self.linear = nn.Linear(self.lstm_hidden, n_visemes)

...

It works fine with the WASM backend in my web frontend, but when I try to run it with webgl I get:
[screenshot: "Cannot resolve operator 'LSTM'" error]

  • I tried different opsets: 7, 9, 10, 11, 14, 17, with no luck (export sketch below)
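
For reference, the export call looks roughly like this (a sketch; dummy_inputs is a placeholder for tensors matching the forward signature elided above):

import torch

model = MyLSTM().eval()
# dummy_inputs: placeholder tuple of tensors matching the model's forward signature
torch.onnx.export(
    model,
    dummy_inputs,
    "model.onnx",
    opset_version=17,  # also tried 7, 9, 10, 11, 14
)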

I also tried to run it with webgpu / webnn, but even though I have a fresh Chromium and have enabled the unsafe WebGPU flag, it says:
[screenshot: error message]

on both macOS and Android.

To reproduce

import * as ort from 'onnxruntime-web/webgl';

this.session = await ort.InferenceSession.create('./model.onnx', {executionProviders: ['webgl']});

Urgency

We can use WASM for now, but we'd like to know what kind of error this is: our misconfiguration, or LSTM not being supported.

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.20.1

Execution Provider

'webgl' (WebGL)

mrdrprofuroboros added the platform:web label on Dec 11, 2024
The github-actions bot added the ep:WebGPU, ep:WebNN, and platform:mobile labels on Dec 11, 2024
skottmckay (Contributor) commented:

Lists of supported operators for the web EPs: https://github.com/microsoft/onnxruntime/tree/main/js/web/docs

LSTM is not implemented for webgl.

It is implemented for webnn, but with a number of limitations in what the implementation supports. See https://github.com/microsoft/onnxruntime/blob/main/js/web/docs/webnn-operators.md

mrdrprofuroboros (Author) commented:

I'm trying to run it with webnn (I had to check the "enable WebNN" flag in Chrome so it could find the backend, not the experimental WebGL flag).

LSTM itself seems to be supported, but there's something wrong with the LSTM input:
[screenshot: error about the LSTM input]

So apparently I get downgraded to CPU:

2024-12-12 12:16:06.272699 [W:onnxruntime:, webnn_execution_provider.cc:207 GetCapability] WebNNExecutionProvider::GetCapability, number of partitions supported by WebNN: 2 number of nodes in the graph: 27 number of nodes supported by WebNN: 26

2024-12-12 12:16:06.486599 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
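
To check what the exported LSTM node actually receives against the WebNN limitations table, I dump its inputs and attributes with the onnx Python package (quick sketch):

import onnx

# Load the exported model and print the LSTM node's inputs and attribute names
model = onnx.load("model.onnx")
for node in model.graph.node:
    if node.op_type == "LSTM":
        print("inputs:", list(node.input))
        print("attributes:", [a.name for a in node.attribute])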

I've already simplified the LSTM usage in the exported model to unidirectional, unbatched input:

import torch
import torch.nn as nn
import torch.nn.functional as F

# n_barks, n_samples, n_visemes are module-level constants (defined elsewhere)
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_channels = 32
        self.lstm_layers = 1
        self.lstm_hidden = 64

        self.conv11 = nn.Conv1d(n_barks, self.conv_channels, kernel_size=3, padding=1)
        self.conv21 = nn.Conv2d(1, 1, kernel_size=5, padding=2)
        self.conv12 = nn.Conv1d(self.conv_channels, n_barks, kernel_size=3, padding=1)
        self.lstm = nn.LSTM(
            input_size=n_samples * n_barks,
            hidden_size=self.lstm_hidden,
            num_layers=self.lstm_layers,
            batch_first=True,
            bidirectional=False
        )
        self.linear = nn.Linear(self.lstm_hidden, n_visemes)

    def forward(self, x, h, c):
        """
        x: [n_samples, n_barks]
        h: [lstm_hidden]
        c: [lstm_hidden]
        """
        n_samples, n_barks = x.size()

        identity = x.transpose(0, 1).unsqueeze(0)
        x = F.relu(self.conv11(identity))
        x = F.relu(self.conv21(x.unsqueeze(1)).squeeze(1))
        x = F.relu(self.conv12(x) + identity)
        x = x.view(1, n_samples * n_barks)

        x, (h, c) = self.lstm(x, (h.unsqueeze(0), c.unsqueeze(0)))
        x = self.linear(x)
        return x, h.squeeze(0), c.squeeze(0)
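
For context, the intent is to run it step by step and thread the state through, roughly like this (a sketch; audio_chunks is a hypothetical iterable of [n_samples, n_barks] tensors):

import torch

model = MyModel().eval()
h = torch.zeros(model.lstm_hidden)
c = torch.zeros(model.lstm_hidden)
with torch.no_grad():
    for chunk in audio_chunks:  # each chunk: [n_samples, n_barks]
        y, h, c = model(chunk, h, c)  # y: viseme logits for this chunk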

What am I missing?
