Finite-difference score-matching of a one-dimensional Gaussian mixture model

Introduction

Aim

The aim, this time, is to fit a neural network via finite-difference score matching, following the pioneering work of Aapo Hyvärinen (2005) on score matching, combined with the work of Pang, Xu, Li, Song, Ermon, and Zhu (2020), which uses finite differences to efficiently approximate the gradient in the loss function proposed by Hyvärinen (2005), and with the idea of Song and Ermon (2019) of modeling directly the score function instead of the pdf or an energy potential of the pdf.

Background

Generative score-matching diffusion methods use Langevin dynamics to draw samples from a modeled score function. This rests on the idea of Aapo Hyvärinen (2005) that one can directly fit the score function from the sample data, using a suitable loss function (associated with the Fisher divergence) that does not depend on the unknown score function of the random variable. This is obtained by a simple integration by parts of the expected square distance between the model score function and the actual score function. The integration by parts separates the dependence on the actual score function from the parameters of the model, so the fitting process (the minimization over the parameters of the model) does not depend on the unknown score function.
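
As a concrete illustration, for a univariate normal distribution $\mathcal{N}(\mu, \sigma^2)$, the score function is

\[ \psi_X(x) = \frac{\mathrm{d}}{\mathrm{d}x} \log p_X(x) = \frac{\mathrm{d}}{\mathrm{d}x} \left( -\frac{(x - \mu)^2}{2\sigma^2} - \log\left(\sigma\sqrt{2\pi}\right) \right) = -\frac{x - \mu}{\sigma^2},\]

which, notably, does not involve the normalization constant of the pdf; being able to work with non-normalized models is one of the motivations in Aapo Hyvärinen (2005).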

The obtained loss function, however, depends on the gradient of the model, which is computationally expensive. Pang, Xu, Li, Song, Ermon, and Zhu (2020) proposed to use finite differences to approximate the derivative of the model to significantly reduce the computational cost of training the model.

The differentiation in the optimization is with respect to the parameters, while the differentiation of the modeled score function is with respect to the variate. Still, differentiating through this derivative is a great computational challenge, and not all automatic differentiation (AD) systems are fit for that. For this reason, we resort to centered finite differences to approximate the derivative of the modeled score function.

For a Python version of a similar pedagogical example, see Eric J. Ma (2021). There, AD is used on top of AD, via the google/jax library, which apparently handles this double AD reasonably well.

Take away

We'll see that, in this simple example at least, we don't need a large or deep neural network. It is much more important to have enough sample points to capture the transition region in the mixture of Gaussians.

The finite-difference implicit score matching method

The score-matching method from Aapo Hyvärinen (2005) rests on the following ideas:

1. Fit the model by minimizing the expected square distance between the model score function $\psi(x; {\boldsymbol{\theta}})$ and the score function $\psi_X(x)$ of the random variable $X$, via the explicit score matching (ESM) objective

\[ J_{\mathrm{ESM}}({\boldsymbol{\theta}}) = \frac{1}{2}\int_{\mathbb{R}^d} p_{\mathbf{X}}(\mathbf{x}) \left\|\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}}) - \boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x})\right\|^2\;\mathrm{d}\mathbf{x}.\]

2. Use integration by parts in the expectation to write that

\[ J_{\mathrm{ESM}}({\boldsymbol{\theta}}) = J_{\mathrm{ISM}}({\boldsymbol{\theta}}) + C,\]

where $C$ is constant with respect to the parameters, so we only need to minimize the implicit score matching (ISM) objective $J_{\mathrm{ISM}}({\boldsymbol{\theta}})$, given by

\[ J_{\mathrm{ISM}}({\boldsymbol{\theta}}) = \int_{\mathbb{R}^d} p_{\mathbf{X}}(\mathbf{x}) \left( \frac{1}{2}\left\|\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}})\right\|^2 + \boldsymbol{\nabla}_{\mathbf{x}} \cdot \boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}}) \right)\;\mathrm{d}\mathbf{x},\]

which does not involve the unknown score function of ${\mathbf{X}}$. It does, however, involve the gradient of the modeled score function, which is expensive to compute.

3. In practice, the implicit score-matching loss function, which depends on the unknown $p_\mathbf{X}(\mathbf{x})$, is estimated via the empirical distribution, obtained from the sample data $(\mathbf{x}_n)_n$. Thus, we minimize the empirical implicit score matching objective

\[ {\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}}) = \frac{1}{N}\sum_{n=1}^N \left( \frac{1}{2}\|\boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}})\|^2 + \boldsymbol{\nabla}_{\mathbf{x}} \cdot \boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}}) \right),\]

where the empirical distribution is given by ${\tilde p}_0 = (1/N)\sum_{n=1}^N \delta_{\mathbf{x}_n}.$

On top of that, we add one more step.

4. As mentioned before, computing a derivative to form the loss function becomes expensive when combined with the usual optimization methods to fit a neural network, as they require the gradient of the loss function itself, i.e. the optimization process involves the gradient of the gradient of the modeled score function. Because of that, one alternative is to approximate the derivative of the model score function by centered finite differences, i.e.

\[ \frac{\partial}{\partial x} \psi(x_n; {\boldsymbol{\theta}}) \approx \frac{\psi(x_n + \delta; {\boldsymbol{\theta}}) - \psi(x_n - \delta; {\boldsymbol{\theta}})}{2\delta},\]

for a suitably small $\delta > 0$ (a quick numerical sanity check of this approximation is included right after these steps).

In this case, since we already need to compute $\psi(x_n + \delta; {\boldsymbol{\theta}})$ and $\psi(x_n - \delta; {\boldsymbol{\theta}})$, we avoid also computing $\psi(x_n; {\boldsymbol{\theta}})$ and approximate it with the average

\[ \psi(x_n; {\boldsymbol{\theta}}) \approx \frac{\psi(x_n + \delta; {\boldsymbol{\theta}}) + \psi(x_n - \delta; {\boldsymbol{\theta}})}{2}.\]

Hence, we approximate the implicit score matching objective $J_{\mathrm{ISM}}({\boldsymbol{\theta}})$ by the finite-difference (implicit) score matching objective

\[ {\tilde J}_{\mathrm{FDSM}}({\boldsymbol{\theta}}) = \int_{\mathbb{R}} p_X(x) \Bigg( \frac{1}{2}\left(\frac{\psi(x + \delta; {\boldsymbol{\theta}}) + \psi(x - \delta; {\boldsymbol{\theta}})}{2}\right)^2 + \frac{\psi(x + \delta; {\boldsymbol{\theta}}) - \psi(x - \delta; {\boldsymbol{\theta}})}{2\delta} \Bigg)\;\mathrm{d}x,\]

and the empirical implicit score matching objective ${\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}})$ is approximated by

\[ {\tilde J}_{\mathrm{FDSM}{\tilde p}_0}({\boldsymbol{\theta}}) = \frac{1}{N}\sum_{n=1}^N \Bigg( \frac{1}{2}\left(\frac{\psi(x_n + \delta; {\boldsymbol{\theta}}) + \psi(x_n - \delta; {\boldsymbol{\theta}})}{2}\right)^2 + \frac{\psi(x_n + \delta; {\boldsymbol{\theta}}) - \psi(x_n - \delta; {\boldsymbol{\theta}})}{2\delta} \Bigg).\]
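
Before moving on, here is a quick numerical sanity check of the centered difference formula, in plain Julia, with an arbitrary smooth test function standing in for the modeled score function (just an illustration, not part of the fitting code below).

ψtest(x) = sin(x) # arbitrary stand-in for a modeled score function
δ = 0.05
x₀ = 0.7
fd = (ψtest(x₀ + δ) - ψtest(x₀ - δ)) / 2δ # centered finite difference
abs(fd - cos(x₀)) # error of order δ^2, about 3.2e-4 here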

Numerical example

We illustrate the above method by fitting a neural network to a univariate Gaussian mixture distribution.

We played with different target distributions and settled here with a bimodal distribution used in Eric J. Ma (2021).

Julia language setup

We use the Julia programming language with suitable packages.

Packages

using StatsPlots
using Random
using Distributions
using Lux # artificial neural networks explicitly parametrized
using Optimisers
using Zygote # automatic differentiation

Reproducibility

We set the random seed for reproducibility purposes.

rng = Xoshiro(12345)

Code introspection

We do not attempt to overly optimize the code here since this is a simple one-dimensional problem. Nevertheless, it is always healthy to check the type stability of the critical parts (like the loss functions) with @code_warntype. One should also check for any unusual time and allocation figures with BenchmarkTools.@btime or BenchmarkTools.@benchmark. We performed these analyses and everything seems fine. We found it unnecessary to clutter the notebook with their outputs here, though.
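
For illustration, once the loss functions below are defined, the checks go along the following lines (a sketch only; outputs omitted).

using InteractiveUtils: @code_warntype
using BenchmarkTools
@code_warntype loss_function_esm(model, ps, st, data) # check type stability of the loss
@btime loss_function_esm($model, $ps, $st, $data) # check time and allocations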

Data

We build the target model and draw samples from it. We need enough sample points to capture the transition region in the mixture of Gaussians.

xrange = range(-10, 10, 200)
dx = Float64(xrange.step)
x = permutedims(collect(xrange))

target_prob = MixtureModel([Normal(-3, 1), Normal(3, 1)], [0.1, 0.9])

target_pdf = pdf.(target_prob, x)
target_score = gradlogpdf.(target_prob, x)

y = target_score # just to simplify the notation
sample_points = permutedims(rand(rng, target_prob, 1024))
data = (x, y, target_pdf, sample_points)
([-10.0 -9.899497487437186 … 9.899497487437186 10.0], [7.0 6.899497487437186 … -6.899497487437186 -7.0], [9.134720408364594e-13 1.8366893783972853e-12 … 1.6530204405575567e-11 8.221248367528135e-12], [2.303077959422043 2.8428423932782843 … 3.1410080972036334 2.488464630750972])

Notice the data x and sample_points are defined as row vectors so we can apply the model in batch to all of their values at once. The values y are also row vectors for easy comparison with the predicted values. When plotting, though, we need to convert them back to (column) vectors.

Visualizing the sample data drawn from the distribution and the PDF.

plot(title="PDF and histogram of sample data from the distribution", titlefont=10)
histogram!(sample_points', normalize=:pdf, nbins=80, label="sample histogram")
plot!(x', target_pdf', linewidth=4, label="pdf")
scatter!(sample_points', s -> pdf(target_prob, s), linewidth=4, label="sample")
Example block output

Visualizing the score function.

plot(title="The score function and the sample", titlefont=10)

plot!(x', target_score', label="score function", markersize=2)
scatter!(sample_points', s -> gradlogpdf(target_prob, s), label="data", markersize=2)
Example block output

The neural network model

The neural network we consider is a simple feed-forward neural network made of a single hidden layer, obtained as a chain of a couple of dense layers. This is implemented with the LuxDL/Lux.jl package.

We will see, again, that we don't need a big neural network in this simple example. We keep it as small as we can while still getting a good fit.

model = Chain(Dense(1 => 8, relu), Dense(8 => 1))
Chain(
    layer_1 = Dense(1 => 8, relu),      # 16 parameters
    layer_2 = Dense(8 => 1),            # 9 parameters
)         # Total: 25 parameters,
          #        plus 0 states.

The LuxDL/Lux.jl package uses explicit parameters, which are initialized (or obtained) with the Lux.setup function, giving us the parameters and the state of the model.

ps, st = Lux.setup(rng, model) # initialize and get the parameters and states of the model
((layer_1 = (weight = Float32[-0.0008383376; -0.4411211; … ; 0.15721959; -0.22093461;;], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[0.14955823 0.4725347 … -0.6732873 -0.76267356], bias = Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))

Explicit score matching loss function $J_{\mathrm{ESM}}({\boldsymbol{\theta}})$

For educational purposes, since we have the pdf and the score function, one of the ways we may train the model is directly with $J_{\mathrm{ESM}}({\boldsymbol{\theta}})$. This is also useful to make sure that our network is able to model the desired score function.

Here is how we implement it.

function loss_function_esm(model, ps, st, data)
    x, y, target_pdf, sample_points = data
    y_pred, st = Lux.apply(model, x, ps, st)
    loss = mean(target_pdf .* (y_pred .- y) .^2) # pdf-weighted mean square error; the factor 1/2 is immaterial for the minimization
    return loss, st, ()
end
loss_function_esm (generic function with 1 method)

Plain square error loss function

Still for educational purposes, we also train with a plain mean squared error, i.e. $J_{\mathrm{ESM}}({\boldsymbol{\theta}})$ without the weighting by the distribution of the random variable itself. This has the benefit of giving more weight to the low-probability transition region. Here is how we implement it.

function loss_function_esm_plain(model, ps, st, data)
    x, y, target_pdf, sample_points = data
    y_pred, st = Lux.apply(model, x, ps, st)
    loss = mean(abs2, y_pred .- y)
    return loss, st, ()
end
loss_function_esm_plain (generic function with 1 method)

Finite-difference score matching ${\tilde J}_{\mathrm{FDSM}}$

Again, for educational purposes, we may implement ${\tilde J}_{\mathrm{FDSM}}({\boldsymbol{\theta}})$, as follows.

function loss_function_FDSM(model, ps, st, data)
    x, y, target_pdf, sample_points = data
    xmin, xmax = extrema(x)
    delta = (xmax - xmin) / 2length(x) # about half the grid spacing
    y_pred_fwd, = Lux.apply(model, x .+ delta, ps, st)
    y_pred_bwd, = Lux.apply(model, x .- delta, ps, st)
    y_pred = ( y_pred_bwd .+ y_pred_fwd ) ./ 2 # approximates the model score ψ(x; θ)
    dy_pred = (y_pred_fwd .- y_pred_bwd ) ./ 2delta # centered difference approximates its derivative
    loss = mean(target_pdf .* (dy_pred + y_pred .^ 2 / 2))
    return loss, st, ()
end
loss_function_FDSM (generic function with 1 method)

Empirical finite-difference score matching loss function ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}$

In practice, we would use the sample data, not the supposedly unknown score function and PDF themselves. Here is one implementation using finite differences along with the Monte Carlo approximation given by the empirical mean over the sample points.

function loss_function_FDSM_over_sample(model, ps, st, data)
    x, y, target_pdf, sample_points = data
    xmin, xmax = extrema(sample_points)
    delta = (xmax - xmin) / 2length(sample_points)
    y_pred_fwd, = Lux.apply(model, sample_points .+ delta, ps, st)
    y_pred_bwd, = Lux.apply(model, sample_points .- delta, ps, st)
    y_pred = ( y_pred_bwd .+ y_pred_fwd ) ./ 2
    dy_pred = (y_pred_fwd .- y_pred_bwd ) ./ 2delta
    loss = mean(dy_pred + y_pred .^ 2 / 2)
    return loss, st, ()
end
loss_function_FDSM_over_sample (generic function with 1 method)

Optimization setup

Optimization method

We use the classical Adam optimiser (see Kingma and Ba (2015)), which is a stochastic gradient-based optimization method.
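
For reference, Adam keeps exponentially decaying averages of the past gradients and of their elementwise squares and uses them to adapt the step of each parameter:

\[ \mathbf{m}_t = \beta_1 \mathbf{m}_{t-1} + (1 - \beta_1) \mathbf{g}_t, \quad \mathbf{v}_t = \beta_2 \mathbf{v}_{t-1} + (1 - \beta_2) \mathbf{g}_t^2, \quad {\boldsymbol{\theta}}_t = {\boldsymbol{\theta}}_{t-1} - \alpha \frac{\hat{\mathbf{m}}_t}{\sqrt{\hat{\mathbf{v}}_t} + \epsilon},\]

where $\mathbf{g}_t$ is the gradient of the loss at step $t$, $\hat{\mathbf{m}}_t = \mathbf{m}_t / (1 - \beta_1^t)$ and $\hat{\mathbf{v}}_t = \mathbf{v}_t / (1 - \beta_2^t)$ are the bias-corrected averages, and all operations are elementwise. Below we set the learning rate $\alpha = 0.03$ and keep the default $\beta_1 = 0.9$, $\beta_2 = 0.999$, and $\epsilon = 10^{-8}$.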

opt = Adam(0.03)

tstate_org = Lux.Training.TrainState(rng, model, opt)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}((layer_1 = Dense(1 => 8, relu), layer_2 = Dense(8 => 1)), nothing)
    # of parameters: 25
    # of states: 0
    optimizer: Adam(0.03, (0.9, 0.999), 1.0e-8)
    step: 0

Automatic differentiation in the optimization

As mentioned, we set up differentiation in LuxDL/Lux.jl with the FluxML/Zygote.jl library, which is currently the only fully implemented option (there are pre-defined types for AutoForwardDiff(), AutoReverseDiff(), AutoFiniteDifferences(), etc., but they are not implemented yet).

vjp_rule = Lux.Training.AutoZygote()
ADTypes.AutoZygote()

Processor

We use the CPU instead of the GPU.

dev_cpu = cpu_device()
## dev_gpu = gpu_device()
(::LuxDeviceUtils.LuxCPUDevice) (generic function with 5 methods)

Check differentiation

Check if Zygote via Lux is working fine to differentiate the loss functions for training.

Lux.Training.compute_gradients(vjp_rule, loss_function_esm, data, tstate_org)
((layer_1 = (weight = Float32[-0.13674347; 0.014439452; … ; 0.12662154; 0.40061465;;], bias = Float32[-0.046839785; -0.002902971; … ; 0.043372646; 0.1372256;;]), layer_2 = (weight = Float32[-0.22635852 -0.0072339457 … -0.354997 -0.46874574], bias = Float32[-0.21803467;;])), 0.2626268433348023, (), Lux.Training.TrainState{Nothing, Nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Optimisers.Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}}}(nothing, nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}((layer_1 = Dense(1 => 8, relu), layer_2 = Dense(8 => 1)), nothing), (layer_1 = (weight = Float32[0.36434844; -0.2782616; … ; 0.57140595; 0.7544968;;], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[0.22010337 0.5554293 … -0.20381103 -0.6448325], bias = Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), Adam(0.03, (0.9, 0.999), 1.0e-8), (layer_1 = (weight = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0;;], Float32[0.0;;], (0.9, 0.999))))), 0))
Lux.Training.compute_gradients(vjp_rule, loss_function_esm_plain, data, tstate_org)
((layer_1 = (weight = Float32[-1.78867; 13.358742; … ; 1.6562703; 5.240231;;], bias = Float32[-0.42981282; -1.6990292; … ; 0.39799753; 1.2592142;;]), layer_2 = (weight = Float32[-2.9608774 -6.692526 … -4.64353 -6.1314178], bias = Float32[-5.011725;;])), 11.552828700551473, (), Lux.Training.TrainState{Nothing, Nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Optimisers.Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}}}(nothing, nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}((layer_1 = Dense(1 => 8, relu), layer_2 = Dense(8 => 1)), nothing), (layer_1 = (weight = Float32[0.36434844; -0.2782616; … ; 0.57140595; 0.7544968;;], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[0.22010337 0.5554293 … -0.20381103 -0.6448325], bias = Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), Adam(0.03, (0.9, 0.999), 1.0e-8), (layer_1 = (weight = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0;;], Float32[0.0;;], (0.9, 0.999))))), 0))
Lux.Training.compute_gradients(vjp_rule, loss_function_FDSM, data, tstate_org)
((layer_1 = (weight = Float32[-0.068371624; 0.0072194785; … ; 0.06331068; 0.20030701;;], bias = Float32[-0.023468263; -0.0013294332; … ; 0.021731108; 0.068754494;;]), layer_2 = (weight = Float32[-0.11317912 -0.003616847 … -0.17749822 -0.2343725], bias = Float32[-0.10901734;;])), 0.10696602187860474, (), Lux.Training.TrainState{Nothing, Nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Optimisers.Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}}}(nothing, nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}((layer_1 = Dense(1 => 8, relu), layer_2 = Dense(8 => 1)), nothing), (layer_1 = (weight = Float32[0.36434844; -0.2782616; … ; 0.57140595; 0.7544968;;], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[0.22010337 0.5554293 … -0.20381103 -0.6448325], bias = Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), Adam(0.03, (0.9, 0.999), 1.0e-8), (layer_1 = (weight = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0;;], Float32[0.0;;], (0.9, 0.999))))), 0))
Lux.Training.compute_gradients(vjp_rule, loss_function_FDSM_over_sample, data, tstate_org)
((layer_1 = (weight = Float32[-1.4075394; 0.13398743; … ; 1.3033524; 4.123642;;], bias = Float32[-0.47902107; -0.024228573; … ; 0.44356346; 1.4033775;;]), layer_2 = (weight = Float32[-2.3299713 -0.06712532 … -3.6540833 -4.824936], bias = Float32[-2.219963;;])), 2.2044295638298066, (), Lux.Training.TrainState{Nothing, Nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Optimisers.Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}}}(nothing, nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}((layer_1 = Dense(1 => 8, relu), layer_2 = Dense(8 => 1)), nothing), (layer_1 = (weight = Float32[0.36434844; -0.2782616; … ; 0.57140595; 0.7544968;;], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[0.22010337 0.5554293 … -0.20381103 -0.6448325], bias = Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), Adam(0.03, (0.9, 0.999), 1.0e-8), (layer_1 = (weight = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0;;], Float32[0.0;;], (0.9, 0.999))))), 0))

Training loop

Here is the typical main training loop suggested in the LuxDL/Lux.jl tutorials, slightly modified to save the history of losses per iteration.

function train(tstate, vjp, data, loss_function, epochs, numshowepochs=20, numsavestates=0)
    losses = zeros(epochs)
    tstates = [(0, tstate)]
    for epoch in 1:epochs
        # compute the loss and its gradients with respect to the parameters
        grads, loss, stats, tstate = Lux.Training.compute_gradients(vjp,
            loss_function, data, tstate)
        # periodically report the loss
        if ( epochs ≥ numshowepochs > 0 ) && rem(epoch, div(epochs, numshowepochs)) == 0
            println("Epoch: $(epoch) || Loss: $(loss)")
        end
        # periodically save the training state (used for the animations)
        if ( epochs ≥ numsavestates > 0 ) && rem(epoch, div(epochs, numsavestates)) == 0
            push!(tstates, (epoch, tstate))
        end
        # record the loss and update the parameters
        losses[epoch] = loss
        tstate = Lux.Training.apply_gradients(tstate, grads)
    end
    return tstate, losses, tstates
end
train (generic function with 3 methods)

Training

Training with $J_{\mathrm{ESM}}({\boldsymbol{\theta}})$

Now we attempt to train the model, starting with $J_{\mathrm{ESM}}({\boldsymbol{\theta}})$.

@time tstate, losses, tstates = train(tstate_org, vjp_rule, data, loss_function_esm, 500, 20, 125)
Epoch: 25 || Loss: 0.032675831937050634
Epoch: 50 || Loss: 0.022691772992628065
Epoch: 75 || Loss: 0.011695980092717115
Epoch: 100 || Loss: 0.0041545269742722865
Epoch: 125 || Loss: 0.001266600849065558
Epoch: 150 || Loss: 0.0007529545466822072
Epoch: 175 || Loss: 0.0005941373524815312
Epoch: 200 || Loss: 0.0004803522782232856
Epoch: 225 || Loss: 0.00039198790492492606
Epoch: 250 || Loss: 0.0003257914389588754
Epoch: 275 || Loss: 0.0002743668252861494
Epoch: 300 || Loss: 0.00023285651412602946
Epoch: 325 || Loss: 0.00020334697196066498
Epoch: 350 || Loss: 0.00017471152727146773
Epoch: 375 || Loss: 0.0001520481276131909
Epoch: 400 || Loss: 0.0001333366811236779
Epoch: 425 || Loss: 0.00011750904380669638
Epoch: 450 || Loss: 0.0001036817484123565
Epoch: 475 || Loss: 0.00018705039420784494
Epoch: 500 || Loss: 8.662498033518696e-5
  1.056222 seconds (2.10 M allocations: 169.255 MiB, 2.89% gc time, 98.67% compilation time)

Testing out the trained model.

y_pred = Lux.apply(tstate.model, dev_cpu(x), tstate.parameters, tstate.states)[1]
1×200 Matrix{Float64}:
 7.55523  7.44214  7.32906  7.21598  7.10289  …  -6.83325  -6.9343  -7.03534

Visualizing the result.

plot(title="Fitting", titlefont=10)

plot!(x', y', linewidth=4, label="score function")

scatter!(sample_points', s -> gradlogpdf(target_prob, s), label="data", markersize=2)

plot!(x', y_pred', linewidth=2, label="predicted MLP")
Example block output

Just for the fun of it, let us see an animation of the optimization process.

Example block output

We also visualize the evolution of the losses.

plot(losses, title="Evolution of the loss", titlefont=10, xlabel="iteration", ylabel="error", legend=false)
Example block output

Recovering the PDF of the distribution from the trained score function.
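
Indeed, since the score function is the derivative of the log of the pdf, the pdf can be recovered, up to the normalization constant, by integrating the score function and exponentiating,

\[ p_X(x) = \frac{1}{Z} e^{\int_{x_0}^x \psi_X(s)\;\mathrm{d}s}, \quad Z = \int_{\mathbb{R}} e^{\int_{x_0}^{x} \psi_X(s)\;\mathrm{d}s}\;\mathrm{d}x.\]

In the code below, the inner integral is approximated by a cumulative sum over the grid and the normalization constant by the corresponding Riemann sum.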

paux = exp.(accumulate(+, y_pred) * dx)
pdf_pred = paux ./ sum(paux) ./ dx
plot(title="Original PDF and PDF from predicted score function", titlefont=10)
plot!(x', target_pdf', label="original")
plot!(x', pdf_pred', label="recovered")
Example block output

Training with plain square error loss

Now we attempt to train it with the plain square error loss function. We do not reuse the state from the previous optimization; we start over at the initial state, for the sake of comparison between the different loss functions.

@time tstate, losses, = train(tstate_org, vjp_rule, data, loss_function_esm_plain, 500)
Epoch: 25 || Loss: 1.6404572243566107
Epoch: 50 || Loss: 1.0461319834741905
Epoch: 75 || Loss: 0.6152428977363775
Epoch: 100 || Loss: 0.37437746165136226
Epoch: 125 || Loss: 0.24093628809173245
Epoch: 150 || Loss: 0.1705608385917573
Epoch: 175 || Loss: 0.12994612636312325
Epoch: 200 || Loss: 0.10238079472863867
Epoch: 225 || Loss: 0.08225582263392217
Epoch: 250 || Loss: 0.0662136887189749
Epoch: 275 || Loss: 0.09365226484930252
Epoch: 300 || Loss: 0.052680189314703564
Epoch: 325 || Loss: 0.041915331154278014
Epoch: 350 || Loss: 0.03690214374281758
Epoch: 375 || Loss: 0.03257531422315961
Epoch: 400 || Loss: 0.028878137231816245
Epoch: 425 || Loss: 0.025547490316876056
Epoch: 450 || Loss: 0.03823244697774122
Epoch: 475 || Loss: 0.022773572009922614
Epoch: 500 || Loss: 0.019275856076596393
  0.014497 seconds (119.51 k allocations: 32.316 MiB, 20.51% compilation time)

Testing out the trained model.

y_pred = Lux.apply(tstate.model, dev_cpu(x), tstate.parameters, tstate.states)[1]
1×200 Matrix{Float64}:
 7.00408  6.90341  6.80274  6.70208  6.60141  …  -6.85726  -6.95932  -7.06138

Visualizing the result.

plot(title="Fitting", titlefont=10)

plot!(x', y', linewidth=4, label="score function")

scatter!(sample_points', s -> gradlogpdf(target_prob, s), label="data", markersize=2)

plot!(x', y_pred', linewidth=2, label="predicted MLP")
Example block output

And the evolution of the losses.

plot(losses, title="Evolution of the loss", titlefont=10, xlabel="iteration", ylabel="error", legend=false)
Example block output

Recovering the PDF of the distribution from the trained score function.

paux = exp.(accumulate(+, y_pred) * dx)
pdf_pred = paux ./ sum(paux) ./ dx
plot(title="Original PDF and PDF from predicted score function", titlefont=10)
plot!(x', target_pdf', label="original")
plot!(x', pdf_pred', label="recovered")
Example block output

That is an almost perfect match.

Training with ${\tilde J}_{\mathrm{FDSM}}({\boldsymbol{\theta}})$

Now we attempt to train it with ${\tilde J}_{\mathrm{FDSM}}$. Again we start over with the untrained state of the model.

@time tstate, losses, = train(tstate_org, vjp_rule, data, loss_function_FDSM, 500)
Epoch: 25 || Loss: -0.00801987717734566
Epoch: 50 || Loss: -0.013019028714630024
Epoch: 75 || Loss: -0.018512717053745874
Epoch: 100 || Loss: -0.02228094198567161
Epoch: 125 || Loss: -0.023724248250155983
Epoch: 150 || Loss: -0.023977730851049085
Epoch: 175 || Loss: -0.024058389932759727
Epoch: 200 || Loss: -0.024117417114578113
Epoch: 225 || Loss: -0.024160075045341723
Epoch: 250 || Loss: -0.024190869216752987
Epoch: 275 || Loss: -0.02421528977470945
Epoch: 300 || Loss: -0.024226069280766906
Epoch: 325 || Loss: -0.024246591455218652
Epoch: 350 || Loss: -0.024263590525287943
Epoch: 375 || Loss: -0.024275486399844226
Epoch: 400 || Loss: -0.024284253917517408
Epoch: 425 || Loss: -0.024286643723827814
Epoch: 450 || Loss: -0.0242971875005839
Epoch: 475 || Loss: -0.02425519735993265
Epoch: 500 || Loss: -0.02430940545784391
  0.046199 seconds (148.05 k allocations: 70.730 MiB, 37.02% gc time)

We train a little longer from this state on.

@time tstate, losses_more, = train(tstate, vjp_rule, data, loss_function_FDSM, 500)
append!(losses, losses_more)
Epoch: 25 || Loss: -0.024314316260531047
Epoch: 50 || Loss: -0.024318056363743516
Epoch: 75 || Loss: -0.024321436419437802
Epoch: 100 || Loss: -0.024323537046868883
Epoch: 125 || Loss: -0.02432263411459379
Epoch: 150 || Loss: -0.024327275684534376
Epoch: 175 || Loss: -0.024329536918566097
Epoch: 200 || Loss: -0.024330929113513044
Epoch: 225 || Loss: -0.02433096813431871
Epoch: 250 || Loss: -0.02433375315416231
Epoch: 275 || Loss: -0.024301605060353816
Epoch: 300 || Loss: -0.0243357677965769
Epoch: 325 || Loss: -0.024335904301146773
Epoch: 350 || Loss: -0.02433656122698805
Epoch: 375 || Loss: -0.02433668776082089
Epoch: 400 || Loss: -0.024338925599734572
Epoch: 425 || Loss: -0.024336827206303015
Epoch: 450 || Loss: -0.024338350777532203
Epoch: 475 || Loss: -0.024339610693607558
Epoch: 500 || Loss: -0.024338929529610174
  0.020891 seconds (148.54 k allocations: 70.737 MiB)

Testing out the trained model.

y_pred = Lux.apply(tstate.model, dev_cpu(x), tstate.parameters, tstate.states)[1]
1×200 Matrix{Float64}:
 7.53048  7.41566  7.30085  7.18603  7.07121  …  -6.81457  -6.91532  -7.01606

Visualizing the result.

plot(title="Fitting", titlefont=10)

plot!(x', y', linewidth=4, label="score function")

scatter!(sample_points', s -> gradlogpdf(target_prob, s), label="data", markersize=2)

plot!(x', y_pred', linewidth=2, label="predicted MLP")
Example block output

And the evolution of the losses.

plot(losses, title="Evolution of the loss", titlefont=10, xlabel="iteration", ylabel="error", legend=false)
Example block output

Recovering the PDF of the distribution from the trained score function.

paux = exp.(accumulate(+, y_pred) * dx)
pdf_pred = paux ./ sum(paux) ./ dx
plot(title="Original PDF and PDF from predicted score function", titlefont=10)
plot!(x', target_pdf', label="original")
plot!(x', pdf_pred', label="recovered")
Example block output

Training with ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}({\boldsymbol{\theta}})$

Finally, we attempt to train with the sample data. This is the real thing, using nothing from the supposedly unknown target distribution other than the sample data.

@time tstate, losses, tstates = train(tstate_org, vjp_rule, data, loss_function_FDSM_over_sample, 500, 20, 125)
Epoch: 25 || Loss: -0.1580686370842676
Epoch: 50 || Loss: -0.24964250186566106
Epoch: 75 || Loss: -0.3435181149382883
Epoch: 100 || Loss: -0.4186046789801589
Epoch: 125 || Loss: -0.4551414807655993
Epoch: 150 || Loss: -0.46028365054753734
Epoch: 175 || Loss: -0.4651216890914625
Epoch: 200 || Loss: -0.4669478549206047
Epoch: 225 || Loss: -0.4699121435284616
Epoch: 250 || Loss: -0.46814590731651334
Epoch: 275 || Loss: -0.467525417603086
Epoch: 300 || Loss: -0.4710916971725224
Epoch: 325 || Loss: -0.4733784560388959
Epoch: 350 || Loss: -0.47341817847673534
Epoch: 375 || Loss: -0.4748245704898255
Epoch: 400 || Loss: -0.4758741956904171
Epoch: 425 || Loss: -0.47646245614118193
Epoch: 450 || Loss: -0.47445831542103567
Epoch: 475 || Loss: -0.4648964263032016
Epoch: 500 || Loss: -0.47594087996096424
  0.120386 seconds (149.66 k allocations: 293.365 MiB, 18.36% gc time)

Testing out the trained model.

y_pred = Lux.apply(tstate.model, dev_cpu(x), tstate.parameters, tstate.states)[1]
1×200 Matrix{Float64}:
 9.44848  9.30553  9.16258  9.01964  8.87669  …  -6.77144  -6.87191  -6.97239

Visualizing the result.

plot(title="Fitting", titlefont=10)

plot!(x', y', linewidth=4, label="score function")

scatter!(sample_points', s -> gradlogpdf(target_prob, s), label="data", markersize=2)

plot!(x', y_pred', linewidth=2, label="predicted MLP")
Example block output

Let us see an animation of the optimization process in this case as well, since it is the one of interest.

Example block output

Here is the evolution of the losses.

plot(losses, title="Evolution of the loss", titlefont=10, xlabel="iteration", ylabel="error", legend=false)
Example block output

Recovering the PDF of the distribution from the trained score function.

paux = exp.(accumulate(+, y_pred) * dx)
pdf_pred = paux ./ sum(paux) ./ dx
plot(title="Original PDF and PDF from predicted score function", titlefont=10)
plot!(x', target_pdf', label="original")
plot!(x', pdf_pred', label="recovered")
Example block output

And the evolution of the PDF.

Example block output

Pre-training ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}({\boldsymbol{\theta}})$ with $J_{\mathrm{ESM}}({\boldsymbol{\theta}})$

Let us now pre-train the model with $J_{\mathrm{ESM}}({\boldsymbol{\theta}})$ and see whether the subsequent training with ${\tilde J}_{\mathrm{FDSM}{\tilde p}_0}({\boldsymbol{\theta}})$ improves.

tstate, = train(tstate_org, vjp_rule, data, loss_function_esm, 500)
Epoch: 25 || Loss: 0.032675831937050634
Epoch: 50 || Loss: 0.022691772992628065
Epoch: 75 || Loss: 0.011695980092717115
Epoch: 100 || Loss: 0.0041545269742722865
Epoch: 125 || Loss: 0.001266600849065558
Epoch: 150 || Loss: 0.0007529545466822072
Epoch: 175 || Loss: 0.0005941373524815312
Epoch: 200 || Loss: 0.0004803522782232856
Epoch: 225 || Loss: 0.00039198790492492606
Epoch: 250 || Loss: 0.0003257914389588754
Epoch: 275 || Loss: 0.0002743668252861494
Epoch: 300 || Loss: 0.00023285651412602946
Epoch: 325 || Loss: 0.00020334697196066498
Epoch: 350 || Loss: 0.00017471152727146773
Epoch: 375 || Loss: 0.0001520481276131909
Epoch: 400 || Loss: 0.0001333366811236779
Epoch: 425 || Loss: 0.00011750904380669638
Epoch: 450 || Loss: 0.0001036817484123565
Epoch: 475 || Loss: 0.00018705039420784494
Epoch: 500 || Loss: 8.662498033518696e-5
tstate, losses, = train(tstate, vjp_rule, data, loss_function_FDSM_over_sample, 500)
Epoch: 25 || Loss: -0.4632072985869437
Epoch: 50 || Loss: -0.4640909494884021
Epoch: 75 || Loss: -0.46590898640331585
Epoch: 100 || Loss: -0.46661665177106665
Epoch: 125 || Loss: -0.4669036102174335
Epoch: 150 || Loss: -0.4672225292467862
Epoch: 175 || Loss: -0.46679644435922474
Epoch: 200 || Loss: -0.4668625249287389
Epoch: 225 || Loss: -0.46750725017925027
Epoch: 250 || Loss: -0.46691416878797115
Epoch: 275 || Loss: -0.466527054759414
Epoch: 300 || Loss: -0.4665553999047953
Epoch: 325 || Loss: -0.4662668413859179
Epoch: 350 || Loss: -0.4665742826928974
Epoch: 375 || Loss: -0.46717195656367927
Epoch: 400 || Loss: -0.46693359959953584
Epoch: 425 || Loss: -0.4672689203067035
Epoch: 450 || Loss: -0.4669978898751888
Epoch: 475 || Loss: -0.46632362553584217
Epoch: 500 || Loss: -0.46695600974177137

Testing out the trained model.

y_pred = Lux.apply(tstate.model, dev_cpu(x), tstate.parameters, tstate.states)[1]
1×200 Matrix{Float64}:
 6.28832  6.19808  6.10783  6.01759  5.92735  …  -6.71155  -6.8126  -6.91366

Visualizing the result.

plot(title="Fitting", titlefont=10)

plot!(x', y', linewidth=4, label="score function")

scatter!(sample_points', s -> gradlogpdf(target_prob, s), label="data", markersize=2)

plot!(x', y_pred', linewidth=2, label="predicted MLP")
Example block output

Recovering the PDF of the distribution from the trained score function.

paux = exp.(accumulate(+, y_pred) * dx)
pdf_pred = paux ./ sum(paux) ./ dx
plot(title="Original PDF and PDF from predicted score function", titlefont=10)
plot!(x', target_pdf', label="original")
plot!(x', pdf_pred', label="recovered")
Example block output

And the evolution of the losses.

plot(losses, title="Evolution of the loss", titlefont=10, xlabel="iteration", ylabel="error", legend=false)
Example block output

The need for enough sample points

One interesting thing is that enough sample points are required in the low-probability transition region for a proper fit, as the following example with fewer sample points illustrates.

y = target_score # just to simplify the notation
sample_points = permutedims(rand(rng, target_prob, 128))
data = (x, y, target_pdf, sample_points)
([-10.0 -9.899497487437186 … 9.899497487437186 10.0], [7.0 6.899497487437186 … -6.899497487437186 -7.0], [9.134720408364594e-13 1.8366893783972853e-12 … 1.6530204405575567e-11 8.221248367528135e-12], [1.0847168319001355 3.322724830699249 … -5.180972754540539 2.4620866466437277])
tstate, losses, = train(tstate_org, vjp_rule, data, loss_function_FDSM_over_sample, 500)
Epoch: 25 || Loss: -0.16504722700312094
Epoch: 50 || Loss: -0.2878780909181356
Epoch: 75 || Loss: -0.4376184179721708
Epoch: 100 || Loss: -0.554194545120695
Epoch: 125 || Loss: -0.6066528808048839
Epoch: 150 || Loss: -0.6161260128813595
Epoch: 175 || Loss: -0.6165745598819554
Epoch: 200 || Loss: -0.6176457771716408
Epoch: 225 || Loss: -0.6158904776313515
Epoch: 250 || Loss: -0.6179317552892093
Epoch: 275 || Loss: -0.6179934532810842
Epoch: 300 || Loss: -0.6193691250336079
Epoch: 325 || Loss: -0.6198039842271456
Epoch: 350 || Loss: -0.6207706393028439
Epoch: 375 || Loss: -0.621154960016422
Epoch: 400 || Loss: -0.621924187558537
Epoch: 425 || Loss: -0.6214588526539855
Epoch: 450 || Loss: -0.6227850340221079
Epoch: 475 || Loss: -0.6238055221149185
Epoch: 500 || Loss: -0.6222359497826381

Testing out the trained model.

y_pred = Lux.apply(tstate.model, dev_cpu(x), tstate.parameters, tstate.states)[1]
1×200 Matrix{Float64}:
 7.07824  6.97367  6.86911  6.76454  6.65997  …  -9.06364  -9.19788  -9.33212

Visualizing the result.

plot(title="Fitting", titlefont=10)

plot!(x', y', linewidth=4, label="score function")

scatter!(sample_points', s -> gradlogpdf(target_prob, s), label="data", markersize=2)

plot!(x', y_pred', linewidth=2, label="predicted MLP")
Example block output

Recovering the PDF of the distribution from the trained score function.

paux = exp.(accumulate(+, y_pred) * dx)
pdf_pred = paux ./ sum(paux) ./ dx
plot(title="Original PDF and PDF from predicted score function", titlefont=10)
plot!(x', target_pdf', label="original")
plot!(x', pdf_pred', label="recovered")
Example block output

And the evolution of the losses.

plot(losses, title="Evolution of the loss", titlefont=10, xlabel="iteration", ylabel="error", legend=false)
Example block output

References

  1. Aapo Hyvärinen (2005), "Estimation of non-normalized statistical models by score matching", Journal of Machine Learning Research 6, 695-709
  2. T. Pang, K. Xu, C. Li, Y. Song, S. Ermon, J. Zhu (2020), "Efficient learning of generative models via finite-difference score matching", NeurIPS - see also the arXiv version
  3. Eric J. Ma (2021), "A Pedagogical Introduction to Score Models", webpage, April 21, 2021 - with the associated github repo
  4. D. P. Kingma, J. Ba (2015), "Adam: A method for stochastic optimization", International Conference on Learning Representations (ICLR) - see also the arXiv version
  5. Y. Song and S. Ermon (2019), "Generative modeling by estimating gradients of the data distribution", NIPS'19: Proceedings of the 33rd International Conference on Neural Information Processing Systems, no. 1067, 11918-11930