Why Naive Bayes works

MNIST: probabilistic digit classification (Naive Bayes)


Implementation in pure PHP

Here we solve a full MNIST-based binary classification task with Gaussian Naive Bayes. Instead of learning a linear boundary, the model estimates per-pixel distributions for each digit class and compares log-probabilities. The example below uses the MNIST CSV dataset loaded through MnistLoader. It keeps only digits 0 and 1, normalizes all pixels, groups training samples by class, computes mean and variance for each pixel, and scores each test image with Gaussian likelihoods.

 
<?php

namespace app\classes;

class 
GaussianNB {
    private array 
$stats = [];
    private array 
$grouped = [];
    private 
int $totalSamples 0;

    
// Calculate arithmetic mean of values
    
private function mean(array $values): float {
        return 
array_sum($values) / count($values);
    }

    
// Calculate variance of values given their mean
    
private function variance(array $valuesfloat $mean): float {
        
$sum 0;

        foreach (
$values as $v) {
            
$sum += pow($v $mean2);
        }

        return 
$sum count($values);
    }

    
// Calculate Gaussian probability density function
    
private function gaussian(float $xfloat $meanfloat $variance): float {
        return (
sqrt(pi() * $variance)) * exp(-pow($x $mean2) / ($variance));
    }

    
// Train the classifier by calculating mean and variance for each class/feature
    
public function train(array $X, array $y): void {
        
$this->totalSamples count($X);

        
// Group samples by class
        
$this->grouped = [];

        foreach (
$X as $i => $sample) {
            
$this->grouped[$y[$i]][] = $sample;
        }

        
// Calculate statistics for each class and feature
        
$this->stats = [];

        foreach (
$this->grouped as $class => $rows) {
            
$features array_map(null, ...$rows);

            foreach (
$features as $i => $values) {
                
$m $this->mean($values);
                
$v $this->variance($values$m);

                
$this->stats[$class][$i] = [
                    
'mean' => $m,
                    
// protection against zero variance
                    
'variance' => $v ?: 1e-6,
                ];
            }
        }
    }

    
// Predict class labels for multiple samples
    
public function predict(array $X): array {
        
$predictions = [];

        foreach (
$X as $sample) {
            
$predictions[] = $this->predictSingle($sample);
        }

        return 
$predictions;
    }

    
// Predict class label for a single sample using Bayes' theorem
    
public function predictSingle(array $input): int {
        
$scores = [];

        foreach (
$this->stats as $class => $features) {
            
// Start with prior probability (class frequency)
            
$logProb log(count($this->grouped[$class]) / $this->totalSamples);

            
// Add likelihood for each feature (naive independence assumption)
            
foreach ($features as $i => $params) {
                
$prob $this->gaussian($input[$i], $params['mean'], $params['variance']);
                
$logProb += log($prob);
            }

            
$scores[$class] = $logProb;
        }

        
// Return class with highest score
        
arsort($scores);

        return (int) 
array_key_first($scores);
    }

    
// Calculate score a set of predictions
    
public function score(array $X, array $y): float {
        
$predictions $this->predict($X);
        
$correct 0;

        foreach (
$predictions as $i => $prediction) {
            if (
$prediction === $y[$i]) {
                
$correct++;
            }
        }

        return 
count($y) > ? ($correct count($y)) : 0.0;
    }

    
// Get calculated statistics (mean and variance) for each class/feature
    
public function getStats(): array {
        return 
$this->stats;
    }

    
// Get grouped training data by class
    
public function getGrouped(): array {
        return 
$this->grouped;
    }
}
 
<?php

namespace app\classes;

use 
Exception;

/**
 * MnistLoader - Utility class for loading and preprocessing MNIST digit dataset
 *
 * This class provides methods to load MNIST data from CSV files with various
 * preprocessing options including normalization, digit filtering, and label formatting.
 * It supports both array-based loading (for custom implementations) and
 * iterable loading (for Rubix ML compatibility).
 */
class MnistLoader {
    
/**
     * Load MNIST data as arrays for custom ML implementations
     */
    
public static function load(string $filestring $directory ''bool $categoricalLabels falsebool $normalize true, array $digits = [01]): array {
        
$features = [];  // 2D array: each element is an array of 784 pixel values
        
$labels = [];   // 1D array: each element is the corresponding digit label

        
$handle self::openFile($file$directory);

        
// Process each row in the CSV file
        // Each row contains: [label, pixel1, pixel2, ..., pixel784]
        
while (($row fgetcsv($handle)) !== false) {
            if (
$processed self::processRow($row$digits$normalize$categoricalLabels)) {
                
$features[] = $processed[0];
                
$labels[] = $processed[1];
            }
        }

        
// Return features and labels in the format expected by callers.
        
return [$features$labels];
    }

    
/**
     * Load MNIST data as an iterator for Rubix ML compatibility
     */
    
public static function loadIterable(string $filestring $directory ''bool $categoricalLabels falsebool $normalize true, array $digits = [01]): iterable {

        
$handle self::openFile($file$directory);

        
// Read the CSV file row by row and keep only valid samples.
        
while (($row fgetcsv($handle)) !== false) {
            if (
$processed self::processRow($row$digits$normalize$categoricalLabels)) {
                yield 
array_merge($processed[0], [$processed[1]]);
            }
        }
    }

    private static function 
openFile(string $filestring $directory) {
        
$handle = @fopen($directory $file'r');

        if (
$handle === false) {
            throw new 
Exception('Dataset file not found: ' $directory $file);
        }

        return 
$handle;
    }

    private static function 
processRow(array &$row, array &$digitsbool $normalizebool $categoricalLabels): ?array {
        if (
$row === [] || $row[0] === null || $row[0] === '') {
            return 
null;
        }

        
$label = (int)$row[0];

        if (!
in_array($label$digits)) {
            return 
null;
        }

        
$pixels array_slice($row1);

        if (
$normalize) {
            
$pixels array_map(static fn ($v): float => ((float)$v) / 255.0$pixels);
        }

        
$formattedLabel $categoricalLabels ? ($label === 'one' 'zero') : $label;

        return [
$pixels$formattedLabel];
    }
}

MNIST case with RubixML GaussianNB

The same problem can be solved with RubixML using GaussianNB. The library handles class priors and Gaussian feature likelihoods for us. This keeps the example short while preserving the exact probabilistic interpretation of the handwritten version.

 
<?php

use app\classes\MnistLoader;
use 
Rubix\ML\Classifiers\GaussianNB;
use 
Rubix\ML\Datasets\Labeled;

// Build the training and test datasets from the filtered CSV rows.
$trainRows MnistLoader::loadIterable('train.csv'categoricalLabelstruenormalizetruedigits: [01]);
$testRows MnistLoader::loadIterable('test.csv'categoricalLabelstruenormalizetruedigits: [01]);

$dataset Labeled::fromIterator($trainRows);
$testDataset Labeled::fromIterator($testRows);

$model = new GaussianNB();
$model->train($dataset);