From 411f66a2540fa17c736116d865e0ceb0cfe5623b Mon Sep 17 00:00:00 2001
From: jeanne <jeanne@localhost.localdomain>
Date: Wed, 11 May 2022 09:54:38 -0700
Subject: Initial commit.

---
 src/lib/test/matrix_test.c                         | 350 +++++++++++++++++++++
 src/lib/test/neuralnet_test.c                      |  92 ++++++
 src/lib/test/test.h                                | 185 +++++++++++
 src/lib/test/test_main.c                           |   3 +
 src/lib/test/test_util.h                           |  22 ++
 .../test/train_linear_perceptron_non_origin_test.c |  67 ++++
 src/lib/test/train_linear_perceptron_test.c        |  62 ++++
 src/lib/test/train_sigmoid_test.c                  |  66 ++++
 src/lib/test/train_xor_test.c                      |  66 ++++
 9 files changed, 913 insertions(+)
 create mode 100644 src/lib/test/matrix_test.c
 create mode 100644 src/lib/test/neuralnet_test.c
 create mode 100644 src/lib/test/test.h
 create mode 100644 src/lib/test/test_main.c
 create mode 100644 src/lib/test/test_util.h
 create mode 100644 src/lib/test/train_linear_perceptron_non_origin_test.c
 create mode 100644 src/lib/test/train_linear_perceptron_test.c
 create mode 100644 src/lib/test/train_sigmoid_test.c
 create mode 100644 src/lib/test/train_xor_test.c

(limited to 'src/lib/test')

diff --git a/src/lib/test/matrix_test.c b/src/lib/test/matrix_test.c
new file mode 100644
index 0000000..8191c97
--- /dev/null
+++ b/src/lib/test/matrix_test.c
@@ -0,0 +1,350 @@
+#include <neuralnet/matrix.h>
+
+#include "test.h"
+#include "test_util.h"
+
+#include <assert.h>
+#include <stdlib.h>
+
+// static void PrintMatrix(const nnMatrix* matrix) {
+//   assert(matrix);
+
+//   for (int i = 0; i < matrix->rows; ++i) {
+//     for (int j = 0; j < matrix->cols; ++j) {
+//       printf("%f ", nnMatrixAt(matrix, i, j));
+//     }
+//     printf("\n");
+//   }
+// }
+
+TEST_CASE(nnMatrixMake_1x1) {
+  nnMatrix A = nnMatrixMake(1, 1);
+  TEST_EQUAL(A.rows, 1);
+  TEST_EQUAL(A.cols, 1);
+}
+
+TEST_CASE(nnMatrixMake_3x1) {
+  nnMatrix A = nnMatrixMake(3, 1);
+  TEST_EQUAL(A.rows, 3);
+  TEST_EQUAL(A.cols, 1);
+}
+
+TEST_CASE(nnMatrixInit_3x1) {
+  nnMatrix A = nnMatrixMake(3, 1);
+  nnMatrixInit(&A, (R[]) { 1, 2, 3 });
+  TEST_EQUAL(A.values[0], 1);
+  TEST_EQUAL(A.values[1], 2);
+  TEST_EQUAL(A.values[2], 3);
+}
+
+TEST_CASE(nnMatrixCopyCol_test) {
+  nnMatrix A = nnMatrixMake(3, 2);
+  nnMatrix B = nnMatrixMake(3, 1);
+
+  nnMatrixInit(&A, (R[]) {
+    1, 2,
+    3, 4,
+    5, 6,
+  });
+
+  nnMatrixCopyCol(&A, &B, 1, 0);
+
+  TEST_EQUAL(nnMatrixAt(&B, 0, 0), 2);
+  TEST_EQUAL(nnMatrixAt(&B, 1, 0), 4);
+  TEST_EQUAL(nnMatrixAt(&B, 2, 0), 6);
+
+  nnMatrixDel(&A);
+  nnMatrixDel(&B);
+}
+
+TEST_CASE(nnMatrixMul_square_3x3) {
+  nnMatrix A = nnMatrixMake(3, 3);
+  nnMatrix B = nnMatrixMake(3, 3);
+  nnMatrix O = nnMatrixMake(3, 3);
+
+  nnMatrixInit(&A, (const R[]){
+    1, 2, 3,
+    4, 5, 6,
+    7, 8, 9,
+  });
+  nnMatrixInit(&B, (const R[]){
+    2, 4, 3,
+    6, 8, 5,
+    1, 7, 9,
+  });
+  nnMatrixMul(&A, &B, &O);
+
+  const R expected[3][3] = {
+    { 17, 41, 40 },
+    { 44, 98, 91 },
+    { 71, 155, 142 },
+  };
+  for (int i = 0; i < O.rows; ++i) {
+    for (int j = 0; j < O.cols; ++j) {
+      TEST_TRUE(double_eq(nnMatrixAt(&O, i, j), expected[i][j], EPS));
+    }
+  }
+
+  nnMatrixDel(&A);
+  nnMatrixDel(&B);
+  nnMatrixDel(&O);
+}
+
+TEST_CASE(nnMatrixMul_non_square_2x3_3x1) {
+  nnMatrix A = nnMatrixMake(2, 3);
+  nnMatrix B = nnMatrixMake(3, 1);
+  nnMatrix O = nnMatrixMake(2, 1);
+
+  nnMatrixInit(&A, (const R[]){
+    1, 2, 3,
+    4, 5, 6,
+  });
+  nnMatrixInit(&B, (const R[]){
+    2,
+    6,
+    1,
+  });
+  nnMatrixMul(&A, &B, &O);
+
+  const R expected[2][1] = {
+    { 17 },
+    { 44 },
+  };
+  for (int i = 0; i < O.rows; ++i) {
+    for (int j = 0; j < O.cols; ++j) {
+      TEST_TRUE(double_eq(nnMatrixAt(&O, i, j), expected[i][j], EPS));
+    }
+  }
+
+  nnMatrixDel(&A);
+  nnMatrixDel(&B);
+  nnMatrixDel(&O);
+}
+
+TEST_CASE(nnMatrixMulAdd_test) {
+  nnMatrix A = nnMatrixMake(2, 3);
+  nnMatrix B = nnMatrixMake(2, 3);
+  nnMatrix O = nnMatrixMake(2, 3);
+  const R scale = 2;
+
+  nnMatrixInit(&A, (const R[]){
+    1, 2, 3,
+    4, 5, 6,
+  });
+  nnMatrixInit(&B, (const R[]){
+    2, 3, 1,
+    7, 4, 3
+  });
+  nnMatrixMulAdd(&A, &B, scale, &O);  // O = A + B * scale
+
+  const R expected[2][3] = {
+    { 5, 8, 5 },
+    { 18, 13, 12 },
+  };
+  for (int i = 0; i < O.rows; ++i) {
+    for (int j = 0; j < O.cols; ++j) {
+      TEST_TRUE(double_eq(nnMatrixAt(&O, i, j), expected[i][j], EPS));
+    }
+  }
+
+  nnMatrixDel(&A);
+  nnMatrixDel(&B);
+  nnMatrixDel(&O);
+}
+
+TEST_CASE(nnMatrixMulSub_test) {
+  nnMatrix A = nnMatrixMake(2, 3);
+  nnMatrix B = nnMatrixMake(2, 3);
+  nnMatrix O = nnMatrixMake(2, 3);
+  const R scale = 2;
+
+  nnMatrixInit(&A, (const R[]){
+    1, 2, 3,
+    4, 5, 6,
+  });
+  nnMatrixInit(&B, (const R[]){
+    2, 3, 1,
+    7, 4, 3
+  });
+  nnMatrixMulSub(&A, &B, scale, &O);  // O = A - B * scale
+
+  const R expected[2][3] = {
+    { -3, -4, 1 },
+    { -10, -3, 0 },
+  };
+  for (int i = 0; i < O.rows; ++i) {
+    for (int j = 0; j < O.cols; ++j) {
+      TEST_TRUE(double_eq(nnMatrixAt(&O, i, j), expected[i][j], EPS));
+    }
+  }
+
+  nnMatrixDel(&A);
+  nnMatrixDel(&B);
+  nnMatrixDel(&O);
+}
+
+TEST_CASE(nnMatrixMulPairs_2x3) {
+  nnMatrix A = nnMatrixMake(2, 3);
+  nnMatrix B = nnMatrixMake(2, 3);
+  nnMatrix O = nnMatrixMake(2, 3);
+
+  nnMatrixInit(&A, (const R[]){
+    1, 2, 3,
+    4, 5, 6,
+  });
+  nnMatrixInit(&B, (const R[]){
+    2, 3, 1,
+    7, 4, 3
+  });
+  nnMatrixMulPairs(&A, &B, &O);
+
+  const R expected[2][3] = {
+    { 2, 6, 3 },
+    { 28, 20, 18 },
+  };
+  for (int i = 0; i < O.rows; ++i) {
+    for (int j = 0; j < O.cols; ++j) {
+      TEST_TRUE(double_eq(nnMatrixAt(&O, i, j), expected[i][j], EPS));
+    }
+  }
+
+  nnMatrixDel(&A);
+  nnMatrixDel(&B);
+  nnMatrixDel(&O);
+}
+
+TEST_CASE(nnMatrixAdd_square_2x2) {
+  nnMatrix A = nnMatrixMake(2, 2);
+  nnMatrix B = nnMatrixMake(2, 2);
+  nnMatrix C = nnMatrixMake(2, 2);
+
+  nnMatrixInit(&A, (R[]) {
+    1, 2,
+    3, 4,
+  });
+  nnMatrixInit(&B, (R[]) {
+    2, 1,
+    5, 3,
+  });
+
+  nnMatrixAdd(&A, &B, &C);
+
+  TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 0), 3, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 1), 3, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 0), 8, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 1), 7, EPS));
+
+  nnMatrixDel(&A);
+  nnMatrixDel(&B);
+  nnMatrixDel(&C);
+}
+
+TEST_CASE(nnMatrixSub_square_2x2) {
+  nnMatrix A = nnMatrixMake(2, 2);
+  nnMatrix B = nnMatrixMake(2, 2);
+  nnMatrix C = nnMatrixMake(2, 2);
+
+  nnMatrixInit(&A, (R[]) {
+    1, 2,
+    3, 4,
+  });
+  nnMatrixInit(&B, (R[]) {
+    2, 1,
+    5, 3,
+  });
+
+  nnMatrixSub(&A, &B, &C);
+
+  TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 0), -1, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 1), +1, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 0), -2, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 1), +1, EPS));
+
+  nnMatrixDel(&A);
+  nnMatrixDel(&B);
+  nnMatrixDel(&C);
+}
+
+TEST_CASE(nnMatrixAddRow_test) {
+  nnMatrix A = nnMatrixMake(2, 3);
+  nnMatrix B = nnMatrixMake(1, 3);
+  nnMatrix C = nnMatrixMake(2, 3);
+
+  nnMatrixInit(&A, (R[]) {
+    1, 2, 3,
+    4, 5, 6,
+  });
+  nnMatrixInit(&B, (R[]) {
+    2, 1, 3,
+  });
+
+  nnMatrixAddRow(&A, &B, &C);
+
+  TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 0), 3, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 1), 3, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&C, 0, 2), 6, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 0), 6, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 1), 6, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&C, 1, 2), 9, EPS));
+
+  nnMatrixDel(&A);
+  nnMatrixDel(&B);
+  nnMatrixDel(&C);
+}
+
+TEST_CASE(nnMatrixTranspose_square_2x2) {
+  nnMatrix A = nnMatrixMake(2, 2);
+  nnMatrix B = nnMatrixMake(2, 2);
+
+  nnMatrixInit(&A, (R[]) {
+    1, 2,
+    3, 4
+  });
+
+  nnMatrixTranspose(&A, &B);
+  TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 0), 1, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 1), 3, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&B, 1, 0), 2, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&B, 1, 1), 4, EPS));
+
+  nnMatrixDel(&A);
+  nnMatrixDel(&B);
+}
+
+TEST_CASE(nnMatrixTranspose_non_square_2x1) {
+  nnMatrix A = nnMatrixMake(2, 1);
+  nnMatrix B = nnMatrixMake(1, 2);
+
+  nnMatrixInit(&A, (R[]) {
+    1,
+    3,
+  });
+
+  nnMatrixTranspose(&A, &B);
+  TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 0), 1, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 1), 3, EPS));
+
+  nnMatrixDel(&A);
+  nnMatrixDel(&B);
+}
+
+TEST_CASE(nnMatrixGt_test) {
+  nnMatrix A = nnMatrixMake(2, 3);
+  nnMatrix B = nnMatrixMake(2, 3);
+
+  nnMatrixInit(&A, (R[]) {
+    -3, 2, 0,
+    4, -1, 5
+  });
+
+  nnMatrixGt(&A, 0, &B);
+  TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 0), 0, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 1), 1, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&B, 0, 2), 0, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&B, 1, 0), 1, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&B, 1, 1), 0, EPS));
+  TEST_TRUE(double_eq(nnMatrixAt(&B, 1, 2), 1, EPS));
+
+  nnMatrixDel(&A);
+  nnMatrixDel(&B);
+}
diff --git a/src/lib/test/neuralnet_test.c b/src/lib/test/neuralnet_test.c
new file mode 100644
index 0000000..14d9438
--- /dev/null
+++ b/src/lib/test/neuralnet_test.c
@@ -0,0 +1,92 @@
+#include <neuralnet/neuralnet.h>
+
+#include <neuralnet/matrix.h>
+#include "activation.h"
+#include "neuralnet_impl.h"
+
+#include "test.h"
+#include "test_util.h"
+
+#include <assert.h>
+
+TEST_CASE(neuralnet_perceptron_test) {
+  const int num_layers = 1;
+  const int layer_sizes[] = { 1, 1 };
+  const nnActivation layer_activations[] = { nnSigmoid };
+  const R weights[] = { 0.3 };
+
+  nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations);
+  assert(net);
+  nnSetWeights(net, weights);
+
+  nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/1);
+
+  const R input[] = { 0.9 };
+  R output[1];
+  nnQueryArray(net, query, input, output);
+
+  const R expected_output = sigmoid(input[0] * weights[0]);
+  printf("\nOutput: %f, Expected: %f\n", output[0], expected_output);
+  TEST_TRUE(double_eq(output[0], expected_output, EPS));
+
+  nnDeleteQueryObject(&query);
+  nnDeleteNet(&net);
+}
+
+TEST_CASE(neuralnet_xor_test) {
+  const int num_layers = 2;
+  const int layer_sizes[] = { 2, 2, 1 };
+  const nnActivation layer_activations[] = { nnRelu, nnIdentity };
+  const R weights[] = {
+    1, 1, 1, 1,  // First (hidden) layer.
+    1, -2        // Second (output) layer.
+  };
+  const R biases[] = {
+    0, -1,  // First (hidden) layer.
+    0       // Second (output) layer.
+  };
+
+  nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations);
+  assert(net);
+  nnSetWeights(net, weights);
+  nnSetBiases(net, biases);
+
+  // First layer weights.
+  TEST_EQUAL(nnMatrixAt(&net->weights[0], 0, 0), 1);
+  TEST_EQUAL(nnMatrixAt(&net->weights[0], 0, 1), 1);
+  TEST_EQUAL(nnMatrixAt(&net->weights[0], 0, 2), 1);
+  TEST_EQUAL(nnMatrixAt(&net->weights[0], 0, 3), 1);
+  // Second layer weights.
+  TEST_EQUAL(nnMatrixAt(&net->weights[1], 0, 0), 1);
+  TEST_EQUAL(nnMatrixAt(&net->weights[1], 0, 1), -2);
+  // First layer biases.
+  TEST_EQUAL(nnMatrixAt(&net->biases[0], 0, 0), 0);
+  TEST_EQUAL(nnMatrixAt(&net->biases[0], 0, 1), -1);
+  // Second layer biases.
+  TEST_EQUAL(nnMatrixAt(&net->biases[1], 0, 0), 0);
+
+  // Test.
+
+  #define M 4
+
+  nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/M);
+
+  const R test_inputs[M][2] = { { 0., 0. }, { 1., 0. }, { 0., 1. }, { 1., 1. } };
+  nnMatrix test_inputs_matrix = nnMatrixMake(M, 2);
+  nnMatrixInit(&test_inputs_matrix, (const R*)test_inputs);
+  nnQuery(net, query, &test_inputs_matrix);
+
+  const R expected_outputs[M] = { 0., 1., 1., 0. };
+  for (int i = 0; i < M; ++i) {
+    const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0);
+    printf("\nInput: (%f, %f), Output: %f, Expected: %f\n",
+      test_inputs[i][0], test_inputs[i][1], test_output, expected_outputs[i]);
+  }
+  for (int i = 0; i < M; ++i) {
+    const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0);
+    TEST_TRUE(double_eq(test_output, expected_outputs[i], OUTPUT_EPS));
+  }
+
+  nnDeleteQueryObject(&query);
+  nnDeleteNet(&net);
+}
diff --git a/src/lib/test/test.h b/src/lib/test/test.h
new file mode 100644
index 0000000..fd8dc22
--- /dev/null
+++ b/src/lib/test/test.h
@@ -0,0 +1,185 @@
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#ifdef UNIT_TEST
+
+#include <stdbool.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#if defined(__DragonFly__) || defined(__FreeBSD__) || defined(__FreeBSD_kernel__) ||     \
+    defined(__NetBSD__) || defined(__OpenBSD__)
+#define USE_SYSCTL_FOR_ARGS 1
+// clang-format off
+#include <sys/types.h>
+#include <sys/sysctl.h>
+// clang-format on
+#include <unistd.h>        // getpid
+#endif
+
+struct test_file_metadata;
+
+struct test_failure {
+	bool present;
+	const char *message;
+	const char *file;
+	int line;
+};
+
+struct test_case_metadata {
+	void (*fn)(struct test_case_metadata *, struct test_file_metadata *);
+	struct test_failure failure;
+	const char *name;
+	struct test_case_metadata *next;
+};
+
+struct test_file_metadata {
+	bool registered;
+	const char *name;
+	struct test_file_metadata *next;
+	struct test_case_metadata *tests;
+};
+
+struct test_file_metadata __attribute__((weak)) * test_file_head;
+
+#define SET_FAILURE(_message)                                                             \
+	metadata->failure = (struct test_failure) {                                       \
+		.message = _message, .file = __FILE__, .line = __LINE__, .present = true, \
+	}
+
+#define TEST_EQUAL(a, b)                                                                 \
+	do {                                                                             \
+		if ((a) != (b)) {                                                        \
+			SET_FAILURE(#a " != " #b);                                       \
+			return;                                                          \
+		}                                                                        \
+	} while (0)
+
+#define TEST_TRUE(a)                                                                     \
+	do {                                                                             \
+		if (!(a)) {                                                              \
+			SET_FAILURE(#a " is not true");                                  \
+			return;                                                          \
+		}                                                                        \
+	} while (0)
+
+#define TEST_STREQUAL(a, b)                                                              \
+	do {                                                                             \
+		if (strcmp(a, b) != 0) {                                                 \
+			SET_FAILURE(#a " != " #b);                                       \
+			return;                                                          \
+		}                                                                        \
+	} while (0)
+
+#define TEST_CASE(_name)                                                                  \
+	static void __test_h_##_name(struct test_case_metadata *,                         \
+	                             struct test_file_metadata *);                        \
+	static struct test_file_metadata __test_h_file;                                   \
+	static struct test_case_metadata __test_h_meta_##_name = {                        \
+	    .name = #_name,                                                               \
+	    .fn = __test_h_##_name,                                                       \
+	};                                                                                \
+	static void __attribute__((constructor(101))) __test_h_##_name##_register(void) { \
+		__test_h_meta_##_name.next = __test_h_file.tests;                         \
+		__test_h_file.tests = &__test_h_meta_##_name;                             \
+		if (!__test_h_file.registered) {                                          \
+			__test_h_file.name = __FILE__;                                    \
+			__test_h_file.next = test_file_head;                              \
+			test_file_head = &__test_h_file;                                  \
+			__test_h_file.registered = true;                                  \
+		}                                                                         \
+	}                                                                                 \
+	static void __test_h_##_name(                                                     \
+	    struct test_case_metadata *metadata __attribute__((unused)),                  \
+	    struct test_file_metadata *file_metadata __attribute__((unused)))
+
+extern void __attribute__((weak)) (*test_h_unittest_setup)(void);
+/// Run defined tests, return true if all tests succeeds
+/// @param[out] tests_run if not NULL, set to whether tests were run
+static inline void __attribute__((constructor(102))) run_tests(void) {
+	bool should_run = false;
+#ifdef USE_SYSCTL_FOR_ARGS
+	int mib[] = {
+		CTL_KERN,
+#if defined(__NetBSD__) || defined(__OpenBSD__)
+		KERN_PROC_ARGS,
+		getpid(),
+		KERN_PROC_ARGV,
+#else
+		KERN_PROC,
+		KERN_PROC_ARGS,
+		getpid(),
+#endif
+	};
+	char *arg = NULL;
+	size_t arglen;
+	sysctl(mib, sizeof(mib) / sizeof(mib[0]), NULL, &arglen, NULL, 0);
+	arg = malloc(arglen);
+	sysctl(mib, sizeof(mib) / sizeof(mib[0]), arg, &arglen, NULL, 0);
+#else
+	FILE *cmdlinef = fopen("/proc/self/cmdline", "r");
+	char *arg = NULL;
+	int arglen;
+	fscanf(cmdlinef, "%ms%n", &arg, &arglen);
+	fclose(cmdlinef);
+#endif
+	for (char *pos = arg; pos < arg + arglen; pos += strlen(pos) + 1) {
+		if (strcmp(pos, "--unittest") == 0) {
+			should_run = true;
+			break;
+		}
+	}
+	free(arg);
+
+	if (!should_run) {
+		return;
+	}
+
+	if (&test_h_unittest_setup) {
+		test_h_unittest_setup();
+	}
+
+	struct test_file_metadata *i = test_file_head;
+	int failed = 0, success = 0;
+	while (i) {
+		fprintf(stderr, "Running tests from %s:\n", i->name);
+		struct test_case_metadata *j = i->tests;
+		while (j) {
+			fprintf(stderr, "\t%s ... ", j->name);
+			j->failure.present = false;
+			j->fn(j, i);
+			if (j->failure.present) {
+				fprintf(stderr, "failed (%s at %s:%d)\n", j->failure.message,
+				        j->failure.file, j->failure.line);
+				failed++;
+			} else {
+				fprintf(stderr, "passed\n");
+				success++;
+			}
+			j = j->next;
+		}
+		fprintf(stderr, "\n");
+		i = i->next;
+	}
+	int total = failed + success;
+	fprintf(stderr, "Test results: passed %d/%d, failed %d/%d\n", success, total,
+	        failed, total);
+	exit(failed == 0 ? EXIT_SUCCESS : EXIT_FAILURE);
+}
+
+#else
+
+#include <stdbool.h>
+
+#define TEST_CASE(name) static void __attribute__((unused)) __test_h_##name(void)
+
+#define TEST_EQUAL(a, b)                                                                 \
+	(void)(a);                                                                       \
+	(void)(b)
+#define TEST_TRUE(a) (void)(a)
+#define TEST_STREQUAL(a, b)                                                              \
+	(void)(a);                                                                       \
+	(void)(b)
+
+#endif
diff --git a/src/lib/test/test_main.c b/src/lib/test/test_main.c
new file mode 100644
index 0000000..4cce7f6
--- /dev/null
+++ b/src/lib/test/test_main.c
@@ -0,0 +1,3 @@
+int main() {
+  return 0;
+}
diff --git a/src/lib/test/test_util.h b/src/lib/test/test_util.h
new file mode 100644
index 0000000..8abb99a
--- /dev/null
+++ b/src/lib/test/test_util.h
@@ -0,0 +1,22 @@
+#pragma once
+
+#include <neuralnet/types.h>
+
+#include <math.h>
+
+// General epsilon for comparing values.
+static const R EPS = 1e-10;
+
+// Epsilon for comparing network weights after training.
+static const R WEIGHT_EPS = 0.01;
+
+// Epsilon for comparing network outputs after training.
+static const R OUTPUT_EPS = 0.01;
+
+static inline bool double_eq(double a, double b, double eps) {
+  return fabs(a - b) <= eps;
+}
+
+static inline R lerp(R a, R b, R t) {
+  return a + t*(b-a);
+}
diff --git a/src/lib/test/train_linear_perceptron_non_origin_test.c b/src/lib/test/train_linear_perceptron_non_origin_test.c
new file mode 100644
index 0000000..5a320ac
--- /dev/null
+++ b/src/lib/test/train_linear_perceptron_non_origin_test.c
@@ -0,0 +1,67 @@
+#include <neuralnet/train.h>
+
+#include <neuralnet/matrix.h>
+#include <neuralnet/neuralnet.h>
+#include "activation.h"
+#include "neuralnet_impl.h"
+
+#include "test.h"
+#include "test_util.h"
+
+#include <assert.h>
+
+TEST_CASE(neuralnet_train_linear_perceptron_non_origin_test) {
+  const int num_layers = 1;
+  const int layer_sizes[] = { 1, 1 };
+  const nnActivation layer_activations[] = { nnIdentity };
+
+  nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations);
+  assert(net);
+
+  // Train.
+
+  // Try to learn the Y = 2X + 1 line.
+  #define N 2
+  const R inputs[N]  = { 0., 1. };
+  const R targets[N] = { 1., 3. };
+
+  nnMatrix inputs_matrix  = nnMatrixMake(N, 1);
+  nnMatrix targets_matrix = nnMatrixMake(N, 1);
+  nnMatrixInit(&inputs_matrix, inputs);
+  nnMatrixInit(&targets_matrix, targets);
+
+  nnTrainingParams params = {
+    .learning_rate = 0.7,
+    .max_iterations = 20,
+    .seed = 0,
+    .weight_init = nnWeightInit01,
+    .debug = false,
+  };
+
+  nnTrain(net, &inputs_matrix, &targets_matrix, &params);
+
+  const R weight = nnMatrixAt(&net->weights[0], 0, 0);
+  const R expected_weight = 2.0;
+  printf("\nTrained network weight: %f, Expected: %f\n", weight, expected_weight);
+  TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS));
+
+  const R bias = nnMatrixAt(&net->biases[0], 0, 0);
+  const R expected_bias = 1.0;
+  printf("Trained network bias: %f, Expected: %f\n", bias, expected_bias);
+  TEST_TRUE(double_eq(bias, expected_bias, WEIGHT_EPS));
+
+  // Test.
+
+  nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/1);
+
+  const R test_input[] = { 2.3 };
+  R test_output[1];
+  nnQueryArray(net, query, test_input, test_output);
+
+  const R expected_output = test_input[0] * expected_weight + expected_bias;
+  printf("Output: %f, Expected: %f\n", test_output[0], expected_output);
+  TEST_TRUE(double_eq(test_output[0], expected_output, OUTPUT_EPS));
+
+  nnDeleteQueryObject(&query);
+  nnDeleteNet(&net);
+}
diff --git a/src/lib/test/train_linear_perceptron_test.c b/src/lib/test/train_linear_perceptron_test.c
new file mode 100644
index 0000000..2b1336d
--- /dev/null
+++ b/src/lib/test/train_linear_perceptron_test.c
@@ -0,0 +1,62 @@
+#include <neuralnet/train.h>
+
+#include <neuralnet/matrix.h>
+#include <neuralnet/neuralnet.h>
+#include "activation.h"
+#include "neuralnet_impl.h"
+
+#include "test.h"
+#include "test_util.h"
+
+#include <assert.h>
+
+TEST_CASE(neuralnet_train_linear_perceptron_test) {
+  const int num_layers = 1;
+  const int layer_sizes[] = { 1, 1 };
+  const nnActivation layer_activations[] = { nnIdentity };
+
+  nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations);
+  assert(net);
+
+  // Train.
+
+  // Try to learn the Y=X line.
+  #define N 2
+  const R inputs[N]  = { 0., 1. };
+  const R targets[N] = { 0., 1. };
+
+  nnMatrix inputs_matrix  = nnMatrixMake(N, 1);
+  nnMatrix targets_matrix = nnMatrixMake(N, 1);
+  nnMatrixInit(&inputs_matrix, inputs);
+  nnMatrixInit(&targets_matrix, targets);
+
+  nnTrainingParams params = {
+    .learning_rate = 0.7,
+    .max_iterations = 10,
+    .seed = 0,
+    .weight_init = nnWeightInit01,
+    .debug = false,
+  };
+
+  nnTrain(net, &inputs_matrix, &targets_matrix, &params);
+
+  const R weight = nnMatrixAt(&net->weights[0], 0, 0);
+  const R expected_weight = 1.0;
+  printf("\nTrained network weight: %f, Expected: %f\n", weight, expected_weight);
+  TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS));
+
+  // Test.
+
+  nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/1);
+
+  const R test_input[] = { 2.3 };
+  R test_output[1];
+  nnQueryArray(net, query, test_input, test_output);
+
+  const R expected_output = test_input[0];
+  printf("Output: %f, Expected: %f\n", test_output[0], expected_output);
+  TEST_TRUE(double_eq(test_output[0], expected_output, OUTPUT_EPS));
+
+  nnDeleteQueryObject(&query);
+  nnDeleteNet(&net);
+}
diff --git a/src/lib/test/train_sigmoid_test.c b/src/lib/test/train_sigmoid_test.c
new file mode 100644
index 0000000..588e7ca
--- /dev/null
+++ b/src/lib/test/train_sigmoid_test.c
@@ -0,0 +1,66 @@
+#include <neuralnet/train.h>
+
+#include <neuralnet/matrix.h>
+#include <neuralnet/neuralnet.h>
+#include "activation.h"
+#include "neuralnet_impl.h"
+
+#include "test.h"
+#include "test_util.h"
+
+#include <assert.h>
+
+TEST_CASE(neuralnet_train_sigmoid_test) {
+  const int num_layers = 1;
+  const int layer_sizes[] = { 1, 1 };
+  const nnActivation layer_activations[] = { nnSigmoid };
+
+  nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations);
+  assert(net);
+
+  // Train.
+
+  // Try to learn the sigmoid function.
+  #define N 3
+  R inputs[N];
+  R targets[N];
+  for (int i = 0; i < N; ++i) {
+    inputs[i] = lerp(-1, +1, (R)i / (R)(N-1));
+    targets[i] = sigmoid(inputs[i]);
+  }
+
+  nnMatrix inputs_matrix  = nnMatrixMake(N, 1);
+  nnMatrix targets_matrix = nnMatrixMake(N, 1);
+  nnMatrixInit(&inputs_matrix, inputs);
+  nnMatrixInit(&targets_matrix, targets);
+
+  nnTrainingParams params = {
+    .learning_rate = 0.9,
+    .max_iterations = 100,
+    .seed = 0,
+    .weight_init = nnWeightInit01,
+    .debug = false,
+  };
+
+  nnTrain(net, &inputs_matrix, &targets_matrix, &params);
+
+  const R weight = nnMatrixAt(&net->weights[0], 0, 0);
+  const R expected_weight = 1.0;
+  printf("\nTrained network weight: %f, Expected: %f\n", weight, expected_weight);
+  TEST_TRUE(double_eq(weight, expected_weight, WEIGHT_EPS));
+
+  // Test.
+
+  nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/1);
+
+  const R test_input[] = { 0.3 };
+  R test_output[1];
+  nnQueryArray(net, query, test_input, test_output);
+
+  const R expected_output = 0.574442516811659;  // sigmoid(0.3)
+  printf("Output: %f, Expected: %f\n", test_output[0], expected_output);
+  TEST_TRUE(double_eq(test_output[0], expected_output, OUTPUT_EPS));
+
+  nnDeleteQueryObject(&query);
+  nnDeleteNet(&net);
+}
diff --git a/src/lib/test/train_xor_test.c b/src/lib/test/train_xor_test.c
new file mode 100644
index 0000000..6ddc6e0
--- /dev/null
+++ b/src/lib/test/train_xor_test.c
@@ -0,0 +1,66 @@
+#include <neuralnet/train.h>
+
+#include <neuralnet/matrix.h>
+#include <neuralnet/neuralnet.h>
+#include "activation.h"
+#include "neuralnet_impl.h"
+
+#include "test.h"
+#include "test_util.h"
+
+#include <assert.h>
+
+TEST_CASE(neuralnet_train_xor_test) {
+  const int num_layers = 2;
+  const int layer_sizes[] = { 2, 2, 1 };
+  const nnActivation layer_activations[] = { nnRelu, nnIdentity };
+
+  nnNeuralNetwork* net = nnMakeNet(num_layers, layer_sizes, layer_activations);
+  assert(net);
+
+  // Train.
+
+  #define N 4
+  const R inputs[N][2]  = { { 0., 0. }, { 0., 1. }, { 1., 0. }, { 1., 1. } };
+  const R targets[N] = { 0., 1., 1., 0. };
+
+  nnMatrix inputs_matrix  = nnMatrixMake(N, 2);
+  nnMatrix targets_matrix = nnMatrixMake(N, 1);
+  nnMatrixInit(&inputs_matrix, (const R*)inputs);
+  nnMatrixInit(&targets_matrix, targets);
+
+  nnTrainingParams params = {
+    .learning_rate = 0.1,
+    .max_iterations = 500,
+    .seed = 0,
+    .weight_init = nnWeightInit01,
+    .debug = false,
+  };
+
+  nnTrain(net, &inputs_matrix, &targets_matrix, &params);
+
+  // Test.
+
+  #define M 4
+
+  nnQueryObject* query = nnMakeQueryObject(net, /*num_inputs=*/M);
+
+  const R test_inputs[M][2] = { { 0., 0. }, { 1., 0. }, { 0., 1. }, { 1., 1. } };
+  nnMatrix test_inputs_matrix = nnMatrixMake(M, 2);
+  nnMatrixInit(&test_inputs_matrix, (const R*)test_inputs);
+  nnQuery(net, query, &test_inputs_matrix);
+
+  const R expected_outputs[M] = { 0., 1., 1., 0. };
+  for (int i = 0; i < M; ++i) {
+    const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0);
+    printf("\nInput: (%f, %f), Output: %f, Expected: %f\n",
+      test_inputs[i][0], test_inputs[i][1], test_output, expected_outputs[i]);
+  }
+  for (int i = 0; i < M; ++i) {
+    const R test_output = nnMatrixAt(nnNetOutputs(query), i, 0);
+    TEST_TRUE(double_eq(test_output, expected_outputs[i], OUTPUT_EPS));
+  }
+
+  nnDeleteQueryObject(&query);
+  nnDeleteNet(&net);
+}
-- 
cgit v1.2.3