The Chain Rule¶
The Problem¶
We know derivatives tell us "which way to nudge." But in a neural network, the loss depends on parameters through a long chain of operations:
```mermaid
flowchart LR
    A["parameter"] --> B["embedding"] --> C["linear"] --> D["relu"] --> E["linear"] --> F["softmax"] --> G["log"] --> H["loss"]
```

How do we find the derivative of loss with respect to parameter when there are 6 operations between them?
Composition of Functions¶
When you nest functions inside functions, it's called composition:
For \(h(x) = (2x + 1)^2\):
- First compute \(g = 2x + 1\)
- Then compute \(f = g^2\)
The question is: what is the derivative of \(h\) with respect to \(x\)?
The Chain Rule¶
The Chain Rule
If \(h(x) = f(g(x))\), then \(\frac{dh}{dx} = \frac{df}{dg} \times \frac{dg}{dx}\). The derivative of a composition is the product of the individual derivatives.
Let's verify with numbers. For \(h(x) = (2x + 1)^2\) at \(x = 3\): the inner value is \(g = 7\), so \(h = 49\). The chain rule predicts \(\frac{dh}{dx} = 2g \times 2 = 4 \times 7 = 28\).

Nudge \(x\) from 3 to 3.001: \(g\) becomes 7.002, and \(h\) becomes \(7.002^2 = 49.028004\). The output moved by about \(0.028 = 28 \times 0.001\) — exactly what the chain rule predicted.
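This check is easy to reproduce in code. A quick sketch, using the same function and step size as above:

```python
def h(x):
    """h(x) = (2x + 1)^2, computed as a composition."""
    g = 2 * x + 1      # inner function
    return g ** 2      # outer function

x = 3.0
eps = 0.001

# Numerical derivative: nudge x and see how much h moves.
numeric = (h(x + eps) - h(x)) / eps

# Chain rule: df/dg * dg/dx = 2g * 2.
g = 2 * x + 1
analytic = 2 * g * 2

print(numeric, analytic)   # ≈ 28.004 vs 28.0
```

The small mismatch (28.004 vs 28) comes from the finite step size; it shrinks as `eps` shrinks.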
The Chain Rule Visually¶
Think of it like a pipeline. Each stage multiplies the "sensitivity":
```mermaid
flowchart LR
    X["x"] -- "×2" --> G["g"] -- "×2g" --> H["h"]
```

If \(x\) wiggles by 1:
- \(g\) wiggles by 2 (because \(\frac{dg}{dx} = 2\))
- \(h\) wiggles by \(2 \times 2g\) (because \(\frac{df}{dg} = 2g\), and \(g\) already wiggled by 2)
Longer Chains¶
The chain rule extends to any number of steps. With intermediate variables \(u = f_1(x)\), \(v = f_2(u)\), and \(y = f_3(v)\):

\[
\frac{dy}{dx} = \frac{dy}{dv} \times \frac{dv}{du} \times \frac{du}{dx}
\]

Just multiply all the individual derivatives along the chain.
Info
In a neural network, this chain might be 10 or 100 steps long. But the principle is always the same: multiply the local derivatives along the path.
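To see a longer chain in action, here is a sketch with three links (the particular functions — a scale, a square, and an exponential — are arbitrary illustrations, not operations from microgpt.py):

```python
import math

x = 0.5

# Chain: x -> u = 3x -> v = u^2 -> y = exp(v)
u = 3 * x
v = u ** 2
y = math.exp(v)

# Local derivative at each link.
du_dx = 3.0
dv_du = 2 * u
dy_dv = math.exp(v)

# Chain rule: multiply along the path.
dy_dx = dy_dv * dv_du * du_dx

# Compare against a finite-difference estimate.
def chain(x):
    return math.exp((3 * x) ** 2)

eps = 1e-6
numeric = (chain(x + eps) - chain(x)) / eps
print(dy_dx, numeric)   # the two agree to several decimal places
```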
The Key Insight for Autograd¶
This is why microgpt.py's Value class stores local gradients at each operation:
| Operation | Local gradient of \(z\) w.r.t. \(x\) |
|---|---|
| \(z = x + y\) | \(1\) |
| \(z = x \times y\) | \(y\) |
| \(z = x^2\) | \(2x\) |
Each operation only needs to know its own local derivative. The chain rule takes care of composing them into the full derivative.
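Here is a minimal sketch of that idea — an illustration in the same spirit, not microgpt.py's actual `Value` class:

```python
class Value:
    """Tiny autograd scalar: stores data, a gradient, and local gradients."""
    def __init__(self, data, children=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self._children = children        # the input Values
        self._local_grads = local_grads  # d(self)/d(child) for each child

    def __add__(self, other):
        # Local gradient of x + y w.r.t. each input is 1.
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        # Local gradient of x * y is y w.r.t. x, and x w.r.t. y.
        return Value(self.data * other.data, (self, other),
                     (other.data, self.data))

    def backward(self, upstream=1.0):
        # Chain rule: multiply the upstream gradient by each local gradient
        # and pass it down. Gradients accumulate across multiple paths.
        self.grad += upstream
        for child, local in zip(self._children, self._local_grads):
            child.backward(upstream * local)

x = Value(3.0)
y = Value(4.0)
z = x * y + x          # dz/dx = y + 1 = 5, dz/dy = x = 3
z.backward()
print(x.grad, y.grad)  # 5.0 3.0
```

Note how `__add__` and `__mul__` record exactly the local gradients from the table above; `backward` does nothing but multiply and accumulate them along each path.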
A Three-Node Example¶
Let's trace through a tiny computation graph:
We want: \(\frac{d(\text{loss})}{da}\) — how does the loss change if we tweak \(a\)?
Each factor is a local gradient — the derivative of one node with respect to its immediate input. The chain rule multiplies them together.
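As one concrete instance of such a chain (the particular operations here, \(b = a^2\) and \(\text{loss} = 3b\), are an illustrative choice):

```python
a = 2.0
b = a ** 2          # node 1: b = a^2
loss = 3 * b        # node 2: loss = 3b

# Local gradients along the path a -> b -> loss.
db_da = 2 * a       # = 4.0
dloss_db = 3.0

# Chain rule: multiply the factors.
dloss_da = dloss_db * db_da   # = 12.0

# Finite-difference check: tweak a and watch the loss.
eps = 1e-6
numeric = (3 * (a + eps) ** 2 - loss) / eps
print(dloss_da, numeric)   # ≈ 12.0 both ways
```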
Terminology
| Term | Meaning |
|---|---|
| Chain rule | \(\frac{d(f \circ g)}{dx} = \frac{df}{dg} \times \frac{dg}{dx}\) — multiply local derivatives |
| Local gradient | The derivative of one operation w.r.t. its immediate input |
| Composition | Nesting functions: \(f(g(x))\) |
| Sensitivity | How much the output changes when an input wiggles |