Gradient descent in simple terms
Example 2. Effect of learning rate
Below you can see how the parameter $w$ behaves for three learning rate values on the same dataset.
Example of use
<?php

function train(float $lr, array $x, array $y): array {
    // Start with an initial weight and track how it changes each epoch.
    $w = 0.0;
    $n = count($x);
    $trajectory = [];

    echo PHP_EOL . "Learning rate = $lr" . PHP_EOL;

    for ($epoch = 1; $epoch <= 10; $epoch++) {
        $gradient = 0.0;

        for ($i = 0; $i < $n; $i++) {
            // Compute prediction error for one training example.
            $error = ($w * $x[$i]) - $y[$i];
            $gradient += $x[$i] * $error;
        }

        // Convert the accumulated value into the mean squared error gradient.
        $gradient = (2 / $n) * $gradient;

        // Update the weight in the direction that reduces the loss.
        $w -= $lr * $gradient;

        // Store the rounded weight so the caller can visualize the path.
        $trajectory[] = [
            'epoch' => $epoch,
            'w' => round($w, 4),
        ];

        echo "Epoch $epoch: w = " . round($w, 4) . PHP_EOL;
    }

    return $trajectory;
}

// Simple training data where the target relationship is y = 2x.
$x = [1, 2, 3, 4];
$y = [2, 4, 6, 8];

// Run the same training loop with three different learning rates.
$trajectories = [
    '0.01' => train(0.01, $x, $y),
    '0.1' => train(0.1, $x, $y),
    '1.0' => train(1.0, $x, $y),
];
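0.01 – too small a step: training progresses slowly, and after 10 epochs $w$ has reached only 1.6063 instead of the correct value 2.
0.1 – a suitable step: $w$ overshoots the correct value 2 on every update, but the error halves each epoch, so $w$ reaches 1.998 by epoch 10.
1.0 – too large a step: every update overshoots further than the previous one, so $w$ flips sign each epoch and diverges.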
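The three behaviours can be traced to one short calculation. The derivation below is a sketch for this particular dataset only; the symbols $L$ (the mean squared error) and $\eta$ (the learning rate) are notation introduced here, not taken from the code. With $\sum x_i^2 = 30$ and $\sum x_i y_i = 60$, the gradient computed in the loop simplifies to

$$
\frac{\partial L}{\partial w} = \frac{2}{n}\sum_{i=1}^{n} x_i\,(w x_i - y_i) = \frac{2}{4}\,(30w - 60) = 15\,(w - 2),
$$

so each update $w \leftarrow w - \eta \cdot 15\,(w - 2)$ rescales the distance to the correct value by a constant factor:

$$
w_{t+1} - 2 = (1 - 15\eta)\,(w_t - 2).
$$

For $\eta = 0.01$ the factor is $0.85$ (slow, steady approach), for $\eta = 0.1$ it is $-0.5$ (the error halves and changes sign each epoch), and for $\eta = 1.0$ it is $-14$ (the error grows 14-fold each epoch). The sequence converges only when $|1 - 15\eta| < 1$, that is, for $0 < \eta < 2/15 \approx 0.133$ on this dataset.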
Result:
Learning rate = 0.01
Epoch 1: w = 0.3
Epoch 2: w = 0.555
Epoch 3: w = 0.7718
Epoch 4: w = 0.956
Epoch 5: w = 1.1126
Epoch 6: w = 1.2457
Epoch 7: w = 1.3588
Epoch 8: w = 1.455
Epoch 9: w = 1.5368
Epoch 10: w = 1.6063
Learning rate = 0.1
Epoch 1: w = 3
Epoch 2: w = 1.5
Epoch 3: w = 2.25
Epoch 4: w = 1.875
Epoch 5: w = 2.0625
Epoch 6: w = 1.9688
Epoch 7: w = 2.0156
Epoch 8: w = 1.9922
Epoch 9: w = 2.0039
Epoch 10: w = 1.998
Learning rate = 1
Epoch 1: w = 30
Epoch 2: w = -390
Epoch 3: w = 5490
Epoch 4: w = -76830
Epoch 5: w = 1075650
Epoch 6: w = -15059070
Epoch 7: w = 210827010
Epoch 8: w = -2951578110
Epoch 9: w = 41322093570
Epoch 10: w = -578509309950
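The printed values can be cross-checked without running the training loop. For a one-parameter model $y = wx$ whose targets follow the model exactly, each update multiplies the distance to the correct weight by `1 - lr * (2 / n) * sum(x_i^2)`. The sketch below relies on that assumption; the helper names `curvature` and `closedFormWeight` are hypothetical and do not appear in the original example.

```php
<?php

// Hypothetical helper: curvature of the mean squared error for y = w * x,
// that is (2 / n) * sum(x_i^2). For x = [1, 2, 3, 4] this equals 15.
function curvature(array $x): float {
    $sumSquares = 0.0;
    foreach ($x as $xi) {
        $sumSquares += $xi * $xi;
    }
    return (2 / count($x)) * $sumSquares;
}

// Hypothetical helper: weight after $epochs updates, starting from $w0.
// Assumes the targets satisfy y = $target * x exactly, as in the dataset above.
function closedFormWeight(float $lr, array $x, float $target, int $epochs, float $w0 = 0.0): float {
    // Each epoch multiplies the distance to the target by this factor.
    $factor = 1 - $lr * curvature($x);
    return $target + ($w0 - $target) * ($factor ** $epochs);
}

$x = [1, 2, 3, 4];

// Learning rates at or above this bound no longer converge on this dataset.
echo "Stability bound for the learning rate: " . round(2 / curvature($x), 4) . PHP_EOL;

foreach ([0.01, 0.1, 1.0] as $lr) {
    echo "lr = $lr: w after 10 epochs = " . round(closedFormWeight($lr, $x, 2.0, 10), 4) . PHP_EOL;
}
```

The three weights it reports should agree with the epoch 10 lines above (1.6063, 1.998 and -578509309950), and the stability bound of about 0.1333 shows why 0.1 still converges while 1.0 does not.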