This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8729f6b  [MetaSchedule] Update scripts for subgraph tuning (#10501)
8729f6b is described below

commit 8729f6b67a1e1d286b536162481014a3bca7633e
Author: Junru Shao <[email protected]>
AuthorDate: Sat Mar 5 23:26:40 2022 -0800

    [MetaSchedule] Update scripts for subgraph tuning (#10501)
---
 python/tvm/auto_scheduler/workload_registry.py     |  10 +-
 python/tvm/meta_schedule/runner/config.py          |   2 +-
 python/tvm/meta_schedule/runner/rpc_runner.py      |   9 +-
 .../testing/run_subgraph_auto_scheduler.py         | 137 +++++++++++++++++++++
 .../testing/run_subgraph_meta_schedule.py          | 120 ++++++++++++++++++
 src/meta_schedule/postproc/verify_gpu_code.cc      |   3 +-
 src/target/tag.cc                                  |  29 ++++-
 src/target/target_kind.cc                          |   6 +-
 .../test_meta_schedule_postproc_verify_gpu_code.py |   3 +-
 tests/python/unittest/test_target_target.py        |  18 +--
 10 files changed, 310 insertions(+), 27 deletions(-)

diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py
index 885eb0d..17d2001 100644
--- a/python/tvm/auto_scheduler/workload_registry.py
+++ b/python/tvm/auto_scheduler/workload_registry.py
@@ -30,13 +30,14 @@ These strings are efficient for serialization/matching and won't be too long.
 When we need the dag, we decode the string and call the function, which will return the dag.
 """
 
+import json
 import logging
 import pickle
-import json
 
 import tvm._ffi
 from tvm.runtime._ffi_node_api import LoadJSON, SaveJSON
-from .utils import serialize_args, deserialize_args, get_func_name
+
+from .utils import deserialize_args, get_func_name, serialize_args
 
 logger = logging.getLogger("auto_scheduler")
 
@@ -194,7 +195,10 @@ def workload_key_to_tensors(workload_key):
     assert callable(value)
 
     args = deserialize_args(workload[1:])
-    return value(*args)
+    result = value(*args)
+    if isinstance(result, tuple):
+        result = list(result)
+    return result
 
 
 def serialize_workload_registry_entry(workload_key):
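
For context, registered workload functions conventionally return a tuple of TE tensors, and workload_key_to_tensors now normalizes that to a list. A minimal sketch of the pattern (the matmul workload below is illustrative, not part of this commit):

    import tvm
    from tvm import auto_scheduler, te

    @auto_scheduler.register_workload
    def matmul(n):
        A = te.placeholder((n, n), name="A")
        B = te.placeholder((n, n), name="B")
        k = te.reduce_axis((0, n), name="k")
        C = te.compute((n, n), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
        return A, B, C  # returned as a tuple; decoding the workload key now yields [A, B, C]
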
diff --git a/python/tvm/meta_schedule/runner/config.py b/python/tvm/meta_schedule/runner/config.py
index 712766d..585b88e 100644
--- a/python/tvm/meta_schedule/runner/config.py
+++ b/python/tvm/meta_schedule/runner/config.py
@@ -45,7 +45,7 @@ class EvaluatorConfig(NamedTuple):
 
     number: int = 3
     repeat: int = 1
-    min_repeat_ms: int = 40
+    min_repeat_ms: int = 100
     enable_cpu_cache_flush: bool = False
 
     @staticmethod
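
The bumped default trades a little tuning time for steadier numbers: min_repeat_ms pads each timed repeat so the kernel runs for at least that long before a measurement is recorded. A sketch that restates the new defaults explicitly:

    from tvm.meta_schedule.runner import EvaluatorConfig

    config = EvaluatorConfig(
        number=3,            # kernel runs averaged within one repeat
        repeat=1,            # independent repeats per measurement
        min_repeat_ms=100,   # pad each repeat to at least 100 ms of execution
        enable_cpu_cache_flush=False,
    )
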
diff --git a/python/tvm/meta_schedule/runner/rpc_runner.py b/python/tvm/meta_schedule/runner/rpc_runner.py
index 66dec30..5697f85 100644
--- a/python/tvm/meta_schedule/runner/rpc_runner.py
+++ b/python/tvm/meta_schedule/runner/rpc_runner.py
@@ -16,9 +16,9 @@
 # under the License.
 """RPC Runner"""
 import concurrent.futures
-from contextlib import contextmanager
 import logging
 import os.path as osp
+from contextlib import contextmanager
 from typing import Callable, List, Optional, Union
 
 from tvm.contrib.popen_pool import PopenPoolExecutor
@@ -31,15 +31,14 @@ from ..utils import (
     get_global_func_with_default_on_worker,
 )
 from .config import EvaluatorConfig, RPCConfig
-from .runner import PyRunner, RunnerFuture, PyRunnerFuture, RunnerInput, RunnerResult
+from .runner import PyRunner, PyRunnerFuture, RunnerFuture, RunnerInput, RunnerResult
 from .utils import (
-    T_ARGUMENT_LIST,
     T_ARG_INFO_JSON_OBJ_LIST,
+    T_ARGUMENT_LIST,
     alloc_argument_common,
     run_evaluator_common,
 )
 
-
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 
 
@@ -118,7 +117,7 @@ class RPCRunnerFuture(PyRunnerFuture):
     def result(self) -> RunnerResult:
         try:
             run_secs: List[float] = self.future.result()
-        except TimeoutError as exception:
+        except TimeoutError:
             return RunnerResult(
                 None,
                 error_msg=f"RPCRunner: Timeout, killed after 
{self.timeout_sec} seconds",
diff --git a/python/tvm/meta_schedule/testing/run_subgraph_auto_scheduler.py 
b/python/tvm/meta_schedule/testing/run_subgraph_auto_scheduler.py
new file mode 100644
index 0000000..b52f88a
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/run_subgraph_auto_scheduler.py
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import argparse
+import os
+
+import tvm
+from tvm import auto_scheduler
+from tvm.meta_schedule.runner import RPCConfig
+from tvm.meta_schedule.testing.te_workload import CONFIGS
+
+
+def _parse_args():
+    args = argparse.ArgumentParser()
+    args.add_argument(
+        "--workload",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--target",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--num-trials",
+        type=int,
+        required=True,
+    )
+    args.add_argument(
+        "--rpc-host",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--rpc-port",
+        type=int,
+        required=True,
+    )
+    args.add_argument(
+        "--rpc-key",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--log-dir",
+        type=str,
+        required=True,
+    )
+    parsed = args.parse_args()
+    parsed.target = tvm.target.Target(parsed.target)
+    parsed.rpc_workers = RPCConfig(
+        tracker_host=parsed.rpc_host,
+        tracker_port=parsed.rpc_port,
+        tracker_key=parsed.rpc_key,
+        session_timeout_sec=30,
+    ).count_num_servers(allow_missing=True)
+    return parsed
+
+
+ARGS = _parse_args()
+
+
+def main():
+    log_file = os.path.join(ARGS.log_dir, f"{ARGS.workload}.json")
+    workload_func, params = CONFIGS[ARGS.workload]
+    params = params[0]  # type: ignore
+    workload_func = auto_scheduler.register_workload(workload_func)
+
+    if ARGS.target.kind.name == "llvm":
+        hardware_params = auto_scheduler.HardwareParams(
+            num_cores=int(ARGS.target.attrs["num-cores"]),
+            target=ARGS.target,
+        )
+    elif ARGS.target.kind.name == "cuda":
+        hardware_params = auto_scheduler.HardwareParams(
+            num_cores=-1,
+            vector_unit_bytes=16,
+            cache_line_bytes=64,
+            max_shared_memory_per_block=int(ARGS.target.attrs["max_shared_memory_per_block"]),
+            max_threads_per_block=int(ARGS.target.attrs["max_threads_per_block"]),
+            max_vthread_extent=8,
+            warp_size=32,
+        )
+    else:
+        raise NotImplementedError(f"Unsupported target {ARGS.target}")
+    task = auto_scheduler.SearchTask(
+        func=workload_func,
+        args=params,
+        target=ARGS.target,
+        hardware_params=hardware_params,
+    )
+    runner = auto_scheduler.RPCRunner(
+        key=ARGS.rpc_key,
+        host=ARGS.rpc_host,
+        port=ARGS.rpc_port,
+        n_parallel=ARGS.rpc_workers,
+        number=3,
+        repeat=1,
+        min_repeat_ms=100,
+        enable_cpu_cache_flush=False,
+    )
+
+    # Inspect the computational graph
+    print("Computational DAG:")
+    print(task.compute_dag)
+    tune_option = auto_scheduler.TuningOptions(
+        num_measure_trials=ARGS.num_trials,
+        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+        verbose=2,
+        runner=runner,
+    )
+    print("Running AutoTuning:")
+    task.tune(tune_option)
+    print("History Best:")
+    print(task.print_best(log_file))
+    sch, args = task.apply_best(log_file)
+    print("Lowered TIR:")
+    print(tvm.lower(sch, args, simple_mode=True))
+
+
+if __name__ == "__main__":
+    main()
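
A hypothetical invocation of the new auto-scheduler script (the workload name, tracker address, key, and log directory below are placeholders for your own setup):

    python -m tvm.meta_schedule.testing.run_subgraph_auto_scheduler \
        --workload C1D \
        --target "nvidia/jetson-agx-xavier" \
        --num-trials 200 \
        --rpc-host 192.168.1.10 \
        --rpc-port 9190 \
        --rpc-key jetson-agx-xavier \
        --log-dir /tmp/tuning-logs
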
diff --git a/python/tvm/meta_schedule/testing/run_subgraph_meta_schedule.py b/python/tvm/meta_schedule/testing/run_subgraph_meta_schedule.py
new file mode 100644
index 0000000..d4166b1
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/run_subgraph_meta_schedule.py
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import argparse
+import logging
+from os import cpu_count
+from typing import Optional
+
+import tvm
+from tvm import meta_schedule as ms
+from tvm import tir
+from tvm.meta_schedule.testing.te_workload import create_te_workload
+
+
+def _parse_args():
+    args = argparse.ArgumentParser()
+    args.add_argument(
+        "--workload",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--target",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--num-trials",
+        type=int,
+        required=True,
+    )
+    args.add_argument(
+        "--work-dir",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--rpc-host",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--rpc-port",
+        type=int,
+        required=True,
+    )
+    args.add_argument(
+        "--rpc-key",
+        type=str,
+        required=True,
+    )
+    parsed = args.parse_args()
+    parsed.target = tvm.target.Target(parsed.target)
+    if parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu":
+        parsed.alloc_repeat = 3
+    else:
+        parsed.alloc_repeat = 1
+    parsed.rpc_config = ms.runner.RPCConfig(
+        tracker_host=parsed.rpc_host,
+        tracker_port=parsed.rpc_port,
+        tracker_key=parsed.rpc_key,
+        session_timeout_sec=30,
+    )
+    parsed.rpc_workers = parsed.rpc_config.count_num_servers(allow_missing=False)
+    return parsed
+
+
+logging.basicConfig()
+logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
+ARGS = _parse_args()
+
+
+def main():
+    runner = ms.runner.RPCRunner(
+        rpc_config=ARGS.rpc_config,
+        evaluator_config=ms.runner.EvaluatorConfig(
+            number=3,
+            repeat=1,
+            min_repeat_ms=100,
+            enable_cpu_cache_flush=False,
+        ),
+        alloc_repeat=ARGS.alloc_repeat,
+        max_workers=ARGS.rpc_workers,
+    )
+    sch: Optional[tir.Schedule] = ms.tune_tir(
+        mod=create_te_workload(ARGS.workload, 0),
+        target=ARGS.target,
+        config=ms.EvolutionarySearchConfig(
+            num_trials_per_iter=64,
+            num_trials_total=ARGS.num_trials,
+            init_min_unmeasured=50,
+        ),
+        runner=runner,  # type: ignore
+        task_name=ARGS.workload,
+        work_dir=ARGS.work_dir,
+        num_threads=cpu_count(),
+    )
+    if sch is None:
+        print("No valid schedule found!")
+    else:
+        print(sch.mod.script())
+        print(sch.trace)
+
+
+if __name__ == "__main__":
+    main()
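
And a matching sketch for the MetaSchedule counterpart (again with placeholder tracker settings; --work-dir is where tuning records land):

    python -m tvm.meta_schedule.testing.run_subgraph_meta_schedule \
        --workload C1D \
        --target "raspberry-pi/4b-aarch64" \
        --num-trials 200 \
        --work-dir /tmp/tune-c1d \
        --rpc-host 192.168.1.10 \
        --rpc-port 9190 \
        --rpc-key raspberry-pi-4b
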
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc
index 6b34f69..e2c71b7 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -104,8 +104,7 @@ class VerifyGPUCodeNode : public PostprocNode {
     ICHECK(context->target.defined());
     Target target = context->target.value();
     this->target_constraints_ = Map<String, PrimExpr>{
-        {"max_shared_memory_per_block", Extract(target, 
"shared_memory_per_block")},
-        {"max_local_memory_per_block", Extract(target, "registers_per_block")},
+        {"max_shared_memory_per_block", Extract(target, 
"max_shared_memory_per_block")},
         {"max_threads_per_block", Extract(target, "max_threads_per_block")},
         {"max_vthread", Integer(8)},
         {"max_vector_bytes", Integer(16)}};
diff --git a/src/target/tag.cc b/src/target/tag.cc
index a931a28..07a5a5f 100644
--- a/src/target/tag.cc
+++ b/src/target/tag.cc
@@ -70,14 +70,38 @@ Target TargetTag::AddTag(String name, Map<String, ObjectRef> config, bool overri
 
 /**********  Register Target tags  **********/
 
+TVM_REGISTER_TARGET_TAG("raspberry-pi/4b-aarch64")
+    .set_config({{"kind", String("llvm")},
+                 {"mtriple", String("aarch64-linux-gnu")},
+                 {"mcpu", String("cortex-a72")},
+                 {"mattr", Array<String>{"+neon"}},
+                 {"num-cores", Integer(4)},
+                 {"host", Map<String, ObjectRef>{{"kind", String("llvm")},
+                                                 {"mtriple", String("aarch64-linux-gnu")},
+                                                 {"mcpu", String("cortex-a72")},
+                                                 {"mattr", Array<String>{"+neon"}},
+                                                 {"num-cores", Integer(4)}}}});
+
+TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier")
+    .set_config({{"kind", String("cuda")},
+                 {"arch", String("sm_72")},
+                 {"max_shared_memory_per_block", Integer(49152)},
+                 {"max_threads_per_block", Integer(1024)},
+                 {"thread_warp_size", Integer(32)},
+                 {"registers_per_block", Integer(65536)},
+                 {"host", Map<String, ObjectRef>{{"kind", String("llvm")},
+                                                 {"mtriple", String("aarch64-linux-gnu")},
+                                                 {"mcpu", String("carmel")},
+                                                 {"num-cores", Integer(4)}}}});
+
 #define TVM_REGISTER_CUDA_TAG(Name, Arch, SharedMem, RegPerBlock) \
   TVM_REGISTER_TARGET_TAG(Name).set_config({                      \
       {"kind", String("cuda")},                                   \
       {"arch", String(Arch)},                                     \
-      {"shared_memory_per_block", Integer(SharedMem)},            \
-      {"registers_per_block", Integer(RegPerBlock)},              \
+      {"max_shared_memory_per_block", Integer(SharedMem)},        \
       {"max_threads_per_block", Integer(1024)},                   \
       {"thread_warp_size", Integer(32)},                          \
+      {"registers_per_block", Integer(RegPerBlock)},              \
   });
 
 TVM_REGISTER_CUDA_TAG("nvidia/tesla-k80", "sm_37", 49152, 65536);
@@ -318,7 +342,6 @@ TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-415m", "sm_21", 
49152, 32768);
 TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-480m", "sm_20", 49152, 32768);
 TVM_REGISTER_CUDA_TAG("nvidia/geforce-710m", "sm_21", 49152, 32768);
 TVM_REGISTER_CUDA_TAG("nvidia/geforce-410m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/jetson-agx-xavier", "sm_72", 49152, 65536);
 TVM_REGISTER_CUDA_TAG("nvidia/jetson-nano", "sm_53", 49152, 32768);
 TVM_REGISTER_CUDA_TAG("nvidia/jetson-tx2", "sm_62", 49152, 32768);
 TVM_REGISTER_CUDA_TAG("nvidia/jetson-tx1", "sm_53", 49152, 32768);
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index c562c78..1131e6e 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -286,11 +286,11 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
     .add_attr_option<String>("mcpu")
     .add_attr_option<String>("arch")
     .add_attr_option<Bool>("system-lib")
-    .add_attr_option<Integer>("max_num_threads", Integer(1024))
+    .add_attr_option<Integer>("max_shared_memory_per_block")
+    .add_attr_option<Integer>("max_threads_per_block")
     .add_attr_option<Integer>("thread_warp_size", Integer(32))
-    .add_attr_option<Integer>("shared_memory_per_block")
     .add_attr_option<Integer>("registers_per_block")
-    .add_attr_option<Integer>("max_threads_per_block")
+    .add_attr_option<Integer>("max_num_threads", Integer(1024))  // 
TODO(@zxybazh): deprecate it
     .set_default_keys({"cuda", "gpu"})
     .set_attrs_preprocessor(UpdateCUDAAttrs);
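
With the reordered options, a CUDA target can carry the renamed limits as first-class attributes. A quick sketch using the dict form of target construction (values here are illustrative):

    import tvm

    tgt = tvm.target.Target(
        {
            "kind": "cuda",
            "arch": "sm_72",
            "max_shared_memory_per_block": 49152,
            "max_threads_per_block": 1024,
        }
    )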
 
diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
index 2c37731..db302f4 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
@@ -17,6 +17,7 @@
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 
 import sys
+
 import pytest
 import tvm
 from tvm import tir
@@ -380,7 +381,7 @@ def test_postproc_verify_gpu_1():
     mod = Conv2dCuda1
     ctx = _create_context(mod, target=_target())
     sch = tir.Schedule(mod, debug_mask="all")
-    assert not ctx.postprocs[0].apply(sch)
+    assert ctx.postprocs[0].apply(sch)
 
 
 def test_postproc_verify_gpu_2():
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index 33f9a96..99cdb86 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -216,7 +216,7 @@ def test_target_tag_0():
     tgt = tvm.target.Target("nvidia/geforce-rtx-2080-ti")
     assert tgt.kind.name == "cuda"
     assert tgt.attrs["arch"] == "sm_75"
-    assert tgt.attrs["shared_memory_per_block"] == 49152
+    assert tgt.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.attrs["max_threads_per_block"] == 1024
     assert tgt.attrs["thread_warp_size"] == 32
     assert tgt.attrs["registers_per_block"] == 65536
@@ -226,7 +226,7 @@ def test_target_tag_1():
     tgt = tvm.target.Target("nvidia/jetson-nano")
     assert tgt.kind.name == "cuda"
     assert tgt.attrs["arch"] == "sm_53"
-    assert tgt.attrs["shared_memory_per_block"] == 49152
+    assert tgt.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.attrs["max_threads_per_block"] == 1024
     assert tgt.attrs["thread_warp_size"] == 32
     assert tgt.attrs["registers_per_block"] == 32768
@@ -243,13 +243,13 @@ def test_target_host_tags():
     tgt = tvm.target.Target("nvidia/jetson-nano", "nvidia/geforce-rtx-2080-ti")
     assert tgt.kind.name == "cuda"
     assert tgt.attrs["arch"] == "sm_53"
-    assert tgt.attrs["shared_memory_per_block"] == 49152
+    assert tgt.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.attrs["max_threads_per_block"] == 1024
     assert tgt.attrs["thread_warp_size"] == 32
     assert tgt.attrs["registers_per_block"] == 32768
     assert tgt.host.kind.name == "cuda"
     assert tgt.host.attrs["arch"] == "sm_75"
-    assert tgt.host.attrs["shared_memory_per_block"] == 49152
+    assert tgt.host.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.host.attrs["max_threads_per_block"] == 1024
     assert tgt.host.attrs["thread_warp_size"] == 32
     assert tgt.host.attrs["registers_per_block"] == 65536
@@ -259,7 +259,7 @@ def test_target_host_tag_dict():
     tgt = tvm.target.Target("nvidia/jetson-nano", {"kind": "llvm"})
     assert tgt.kind.name == "cuda"
     assert tgt.attrs["arch"] == "sm_53"
-    assert tgt.attrs["shared_memory_per_block"] == 49152
+    assert tgt.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.attrs["max_threads_per_block"] == 1024
     assert tgt.attrs["thread_warp_size"] == 32
     assert tgt.attrs["registers_per_block"] == 32768
@@ -271,7 +271,7 @@ def test_target_host_single_dict():
     assert tgt.kind.name == "llvm"
     assert tgt.host.kind.name == "cuda"
     assert tgt.host.attrs["arch"] == "sm_53"
-    assert tgt.host.attrs["shared_memory_per_block"] == 49152
+    assert tgt.host.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.host.attrs["max_threads_per_block"] == 1024
     assert tgt.host.attrs["thread_warp_size"] == 32
     assert tgt.host.attrs["registers_per_block"] == 32768
@@ -288,7 +288,7 @@ def test_target_host_single_string_with_tag():
     assert tgt.kind.name == "cuda"
     assert tgt.host.kind.name == "cuda"
     assert tgt.host.attrs["arch"] == "sm_53"
-    assert tgt.host.attrs["shared_memory_per_block"] == 49152
+    assert tgt.host.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.host.attrs["max_threads_per_block"] == 1024
     assert tgt.host.attrs["thread_warp_size"] == 32
     assert tgt.host.attrs["registers_per_block"] == 32768
@@ -299,7 +299,7 @@ def test_target_host_merge_0():
     assert tgt.kind.name == "cuda"
     assert tgt.host.kind.name == "cuda"
     assert tgt.host.attrs["arch"] == "sm_53"
-    assert tgt.host.attrs["shared_memory_per_block"] == 49152
+    assert tgt.host.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.host.attrs["max_threads_per_block"] == 1024
     assert tgt.host.attrs["thread_warp_size"] == 32
     assert tgt.host.attrs["registers_per_block"] == 32768
@@ -346,7 +346,7 @@ def test_target_with_host():
     tgt = tgt.with_host(cuda_host)
     assert tgt.host.kind.name == "cuda"
     assert tgt.host.attrs["arch"] == "sm_53"
-    assert tgt.host.attrs["shared_memory_per_block"] == 49152
+    assert tgt.host.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.host.attrs["max_threads_per_block"] == 1024
     assert tgt.host.attrs["thread_warp_size"] == 32
     assert tgt.host.attrs["registers_per_block"] == 32768
