diff options
| author | jeanne <jeanne@localhost.localdomain> | 2022-05-11 09:54:38 -0700 |
|---|---|---|
| committer | jeanne <jeanne@localhost.localdomain> | 2022-05-11 09:54:38 -0700 |
| commit | 411f66a2540fa17c736116d865e0ceb0cfe5623b (patch) | |
| tree | fa92c69ec627642c8452f928798ff6eccd24ddd6 /src/bin | |
| parent | 7705b07456dfd4b89c272613e98eda36cc787254 (diff) | |
Initial commit.
Diffstat (limited to 'src/bin')
| -rw-r--r-- | src/bin/CMakeLists.txt | 3 | ||||
| -rw-r--r-- | src/bin/mnist/CMakeLists.txt | 11 | ||||
| -rw-r--r-- | src/bin/mnist/src/main.c | 473 |
3 files changed, 487 insertions, 0 deletions
diff --git a/src/bin/CMakeLists.txt b/src/bin/CMakeLists.txt new file mode 100644 index 0000000..051a56f --- /dev/null +++ b/src/bin/CMakeLists.txt | |||
| @@ -0,0 +1,3 @@ | |||
| 1 | cmake_minimum_required(VERSION 3.0) | ||
| 2 | |||
| 3 | add_subdirectory(mnist) | ||
diff --git a/src/bin/mnist/CMakeLists.txt b/src/bin/mnist/CMakeLists.txt new file mode 100644 index 0000000..a6c54f2 --- /dev/null +++ b/src/bin/mnist/CMakeLists.txt | |||
| @@ -0,0 +1,11 @@ | |||
| 1 | cmake_minimum_required(VERSION 3.0) | ||
| 2 | |||
| 3 | add_executable(mnist | ||
| 4 | src/main.c) | ||
| 5 | |||
| 6 | target_link_libraries(mnist PRIVATE | ||
| 7 | neuralnet | ||
| 8 | bsd | ||
| 9 | z) | ||
| 10 | |||
| 11 | target_compile_options(mnist PRIVATE -Wall -Wextra) | ||
diff --git a/src/bin/mnist/src/main.c b/src/bin/mnist/src/main.c new file mode 100644 index 0000000..4d268ac --- /dev/null +++ b/src/bin/mnist/src/main.c | |||
| @@ -0,0 +1,473 @@ | |||
| 1 | #include <neuralnet/matrix.h> | ||
| 2 | #include <neuralnet/neuralnet.h> | ||
| 3 | #include <neuralnet/train.h> | ||
| 4 | |||
| 5 | #include <zlib.h> | ||
| 6 | |||
| 7 | #include <assert.h> | ||
| 8 | #include <bsd/string.h> | ||
| 9 | #include <linux/limits.h> | ||
| 10 | #include <math.h> | ||
| 11 | #include <stdbool.h> | ||
| 12 | #include <stdint.h> | ||
| 13 | #include <stdio.h> | ||
| 14 | #include <stdlib.h> | ||
| 15 | |||
| 16 | static const int TRAIN_ITERATIONS = 100; | ||
| 17 | |||
| 18 | static const int32_t IMAGE_FILE_MAGIC = 0x00000803; | ||
| 19 | static const int32_t LABEL_FILE_MAGIC = 0x00000801; | ||
| 20 | |||
| 21 | // Inputs of 0 cancel weights during training. This value is used to rescale the | ||
| 22 | // input pixels from [0,255] to [PIXEL_LOWER_BOUND, 1.0]. | ||
| 23 | static const double PIXEL_LOWER_BOUND = 0.01; | ||
| 24 | |||
| 25 | // Scale the outputs to (0,1) since the sigmoid cannot produce 0 or 1. | ||
| 26 | static const double LABEL_LOWER_BOUND = 0.01; | ||
| 27 | static const double LABEL_UPPER_BOUND = 0.99; | ||
| 28 | |||
| 29 | // Epsilon used to compare R values. | ||
| 30 | static const double EPS = 1e-10; | ||
| 31 | |||
| 32 | #define min(a,b) ((a) < (b) ? (a) : (b)) | ||
| 33 | |||
| 34 | typedef struct ImageSet { | ||
| 35 | nnMatrix images; // Images flattened into row vectors of the matrix. | ||
| 36 | nnMatrix labels; // One-hot-encoded labels. | ||
| 37 | int count; // Number of images and labels. | ||
| 38 | int rows; // Rows in an image. | ||
| 39 | int cols; // Columns in an image. | ||
| 40 | } ImageSet; | ||
| 41 | |||
| 42 | static void usage(const char* argv0) { | ||
| 43 | fprintf(stderr, "Usage: %s <path to mnist files directory> [num images]\n", argv0); | ||
| 44 | fprintf(stderr, "\n"); | ||
| 45 | fprintf(stderr, " Use -1 for [num images] to use all the images in the data set\n"); | ||
| 46 | } | ||
| 47 | |||
| 48 | static bool R_eq(R a, R b) { | ||
| 49 | return fabs(a-b) <= EPS; | ||
| 50 | } | ||
| 51 | |||
| 52 | static void PrintImage(const nnMatrix* images, int rows, int cols, int image_index) { | ||
| 53 | assert(images); | ||
| 54 | assert((0 <= image_index) && (image_index < images->rows)); | ||
| 55 | |||
| 56 | // Top line. | ||
| 57 | for (int j = 0; j < cols/2; ++j) { | ||
| 58 | printf(" -"); | ||
| 59 | } | ||
| 60 | printf("\n"); | ||
| 61 | |||
| 62 | // Image. | ||
| 63 | const R* value = nnMatrixRow(images, image_index); | ||
| 64 | for (int i = 0; i < rows; ++i) { | ||
| 65 | printf("|"); | ||
| 66 | for (int j = 0; j < cols; ++j) { | ||
| 67 | if (*value > 0.8) { | ||
| 68 | printf("#"); | ||
| 69 | } else if (*value > 0.5) { | ||
| 70 | printf("*"); | ||
| 71 | } | ||
| 72 | else if (*value > PIXEL_LOWER_BOUND) { | ||
| 73 | printf(":"); | ||
| 74 | } else if (*value == 0.0) { | ||
| 75 | // Values should not be exactly 0, otherwise they cancel out weights | ||
| 76 | // during training. | ||
| 77 | printf("X"); | ||
| 78 | } else { | ||
| 79 | printf(" "); | ||
| 80 | } | ||
| 81 | value++; | ||
| 82 | } | ||
| 83 | printf("|\n"); | ||
| 84 | } | ||
| 85 | |||
| 86 | // Bottom line. | ||
| 87 | for (int j = 0; j < cols/2; ++j) { | ||
| 88 | printf(" -"); | ||
| 89 | } | ||
| 90 | printf("\n"); | ||
| 91 | } | ||
| 92 | |||
| 93 | static void PrintLabel(const nnMatrix* labels, int label_index) { | ||
| 94 | assert(labels); | ||
| 95 | assert((0 <= label_index) && (label_index < labels->rows)); | ||
| 96 | |||
| 97 | // Compute the label from the one-hot encoding. | ||
| 98 | const R* value = nnMatrixRow(labels, label_index); | ||
| 99 | int label = -1; | ||
| 100 | for (int i = 0; i < 10; ++i) { | ||
| 101 | if (R_eq(*value++, LABEL_UPPER_BOUND)) { | ||
| 102 | label = i; | ||
| 103 | break; | ||
| 104 | } | ||
| 105 | } | ||
| 106 | assert((0 <= label) && (label <= 9)); | ||
| 107 | |||
| 108 | printf("Label: %d ( ", label); | ||
| 109 | value = nnMatrixRow(labels, label_index); | ||
| 110 | for (int i = 0; i < 10; ++i) { | ||
| 111 | printf("%.3f ", *value++); | ||
| 112 | } | ||
| 113 | printf(")\n"); | ||
| 114 | } | ||
| 115 | |||
| 116 | static R lerp(R a, R b, R t) { | ||
| 117 | return a + t*(b-a); | ||
| 118 | } | ||
| 119 | |||
| 120 | /// Rescales a pixel from [0,255] to [PIXEL_LOWER_BOUND, 1.0]. | ||
| 121 | static R FormatPixel(uint8_t pixel) { | ||
| 122 | const R value = (R)(pixel) / 255.0 * (1.0 - PIXEL_LOWER_BOUND) + PIXEL_LOWER_BOUND; | ||
| 123 | assert(value >= PIXEL_LOWER_BOUND); | ||
| 124 | assert(value <= 1.0); | ||
| 125 | return value; | ||
| 126 | } | ||
| 127 | |||
| 128 | /// Rescales a one-hot-encoded label value to (0,1). | ||
| 129 | static R FormatLabel(R label) { | ||
| 130 | const R value = lerp(LABEL_LOWER_BOUND, LABEL_UPPER_BOUND, label); | ||
| 131 | assert(value > 0.0); | ||
| 132 | assert(value < 1.0); | ||
| 133 | return value; | ||
| 134 | } | ||
| 135 | |||
| 136 | static int32_t ReverseEndian32(int32_t x) { | ||
| 137 | const int32_t x0 = x & 0xff; | ||
| 138 | const int32_t x1 = (x >> 8) & 0xff; | ||
| 139 | const int32_t x2 = (x >> 16) & 0xff; | ||
| 140 | const int32_t x3 = (x >> 24) & 0xff; | ||
| 141 | return (x0 << 24) | (x1 << 16) | (x2 << 8) | x3; | ||
| 142 | } | ||
| 143 | |||
| 144 | static void ImageToMatrix( | ||
| 145 | const uint8_t* pixels, int num_pixels, int row, nnMatrix* images) { | ||
| 146 | assert(pixels); | ||
| 147 | assert(images); | ||
| 148 | |||
| 149 | for (int i = 0; i < num_pixels; ++i) { | ||
| 150 | const R pixel = FormatPixel(pixels[i]); | ||
| 151 | nnMatrixSet(images, row, i, pixel); | ||
| 152 | } | ||
| 153 | } | ||
| 154 | |||
| 155 | static bool ReadImages(gzFile images_file, int max_num_images, ImageSet* image_set) { | ||
| 156 | assert(images_file != Z_NULL); | ||
| 157 | assert(image_set); | ||
| 158 | |||
| 159 | bool success = false; | ||
| 160 | |||
| 161 | uint8_t* pixels = 0; | ||
| 162 | |||
| 163 | int32_t magic, total_images, rows, cols; | ||
| 164 | if ( (gzread(images_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) || | ||
| 165 | (gzread(images_file, (char*)&total_images, sizeof(int32_t)) != sizeof(int32_t)) || | ||
| 166 | (gzread(images_file, (char*)&rows, sizeof(int32_t)) != sizeof(int32_t)) || | ||
| 167 | (gzread(images_file, (char*)&cols, sizeof(int32_t)) != sizeof(int32_t)) ) { | ||
| 168 | fprintf(stderr, "Failed to read header\n"); | ||
| 169 | goto cleanup; | ||
| 170 | } | ||
| 171 | |||
| 172 | magic = ReverseEndian32(magic); | ||
| 173 | total_images = ReverseEndian32(total_images); | ||
| 174 | rows = ReverseEndian32(rows); | ||
| 175 | cols = ReverseEndian32(cols); | ||
| 176 | |||
| 177 | if (magic != IMAGE_FILE_MAGIC) { | ||
| 178 | fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n", | ||
| 179 | magic, IMAGE_FILE_MAGIC); | ||
| 180 | goto cleanup; | ||
| 181 | } | ||
| 182 | |||
| 183 | printf("Magic: %.8x\nTotal images: %d\nRows: %d\nCols: %d\n", | ||
| 184 | magic, total_images, rows, cols); | ||
| 185 | |||
| 186 | total_images = max_num_images >= 0 ? min(total_images, max_num_images) : total_images; | ||
| 187 | |||
| 188 | // Images are flattened into single row vectors. | ||
| 189 | const int num_pixels = rows * cols; | ||
| 190 | image_set->images = nnMatrixMake(total_images, num_pixels); | ||
| 191 | image_set->count = total_images; | ||
| 192 | image_set->rows = rows; | ||
| 193 | image_set->cols = cols; | ||
| 194 | |||
| 195 | pixels = calloc(1, num_pixels); | ||
| 196 | if (!pixels) { | ||
| 197 | fprintf(stderr, "Failed to allocate image buffer\n"); | ||
| 198 | goto cleanup; | ||
| 199 | } | ||
| 200 | |||
| 201 | for (int i = 0; i < total_images; ++i) { | ||
| 202 | const int bytes_read = gzread(images_file, pixels, num_pixels); | ||
| 203 | if (bytes_read < num_pixels) { | ||
| 204 | fprintf(stderr, "Failed to read image %d\n", i); | ||
| 205 | goto cleanup; | ||
| 206 | } | ||
| 207 | ImageToMatrix(pixels, num_pixels, i, &image_set->images); | ||
| 208 | } | ||
| 209 | |||
| 210 | success = true; | ||
| 211 | |||
| 212 | cleanup: | ||
| 213 | if (pixels) { | ||
| 214 | free(pixels); | ||
| 215 | } | ||
| 216 | if (!success) { | ||
| 217 | nnMatrixDel(&image_set->images); | ||
| 218 | } | ||
| 219 | return success; | ||
| 220 | } | ||
| 221 | |||
| 222 | static void OneHotEncode(const uint8_t* labels_bytes, int num_labels, nnMatrix* labels) { | ||
| 223 | assert(labels_bytes); | ||
| 224 | assert(labels); | ||
| 225 | assert(labels->rows == num_labels); | ||
| 226 | assert(labels->cols == 10); | ||
| 227 | |||
| 228 | static const R one_hot[10][10] = { | ||
| 229 | { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, | ||
| 230 | { 0, 1, 0, 0, 0, 0, 0, 0, 0, 0 }, | ||
| 231 | { 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 }, | ||
| 232 | { 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 }, | ||
| 233 | { 0, 0, 0, 0, 1, 0, 0, 0, 0, 0 }, | ||
| 234 | { 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 }, | ||
| 235 | { 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 }, | ||
| 236 | { 0, 0, 0, 0, 0, 0, 0, 1, 0, 0 }, | ||
| 237 | { 0, 0, 0, 0, 0, 0, 0, 0, 1, 0 }, | ||
| 238 | { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, | ||
| 239 | }; | ||
| 240 | |||
| 241 | R* value = labels->values; | ||
| 242 | |||
| 243 | for (int i = 0; i < num_labels; ++i) { | ||
| 244 | const uint8_t label = labels_bytes[i]; | ||
| 245 | const R* one_hot_value = one_hot[label]; | ||
| 246 | |||
| 247 | for (int j = 0; j < 10; ++j) { | ||
| 248 | *value++ = FormatLabel(*one_hot_value++); | ||
| 249 | } | ||
| 250 | } | ||
| 251 | } | ||
| 252 | |||
| 253 | static int OneHotDecode(const nnMatrix* label_matrix) { | ||
| 254 | assert(label_matrix); | ||
| 255 | assert(label_matrix->cols == 1); | ||
| 256 | assert(label_matrix->rows == 10); | ||
| 257 | |||
| 258 | R max_value = 0; | ||
| 259 | int pos_max = 0; | ||
| 260 | for (int i = 0; i < 10; ++i) { | ||
| 261 | const R value = nnMatrixAt(label_matrix, 0, i); | ||
| 262 | if (value > max_value) { | ||
| 263 | max_value = value; | ||
| 264 | pos_max = i; | ||
| 265 | } | ||
| 266 | } | ||
| 267 | assert(pos_max >= 0); | ||
| 268 | assert(pos_max <= 10); | ||
| 269 | return pos_max; | ||
| 270 | } | ||
| 271 | |||
| 272 | static bool ReadLabels(gzFile labels_file, int max_num_labels, ImageSet* image_set) { | ||
| 273 | assert(labels_file != Z_NULL); | ||
| 274 | assert(image_set != 0); | ||
| 275 | |||
| 276 | bool success = false; | ||
| 277 | |||
| 278 | uint8_t* labels = 0; | ||
| 279 | |||
| 280 | int32_t magic, total_labels; | ||
| 281 | if ( (gzread(labels_file, (char*)&magic, sizeof(int32_t)) != sizeof(int32_t)) || | ||
| 282 | (gzread(labels_file, (char*)&total_labels, sizeof(int32_t)) != sizeof(int32_t)) ) { | ||
| 283 | fprintf(stderr, "Failed to read header\n"); | ||
| 284 | goto cleanup; | ||
| 285 | } | ||
| 286 | |||
| 287 | magic = ReverseEndian32(magic); | ||
| 288 | total_labels = ReverseEndian32(total_labels); | ||
| 289 | |||
| 290 | if (magic != LABEL_FILE_MAGIC) { | ||
| 291 | fprintf(stderr, "Magic number mismatch. Got %x, expected: %x\n", | ||
| 292 | magic, LABEL_FILE_MAGIC); | ||
| 293 | goto cleanup; | ||
| 294 | } | ||
| 295 | |||
| 296 | printf("Magic: %.8x\nTotal labels: %d\n", magic, total_labels); | ||
| 297 | |||
| 298 | total_labels = max_num_labels >= 0 ? min(total_labels, max_num_labels) : total_labels; | ||
| 299 | |||
| 300 | assert(image_set->count == total_labels); | ||
| 301 | |||
| 302 | // One-hot encoding of labels, 10 values (digits) per label. | ||
| 303 | image_set->labels = nnMatrixMake(total_labels, 10); | ||
| 304 | |||
| 305 | labels = calloc(total_labels, sizeof(uint8_t)); | ||
| 306 | if (!labels) { | ||
| 307 | fprintf(stderr, "Failed to allocate labels buffer\n"); | ||
| 308 | goto cleanup; | ||
| 309 | } | ||
| 310 | |||
| 311 | if (gzread(labels_file, labels, total_labels * sizeof(uint8_t)) != total_labels) { | ||
| 312 | fprintf(stderr, "Failed to read labels\n"); | ||
| 313 | goto cleanup; | ||
| 314 | } | ||
| 315 | |||
| 316 | OneHotEncode(labels, total_labels, &image_set->labels); | ||
| 317 | |||
| 318 | success = true; | ||
| 319 | |||
| 320 | cleanup: | ||
| 321 | if (labels) { | ||
| 322 | free(labels); | ||
| 323 | } | ||
| 324 | if (!success) { | ||
| 325 | nnMatrixDel(&image_set->labels); | ||
| 326 | } | ||
| 327 | return success; | ||
| 328 | } | ||
| 329 | |||
| 330 | int main(int argc, const char** argv) { | ||
| 331 | if (argc < 2) { | ||
| 332 | usage(argv[0]); | ||
| 333 | return 1; | ||
| 334 | } | ||
| 335 | |||
| 336 | bool success = false; | ||
| 337 | |||
| 338 | gzFile train_images_file = Z_NULL; | ||
| 339 | gzFile train_labels_file = Z_NULL; | ||
| 340 | gzFile test_images_file = Z_NULL; | ||
| 341 | gzFile test_labels_file = Z_NULL; | ||
| 342 | ImageSet train_set = { 0 }; | ||
| 343 | ImageSet test_set = { 0 }; | ||
| 344 | nnNeuralNetwork* net = 0; | ||
| 345 | nnQueryObject* query = 0; | ||
| 346 | |||
| 347 | const char* mnist_files_dir = argv[1]; | ||
| 348 | const int max_num_images = argc > 2 ? atoi(argv[2]) : -1; | ||
| 349 | |||
| 350 | char train_labels_path[PATH_MAX]; | ||
| 351 | char train_images_path[PATH_MAX]; | ||
| 352 | char test_labels_path[PATH_MAX]; | ||
| 353 | char test_images_path[PATH_MAX]; | ||
| 354 | strlcpy(train_labels_path, mnist_files_dir, PATH_MAX); | ||
| 355 | strlcpy(train_images_path, mnist_files_dir, PATH_MAX); | ||
| 356 | strlcpy(test_labels_path, mnist_files_dir, PATH_MAX); | ||
| 357 | strlcpy(test_images_path, mnist_files_dir, PATH_MAX); | ||
| 358 | strlcat(train_labels_path, "/train-labels-idx1-ubyte.gz", PATH_MAX); | ||
| 359 | strlcat(train_images_path, "/train-images-idx3-ubyte.gz", PATH_MAX); | ||
| 360 | strlcat(test_labels_path, "/t10k-labels-idx1-ubyte.gz", PATH_MAX); | ||
| 361 | strlcat(test_images_path, "/t10k-images-idx3-ubyte.gz", PATH_MAX); | ||
| 362 | |||
| 363 | train_images_file = gzopen(train_images_path, "r"); | ||
| 364 | if (train_images_file == Z_NULL) { | ||
| 365 | fprintf(stderr, "Failed to open file: %s\n", train_images_path); | ||
| 366 | goto cleanup; | ||
| 367 | } | ||
| 368 | |||
| 369 | train_labels_file = gzopen(train_labels_path, "r"); | ||
| 370 | if (train_labels_file == Z_NULL) { | ||
| 371 | fprintf(stderr, "Failed to open file: %s\n", train_labels_path); | ||
| 372 | goto cleanup; | ||
| 373 | } | ||
| 374 | |||
| 375 | test_images_file = gzopen(test_images_path, "r"); | ||
| 376 | if (test_images_file == Z_NULL) { | ||
| 377 | fprintf(stderr, "Failed to open file: %s\n", test_images_path); | ||
| 378 | goto cleanup; | ||
| 379 | } | ||
| 380 | |||
| 381 | test_labels_file = gzopen(test_labels_path, "r"); | ||
| 382 | if (test_labels_file == Z_NULL) { | ||
| 383 | fprintf(stderr, "Failed to open file: %s\n", test_labels_path); | ||
| 384 | goto cleanup; | ||
| 385 | } | ||
| 386 | |||
| 387 | if (!ReadImages(train_images_file, max_num_images, &train_set)) { | ||
| 388 | goto cleanup; | ||
| 389 | } | ||
| 390 | if (!ReadLabels(train_labels_file, max_num_images, &train_set)) { | ||
| 391 | goto cleanup; | ||
| 392 | } | ||
| 393 | |||
| 394 | if (!ReadImages(test_images_file, max_num_images, &test_set)) { | ||
| 395 | goto cleanup; | ||
| 396 | } | ||
| 397 | if (!ReadLabels(test_labels_file, max_num_images, &test_set)) { | ||
| 398 | goto cleanup; | ||
| 399 | } | ||
| 400 | |||
| 401 | printf("\nTraining image/label pair examples:\n"); | ||
| 402 | for (int i = 0; i < min(3, train_set.images.rows); ++i) { | ||
| 403 | PrintImage(&train_set.images, train_set.rows, train_set.cols, i); | ||
| 404 | PrintLabel(&train_set.labels, i); | ||
| 405 | printf("\n"); | ||
| 406 | } | ||
| 407 | |||
| 408 | // Network definition. | ||
| 409 | const int image_size_pixels = train_set.rows * train_set.cols; | ||
| 410 | const int num_layers = 2; | ||
| 411 | const int layer_sizes[3] = { image_size_pixels, 100, 10 }; | ||
| 412 | const nnActivation layer_activations[2] = { nnSigmoid, nnSigmoid }; | ||
| 413 | if (!(net = nnMakeNet(num_layers, layer_sizes, layer_activations))) { | ||
| 414 | fprintf(stderr, "Failed to create neural network\n"); | ||
| 415 | goto cleanup; | ||
| 416 | } | ||
| 417 | |||
| 418 | // Train. | ||
| 419 | printf("Training with up to %d images from the data set\n\n", max_num_images); | ||
| 420 | const nnTrainingParams training_params = { | ||
| 421 | .learning_rate = 0.1, | ||
| 422 | .max_iterations = TRAIN_ITERATIONS, | ||
| 423 | .seed = 0, | ||
| 424 | .weight_init = nnWeightInitNormal, | ||
| 425 | .debug = true, | ||
| 426 | }; | ||
| 427 | nnTrain(net, &train_set.images, &train_set.labels, &training_params); | ||
| 428 | |||
| 429 | // Test. | ||
| 430 | int hits = 0; | ||
| 431 | query = nnMakeQueryObject(net, /*num_inputs=*/1); | ||
| 432 | for (int i = 0; i < test_set.count; ++i) { | ||
| 433 | const nnMatrix test_image = nnMatrixBorrowRows(&test_set.images, i, 1); | ||
| 434 | const nnMatrix test_label = nnMatrixBorrowRows(&test_set.labels, i, 1); | ||
| 435 | |||
| 436 | nnQuery(net, query, &test_image); | ||
| 437 | |||
| 438 | const int test_label_expected = OneHotDecode(&test_label); | ||
| 439 | const int test_label_actual = OneHotDecode(nnNetOutputs(query)); | ||
| 440 | |||
| 441 | if (test_label_actual == test_label_expected) { | ||
| 442 | ++hits; | ||
| 443 | } | ||
| 444 | } | ||
| 445 | const R hit_ratio = (R)hits / (R)test_set.count; | ||
| 446 | printf("Test images: %d\n", test_set.count); | ||
| 447 | printf("Hits: %d/%d (%.3f%%)\n", hits, test_set.count, hit_ratio*100); | ||
| 448 | |||
| 449 | success = true; | ||
| 450 | |||
| 451 | cleanup: | ||
| 452 | if (query) { | ||
| 453 | nnDeleteQueryObject(&query); | ||
| 454 | } | ||
| 455 | if (net) { | ||
| 456 | nnDeleteNet(&net); | ||
| 457 | } | ||
| 458 | nnMatrixDel(&train_set.images); | ||
| 459 | nnMatrixDel(&test_set.images); | ||
| 460 | if (train_images_file != Z_NULL) { | ||
| 461 | gzclose(train_images_file); | ||
| 462 | } | ||
| 463 | if (train_labels_file != Z_NULL) { | ||
| 464 | gzclose(train_labels_file); | ||
| 465 | } | ||
| 466 | if (test_images_file != Z_NULL) { | ||
| 467 | gzclose(test_images_file); | ||
| 468 | } | ||
| 469 | if (test_labels_file != Z_NULL) { | ||
| 470 | gzclose(test_labels_file); | ||
| 471 | } | ||
| 472 | return success ? 0 : 1; | ||
| 473 | } | ||
