This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 54a62c1b53 [Fix][TIR] SampleCategorical apply-to-schedule (#14133)
54a62c1b53 is described below

commit 54a62c1b5381a3ab630b143648c6a52b2e761027
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Feb 26 23:17:44 2023 -0500

    [Fix][TIR] SampleCategorical apply-to-schedule (#14133)
    
    This PR is another way to fix the issue described in #14118.
    
    Since there is no standard for how JSON files format floating-point
    numbers (for example, we cannot require a JSON producer to print
    integer-valued floats with at least one decimal place), and the JSON
    parser is not responsible for deciding whether an integer in a JSON
    file should be parsed as a float or an int, the most convenient way to
    fix the SampleCategorical issue is to accept both FloatImms and IntImms
    as input and convert every IntImm to a FloatImm.
    
    This PR fixes the issue in this way.
---
 src/tir/schedule/primitive/sampling.cc           | 17 +++++++++++++--
 tests/python/unittest/test_tir_schedule_trace.py | 27 ++++++++++++++++++++++--
 2 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc
index ec12b045d3..e84e171811 100644
--- a/src/tir/schedule/primitive/sampling.cc
+++ b/src/tir/schedule/primitive/sampling.cc
@@ -391,9 +391,22 @@ struct SampleCategoricalTraits : public UnpackedInstTraits<SampleCategoricalTrai
 
   static ExprRV UnpackedApplyToSchedule(Schedule sch,               //
                                         Array<Integer> candidates,  //
-                                        Array<FloatImm> probs,      //
+                                        Array<ObjectRef> probs,     //
                                         Optional<Integer> decision) {
-    return sch->SampleCategorical(candidates, probs, decision);
+    Array<FloatImm> probs_float = probs.Map([](const ObjectRef& prob) {
+      const auto* prob_float = prob.as<FloatImmNode>();
+      if (prob_float != nullptr) {
+        return GetRef<FloatImm>(prob_float);
+      }
+      const auto* prob_int = prob.as<IntImmNode>();
+      if (prob_int != nullptr) {
+        return FloatImm(DataType::Float(32), static_cast<double>(prob_int->value));
+      }
+      LOG(FATAL)
+          << "SampleCategorical does not accept probability with type other than float or int.";
+      throw;
+    });
+    return sch->SampleCategorical(candidates, probs_float, decision);
   }
 
   static String UnpackedAsPython(Array<String> outputs,      //
diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py
index 916db184e0..a87fd4ed5b 100644
--- a/tests/python/unittest/test_tir_schedule_trace.py
+++ b/tests/python/unittest/test_tir_schedule_trace.py
@@ -316,6 +316,30 @@ def test_apply_json_to_schedule_1():
     tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"])
 
 
+def test_apply_json_to_schedule_sample_categorical():
+    var = tir.Var("v", "int32")
+    trace1 = Trace(
+        insts=[
+            Instruction(
+                kind=InstructionKind.get("SampleCategorical"),
+                inputs=[],
+                attrs=[[tvm.tir.IntImm("int32", 3)], [tvm.tir.FloatImm("float32", 1.0)]],
+                outputs=[var],
+            )
+        ],
+        decisions={},
+    )
+    json = trace1.as_json()
+    assert json == [[["SampleCategorical", [], [[3], [1]], ["v0"]]], []]
+
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    # As long as the application does not fail, it is fine.
+    Trace.apply_json_to_schedule(json, sch)
+    python_str = sch.trace.as_python()
+    assert len(python_str) == 1
+    assert python_str[0] == "v0 = sch.sample_categorical(candidates=[3], probs=[1], decision=0)"
+
+
 def _test_apply_annotation_trace_from_json(annotation: str):
     """Test applying an annotation works without crashing.
 
@@ -367,5 +391,4 @@ def test_apply_annotation_from_json():
 
 
 if __name__ == "__main__":
-    test_trace_simplified_2()
-    # tvm.testing.main()
+    tvm.testing.main()

Reply via email to