Multiple denoising score matching with annealed Langevin dynamics
Introduction
Aim
Review the (multiple denoising) score matching with Langevin dynamics (SMLD) method, which fits a noise conditional score network (NCSN), as introduced by Song and Ermon (2019). Together with DDPM, this method was a step towards the score-based SDE model.
Background
After Aapo Hyvärinen (2005) suggested fitting the score function of a distribution, several directions were undertaken to improve the quality of the method and make it more practical.
One of the approaches was the denoising score matching of Pascal Vincent (2011), in which the data is corrupted by Gaussian noise and the model is trained to correctly denoise the corrupted data. The model would be either of the pdf itself or of an energy potential for the pdf. In either case, one would have a model for the pdf and could draw samples directly from it.
Song and Ermon (2019) came up with two ideas tied together. The first idea was to model the score function directly and use the Langevin equation to draw samples from it. One difficulty with Langevin sampling, however, is in correctly estimating the weights of multimodal distributions: it may overestimate or underestimate some modal regions, depending on where the initial distribution of points is located relative to the modal regions. It may take a long time to reach the desired distribution.
In order to overcome that, Song and Ermon (2019) also proposed using an annealed version of Langevin dynamics, based on a scale of denoising score matching models, with different levels of noise, instead of a single denoising. Lower noise levels are closer to the target distribution but are challenging for Langevin sampling, while higher noise levels are better for Langevin sampling but depart from the target distribution. Combining different levels of noise and gradually sampling between the different denoising models improves the modeling and sampling of a distribution. That is the idea of their proposed noise conditional score network (NCSN) framework, in a method that was later named denoising score matching with Langevin dynamics (SMLD), and for which a more precise description would be (multiple denoising) score matching with (annealed) Langevin dynamics.
Multiple denoising score matching
The idea is to consider a sequence of denoising score matching models, starting with a relatively large noise level $\sigma_1$, to avoid the difficulties with Langevin sampling described earlier, and ending with a relatively small noise level $\sigma_L$, to minimize the effect of the noise on the data.
For training, one directly trains a score model according to a weighted loss involving all noise levels.
Then, for sampling, one runs a corresponding sequence of Langevin dynamics, with decreasing levels of noise, driving new samples closer and closer to the target distribution.
The model
More precisely, one starts with a positive geometric sequence of noise levels $\sigma_1, \ldots, \sigma_L$ satisfying
\[ \frac{\sigma_1}{\sigma_2} = \cdots = \frac{\sigma_{L-1}}{\sigma_L} > 1,\]
which is the same as
\[ \sigma_i = \theta^{i-1} \sigma_1, \quad i = 1, \ldots, L,\]
for a starting $\sigma_1 > 0$ and a rate $0 < \theta < 1$ given by $\theta = \sigma_2/\sigma_1 = \ldots = \sigma_L/\sigma_{L-1}$.
For each $\sigma=\sigma_i$, $i=1, \ldots, L$, one considers the perturbed distribution
\[ p_{\sigma}(\tilde{\mathbf{x}}) = \int_{\mathbb{R}^d} p(\mathbf{x})p_{\sigma}(\tilde{\mathbf{x}}|\mathbf{x})\;\mathrm{d}\mathbf{x},\]
with a perturbation kernel
\[ p_{\sigma}(\tilde{\mathbf{x}}|\mathbf{x}) = \mathcal{N}\left(\tilde{\mathbf{x}}; \mathbf{x}, \sigma^2 \mathbf{I}\right).\]
This yields a sequence of perturbed distributions
\[ \{p_{\sigma_i}\}_{i=1}^L.\]
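As a quick sanity check of the perturbation step, here is a minimal sketch that draws one perturbed point per clean point, $\tilde{\mathbf{x}} = \mathbf{x} + \sigma\mathbf{z}$, and compares the empirical variance with the variance of the convolved density. The Gaussian data distribution, the names `mu` and `s`, the sample size and the three noise levels are assumptions chosen only so that the answer, $s^2 + \sigma^2$, is known in closed form.

```julia
using Random, Statistics

rng = Xoshiro(42)
mu, s = 1.0, 0.5                 # hypothetical data distribution N(mu, s^2)
M = 100_000                      # number of clean samples (illustrative)
x = mu .+ s .* randn(rng, M)     # clean samples x ~ p
sigmas = [2.0, 1.0, 0.5]         # a few illustrative noise levels

# One perturbed sample per clean sample and per level: x_tilde = x + sigma * z
perturbed_vars = [var(x .+ sigma .* randn(rng, M)) for sigma in sigmas]

# Convolving N(mu, s^2) with N(0, sigma^2) gives N(mu, s^2 + sigma^2)
expected_vars = [s^2 + sigma^2 for sigma in sigmas]

println(hcat(perturbed_vars, expected_vars))
```

The two columns should agree up to sampling error, with the largest noise level dominating the variance of the perturbed data.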
We model the corresponding family of score functions by $\{s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}, \sigma_i)\}$, so that $s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}, \sigma_i)$ approximates the score function of $p_{\sigma_i}$, i.e.
\[ s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}, \sigma_i) \approx \boldsymbol{\nabla}_{\tilde{\mathbf{x}}} \log p_{\sigma_i}(\tilde{\mathbf{x}}).\]
The noise conditional score network (NCSN) is precisely
\[ s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}, \sigma).\]
The loss function
One wants to train the noise conditional score network $s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}, \sigma)$ by weighting together the denoising loss functions of all the perturbations, i.e.
\[ J_{\textrm{SMLD}}(\boldsymbol{\theta}) = \frac{1}{2L}\sum_{i=1}^L \lambda(\sigma_i) \mathbb{E}_{p(\mathbf{x})p_{\sigma_i}(\tilde{\mathbf{x}}|\mathbf{x})}\left[ \left\| s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}, \sigma_i) - \frac{\mathbf{x} - \tilde{\mathbf{x}}}{\sigma_i^2} \right\|^2 \right],\]
where $\lambda = \lambda(\sigma_i)$ is a weighting factor.
In practice, we use the empirical distribution and a single corrupted point for each sample point and each noise level, i.e.
\[ {\tilde J}_{\textrm{SMLD}}(\boldsymbol{\theta}) = \frac{1}{2LN} \sum_{n=1}^N \sum_{i=1}^L \lambda(\sigma_i)\left\| s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}_{n, i}, \sigma_i) - \frac{\mathbf{x}_n - \tilde{\mathbf{x}}_{n, i}}{\sigma_i^2} \right\|^2, \quad \tilde{\mathbf{x}}_{n, i} \sim \mathcal{N}\left(\mathbf{x}_n, \sigma_i^2 \mathbf{I}\right).\]
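This can also be written with the reparametrization $\tilde{\mathbf{x}}_{n, i} = \mathbf{x}_n + \sigma_i\boldsymbol{\epsilon}_{n, i}$, for which $(\mathbf{x}_n - \tilde{\mathbf{x}}_{n, i})/\sigma_i^2 = -\boldsymbol{\epsilon}_{n, i}/\sigma_i$, so that
\[ {\tilde J}_{\textrm{SMLD}}(\boldsymbol{\theta}) = \frac{1}{2LN} \sum_{n=1}^N \sum_{i=1}^L \lambda(\sigma_i) \left\| s_{\boldsymbol{\theta}}(\mathbf{x}_n + \sigma_i\boldsymbol{\epsilon}_{n, i}, \sigma_i) + \frac{\boldsymbol{\epsilon}_{n, i}}{\sigma_i} \right\|^2, \quad \boldsymbol{\epsilon}_{n, i} \sim \mathcal{N}\left(\mathbf{0}, \mathbf{I}\right).\]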
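The equivalence between the two ways of writing a summand, with $\tilde{\mathbf{x}} = \mathbf{x} + \sigma\boldsymbol{\epsilon}$ and $\boldsymbol{\epsilon}$ a standard normal draw, can be checked numerically. The sketch below does so in one dimension; the function `s_model` is an arbitrary stand-in for the network (a hypothetical choice, not the trained model), and the values of `x` and `sigma` are illustrative.

```julia
using Random

rng = Xoshiro(1)

# Arbitrary stand-in for the score network s_theta(x, sigma); hypothetical.
s_model(x, sigma) = -x / (1 + sigma^2)

x, sigma = 0.7, 0.5              # one data point and one noise level
eps_draw = randn(rng)            # standard normal draw
x_noisy = x + sigma * eps_draw   # corrupted point

# Summand written with the corrupted point and the kernel score target
term_direct = abs2(s_model(x_noisy, sigma) - (x - x_noisy) / sigma^2)

# Same summand written with the reparametrization
term_reparam = abs2(s_model(x + sigma * eps_draw, sigma) + eps_draw / sigma)

println((term_direct, term_reparam))
```

Both values coincide up to floating-point rounding, since $(\mathbf{x} - \tilde{\mathbf{x}})/\sigma^2 = -\boldsymbol{\epsilon}/\sigma$.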
As for the choice of $\lambda(\sigma)$, Song and Ermon (2019) suggested choosing
\[ \lambda(\sigma) = \sigma^2.\]
This comes from the observation that
\[ \|s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}_{n, i}, \sigma_i)\| \sim \frac{1}{\sigma_i},\]
hence
\[ \lambda(\sigma_i)\left\| s_{\boldsymbol{\theta}}(\mathbf{x}_n + \sigma_i\boldsymbol{\epsilon}_{n, i}, \sigma_i) + \frac{\boldsymbol{\epsilon}_{n, i}}{\sigma_i} \right\|^2 \sim 1\]
is independent of $i=1, \ldots, L$. Choosing such weighting, the loss function becomes
\[ {\tilde J}_{\textrm{SMLD}}(\boldsymbol{\theta}) = \frac{1}{2LN} \sum_{n=1}^N \sum_{i=1}^L \left\| \sigma_i s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}_{n, i}, \sigma_i) - \frac{\mathbf{x}_n - \tilde{\mathbf{x}}_{n, i}}{\sigma_i}\right\|^2, \quad \tilde{\mathbf{x}}_{n, i} \sim \mathcal{N}\left(\mathbf{x}_n, \sigma_i^2 \mathbf{I}\right).\]
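For reference, the effect of the weighting $\lambda(\sigma_i) = \sigma_i^2$ on a single summand can be worked out term by term, writing $\tilde{\mathbf{x}}_{n, i} = \mathbf{x}_n + \sigma_i\boldsymbol{\epsilon}_{n, i}$ with $\boldsymbol{\epsilon}_{n, i} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$:
\[ \sigma_i^2 \left\| s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}_{n, i}, \sigma_i) - \frac{\mathbf{x}_n - \tilde{\mathbf{x}}_{n, i}}{\sigma_i^2} \right\|^2 = \left\| \sigma_i s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}_{n, i}, \sigma_i) - \frac{\mathbf{x}_n - \tilde{\mathbf{x}}_{n, i}}{\sigma_i} \right\|^2 = \left\| \sigma_i s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}_{n, i}, \sigma_i) + \boldsymbol{\epsilon}_{n, i} \right\|^2.\]
In this form, the rescaled network output $\sigma_i s_{\boldsymbol{\theta}}$ is compared with a standard normal variable at every noise level, and the target satisfies $\mathbb{E}\left[\|\boldsymbol{\epsilon}_{n, i}\|^2\right] = d$ regardless of $i$.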
Sampling
For each $i=1, \ldots, L$, the dynamics of the overdamped Langevin equation
\[ \mathrm{d}\mathbf{X}_t = \boldsymbol{\nabla}_{\tilde{\mathbf{x}}} \log p_{\sigma_i}(\mathbf{X}_t)\;\mathrm{d}t + \sqrt{2}\;\mathrm{d}\mathbf{W}_t\]
drives any initial sample towards the distribution defined by $p_{\sigma_i}$. With $s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}, \sigma_i)$ being an approximation of $\boldsymbol{\nabla}_{\tilde{\mathbf{x}}} \log p_{\sigma_i}(\tilde{\mathbf{x}})$ and with $p_{\sigma_i}(\tilde{\mathbf{x}})$ being closer to the target $p(\mathbf{x})$ the smaller the $\sigma_i$, the idea is to run batches of Langevin dynamics for decreasing values of noise, i.e. for $\sigma_1$ down to $\sigma_L$.
More precisely, given $K\in\mathbb{N}$, we run the Langevin equation for $K$ steps, for each $i=1, \ldots, L$:
1. Start with $M$ sample points $\mathbf{y}_m$, $m=1, \ldots, M$, $M\in\mathbb{N}$, drawn from a multivariate normal distribution, a uniform distribution, or any other known distribution.
2. Then for each $i=1, \ldots, L$, run the discretized overdamped Langevin equation for $K$ steps
\[ \mathbf{y}^i_{m, k} = \mathbf{y}^i_{m, k-1} + s_{\boldsymbol{\theta}}(\mathbf{y}^i_{m, k-1}, \sigma_i) \tau_i + \sqrt{2\tau_i}\mathbf{z}^i_{m, k},\]
where $\tau_i > 0$ is a given time step (which may or may not vary with $i$); the $\mathbf{z}^i_{m, k} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ are independent; and the initial conditions are given by
\[ \mathbf{y}^1_{m, 0} = \mathbf{y}_m,\]
for $i=1$, and
\[ \mathbf{y}^i_{m, 0} = \mathbf{y}^{i-1}_{m, K},\]
for $i = 2, \ldots, L$, i.e. the final $K$th step of the Langevin dynamics at a given level $i = 1, \ldots, L-1$ is the initial state of the Langevin dynamics at level $i+1$.
3. The final points $\mathbf{y}^L_{m, K}$, $m=1, \ldots, M$, are the $M$ desired new generated samples of the distribution approximating the data distribution.
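The three steps above can be sketched as follows. This is only an illustration under stated assumptions: the trained network is replaced by the analytic score of $p_\sigma = \mathcal{N}(\mu, s^2 + \sigma^2)$ for a hypothetical Gaussian data distribution, and the step size follows the schedule $\tau_i = \tau_0 \sigma_i^2/\sigma_L^2$, which has the same form as the schedule used by Song and Ermon (2019) but is just one possible choice. The function name `annealed_langevin` and the values of `L`, `K`, `M` and `tau0` are likewise illustrative.

```julia
using Random, Statistics

# Analytic score of p_sigma = N(mu, s^2 + sigma^2); a stand-in for s_theta (assumption).
score(x, sigma; mu=1.0, s=0.5) = -(x - mu) / (s^2 + sigma^2)

function annealed_langevin(rng, score, y0, sigmas, K, tau0)
    y = copy(y0)
    for sigma in sigmas                        # sigma_1 > ... > sigma_L
        tau = tau0 * sigma^2 / sigmas[end]^2   # assumed step-size schedule
        for _ in 1:K                           # K Euler-Maruyama steps per level
            z = randn(rng, length(y))
            y .= y .+ tau .* score.(y, sigma) .+ sqrt(2 * tau) .* z
        end
    end                                        # last state of level i starts level i+1
    return y
end

rng = Xoshiro(2024)
L, K, M = 10, 100, 20_000
sigmas = [2.0 * (0.1 / 2.0)^((i - 1) / (L - 1)) for i in 1:L]  # geometric, 2.0 down to 0.1
y0 = 2.0 .* randn(rng, M)                      # step 1: initial points from a known distribution
samples = annealed_langevin(rng, score, y0, sigmas, K, 0.002)  # steps 2 and 3

println((mean(samples), var(samples)))
```

With these choices the final samples should have mean close to $\mu = 1$ and variance close to $s^2 + \sigma_L^2 = 0.26$, i.e. they follow the least perturbed distribution rather than the data distribution itself.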
Numerical example
We illustrate, numerically, the use of multiple denoising score matching to model a synthetic univariate Gaussian mixture distribution.
Julia language setup
We use the Julia programming language for the numerical simulations, with suitable packages.
Packages
using StatsPlots
using Random
using Distributions
using Lux # artificial neural networks explicitly parametrized
using Optimisers
using Zygote # automatic differentiation
using Markdown
Reproducibility
We set the random seed for reproducibility purposes.
rng = Xoshiro(12345)
Data
We build the usual target model and draw samples from it.
Visualizing the sample data drawn from the distribution and the PDF.
Visualizing the score function.
Parameters
Here we set some parameters for the model and prepare any necessary data.
L = 16
sigma_1 = 2.0
sigma_L = 0.5
theta = ( sigma_L / sigma_1 )^(1/(L-1))
sigmas = [sigma_1 * theta ^ (i-1) for i in 1:L]
16-element Vector{Float64}:
2.0
1.8234449771164336
1.6624757922855755
1.515716566510398
1.381912879967776
1.2599210498948732
1.1486983549970349
1.0472941228206267
0.9548416039104165
0.8705505632961241
0.7937005259840997
0.723634618720189
0.6597539553864471
0.6015125180410583
0.548412489847313
0.49999999999999994
data = (sample_points, sigmas)
([2.303077959422043 2.8428423932782843 … 3.1410080972036334 2.488464630750972], [2.0, 1.8234449771164336, 1.6624757922855755, 1.515716566510398, 1.381912879967776, 1.2599210498948732, 1.1486983549970349, 1.0472941228206267, 0.9548416039104165, 0.8705505632961241, 0.7937005259840997, 0.723634618720189, 0.6597539553864471, 0.6015125180410583, 0.548412489847313, 0.49999999999999994])
The neural network model
The neural network we consider is a simple feed-forward neural network made of a single hidden layer, obtained as a chain of a couple of dense layers. This is implemented with the LuxDL/Lux.jl package.
We will see that we don't need a big neural network in this simple example. We make it as small as it can be while still working.
model = Chain(Dense(2 => 64, relu), Dense(64 => 1))
Chain(
layer_1 = Dense(2 => 64, relu), # 192 parameters
layer_2 = Dense(64 => 1), # 65 parameters
) # Total: 257 parameters,
# plus 0 states.
The LuxDL/Lux.jl package uses explicit parameters, which are initialized (or obtained) with the Lux.setup function, giving us the parameters and the state of the model.
ps, st = Lux.setup(rng, model) # initialize and get the parameters and states of the model
((layer_1 = (weight = Float32[0.11885788 -0.17188427; -0.106466874 0.1387838; … ; 0.088112965 0.14238097; 0.29127795 0.1525562], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[0.078679904 0.17515078 … 0.1578623 -0.2923103], bias = Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
Loss function
function loss_function_mdsm(model, ps, st, data)
    sample_points, sigmas = data
    # L×N matrix: row i holds the sample points perturbed at noise level sigmas[i]
    noisy_sample_points = sample_points .+ sigmas .* randn(rng, size(sample_points))
    # score of the Gaussian perturbation kernel, used as the regression target
    scores = ( sample_points .- noisy_sample_points ) ./ sigmas .^ 2
    flattened_noisy_sample_points = reshape(noisy_sample_points, 1, :)
    # noise levels repeated to match the column-major flattening above
    flattened_sigmas = repeat(sigmas', 1, length(sample_points))
    model_input = [flattened_noisy_sample_points; flattened_sigmas]
    y_score_pred, st = Lux.apply(model, model_input, ps, st)
    flattened_scores = reshape(scores, 1, :)
    # weighting λ(σ) = σ², applied as a factor σ inside the squared norm
    loss = mean(abs2, flattened_sigmas .* (y_score_pred .- flattened_scores)) / 2
    return loss, st, ()
end
loss_function_mdsm (generic function with 1 method)
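The loss function above relies on the flattened perturbed points and the repeated noise levels lining up column by column. Here is a small sketch of that bookkeeping with plain arrays, where the noise is set to zero so the layout is easy to read; the three data values and two noise levels are made up for the illustration.

```julia
sample_points = [10.0 20.0 30.0]      # 1×N row of hypothetical data (N = 3)
sigmas = [1.0, 0.5]                   # hypothetical noise levels (L = 2)

# Same broadcasting as in the loss, but with zero noise: an L×N matrix
# whose row i would hold the points perturbed at level sigmas[i].
noisy = sample_points .+ sigmas .* zeros(size(sample_points))

# Column-major flattening: all levels of point 1, then all levels of point 2, ...
flattened_points = reshape(noisy, 1, :)

# Noise levels repeated N times, matching the ordering above
flattened_sigmas = repeat(sigmas', 1, length(sample_points))

model_input = [flattened_points; flattened_sigmas]   # 2×(L*N) network input
display(model_input)
```

Each column of `model_input` is one (point, noise level) pair, which is what the two-input network expects.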
Optimization setup
Optimization method
We use the Adam optimiser.
opt = Adam(0.01)
tstate_org = Lux.Training.TrainState(rng, model, opt)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}((layer_1 = Dense(2 => 64, relu), layer_2 = Dense(64 => 1)), nothing)
# of parameters: 257
# of states: 0
optimizer: Adam(0.01, (0.9, 0.999), 1.0e-8)
step: 0
Automatic differentiation in the optimization
As mentioned, we set up differentiation in LuxDL/Lux.jl with the FluxML/Zygote.jl library.
vjp_rule = Lux.Training.AutoZygote()
ADTypes.AutoZygote()
Processor
We use the CPU instead of the GPU.
dev_cpu = cpu_device()
## dev_gpu = gpu_device()
(::LuxDeviceUtils.LuxCPUDevice) (generic function with 5 methods)
Check differentiation
Check if Zygote via Lux is working fine to differentiate the loss functions for training.
Lux.Training.compute_gradients(vjp_rule, loss_function_mdsm, data, tstate_org)
((layer_1 = (weight = Float32[-0.16558973 -0.041655425; 0.115008354 0.030782552; … ; 0.39658588 0.10170903; 0.4404184 0.11116221], bias = Float32[-0.024725612; 0.018824441; … ; 0.061242532; 0.06660229;;]), layer_2 = (weight = Float32[0.25526664 0.26035213 … 0.09442038 0.5083516], bias = Float32[0.2629434;;])), 0.4839145961178523, (), Lux.Training.TrainState{Nothing, Nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Optimisers.Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}}}(nothing, nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}((layer_1 = Dense(2 => 64, relu), layer_2 = Dense(64 => 1)), nothing), (layer_1 = (weight = Float32[0.12625368 0.12188108; 0.18178563 -0.08310586; … ; 0.058616165 -0.0022444578; 0.2997266 0.05032168], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[-0.10178894 0.0704763 … 0.24378277 0.27067676], bias = 
Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), Adam(0.01, (0.9, 0.999), 1.0e-8), (layer_1 = (weight = Leaf(Adam(0.01, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.01, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam(0.01, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.01, (0.9, 0.999), 1.0e-8), (Float32[0.0;;], Float32[0.0;;], (0.9, 0.999))))), 0))
Training loop
Here is the typical main training loop suggested in the LuxDL/Lux.jl tutorials, but slightly modified to save the history of losses per iteration.
function train(tstate, vjp, data, loss_function, epochs, numshowepochs=20, numsavestates=0)
    losses = zeros(epochs)
    tstates = [(0, tstate)]
    for epoch in 1:epochs
        grads, loss, stats, tstate = Lux.Training.compute_gradients(vjp,
            loss_function, data, tstate)
        if ( epochs ≥ numshowepochs > 0 ) && rem(epoch, div(epochs, numshowepochs)) == 0
            println("Epoch: $(epoch) || Loss: $(loss)")
        end
        if ( epochs ≥ numsavestates > 0 ) && rem(epoch, div(epochs, numsavestates)) == 0
            push!(tstates, (epoch, tstate))
        end
        losses[epoch] = loss
        tstate = Lux.Training.apply_gradients(tstate, grads)
    end
    return tstate, losses, tstates
end
train (generic function with 3 methods)
Training
Now we train the model with the objective function ${\tilde J}_{\textrm{SMLD}}({\boldsymbol{\theta}})$.
@time tstate, losses, tstates = train(tstate_org, vjp_rule, data, loss_function_mdsm, 1000, 20, 125)
Epoch: 50 || Loss: 0.3273645129483826
Epoch: 100 || Loss: 0.3337394634350786
Epoch: 150 || Loss: 0.323018634313816
Epoch: 200 || Loss: 0.3107805815721189
Epoch: 250 || Loss: 0.2921578105878577
Epoch: 300 || Loss: 0.2997176522570383
Epoch: 350 || Loss: 0.2819946712084884
Epoch: 400 || Loss: 0.31059622850316665
Epoch: 450 || Loss: 0.302095911993006
Epoch: 500 || Loss: 0.3041323161998701
Epoch: 550 || Loss: 0.2913061692857132
Epoch: 600 || Loss: 0.30119762097901603
Epoch: 650 || Loss: 0.3248521166140513
Epoch: 700 || Loss: 0.3081915535443235
Epoch: 750 || Loss: 0.2926553609360526
Epoch: 800 || Loss: 0.283543851027674
Epoch: 850 || Loss: 0.2899438329188816
Epoch: 900 || Loss: 0.262881496007787
Epoch: 950 || Loss: 0.28384273036379404
Epoch: 1000 || Loss: 0.29185283449284694
6.104041 seconds (431.80 k allocations: 26.562 GiB, 4.33% gc time, 0.55% compilation time)
Results
Checking out the trained model.
Visualizing the result with the smallest noise.
Recovering the PDF of the distribution from the trained score function.
With the smallest noise.
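In one dimension, recovering a PDF from a score function amounts to integrating the score to get the log-density up to an additive constant, exponentiating, and normalizing. The following is a minimal sketch of that procedure, assuming the score of a standard normal as a stand-in for the trained model at a fixed noise level; the grid, the trapezoidal rule and all names are illustrative choices, not necessarily what was used for the figures.

```julia
xs = range(-6.0, 6.0, length=1201)    # evaluation grid (illustrative)
dx = step(xs)

# Stand-in score values; with a trained model these would be its outputs on xs.
score_vals = [-x for x in xs]         # score of N(0, 1) (assumption)

# Cumulative trapezoidal rule: log-density up to an additive constant
increments = (score_vals[1:end-1] .+ score_vals[2:end]) ./ 2 .* dx
logp = vcat(0.0, cumsum(increments))

# Exponentiate (shifted for numerical stability) and normalize on the grid
p_unnorm = exp.(logp .- maximum(logp))
pdf_rec = p_unnorm ./ (sum(p_unnorm) * dx)

pdf_true = [exp(-x^2 / 2) / sqrt(2 * pi) for x in xs]
println(maximum(abs.(pdf_rec .- pdf_true)))
```

The normalization is only as good as the grid: the interval must be wide enough to contain essentially all of the mass.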
Just for the fun of it, let us see an animation of the optimization process.
And the animation of the evolution of the PDF.
We also visualize the evolution of the losses.
Sampling with annealed Langevin
Now we sample the modeled distribution with the annealed Langevin method described earlier.
Here are the trajectories.
The sample histogram obtained at the end of the trajectories.
References
- Aapo Hyvärinen (2005), "Estimation of non-normalized statistical models by score matching", Journal of Machine Learning Research 6, 695-709
- Pascal Vincent (2011), "A connection between score matching and denoising autoencoders", Neural Computation 23 (7), 1661-1674, doi:10.1162/NECO_a_00142
- Y. Song and S. Ermon (2019), "Generative modeling by estimating gradients of the data distribution", NIPS'19: Proceedings of the 33rd International Conference on Neural Information Processing Systems, no. 1067, 11918-11930