shingjan commented on code in PR #12485:
URL: https://github.com/apache/tvm/pull/12485#discussion_r959842525


##########
python/tvm/relay/frontend/pytorch.py:
##########
@@ -2247,6 +2247,47 @@ def embedding(self, inputs, input_types):
 
         return _op.take(weight, indices.astype("int32"), axis=0)
 
+    def embedding_bag(self, inputs, _):
+        assert len(inputs) == 9, "embedding_bag needs 9 arguments"
+        (
+            weights,
+            indices,
+            offsets_1d,
+            scale_grad_by_freq,
+            mode,
+            sparse,
+            per_sample_weights,
+            include_last_offset,
+            padding_idx,
+        ) = inputs
+
+        assert len(_infer_shape(indices)) == 1, "Expects 1D indices for aten::embedding_bag."
+
+        assert (

Review Comment:
   Actually, if the `scale_grad_by_freq` and `sparse` arguments don't affect 
the results of our embedding_bag implementation here, I'd expect us to add 
tests confirming that, and to raise a warning rather than an assertion when 
either argument is passed with a non-default value. Does that make sense?
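
   A rough sketch of the warning path, not the actual frontend code: the 
helper name `warn_on_ignored_args` and the message text are hypothetical, it 
uses the standard `warnings` module (the frontend may prefer its own logger), 
and it assumes both flags default to `False` and only influence gradients.

```python
import warnings


def warn_on_ignored_args(scale_grad_by_freq=False, sparse=False):
    """Hypothetical helper: warn, instead of asserting, on ignored flags.

    Assumption: both flags only change how gradients are computed, so a
    forward-only conversion can ignore them safely.
    """
    ignored = []
    if scale_grad_by_freq:
        ignored.append("scale_grad_by_freq")
    if sparse:
        ignored.append("sparse")

    if ignored:
        # One warning that lists every non-default flag being ignored.
        warnings.warn(
            "aten::embedding_bag: ignoring non-default "
            + ", ".join(ignored)
            + "; assumed not to change the forward result.",
            UserWarning,
        )
    return ignored
```

   A test could then convert the same module with default and non-default 
flags and check that the outputs match, which would back up the assumption 
stated in the warning.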



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.