diff --git a/cirkit/backend/torch/layers/optimized.py b/cirkit/backend/torch/layers/optimized.py
index b823b8e9..34d2ed08 100644
--- a/cirkit/backend/torch/layers/optimized.py
+++ b/cirkit/backend/torch/layers/optimized.py
@@ -196,9 +196,11 @@ class TorchTensorDotLayer(TorchInnerLayer):
     and $K_i$ is the number of input uits. The tensor dot layer firstly reshapes as the tensor
     $\mathcal{Z}$ having shape $(B, K_j, K_q)$, where $K_i = K_jK_q$. Then, it computes the tensor
     $\mathcal{S}$ of shape $(B, K_q, K_k)$ as follows:
+
     $$
     \mathcal{S}_{bqk} = \sum_{j=1}^{K_j} w_{kj} z_{bjq}
     $$
+
     in element-wise notation, where $\mathbf{W}$ is a tensor of shape $(K_k, K_j)$, where we
     have that $K_o = K_qK_k$ is the number of output units. Finally, it returns the output
     tensor of shape $(B, K_o)$ obtained by flattening the
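
For reference while reviewing the docstring, the contraction it describes can be sketched in plain Python. This is not the layer's actual implementation: the function name `tensordot_layer` and its arguments are hypothetical, and both the reshape of the input into $(K_j, K_q)$ and the final flattening of $(K_q, K_k)$ are assumed to be row-major, since the hunk does not show that part of the docstring.

```python
def tensordot_layer(x, w, k_q):
    """Loop-based sketch of S[b][q][k] = sum_j w[k][j] * z[b][j][q].

    x: list of B rows, each of length K_i = K_j * K_q
    w: list of K_k rows, each of length K_j
    Returns a list of B rows, each of length K_o = K_q * K_k.
    """
    k_k, k_j = len(w), len(w[0])
    out = []
    for row in x:
        if len(row) != k_j * k_q:
            raise ValueError("each input row must have K_j * K_q entries")
        # Assumption: row-major reshape, so z[j][q] = row[j * K_q + q]
        z = [row[j * k_q:(j + 1) * k_q] for j in range(k_j)]
        # s[q][k] = sum_j w[k][j] * z[j][q], matching the docstring's formula
        s = [
            [sum(w[k][j] * z[j][q] for j in range(k_j)) for k in range(k_k)]
            for q in range(k_q)
        ]
        # Assumption: row-major flattening of (K_q, K_k) into K_o entries
        out.append([value for s_q in s for value in s_q])
    return out


# B = 1, K_j = 3, K_q = 2, K_k = 2
x = [[1, 2, 3, 4, 5, 6]]
w = [[1, 0, 0], [1, 1, 1]]
result = tensordot_layer(x, w, k_q=2)
```

With tensors, the same index pattern would correspond to an einsum with subscripts `kj,bjq->bqk`, followed by a reshape to $(B, K_o)$.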