Felipe Tavares

Simple Symbolic Differentiation

August 7, '22

Hello! Today we have another one of those mathy posts. Fasten your seat belts, here we go! 🔥🚀

Today we are talking about derivatives. In Calculus, derivatives are surprisingly easy to compute (and loved by many) because they follow simple rules which can always be used, so you never find corner cases, just cases where you could have used a more efficient approach. This is quite different from integration, for example, where a big part of getting any result at all is having strong intuitions and heuristics around which paths are the correct ones to take.

We are going to use Rust, the superior language. Get out of here, Carbon. 🔥🔥🔥

Expressions

The first thing we need for calculating derivatives is to be able to represent expressions. In mathematics, expressions represent a sequence of operations that need to be executed over some constants and variables. Those expressions are what we use as input for calculating a derivative. For example, we could have:

$$f(x) = x^2$$

This would mean that if we were to take the derivative:

$$f^\prime(x) = 2x$$

the input expression would have been x*x and the output (the derivative) would have been 2*x.

Those expressions need to be represented in our code so we can transform them (i.e.: calculate the derivative).

Since expressions represent a sequence of operations and a few constants and variables, the first thing we need to represent in our program is variables and constants, and then the operations we can make with them.

In Rust, we can start out with something like:

enum Expression {
    Constant(f32),
    Variable,
}

Here we are limiting our variables to a single one, namely x. Constants are represented as floating-point values (f32).

Now we need to represent some operations. Let’s start with operations that take two operands, in other words a+b, a-b, a*b and a/b:

enum BinaryOperator {
    Addition,
    Subtraction,
    Multiplication,
    Division,
}
struct BinaryExpression {
    operator: BinaryOperator,
    left_operand: Box<Expression>,
    right_operand: Box<Expression>,
}

And let’s also augment our original definition of an expression to include this new kind of expression:

enum Expression {
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}

Notice we use Box<Expression> as the operands for our binary expressions (i.e. the as and bs). Using Expression as the operand type lets operations refer to further operations. The Box is there because without it the Rust compiler can’t build those types for us: an Expression would contain a BinaryExpression, which would contain two more Expressions, and so on without end, so the type would have infinite size. A Box stores its contents on the heap behind a pointer of known size, which breaks the cycle.
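
To make the role of Box a bit more concrete, here is a minimal sketch that builds the tree for (x*x)+3 and walks it. The helpers x_squared_plus_three and count_nodes are hypothetical names made up for this illustration; they are not part of the original program.

```rust
#[allow(dead_code)]
enum BinaryOperator {
    Addition,
    Subtraction,
    Multiplication,
    Division,
}

#[allow(dead_code)]
struct BinaryExpression {
    operator: BinaryOperator,
    left_operand: Box<Expression>,
    right_operand: Box<Expression>,
}

#[allow(dead_code)]
enum Expression {
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}

// Hypothetical helper: builds the tree for (x*x)+3.
// Every operand is wrapped in a Box, so each level of nesting
// is a heap allocation reached through a pointer.
fn x_squared_plus_three() -> Expression {
    let x_squared = Expression::Binary(BinaryExpression {
        operator: BinaryOperator::Multiplication,
        left_operand: Box::new(Expression::Variable),
        right_operand: Box::new(Expression::Variable),
    });
    Expression::Binary(BinaryExpression {
        operator: BinaryOperator::Addition,
        left_operand: Box::new(x_squared),
        right_operand: Box::new(Expression::Constant(3.0)),
    })
}

// Hypothetical helper: counts the nodes of a tree by following the boxes.
fn count_nodes(expr: &Expression) -> usize {
    match expr {
        Expression::Binary(b) => {
            1 + count_nodes(&b.left_operand) + count_nodes(&b.right_operand)
        }
        Expression::Constant(_) | Expression::Variable => 1,
    }
}
```

Calling count_nodes(&x_squared_plus_three()) visits the addition, the multiplication, the two x leaves and the constant, so it counts five nodes.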

Now we can think about another kind of operation we can make. When we write things like -(a*b) we have a unary additive inverse operation, which takes (a*b) and calculates the additive inverse.

In those cases we want to define a whole new type of expression:

enum UnaryOperator {
    Minus,
}

struct UnaryExpression {
    operator: UnaryOperator,
    operand: Box<Expression>,
}

And again, we need to add it to the base expression type we created:

enum Expression {
    Unary(UnaryExpression),
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}

Derivatives

Like I said at the beginning of this post, derivatives have very simple and clear rules. All we need to do is encode those rules as code. Let’s start with a derivative function which takes in an Expression to differentiate and outputs another Expression which represents the derivative.

This function can be as simple as:

fn derivative(input: Expression) -> Expression {
    match input {
        Expression::Unary(unary_expr) => derivative_unary(unary_expr),
        Expression::Binary(bin_expr) => derivative_binary(bin_expr),
        Expression::Constant(_) => Expression::Constant(0.0),
        Expression::Variable => Expression::Constant(1.0),
    }
}

It offloads the bulk of the work to two other functions: derivative_unary and derivative_binary. But it does have two rules built in: the derivative of a constant is 0, and the derivative of x with respect to x is 1.

Let’s take a quick look into derivative_unary:

fn derivative_unary(input: UnaryExpression) -> Expression {
    match input.operator {
        UnaryOperator::Minus => Expression::Unary(UnaryExpression {
            operator: UnaryOperator::Minus,
            operand: Box::new(derivative(*input.operand)),
        }),
    }
}

There’s only a single case to handle here, which is -a. In this case the derivative is always:

$$\frac{d}{dx} \left[-a\right] = -\frac{d}{dx}\left[a\right]$$

Notice there’s some unboxing and re-boxing but aside from that not much goes on in this function.

The bulk of the work is done in the derivative_binary function:

fn derivative_binary(input: BinaryExpression) -> Expression {
    match input.operator {
        BinaryOperator::Addition => Expression::Binary(BinaryExpression {
            operator: BinaryOperator::Addition,
            left_operand: Box::new(derivative(*input.left_operand)),
            right_operand: Box::new(derivative(*input.right_operand)),
        }),
        BinaryOperator::Subtraction => Expression::Binary(BinaryExpression {
            // ...
        }),
        BinaryOperator::Multiplication => Expression::Binary(BinaryExpression {
            // ...
        }),
        BinaryOperator::Division => Expression::Binary(BinaryExpression {
            // ...
        }),
    }
}

All this function does is implement four simple differentiation rules: the sum, difference, product and quotient rules (full code linked at the end).
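
Written out in the same notation as the rule for -a above, with a and b standing for arbitrary expressions in x, the standard forms of those rules are:

$$\frac{d}{dx}\left[a + b\right] = \frac{d}{dx}\left[a\right] + \frac{d}{dx}\left[b\right]$$

$$\frac{d}{dx}\left[a - b\right] = \frac{d}{dx}\left[a\right] - \frac{d}{dx}\left[b\right]$$

$$\frac{d}{dx}\left[a \cdot b\right] = \frac{d}{dx}\left[a\right] \cdot b + a \cdot \frac{d}{dx}\left[b\right]$$

$$\frac{d}{dx}\left[\frac{a}{b}\right] = \frac{\frac{d}{dx}\left[a\right] \cdot b - a \cdot \frac{d}{dx}\left[b\right]}{b^2}$$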

There’s a nice page on Wikipedia detailing those and other differentiation rules.
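
For readers who want to see the elided arms spelled out, here is one way they might look. This is a sketch under a few assumptions, not the code from the linked file: the types derive Clone (the product and quotient rules need both an operand and its derivative), unary expressions are left out to keep it short, and binary is a hypothetical helper that hides the boxing.

```rust
#[derive(Clone, Debug, PartialEq)]
enum BinaryOperator {
    Addition,
    Subtraction,
    Multiplication,
    Division,
}

#[derive(Clone, Debug, PartialEq)]
struct BinaryExpression {
    operator: BinaryOperator,
    left_operand: Box<Expression>,
    right_operand: Box<Expression>,
}

// Assumption: unary expressions are omitted from this sketch.
#[derive(Clone, Debug, PartialEq)]
enum Expression {
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}

// Hypothetical helper that hides the boxing boilerplate.
fn binary(operator: BinaryOperator, left: Expression, right: Expression) -> Expression {
    Expression::Binary(BinaryExpression {
        operator,
        left_operand: Box::new(left),
        right_operand: Box::new(right),
    })
}

fn derivative(input: Expression) -> Expression {
    match input {
        Expression::Binary(bin_expr) => derivative_binary(bin_expr),
        Expression::Constant(_) => Expression::Constant(0.0),
        Expression::Variable => Expression::Constant(1.0),
    }
}

fn derivative_binary(input: BinaryExpression) -> Expression {
    use BinaryOperator::*;
    // Unbox both operands once; they are cloned below where a rule
    // needs an operand and its derivative at the same time.
    let (a, b) = (*input.left_operand, *input.right_operand);
    match input.operator {
        // (a + b)' = a' + b'
        Addition => binary(Addition, derivative(a), derivative(b)),
        // (a - b)' = a' - b'
        Subtraction => binary(Subtraction, derivative(a), derivative(b)),
        // (a * b)' = a' * b + a * b'
        Multiplication => binary(
            Addition,
            binary(Multiplication, derivative(a.clone()), b.clone()),
            binary(Multiplication, a, derivative(b)),
        ),
        // (a / b)' = (a' * b - a * b') / (b * b)
        Division => {
            let numerator = binary(
                Subtraction,
                binary(Multiplication, derivative(a.clone()), b.clone()),
                binary(Multiplication, a, derivative(b.clone())),
            );
            binary(Division, numerator, binary(Multiplication, b.clone(), b))
        }
    }
}
```

Cloning whole subtrees is the simplest option here, not necessarily the cheapest; sharing subtrees through reference counting would be an alternative design.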

Testing

Now is a good time to test our program. Let’s try the function from the start of this post, x*x:

let expression: Expression = Expression::Binary(BinaryExpression {
    operator: BinaryOperator::Multiplication,
    left_operand: Box::new(Expression::Variable),
    right_operand: Box::new(Expression::Variable),
});

And now let’s run our program:

❯ cargo run
((1*x)+(x*1))

This is the correct derivative! If we work through a bit of simplification we can see this can be rewritten as x+x, which is 2*x.
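
The post does not show how the expression gets printed. One way to produce output in this shape is a Display implementation like the sketch below. The fully parenthesized format is inferred from the output above, and the "(-a)" format for unary minus is purely an assumption.

```rust
use std::fmt;

#[allow(dead_code)]
enum UnaryOperator {
    Minus,
}

#[allow(dead_code)]
enum BinaryOperator {
    Addition,
    Subtraction,
    Multiplication,
    Division,
}

#[allow(dead_code)]
struct UnaryExpression {
    operator: UnaryOperator,
    operand: Box<Expression>,
}

#[allow(dead_code)]
struct BinaryExpression {
    operator: BinaryOperator,
    left_operand: Box<Expression>,
    right_operand: Box<Expression>,
}

#[allow(dead_code)]
enum Expression {
    Unary(UnaryExpression),
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}

impl fmt::Display for Expression {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Assumed format for unary minus: (-a)
            Expression::Unary(u) => match u.operator {
                UnaryOperator::Minus => write!(f, "(-{})", u.operand),
            },
            // Binary expressions are always wrapped in parentheses,
            // so operator precedence never has to be considered.
            Expression::Binary(b) => {
                let symbol = match b.operator {
                    BinaryOperator::Addition => '+',
                    BinaryOperator::Subtraction => '-',
                    BinaryOperator::Multiplication => '*',
                    BinaryOperator::Division => '/',
                };
                write!(f, "({}{}{})", b.left_operand, symbol, b.right_operand)
            }
            // Whole-number f32 values such as 1.0 print without a fraction.
            Expression::Constant(c) => write!(f, "{}", c),
            Expression::Variable => write!(f, "x"),
        }
    }
}
```

With this in place, println!("{}", expression) works for any Expression, including the result of derivative().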

Simplification

We can do a little bit better, however, by making the program itself do the simplifications! We can use the same general approach as we used to calculate the derivative: create a function which takes in an Expression as input and then outputs another, modified, Expression.

Thanks to Rust pattern matching, this function could be written as:

fn simplify(input: Expression) -> Expression {
    match input {
        Expression::Binary(bin_expr) => match bin_expr.operator {
            BinaryOperator::Multiplication => {
                match (*bin_expr.left_operand, *bin_expr.right_operand) {
                    // match 1*a and b*1, transform into a or b
                    (Expression::Constant(c), right_operand) if c == 1.0 => right_operand,
                    (left_operand, Expression::Constant(c)) if c == 1.0 => left_operand,
                    (left_operand, right_operand) => Expression::Binary(BinaryExpression {
                        operator: BinaryOperator::Multiplication,
                        left_operand: Box::new(simplify(left_operand)),
                        right_operand: Box::new(simplify(right_operand)),
                    }),
                }
            }
            _ => {
                // ... recursively simplify() ...
            }
        },
        Expression::Unary(unary_expr) => {
            // ... recursively simplify() ...
        },
        Expression::Constant(c) => Expression::Constant(c),
        Expression::Variable => Expression::Variable,
    }
}

Which after running the program again gives us:

❯ cargo run
(x+x)

If we also add new rules to match this form (i.e.: variable plus variable) and convert that into 2*x, and call simplify() twice, we can get our desired simplified output:

❯ cargo run
(2*x)
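
As a rough sketch of what those extra rules could look like, the version below folds the x+x rule into the same match as the multiplication rules. It is an illustration under assumptions, not the code from the linked file: unary expressions are left out, the operator is matched together with the operands, and binary is a hypothetical helper that hides the boxing.

```rust
#[derive(Debug, PartialEq)]
enum BinaryOperator {
    Addition,
    Subtraction,
    Multiplication,
    Division,
}

#[derive(Debug, PartialEq)]
struct BinaryExpression {
    operator: BinaryOperator,
    left_operand: Box<Expression>,
    right_operand: Box<Expression>,
}

// Assumption: unary expressions are omitted from this sketch.
#[derive(Debug, PartialEq)]
enum Expression {
    Binary(BinaryExpression),
    Constant(f32),
    Variable,
}

// Hypothetical helper that hides the boxing boilerplate.
fn binary(operator: BinaryOperator, left: Expression, right: Expression) -> Expression {
    Expression::Binary(BinaryExpression {
        operator,
        left_operand: Box::new(left),
        right_operand: Box::new(right),
    })
}

fn simplify(input: Expression) -> Expression {
    use BinaryOperator::*;
    let bin_expr = match input {
        Expression::Binary(bin_expr) => bin_expr,
        // Constants and variables are already as simple as they get.
        other => return other,
    };
    match (bin_expr.operator, *bin_expr.left_operand, *bin_expr.right_operand) {
        // 1*a and a*1 become a
        (Multiplication, Expression::Constant(c), a) if c == 1.0 => a,
        (Multiplication, a, Expression::Constant(c)) if c == 1.0 => a,
        // x+x becomes 2*x
        (Addition, Expression::Variable, Expression::Variable) => {
            binary(Multiplication, Expression::Constant(2.0), Expression::Variable)
        }
        // Anything else: keep the operator and simplify both operands.
        (operator, left, right) => binary(operator, simplify(left), simplify(right)),
    }
}
```

Each call only rewrites what it can see in a single pass, which is why two calls are needed: the first turns ((1*x)+(x*1)) into (x+x), and only then does the x+x rule have something to match.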

Yay! There we have it! We built a simple symbolic differentiation system! We have sidestepped a few interesting problems here by using a constrained version of the algebra (namely: no exponential functions), but it is nevertheless a fully working system and it does produce correct results. 😁

This kind of system has many applications, ranging from finding roots of continuous functions to ML systems 🤖

symbolic-differentiation.rs