MNIST: digit recognition with kNN (no training)

Implementation in pure PHP

In the previous case we recognized digits using Naive Bayes – a probabilistic model that trains on data and builds a global representation of classes. Now we will look at the same task from the opposite side. There will be no training in the usual sense: we will not search for parameters, optimize a loss function, or build a model. Instead, we do something much more direct: we compare a new image to already known ones and search for similar examples. This is exactly the k-nearest neighbors algorithm.

 
<?php

namespace app\classes;

/**
 * Simple K-Nearest Neighbors (KNN) classifier.
 */
class KNearestNeighbors {
    
/**
     * Training feature vectors.
     */
    
private array $trainSamples = [];

    
/**
     * Training labels aligned by index with `$trainSamples`.
     */
    
private array $trainLabels = [];

    public function 
__construct(array $samples, array $labels) {
        
$this->trainSamples $samples;
        
$this->trainLabels $labels;
    }

    
/**
     * Compute Euclidean distance between two vectors.
     */
    
public function euclideanDistance(array $a, array $b): float {
        
$sum 0.0;

        foreach (
$a as $i => $value) {
            
$diff $value $b[$i];
            
$sum += $diff $diff;
        }

        return 
sqrt($sum);
    }

    
/**
     * Predict a label for a single query vector.
     *
     * @param array $query Feature vector to classify
     * @param int $k Number of neighbors to vote (top-K)
     * @param int|null $trainLimit Optional cap on how many training rows are used
     *
     * @return string|int|null The winning label (depends on how labels were provided)
     */
    
public function predict(array $queryint $k 3, ?int $trainLimit null): string|int|null {
        
$trainLimit min($trainLimitcount($this->trainSamples), count($this->trainLabels));

        
$k max(1min($k$trainLimit));

        
$distances = [];

        for (
$i 0$i $trainLimit$i++) {
            
$distances[] = [
                
'distance' => $this->euclideanDistance($this->trainSamples[$i], $query),
                
'label' => $this->trainLabels[$i],
            ];
        }

        
usort($distances, static fn ($a$b) => $a['distance'] <=> $b['distance']);
        
$neighbors array_slice($distances0$k);

        
$votes array_count_values(array_column($neighbors'label'));

        
arsort($votes);

        return 
array_key_first($votes);
    }

    public function 
predictBatch(array $Xint $k 3, ?int $trainLimit null): array {
        
$predictions = [];

        foreach (
$X as $x) {
            
$predictions[] = $this->predict($x$k$trainLimit);
        }

        return 
$predictions;
    }

    
/**
     * Compute simple accuracy score for a labeled dataset.
     */
    
public function score(array $X, array $yint $k 3, ?int $trainLimit null): float {
        
$predictions $this->predictBatch($X$k$trainLimit);
        
$correct 0;

        foreach (
$predictions as $i => $prediction) {
            if (isset(
$y[$i]) && $prediction === $y[$i]) {
                
$correct++;
            }
        }

        return 
count($y) > ? ($correct count($y)) : 0.0;
    }
}
 
<?php

namespace app\classes;

use 
Exception;

/**
 * MnistLoader - Utility class for loading and preprocessing MNIST digit dataset
 *
 * This class provides methods to load MNIST data from CSV files with various
 * preprocessing options including normalization, digit filtering, and label formatting.
 * It supports both array-based loading (for custom implementations) and
 * iterable loading (for Rubix ML compatibility).
 */
class MnistLoader {
    
/**
     * Load MNIST data as arrays for custom ML implementations
     */
    
public static function load(string $filestring $directory ''bool $categoricalLabels falsebool $normalize true, array $digits = [01]): array {
        
$features = [];  // 2D array: each element is an array of 784 pixel values
        
$labels = [];   // 1D array: each element is the corresponding digit label

        
$handle self::openFile($file$directory);

        
// Process each row in the CSV file
        // Each row contains: [label, pixel1, pixel2, ..., pixel784]
        
while (($row fgetcsv($handle)) !== false) {
            if (
$processed self::processRow($row$digits$normalize$categoricalLabels)) {
                
$features[] = $processed[0];
                
$labels[] = $processed[1];
            }
        }

        
// Return features and labels in the format expected by callers.
        
return [$features$labels];
    }

    
/**
     * Load MNIST data as an iterator for Rubix ML compatibility
     */
    
public static function loadIterable(string $filestring $directory ''bool $categoricalLabels falsebool $normalize true, array $digits = [01]): iterable {

        
$handle self::openFile($file$directory);

        
// Read the CSV file row by row and keep only valid samples.
        
while (($row fgetcsv($handle)) !== false) {
            if (
$processed self::processRow($row$digits$normalize$categoricalLabels)) {
                yield 
array_merge($processed[0], [$processed[1]]);
            }
        }
    }

    private static function 
openFile(string $filestring $directory) {
        
$handle = @fopen($directory $file'r');

        if (
$handle === false) {
            throw new 
Exception('Dataset file not found: ' $directory $file);
        }

        return 
$handle;
    }

    private static function 
processRow(array &$row, array &$digitsbool $normalizebool $categoricalLabels): ?array {
        if (
$row === [] || $row[0] === null || $row[0] === '') {
            return 
null;
        }

        
$label = (int)$row[0];

        if (!
in_array($label$digits)) {
            return 
null;
        }

        
$pixels array_slice($row1);

        if (
$normalize) {
            
$pixels array_map(static fn ($v): float => ((float)$v) / 255.0$pixels);
        }

        
$formattedLabel $categoricalLabels ? ($label === 'one' 'zero') : $label;

        return [
$pixels$formattedLabel];
    }
}

Implementation in RubixML

Now the same approach, but using a library.

 
<?php

use app\classes\MnistLoader;
use 
Rubix\ML\Classifiers\KNearestNeighbors;
use 
Rubix\ML\Datasets\Labeled;
use 
Rubix\ML\Kernels\Distance\Euclidean;

try {
    
// Build the training and test datasets from the filtered CSV rows.
    
$trainRows MnistLoader::loadIterable('train.csv'categoricalLabelstruenormalizetruedigits: [01]);
    
$testRows MnistLoader::loadIterable('test.csv'categoricalLabelstruenormalizetruedigits: [01]);

    
$dataset Labeled::fromIterator($trainRows);
    
$testDataset Labeled::fromIterator($testRows);
} catch (
Exception $e) {
    echo 
'<div class="alert alert-danger" role="alert">' htmlspecialchars($e->getMessage(), ENT_QUOTES'UTF-8') . '</div>';
    exit;
}

$model = new KNearestNeighbors(3false, new Euclidean());
$model->train($dataset);