Score-based generative modeling through stochastic differential equations
Introduction
Aim
Review the work of Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole (2020) that takes a complex data distribution, adds noise to it via a stochastic differential equation and generates new samples by modeling the reverse process. It is a generalization to the continuous case of the previous discrete processes of denoising diffusion probabilistic models and multiple denoising score matching.
Background
After Aapo Hyvärinen (2005) proposed implicit score matching to model a distribution by fitting its score function, several works followed, including the denoising score matching of Pascal Vincent (2011), which perturbs the data so that the analytic expression of the score function of the perturbation kernel can be used. Then the denoising diffusion probabilistic models of Sohl-Dickstein, Weiss, Maheswaranathan, Ganguli (2015) and Ho, Jain, and Abbeel (2020), and the multiple denoising score matching of Song and Ermon (2019), went one step further by adding several levels of noise, which facilitates the generation process. The work of Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole (2020) extended that idea to the continuous case, adding noise via a stochastic differential equation.
Forward SDE
An initial unknown probability distribution with density $p_0=p_0(x),$ associated with a random variable $X_0,$ is embedded into the distribution of an SDE of the form
\[ \mathrm{d}X_t = f(t)X_t\;\mathrm{d}t + g(t)\;\mathrm{d}W_t,\]
with initial condition $X_0.$ The solution can be obtained with the help of the integrating factor $e^{-\int_0^t f(s)\;\mathrm{d}s}$ associated with the deterministic part of the equation. In this case,
\[ \begin{aligned} \mathrm{d}\left(X_te^{-\int_0^t f(s)\;\mathrm{d}s}\right) & = \mathrm{d}X_t e^{-\int_0^t f(s)\;\mathrm{d}s} - X_t f(t) e^{-\int_0^t f(s)\;\mathrm{d}s} \;\mathrm{d}t \\ & = \left(f(t)X_t\;\mathrm{d}t + g(t)\;\mathrm{d}W_t\right)e^{-\int_0^t f(s)\;\mathrm{d}s} - X_t f(t) e^{-\int_0^t f(s)\;\mathrm{d}s} \;\mathrm{d}t \\ & = g(t)e^{-\int_0^t f(s)\;\mathrm{d}s}\;\mathrm{d}W_t. \end{aligned}\]
Integrating yields
\[ X_te^{-\int_0^t f(s)\;\mathrm{d}s} - X_0 = \int_0^t g(s)e^{-\int_0^s f(\tau)\;\mathrm{d}\tau}\;\mathrm{d}W_s.\]
Moving the exponential term to the right hand side yields the solution
\[ X_t = X_0 e^{\int_0^t f(s)\;\mathrm{d}s} + \int_0^t e^{\int_s^t f(\tau)\;\mathrm{d}\tau}g(s)\;\mathrm{d}W_s.\]
The mean value evolves according to
\[ \mathbb{E}[X_t] = \mathbb{E}[X_0] e^{\int_0^t f(s)\;\mathrm{d}s}.\]
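Using the Itô isometry, together with the independence between $X_0$ and the Wiener process (so that the cross term vanishes), the second moment evolves with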
\[ \mathbb{E}[X_t^2] = \mathbb{E}[X_0^2]e^{2\int_0^t f(s)\;\mathrm{d}s} + \int_0^t e^{2\int_s^t f(\tau)\;\mathrm{d}\tau}g(s)^2\;\mathrm{d}s.\]
Hence, the variance is given by
\[ \operatorname{Var}(X_t) = \operatorname{Var}(X_0)e^{2\int_0^t f(s)\;\mathrm{d}s} + \int_0^t e^{2\int_s^t f(\tau)\;\mathrm{d}\tau}g(s)^2\;\mathrm{d}s.\]
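Moreover, for a fixed initial point $X_0 = x_0,$ the solution $X_t$ is Gaussian, since the stochastic integral has a deterministic integrand. Thus, the probability density function $p(t, x)$ can be obtained by conditioning on each initial point, with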
\[ p(t, x) = \int_{\mathbb{R}} p(t, x | 0, x_0) p_0(x_0)\;\mathrm{d}x_0,\]
and
\[ p(t, x | 0, x_0) = \mathcal{N}(x; \mu(t)x_0, \zeta(t)^2),\]
where
\[ \mu(t) = e^{\int_0^t f(s)\;\mathrm{d}s}\]
and
\[ \zeta(t)^2 = \int_0^t e^{2\int_s^t f(\tau)\;\mathrm{d}\tau}g(s)^2\;\mathrm{d}s.\]
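The formulas for $\mu(t)$ and $\zeta(t)^2$ can be checked by simulation. The following is a minimal sketch, not part of the original notes: it integrates the SDE with the Euler-Maruyama scheme from a fixed $x_0$ and compares the sample mean and variance with the closed forms. The helper names (`euler_maruyama_endpoints`, `beta`, `int_beta`, `mu_exact`, `zeta2_exact`) and the choice $f(t) = -\beta(t)/2,$ $g(t) = \sqrt{\beta(t)}$ with a linear $\beta(t)$ are assumptions made for illustration, chosen because the integrals can be done by hand.

```julia
using Random, Statistics

# Euler-Maruyama endpoints of dX = f(t) X dt + g(t) dW, all paths started at x0
function euler_maruyama_endpoints(rng, f, g, x0, T, nsteps, npaths)
    dt = T / nsteps
    x = fill(float(x0), npaths)
    for k in 0:(nsteps - 1)
        t = k * dt
        x .+= f(t) .* x .* dt .+ g(t) .* sqrt(dt) .* randn(rng, npaths)
    end
    return x
end

# Assumed coefficients: linear beta(t), with f = -beta/2 and g = sqrt(beta)
beta(t) = 0.1 + t * (20.0 - 0.1)
f(t) = -beta(t) / 2
g(t) = sqrt(beta(t))

# Closed forms of mu(t) and zeta(t)^2 for this choice; int_beta(t) = ∫₀ᵗ beta(s) ds
int_beta(t) = 0.1 * t + t^2 * (20.0 - 0.1) / 2
mu_exact(t) = exp(-int_beta(t) / 2)
zeta2_exact(t) = 1 - exp(-int_beta(t))

rng = Xoshiro(2024)
x0, T = 1.5, 0.5
xT = euler_maruyama_endpoints(rng, f, g, x0, T, 2_000, 20_000)
println("mean: simulated ", mean(xT), " vs closed form ", x0 * mu_exact(T))
println("var:  simulated ", var(xT), " vs closed form ", zeta2_exact(T))
```

The two pairs of numbers should agree up to Monte Carlo and time-discretization error, which shrink as the number of paths and steps grow.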
The probability density function $p(t, x)$ can also be obtained with the help of the Fokker-Planck equation
\[ \frac{\partial p}{\partial t} + \nabla_x \cdot (f(t) x\, p(t, x)) = \frac{1}{2}\Delta_x \left( g(t)^2 p(t, x) \right),\]
whose fundamental solutions are precisely $p(t, x | 0, x_0) = \mathcal{N}(x; \mu(t)x_0, \zeta(t)^2).$
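As a consistency check (a short worked step, not in the original derivation), differentiating the expressions for $\mu(t)$ and $\zeta(t)^2$ with respect to $t$ gives the linear moment equations
\[ \frac{\mathrm{d}\mu}{\mathrm{d}t} = f(t)\mu(t), \quad \mu(0) = 1, \qquad \frac{\mathrm{d}(\zeta^2)}{\mathrm{d}t} = 2f(t)\zeta(t)^2 + g(t)^2, \quad \zeta(0)^2 = 0.\]
These are the same equations obtained for the mean and variance of a Gaussian solution of the Fokker-Planck equation with drift $f(t)x$ and diffusion $g(t),$ starting from a point mass at $x_0.$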
Examples
Variance-exploding SDE
For example, in the variance-exploding case (VE SDE), as discussed in Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole (2020), as the continuous limit of the Multiple Denoising Score Matching, we have
\[ f(t) = 0, \quad g(t) = \sqrt{\frac{\mathrm{d}(\sigma(t)^2)}{\mathrm{d}t}},\]
so that
\[ \mu(t) = 1\]
and
\[ \zeta(t)^2 = \int_0^t \frac{\mathrm{d}(\sigma(s)^2)}{\mathrm{d}s}\;\mathrm{d}s = \sigma(t)^2 - \sigma(0)^2.\]
Thus,
\[ p(t, x | 0, x_0) = \mathcal{N}\left( x; x_0, \sigma(t)^2 - \sigma(0)^2\right).\]
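For instance, assuming the geometric noise schedule $\sigma(t) = \sigma_{\min}(\sigma_{\max}/\sigma_{\min})^t$ (the schedule coded in the numerical example further on, with $\sigma_{\min}$ and $\sigma_{\max}$ free parameters), the coefficients work out as
\[ \frac{\mathrm{d}(\sigma(t)^2)}{\mathrm{d}t} = 2\log\left(\frac{\sigma_{\max}}{\sigma_{\min}}\right)\sigma(t)^2, \quad g(t) = \sigma(t)\sqrt{2\log\left(\frac{\sigma_{\max}}{\sigma_{\min}}\right)}, \quad \zeta(t)^2 = \sigma_{\min}^2\left(\left(\frac{\sigma_{\max}}{\sigma_{\min}}\right)^{2t} - 1\right).\]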
Variance-preserving SDE
In the variance-preserving case (VP SDE), as discussed in Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole (2020), as the continuous limit of the Denoising Diffusion Probabilistic Model,
\[ f(t) = -\frac{1}{2}\beta(t), \quad g(t) = \sqrt{\beta(t)},\]
so that
\[ \mu(t) = e^{-\frac{1}{2}\int_0^t \beta(s)\;\mathrm{d}s}\]
and
\[ \zeta(t)^2 = \int_0^t e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s = \left. -e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau} \right|_{s=0}^{s=t} = 1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau}.\]
Thus,
\[ p(t, x | 0, x_0) = \mathcal{N}\left( x; x_0 e^{-\frac{1}{2}\int_0^t \beta(s)\;\mathrm{d}s}, 1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau}\right).\]
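As a worked case, assume the linear schedule $\beta(t) = \beta_{\min} + t(\beta_{\max} - \beta_{\min})$ (the one coded in the numerical example further on). Then
\[ \int_0^t \beta(s)\;\mathrm{d}s = \beta_{\min}t + \frac{1}{2}(\beta_{\max} - \beta_{\min})t^2,\]
so that
\[ \mu(t) = e^{-\frac{1}{2}\beta_{\min}t - \frac{1}{4}(\beta_{\max} - \beta_{\min})t^2}, \qquad \zeta(t)^2 = 1 - e^{-\beta_{\min}t - \frac{1}{2}(\beta_{\max} - \beta_{\min})t^2}.\]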
Sub-variance-preserving SDE
In the sub-variance-preserving case (sub-VP SDE), proposed in Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole (2020) as an alternative to the previous ones,
\[ f(t) = -\frac{1}{2}\beta(t), \quad g(t) = \sqrt{\beta(t)(1 - e^{-2\int_0^t \beta(s)\;\mathrm{d}s})},\]
so that
\[ \mu(t) = e^{-\frac{1}{2}\int_0^t \beta(s)\;\mathrm{d}s}\]
and
\[ \begin{aligned} \zeta(t)^2 & = \int_0^t e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau}\beta(s)(1 - e^{-2\int_0^s \beta(\tau)\;\mathrm{d}\tau})\;\mathrm{d}s \\ & = \int_0^t e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s - \int_0^t e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau}e^{-2\int_0^s \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s \\ & = \int_0^t e^{-\int_s^t \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s - \int_0^t e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau}e^{-\int_0^s \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s \\ & = 1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} \int_0^t e^{-\int_0^s \beta(\tau)\;\mathrm{d}\tau}\beta(s)\;\mathrm{d}s \\ & = 1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} + e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} \left.e^{-\int_0^s \beta(\tau)\;\mathrm{d}\tau}\right|_{s=0}^t \\ & = 1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} + e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} \left(e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} - 1\right) \\ & = 1 - 2e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau} + e^{-2\int_0^t \beta(\tau)\;\mathrm{d}\tau} \\ & = \left(1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau}\right)^2. \end{aligned}\]
Thus,
\[ p(t, x | 0, x_0) = \mathcal{N}\left( x; x_0 e^{-\frac{1}{2}\int_0^t \beta(s)\;\mathrm{d}s}, \left(1 - e^{-\int_0^t \beta(\tau)\;\mathrm{d}\tau}\right)^2\right).\]
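The closed form of the sub-VP variance can be confirmed numerically. The sketch below is an illustration under assumptions, not code from the original notes: it takes a linear $\beta(t)$ and compares a midpoint-rule quadrature of the defining integral of $\zeta(t)^2$ with $\left(1 - e^{-\int_0^t \beta}\right)^2.$ The names `beta`, `int_beta`, `integrand` and `midpoint_rule` are hypothetical helpers.

```julia
# Assumed linear schedule beta(t) and its integral int_beta(t) = ∫₀ᵗ beta(s) ds
beta(t; bmin = 0.1, bmax = 20.0) = bmin + t * (bmax - bmin)
int_beta(t; bmin = 0.1, bmax = 20.0) = bmin * t + t^2 * (bmax - bmin) / 2

# Integrand of zeta(t)^2 for the sub-VP SDE:
# exp(-∫ₛᵗ beta) * beta(s) * (1 - exp(-2 ∫₀ˢ beta))
integrand(s, t) = exp(-(int_beta(t) - int_beta(s))) * beta(s) * (1 - exp(-2 * int_beta(s)))

# Composite midpoint rule for ∫ₐᵇ h(s) ds with n subintervals
function midpoint_rule(h, a, b, n)
    step = (b - a) / n
    return step * sum(h(a + (k - 1 / 2) * step) for k in 1:n)
end

t = 0.7
numeric = midpoint_rule(s -> integrand(s, t), 0.0, t, 100_000)
closed_form = (1 - exp(-int_beta(t)))^2
println("quadrature: ", numeric, "  closed form: ", closed_form)
```

Both values should agree to many digits; a mismatch would point to an algebra slip in the derivation.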
Loss function
The loss function for training is a continuous version of the loss for the multiple denoising score-matching. In that case, we had
\[ J_{\textrm{SMLD}}(\boldsymbol{\theta}) = \frac{1}{2L}\sum_{i=1}^L \lambda(\sigma_i) \mathbb{E}_{p(\mathbf{x})p_{\sigma_i}(\tilde{\mathbf{x}}|\mathbf{x})}\left[ \left\| s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}, \sigma_i) - \frac{\mathbf{x} - \tilde{\mathbf{x}}}{\sigma_i^2} \right\|^2 \right],\]
where $\lambda = \lambda(\sigma_i)$ is a weighting factor. When too many levels are considered, one takes a stochastic approach and approximates the loss $J_{\textrm{SMLD}}(\boldsymbol{\theta})$ by
\[ J_{\textrm{SMLD}}^*(\boldsymbol{\theta}) = \frac{1}{2}\lambda(\sigma_i) \mathbb{E}_{p(\mathbf{x})p_{\sigma_i}(\tilde{\mathbf{x}}|\mathbf{x})}\left[ \left\| s_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}, \sigma_i) - \frac{\mathbf{x} - \tilde{\mathbf{x}}}{\sigma_i^2} \right\|^2 \right],\]
with
\[ i \sim \operatorname{Uniform}[\{1, 2, \ldots, L\}].\]
The continuous version becomes
\[ J_{\textrm{SDE}}^*(\boldsymbol{\theta}) = \frac{1}{2}\lambda(t) \mathbb{E}_{p_0(\mathbf{x}_0)p(t, \tilde{\mathbf{x}}|0, \mathbf{x}_0)}\left[ \left\| s_{\boldsymbol{\theta}}(t, \tilde{\mathbf{x}}) - \boldsymbol{\nabla}_{\tilde{\mathbf{x}}} \log p(t, \tilde{\mathbf{x}}|0, \mathbf{x}_0) \right\|^2 \right],\]
with
\[ t \sim \operatorname{Uniform}[0, T].\]
In practice, the empirical distribution is considered for $p_0(\mathbf{x}_0),$ and a stochastic approach is taken by sampling a single $\tilde{\mathbf{x}}_n \sim p(t_n, \tilde{\mathbf{x}}|0, \mathbf{x}_n)$ for each data point, in addition to $t_n \sim \operatorname{Uniform}([0, T]).$ Thus, the loss takes the form
\[ {\tilde J}_{\textrm{SDE}}^*(\boldsymbol{\theta}) = \frac{1}{2N}\sum_{n=1}^N \lambda(t_n) \left[ \left\| s_{\boldsymbol{\theta}}(t_n, \tilde{\mathbf{x}}_n) - \boldsymbol{\nabla}_{\tilde{\mathbf{x}}} \log p(t_n, \tilde{\mathbf{x}}_n|0, \mathbf{x}_n) \right\|^2 \right],\]
with
\[ \mathbf{x}_n \sim p_0, \quad t_n \sim \operatorname{Uniform}[0, T], \quad \tilde{\mathbf{x}}_n \sim p(t_n, x | 0, \mathbf{x}_n).\]
The explicit form for the distribution $p(t_n, x | 0, \mathbf{x}_n)$ and its score $\boldsymbol{\nabla}_{\tilde{\mathbf{x}}} \log p(t_n, \tilde{\mathbf{x}}_n|0, \mathbf{x}_n)$ depends on the choice of the SDE.
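For the Gaussian kernels above, the target in the loss is explicit: since the kernel is $\mathcal{N}(\mu(t)x_0, \zeta(t)^2),$ its score at $\tilde x$ is $-(\tilde x - \mu(t)x_0)/\zeta(t)^2,$ and writing $\tilde x = \mu(t)x_0 + \zeta(t)\varepsilon$ with $\varepsilon \sim \mathcal{N}(0, 1)$ reduces it to $-\varepsilon/\zeta(t).$ The following is a minimal univariate sketch of the resulting Monte Carlo loss, under stated assumptions rather than the implementation used in these notes: the VP kernel with a linear $\beta(t),$ the weighting $\lambda(t) = \zeta(t)^2,$ and a small lower cutoff `tmin` on the sampled times to avoid dividing by $\zeta(0) = 0.$ The names `sde_loss`, `mu_vp`, `zeta_vp` and `int_beta` are hypothetical.

```julia
using Random, Statistics

# Assumed VP kernel coefficients for a linear beta(t); int_beta(t) = ∫₀ᵗ beta(s) ds
int_beta(t; bmin = 0.1, bmax = 20.0) = bmin * t + t^2 * (bmax - bmin) / 2
mu_vp(t) = exp(-int_beta(t) / 2)
zeta_vp(t) = sqrt(1 - exp(-int_beta(t)))

# Monte Carlo denoising loss with weighting lambda(t) = zeta(t)^2.
# `score_model(t, x)` stands in for the parametrized score s_θ(t, x).
function sde_loss(rng, score_model, x0s; T = 1.0, tmin = 1e-3)
    n = length(x0s)
    ts = tmin .+ (T - tmin) .* rand(rng, n)          # t_n ~ Uniform[tmin, T]
    noise = randn(rng, n)                            # ε_n ~ N(0, 1)
    xts = mu_vp.(ts) .* x0s .+ zeta_vp.(ts) .* noise # x̃_n ~ p(t_n, x | 0, x_n)
    targets = -noise ./ zeta_vp.(ts)                 # score of the Gaussian kernel
    residuals = score_model.(ts, xts) .- targets
    return mean(zeta_vp.(ts) .^ 2 .* abs2.(residuals)) / 2
end

rng = Xoshiro(1)
x0s = randn(rng, 20_000)                             # toy data: standard normal
loss_zero = sde_loss(rng, (t, x) -> 0.0, x0s)        # trivial model s = 0
loss_true = sde_loss(rng, (t, x) -> -x, x0s)         # exact marginal score for this data
println("zero model: ", loss_zero, "  exact score: ", loss_true)
```

With standard normal data the VP marginal stays standard normal, so $-x$ is the exact marginal score and should give a visibly smaller loss than the zero model. The loss does not vanish even at the exact score, since the conditional target keeps an irreducible variance.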
Numerical example
We illustrate, numerically, the use of the score-based SDE method to model a synthetic univariate Gaussian mixture distribution.
Julia language setup
We use the Julia programming language for the numerical simulations, with suitable packages.
Packages
using StatsPlots
using Random
using Distributions
using Lux # artificial neural networks explicitly parametrized
using Optimisers
using Zygote # automatic differentiation
using Markdown
Reproducibility
We set the random seed for reproducibility purposes.
rng = Xoshiro(12345)
Data
We build the usual target model and draw samples from it.
Visualizing the sample data drawn from the distribution and the PDF.
Visualizing the score function.
Parameters
Here we set some parameters for the model and prepare any necessary data.
trange = 0.0:0.01:1.0
0.0:0.01:1.0
Variance exploding
sigma_min = 0.01
sigma_max = 1.0
f_ve(t) = 0.0
g_ve(t; σₘᵢₙ = sigma_min, σₘₐₓ = sigma_max) = σₘᵢₙ * ( σₘₐₓ / σₘᵢₙ)^t * √(2 * log(σₘₐₓ/σₘᵢₙ))
prob_kernel_ve(t, x0; σₘᵢₙ = sigma_min, σₘₐₓ = sigma_max) = Normal( x0, σₘᵢₙ * (σₘₐₓ/σₘᵢₙ)^t ) # Normal takes the standard deviation; σ(t) is used, dropping the small σ(0)^2 correction
p_kernel_ve(t, x, x0) = pdf(prob_kernel_ve(t, x0), x)
score_kernel_ve(t, x, x0) = gradlogpdf(prob_kernel_ve(t, x0), x)
score_kernel_ve (generic function with 1 method)
surface(trange, xrange, (t, x) -> log(sum(x0 -> pdf(prob_kernel_ve(t, x0), x) * pdf(target_prob, x0), xrange)))
heatmap(trange, xrange, (t, x) -> log(sum(x0 -> pdf(prob_kernel_ve(t, x0), x) * pdf(target_prob, x0), xrange)))
Variance preserving
beta_min = 0.1
beta_max = 20.0
f_vp(t; βₘᵢₙ=beta_min, βₘₐₓ=beta_max) = - ( βₘᵢₙ + t * ( βₘₐₓ - βₘᵢₙ ) ) / 2
g_vp(t; βₘᵢₙ=beta_min, βₘₐₓ=beta_max) = √( βₘᵢₙ + t * ( βₘₐₓ - βₘᵢₙ ) )
prob_kernel_vp(t, x0; βₘᵢₙ=beta_min, βₘₐₓ=beta_max) = Normal( x0 * exp( - t^2 * ( βₘₐₓ - βₘᵢₙ ) / 4 - t * βₘᵢₙ / 2 ), √(1 - exp( - t^2 * ( βₘₐₓ - βₘᵢₙ ) / 2 - t * βₘᵢₙ ))) # second argument is the standard deviation ζ(t)
prob_kernel_vp (generic function with 1 method)
surface(trange, xrange, (t, x) -> log(sum(x0 -> pdf(prob_kernel_vp(t, x0), x) * pdf(target_prob, x0), xrange)))
heatmap(trange, xrange, (t, x) -> log(sum(x0 -> pdf(prob_kernel_vp(t, x0), x) * pdf(target_prob, x0), xrange)))
L = 6
sigma_1 = 2.0
sigma_L = 0.5
theta = ( sigma_L / sigma_1 )^(1/(L-1))
sigmas = [sigma_1 * theta ^ (i-1) for i in 1:L]
6-element Vector{Float64}:
2.0
1.515716566510398
1.1486983549970349
0.870550563296124
0.659753955386447
0.49999999999999983
data = (sample_points, sigmas)
([0.8606155918844086 -1.0314315213443432 … -0.9717983805592734 0.8976929261501944], [2.0, 1.515716566510398, 1.1486983549970349, 0.870550563296124, 0.659753955386447, 0.49999999999999983])
The neural network model
The neural network we consider is a simple feed-forward neural network made of a single hidden layer, obtained as a chain of a couple of dense layers. This is implemented with the LuxDL/Lux.jl package.
We will see that we don't need a big neural network in this simple example. We go as low as it works.
model = Chain(Dense(2 => 64, relu), Dense(64 => 1))
Chain(
layer_1 = Dense(2 => 64, relu), # 192 parameters
layer_2 = Dense(64 => 1), # 65 parameters
) # Total: 257 parameters,
# plus 0 states.
The LuxDL/Lux.jl package uses explicit parameters, which are initialized (or obtained) with the Lux.setup function, giving us the parameters and the state of the model.
ps, st = Lux.setup(rng, model) # initialize and get the parameters and states of the model
((layer_1 = (weight = Float32[0.11885788 -0.17188427; -0.106466874 0.1387838; … ; 0.088112965 0.14238097; 0.29127795 0.1525562], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[0.078679904 0.17515078 … 0.1578623 -0.2923103], bias = Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
Loss function
function loss_function_sde(model, ps, st, data)
    sample_points, sigmas = data
    noisy_sample_points = sample_points .+ sigmas .* randn(rng, size(sample_points))
    scores = ( sample_points .- noisy_sample_points ) ./ sigmas .^ 2
    flattened_noisy_sample_points = reshape(noisy_sample_points, 1, :)
    flattened_sigmas = repeat(sigmas', 1, length(sample_points))
    model_input = [flattened_noisy_sample_points; flattened_sigmas]
    y_score_pred, st = Lux.apply(model, model_input, ps, st)
    flattened_scores = reshape(scores, 1, :)
    loss = mean(abs2, flattened_sigmas .* (y_score_pred .- flattened_scores)) / 2
    return loss, st, ()
end
loss_function_sde (generic function with 1 method)
Optimization setup
Optimization method
We use the Adam optimiser.
opt = Adam(0.01)
tstate_org = Lux.Training.TrainState(rng, model, opt)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}((layer_1 = Dense(2 => 64, relu), layer_2 = Dense(64 => 1)), nothing)
# of parameters: 257
# of states: 0
optimizer: Adam(0.01, (0.9, 0.999), 1.0e-8)
step: 0
Automatic differentiation in the optimization
As mentioned, we set up differentiation in LuxDL/Lux.jl with the FluxML/Zygote.jl library.
vjp_rule = Lux.Training.AutoZygote()
ADTypes.AutoZygote()
Processor
We use the CPU instead of the GPU.
dev_cpu = cpu_device()
## dev_gpu = gpu_device()
(::LuxDeviceUtils.LuxCPUDevice) (generic function with 5 methods)
Check differentiation
Check if Zygote via Lux is working fine to differentiate the loss functions for training.
Lux.Training.compute_gradients(vjp_rule, loss_function_sde, data, tstate_org)
((layer_1 = (weight = Float32[-0.084076755 -0.103119634; 0.064187735 0.041677713; … ; 0.2305074 0.17864294; 0.25498065 0.21438307], bias = Float32[-0.063371815; 0.026718244; … ; 0.11149582; 0.13256949;;]), layer_2 = (weight = Float32[0.22775885 0.116418496 … 0.05377945 0.32220203], bias = Float32[0.55424875;;])), 0.532838462238451, (), Lux.Training.TrainState{Nothing, Nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Optimisers.Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}}}(nothing, nothing, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}((layer_1 = Dense(2 => 64, relu), layer_2 = Dense(64 => 1)), nothing), (layer_1 = (weight = Float32[0.12625368 0.12188108; 0.18178563 -0.08310586; … ; 0.058616165 -0.0022444578; 0.2997266 0.05032168], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[-0.10178894 0.0704763 … 0.24378277 0.27067676], bias 
= Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), Adam(0.01, (0.9, 0.999), 1.0e-8), (layer_1 = (weight = Leaf(Adam(0.01, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], Float32[0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.01, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam(0.01, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.01, (0.9, 0.999), 1.0e-8), (Float32[0.0;;], Float32[0.0;;], (0.9, 0.999))))), 0))
Training loop
Here is the typical main training loop suggested in the LuxDL/Lux.jl tutorials, slightly modified to save the history of losses per iteration.
function train(tstate, vjp, data, loss_function, epochs, numshowepochs=20, numsavestates=0)
    losses = zeros(epochs)
    tstates = [(0, tstate)]
    for epoch in 1:epochs
        grads, loss, stats, tstate = Lux.Training.compute_gradients(vjp,
            loss_function, data, tstate)
        if ( epochs ≥ numshowepochs > 0 ) && rem(epoch, div(epochs, numshowepochs)) == 0
            println("Epoch: $(epoch) || Loss: $(loss)")
        end
        if ( epochs ≥ numsavestates > 0 ) && rem(epoch, div(epochs, numsavestates)) == 0
            push!(tstates, (epoch, tstate))
        end
        losses[epoch] = loss
        tstate = Lux.Training.apply_gradients(tstate, grads)
    end
    return tstate, losses, tstates
end
train (generic function with 3 methods)
Training
Now we train the model with the loss function loss_function_sde defined above.
@time tstate, losses, tstates = train(tstate_org, vjp_rule, data, loss_function_sde, 1000, 20, 125)
Epoch: 50 || Loss: 0.2715468181871004
Epoch: 100 || Loss: 0.2695423077608621
Epoch: 150 || Loss: 0.2514633697434851
Epoch: 200 || Loss: 0.23665890168742354
Epoch: 250 || Loss: 0.2330884618254507
Epoch: 300 || Loss: 0.22884665011645136
Epoch: 350 || Loss: 0.22958715585951375
Epoch: 400 || Loss: 0.22552964253715402
Epoch: 450 || Loss: 0.22143441845904807
Epoch: 500 || Loss: 0.2210533151851991
Epoch: 550 || Loss: 0.19939301850289137
Epoch: 600 || Loss: 0.21509002089540163
Epoch: 650 || Loss: 0.23640786716906548
Epoch: 700 || Loss: 0.2181223249676826
Epoch: 750 || Loss: 0.2066386924973382
Epoch: 800 || Loss: 0.20148734916493982
Epoch: 850 || Loss: 0.19891221269388212
Epoch: 900 || Loss: 0.194704762071967
Epoch: 950 || Loss: 0.19949806411497586
Epoch: 1000 || Loss: 0.20058908662461314
2.017295 seconds (418.17 k allocations: 10.005 GiB, 4.91% gc time, 1.51% compilation time)
Results
Checking out the trained model.
Visualizing the result with the smallest noise.
Recovering the PDF of the distribution from the trained score function.
With the smallest noise.
Just for the fun of it, let us see an animation of the optimization process.
And the animation of the evolution of the PDF.
We also visualize the evolution of the losses.
Sampling with annealed Langevin
Now we sample the modeled distribution with the annealed Langevin method described earlier.
Here are the trajectories.
The sample histogram obtained at the end of the trajectories.
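For reference, an annealed Langevin sampler in the style of Song and Ermon (2019) can be sketched as follows. This is a hedged illustration, not the code that produced the figures above: the function name `annealed_langevin`, the step-size rule $\alpha_i = \epsilon\,\sigma_i^2/\sigma_L^2$ with `step_size` playing the role of $\epsilon,$ and the toy `exact_score` (the score of a Gaussian $\mathcal{N}(m, s^2)$ perturbed by noise of level $\sigma,$ standing in for a trained model) are assumptions made so the sketch is self-contained.

```julia
using Random, Statistics

# Annealed Langevin dynamics: for each noise level σ (from largest to smallest),
# run `nsteps` Langevin updates with step α = step_size * σ² / σ_L².
# `score(x, σ)` stands in for a trained model s_θ(x, σ).
function annealed_langevin(rng, score, x_init, sigmas; step_size = 0.01, nsteps = 200)
    x = copy(x_init)
    for σ in sigmas
        α = step_size * σ^2 / sigmas[end]^2
        for _ in 1:nsteps
            x .+= (α / 2) .* score.(x, σ) .+ sqrt(α) .* randn(rng, length(x))
        end
    end
    return x
end

# Toy check with an exactly known score: data ~ N(m, s²) perturbed by N(0, σ²)
m, s = 1.0, 0.5
exact_score(x, σ) = -(x - m) / (s^2 + σ^2)

# Geometric noise levels from 2.0 down to 0.5, as in the data setup above
sigmas = [2.0 * (0.5 / 2.0)^((i - 1) / 5) for i in 1:6]

rng = Xoshiro(7)
samples = annealed_langevin(rng, exact_score, 4.0 .* randn(rng, 10_000), sigmas)
println("sample mean: ", mean(samples), "  sample variance: ", var(samples))
```

In this toy setting the samples should end up close to $\mathcal{N}(m, s^2 + \sigma_L^2),$ the data distribution perturbed at the smallest noise level, up to the bias of the discretized Langevin step.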
References
- Aapo Hyvärinen (2005), "Estimation of non-normalized statistical models by score matching", Journal of Machine Learning Research 6, 695-709
- Pascal Vincent (2011), "A connection between score matching and denoising autoencoders," Neural Computation, 23 (7), 1661-1674, doi:10.1162/NECO_a_00142
- J. Sohl-Dickstein, E. A. Weiss, N. Maheswaranathan, S. Ganguli (2015), "Deep unsupervised learning using nonequilibrium thermodynamics", ICML'15: Proceedings of the 32nd International Conference on Machine Learning - Volume 37, 2256-2265
- Y. Song and S. Ermon (2019), "Generative modeling by estimating gradients of the data distribution", NIPS'19: Proceedings of the 33rd International Conference on Neural Information Processing Systems, no. 1067, 11918-11930
- J. Ho, A. Jain, P. Abbeel (2020), "Denoising diffusion probabilistic models", in Advances in Neural Information Processing Systems 33, NeurIPS2020
- Y. Song, J. Sohl-Dickstein, D. P. Kingma, A. Kumar, S. Ermon, B. Poole (2020), "Score-based generative modeling through stochastic differential equations", arXiv:2011.13456