junrushao1994 commented on a change in pull request #10066: URL: https://github.com/apache/tvm/pull/10066#discussion_r793144276
########## File path: tests/python/unittest/test_tir_transform_inject_software_pipeline.py ########## @@ -0,0 +1,824 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import sys + +import tvm +from tvm import tir, te, TVMError +from tvm.script import tir as T + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.InjectSoftwarePipeline()(mod) + mod = tvm.tir.transform.Simplify()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed, True) + + +def _check_error(func): + mod = tvm.IRModule.from_expr(func) + with pytest.raises(ValueError): + tvm.tir.transform.InjectSoftwarePipeline()(mod) + + +@T.prim_func +def trivial_pipeline(A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "float32"]): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, 1, annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]} + ): + with T.block(): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, 
i]) + C[tx, i] = B[tx, 0] + T.float32(1) + + +@T.prim_func +def transformed_trivial_pipeline( + A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "float32"] +) -> None: + for tx in T.thread_binding(16, thread="threadIdx.x"): + with T.block(): + T.reads(A[tx, 0]) + T.writes(C[tx, 0]) + B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, 0]) + T.writes(B[0, tx, 0]) + B[0, tx, 0] = A[tx, 0] * T.float32(2) + with T.block(): + T.reads() + T.writes() + T.evaluate(0) + with T.block(): + T.reads(B[0, tx, 0]) + T.writes(C[tx, 0]) + C[tx, 0] = B[0, tx, 0] + T.float32(1) + + +@T.prim_func +def simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + 16, + annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}, + ): + with T.block(): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = B[tx, 0] + T.float32(1) + + +@T.prim_func +def transformed_simple_compute( + A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"] +) -> None: + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16]]) + T.writes([C[tx, 0:16]]) + B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0]]) + T.writes([B[0, tx, 0]]) + B[0, tx, 0] = A[tx, 0] * T.float32(2) + with T.block(): + T.reads([A[tx, 1:16], B[0:2, tx, 0]]) + T.writes([B[0:2, tx, 0], C[tx, 0:15]]) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1]]) + T.writes([B[(i + 1) % 2, tx, 0]]) + B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) + with T.block(): + T.reads([B[i % 2, tx, 0]]) + 
T.writes([C[tx, i]]) + C[tx, i] = B[i % 2, tx, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 0]]) + T.writes([C[tx, 15]]) + C[tx, 15] = B[1, tx, 0] + T.float32(1) + + +@T.prim_func +def nested_pipeline_simple( + A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] +): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1, 1, 1], + "software_pipeline_order": [0, 1, 2, 3], + }, + ): + with T.block(): + T.reads(A[tx, i, 0:16]) + T.writes(C[tx, i, 0:16]) + A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") + for j in T.serial(0, 16): + with T.block(): + T.reads(A[tx, i, j]) + T.writes(A_shared[tx, 0, j]) + A_shared[tx, 0, j] = A[tx, i, j] + for j in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1], + }, + ): + with T.block(): + T.reads(A_shared[tx, 0, j]) + T.writes(C[tx, i, j]) + B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A_shared[tx, i, j]) + T.writes(B[tx, i, 0]) + B[tx, i, 0] = A_shared[tx, 0, j] * T.float32(2) + with T.block(): + T.reads(B[tx, i, 0]) + T.writes(C[tx, i, j]) + C[tx, i, j] = B[tx, i, 0] + T.float32(1) + + +@T.prim_func +def transformed_nested_pipeline_simple( + A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] +) -> None: + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16, 0:16]]) + T.writes([C[tx, 0:16, 0:16]]) + A_shared = T.alloc_buffer([2, 16, 1, 16], dtype="float32", scope="shared") + B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0, 0:16]]) + T.writes([A_shared[0, tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, 0, j]]) + T.writes([A_shared[0, tx, 0, j]]) + A_shared[0, tx, 0, j] = A[tx, 0, j] + with T.block(): + 
T.reads([A[tx, 1:16, 0:16], A_shared[0:2, tx, 0:15, 0:16], B[0:2, tx, 0:15, 0]]) + T.writes([A_shared[0:2, tx, 0, 0:16], B[0:2, tx, 0:15, 0], C[tx, 0:15, 0:16]]) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1, 0:16]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, i + 1, j]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, j]]) + A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j] + with T.block(): + T.reads([A_shared[i % 2, tx, i, 0]]) + T.writes([B[0, tx, i, 0]]) + B[0, tx, i, 0] = A_shared[i % 2, tx, 0, 0] * T.float32(2) + with T.block(): + T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) + T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_shared[i % 2, tx, i, j + 1]]) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_shared[ + i % 2, tx, 0, j + 1 + ] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, i, 0]]) + T.writes([C[tx, i, j]]) + C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, i, 0]]) + T.writes([C[tx, i, 15]]) + C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[1, tx, 15, 0:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) + with T.block(): + T.reads([A_shared[1, tx, 15, 0]]) + T.writes([B[0, tx, 15, 0]]) + B[0, tx, 15, 0] = A_shared[1, tx, 0, 0] * T.float32(2) + with T.block(): + T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_shared[1, tx, 15, j + 1]]) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, 15, 0]]) + T.writes([C[tx, 15, j]]) + C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 15, 0]]) + T.writes([C[tx, 15, 15]]) + C[tx, 15, 15] = 
B[1, tx, 15, 0] + T.float32(1) + + +@T.prim_func +def nested_pipeline_prefetch_inner( + A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] +): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 0, 1, 1], + "software_pipeline_order": [0, 2, 1, 3], + }, + ): + with T.block(): + T.reads(A[tx, i, 0:16]) + T.writes(C[tx, i, 0:16]) + A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") + for j in T.serial(0, 16): + with T.block(): + T.reads(A[tx, i, j]) + T.writes(A_shared[tx, 0, j]) + A_shared[tx, 0, j] = A[tx, i, j] + for j in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1], + }, + ): + with T.block(): + T.reads(A_shared[tx, 0, j]) + T.writes(C[tx, i, j]) + B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A_shared[tx, i, j]) + T.writes(B[tx, i, 0]) + B[tx, i, 0] = A_shared[tx, 0, j] * T.float32(2) + with T.block(): + T.reads(B[tx, i, 0]) + T.writes(C[tx, i, j]) + C[tx, i, j] = B[tx, i, 0] + T.float32(1) + + +@T.prim_func +def transformed_nested_pipeline_prefetch_inner( + A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] +) -> None: + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16, 0:16]]) + T.writes([C[tx, 0:16, 0:16]]) + A_shared = T.alloc_buffer([2, 16, 1, 16], dtype="float32", scope="shared") + B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0, 0:16], A_shared[0, tx, 0, 0]]) + T.writes([A_shared[0, tx, 0, 0:16], B[0, tx, 0, 0]]) + with T.block(): + T.reads([A[tx, 0, 0:16]]) + T.writes([A_shared[0, tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, 0, j]]) + T.writes([A_shared[0, tx, 0, j]]) + A_shared[0, tx, 0, j] = A[tx, 0, j] + with T.block(): + 
T.reads([A_shared[0, tx, 0, 0]]) + T.writes([B[0, tx, 0, 0]]) + B[0, tx, 0, 0] = A_shared[0, tx, 0, 0] * T.float32(2) + with T.block(): + T.reads([A[tx, 1:16, 0:16], A_shared[0:2, tx, 0:16, 0:16], B[0:2, tx, 0:15, 0]]) + T.writes([A_shared[0:2, tx, 0, 0:16], B[0:2, tx, 0:16, 0], C[tx, 0:15, 0:16]]) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1, 0:16]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, i + 1, j]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, j]]) + A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j] + with T.block(): + T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) + T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_shared[i % 2, tx, i, j + 1]]) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_shared[ + i % 2, tx, 0, j + 1 + ] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, i, 0]]) + T.writes([C[tx, i, j]]) + C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[(i + 1) % 2, tx, i + 1, 0]]) + T.writes([B[0, tx, i + 1, 0]]) + B[0, tx, i + 1, 0] = A_shared[(i + 1) % 2, tx, 0, 0] * T.float32(2) + with T.block(): + T.reads([B[1, tx, i, 0]]) + T.writes([C[tx, i, 15]]) + C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) + with T.block(): + T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_shared[1, tx, 15, j + 1]]) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, 15, 0]]) + T.writes([C[tx, 15, j]]) + C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 15, 0]]) + T.writes([C[tx, 15, 
15]]) + C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) + + +@T.prim_func +def nested_pipeline_interleaving( + A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] +): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 0, 0, 1, 1], + "software_pipeline_order": [0, 2, 3, 1, 4], + }, + ): + with T.block(): + T.reads(A[tx, i, 0:16]) + T.writes(C[tx, i, 0:16]) + A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") + A_local = T.alloc_buffer((1, 1, 16), dtype="float32", scope="local") + for j in T.serial(0, 16): + with T.block(): + T.reads(A[tx, i, j]) + T.writes(A_shared[tx, 0, j]) + A_shared[tx, 0, j] = A[tx, i, j] + for j in T.serial(0, 16): + with T.block(): + T.reads(A_shared[tx, 0, j]) + T.writes(A_local[0, 0, j]) + A_local[0, 0, j] = A_shared[tx, i, j] + for j in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1], + }, + ): + with T.block(): + T.reads(A_local[0, 0, j]) + T.writes(C[tx, i, j]) + B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A_local[tx, i, j]) + T.writes(B[tx, i, 0]) + B[tx, i, 0] = A_local[0, 0, j] * T.float32(2) + with T.block(): + T.reads(B[tx, i, 0]) + T.writes(C[tx, i, j]) + C[tx, i, j] = B[tx, i, 0] + T.float32(1) + + +@T.prim_func +def transformed_nested_pipeline_interleaving( + A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] +) -> None: + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16, 0:16]]) + T.writes([C[tx, 0:16, 0:16]]) + A_shared = T.alloc_buffer([16, 1, 16], dtype="float32", scope="shared") + A_local = T.alloc_buffer([1, 1, 16], dtype="float32", scope="local") + B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0, 0:16], A_shared[tx, 0, 0:16], A_local[tx, 0, 
0]]) + T.writes([A_shared[tx, 0, 0:16], A_local[0, 0, 0:16], B[0, tx, 0, 0]]) + with T.block(): + T.reads([A[tx, 0, 0:16]]) + T.writes([A_shared[tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, 0, j]]) + T.writes([A_shared[tx, 0, j]]) + A_shared[tx, 0, j] = A[tx, 0, j] + with T.block(): + T.reads([A_shared[tx, 0, 0:16]]) + T.writes([A_local[0, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A_shared[tx, 0, j]]) + T.writes([A_local[0, 0, j]]) + A_local[0, 0, j] = A_shared[tx, 0, j] + with T.block(): + T.reads([A_local[tx, 0, 0]]) + T.writes([B[0, tx, 0, 0]]) + B[0, tx, 0, 0] = A_local[0, 0, 0] * T.float32(2) + with T.block(): + T.reads( + [ + A[tx, 1:16, 0:16], + A_local[tx, 0:16, 0:16], + B[0:2, tx, 0:15, 0], + A_shared[tx, 0, 0:16], + ] + ) + T.writes( + [ + A_shared[tx, 0, 0:16], + B[0:2, tx, 0:16, 0], + C[tx, 0:15, 0:16], + A_local[0, 0, 0:16], + ] + ) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1, 0:16]]) + T.writes([A_shared[tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, i + 1, j]]) + T.writes([A_shared[tx, 0, j]]) + A_shared[tx, 0, j] = A[tx, i + 1, j] + with T.block(): + T.reads([A_local[tx, i, 1:16], B[0:2, tx, i, 0]]) + T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_local[tx, i, j + 1]]) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, i, 0]]) + T.writes([C[tx, i, j]]) + C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[tx, 0, 0:16]]) + T.writes([A_local[0, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A_shared[tx, 0, j]]) + T.writes([A_local[0, 0, j]]) + A_local[0, 0, j] = A_shared[tx, i + 1, j] + with T.block(): + T.reads([A_local[tx, i + 1, 0]]) + T.writes([B[0, tx, i + 1, 0]]) + B[0, tx, i + 1, 0] = A_local[0, 0, 0] * T.float32(2) + with 
T.block(): + T.reads([B[1, tx, i, 0]]) + T.writes([C[tx, i, 15]]) + C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_local[tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) + with T.block(): + T.reads([A_local[tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_local[tx, 15, j + 1]]) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, 15, 0]]) + T.writes([C[tx, 15, j]]) + C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 15, 0]]) + T.writes([C[tx, 15, 15]]) + C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) + + +@T.prim_func +def nested_pipeline_double_buffer( + A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] +): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 0, 0, 1, 1], + "software_pipeline_order": [0, 2, 3, 1, 4], + }, + ): + with T.block(): + T.reads(A[tx, i, 0:16]) + T.writes(C[tx, i, 0:16]) + A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") + A_local = T.alloc_buffer((1, 1, 16), dtype="float32", scope="local") + for j in T.serial(0, 16): + with T.block(): + T.reads(A[tx, i, j]) + T.writes(A_shared[tx, 0, j]) + A_shared[tx, 0, j] = A[tx, i, j] + for j in T.serial(0, 16): + with T.block(): + T.block_attr({"double_buffer_scope": 0}) + T.reads(A_shared[tx, 0, j]) + T.writes(A_local[0, 0, j]) + A_local[0, 0, j] = A_shared[tx, i, j] + for j in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1], + }, + ): + with T.block(): + T.reads(A_local[0, 0, j]) + T.writes(C[tx, i, j]) + B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A_local[tx, i, 
j]) + T.writes(B[tx, i, 0]) + B[tx, i, 0] = A_local[0, 0, j] * T.float32(2) + with T.block(): + T.reads(B[tx, i, 0]) + T.writes(C[tx, i, j]) + C[tx, i, j] = B[tx, i, 0] + T.float32(1) + + +@T.prim_func +def transformed_nested_pipeline_double_buffer( + A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] +) -> None: + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16, 0:16]]) + T.writes([C[tx, 0:16, 0:16]]) + A_shared = T.alloc_buffer([16, 1, 16], dtype="float32", scope="shared") + A_local = T.alloc_buffer([2, 1, 1, 16], dtype="float32", scope="local") + B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0, 0:16], A_shared[tx, 0, 0:16], A_local[0, tx, 0, 0]]) + T.writes([A_shared[tx, 0, 0:16], A_local[0, 0, 0, 0:16], B[0, tx, 0, 0]]) + with T.block(): + T.reads([A[tx, 0, 0:16]]) + T.writes([A_shared[tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, 0, j]]) + T.writes([A_shared[tx, 0, j]]) + A_shared[tx, 0, j] = A[tx, 0, j] + with T.block(): + T.reads([A_shared[tx, 0, 0:16]]) + T.writes([A_local[0, 0, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A_shared[tx, 0, j]]) + T.writes([A_local[0, 0, 0, j]]) + T.block_attr({"double_buffer_scope": 0}) + A_local[0, 0, 0, j] = A_shared[tx, 0, j] + with T.block(): + T.reads([A_local[0, tx, 0, 0]]) + T.writes([B[0, tx, 0, 0]]) + B[0, tx, 0, 0] = A_local[0, 0, 0, 0] * T.float32(2) + with T.block(): + T.reads( + [ + A[tx, 1:16, 0:16], + A_local[0:2, tx, 0:16, 0:16], + B[0:2, tx, 0:15, 0], + A_shared[tx, 0, 0:16], + ] + ) + T.writes( + [ + A_shared[tx, 0, 0:16], + B[0:2, tx, 0:16, 0], + C[tx, 0:15, 0:16], + A_local[0:2, 0, 0, 0:16], + ] + ) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1, 0:16]]) + T.writes([A_shared[tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, i + 1, j]]) + T.writes([A_shared[tx, 0, 
j]]) + A_shared[tx, 0, j] = A[tx, i + 1, j] + with T.block(): + T.reads([A_local[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) + T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_local[i % 2, tx, i, j + 1]]) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32( + 2 + ) + with T.block(): + T.reads([B[j % 2, tx, i, 0]]) + T.writes([C[tx, i, j]]) + C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[tx, 0, 0:16]]) + T.writes([A_local[(i + 1) % 2, 0, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A_shared[tx, 0, j]]) + T.writes([A_local[(i + 1) % 2, 0, 0, j]]) + T.block_attr({"double_buffer_scope": 0}) + A_local[(i + 1) % 2, 0, 0, j] = A_shared[tx, i + 1, j] Review comment: Sounds great! Let’s work together to make our doc really slick :-) -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
