Linear regression in several ways
The plan is to fit a simple linear regression in Julia, in several different ways: with plain least squares via Base.:\, with the generalized linear models package JuliaStats/GLM.jl, and with the probabilistic programming package Turing.jl.
using Distributions, GLM, Turing, StatsPlots

The test data
This is a simple test set: we generate a synthetic sample with a bunch of points approximating a straight line. We actually create two data sets, an unperturbed straight line and a perturbed one.
num_points = 20
xx = range(0.0, 1.0, length=num_points)    # equally spaced sample points
intercept = 1.0
slope = 2.0
ε = 0.1                                     # noise amplitude
yy = intercept .+ slope * xx                # unperturbed line
yy_perturbed = yy .+ ε * randn(num_points)  # sample with Gaussian noise
plt = plot(title="Synthetic data", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy, label="unperturbed line")
scatter!(plt, xx, yy_perturbed, label="perturbed sample")

Straightforward least squares
The least-squares solution $\hat x$ to a linear problem $Ax = b$ is the vector that minimizes the sum of the squares of the residuals $b - Ax$, i.e.
\[ \hat x = \operatorname{arg\,min}_{x} \|Ax - b\|^2 \]
The solution is obtained by solving the normal equations
\[ (A^t A)\hat x = A^t b\]
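To see why, note that $x \mapsto \|Ax - b\|^2$ is a smooth convex quadratic, so its minimizer is characterized by a vanishing gradient,

\[ \nabla_x \|Ax - b\|^2 = 2A^t(Ax - b) = 0, \]

which is precisely the normal equations above.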
In Julia, $\hat x$ can be found with the function Base.:\, which is actually a polyalgorithm, meaning it dispatches to different algorithms depending on the shape and type of $A$; the generic cases below are illustrated in the sketch after this list.
- When $A$ is an invertible square matrix, then $\hat x = A^{-1}b$ is the unique solution of $Ax = b$.
- When $A$ has more rows than columns and the columns are linearly independent, then $\hat x = (A^tA)^{-1}A^tb$ is the unique least-squares solution of the overdetermined system $Ax = b$.
- When $A$ has more columns than rows and the rows are linearly independent, then $\hat x = A^t(AA^t)^{-1}b$ is the unique least-norm solution of the underdetermined system $Ax = b$.
- In all other cases, attempting to solve $Ax = b$ with the backslash operator throws an error.
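Here is a minimal sketch of the three cases (the small matrices are illustrative choices of mine, not part of the regression example):

Asq = [1.0 2.0; 3.0 4.0]             # square and invertible
Asq \ [1.0, 1.0]                     # unique solution, same as inv(Asq) * [1.0, 1.0]

Atall = [1.0 0.0; 0.0 1.0; 1.0 1.0]  # tall, independent columns
Atall \ [1.0, 2.0, 3.0]              # least-squares solution of the overdetermined system

Awide = [1.0 0.0 1.0; 0.0 1.0 1.0]   # wide, independent rows
Awide \ [1.0, 2.0]                   # least-norm solution of the underdetermined system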
First we build the Vandermonde matrix.
A = [ones(length(xx)) xx]
size(A)

(20, 2)

Now we solve the least-squares problem with the unperturbed data, solving it explicitly with $\hat x = (A^tA)^{-1}A^tb$ and via $A \setminus b$, and checking both against the original slope and intercept.
betahat = inv(transpose(A) * A) * transpose(A) * yy  # explicit normal-equations solution
betahat ≈ A \ yy ≈ [intercept, slope]

true

Now we solve it with the perturbed data
betahat = inv(transpose(A) * A) * transpose(A) * yy_perturbed
betahat ≈ A \ yy_perturbed

true

We extract the intercept and slope of the fitted line
intercepthat, slopehat = betahat

2-element Vector{Float64}:
 1.0107601468755272
 2.0339119675408757

and use that to build the fitted line
yy_hat = intercepthat .+ slopehat * xx

1.0107601468755272:0.10704799829162503:3.044672114416403

which looks as follows
plt = plot(title="Synthetic data and least square fit", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy, label="unperturbed line")
scatter!(plt, xx, yy_perturbed, label="perturbed sample")
plot!(plt, xx, yy_hat, label="fitted line")
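Since GLM.jl is already loaded, we can cross-check the fit with it. This is a minimal sketch (my addition, not part of the original walkthrough), using the method of lm that accepts the design matrix directly:

fitted_glm = lm(A, yy_perturbed)  # same design matrix A as above
coef(fitted_glm) ≈ betahat        # should agree with the explicit least-squares solution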
Bayesian linear regression with Turing.jl

Now we use Turing.jl to fit the parameters via Bayesian inference. We start by defining a model, as a compound distribution.
@model function regfit(x, y)
    σ² ~ InverseGamma()            # prior on the noise variance
    σ = sqrt(σ²)
    intercept ~ Normal(0.0, 10.0)  # weakly informative priors
    slope ~ Normal(0.0, 10.0)
    for i in eachindex(x)
        v = intercept + slope * x[i]
        y[i] ~ Normal(v, σ)        # likelihood of each observation
    end
end

regfit (generic function with 2 methods)

Let's use the Hamiltonian Monte Carlo method to infer the parameters of the model.
model = regfit(xx, yy_perturbed)
chain = sample(model, HMC(0.05, 10), 4_000)  # leapfrog step size 0.05, 10 steps per iteration

Chains MCMC chain (4000×15×1 Array{Union{Missing, Float64}, 3}):
Iterations = 1:1:4000
Number of chains = 1
Samples per chain = 4000
Wall duration = 9.34 seconds
Compute duration = 9.34 seconds
parameters = σ², intercept, slope
internals = hamiltonian_energy, n_steps, numerical_error, loglikelihood, hamiltonian_energy_error, is_accept, logprior, log_density, step_size, acceptance_rate, lp, nom_step_size
Use `describe(chains)` for summary statistics and quantiles.
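As an aside (an alternative of my own, not part of the original run), Turing also provides the adaptive NUTS sampler, which tunes the step size automatically and avoids hand-picking the HMC parameters:

chain_nuts = sample(model, NUTS(), 4_000)  # hypothetical alternative run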
plot(chain)

We can also access the summary statistics for the Markov chain generated by the sampler as follows.
summarize(chain)
parameters mean std mcse ess_bulk ess_tail rhat ⋯
Symbol Float64 Float64 Float64 Float64? Float64 Float64 ⋯
σ² 0.2031 4.4787 0.0856 3126.0594 2521.2373 0.9998 ⋯
intercept 1.0121 0.1643 0.0027 3690.9735 3002.3033 1.0003 ⋯
slope 2.0378 0.2637 0.0045 3779.3300 2846.3180 0.9998 ⋯
1 column omitted
The full set of parameter values computed along the chain is accessible via chain[[:intercept]].value, chain[[:slope]].value, and chain[[:σ²]].value.
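For instance, here is a minimal sketch (the variable name is mine) collecting the sampled slopes into a plain vector:

slope_samples = vec(chain[[:slope]].value)  # 4000 sampled slope values, one per iteration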
We can also directly compute the mean of each of them,
mean(chain, [:intercept, :slope, :σ²])

Mean
parameters mean
Symbol Float64
σ² 0.2031
intercept 1.0121
slope 2.0378
and use them to find the Bayesian fitted regression line
yy_bayes = mean(chain, :intercept) .+ mean(chain, :slope) * xx

1.0120701488710526:0.10725067367606578:3.0498329487163023

plt = plot(title="Synthetic data and LSQ and Turing fit", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
scatter!(plt, xx, yy_perturbed, label="perturbed sample")
plot!(plt, xx, yy, label="unperturbed line")
plot!(plt, xx, yy_hat, label="LSQ fitted line")
plot!(plt, xx, yy_bayes, label="Bayesian fitted line")

The least-squares and Bayesian fits are pretty close to each other:
extrema(yy_hat - yy_bayes)

(-0.005160834299899664, -0.001310001995525445)

Now it remains to compute and plot the credible intervals. First, we plot the ensemble of lines generated with the chain.
plt = plot(title="Synthetic data and Turing posterior", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy_bayes, label="Bayesian fitted line", color=2)
for (c, m) in eachrow(view(chain.value.data, :, 2:3, 1))  # columns 2:3 hold the intercept and slope samples
    plot!(plt, xx, c .+ m * xx, alpha=0.01, color=2, label=false)
end
scatter!(plt, xx, yy_perturbed, label="perturbed sample", color=1)

Now we again use the values computed along the chain to find the credible interval at each point $x$.
quantiles = reduce(hcat, quantile([c + m * x for (c, m) in eachrow(view(chain.value.data, :, 2:3, 1))], [0.025, 0.975]) for x in xx)

2×20 Matrix{Float64}:
 0.710318  0.840986  0.968195  1.09584  …  2.49992  2.5866   2.67254  2.7595
 1.30502   1.39072   1.47936   1.56468     2.95358  3.08212  3.2113   3.34646

With the computed quantiles, we are ready to plot the Bayesian fit with the credible interval.
plt = plot(title="Synthetic data and Turing fit with 95% credible interval", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy_bayes, ribbon=(yy_bayes .- view(quantiles, 1, :), view(quantiles, 2, :) .- yy_bayes), label="Bayesian fitted line", color=2)
scatter!(plt, xx, yy_perturbed, label="perturbed sample", color=1)
plot!(plt, xx, yy, label="unperturbed line", color=1)