Gradient descent is an effective algorithm for finding local minima of functions, and the global minimum of a convex function. It's very useful in machine learning for fitting a model from a family of models by finding the parameters that minimise a loss function. However, sometimes you have extra constraints; I recently worked on a problem where the maximum value of the fitted function had to occur at a certain point. It's straightforward to adapt gradient descent to handle differentiable equality constraints.

The idea is simple: we've got a loss function that we're trying to minimise subject to some constraint function. With gradient descent we take a step in the direction of greatest local decrease of the loss, along the negative gradient. The size of the step we take is called the learning rate, lr. In PyTorch:

def gradient_descent_step(x, loss, lr=1e-3):
    # Clear any gradient left over from a previous step
    if x.grad is not None:
        x.grad.zero_()

    # Calculate the derivative of the loss function at x
    loss(x).backward()

    # Step in the direction of greatest local decrease,
    # without recording the update in the autograd graph
    with torch.no_grad():
        x -= lr * x.grad
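As a quick sanity check, here is a minimal, self-contained sketch of using such a step in a loop on a toy quadratic (the toy loss and the loop length are illustrative choices, not from the post; the step function is repeated so the snippet runs on its own):

```python
import torch

def gradient_descent_step(x, loss, lr=1e-3):
    # Clear any leftover gradient, differentiate the loss, and step
    if x.grad is not None:
        x.grad.zero_()
    loss(x).backward()
    with torch.no_grad():
        x -= lr * x.grad

# Minimise (x - 3)^2, whose minimum is at x = 3
x = torch.tensor([0.0], requires_grad=True)
loss = lambda x: ((x - 3.0) ** 2).sum()

for _ in range(1000):
    gradient_descent_step(x, loss, lr=1e-2)
```

After a thousand steps x should be very close to 3.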

However, this may take us off our constraint curve. As long as we are at a point x on the constraint curve, constraint(x) = 0, we want to stay on that curve. That means we want to take a step in a direction along which the derivative of the constraint is zero, so its value won't change (to first order). Such directions are orthogonal to the gradient of the constraint, so we remove from the loss gradient its component parallel to the constraint's gradient. In PyTorch:
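In symbols (my notation, not from the post): writing g for the loss gradient and n for the constraint gradient at x, the projected step is

```latex
g_\perp = g - \frac{g \cdot n}{n \cdot n}\, n,
\qquad
x \leftarrow x - \eta\, g_\perp
```

where η is the learning rate; g_⊥ is perpendicular to n by construction, since g_⊥ · n = g · n − (g · n / n · n)(n · n) = 0.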

def gradient_descent_step(x, loss, constraint, lr=1e-3):
    # Clear any gradient left over from a previous step
    if x.grad is not None:
        x.grad.zero_()

    # Calculate the gradient of the constraint at x
    constraint(x).backward()
    direction = x.grad.clone()

    # Calculate the gradient of the loss function at x
    x.grad.zero_()
    loss(x).backward()

    with torch.no_grad():
        # Remove the projection of the loss gradient onto the constraint
        # gradient. The resulting vector is perpendicular to the
        # gradient of the constraint.
        perp_proj = x.grad - (x.grad @ direction) / (direction @ direction) * direction

        # Step in this direction
        x -= lr * perp_proj
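To see the constrained step in action, here is a self-contained sketch: minimising the first coordinate of a point constrained to the unit circle (the example set-up is mine, not from the post; note that tangent steps only preserve the constraint to first order, so the point drifts slightly off the circle over many steps):

```python
import torch

def gradient_descent_step(x, loss, constraint, lr=1e-3):
    # Gradient of the constraint at x
    if x.grad is not None:
        x.grad.zero_()
    constraint(x).backward()
    direction = x.grad.clone()

    # Gradient of the loss at x
    x.grad.zero_()
    loss(x).backward()

    with torch.no_grad():
        # Project out the component parallel to the constraint gradient
        perp_proj = x.grad - (x.grad @ direction) / (direction @ direction) * direction
        x -= lr * perp_proj

# Stay on the unit circle: constraint(x) = |x|^2 - 1 = 0
constraint = lambda x: x @ x - 1.0
# Minimise the first coordinate; on the circle the minimum is at (-1, 0)
loss = lambda x: x[0]

x = torch.tensor([0.0, 1.0], requires_grad=True)
for _ in range(5000):
    gradient_descent_step(x, loss, constraint, lr=1e-3)
```

After these steps x should be near (-1, 0), and its norm should still be close to 1. In practice you may also want to periodically re-project x exactly back onto the constraint set to correct the first-order drift.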