Illustrated using Julia

Published: April 21, 2023

An example of estimating linear regression beta coefficients via gradient descent, using Julia.
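The model being estimated is ordinary linear regression: for a design matrix $\mathbf{X}$ and coefficient vector $\boldsymbol{\beta}$,

$$\mathbf{y} = \mathbf{X}\boldsymbol{\beta} + \boldsymbol{\epsilon}, \qquad \epsilon_i \sim \mathcal{N}(0, \sigma^2)$$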

``````using Random
using ForwardDiff
using Distributions
using Statistics``````

# Generate Data

First, we generate some fake data: a design matrix 𝐗 with an intercept column and three standard-normal predictors, plus a response y built from known betas and Gaussian noise.

``````Random.seed!(0408)

#x data: an intercept column plus 3 standard-normal predictors
𝐗 = hcat(ones(1000), randn(1000, 3))

#ground truth betas
𝚩 = [.5, 1, 2, 3]

#multiply data by betas
f₁(X) = X*𝚩

#make some error (one draw per observation)
ϵ = rand(Normal(0, .5), size(𝐗, 1))

#generate y
y = f₁(𝐗) + ϵ;``````

# Define a Loss Function

Mean squared error is the most straightforward choice.
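For $n$ observations, the loss as a function of the coefficient vector $\mathbf{b}$ is

$$\text{MSE}(\mathbf{b}) = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \mathbf{x}_i^\top \mathbf{b}\right)^2$$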

``````function mse_loss(X, y, b)
    ŷ = X*b

    l = mean((y .- ŷ).^2)

    return l
end``````
``mse_loss (generic function with 1 method)``

# Define a training function

This implements the gradient descent algorithm:

• initialize some random beta values
• initialize error as some very large number (the init value doesn’t really matter as long as it’s greater than the function’s `tol` parameter)
• initialize the number of iterations (`iter`) at 0
• define a function `d()` to get the gradient of the loss function at a given set of betas
• define a loop that updates the beta values by the learning rate * the gradients until convergence (the update rule is written out below)
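Each pass through the loop applies the standard gradient descent update, with learning rate $\eta$ (`lr` in the code):

$$\boldsymbol{\beta} \leftarrow \boldsymbol{\beta} - \eta \, \nabla_{\boldsymbol{\beta}}\, \text{MSE}(\boldsymbol{\beta})$$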
``````function grad_descent(X, y; lr = .01, tol = .01, max_iter = 1_000, noisy = false)
    #randomly initialize betas (one per column of X)
    β = rand(size(X, 2))

    #init error to something large
    err = 1e10

    #initialize iterations at 0
    iter = 0

    #define a function to get the gradient of the loss function at a given set of betas
    d(b) = ForwardDiff.gradient(params -> mse_loss(X, y, params), b)

    while err > tol && iter < max_iter
        β -= lr*d(β)
        err = mse_loss(X, y, β)
        if noisy
            println("Iteration $(iter): current error is $(err)")
        end
        iter += 1
    end
    return β
end``````
``grad_descent (generic function with 1 method)``
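As a quick sanity check (not part of the original post), the automatic gradient from `ForwardDiff` can be compared against the closed-form gradient of the MSE loss, $\nabla_{\mathbf{b}}\,\text{MSE} = \frac{2}{n}\mathbf{X}^\top(\mathbf{X}\mathbf{b} - \mathbf{y})$:

``````#sanity check (illustrative): the ForwardDiff gradient should match
#the closed-form MSE gradient (2/n) * X' * (X*b - y)
b₀ = rand(size(𝐗, 2))
g_ad = ForwardDiff.gradient(params -> mse_loss(𝐗, y, params), b₀)
g_cf = (2 / size(𝐗, 1)) .* (𝐗' * (𝐗 * b₀ .- y))
g_ad ≈ g_cf  #expected: true``````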

# Estimate βs

To estimate the betas, we just run the function:

``b = grad_descent(𝐗, y)``
``````4-element Vector{Float64}:
0.5220524143318362
0.992503536801155
1.9951668587882012
2.997961983119764``````
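The keyword arguments can be tuned as needed; for example (the values here are purely illustrative), a smaller learning rate with per-iteration logging:

``grad_descent(𝐗, y; lr = .001, max_iter = 10_000, noisy = true)``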

# Check Solution Against Base Julia Solver

For an overdetermined system, Julia's `\` operator returns the least-squares solution (computed via QR factorization), so it gives a direct check on the gradient descent estimates:

``𝐗 \ y .≈ b``
``````4-element BitVector:
1
1
1
1``````

Huzzah!