Skip to content

Commit

Permalink
Math: Optimize 16-bit matrix multiplication function
Browse files Browse the repository at this point in the history
Performed multiple optimization in the 16-bit matrix
multiplication function.

- This check-in changed accumulator data type from
int64_t to int32_t , reducing the instruction cycle
count by ~8.18% gain for matrix multiplication.
- Enhanced pointer arithmetic within for loops
- Eliminated unnecessary conditionals by directly
handling Q0 data within algorithm core logic

These optimization yied a ~36.31% reduction in
memory usage for matrix multplication function

Signed-off-by: Shriram Shastry <[email protected]>
  • Loading branch information
ShriramShastry committed Aug 25, 2024
1 parent 8fe7d36 commit 65a34cd
Showing 1 changed file with 41 additions and 31 deletions.
72 changes: 41 additions & 31 deletions src/math/matrix.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,56 +25,66 @@
* -EINVAL if input dimensions do not allow for multiplication.
* -ERANGE if the shift operation might cause integer overflow.
*/
int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c)
int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b,
struct mat_matrix_16b *c)
{
/* Validate matrix dimensions are compatible for multiplication */
if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns)
return -EINVAL;

int64_t s;
int16_t *x;
int16_t *y;
int16_t *z = c->data;
int i, j, k;
int y_inc = b->columns;
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
int32_t acc; /* Accumulator for dot product calculation */
int16_t *x, *y, *z = c->data; /* Pointers for matrices a, b, and c */
int i, j, k; /* Loop counters */
int y_inc = b->columns; /* Column increment for matrix b elements */
/* Calculate shift amount for adjusting fractional bits in the result */
const int shift = a->fractions + b->fractions - c->fractions - 1;

/* Check shift to ensure no integer overflow occurs during shifting */
if (shift_minus_one < -1 || shift_minus_one > 31)
if (shift < -1 || shift > 31)
return -ERANGE;

/* If all data is Q0 */
if (shift_minus_one == -1) {
/* Matrix multiplication loop */
if (shift == -1) {
/* Special case when shift is -1 (Q0 data) */
for (i = 0; i < a->rows; i++) {
for (j = 0; j < b->columns; j++) {
s = 0;
/* Initialize accumulator for each element */
acc = 0;
/* Set x at the start of ith row of a */
x = a->data + a->columns * i;
/* Set y at the top of jth column of b */
y = b->data + j;
/* Dot product loop */
for (k = 0; k < b->rows; k++) {
s += (int32_t)(*x) * (*y);
x++;
/* Multiply & accumulate */
acc += (int32_t)(*x++) * (*y);
/* Move to next row in the current column of b */
y += y_inc;
}
*z = (int16_t)s; /* For Q16.0 */
z++;
*z = (int16_t)acc;
z++; /* Move to the next element in the output matrix */
}
}

return 0;
}

for (i = 0; i < a->rows; i++) {
for (j = 0; j < b->columns; j++) {
s = 0;
x = a->data + a->columns * i;
y = b->data + j;
for (k = 0; k < b->rows; k++) {
s += (int32_t)(*x) * (*y);
x++;
y += y_inc;
} else {
/* General case for other shift values */
for (i = 0; i < a->rows; i++) {
for (j = 0; j < b->columns; j++) {
/* Initialize accumulator for each element */
acc = 0;
/* Set x at the start of ith row of a */
x = a->data + a->columns * i;
/* Set y at the top of jth column of b */
y = b->data + j;
/* Dot product loop */
for (k = 0; k < b->rows; k++) {
/* Multiply & accumulate */
acc += (int32_t)(*x++) * (*y);
/* Move to next row in the current column of b */
y += y_inc;
}
*z = (int16_t)(((acc >> shift) + 1) >> 1);
z++; /* Move to the next element in the output matrix */
}
*z = (int16_t)(((s >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
z++;
}
}
return 0;
Expand Down

0 comments on commit 65a34cd

Please sign in to comment.