MNIST: digit recognition with kNN (lazy learner — no iterative training phase)
Implementation in RubixML
Running the kNN example with RubixML.
Example of use
<?php
use app\classes\MnistLoader;
use Rubix\ML\Classifiers\KNearestNeighbors;
use Rubix\ML\CrossValidation\Metrics\Accuracy;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Kernels\Distance\Euclidean;
try {
    // Build the training and test datasets from the CSV rows that MnistLoader
    // yields for digits 0 and 1 (categorical labels and normalization are
    // requested through the named arguments).
    $trainRows = MnistLoader::loadIterable('train.csv', categoricalLabels: true, normalize: true, digits: [0, 1]);
    $testRows = MnistLoader::loadIterable('test.csv', categoricalLabels: true, normalize: true, digits: [0, 1]);

    $dataset = Labeled::fromIterator($trainRows);
    $testDataset = Labeled::fromIterator($testRows);
} catch (Exception $e) {
    // Escape the message before embedding it in HTML. ENT_SUBSTITUTE keeps the
    // text visible when it contains invalid UTF-8 sequences; without that flag
    // htmlspecialchars() returns an empty string for such input.
    echo '<div class="alert alert-danger" role="alert">'
        . htmlspecialchars($e->getMessage(), ENT_QUOTES | ENT_SUBSTITUTE, 'UTF-8')
        . '</div>';
    exit;
}
// k = 3 neighbours, unweighted votes, Euclidean distance between samples.
$model = new KNearestNeighbors(3, false, new Euclidean());
$model->train($dataset);

$testingLabels = $testDataset->labels();

// Classify the whole test set in one call instead of wrapping every sample in
// its own single-row Unlabeled dataset. The call is skipped for an empty test
// set so that case yields an empty prediction list.
$predictions = $testDataset->numSamples() > 0
    ? $model->predict($testDataset)
    : [];

// Accuracy compares the predictions with the true labels position by position.
$metric = new Accuracy();
$score = $metric->score($predictions, $testingLabels);

echo 'Train samples handled: ' . number_format($dataset->numSamples()) . PHP_EOL;
echo 'Test samples handled: ' . number_format($testDataset->numSamples()) . PHP_EOL . PHP_EOL;
echo 'Accuracy: ' . round($score * 100, 2) . '%';
Samples of digit: 0
Predicted digit: 0
Predicted digit: 0
Predicted digit: 0
Samples of digit: 1
Predicted digit: 1
Predicted digit: 1
Predicted digit: 1
Result:
Memory: 0 Mb
Time running: < 0.001 sec.
Train samples handled: 12,666
Test samples handled: 2,116
Accuracy: 99.92%