Hello! Today we have another one of those mathy posts. Fasten your seat belts, here we go! 🔥🚀
Today we are talking about derivatives. In calculus, derivatives are surprisingly easy to compute (and loved by many) because they follow simple rules which can always be applied, so you never run into corner cases, just cases where you could have used a more efficient approach. This is quite different from integration, for example, where a big part of getting any result at all is having strong intuitions and heuristics about which paths are the correct ones to take.
We are going to use Rust, the superior language. Get out of here, Carbon. 🔥🔥🔥
Expressions
The first thing we need for calculating derivatives is to be able to represent expressions. In mathematics, expressions represent a sequence of operations that need to be executed over some constants and variables. Those expressions are what we use as input for calculating a derivative. For example, we could have:
$$f(x) = x^2$$
This would mean that if we were to take the derivative:
$$f^\prime(x) = 2x$$
the input expression would have been `x*x` and the output (the derivative) would have been `2*x`.
Those expressions need to be represented in our code so we can transform them (i.e., calculate the derivative).
Since expressions represent a sequence of operations and a few constants and variables, the first thing we need to represent in our program is variables and constants, and then the operations we can make with them.
In Rust, we can start out with something like:
enum Expression {
    Constant(f32),
    Variable,
}
Here we are limiting our variables to a single one, namely `x`. Constants are represented as floating point values.
Now we need to represent some operations. Let’s start with operations that take in two different operands, in other words:
- a + b
- a - b
- a * b
- a / b
enum BinaryOperator {
    Addition,
    Subtraction,
    Multiplication,
    Division,
}
struct BinaryExpression {
    operator: BinaryOperator,
    left_operand: Box<Expression>,
    right_operand: Box<Expression>,
}
Let’s also augment our original definition of an expression to include this new kind of expression:
enum Expression {
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}
Notice we use `Box<Expression>` as the operands for our binary expressions (i.e., the `a`s and `b`s). The `Expression` is there to make sure we can write operations that refer to further operations. The `Box` is there because without it the Rust compiler can’t actually build those types for us, since they would have infinite size: each expression would contain other expressions inline, nested without limit. A `Box` stores its contents on the heap behind a fixed-size pointer, which gives the type a known size.
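To make the nesting concrete, here is a minimal sketch that builds the tree for `(x + 2) * x` out of the types defined so far. The `derive` attributes and the helper functions `x_plus_two_times_x` and `depth` are assumptions added for illustration; they are not part of the original code.

```rust
// Types repeated from the post so this snippet compiles on its own.
// The derives are an addition for illustration.
#[derive(Debug, PartialEq)]
enum BinaryOperator {
    Addition,
    Subtraction,
    Multiplication,
    Division,
}

#[derive(Debug, PartialEq)]
struct BinaryExpression {
    operator: BinaryOperator,
    left_operand: Box<Expression>,
    right_operand: Box<Expression>,
}

#[derive(Debug, PartialEq)]
enum Expression {
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}

// Hypothetical helper: builds the tree for (x + 2) * x.
fn x_plus_two_times_x() -> Expression {
    // Inner node: x + 2
    let sum = Expression::Binary(BinaryExpression {
        operator: BinaryOperator::Addition,
        left_operand: Box::new(Expression::Variable),
        right_operand: Box::new(Expression::Constant(2.0)),
    });
    // Outer node: (x + 2) * x, with the inner node boxed as the left operand
    Expression::Binary(BinaryExpression {
        operator: BinaryOperator::Multiplication,
        left_operand: Box::new(sum),
        right_operand: Box::new(Expression::Variable),
    })
}

// Hypothetical helper: counts how many levels deep the tree goes.
fn depth(expr: &Expression) -> usize {
    match expr {
        Expression::Binary(b) => 1 + depth(&b.left_operand).max(depth(&b.right_operand)),
        Expression::Constant(_) | Expression::Variable => 1,
    }
}
```

Each `Box::new` call moves one sub-expression to the heap, so every `BinaryExpression` only has to hold two pointers no matter how deep the tree below it goes.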
Now we can think about another kind of operation. When we write things like `-(a*b)` we have a unary additive inverse operation, which takes `(a*b)` and calculates its additive inverse.
For those cases we want to define a whole new type of expression:
enum UnaryOperator {
    Minus,
}
struct UnaryExpression {
    operator: UnaryOperator,
    operand: Box<Expression>,
}
And again, we need to add it to the base expression type we created:
enum Expression {
    Unary(UnaryExpression),
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}
Derivatives
Like I said at the beginning of this post, derivatives have very simple and clear rules. All we need to do is encode those rules as code. Let’s start with a `derivative` function which takes in an `Expression` to differentiate and outputs another `Expression` which represents the derivative.
This function can be as simple as:
fn derivative(input: Expression) -> Expression {
    match input {
        Expression::Unary(unary_expr) => derivative_unary(unary_expr),
        Expression::Binary(bin_expr) => derivative_binary(bin_expr),
        Expression::Constant(_) => Expression::Constant(0.0),
        Expression::Variable => Expression::Constant(1.0),
    }
}
It offloads the bulk of the work to two other functions, `derivative_unary` and `derivative_binary`, but it does have some built-in rules:
- for constants, the derivative is always zero;
- for the variable (always `x`), the derivative is always `1`.
Let’s take a quick look at `derivative_unary`:
fn derivative_unary(input: UnaryExpression) -> Expression {
    match input.operator {
        UnaryOperator::Minus => Expression::Unary(UnaryExpression {
            operator: UnaryOperator::Minus,
            operand: Box::new(derivative(*input.operand)),
        }),
    }
}
There’s only a single case to handle here, which is `-a`. In this case the derivative is always:
$$\frac{d}{dx} \left[-a\right] = -\frac{d}{dx}\left[a\right]$$
Notice there’s some unboxing and re-boxing but aside from that not much goes on in this function.
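The unboxing and re-boxing step can be looked at in isolation. The sketch below uses a cut-down `Expression` with only the two leaf variants, and the names `derivative_leaf` and `rebox` are hypothetical stand-ins introduced here, not functions from the original code.

```rust
// Cut-down expression type with only the leaf variants (an assumption
// made to keep this sketch short).
#[derive(Debug, PartialEq)]
enum Expression {
    Constant(f32),
    Variable,
}

// Hypothetical stand-in for `derivative`, restricted to the leaf cases.
fn derivative_leaf(input: Expression) -> Expression {
    match input {
        Expression::Constant(_) => Expression::Constant(0.0),
        Expression::Variable => Expression::Constant(1.0),
    }
}

// Takes a boxed operand, moves the value out with `*`, transforms it,
// and puts the result into a fresh Box.
fn rebox(operand: Box<Expression>) -> Box<Expression> {
    // Unboxing: dereferencing an owned Box moves the Expression out of it.
    let inner: Expression = *operand;
    // Re-boxing: the transformed value gets a new heap allocation.
    Box::new(derivative_leaf(inner))
}
```

Because `derivative` takes its argument by value, the operand has to be moved out of its `Box` first, which is what the `*input.operand` expression does.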
The bulk of the work is done in the `derivative_binary` function:
fn derivative_binary(input: BinaryExpression) -> Expression {
    match input.operator {
        BinaryOperator::Addition => Expression::Binary(BinaryExpression {
            operator: BinaryOperator::Addition,
            left_operand: Box::new(derivative(*input.left_operand)),
            right_operand: Box::new(derivative(*input.right_operand)),
        }),
        BinaryOperator::Subtraction => Expression::Binary(BinaryExpression {
            // ...
        }),
        BinaryOperator::Multiplication => Expression::Binary(BinaryExpression {
            // ...
        }),
        BinaryOperator::Division => Expression::Binary(BinaryExpression {
            // ...
        }),
    }
}
All this function is doing is implementing four simple differentiation rules (full code linked at the end):
- addition rule;
- subtraction rule;
- product rule;
- quotient rule.
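Written out for two operands $a$ and $b$, in the same notation as the unary case, these are the standard calculus identities:

$$\frac{d}{dx}\left[a + b\right] = \frac{d}{dx}\left[a\right] + \frac{d}{dx}\left[b\right]$$

$$\frac{d}{dx}\left[a - b\right] = \frac{d}{dx}\left[a\right] - \frac{d}{dx}\left[b\right]$$

$$\frac{d}{dx}\left[a \cdot b\right] = \frac{d}{dx}\left[a\right] \cdot b + a \cdot \frac{d}{dx}\left[b\right]$$

$$\frac{d}{dx}\left[\frac{a}{b}\right] = \frac{\frac{d}{dx}\left[a\right] \cdot b - a \cdot \frac{d}{dx}\left[b\right]}{b^2}$$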
There’s a nice page on Wikipedia detailing those and other differentiation rules.
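For reference, here is one way the elided arms might be filled in. This is a sketch and not necessarily the code the post links to. It assumes the types derive `Clone` (the product and quotient rules need each operand both as-is and differentiated), and it introduces a hypothetical `binary` helper to cut down on boilerplate.

```rust
// Types repeated from the post; the derives are an assumption of this sketch.
#[derive(Clone, Debug, PartialEq)]
enum BinaryOperator {
    Addition,
    Subtraction,
    Multiplication,
    Division,
}

#[derive(Clone, Debug, PartialEq)]
enum UnaryOperator {
    Minus,
}

#[derive(Clone, Debug, PartialEq)]
struct BinaryExpression {
    operator: BinaryOperator,
    left_operand: Box<Expression>,
    right_operand: Box<Expression>,
}

#[derive(Clone, Debug, PartialEq)]
struct UnaryExpression {
    operator: UnaryOperator,
    operand: Box<Expression>,
}

#[derive(Clone, Debug, PartialEq)]
enum Expression {
    Unary(UnaryExpression),
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}

// Hypothetical helper: builds a binary node from two owned operands.
fn binary(operator: BinaryOperator, left: Expression, right: Expression) -> Expression {
    Expression::Binary(BinaryExpression {
        operator,
        left_operand: Box::new(left),
        right_operand: Box::new(right),
    })
}

fn derivative(input: Expression) -> Expression {
    match input {
        Expression::Unary(u) => match u.operator {
            // (-a)' = -(a')
            UnaryOperator::Minus => Expression::Unary(UnaryExpression {
                operator: UnaryOperator::Minus,
                operand: Box::new(derivative(*u.operand)),
            }),
        },
        Expression::Binary(b) => derivative_binary(b),
        Expression::Constant(_) => Expression::Constant(0.0),
        Expression::Variable => Expression::Constant(1.0),
    }
}

fn derivative_binary(input: BinaryExpression) -> Expression {
    use BinaryOperator::*;
    // Move both operands out of their boxes up front.
    let (a, b) = (*input.left_operand, *input.right_operand);
    match input.operator {
        // (a + b)' = a' + b'
        Addition => binary(Addition, derivative(a), derivative(b)),
        // (a - b)' = a' - b'
        Subtraction => binary(Subtraction, derivative(a), derivative(b)),
        // (a * b)' = a' * b + a * b'
        Multiplication => binary(
            Addition,
            binary(Multiplication, derivative(a.clone()), b.clone()),
            binary(Multiplication, a, derivative(b)),
        ),
        // (a / b)' = (a' * b - a * b') / (b * b)
        Division => binary(
            Division,
            binary(
                Subtraction,
                binary(Multiplication, derivative(a.clone()), b.clone()),
                binary(Multiplication, a, derivative(b.clone())),
            ),
            binary(Multiplication, b.clone(), b),
        ),
    }
}
```

Cloning whole subtrees is the simplest way to satisfy the borrow checker here. A version that takes `&Expression` instead of an owned value would avoid some of the copies, at the cost of a different function signature than the one used in this post.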
Testing
Now is a good time to test our program. Let’s try the function from the start of this post, `x*x`:
let expression: Expression = Expression::Binary(BinaryExpression {
    operator: BinaryOperator::Multiplication,
    left_operand: Box::new(Expression::Variable),
    right_operand: Box::new(Expression::Variable),
});
And now let’s run our program:
❯ cargo run
((1*x)+(x*1))
This is the correct derivative! If we work through a bit of simplification we can see this can be rewritten as `x+x`, which is `2*x`.
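The printing code is not shown above. One plausible way to get output in the `((1*x)+(x*1))` form is a `Display` implementation that wraps every binary expression in parentheses. The sketch below is an assumption about how that could look, not the post's actual code, and the `(-a)` format chosen for unary minus is a guess.

```rust
use std::fmt;

// Types repeated from the post so the snippet compiles on its own.
enum BinaryOperator {
    Addition,
    Subtraction,
    Multiplication,
    Division,
}

enum UnaryOperator {
    Minus,
}

struct BinaryExpression {
    operator: BinaryOperator,
    left_operand: Box<Expression>,
    right_operand: Box<Expression>,
}

struct UnaryExpression {
    operator: UnaryOperator,
    operand: Box<Expression>,
}

enum Expression {
    Unary(UnaryExpression),
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}

impl fmt::Display for Expression {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Assumed format for unary minus: (-operand)
            Expression::Unary(u) => match u.operator {
                UnaryOperator::Minus => write!(f, "(-{})", u.operand),
            },
            // Binary nodes are always parenthesized, so no precedence
            // handling is needed.
            Expression::Binary(b) => {
                let symbol = match b.operator {
                    BinaryOperator::Addition => '+',
                    BinaryOperator::Subtraction => '-',
                    BinaryOperator::Multiplication => '*',
                    BinaryOperator::Division => '/',
                };
                write!(f, "({}{}{})", b.left_operand, symbol, b.right_operand)
            }
            // `{}` prints an f32 such as 1.0 as "1".
            Expression::Constant(c) => write!(f, "{}", c),
            Expression::Variable => write!(f, "x"),
        }
    }
}
```

With this in place, `println!("{}", expression)` prints any expression tree, and always parenthesizing keeps the output unambiguous at the cost of some visual noise.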
Simplification
We can do a little bit better, however, by making the program itself do the simplifications! We can use the same general approach as we used to calculate the derivative: create a function which takes in an `Expression` as input and then outputs another, modified, `Expression`.
Thanks to Rust pattern matching, this function could be written as:
fn simplify(input: Expression) -> Expression {
    match input {
        Expression::Binary(bin_expr) => match bin_expr.operator {
            BinaryOperator::Multiplication => {
                match (*bin_expr.left_operand, *bin_expr.right_operand) {
                    // match 1*b and a*1, transform into b or a
                    (Expression::Constant(c), right_operand) if c == 1.0 => right_operand,
                    (left_operand, Expression::Constant(c)) if c == 1.0 => left_operand,
                    // anything else: keep the product, simplify both operands
                    (left_operand, right_operand) => Expression::Binary(BinaryExpression {
                        operator: BinaryOperator::Multiplication,
                        left_operand: Box::new(simplify(left_operand)),
                        right_operand: Box::new(simplify(right_operand)),
                    }),
                }
            }
            _ => {
                // ... recursively simplify() ...
            }
        },
        Expression::Unary(unary_expr) => {
            // ... recursively simplify() ...
        },
        Expression::Constant(c) => Expression::Constant(c),
        Expression::Variable => Expression::Variable,
    }
}
Running the program again now gives us:
❯ cargo run
(x+x)
If we also add new rules to match this form (i.e., variable plus variable) and convert that into `2*x`, and call `simplify()` twice, we can get our desired simplified output:
❯ cargo run
(2*x)
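As a variation on the idea, here is a sketch of a `simplify` that reduces the operands first and only then applies the rewrite rules. This ordering is an assumption of the sketch and differs from the version above: because the rules see already-simplified operands, a single call is enough for this particular example. The `binary` helper and the `derive` attributes are hypothetical additions.

```rust
// Types repeated from the post; the derives are an assumption of this sketch.
#[derive(Debug, PartialEq)]
enum BinaryOperator {
    Addition,
    Subtraction,
    Multiplication,
    Division,
}

#[derive(Debug, PartialEq)]
enum UnaryOperator {
    Minus,
}

#[derive(Debug, PartialEq)]
struct BinaryExpression {
    operator: BinaryOperator,
    left_operand: Box<Expression>,
    right_operand: Box<Expression>,
}

#[derive(Debug, PartialEq)]
struct UnaryExpression {
    operator: UnaryOperator,
    operand: Box<Expression>,
}

#[derive(Debug, PartialEq)]
enum Expression {
    Unary(UnaryExpression),
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}

// Hypothetical helper: builds a binary node from two owned operands.
fn binary(operator: BinaryOperator, left: Expression, right: Expression) -> Expression {
    Expression::Binary(BinaryExpression {
        operator,
        left_operand: Box::new(left),
        right_operand: Box::new(right),
    })
}

fn simplify(input: Expression) -> Expression {
    match input {
        Expression::Binary(b) => {
            // Simplify the operands first so the rules below see reduced forms.
            let left = simplify(*b.left_operand);
            let right = simplify(*b.right_operand);
            match (b.operator, left, right) {
                // 1*a => a
                (BinaryOperator::Multiplication, Expression::Constant(c), a) if c == 1.0 => a,
                // a*1 => a
                (BinaryOperator::Multiplication, a, Expression::Constant(c)) if c == 1.0 => a,
                // x+x => 2*x
                (BinaryOperator::Addition, Expression::Variable, Expression::Variable) => binary(
                    BinaryOperator::Multiplication,
                    Expression::Constant(2.0),
                    Expression::Variable,
                ),
                // No rule applies: rebuild the node from the simplified operands.
                (operator, left, right) => binary(operator, left, right),
            }
        }
        Expression::Unary(u) => Expression::Unary(UnaryExpression {
            operator: u.operator,
            operand: Box::new(simplify(*u.operand)),
        }),
        // Constants and the variable are already as simple as they get.
        leaf => leaf,
    }
}
```

Only three rewrite rules are covered here. Rules such as `0*a`, `a+0`, or folding two constants together would slot into the same inner `match`.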
Yay! There we have it! We built a simple symbolic differentiation system! We have sidestepped a few interesting problems here by using a constrained version of the algebra (namely: no exponential functions), but it is nevertheless a fully working system and it does produce correct results. 😁
This kind of system has many applications, ranging from finding roots of continuous functions to ML systems 🤖