diff --git a/src/math/matrix.c b/src/math/matrix.c index 7ff418178650..e1a160b51516 100644 --- a/src/math/matrix.c +++ b/src/math/matrix.c @@ -24,58 +24,41 @@ * -EINVAL if input dimensions do not allow for multiplication. * -ERANGE if the shift operation might cause integer overflow. */ -int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c) +int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, + struct mat_matrix_16b *c) { /* Validate matrix dimensions are compatible for multiplication */ if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns) return -EINVAL; - int64_t s; - int16_t *x; - int16_t *y; - int16_t *z = c->data; - int i, j, k; - int y_inc = b->columns; - const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; - - + int32_t acc; /* Accumulator for dot product calculation */ + int16_t *x, *y, *z = c->data; /* Pointers for matrices a, b, and c */ + int i, j, k; /* Loop counters */ + int y_inc = b->columns; /* Column increment for matrix b elements */ + /* Calculate shift amount for adjusting fractional bits in the result */ + const int shift = a->fractions + b->fractions - c->fractions - 1; /* Check shift to ensure no integer overflow occurs during shifting */ if (shift < -1 || shift > 31) return -ERANGE; - /* If all data is Q0 */ - if (shift_minus_one == -1) { - for (i = 0; i < a->rows; i++) { - for (j = 0; j < b->columns; j++) { - s = 0; - x = a->data + a->columns * i; - y = b->data + j; - for (k = 0; k < b->rows; k++) { - s += (int32_t)(*x) * (*y); - x++; - y += y_inc; - } - *z = (int16_t)s; /* For Q16.0 */ - z++; - } - } - - return 0; - } - + /* Matrix multiplication loop */ for (i = 0; i < a->rows; i++) { for (j = 0; j < b->columns; j++) { - s = 0; - x = a->data + a->columns * i; - y = b->data + j; + acc = 0; /* Initialize accumulator for each element */ + x = a->data + a->columns * i; /* Set x at the start of ith row of a */ + y = b->data + j; /* Set y at the top of jth column of b */ + /* Dot product loop */ for (k = 0; k < b->rows; k++) { - s += (int32_t)(*x) * (*y); - x++; - y += y_inc; + acc += (int32_t)(*x++) * (*y); /* Multiply & accumulate */ + y += y_inc; /* Move to next row in the current column of b */ } - *z = (int16_t)(((s >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */ - z++; + /* Assign computed value to c matrix, adjusting for fractional bits */ + if (shift == -1) + *z = (int16_t)acc; + else + *z = (int16_t)(((acc >> shift) + 1) >> 1); + z++; /* Move to the next element in the output matrix */ } } return 0; @@ -98,15 +81,7 @@ int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_ */ int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c) -{ - /* Validate matrix dimensions and non-null pointers */ - if (!a || !b || !c || - a->columns != b->columns || a->rows != b->rows || - c->columns != a->columns || c->rows != a->rows) { - return -EINVAL; - } - - int64_t p; +{ int64_t p; int16_t *x = a->data; int16_t *y = b->data; int16_t *z = c->data;