let's build an autograd engine (in Rust)
Building reverse-mode automatic differentiation from scratch in Rust, following Andrej Karpathy's micrograd. Covers computational graphs, Rc/RefCell, operator overloading, and backprop.
Andrej Karpathy's micrograd is a 150-line Python autograd engine, and the clearest possible explanation of how backpropagation actually works. Not as a matrix operation, but as a walk over a directed acyclic graph of scalar values. I wanted to build the same thing in Rust.
The result is CrabGrad. The interesting parts are all the places where Rust pushes back.
You'll need basic familiarity with Rust ownership and some understanding of derivatives. The concepts follow Karpathy's guide closely, but the challenge is mostly translating Python's implicit shared mutable state into explicit Rust.
what autograd actually is
When you compute loss = (pred - target)^2, you're building a graph. Each operation (-, ^2) creates a new node that remembers its inputs. Backpropagation is just a reverse traversal of that graph, applying the chain rule at each node to accumulate gradients.
In Python this is trivial: objects are reference-counted by default, mutation is unchecked, and you can store references wherever you want. In Rust, shared mutable state is the thing the language is designed to prevent. So every ownership and borrowing decision has to be explicit.
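To make the reverse traversal concrete, here's loss = (pred - target)^2 differentiated by hand. It's a minimal sketch with plain f64 values and a hypothetical helper, squared_error_with_grads. The inputs pred = 3.0 and target = 1.0 are arbitrary. Each backward line multiplies one local derivative by the upstream gradient, and the nodes are visited in reverse order.

```rust
/// Forward and backward pass for loss = (pred - target)^2, done by hand.
/// Hypothetical helper for illustration. Returns (loss, dL/dpred, dL/dtarget).
fn squared_error_with_grads(pred: f64, target: f64) -> (f64, f64, f64) {
    // Forward pass: each intermediate value is one node in the graph.
    let diff = pred - target; // node 1: subtraction
    let loss = diff * diff; // node 2: squaring

    // Backward pass: start at the output and visit nodes in reverse.
    let d_loss = 1.0; // dL/dL = 1
    // Local derivative of diff^2 with respect to diff is 2 * diff.
    let d_diff = d_loss * 2.0 * diff;
    // Local derivatives of (pred - target): +1 for pred, -1 for target.
    let d_pred = d_diff * 1.0;
    let d_target = d_diff * -1.0;

    (loss, d_pred, d_target)
}

fn main() {
    let (loss, d_pred, d_target) = squared_error_with_grads(3.0, 1.0);
    // prints: loss = 4, dL/dpred = 4, dL/dtarget = -4
    println!("loss = {loss}, dL/dpred = {d_pred}, dL/dtarget = {d_target}");
}
```

An autograd engine performs exactly these steps. The difference is that it records the nodes while the forward pass runs, so nobody has to write the backward lines by hand.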
the core data structure
Every number in the computation is a Value. Let's define it:
```rust
use std::cell::RefCell;

pub struct Value {
    pub data: RefCell<f64>,
    pub grad: RefCell<f64>,
    pub op: Option<Op>,
    pub inputs: Inputs,
    pub label: String,
}
```

data and grad are wrapped in RefCell, Rust's escape hatch for interior mutability. The borrow checker normally enforces its rules at compile time. RefCell defers them to runtime and panics on a conflicting borrow. We need it because nodes are shared through Rc, which only hands out shared (&) references. During the backward pass, several nodes still need to accumulate into the same gradient (grad += upstream). The compiler can't statically prove those writes never overlap, so the check moves to runtime.
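Here's what deferring the rules to runtime looks like in practice. This is a minimal sketch that uses a stripped-down, hypothetical Node type in place of Value. Two Rc handles to the same node both accumulate into its grad through shared references. RefCell::try_borrow_mut reports a conflict only while another borrow of the cell is alive.

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// Hypothetical stand-in for Value: only the gradient cell.
struct Node {
    grad: RefCell<f64>,
}

/// Accumulate an upstream gradient through a shared reference.
/// This compiles because RefCell checks the borrow at runtime instead.
fn accumulate(node: &Node, upstream: f64) {
    *node.grad.borrow_mut() += upstream;
}

fn main() {
    let node = Rc::new(Node { grad: RefCell::new(0.0) });
    // Two handles to the same node, like two parents sharing one input.
    let from_parent_a = Rc::clone(&node);
    let from_parent_b = Rc::clone(&node);
    accumulate(&from_parent_a, 1.0);
    accumulate(&from_parent_b, 2.0);
    println!("grad = {}", node.grad.borrow()); // prints: grad = 3

    // While a shared borrow is alive, a mutable borrow is refused at runtime.
    let reading = node.grad.borrow();
    assert!(node.grad.try_borrow_mut().is_err());
    drop(reading);
    assert!(node.grad.try_borrow_mut().is_ok());
}
```

Replace `try_borrow_mut` with `borrow_mut` in that conflicting spot and the code still compiles, but it panics when it runs. RefCell makes exactly that trade.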
The Inputs enum makes the graph structure explicit:
```rust
pub enum Inputs {
    Leaf,
    Unary(Rc<Value>),
    Binary(Rc<Value>, Rc<Value>),
}
```

A leaf node is a weight, bias, or input with no inputs of its own. Everything else is a unary op (relu, tanh, exp, neg) or a binary op (add, mul, sub, div, pow). Nodes are shared via Rc<Value>, a reference-counted, single-threaded smart pointer. When you write a + b, both a and b might also be inputs to other nodes. Rc handles that shared ownership.
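Rc::strong_count makes the shared ownership visible. The sketch below uses a cut-down Value with only data and inputs, plus a hypothetical `leaf` helper. The article's struct has more fields. One leaf that feeds two different nodes ends up with three owners.

```rust
use std::rc::Rc;

/// Cut-down versions of the article's enum and struct, for illustration only.
enum Inputs {
    Leaf,
    Unary(Rc<Value>),
    Binary(Rc<Value>, Rc<Value>),
}

struct Value {
    data: f64,
    inputs: Inputs,
}

/// Hypothetical helper that builds a leaf node.
fn leaf(data: f64) -> Rc<Value> {
    Rc::new(Value { data, inputs: Inputs::Leaf })
}

fn main() {
    let a = leaf(2.0);
    let b = leaf(3.0);
    // `a` feeds two different nodes, so both hold an Rc pointing at it.
    let sum = Rc::new(Value {
        data: a.data + b.data,
        inputs: Inputs::Binary(Rc::clone(&a), Rc::clone(&b)),
    });
    let neg = Rc::new(Value {
        data: -a.data,
        inputs: Inputs::Unary(Rc::clone(&a)),
    });
    // Owners of `a`: the binding itself, `sum`, and `neg`.
    println!("strong count of a: {}", Rc::strong_count(&a)); // prints: strong count of a: 3
    // Dropping the downstream nodes releases their handles.
    drop(sum);
    drop(neg);
    println!("strong count of a: {}", Rc::strong_count(&a)); // prints: strong count of a: 1
}
```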
the operator overloading problem
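In Python, a + b just works because you override __add__. In Rust, you can't implement Add for Rc<Value>. Both Add and Rc come from the standard library, so Rc<Value> counts as a foreign type even though Value is yours. This is the orphan rule: you can only implement a foreign trait for a type defined in your own crate.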
The fix is a newtype wrapper:
```rust
pub struct GradValue(pub Rc<Value>);
```

GradValue is a local type, so you can implement Add, Mul, etc. for it. The ergonomic cost is that each binary op needs four implementations to cover every ownership combination: GradValue + GradValue, &GradValue + GradValue, GradValue + &GradValue, and &GradValue + &GradValue. The first three all delegate to the &GradValue + &GradValue case.
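Here's a sketch of that four-impl pattern for Add. It uses a cut-down Value that only holds data, and the `new` and `data` helpers are assumptions. The article's version also wires up graph edges in the reference-reference impl. The point is that only `&GradValue + &GradValue` does any work. The other three borrow their operands and forward to it.

```rust
use std::cell::RefCell;
use std::ops::Add;
use std::rc::Rc;

/// Cut-down node: just the forward value.
pub struct Value {
    pub data: RefCell<f64>,
}

/// Newtype around Rc<Value>: a local type, so the orphan rule allows `impl Add`.
#[derive(Clone)]
pub struct GradValue(pub Rc<Value>);

impl GradValue {
    pub fn new(data: f64) -> Self {
        GradValue(Rc::new(Value { data: RefCell::new(data) }))
    }
    pub fn data(&self) -> f64 {
        *self.0.data.borrow()
    }
}

// The one impl that does the real work.
impl<'a, 'b> Add<&'b GradValue> for &'a GradValue {
    type Output = GradValue;
    fn add(self, rhs: &'b GradValue) -> GradValue {
        GradValue::new(self.data() + rhs.data())
    }
}

// The other three ownership combinations delegate to &a + &b.
impl Add<GradValue> for GradValue {
    type Output = GradValue;
    fn add(self, rhs: GradValue) -> GradValue {
        &self + &rhs
    }
}

impl Add<&GradValue> for GradValue {
    type Output = GradValue;
    fn add(self, rhs: &GradValue) -> GradValue {
        &self + rhs
    }
}

impl Add<GradValue> for &GradValue {
    type Output = GradValue;
    fn add(self, rhs: GradValue) -> GradValue {
        self + &rhs
    }
}

fn main() {
    let a = GradValue::new(2.0);
    let b = GradValue::new(3.0);
    let c = &a + &b; // &GradValue + &GradValue -> 5
    let d = c.clone() + &a; // GradValue + &GradValue -> 7
    let e = &b + d.clone(); // &GradValue + GradValue -> 10
    let f = d + e; // GradValue + GradValue -> 17
    println!("{}", f.data()); // prints: 17
}
```

Cloning a GradValue only bumps the Rc count, so it stays cheap. Even so, the reference forms avoid touching the count at all.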
Each operator creates a new node and wires up the graph:
```rust
// in ops/add.rs
pub fn forward(lhs: &Rc<Value>, rhs: &Rc<Value>) -> GradValue {
    let result = *lhs.data.borrow() + *rhs.data.borrow();
    GradValue(Rc::new(Value {
        data: RefCell::new(result),
        grad: RefCell::new(0.0),
        op: Some(Op::Add),
        inputs: Inputs::Binary(Rc::clone(lhs), Rc::clone(rhs)),
        label: String::new(),
    }))
}
```

The backward function for add is trivial. The local derivative of lhs + rhs with respect to either input is 1, so the upstream gradient flows unchanged to both:
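```rust
pub fn backward(lhs: &Rc<Value>, rhs: &Rc<Value>, upstream: f64) {
    *lhs.grad.borrow_mut() += upstream;
    *rhs.grad.borrow_mut() += upstream;
}
```

Note the +=: gradients accumulate. If the same node is used in multiple parts of the computation, its gradient is the sum of all upstream contributions. That's the multivariable chain rule: a node's total derivative is the sum over every path from it to the output. It also makes a + a work. lhs and rhs are then the same node, each borrow_mut ends with its statement, and the node receives 2 × upstream.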
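To see accumulation in action, here's a small self-contained sketch with add and mul written in the same forward/backward style. The helper names are hypothetical. The node type drops the op/inputs bookkeeping, and the backward calls are made by hand in reverse order rather than through a topological sort. With f = a * b + a, the node a is used twice, so its gradient is b + 1.

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// Minimal node: value and accumulated gradient (no op/inputs bookkeeping).
struct Value {
    data: RefCell<f64>,
    grad: RefCell<f64>,
}

fn leaf(x: f64) -> Rc<Value> {
    Rc::new(Value { data: RefCell::new(x), grad: RefCell::new(0.0) })
}

// Forward helpers only compute the output value in this sketch.
fn add(l: &Rc<Value>, r: &Rc<Value>) -> Rc<Value> {
    leaf(*l.data.borrow() + *r.data.borrow())
}
fn mul(l: &Rc<Value>, r: &Rc<Value>) -> Rc<Value> {
    leaf(*l.data.borrow() * *r.data.borrow())
}

// d(l + r)/dl = d(l + r)/dr = 1
fn add_backward(l: &Rc<Value>, r: &Rc<Value>, upstream: f64) {
    *l.grad.borrow_mut() += upstream;
    *r.grad.borrow_mut() += upstream;
}
// d(l * r)/dl = r and d(l * r)/dr = l
fn mul_backward(l: &Rc<Value>, r: &Rc<Value>, upstream: f64) {
    let (lv, rv) = (*l.data.borrow(), *r.data.borrow());
    *l.grad.borrow_mut() += upstream * rv;
    *r.grad.borrow_mut() += upstream * lv;
}

/// Computes f = a * b + a and returns (df/da, df/db).
fn grads_of_ab_plus_a(a_val: f64, b_val: f64) -> (f64, f64) {
    let a = leaf(a_val);
    let b = leaf(b_val);
    let ab = mul(&a, &b);
    let f = add(&ab, &a);
    // Backward in reverse order: f first, then ab.
    *f.grad.borrow_mut() = 1.0;
    add_backward(&ab, &a, *f.grad.borrow()); // a.grad += 1
    mul_backward(&a, &b, *ab.grad.borrow()); // a.grad += b, b.grad += a
    // Bind first so the Ref temporaries drop before `a` and `b` do.
    let result = (*a.grad.borrow(), *b.grad.borrow());
    result
}

fn main() {
    let (da, db) = grads_of_ab_plus_a(2.0, 3.0);
    println!("df/da = {da}, df/db = {db}"); // prints: df/da = 4, df/db = 2
}
```

If backward overwrote with = instead of +=, the mul contribution would replace the add contribution and a.grad would come out as 3 instead of 4.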
Try it:
```rust
let a = GradValue::new(2.0);
let b = GradValue::new(3.0);
let c = &a + &b; // c.data = 5.0
```

The graph builds itself as you do arithmetic.
a few gradient formulas
tanh is a common activation function. Its derivative is 1 - tanh(x)². Since we already computed tanh(x) as the forward output, we reuse it:
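```rust
pub fn backward(input: &Rc<Value>, upstream: f64, output: f64) {
    *input.grad.borrow_mut() += upstream * (1.0 - output * output);
}
```

pow needs to propagate to both base and exponent: d/dx (x^n) = n * x^(n-1), and d/dn (x^n) = x^n * ln(x). Only the exponent term needs ln(x), which is not real-valued for x <= 0: ln(0) is -inf and the log of a negative number is NaN. So we guard that term alone. The base term must always flow. Squaring a negative residual such as pred - target is common, and skipping its gradient would silently zero the gradient for every negative residual.

```rust
pub fn backward(base: &Rc<Value>, exp: &Rc<Value>, upstream: f64, output: f64) {
    let b = *base.data.borrow();
    let e = *exp.data.borrow();
    // d/db (b^e) = e * b^(e-1): well defined for negative bases with integer exponents
    *base.grad.borrow_mut() += upstream * e * b.powf(e - 1.0);
    // d/de (b^e) = b^e * ln(b): only real-valued for b > 0
    if b > 0.0 {
        *exp.grad.borrow_mut() += upstream * output * b.ln();
    }
}
```

Each operation lives in its own file under ops/. The pattern is always the same: a forward function that builds a new node, and a backward function that accumulates gradients into the inputs.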
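A cheap way to sanity-check formulas like these is a central finite-difference gradient check. Nudge the input by a small h in each direction and compare (f(x + h) - f(x - h)) / 2h with the analytic derivative. This sketch checks the tanh and pow formulas with plain f64 functions, not the ops/ code itself. The helper names, the step h = 1e-5 and the tolerance 1e-6 are arbitrary choices. The pow base derivative is checked at a negative base, where it is well defined even though the exponent derivative is not.

```rust
/// Central finite-difference approximation of f'(x).
fn numeric_grad(f: impl Fn(f64) -> f64, x: f64) -> f64 {
    let h = 1e-5;
    (f(x + h) - f(x - h)) / (2.0 * h)
}

/// Analytic d/dx tanh(x), written in terms of the forward output.
fn tanh_grad(x: f64) -> f64 {
    let output = x.tanh();
    1.0 - output * output
}

/// Analytic d/db (b^e) = e * b^(e - 1).
fn pow_base_grad(b: f64, e: f64) -> f64 {
    e * b.powf(e - 1.0)
}

/// Analytic d/de (b^e) = b^e * ln(b), only real-valued for b > 0.
fn pow_exp_grad(b: f64, e: f64) -> f64 {
    b.powf(e) * b.ln()
}

fn close(a: f64, b: f64) -> bool {
    (a - b).abs() < 1e-6
}

fn main() {
    for x in [-2.0, -0.5, 0.0, 0.7, 1.5] {
        assert!(close(tanh_grad(x), numeric_grad(|t| t.tanh(), x)));
    }
    // Base derivative at a negative base: (-3)^2 has slope -6.
    assert!(close(pow_base_grad(-3.0, 2.0), numeric_grad(|t| t.powf(2.0), -3.0)));
    // Exponent derivative at a positive base.
    assert!(close(pow_exp_grad(2.0, 3.0), numeric_grad(|t| 2.0_f64.powf(t), 3.0)));
    println!("all gradient checks passed");
}
```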
the backward pass
Backpropagation requires processing nodes in reverse topological order: outputs before inputs. The topo sort is a standard post-order DFS, which appends a node only after all of its inputs. The tricky part is the visited set:
```rust
use std::collections::HashSet;
use std::rc::Rc;

fn _build_topo(node: &Rc<Value>, visited: &mut HashSet<*const Value>, topo: &mut Vec<Rc<Value>>) {
    let ptr = Rc::as_ptr(node);
    if visited.contains(&ptr) { return; }
    visited.insert(ptr);
    match &node.inputs {
        Inputs::Binary(lhs, rhs) => {
            _build_topo(lhs, visited, topo);
            _build_topo(rhs, visited, topo);
        }
        Inputs::Unary(input) => _build_topo(input, visited, topo),
        Inputs::Leaf => {}
    }
    topo.push(Rc::clone(node));
}
```

The visited set uses *const Value (a raw pointer) as the key, not Rc<Value> itself. Rc<T> implements Hash and Eq by delegating to the value it points to, not to its address. Two distinct nodes with equal contents would therefore look like the same key. Value couldn't derive Hash anyway, since neither f64 nor RefCell implements it. Here we explicitly want pointer identity to detect shared nodes. Rc::as_ptr() returns that raw pointer without touching the reference count. The pointer is only compared, never dereferenced, so no unsafe code is needed.
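Here's a self-contained version of the same DFS on a cut-down node type with only data and inputs. The names `node`, `build_topo_into` and this `build_topo` wrapper are assumptions made for the sketch. It builds a diamond where d = b + c and both b and c read from a. It also adds two separate leaves that happen to hold the same value. Keying on Rc::as_ptr records a exactly once but still keeps both equal-valued leaves, which is the pointer-identity behavior we want.

```rust
use std::collections::HashSet;
use std::rc::Rc;

enum Inputs {
    Leaf,
    Unary(Rc<Value>),
    Binary(Rc<Value>, Rc<Value>),
}

struct Value {
    data: f64,
    inputs: Inputs,
}

fn node(data: f64, inputs: Inputs) -> Rc<Value> {
    Rc::new(Value { data, inputs })
}

/// Post-order DFS: a node is pushed only after all of its inputs.
fn build_topo_into(n: &Rc<Value>, visited: &mut HashSet<*const Value>, topo: &mut Vec<Rc<Value>>) {
    // HashSet::insert returns false if this address was already seen.
    if !visited.insert(Rc::as_ptr(n)) {
        return;
    }
    match &n.inputs {
        Inputs::Binary(l, r) => {
            build_topo_into(l, visited, topo);
            build_topo_into(r, visited, topo);
        }
        Inputs::Unary(i) => build_topo_into(i, visited, topo),
        Inputs::Leaf => {}
    }
    topo.push(Rc::clone(n));
}

/// Wrapper that owns the visited set and output vector.
fn build_topo(root: &Rc<Value>) -> Vec<Rc<Value>> {
    let mut visited = HashSet::new();
    let mut topo = Vec::new();
    build_topo_into(root, &mut visited, &mut topo);
    topo
}

fn main() {
    let a = node(2.0, Inputs::Leaf);
    let b = node(-2.0, Inputs::Unary(Rc::clone(&a))); // b = -a
    let c = node(4.0, Inputs::Binary(Rc::clone(&a), Rc::clone(&a))); // c = a + a
    let d = node(2.0, Inputs::Binary(Rc::clone(&b), Rc::clone(&c))); // d = b + c
    // Two distinct leaves that happen to hold the same value.
    let x = node(1.0, Inputs::Leaf);
    let y = node(1.0, Inputs::Leaf);
    let e = node(3.0, Inputs::Binary(Rc::clone(&d), Rc::clone(&x))); // e = d + x
    let root = node(4.0, Inputs::Binary(Rc::clone(&e), Rc::clone(&y))); // root = e + y

    let topo = build_topo(&root);
    // `a` is reached three times but recorded once; `x` and `y` are both kept.
    println!("{} nodes in topological order", topo.len()); // prints: 8 nodes in topological order
    let order: Vec<f64> = topo.iter().map(|n| n.data).collect();
    println!("{:?}", order); // prints: [2.0, -2.0, 4.0, 2.0, 1.0, 3.0, 1.0, 4.0]
}
```

The addresses stay valid for the whole traversal because every node is reachable from the root, and the root is alive. That guarantees no address can be freed and reused mid-sort.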
Once the topological order is built, backward is straightforward:
```rust
pub fn backward(root: &Rc<Value>) {
    *root.grad.borrow_mut() = 1.0; // dL/dL = 1
    let topo = build_topo(root);
    for node in topo.iter().rev() {
        if let Some(op) = &node.op {
            let upstream = *node.grad.borrow();
            let output = *node.data.borrow();
            op.backward(&node.inputs, upstream, output);
        }
    }
}
```

Let's verify it works on a simple example:
```rust
let a = GradValue::new(2.0);
let b = GradValue::new(3.0);
let c = &a * &b; // c = a * b = 6
let loss = c.tanh();
loss.backward();
// d(loss)/da = (1 - tanh²(c)) * b and d(loss)/db = (1 - tanh²(c)) * a
println!("a.grad = {}", a.grad()); // ≈ 7.37e-5
println!("b.grad = {}", b.grad()); // ≈ 4.92e-5
```

The gradients are small because tanh is almost completely saturated at c = 6, where its slope is only about 2.46e-5.

building a neuron
On top of the graph engine, we can build a small MLP. A Neuron holds its weights and bias as GradValues, which are just leaf nodes:
```rust
pub struct Neuron {
    pub weights: Vec<GradValue>,
    pub bias: GradValue,
    pub non_lin: bool,
}
```

The forward pass is just arithmetic, and the graph builds itself:
```rust
pub fn forward(&self, inputs: &[GradValue]) -> GradValue {
    let mut out = self.bias.clone();
    for (w, x) in self.weights.iter().zip(inputs.iter()) {
        out = out + w * x;
    }
    if self.non_lin { out.relu() } else { out }
}
```

The parameter update is vanilla gradient descent:
```rust
pub fn update_parameters(&self, lr: f64) {
    for p in self.parameters() {
        let grad = *p.grad.borrow();
        *p.data.borrow_mut() -= lr * grad;
    }
}
```

Layer stacks neurons; MLP stacks layers.
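The stacking can be sketched like this. To keep it short and self-contained, this version uses plain f64 in place of GradValue, so there's no gradient tracking. It also hard-codes every weight to 0.1 instead of using random initialization. It assumes, as micrograd does, that every layer except the last applies ReLU. Layer::new, MLP::new and num_parameters are hypothetical. For the 1 → [10, 10] → 1 shape used below, this layout has 20 + 110 + 11 = 141 parameters.

```rust
struct Neuron {
    weights: Vec<f64>,
    bias: f64,
    non_lin: bool,
}

impl Neuron {
    fn forward(&self, inputs: &[f64]) -> f64 {
        let sum: f64 = self.bias
            + self.weights.iter().zip(inputs).map(|(w, x)| w * x).sum::<f64>();
        if self.non_lin { sum.max(0.0) } else { sum } // ReLU or linear
    }
}

/// A layer is a list of neurons that all read the same inputs.
struct Layer {
    neurons: Vec<Neuron>,
}

impl Layer {
    fn new(n_in: usize, n_out: usize, non_lin: bool) -> Self {
        let neurons = (0..n_out)
            .map(|_| Neuron { weights: vec![0.1; n_in], bias: 0.0, non_lin })
            .collect();
        Layer { neurons }
    }
    fn forward(&self, inputs: &[f64]) -> Vec<f64> {
        self.neurons.iter().map(|n| n.forward(inputs)).collect()
    }
    fn num_parameters(&self) -> usize {
        self.neurons.iter().map(|n| n.weights.len() + 1).sum()
    }
}

/// An MLP feeds each layer's outputs into the next layer.
struct MLP {
    layers: Vec<Layer>,
}

impl MLP {
    /// `n_in` inputs, then one layer per entry in `sizes`.
    fn new(n_in: usize, sizes: &[usize]) -> Self {
        let mut dims = vec![n_in];
        dims.extend_from_slice(sizes);
        let layers = (0..sizes.len())
            // ReLU on every layer except the last, which stays linear.
            .map(|i| Layer::new(dims[i], dims[i + 1], i + 1 < sizes.len()))
            .collect();
        MLP { layers }
    }
    fn forward(&self, inputs: &[f64]) -> Vec<f64> {
        self.layers
            .iter()
            .fold(inputs.to_vec(), |acc, layer| layer.forward(&acc))
    }
    fn num_parameters(&self) -> usize {
        self.layers.iter().map(Layer::num_parameters).sum()
    }
}

fn main() {
    let mlp = MLP::new(1, &[10, 10, 1]);
    println!("parameters: {}", mlp.num_parameters()); // prints: parameters: 141
    println!("output: {:?}", mlp.forward(&[2.0]));
}
```

Keeping the last layer linear matters for regression: a ReLU there could never output a negative prediction.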
training
main.rs fits a 1 → [10, 10] → 1 MLP to y = 0.5x² + 0.3x + 0.1, using 50 training points, 100 epochs, MSE loss, and lr = 0.01:
```rust
for epoch in 0..100 {
    let preds: Vec<GradValue> = xs.iter().map(|x| mlp.forward(&[x.clone()])).collect();
    let loss = mse(&preds, &ys);
    mlp.zero_grad();
    loss.backward();
    mlp.update_parameters(0.01);
    if epoch % 10 == 0 {
        println!("epoch {} loss: {:.6}", epoch, loss.data());
    }
}
```

zero_grad() resets the parameters' gradients to 0.0 before each backward pass. This is necessary because gradients accumulate with +=. Without the reset, each epoch's gradient would be added on top of the previous one. Running it, the loss drops from ~0.3 to ~0.001 over 100 epochs. It works.
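To see why the reset matters, here's a sketch that runs the same tiny backward pass twice on one parameter, first without a reset and then with one. Param and its backward_square and zero_grad methods are hypothetical stand-ins for a parameter and a backward pass through loss = w², not the article's code. The parameter value is held fixed so only the gradient changes.

```rust
use std::cell::RefCell;

/// Hypothetical stand-in for a parameter leaf.
struct Param {
    data: f64,
    grad: RefCell<f64>,
}

impl Param {
    /// Backward pass for loss = w^2: accumulates dL/dw = 2w with +=.
    fn backward_square(&self) {
        *self.grad.borrow_mut() += 2.0 * self.data;
    }
    fn zero_grad(&self) {
        *self.grad.borrow_mut() = 0.0;
    }
    fn grad(&self) -> f64 {
        *self.grad.borrow()
    }
}

fn main() {
    let w = Param { data: 3.0, grad: RefCell::new(0.0) };

    // Two backward passes without a reset: the stale gradient is added again.
    w.backward_square();
    w.backward_square();
    println!("without zero_grad: {}", w.grad()); // prints: without zero_grad: 12

    // Resetting first gives the true dL/dw = 2 * 3 = 6.
    w.zero_grad();
    w.backward_square();
    println!("with zero_grad: {}", w.grad()); // prints: with zero_grad: 6
}
```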
next steps
This covers the core of reverse-mode autodiff. From here, the natural next step is tensors: instead of scalar Value nodes, each node holds an n-dimensional array, and the graph operates on batches. That's what PyTorch and JAX do under the hood.
Karpathy's lecture series Neural Networks: Zero to Hero starts from micrograd and builds up from there. It's a good place to continue if you want to see how this scales up.
The main friction points in Rust were all about shared mutable state, which is exactly what a computational graph is. Multiple nodes share the same inputs; the backward pass mutates gradients on nodes referenced from many places. The tools: Rc for shared ownership, RefCell for interior mutability, raw pointers for identity-based hashing. None of it is exotic, but writing it forces you to name the decisions that Python makes implicitly.