The original definition of ALiBi can be found here.

`dynamic_batching.ALiBiMask` uses `seqstarts` and `kvstarts` to record the beginning position of each sequence in the batch. The mask is similar to a block diagonal matrix: we generate a sub-mask for each batch and fill it into the corresponding position. The following code shows the calculation process.
```python
import torch

alibi_mask = torch.zeros([seqstarts[-1], kvstarts[-1]], dtype=data_type)
seqlens = seqstarts[1:] - seqstarts[:-1]
kvlens = kvstarts[1:] - kvstarts[:-1]
for batch_idx, seqlen in enumerate(seqlens):
    kvlen = kvlens[batch_idx]
    seqbeg = seqstarts[batch_idx]
    seqend = seqstarts[batch_idx + 1]
    kvbeg = kvstarts[batch_idx]
    kvend = kvstarts[batch_idx + 1]
    # masked (future) positions stay at -inf so they vanish after softmax
    tmp_alibi_mask = torch.full((seqlen, kvlen), float('-inf'), dtype=data_type)
    # generate the mask for this batch; the process is the same as static batch ALiBi:
    # query i sits at position kvlen - seqlen + i of the key/value sequence, and every
    # visible key j gets the non-positive relative-position bias j - (kvlen - seqlen + i)
    for i in range(seqlen):
        for j in range(kvlen):
            if j <= kvlen - seqlen + i:
                tmp_alibi_mask[i][j] = j - (kvlen - seqlen + i)
    alibi_mask[seqbeg:seqend, kvbeg:kvend] = tmp_alibi_mask

alibi_mask = alibi_mask.unsqueeze(0).expand(num_heads, -1, -1)
# alibi_mask shape -> (num_heads, sum(seqlens), sum(kvlens))
# slopes_m shape   -> (num_heads, 1, 1)
alibi_mask = slopes_m * alibi_mask
```
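The code above assumes `slopes_m` is already available. As a minimal sketch (assuming the standard per-head slopes from the ALiBi paper, and `num_heads` being a power of two), it could be computed as:

```python
import torch

def alibi_slopes(num_heads: int) -> torch.Tensor:
    # Standard ALiBi slopes m_h = 2^(-8h / num_heads) for h = 1..num_heads.
    # The original paper uses an interpolation scheme when num_heads is not
    # a power of two; this sketch skips that case.
    exponents = torch.arange(1, num_heads + 1, dtype=torch.float32)
    return torch.pow(2.0, -8.0 * exponents / num_heads).reshape(num_heads, 1, 1)

slopes_m = alibi_slopes(8)  # shape (8, 1, 1), broadcastable against the mask
```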
`num_heads`: Number of heads.

`data_type`: Data type of the ALiBi mask.
`seqstarts`: `seqstarts[:B]` contains the position of the first token in `query` for each batch, and `seqstarts[B]` contains the total length of `query`. Note that `seqstarts[b+1] - seqstarts[b]` gives the sequence length of batch `b`. Shape: `(B + 1,)`.
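For illustration, `seqstarts` is just an exclusive prefix sum over the per-batch query lengths (hypothetical numbers, not part of the operator itself):

```python
import torch

seqlens = torch.tensor([2, 3, 4])  # hypothetical query lengths for B = 3 batches
seqstarts = torch.cat([torch.zeros(1, dtype=seqlens.dtype), seqlens.cumsum(0)])
# seqstarts -> tensor([0, 2, 5, 9]); seqstarts[b+1] - seqstarts[b] recovers seqlens[b]
```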
`kvstarts`: `kvstarts[:B]` contains the position of the first token in `key` and `value` for each batch, and `kvstarts[B]` contains the total length of `key` and `value`. Note that `kvstarts[b+1] - kvstarts[b]` gives the key/value length of batch `b`. Shape: `(B + 1,)`.
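In the decode phase, for example, each sequence typically contributes a single new query token while its key/value length also counts the cached tokens (again hypothetical numbers):

```python
import torch

seqstarts = torch.tensor([0, 1, 2, 3])  # B = 3, one new query token per sequence
kvstarts = torch.tensor([0, 3, 7, 12])  # cached lengths 2, 3, 4 plus one new token each
kvlens = kvstarts[1:] - kvstarts[:-1]   # tensor([3, 4, 5])
```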
Optional custom mask. `seqlens = seqstarts[1:] - seqstarts[:B]` is a sequence containing the length of `query` for each batch, and `kvlens = kvstarts[1:] - kvstarts[:B]` is a sequence containing the length of `key` and `value` for each batch.
Note: the last dim of the mask could be bigger than `kvstarts[B]`. Shape: `(seqstarts[B], kvstarts[B])` or `(num_heads, seqstarts[B], kvstarts[B])`.
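As a hypothetical illustration of that note, a caller might hand in a mask buffer whose last dimension is padded beyond `kvstarts[B]` (for example, a preallocated cache capacity); presumably only the first `kvstarts[B]` columns are consumed:

```python
import torch

# Totals from the hypothetical example above: seqstarts[B] = 3, kvstarts[B] = 12.
# The buffer's last dim is padded to 16; the trailing columns are expected to be unused.
custom_mask = torch.zeros(3, 16, dtype=torch.float32)
```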
Output mask of ALiBi. Shape: `(num_heads, seqstarts[B], kvstarts[B])`.
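Putting the pieces together, here is a self-contained, vectorized sketch of the same computation (an illustration of the reference code above, reusing the hypothetical `alibi_slopes` helper; it is not the library's actual kernel):

```python
import torch

def build_alibi_mask(seqstarts, kvstarts, slopes_m, data_type=torch.float32):
    mask = torch.full((int(seqstarts[-1]), int(kvstarts[-1])), float('-inf'), dtype=data_type)
    for b in range(len(seqstarts) - 1):
        seqlen = int(seqstarts[b + 1] - seqstarts[b])
        kvlen = int(kvstarts[b + 1] - kvstarts[b])
        i = torch.arange(seqlen).unsqueeze(1)            # (seqlen, 1)
        j = torch.arange(kvlen).unsqueeze(0)             # (1, kvlen)
        bias = (j - (kvlen - seqlen + i)).to(data_type)  # <= 0 on visible positions
        block = torch.where(bias <= 0, bias, torch.full_like(bias, float('-inf')))
        mask[seqstarts[b]:seqstarts[b + 1], kvstarts[b]:kvstarts[b + 1]] = block
    return slopes_m * mask.unsqueeze(0)  # (num_heads, seqstarts[B], kvstarts[B])

mask = build_alibi_mask(torch.tensor([0, 1, 2, 3]),
                        torch.tensor([0, 3, 7, 12]),
                        slopes_m=alibi_slopes(8))
print(mask.shape)  # torch.Size([8, 3, 12])
```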