Wrap indices to be within max range for as_strided + index_select #3902

Open: wants to merge 2 commits into main
19 changes: 16 additions & 3 deletions lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -417,6 +417,18 @@ class ConvertAtenEmbeddingBagPaddingIdxOp
};
} // namespace

static Value wrapIndicesAroundMax(OpBuilder &b, Location loc, Value index,
Value input, int64_t dim) {
// Performs the operation: index = index % maxIndex to wrap index around
// maxIndex.
Value maxIndexValue = castIndexToInt64(b, loc, getDimOp(b, loc, input, dim));
Value isBeyondMaxIndices = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, index, maxIndexValue);
Value wrappedIndices = b.create<arith::RemSIOp>(loc, index, maxIndexValue);
return b.create<arith::SelectOp>(loc, isBeyondMaxIndices, wrappedIndices,
index);
}
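
For reference, the scalar behaviour this helper lowers to (an illustrative Python sketch, not code from the patch):

def wrap_index(index, max_index):
    # Mirrors the cmpi/remsi/select sequence above: indices at or beyond
    # max_index are reduced modulo max_index; in-range indices pass through.
    return index % max_index if index >= max_index else index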

namespace {
// Let's say we have an input tensor: initialized with some random values of
// size [4, 5, 6]. An index tensor (always 1-d): [0, 2] of size [2], and an
@@ -478,16 +490,17 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {

auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr},
rewriter.getContext());

Value finalRes =
rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), ValueRange{indices}, initTensor,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value index = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), args[0]);
Value index =
wrapIndicesAroundMax(b, loc, args[0], input, dimInt);
index = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), index);
SmallVector<Value> indexTarget;
for (unsigned i = 0; i < inputRank; i++)
indexTarget.push_back(b.create<linalg::IndexOp>(loc, i));
41 changes: 41 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3988,6 +3988,41 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
return success();
}

Value wrapIndicesAroundMax(Value index, int maxIndex, Operation *op,
ConversionPatternRewriter &rewriter) {
// Performs the operation: index = index % maxIndex to wrap index around
// maxIndex. TOSA has no remainder op, so the modulo is emulated below as
// index - (index / maxIndex) * maxIndex.

auto maxIndexValue =
tosa::getConstTensor<int32_t>(rewriter, op, maxIndex, {}).value();
auto maxIndexValueMinusOne =
tosa::getConstTensor<int32_t>(rewriter, op, maxIndex - 1, {}).value();

auto indexType = dyn_cast<RankedTensorType>(index.getType());
auto boolType = indexType.clone(rewriter.getIntegerType(1));

auto isBeyondMaxIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
rewriter, op->getLoc(), boolType, index, maxIndexValueMinusOne);
auto wrappedBeyondMaxIndicesQuotient =
tosa::CreateOpAndInfer<tosa::IntDivOp>(rewriter, op->getLoc(), indexType,
index, maxIndexValue)
.getResult();
auto wrappedBeyondMaxIndicesQuotientTimesIndices =
tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), indexType,
wrappedBeyondMaxIndicesQuotient,
maxIndexValue, /*shift=*/0)
.getResult();
auto wrappedBeyondMaxIndices =
tosa::CreateOpAndInfer<tosa::SubOp>(
rewriter, op->getLoc(), indexType, index,
wrappedBeyondMaxIndicesQuotientTimesIndices)
.getResult();

return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
indexType, isBeyondMaxIndices,
wrappedBeyondMaxIndices, index);
}
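
In scalar terms, the TOSA helper above computes the following (an illustrative Python sketch; the actual lowering operates elementwise on index tensors):

def wrap_index_tosa(index, max_index):
    # tosa.greater against (max_index - 1): does this index need wrapping?
    if index > max_index - 1:
        # int_div, mul and sub together emulate index % max_index,
        # since TOSA has no remainder operator.
        return index - (index // max_index) * max_index
    # tosa.select keeps in-range indices unchanged.
    return index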

template <>
LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
AtenIndexSelectOp op, OpAdaptor adaptor,
@@ -4031,6 +4066,10 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
}

int64_t selfNumElems = std::accumulate(inputShape.begin(), inputShape.end(),
1, std::multiplies<int64_t>());
index = wrapIndicesAroundMax(index, selfNumElems, op, rewriter);

// Get positive dim
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
@@ -7237,10 +7276,12 @@ LogicalResult ConvertAtenOp<AtenAsStridedOp>::matchAndRewrite(
// coord_i_n * stride[n]
int32_t index = offset;
int64_t coordFinder = i;

for (int64_t dim = 0; dim < outputRank; dim++) {
int64_t indexCoord = coordFinder % outputSize[outputRank - dim - 1];
index += indexCoord * stride[outputRank - dim - 1];
coordFinder /= outputSize[outputRank - dim - 1];
index = (index % selfNumElems);
}
targetIndicesVec.push_back(index);
}
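
The loop above, restated in Python for clarity (a sketch using the same quantities; outputSize, stride, and selfNumElems come from the surrounding lowering):

def flat_index(i, offset, output_size, stride, self_num_elems):
    # Decompose the linear output position i into per-dimension coordinates,
    # innermost dimension first, and accumulate coordinate * stride.
    # Wrapping the running index into the element count of self after each
    # step keeps values such as a non-zero storage offset within range.
    index = offset
    coord_finder = i
    for dim in range(len(output_size) - 1, -1, -1):
        index += (coord_finder % output_size[dim]) * stride[dim]
        coord_finder //= output_size[dim]
        index %= self_num_elems
    return index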
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -492,6 +492,7 @@
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"SplitWithSizes_Module_basic",
"AsStridedWithOffsetModule_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic",
@@ -905,6 +906,7 @@
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"SplitWithSizes_Module_basic",
"AsStridedWithOffsetModule_basic",
"Unfold_Module_basic",
"Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_basic",
@@ -1765,6 +1767,7 @@
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"AsStridedWithOffsetModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
29 changes: 29 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -1144,3 +1144,32 @@ def forward(self, x):
@register_test_case(module_factory=lambda: TensorSplitSections_ListUnpackModule())
def TensorSplitSections_ListUnpackModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5))


# ==============================================================================


class AsStridedWithOffsetModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([2, 6, 60], torch.float32, True),
]
)
def forward(self, x):
output_size = [6, 20]
stride = [60, 1]
slice = torch.ops.aten.slice.Tensor(x, 0, 1, 2)
squeeze = torch.ops.aten.squeeze.dim(slice, 0)
return torch.ops.aten.as_strided(
squeeze, size=output_size, stride=stride, storage_offset=360
)


@register_test_case(module_factory=lambda: AsStridedWithOffsetModule())
def AsStridedWithOffsetModule_basic(module, tu: TestUtils):
module.forward(torch.rand(2, 6, 60))
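
Why the wrap makes this case work (an informal eager-mode check, not part of the test suite): the squeezed tensor is a view that already starts at storage offset 360, and as_strided's storage_offset is absolute into the underlying storage, so the requested elements live at 360 + i*60 + j. Reducing those flat indices modulo the view's 360 elements gives i*60 + j within the view, which addresses the same data:

# Assumes torch is imported as in the surrounding test file.
x = torch.rand(2, 6, 60)
view = x[1]  # view with storage offset 360 and 360 elements
expected = torch.as_strided(view, (6, 20), (60, 1), 360)
wrapped = [(360 + i * 60 + j) % 360 for i in range(6) for j in range(20)]
manual = view.reshape(-1)[wrapped].reshape(6, 20)
assert torch.equal(expected, manual)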
2 changes: 2 additions & 0 deletions python/torch_mlir/extras/fx_decomp_util.py
@@ -49,6 +49,8 @@
torch.ops.aten.nan_to_num.default,
torch.ops.aten.unbind,
torch.ops.aten.diag,
torch.ops.aten.lstm.input,
torch.ops.aten.gru.input,
]
if hasattr(torch.ops.aten, "_scaled_dot_product_flash_attention_for_cpu"):
DEFAULT_DECOMPOSITIONS.append(
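These two entries extend the default decomposition list so that aten.lstm.input and aten.gru.input are decomposed before lowering. A minimal sketch of how such a list is typically materialized into a decomposition table (assuming torch._decomp.get_decompositions; names outside this diff are assumptions, not necessarily this utility's exact API):

import torch
from torch._decomp import get_decompositions

# Map the listed ops to the decompositions registered for them in PyTorch.
table = get_decompositions([torch.ops.aten.lstm.input, torch.ops.aten.gru.input])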
75 changes: 59 additions & 16 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -1893,22 +1893,29 @@ func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) ->
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<2xi64>) -> tensor<2xi32>
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1, 2>} : (tensor<2xi32>) -> tensor<1x1x2xi32>
// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array<i64: 4, 5, 1>} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 4, 5, 2, 1>} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32>
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 120, 1>} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32>
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 40, 3>} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32>
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32>
// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32>
// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 40>} : (tensor<40x1xi32>) -> tensor<1x40xi32>
// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32>
// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 4, 5, 2>} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32>
// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32>
// CHECK: return %[[VAL_20]] : !torch.vtensor<[4,5,2],f32>
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<120> : tensor<i32>}> : () -> tensor<i32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<119> : tensor<i32>}> : () -> tensor<i32>
// CHECK: %[[VAL_8:.*]] = tosa.greater %[[VAL_5]], %[[VAL_7]] : (tensor<2xi32>, tensor<i32>) -> tensor<2xi1>
// CHECK: %[[VAL_9:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor<2xi32>, tensor<i32>) -> tensor<2xi32>
// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_9]], %[[VAL_6]] {shift = 0 : i8} : (tensor<2xi32>, tensor<i32>) -> tensor<2xi32>
// CHECK: %[[VAL_11:.*]] = tosa.sub %[[VAL_5]], %[[VAL_10]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
// CHECK: %[[VAL_12:.*]] = tosa.select %[[VAL_8]], %[[VAL_11]], %[[VAL_5]] : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array<i64: 1, 1, 2>} : (tensor<2xi32>) -> tensor<1x1x2xi32>
// CHECK: %[[VAL_14:.*]] = tosa.tile %[[VAL_13]] {multiples = array<i64: 4, 5, 1>} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32>
// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 4, 5, 2, 1>} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32>
// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
// CHECK: %[[VAL_17:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
// CHECK: %[[VAL_18:.*]] = tosa.concat %[[VAL_16]], %[[VAL_17]], %[[VAL_15]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32>
// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 120, 1>} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32>
// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 40, 3>} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32>
// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: %[[VAL_22:.*]] = tosa.mul %[[VAL_20]], %[[VAL_21]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32>
// CHECK: %[[VAL_23:.*]] = tosa.reduce_sum %[[VAL_22]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32>
// CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_23]] {new_shape = array<i64: 1, 40>} : (tensor<40x1xi32>) -> tensor<1x40xi32>
// CHECK: %[[VAL_25:.*]] = tosa.gather %[[VAL_19]], %[[VAL_24]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32>
// CHECK: %[[VAL_26:.*]] = tosa.reshape %[[VAL_25]] {new_shape = array<i64: 4, 5, 2>} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32>
// CHECK: %[[VAL_27:.*]] = torch_c.from_builtin_tensor %[[VAL_26]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32>
// CHECK: return %[[VAL_27]] : !torch.vtensor<[4,5,2],f32>
// CHECK: }
func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> {
%int2 = torch.constant.int 2
@@ -2306,6 +2313,42 @@ func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !tor
return %2 : !torch.vtensor<[3,3],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.as_strided$offset(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 30
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = torch.constant.int 3
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 25>} : (tensor<5x5xf32>) -> tensor<25xf32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[5, 6, 7, 7, 8, 9, 9, 10, 11]> : tensor<9xi32>}> : () -> tensor<9xi32>
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 9, 1>} : (tensor<9xi32>) -> tensor<9x1xi32>
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_10]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 25, 1>} : (tensor<25xf32>) -> tensor<1x25x1xf32>
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 9, 1>} : (tensor<9x1xi32>) -> tensor<9x1xi32>
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<9x1xi32>, tensor<1xi32>) -> tensor<9x1xi32>
// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 9>} : (tensor<9x1xi32>) -> tensor<1x9xi32>
// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x25x1xf32>, tensor<1x9xi32>) -> tensor<1x9x1xf32>
// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 9>} : (tensor<1x9x1xf32>) -> tensor<9xf32>
// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array<i64: 3, 3>} : (tensor<9xf32>) -> tensor<3x3xf32>
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32>
// CHECK: return %[[VAL_21]] : !torch.vtensor<[3,3],f32>
func.func @torch.aten.as_strided$offset(%arg0: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> {
%int30 = torch.constant.int 30
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.as_strided %arg0, %0, %1, %int30 : !torch.vtensor<[5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[3,3],f32>
return %2 : !torch.vtensor<[3,3],f32>
}
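
The constant [5, 6, 7, 7, 8, 9, 9, 10, 11] checked above follows directly from the wrapping: with storage_offset 30, size [3, 3], stride [2, 1], and a 25-element 5x5 input, each flat index 30 + 2*i + j is reduced modulo 25. The lowering applies the modulo after every dimension, but for these values a single final modulo gives the same result (an illustrative recomputation, not part of the lit test):

indices = [(30 + 2 * i + j) % 25 for i in range(3) for j in range(3)]
print(indices)  # [5, 6, 7, 7, 8, 9, 9, 10, 11]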

// -----

// CHECK-LABEL: func.func @torch.aten.max_pool1d$basic(