merrymercy commented on a change in pull request #6568:
URL: https://github.com/apache/incubator-tvm/pull/6568#discussion_r495524624



##########
File path: tests/python/unittest/test_auto_scheduler_evolutionary_search.py
##########
@@ -22,56 +22,102 @@
 from tvm.auto_scheduler.cost_model.cost_model import PythonBasedModel
 
 
-class MockCostModel(PythonBasedModel):
-    """A mock cost model that rates 1 only for the states with tile_k=2."""
+def test_mutate_tile_size():
+    """
+    The test case initializes evo search with a batch of "bad" states and check whether
+    the search algorithm can find "good" states by mutating the "bad" states.
+
+    This unit test has been tested with 1,000 runs with no failures, meaning that
+    the failure rate is less than 0.1%.
+    """
 
-    def predict(self, task, states):
-        scores = []
-        found = False
-        for state in states:
+    class MockCostModel(PythonBasedModel):
+        """A mock cost model that rates 1 only for the states with tile_k=2."""
+
+        @staticmethod
+        def is_good_state(state):
             for line in str(state).split("\n"):
                 if line.find("k.1") != -1 and line.find("(0,2)") != -1:
-                    found = True
-                    break
-            scores.append(1 if found else 0)
-        return scores
+                    return True
+            return False
 
+        def predict(self, task, states):
+            scores = []
+            found = False
+            for state in states:
+                scores.append(1 if self.is_good_state(state) else 0)
+            return scores
 
-def test_evo_search():
-    """Test evolutionary search. Since we cannot mock random number generator,
-    we mocked the cost model to manually guide the evo search. If evo search works
-    as expected, it should find the target state after a sufficient number of iterations.
-    This unit test has been tested with 1,000 runs with no failures, meaning that
-    the failure rate is less than 0.1%.
-    """
     workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (10, 10, 4))
     dag = auto_scheduler.ComputeDAG(workload_key)
     task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.Target("llvm"))
     policy = auto_scheduler.SketchPolicy(task, program_cost_model=MockCostModel(), verbose=0)
     states = policy.sample_initial_population(50)
-    pruned_states = []
+
+    bad_states = []
     for state in states:
-        found = False
-        for line in str(state).split("\n"):
-            # Remove all tile_k=2 states and expect evo search will fine them.
-            if line.find("k.1") != -1 and line.find("(0,2)") != -1:
-                found = True
-                break
-        if not found:
-            pruned_states.append(state)
+        if not MockCostModel.is_good_state(state):
+            bad_states.append(state)
 
-    new_states = policy.evolutionary_search(pruned_states, 50)
+    new_states = policy.evolutionary_search(bad_states, 50)
     found = False
     for state in new_states:
-        for line in str(state).split("\n"):
-            # Check if evo search found at least one state with tile_k=2.
-            if line.find("k.1") != -1 and line.find("(0,2)") != -1:
+        if MockCostModel.is_good_state(state):
+            found = True
+            break
+    assert found
+
+
+def test_mutate_parallel():
+    """
+    The test case initializes evo search with a batch of "bad" states and check whether
+    the search algorithm can find "good" states by mutating the "bad" states.
+
+    This unit test has been tested with 1,000 runs with no failures, meaning that
+    the failure rate is less than 0.1%.
+    """
+
+    class MockCostModel(PythonBasedModel):
+        @staticmethod
+        def is_good_state(state):
+            for line in str(state).split("\n"):
+                if (
+                    line.find("parallel i.0@ (0") != -1
+                    or line.find("parallel i.0@j.0@ (0") != -1
+                    or line.find("parallel i.0@j.0@i.1@ (0") != -1
+                ):
+                    return True
+            return False
+
+        def predict(self, task, states):
+            scores = []
+            found = False
+            for state in states:
+                scores.append(1 if self.is_good_state(state) else 0)
+            return scores
+
+    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (1024, 1024, 1024))
+    dag = auto_scheduler.ComputeDAG(workload_key)
+    task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.Target("llvm"))
+    policy = auto_scheduler.SketchPolicy(task, program_cost_model=MockCostModel(), verbose=0)
+    states = policy.sample_initial_population(100)
+
+    bad_states = []
+    for state in states:
+        if not MockCostModel.is_good_state(state):
+            bad_states.append(state)
+
+    found = False
+    retry_ct = 0

Review comment:
       Good catch. Fixed.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to