MNIST: digit recognition with kNN (no training)
Implementation in pure PHP
In the previous case we recognized digits with Naive Bayes, a probabilistic model that learns from the training data and builds a global representation of each class. Now we approach the same task from the opposite direction. There is no training in the usual sense: we do not search for parameters, optimize a loss function, or build a model. The "training" step simply stores the labeled images. To classify a new image, we compare it with every stored image, pick the k most similar ones (the smallest distances), and let their labels vote. This is exactly the k-nearest neighbors (kNN) algorithm.
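To make the idea concrete, here is a minimal sketch of the whole procedure on five hand-made 2D points instead of 784-pixel images. The data and the function names `euclidean()` and `knnClassify()` are hypothetical, chosen only for illustration. Distance is the Euclidean distance, sqrt(sum over i of (a_i - b_i)^2), the same metric the PHP class in this article uses.

```php
<?php
// Toy data (hypothetical, not MNIST): 2D points labeled 'zero' or 'one'.
$points = [[0, 0], [0, 1], [1, 0], [5, 5], [6, 5]];
$labels = ['zero', 'zero', 'zero', 'one', 'one'];

// Euclidean distance: square root of the summed squared coordinate differences.
function euclidean(array $a, array $b): float
{
    $sum = 0.0;
    foreach ($a as $i => $value) {
        $sum += ($value - $b[$i]) ** 2;
    }
    return sqrt($sum);
}

function knnClassify(array $points, array $labels, array $query, int $k): string
{
    // 1. "Training" is just the stored arrays: measure the query against every point.
    $distances = [];
    foreach ($points as $i => $point) {
        $distances[$i] = euclidean($point, $query);
    }
    // 2. Sort ascending while keeping the indices as keys, then keep the k closest.
    asort($distances);
    $nearest = array_slice($distances, 0, $k, true);
    // 3. Each neighbor casts one vote for its label; the most frequent label wins.
    $votes = [];
    foreach (array_keys($nearest) as $i) {
        $votes[$labels[$i]] = ($votes[$labels[$i]] ?? 0) + 1;
    }
    arsort($votes);
    return (string) array_key_first($votes);
}

echo knnClassify($points, $labels, [0.5, 0.5], 3), PHP_EOL; // zero
echo knnClassify($points, $labels, [5, 4], 3), PHP_EOL;     // one
```

For the query [5, 4] one of the three nearest points, [1, 0], is labeled 'zero', but the two 'one' neighbors outvote it. That majority vote is the entire decision rule.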
Example of code: Class KNearestNeighbors
<?php
namespace app\classes;
/**
* Simple K-Nearest Neighbors (KNN) classifier.
*/
class KNearestNeighbors {
/**
* Training feature vectors.
*/
private array $trainSamples = [];
/**
* Training labels aligned by index with `$trainSamples`.
*/
private array $trainLabels = [];
public function __construct(array $samples, array $labels) {
$this->trainSamples = $samples;
$this->trainLabels = $labels;
}
/**
* Compute Euclidean distance between two vectors.
*/
public function euclideanDistance(array $a, array $b): float {
$sum = 0.0;
foreach ($a as $i => $value) {
$diff = $value - $b[$i];
$sum += $diff * $diff;
}
return sqrt($sum);
}
/**
* Predict a label for a single query vector.
*
* @param array $query Feature vector to classify
* @param int $k Number of neighbors to vote (top-K)
* @param int|null $trainLimit Optional cap on how many training rows are used (null = all rows)
*
* @return string|int|null The winning label (depends on how labels were provided)
*/
public function predict(array $query, int $k = 3, ?int $trainLimit = null): string|int|null {
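// null means "no cap": use the full training set (min() with null would return null and no rows would be compared).
$trainLimit = min($trainLimit ?? PHP_INT_MAX, count($this->trainSamples), count($this->trainLabels));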
$k = max(1, min($k, $trainLimit));
$distances = [];
for ($i = 0; $i < $trainLimit; $i++) {
$distances[] = [
'distance' => $this->euclideanDistance($this->trainSamples[$i], $query),
'label' => $this->trainLabels[$i],
];
}
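// Sort all candidates by distance (closest first) and keep the top k.
usort($distances, static fn ($a, $b) => $a['distance'] <=> $b['distance']);
$neighbors = array_slice($distances, 0, $k);
// Count votes per label. arsort() is stable since PHP 8.0, so on a tie the label that appears first among the sorted neighbors wins.
$votes = array_count_values(array_column($neighbors, 'label'));
arsort($votes);
return array_key_first($votes);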
}
/**
* Predict labels for several query vectors, in input order.
*/
public function predictBatch(array $X, int $k = 3, ?int $trainLimit = null): array {
$predictions = [];
foreach ($X as $x) {
$predictions[] = $this->predict($x, $k, $trainLimit);
}
return $predictions;
}
/**
* Compute simple accuracy score for a labeled dataset.
*/
public function score(array $X, array $y, int $k = 3, ?int $trainLimit = null): float {
$predictions = $this->predictBatch($X, $k, $trainLimit);
$correct = 0;
foreach ($predictions as $i => $prediction) {
if (isset($y[$i]) && $prediction === $y[$i]) {
$correct++;
}
}
return count($y) > 0 ? ($correct / count($y)) : 0.0;
}
}
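The `predict()` method above sorts all n distances to take the top k. Below is a minimal sketch of an alternative, assuming PHP 8.0+: it ranks by squared distance (sqrt() is monotonic, so the nearest neighbors do not change) and keeps only k candidates in an `SplMaxHeap`. The functions `squaredDistance()`, `nearestIndices()` and `knnPredictHeap()` are hypothetical helpers, not part of the article's class.

```php
<?php
// Hypothetical helpers (not part of the article's class); requires PHP 8.0+.

// Squared Euclidean distance: sqrt() is monotonic, so skipping it
// does not change which training rows are nearest.
function squaredDistance(array $a, array $b): float
{
    $sum = 0.0;
    foreach ($a as $i => $value) {
        $diff = $value - $b[$i];
        $sum += $diff * $diff;
    }
    return $sum;
}

// Return the indices of the k nearest rows, nearest first.
// A max-heap of [distance, index] pairs holds at most k candidates;
// when it grows past k, the farthest candidate (the heap top) is dropped.
function nearestIndices(array $samples, array $query, int $k): array
{
    $heap = new SplMaxHeap();
    foreach ($samples as $i => $sample) {
        $heap->insert([squaredDistance($sample, $query), $i]);
        if ($heap->count() > $k) {
            $heap->extract();
        }
    }
    // Iterating a heap extracts from the top, i.e. farthest first.
    $candidates = iterator_to_array($heap, false);
    return array_column(array_reverse($candidates), 1);
}

// Majority vote over the k nearest labels, in the same style as predict().
function knnPredictHeap(array $samples, array $labels, array $query, int $k = 3): string|int|null
{
    $k = max(1, min($k, count($samples)));
    $neighborLabels = [];
    foreach (nearestIndices($samples, $query, $k) as $i) {
        $neighborLabels[] = $labels[$i];
    }
    $votes = array_count_values($neighborLabels);
    arsort($votes); // stable since PHP 8.0: tied labels keep first-seen (nearest) order
    return array_key_first($votes);
}

$samples = [[0, 0], [0, 1], [5, 5], [5, 6], [6, 5]];
$labels = ['zero', 'zero', 'one', 'one', 'one'];
echo knnPredictHeap($samples, $labels, [0.2, 0.3]), PHP_EOL; // zero
echo knnPredictHeap($samples, $labels, [5.2, 5.1]), PHP_EOL; // one
```

The gain is modest: selecting the top k drops from O(n log n) to O(n log k) and avoids building an n-element array of candidates, but the distance loop itself (n rows times 784 pixels for MNIST) remains the dominant cost per query in both versions.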
Example of code: Class MnistLoader
<?php
namespace app\classes;
use Exception;
/**
* MnistLoader - Utility class for loading and preprocessing MNIST digit dataset
*
* This class provides methods to load MNIST data from CSV files with various
* preprocessing options including normalization, digit filtering, and label formatting.
* It supports both array-based loading (for custom implementations) and
* iterable loading (for Rubix ML compatibility).
*/
class MnistLoader {
/**
* Load MNIST data as arrays for custom ML implementations
*/
public static function load(string $file, string $directory = '', bool $categoricalLabels = false, bool $normalize = true, array $digits = [0, 1]): array {
$features = []; // 2D array: each element is an array of 784 pixel values
$labels = []; // 1D array: each element is the corresponding digit label
$handle = self::openFile($file, $directory);
// Process each row in the CSV file
// Each row contains: [label, pixel1, pixel2, ..., pixel784]
while (($row = fgetcsv($handle)) !== false) {
if ($processed = self::processRow($row, $digits, $normalize, $categoricalLabels)) {
$features[] = $processed[0];
$labels[] = $processed[1];
}
}
// Release the file handle once every row has been read.
fclose($handle);
// Return features and labels in the format expected by callers.
return [$features, $labels];
}
/**
* Load MNIST data as an iterator for Rubix ML compatibility
*/
public static function loadIterable(string $file, string $directory = '', bool $categoricalLabels = false, bool $normalize = true, array $digits = [0, 1]): iterable {
$handle = self::openFile($file, $directory);
// Read the CSV file row by row and keep only valid samples.
while (($row = fgetcsv($handle)) !== false) {
if ($processed = self::processRow($row, $digits, $normalize, $categoricalLabels)) {
yield array_merge($processed[0], [$processed[1]]);
}
}
// Close the file once the generator has yielded every row.
fclose($handle);
}
private static function openFile(string $file, string $directory) {
$handle = @fopen($directory . $file, 'r');
if ($handle === false) {
throw new Exception('Dataset file not found: ' . $directory . $file);
}
return $handle;
}
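/**
* Parse one CSV row [label, pixel1, ..., pixel784]; return [pixels, label] or null to skip it.
*/
private static function processRow(array $row, array $digits, bool $normalize, bool $categoricalLabels): ?array {
// Skip blank lines (fgetcsv returns [null]) and non-numeric header rows such as "label,pixel0,...".
if ($row === [] || $row[0] === null || !is_numeric($row[0])) {
return null;
}
$label = (int)$row[0];
if (!in_array($label, $digits)) {
return null;
}
// Cast pixels to float so they are treated as continuous features (Rubix ML treats strings as categorical).
$pixels = array_map(static fn ($v): float => $normalize ? ((float)$v) / 255.0 : (float)$v, array_slice($row, 1));
// Map each digit to a word label, e.g. 0 => 'zero', 1 => 'one'.
$names = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine'];
$formattedLabel = $categoricalLabels ? $names[$label] : $label;
return [$pixels, $formattedLabel];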
}
}
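The loader expects each CSV line to hold the label first, followed by the pixel values. Here is a minimal sketch of that row handling on in-memory lines, using a hypothetical `parseRow()` function and four pixels per row instead of 784. It mirrors the loader's digit filtering and division by 255; the check that skips a non-numeric header line is an added assumption about how the CSV files may look.

```php
<?php
// Hypothetical stand-in for the loader's row handling, applied to in-memory lines.
// Real MNIST rows have 1 label + 784 pixels; these toy rows have 4 pixels.
function parseRow(string $line, array $digits, bool $normalize = true): ?array
{
    $row = explode(',', trim($line));
    // Skip blank lines and a non-numeric header such as "label,pixel0,...".
    if (!is_numeric($row[0])) {
        return null;
    }
    $label = (int) $row[0];
    if (!in_array($label, $digits)) {
        return null; // digit not requested, e.g. a 7 when only [0, 1] are wanted
    }
    // Grayscale pixels are 0..255; dividing by 255 maps them into [0, 1].
    $pixels = array_map(
        static fn ($v): float => $normalize ? ((float) $v) / 255.0 : (float) $v,
        array_slice($row, 1)
    );
    return [$pixels, $label];
}

$csv = "label,p0,p1,p2,p3\n1,0,51,255,0\n7,10,20,30,40\n0,255,0,0,102\n";
$parsed = [];
foreach (explode("\n", $csv) as $line) {
    if (($result = parseRow($line, [0, 1])) !== null) {
        $parsed[] = $result;
    }
}
echo count($parsed), PHP_EOL;                         // 2
echo $parsed[0][1], ' ', $parsed[1][1], PHP_EOL;      // 1 0
```

The header line, the row labeled 7 and the trailing empty line are all skipped, leaving one sample each for digits 1 and 0.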
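Implementation in Rubix ML
Now the same approach, but using a library. Rubix ML provides a ready-made KNearestNeighbors classifier, so all we need to do is load the rows into a Labeled dataset and train the model on it.
Example of code: Rubix ML KNearestNeighbors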
<?php
use app\classes\MnistLoader;
use Rubix\ML\Classifiers\KNearestNeighbors;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Kernels\Distance\Euclidean;
try {
// Build the training and test datasets from the filtered CSV rows.
$trainRows = MnistLoader::loadIterable('train.csv', categoricalLabels: true, normalize: true, digits: [0, 1]);
$testRows = MnistLoader::loadIterable('test.csv', categoricalLabels: true, normalize: true, digits: [0, 1]);
$dataset = Labeled::fromIterator($trainRows);
$testDataset = Labeled::fromIterator($testRows);
} catch (Exception $e) {
echo '<div class="alert alert-danger" role="alert">' . htmlspecialchars($e->getMessage(), ENT_QUOTES, 'UTF-8') . '</div>';
exit;
}
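// k = 3 neighbors, unweighted majority vote (weighted = false), Euclidean distance.
$model = new KNearestNeighbors(3, false, new Euclidean());
// As in the plain PHP version, "training" only stores the samples and labels.
$model->train($dataset);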