diff options
| -rw-r--r-- | src/lib/src/matrix.c | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/src/lib/src/matrix.c b/src/lib/src/matrix.c index 29cdec5..f937c01 100644 --- a/src/lib/src/matrix.c +++ b/src/lib/src/matrix.c | |||
| @@ -131,21 +131,23 @@ void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) { | |||
| 131 | assert(out->cols == right->cols); | 131 | assert(out->cols == right->cols); |
| 132 | 132 | ||
| 133 | R* out_value = out->values; | 133 | R* out_value = out->values; |
| 134 | for (int i = 0; i < out->rows * out->cols; ++i) { | ||
| 135 | *out_value++ = 0; | ||
| 136 | } | ||
| 134 | 137 | ||
| 135 | for (int i = 0; i < left->rows; ++i) { | 138 | for (int i = 0; i < left->rows; ++i) { |
| 136 | const R* left_row = &left->values[i * left->cols]; | 139 | const R* p_left_value = &left->values[i * left->cols]; |
| 137 | 140 | ||
| 138 | for (int j = 0; j < right->cols; ++j) { | 141 | for (int j = 0; j < left->cols; ++j) { |
| 139 | const R* right_col = &right->values[j]; | 142 | const R left_value = *p_left_value; |
| 140 | *out_value = 0; | 143 | const R* right_value = &right->values[j * right->cols]; |
| 144 | R* out_value = &out->values[i * out->cols]; | ||
| 141 | 145 | ||
| 142 | // Vector dot product. | 146 | for (int k = 0; k < right->cols; ++k) { |
| 143 | for (int k = 0; k < left->cols; ++k) { | 147 | *out_value++ += left_value * *right_value++; |
| 144 | *out_value += left_row[k] * right_col[0]; | ||
| 145 | right_col += right->cols; // Next row in the column. | ||
| 146 | } | 148 | } |
| 147 | 149 | ||
| 148 | out_value++; | 150 | p_left_value++; |
| 149 | } | 151 | } |
| 150 | } | 152 | } |
| 151 | } | 153 | } |
