Score matching a neural network
Introduction
Aim
Apply the score-matching method of Aapo Hyvärinen (2005) to fit a neural network model of the score function to a univariate Gaussian mixture distribution. This borrows the idea from Kingma and LeCun (2010) of using automatic differentiation to differentiate the neural network, and the idea from Song and Ermon (2019) of modeling the score function directly, instead of the pdf or an energy potential for the pdf.
Motivation
The motivation is to revisit the original idea of Aapo Hyvärinen (2005) and see how it performs for training a neural network to model the score function.
Background
The idea of Aapo Hyvärinen (2005) is to fit the score function directly from the sample data, using a suitable implicit score matching loss function that does not depend on the unknown score function of the random variable. This loss function is obtained by a simple integration by parts of the explicit score matching objective function, which is the expected square distance between the score of the model and the score of the unknown target distribution, also known as the Fisher divergence. The integration by parts separates the dependence on the unknown target score function from the parameters of the model, so the fitting process (minimization over the parameters of the model) does not depend on the unknown distribution.
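In the notation used below, with $\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}})$ denoting the model score function and $p_{\mathbf{X}}$ the pdf of the unknown target distribution, the explicit score matching objective of Aapo Hyvärinen (2005) reads
\[ J_{\mathrm{ESM}}({\boldsymbol{\theta}}) = \frac{1}{2}\mathbb{E}_{\mathbf{X} \sim p_{\mathbf{X}}}\left[ \left\| \boldsymbol{\psi}(\mathbf{X}; {\boldsymbol{\theta}}) - \boldsymbol{\nabla}_{\mathbf{x}} \log p_{\mathbf{X}}(\mathbf{X}) \right\|^2 \right],\]
which cannot be minimized directly, since $\boldsymbol{\nabla}_{\mathbf{x}} \log p_{\mathbf{X}}$ is precisely what we do not know.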
The implicit score matching method requires, however, the derivative of the score function of the model, which is costly to compute in general. In Hyvärinen's original work, all the examples considered models for which this derivative could be computed fairly explicitly; no artificial neural network was involved.
In a subsequent work, Köster and Hyvärinen (2010) applied the method to fit the score function of a model probability whose log-likelihood is given by a two-layer neural network, so that the derivative of the score function could still be expressed fairly explicitly.
After that, Kingma and LeCun (2010) considered a larger artificial neural network and used automatic differentiation to optimize the model. They also proposed a penalization term in the loss function, to regularize and stabilize the optimization process, yielding a regularized implicit score matching method. The model in Kingma and LeCun (2010) was not of the pdf directly, but of an energy potential, i.e. with
\[ p_{\boldsymbol{\theta}}(\mathbf{x}) = \frac{1}{Z(\boldsymbol{\theta})} e^{-U(\mathbf{x}; \boldsymbol{\theta})},\]
where $U(\mathbf{x}; \boldsymbol{\theta})$ is modeled by a neural network, so that the score is $\boldsymbol{\nabla}_{\mathbf{x}} \log p_{\boldsymbol{\theta}}(\mathbf{x}) = -\boldsymbol{\nabla}_{\mathbf{x}} U(\mathbf{x}; \boldsymbol{\theta})$, independently of the normalization constant $Z(\boldsymbol{\theta})$.
Finally, Song and Ermon (2019) proposed modeling the score function directly as a neural network $s(\mathbf{x}; \boldsymbol{\theta})$, i.e.
\[ \boldsymbol{\nabla}_{\mathbf{x}} \log p_{\boldsymbol{\theta}}(\mathbf{x}) = s(\mathbf{x}; \boldsymbol{\theta}).\]
Song and Ermon (2019), however, went further and proposed a different method, based on several perturbations of the data, each akin to denoising score matching. At this point, we do not address the main method proposed in Song and Ermon (2019); we only borrow the idea of modeling the score function directly, instead of the pdf or an energy potential of the pdf.
In a sense, we do an analysis in hindsight, combining ideas proposed in subsequent articles to implement the implicit score matching method in a different way. In summary, we illustrate the use of automatic differentiation to apply the implicit score matching and the regularized implicit score matching methods to directly fit the score function modeled by a neural network.
Loss function for implicit score matching
The score-matching method of Aapo Hyvärinen (2005) aims to minimize the empirical implicit score matching loss function ${\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}})$ given by
\[ {\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}}) = \frac{1}{N}\sum_{n=1}^N \left( \frac{1}{2}\|\boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}})\|^2 + \boldsymbol{\nabla}_{\mathbf{x}} \cdot \boldsymbol{\psi}(\mathbf{x}_n; {\boldsymbol{\theta}}) \right),\]
where $(\mathbf{x}_n)_{n=1}^N$ is the sample data from an unknown target distribution and $\boldsymbol{\psi}(\mathbf{x}; {\boldsymbol{\theta}})$ is a parametrized model for the desired score function.
The method rests on the idea of rewriting the explicit score matching loss function $J_{\mathrm{ESM}}({\boldsymbol{\theta}})$ (essentially the Fisher divergence) in terms of the implicit score matching loss function $J_{\mathrm{ISM}}({\boldsymbol{\theta}})$, showing that
\[J_{\mathrm{ESM}}({\boldsymbol{\theta}}) = J_{\mathrm{ISM}}({\boldsymbol{\theta}}) + C,\]
where $C$ is a constant that does not depend on the parameters ${\boldsymbol{\theta}}$, and then approximating the latter by the empirical implicit score matching loss function ${\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}})$.
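For reference, the implicit score matching objective and the constant, as derived in Aapo Hyvärinen (2005), are
\[ J_{\mathrm{ISM}}({\boldsymbol{\theta}}) = \mathbb{E}_{\mathbf{X} \sim p_{\mathbf{X}}}\left[ \frac{1}{2}\left\|\boldsymbol{\psi}(\mathbf{X}; {\boldsymbol{\theta}})\right\|^2 + \boldsymbol{\nabla}_{\mathbf{x}} \cdot \boldsymbol{\psi}(\mathbf{X}; {\boldsymbol{\theta}}) \right], \qquad C = \frac{1}{2}\mathbb{E}_{\mathbf{X} \sim p_{\mathbf{X}}}\left[ \left\|\boldsymbol{\nabla}_{\mathbf{x}} \log p_{\mathbf{X}}(\mathbf{X})\right\|^2 \right],\]
so that replacing the expectation in $J_{\mathrm{ISM}}({\boldsymbol{\theta}})$ by the average over the sample points yields the empirical loss ${\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}})$ above.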
Numerical example
We illustrate the method numerically by modeling a synthetic univariate Gaussian mixture distribution.
Julia language setup
We use the Julia programming language for the numerical simulations, with suitable packages.
Packages
using StatsPlots
using Random
using Distributions
using Lux # artificial neural networks explicitly parametrized
using Optimisers
using Zygote # automatic differentiation
using Markdown
There are several Julia libraries for artificial neural networks and for automatic differentiation (AD). The most established package for artificial neural networks is the FluxML/Flux.jl library, which handles the parameters implicitly, although it is moving towards explicit parameters. A newer library that handles the parameters explicitly is the LuxDL/Lux.jl library, which is tailored to the SciML ecosystem for differential equations.
Since we aim to combine score-matching with neural networks and, eventually, with stochastic differential equations, we thought it was a reasonable idea to experiment with the LuxDL/Lux.jl library.
As we mentioned, the LuxDL/Lux.jl library is a newer package and not as well developed. In particular, it seems the only AD backend that works with it is the FluxML/Zygote.jl library. Unfortunately, the FluxML/Zygote.jl library is not well suited for AD on top of AD, as one can see from e.g. Zygote: Design limitations. Thus we only illustrate the method with a small network on a simple univariate problem.
Reproducibility
We set the random seed for reproducibility purposes.
rng = Xoshiro(12345)
Data
We build the target model and draw samples from it.
The target model is a univariate random variable denoted by $X$ and defined by a probability distribution. Associated with that, we consider its PDF and its score function.
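For a Gaussian mixture such as the one defined below, with weights $w_i$, means $\mu_i$ and variances $\sigma_i^2$, the score function has a closed form, namely the posterior-weighted average of the component scores,
\[ \frac{\mathrm{d}}{\mathrm{d}x} \log p(x) = \sum_i \frac{w_i \mathcal{N}(x; \mu_i, \sigma_i^2)}{\sum_j w_j \mathcal{N}(x; \mu_j, \sigma_j^2)} \left( -\frac{x - \mu_i}{\sigma_i^2} \right),\]
which is the quantity evaluated by gradlogpdf in the code below.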
target_prob = MixtureModel([Normal(-3, 1), Normal(3, 1)], [0.1, 0.9])
xrange = range(-10, 10, 200)
dx = Float64(xrange.step)
xx = permutedims(collect(xrange))
target_pdf = pdf.(target_prob, xrange')
target_score = gradlogpdf.(target_prob, xrange')
lambda = 0.1
sample_points = permutedims(rand(rng, target_prob, 1024))
data = (sample_points, lambda)
([2.303077959422043 2.8428423932782843 … 3.1410080972036334 2.488464630750972], 0.1)
Visualizing the sample data drawn from the distribution and the PDF.
Visualizing the score function.
The neural network model
The neural network we consider is a simple feed-forward neural network made of a single hidden layer, obtained as a chain of a couple of dense layers. This is implemented with the LuxDL/Lux.jl package.
We will see that we don't need a big neural network in this simple example; we keep it about as small as possible while still getting a good fit.
model = Chain(Dense(1 => 8, sigmoid), Dense(8 => 1))
Chain(
layer_1 = Dense(1 => 8, sigmoid_fast), # 16 parameters
layer_2 = Dense(8 => 1), # 9 parameters
) # Total: 25 parameters,
# plus 0 states.
The LuxDL/Lux.jl package uses explicit parameters, which are initialized (or obtained) with the Lux.setup function, giving us the parameters and the state of the model.
ps, st = Lux.setup(rng, model) # initialize and get the parameters and states of the model
((layer_1 = (weight = Float32[-0.0008383376; -0.4411211; … ; 0.15721959; -0.22093461;;], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[0.14955823 0.4725347 … -0.6732873 -0.76267356], bias = Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
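Just to see the shapes involved, we can apply the freshly initialized model to the grid points xx defined above (the variable name y_check below is ours, for this quick check only); the model maps a 1×N matrix of points to a 1×N matrix of score values.
y_check, _ = Lux.apply(model, xx, ps, st) # apply the (untrained) model with explicit parameters and state
size(y_check) # should be (1, 200), one score value per grid point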
Loss function
Here is how we implement the objective ${\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}})$.
function loss_function_EISM_Zygote(model, ps, st, sample_points)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    # model score ψ(xₙ; θ) at each sample point
    y_pred = smodel(sample_points)
    # derivative ∂ₓψ(xₙ; θ) at each sample point, obtained by differentiating the sum of the
    # outputs with respect to the inputs (each output depends only on its own input)
    dy_pred = only(Zygote.gradient(sum ∘ smodel, sample_points))
    # empirical implicit score matching loss: mean of ∂ₓψ + ψ²/2 over the sample
    loss = mean(dy_pred .+ y_pred .^ 2 / 2)
    return loss, smodel.st, ()
end
loss_function_EISM_Zygote (generic function with 1 method)
We also implement a regularized version as proposed by Kingma and LeCun (2010).
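In this univariate setting, writing $\psi(x; {\boldsymbol{\theta}})$ for the scalar model score, the regularized empirical objective implemented below (let us denote it ${\tilde J}_{\mathrm{ISM}{\tilde p}_0,\lambda}$) penalizes the square of the derivative of the model score, with the weight $\lambda$ defined along with the data:
\[ {\tilde J}_{\mathrm{ISM}{\tilde p}_0,\lambda}({\boldsymbol{\theta}}) = \frac{1}{N}\sum_{n=1}^N \left( \frac{1}{2} \psi(x_n; {\boldsymbol{\theta}})^2 + \partial_x \psi(x_n; {\boldsymbol{\theta}}) + \lambda \left(\partial_x \psi(x_n; {\boldsymbol{\theta}})\right)^2 \right).\]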
function loss_function_EISM_Zygote_regularized(model, ps, st, data)
    sample_points, lambda = data
    smodel = StatefulLuxLayer{true}(model, ps, st)
    y_pred = smodel(sample_points)
    dy_pred = only(Zygote.gradient(sum ∘ smodel, sample_points))
    # same loss as above, with the extra penalization term λ (∂ₓψ)² from Kingma and LeCun (2010)
    loss = mean(dy_pred .+ y_pred .^ 2 / 2 .+ lambda .* dy_pred .^ 2)
    return loss, smodel.st, ()
end
loss_function_EISM_Zygote_regularized (generic function with 1 method)
Optimization setup
Optimization method
We use the Adam optimiser.
opt = Adam(0.01)
tstate_org = Lux.Training.TrainState(rng, model, opt)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.sigmoid_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}((layer_1 = Dense(1 => 8, sigmoid_fast), layer_2 = Dense(8 => 1)), nothing)
# of parameters: 25
# of states: 0
optimizer: Adam(0.01, (0.9, 0.999), 1.0e-8)
step: 0
Automatic differentiation in the optimization
As mentioned, we set up the differentiation in LuxDL/Lux.jl with the FluxML/Zygote.jl library.
vjp_rule = Lux.Training.AutoZygote()
ADTypes.AutoZygote()
Processor
We use the CPU instead of the GPU.
dev_cpu = cpu_device()
## dev_gpu = gpu_device()
(::LuxDeviceUtils.LuxCPUDevice) (generic function with 5 methods)
Check differentiation
We check that Zygote, used via Lux, properly differentiates the loss functions for training.
@time Lux.Training.compute_gradients(vjp_rule, loss_function_EISM_Zygote, sample_points, tstate_org)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [LinearAlgebra.Adjoint{Float32, Matrix{Float32}}] x B [FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}]). Converting to common type to to attempt to use BLAS. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/ZEWr3/src/impl/matmul.jl:153
15.322376 seconds (23.52 M allocations: 1.435 GiB, 2.07% gc time, 99.97% compilation time)
The first run is pretty slow, since it involves compiling a specialized method for it. Remember there is already a gradient inside the loss function, so this amounts to a double automatic differentiation. The subsequent runs are faster, but still slow for training:
@time Lux.Training.compute_gradients(vjp_rule, loss_function_EISM_Zygote, sample_points, tstate_org)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [LinearAlgebra.Adjoint{Float32, Matrix{Float32}}] x B [FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}]). Converting to common type to to attempt to use BLAS. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/ZEWr3/src/impl/matmul.jl:153
0.002869 seconds (1.11 k allocations: 1.311 MiB)
Now the version with regularization.
@time Lux.Training.compute_gradients(vjp_rule, loss_function_EISM_Zygote_regularized, data, tstate_org)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [LinearAlgebra.Adjoint{Float32, Matrix{Float32}}] x B [FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}]). Converting to common type to to attempt to use BLAS. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/ZEWr3/src/impl/matmul.jl:153
0.883163 seconds (415.08 k allocations: 25.460 MiB, 99.60% compilation time)
@time Lux.Training.compute_gradients(vjp_rule, loss_function_EISM_Zygote_regularized, data, tstate_org)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [LinearAlgebra.Adjoint{Float32, Matrix{Float32}}] x B [FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}]). Converting to common type to to attempt to use BLAS. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/ZEWr3/src/impl/matmul.jl:153
0.003019 seconds (1.14 k allocations: 1.360 MiB)
Training loop
Here is the typical main training loop suggested in the LuxDL/Lux.jl tutorials, slightly modified to save the history of losses per iteration.
function train(tstate, vjp, data, loss_function, epochs, numshowepochs=20, numsavestates=0)
    losses = zeros(epochs)
    tstates = [(0, tstate)]
    for epoch in 1:epochs
        grads, loss, stats, tstate = Lux.Training.compute_gradients(vjp,
            loss_function, data, tstate)
        if ( epochs ≥ numshowepochs > 0 ) && rem(epoch, div(epochs, numshowepochs)) == 0
            println("Epoch: $(epoch) || Loss: $(loss)")
        end
        if ( epochs ≥ numsavestates > 0 ) && rem(epoch, div(epochs, numsavestates)) == 0
            push!(tstates, (epoch, tstate))
        end
        losses[epoch] = loss
        tstate = Lux.Training.apply_gradients(tstate, grads)
    end
    return tstate, losses, tstates
end
train (generic function with 3 methods)
Training
Now we train the model with the objective function ${\tilde J}_{\mathrm{ISM}{\tilde p}_0}({\boldsymbol{\theta}})$.
@time tstate, losses, tstates = train(tstate_org, vjp_rule, sample_points, loss_function_EISM_Zygote, 500, 20, 100)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [LinearAlgebra.Adjoint{Float32, Matrix{Float32}}] x B [FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}]). Converting to common type to to attempt to use BLAS. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/ZEWr3/src/impl/matmul.jl:153
Epoch: 25 || Loss: -0.026121068195480806
Epoch: 50 || Loss: -0.10245648838075687
Epoch: 75 || Loss: -0.1330339498478429
Epoch: 100 || Loss: -0.15500073536203798
Epoch: 125 || Loss: -0.17861893135915577
Epoch: 150 || Loss: -0.20645410637788614
Epoch: 175 || Loss: -0.23856230210624216
Epoch: 200 || Loss: -0.27344080204253124
Epoch: 225 || Loss: -0.3085457902717294
Epoch: 250 || Loss: -0.34104238259227876
Epoch: 275 || Loss: -0.3685945677238572
Epoch: 300 || Loss: -0.38996727376574997
Epoch: 325 || Loss: -0.40519092069560814
Epoch: 350 || Loss: -0.4152452732298996
Epoch: 375 || Loss: -0.4215095802473945
Epoch: 400 || Loss: -0.42529234883270695
Epoch: 425 || Loss: -0.4275926681335935
Epoch: 450 || Loss: -0.42906658371088513
Epoch: 475 || Loss: -0.43009920286167247
Epoch: 500 || Loss: -0.4308988440394149
0.414083 seconds (379.76 k allocations: 573.619 MiB, 15.10% gc time, 15.59% compilation time)
Results
Testing out the trained model.
y_pred = Lux.apply(tstate.model, xrange', tstate.parameters, tstate.states)[1]
1×200 Matrix{Float64}:
0.28458 0.284291 0.283986 0.283665 … -2.93058 -2.93492 -2.93901
Visualizing the result.
plot(title="Fitting", titlefont=10)
plot!(xrange, target_score', linewidth=4, label="score function")
scatter!(sample_points', s -> gradlogpdf(target_prob, s), label="data", markersize=2)
plot!(xx', y_pred', linewidth=2, label="predicted MLP")
Just for the fun of it, let us see an animation of the optimization process.
Recovering the PDF of the distribution from the trained score function.
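Since the score function is the derivative of the logarithm of the PDF, we can recover the PDF up to a multiplicative constant by integrating the trained score along the grid and exponentiating,
\[ p(x) \propto \exp\left( \int_{x_0}^{x} \psi(\xi; {\boldsymbol{\theta}}) \,\mathrm{d}\xi \right),\]
and the code below approximates the integral by a cumulative sum over the grid and then fixes the constant by normalizing the Riemann sum of the recovered PDF to one.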
paux = exp.(accumulate(+, y_pred) .* dx)
pdf_pred = paux ./ sum(paux) ./ dx
plot(title="Original PDF and PDF from predicted score function", titlefont=10)
plot!(xrange, target_pdf', label="original")
plot!(xrange, pdf_pred', label="recovered")
And the animation of the evolution of the PDF.
We also visualize the evolution of the losses.
plot(losses, title="Evolution of the loss", titlefont=10, xlabel="iteration", ylabel="error", legend=false)
Training with the regularization term
Now we train the model with the regularized objective function ${\tilde J}_{\mathrm{ISM}{\tilde p}_0,\lambda}({\boldsymbol{\theta}})$, i.e. with the penalization term proposed by Kingma and LeCun (2010).
@time tstate, losses, tstates = train(tstate_org, vjp_rule, data, loss_function_EISM_Zygote_regularized, 500, 20, 100)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [LinearAlgebra.Adjoint{Float32, Matrix{Float32}}] x B [FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}]). Converting to common type to to attempt to use BLAS. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/ZEWr3/src/impl/matmul.jl:153
Epoch: 25 || Loss: -0.02551931140871275
Epoch: 50 || Loss: -0.10008760157708277
Epoch: 75 || Loss: -0.12863587208648372
Epoch: 100 || Loss: -0.1484644217548649
Epoch: 125 || Loss: -0.1692681361456623
Epoch: 150 || Loss: -0.19312453983160396
Epoch: 175 || Loss: -0.21978857900107593
Epoch: 200 || Loss: -0.24774625219295204
Epoch: 225 || Loss: -0.27481156751194896
Epoch: 250 || Loss: -0.29885019303018395
Epoch: 275 || Loss: -0.31839504397525503
Epoch: 300 || Loss: -0.3329624490968291
Epoch: 325 || Loss: -0.34297975648715545
Epoch: 350 || Loss: -0.34942177477126557
Epoch: 375 || Loss: -0.3533874177310288
Epoch: 400 || Loss: -0.355808535067763
Epoch: 425 || Loss: -0.3573428201315082
Epoch: 450 || Loss: -0.35839588558614544
Epoch: 475 || Loss: -0.35919382690054624
Epoch: 500 || Loss: -0.35985425434130974
0.342103 seconds (369.28 k allocations: 596.393 MiB, 7.74% gc time, 9.94% compilation time)
Results
Testing out the trained model.
y_pred = Lux.apply(tstate.model, xrange', tstate.parameters, tstate.states)[1]
1×200 Matrix{Float64}:
0.356996 0.356549 0.356079 0.355586 … -2.64119 -2.64653 -2.65158
Visualizing the result.
plot(title="Fitting", titlefont=10)
plot!(xrange, target_score', linewidth=4, label="score function")
scatter!(sample_points', s -> gradlogpdf(target_prob, s), label="data", markersize=2)
plot!(xx', y_pred', linewidth=2, label="predicted MLP")
Just for the fun of it, let us see an animation of the optimization process.
Recovering the PDF of the distribution from the trained score function.
paux = exp.(accumulate(+, y_pred) .* dx)
pdf_pred = paux ./ sum(paux) ./ dx
plot(title="Original PDF and PDF from predicted score function", titlefont=10)
plot!(xrange, target_pdf', label="original")
plot!(xrange, pdf_pred', label="recovered")
And the animation of the evolution of the PDF.
We also visualize the evolution of the losses.
plot(losses, title="Evolution of the loss", titlefont=10, xlabel="iteration", ylabel="error", legend=false)
References
- Aapo Hyvärinen (2005), "Estimation of non-normalized statistical models by score matching", Journal of Machine Learning Research 6, 695-709
- U. Köster, A. Hyvärinen (2010), "A two-layer model of natural stimuli estimated with score matching", Neural Computation 22 (no. 9), 2308-2333, doi: 10.1162/NECO_a_00010
- Durk P. Kingma, Yann LeCun (2010), "Regularized estimation of image statistics by score matching", Advances in Neural Information Processing Systems 23 (NIPS 2010)
- Y. Song and S. Ermon (2019), "Generative modeling by estimating gradients of the data distribution", Advances in Neural Information Processing Systems 32 (NeurIPS 2019), 11918-11930