Score-based generative modeling through stochastic differential equations

Introduction

Aim

Review the work of Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole (2020), which takes a complex data distribution, adds noise to it via a stochastic differential equation, and generates new samples by modeling the reverse process. It generalizes to continuous time the earlier discrete-noise approaches of denoising diffusion probabilistic models and multiple denoising score matching.

Background

After Aapo Hyvärinen (2005) proposed implicit score matching to model a distribution by fitting its score function, several works followed, including the denoising score matching of Pascal Vincent (2011), which perturbs the data so that the analytic expression of the score function of the perturbation kernel can be used. The denoising diffusion probabilistic models of Sohl-Dickstein, Weiss, Maheswaranathan, Ganguli (2015) and Ho, Jain, and Abbeel (2020), and the multiple denoising score matching of Song and Ermon (2019), went one step further by adding several levels of noise, which facilitates the generation process. The work of Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole (2020) extended that idea to the continuous case, adding noise via a stochastic differential equation.

Forward SDE

An initial unknown probability distribution with density $p_0=p_0(x),$ associated with a random variable $X_0,$ is embedded into the distribution of the solution of an SDE of the form

\[ \mathrm{d}X_t = f(t)X_t\;\mathrm{d}t + g(t)\;\mathrm{d}W_t,\]

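with initial condition $X_0.$ The solution can be obtained with the help of the integrating factor $e^{-\int_0^t f(s)\;\mathrm{d}s}$ associated with the deterministic part of the equation. Since this factor is deterministic and differentiable in $t,$ the Itô product rule produces no quadratic covariation term. In this case,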

\[ \begin{aligned} \mathrm{d}\left(X_te^{-\int_0^t f(s)\;\mathrm{d}s}\right) & = \mathrm{d}X_t e^{-\int_0^t f(s)\;\mathrm{d}s} - X_t f(t) e^{-\int_0^t f(s)\;\mathrm{d}s} \;\mathrm{d}t \\ & = \left(f(t)X_t\;\mathrm{d}t + g(t)\;\mathrm{d}W_t\right)e^{-\int_0^t f(s)\;\mathrm{d}s} - X_t f(t) e^{-\int_0^t f(s)\;\mathrm{d}s} \;\mathrm{d}t \\ & = g(t)e^{-\int_0^t f(s)\;\mathrm{d}s}\;\mathrm{d}W_t. \end{aligned}\]

Integrating yields

\[ X_te^{-\int_0^t f(s)\;\mathrm{d}s} - X_0 = \int_0^t g(s)e^{-\int_0^s f(\tau)\;\mathrm{d}\tau}\;\mathrm{d}W_s.\]

Multiplying both sides by $e^{\int_0^t f(s)\;\mathrm{d}s}$ yields the solution

\[ X_t = X_0 e^{\int_0^t f(s)\;\mathrm{d}s} + \int_0^t e^{\int_s^t f(\tau)\;\mathrm{d}\tau}g(s)\;\mathrm{d}W_s.\]

The mean value evolves according to

\[ \mathbb{E}[X_t] = \mathbb{E}[X_0] e^{\int_0^t f(s)\;\mathrm{d}s}.\]

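Using the Itô isometry, and assuming that $X_0$ is independent of the Wiener process, so that the cross term has zero mean, the second moment evolves according to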

\[ \mathbb{E}[X_t^2] = \mathbb{E}[X_0^2]e^{2\int_0^t f(s)\;\mathrm{d}s} + \int_0^t e^{2\int_s^t f(\tau)\;\mathrm{d}\tau}g(s)^2\;\mathrm{d}s.\]

Hence, the variance is given by

\[ \operatorname{Var}(X_t) = \operatorname{Var}(X_0)e^{2\int_0^t f(s)\;\mathrm{d}s} + \int_0^t e^{2\int_s^t f(\tau)\;\mathrm{d}\tau}g(s)^2\;\mathrm{d}s.\]
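The mean and variance formulas above can be checked by simulation. The following Julia sketch assumes constant coefficients $f(t) = a$ and $g(t) = b$ with hypothetical values $a = -0.5,$ $b = 1,$ for which $\int_0^t e^{2a(t-s)}b^2\;\mathrm{d}s = b^2(e^{2at} - 1)/(2a),$ and a hypothetical initial distribution $X_0 \sim \mathcal{N}(1, 0.5^2).$ It integrates many paths with the Euler–Maruyama scheme and compares the sample mean and variance at time $T$ with the closed-form expressions; the function name simulate_linear_sde is illustrative.

```julia
using Random, Statistics

# Euler–Maruyama integration of dX = a X dt + b dW (constant a, b), all paths at once
function simulate_linear_sde(a, b, x0::Vector{Float64}, T, nsteps, rng)
    dt = T / nsteps
    x = copy(x0)
    for _ in 1:nsteps
        x .+= a .* x .* dt .+ b .* sqrt(dt) .* randn(rng, length(x))
    end
    return x
end

rng_em = Xoshiro(42)
a, b, T = -0.5, 1.0, 1.0                    # hypothetical coefficients and final time
x0 = 1.0 .+ 0.5 .* randn(rng_em, 50_000)    # hypothetical X₀ ~ N(1, 0.5²)
xT = simulate_linear_sde(a, b, x0, T, 500, rng_em)

# closed-form mean and variance, evaluated with the sample moments of X₀
mean_exact = mean(x0) * exp(a * T)
var_exact = var(x0) * exp(2a * T) + b^2 * (exp(2a * T) - 1) / (2a)

println("mean: ", mean(xT), " vs ", mean_exact)
println("variance: ", var(xT), " vs ", var_exact)
```

Agreement is expected up to the Monte Carlo error, of order $10^{-3}$ here, and the $O(\Delta t)$ bias of the scheme.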

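Moreover, conditioned on $X_0 = x_0,$ the solution $X_t$ is Gaussian, since it is the sum of the deterministic term $x_0 e^{\int_0^t f(s)\;\mathrm{d}s}$ and a Wiener integral of a deterministic function. Thus, the probability density function $p(t, x)$ can be obtained by conditioning on each initial point, with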

\[ p(t, x) = \int_{\mathbb{R}} p(t, x | 0, x_0) p_0(x_0)\;\mathrm{d}x_0,\]

and

\[ p(t, x | 0, x_0) = \mathcal{N}(x; \mu(t)x_0, \zeta(t)^2),\]

where

\[ \mu(t) = e^{\int_0^t f(s)\;\mathrm{d}s}\]

and

\[ \zeta(t)^2 = \int_0^t e^{2\int_s^t f(\tau)\;\mathrm{d}\tau}g(s)^2\;\mathrm{d}s.\]
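When $f$ and $g$ have no convenient antiderivatives, $\mu(t)$ and $\zeta(t)^2$ can be computed by quadrature. The sketch below, with hypothetical helper names cumtrapz and kernel_params, uses the cumulative trapezoidal rule and, as a check, compares the result with the closed form of the variance-preserving SDE, assuming the linear schedule $\beta(t) = \beta_{\min} + t(\beta_{\max} - \beta_{\min})$ with $\beta_{\min} = 0.1,$ $\beta_{\max} = 20,$ the values of the numerical example, for which $\int_0^t \beta(s)\;\mathrm{d}s = \beta_{\min}t + (\beta_{\max} - \beta_{\min})t^2/2.$

```julia
# Cumulative trapezoidal rule for samples y on a uniform grid with spacing h
function cumtrapz(y::AbstractVector, h)
    c = zeros(length(y))
    for k in 2:length(y)
        c[k] = c[k-1] + h * (y[k-1] + y[k]) / 2
    end
    return c
end

# μ(t) and ζ(t)² of the kernel p(t, x | 0, x₀) = N(x; μ(t) x₀, ζ(t)²), by quadrature
function kernel_params(f, g, t; n = 2_001)
    s = range(0.0, t, length = n)
    h = step(s)
    F = cumtrapz(f.(s), h)                                # F(s) = ∫₀ˢ f(τ) dτ
    integrand = exp.(2 .* (F[end] .- F)) .* g.(s) .^ 2    # e^{2∫ₛᵗ f(τ) dτ} g(s)²
    return exp(F[end]), cumtrapz(integrand, h)[end]
end

# check against the VP SDE with a linear schedule (values of the numerical example)
bmin, bmax = 0.1, 20.0
beta_lin(t) = bmin + t * (bmax - bmin)
t_check = 0.5
mu_num, zeta2_num = kernel_params(t -> -beta_lin(t) / 2, t -> sqrt(beta_lin(t)), t_check)

B = bmin * t_check + (bmax - bmin) * t_check^2 / 2        # ∫₀ᵗ β(s) ds
mu_exact, zeta2_exact = exp(-B / 2), 1 - exp(-B)
```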

The probability density function $p(t, x)$ can also be obtained with the help of the Fokker-Planck equation

\[ \frac{\partial p}{\partial t} + \nabla_x \cdot \left(f(t)\, x\, p(t, x)\right) = \frac{1}{2}\Delta_x \left( g(t)^2 p(t, x) \right),\]

whose fundamental solutions are precisely $p(t, x | 0, x_0) = \mathcal{N}(x; \mu(t)x_0, \zeta(t)^2).$
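As a quick consistency check, the following sketch evaluates the Fokker–Planck residual, with drift $f(t)x,$ by central finite differences for the Gaussian kernel in a hypothetical case with constant coefficients $f(t) = -1/2$ and $g(t) = 1,$ for which $\mu(t) = e^{-t/2}$ and $\zeta(t)^2 = 1 - e^{-t}.$ The residual should vanish up to discretization error; the helper names are illustrative.

```julia
# Hypothetical constant coefficients: f(t) = -1/2, g(t) = 1
f_fp(t) = -0.5
g_fp(t) = 1.0

# Gaussian kernel N(x; μ(t) x₀, ζ(t)²) with μ(t) = e^{-t/2}, ζ(t)² = 1 - e^{-t}
function kernel_fp(t, x, x0)
    μ, ζ2 = exp(-t / 2), 1 - exp(-t)
    return exp(-(x - μ * x0)^2 / (2 * ζ2)) / sqrt(2π * ζ2)
end

# ∂p/∂t + ∂ₓ(f(t) x p) - ½ g(t)² ∂ₓₓp, by central finite differences
function fp_residual(t, x, x0; ht = 1e-5, hx = 1e-3)
    p(τ, y) = kernel_fp(τ, y, x0)
    dp_dt = (p(t + ht, x) - p(t - ht, x)) / (2 * ht)
    flux(y) = f_fp(t) * y * p(t, y)
    dflux_dx = (flux(x + hx) - flux(x - hx)) / (2 * hx)
    d2p_dx2 = (p(t, x + hx) - 2 * p(t, x) + p(t, x - hx)) / hx^2
    return dp_dt + dflux_dx - g_fp(t)^2 / 2 * d2p_dx2
end

residual = fp_residual(0.7, 0.3, 1.0)
println("Fokker–Planck residual: ", residual)
```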

Examples

Variance-exploding SDE

In the variance-exploding case (VE SDE), obtained in Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole (2020) as the continuous limit of multiple denoising score matching, we have

\[ f(t) = 0, \quad g(t) = \sqrt{\frac{\mathrm{d}(\sigma(t)^2)}{\mathrm{d}t}},\]

so that

\[ \mu(t) = 1\]

and

\[ \zeta(t)^2 = \int_0^t \frac{\mathrm{d}(\sigma(s)^2)}{\mathrm{d}s}\;\mathrm{d}s = \sigma(t)^2 - \sigma(0)^2.\]

Thus,

\[ p(t, x | 0, x_0) = \mathcal{N}\left( x; x_0, \sigma(t)^2 - \sigma(0)^2\right).\]
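For the geometric schedule $\sigma(t) = \sigma_{\min}(\sigma_{\max}/\sigma_{\min})^t$ assumed by g_ve in the numerical example, $\frac{\mathrm{d}}{\mathrm{d}t}\sigma(t)^2 = 2\log(\sigma_{\max}/\sigma_{\min})\,\sigma(t)^2,$ so $g(t) = \sigma(t)\sqrt{2\log(\sigma_{\max}/\sigma_{\min})}.$ The sketch below checks numerically that $\int_0^t g(s)^2\;\mathrm{d}s = \sigma(t)^2 - \sigma(0)^2$; the function names are hypothetical.

```julia
# Geometric schedule σ(t) = σmin (σmax/σmin)^t, with σmin = 0.01, σmax = 1 (values of the numerical example)
smin, smax = 0.01, 1.0
sigma_sched(t) = smin * (smax / smin)^t
# g(t) = √(dσ(t)²/dt) = σ(t) √(2 log(σmax/σmin))
g_sched(t) = sigma_sched(t) * sqrt(2 * log(smax / smin))

# ζ(t)² = ∫₀ᵗ g(s)² ds by the trapezoidal rule
function zeta2_quad(t; n = 10_001)
    s = range(0.0, t, length = n)
    y = g_sched.(s) .^ 2
    return step(s) * (sum(y) - (y[1] + y[end]) / 2)
end

t_ve = 0.8
zeta2_num_ve = zeta2_quad(t_ve)
zeta2_exact_ve = sigma_sched(t_ve)^2 - sigma_sched(0.0)^2
```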

Variance-preserving SDE

In the variance-preserving case (VP SDE), obtained in Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole (2020) as the continuous limit of denoising diffusion probabilistic models, we have

\[ f(t) = -\frac{1}{2}\beta(t), \quad g(t) = \sqrt{\beta(t)},\]

so that

\[ \mu(t) = e^{-\frac{1}{2}\int_0^t \beta(s)\;\mathrm{d}s}\]

and

\[ \zeta(t)^2 = \int_0^t e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s = \left. e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau} \right|_{s=0}^{s=t} = 1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau}.\]

Thus,

\[ p(t, x | 0, x_0) = \mathcal{N}\left( x; x_0 e^{-\frac{1}{2}\int_0^t \beta(s)\;\mathrm{d}s}, 1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau}\right).\]
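The identity $\mu(t)^2 + \zeta(t)^2 = 1$ explains the name: if $\operatorname{Var}(X_0) = 1,$ then $\operatorname{Var}(X_t) = \mu(t)^2 + \zeta(t)^2 = 1$ for all $t.$ The sketch below checks this for the linear schedule $\beta(t) = \beta_{\min} + t(\beta_{\max} - \beta_{\min})$ with the values of the numerical example; the helper names are hypothetical.

```julia
# Linear schedule β(t) = βmin + t (βmax - βmin), values of the numerical example
bmin, bmax = 0.1, 20.0
int_beta(t) = bmin * t + (bmax - bmin) * t^2 / 2    # ∫₀ᵗ β(s) ds

vp_mean_coef(t) = exp(-int_beta(t) / 2)             # μ(t)
vp_var(t) = 1 - exp(-int_beta(t))                   # ζ(t)²

# Var(Xₜ) = μ(t)² Var(X₀) + ζ(t)², here with Var(X₀) = 1
ts = 0.0:0.1:1.0
total_var = [vp_mean_coef(t)^2 * 1.0 + vp_var(t) for t in ts]
```

In general, $\operatorname{Var}(X_t) = \mu(t)^2\operatorname{Var}(X_0) + 1 - \mu(t)^2,$ which approaches $1$ as $\int_0^t \beta(s)\;\mathrm{d}s$ grows, whatever $\operatorname{Var}(X_0).$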

Sub-variance-preserving SDE

In the sub-variance-preserving case (sub-VP SDE), proposed in Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole (2020) as an alternative to the previous ones, we have

\[ f(t) = -\frac{1}{2}\beta(t), \quad g(t) = \sqrt{\beta(t)(1 - e^{-2\int_0^t \beta(s)\;\mathrm{d}s})},\]

so that

\[ \mu(t) = e^{-\frac{1}{2}\int_0^t \beta(s)\;\mathrm{d}s}\]

and

\[ \begin{aligned} \zeta(t)^2 & = \int_0^t e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau}\beta(s)(1 - e^{-2\int_0^s \beta(\tau)\;\mathrm{d}\tau})\;\mathrm{d}s \\ & = \int_0^t e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s - \int_0^t e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau}e^{-2\int_0^s \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s \\ & = \int_0^t e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s - \int_0^t e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau}e^{-\int_0^s \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s \\ & = 1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} \int_0^t e^{-\int_0^s \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s \\ & = 1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} + e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} \left.e^{-\int_0^s \beta(\tau)\;\mathrm{d}\tau}\right|_{s=0}^{s=t} \\ & = 1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} + e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} \left(e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} - 1\right) \\ & = 1 - 2e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} + e^{-2\int_0^t \beta(\tau)\;\mathrm{d}\tau} \\ & = \left(1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau}\right)^2. \end{aligned}\]

Thus,

\[ p(t, x | 0, x_0) = \mathcal{N}\left( x; x_0 e^{-\frac{1}{2}\int_0^t \beta(s)\;\mathrm{d}s}, \left(1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau}\right)^2\right).\]
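Since $\zeta_{\textrm{sub-VP}}(t)^2 = \left(1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau}\right)\zeta_{\textrm{VP}}(t)^2$ and $0 \le 1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} < 1,$ the sub-VP kernel variance always stays below the VP one, which is what the prefix "sub" refers to. The sketch compares both for the linear schedule with the values of the numerical example; the helper names are hypothetical.

```julia
# Kernel variances of the VP and sub-VP SDEs for the same linear schedule
bmin, bmax = 0.1, 20.0
B_lin(t) = bmin * t + (bmax - bmin) * t^2 / 2       # ∫₀ᵗ β(s) ds

vp_variance(t) = 1 - exp(-B_lin(t))
subvp_variance(t) = (1 - exp(-B_lin(t)))^2

# ratio sub-VP / VP = 1 - e^{-∫₀ᵗ β}, in (0, 1) for t > 0 (t = 0 is skipped: both variances vanish)
ts_sub = 0.05:0.05:1.0
ratios = [subvp_variance(t) / vp_variance(t) for t in ts_sub]
```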

Loss function

The loss function for training is a continuous version of the loss for the multiple denoising score-matching. In that case, we had

\[ J_{\textrm{SMLD}}(\boldsymbol{\theta}) = \frac{1}{2L}\sum_{i=1}^L \lambda(\sigma_i) \mathbb{E}_{p(\mathbf{x})p_{\sigma_i}(\tilde{\mathbf{x}}|\mathbf{x})}\left[ \left\| s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}, \sigma_i) - \frac{\mathbf{x} - \tilde{\mathbf{x}}}{\sigma_i^2} \right\|^2 \right],\]

where $\lambda = \lambda(\sigma_i)$ is a weighting factor. When many noise levels are considered, one takes a stochastic approach and approximates the loss $J_{\textrm{SMLD}}(\boldsymbol{\theta})$ by

\[ J_{\textrm{SMLD}}^*(\boldsymbol{\theta}) = \frac{1}{2}\lambda(\sigma_i) \mathbb{E}_{p(\mathbf{x})p_{\sigma_i}(\tilde{\mathbf{x}}|\mathbf{x})}\left[ \left\| s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}, \sigma_i) - \frac{\mathbf{x} - \tilde{\mathbf{x}}}{\sigma_i^2} \right\|^2 \right],\]

with

\[ i \sim \operatorname{Uniform}[\{1, 2, \ldots, L\}].\]

The continuous version becomes

\[ J_{\textrm{SDE}}^*(\boldsymbol{\theta}) = \frac{1}{2}\lambda(t) \mathbb{E}_{p_0(\mathbf{x}_0)p(t, \tilde{\mathbf{x}}|0, \mathbf{x}_0)}\left[ \left\| s_{\boldsymbol{\theta}}(t, \tilde{\mathbf{x}}) - \boldsymbol{\nabla}_{\tilde{\mathbf{x}}} \log p(t, \tilde{\mathbf{x}}|0, \mathbf{x}_0) \right\|^2 \right],\]

with

\[ t \sim \operatorname{Uniform}[0, T].\]

In practice, the empirical distribution is used for $p_0(\mathbf{x}_0),$ and a stochastic approach is taken by sampling, for each data point $\mathbf{x}_n,$ a single time $t_n \sim \operatorname{Uniform}[0, T]$ and a single $\tilde{\mathbf{x}}_n \sim p(t_n, \tilde{\mathbf{x}}|0, \mathbf{x}_n).$ Thus, the loss takes the form

\[ {\tilde J}_{\textrm{SDE}}^*(\boldsymbol{\theta}) = \frac{1}{2N}\sum_{n=1}^N \lambda(t_n) \left[ \left\| s_{\boldsymbol{\theta}}(t_n, \tilde{\mathbf{x}}_n) - \boldsymbol{\nabla}_{\tilde{\mathbf{x}}} \log p(t_n, \tilde{\mathbf{x}}_n|0, \mathbf{x}_n) \right\|^2 \right],\]

with

\[ \mathbf{x}_n \sim p_0, \quad t_n \sim \operatorname{Uniform}[0, T], \quad \tilde{\mathbf{x}}_n \sim p(t_n, \cdot | 0, \mathbf{x}_n).\]

The explicit form of the distribution $p(t_n, \cdot | 0, \mathbf{x}_n)$ and of its score $\boldsymbol{\nabla}_{\tilde{\mathbf{x}}} \log p(t_n, \tilde{\mathbf{x}}_n|0, \mathbf{x}_n)$ depend on the choice of the SDE.
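For the Gaussian kernels above, the score is explicit: $\boldsymbol{\nabla}_{\tilde{\mathbf{x}}}\log p(t, \tilde{\mathbf{x}}|0, \mathbf{x}_0) = -(\tilde{\mathbf{x}} - \mu(t)\mathbf{x}_0)/\zeta(t)^2.$ The following sketch estimates ${\tilde J}_{\textrm{SDE}}^*$ for the VP SDE with the linear schedule of the numerical example, assuming the weighting $\lambda(t) = \zeta(t)^2$ and sampling $t_n$ from $[t_{\min}, T]$ with a small $t_{\min} = 10^{-5}$ to avoid the singular kernel at $t = 0.$ As hypothetical data it uses $p_0 = \mathcal{N}(0, 1),$ whose VP marginals remain $\mathcal{N}(0, 1),$ so the exact score $-x$ can be compared with a trivial zero score. The function names are illustrative.

```julia
using Random, Statistics

# VP SDE with a linear schedule (values of the numerical example)
bmin, bmax = 0.1, 20.0
intbeta_vp(t) = bmin * t + (bmax - bmin) * t^2 / 2   # ∫₀ᵗ β(s) ds
mu_vp(t) = exp(-intbeta_vp(t) / 2)                   # μ(t)
zeta_vp(t) = sqrt(1 - exp(-intbeta_vp(t)))           # ζ(t)

# Monte Carlo estimate of the loss with λ(t) = ζ(t)², one (tₙ, x̃ₙ) per data point xₙ
function sde_loss(score, xs, rng; T = 1.0, tmin = 1e-5)
    total = 0.0
    for x0 in xs
        t = tmin + (T - tmin) * rand(rng)   # tₙ ~ Uniform[tmin, T]
        μ, ζ = mu_vp(t), zeta_vp(t)
        xt = μ * x0 + ζ * randn(rng)        # x̃ₙ ~ p(tₙ, · | 0, xₙ)
        target = -(xt - μ * x0) / ζ^2       # ∇ log p(tₙ, x̃ₙ | 0, xₙ)
        total += ζ^2 * (score(t, xt) - target)^2
    end
    return total / (2 * length(xs))
end

rng_loss = Xoshiro(0)
xs = randn(rng_loss, 50_000)                          # hypothetical data, p₀ = N(0, 1)
loss_exact = sde_loss((t, x) -> -x, xs, rng_loss)     # exact marginal score of N(0, 1)
loss_zero = sde_loss((t, x) -> 0.0, xs, rng_loss)     # trivial zero score
```

Writing $\tilde{x}_n = \mu x_n + \zeta z_n,$ each term equals $(\zeta\, s + z_n)^2,$ so the zero score gives a loss close to $1/2,$ while the exact score gives a smaller, but nonzero, value, since the regression target is the conditional score.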

Numerical example

We illustrate, numerically, the use of the score-based SDE method to model a synthetic univariate Gaussian mixture distribution.

Julia language setup

We use the Julia programming language for the numerical simulations, with suitable packages.

Packages

using StatsPlots
using Random
using Distributions
using Lux # artificial neural networks explicitly parametrized
using Optimisers
using Zygote # automatic differentiation
using Markdown

Reproducibility

We set the random seed for reproducibility purposes.

rng = Xoshiro(12345)

Data

We build the usual target model and draw samples from it.

Visualizing the sample data drawn from the distribution and the PDF.


Visualizing the score function.

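The score of a Gaussian mixture has a closed form, $\nabla \log p(x) = \sum_k w_k \mathcal{N}(x; m_k, s_k^2)\,\left(-(x - m_k)/s_k^2\right) \big/ \sum_k w_k \mathcal{N}(x; m_k, s_k^2).$ The sketch below assumes a hypothetical two-component mixture, whose weights, means and standard deviations are not necessarily those of the target model used here, and checks the formula against a finite-difference derivative of $\log p.$

```julia
# Hypothetical two-component Gaussian mixture: weights, means, standard deviations
wts = [0.5, 0.5]
mus = [-1.0, 1.0]
sds = [0.3, 0.3]

normal_pdf(x, μ, σ) = exp(-(x - μ)^2 / (2 * σ^2)) / (σ * sqrt(2π))
mixture_pdf(x) = sum(wts[k] * normal_pdf(x, mus[k], sds[k]) for k in eachindex(wts))

# closed-form score: posterior-weighted average of the component scores -(x - mₖ)/sₖ²
function mixture_score(x)
    num = sum(wts[k] * normal_pdf(x, mus[k], sds[k]) * (-(x - mus[k]) / sds[k]^2) for k in eachindex(wts))
    return num / mixture_pdf(x)
end

# central finite difference of log p, for comparison
fd_score(x; h = 1e-5) = (log(mixture_pdf(x + h)) - log(mixture_pdf(x - h))) / (2 * h)

for x in (-2.0, -0.5, 0.0, 0.7, 1.5)
    println(x, ": ", mixture_score(x), " vs ", fd_score(x))
end
```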

Parameters

Here we set some parameters for the model and prepare any necessary data.

trange = 0.0:0.01:1.0
0.0:0.01:1.0

Variance exploding

sigma_min = 0.01
sigma_max = 1.0

f_ve(t) = 0.0
g_ve(t; σₘᵢₙ = sigma_min, σₘₐₓ = sigma_max) = σₘᵢₙ * ( σₘₐₓ / σₘᵢₙ)^t * √(2 * log(σₘₐₓ/σₘᵢₙ))

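# Normal takes the standard deviation σ(t) = σₘᵢₙ (σₘₐₓ/σₘᵢₙ)^t; the kernel variance σ(t)² - σ(0)² is approximated by σ(t)²
prob_kernel_ve(t, x0; σₘᵢₙ = sigma_min, σₘₐₓ = sigma_max) = Normal( x0, σₘᵢₙ * (σₘₐₓ/σₘᵢₙ)^t )
p_kernel_ve(t, x, x0) = pdf(prob_kernel_ve(t, x0), x)
score_kernel_ve(t, x, x0) = gradlogpdf(prob_kernel_ve(t, x0), x)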
score_kernel_ve (generic function with 1 method)
surface(trange, xrange, (t, x) -> log(sum(x0 -> pdf(prob_kernel_ve(t, x0), x) * pdf(target_prob, x0), xrange)))
heatmap(trange, xrange, (t, x) -> log(sum(x0 -> pdf(prob_kernel_ve(t, x0), x) * pdf(target_prob, x0), xrange)))

Variance preserving

beta_min = 0.1
beta_max = 20.0

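# β(t) = βₘᵢₙ + t (βₘₐₓ - βₘᵢₙ); the VP SDE has f(t) = -β(t)/2 and g(t) = √β(t)
f_vp(t; βₘᵢₙ=beta_min, βₘₐₓ=beta_max) = - ( βₘᵢₙ + t * ( βₘₐₓ - βₘᵢₙ ) ) / 2
g_vp(t; βₘᵢₙ=beta_min, βₘₐₓ=beta_max) = √( βₘᵢₙ + t * ( βₘₐₓ - βₘᵢₙ ) )

# ∫₀ᵗ β(s) ds = t βₘᵢₙ + t² (βₘₐₓ - βₘᵢₙ) / 2; Normal takes the standard deviation √(1 - e^{-∫₀ᵗ β})
prob_kernel_vp(t, x0; βₘᵢₙ=beta_min, βₘₐₓ=beta_max) = Normal( x0 * exp( - t^2 * ( βₘₐₓ - βₘᵢₙ ) / 4 - t * βₘᵢₙ / 2 ), √(1 - exp( - t^2 * ( βₘₐₓ - βₘᵢₙ ) / 2 - t * βₘᵢₙ )) )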
prob_kernel_vp (generic function with 1 method)
surface(trange, xrange, (t, x) -> log(sum(x0 -> pdf(prob_kernel_vp(t, x0), x) * pdf(target_prob, x0), xrange)))
heatmap(trange, xrange, (t, x) -> log(sum(x0 -> pdf(prob_kernel_vp(t, x0), x) * pdf(target_prob, x0), xrange)))
L = 6
sigma_1 = 2.0
sigma_L = 0.5
theta = ( sigma_L / sigma_1 )^(1/(L-1))
sigmas = [sigma_1 * theta ^ (i-1) for i in 1:L]
6-element Vector{Float64}:
 2.0
 1.515716566510398
 1.1486983549970349
 0.870550563296124
 0.659753955386447
 0.49999999999999983
data = (sample_points, sigmas)
([0.8606155918844086 -1.0314315213443432 … -0.9717983805592734 0.8976929261501944], [2.0, 1.515716566510398, 1.1486983549970349, 0.870550563296124, 0.659753955386447, 0.49999999999999983])

The neural network model

The neural network we consider is a simple feed-forward network with a single hidden layer, built as a chain of two dense layers. It takes as input a noisy point together with its noise level and outputs the predicted score. This is implemented with the LuxDL/Lux.jl package.

We will see that we do not need a big neural network in this simple example, so we keep it as small as possible while still working.

model = Chain(Dense(2 => 64, relu), Dense(64 => 1))
Chain(
    layer_1 = Dense(2 => 64, relu),     # 192 parameters
    layer_2 = Dense(64 => 1),           # 65 parameters
)         # Total: 257 parameters,
          #        plus 0 states.

The LuxDL/Lux.jl package uses explicit parameters, which are initialized (or obtained) with the Lux.setup function, giving us the parameters and the state of the model.

ps, st = Lux.setup(rng, model) # initialize and get the parameters and states of the model
((layer_1 = (weight = Float32[0.96560603 -1.3963945; -0.864941 1.1274849; … ; 0.7158331 1.1567085; 2.3663533 1.2393725], bias = Float32[0.18311752, 0.4076413, 0.6408686, -0.013799805, 0.65769726, 0.27097112, -0.11596697, 0.48491788, -0.42007604, 0.118874006  …  0.1418748, -0.5943321, -0.60492563, 0.5363705, 0.31154066, -0.022509282, 0.3484151, 0.14444627, 0.36740452, -0.6803152]), layer_2 = (weight = Float32[0.09065902 0.13053486 … 0.042090528 0.21522477], bias = Float32[0.004801482])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))

Loss function

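function loss_function_sde(model, ps, st, data)
    sample_points, sigmas = data

    # one standard normal draw per data point, scaled by each σᵢ: (L) .* (1 × N) broadcasts to L × N
    noisy_sample_points = sample_points .+ sigmas .* randn(rng, size(sample_points))
    # denoising score-matching targets (x - x̃) / σ² of the Gaussian kernels
    scores = ( sample_points .- noisy_sample_points ) ./ sigmas .^ 2

    # flatten column-major into 1 × (L N) rows, pairing each noisy point with its own σ
    flattened_noisy_sample_points = reshape(noisy_sample_points, 1, :)
    flattened_sigmas = repeat(sigmas', 1, length(sample_points))
    model_input = [flattened_noisy_sample_points; flattened_sigmas]

    y_score_pred, st = Lux.apply(model, model_input, ps, st)

    flattened_scores = reshape(scores, 1, :)

    # weighted loss with λ(σ) = σ², averaged over all points and noise levels
    loss = mean(abs2, flattened_sigmas .* (y_score_pred .- flattened_scores)) / 2

    return loss, st, ()
end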
loss_function_sde (generic function with 1 method)
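The flattening in loss_function_sde relies on Julia's column-major order: reshape(noisy_sample_points, 1, :) runs through the noise levels fastest, which matches the order of repeat(sigmas', 1, length(sample_points)). The following toy check, with hypothetical sizes and values, illustrates that each flattened noisy point ends up paired with its own noise level.

```julia
# Hypothetical toy sizes: 3 noise levels and a 1×4 row of data points
sigmas_demo = [2.0, 1.0, 0.5]
points_demo = [10.0 20.0 30.0 40.0]

# broadcasting gives 3×4 matrices whose entry (i, n) refers to noise level i and point n
grid_sigma = sigmas_demo .+ 0 .* points_demo
grid_point = points_demo .+ 0 .* sigmas_demo

# column-major flattening: the noise level varies fastest ...
flat_point = reshape(grid_point, 1, :)
# ... which is the order produced by repeating the row sigmas' once per data point
flat_sigma = repeat(sigmas_demo', 1, length(points_demo))

pairs_demo = [flat_point; flat_sigma]   # 2×12 model input: (point, σ) in each column
```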

Optimization setup

Optimization method

We use the Adam optimiser.

opt = Adam(0.01)

tstate_org = Lux.Training.TrainState(model, ps, st, opt)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 64, relu), layer_2 = Dense(64 => 1)), nothing)
    # of parameters: 257
    # of states: 0
    optimizer: Optimisers.Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8)
    step: 0

Automatic differentiation in the optimization

As mentioned, we set up differentiation in LuxDL/Lux.jl with the FluxML/Zygote.jl library.

vjp_rule = Lux.Training.AutoZygote()
ADTypes.AutoZygote()

Processor

We use the CPU instead of the GPU.

dev_cpu = cpu_device()
## dev_gpu = gpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)

Check differentiation

Check that Zygote, via Lux, differentiates the loss function used for training.

Lux.Training.compute_gradients(vjp_rule, loss_function_sde, data, tstate_org)
((layer_1 = (weight = Float32[0.34248495 0.17144734; 0.3385421 0.43916073; … ; 0.23387475 0.2079672; 1.198205 1.0327464], bias = Float32[0.111951314, 0.255828, 0.05792513, -0.1414185, 0.011816772, -0.1453327, -0.040131677, 0.0, 0.17985131, 0.16945557  …  -0.20052071, -0.061180312, -0.339524, 0.113274135, -0.4008663, -0.14912026, 0.08600748, -0.020066934, 0.12535073, 0.6215555]), layer_2 = (weight = Float32[1.2331666 2.348907 … 10.78692 17.15639], bias = Float32[2.9029262])), 6.685081362917852, (), Lux.Training.TrainState{Nothing, Nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}}(nothing, nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 64, relu), layer_2 = Dense(64 => 1)), nothing), (layer_1 = (weight = 
Float32[0.96560603 -1.3963945; -0.864941 1.1274849; … ; 0.7158331 1.1567085; 2.3663533 1.2393725], bias = Float32[0.18311752, 0.4076413, 0.6408686, -0.013799805, 0.65769726, 0.27097112, -0.11596697, 0.48491788, -0.42007604, 0.118874006  …  0.1418748, -0.5943321, -0.60492563, 0.5363705, 0.31154066, -0.022509282, 0.3484151, 0.14444627, 0.36740452, -0.6803152]), layer_2 = (weight = Float32[0.09065902 0.13053486 … 0.042090528 0.21522477], bias = Float32[0.004801482])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), Optimisers.Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8), (layer_1 = (weight = Leaf(Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0], Float32[0.0], (0.9, 0.999))))), 0))

Training loop

Here is the typical main training loop suggested in the LuxDL/Lux.jl tutorials, slightly modified to save the history of losses per iteration and, optionally, intermediate training states.

function train(tstate, vjp, data, loss_function, epochs, numshowepochs=20, numsavestates=0)
    losses = zeros(epochs)
    tstates = [(0, tstate)]
    for epoch in 1:epochs
        grads, loss, stats, tstate = Lux.Training.compute_gradients(vjp,
            loss_function, data, tstate)
        if ( epochs ≥ numshowepochs > 0 ) && rem(epoch, div(epochs, numshowepochs)) == 0
            println("Epoch: $(epoch) || Loss: $(loss)")
        end
        if ( epochs ≥ numsavestates > 0 ) && rem(epoch, div(epochs, numsavestates)) == 0
            push!(tstates, (epoch, tstate))
        end
        losses[epoch] = loss
        tstate = Lux.Training.apply_gradients(tstate, grads)
    end
    return tstate, losses, tstates
end
train (generic function with 3 methods)
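In train, an epoch is reported (or its state saved) when it is a multiple of div(epochs, n), with n = numshowepochs (or numsavestates). The hypothetical helper below reproduces that condition to list the selected epochs.

```julia
# Reproduces the condition used in train: an epoch is selected when
# epochs ≥ n > 0 and epoch is a multiple of div(epochs, n)
selected_epochs(epochs, n) =
    [epoch for epoch in 1:epochs if (epochs ≥ n > 0) && rem(epoch, div(epochs, n)) == 0]

shown = selected_epochs(1000, 20)     # numshowepochs = 20
saved = selected_epochs(1000, 125)    # numsavestates = 125
```

With epochs = 1000, this gives every 50th epoch for n = 20, matching the printed log, and every 8th epoch for n = 125. When n does not divide epochs, more than n epochs may be selected: for n = 300, div(1000, 300) = 3, so 333 epochs are selected.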

Training

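Now we train the model by minimizing loss_function_sde, which, as implemented above, is a Monte Carlo estimate of $J_{\textrm{SMLD}}(\boldsymbol{\theta})$ with weighting $\lambda(\sigma_i) = \sigma_i^2,$ evaluated with one noisy sample per data point at each of the noise levels in sigmas.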

@time tstate, losses, tstates = train(tstate_org, vjp_rule, data, loss_function_sde, 1000, 20, 125)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [Matrix{Float32}] x B [Matrix{Float64}]). Falling back to generic implementation. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/1B1qw/src/impl/matmul.jl:190
Epoch: 50 || Loss: 0.2825422092260028
Epoch: 100 || Loss: 0.2664865870837632
Epoch: 150 || Loss: 0.27509917457510513
Epoch: 200 || Loss: 0.25550228691048
Epoch: 250 || Loss: 0.24452857389854796
Epoch: 300 || Loss: 0.23809200120701077
Epoch: 350 || Loss: 0.23584896317371198
Epoch: 400 || Loss: 0.2419833316277218
Epoch: 450 || Loss: 0.22898116619073225
Epoch: 500 || Loss: 0.23564025395575575
Epoch: 550 || Loss: 0.20653899841845824
Epoch: 600 || Loss: 0.20108098836715735
Epoch: 650 || Loss: 0.2252943786223162
Epoch: 700 || Loss: 0.21997904580626185
Epoch: 750 || Loss: 0.2125949631680312
Epoch: 800 || Loss: 0.20827904310322956
Epoch: 850 || Loss: 0.19328358064767995
Epoch: 900 || Loss: 0.20871206386816585
Epoch: 950 || Loss: 0.2064316094779668
Epoch: 1000 || Loss: 0.21487302803497235
  4.177228 seconds (511.96 k allocations: 10.010 GiB, 10.82% gc time, 0.85% compilation time)

Results

Checking out the trained model.


Visualizing the result with the smallest noise.


Recovering the PDF of the distribution from the trained score function.


With the smallest noise.

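Since the score is the derivative of $\log p,$ a PDF can be recovered from score values on a grid by integrating them, exponentiating, and normalizing. The sketch below is one way to do this, not necessarily the one used for the figures; as a check, it feeds in the exact score $-x$ of the standard normal in place of the trained model. The function name pdf_from_score is hypothetical.

```julia
# Recover a normalized PDF on a uniform grid from score values s(x) = d/dx log p(x)
function pdf_from_score(score_values::AbstractVector, xgrid::AbstractRange)
    h = step(xgrid)
    logp = zeros(length(xgrid))
    for k in 2:length(xgrid)
        # cumulative trapezoidal integration of the score gives log p up to a constant
        logp[k] = logp[k-1] + h * (score_values[k-1] + score_values[k]) / 2
    end
    p = exp.(logp .- maximum(logp))                    # shift to avoid overflow
    return p ./ (h * (sum(p) - (p[1] + p[end]) / 2))   # trapezoidal normalization
end

# check with the standard normal, whose score is -x
xgrid = range(-5.0, 5.0, length = 1_001)
p_rec = pdf_from_score(-collect(xgrid), xgrid)
p_true = exp.(-xgrid .^ 2 ./ 2) ./ sqrt(2π)
```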

Just for the fun of it, let us see an animation of the optimization process.


And the animation of the evolution of the PDF.


We also visualize the evolution of the losses.


Sampling with annealed Langevin

Now we sample the modeled distribution with the annealed Langevin method of Song and Ermon (2019), used in multiple denoising score matching.

Here are the trajectories.


The sample histogram obtained at the end of the trajectories.

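Annealed Langevin dynamics runs a number of Langevin steps $x \leftarrow x + \frac{\alpha_i}{2}s(x, \sigma_i) + \sqrt{\alpha_i}\,z,$ with $z \sim \mathcal{N}(0, 1),$ at each noise level, from the largest to the smallest, with step size $\alpha_i = \epsilon\,\sigma_i^2/\sigma_L^2.$ The sketch below replaces the trained model by the analytic score $-x/(1 + \sigma^2)$ of hypothetical data $p_0 = \mathcal{N}(0, 1)$; the noise levels, $\epsilon,$ number of steps, initial distribution and function names are also assumptions.

```julia
using Random, Statistics

# analytic score of the σ-perturbed hypothetical data N(0, 1 + σ²)
noisy_score(x, σ) = -x / (1 + σ^2)

function annealed_langevin(score, sigmas, nsamples, rng; ϵ = 2e-5, nsteps = 200)
    x = 6 .* rand(rng, nsamples) .- 3           # initial points, Uniform[-3, 3]
    for σ in sigmas                             # from the largest to the smallest noise level
        α = ϵ * σ^2 / sigmas[end]^2             # step size scaled with the noise level
        for _ in 1:nsteps
            x .+= (α / 2) .* score.(x, σ) .+ sqrt(α) .* randn(rng, nsamples)
        end
    end
    return x
end

rng_al = Xoshiro(1)
sigmas_langevin = exp.(range(log(1.0), log(0.01), length = 10))   # hypothetical geometric levels
samples = annealed_langevin(noisy_score, sigmas_langevin, 10_000, rng_al)
```

The final samples should be approximately standard normal, up to the bias from the finite step size and number of steps per level.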

References

  1. Aapo Hyvärinen (2005), "Estimation of non-normalized statistical models by score matching", Journal of Machine Learning Research 6, 695-709
  2. Pascal Vincent (2011), "A connection between score matching and denoising autoencoders", Neural Computation 23 (7), 1661-1674, doi:10.1162/NECO_a_00142
  3. J. Sohl-Dickstein, E. A. Weiss, N. Maheswaranathan, S. Ganguli (2015), "Deep unsupervised learning using nonequilibrium thermodynamics", ICML'15: Proceedings of the 32nd International Conference on International Conference on Machine Learning - Volume 37, 2256-2265
  4. Y. Song and S. Ermon (2019), "Generative modeling by estimating gradients of the data distribution", NIPS'19: Proceedings of the 33rd International Conference on Neural Information Processing Systems, no. 1067, 11918-11930
  5. J. Ho, A. Jain, P. Abbeel (2020), "Denoising diffusion probabilistic models", in Advances in Neural Information Processing Systems 33, NeurIPS2020
  6. Y. Song, J. Sohl-Dickstein, D. P. Kingma, A. Kumar, S. Ermon, B. Poole (2020), "Score-based generative modeling through stochastic differential equations", arXiv:2011.13456