diff --git a/src/math/matrix.c b/src/math/matrix.c index e4c4c1ecbfff..8b4d072d341a 100644 --- a/src/math/matrix.c +++ b/src/math/matrix.c @@ -3,93 +3,112 @@ // Copyright(c) 2022 Intel Corporation. All rights reserved. // // Author: Seppo Ingalsuo +// Shriram Shastry +// #include #include #include -int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c) +/* Performs matrix multiplication of two fixed-point 16-bit integer matrices, + * storing the result in a third matrix. It accounts for fractional bits for + * fixed-point arithmetic, adjusting the result accordingly. + * + * Arguments: + * a: pointer to the first input matrix + * b: pointer to the second input matrix + * c: pointer to the output matrix to store result + * + * Return: + * 0 on successful multiplication. + * -EINVAL if input dimensions do not allow for multiplication. + * -ERANGE if the shift operation might cause integer overflow. + */ +int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, + struct mat_matrix_16b *c) { - int64_t s; - int16_t *x; - int16_t *y; - int16_t *z = c->data; - int i, j, k; - int y_inc = b->columns; - const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; + int32_t acc; /* Accumulator for dot product calculation */ + int16_t *x, *y, *z = c->data; /* Pointers for matrices a, b, and c */ + int i, j, k; /* Loop counters */ + int y_inc = b->columns; /* Column increment for matrix b elements */ + /* Calculate shift amount for adjusting fractional bits in the result */ + const int shift = a->fractions + b->fractions - c->fractions - 1; + /* Validate matrix dimensions are compatible for multiplication */ if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns) return -EINVAL; - /* If all data is Q0 */ - if (shift_minus_one == -1) { - for (i = 0; i < a->rows; i++) { - for (j = 0; j < b->columns; j++) { - s = 0; - x = a->data + a->columns * i; - y = b->data + j; - for (k = 0; k < b->rows; k++) { - s += (int32_t)(*x) * (*y); - x++; - y += y_inc; - } - *z = (int16_t)s; /* For Q16.0 */ - z++; - } - } - - return 0; - } + /* Check shift to ensure no integer overflow occurs during shifting */ + if (shift < -1 || shift > 31) + return -ERANGE; + /* Matrix multiplication loop */ for (i = 0; i < a->rows; i++) { for (j = 0; j < b->columns; j++) { - s = 0; - x = a->data + a->columns * i; - y = b->data + j; + acc = 0; /* Initialize accumulator for each element */ + x = a->data + a->columns * i; /* Set x at the start of ith row of a */ + y = b->data + j; /* Set y at the top of jth column of b */ + /* Dot product loop */ for (k = 0; k < b->rows; k++) { - s += (int32_t)(*x) * (*y); - x++; - y += y_inc; + acc += (int32_t)(*x++) * (*y); /* Multiply & accumulate */ + y += y_inc; /* Move to next row in the current column of b */ } - *z = (int16_t)(((s >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */ - z++; + /* Assign computed value to c matrix, adjusting for fractional bits */ + if (shift == -1) + *z = (int16_t)acc; + else + *z = (int16_t)(((acc >> shift) + 1) >> 1); + z++; /* Move to the next element in the output matrix */ } } return 0; } +/* Description: Performs element-wise multiplication of two matrices with 16-bit integer elements + * and stores the result in a third matrix. Checks that all matrices have the same + * dimensions and adjusts for fractional bits appropriately. This operation handles + * the manipulation of fixed-point precision based on the fractional bits present in + * the matrices. + * + * Arguments: + * a - pointer to the first input matrix + * b - pointer to the second input matrix + * c - pointer to the output matrix where the result will be stored + * + * Returns: + * 0 on successful multiplication, + * -EINVAL if input pointers are NULL or matrix dimensions do not match. + */ int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c) -{ int64_t p; +{ int16_t *x = a->data; int16_t *y = b->data; int16_t *z = c->data; - int i; - const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; + int32_t prod; - if (a->columns != b->columns || b->columns != c->columns || - a->rows != b->rows || b->rows != c->rows) { + /* Validate matrix dimensions and non-null pointers */ + if (!a || !b || !c || a->columns != b->columns || a->rows != b->rows) return -EINVAL; - } - /* If all data is Q0 */ - if (shift_minus_one == -1) { - for (i = 0; i < a->rows * a->columns; i++) { + /* Compute the total number of elements in the matrices */ + const int total_elements = a->rows * a->columns; + /* Compute the required bit shift based on the fractional part of each matrix */ + const int shift = a->fractions + b->fractions - c->fractions - 1; + + /* Perform multiplication with or without adjusting the fractional bits */ + if (shift == -1) { + /* Direct multiplication when no adjustment for fractional bits is needed */ + for (int i = 0; i < total_elements; i++, x++, y++, z++) *z = *x * *y; - x++; - y++; - z++; + } else { + /* Multiplication with rounding to account for the fractional bits */ + for (int i = 0; i < total_elements; i++, x++, y++, z++) { + /* Multiply elements as int32_t */ + prod = (int32_t)(*x) * (*y); + /* Adjust and round the result */ + *z = (int16_t)(((prod >> shift) + 1) >> 1); } - - return 0; - } - - for (i = 0; i < a->rows * a->columns; i++) { - p = (int32_t)(*x) * *y; - *z = (int16_t)(((p >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */ - x++; - y++; - z++; } return 0;