Score-based generative modeling through stochastic differential equations

Introduction

Aim

Review the work of Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole (2020) that takes a complex data distribution, adds noise to it via a stochastic differential equation and generates new samples by modeling the reverse process. It is a generalization to the continuous case of the previous discrete processes of denoising diffusion probabilistic models and multiple denoising score matching.

Background

After Aapo Hyvärinen (2005) proposed implicit score matching to model a distribution by fitting its score function, several works followed, including the denoising score matching of Pascal Vincent (2011), which perturbs the data so that the analytic expression of the score function of the perturbation kernel can be used. Then the denoising diffusion probabilistic models of Sohl-Dickstein, Weiss, Maheswaranathan, Ganguli (2015) and Ho, Jain, and Abbeel (2020), and the multiple denoising score matching of Song and Ermon (2019), went one step further by adding several levels of noise, which facilitates the generation process. The work of Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole (2020) extended that idea to the continuous case, adding noise via a stochastic differential equation.

Forward SDE

An initial unknown probability distribution with density $p_0=p_0(x),$ associated with a random variable $X_0,$ is embedded into the distribution of an SDE of the form

\[ \mathrm{d}X_t = f(t)X_t\;\mathrm{d}t + g(t)\;\mathrm{d}W_t,\]

with initial condition $X_0.$ The solution can be obtained with the help of the integrating factor $e^{-\int_0^t f(s)\;\mathrm{d}s}$ associated with the deterministic part of the equation. In this case,

\[ \begin{aligned} \mathrm{d}\left(X_te^{-\int_0^t f(s)\;\mathrm{d}s}\right) & = \mathrm{d}X_t e^{-\int_0^t f(s)\;\mathrm{d}s} - X_t f(t) e^{-\int_0^t f(s)\;\mathrm{d}s} \;\mathrm{d}t \\ & = \left(f(t)X_t\;\mathrm{d}t + g(t)\;\mathrm{d}W_t\right)e^{-\int_0^t f(s)\;\mathrm{d}s} - X_t f(t) e^{-\int_0^t f(s)\;\mathrm{d}s} \;\mathrm{d}t \\ & = g(t)e^{-\int_0^t f(s)\;\mathrm{d}s}\;\mathrm{d}W_t. \end{aligned}\]

Integrating yields

\[ X_te^{-\int_0^t f(s)\;\mathrm{d}s} - X_0 = \int_0^t g(s)e^{-\int_0^s f(\tau)\;\mathrm{d}\tau}\;\mathrm{d}W_s.\]

Moving the exponential term to the right hand side yields the solution

\[ X_t = X_0 e^{\int_0^t f(s)\;\mathrm{d}s} + \int_0^t e^{\int_s^t f(\tau)\;\mathrm{d}\tau}g(s)\;\mathrm{d}W_s.\]

The mean value evolves according to

\[ \mathbb{E}[X_t] = \mathbb{E}[X_0] e^{\int_0^t f(s)\;\mathrm{d}s}.\]

Using the Itô isometry, the second moment evolves with

\[ \mathbb{E}[X_t^2] = \mathbb{E}[X_0^2]e^{2\int_0^t f(s)\;\mathrm{d}s} + \int_0^t e^{2\int_s^t f(\tau)\;\mathrm{d}\tau}g(s)^2\;\mathrm{d}s.\]

Hence, the variance is given by

\[ \operatorname{Var}(X_t) = \operatorname{Var}(X_0)e^{2\int_0^t f(s)\;\mathrm{d}s} + \int_0^t e^{2\int_s^t f(\tau)\;\mathrm{d}\tau}g(s)^2\;\mathrm{d}s.\]

Thus, the probability density function $p(t, x)$ can be obtained by conditioning it at each initial point, with

\[ p(t, x) = \int_{\mathbb{R}} p(t, x | 0, x_0) p_0(x_0)\;\mathrm{d}x_0,\]

and

\[ p(t, x | 0, x_0) = \mathcal{N}(x; \mu(t)x_0, \zeta(t)^2),\]

where

\[ \mu(t) = e^{\int_0^t f(s)\;\mathrm{d}s}\]

and

\[ \zeta(t)^2 = \int_0^t e^{2\int_s^t f(\tau)\;\mathrm{d}\tau}g(s)^2\;\mathrm{d}s.\]

The probability density function $p(t, x)$ can also be obtained with the help of the Fokker-Planck equation

\[ \frac{\partial p}{\partial t} + \nabla_x \cdot (f(t) x\, p(t, x)) = \frac{1}{2}\Delta_x \left( g(t)^2 p(t, x) \right),\]

whose fundamental solutions are precisely $p(t, x | 0, x_0) = \mathcal{N}(x; \mu(t)x_0, \zeta(t)^2).$
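
The closed-form mean and variance of the transition kernel can be checked numerically. The following is a minimal sketch, assuming the Ornstein-Uhlenbeck choice $f(t) = -1,$ $g(t) = 1$ (an illustrative choice, not one of the SDEs discussed in this text), a deterministic initial condition $x_0,$ and a hypothetical helper `euler_maruyama`. For this choice, $\mu(t) = e^{-t}$ and $\zeta(t)^2 = (1 - e^{-2t})/2.$

```julia
using Random, Statistics

# Euler-Maruyama scheme for dX = f(t) X dt + g(t) dW, advancing many paths at once.
# `euler_maruyama` is a hypothetical helper written for this illustration.
function euler_maruyama(rng, f, g, x0, T; nsteps=1_000, npaths=20_000)
    dt = T / nsteps
    x = fill(float(x0), npaths)
    for k in 1:nsteps
        t = (k - 1) * dt
        x .+= f(t) .* x .* dt .+ g(t) .* sqrt(dt) .* randn(rng, npaths)
    end
    return x
end

# Assumed coefficients (Ornstein-Uhlenbeck), chosen so the integrals are explicit
f_ou(t) = -1.0
g_ou(t) = 1.0

x0, T = 2.0, 1.0
xT = euler_maruyama(Xoshiro(123), f_ou, g_ou, x0, T)

mu_T = exp(-T)                 # μ(T) = exp(∫₀ᵀ f(s) ds)
zeta2_T = (1 - exp(-2T)) / 2   # ζ(T)² = ∫₀ᵀ exp(-2(T - s)) ds

println("mean: empirical = ", mean(xT), ", closed form = ", mu_T * x0)
println("variance: empirical = ", var(xT), ", closed form = ", zeta2_T)
```

The empirical moments should agree with $\mu(T)x_0$ and $\zeta(T)^2$ up to Monte Carlo and time-discretization error.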

Examples

Variance-exploding SDE

For example, in the variance-exploding case (VE SDE), as discussed in Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole (2020), as the continuous limit of the Multiple Denoising Score Matching, we have

\[ f(t) = 0, \quad g(t) = \sqrt{\frac{\mathrm{d}(\sigma(t)^2)}{\mathrm{d}t}},\]

so that

\[ \mu(t) = 1\]

and

\[ \zeta(t)^2 = \int_0^t \frac{\mathrm{d}(\sigma(s)^2)}{\mathrm{d}s}\;\mathrm{d}s = \sigma(t)^2 - \sigma(0)^2.\]

Thus,

\[ p(t, x | 0, x_0) = \mathcal{N}\left( x; x_0, \sigma(t)^2 - \sigma(0)^2\right).\]
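
As a worked instance, assume the geometric noise schedule $\sigma(t) = \sigma_{\min}\left(\sigma_{\max}/\sigma_{\min}\right)^t,$ which is the one behind the function g_ve in the numerical example. Differentiating $\sigma(t)^2$ gives

\[ \frac{\mathrm{d}(\sigma(t)^2)}{\mathrm{d}t} = 2\sigma(t)^2\log\left(\frac{\sigma_{\max}}{\sigma_{\min}}\right), \quad g(t) = \sigma_{\min}\left(\frac{\sigma_{\max}}{\sigma_{\min}}\right)^t\sqrt{2\log\left(\frac{\sigma_{\max}}{\sigma_{\min}}\right)},\]

and, for this schedule,

\[ \zeta(t)^2 = \sigma(t)^2 - \sigma(0)^2 = \sigma_{\min}^2\left(\left(\frac{\sigma_{\max}}{\sigma_{\min}}\right)^{2t} - 1\right).\]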

Variance-preserving SDE

In the variance-preserving case (VP SDE), as discussed in Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole (2020), as the continuous limit of the Denoising Diffusion Probabilistic Model,

\[ f(t) = -\frac{1}{2}\beta(t), \quad g(t) = \sqrt{\beta(t)},\]

so that

\[ \mu(t) = e^{-\frac{1}{2}\int_0^t \beta(s)\;\mathrm{d}s}\]

and

\[ \zeta(t)^2 = \int_0^t e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s = \left. -e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau} \right|_{s=0}^{s=t} = 1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau}.\]

Thus,

\[ p(t, x | 0, x_0) = \mathcal{N}\left( x; x_0 e^{-\frac{1}{2}\int_0^t \beta(s)\;\mathrm{d}s}, 1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau}\right).\]
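
The closed form of $\zeta(t)^2$ can be checked against a direct quadrature of its defining integral. The sketch below assumes the linear schedule $\beta(t) = \beta_{\min} + t(\beta_{\max} - \beta_{\min})$ with the values $0.1$ and $20$ used in the numerical example. The names `beta_lin`, `int_beta`, `zeta2_quadrature` and `zeta2_closed` are hypothetical helpers introduced only for this illustration.

```julia
# Linear schedule β(t) = bmin + t (bmax - bmin) (assumed for this illustration)
beta_lin(t; bmin=0.1, bmax=20.0) = bmin + t * (bmax - bmin)

# Exact value of ∫ₛᵗ β(τ) dτ for the linear schedule
int_beta(s, t; bmin=0.1, bmax=20.0) = bmin * (t - s) + (bmax - bmin) * (t^2 - s^2) / 2

# Composite trapezoidal rule for ζ(t)² = ∫₀ᵗ exp(-∫ₛᵗ β) β(s) ds
function zeta2_quadrature(t; n=20_000)
    s = range(0.0, t; length=n + 1)
    vals = @. exp(-int_beta(s, t)) * beta_lin(s)
    h = t / n
    return h * (sum(vals) - (vals[1] + vals[end]) / 2)
end

# Closed form ζ(t)² = 1 - exp(-∫₀ᵗ β)
zeta2_closed(t) = 1 - exp(-int_beta(0.0, t))

for t in (0.1, 0.5, 1.0)
    println("t = ", t, ": quadrature = ", zeta2_quadrature(t), ", closed form = ", zeta2_closed(t))
end
```

For this schedule, $\int_0^t \beta(s)\;\mathrm{d}s = t\beta_{\min} + t^2(\beta_{\max} - \beta_{\min})/2,$ so the exponent in $\mu(t)$ is quadratic in $t.$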

Sub-variance-preserving SDE

In the sub-variance-preserving case (sub-VP SDE), proposed in Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole (2020) as an alternative to the previous ones,

\[ f(t) = -\frac{1}{2}\beta(t), \quad g(t) = \sqrt{\beta(t)(1 - e^{-2\int_0^t \beta(s)\;\mathrm{d}s})},\]

so that

\[ \mu(t) = e^{-\frac{1}{2}\int_0^t \beta(s)\;\mathrm{d}s}\]

and

\[ \begin{aligned} \zeta(t)^2 & = \int_0^t e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau}\beta(s)(1 - e^{-2\int_0^s \beta(\tau)\;\mathrm{d}\tau})\;\mathrm{d}s \\ & = \int_0^t e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s - \int_0^t e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau}e^{-2\int_0^s \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s \\ & = \int_0^t e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s - \int_0^t e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau}e^{-\int_0^s \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s \\ & = 1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} \int_0^t e^{-\int_0^s \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s \\ & = 1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} + e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} \left.e^{-\int_0^s \beta(\tau)\;\mathrm{d}\tau}\right|_{s=0}^t \\ & = 1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} + e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} \left(e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} - 1\right) \\ & = 1 - 2e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} + e^{-2\int_0^t \beta(\tau)\;\mathrm{d}\tau} \\ & = \left(1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau}\right)^2. \end{aligned}\]

Thus,

\[ p(t, x | 0, x_0) = \mathcal{N}\left( x; x_0 e^{-\frac{1}{2}\int_0^t \beta(s)\;\mathrm{d}s}, \left(1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau}\right)^2\right).\]
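
A short comparison explains the name. Writing $B(t) = \int_0^t \beta(\tau)\;\mathrm{d}\tau$ (a shorthand introduced here) and assuming $\beta \geq 0,$ we have $0 \leq 1 - e^{-B(t)} \leq 1,$ and therefore

\[ \zeta_{\textrm{sub-VP}}(t)^2 = \left(1 - e^{-B(t)}\right)^2 \leq 1 - e^{-B(t)} = \zeta_{\textrm{VP}}(t)^2.\]

Both kernels share the same mean factor $\mu(t) = e^{-\frac{1}{2}B(t)},$ so the sub-VP kernel has a variance bounded by that of the VP kernel at every time $t.$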

Loss function

The loss function for training is a continuous version of the loss for the multiple denoising score-matching. In that case, we had

\[ J_{\textrm{SMLD}}(\boldsymbol{\theta}) = \frac{1}{2L}\sum_{i=1}^L \lambda(\sigma_i) \mathbb{E}_{p(\mathbf{x})p_{\sigma_i}(\tilde{\mathbf{x}}|\mathbf{x})}\left[ \left\| s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}, \sigma_i) - \frac{\mathbf{x} - \tilde{\mathbf{x}}}{\sigma_i^2} \right\|^2 \right],\]

where $\lambda = \lambda(\sigma_i)$ is a weighting factor. When too many levels are considered, one takes a stochastic approach and approximates the loss $J_{\textrm{SMLD}}(\boldsymbol{\theta})$ by

\[ J_{\textrm{SMLD}}^*(\boldsymbol{\theta}) = \frac{1}{2}\lambda(\sigma_i) \mathbb{E}_{p(\mathbf{x})p_{\sigma_i}(\tilde{\mathbf{x}}|\mathbf{x})}\left[ \left\| s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}, \sigma_i) - \frac{\mathbf{x} - \tilde{\mathbf{x}}}{\sigma_i^2} \right\|^2 \right],\]

with

\[ i \sim \operatorname{Uniform}(\{1, 2, \ldots, L\}).\]

The continuous version becomes

\[ J_{\textrm{SDE}}^*(\boldsymbol{\theta}) = \frac{1}{2}\lambda(t) \mathbb{E}_{p_0(\mathbf{x}_0)p(t, \tilde{\mathbf{x}}|0, \mathbf{x}_0)}\left[ \left\| s_{\boldsymbol{\theta}}(t, \tilde{\mathbf{x}}) - \boldsymbol{\nabla}_{\tilde{\mathbf{x}}} \log p(t, \tilde{\mathbf{x}}|0, \mathbf{x}_0) \right\|^2 \right],\]

with

\[ t \sim \operatorname{Uniform}[0, T].\]

In practice, the empirical distribution is considered for $p_0(\mathbf{x}_0),$ and a stochastic approach is taken by sampling a single $\tilde{\mathbf{x}}_n \sim p(t_n, \tilde{\mathbf{x}}|0, \mathbf{x}_n)$ for each data point, along with $t_n \sim \operatorname{Uniform}([0, T]).$ Thus, the loss takes the form

\[ {\tilde J}_{\textrm{SDE}}^*(\boldsymbol{\theta}) = \frac{1}{2N}\sum_{n=1}^N \lambda(t_n) \left[ \left\| s_{\boldsymbol{\theta}}(t_n, \tilde{\mathbf{x}}_n) - \boldsymbol{\nabla}_{\tilde{\mathbf{x}}} \log p(t_n, \tilde{\mathbf{x}}_n|0, \mathbf{x}_n) \right\|^2 \right],\]

with

\[ \mathbf{x}_n \sim p_0, \quad t_n \sim \operatorname{Uniform}[0, T], \quad \tilde{\mathbf{x}}_n \sim p(t_n, x | 0, \mathbf{x}_n).\]

The explicit form for the distribution $p(t_n, x | 0, \mathbf{x}_n)$ and its score $\boldsymbol{\nabla}_{\tilde{\mathbf{x}}} \log p(t_n, \tilde{\mathbf{x}}_n|0, \mathbf{x}_n)$ depends on the choice of the SDE.
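
For the Gaussian kernels $\mathcal{N}(x; \mu(t)x_0, \zeta(t)^2)$ obtained earlier, the score of the kernel is $-(\tilde x - \mu(t)x_0)/\zeta(t)^2.$ The following is a minimal sketch of the Monte Carlo loss in the univariate VE case, under several assumptions: the geometric schedule of the numerical example, the $\sigma(0)^2$ offset in the variance is neglected, the weight is taken as $\lambda(t) = \zeta(t)^2,$ and `score_model` is any function of $(t, x)$ standing in for $s_{\boldsymbol{\theta}}.$ The names `sigma_t`, `mu_ve`, `zeta_ve` and `sde_loss` are hypothetical.

```julia
using Random

# Assumed VE kernel: mean factor μ(t) = 1, standard deviation ζ(t) ≈ σ(t)
sigma_t(t; smin=0.01, smax=1.0) = smin * (smax / smin)^t
mu_ve(t) = 1.0
zeta_ve(t) = sigma_t(t)

# One-sample-per-data-point estimate of the loss:
# (1 / 2N) Σₙ λ(tₙ) |s(tₙ, x̃ₙ) - ∇ log p(tₙ, x̃ₙ | 0, xₙ)|²
function sde_loss(rng, score_model, xs; T=1.0, weight = t -> zeta_ve(t)^2)
    total = 0.0
    for x0 in xs
        t = T * rand(rng)                                 # tₙ ~ Uniform[0, T]
        xt = mu_ve(t) * x0 + zeta_ve(t) * randn(rng)      # x̃ₙ ~ p(tₙ, · | 0, xₙ)
        target = -(xt - mu_ve(t) * x0) / zeta_ve(t)^2     # score of the Gaussian kernel
        total += weight(t) * abs2(score_model(t, xt) - target)
    end
    return total / (2 * length(xs))
end

# Usage with a trivial stand-in model that always predicts a zero score
xs = randn(Xoshiro(1), 500)
println(sde_loss(Xoshiro(2), (t, x) -> 0.0, xs))
```

With the weight $\lambda(t) = \zeta(t)^2,$ each term has the same order of magnitude regardless of $t,$ which is the usual motivation for this choice of weight.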

Numerical example

We illustrate, numerically, the use of the score-based SDE method to model a synthetic univariate Gaussian mixture distribution.

Julia language setup

We use the Julia programming language for the numerical simulations, with suitable packages.

Packages

using StatsPlots
using Random
using Distributions
using Lux # artificial neural networks explicitly parametrized
using Optimisers
using Zygote # automatic differentiation
using Markdown

Reproducibility

We set the random seed for reproducibility purposes.

rng = Xoshiro(12345)

Data

We build the usual target model and draw samples from it.

Visualizing the sample data drawn from the distribution and the PDF.

Example block output

Visualizing the score function.

Example block output

Parameters

Here we set some parameters for the model and prepare any necessary data.

trange = 0.0:0.01:1.0
0.0:0.01:1.0

Variance exploding

sigma_min = 0.01
sigma_max = 1.0

f_ve(t) = 0.0
g_ve(t; σₘᵢₙ = sigma_min, σₘₐₓ = sigma_max) = σₘᵢₙ * ( σₘₐₓ / σₘᵢₙ)^t * √(2 * log(σₘₐₓ/σₘᵢₙ))

# Normal takes the standard deviation as second argument; we use σ(t) and neglect the σ(0)² offset in the variance
prob_kernel_ve(t, x0; σₘᵢₙ = sigma_min, σₘₐₓ = sigma_max) = Normal( x0, σₘᵢₙ * (σₘₐₓ/σₘᵢₙ)^t )
p_kernel_ve(t, x, x0) = pdf(prob_kernel_ve(t, x0), x)
score_kernel_ve(t, x, x0) = gradlogpdf(prob_kernel_ve(t, x0), x)
score_kernel_ve (generic function with 1 method)
surface(trange, xrange, (t, x) -> log(sum(x0 -> pdf(prob_kernel_ve(t, x0), x) * pdf(target_prob, x0), xrange)))
Example block output
heatmap(trange, xrange, (t, x) -> log(sum(x0 -> pdf(prob_kernel_ve(t, x0), x) * pdf(target_prob, x0), xrange)))
Example block output

Variance preserving

beta_min = 0.1
beta_max = 20.0

f_vp(t; βₘᵢₙ=beta_min, βₘₐₓ=beta_max) = - ( βₘᵢₙ + t * ( βₘₐₓ - βₘᵢₙ ) ) / 2
g_vp(t; βₘᵢₙ=beta_min, βₘₐₓ=beta_max) = √( βₘᵢₙ + t * ( βₘₐₓ - βₘᵢₙ ) )

# mean μ(t) x0 and standard deviation ζ(t), using ∫₀ᵗ β(s) ds = t βₘᵢₙ + t² (βₘₐₓ - βₘᵢₙ) / 2
prob_kernel_vp(t, x0; βₘᵢₙ=beta_min, βₘₐₓ=beta_max) = Normal( x0 * exp( - t^2 * ( βₘₐₓ - βₘᵢₙ ) / 4 - t * βₘᵢₙ / 2 ), √( 1 - exp( - t^2 * ( βₘₐₓ - βₘᵢₙ ) / 2 - t * βₘᵢₙ ) ) )
prob_kernel_vp (generic function with 1 method)
surface(trange, xrange, (t, x) -> log(sum(x0 -> pdf(prob_kernel_vp(t, x0), x) * pdf(target_prob, x0), xrange)))
Example block output
heatmap(trange, xrange, (t, x) -> log(sum(x0 -> pdf(prob_kernel_vp(t, x0), x) * pdf(target_prob, x0), xrange)))
Example block output
L = 6
sigma_1 = 2.0
sigma_L = 0.5
theta = ( sigma_L / sigma_1 )^(1/(L-1))
sigmas = [sigma_1 * theta ^ (i-1) for i in 1:L]
6-element Vector{Float64}:
 2.0
 1.515716566510398
 1.1486983549970349
 0.870550563296124
 0.659753955386447
 0.49999999999999983
data = (sample_points, sigmas)
([0.8606155918844086 -1.0314315213443432 … -0.9717983805592734 0.8976929261501944], [2.0, 1.515716566510398, 1.1486983549970349, 0.870550563296124, 0.659753955386447, 0.49999999999999983])

The neural network model

The neural network we consider is a simple feed-forward neural network made of a single hidden layer, obtained as a chain of a couple of dense layers. This is implemented with the LuxDL/Lux.jl package.

We will see that we don't need a big neural network in this simple example. We use the smallest network that still works.

model = Chain(Dense(2 => 64, relu), Dense(64 => 1))
Chain(
    layer_1 = Dense(2 => 64, relu),     # 192 parameters
    layer_2 = Dense(64 => 1),           # 65 parameters
)         # Total: 257 parameters,
          #        plus 0 states.

The LuxDL/Lux.jl package uses explicit parameters, that are initialized (or obtained) with the Lux.setup function, giving us the parameters and the state of the model.

ps, st = Lux.setup(rng, model) # initialize and get the parameters and states of the model
((layer_1 = (weight = Float32[0.11885788 -0.17188427; -0.106466874 0.1387838; … ; 0.088112965 0.14238097; 0.29127795 0.1525562], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[0.078679904 0.17515078 … 0.1578623 -0.2923103], bias = Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))

Loss function

function loss_function_sde(model, ps, st, data)
    sample_points, sigmas = data

    noisy_sample_points = sample_points .+ sigmas .* randn(rng, size(sample_points))
    scores = ( sample_points .- noisy_sample_points ) ./ sigmas .^ 2

    flattened_noisy_sample_points = reshape(noisy_sample_points, 1, :)
    flattened_sigmas = repeat(sigmas', 1, length(sample_points))
    model_input = [flattened_noisy_sample_points; flattened_sigmas]

    y_score_pred, st = Lux.apply(model, model_input, ps, st)

    flattened_scores = reshape(scores, 1, :)

    loss = mean(abs2, flattened_sigmas .* (y_score_pred .- flattened_scores)) / 2

    return loss, st, ()
end
loss_function_sde (generic function with 1 method)

Optimization setup

Optimization method

We use the Adam optimiser.

opt = Adam(0.01)

tstate_org = Lux.Training.TrainState(rng, model, opt)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}((layer_1 = Dense(2 => 64, relu), layer_2 = Dense(64 => 1)), nothing)
    # of parameters: 257
    # of states: 0
    optimizer: Adam(0.01, (0.9, 0.999), 1.0e-8)
    step: 0

Automatic differentiation in the optimization

As mentioned, we set up differentiation in LuxDL/Lux.jl with the FluxML/Zygote.jl library.

vjp_rule = Lux.Training.AutoZygote()
ADTypes.AutoZygote()

Processor

We use the CPU instead of the GPU.

dev_cpu = cpu_device()
## dev_gpu = gpu_device()
(::LuxDeviceUtils.LuxCPUDevice) (generic function with 5 methods)

Check differentiation

Check if Zygote via Lux is working fine to differentiate the loss functions for training.

Lux.Training.compute_gradients(vjp_rule, loss_function_sde, data, tstate_org)
((layer_1 = (weight = Float32[-0.084076755 -0.103119634; 0.064187735 0.041677713; … ; 0.2305074 0.17864294; 0.25498065 0.21438307], bias = Float32[-0.063371815; 0.026718244; … ; 0.11149582; 0.13256949;;]), layer_2 = (weight = Float32[0.22775885 0.116418496 … 0.05377945 0.32220203], bias = Float32[0.55424875;;])), 0.532838462238451, (), Lux.Training.TrainState{Nothing, Nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Optimisers.Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}}}(nothing, nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}((layer_1 = Dense(2 => 64, relu), layer_2 = Dense(64 => 1)), nothing), (layer_1 = (weight = Float32[0.12625368 0.12188108; 0.18178563 -0.08310586; … ; 0.058616165 -0.0022444578; 0.2997266 0.05032168], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[-0.10178894 0.0704763 … 0.24378277 0.27067676], bias 
= Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), Adam(0.01, (0.9, 0.999), 1.0e-8), (layer_1 = (weight = Leaf(Adam(0.01, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.01, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam(0.01, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.01, (0.9, 0.999), 1.0e-8), (Float32[0.0;;], Float32[0.0;;], (0.9, 0.999))))), 0))

Training loop

Here is the typical main training loop suggested in the LuxDL/Lux.jl tutorials, but slightly modified to save the history of losses per iteration.

function train(tstate, vjp, data, loss_function, epochs, numshowepochs=20, numsavestates=0)
    losses = zeros(epochs)
    tstates = [(0, tstate)]
    for epoch in 1:epochs
        grads, loss, stats, tstate = Lux.Training.compute_gradients(vjp,
            loss_function, data, tstate)
        if ( epochs ≥ numshowepochs > 0 ) && rem(epoch, div(epochs, numshowepochs)) == 0
            println("Epoch: $(epoch) || Loss: $(loss)")
        end
        if ( epochs ≥ numsavestates > 0 ) && rem(epoch, div(epochs, numsavestates)) == 0
            push!(tstates, (epoch, tstate))
        end
        losses[epoch] = loss
        tstate = Lux.Training.apply_gradients(tstate, grads)
    end
    return tstate, losses, tstates
end
train (generic function with 3 methods)

Training

Now we train the model with the loss function loss_function_sde defined above.

@time tstate, losses, tstates = train(tstate_org, vjp_rule, data, loss_function_sde, 1000, 20, 125)
Epoch: 50 || Loss: 0.2715468181871004
Epoch: 100 || Loss: 0.2695423077608621
Epoch: 150 || Loss: 0.2514633697434851
Epoch: 200 || Loss: 0.23665890168742354
Epoch: 250 || Loss: 0.2330884618254507
Epoch: 300 || Loss: 0.22884665011645136
Epoch: 350 || Loss: 0.22958715585951375
Epoch: 400 || Loss: 0.22552964253715402
Epoch: 450 || Loss: 0.22143441845904807
Epoch: 500 || Loss: 0.2210533151851991
Epoch: 550 || Loss: 0.19939301850289137
Epoch: 600 || Loss: 0.21509002089540163
Epoch: 650 || Loss: 0.23640786716906548
Epoch: 700 || Loss: 0.2181223249676826
Epoch: 750 || Loss: 0.2066386924973382
Epoch: 800 || Loss: 0.20148734916493982
Epoch: 850 || Loss: 0.19891221269388212
Epoch: 900 || Loss: 0.194704762071967
Epoch: 950 || Loss: 0.19949806411497586
Epoch: 1000 || Loss: 0.20058908662461314
  2.017295 seconds (418.17 k allocations: 10.005 GiB, 4.91% gc time, 1.51% compilation time)

Results

Checking out the trained model.

Example block output

Visualizing the result with the smallest noise.

Example block output

Recovering the PDF of the distribution from the trained score function.

Example block output

With the smallest noise.

Example block output

Just for the fun of it, let us see an animation of the optimization process.

Example block output

And the animation of the evolution of the PDF.

Example block output

We also visualize the evolution of the losses.

Example block output

Sampling with annealed Langevin

Now we sample the modeled distribution with the annealed Langevin method described earlier.

Here are the trajectories.

Example block output

The sample histogram obtained at the end of the trajectories.

Example block output
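
The annealed Langevin iteration can be sketched as follows. This is a minimal illustration under stated assumptions, not the code used to produce the figures above: the trained network is replaced by a hypothetical stand-in `exact_score`, the exact score of a Gaussian $\mathcal{N}(m, s^2)$ perturbed with noise level $\sigma$; the step size follows the rule $\alpha_i = \epsilon\,\sigma_i^2/\sigma_L^2$ of Song and Ermon (2019); and the values of `step_scale` and `nsteps` are arbitrary choices.

```julia
using Random, Statistics

# Hypothetical stand-in for the trained score model:
# exact score of N(m, s²) convolved with N(0, σ²), that is, of N(m, s² + σ²)
exact_score(x, sigma; m=1.0, s=0.5) = -(x - m) / (s^2 + sigma^2)

# Annealed Langevin dynamics: run Langevin steps at each noise level, from the largest to the smallest
function annealed_langevin(rng, score, x_init, sigmas; step_scale=0.02, nsteps=100)
    x = copy(x_init)
    for sigma in sigmas
        alpha = step_scale * sigma^2 / sigmas[end]^2   # step size for this noise level
        for _ in 1:nsteps
            x .+= (alpha / 2) .* score.(x, sigma) .+ sqrt(alpha) .* randn(rng, length(x))
        end
    end
    return x
end

rng = Xoshiro(2024)
sigmas = [2.0 * (0.5 / 2.0)^((i - 1) / 5) for i in 1:6]   # geometric sequence from 2.0 down to 0.5
x_init = 6 .* rand(rng, 5_000) .- 3                        # initial points uniform on [-3, 3]
samples = annealed_langevin(rng, exact_score, x_init, sigmas)

println("sample mean = ", mean(samples), ", sample variance = ", var(samples))
```

With this stand-in score, the samples should approximately follow the target perturbed at the smallest noise level, with mean $m$ and variance close to $s^2 + \sigma_L^2.$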

References

  1. Aapo Hyvärinen (2005), "Estimation of non-normalized statistical models by score matching", Journal of Machine Learning Research 6, 695-709
  2. Pascal Vincent (2011), "A connection between score matching and denoising autoencoders", Neural Computation, 23 (7), 1661-1674, doi:10.1162/NECO_a_00142
  3. J. Sohl-Dickstein, E. A. Weiss, N. Maheswaranathan, S. Ganguli (2015), "Deep unsupervised learning using nonequilibrium thermodynamics", ICML'15: Proceedings of the 32nd International Conference on Machine Learning - Volume 37, 2256-2265
  4. Y. Song and S. Ermon (2019), "Generative modeling by estimating gradients of the data distribution", NIPS'19: Proceedings of the 33rd International Conference on Neural Information Processing Systems, no. 1067, 11918-11930
  5. J. Ho, A. Jain, P. Abbeel (2020), "Denoising diffusion probabilistic models", in Advances in Neural Information Processing Systems 33, NeurIPS 2020
  6. Y. Song, J. Sohl-Dickstein, D. P. Kingma, A. Kumar, S. Ermon, B. Poole (2020), "Score-based generative modeling through stochastic differential equations", arXiv:2011.13456