diff options
| -rw-r--r-- | src/lib/include/neuralnet/matrix.h | 6 | ||||
| -rw-r--r-- | src/lib/src/matrix.c | 29 | ||||
| -rw-r--r-- | src/lib/src/train.c | 18 | 
3 files changed, 45 insertions, 8 deletions
| diff --git a/src/lib/include/neuralnet/matrix.h b/src/lib/include/neuralnet/matrix.h index 9816b81..0cb40cf 100644 --- a/src/lib/include/neuralnet/matrix.h +++ b/src/lib/include/neuralnet/matrix.h | |||
| @@ -52,6 +52,12 @@ void nnMatrixInitConstant(nnMatrix*, R value); | |||
| 52 | /// Multiply two matrices. | 52 | /// Multiply two matrices. | 
| 53 | void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); | 53 | void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); | 
| 54 | 54 | ||
| 55 | /// Multiply two matrices, row variant. | ||
| 56 | /// | ||
| 57 | /// This function multiples two matrices row-by-row instead of row-by-column. | ||
| 58 | /// nnMatrixMul(A, B, O) == nnMatrixMulRows(A, B^T, O). | ||
| 59 | void nnMatrixMulRows(const nnMatrix* left, const nnMatrix* right, nnMatrix* out); | ||
| 60 | |||
| 55 | /// Matrix multiply-add. | 61 | /// Matrix multiply-add. | 
| 56 | /// | 62 | /// | 
| 57 | /// out = left + (right * scale) | 63 | /// out = left + (right * scale) | 
| diff --git a/src/lib/src/matrix.c b/src/lib/src/matrix.c index a7a4ce6..29cdec5 100644 --- a/src/lib/src/matrix.c +++ b/src/lib/src/matrix.c | |||
| @@ -150,6 +150,35 @@ void nnMatrixMul(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) { | |||
| 150 | } | 150 | } | 
| 151 | } | 151 | } | 
| 152 | 152 | ||
| 153 | void nnMatrixMulRows(const nnMatrix* left, const nnMatrix* right, nnMatrix* out) { | ||
| 154 | assert(left != 0); | ||
| 155 | assert(right != 0); | ||
| 156 | assert(out != 0); | ||
| 157 | assert(out != left); | ||
| 158 | assert(out != right); | ||
| 159 | assert(left->cols == right->cols); | ||
| 160 | assert(out->rows == left->rows); | ||
| 161 | assert(out->cols == right->rows); | ||
| 162 | |||
| 163 | R* out_value = out->values; | ||
| 164 | |||
| 165 | for (int i = 0; i < left->rows; ++i) { | ||
| 166 | const R* left_row = &left->values[i * left->cols]; | ||
| 167 | const R* right_value = right->values; | ||
| 168 | |||
| 169 | for (int j = 0; j < right->rows; ++j) { | ||
| 170 | *out_value = 0; | ||
| 171 | |||
| 172 | // Vector dot product. | ||
| 173 | for (int k = 0; k < left->cols; ++k) { | ||
| 174 | *out_value += left_row[k] * *right_value++; | ||
| 175 | } | ||
| 176 | |||
| 177 | out_value++; | ||
| 178 | } | ||
| 179 | } | ||
| 180 | } | ||
| 181 | |||
| 153 | void nnMatrixMulAdd(const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out) { | 182 | void nnMatrixMulAdd(const nnMatrix* left, const nnMatrix* right, R scale, nnMatrix* out) { | 
| 154 | assert(left); | 183 | assert(left); | 
| 155 | assert(right); | 184 | assert(right); | 
| diff --git a/src/lib/src/train.c b/src/lib/src/train.c index 027de66..3061a99 100644 --- a/src/lib/src/train.c +++ b/src/lib/src/train.c | |||
| @@ -129,7 +129,7 @@ void nnTrain( | |||
| 129 | nnMatrix* errors = calloc(net->num_layers, sizeof(nnMatrix)); | 129 | nnMatrix* errors = calloc(net->num_layers, sizeof(nnMatrix)); | 
| 130 | 130 | ||
| 131 | // Allocate the weight transpose matrices up front for backpropagation. | 131 | // Allocate the weight transpose matrices up front for backpropagation. | 
| 132 | nnMatrix* weights_T = calloc(net->num_layers, sizeof(nnMatrix)); | 132 | //nnMatrix* weights_T = calloc(net->num_layers, sizeof(nnMatrix)); | 
| 133 | 133 | ||
| 134 | // Allocate the weight delta matrices. | 134 | // Allocate the weight delta matrices. | 
| 135 | nnMatrix* weight_deltas = calloc(net->num_layers, sizeof(nnMatrix)); | 135 | nnMatrix* weight_deltas = calloc(net->num_layers, sizeof(nnMatrix)); | 
| @@ -143,7 +143,7 @@ void nnTrain( | |||
| 143 | nnMatrix* outputs_T = calloc(net->num_layers, sizeof(nnMatrix)); | 143 | nnMatrix* outputs_T = calloc(net->num_layers, sizeof(nnMatrix)); | 
| 144 | 144 | ||
| 145 | assert(errors != 0); | 145 | assert(errors != 0); | 
| 146 | assert(weights_T != 0); | 146 | //assert(weights_T != 0); | 
| 147 | assert(weight_deltas != 0); | 147 | assert(weight_deltas != 0); | 
| 148 | assert(gradient_elems); | 148 | assert(gradient_elems); | 
| 149 | assert(outputs_T); | 149 | assert(outputs_T); | 
| @@ -155,8 +155,8 @@ void nnTrain( | |||
| 155 | 155 | ||
| 156 | errors[l] = nnMatrixMake(1, layer_weights->cols); | 156 | errors[l] = nnMatrixMake(1, layer_weights->cols); | 
| 157 | 157 | ||
| 158 | weights_T[l] = nnMatrixMake(layer_weights->cols, layer_weights->rows); | 158 | //weights_T[l] = nnMatrixMake(layer_weights->cols, layer_weights->rows); | 
| 159 | nnMatrixTranspose(layer_weights, &weights_T[l]); | 159 | //nnMatrixTranspose(layer_weights, &weights_T[l]); | 
| 160 | 160 | ||
| 161 | weight_deltas[l] = nnMatrixMake(layer_weights->rows, layer_weights->cols); | 161 | weight_deltas[l] = nnMatrixMake(layer_weights->rows, layer_weights->cols); | 
| 162 | 162 | ||
| @@ -267,7 +267,9 @@ void nnTrain( | |||
| 267 | 267 | ||
| 268 | // Backpropagate the error before updating weights. | 268 | // Backpropagate the error before updating weights. | 
| 269 | if (l > 0) { | 269 | if (l > 0) { | 
| 270 | nnMatrixMul(gradient, &weights_T[l], &errors[l-1]); | 270 | // G * W^T == G *^T W. | 
| 271 | //nnMatrixMul(gradient, &weights_T[l], &errors[l-1]); | ||
| 272 | nnMatrixMulRows(gradient, layer_weights, &errors[l-1]); | ||
| 271 | } | 273 | } | 
| 272 | 274 | ||
| 273 | // Update weights. | 275 | // Update weights. | 
| @@ -278,7 +280,7 @@ void nnTrain( | |||
| 278 | nnMatrixSub(layer_weights, &weight_deltas[l], layer_weights); | 280 | nnMatrixSub(layer_weights, &weight_deltas[l], layer_weights); | 
| 279 | 281 | ||
| 280 | // Update weight transpose matrix for the next training iteration. | 282 | // Update weight transpose matrix for the next training iteration. | 
| 281 | nnMatrixTranspose(layer_weights, &weights_T[l]); | 283 | //nnMatrixTranspose(layer_weights, &weights_T[l]); | 
| 282 | 284 | ||
| 283 | // Update biases. | 285 | // Update biases. | 
| 284 | // This is the same formula as for weights, except that the o_j term is | 286 | // This is the same formula as for weights, except that the o_j term is | 
| @@ -319,7 +321,7 @@ void nnTrain( | |||
| 319 | for (int l = 0; l < net->num_layers; ++l) { | 321 | for (int l = 0; l < net->num_layers; ++l) { | 
| 320 | nnMatrixDel(&errors[l]); | 322 | nnMatrixDel(&errors[l]); | 
| 321 | nnMatrixDel(&outputs_T[l]); | 323 | nnMatrixDel(&outputs_T[l]); | 
| 322 | nnMatrixDel(&weights_T[l]); | 324 | //nnMatrixDel(&weights_T[l]); | 
| 323 | nnMatrixDel(&weight_deltas[l]); | 325 | nnMatrixDel(&weight_deltas[l]); | 
| 324 | 326 | ||
| 325 | nnGradientElements* elems = &gradient_elems[l]; | 327 | nnGradientElements* elems = &gradient_elems[l]; | 
| @@ -340,7 +342,7 @@ void nnTrain( | |||
| 340 | nnMatrixDel(&training_inputs_T); | 342 | nnMatrixDel(&training_inputs_T); | 
| 341 | free(errors); | 343 | free(errors); | 
| 342 | free(outputs_T); | 344 | free(outputs_T); | 
| 343 | free(weights_T); | 345 | //free(weights_T); | 
| 344 | free(weight_deltas); | 346 | free(weight_deltas); | 
| 345 | free(gradient_elems); | 347 | free(gradient_elems); | 
| 346 | } | 348 | } | 
