@@ -2316,3 +2315,37 @@ def backward(self, dy):
def max(a, b):
    """Apply the ``Max`` operation to ``a`` and ``b`` and return its first output.

    Note: this name shadows the builtin ``max`` wherever it is in scope.
    """
    outputs = Max()(a, b)
    return outputs[0]

+class GEMM(Operation):
+    def __init__(self, alpha, beta, transA, transB):
+        """Store the GEMM hyper-parameters.
+
+        Args:
+            alpha: scale applied to the product op(A) @ op(B).
+            beta: scale applied to the input C.
+            transA: if truthy, forward transposes A before the product.
+            transB: if truthy, forward transposes B before the product.
+        """
+        self.alpha, self.beta = alpha, beta
+        self.transA, self.transB = transA, transB
+        super(GEMM, self).__init__()
+
+    def forward(self, A, B, C):
+        """Compute alpha * op(A) @ op(B) + beta * C without modifying C.
+
+        op(X) is X, or X transposed when transA / transB is set. When
+        training, op(A) and op(B) are kept on the operation for backward.
+
+        Returns:
+            A newly allocated tensor holding the result; the input C is
+            left untouched.
+        """
+        if self.transA:
+            A = singa.DefaultTranspose(A)
+        if self.transB:
+            B = singa.DefaultTranspose(B)
+        if training:
+            # Saved after transposition, so backward sees op(A) / op(B).
+            self.A = A
+            self.B = B
+
+        # MultWithScale writes its result into its last argument. Passing C
+        # itself would overwrite the caller's tensor (e.g. a bias parameter)
+        # and make the output alias an input, so accumulate into a fresh
+        # beta * C and add alpha * op(A) @ op(B) to it with scale 1.0.
+        Y = singa.MultFloat(C, self.beta)
+        singa.MultWithScale(self.alpha, A, B, 1.0, Y)
+        return Y
+
+    def backward(self, dY):
+        dC = singa.MultFloat(dY, self.beta)

Review comment:
Is the backward pass the same whether transA and transB are 1 or 0?
Or does it need to transpose the gradients at the end, something like this?
```python
def backward(self, dY):
    """Return (dA, dB, dC), the gradients of GEMM w.r.t. A, B and C.

    With Y = alpha * op(A) @ op(B) + beta * C, where op(X) is X or X^T:
        dL/d op(A) = alpha * dY @ op(B)^T
        dL/d op(B) = alpha * op(A)^T @ dY
        dL/dC      = beta * dY
    self.A and self.B hold op(A) and op(B), because forward applies the
    transposes before saving them. A gradient taken w.r.t. op(X) is
    therefore transposed once more when X itself was transposed.
    """
    grad_c = singa.MultFloat(dY, self.beta)

    op_b_t = singa.DefaultTranspose(self.B)
    op_a_t = singa.DefaultTranspose(self.A)
    grad_a = singa.MultFloat(singa.Mult(dY, op_b_t), self.alpha)
    grad_b = singa.MultFloat(singa.Mult(op_a_t, dY), self.alpha)

    # op(X) = X^T when transX is set, so dL/dX = (dL/d op(X))^T.
    if self.transA:
        grad_a = singa.DefaultTranspose(grad_a)
    if self.transB:
        grad_b = singa.DefaultTranspose(grad_b)

    # Drop the operands saved by forward now that they have been consumed.
    del self.A, self.B
    return grad_a, grad_b, grad_c
```

