Constrained Gradient Descent
Gradient descent is an effective algorithm for finding the local minima of functions, and the global minimum of convex functions. It’s very useful in machine learning for fitting a model from a family of models by finding the parameters that minimise a loss function. However, sometimes you have extra constraints; I recently worked on a problem where the maximum value of the fitted function had to occur at a certain point. It’s straightforward to adapt gradient descent to handle differentiable equality constraints.
The idea is simple: we’ve got a function `loss` that we’re trying to minimise subject to some constraint function, `constraint(x) = 0`. With gradient descent we take a step in the direction of greatest decrease of the loss function, which is along the negative gradient. The size of the step we take is called the learning rate, `lr`. In PyTorch:
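```python
import torch

def gradient_descent_step(x, loss, lr=1e-3):
    # x is a leaf tensor created with requires_grad=True
    # Set gradient to zero (gradients accumulate across backward calls)
    if x.grad is not None:
        x.grad.zero_()
    # Calculate the derivative of the loss function at x
    loss(x).backward()
    # Step in direction of greatest local decrease
    with torch.no_grad():
        x -= lr * x.grad
```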
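For the unconstrained case, the hand-written step above performs the same update as PyTorch’s built-in `torch.optim.SGD` without momentum. The sketch below checks this on a hypothetical quadratic loss; the `target` point, learning rate and iteration count are assumptions chosen for illustration.

```python
import torch

# Hypothetical example loss: a quadratic bowl with its minimum at `target`
target = torch.tensor([3.0, -1.0])

def loss(x):
    return ((x - target) ** 2).sum()

lr = 0.1

# Manual gradient descent, using the same update as the step function above
x_manual = torch.zeros(2, requires_grad=True)
for _ in range(100):
    if x_manual.grad is not None:
        x_manual.grad.zero_()
    loss(x_manual).backward()
    with torch.no_grad():
        x_manual -= lr * x_manual.grad

# The same update using PyTorch's built-in optimiser (no momentum)
x_sgd = torch.zeros(2, requires_grad=True)
opt = torch.optim.SGD([x_sgd], lr=lr)
for _ in range(100):
    opt.zero_grad()
    loss(x_sgd).backward()
    opt.step()

print(x_manual.detach(), x_sgd.detach())
```

Both tensors end up at (essentially) `target`. Writing the step by hand is what makes it easy to modify the update direction, which is what the constrained version needs.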
However, this step may take us off our constraint curve. If we are at a point `x` on the constraint curve, `constraint(x) = 0`, we want to stay on that curve. That means we want to step in a direction along which the derivative of the constraint is zero, so its value won’t change (to first order). Those are exactly the directions orthogonal to the gradient of the constraint, so we can get one by removing the component of the loss gradient that is parallel to the constraint’s gradient. In PyTorch:
```python
import torch

def gradient_descent_step(x, loss, constraint, lr=1e-3):
    # x is a 1-D leaf tensor created with requires_grad=True
    # Set gradient to zero
    if x.grad is not None:
        x.grad.zero_()
    # Calculate gradient of the constraint
    constraint(x).backward()
    direction = x.grad.clone()
    # Calculate gradient of loss function
    x.grad.zero_()
    aloss = loss(x)
    aloss.backward()
    # Remove the projection of the loss gradient onto the constraint gradient.
    # The resulting vector will be perpendicular to the gradient of the constraint.
    perp_proj = x.grad - (x.grad @ direction) / (direction @ direction) * direction
    # Step in this direction
    with torch.no_grad():
        x -= lr * perp_proj
```
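As a quick check, here is a hypothetical example: find the point on the unit circle closest to `a = (2, 1)`. The target point, starting point, learning rate and step count are assumptions of this sketch, and the step function is repeated so the snippet runs on its own.

```python
import torch

def gradient_descent_step(x, loss, constraint, lr=1e-3):
    # Same projected step as above, repeated so this snippet is self-contained
    if x.grad is not None:
        x.grad.zero_()
    constraint(x).backward()
    direction = x.grad.clone()
    x.grad.zero_()
    loss(x).backward()
    perp_proj = x.grad - (x.grad @ direction) / (direction @ direction) * direction
    with torch.no_grad():
        x -= lr * perp_proj

# Hypothetical problem: squared distance to a = (2, 1), restricted to the unit circle
a = torch.tensor([2.0, 1.0])

def loss(x):
    return ((x - a) ** 2).sum()

def constraint(x):
    # Zero exactly on the unit circle
    return x @ x - 1.0

x = torch.tensor([1.0, 0.0], requires_grad=True)  # start on the circle
for _ in range(1000):
    gradient_descent_step(x, loss, constraint, lr=1e-2)

print(x.detach())            # points in the direction of a
print(constraint(x).item())  # small positive drift off the circle
```

The point converges to the direction of `a`, while `constraint(x)` creeps slightly above zero. Each step is tangent to the circle, so it keeps the constraint satisfied only to first order; a smaller learning rate shrinks this drift.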
We can put this into highfalutin differential geometry terminology. By the implicit function theorem, the equality constraint defines a submanifold of the overall space wherever the gradient of the constraint is nonzero (this can fail at points where that gradient vanishes, but there’s often a region where it holds). We want to optimise the loss function on this submanifold. This is done by projecting the gradient of the loss function onto the tangent space of the submanifold defined by the constraint.
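In symbols, write $n = \nabla \text{constraint}(x)$ (the `direction` in the code), $g = \nabla \text{loss}(x)$ and $\eta$ for `lr`. The step function applies the orthogonal projector onto the tangent space, and a first-order Taylor expansion shows why the constraint value is unchanged to first order along the projected step:

$$
P = I - \frac{n\,n^\top}{n^\top n}, \qquad
g_\perp = P\,g = g - \frac{g^\top n}{n^\top n}\, n
$$

$$
n^\top g_\perp = n^\top g - \frac{g^\top n}{n^\top n}\, n^\top n = 0
$$

$$
\text{constraint}(x - \eta\, g_\perp) = \text{constraint}(x) - \eta\, n^\top g_\perp + O(\eta^2) = \text{constraint}(x) + O(\eta^2)
$$

Here $g_\perp$ is `perp_proj`. The leftover $O(\eta^2)$ term is why, on a curved constraint, repeated steps slowly drift off the constraint surface.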
In fact the method of Lagrange multipliers solves this exact problem; however, on my problem I had difficulty getting the point back onto the constraint curve, whereas this projection method worked really well. We could potentially use a similar approach for an inequality constraint, by first searching in the interior of the region and then applying the gradient projection along the boundary (Boyd and Vandenberghe’s Convex Optimization has the details).
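The two views meet at the solution. At a constrained minimum the projected gradient vanishes, so the loss gradient is a multiple of the constraint gradient, ∇loss = λ ∇constraint, and the coefficient `(x.grad @ direction) / (direction @ direction)` from the step function equals that multiplier λ there. The sketch below checks this with autograd on a hypothetical unit-circle problem (the point `a` is an assumption), using the closed-form nearest point `a / |a|`.

```python
import torch

# Hypothetical problem: squared distance to a, restricted to the unit circle
a = torch.tensor([2.0, 1.0])

def loss(x):
    return ((x - a) ** 2).sum()

def constraint(x):
    return x @ x - 1.0

# Closed-form constrained minimiser: the point on the circle nearest a
x_star = (a / a.norm()).requires_grad_()

grad_loss, = torch.autograd.grad(loss(x_star), x_star)
grad_con, = torch.autograd.grad(constraint(x_star), x_star)

# Projection coefficient from the step function, i.e. the Lagrange multiplier here
lam = (grad_loss @ grad_con) / (grad_con @ grad_con)
residual = grad_loss - lam * grad_con

print(lam.item(), residual.norm().item())
```

For this problem λ works out to 1 − √5 ≈ −1.236 and the residual is zero up to floating-point error, so the projected gradient at the solution is zero, as the Lagrange condition requires.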