Linear Regression in several ways

The plan is to do a simple linear regression in Julia, in several different ways. We'll use plain least squares via Base.:\, the generalized linear models package JuliaStats/GLM.jl, and the probabilistic programming package Turing.jl.

using Distributions, GLM, Turing, StatsPlots

The test data

This is a simple test set: a synthetic sample of points approximating a straight line. We actually create two sets, the unperturbed straight line and a perturbed sample of it.

num_points = 20
xx = range(0.0, 1.0, length=num_points)

intercept = 1.0
slope = 2.0
ε = 0.1

yy = intercept .+ slope * xx

yy_perturbed = yy .+ ε * randn(num_points)

plt = plot(title="Synthetic data", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy, label="unperturbed line")
scatter!(plt, xx, yy_perturbed, label="perturbed sample")
[Plot: synthetic data, showing the unperturbed line and the perturbed sample]

Straightforward least squares

The least squares solution $\hat x$ of a linear problem $Ax = b$ is the vector that minimizes the sum of the squares of the residuals $b - Ax$, i.e.

\[ \hat x = \operatorname{argmin}_{x} \|Ax - b\|^2\]

The solution is obtained by solving the normal equation

\[ (A^t A)\hat x = A^t b\]
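
This follows from setting the gradient of $\|Ax - b\|^2$ with respect to $x$ to zero:

\[ \nabla_x \|Ax - b\|^2 = 2 A^t (A x - b) = 0,\]

which is satisfied precisely when $(A^t A)x = A^t b$.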

In Julia, $\hat x$ can be found by using the function Base.:\, which is actually a polyalgorithm, meaning it uses different algorithms depending on the shape and type of $A$; a short demonstration follows the list below.

  1. When $A$ is an invertible square matrix, then $\hat x = A^{-1}b$ is the unique solution of $Ax = b$.
  2. When $A$ has more rows than columns and the columns are linearly independent, then $\hat x = (A^tA)^{-1}A^tb$ is the unique least squares solution of the overdetermined system $Ax = b$.
  3. When $A$ has more columns than rows and the rows are linearly independent, then $\hat x = A^t(AA^t)^{-1}b$ is the unique least norm solution of the underdetermined system $Ax = b$.
  4. In other cases, attempting to solve via $A \setminus b$ may throw an error (for instance, a singular square matrix raises a SingularException).
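
To make the list above concrete, here is a minimal sketch of the first three cases; the small matrices and right-hand sides are made up purely for illustration.

Asquare = [1.0 2.0; 3.0 4.0]        # case 1: invertible square matrix
Asquare \ [1.0, 1.0]                # unique solution of Asquare * x = b

Atall = [ones(3) [0.0, 1.0, 2.0]]   # case 2: more rows than columns, full column rank
Atall \ [1.0, 2.0, 2.5]             # least squares solution of the overdetermined system

Awide = [1.0 0.0 1.0; 0.0 1.0 1.0]  # case 3: more columns than rows, full row rank
Awide \ [1.0, 2.0]                  # solution of the underdetermined system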

First we build the Vandermonde matrix.

A = [ones(length(xx)) xx]
size(A)
(20, 2)

Now we solve the least squares problem with the unperturbed data, both explicitly via $\hat x = (A^tA)^{-1}A^tb$ and via $A \setminus b$, checking both against the original intercept and slope.

betahat = inv(transpose(A) * A) * transpose(A) * yy

betahat ≈ A \ yy ≈ [intercept, slope]
true

Now we solve it with the perturbed data

betahat = inv(transpose(A) * A) * transpose(A) * yy_perturbed

betahat ≈ A \ yy_perturbed
true

We extract the intercept and slope of the fitted line

intercepthat, slopehat = betahat
2-element Vector{Float64}:
 1.0107601468755272
 2.0339119675408757

and use that to build the fitted line

yy_hat = intercepthat .+ slopehat * xx
1.0107601468755272:0.10704799829162503:3.044672114416403

which looks as follows

plt = plot(title="Synthetic data and least square fit", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy, label="unperturbed line")
scatter!(plt, xx, yy_perturbed, label="perturbed sample")
plot!(plt, xx, yy_hat, label="fitted line")
[Plot: synthetic data and the least squares fitted line]

Bayesian linear regression with Turing.jl

Now we use Turing.jl to fit the data via Bayesian inference. We start by defining a model, as a compound distribution, with priors on the noise variance σ², the intercept, and the slope.

@model function regfit(x, y)
    σ² ~ InverseGamma()
    σ = sqrt(σ²)
    intercept ~ Normal(0.0, 10.0)
    slope ~ Normal(0.0, 10.0)

    for i in eachindex(x)
        v = intercept + slope * x[i]
        y[i] ~ Normal(v, σ)
    end
end
regfit (generic function with 2 methods)

Let's use the Hamiltonian Monte Carlo method to infer the parameters of the model, with a leapfrog step size of 0.05 and 10 leapfrog steps per iteration.

model = regfit(xx, yy_perturbed)

chain = sample(model, HMC(0.05, 10), 4_000)
Chains MCMC chain (4000×13×1 Array{Float64, 3}):

Iterations        = 1:1:4000
Number of chains  = 1
Samples per chain = 4000
Wall duration     = 9.54 seconds
Compute duration  = 9.54 seconds
parameters        = σ², intercept, slope
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Use `describe(chains)` for summary statistics and quantiles.
plot(chain)
[Plot: trace and density plots of the chain for σ², intercept, and slope]

We can access the summary statistics for the Markov chain generated by the sampler as follows.

summarize(chain)

  parameters       mean       std      mcse   ess_bulk   ess_tail      rhat    ⋯
      Symbol    Float64   Float64   Float64    Float64    Float64   Float64    ⋯

          σ²     0.2883    0.0000    0.0000        NaN        NaN       NaN    ⋯
   intercept    20.2894    0.0000    0.0000        NaN        NaN       NaN    ⋯
       slope   -10.3999    0.0000    0.0000        NaN        NaN       NaN    ⋯
                                                                1 column omitted

The full values of the parameters computed along the chain are accessible via chain[[:intercept]].value, chain[[:slope]].value, and chain[[:σ²]].value.
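
As a minimal sketch (not in the original), these values can be collected into plain vectors, assuming a single chain as in the run above.

intercept_samples = vec(chain[[:intercept]].value)  # sampled intercepts, one per iteration
slope_samples = vec(chain[[:slope]].value)          # sampled slopes
σ²_samples = vec(chain[[:σ²]].value)                # sampled noise variances
length(intercept_samples)                           # equals the number of iterations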

We can also directly compute the mean of each of them,

mean(chain, [:intercept, :slope, :σ²])
Mean
  parameters       mean
      Symbol    Float64

          σ²     0.2883
   intercept    20.2894
       slope   -10.3999

and use them to find the Bayesian fitted regression line

yy_bayes = mean(chain, :intercept) .+ mean(chain, :slope) * xx
20.289370833314997:-0.5473653537203736:9.889429112627898
plt = plot(title="Synthetic data and LSQ and Turing fit", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
scatter!(plt, xx, yy_perturbed, label="perturbed sample")
plot!(plt, xx, yy, label="unperturbed line")
plot!(plt, xx, yy_hat, label="LSQ fitted line")
plot!(plt, xx, yy_bayes, label="Bayesian fitted line")
[Plot: synthetic data with the least squares and Bayesian fitted lines]

We can compare the least squares and Bayesian fitted lines by looking at the extrema of their difference:

extrema(yy_hat - yy_bayes)
(-19.27861068643947, -6.844756998211495)

Now it remains to compute and plot the credible intervals. First, we plot the ensemble of lines generated with the chain.

plt = plot(title="Synthetic data and Turing posterior", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy_bayes, label="Bayesian fitted line", color=2)
for (c, m) in eachrow(view(chain.value.data, :, 2:3, 1))
    plot!(plt, xx, c .+ m * xx, alpha=0.01, color=2, label=false)
end
scatter!(plt, xx, yy_perturbed, label="perturbed sample", color=1)
[Plot: perturbed sample with the ensemble of lines drawn from the posterior]

Now we again use the values computed along the chain to find the credible interval at each point $x$.

quantiles = reduce(hcat, quantile([c + m * x for (c, m) in eachrow(view(chain.value.data, :, 2:3, 1))], [0.025, 0.975]) for x in xx)
2×20 Matrix{Float64}:
 20.2894  19.742  19.1946  18.6473  …  11.5315  10.9842  10.4368  9.88943
 20.2894  19.742  19.1946  18.6473     11.5315  10.9842  10.4368  9.88943

With the computed quantiles, we are ready to plot the Bayesian fit with the credible interval.

plt = plot(title="Synthetic data and Turing fit with 95% credible interval", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy_bayes, ribbon=(yy_bayes .- view(quantiles, 1, :), view(quantiles, 2, :) .- yy_bayes), label="Bayesian fitted line", color=2)
scatter!(plt, xx, yy_perturbed, label="perturbed sample", color=1)
plot!(plt, xx, yy, label="unperturbed line", color=1)
[Plot: Bayesian fit with the 95% credible interval ribbon]