Linear Regression in several ways

The plan is to do a simple linear regression in Julia, in several different ways. We'll use plain least squares with Base.:\, the general linear model package JuliaStats/GLM.jl, and the probabilistic programming package Turing.jl.

using Distributions, GLM, Turing, StatsPlots

The test data

This is a simple test set. We generate a synthetic sample of points approximating a straight line. In fact, we create two data sets: the unperturbed straight line and a perturbed version obtained by adding Gaussian noise.

num_points = 20
xx = range(0.0, 1.0, length=num_points)

intercept = 1.0
slope = 2.0
ε = 0.1

yy = intercept .+ slope * xx

yy_perturbed = yy .+ ε * randn(num_points)

plt = plot(title="Synthetic data", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy, label="unperturbed line")
scatter!(plt, xx, yy_perturbed, label="perturbed sample")
[Figure: synthetic data]

Straightforward least squares

The least squares solution $\hat x$ to a linear problem $Ax = b$ is the vector $\hat x$ that minimizes the sum of the squares of the residuals $b - Ax$, i.e.

\[ \hat x = \operatorname{arg\,min}_{x} \|Ax - b\|^2\]

The solution is obtained by solving the normal equation

\[ (A^t A)\hat x = A^t b\]
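Indeed, expanding $\|Ax - b\|^2 = x^t A^t A x - 2 x^t A^t b + b^t b$ and setting the gradient with respect to $x$ to zero at the minimizer $\hat x$ gives

\[ \nabla_x \|A\hat x - b\|^2 = 2A^tA\hat x - 2A^tb = 0,\]

which is precisely the normal equation.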

In Julia, $\hat x$ can be found by using the function Base.:\, which is actually a polyalgorithm, meaning it uses different algorithms depending on the shape and type of $A$:

  1. When $A$ is an invertible square matrix, then $\hat x = A^{-1}b$ is the unique solution of $Ax = b$.
  2. When $A$ has more rows than columns and the columns are linearly independent, then $\hat x = (A^tA)^{-1}A^tb$ is the unique least squares solution of the overdetermined system $Ax = b$.
  3. When $A$ has more columns than rows and the rows are linearly independent, then $\hat x = A^t(AA^t)^{-1}b$ is the unique least norm solution of the underdetermined system $Ax = b$.
  4. In all other cases, attempting to solve $A \backslash b$ throws an error.
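As a quick illustration of the first three cases, here is a small sketch with illustrative matrices (the exact factorization picked by Base.:\ also depends on the element and matrix types); each comparison should evaluate to true.

# case 1: square invertible matrix, unique solution
A1 = [2.0 1.0; 1.0 3.0]
b1 = [1.0, 2.0]
A1 \ b1 ≈ inv(A1) * b1

# case 2: overdetermined system with independent columns, least squares solution
A2 = [1.0 0.0; 0.0 1.0; 1.0 1.0]
b2 = [1.0, 2.0, 0.0]
A2 \ b2 ≈ inv(transpose(A2) * A2) * transpose(A2) * b2

# case 3: underdetermined system with independent rows, least norm solution
A3 = [1.0 0.0 1.0; 0.0 1.0 1.0]
b3 = [1.0, 2.0]
A3 \ b3 ≈ transpose(A3) * inv(A3 * transpose(A3)) * b3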

First we build the Vandermonde matrix.

A = [ones(length(xx)) xx]
size(A)
(20, 2)

Now we solve the least squares problem with the unperturbed data, solving it both explicitly with $\hat x = (A^tA)^{-1}A^tb$ and via $A \backslash b$, and checking both results against the original intercept and slope.

betahat = inv(transpose(A) * A) * transpose(A) * yy

betahat ≈ A \ yy ≈ [intercept, slope]
true

Now we solve it with the perturbed data.

betahat = inv(transpose(A) * A) * transpose(A) * yy_perturbed

betahat ≈ A \ yy_perturbed
true

We extract the intercept and slope of the fitted line

intercepthat, slopehat = betahat
2-element Vector{Float64}:
 1.0107601468755272
 2.0339119675408757

and use that to build the fitted line

yy_hat = intercepthat .+ slopehat * xx
1.0107601468755272:0.10704799829162503:3.044672114416403

which looks as follows

plt = plot(title="Synthetic data and least square fit", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy, label="unperturbed line")
scatter!(plt, xx, yy_perturbed, label="perturbed sample")
plot!(plt, xx, yy_hat, label="fitted line")
[Figure: synthetic data and least squares fit]

Bayesian linear regression with Turing.jl

Now we use Turing.jl to fit the data via Bayesian inference. We start by defining the model as a compound distribution.

@model function regfit(x, y)
    σ² ~ InverseGamma()
    σ = sqrt(σ²)
    intercept ~ Normal(0.0, 10.0)
    slope ~ Normal(0.0, 10.0)

    for i in eachindex(x)
        v = intercept + slope * x[i]
        y[i] ~ Normal(v, σ)
    end
end
regfit (generic function with 2 methods)
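In mathematical terms, this corresponds to the compound (hierarchical) distribution

\[ \sigma^2 \sim \operatorname{InverseGamma}, \quad \mathrm{intercept} \sim \mathcal{N}(0, 10^2), \quad \mathrm{slope} \sim \mathcal{N}(0, 10^2), \quad y_i \sim \mathcal{N}(\mathrm{intercept} + \mathrm{slope}\,x_i, \sigma^2),\]

with the $y_i$ conditionally independent given the parameters, and with the inverse-gamma prior taking the default parameters of InverseGamma() from Distributions.jl.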

Let's use the Hamiltonian Monte Carlo method to infer the parameters of the model.

model = regfit(xx, yy_perturbed)

chain = sample(model, HMC(0.05, 10), 4_000)
Chains MCMC chain (4000×15×1 Array{Union{Missing, Float64}, 3}):

Iterations        = 1:1:4000
Number of chains  = 1
Samples per chain = 4000
Wall duration     = 9.34 seconds
Compute duration  = 9.34 seconds
parameters        = σ², intercept, slope
internals         = hamiltonian_energy, n_steps, numerical_error, loglikelihood, hamiltonian_energy_error, is_accept, logprior, log_density, step_size, acceptance_rate, lp, nom_step_size

Use `describe(chains)` for summary statistics and quantiles.
plot(chain)
[Figure: trace and density plots of the chain]

We can again access the summary statistics for the Markov chain generated by the sampler, as follows.

summarize(chain)


  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64    Float64?     Float64   Float64   ⋯

          σ²    0.2031    4.4787    0.0856   3126.0594   2521.2373    0.9998   ⋯
   intercept    1.0121    0.1643    0.0027   3690.9735   3002.3033    1.0003   ⋯
       slope    2.0378    0.2637    0.0045   3779.3300   2846.3180    0.9998   ⋯

                                                                1 column omitted

The full set of parameter values computed along the chain is accessible via chain[[:intercept]].value, chain[[:slope]].value, and chain[[:σ²]].value.
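For instance, here is a minimal sketch of how one could collect the sampled intercepts and slopes as plain vectors (the variable names are just illustrative):

# flatten the sampled values of each parameter into a plain vector
intercept_samples = vec(chain[[:intercept]].value.data)
slope_samples = vec(chain[[:slope]].value.data)
length(intercept_samples)   # one sample per iteration, 4000 in this case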

We can also directly compute the mean of each of them,

mean(chain, [:intercept, :slope, :σ²])
Mean

  parameters      mean
      Symbol   Float64

          σ²    0.2031
   intercept    1.0121
       slope    2.0378

and use them to find the Bayesian fitted regression line

yy_bayes = mean(chain, :intercept) .+ mean(chain, :slope) * xx
1.0120701488710526:0.10725067367606578:3.0498329487163023
plt = plot(title="Synthetic data and LSQ and Turing fit", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
scatter!(plt, xx, yy_perturbed, label="perturbed sample")
plot!(plt, xx, yy, label="unperturbed line")
plot!(plt, xx, yy_hat, label="LSQ fitted line")
plot!(plt, xx, yy_bayes, label="Bayesian fitted line")
[Figure: synthetic data with the least squares and Turing fits]

The least squares and Bayesian fits are pretty close to each other:

extrema(yy_hat - yy_bayes)
(-0.005160834299899664, -0.001310001995525445)

Now it remains to compute and plot the credible intervals. First, we plot the ensemble of lines generated with the chain.

plt = plot(title="Synthetic data and Turing posterior", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy_bayes, label="Bayesian fitted line", color=2)
for (c, m) in eachrow(view(chain.value.data, :, 2:3, 1))
    plot!(plt, xx, c .+ m * xx, alpha=0.01, color=2, label=false)
end
scatter!(plt, xx, yy_perturbed, label="perturbed sample", color=1)
[Figure: synthetic data and the ensemble of posterior lines]

Now we again use the values computed along the chain to find the credible interval at each point $x$.

quantiles = reduce(hcat, quantile([c + m * x for (c, m) in eachrow(view(chain.value.data, :, 2:3, 1))], [0.025, 0.975]) for x in xx)
2×20 Matrix{Float64}:
 0.710318  0.840986  0.968195  1.09584  …  2.49992  2.5866   2.67254  2.7595
 1.30502   1.39072   1.47936   1.56468     2.95358  3.08212  3.2113   3.34646

With the computed quantiles, we are ready to plot the Bayesian fit with the credible interval.

plt = plot(title="Synthetic data and Turing fit with 95% credible interval", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy_bayes, ribbon=(yy_bayes .- view(quantiles, 1, :), view(quantiles, 2, :) .- yy_bayes), label="Bayesian fitted line", color=2)
scatter!(plt, xx, yy_perturbed, label="perturbed sample", color=1)
plot!(plt, xx, yy, label="unperturbed line", color=1)
[Figure: synthetic data and the Turing fit with 95% credible interval]