Finite-difference score-matching of a two-dimensional Gaussian mixture model
Introduction
Here, we modify the previous finite-difference score-matching example to fit a two-dimensional model.
Julia language setup
We use the Julia programming language with suitable packages.
using StatsPlots
using Random
using Distributions
using Lux # artificial neural networks explicitly parametrized
using Optimisers
using Zygote # automatic differentiation
We set the random seed for reproducibility purposes.
rng = Xoshiro(12345)
Data
We build the target model and draw samples from it. This time the target model is a bivariate random variable.
xrange = range(-8, 8, 120)
yrange = range(-8, 8, 120)
dx = step(xrange)
dy = step(yrange)
target_prob = MixtureModel([MvNormal([-3, -3], [1 0; 0 1]), MvNormal([3, 3], [1 0; 0 1]), MvNormal([-1, 1], [1 0; 0 1])], [0.4, 0.4, 0.2])
target_pdf = [pdf(target_prob, [x, y]) for y in yrange, x in xrange]
target_score = reduce(hcat, gradlogpdf(target_prob, [x, y]) for y in yrange, x in xrange)
2×14400 Matrix{Float64}:
5.0 5.0 5.0 5.0 … -5.0 -5.0 -5.0 -5.0
5.0 4.86555 4.73109 4.59664 -4.59664 -4.73109 -4.86555 -5.0
sample_points = rand(rng, target_prob, 1024)
2×1024 Matrix{Float64}:
-3.69692 2.55677 -0.317683 -1.16593 … 2.92216 -3.03365 5.16729
-3.64622 3.4103 1.34462 -0.286077 2.96657 0.669646 3.21557
surface(xrange, yrange, target_pdf, title="PDF", titlefont=10, legend=false, color=:vik)
scatter!(sample_points[1, :], sample_points[2, :], [pdf(target_prob, [x, y]) for (x, y) in eachcol(sample_points)], markercolor=:lightgreen, markersize=2, alpha=0.5)
heatmap(xrange, yrange, target_pdf, title="PDF", titlefont=10, legend=false, color=:vik)
scatter!(sample_points[1, :], sample_points[2, :], markersize=2, markercolor=:lightgreen, alpha=0.5)
surface(xrange, yrange, (x, y) -> logpdf(target_prob, [x, y]), title="Logpdf", titlefont=10, legend=false, color=:vik)
scatter!(sample_points[1, :], sample_points[2, :], [logpdf(target_prob, [x, y]) for (x, y) in eachcol(sample_points)], markercolor=:lightgreen, alpha=0.5, markersize=2)
meshgrid(x, y) = (repeat(x, outer=length(y)), repeat(y, inner=length(x)))
xx, yy = meshgrid(xrange[begin:8:end], yrange[begin:8:end])
uu = reduce(hcat, gradlogpdf(target_prob, [x, y]) for (x, y) in zip(xx, yy))
2×225 Matrix{Float64}:
5.0 3.92437 2.84874 1.77311 0.697479 … -1.90756 -2.98319 -4.05882
5.0 5.0 5.0 5.0 5.0 -4.05882 -4.05882 -4.05882
heatmap(xrange, yrange, (x, y) -> logpdf(target_prob, [x, y]), title="Logpdf (heatmap) and score function (vector field)", titlefont=10, legend=false, color=:vik)
quiver!(xx, yy, quiver = (uu[1, :] ./ 8, uu[2, :] ./ 8), color=:yellow, alpha=0.5)
scatter!(sample_points[1, :], sample_points[2, :], markersize=2, markercolor=:lightgreen, alpha=0.5)
The neural network model
The neural network we consider is again a simple feed-forward neural network made of a single hidden layer. For the 2d case, we need to bump it a little bit, doubling the width of the hidden layer.
model = Chain(Dense(2 => 16, relu), Dense(16 => 2))
Chain(
layer_1 = Dense(2 => 16, relu), # 48 parameters
layer_2 = Dense(16 => 2), # 34 parameters
) # Total: 82 parameters,
# plus 0 states.
The LuxDL/Lux.jl package uses explicit parameters, which are initialized (or obtained) with the Lux.setup function, giving us the parameters and the state of the model.
ps, st = Lux.setup(rng, model) # initialize and get the parameters and states of the model
((layer_1 = (weight = Float32[-0.29777998 0.42855018; 0.5625548 0.13466159; … ; -0.17323382 -0.3161074; -0.27522787 -0.4436697], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[0.39751986 0.19187844 … 0.085118644 0.51215607; 0.26809993 -0.5135145 … 0.1600962 -0.24594013], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
Loss functions for score-matching
The loss function is again based on Aapo Hyvärinen (2005), combined with the work of Pang, Xu, Li, Song, Ermon, and Zhu (2020) using finite differences to approximate the divergence of the modeled score function.
In the multidimensional case, say on $\mathbb{R}^d$, $d\in\mathbb{N}$, the explicit score matching loss function is given by
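\[ J_{\mathrm{ESM}}({\boldsymbol{\theta}}) = \frac{1}{2}\int_{\mathbb{R}^d} p_{\mathbf{X}}(\mathbf{x}) \|\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}}) - \boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x})\|^2\;\mathrm{d}\mathbf{x},\]
where $p_{\mathbf{X}}(\mathbf{x})$ is the PDF of the target distribution, $\boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x}) = \boldsymbol{\nabla}_{\mathbf{x}} \log p_{\mathbf{X}}(\mathbf{x})$ is its score function, and $\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}})$ is the modeled score function.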
Integrating by parts in the expectation yields $J_{\mathrm{ESM}}({\boldsymbol{\theta}}) = J_{\mathrm{ISM}}({\boldsymbol{\theta}}) + C$, where $C$ is constant with respect to the parameters and the implicit score matching loss function $J_{\mathrm{ISM}}({\boldsymbol{\theta}})$ is given by
\[ J_{\mathrm{ISM}}({\boldsymbol{\theta}}) = \int_{\mathbb{R}^d} p_{\mathbf{X}}(\mathbf{x}) \left( \frac{1}{2}\|\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}})\|^2 + \boldsymbol{\nabla}_{\mathbf{x}} \cdot \boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}}) \right)\;\mathrm{d}\mathbf{x},\]
which does not involve the unknown score function of ${\mathbf{X}}$. It does, however, involve the divergence of the modeled score function, which is expensive to compute.
In practice, the loss function is estimated via the empirical distribution, so the unknown $p_{\mathbf{X}}(\mathbf{x})$ is handled implicitly by the sample data $(\mathbf{x}_n)_n$, and we minimize the empirical implicit score matching loss function
\[ {\tilde J}_{\mathrm{ISM}{\tilde p}_0} = \frac{1}{N}\sum_{n=1}^N \left( \frac{1}{2}\|\boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}})\|^2 + \boldsymbol{\nabla}_{\mathbf{x}} \cdot \boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}}) \right).\]
Componentwise, with $\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}}) = (\psi_i(\mathbf{x}; {\boldsymbol{\theta}}))_{i=1}^d$, this is written as
\[ {\tilde J}_{\mathrm{ISM}{\tilde p}_0} = \frac{1}{N}\sum_{n=1}^N \sum_{i=1}^d \left( \frac{1}{2}\psi_i(\mathbf{x}_n; {\boldsymbol{\theta}})^2 + \frac{\partial}{\partial x_i} \psi_i(\mathbf{x}_n; {\boldsymbol{\theta}}) \right).\]
As mentioned before, computing a derivative to form the loss function becomes expensive when combined with the usual optimization methods to fit a neural network, as they require the gradient of the loss function itself, so we approximate the derivative of the modeled score function by centered finite differences. With the model calculated at the displaced points, we just average them to avoid computing the model at the sample point itself. This leads to the empirical finite-difference (implicit) score matching loss function
\[ {\tilde J}_{\mathrm{FDSM}{\tilde p}_0} = \frac{1}{N}\sum_{n=1}^N \sum_{i=1}^d \Bigg( \frac{1}{2}\left(\frac{1}{d}\sum_{j=1}^d \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_j; {\boldsymbol{\theta}}) + \psi_i(\mathbf{x}_n - \delta\mathbf{e}_j; {\boldsymbol{\theta}})}{2}\right)^2 \\ \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad + \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_i; {\boldsymbol{\theta}}) - \psi_i(\mathbf{x}_n - \delta\mathbf{e}_i; {\boldsymbol{\theta}})}{2\delta} \Bigg).\]
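To make the finite-difference step concrete, here is a minimal sketch, not part of the original workflow, that compares a centered finite-difference divergence with the exact one. The vector field `psi`, the point `x0`, and the helper `div_fd` are hypothetical and chosen only for illustration.

```julia
# Hypothetical smooth vector field on R^2, used only for illustration
psi(x) = [sin(x[1]) * x[2], x[1]^2 + cos(x[2])]

# Exact divergence: d(psi_1)/dx_1 + d(psi_2)/dx_2
div_exact(x) = cos(x[1]) * x[2] - sin(x[2])

# Centered finite-difference approximation of the divergence with step delta
function div_fd(psi, x, delta)
    d = length(x)
    total = 0.0
    for i in 1:d
        e = zeros(d)
        e[i] = 1.0  # i-th canonical basis vector
        total += (psi(x .+ delta .* e)[i] - psi(x .- delta .* e)[i]) / (2 * delta)
    end
    return total
end

x0 = [0.7, -1.3]
delta = 1e-3
fd_error = abs(div_fd(psi, x0, delta) - div_exact(x0))
println("finite-difference error: ", fd_error)
```

The error of a centered difference is of second order in the step, so halving `delta` should reduce `fd_error` by roughly a factor of four, until round-off takes over.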
Since this is a synthetic problem and we actually know the target distribution, we implement the empirical explicit score matching loss function
\[ {\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) = \frac{1}{2}\frac{1}{N}\sum_{n=1}^N \|\boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}}) - \boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x}_n)\|^2.\]
This serves as a check on whether the neural network is expressive enough to model the score function, and on the optimization process, since in theory it should differ from ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}$ only by a constant, apart from the errors due to the empirical distribution, the finite-difference approximation, and round-off.
Implementation of ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}({\boldsymbol{\theta}})$
In the two-dimensional case, $d = 2$, this becomes
\[ \begin{align*} {\tilde J}_{\mathrm{FDSM}{\tilde p}_0} & = \frac{1}{N}\sum_{n=1}^N \sum_{i=1}^d \Bigg( \frac{1}{2}\left(\frac{1}{d}\sum_{j=1}^d \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_j; {\boldsymbol{\theta}}) + \psi_i(\mathbf{x}_n - \delta\mathbf{e}_j; {\boldsymbol{\theta}})}{2}\right)^2 \\ & \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad + \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_i; {\boldsymbol{\theta}}) - \psi_i(\mathbf{x}_n - \delta\mathbf{e}_i; {\boldsymbol{\theta}})}{2\delta} \Bigg) \\ & = \frac{1}{N}\sum_{n=1}^N \sum_{i=1}^2 \Bigg( \frac{1}{2}\left(\sum_{j=1}^2 \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_j; {\boldsymbol{\theta}}) + \psi_i(\mathbf{x}_n - \delta\mathbf{e}_j; {\boldsymbol{\theta}})}{4}\right)^2 \\ & \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad + \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_i; {\boldsymbol{\theta}}) - \psi_i(\mathbf{x}_n - \delta\mathbf{e}_i; {\boldsymbol{\theta}})}{2\delta} \Bigg) \\ & = \frac{1}{N} \frac{1}{2} \sum_{n=1}^N \sum_{i=1}^2 \left(\sum_{j=1}^2 \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_j; {\boldsymbol{\theta}}) + \psi_i(\mathbf{x}_n - \delta\mathbf{e}_j; {\boldsymbol{\theta}})}{4}\right)^2 \\ & \qquad \qquad \qquad \qquad \qquad \qquad + \frac{1}{N}\sum_{n=1}^N \frac{\psi_1(\mathbf{x}_n + \delta\mathbf{e}_1; {\boldsymbol{\theta}}) - \psi_1(\mathbf{x}_n - \delta\mathbf{e}_1; {\boldsymbol{\theta}})}{2\delta} \\ & \qquad \qquad \qquad \qquad \qquad \qquad + \frac{1}{N}\sum_{n=1}^N \frac{\psi_2(\mathbf{x}_n + \delta\mathbf{e}_2; {\boldsymbol{\theta}}) - \psi_2(\mathbf{x}_n - \delta\mathbf{e}_2; {\boldsymbol{\theta}})}{2\delta} \end{align*}\]
function loss_function(model, ps, st, data)
    sample_points, deltax, deltay = data
    # model evaluated at the four displaced points x ± deltax e₁ and x ± deltay e₂
    s_pred_fwd_x, _ = Lux.apply(model, sample_points .+ [deltax, 0.0], ps, st)
    s_pred_bwd_x, _ = Lux.apply(model, sample_points .- [deltax, 0.0], ps, st)
    s_pred_fwd_y, _ = Lux.apply(model, sample_points .+ [0.0, deltay], ps, st)
    s_pred_bwd_y, _ = Lux.apply(model, sample_points .- [0.0, deltay], ps, st)
    # average of the displaced evaluations, standing in for the model at the sample points
    s_pred = (s_pred_bwd_x .+ s_pred_fwd_x .+ s_pred_bwd_y .+ s_pred_fwd_y) ./ 4
    # centered finite differences in each direction
    dsdx_pred = (s_pred_fwd_x .- s_pred_bwd_x) ./ 2deltax
    dsdy_pred = (s_pred_fwd_y .- s_pred_bwd_y) ./ 2deltay
    # mean(abs2, s_pred) averages over all 2N entries, which accounts for the 1/2 factor when d = 2
    loss = mean(abs2, s_pred) + mean(view(dsdx_pred, 1, :)) + mean(view(dsdy_pred, 2, :))
    return loss, st, ()
end
loss_function (generic function with 1 method)
We include the steps for the finite-difference computations in the data passed to training, to avoid recomputing them at every iteration.
xmin, xmax = extrema(sample_points[1, :])
ymin, ymax = extrema(sample_points[2, :])
deltax, deltay = (xmax - xmin) / 2size(sample_points, 2), (ymax - ymin) / 2size(sample_points, 2)
(0.00581614558657555, 0.006467469210160699)
data = sample_points, deltax, deltay
([-3.696922040577957 2.556772674293886 … -3.033651509981303 5.1672947152544895; -3.6462237060140024 3.4102960875997117 … 0.6696459295015311 3.2155718822745527], 0.00581614558657555, 0.006467469210160699)
Implementation of ${\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}})$
As a sanity check, we also include the empirical explicit score matching loss function, which uses the known score function of the target model.
In the two-dimensional case, this is simply the mean square value of all the components.
\[ {\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) = \frac{1}{2} \frac{1}{N}\sum_{n=1}^N \|\boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}}) - \boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x}_n)\|^2 = \frac{1}{2} \frac{1}{N}\sum_{n=1}^N \sum_{i=1}^2 \left(\psi_i(\mathbf{x}_n; {\boldsymbol{\theta}}) - \psi_{\mathbf{X}, i}(\mathbf{x}_n) \right)^2.\]
function loss_function_cheat(model, ps, st, data)
    sample_points, score_cheat = data
    score_pred, st = Lux.apply(model, sample_points, ps, st)
    # mean over all 2N entries, i.e. (1/2)(1/N) times the sum of squared norms when d = 2
    loss = mean(abs2, score_pred .- score_cheat)
    return loss, st, ()
end
loss_function_cheat (generic function with 1 method)
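The fact that `mean(abs2, …)` already carries the factor $1/2$ is specific to $d = 2$. The following small sketch, with hypothetical $2 \times N$ arrays standing in for the predicted and the true scores, checks the identity numerically.

```julia
using Statistics  # mean

# Hypothetical 2×N arrays standing in for the predicted and the true scores
score_pred = [1.0 2.0 3.0; 0.5 -1.0 2.0]
score_true = [0.0 2.5 1.0; 0.5 1.0 -2.0]
N = size(score_pred, 2)

# Formula as written: (1/2) (1/N) times the sum over samples of the squared Euclidean norm
loss_formula = sum(abs2, score_pred .- score_true) / (2 * N)

# Shortcut used in the loss functions: mean over all 2N entries
loss_mean = mean(abs2, score_pred .- score_true)

println((loss_formula, loss_mean))
```

In dimension $d \neq 2$ the mean runs over $dN$ entries, so the shortcut would need an extra factor of $d/2$ to match the formula.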
The data in this case includes information about the target distribution.
score_cheat = reduce(hcat, gradlogpdf(target_prob, u) for u in eachcol(sample_points))
data_cheat = sample_points, score_cheat
([-3.696922040577957 2.556772674293886 … -3.033651509981303 5.1672947152544895; -3.6462237060140024 3.4102960875997117 … 0.6696459295015311 3.2155718822745527], [0.6969228899748048 0.44299201282532824 … 1.9946701191708671 -2.1672947253904584; 0.6462254048076983 -0.41041374404010544 … 0.2523912079807282 -0.2155718873425372])
Computing the constant
The expression ${\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) \approx {\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}}) + C$ can be used to test the implementation of the different loss functions. For that, we need to compute the constant $C$. This can be computed with a fine mesh or with a Monte-Carlo approximation. We do both just for fun.
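For reference, the constant can be read off by expanding the square in $J_{\mathrm{ESM}}$. The following is a sketch of that computation, assuming enough decay at infinity for the boundary terms in the integration by parts to vanish, and using $p_{\mathbf{X}}(\mathbf{x})\boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x}) = \boldsymbol{\nabla}_{\mathbf{x}} p_{\mathbf{X}}(\mathbf{x})$:

\[ \begin{align*} J_{\mathrm{ESM}}({\boldsymbol{\theta}}) & = \frac{1}{2}\int_{\mathbb{R}^d} p_{\mathbf{X}}(\mathbf{x}) \|\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}})\|^2\;\mathrm{d}\mathbf{x} - \int_{\mathbb{R}^d} \boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}}) \cdot \boldsymbol{\nabla}_{\mathbf{x}} p_{\mathbf{X}}(\mathbf{x})\;\mathrm{d}\mathbf{x} + \frac{1}{2}\int_{\mathbb{R}^d} p_{\mathbf{X}}(\mathbf{x}) \|\boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x})\|^2\;\mathrm{d}\mathbf{x} \\ & = \int_{\mathbb{R}^d} p_{\mathbf{X}}(\mathbf{x}) \left( \frac{1}{2}\|\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}})\|^2 + \boldsymbol{\nabla}_{\mathbf{x}} \cdot \boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}}) \right)\;\mathrm{d}\mathbf{x} + \frac{1}{2}\int_{\mathbb{R}^d} p_{\mathbf{X}}(\mathbf{x}) \|\boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x})\|^2\;\mathrm{d}\mathbf{x}, \end{align*}\]

so that

\[ C = \frac{1}{2}\int_{\mathbb{R}^d} p_{\mathbf{X}}(\mathbf{x}) \|\boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x})\|^2\;\mathrm{d}\mathbf{x} = \frac{1}{2}\mathbb{E}\left[\|\boldsymbol{\psi}_{\mathbf{X}}(\mathbf{X})\|^2\right].\]

The mesh version approximates the integral by a Riemann sum, while the Monte-Carlo version approximates the expectation by a sample mean.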
function compute_constante(target_prob, xrange, yrange)
dx = step(xrange)
dy = step(yrange)
Jconstant = sum(pdf(target_prob, [x, y]) * sum(abs2, gradlogpdf(target_prob, [x, y])) for y in yrange, x in xrange) * dx * dy / 2
return Jconstant
end
compute_constante (generic function with 1 method)
function compute_constante_MC(target_prob, sample_points)
Jconstant = mean(sum(abs2, gradlogpdf(target_prob, s)) for s in eachcol(sample_points)) / 2
return Jconstant
end
compute_constante_MC (generic function with 1 method)
Jconstant = compute_constante(target_prob, xrange, yrange)
0.8921462840364986
Jconstant_MC = compute_constante_MC(target_prob, sample_points)
0.9119907821889877
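As an independent sanity check of the Monte-Carlo estimate, one can try it on a case where the constant is known in closed form. The sketch below assumes a standard bivariate normal, which is not the target of this example: its score is $-\mathbf{x}$, so $C = \mathbb{E}[\|\mathbf{X}\|^2]/2 = d/2 = 1$. The names `rng_demo`, `samples`, and `C_mc` are hypothetical.

```julia
using Random, Statistics

rng_demo = Xoshiro(2024)                  # hypothetical seed, independent of the main example
samples = randn(rng_demo, 2, 100_000)     # draws from a standard bivariate normal

# The score of a standard normal is -x, so C = E[‖X‖²] / 2 = d / 2 = 1 for d = 2
C_mc = mean(sum(abs2, -x) for x in eachcol(samples)) / 2

println("Monte-Carlo estimate of C: ", C_mc)
```

The estimate should be close to 1, with a statistical error that decays like $1/\sqrt{N}$ in the sample size.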
constants = [(n, compute_constante_MC(target_prob, rand(rng, target_prob, n))) for _ in 1:100 for n in (1, 10, 20, 50, 100, 500, 1000, 2000, 4000)]
900-element Vector{Tuple{Int64, Float64}}:
(1, 1.1324719316165153)
(10, 0.5574422233620775)
(20, 1.044989829805957)
(50, 1.0978032717040624)
(100, 1.0246691007312956)
(500, 0.8688543773742903)
(1000, 0.8357253711152717)
(2000, 0.877040331952378)
(4000, 0.8823978968841848)
(1, 0.4579093796260303)
⋮
(1, 0.47998329055623734)
(10, 1.4521777722883298)
(20, 0.9039503903514523)
(50, 0.7744734260262902)
(100, 0.8494876063783782)
(500, 0.9380236094631614)
(1000, 0.9086250828841758)
(2000, 0.8938509703138009)
(4000, 0.8685813200425203)
scatter(constants, markersize=2, title="constant computed by MC and fine mesh", titlefont=10, xlabel="sample size", ylabel="value", label="via various samples")
hline!([Jconstant], label="via fine mesh")
hline!([Jconstant_MC], label="via working sample", linestyle=:dash)
A test for the implementations of the loss functions
Notice that, for a sufficiently large sample and sufficiently small discretization step $\delta$, we should have
\[ {\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) \approx J_{\mathrm{ESM}}({\boldsymbol{\theta}}) = J_{\mathrm{ISM}}({\boldsymbol{\theta}}) + C \approx {\tilde J}_{\mathrm{FDSM}}({\boldsymbol{\theta}}) + C \approx {\tilde J}_{\mathrm{FDSM}{\tilde p}_0}({\boldsymbol{\theta}}) + C,\]
which is a good test for the implementations of the loss functions. For example:
first(loss_function_cheat(model, ps, st, data_cheat))
2.0082112647167785
first(loss_function(model, ps, st, data)) + Jconstant
1.98284845662185
Let us do a more statistically significant test.
test_losses = reduce(
hcat,
Lux.setup(rng, model) |> pstj ->
[
first(loss_function_cheat(model, pstj[1], pstj[2], data_cheat)),
first(loss_function(model, pstj[1], pstj[2], data))
]
for _ in 1:30
)
2×30 Matrix{Float64}:
2.94316 3.8376 1.89918 2.11019 … 2.10937 2.01107 1.48007 2.47952
2.00378 2.82067 1.00093 1.20077 1.29642 1.03138 0.573855 1.53678
plot(title="Loss functions at random model parameters", titlefont=10)
scatter!(test_losses[1, :], label="\${\\tilde J}_{\\mathrm{ESM}{\\tilde p}_0}\$")
scatter!(test_losses[2, :], label="\${\\tilde J}_{\\mathrm{FDSM}{\\tilde p}_0}\$")
scatter!(test_losses[2, :] .+ Jconstant, label="\${\\tilde J}_{\\mathrm{FDSM}{\\tilde p}_0} + C\$")
One can check by visual inspection that the agreement between ${\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) - C$ and ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}({\boldsymbol{\theta}})$ seems reasonably good. Let us estimate the relative error.
rel_errors = abs.( ( test_losses[2, :] .+ Jconstant .- test_losses[1, :] ) ./ test_losses[1, :] )
plot(title="Relative error at random model parameters", titlefont=10, legend=false)
scatter!(rel_errors, markercolor=2, label="error")
mm = mean(rel_errors)
mmstd = std(rel_errors)
hline!([mm], label="mean")
hspan!([mm+mmstd, mm-mmstd], fillbetween=true, alpha=0.3, label="mean ± std")
Ok, good enough, just a few percentage points.
An extra test for the implementations of the loss functions and the gradient computation
We also have
\[ \boldsymbol{\nabla}_{\boldsymbol{\theta}} {\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) \approx \boldsymbol{\nabla}_{\boldsymbol{\theta}} {\tilde J}_{\mathrm{FDSM}{\tilde p}_0}({\boldsymbol{\theta}}),\]
which is another good test, since it also checks the gradient computation. Everything seems fine so far, though, so we do not pursue it here.
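For the curious, here is a minimal sketch of what such a gradient comparison looks like in a toy setting that can be worked out by hand. It is not the test for the neural network above: it assumes a linear model $\boldsymbol{\psi}(\mathbf{x}) = W\mathbf{x} + \mathbf{b}$ and a standard normal sample, for which the centered differences are exact and both gradients with respect to $W$ have closed forms. All names are hypothetical.

```julia
using Random, LinearAlgebra

rng_demo = Xoshiro(7)                  # hypothetical seed
N = 50_000
X = randn(rng_demo, 2, N)              # standard normal sample; its score is -x

W = [0.3 -0.2; 0.1 0.4]                # hypothetical parameters of psi(x) = W x + b
b = [0.05, -0.1]

# Gradient with respect to W of the empirical ESM loss (1/2N) Σ ‖W xₙ + b + xₙ‖²
grad_esm = (W * X .+ b .+ X) * X' ./ N

# For a linear model the centered differences are exact, so the empirical FDSM loss
# reduces to (1/2N) Σ ‖W xₙ + b‖² + tr(W), whose gradient with respect to W is
grad_fdsm = (W * X .+ b) * X' ./ N + I

max_gap = maximum(abs, grad_esm .- grad_fdsm)
println("largest entrywise gap between the gradients: ", max_gap)
```

In this toy case the gap is exactly the difference between the sample second-moment matrix and the identity, so it shrinks as the sample grows.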
Optimization setup
Optimization method
As usual, we use the Adam optimizer.
opt = Adam(0.003)
tstate_org = Lux.Training.TrainState(rng, model, opt)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}((layer_1 = Dense(2 => 16, relu), layer_2 = Dense(16 => 2)), nothing)
# of parameters: 82
# of states: 0
optimizer: Adam(0.003, (0.9, 0.999), 1.0e-8)
step: 0
Automatic differentiation in the optimization
FluxML/Zygote.jl is used for the automatic differentiation as it is currently the only AD backend working with LuxDL/Lux.jl.
vjp_rule = Lux.Training.AutoZygote()
ADTypes.AutoZygote()
Processor
We use the CPU instead of the GPU.
dev_cpu = cpu_device()
## dev_gpu = gpu_device()
(::LuxDeviceUtils.LuxCPUDevice) (generic function with 5 methods)
Check differentiation
Check if AD is working fine to differentiate the loss functions for training.
Lux.Training.compute_gradients(vjp_rule, loss_function, data, tstate_org)
((layer_1 = (weight = Float32[-0.28156662 -0.33435822; 0.3774376 0.010066986; … ; 0.13616276 0.1693325; -0.09726906 -0.0073611736], bias = Float32[0.034161568; 0.12943935; … ; -0.042779922; -0.05421412;;]), layer_2 = (weight = Float32[-0.30857086 0.06406593 … -0.059645653 0.13012981; -0.18228635 -0.12837453 … 0.009949422 -0.068630844], bias = Float32[0.31969452; 0.1720464;;])), 0.2138332083221143, (), Lux.Training.TrainState{Nothing, Nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Optimisers.Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}}}(nothing, nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}((layer_1 = Dense(2 => 16, relu), layer_2 = Dense(16 => 2)), nothing), (layer_1 = (weight = Float32[-0.38554642 -0.20945442; 0.2078916 -0.3777418; … ; -0.101795964 0.048385024; 0.2549977 -0.3194088], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = 
(weight = Float32[-0.27677727 0.38801375 … 0.14126216 0.043854997; -0.5111967 -0.38798138 … 0.2772176 0.4102923], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), Adam(0.003, (0.9, 0.999), 1.0e-8), (layer_1 = (weight = Leaf(Adam(0.003, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.003, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam(0.003, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.003, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0;;], Float32[0.0; 0.0;;], (0.9, 0.999))))), 0))
Lux.Training.compute_gradients(vjp_rule, loss_function_cheat, data_cheat, tstate_org)
((layer_1 = (weight = Float32[-0.3613443 -0.35442525; 0.1799533 -0.11129212; … ; 0.19689769 0.2527126; 0.33015817 0.35009977], bias = Float32[0.09297755; 0.0841216; … ; -0.036208898; -0.03519474;;]), layer_2 = (weight = Float32[-0.33517796 0.0561612 … -0.061413787 0.13151819; -0.23627168 -0.14861318 … 0.0031006688 -0.081412934], bias = Float32[0.31775895; 0.13971485;;])), 1.1493356230477758, (), Lux.Training.TrainState{Nothing, Nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Optimisers.Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}}}(nothing, nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}((layer_1 = Dense(2 => 16, relu), layer_2 = Dense(16 => 2)), nothing), (layer_1 = (weight = Float32[-0.38554642 -0.20945442; 0.2078916 -0.3777418; … ; -0.101795964 0.048385024; 0.2549977 -0.3194088], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = 
Float32[-0.27677727 0.38801375 … 0.14126216 0.043854997; -0.5111967 -0.38798138 … 0.2772176 0.4102923], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), Adam(0.003, (0.9, 0.999), 1.0e-8), (layer_1 = (weight = Leaf(Adam(0.003, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.003, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam(0.003, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.003, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0;;], Float32[0.0; 0.0;;], (0.9, 0.999))))), 0))
Training loop
Here is the typical main training loop suggested in the LuxDL/Lux.jl tutorials, slightly modified to save the history of losses per iteration and the model state for animation.
function train(tstate, vjp, data, loss_function, epochs, numshowepochs=20, numsavestates=0)
    losses = zeros(epochs)
    tstates = [(0, tstate)]
    for epoch in 1:epochs
        grads, loss, stats, tstate = Lux.Training.compute_gradients(vjp,
            loss_function, data, tstate)
        # print the loss at regularly spaced epochs
        if ( epochs ≥ numshowepochs > 0 ) && rem(epoch, div(epochs, numshowepochs)) == 0
            println("Epoch: $(epoch) || Loss: $(loss)")
        end
        # save the training state at regularly spaced epochs
        if ( epochs ≥ numsavestates > 0 ) && rem(epoch, div(epochs, numsavestates)) == 0
            push!(tstates, (epoch, tstate))
        end
        losses[epoch] = loss
        tstate = Lux.Training.apply_gradients(tstate, grads)
    end
    return tstate, losses, tstates
end
train (generic function with 3 methods)
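The conditions that decide when to print and when to save a state can be inspected in isolation. The sketch below copies the condition into a hypothetical helper, `is_marked`, and lists the epochs it selects for the values used in the training calls that follow (2000 epochs, 20 printed losses, 100 saved states).

```julia
# Same condition as in the training loop, isolated in a hypothetical helper
is_marked(epoch, epochs, nummarks) =
    (epochs ≥ nummarks > 0) && rem(epoch, div(epochs, nummarks)) == 0

shown = [epoch for epoch in 1:2000 if is_marked(epoch, 2000, 20)]    # epochs whose loss is printed
saved = [epoch for epoch in 1:2000 if is_marked(epoch, 2000, 100)]   # epochs whose state is saved

# When epochs is not a multiple of the requested count, more epochs than requested are selected
uneven = count(epoch -> is_marked(epoch, 50, 20), 1:50)

println((length(shown), length(saved), uneven))
```

With a count of zero the first test short-circuits, so nothing is selected and no division by zero occurs.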
Cheat training with ${\tilde J}_{\mathrm{ESM}{\tilde p}_0}$
We first train the model with the known score function on the sample data. That is cheating. The aim is a sanity check, to make sure the proposed model is good enough to fit the desired score function and that the setup is right.
@time tstate_cheat, losses_cheat, tstates_cheat = train(tstate_org, vjp_rule, data_cheat, loss_function_cheat, 2000, 20, 100)
Epoch: 100 || Loss: 0.3806969302460832
Epoch: 200 || Loss: 0.31455506688803014
Epoch: 300 || Loss: 0.23650958560221114
Epoch: 400 || Loss: 0.17584226483150162
Epoch: 500 || Loss: 0.13336146141196442
Epoch: 600 || Loss: 0.10494339476099604
Epoch: 700 || Loss: 0.08319613277988377
Epoch: 800 || Loss: 0.06353660069610059
Epoch: 900 || Loss: 0.04579514452999741
Epoch: 1000 || Loss: 0.033082546934338375
Epoch: 1100 || Loss: 0.024682957144383684
Epoch: 1200 || Loss: 0.018514478744863578
Epoch: 1300 || Loss: 0.014112584651106378
Epoch: 1400 || Loss: 0.011000791201833029
Epoch: 1500 || Loss: 0.008546208613690945
Epoch: 1600 || Loss: 0.00662317281685257
Epoch: 1700 || Loss: 0.005286038656992195
Epoch: 1800 || Loss: 0.004320141699648445
Epoch: 1900 || Loss: 0.003598532243191301
Epoch: 2000 || Loss: 0.0030705347715038554
0.258116 seconds (507.71 k allocations: 938.133 MiB, 16.82% gc time, 12.62% compilation time)
Testing out the trained model.
uu_cheat = Lux.apply(tstate_cheat.model, vcat(xx', yy'), tstate_cheat.parameters, tstate_cheat.states)[1]
2×225 Matrix{Float64}:
4.96955 3.89614 2.82273 1.74932 … -1.88924 -2.96167 -4.03409
4.93425 4.9404 4.94656 4.95271 -4.05192 -4.04988 -4.04783
heatmap(xrange, yrange, (x, y) -> logpdf(target_prob, [x, y]), title="Logpdf (heatmap) and score functions (vector fields)", titlefont=10, color=:vik, xlims=extrema(xrange), ylims=extrema(yrange), legend=false)
quiver!(xx, yy, quiver = (uu[1, :] ./ 8, uu[2, :] ./ 8), color=:yellow, alpha=0.5)
scatter!(sample_points[1, :], sample_points[2, :], markersize=2, markercolor=:lightgreen, alpha=0.5)
quiver!(xx, yy, quiver = (uu_cheat[1, :] ./ 8, uu_cheat[2, :] ./ 8), color=:cyan, alpha=0.5)
plot(losses_cheat, title="Evolution of the loss", titlefont=10, xlabel="iteration", ylabel="error", legend=false)
Real training with ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}$
Now we go to the real thing.
@time tstate, losses, tstates = train(tstate_org, vjp_rule, data, loss_function, 2000, 20, 100)
Epoch: 100 || Loss: -0.4847498428836804
Epoch: 200 || Loss: -0.515890134342076
Epoch: 300 || Loss: -0.5492928868336563
Epoch: 400 || Loss: -0.6022552318483821
Epoch: 500 || Loss: -0.6650272062116336
Epoch: 600 || Loss: -0.7145443843008135
Epoch: 700 || Loss: -0.7517938448317422
Epoch: 800 || Loss: -0.7799219290676855
Epoch: 900 || Loss: -0.8007178895698204
Epoch: 1000 || Loss: -0.8128233755305586
Epoch: 1100 || Loss: -0.8231598096900492
Epoch: 1200 || Loss: -0.8301576644434572
Epoch: 1300 || Loss: -0.8359090833789496
Epoch: 1400 || Loss: -0.8427182828751276
Epoch: 1500 || Loss: -0.8484728540837287
Epoch: 1600 || Loss: -0.8518516646003788
Epoch: 1700 || Loss: -0.8568799251405146
Epoch: 1800 || Loss: -0.8609811228233294
Epoch: 1900 || Loss: -0.862762600689807
Epoch: 2000 || Loss: -0.8652882390883634
0.914693 seconds (921.70 k allocations: 4.028 GiB, 5.63% gc time, 3.71% compilation time)
Testing out the trained model.
uu_pred = Lux.apply(tstate.model, vcat(xx', yy'), tstate.parameters, tstate.states)[1]
2×225 Matrix{Float64}:
4.52345 3.32892 2.24847 1.87043 0.96518 … -1.98797 -2.86147 -3.64632
3.99869 4.10775 4.1749 3.98047 3.82042 -4.14843 -4.14594 -4.17602
heatmap(xrange, yrange, (x, y) -> logpdf(target_prob, [x, y]), title="Logpdf (heatmap) and score functions (vector fields)", titlefont=10, color=:vik, xlims=extrema(xrange), ylims=extrema(yrange), legend=false)
quiver!(xx, yy, quiver = (uu[1, :] ./ 8, uu[2, :] ./ 8), color=:yellow, alpha=0.5)
scatter!(sample_points[1, :], sample_points[2, :], markersize=2, markercolor=:lightgreen, alpha=0.5)
quiver!(xx, yy, quiver = (uu_pred[1, :] ./ 8, uu_pred[2, :] ./ 8), color=:cyan, alpha=0.5)
plot(losses, title="Evolution of the losses", titlefont=10, xlabel="iteration", ylabel="error", label="\${\\tilde J}_{\\mathrm{FDSM}{\\tilde p}_0}\$")
plot!(losses_cheat, linestyle=:dash, label="\${\\tilde J}_{\\mathrm{ESM}{\\tilde p}_0}\$")
plot!(losses .+ Jconstant, linestyle=:dash, color=1, label="\${\\tilde J}_{\\mathrm{FDSM}{\\tilde p}_0} + C\$")
Ok, that seems visually good enough. We will later check the sampling from this score function via Langevin sampling.