Matrix-Vector & Matrix-Matrix Multiplication
To perform matrix-vector or matrix-matrix multiplication between two tensors, we can use the @ operator (or, equivalently, torch.matmul) in PyTorch. For example:
import torch

# Matrix-vector multiplication: (3, 4) @ (4,) -> (3,)
a = torch.randn(3, 4)
b = torch.randn(4)
print((a @ b).shape) # torch.Size([3])
print((torch.matmul(a, b)).shape) # torch.Size([3])

# Matrix-matrix multiplication: (3, 4) @ (4, 5) -> (3, 5)
a = torch.randn(3, 4)
b = torch.randn(4, 5)
print((a @ b).shape) # torch.Size([3, 5])
print((torch.matmul(a, b)).shape) # torch.Size([3, 5])
This also aligns with the mathematics. If we have two matrices $A \in \mathbb{R}^{m \times n}$ and $B \in \mathbb{R}^{n \times p}$, then the matrix-matrix product is $AB \in \mathbb{R}^{m \times p}$.
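Concretely, each entry of the product is the dot product of a row of $A$ with a column of $B$:

$$(AB)_{ij} = \sum_{k=1}^{n} A_{ik} B_{kj}, \qquad 1 \le i \le m, \quad 1 \le j \le p.$$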
However, when the participating tensors have more than two dimensions, it becomes less obvious how the multiplication should be performed.
Batched Matrix-Matrix Multiplication
Let's say we now work with 3D tensors $A \in \mathbb{R}^{b \times m \times n}$ and $B \in \mathbb{R}^{b \times n \times p}$. This is the batched version of the previous example: the batch dimension doesn't participate in the multiplication, and the result is a 3D tensor $AB \in \mathbb{R}^{b \times m \times p}$.
a = torch.randn(10, 3, 4)
b = torch.randn(10, 4, 5)
print((a @ b).shape) # torch.Size([10, 3, 5])
print((torch.bmm(a, b)).shape) # torch.Size([10, 3, 5])
print((torch.matmul(a, b)).shape) # torch.Size([10, 3, 5])
assert (a @ b == torch.bmm(a, b)).all() # True
assert (a @ b == torch.matmul(a, b)).all() # True
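To make the role of the batch dimension explicit, here is a minimal sketch showing that the batched product is the same as multiplying the matrix pairs one by one in a Python loop:

import torch

a = torch.randn(10, 3, 4)
b = torch.randn(10, 4, 5)

# Multiply the i-th matrix of a with the i-th matrix of b, one batch entry at a time.
looped = torch.stack([a[i] @ b[i] for i in range(a.shape[0])])

assert torch.allclose(looped, torch.bmm(a, b))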
torch.bmm is specifically for batched matrix-matrix multiplication and expects both input tensors to be 3D. @ and torch.matmul are more flexible: they can handle tensors with an arbitrary number of dimensions, but their behavior is also more confusing.
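For instance (a minimal sketch; the exact error message can vary across PyTorch versions), torch.bmm rejects plain 2D matrices, while torch.matmul handles them without complaint:

import torch

a = torch.randn(3, 4)
b = torch.randn(4, 5)

print(torch.matmul(a, b).shape)  # torch.Size([3, 5]); plain 2D inputs are fine

try:
    torch.bmm(a, b)  # bmm insists on 3D (batched) inputs
except RuntimeError as e:
    print("torch.bmm failed:", e)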
Higher Dimensional Matrix-Matrix Multiplication
When the participating tensors have more than two dimensions, only the last two dimensions participate in the multiplication; the leading dimensions are treated as batch dimensions.
a = torch.randn(64, 10, 3, 4)
b = torch.randn(64, 10, 4, 5)
print((a @ b).shape) # torch.Size([64, 10, 3, 5])
print((torch.matmul(a, b)).shape) # torch.Size([64, 10, 3, 5])
assert (a @ b == torch.matmul(a, b)).all() # True
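The leading batch dimensions do not even have to match exactly; torch.matmul broadcasts them following the usual broadcasting rules. A minimal sketch:

import torch

a = torch.randn(64, 10, 3, 4)
b = torch.randn(4, 5)  # no batch dimensions at all

# b is broadcast against the leading (64, 10) batch dimensions of a.
print(torch.matmul(a, b).shape)  # torch.Size([64, 10, 3, 5])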