| author | 3gg <3gg@shellblade.net> | 2023-12-16 11:29:33 -0800 |
|---|---|---|
| committer | 3gg <3gg@shellblade.net> | 2023-12-16 11:29:33 -0800 |
| commit | 57bf2b46b4b277952d722f6439b72f9e40db129c | |
| tree | 2f979007fea1fd77b5ccfc2e7b530dd1842c1503 | |
| parent | dc538733da8d49e7240d00fb05517053076fe261 | |
Clarify some terminology.
| -rw-r--r-- | src/lib/src/train.c | 13 |
1 file changed, 6 insertions, 7 deletions
```diff
diff --git a/src/lib/src/train.c b/src/lib/src/train.c
index fe9f598..7559ece 100644
--- a/src/lib/src/train.c
+++ b/src/lib/src/train.c
@@ -239,17 +239,16 @@ void nnTrain(
 
       // Compute this layer's gradient.
       //
-      // By "gradient" we mean the expression common to the weights and bias
-      // gradients. This is the part of the expression that does not contain
-      // this layer's input.
+      // By 'gradient' we mean the subexpression common to all the gradients
+      // for this layer.
+      // For linear layers, this is the subexpression common to both the
+      // weights and bias gradients.
       //
       // Linear:  G = id
       // Relu:    G = (output_k > 0 ? 1 : 0)
       // Sigmoid: G = output_k * (1 - output_k)
       switch (layer->type) {
         case nnLinear: {
-          // TODO: Just copy the pointer?
-          *gradient = nnMatrixBorrow(&errors[l]);
           break;
         }
         case nnRelu:
@@ -294,7 +293,7 @@ void nnTrain(
       nnMatrix* layer_biases = &linear->biases;
 
       // Outer product to compute the weight deltas.
-      nnMatrixMulOuter(layer_input, gradient, &weight_deltas[l]);
+      nnMatrixMulOuter(layer_input, &errors[l], &weight_deltas[l]);
 
       // Update weights.
       nnMatrixScale(&weight_deltas[l], params->learning_rate);
@@ -304,7 +303,7 @@ void nnTrain(
       // This is the same formula as for weights, except that the o_j term
       // is just 1.
       nnMatrixMulSub(
-          layer_biases, gradient, params->learning_rate, layer_biases);
+          layer_biases, &errors[l], params->learning_rate, layer_biases);
     }
   }
 
```
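For reference, the update performed by the two later hunks can be written out with plain arrays. This is only an illustrative sketch, not the library code: `linear_layer_update` and its arguments are hypothetical stand-ins for the `nnMatrix` objects that `nnTrain` manipulates through `nnMatrixMulOuter`, `nnMatrixScale`, and `nnMatrixMulSub`, and the `G_*` helpers simply restate the per-layer gradient expressions from the comment.

```c
#include <stddef.h>

/* The "gradient" G from the comment, per layer type. For nnLinear, G = id,
 * which is why the hunks above use errors[l] directly. */
float G_relu(float output_k)    { return (output_k > 0.0f) ? 1.0f : 0.0f; }
float G_sigmoid(float output_k) { return output_k * (1.0f - output_k); }

/* Weight and bias update for a linear layer: the outer product of the layer
 * input o_j with the error term e_k gives the weight deltas, scaled by the
 * learning rate; the bias update uses the same formula with o_j = 1. */
void linear_layer_update(
    const float* input,   /* o_j, length num_inputs */
    const float* error,   /* e_k, length num_outputs */
    float*       weights, /* w_jk, row-major, num_inputs x num_outputs */
    float*       biases,  /* b_k, length num_outputs */
    size_t num_inputs, size_t num_outputs, float learning_rate) {
  for (size_t j = 0; j < num_inputs; ++j) {
    for (size_t k = 0; k < num_outputs; ++k) {
      weights[j * num_outputs + k] -= learning_rate * input[j] * error[k];
    }
  }
  for (size_t k = 0; k < num_outputs; ++k) {
    biases[k] -= learning_rate * error[k];
  }
}
```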
