Finite-difference score-matching of a two-dimensional Gaussian mixture model

Introduction

Here, we modify the previous finite-difference score-matching example to fit a two-dimensional model.

Julia language setup

We use the Julia programming language with suitable packages.

using StatsPlots
using Random
using Distributions
using Lux # artificial neural networks explicitly parametrized
using Optimisers
using Zygote # automatic differentiation

We set the random seed for reproducibility.

rng = Xoshiro(12345)

Data

We build the target model and draw samples from it. This time, the target model is a bivariate distribution: a mixture of three Gaussians with identity covariance matrices and weights 0.4, 0.4, and 0.2.

xrange = range(-8, 8, 120)
yrange = range(-8, 8, 120)
dx = Float64(xrange.step)
dy = Float64(yrange.step)

target_prob = MixtureModel([MvNormal([-3, -3], [1 0; 0 1]), MvNormal([3, 3], [1 0; 0 1]), MvNormal([-1, 1], [1 0; 0 1])], [0.4, 0.4, 0.2])

target_pdf = [pdf(target_prob, [x, y]) for y in yrange, x in xrange]
target_score = reduce(hcat, gradlogpdf(target_prob, [x, y]) for y in yrange, x in xrange)
2×14400 Matrix{Float64}:
 5.0  5.0      5.0      5.0      …  -5.0      -5.0      -5.0      -5.0
 5.0  4.86555  4.73109  4.59664     -4.59664  -4.73109  -4.86555  -5.0
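The function gradlogpdf returns the score $\boldsymbol{\psi}_{\mathbf{X}} = \boldsymbol{\nabla}_{\mathbf{x}}\log p_{\mathbf{X}}$. All three components have identity covariance, so the score has a simple closed form, $\boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x}) = \sum_k r_k(\mathbf{x})(\boldsymbol{\mu}_k - \mathbf{x})$, where $r_k(\mathbf{x})$ are the posterior responsibilities of the components.

The sketch below implements this formula in plain Julia. The helper mixture_score is hypothetical and is not part of Distributions.jl. It should reproduce, for instance, the first column of target_score at $(-8, -8)$, where the component centered at $(-3, -3)$ dominates.

```julia
# Hypothetical helper (not from the original): closed-form score of a Gaussian
# mixture whose components all have identity covariance,
#   ∇log p(x) = Σ_k r_k(x) (μ_k - x),   r_k(x) = w_k N(x; μ_k, I) / p(x).
function mixture_score(x, mus, weights)
    # log of the unnormalized responsibilities; the common (2π)^(-d/2) factor cancels
    logr = [log(w) - sum(abs2, x .- mu) / 2 for (mu, w) in zip(mus, weights)]
    r = exp.(logr .- maximum(logr))  # subtract the maximum for numerical stability
    r ./= sum(r)                     # normalize the responsibilities
    return sum(rk .* (mu .- x) for (rk, mu) in zip(r, mus))
end

mus = [[-3.0, -3.0], [3.0, 3.0], [-1.0, 1.0]]
weights = [0.4, 0.4, 0.2]

mixture_score([-8.0, -8.0], mus, weights)  # ≈ [5.0, 5.0], as in the first column of target_score
```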
sample_points = rand(rng, target_prob, 1024)
2×1024 Matrix{Float64}:
 -3.69692  2.55677  -0.317683  -1.16593   …  2.92216  -3.03365   5.16729
 -3.64622  3.4103    1.34462   -0.286077     2.96657   0.669646  3.21557
surface(xrange, yrange, target_pdf, title="PDF", titlefont=10, legend=false, color=:vik)
scatter!(sample_points[1, :], sample_points[2, :], [pdf(target_prob, [x, y]) for (x, y) in eachcol(sample_points)], markercolor=:lightgreen, markersize=2, alpha=0.5)
[Figure: surface plot of the PDF, with the sample points]
heatmap(xrange, yrange, target_pdf, title="PDF", titlefont=10, legend=false, color=:vik)
scatter!(sample_points[1, :], sample_points[2, :], markersize=2, markercolor=:lightgreen, alpha=0.5)
[Figure: heatmap of the PDF, with the sample points]
surface(xrange, yrange, (x, y) -> logpdf(target_prob, [x, y]), title="Logpdf", titlefont=10, legend=false, color=:vik)
scatter!(sample_points[1, :], sample_points[2, :], [logpdf(target_prob, [x, y]) for (x, y) in eachcol(sample_points)], markercolor=:lightgreen, alpha=0.5, markersize=2)
[Figure: surface plot of the logpdf, with the sample points]
meshgrid(x, y) = (repeat(x, outer=length(y)), repeat(y, inner=length(x)))
xx, yy = meshgrid(xrange[begin:8:end], yrange[begin:8:end])
uu = reduce(hcat, gradlogpdf(target_prob, [x, y]) for (x, y) in zip(xx, yy))
2×225 Matrix{Float64}:
 5.0  3.92437  2.84874  1.77311  0.697479  …  -1.90756  -2.98319  -4.05882
 5.0  5.0      5.0      5.0      5.0          -4.05882  -4.05882  -4.05882
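To see how the meshgrid helper orders the grid points, here is the same one-liner applied to two short ranges. The x coordinate varies fastest.

```julia
# Same one-liner as in the text: x varies fastest, y slowest.
meshgrid(x, y) = (repeat(x, outer=length(y)), repeat(y, inner=length(x)))

xx, yy = meshgrid(1:3, 10:10:20)
collect(zip(xx, yy))
# [(1, 10), (2, 10), (3, 10), (1, 20), (2, 20), (3, 20)]
```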
heatmap(xrange, yrange, (x, y) -> logpdf(target_prob, [x, y]), title="Logpdf (heatmap) and score function (vector field)", titlefont=10, legend=false, color=:vik)
quiver!(xx, yy, quiver = (uu[1, :] ./ 8, uu[2, :] ./ 8), color=:yellow, alpha=0.5)
scatter!(sample_points[1, :], sample_points[2, :], markersize=2, markercolor=:lightgreen, alpha=0.5)
[Figure: Logpdf (heatmap) and score function (vector field), with the sample points]

The neural network model

The neural network we consider is again a simple feed-forward neural network with a single hidden layer. For the two-dimensional case, we make it a little larger by doubling the width of the hidden layer.

model = Chain(Dense(2 => 16, relu), Dense(16 => 2))
Chain(
    layer_1 = Dense(2 => 16, relu),     # 48 parameters
    layer_2 = Dense(16 => 2),           # 34 parameters
)         # Total: 82 parameters,
          #        plus 0 states.
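The parameter counts in the summary follow from each Dense(in => out) layer having in × out weights plus out biases. Here is a quick check using a hypothetical helper, dense_params:

```julia
# A dense layer nin => nout with bias has nin*nout weights plus nout biases.
dense_params(nin, nout) = nin * nout + nout

layer_1 = dense_params(2, 16)   # 48
layer_2 = dense_params(16, 2)   # 34
layer_1 + layer_2               # 82
```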

The LuxDL/Lux.jl package uses explicit parameters, which are initialized with the Lux.setup function. It returns both the parameters and the state of the model.

ps, st = Lux.setup(rng, model) # initialize and get the parameters and states of the model
((layer_1 = (weight = Float32[-1.2633736 1.8181845; 2.3867178 0.5713208; … ; -0.7349689 -1.3411301; -1.167693 -1.8823313], bias = Float32[0.48686042, 0.32835403, 0.23500215, -0.6289243, -0.4877348, -0.08177762, 0.57667726, -0.06088416, 0.20568033, -0.66487825, 0.17736456, -0.20175992, 0.4999213, -0.5164128, 0.5682744, 0.33539677]), layer_2 = (weight = Float32[0.3586208 -0.057023764 … -0.17773739 -0.41380876; 0.12927775 -0.22223856 … 0.06703623 0.18231718], bias = Float32[0.067408144, 0.110440105])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))

Loss functions for score-matching

The loss function is again based on Aapo Hyvärinen (2005). It is combined with the work of Pang, Xu, Li, Song, Ermon, and Zhu (2020), which uses finite differences to approximate the divergence of the modeled score function.

In the multidimensional case, say on $\mathbb{R}^d$, $d\in\mathbb{N}$, the explicit score matching loss function is given by

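\[ J_{\mathrm{ESM}}({\boldsymbol{\theta}}) = \frac{1}{2}\int_{\mathbb{R}^d} p_{\mathbf{X}}(\mathbf{x}) \|\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}}) - \boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x})\|^2\;\mathrm{d}\mathbf{x},\]

where:

- $p_{\mathbf{X}}(\mathbf{x})$ is the PDF of the target distribution,
- $\boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x}) = \boldsymbol{\nabla}_{\mathbf{x}} \log p_{\mathbf{X}}(\mathbf{x})$ is its score function, and
- $\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}})$ is the modeled score function with parameters ${\boldsymbol{\theta}}$.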

Integrating by parts (assuming the boundary terms vanish) yields $J_{\mathrm{ESM}}({\boldsymbol{\theta}}) = J_{\mathrm{ISM}}({\boldsymbol{\theta}}) + C$. Here $C$ is constant with respect to the parameters, and the implicit score matching loss function $J_{\mathrm{ISM}}({\boldsymbol{\theta}})$ is given by

\[ J_{\mathrm{ISM}}({\boldsymbol{\theta}}) = \int_{\mathbb{R}^d} p_{\mathbf{X}}(\mathbf{x}) \left( \frac{1}{2}\|\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}})\|^2 + \boldsymbol{\nabla}_{\mathbf{x}} \cdot \boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}}) \right)\;\mathrm{d}\mathbf{x},\]

which does not involve the unknown score function of ${\mathbf{X}}$. It does, however, involve the divergence of the modeled score function, which is expensive to compute.
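For completeness, here is a sketch of that integration by parts. It assumes that $p_{\mathbf{X}}(\mathbf{x})\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}})$ decays fast enough as $\|\mathbf{x}\| \to \infty$ for the boundary terms to vanish. Expand the square in $J_{\mathrm{ESM}}$ and use $p_{\mathbf{X}}\boldsymbol{\psi}_{\mathbf{X}} = p_{\mathbf{X}}\boldsymbol{\nabla}_{\mathbf{x}}\log p_{\mathbf{X}} = \boldsymbol{\nabla}_{\mathbf{x}} p_{\mathbf{X}}$ (arguments omitted):

\[ \begin{align*} J_{\mathrm{ESM}}({\boldsymbol{\theta}}) & = \int_{\mathbb{R}^d} p_{\mathbf{X}} \frac{1}{2}\|\boldsymbol{\psi}\|^2\;\mathrm{d}\mathbf{x} - \int_{\mathbb{R}^d} \boldsymbol{\psi} \cdot \boldsymbol{\nabla}_{\mathbf{x}} p_{\mathbf{X}}\;\mathrm{d}\mathbf{x} + \frac{1}{2}\int_{\mathbb{R}^d} p_{\mathbf{X}} \|\boldsymbol{\psi}_{\mathbf{X}}\|^2\;\mathrm{d}\mathbf{x} \\ & = \int_{\mathbb{R}^d} p_{\mathbf{X}} \left( \frac{1}{2}\|\boldsymbol{\psi}\|^2 + \boldsymbol{\nabla}_{\mathbf{x}} \cdot \boldsymbol{\psi} \right)\;\mathrm{d}\mathbf{x} + C = J_{\mathrm{ISM}}({\boldsymbol{\theta}}) + C, \end{align*}\]

with $C = \frac{1}{2}\int_{\mathbb{R}^d} p_{\mathbf{X}}(\mathbf{x}) \|\boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x})\|^2\;\mathrm{d}\mathbf{x}$, which does not depend on ${\boldsymbol{\theta}}$.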

In practice, the loss function is estimated via the empirical distribution, so the unknown $p_{\mathbf{X}}(\mathbf{x})$ is handled implicitly by the sample data $(\mathbf{x}_n)_n$, and we minimize the empirical implicit score matching loss function

\[ {\tilde J}_{\mathrm{ISM}{\tilde p}_0} = \frac{1}{N}\sum_{n=1}^N \left( \frac{1}{2}\|\boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}})\|^2 + \boldsymbol{\nabla}_{\mathbf{x}} \cdot \boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}}) \right).\]

Componentwise, with $\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}}) = (\psi_i(\mathbf{x}; {\boldsymbol{\theta}}))_{i=1}^d$, this is written as

\[ {\tilde J}_{\mathrm{ISM}{\tilde p}_0} = \frac{1}{N}\sum_{n=1}^N \sum_{i=1}^d \left( \frac{1}{2}\psi_i(\mathbf{x}_n; {\boldsymbol{\theta}})^2 + \frac{\partial}{\partial x_i} \psi_i(\mathbf{x}_n; {\boldsymbol{\theta}}) \right).\]

As mentioned before, computing a derivative inside the loss function becomes expensive when combined with the usual gradient-based methods for fitting a neural network, since these also require the gradient of the loss function itself. Hence, we approximate the derivative of the modeled score function by centered finite differences.

Since the model is then evaluated at the displaced points anyway, we replace its value at the sample point itself by the average of these displaced values. This avoids an extra model evaluation. Denoting by $\mathbf{e}_j$ the $j$-th vector of the canonical basis of $\mathbb{R}^d$ and by $\delta > 0$ a small step, this leads to the empirical finite-difference (implicit) score matching loss function

\[ \begin{align*} {\tilde J}_{\mathrm{FDSM}{\tilde p}_0} = \frac{1}{N}\sum_{n=1}^N \sum_{i=1}^d \Bigg( & \frac{1}{2}\left(\frac{1}{d}\sum_{j=1}^d \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_j; {\boldsymbol{\theta}}) + \psi_i(\mathbf{x}_n - \delta\mathbf{e}_j; {\boldsymbol{\theta}})}{2}\right)^2 \\ & + \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_i; {\boldsymbol{\theta}}) - \psi_i(\mathbf{x}_n - \delta\mathbf{e}_i; {\boldsymbol{\theta}})}{2\delta} \Bigg). \end{align*}\]
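The two approximations in ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}$ are the centered differences for the divergence and the average of the $2d$ displaced evaluations in place of $\boldsymbol{\psi}(\mathbf{x}_n)$. Both have errors of order $\delta^2$.

A small sketch illustrates this on a hypothetical test field with a known divergence. The field ψ, the point x0, and the helper fd_divergence_and_average are illustrative choices, not part of the original.

```julia
using LinearAlgebra

# Hypothetical test field ψ: ℝ² → ℝ² with known divergence cos(x₁)x₂ - sin(x₂).
ψ(x) = [sin(x[1]) * x[2], x[1]^2 + cos(x[2])]
div_exact(x) = cos(x[1]) * x[2] - sin(x[2])

# Centered finite differences as in the FDSM loss: one ± pair per direction e_j.
function fd_divergence_and_average(f, x, δ)
    d = length(x)
    E = Matrix{Float64}(I, d, d)        # columns are the unit vectors e_j
    fwd = [f(x .+ δ .* E[:, j]) for j in 1:d]
    bwd = [f(x .- δ .* E[:, j]) for j in 1:d]
    # only the i-th component displaced along e_i enters the divergence
    divergence = sum((fwd[i][i] - bwd[i][i]) / (2δ) for i in 1:d)
    # average of the 2d displaced evaluations, used in place of f(x)
    average = sum(fwd .+ bwd) ./ (2d)
    return divergence, average
end

x0 = [0.7, -1.2]
divergence, average = fd_divergence_and_average(ψ, x0, 1e-3)
abs(divergence - div_exact(x0)), norm(average - ψ(x0))   # both small, of order δ²
```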

Since this is a synthetic problem and we actually know the target distribution, we implement the empirical explicit score matching loss function

\[ {\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) = \frac{1}{2}\frac{1}{N}\sum_{n=1}^N \|\boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}}) - \boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x}_n)\|^2.\]

This serves as a sanity check on two things: whether the neural network is expressive enough to model the score function, and whether the optimization process works. In theory, this loss should differ from ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}$ by roughly a constant, apart from the errors due to the empirical distribution, the finite-difference approximation, and round-off.
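The constant offset between the explicit and implicit losses can be checked numerically in a case where everything is known in closed form. This sketch uses a hypothetical setup, not the mixture of the text:

- The target is a standard bivariate normal, so $\boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x}) = -\mathbf{x}$ and $C = d/2 = 1$.
- The model is linear, $\boldsymbol{\psi}(\mathbf{x}; A, \mathbf{b}) = A\mathbf{x} + \mathbf{b}$, so that $\boldsymbol{\nabla}_{\mathbf{x}}\cdot\boldsymbol{\psi} = \operatorname{tr} A$ exactly.

```julia
using Random, Statistics, LinearAlgebra

# Hypothetical setting: target X ~ N(0, I₂), whose score is ψ_X(x) = -x, so that
# C = E‖ψ_X(X)‖² / 2 = d / 2 = 1; model ψ(x; A, b) = A x + b with ∇⋅ψ = tr(A).
rng = Xoshiro(42)
N = 200_000
X = randn(rng, 2, N)               # samples as columns, as in the text
A = [0.3 -0.5; 0.2 -1.1]
b = [0.4, -0.1]

Ψ = A * X .+ b                      # model score at every sample
J_esm = mean(sum(abs2, Ψ .+ X; dims=1)) / 2      # ψ - ψ_X = A x + b + x
J_ism = mean(sum(abs2, Ψ; dims=1)) / 2 + tr(A)

J_esm - J_ism                       # ≈ 1 = d/2, up to Monte Carlo error
```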

Implementation of ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}({\boldsymbol{\theta}})$

In the two-dimensional case, $d = 2$, this becomes

\[ \begin{align*} {\tilde J}_{\mathrm{FDSM}{\tilde p}_0} & = \frac{1}{N}\sum_{n=1}^N \sum_{i=1}^d \Bigg( \frac{1}{2}\left(\frac{1}{d}\sum_{j=1}^d \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_j; {\boldsymbol{\theta}}) + \psi_i(\mathbf{x}_n - \delta\mathbf{e}_j; {\boldsymbol{\theta}})}{2}\right)^2 \\ & \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad + \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_i; {\boldsymbol{\theta}}) - \psi_i(\mathbf{x}_n - \delta\mathbf{e}_i; {\boldsymbol{\theta}})}{2\delta} \Bigg) \\ & = \frac{1}{N}\sum_{n=1}^N \sum_{i=1}^2 \Bigg( \frac{1}{2}\left(\sum_{j=1}^2 \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_j; {\boldsymbol{\theta}}) + \psi_i(\mathbf{x}_n - \delta\mathbf{e}_j; {\boldsymbol{\theta}})}{4}\right)^2 \\ & \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad + \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_i; {\boldsymbol{\theta}}) - \psi_i(\mathbf{x}_n - \delta\mathbf{e}_i; {\boldsymbol{\theta}})}{2\delta} \Bigg) \\ & = \frac{1}{N} \frac{1}{2} \sum_{n=1}^N \sum_{i=1}^2 \left(\sum_{j=1}^2 \frac{\psi_i(\mathbf{x}_n + \delta\mathbf{e}_j; {\boldsymbol{\theta}}) + \psi_i(\mathbf{x}_n - \delta\mathbf{e}_j; {\boldsymbol{\theta}})}{4}\right)^2 \\ & \qquad \qquad \qquad \qquad \qquad \qquad + \frac{1}{N}\sum_{n=1}^N \frac{\psi_1(\mathbf{x}_n + \delta\mathbf{e}_1; {\boldsymbol{\theta}}) - \psi_1(\mathbf{x}_n - \delta\mathbf{e}_1; {\boldsymbol{\theta}})}{2\delta} \\ & \qquad \qquad \qquad \qquad \qquad \qquad + \frac{1}{N}\sum_{n=1}^N \frac{\psi_2(\mathbf{x}_n + \delta\mathbf{e}_2; {\boldsymbol{\theta}}) - \psi_2(\mathbf{x}_n - \delta\mathbf{e}_2; {\boldsymbol{\theta}})}{2\delta} \end{align*}\]

function loss_function(model, ps, st, data)
    sample_points, deltax, deltay = data
    # model evaluated at the points displaced by ±δx e₁ and ±δy e₂
    s_pred_fwd_x, = Lux.apply(model, sample_points .+ [deltax, 0.0], ps, st)
    s_pred_bwd_x, = Lux.apply(model, sample_points .- [deltax, 0.0], ps, st)
    s_pred_fwd_y, = Lux.apply(model, sample_points .+ [0.0, deltay], ps, st)
    s_pred_bwd_y, = Lux.apply(model, sample_points .- [0.0, deltay], ps, st)
    # average of the four displaced evaluations, used in place of ψ(xₙ; θ)
    s_pred = ( s_pred_bwd_x .+ s_pred_fwd_x .+ s_pred_bwd_y .+ s_pred_fwd_y) ./ 4
    # centered finite differences along each coordinate direction
    dsdx_pred = (s_pred_fwd_x .- s_pred_bwd_x ) ./ 2deltax
    dsdy_pred = (s_pred_fwd_y .- s_pred_bwd_y ) ./ 2deltay
    # the mean over the 2×N entries provides the factor 1/2;
    # only ∂ψ₁/∂x₁ and ∂ψ₂/∂x₂ enter the divergence
    loss = mean(abs2, s_pred) + mean(view(dsdx_pred, 1, :)) +  mean(view(dsdy_pred, 2, :))
    return loss, st, ()
end
loss_function (generic function with 1 method)
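To illustrate the structure of loss_function without Lux, here is a sketch with a hypothetical linear score model in place of the neural network. For a linear model, two things hold exactly:

- the average of the displaced evaluations equals $\boldsymbol{\psi}(\mathbf{x}_n)$, and
- the centered differences are exact.

So the finite-difference loss should coincide with the empirical implicit loss up to round-off. The names linear_score, fdsm_loss, and ism_loss are illustrative only.

```julia
using Statistics, LinearAlgebra, Random

# Hypothetical stand-in for the Lux model: a linear score model acting on the
# columns of a 2×N matrix, ψ(x; A, b) = A x + b.
linear_score(X, A, b) = A * X .+ b

# Same structure as `loss_function` above, with the model call made explicit.
function fdsm_loss(score, X, δx, δy)
    fwd_x = score(X .+ [δx, 0.0]); bwd_x = score(X .- [δx, 0.0])
    fwd_y = score(X .+ [0.0, δy]); bwd_y = score(X .- [0.0, δy])
    s = (fwd_x .+ bwd_x .+ fwd_y .+ bwd_y) ./ 4         # replaces ψ(xₙ)
    dψ1dx = (fwd_x[1, :] .- bwd_x[1, :]) ./ (2δx)        # ∂ψ₁/∂x₁
    dψ2dy = (fwd_y[2, :] .- bwd_y[2, :]) ./ (2δy)        # ∂ψ₂/∂x₂
    return mean(abs2, s) + mean(dψ1dx) + mean(dψ2dy)
end

# Empirical implicit score matching loss for the linear model (∇⋅ψ = tr A)
ism_loss(X, A, b) = mean(sum(abs2, linear_score(X, A, b); dims=1)) / 2 + tr(A)

Xs = randn(Xoshiro(1), 2, 1000)
Am = [0.3 -0.5; 0.2 -1.1]; bm = [0.4, -0.1]
gap = fdsm_loss(X -> linear_score(X, Am, bm), Xs, 6e-3, 7e-3) - ism_loss(Xs, Am, bm)
gap    # ≈ 0 up to round-off
```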

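We include the finite-difference steps in the data passed to the training, so that they are computed only once. The implementation allows different steps $\delta_x$ and $\delta_y$ in each direction. Here, each step is the range of the sample in that coordinate divided by twice the sample size.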

xmin, xmax = extrema(sample_points[1, :])
ymin, ymax = extrema(sample_points[2, :])
deltax, deltay = (xmax - xmin) / 2size(sample_points, 2), (ymax - ymin) / 2size(sample_points, 2)
(0.00581614558657555, 0.006593323926327331)
data = sample_points, deltax, deltay
([-3.696922040577957 2.556772674293886 … -3.033651509981303 5.1672947152544895; -3.6462237060140024 3.4102960875997117 … 0.6696459295015311 3.2155718822745527], 0.00581614558657555, 0.006593323926327331)

Implementation of ${\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}})$

As a sanity check, we also include the empirical explicit score matching loss function, which uses the known score function of the target model.

In the two-dimensional case, the factor $1/2$ coincides with $1/d$. So this is simply the mean of the squares of all the entries of the $2\times N$ residual matrix:

\[ {\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) = \frac{1}{2} \frac{1}{N}\sum_{n=1}^N \|\boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}}) - \boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x}_n)\|^2 = \frac{1}{2} \frac{1}{N}\sum_{n=1}^N \sum_{i=1}^2 \left(\psi_i(\mathbf{x}_n; {\boldsymbol{\theta}}) - \psi_{\mathbf{X}, i}(\mathbf{x}_n) \right)^2.\]

function loss_function_cheat(model, ps, st, data)
    sample_points, score_cheat = data
    score_pred, st = Lux.apply(model, sample_points, ps, st)
    loss = mean(abs2, score_pred .- score_cheat)
    return loss, st, ()
end
loss_function_cheat (generic function with 1 method)
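The identity behind using mean(abs2, ·) in loss_function_cheat can be checked on a small, made-up residual matrix:

```julia
using Statistics

# For a 2×N residual matrix R (columns rₙ = ψ(xₙ; θ) - ψ_X(xₙ)), averaging over
# all 2N entries gives (1/2)(1/N) Σₙ ‖rₙ‖², because d = 2.
R = [1.0 -2.0 0.5; 3.0 0.0 -1.5]
lhs = mean(abs2, R)
rhs = sum(sum(abs2, r) for r in eachcol(R)) / size(R, 2) / 2
lhs, rhs    # (2.75, 2.75)
```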

In this case, the data also include the true score of the target distribution at each sample point.

score_cheat = reduce(hcat, gradlogpdf(target_prob, u) for u in eachcol(sample_points))
data_cheat = sample_points, score_cheat
([-3.696922040577957 2.556772674293886 … -3.033651509981303 5.1672947152544895; -3.6462237060140024 3.4102960875997117 … 0.6696459295015311 3.2155718822745527], [0.6969228899748048 0.44299201282532824 … 1.9946701191708671 -2.1672947253904584; 0.6462254048076983 -0.41041374404010544 … 0.2523912079807282 -0.2155718873425372])

Computing the constant

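The relation ${\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) \approx {\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}}) + C$ can be used to test the implementation of the different loss functions. For that, we need the constant $C = \frac{1}{2}\int_{\mathbb{R}^d} p_{\mathbf{X}}(\mathbf{x}) \|\boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x})\|^2\;\mathrm{d}\mathbf{x}$, which is the term left over when expanding the square in $J_{\mathrm{ESM}}$. It can be computed with a fine mesh or with a Monte Carlo approximation. We do both, just for fun.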

function compute_constante(target_prob, xrange, yrange)
    dx = Float64(xrange.step)
    dy = Float64(yrange.step)
    Jconstant = sum(pdf(target_prob, [x, y]) * sum(abs2, gradlogpdf(target_prob, [x, y])) for y in yrange, x in xrange) * dx * dy / 2
    return Jconstant
end
compute_constante (generic function with 1 method)
function compute_constante_MC(target_prob, sample_points)
    Jconstant = mean(sum(abs2, gradlogpdf(target_prob, s)) for s in eachcol(sample_points)) / 2
    return Jconstant
end
compute_constante_MC (generic function with 1 method)
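The two estimators themselves can be sanity-checked on a case with a known answer. For a standard bivariate normal (a hypothetical target, not the mixture above), $\boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x}) = -\mathbf{x}$, so $C = \frac{1}{2}\mathbb{E}\|\mathbf{X}\|^2 = d/2 = 1$. Here is a minimal sketch mirroring the mesh and Monte Carlo computations, using only the Julia standard library:

```julia
using Random, Statistics

# Hypothetical check with a standard bivariate normal, for which ψ_X(x) = -x and
# C = E‖X‖² / 2 = 1 exactly. Both estimators mirror the ones above.
pdf_std(x, y) = exp(-(x^2 + y^2) / 2) / (2π)

xr = range(-8, 8, 120); yr = range(-8, 8, 120)
dxr, dyr = step(xr), step(yr)
# fine-mesh (Riemann sum) estimate of (1/2) ∫ p ‖ψ_X‖²
C_mesh = sum(pdf_std(x, y) * (x^2 + y^2) for y in yr, x in xr) * dxr * dyr / 2

# Monte Carlo estimate of (1/2) E‖ψ_X(X)‖²
pts = randn(Xoshiro(7), 2, 10_000)
C_mc = mean(sum(abs2, p) for p in eachcol(pts)) / 2

C_mesh, C_mc    # both close to 1
```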
Jconstant = compute_constante(target_prob, xrange, yrange)
0.8921462840364986
Jconstant_MC = compute_constante_MC(target_prob, sample_points)
0.9129656846610863
constants = [(n, compute_constante_MC(target_prob, rand(rng, target_prob, n))) for _ in 1:100 for n in (1, 10, 20, 50, 100, 500, 1000, 2000, 4000)]
900-element Vector{Tuple{Int64, Float64}}:
 (1, 1.1538119449520563)
 (10, 1.2247351227578578)
 (20, 1.1028524320443522)
 (50, 0.9920784868026302)
 (100, 1.0063955226226398)
 (500, 0.8755541733725885)
 (1000, 0.8342517573145984)
 (2000, 0.8759484336636649)
 (4000, 0.8827405034687669)
 (1, 0.9485837438780702)
 ⋮
 (1, 0.298790143447742)
 (10, 0.7794066301161721)
 (20, 1.1364257738853651)
 (50, 0.6101320519101868)
 (100, 0.8696468922755639)
 (500, 0.9416724546394305)
 (1000, 0.9063918160454828)
 (2000, 0.8914872767709621)
 (4000, 0.871944900929463)
scatter(constants, markersize=2, title="constant computed by MC and fine mesh", titlefont=10, xlabel="sample size", ylabel="value", label="via various samples")
hline!([Jconstant], label="via fine mesh")
hline!([Jconstant_MC], label="via working sample", linestyle=:dash)
[Figure: constant computed by MC and fine mesh (x: sample size, y: value)]
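The spread of the Monte Carlo values in the plot shrinks as the sample size grows, roughly like $1/\sqrt{n}$. Here is a small stdlib-only sketch, again for the hypothetical standard normal target, for which $\|\mathbf{X}\|^2/2$ has unit variance:

```julia
using Random, Statistics

# Spread of the Monte Carlo estimate of C for a hypothetical standard normal
# target (C = 1), over 200 repetitions per sample size n; it decays like 1/√n.
rng_mc = Xoshiro(2024)
C_mc_estimate(n) = mean(sum(abs2, p) for p in eachcol(randn(rng_mc, 2, n))) / 2
spread(n) = std(C_mc_estimate(n) for _ in 1:200)

ratio = spread(10) / spread(1000)
ratio    # close to √(1000/10) = 10
```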

A test for the implementations of the loss functions

Notice that, for a sufficiently large sample and sufficiently small discretization step $\delta$, we should have

\[ {\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) \approx J_{\mathrm{ESM}}({\boldsymbol{\theta}}) = J_{\mathrm{ISM}}({\boldsymbol{\theta}}) + C \approx {\tilde J}_{\mathrm{FDSM}}({\boldsymbol{\theta}}) + C \approx {\tilde J}_{\mathrm{FDSM}{\tilde p}_0}({\boldsymbol{\theta}}) + C,\]

which is a good test for the implementations of the loss functions. For example:

first(loss_function_cheat(model, ps, st, data_cheat))
32.62664847852629
first(loss_function(model, ps, st, data)) + Jconstant
32.56518812592081

Let us run a more systematic test, over 30 random initializations of the model parameters.

test_losses = reduce(
    hcat,
    Lux.setup(rng, model) |> pstj ->
    [
        first(loss_function_cheat(model, pstj[1], pstj[2], data_cheat)),
        first(loss_function(model, pstj[1], pstj[2], data))
    ]
    for _ in 1:30
)
2×30 Matrix{Float64}:
 13.7323  38.8598  5.71619  9.86731  …  38.9853  20.4021  14.7908  7.77597
 13.0453  37.4968  4.90646  8.84147     38.2617  19.6932  13.6616  7.00106
plot(title="Loss functions at random model parameters", titlefont=10)
scatter!(test_losses[1, :], label="\${\\tilde J}_{\\mathrm{ESM}{\\tilde p}_0}\$")
scatter!(test_losses[2, :], label="\${\\tilde J}_{\\mathrm{FDSM}{\\tilde p}_0}\$")
scatter!(test_losses[2, :] .+ Jconstant, label="\${\\tilde J}_{\\mathrm{FDSM}{\\tilde p}_0} + C\$")
[Figure: Loss functions at random model parameters]

Visual inspection suggests that the agreement between ${\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) - C$ and ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}({\boldsymbol{\theta}})$ is reasonably good. Let us estimate the relative error.

rel_errors = abs.( ( test_losses[2, :] .+ Jconstant .- test_losses[1, :] ) ./ test_losses[1, :] )
plot(title="Relative error at random model parameters", titlefont=10, legend=false)
scatter!(rel_errors, markercolor=2, label="error")
mm = mean(rel_errors)
mmstd = std(rel_errors)
hline!([mm], label="mean")
hspan!([mm+mmstd, mm-mmstd], fillbetween=true, alpha=0.3, label="mean ± 1 std")
[Figure: Relative error at random model parameters]

OK, good enough: the relative error is just a few percent.

An extra test for the implementations of the loss functions and the gradient computation

We also have

\[ \boldsymbol{\nabla}_{\boldsymbol{\theta}} {\tilde J}_{\mathrm{ESM}{\tilde p}_0}({\boldsymbol{\theta}}) \approx \boldsymbol{\nabla}_{\boldsymbol{\theta}} {\tilde J}_{\mathrm{FDSM}{\tilde p}_0}({\boldsymbol{\theta}}),\]

which is another good test, since it also checks the gradient computation. However, everything seems fine so far, so there is no need to push this further.
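Here is a sketch of that gradient test in the hypothetical setting of a standard normal target and a linear model. Central differences in parameter space stand in for automatic differentiation here, so this does not exercise Zygote. The two gradients differ only by the gradient of the sample-dependent offset, which vanishes as $N \to \infty$.

```julia
using Random, Statistics, LinearAlgebra

# Hypothetical check of ∇_θ J_ESM ≈ ∇_θ J_ISM for the linear score model
# ψ(x; θ) = A x + b with θ = (vec(A), b), target N(0, I₂) (so ψ_X(x) = -x).
Xg = randn(Xoshiro(3), 2, 100_000)
unpack(θ) = (reshape(θ[1:4], 2, 2), θ[5:6])

function esm(θ)
    A, b = unpack(θ)
    return mean(sum(abs2, A * Xg .+ b .+ Xg; dims=1)) / 2
end

function ism(θ)
    A, b = unpack(θ)
    return mean(sum(abs2, A * Xg .+ b; dims=1)) / 2 + tr(A)
end

# Central finite differences in parameter space (a stand-in for AD here)
function num_grad(f, θ; h=1e-5)
    return [(f(θ .+ h .* (1:length(θ) .== k)) - f(θ .- h .* (1:length(θ) .== k))) / (2h)
            for k in eachindex(θ)]
end

θ0 = [0.3, 0.2, -0.5, -1.1, 0.4, -0.1]
g_esm, g_ism = num_grad(esm, θ0), num_grad(ism, θ0)
norm(g_esm - g_ism)    # small: the two losses differ by (nearly) a constant
```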

Optimization setup

Optimization method

As usual, we use the Adam optimizer, here with a learning rate of 0.003.

opt = Adam(0.003)

tstate_org = Lux.Training.TrainState(model, ps, st, opt)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 16, relu), layer_2 = Dense(16 => 2)), nothing)
    # of parameters: 82
    # of states: 0
    optimizer: Optimisers.Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8)
    step: 0
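For reference, here is a minimal sketch of the standard Adam update of Kingma and Ba, with the hyperparameters displayed in the TrainState above: η = 0.003, β = (0.9, 0.999), ϵ = 1e-8. It is not the Optimisers.jl source, whose implementation may differ in small details. The toy quadratic objective and the helper adam_minimize are hypothetical.

```julia
using LinearAlgebra

# Minimal sketch of the standard Adam update, for illustration only.
function adam_minimize(grad, θ; η=0.003, β1=0.9, β2=0.999, epsilon=1e-8, steps=5000)
    m = zero(θ); v = zero(θ)
    for t in 1:steps
        g = grad(θ)
        m = β1 .* m .+ (1 - β1) .* g          # first-moment estimate
        v = β2 .* v .+ (1 - β2) .* g .^ 2     # second-moment estimate
        mhat = m ./ (1 - β1^t)                # bias corrections
        vhat = v ./ (1 - β2^t)
        θ = θ .- η .* mhat ./ (sqrt.(vhat) .+ epsilon)
    end
    return θ
end

# Toy problem: minimize ‖θ - c‖², whose gradient is 2(θ - c)
c = [1.0, -2.0]
θ_opt = adam_minimize(θ -> 2 .* (θ .- c), zeros(2))
norm(θ_opt - c)    # small
```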

Automatic differentiation in the optimization

FluxML/Zygote.jl is used for the automatic differentiation, as it is a well-supported AD backend for LuxDL/Lux.jl.

vjp_rule = Lux.Training.AutoZygote()
ADTypes.AutoZygote()

Processor

We use the CPU instead of the GPU.

dev_cpu = cpu_device()
## dev_gpu = gpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)

Check differentiation

We check that AD works for differentiating the loss functions used in training.

Lux.Training.compute_gradients(vjp_rule, loss_function, data, tstate_org)
((layer_1 = (weight = Float32[3.6879158 4.0172615; -2.2284822 -2.6868043; … ; -1.9008694 -1.7861443; -4.612377 -4.320793], bias = Float32[0.8997364, -0.8043432, 1.7370119, 0.45905542, 0.8397131, -1.230505, 0.353755, -1.2257671, 0.9869879, 1.0354862, 0.52109236, 1.3021913, 1.5245895, 0.877233, 0.5858979, 1.4446392]), layer_2 = (weight = Float32[4.2770233 22.516876 … -17.53038 -24.572052; 11.982891 26.250607 … 15.06171 21.036867], bias = Float32[-0.113342285, 5.354232])), 31.673041841884316, (), Lux.Training.TrainState{Nothing, Nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}}(nothing, nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 16, relu), layer_2 = Dense(16 => 2)), nothing), (layer_1 = (weight = 
Float32[-1.2633736 1.8181845; 2.3867178 0.5713208; … ; -0.7349689 -1.3411301; -1.167693 -1.8823313], bias = Float32[0.48686042, 0.32835403, 0.23500215, -0.6289243, -0.4877348, -0.08177762, 0.57667726, -0.06088416, 0.20568033, -0.66487825, 0.17736456, -0.20175992, 0.4999213, -0.5164128, 0.5682744, 0.33539677]), layer_2 = (weight = Float32[0.3586208 -0.057023764 … -0.17773739 -0.41380876; 0.12927775 -0.22223856 … 0.06703623 0.18231718], bias = Float32[0.067408144, 0.110440105])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), Optimisers.Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (layer_1 = (weight = Leaf(Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999))))), 0))
Lux.Training.compute_gradients(vjp_rule, loss_function_cheat, data_cheat, tstate_org)
((layer_1 = (weight = Float32[3.689433 4.0178785; -2.264155 -2.618133; … ; -1.9069793 -1.7834694; -4.6186814 -4.3203063], bias = Float32[0.8609152, -0.7731151, 1.725719, 0.44558188, 0.84764415, -1.251586, 0.36021093, -1.1504575, 0.975368, 1.044547, 0.5048005, 1.1946136, 1.5058247, 0.893234, 0.58936614, 1.429456]), layer_2 = (weight = Float32[4.235089 22.786686 … -17.630789 -24.70937; 11.946889 26.341808 … 14.838343 20.73272], bias = Float32[-0.11527908, 5.321644])), 32.62664847852629, (), Lux.Training.TrainState{Nothing, Nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}}(nothing, nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 16, relu), layer_2 = Dense(16 => 2)), nothing), (layer_1 = (weight = Float32[-1.2633736 
1.8181845; 2.3867178 0.5713208; … ; -0.7349689 -1.3411301; -1.167693 -1.8823313], bias = Float32[0.48686042, 0.32835403, 0.23500215, -0.6289243, -0.4877348, -0.08177762, 0.57667726, -0.06088416, 0.20568033, -0.66487825, 0.17736456, -0.20175992, 0.4999213, -0.5164128, 0.5682744, 0.33539677]), layer_2 = (weight = Float32[0.3586208 -0.057023764 … -0.17773739 -0.41380876; 0.12927775 -0.22223856 … 0.06703623 0.18231718], bias = Float32[0.067408144, 0.110440105])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), Optimisers.Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (layer_1 = (weight = Leaf(Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(eta=0.003, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999))))), 0))

Training loop

Here is the typical main training loop suggested in the LuxDL/Lux.jl tutorials, slightly modified to save the history of losses per iteration and the model states for animation.

function train(tstate, vjp, data, loss_function, epochs, numshowepochs=20, numsavestates=0)
    losses = zeros(epochs)
    tstates = [(0, tstate)]  # the initial state is saved as epoch 0
    for epoch in 1:epochs
        # loss and gradients at the current parameters
        grads, loss, stats, tstate = Lux.Training.compute_gradients(vjp,
            loss_function, data, tstate)
        # print the loss every div(epochs, numshowepochs) epochs
        if ( epochs ≥ numshowepochs > 0 ) && rem(epoch, div(epochs, numshowepochs)) == 0
            println("Epoch: $(epoch) || Loss: $(loss)")
        end
        # save the training state every div(epochs, numsavestates) epochs (for animations)
        if ( epochs ≥ numsavestates > 0 ) && rem(epoch, div(epochs, numsavestates)) == 0
            push!(tstates, (epoch, tstate))
        end
        losses[epoch] = loss
        # optimizer step
        tstate = Lux.Training.apply_gradients(tstate, grads)
    end
    return tstate, losses, tstates
end
train (generic function with 3 methods)
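With the call train(..., 2000, 20, 100) used in the training runs, the loss is printed every div(2000, 20) = 100 epochs, and the state is saved every div(2000, 100) = 20 epochs. The two conditions can be checked in isolation:

```julia
# Which epochs `train` prints and saves, mirroring its two `rem` conditions.
epochs, numshowepochs, numsavestates = 2000, 20, 100
shown = [e for e in 1:epochs if (epochs ≥ numshowepochs > 0) && rem(e, div(epochs, numshowepochs)) == 0]
saved = [e for e in 1:epochs if (epochs ≥ numsavestates > 0) && rem(e, div(epochs, numsavestates)) == 0]

length(shown), first(shown), length(saved) + 1   # (20, 100, 101); +1 for the initial state at epoch 0
```

Note that the state saved under epoch e is the one used to compute the e-th gradient, that is, the state before the e-th update is applied.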

Cheat training with ${\tilde J}_{\mathrm{ESM}{\tilde p}_0}$

We first train the model with the known score function on the sample data. That is cheating. The aim is a sanity check, to make sure the proposed model is good enough to fit the desired score function and that the setup is right.

@time tstate_cheat, losses_cheat, tstates_cheat = train(tstate_org, vjp_rule, data_cheat, loss_function_cheat, 2000, 20, 100)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [Matrix{Float32}] x B [Matrix{Float64}]). Falling back to generic implementation. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/1B1qw/src/impl/matmul.jl:190
Epoch: 100 || Loss: 0.8401849278869697
Epoch: 200 || Loss: 0.563607700266948
Epoch: 300 || Loss: 0.4926258151863607
Epoch: 400 || Loss: 0.45830559661430026
Epoch: 500 || Loss: 0.43358616510264614
Epoch: 600 || Loss: 0.41197109354966255
Epoch: 700 || Loss: 0.39327886995951344
Epoch: 800 || Loss: 0.3768739639070251
Epoch: 900 || Loss: 0.3631936077433342
Epoch: 1000 || Loss: 0.35101209994291793
Epoch: 1100 || Loss: 0.33997820230228853
Epoch: 1200 || Loss: 0.3293390319395758
Epoch: 1300 || Loss: 0.31880814244768374
Epoch: 1400 || Loss: 0.30846754084144945
Epoch: 1500 || Loss: 0.2984211525235665
Epoch: 1600 || Loss: 0.28827943967632763
Epoch: 1700 || Loss: 0.2778414128971065
Epoch: 1800 || Loss: 0.26736572655519075
Epoch: 1900 || Loss: 0.25641855404231345
Epoch: 2000 || Loss: 0.2448096823452019
  0.510287 seconds (722.74 k allocations: 955.932 MiB, 12.15% gc time, 7.54% compilation time)

Testing out the trained model.

uu_cheat = Lux.apply(tstate_cheat.model, vcat(xx', yy'), tstate_cheat.parameters, tstate_cheat.states)[1]
2×225 Matrix{Float64}:
 2.19923   1.53509  0.870944  0.206801  …   0.502545  -0.277666  -1.05788
 0.997473  1.47914  1.9608    2.44247      -3.00036   -2.75735   -2.51434
heatmap(xrange, yrange, (x, y) -> logpdf(target_prob, [x, y]), title="Logpdf (heatmap) and score functions (vector fields)", titlefont=10, color=:vik, xlims=extrema(xrange), ylims=extrema(yrange), legend=false)
quiver!(xx, yy, quiver = (uu[1, :] ./ 8, uu[2, :] ./ 8), color=:yellow, alpha=0.5)
scatter!(sample_points[1, :], sample_points[2, :], markersize=2, markercolor=:lightgreen, alpha=0.5)
quiver!(xx, yy, quiver = (uu_cheat[1, :] ./ 8, uu_cheat[2, :] ./ 8), color=:cyan, alpha=0.5)
[Figure: Logpdf (heatmap) and score functions (vector fields): target score in yellow, model trained with the cheat loss in cyan]
plot(losses_cheat, title="Evolution of the loss", titlefont=10, xlabel="iteration", ylabel="error", legend=false)
[Figure: Evolution of the loss (x: iteration, y: error)]

Real training with ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}$

Now we go to the real thing.

@time tstate, losses, tstates = train(tstate_org, vjp_rule, data, loss_function, 2000, 20, 100)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [Matrix{Float32}] x B [Matrix{Float64}]). Falling back to generic implementation. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/1B1qw/src/impl/matmul.jl:190
Epoch: 100 || Loss: -0.028526071882063422
Epoch: 200 || Loss: -0.3056820935385293
Epoch: 300 || Loss: -0.3735720043248265
Epoch: 400 || Loss: -0.4060071208695759
Epoch: 500 || Loss: -0.42700676492413925
Epoch: 600 || Loss: -0.4442805461267297
Epoch: 700 || Loss: -0.45911840778124424
Epoch: 800 || Loss: -0.4711387063320857
Epoch: 900 || Loss: -0.48093733644650705
Epoch: 1000 || Loss: -0.48972825080776355
Epoch: 1100 || Loss: -0.49795548481343427
Epoch: 1200 || Loss: -0.5057248136083738
Epoch: 1300 || Loss: -0.5133231683129973
Epoch: 1400 || Loss: -0.5203089294538508
Epoch: 1500 || Loss: -0.5271916186760042
Epoch: 1600 || Loss: -0.5339186475303448
Epoch: 1700 || Loss: -0.5407054119759426
Epoch: 1800 || Loss: -0.5470494965278492
Epoch: 1900 || Loss: -0.5533731989980991
Epoch: 2000 || Loss: -0.5598464396520446
  1.851211 seconds (1.36 M allocations: 4.064 GiB, 10.31% gc time, 2.10% compilation time)

Testing out the trained model.

uu_pred = Lux.apply(tstate.model, vcat(xx', yy'), tstate.parameters, tstate.states)[1]
2×225 Matrix{Float64}:
  1.71235   1.17107   0.606229  -0.0336345  …   1.14829   0.627492   0.106691
 -0.196228  0.441023  1.20491    1.90302       -2.18941  -2.07645   -1.9635
heatmap(xrange, yrange, (x, y) -> logpdf(target_prob, [x, y]), title="Logpdf (heatmap) and score functions (vector fields)", titlefont=10, color=:vik, xlims=extrema(xrange), ylims=extrema(yrange), legend=false)
quiver!(xx, yy, quiver = (uu[1, :] ./ 8, uu[2, :] ./ 8), color=:yellow, alpha=0.5)
scatter!(sample_points[1, :], sample_points[2, :], markersize=2, markercolor=:lightgreen, alpha=0.5)
quiver!(xx, yy, quiver = (uu_pred[1, :] ./ 8, uu_pred[2, :] ./ 8), color=:cyan, alpha=0.5)
[Figure: Logpdf (heatmap) and score functions (vector fields): target score in yellow, model trained with the FDSM loss in cyan]
plot(losses, title="Evolution of the losses", titlefont=10, xlabel="iteration", ylabel="error", label="\${\\tilde J}_{\\mathrm{FDSM}{\\tilde p}_0}\$")
plot!(losses_cheat, linestyle=:dash, label="\${\\tilde J}_{\\mathrm{ESM}{\\tilde p}_0}\$")
plot!(losses .+ Jconstant, linestyle=:dash, color=1, label="\${\\tilde J}_{\\mathrm{FDSM}{\\tilde p}_0} + C\$")
[Figure: Evolution of the losses (x: iteration, y: error)]

OK, that seems visually good enough. We will later check sampling from this score function via Langevin sampling.