Finite-difference score-matching of a two-dimensional Gaussian mixture model
Introduction
Here, we modify the previous finite-difference score-matching example to fit a two-dimensional model.
Julia language setup
We use the Julia programming language with suitable packages.
using StatsPlots
using Random
using Distributions
using Lux # artificial neural networks explicitly parametrized
using Optimisers
using Zygote # automatic differentiation
We set the random seed for reproducibility purposes.
rng = Xoshiro(12345)
Data
We build the target model and draw samples from it. This time the target model is a bivariate random variable.
xrange = range(-8, 8, 120)
yrange = range(-8, 8, 120)
dx = Float64(xrange.step)
dy = Float64(yrange.step)
target_prob = MixtureModel([MvNormal([-3, -3], [1 0; 0 1]), MvNormal([3, 3], [1 0; 0 1]), MvNormal([-1, 1], [1 0; 0 1])], [0.4, 0.4, 0.2])
target_pdf = [pdf(target_prob, [x, y]) for y in yrange, x in xrange]
target_score = reduce(hcat, gradlogpdf(target_prob, [x, y]) for y in yrange, x in xrange)
2×14400 Matrix{Float64}:
5.0 5.0 5.0 5.0 … -5.0 -5.0 -5.0 -5.0
5.0 4.86555 4.73109 4.59664 -4.59664 -4.73109 -4.86555 -5.0
sample_points = rand(rng, target_prob, 1024)
2×1024 Matrix{Float64}:
-3.69692 2.55677 -0.317683 -1.16593 … 2.92216 -3.03365 5.16729
-3.64622 3.4103 1.34462 -0.286077 2.96657 0.669646 3.21557
surface(xrange, yrange, target_pdf, title="PDF", titlefont=10, legend=false, color=:vik)
scatter!(sample_points[1, :], sample_points[2, :], [pdf(target_prob, [x, y]) for (x, y) in eachcol(sample_points)], markercolor=:lightgreen, markersize=2, alpha=0.5)
heatmap(xrange, yrange, target_pdf, title="PDF", titlefont=10, legend=false, color=:vik)
scatter!(sample_points[1, :], sample_points[2, :], markersize=2, markercolor=:lightgreen, alpha=0.5)
surface(xrange, yrange, (x, y) -> logpdf(target_prob, [x, y]), title="Logpdf", titlefont=10, legend=false, color=:vik)
scatter!(sample_points[1, :], sample_points[2, :], [logpdf(target_prob, [x, y]) for (x, y) in eachcol(sample_points)], markercolor=:lightgreen, alpha=0.5, markersize=2)
meshgrid(x, y) = (repeat(x, outer=length(y)), repeat(y, inner=length(x)))
xx, yy = meshgrid(xrange[begin:8:end], yrange[begin:8:end])
uu = reduce(hcat, gradlogpdf(target_prob, [x, y]) for (x, y) in zip(xx, yy))
2×225 Matrix{Float64}:
5.0 3.92437 2.84874 1.77311 0.697479 … -1.90756 -2.98319 -4.05882
5.0 5.0 5.0 5.0 5.0 -4.05882 -4.05882 -4.05882
heatmap(xrange, yrange, (x, y) -> logpdf(target_prob, [x, y]), title="Logpdf (heatmap) and score function (vector field)", titlefont=10, legend=false, color=:vik)
quiver!(xx, yy, quiver = (uu[1, :] ./ 8, uu[2, :] ./ 8), color=:yellow, alpha=0.5)
scatter!(sample_points[1, :], sample_points[2, :], markersize=2, markercolor=:lightgreen, alpha=0.5)
The neural network model
The neural network we consider is again a simple feed-forward neural network with a single hidden layer. For the two-dimensional case, we bump it up a bit by doubling the width of the hidden layer.
model = Chain(Dense(2 => 16, relu), Dense(16 => 2))
Chain(
layer_1 = Dense(2 => 16, relu), # 48 parameters
layer_2 = Dense(16 => 2), # 34 parameters
) # Total: 82 parameters,
# plus 0 states.
The LuxDL/Lux.jl package uses explicit parameters, which are initialized (or retrieved) with the Lux.setup function, giving us the parameters and the state of the model.
ps, st = Lux.setup(rng, model) # initialize and get the parameters and states of the model
((layer_1 = (weight = Float32[-1.2633736 1.8181845; 2.3867178 0.5713208; … ; -0.7349689 -1.3411301; -1.167693 -1.8823313], bias = Float32[0.48686042, 0.32835403, 0.23500215, -0.6289243, -0.4877348, -0.08177762, 0.57667726, -0.06088416, 0.20568033, -0.66487825, 0.17736456, -0.20175992, 0.4999213, -0.5164128, 0.5682744, 0.33539677]), layer_2 = (weight = Float32[0.3586208 -0.057023764 … -0.17773739 -0.41380876; 0.12927775 -0.22223856 … 0.06703623 0.18231718], bias = Float32[0.067408144, 0.110440105])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
Loss functions for score-matching
The loss function is again based on Aapo Hyvärinen (2005), combined with the work of Pang, Xu, Li, Song, Ermon, and Zhu (2020) using finite differences to approximate the divergence of the modeled score function.
In the multidimensional case, say on $\mathbb{R}^d$, $d\in\mathbb{N}$, the explicit score matching loss function is given by
\[ J_{\mathrm{ESM}}({\boldsymbol{\theta}}) = \frac{1}{2}\int_{\mathbb{R}^d} p_{\mathbf{X}}(\mathbf{x}) \|\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}}) - \boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x})\|^2\;\mathrm{d}\mathbf{x};\]
where $p_{\mathbf{X}}(\mathbf{x})$ is the PDF of the target distribution.
The integration by parts in the expectation yields $J_{\mathrm{ESM}}({\boldsymbol{\theta}}) = J_{\mathrm{ISM}}({\boldsymbol{\theta}}) + C$, where $C$ is constant with respect to the parameters and the implicit score matching loss function $J_{\mathrm{ISM}}({\boldsymbol{\theta}})$ is given by
\[ J_{\mathrm{ISM}}({\boldsymbol{\theta}}) = \int_{\mathbb{R}^d} p_{\mathbf{X}}(\mathbf{x}) \left( \frac{1}{2}\|\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}})\|^2 + \boldsymbol{\nabla}_{\mathbf{x}} \cdot \boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}}) \right)\;\mathrm{d}\mathbf{x},\]
which does not involve the unknown score function of ${\mathbf{X}}$. It does, however, involve the divergence of the modeled score function, which is expensive to compute.
In practice, the loss function is estimated via the empirical distribution, so the unknown $p_{\mathbf{X}}(\mathbf{x})$ is handled implicitly by the sample data $(\mathbf{x}_n)_n$, and we minimize the empirical implicit score matching loss function
\[ {\tilde J}_{\mathrm{ISM}{\tilde p}_0} = \frac{1}{N}\sum_{n=1}^N \left( \frac{1}{2}\|\boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}})\|^2 + \boldsymbol{\nabla}_{\mathbf{x}} \cdot \boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}}) \right).\]
Componentwise, with $\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}}) = (\psi_i(\mathbf{x}; {\boldsymbol{\theta}}))_{i=1}^d$, this is written as
\[ {\tilde J}_{\mathrm{ISM}{\tilde p}_0} = \frac{1}{N}\sum_{n=1}^N \sum_{i=1}^d \left( \frac{1}{2}\psi_i(\mathbf{x}_n; {\boldsymbol{\theta}})^2 + \frac{\partial}{\partial x_i} \psi_i(\mathbf{x}_n; {\boldsymbol{\theta}}) \right).\]
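To make this cost concrete, the divergence term could in principle be computed exactly with automatic differentiation, at the price of one full Jacobian of the model per sample point. The following is only an illustrative sketch (not used in the training below), assuming Zygote and a hypothetical helper name.
# Exact divergence of the modeled score function at a single point x, via the
# trace of the Jacobian computed with automatic differentiation (illustrative sketch).
function divergence_via_AD(model, ps, st, x)
    J = Zygote.jacobian(u -> first(Lux.apply(model, u, ps, st)), x)[1]
    return sum(J[i, i] for i in axes(J, 1))
end

divergence_via_AD(model, ps, st, sample_points[:, 1])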
As mentioned before, computing a derivative to form the loss function becomes expensive when combined with the usual optimization methods for fitting a neural network, since they require the gradient of the loss function itself. Thus, we approximate the derivative of the modeled score function by centered finite differences. Since the model is already evaluated at the displaced points, we simply average those evaluations instead of evaluating the model at the sample point itself. This leads to the empirical finite-difference (implicit) score matching loss function
\[ {\tilde J}_{\mathrm{FDSM}{\tilde p}_0} = \frac{1}{N}\sum_{n=1}^N \sum_{i=1}^d \Bigg( \frac{1}{2}\left(\frac{1}{d}\sum_{j=1}^d \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_j; {\boldsymbol{\theta}}) + \psi_i(\mathbf{x}_n - \delta\mathbf{e}_j; {\boldsymbol{\theta}})}{2}\right)^2 \\ \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad + \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_i; {\boldsymbol{\theta}}) - \psi_i(\mathbf{x}_n - \delta\mathbf{e}_i; {\boldsymbol{\theta}})}{2\delta} \Bigg).\]
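For reference, a dimension-generic version of this loss could be sketched as follows. This is only an illustrative sketch, assuming the sample points form a d×N matrix and a single scalar step `delta`; the two-dimensional implementation actually used in this example is given further below.
# Illustrative sketch of a dimension-generic finite-difference score-matching loss.
# `sample_points` is a d×N matrix of samples and `delta` is a scalar step size.
function loss_function_fd_generic(model, ps, st, sample_points, delta)
    d, N = size(sample_points)
    s_avg = zero(sample_points) # average of the model over the 2d displaced points
    div_fd = zero(eltype(sample_points)) # centered finite-difference divergence term
    for j in 1:d
        e_j = zeros(d)
        e_j[j] = delta
        s_fwd, _ = Lux.apply(model, sample_points .+ e_j, ps, st)
        s_bwd, _ = Lux.apply(model, sample_points .- e_j, ps, st)
        s_avg = s_avg .+ (s_fwd .+ s_bwd) ./ (2d)
        div_fd += mean(view(s_fwd, j, :) .- view(s_bwd, j, :)) / (2delta)
    end
    return sum(abs2, s_avg) / (2N) + div_fd
end

loss_function_fd_generic(model, ps, st, sample_points, 0.01) # 0.01 is an arbitrary step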
Since this is a synthetic problem and we actually know the target distribution, we implement the empirical explicit score matching loss function
\[ {\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) = \frac{1}{2}\frac{1}{N}\sum_{n=1}^N \|\boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}}) - \boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x}_n)\|^2.\]
This serves as a reliable check of whether the neural network is expressive enough to model the score function, and of the optimization process itself, since in theory it should differ from ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}$ by roughly a constant (apart from the approximation by the empirical distribution, the finite-difference approximation, and round-off errors).
Implementation of ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}({\boldsymbol{\theta}})$
In the two-dimensional case, $d = 2$, this becomes
\[ \begin{align*} {\tilde J}_{\mathrm{FDSM}{\tilde p}_0} & = \frac{1}{N}\sum_{n=1}^N \sum_{i=1}^d \Bigg( \frac{1}{2}\left(\frac{1}{d}\sum_{j=1}^d \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_j; {\boldsymbol{\theta}}) + \psi_i(\mathbf{x}_n - \delta\mathbf{e}_j; {\boldsymbol{\theta}})}{2}\right)^2 \\ & \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad + \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_i; {\boldsymbol{\theta}}) - \psi_i(\mathbf{x}_n - \delta\mathbf{e}_i; {\boldsymbol{\theta}})}{2\delta} \Bigg) \\ & = \frac{1}{N}\sum_{n=1}^N \sum_{i=1}^2 \Bigg( \frac{1}{2}\left(\sum_{j=1}^2 \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_j; {\boldsymbol{\theta}}) + \psi_i(\mathbf{x}_n - \delta\mathbf{e}_j; {\boldsymbol{\theta}})}{4}\right)^2 \\ & \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad + \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_i; {\boldsymbol{\theta}}) - \psi_i(\mathbf{x}_n - \delta\mathbf{e}_i; {\boldsymbol{\theta}})}{2\delta} \Bigg) \\ & = \frac{1}{N} \frac{1}{2} \sum_{n=1}^N \sum_{i=1}^2 \left(\sum_{j=1}^2 \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_j; {\boldsymbol{\theta}}) + \psi_i(\mathbf{x}_n - \delta\mathbf{e}_j; {\boldsymbol{\theta}})}{4}\right)^2 \\ & \qquad \qquad \qquad \qquad \qquad \qquad + \frac{1}{N}\sum_{n=1}^N \frac{\psi_1(\mathbf{x}_n + \delta\mathbf{e}_1; {\boldsymbol{\theta}}) - \psi_1(\mathbf{x}_n - \delta\mathbf{e}_1; {\boldsymbol{\theta}})}{2\delta} \\ & \qquad \qquad \qquad \qquad \qquad \qquad + \frac{1}{N}\sum_{n=1}^N \frac{\psi_2(\mathbf{x}_n + \delta\mathbf{e}_2; {\boldsymbol{\theta}}) - \psi_2(\mathbf{x}_n - \delta\mathbf{e}_2; {\boldsymbol{\theta}})}{2\delta} \end{align*}\]
function loss_function(model, ps, st, data)
sample_points, deltax, deltay = data
s_pred_fwd_x, = Lux.apply(model, sample_points .+ [deltax, 0.0], ps, st)
s_pred_bwd_x, = Lux.apply(model, sample_points .- [deltax, 0.0], ps, st)
s_pred_fwd_y, = Lux.apply(model, sample_points .+ [0.0, deltay], ps, st)
s_pred_bwd_y, = Lux.apply(model, sample_points .- [0.0, deltay], ps, st)
s_pred = ( s_pred_bwd_x .+ s_pred_fwd_x .+ s_pred_bwd_y .+ s_pred_fwd_y) ./ 4
dsdx_pred = (s_pred_fwd_x .- s_pred_bwd_x ) ./ 2deltax
dsdy_pred = (s_pred_fwd_y .- s_pred_bwd_y ) ./ 2deltay
loss = mean(abs2, s_pred) + mean(view(dsdx_pred, 1, :)) + mean(view(dsdy_pred, 2, :))
return loss, st, ()
end
loss_function (generic function with 1 method)
We include the finite-difference step sizes in the data passed to training, to avoid recomputing them at every iteration.
xmin, xmax = extrema(sample_points[1, :])
ymin, ymax = extrema(sample_points[2, :])
deltax, deltay = (xmax - xmin) / 2size(sample_points, 2), (ymax - ymin) / 2size(sample_points, 2)
(0.00581614558657555, 0.006593323926327331)
data = sample_points, deltax, deltay
([-3.696922040577957 2.556772674293886 … -3.033651509981303 5.1672947152544895; -3.6462237060140024 3.4102960875997117 … 0.6696459295015311 3.2155718822745527], 0.00581614558657555, 0.006593323926327331)
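As a quick sanity check on the finite-difference approximation itself, one could compare the centered differences against the exact divergence at a single sample point, reusing the hypothetical `divergence_via_AD` helper sketched earlier (illustrative only).
# Centered finite-difference divergence vs exact AD divergence at one sample point.
let x = sample_points[:, 1]
    fd_div = (first(Lux.apply(model, x .+ [deltax, 0.0], ps, st))[1] -
              first(Lux.apply(model, x .- [deltax, 0.0], ps, st))[1]) / 2deltax +
             (first(Lux.apply(model, x .+ [0.0, deltay], ps, st))[2] -
              first(Lux.apply(model, x .- [0.0, deltay], ps, st))[2]) / 2deltay
    (fd_div, divergence_via_AD(model, ps, st, x))
end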
Implementation of ${\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}})$
As a sanity check, we also include the empirical explicit score matching loss function, which uses the known score function of the target model.
In the two-dimensional case, this amounts to the mean of the squared differences over all components and all sample points, with the factor $1/2$ absorbed by the averaging over the $d = 2$ components.
\[ {\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) = \frac{1}{2} \frac{1}{N}\sum_{n=1}^N \|\boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}}) - \boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x}_n)\|^2 = \frac{1}{2} \frac{1}{N}\sum_{n=1}^N \sum_{i=1}^2 \left(\psi_i(\mathbf{x}_n; {\boldsymbol{\theta}}) - \psi_{\mathbf{X}, i}(\mathbf{x}_n) \right)^2.\]
function loss_function_cheat(model, ps, st, data)
sample_points, score_cheat = data
score_pred, st = Lux.apply(model, sample_points, ps, st)
loss = mean(abs2, score_pred .- score_cheat)
return loss, st, ()
end
loss_function_cheat (generic function with 1 method)
The data in this case includes information about the target distribution.
score_cheat = reduce(hcat, gradlogpdf(target_prob, u) for u in eachcol(sample_points))
data_cheat = sample_points, score_cheat
([-3.696922040577957 2.556772674293886 … -3.033651509981303 5.1672947152544895; -3.6462237060140024 3.4102960875997117 … 0.6696459295015311 3.2155718822745527], [0.6969228899748048 0.44299201282532824 … 1.9946701191708671 -2.1672947253904584; 0.6462254048076983 -0.41041374404010544 … 0.2523912079807282 -0.2155718873425372])
Computing the constant
The expression ${\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) \approx {\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}}) + C$ can be used to test the implementation of the different loss functions. For that, we need to compute the constant $C$. This can be computed with a fine mesh or with a Monte-Carlo approximation. We do both just for fun.
function compute_constante(target_prob, xrange, yrange)
dx = Float64(xrange.step)
dy = Float64(yrange.step)
Jconstant = sum(pdf(target_prob, [x, y]) * sum(abs2, gradlogpdf(target_prob, [x, y])) for y in yrange, x in xrange) * dx * dy / 2
return Jconstant
end
compute_constante (generic function with 1 method)
function compute_constante_MC(target_prob, sample_points)
Jconstant = mean(sum(abs2, gradlogpdf(target_prob, s)) for s in eachcol(sample_points)) / 2
return Jconstant
end
compute_constante_MC (generic function with 1 method)
Jconstant = compute_constante(target_prob, xrange, yrange)
0.8921462840364986
Jconstant_MC = compute_constante_MC(target_prob, sample_points)
0.9129656846610863
constants = [(n, compute_constante_MC(target_prob, rand(rng, target_prob, n))) for _ in 1:100 for n in (1, 10, 20, 50, 100, 500, 1000, 2000, 4000)]
900-element Vector{Tuple{Int64, Float64}}:
(1, 1.1538119449520563)
(10, 1.2247351227578578)
(20, 1.1028524320443522)
(50, 0.9920784868026302)
(100, 1.0063955226226398)
(500, 0.8755541733725885)
(1000, 0.8342517573145984)
(2000, 0.8759484336636649)
(4000, 0.8827405034687669)
(1, 0.9485837438780702)
⋮
(1, 0.298790143447742)
(10, 0.7794066301161721)
(20, 1.1364257738853651)
(50, 0.6101320519101868)
(100, 0.8696468922755639)
(500, 0.9416724546394305)
(1000, 0.9063918160454828)
(2000, 0.8914872767709621)
(4000, 0.871944900929463)
scatter(constants, markersize=2, title="constant computed by MC and fine mesh", titlefont=10, xlabel="sample size", ylabel="value", label="via various samples")
hline!([Jconstant], label="via fine mesh")
hline!([Jconstant_MC], label="via working sample", linestyle=:dash)
A test for the implementations of the loss functions
Notice that, for a sufficiently large sample and sufficiently small discretization step $\delta$, we should have
\[ {\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) \approx J_{\mathrm{ESM}}({\boldsymbol{\theta}}) = J_{\mathrm{ISM}}({\boldsymbol{\theta}}) + C \approx {\tilde J}_{\mathrm{FDSM}}({\boldsymbol{\theta}}) + C \approx {\tilde J}_{\mathrm{FDSM}{\tilde p}_0}({\boldsymbol{\theta}}) + C,\]
which is a good test for the implementations of the loss functions. For example:
first(loss_function_cheat(model, ps, st, data_cheat))
32.62664847852629
first(loss_function(model, ps, st, data)) + Jconstant
32.56518812592081
Let us do a more statistically significant test.
test_losses = reduce(
hcat,
Lux.setup(rng, model) |> pstj ->
[
first(loss_function_cheat(model, pstj[1], pstj[2], data_cheat)),
first(loss_function(model, pstj[1], pstj[2], data))
]
for _ in 1:30
)
2×30 Matrix{Float64}:
13.7323 38.8598 5.71619 9.86731 … 38.9853 20.4021 14.7908 7.77597
13.0453 37.4968 4.90646 8.84147 38.2617 19.6932 13.6616 7.00106
plot(title="Loss functions at random model parameters", titlefont=10)
scatter!(test_losses[1, :], label="\${\\tilde J}_{\\mathrm{ESM}{\\tilde p}_0}\$")
scatter!(test_losses[2, :], label="\${\\tilde J}_{\\mathrm{FDSM}{\\tilde p}_0}\$")
scatter!(test_losses[2, :] .+ Jconstant, label="\${\\tilde J}_{\\mathrm{FDSM}{\\tilde p}_0} + C\$")
One can check by visual inspection that the agreement between ${\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) - C$ and ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}({\boldsymbol{\theta}})$ seems reasonably good. Let us estimate the relative error.
rel_errors = abs.( ( test_losses[2, :] .+ Jconstant .- test_losses[1, :] ) ./ test_losses[1, :] )
plot(title="Relative error at random model parameters", titlefont=10, legend=false)
scatter!(rel_errors, markercolor=2, label="error")
mm = mean(rel_errors)
mmstd = std(rel_errors)
hline!([mm], label="mean")
hspan!([mm+mmstd, mm-mmstd], fillbetween=true, alpha=0.3, label="68% margin")
Ok, good enough, just a few percentage points.
An extra test for the implementations of the loss functions and the gradient computation
We also have
\[ \boldsymbol{\nabla}_{\boldsymbol{\theta}} {\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) \approx \boldsymbol{\nabla}_{\boldsymbol{\theta}} {\tilde J}_{\mathrm{FDSM}{\tilde p}_0}({\boldsymbol{\theta}}),\]
which is another good test that also checks the gradient computation. Everything seems fine on that front, so we do not push it further here.
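For reference, such a gradient comparison could be sketched as follows, differentiating both empirical losses with respect to the explicit parameters directly with Zygote (an illustrative sketch only; the layer names follow the `ps` NamedTuple shown above).
# Gradients of both losses with respect to the model parameters, via Zygote.
grad_fdsm = Zygote.gradient(p -> first(loss_function(model, p, st, data)), ps)[1]
grad_esm = Zygote.gradient(p -> first(loss_function_cheat(model, p, st, data_cheat)), ps)[1]

# Discrepancy in the first-layer weight gradients, relative to the ESM gradient.
maximum(abs, grad_fdsm.layer_1.weight .- grad_esm.layer_1.weight) / maximum(abs, grad_esm.layer_1.weight)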
Optimization setup
Optimization method
As usual, we use the Adam optimizer.
opt = Adam(0.003)
tstate_org = Lux.Training.TrainState(model, ps, st, opt)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 16, relu), layer_2 = Dense(16 => 2)), nothing)
# of parameters: 82
# of states: 0
optimizer: Optimisers.Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8)
step: 0
Automatic differentiation in the optimization
FluxML/Zygote.jl is used for the automatic differentiation as it is currently the only AD backend working with LuxDL/Lux.jl.
vjp_rule = Lux.Training.AutoZygote()
ADTypes.AutoZygote()
Processor
We use the CPU instead of the GPU.
dev_cpu = cpu_device()
## dev_gpu = gpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)
Check differentiation
Check if AD is working fine to differentiate the loss functions for training.
Lux.Training.compute_gradients(vjp_rule, loss_function, data, tstate_org)
((layer_1 = (weight = Float32[3.6879158 4.0172615; -2.2284822 -2.6868043; … ; -1.9008694 -1.7861443; -4.612377 -4.320793], bias = Float32[0.8997364, -0.8043432, 1.7370119, 0.45905542, 0.8397131, -1.230505, 0.353755, -1.2257671, 0.9869879, 1.0354862, 0.52109236, 1.3021913, 1.5245895, 0.877233, 0.5858979, 1.4446392]), layer_2 = (weight = Float32[4.2770233 22.516876 … -17.53038 -24.572052; 11.982891 26.250607 … 15.06171 21.036867], bias = Float32[-0.113342285, 5.354232])), 31.673041841884316, (), Lux.Training.TrainState{Nothing, Nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}}(nothing, nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 16, relu), layer_2 = Dense(16 => 2)), nothing), (layer_1 = (weight = Float32[-1.2633736 1.8181845; 2.3867178 0.5713208; … ; -0.7349689 -1.3411301; -1.167693 -1.8823313], bias = Float32[0.48686042, 0.32835403, 0.23500215, -0.6289243, -0.4877348, -0.08177762, 0.57667726, -0.06088416, 0.20568033, -0.66487825, 0.17736456, -0.20175992, 0.4999213, -0.5164128, 0.5682744, 0.33539677]), layer_2 = (weight = Float32[0.3586208 -0.057023764 … -0.17773739 -0.41380876; 0.12927775 -0.22223856 … 0.06703623 0.18231718], bias = Float32[0.067408144, 0.110440105])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), Optimisers.Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (layer_1 = (weight = Leaf(Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999))))), 0))
Lux.Training.compute_gradients(vjp_rule, loss_function_cheat, data_cheat, tstate_org)
((layer_1 = (weight = Float32[3.689433 4.0178785; -2.264155 -2.618133; … ; -1.9069793 -1.7834694; -4.6186814 -4.3203063], bias = Float32[0.8609152, -0.7731151, 1.725719, 0.44558188, 0.84764415, -1.251586, 0.36021093, -1.1504575, 0.975368, 1.044547, 0.5048005, 1.1946136, 1.5058247, 0.893234, 0.58936614, 1.429456]), layer_2 = (weight = Float32[4.235089 22.786686 … -17.630789 -24.70937; 11.946889 26.341808 … 14.838343 20.73272], bias = Float32[-0.11527908, 5.321644])), 32.62664847852629, (), Lux.Training.TrainState{Nothing, Nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}}(nothing, nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 16, relu), layer_2 = Dense(16 => 2)), nothing), (layer_1 = (weight = Float32[-1.2633736 1.8181845; 2.3867178 0.5713208; … ; -0.7349689 -1.3411301; -1.167693 -1.8823313], bias = Float32[0.48686042, 0.32835403, 0.23500215, -0.6289243, -0.4877348, -0.08177762, 0.57667726, -0.06088416, 0.20568033, -0.66487825, 0.17736456, -0.20175992, 0.4999213, -0.5164128, 0.5682744, 0.33539677]), layer_2 = (weight = Float32[0.3586208 -0.057023764 … -0.17773739 -0.41380876; 0.12927775 -0.22223856 … 0.06703623 0.18231718], bias = Float32[0.067408144, 0.110440105])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), Optimisers.Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (layer_1 = (weight = Leaf(Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999))))), 0))
Training loop
Here is the typical main training loop suggested in the LuxDL/Lux.jl tutorials, slightly modified to save the history of losses per iteration and snapshots of the model state for animation.
function train(tstate, vjp, data, loss_function, epochs, numshowepochs=20, numsavestates=0)
losses = zeros(epochs)
tstates = [(0, tstate)]
for epoch in 1:epochs
grads, loss, stats, tstate = Lux.Training.compute_gradients(vjp,
loss_function, data, tstate)
if ( epochs ≥ numshowepochs > 0 ) && rem(epoch, div(epochs, numshowepochs)) == 0
println("Epoch: $(epoch) || Loss: $(loss)")
end
if ( epochs ≥ numsavestates > 0 ) && rem(epoch, div(epochs, numsavestates)) == 0
push!(tstates, (epoch, tstate))
end
losses[epoch] = loss
tstate = Lux.Training.apply_gradients(tstate, grads)
end
return tstate, losses, tstates
end
train (generic function with 3 methods)
Cheat training with ${\tilde J}_{\mathrm{ESM}{\tilde p}_0}$
We first train the model with the known score function on the sample data. That is cheating. The aim is a sanity check, to make sure the proposed model is good enough to fit the desired score function and that the setup is right.
@time tstate_cheat, losses_cheat, tstates_cheat = train(tstate_org, vjp_rule, data_cheat, loss_function_cheat, 2000, 20, 100)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [Matrix{Float32}] x B [Matrix{Float64}]). Falling back to generic implementation. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/1B1qw/src/impl/matmul.jl:190
Epoch: 100 || Loss: 0.8401849278869697
Epoch: 200 || Loss: 0.563607700266948
Epoch: 300 || Loss: 0.4926258151863607
Epoch: 400 || Loss: 0.45830559661430026
Epoch: 500 || Loss: 0.43358616510264614
Epoch: 600 || Loss: 0.41197109354966255
Epoch: 700 || Loss: 0.39327886995951344
Epoch: 800 || Loss: 0.3768739639070251
Epoch: 900 || Loss: 0.3631936077433342
Epoch: 1000 || Loss: 0.35101209994291793
Epoch: 1100 || Loss: 0.33997820230228853
Epoch: 1200 || Loss: 0.3293390319395758
Epoch: 1300 || Loss: 0.31880814244768374
Epoch: 1400 || Loss: 0.30846754084144945
Epoch: 1500 || Loss: 0.2984211525235665
Epoch: 1600 || Loss: 0.28827943967632763
Epoch: 1700 || Loss: 0.2778414128971065
Epoch: 1800 || Loss: 0.26736572655519075
Epoch: 1900 || Loss: 0.25641855404231345
Epoch: 2000 || Loss: 0.2448096823452019
0.510287 seconds (722.74 k allocations: 955.932 MiB, 12.15% gc time, 7.54% compilation time)
Testing out the trained model.
uu_cheat = Lux.apply(tstate_cheat.model, vcat(xx', yy'), tstate_cheat.parameters, tstate_cheat.states)[1]
2×225 Matrix{Float64}:
2.19923 1.53509 0.870944 0.206801 … 0.502545 -0.277666 -1.05788
0.997473 1.47914 1.9608 2.44247 -3.00036 -2.75735 -2.51434
heatmap(xrange, yrange, (x, y) -> logpdf(target_prob, [x, y]), title="Logpdf (heatmap) and score functions (vector fields)", titlefont=10, color=:vik, xlims=extrema(xrange), ylims=extrema(yrange), legend=false)
quiver!(xx, yy, quiver = (uu[1, :] ./ 8, uu[2, :] ./ 8), color=:yellow, alpha=0.5)
scatter!(sample_points[1, :], sample_points[2, :], markersize=2, markercolor=:lightgreen, alpha=0.5)
quiver!(xx, yy, quiver = (uu_cheat[1, :] ./ 8, uu_cheat[2, :] ./ 8), color=:cyan, alpha=0.5)

plot(losses_cheat, title="Evolution of the loss", titlefont=10, xlabel="iteration", ylabel="error", legend=false)
Real training with ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}$
Now we go to the real thing.
@time tstate, losses, tstates = train(tstate_org, vjp_rule, data, loss_function, 2000, 20, 100)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [Matrix{Float32}] x B [Matrix{Float64}]). Falling back to generic implementation. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/1B1qw/src/impl/matmul.jl:190
Epoch: 100 || Loss: -0.028526071882063422
Epoch: 200 || Loss: -0.3056820935385293
Epoch: 300 || Loss: -0.3735720043248265
Epoch: 400 || Loss: -0.4060071208695759
Epoch: 500 || Loss: -0.42700676492413925
Epoch: 600 || Loss: -0.4442805461267297
Epoch: 700 || Loss: -0.45911840778124424
Epoch: 800 || Loss: -0.4711387063320857
Epoch: 900 || Loss: -0.48093733644650705
Epoch: 1000 || Loss: -0.48972825080776355
Epoch: 1100 || Loss: -0.49795548481343427
Epoch: 1200 || Loss: -0.5057248136083738
Epoch: 1300 || Loss: -0.5133231683129973
Epoch: 1400 || Loss: -0.5203089294538508
Epoch: 1500 || Loss: -0.5271916186760042
Epoch: 1600 || Loss: -0.5339186475303448
Epoch: 1700 || Loss: -0.5407054119759426
Epoch: 1800 || Loss: -0.5470494965278492
Epoch: 1900 || Loss: -0.5533731989980991
Epoch: 2000 || Loss: -0.5598464396520446
1.851211 seconds (1.36 M allocations: 4.064 GiB, 10.31% gc time, 2.10% compilation time)
Testing out the trained model.
uu_pred = Lux.apply(tstate.model, vcat(xx', yy'), tstate.parameters, tstate.states)[1]
2×225 Matrix{Float64}:
1.71235 1.17107 0.606229 -0.0336345 … 1.14829 0.627492 0.106691
-0.196228 0.441023 1.20491 1.90302 -2.18941 -2.07645 -1.9635
heatmap(xrange, yrange, (x, y) -> logpdf(target_prob, [x, y]), title="Logpdf (heatmap) and score functions (vector fields)", titlefont=10, color=:vik, xlims=extrema(xrange), ylims=extrema(yrange), legend=false)
quiver!(xx, yy, quiver = (uu[1, :] ./ 8, uu[2, :] ./ 8), color=:yellow, alpha=0.5)
scatter!(sample_points[1, :], sample_points[2, :], markersize=2, markercolor=:lightgreen, alpha=0.5)
quiver!(xx, yy, quiver = (uu_pred[1, :] ./ 8, uu_pred[2, :] ./ 8), color=:cyan, alpha=0.5)

plot(losses, title="Evolution of the losses", titlefont=10, xlabel="iteration", ylabel="error", label="\${\\tilde J}_{\\mathrm{FDSM}{\\tilde p}_0}\$")
plot!(losses_cheat, linestyle=:dash, label="\${\\tilde J}_{\\mathrm{ESM}{\\tilde p}_0}\$")
plot!(losses .+ Jconstant, linestyle=:dash, color=1, label="\${\\tilde J}_{\\mathrm{FDSM}{\\tilde p}_0} + C\$")
Ok, that seems visually good enough. We will later check the sampling from this score function via Langevin sampling.
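For a glimpse of what that entails, a minimal sketch of unadjusted Langevin dynamics driven by the trained score model could look like the following. This is illustrative only; the step size, the number of steps, and the initial distribution are arbitrary choices here, and the proper sampling experiment is left for the later example.
# Illustrative sketch of unadjusted Langevin dynamics with the trained score model:
# x_{k+1} = x_k + (tau/2) * psi(x_k; theta) + sqrt(tau) * z_k, with z_k standard normal.
function langevin_sample(rng, model, ps, st, x_init, tau, nsteps)
    x = copy(x_init)
    for _ in 1:nsteps
        score, _ = Lux.apply(model, x, ps, st)
        x = x .+ (tau / 2) .* score .+ sqrt(tau) .* randn(rng, size(x)...)
    end
    return x
end

x_gen = langevin_sample(rng, model, tstate.parameters, tstate.states, randn(rng, 2, 1024), 0.1, 500)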