Linear regression in several ways
The plan is to do a simple linear regression in Julia, in several different ways. We'll use plain least squares via Base.:\, the general linear model package JuliaStats/GLM.jl, and the probabilistic programming package Turing.jl.
using Distributions, GLM, Turing, StatsPlots
The test data
This is a simple test set. We generate a synthetic sample with a bunch of points approximating a straight line. We actually create two data sets, an unperturbed straight line and a perturbed version of it.
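Since the perturbation is drawn with randn, the sample changes from run to run; to make it reproducible, one could seed the global RNG first (the seed below is an arbitrary choice, not from the original run):
using Random
Random.seed!(1234)  # arbitrary seed, only for reproducibility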
num_points = 20
xx = range(0.0, 1.0, length=num_points)
intercept = 1.0
slope = 2.0
ε = 0.1
yy = intercept .+ slope * xx
yy_perturbed = yy .+ ε * randn(num_points)
plt = plot(title="Synthetic data", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy, label="unperturbed line")
scatter!(plt, xx, yy_perturbed, label="perturbed sample")
Straightforward least squares
The least squares solution $\hat x$ to a linear problem $Ax = b$ is the vector that minimizes the sum of the squares of the residuals $b - Ax$, i.e.
\[ \hat x = \operatorname{argmin}_{x} \|Ax - b\|^2. \]
The solution is obtained by solving the normal equation
\[ (A^t A)\hat x = A^t b\]
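To see where the normal equation comes from, expand the squared norm and set its gradient with respect to $x$ to zero:
\[ \|Ax - b\|^2 = x^t A^t A x - 2 x^t A^t b + b^t b, \qquad \nabla_x \|Ax - b\|^2 = 2 A^t A x - 2 A^t b = 0. \]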
In Julia, $\hat x$ can be found using the function Base.:\, which is actually a polyalgorithm, meaning it uses different algorithms depending on the shape and type of $A$; each case is illustrated in the short sketch after the list below.
- When $A$ is an invertible square matrix, then $\hat x = A^{-1}b$ is the unique solution of $Ax = b$.
- When $A$ has more rows than columns and the columns are linearly independent, then $\hat x = (A^tA)^{-1}A^tb$ is the unique least squares solution of the overdetermined system $Ax = b$.
- When $A$ has more columns than rows and the rows are linearly independent, then $\hat x = A^t(AA^t)^{-1}b$ is the unique least norm solution of the underdetermined system $Ax = b$.
- In all other cases, attempting to solve via $A \backslash b$ throws an error.
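As a quick illustration of the three cases, here is a minimal sketch with toy matrices made up for this purpose:
Asq = [1.0 2.0; 3.0 4.0]              # square and invertible
Asq \ [1.0, 1.0]                      # unique solution of Ax = b
Atall = [1.0 0.0; 0.0 1.0; 1.0 1.0]   # more rows than columns, full column rank
Atall \ [1.0, 2.0, 3.0]               # least squares solution
Awide = [1.0 0.0 1.0; 0.0 1.0 1.0]    # more columns than rows, full row rank
Awide \ [1.0, 2.0]                    # least norm solution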
First we build the Vandermonde matrix.
A = [ones(length(xx)) xx]
size(A)
(20, 2)
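As an aside, the same construction extends to polynomial regression: stacking powers of xx column by column gives a higher-degree Vandermonde matrix. A hypothetical sketch, not used below:
p = 3                                        # hypothetical polynomial degree
A_poly = reduce(hcat, xx .^ j for j in 0:p)  # columns 1, x, x², x³
size(A_poly)                                 # (20, 4)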
Now we solve the least squares problem with the unperturbed data, both explicitly via $\hat x = (A^tA)^{-1}A^tb$ and via $A \setminus b$, checking both against the original intercept and slope.
betahat = inv(transpose(A) * A) * transpose(A) * yy
betahat ≈ A \ yy ≈ [intercept, slope]
true
Now we solve it with the perturbed data.
betahat = inv(transpose(A) * A) * transpose(A) * yy_perturbed
betahat ≈ A \ yy_perturbed
true
We extract the intercept and slope of the fitted line
intercepthat, slopehat = betahat
2-element Vector{Float64}:
1.0107601468755272
2.0339119675408757
and use that to build the fitted line
yy_hat = intercepthat .+ slopehat * xx
1.0107601468755272:0.10704799829162503:3.044672114416403
which looks as follows
plt = plot(title="Synthetic data and least square fit", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy, label="unperturbed line")
scatter!(plt, xx, yy_perturbed, label="perturbed sample")
plot!(plt, xx, yy_hat, label="fitted line")
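Since GLM.jl was loaded at the start, the same least squares fit can also be obtained through it; a minimal sketch, passing the design matrix A directly to lm:
fitted = lm(A, yy_perturbed)   # ordinary least squares via GLM.jl
coef(fitted)                   # intercept and slope estimates
stderror(fitted)               # standard errors of the coefficients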
Bayesian linear regression with Turing.jl
Now we use Turing.jl to fit the model via Bayesian inference. We start by defining a model, as a compound distribution.
@model function regfit(x, y)
    # prior on the noise variance
    σ² ~ InverseGamma()
    σ = sqrt(σ²)
    # priors on the regression coefficients
    intercept ~ Normal(0.0, 10.0)
    slope ~ Normal(0.0, 10.0)
    # likelihood: each observation is normally distributed around the line
    for i in eachindex(x)
        v = intercept + slope * x[i]
        y[i] ~ Normal(v, σ)
    end
end
regfit (generic function with 2 methods)
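Before conditioning on the data, one way to sanity-check the priors is to draw from them directly with Turing's Prior sampler; a sketch:
# draw 1_000 samples of (σ², intercept, slope) from the priors alone
prior_chain = sample(regfit(xx, yy_perturbed), Prior(), 1_000)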
Let's use the Hamiltonian Monte Carlo method to infer the parameters of the model.
model = regfit(xx, yy_perturbed)
chain = sample(model, HMC(0.05, 10), 4_000)
Chains MCMC chain (4000×13×1 Array{Float64, 3}):
Iterations = 1:1:4000
Number of chains = 1
Samples per chain = 4000
Wall duration = 8.46 seconds
Compute duration = 8.46 seconds
parameters = σ², intercept, slope
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
σ² 0.2883 0.0000 0.0000 NaN NaN NaN ⋯
intercept 20.2894 0.0000 0.0000 NaN NaN NaN ⋯
slope -10.3999 0.0000 0.0000 NaN NaN NaN ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
σ² 0.2883 0.2883 0.2883 0.2883 0.2883
intercept 20.2894 20.2894 20.2894 20.2894 20.2894
slope -10.3999 -10.3999 -10.3999 -10.3999 -10.3999
plot(chain)
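HMC with a hand-picked step size and number of leapfrog steps can be sensitive to those choices; Turing also provides the No-U-Turn sampler, which adapts the step size during warm-up. A sketch of the alternative call:
# NUTS with default arguments; adapts its step size automatically
chain_nuts = sample(model, NUTS(), 4_000)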
We can again access the summary statistics for the Markov chain generated by the sampler as follows.
summarize(chain)
parameters mean std mcse ess_bulk ess_tail rhat ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
σ² 0.2883 0.0000 0.0000 NaN NaN NaN ⋯
intercept 20.2894 0.0000 0.0000 NaN NaN NaN ⋯
slope -10.3999 0.0000 0.0000 NaN NaN NaN ⋯
1 column omitted
The full set of parameter values computed along the chain is accessible via chain[[:intercept]].value, chain[[:slope]].value, and chain[[:σ²]].value.
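For instance, to work with the raw slope samples as a plain vector, one could flatten the corresponding values; a sketch (std comes from Statistics, reexported by Distributions):
slope_samples = vec(chain[[:slope]].value)  # flatten the chain values to a Vector
std(slope_samples)                          # posterior spread of the slope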
We can also directly compute the mean of each of them,
mean(chain, [:intercept, :slope, :σ²])
Mean
parameters mean
Symbol Float64
σ² 0.2883
intercept 20.2894
slope -10.3999
and use them to find the Bayesian fitted regression line
yy_bayes = mean(chain, :intercept) .+ mean(chain, :slope) * xx
20.289370833314997:-0.5473653537203736:9.889429112627898
plt = plot(title="Synthetic data and LSQ and Turing fit", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
scatter!(plt, xx, yy_perturbed, label="perturbed sample")
plot!(plt, xx, yy, label="unperturbed line")
plot!(plt, xx, yy_hat, label="LSQ fitted line")
plot!(plt, xx, yy_bayes, label="Bayesian fitted line")
We can quantify how close the least squares and Bayesian fits are by looking at the extrema of their pointwise difference:
extrema(yy_hat - yy_bayes)
(-19.27861068643947, -6.844756998211495)
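Equivalently, one can compare the two coefficient estimates directly, reusing the mean accessor from above:
beta_bayes = [mean(chain, :intercept), mean(chain, :slope)]
betahat .- beta_bayes   # elementwise difference between the LSQ and Bayesian estimates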
Now it remains to compute and plot the credible intervals. First, we plot the ensemble of lines generated with the chain.
plt = plot(title="Synthetic data and Turing posterior", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy_bayes, label="Bayesian fitted line", color=2)
for (c, m) in eachrow(view(chain.value.data, :, 2:3, 1))  # columns 2:3 hold intercept and slope
    plot!(plt, xx, c .+ m * xx, alpha=0.01, color=2, label=false)
end
scatter!(plt, xx, yy_perturbed, label="perturbed sample", color=1)
Now we again use the values computed along the chain to find a 95% credible interval at each point $x$.
quantiles = reduce(hcat, quantile([c + m * x for (c, m) in eachrow(view(chain.value.data, :, 2:3, 1))], [0.025, 0.975]) for x in xx)
2×20 Matrix{Float64}:
20.2894 19.742 19.1946 18.6473 … 11.5315 10.9842 10.4368 9.88943
20.2894 19.742 19.1946 18.6473 11.5315 10.9842 10.4368 9.88943
With the computed quantiles, we are ready to plot the Bayesian fit with the credible interval.
plt = plot(title="Synthetic data and Turing fit with 95% credible interval", titlefont=10, ylims=(0.0, 1.1 * (intercept + slope)))
plot!(plt, xx, yy_bayes, ribbon=(yy_bayes .- view(quantiles, 1, :), view(quantiles, 2, :) .- yy_bayes), label="Bayesian fitted line", color=2)
scatter!(plt, xx, yy_perturbed, label="perturbed sample", color=1)
plot!(plt, xx, yy, label="unperturbed line", color=1)