Score matching a neural network

Introduction

Aim

Apply the score-matching method of Aapo Hyvärinen (2005) to fit a neural network model of the score function to a univariate Gaussian mixture distribution. This borrows two ideas: from Kingma and LeCun (2010), the use of automatic differentiation to differentiate the neural network, and from Song and Ermon (2019), modeling the score function directly, instead of the pdf or an energy potential for the pdf.

Motivation

The motivation is to revisit the original idea of Aapo Hyvärinen (2005) and see how it performs for training a neural network to model the score function.

Background

The idea of Aapo Hyvärinen (2005) is to directly fit the score function from the sample data, using a suitable implicit score matching loss function not depending on the unknown score function of the random variable. This loss function is obtained by a simple integration by parts on the explicit score matching objective function given by the expected square distance between the score of the model and the score of the unknown target distribution, also known as the Fisher divergence. The integration by parts separates the dependence on the unknown target score function from the parameters of the model, so the fitting process (minimization over the parameters of the model) does not depend on the unknown distribution.

The implicit score matching method requires, however, the derivative of the score function of the model pdf, which is costly to compute in general. In Hyvärinen's original work, all the examples considered models for which this derivative could be computed more or less explicitly. No artificial neural network was involved.

In a subsequent work, Köster and Hyvärinen (2010) applied the method to fit the score function of a model probability whose log-likelihood is obtained from a two-layer neural network, so that the derivative of the score function could still be expressed somewhat explicitly.

After that, Kingma and LeCun (2010) considered a larger artificial neural network and used automatic differentiation to optimize the model. They also proposed a penalization term in the loss function, to regularize and stabilize the optimization process, yielding a regularized implicit score matching method. The model in Kingma and LeCun (2010) was not of the pdf directly, but of an energy potential, i.e. with

\[ p_{\boldsymbol{\theta}}(\mathbf{x}) = \frac{1}{Z(\boldsymbol{\theta})} e^{-U(\mathbf{x}; \boldsymbol{\theta})},\]

where $U(\mathbf{x}; \boldsymbol{\theta})$ is modeled by a neural network.
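Note that, for such a model, the normalization constant $Z(\boldsymbol{\theta})$ drops out of the score, since

\[ \boldsymbol{\nabla}_{\mathbf{x}} \log p_{\boldsymbol{\theta}}(\mathbf{x}) = -\boldsymbol{\nabla}_{\mathbf{x}} U(\mathbf{x}; \boldsymbol{\theta}),\]

which is precisely what makes score matching attractive for non-normalized models.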

Finally, Song and Ermon (2019) proposed modeling the score function directly as a neural network $s(\mathbf{x}; \boldsymbol{\theta})$, i.e.

\[ \boldsymbol{\nabla}_{\mathbf{x}}\log p_{\boldsymbol{\theta}}(\mathbf{x}) = s(\mathbf{x}; \boldsymbol{\theta}).\]

Song and Ermon (2019), however, went further and proposed a different method, based on several perturbations of the data, each of which is akin to denoising score matching. At this point, we do not address the main method proposed in Song and Ermon (2019); we only borrow the idea of modeling the score function directly, instead of the pdf or an energy potential of the pdf.

In a sense, we do an analysis in hindsight, combining ideas proposed in subsequent articles to implement the implicit score matching method in a different way. In summary, we illustrate the use of automatic differentiation to apply the implicit score matching and the regularized implicit score matching methods to directly fit the score function as modeled by a neural network.

Loss function for implicit score matching

The score-matching method of Aapo Hyvärinen (2005) aims to minimize the empirical implicit score matching loss function ${\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}})$ given by

\[ {\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}}) = \frac{1}{N}\sum_{n=1}^N \left( \frac{1}{2}\|\boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}})\|^2 + \boldsymbol{\nabla}_{\mathbf{x}} \cdot \boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}}) \right),\]

where $(\mathbf{x}_n)_{n=1}^N$ is the sample data from an unknown target distribution and $\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}})$ is a parametrized model for the desired score function.
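In the univariate case considered below, the divergence term reduces to the spatial derivative of the model, so the loss becomes

\[ {\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}}) = \frac{1}{N}\sum_{n=1}^N \left( \frac{1}{2}\psi(x_n; {\boldsymbol{\theta}})^2 + \frac{\partial}{\partial x}\psi(x_n; {\boldsymbol{\theta}}) \right),\]

which is exactly what the loss functions implemented further below compute, with the derivative obtained by automatic differentiation.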

The method rests on the idea of rewriting the explicit score matching loss function $J_{\mathrm{ESM}}({\boldsymbol{\theta}})$ (essentially the Fisher divergence) in terms of the implicit score matching loss function $J_{\mathrm{ISM}}({\boldsymbol{\theta}})$, showing that

\[J_{\mathrm{ESM}}({\boldsymbol{\theta}}) = J_{\mathrm{ISM}}({\boldsymbol{\theta}}) + C,\]

and then approximating the latter by the empirical implicit score matching loss function ${\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}})$.
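For reference, writing $p_{\mathbf{X}}$ for the pdf of the unknown target distribution, the explicit and implicit score matching objectives read, under suitable regularity and decay conditions,

\[ J_{\mathrm{ESM}}({\boldsymbol{\theta}}) = \frac{1}{2}\int p_{\mathbf{X}}(\mathbf{x}) \left\| \boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}}) - \boldsymbol{\nabla}_{\mathbf{x}} \log p_{\mathbf{X}}(\mathbf{x}) \right\|^2 \,\mathrm{d}\mathbf{x} \]

and

\[ J_{\mathrm{ISM}}({\boldsymbol{\theta}}) = \int p_{\mathbf{X}}(\mathbf{x}) \left( \frac{1}{2}\|\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}})\|^2 + \boldsymbol{\nabla}_{\mathbf{x}} \cdot \boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}}) \right) \,\mathrm{d}\mathbf{x},\]

while the constant $C$ collects the term that depends only on $p_{\mathbf{X}}$ and not on the parameters ${\boldsymbol{\theta}}$.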

Numerical example

We illustrate the method numerically by modeling a synthetic univariate Gaussian mixture distribution.

Julia language setup

We use the Julia programming language for the numerical simulations, with suitable packages.

Packages

using StatsPlots
using Random
using Distributions
using Lux # artificial neural networks explicitly parametrized
using Optimisers
using Zygote # automatic differentiation
using Markdown

There are several Julia libraries for artificial neural networks and for automatic differentiation (AD). The most established package for artificial neural networks is the FluxML/Flux.jl library, which handles the parameters implicitly, although it is moving towards explicit parameters. A newer library that handles the parameters explicitly is the LuxDL/Lux.jl library, which is tailored to the SciML ecosystem of differential equations.

Since we aim to combine score-matching with neural networks and, eventually, with stochastic differential equations, we thought it was a reasonable idea to experiment with the LuxDL/Lux.jl library.

As we mentioned, the LuxDL/Lux.jl library is a newer package and not as mature. In particular, it seems the only AD backend that currently works with it is the FluxML/Zygote.jl library. Unfortunately, the FluxML/Zygote.jl library is not well suited for nested AD (differentiating code that itself computes a gradient), as one can see, e.g., in Zygote: Design limitations. Thus we only illustrate the method with a small network on a simple univariate problem.

Reproducibility

We set the random seed for reproducibility purposes.

rng = Xoshiro(12345)

Data

We build the target model and draw samples from it.

The target model is a univariate random variable denoted by $X$ and defined by a probability distribution. Associated with it, we consider its PDF and its score function.

target_prob = MixtureModel([Normal(-3, 1), Normal(3, 1)], [0.1, 0.9])

xrange = range(-10, 10, 200)
dx = Float64(xrange.step)
xx = permutedims(collect(xrange))
target_pdf = pdf.(target_prob, xrange')
target_score = gradlogpdf.(target_prob, xrange')

lambda = 0.1
sample_points = permutedims(rand(rng, target_prob, 1024))
data = (sample_points, lambda)
([2.303077959422043 2.8428423932782843 … 3.1410080972036334 2.488464630750972], 0.1)
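As a small sanity check (this snippet and the test point are ours and are not needed for the rest of the script), the score of the mixture computed by gradlogpdf should agree with the weighted combination of the component scores divided by the mixture pdf:

x0 = 1.5 # arbitrary test point (our choice)
# weighted combination of the component scores, divided by the mixture pdf
num = sum(w * pdf(c, x0) * gradlogpdf(c, x0) for (w, c) in zip(probs(target_prob), components(target_prob)))
score_manual = num / pdf(target_prob, x0)
score_manual ≈ gradlogpdf(target_prob, x0) # should hold up to floating-point error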

Visualizing the sample data drawn from the distribution and the PDF.

[Figure: sample data drawn from the target distribution, along with the target PDF]

Visualizing the score function.

[Figure: score function of the target distribution]

The neural network model

The neural network we consider is a simple feed-forward neural network made of a single hidden layer, obtained as a chain of a couple of dense layers. This is implemented with the LuxDL/Lux.jl package.

We will see that we don't need a big neural network in this simple example. We keep it about as small as we can while it still works.

model = Chain(Dense(1 => 8, sigmoid), Dense(8 => 1))
Chain(
    layer_1 = Dense(1 => 8, sigmoid_fast),  # 16 parameters
    layer_2 = Dense(8 => 1),            # 9 parameters
)         # Total: 25 parameters,
          #        plus 0 states.

The LuxDL/Lux.jl package uses explicit parameters, which are initialized (or obtained) with the Lux.setup function, giving us the parameters and the state of the model.

ps, st = Lux.setup(rng, model) # initialize and get the parameters and states of the model
((layer_1 = (weight = Float32[-0.0008383376; -0.4411211; … ; 0.15721959; -0.22093461;;], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[0.14955823 0.4725347 … -0.6732873 -0.76267356], bias = Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
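Before setting up the loss function, we can do a quick forward pass of the initialized model on the sample points (this check is ours and is not required for training); Lux.apply returns the model output along with the (here trivial) state:

y_init, _ = Lux.apply(model, sample_points, ps, st) # untrained score predictions for each sample point
size(y_init) # should be (1, 1024): one score value per sample point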

Loss function

Here is how we implement the objective ${\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}})$.

function loss_function_EISM_Zygote(model, ps, st, sample_points)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    # modeled score at each sample point
    y_pred = smodel(sample_points)
    # derivative of the modeled score at each sample point; since each output depends
    # only on its own input point, the gradient of the sum yields the pointwise derivatives
    dy_pred = only(Zygote.gradient(sum ∘ smodel, sample_points))
    # empirical implicit score matching loss
    loss = mean(dy_pred .+ y_pred .^2 / 2)
    return loss, smodel.st, ()
end
loss_function_EISM_Zygote (generic function with 1 method)
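As a quick check (this call is ours; the exact value depends on the random initialization above), we can evaluate the loss at the initial, untrained parameters:

first(loss_function_EISM_Zygote(model, ps, st, sample_points)) # loss value at the initial parameters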

We also implement a regularized version as proposed by Kingma and LeCun (2010).
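In the univariate notation above, and with the weight $\lambda > 0$ defined earlier along with the sample data, this amounts to penalizing the derivative of the modeled score,

\[ \frac{1}{N}\sum_{n=1}^N \left( \frac{1}{2}\psi(x_n; {\boldsymbol{\theta}})^2 + \frac{\partial}{\partial x}\psi(x_n; {\boldsymbol{\theta}}) + \lambda \left( \frac{\partial}{\partial x}\psi(x_n; {\boldsymbol{\theta}}) \right)^2 \right),\]

which is what the implementation below computes.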

function loss_function_EISM_Zygote_regularized(model, ps, st, data)
    sample_points, lambda = data
    smodel = StatefulLuxLayer{true}(model, ps, st)
    y_pred = smodel(sample_points)
    dy_pred = only(Zygote.gradient(sum ∘ smodel, sample_points))
    loss = mean(dy_pred .+ y_pred .^2 / 2 .+ lambda .* dy_pred .^2 )
    return loss, smodel.st, ()
end
loss_function_EISM_Zygote_regularized (generic function with 1 method)

Optimization setup

Optimization method

We use the Adam optimiser.

opt = Adam(0.01)

tstate_org = Lux.Training.TrainState(rng, model, opt)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.sigmoid_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}((layer_1 = Dense(1 => 8, sigmoid_fast), layer_2 = Dense(8 => 1)), nothing)
    # of parameters: 25
    # of states: 0
    optimizer: Adam(0.01, (0.9, 0.999), 1.0e-8)
    step: 0

Automatic differentiation in the optimization

As mentioned, we set up differentiation in LuxDL/Lux.jl with the FluxML/Zygote.jl library.

vjp_rule = Lux.Training.AutoZygote()
ADTypes.AutoZygote()

Processor

We use the CPU instead of the GPU.

dev_cpu = cpu_device()
## dev_gpu = gpu_device()
(::LuxDeviceUtils.LuxCPUDevice) (generic function with 5 methods)

Check differentiation

Check that Zygote, via Lux, can differentiate the loss functions for training.

@time Lux.Training.compute_gradients(vjp_rule, loss_function_EISM_Zygote, sample_points, tstate_org)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [LinearAlgebra.Adjoint{Float32, Matrix{Float32}}] x B [FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}]). Converting to common type to to attempt to use BLAS. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/ZEWr3/src/impl/matmul.jl:153
 15.322376 seconds (23.52 M allocations: 1.435 GiB, 2.07% gc time, 99.97% compilation time)

It is pretty slow the first time it runs, since it involves compiling a specialized method. Remember the loss function itself already contains a gradient, so this amounts to a double (nested) automatic differentiation. Subsequent calls are faster, but still slow for training:

@time Lux.Training.compute_gradients(vjp_rule, loss_function_EISM_Zygote, sample_points, tstate_org)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [LinearAlgebra.Adjoint{Float32, Matrix{Float32}}] x B [FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}]). Converting to common type to to attempt to use BLAS. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/ZEWr3/src/impl/matmul.jl:153
  0.002869 seconds (1.11 k allocations: 1.311 MiB)

Now the version with regularization.

@time Lux.Training.compute_gradients(vjp_rule, loss_function_EISM_Zygote_regularized, data, tstate_org)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [LinearAlgebra.Adjoint{Float32, Matrix{Float32}}] x B [FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}]). Converting to common type to to attempt to use BLAS. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/ZEWr3/src/impl/matmul.jl:153
  0.883163 seconds (415.08 k allocations: 25.460 MiB, 99.60% compilation time)
@time Lux.Training.compute_gradients(vjp_rule, loss_function_EISM_Zygote_regularized, data, tstate_org)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [LinearAlgebra.Adjoint{Float32, Matrix{Float32}}] x B [FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}]). Converting to common type to to attempt to use BLAS. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/ZEWr3/src/impl/matmul.jl:153
  0.003019 seconds (1.14 k allocations: 1.360 MiB)

Training loop

Here is the typical main training loop suggested in the LuxDL/Lux.jl tutorials, slightly modified to save the history of losses per iteration and a few intermediate training states.

function train(tstate, vjp, data, loss_function, epochs, numshowepochs=20, numsavestates=0)
    losses = zeros(epochs)
    tstates = [(0, tstate)]
    for epoch in 1:epochs
        grads, loss, stats, tstate = Lux.Training.compute_gradients(vjp,
            loss_function, data, tstate)
        if ( epochs ≥ numshowepochs > 0 ) && rem(epoch, div(epochs, numshowepochs)) == 0
            println("Epoch: $(epoch) || Loss: $(loss)")
        end
        if ( epochs ≥ numsavestates > 0 ) && rem(epoch, div(epochs, numsavestates)) == 0
            push!(tstates, (epoch, tstate))
        end
        losses[epoch] = loss
        tstate = Lux.Training.apply_gradients(tstate, grads)
    end
    return tstate, losses, tstates
end
train (generic function with 3 methods)

Training

Now we train the model with the objective function ${\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}})$.

@time tstate, losses, tstates = train(tstate_org, vjp_rule, sample_points, loss_function_EISM_Zygote, 500, 20, 100)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [LinearAlgebra.Adjoint{Float32, Matrix{Float32}}] x B [FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}]). Converting to common type to to attempt to use BLAS. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/ZEWr3/src/impl/matmul.jl:153
Epoch: 25 || Loss: -0.026121068195480806
Epoch: 50 || Loss: -0.10245648838075687
Epoch: 75 || Loss: -0.1330339498478429
Epoch: 100 || Loss: -0.15500073536203798
Epoch: 125 || Loss: -0.17861893135915577
Epoch: 150 || Loss: -0.20645410637788614
Epoch: 175 || Loss: -0.23856230210624216
Epoch: 200 || Loss: -0.27344080204253124
Epoch: 225 || Loss: -0.3085457902717294
Epoch: 250 || Loss: -0.34104238259227876
Epoch: 275 || Loss: -0.3685945677238572
Epoch: 300 || Loss: -0.38996727376574997
Epoch: 325 || Loss: -0.40519092069560814
Epoch: 350 || Loss: -0.4152452732298996
Epoch: 375 || Loss: -0.4215095802473945
Epoch: 400 || Loss: -0.42529234883270695
Epoch: 425 || Loss: -0.4275926681335935
Epoch: 450 || Loss: -0.42906658371088513
Epoch: 475 || Loss: -0.43009920286167247
Epoch: 500 || Loss: -0.4308988440394149
  0.414083 seconds (379.76 k allocations: 573.619 MiB, 15.10% gc time, 15.59% compilation time)

Results

Testing out the trained model.

y_pred = Lux.apply(tstate.model, xrange', tstate.parameters, tstate.states)[1]
1×200 Matrix{Float64}:
 0.28458  0.284291  0.283986  0.283665  …  -2.93058  -2.93492  -2.93901
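For a rough quantitative check (this diagnostic is our addition; the actual number depends on the run), we can compute the mean squared error between the predicted score and the exact score on the grid:

mean(abs2, y_pred .- target_score) # mean squared error of the predicted score over the grid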

Visualizing the result.

plot(title="Fitting", titlefont=10)

plot!(xrange, target_score', linewidth=4, label="score function")

scatter!(sample_points', s -> gradlogpdf(target_prob, s), label="data", markersize=2)

plot!(xx', y_pred', linewidth=2, label="predicted MLP")
[Figure: target score function, sample data, and predicted score]

Just for the fun of it, let us see an animation of the optimization process.

[Animation: evolution of the predicted score function along the optimization]

Recovering the PDF of the distribution from the trained score function.
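Since the score is the derivative of the log of the pdf, the pdf can be recovered, up to the normalization constant, by integrating and exponentiating the predicted score,

\[ p_{\boldsymbol{\theta}}(x) \propto \exp\left( \int_{x_0}^{x} \psi(\xi; \boldsymbol{\theta}) \,\mathrm{d}\xi \right).\]

The code below approximates the integral by a cumulative sum over the grid and then normalizes the result with respect to the grid spacing.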

paux = exp.(accumulate(+, y_pred) .* dx)
pdf_pred = paux ./ sum(paux) ./ dx
plot(title="Original PDF and PDF from predicted score function", titlefont=10)
plot!(xrange, target_pdf', label="original")
plot!(xrange, pdf_pred', label="recovered")
[Figure: original PDF and PDF recovered from the predicted score function]

And the animation of the evolution of the PDF.

[Animation: evolution of the recovered PDF along the optimization]

We also visualize the evolution of the losses.

plot(losses, title="Evolution of the loss", titlefont=10, xlabel="iteration", ylabel="error", legend=false)
[Figure: evolution of the loss]

Training with the regularization term

Now we train the model with the regularized objective function, i.e. ${\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}})$ augmented with the penalization term of Kingma and LeCun (2010).

@time tstate, losses, tstates = train(tstate_org, vjp_rule, data, loss_function_EISM_Zygote_regularized, 500, 20, 100)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [LinearAlgebra.Adjoint{Float32, Matrix{Float32}}] x B [FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}]). Converting to common type to to attempt to use BLAS. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/ZEWr3/src/impl/matmul.jl:153
Epoch: 25 || Loss: -0.02551931140871275
Epoch: 50 || Loss: -0.10008760157708277
Epoch: 75 || Loss: -0.12863587208648372
Epoch: 100 || Loss: -0.1484644217548649
Epoch: 125 || Loss: -0.1692681361456623
Epoch: 150 || Loss: -0.19312453983160396
Epoch: 175 || Loss: -0.21978857900107593
Epoch: 200 || Loss: -0.24774625219295204
Epoch: 225 || Loss: -0.27481156751194896
Epoch: 250 || Loss: -0.29885019303018395
Epoch: 275 || Loss: -0.31839504397525503
Epoch: 300 || Loss: -0.3329624490968291
Epoch: 325 || Loss: -0.34297975648715545
Epoch: 350 || Loss: -0.34942177477126557
Epoch: 375 || Loss: -0.3533874177310288
Epoch: 400 || Loss: -0.355808535067763
Epoch: 425 || Loss: -0.3573428201315082
Epoch: 450 || Loss: -0.35839588558614544
Epoch: 475 || Loss: -0.35919382690054624
Epoch: 500 || Loss: -0.35985425434130974
  0.342103 seconds (369.28 k allocations: 596.393 MiB, 7.74% gc time, 9.94% compilation time)

Results

Testing out the trained model.

y_pred = Lux.apply(tstate.model, xrange', tstate.parameters, tstate.states)[1]
1×200 Matrix{Float64}:
 0.356996  0.356549  0.356079  0.355586  …  -2.64119  -2.64653  -2.65158

Visualizing the result.

plot(title="Fitting", titlefont=10)

plot!(xrange, target_score', linewidth=4, label="score function")

scatter!(sample_points', s -> gradlogpdf(target_prob, s), label="data", markersize=2)

plot!(xx', y_pred', linewidth=2, label="predicted MLP")
[Figure: target score function, sample data, and predicted score (regularized training)]

Just for the fun of it, let us see an animation of the optimization process.

[Animation: evolution of the predicted score function along the optimization (regularized training)]

Recovering the PDF of the distribution from the trained score function.

paux = exp.(accumulate(+, y_pred) .* dx)
pdf_pred = paux ./ sum(paux) ./ dx
plot(title="Original PDF and PDF from predicted score function", titlefont=10)
plot!(xrange, target_pdf', label="original")
plot!(xrange, pdf_pred', label="recovered")
[Figure: original PDF and PDF recovered from the predicted score function (regularized training)]

And the animation of the evolution of the PDF.

[Animation: evolution of the recovered PDF along the optimization (regularized training)]

We also visualize the evolution of the losses.

plot(losses, title="Evolution of the loss", titlefont=10, xlabel="iteration", ylabel="error", legend=false)
[Figure: evolution of the loss (regularized training)]

References

  1. Aapo Hyvärinen (2005), "Estimation of non-normalized statistical models by score matching", Journal of Machine Learning Research 6, 695-709
  2. U. Köster, A. Hyvärinen (2010), "A two-layer model of natural stimuli estimated with score matching", Neural Computation 22 (no. 9), 2308-2333, doi: 10.1162/NECO_a_00010
  3. Durk P. Kingma, Yann LeCun (2010), "Regularized estimation of image statistics by Score Matching", Advances in Neural Information Processing Systems 23 (NIPS 2010)
  4. Y. Song and S. Ermon (2019), "Generative modeling by estimating gradients of the data distribution", NIPS'19: Proceedings of the 33rd International Conference on Neural Information Processing Systems, no. 1067, 11918-11930