Illustrated using Julia

Published: April 21, 2023

An example of estimating linear regression beta coefficients via gradient descent, using Julia.
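The model being estimated is ordinary linear regression: for a design matrix $\mathbf{X}$ and coefficient vector $\boldsymbol{\beta}$,

$$\mathbf{y} = \mathbf{X}\boldsymbol{\beta} + \boldsymbol{\epsilon}, \qquad \epsilon_i \sim \mathcal{N}(0, \sigma^2)$$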

``````using Random
using ForwardDiff
using Distributions
using Statistics``````

# Generate Data

First, we generate some fake data: a design matrix 𝐗 with an intercept column and three standard-normal predictors, plus a response y built from known betas and Gaussian noise.

``````Random.seed!(0408)

#x data: an intercept column plus 3 standard-normal predictors
𝐗 = hcat(ones(1000), randn(1000, 3))

#ground truth betas
𝚩 = [.5, 1, 2, 3]

#multiply data by betas
f₁(X) = X*𝚩

#make some error (one draw per observation)
ϵ = rand(Normal(0, .5), size(𝐗, 1))

#generate y
y = f₁(𝐗) + ϵ;``````

# Define a Loss Function

Mean squared error is the most straightforward choice.
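For $n$ observations, the loss as a function of the coefficient vector $\mathbf{b}$ is

$$\text{MSE}(\mathbf{b}) = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \mathbf{x}_i^\top \mathbf{b}\right)^2$$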

``````function mse_loss(X, y, b)
    ŷ = X*b

    l = mean((y .- ŷ).^2)

    return l
end``````
``mse_loss (generic function with 1 method)``

# Define a training function

This implements the gradient descent algorithm:

• initialize some random beta values
• initialize error as some very large number (the init value doesn’t really matter as long as it’s greater than the function’s `tol` parameter)
• initialize the number of iterations (`iter`) at 0
• define a function `d()` to get the gradient of the loss function at a given set of betas
• define a loop that updates the beta values by the learning rate * the gradients until convergence (the update rule is written out below)
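Each pass through the loop applies the standard gradient descent update, with learning rate $\eta$ (`lr` in the code):

$$\boldsymbol{\beta} \leftarrow \boldsymbol{\beta} - \eta \, \nabla_{\boldsymbol{\beta}}\, \text{MSE}(\boldsymbol{\beta})$$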
``````function grad_descent(X, y; lr = .01, tol = .01, max_iter = 1_000, noisy = false)
    #randomly initialize betas (one per column of X)
    β = rand(size(X, 2))

    #init error to something large
    err = 1e10

    #initialize iterations at 0
    iter = 0

    #define a function to get the gradient of the loss function at a given set of betas
    d(b) = ForwardDiff.gradient(params -> mse_loss(X, y, params), b)

    while err > tol && iter < max_iter
        β -= lr*d(β)
        err = mse_loss(X, y, β)
        if noisy
            println("Iteration $(iter): current error is $(err)")
        end
        iter += 1
    end
    return β
end``````
``grad_descent (generic function with 1 method)``
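As a quick sanity check (not part of the original post), the automatic gradient from `ForwardDiff` can be compared against the closed-form gradient of the MSE loss, $\nabla_{\mathbf{b}}\,\text{MSE} = \frac{2}{n}\mathbf{X}^\top(\mathbf{X}\mathbf{b} - \mathbf{y})$:

``````#sanity check (illustrative): the ForwardDiff gradient should match
#the closed-form MSE gradient (2/n) * X' * (X*b - y)
b₀ = rand(size(𝐗, 2))
g_ad = ForwardDiff.gradient(params -> mse_loss(𝐗, y, params), b₀)
g_cf = (2 / size(𝐗, 1)) .* (𝐗' * (𝐗 * b₀ .- y))
g_ad ≈ g_cf  #expected: true``````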

# Estimate βs

To estimate the betas, we just run the function:

``b = grad_descent(𝐗, y)``
``````4-element Vector{Float64}:
0.5220524143318362
0.992503536801155
1.9951668587882012
2.997961983119764``````
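The keyword arguments can be tuned as needed; for example (the values here are purely illustrative), a smaller learning rate with per-iteration logging:

``grad_descent(𝐗, y; lr = .001, max_iter = 10_000, noisy = true)``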

# Check Solution Against Base Julia Solver

For an overdetermined system, Julia's `\` operator returns the least-squares solution (computed via QR factorization), so it gives a direct check on the gradient descent estimates:

``𝐗 \ y .≈ b``
``````4-element BitVector:
1
1
1
1``````

Huzzah!