Hello! Today we have another one of those mathy posts. Fasten your seat belts, here we go! 🔥🚀

Today we are talking about derivatives. In Calculus, derivatives are
surprisingly easy to compute (and loved by many) because they follow simple
rules which can *always* be used, so you never find corner cases, just cases
where you could have used a more efficient approach. This is quite different
from integration, for example, where a big part of getting any result at all is
having strong intuitions and heuristics around which paths are the correct ones
to take.

We are going to use Rust, the **superior language**. Get out of here,
Carbon. 🔥🔥🔥

# Expressions

The first thing we need for calculating derivatives is to be able to represent
*expressions*. In mathematics, expressions represent a sequence of operations
that need to be executed over some constants and variables. Those expressions
are what we use as input for calculating a derivative. For example, we could have:

$$f(x) = x^2$$

This would mean that if we were to take the derivative:

$$f^\prime(x) = 2x$$

the input expression would have been `x*x` and the output (the derivative)
would have been `2*x`.

Those expressions need to be represented in our code so we can transform them (i.e.: calculate the derivative).

Since expressions represent a sequence of operations and a few constants and variables, the first thing we need to represent in our program is variables and constants, and then the operations we can make with them.

In *Rust*, we can start out with something like:

```
enum Expression {
    Constant(f32),
    Variable,
}
```

Here we are limiting our variables to a single one, namely `x`. Constants are
represented as floating point values.

Now we need to represent some operations. Let's start with operations that take two operands, in other words:

- a + b
- a - b
- a * b
- a / b

```
enum BinaryOperator {
    Addition,
    Subtraction,
    Multiplication,
    Division,
}

struct BinaryExpression {
    operator: BinaryOperator,
    left_operand: Box<Expression>,
    right_operand: Box<Expression>,
}
```

Let's also augment our original definition of an expression to include this new kind of expression:

```
enum Expression {
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}
```

Notice we use `Box<Expression>` as the operands for our binary expressions
(i.e.: the `a`s and `b`s). The `Expression` is there to make sure we can write
operations that refer to further operations. The `Box` is there because if we
don’t use a `Box`, the *Rust* compiler can’t actually build those types for us
since they would have infinite size: an `Expression` would contain other
`Expression`s inline, which would contain yet more, infinitely deep.
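
To make the role of `Box` a bit more concrete, here is a small sketch using the types defined so far. The functions `nested_example`, `depth` and `box_is_pointer_sized` are names made up for this illustration, and the `derive` attributes are an assumption added so the values can be inspected and compared:

```rust
use std::mem::size_of;

#[derive(Debug, Clone, PartialEq)]
enum BinaryOperator {
    Addition,
    Subtraction,
    Multiplication,
    Division,
}

#[derive(Debug, Clone, PartialEq)]
struct BinaryExpression {
    operator: BinaryOperator,
    left_operand: Box<Expression>,
    right_operand: Box<Expression>,
}

#[derive(Debug, Clone, PartialEq)]
enum Expression {
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}

/// Hypothetical example: builds `(x+1)*x`, where the left operand of the
/// product is itself a binary expression.
fn nested_example() -> Expression {
    let x_plus_one = Expression::Binary(BinaryExpression {
        operator: BinaryOperator::Addition,
        left_operand: Box::new(Expression::Variable),
        right_operand: Box::new(Expression::Constant(1.0)),
    });
    Expression::Binary(BinaryExpression {
        operator: BinaryOperator::Multiplication,
        left_operand: Box::new(x_plus_one),
        right_operand: Box::new(Expression::Variable),
    })
}

/// Hypothetical helper: counts how many levels of operations are nested.
fn depth(expr: &Expression) -> usize {
    match expr {
        Expression::Binary(b) => 1 + depth(&b.left_operand).max(depth(&b.right_operand)),
        Expression::Constant(_) | Expression::Variable => 0,
    }
}

/// A `Box<Expression>` is just a pointer to the heap, so its size is fixed
/// no matter how deep the expression behind it goes.
fn box_is_pointer_sized() -> bool {
    size_of::<Box<Expression>>() == size_of::<usize>()
}
```

Because each operand is only a pointer, `Expression` itself has a known, finite size, while the tree it points to can be as deep as we like (here `depth(&nested_example())` is `2`).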

Now we can think about another kind of operation we can make. When we write
things like `-(a*b)` we have a unary additive inverse operation, which takes
`(a*b)` and calculates the additive inverse.

In those cases we want to define a whole new type of expression:

```
enum UnaryOperator {
    Minus,
}

struct UnaryExpression {
    operator: UnaryOperator,
    operand: Box<Expression>,
}
```

And again, we need to add it to the base expression type we created:

```
enum Expression {
    Unary(UnaryExpression),
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}
```

# Derivatives

Like I said at the beginning of this post, derivatives have very simple and
clear rules. All we need to do is encode those rules as code. Let's start with a
`derivative` function which takes in an `Expression` to differentiate and
outputs another `Expression` which represents the derivative.

This function can be as simple as:

```
fn derivative(input: Expression) -> Expression {
    match input {
        Expression::Unary(unary_expr) => derivative_unary(unary_expr),
        Expression::Binary(bin_expr) => derivative_binary(bin_expr),
        Expression::Constant(_) => Expression::Constant(0.0),
        Expression::Variable => Expression::Constant(1.0),
    }
}
```

It offloads the bulk of the work to two other functions: `derivative_unary`
and `derivative_binary`. But it does have some built-in rules:

- for constants, the derivative is always zero;
- for the variable (always `x`), the derivative is always `1`.

Let's take a quick look at `derivative_unary`:

```
fn derivative_unary(input: UnaryExpression) -> Expression {
    match input.operator {
        UnaryOperator::Minus => Expression::Unary(UnaryExpression {
            operator: UnaryOperator::Minus,
            operand: Box::new(derivative(*input.operand)),
        }),
    }
}
```

There’s only a single case to handle here, which is `-a`. In this case the derivative is always:

$$\frac{d}{dx} \left[-a\right] = -\frac{d}{dx}\left[a\right]$$

Notice there’s some unboxing and re-boxing but aside from that not much goes on in this function.

The bulk of the work is done in the `derivative_binary` function:

```
fn derivative_binary(input: BinaryExpression) -> Expression {
    match input.operator {
        BinaryOperator::Addition => Expression::Binary(BinaryExpression {
            operator: BinaryOperator::Addition,
            left_operand: Box::new(derivative(*input.left_operand)),
            right_operand: Box::new(derivative(*input.right_operand)),
        }),
        BinaryOperator::Subtraction => Expression::Binary(BinaryExpression {
            // ...
        }),
        BinaryOperator::Multiplication => Expression::Binary(BinaryExpression {
            // ...
        }),
        BinaryOperator::Division => Expression::Binary(BinaryExpression {
            // ...
        }),
    }
}
```

All this function is doing is implementing four simple differentiation rules (full code linked at the end):

- addition rule;
- subtraction rule;
- product rule;
- quotient rule.
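
For reference, here are those four rules written out, using $a$ and $b$ for the left and right operands and $a^\prime$, $b^\prime$ for their derivatives with respect to $x$ (this notation is just a convenience picked for this summary):

$$\frac{d}{dx}\left[a + b\right] = a^\prime + b^\prime$$

$$\frac{d}{dx}\left[a - b\right] = a^\prime - b^\prime$$

$$\frac{d}{dx}\left[a \cdot b\right] = a^\prime \cdot b + a \cdot b^\prime$$

$$\frac{d}{dx}\left[\frac{a}{b}\right] = \frac{a^\prime \cdot b - a \cdot b^\prime}{b^2}$$

Plugging $a = b = x$ into the product rule gives $1 \cdot x + x \cdot 1$, which simplifies to the $2x$ we saw at the start.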

There’s a nice page on Wikipedia detailing those and other differentiation rules.
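
As a rough idea of what the elided arms could look like, here is one possible sketch. It is not necessarily how the full code does it: the `binary` helper is made up for this example, and the `Clone` derive is an assumption, needed here because the product and quotient rules use each operand both as-is and differentiated.

```rust
#[derive(Debug, Clone, PartialEq)]
enum BinaryOperator {
    Addition,
    Subtraction,
    Multiplication,
    Division,
}

#[derive(Debug, Clone, PartialEq)]
struct BinaryExpression {
    operator: BinaryOperator,
    left_operand: Box<Expression>,
    right_operand: Box<Expression>,
}

#[derive(Debug, Clone, PartialEq)]
enum UnaryOperator {
    Minus,
}

#[derive(Debug, Clone, PartialEq)]
struct UnaryExpression {
    operator: UnaryOperator,
    operand: Box<Expression>,
}

#[derive(Debug, Clone, PartialEq)]
enum Expression {
    Unary(UnaryExpression),
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}

/// Hypothetical helper (not part of the post): builds a binary node.
fn binary(operator: BinaryOperator, left: Expression, right: Expression) -> Expression {
    Expression::Binary(BinaryExpression {
        operator,
        left_operand: Box::new(left),
        right_operand: Box::new(right),
    })
}

fn derivative(input: Expression) -> Expression {
    match input {
        Expression::Unary(unary_expr) => derivative_unary(unary_expr),
        Expression::Binary(bin_expr) => derivative_binary(bin_expr),
        Expression::Constant(_) => Expression::Constant(0.0),
        Expression::Variable => Expression::Constant(1.0),
    }
}

fn derivative_unary(input: UnaryExpression) -> Expression {
    match input.operator {
        UnaryOperator::Minus => Expression::Unary(UnaryExpression {
            operator: UnaryOperator::Minus,
            operand: Box::new(derivative(*input.operand)),
        }),
    }
}

fn derivative_binary(input: BinaryExpression) -> Expression {
    use BinaryOperator::*;
    let (a, b) = (*input.left_operand, *input.right_operand);
    match input.operator {
        // (a + b)' = a' + b'
        Addition => binary(Addition, derivative(a), derivative(b)),
        // (a - b)' = a' - b'
        Subtraction => binary(Subtraction, derivative(a), derivative(b)),
        // (a * b)' = a' * b + a * b'
        Multiplication => {
            let (da, db) = (derivative(a.clone()), derivative(b.clone()));
            binary(Addition, binary(Multiplication, da, b), binary(Multiplication, a, db))
        }
        // (a / b)' = (a' * b - a * b') / (b * b)
        Division => {
            let (da, db) = (derivative(a.clone()), derivative(b.clone()));
            let numerator = binary(
                Subtraction,
                binary(Multiplication, da, b.clone()),
                binary(Multiplication, a, db),
            );
            binary(Division, numerator, binary(Multiplication, b.clone(), b))
        }
    }
}
```

Note that this sketch does no simplification at all, so the resulting trees grow quickly: every product turns into a sum of two products, and every quotient into a quotient of a difference and a product.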

# Testing

Now is a good time to test our program. Let's try the function from the start of this post, `x*x`:

```
let expression: Expression = Expression::Binary(BinaryExpression {
    operator: BinaryOperator::Multiplication,
    left_operand: Box::new(Expression::Variable),
    right_operand: Box::new(Expression::Variable),
});
```

And now let's run our program:

```
❯ cargo run
((1*x)+(x*1))
```

This is the correct derivative! If we work through a bit of simplification we can see this can be rewritten as `x+x`, which is `2*x`.
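
The printing code isn't shown in this post. One way to get fully parenthesized output like the above is to implement `Display` for the expression types. The sketch below is an assumption about how that could be done, and the `(-a)` format picked for the unary minus is arbitrary:

```rust
use std::fmt;

enum BinaryOperator {
    Addition,
    Subtraction,
    Multiplication,
    Division,
}

struct BinaryExpression {
    operator: BinaryOperator,
    left_operand: Box<Expression>,
    right_operand: Box<Expression>,
}

enum UnaryOperator {
    Minus,
}

struct UnaryExpression {
    operator: UnaryOperator,
    operand: Box<Expression>,
}

enum Expression {
    Unary(UnaryExpression),
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}

impl fmt::Display for BinaryOperator {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            BinaryOperator::Addition => "+",
            BinaryOperator::Subtraction => "-",
            BinaryOperator::Multiplication => "*",
            BinaryOperator::Division => "/",
        })
    }
}

impl fmt::Display for Expression {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Assumed format for the additive inverse: (-operand)
            Expression::Unary(u) => match u.operator {
                UnaryOperator::Minus => write!(f, "(-{})", u.operand),
            },
            // Every binary operation gets its own pair of parentheses,
            // so operator precedence never has to be considered.
            Expression::Binary(b) => {
                write!(f, "({}{}{})", b.left_operand, b.operator, b.right_operand)
            }
            // `f32` prints whole numbers without a fractional part.
            Expression::Constant(c) => write!(f, "{}", c),
            // There is only one variable, and it is always `x`.
            Expression::Variable => f.write_str("x"),
        }
    }
}
```

With this in place, `println!("{}", expression)` works for any `Expression`, since `Box<Expression>` forwards to the `Display` implementation of its contents.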

# Simplification

We can do a little bit better, however, by making the program itself do the
simplifications! We can use the same general approach as we used to calculate
the derivative: create a function which takes in an `Expression` as input and
then outputs another, modified, `Expression`.

Thanks to *Rust* pattern matching, this function could be written as:

```
fn simplify(input: Expression) -> Expression {
    match input {
        Expression::Binary(bin_expr) => match bin_expr.operator {
            BinaryOperator::Multiplication => {
                match (*bin_expr.left_operand, *bin_expr.right_operand) {
                    // match 1*a and b*1, transform into a or b
                    (Expression::Constant(c), right_operand) if c == 1.0 => right_operand,
                    (left_operand, Expression::Constant(c)) if c == 1.0 => left_operand,
                    (left_operand, right_operand) => Expression::Binary(BinaryExpression {
                        operator: BinaryOperator::Multiplication,
                        left_operand: Box::new(simplify(left_operand)),
                        right_operand: Box::new(simplify(right_operand)),
                    }),
                }
            }
            _ => {
                // ... recursively simplify() ...
            }
        },
        Expression::Unary(unary_expr) => {
            // ... recursively simplify() ...
        },
        Expression::Constant(c) => Expression::Constant(c),
        Expression::Variable => Expression::Variable,
    }
}
```

Running the program again now gives us:

```
❯ cargo run
(x+x)
```

If we also add new rules to match this form (i.e.: variable plus variable) and
convert that into `2*x`, and call `simplify()` twice, we can get our desired
simplified output:

```
❯ cargo run
(2*x)
```
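
To show why two passes are needed, here is a condensed sketch of a `simplify` with both rules in it. It is a variation on the function above rather than the exact code behind the output: the `binary` helper and the derives are assumptions made for this example, and the elided branches are filled in with a plain "simplify the operands and rebuild" fallback.

```rust
#[derive(Debug, Clone, PartialEq)]
enum BinaryOperator {
    Addition,
    Subtraction,
    Multiplication,
    Division,
}

#[derive(Debug, Clone, PartialEq)]
struct BinaryExpression {
    operator: BinaryOperator,
    left_operand: Box<Expression>,
    right_operand: Box<Expression>,
}

#[derive(Debug, Clone, PartialEq)]
enum UnaryOperator {
    Minus,
}

#[derive(Debug, Clone, PartialEq)]
struct UnaryExpression {
    operator: UnaryOperator,
    operand: Box<Expression>,
}

#[derive(Debug, Clone, PartialEq)]
enum Expression {
    Unary(UnaryExpression),
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}

/// Hypothetical helper (not part of the post): builds a binary node.
fn binary(operator: BinaryOperator, left: Expression, right: Expression) -> Expression {
    Expression::Binary(BinaryExpression {
        operator,
        left_operand: Box::new(left),
        right_operand: Box::new(right),
    })
}

fn simplify(input: Expression) -> Expression {
    use BinaryOperator::{Addition, Multiplication};
    match input {
        Expression::Binary(bin_expr) => {
            match (bin_expr.operator, *bin_expr.left_operand, *bin_expr.right_operand) {
                // 1*a => a
                (Multiplication, Expression::Constant(c), right) if c == 1.0 => right,
                // a*1 => a
                (Multiplication, left, Expression::Constant(c)) if c == 1.0 => left,
                // x+x => 2*x
                (Addition, Expression::Variable, Expression::Variable) => {
                    binary(Multiplication, Expression::Constant(2.0), Expression::Variable)
                }
                // Anything else: simplify both operands and rebuild the node.
                (operator, left, right) => binary(operator, simplify(left), simplify(right)),
            }
        }
        Expression::Unary(unary_expr) => Expression::Unary(UnaryExpression {
            operator: unary_expr.operator,
            operand: Box::new(simplify(*unary_expr.operand)),
        }),
        // Constants and the variable are already as simple as they get.
        other => other,
    }
}
```

Each rule only looks at the node as it is *before* its operands are simplified. On the first pass `((1*x)+(x*1))` is not yet of the form `x+x`, so only the inner products collapse; the second pass then sees `(x+x)` and rewrites it to `(2*x)`.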

Yay! There we have it! We built a simple symbolic differentiation system! We have sidestepped a few interesting problems here by using a constrained version of the algebra (namely: no exponential functions), but it is nevertheless a fully working system and it does produce correct results. 😁

This kind of system has many applications, ranging from finding roots of continuous functions to ML systems 🤖