Linear Regression in several ways

The plan is to do a simple linear regression in Julia, in several different ways. We'll use plain least squares via Base.:\, the generalized linear models package JuliaStats/GLM.jl, and the probabilistic programming package Turing.jl.

using Distributions, GLM, Turing, StatsPlots

The test data

This is a simple test set: a synthetic sample of points approximating a straight line. We actually create two data sets, the points on the unperturbed straight line and a perturbed version of them, with additive Gaussian noise of amplitude ε.

num_points = 20
xx = range(0.0, 1.0, length=num_points)

intercept = 1.0
slope = 2.0
ε = 0.1

yy = intercept .+ slope * xx

yy_perturbed = yy .+ ε * randn(num_points)

plt = plot(title="Synthetic data", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy, label="unperturbed line")
scatter!(plt, xx, yy_perturbed, label="perturbed sample")
(Figure: synthetic data, unperturbed line and perturbed sample.)

Straightforward least squares

The least squares solution $\hat x$ to a linear problem $Ax = b$ is the vector $\hat x$ that minimizes the sum of the squares of the residuals $b - Ax$, i.e.

\[ \hat x = \operatorname{arg\,min}_{x} \|Ax - b\|^2\]

The solution is obtained by solving the normal equation

\[ (A^t A)\hat x = A^t b\]
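
As a quick sanity check of the normal equation, here is a small sketch with made-up data (A0, b0, and x̂ are only for this illustration, not part of the original script); the residual $b - A\hat x$ of the least squares solution is orthogonal to the columns of $A$:

A0 = [1.0 0.0; 1.0 1.0; 1.0 2.0]
b0 = [0.1, 1.1, 1.9]
x̂ = inv(transpose(A0) * A0) * (transpose(A0) * b0)  # solve the normal equation
transpose(A0) * (b0 - A0 * x̂)  # ≈ zero vector, up to floating point error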

In Julia, $\hat x$ can be found by using the function Base.:\, which is actually a polyalgorithm, meaning it uses different algorithms depending on the shape and type of $A$. The main cases are the following, with a quick sketch after the list.

  1. When $A$ is an invertible square matrix, then $\hat x = A^{-1}b$ is the unique solution of $Ax = b$.
  2. When $A$ has more rows than columns and the columns are linearly independent, then $\hat x = (A^tA)^{-1}A^tb$ is the unique least squares solution of the overdetermined system $Ax = b$.
  3. When $A$ has more columns than rows and the rows are linearly independent, then $\hat x = A^t(AA^t)^{-1}b$ is the unique least norm solution of the underdetermined system $Ax = b$.
  4. In all other cases, attempting to solve with A \ b throws an error.
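
Here is a minimal sketch of the first three cases on tiny matrices (the names Asq, Atall, and Awide and their values are made up just for this illustration):

# 1. square and invertible: unique solution of Ax = b
Asq = [2.0 1.0; 1.0 3.0]
bsq = [1.0, 2.0]
Asq \ bsq ≈ inv(Asq) * bsq

# 2. more rows than columns, independent columns: least squares solution
Atall = [ones(5) collect(1.0:5.0)]
btall = [1.1, 2.9, 5.2, 6.8, 9.1]
Atall \ btall ≈ inv(transpose(Atall) * Atall) * transpose(Atall) * btall

# 3. more columns than rows, independent rows: minimum-norm solution
Awide = [1.0 0.0 1.0; 0.0 1.0 1.0]
bwide = [1.0, 2.0]
Awide \ bwide ≈ transpose(Awide) * inv(Awide * transpose(Awide)) * bwide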

First we build the Vandermonde matrix.

A = [ones(length(xx)) xx]
size(A)
(20, 2)

Now we solve the least squares problem with the unperturbed data, both explicitly via $\hat x = (A^tA)^{-1}A^tb$ and with $A \setminus b$, checking the results against the original intercept and slope.

betahat = inv(transpose(A) * A) * transpose(A) * yy

betahat ≈ A \ yy ≈ [intercept, slope]
true

Now we solve it with the perturbed data

betahat = inv(transpose(A) * A) * transpose(A) * yy_perturbed

betahat ≈ A \ yy_perturbed
true

We extract the intercept and slope of the fitted line

intercepthat, slopehat = betahat
2-element Vector{Float64}:
 1.0107601468755272
 2.0339119675408757

and use that to build the fitted line

yy_hat = intercepthat .+ slopehat * xx
1.0107601468755272:0.10704799829162503:3.044672114416403

which looks as follows

plt = plot(title="Synthetic data and least square fit", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy, label="unperturbed line")
scatter!(plt, xx, yy_perturbed, label="perturbed sample")
plot!(plt, xx, yy_hat, label="fitted line")
(Figure: synthetic data and least squares fit.)

Bayesian linear regression with Turing.jl

Now we use Turing.jl to fit the data via Bayesian inference. We start by defining the model as a compound distribution, with priors on the noise variance σ², the intercept, and the slope, and a Gaussian likelihood for the observations.

@model function regfit(x, y)
    # prior on the noise variance
    σ² ~ InverseGamma()
    σ = sqrt(σ²)
    # priors on the intercept and the slope
    intercept ~ Normal(0.0, 10.0)
    slope ~ Normal(0.0, 10.0)

    # Gaussian likelihood of each observation around the regression line
    for i in eachindex(x)
        v = intercept + slope * x[i]
        y[i] ~ Normal(v, σ)
    end
end
regfit (generic function with 2 methods)

Let's use the Hamiltonian Monte Carlo method to infer the parameters of the model, here with leapfrog step size 0.05 and 10 leapfrog steps per iteration.

model = regfit(xx, yy_perturbed)

chain = sample(model, HMC(0.05, 10), 4_000)
Chains MCMC chain (4000×13×1 Array{Float64, 3}):

Iterations        = 1:1:4000
Number of chains  = 1
Samples per chain = 4000
Wall duration     = 8.46 seconds
Compute duration  = 8.46 seconds
parameters        = σ², intercept, slope
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters       mean       std      mcse   ess_bulk   ess_tail      rhat    ⋯
      Symbol    Float64   Float64   Float64    Float64    Float64   Float64    ⋯

          σ²     0.2883    0.0000    0.0000        NaN        NaN       NaN    ⋯
   intercept    20.2894    0.0000    0.0000        NaN        NaN       NaN    ⋯
       slope   -10.3999    0.0000    0.0000        NaN        NaN       NaN    ⋯
                                                                1 column omitted

Quantiles
  parameters       2.5%      25.0%      50.0%      75.0%      97.5%
      Symbol    Float64    Float64    Float64    Float64    Float64

          σ²     0.2883     0.2883     0.2883     0.2883     0.2883
   intercept    20.2894    20.2894    20.2894    20.2894    20.2894
       slope   -10.3999   -10.3999   -10.3999   -10.3999   -10.3999
plot(chain)
(Figure: trace and density plots for σ², intercept, and slope.)

We can again access the summary statistics of the Markov chain generated by the sampler as follows.

summarize(chain)

  parameters       mean       std      mcse   ess_bulk   ess_tail      rhat    ⋯
      Symbol    Float64   Float64   Float64    Float64    Float64   Float64    ⋯

          σ²     0.2883    0.0000    0.0000        NaN        NaN       NaN    ⋯
   intercept    20.2894    0.0000    0.0000        NaN        NaN       NaN    ⋯
       slope   -10.3999    0.0000    0.0000        NaN        NaN       NaN    ⋯
                                                                1 column omitted

The full set of parameter values computed along the chain is accessible via chain[[:intercept]].value, chain[[:slope]].value, and chain[[:σ²]].value.
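
For instance, a plain vector with one sample per iteration can be obtained by flattening those arrays (a small sketch; the variable names are just for illustration):

intercept_samples = vec(chain[[:intercept]].value)
slope_samples = vec(chain[[:slope]].value)
σ²_samples = vec(chain[[:σ²]].value)
length(intercept_samples)  # 4000, one value per iteration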

We can also directly compute the mean of each of them,

mean(chain, [:intercept, :slope, :σ²])
Mean
  parameters       mean
      Symbol    Float64

          σ²     0.2883
   intercept    20.2894
       slope   -10.3999

and use them to find the Bayesian fitted regression line

yy_bayes = mean(chain, :intercept) .+ mean(chain, :slope) * xx
20.289370833314997:-0.5473653537203736:9.889429112627898
plt = plot(title="Synthetic data and LSQ and Turing fit", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
scatter!(plt, xx, yy_perturbed, label="perturbed sample")
plot!(plt, xx, yy, label="unperturbed line")
plot!(plt, xx, yy_hat, label="LSQ fitted line")
plot!(plt, xx, yy_bayes, label="Bayesian fitted line")
(Figure: synthetic data with the least squares and Turing fits.)

In this particular run, however, the least squares and Bayesian fits are not close to each other. The chain appears to have gotten stuck (notice the zero standard deviations and the NaN values of ess_bulk, ess_tail, and rhat in the summary above), so the Bayesian estimates end up far from the least squares ones:

extrema(yy_hat - yy_bayes)
(-19.27861068643947, -6.844756998211495)
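
As an aside, not part of the original script, when a fixed-step HMC chain fails to move like this, one option is to rerun the inference with Turing's adaptive NUTS sampler and compare the summaries; a minimal sketch with default settings:

chain_nuts = sample(model, NUTS(), 4_000)
summarize(chain_nuts)

We keep working with the original chain below.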

Now it remains to compute and plot the credible intervals. First, we plot the ensemble of lines generated with the chain.

plt = plot(title="Synthetic data and Turing posterior", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy_bayes, label="Bayesian fitted line", color=2)
for (c, m) in eachrow(view(chain.value.data, :, 2:3, 1))
    plot!(plt, xx, c .+ m * xx, alpha=0.01, color=2, label=false)
end
scatter!(plt, xx, yy_perturbed, label="perturbed sample", color=1)
(Figure: synthetic data and the ensemble of lines from the Turing posterior.)

Now we again use the values computed along the chain to find the credible interval at each point $x$.

quantiles = reduce(hcat, quantile([c + m * x for (c, m) in eachrow(view(chain.value.data, :, 2:3, 1))], [0.025, 0.975]) for x in xx)
2×20 Matrix{Float64}:
 20.2894  19.742  19.1946  18.6473  …  11.5315  10.9842  10.4368  9.88943
 20.2894  19.742  19.1946  18.6473     11.5315  10.9842  10.4368  9.88943

With the computed quantiles, we are ready to plot the Bayesian fit with the credible interval.

plt = plot(title="Synthetic data and Turing fit with 95% credible interval", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy_bayes, ribbon=(yy_bayes .- view(quantiles, 1, :), view(quantiles, 2, :) .- yy_bayes), label="Bayesian fitted line", color=2)
scatter!(plt, xx, yy_perturbed, label="perturbed sample", color=1)
plot!(plt, xx, yy, label="unperturbed line", color=1)
(Figure: synthetic data and Turing fit with 95% credible interval.)