interesaaat opened a new issue #6300:
URL: https://github.com/apache/incubator-tvm/issues/6300
I have the following PyTorch program:
```
tree_nodes = indexes
feature_nodes = torch.index_select(self.features, 0, tree_nodes).view(-1, self.num_trees)
feature_values = torch.gather(x, 1, feature_nodes)
thresholds = torch.index_select(self.thresholds, 0, indexes).view(-1, self.num_trees)
lefts = torch.index_select(self.lefts, 0, indexes).view(-1, self.num_trees)
rights = torch.index_select(self.rights, 0, indexes).view(-1, self.num_trees)
indexes = torch.where(torch.ge(feature_values, thresholds), rights, lefts).long()
indexes = indexes + self.nodes_offset
indexes = indexes.view(-1)
```
When I compile it with TVM, I get the following interesting trace:
```
%0 = (%v_operator_map.SklearnLGBMRegressor.nodes_offset,
%v_operator_map.SklearnLGBMRegressor.nodes_offset,
%v_operator_map.SklearnLGBMRegressor.nodes_offset);
%1 = concatenate(%0);
%2 = reshape(%1, newshape=[-1]);
%3 = take(%v_operator_map.SklearnLGBMRegressor.features, %2, axis=0);
%4 = reshape(%3, newshape=[-1, 3]);
%5 = gather(%input, %4, axis=1);
%6 = take(%v_operator_map.SklearnLGBMRegressor.thresholds, %2, axis=0);
%7 = reshape(%6, newshape=[-1, 3]);
%8 = greater_equal(%5, %7);
%9 = take(%v_operator_map.SklearnLGBMRegressor.rights, %2, axis=0);
%10 = reshape(%9, newshape=[-1, 3]);
%11 = take(%v_operator_map.SklearnLGBMRegressor.lefts, %2, axis=0);
%12 = reshape(%11, newshape=[-1, 3]);
%13 = where(%8, %10, %12);
%14 = cast(%13, dtype="float32");
%15 = add(%14, %v_operator_map.SklearnLGBMRegressor.nodes_offset)
an internal invariant was violated while typechecking your program
[16:15:28] /Users/mainterl/Develop/tvm/src/relay/op/type_relations.cc:107: Check failed: t0->dtype == t1->dtype (float32 vs. int64) :
; ;
```
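Apparently, at `%14` in the trace, the cast to long (`.long()` in the PyTorch
code) is translated into a cast to float32, so the following `add` with the
int64 `nodes_offset` fails the dtype check.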
To reproduce, you can pull
[this](https://github.com/microsoft/hummingbird/tree/mainterl/tvm) branch, run
`pip install -e .[extra]`, and run
[this](https://github.com/microsoft/hummingbird/blob/mainterl/tvm/tests/test_lightgbm_converter.py)
test file. I can try to generate a minimal running example if it helps.
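
As a starting point for such a minimal example, here is a rough standalone sketch of the traversal step on the PyTorch side only. The class name `TreeStep`, the tensor values, and the shapes are all hypothetical and not taken from the hummingbird branch; the only things carried over from the snippet above are the operations themselves and `num_trees = 3` (inferred from the `newshape=[-1, 3]` in the trace). It only shows that PyTorch itself keeps the result as int64.

```python
import torch


class TreeStep(torch.nn.Module):
    """Hypothetical stand-in for one traversal step of the model in this issue."""

    def __init__(self):
        super().__init__()
        self.num_trees = 3
        # Toy parameters (made up): 3 trees, 3 nodes each, stored as flat arrays.
        self.features = torch.tensor([0, 0, 0, 1, 1, 1, 0, 0, 0])  # int64
        self.thresholds = torch.tensor([0.5, 0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0, 0.0])
        self.lefts = torch.tensor([1, 1, 2, 1, 1, 2, 1, 1, 2])  # int64
        self.rights = torch.tensor([2, 1, 2, 2, 1, 2, 2, 1, 2])  # int64
        self.nodes_offset = torch.tensor([[0, 3, 6]])  # int64, shape (1, num_trees)

    def forward(self, x):
        # Start every sample at the root of every tree.
        indexes = self.nodes_offset.repeat(x.size(0), 1).view(-1)

        # Same operations as the snippet in the issue.
        feature_nodes = torch.index_select(self.features, 0, indexes).view(-1, self.num_trees)
        feature_values = torch.gather(x, 1, feature_nodes)
        thresholds = torch.index_select(self.thresholds, 0, indexes).view(-1, self.num_trees)
        lefts = torch.index_select(self.lefts, 0, indexes).view(-1, self.num_trees)
        rights = torch.index_select(self.rights, 0, indexes).view(-1, self.num_trees)
        indexes = torch.where(torch.ge(feature_values, thresholds), rights, lefts).long()
        indexes = indexes + self.nodes_offset
        return indexes.view(-1)


model = TreeStep()
x = torch.tensor([[0.0, 1.0], [1.0, 0.0]])  # 2 samples, 2 features
out = model(x)
print(out.dtype, out.tolist())
```

In PyTorch the output dtype is `torch.int64`, which is what the `add` with `nodes_offset` expects. Tracing this module with `torch.jit.trace` and passing it to `tvm.relay.frontend.from_pytorch` may be enough to trigger the same `float32 vs. int64` check, but that has not been verified here.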