diff options
| author | 3gg <3gg@shellblade.net> | 2023-11-23 10:02:33 -0800 |
|---|---|---|
| committer | 3gg <3gg@shellblade.net> | 2023-11-23 10:02:33 -0800 |
| commit | 3df7b6fb0c65295eed4590e6f166d60e89b3c68e (patch) | |
| tree | 51c53d0b55e4fdff0facc5c4624b1102a40a13f0 | |
| parent | 6ca8a31143f087f3bc470d39eb3c00156443802a (diff) | |
Documentation.
| -rw-r--r-- | src/lib/src/matrix.c | 2 | ||||
| -rw-r--r-- | src/lib/src/neuralnet_impl.h | 2 | ||||
| -rw-r--r-- | src/lib/src/train.c | 15 |
3 files changed, 11 insertions, 8 deletions
diff --git a/src/lib/src/matrix.c b/src/lib/src/matrix.c index 174504f..d98c8bb 100644 --- a/src/lib/src/matrix.c +++ b/src/lib/src/matrix.c | |||
| @@ -6,7 +6,7 @@ | |||
| 6 | 6 | ||
| 7 | nnMatrix nnMatrixMake(int rows, int cols) { | 7 | nnMatrix nnMatrixMake(int rows, int cols) { |
| 8 | R* values = calloc(rows * cols, sizeof(R)); | 8 | R* values = calloc(rows * cols, sizeof(R)); |
| 9 | assert(values != 0); | 9 | assert(values != 0); // TODO: Make it a hard assert. |
| 10 | 10 | ||
| 11 | return (nnMatrix){ | 11 | return (nnMatrix){ |
| 12 | .rows = rows, | 12 | .rows = rows, |
diff --git a/src/lib/src/neuralnet_impl.h b/src/lib/src/neuralnet_impl.h index 18694f4..f5a9c63 100644 --- a/src/lib/src/neuralnet_impl.h +++ b/src/lib/src/neuralnet_impl.h | |||
| @@ -30,7 +30,7 @@ typedef struct nnNeuralNetwork { | |||
| 30 | /// |network_outputs| points to the last output matrix in |layer_outputs| for | 30 | /// |network_outputs| points to the last output matrix in |layer_outputs| for |
| 31 | /// convenience. | 31 | /// convenience. |
| 32 | typedef struct nnQueryObject { | 32 | typedef struct nnQueryObject { |
| 33 | int num_layers; | 33 | int num_layers; // Same as nnNeuralNetwork::num_layers. |
| 34 | nnMatrix* layer_outputs; // Output matrices, one output per layer. | 34 | nnMatrix* layer_outputs; // Output matrices, one output per layer. |
| 35 | nnMatrix* network_outputs; // Points to the last output matrix. | 35 | nnMatrix* network_outputs; // Points to the last output matrix. |
| 36 | } nnTrainingQueryObject; | 36 | } nnTrainingQueryObject; |
diff --git a/src/lib/src/train.c b/src/lib/src/train.c index 9244907..dc93f0f 100644 --- a/src/lib/src/train.c +++ b/src/lib/src/train.c | |||
| @@ -219,13 +219,15 @@ void nnTrain( | |||
| 219 | // Assuming one training input per iteration for now. | 219 | // Assuming one training input per iteration for now. |
| 220 | nnMatrixTranspose(&training_inputs, &training_inputs_T); | 220 | nnMatrixTranspose(&training_inputs, &training_inputs_T); |
| 221 | 221 | ||
| 222 | // Run a forward pass and compute the output layer error. | 222 | // Run a forward pass and compute the output layer error relevant to the |
| 223 | // We don't square the error here; instead, we just compute t-o, which is | 223 | // derivative: o-t. |
| 224 | // part of the derivative, -2(t-o). Also, we compute o-t instead to | 224 | // Error: (t-o)^2 |
| 225 | // remove that outer negative sign. | 225 | // dE/do = -2(t-o) |
| 226 | // = +2(o-t) | ||
| 227 | // Note that we compute o-t instead to remove that outer negative sign. | ||
| 228 | // The 2 is dropped because we are only interested in the direction of the | ||
| 229 | // gradient. The learning rate controls the magnitude. | ||
| 226 | nnQuery(net, query, &training_inputs); | 230 | nnQuery(net, query, &training_inputs); |
| 227 | // nnMatrixSub(&training_targets, training_outputs, | ||
| 228 | // &errors[net->num_layers - 1]); | ||
| 229 | nnMatrixSub( | 231 | nnMatrixSub( |
| 230 | training_outputs, &training_targets, &errors[net->num_layers - 1]); | 232 | training_outputs, &training_targets, &errors[net->num_layers - 1]); |
| 231 | 233 | ||
| @@ -328,6 +330,7 @@ void nnTrain( | |||
| 328 | params->max_iterations, ComputeMSE(&errors[net->num_layers - 1])); | 330 | params->max_iterations, ComputeMSE(&errors[net->num_layers - 1])); |
| 329 | } | 331 | } |
| 330 | 332 | ||
| 333 | // Clean up. | ||
| 331 | for (int l = 0; l < net->num_layers; ++l) { | 334 | for (int l = 0; l < net->num_layers; ++l) { |
| 332 | nnMatrixDel(&errors[l]); | 335 | nnMatrixDel(&errors[l]); |
| 333 | nnMatrixDel(&outputs_T[l]); | 336 | nnMatrixDel(&outputs_T[l]); |
