Stein score function

Aim

Revisit the origin of the Stein score function, which is the basis of score-based generative models.

The Stein score function

Given a random variable $\mathbf{X}$ in $\mathbb{R}^d$, $d\in\mathbb{N},$ we denote its pdf by $p_{\mathbf{X}}(\mathbf{x})$, while its (Stein) score function, also known as gradlogpdf, is defined by

\[ \boldsymbol{\psi}_{\mathbf{X}}(\mathbf{x}) = \boldsymbol{\nabla}_{\mathbf{x}} \log p_{\mathbf{X}}(\mathbf{x}) = \frac{\boldsymbol{\partial}\log p_{\mathbf{X}}(\mathbf{x})}{\boldsymbol{\partial}\mathbf{x}} = \left( \frac{\partial}{\partial x_j} \log p_{\mathbf{X}}(\mathbf{x})\right)_{j=1, \ldots, d},\]

where we may use either notation $\boldsymbol{\nabla}_{\mathbf{x}}$ or ${\boldsymbol{\partial}}/{\boldsymbol{\partial}\mathbf{x}}$ for the gradient of a scalar function. (For the differential of a vector-valued function, we will use either $\mathrm{D}_{\mathbf{x}}$ or ${\boldsymbol{\partial}}/{\boldsymbol{\partial}\mathbf{x}}$.)

For a parametrized model with pdf denoted by $p(\mathbf{x}; \boldsymbol{\theta})$, or $p(\mathbf{x} | \boldsymbol{\theta})$, and parameters $\boldsymbol{\theta} = (\theta_1, \ldots, \theta_m),$ $m\in \mathbb{N}$, the score function becomes

\[ \boldsymbol{\psi}(\mathbf{x}; \boldsymbol{\theta}) = \boldsymbol{\nabla}_{\mathbf{x}}\log p(\mathbf{x}; \boldsymbol{\theta}) = \left( \frac{\partial}{\partial x_j} \log p(\mathbf{x}; \boldsymbol{\theta})\right)_{j=1, \ldots, d}.\]

In the univariate case, the score function is also univariate and is given by the derivative of the log of the pdf. For example, for a univariate Normal distribution $\mathcal{N}(\mu, \sigma^2)$, $\mu\in\mathbb{R}$, $\sigma > 0$, the pdf, logpdf and gradlogpdf are

\[ \begin{align*} p_X(x) & = \frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{1}{2}\left(\frac{x - \mu}{\sigma}\right)^2}, \\ \log p_X(x) & = -\frac{1}{2}\left(\frac{x - \mu}{\sigma}\right)^2 - \log(\sqrt{2\pi}\sigma), \\ \psi_X(x) & = - \frac{x - \mu}{\sigma^2}. \end{align*}\]

Notice that the score function in this case is just a linear function vanishing at the mean of the distribution, with slope equal to minus the reciprocal of its variance.

[Figure: pdf, logpdf, and score function of the univariate normal distribution]
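
This plot can be reproduced with a minimal sketch using Distributions.jl (the mean and standard deviation below are just illustrative values), which evaluates the pdf, logpdf, and gradlogpdf of a normal distribution and checks the linear form of the score:

using Distributions

# univariate normal with illustrative mean and standard deviation
mu, sigma = 2.0, 0.5
dist = Normal(mu, sigma)

xx = range(mu - 3sigma, mu + 3sigma, length = 201)

px = pdf.(dist, xx)            # pdf
logpx = logpdf.(dist, xx)      # logpdf
psix = gradlogpdf.(dist, xx)   # score (gradlogpdf)

# the score of a normal is linear: -(x - mu) / sigma^2
@assert psix ≈ [-(x - mu) / sigma^2 for x in xx]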

In the multivariate case, the score function is a vector field in the event space $\mathbb{R}^d$.

[Figure: score function of a multivariate distribution displayed as a vector field]
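
The vectors in such a plot can be computed with a similar sketch in the bivariate case (the mean, covariance, and grid below are illustrative), evaluating the score of an MvNormal on a grid of points:

using Distributions

# bivariate normal with illustrative mean and (diagonal) covariance
dist = MvNormal([0.0, 0.0], [1.0 0.0; 0.0 2.0])

# score (gradlogpdf) evaluated on a small grid, giving a vector field
xs = range(-3, 3, length = 11)
ys = range(-3, 3, length = 11)
field = [gradlogpdf(dist, [x, y]) for x in xs, y in ys]

# for this diagonal covariance, the score at (x, y) is (-x/1.0, -y/2.0)
@assert field[1, 1] ≈ [-xs[1] / 1.0, -ys[1] / 2.0]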

This notion of score function used in generative models in machine learning is different from the more classical notion of score in Statistics. The classical score function is defined for a parametrized model and refers to the gradient of the log-likelihood

\[ \ell(\boldsymbol{\theta}|\mathbf{x}) = \log\mathcal{L}(\boldsymbol{\theta}|\mathbf{x}) = \log p(\mathbf{x}|\boldsymbol{\theta}),\]

of a parametrized model, with respect to the parameters, i.e.

\[ s(\boldsymbol{\theta}; \mathbf{x}) = \boldsymbol{\nabla}_{\boldsymbol{\theta}}\log \mathcal{L}(\boldsymbol{\theta}|\mathbf{x}) = \frac{\boldsymbol{\partial}\log \mathcal{L}(\boldsymbol{\theta}|\mathbf{x})}{\boldsymbol{\partial}\boldsymbol{\theta}}.\]

This notion measures the sensitivity of the model with respect to changes in the parameters and is useful, for instance, in the maximization of the likelihood function when fitting a parametrized distribution to data.
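
As an illustration, here is a minimal sketch of this classical score for a univariate normal with a single observation. It assumes the ForwardDiff.jl package is available (any other differentiation tool would do) and compares the automatic gradient with the closed-form expressions; the observation and parameter values are arbitrary:

using Distributions
using ForwardDiff   # assumed available for automatic differentiation

# classical score: gradient of the log-likelihood with respect to the parameters,
# here for a single observation of a univariate normal (illustrative values)
x = 1.3
loglik(θ) = logpdf(Normal(θ[1], θ[2]), x)

θ = [0.5, 2.0]                       # parameters (μ, σ)
s = ForwardDiff.gradient(loglik, θ)

# closed form for the normal: ∂ℓ/∂μ = (x - μ)/σ², ∂ℓ/∂σ = -1/σ + (x - μ)²/σ³
μ, σ = θ
@assert s ≈ [(x - μ) / σ^2, -1 / σ + (x - μ)^2 / σ^3]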

The score function given by the gradlogpdf of a distribution is, on the other hand, useful for drawing samples via Langevin dynamics.
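
For instance, a minimal sketch of the (unadjusted) Langevin iteration $\mathbf{x}_{n+1} = \mathbf{x}_n + \epsilon \boldsymbol{\nabla}_{\mathbf{x}}\log p(\mathbf{x}_n) + \sqrt{2\epsilon}\,\mathbf{z}_n$, with $\mathbf{z}_n$ standard normal, driven by the gradlogpdf of a target distribution could look as follows (the helper name langevin_sample, the step size, and the number of steps are arbitrary choices for illustration):

using Distributions
using Random

# unadjusted Langevin dynamics driven by the score (gradlogpdf) of a target distribution
function langevin_sample(target, x0, nsteps, ϵ)
    x = copy(x0)
    for _ in 1:nsteps
        x .+= ϵ .* gradlogpdf(target, x) .+ sqrt(2ϵ) .* randn(length(x))
    end
    return x
end

Random.seed!(123)
target = MvNormal([1.0, -1.0], [1.0 0.0; 0.0 1.0])
samples = [langevin_sample(target, zeros(2), 1_000, 1e-2) for _ in 1:500]
# the empirical mean of `samples` should be close to the target mean [1, -1]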

Stein divergence

Stein (1972) addressed a more general framework for estimating distances between distributions, with the aim of approximating the distribution of a sum of dependent random variables by a normal distribution, in a generalization of the Central Limit Theorem. In a particular case, as described in Liu, Lee, and Jordan (2016), this distance involves the Stein score function and reads as follows.

If $p=p(\mathbf{x})$ is a probability density function on $\mathbf{x}\in\mathbb{R}^d$, then, for any smooth scalar function $f(\mathbf{x})$ decaying sufficiently fast relative to $p(\mathbf{x})$,

\[ \begin{align*} \int_{\mathbb{R}^d} p(\mathbf{x})\left( \boldsymbol{\nabla}_{\mathbf{x}} \log p(\mathbf{x})f(\mathbf{x}) + \boldsymbol{\nabla}_{\mathbf{x}} f(\mathbf{x}) \right)\;\mathrm{d}\mathbf{x} & = \int_{\mathbb{R}^d} p(\mathbf{x})\left( \frac{\boldsymbol{\nabla}_{\mathbf{x}} p(\mathbf{x})}{p(\mathbf{x})} f(\mathbf{x}) + \boldsymbol{\nabla}_{\mathbf{x}} f(\mathbf{x}) \right)\;\mathrm{d}\mathbf{x} \\ & = \int_{\mathbb{R}^d} \left(\boldsymbol{\nabla}_{\mathbf{x}} p(\mathbf{x}) f(\mathbf{x}) + p(\mathbf{x})\boldsymbol{\nabla}_{\mathbf{x}} f(\mathbf{x}) \right)\;\mathrm{d}\mathbf{x} \\ & = \int_{\mathbb{R}^d} \boldsymbol{\nabla}_{\mathbf{x}} \left(p(\mathbf{x}) f(\mathbf{x})\right)\;\mathrm{d}\mathbf{x} \\ & = 0. \end{align*}\]
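
As a quick sanity check, the following minimal sketch estimates the left-hand side above by Monte Carlo for $p = \mathcal{N}(0,1)$ and the (illustrative) test function $f(x) = \sin(x)$, which should be close to zero:

using Distributions
using Random
using Statistics

# Monte Carlo check of E_p[ψ_p(x) f(x) + f'(x)] = 0 for p = N(0, 1) and f(x) = sin(x)
Random.seed!(42)
p = Normal(0.0, 1.0)
xs = rand(p, 10^6)

f(x) = sin(x)
df(x) = cos(x)          # derivative of f

estimate = mean(gradlogpdf(p, x) * f(x) + df(x) for x in xs)
# `estimate` should vanish up to Monte Carlo error
@assert abs(estimate) < 1e-2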

This computation is a particular case of the Stein identity. Now, if $q=q(\mathbf{x})$ is another probability density function, then subtracting the identity above yields

\[ \begin{align*} \int_{\mathbb{R}^d} p(\mathbf{x})\left( \boldsymbol{\nabla}_{\mathbf{x}} \log q(\mathbf{x})f(\mathbf{x}) + \boldsymbol{\nabla}_{\mathbf{x}} f(\mathbf{x}) \right)\;\mathrm{d}\mathbf{x} & = \int_{\mathbb{R}^d} p(\mathbf{x})\left(\boldsymbol{\nabla}_{\mathbf{x}} \log q(\mathbf{x}) - \boldsymbol{\nabla}_{\mathbf{x}} \log p(\mathbf{x}) \right)f(\mathbf{x}) \;\mathrm{d}\mathbf{x}, \end{align*}\]

and we see that

\[ \int_{\mathbb{R}^d} p(\mathbf{x})\left( \boldsymbol{\nabla}_{\mathbf{x}} \log q(\mathbf{x})f(\mathbf{x}) + \boldsymbol{\nabla}_{\mathbf{x}} f(\mathbf{x}) \right)\;\mathrm{d}\mathbf{x} = 0, \;\forall f\in\mathcal{F}, \quad \textrm{if, and only if,} \quad q = p, \;\textrm{a.e.,}\]

where $\mathcal{F}$ is taken to be the class of functions $f=f(\mathbf{x})$ which are smooth and decay sufficiently fast relative to $p=p(\mathbf{x})$ (or, more generally, a suitably large subset of that class). Based on that, the Stein discrepancy measure was originally defined as

\[ \mathbb{S}_{\mathcal{F}}^1(p, q) = \max_{f\in\mathcal{F}}\mathbb{E}_p\left[\boldsymbol{\nabla}_{\mathbf{x}} \log q(\mathbf{x})f(\mathbf{x}) + \boldsymbol{\nabla}_{\mathbf{x}} f(\mathbf{x})\right],\]

but in other works such as Liu, Lee, and Jordan (2016) the Stein discrepancy measure is taken with the square of the expectation:

\[ \mathbb{S}_{\mathcal{F}}^2(p, q) = \max_{f\in\mathcal{F}}\mathbb{E}_p\left[\boldsymbol{\nabla}_{\mathbf{x}} \log q(\mathbf{x})f(\mathbf{x}) + \boldsymbol{\nabla}_{\mathbf{x}} f(\mathbf{x})\right]^2.\]

When $\mathcal{F}$ is such a large class of functions, this is usually not computationally tractable and is seldom used in practice. Liu, Lee, and Jordan (2016), however, proposed working with a particular set $\mathcal{F}$, given by a ball in a reproducing kernel Hilbert space, for which the discrepancy can be computed via

\[ \mathbb{S}_{\mathcal{F}}^2(p, q) = \mathbb{E}_{\mathbf{x}, \mathbf{x}' \sim p}\left[u_q(\mathbf{x}, \mathbf{x}')\right],\]

where $\mathbf{x}, \mathbf{x}'$ are independently drawn from $p$, and $u_q$ is a function involving the (Stein) score of $q$ and a suitable (Stein) kernel.
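
For reference, a minimal sketch of the corresponding U-statistic estimator, using the closed form of $u_q$ from Liu, Lee, and Jordan (2016) with a Gaussian RBF kernel, could look as follows (the function name ksd, the bandwidth, and the sample sizes are illustrative choices):

using Distributions
using LinearAlgebra
using Random

# U-statistic estimate of the kernelized Stein discrepancy with an RBF kernel,
# following the closed form of u_q in Liu, Lee, and Jordan (2016)
function ksd(xs::Vector{<:AbstractVector}, q, h)
    d = length(first(xs))
    sq = [gradlogpdf(q, x) for x in xs]           # score of q at each sample point
    n = length(xs)
    total = 0.0
    for i in 1:n, j in 1:n
        i == j && continue
        δ = xs[i] - xs[j]
        k = exp(-sum(abs2, δ) / (2h^2))           # RBF kernel k(xᵢ, xⱼ)
        gkxj = (δ / h^2) * k                      # gradient of k with respect to xⱼ
        gkxi = (-δ / h^2) * k                     # gradient of k with respect to xᵢ
        trace_term = (d / h^2 - sum(abs2, δ) / h^4) * k
        total += dot(sq[i], sq[j]) * k + dot(sq[i], gkxj) + dot(sq[j], gkxi) + trace_term
    end
    return total / (n * (n - 1))
end

Random.seed!(0)
p = MvNormal([0.0, 0.0], [1.0 0.0; 0.0 1.0])
xs = [rand(p) for _ in 1:200]
ksd(xs, p, 1.0)                                         # close to zero when q = p
ksd(xs, MvNormal([2.0, 0.0], [1.0 0.0; 0.0 1.0]), 1.0)  # larger when q differs from p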

We do not go into further detail here, since that would take us away from our objective. The aim was just to mention the context in which the (Stein) score function came to prominence, before it started being used in generative methods.

Score function in the Julia language

The distributions and their pdfs are obtained from the JuliaStats/Distributions.jl package. The score function is also implemented in JuliaStats/Distributions.jl, as gradlogpdf, but only for some distributions. Since we are interested in Gaussian mixtures, we resorted to some type piracy and extended Distributions.gradlogpdf to MixtureModels, both univariate and multivariate.

Consider a mixture model with pdf given by

\[ p(\mathbf{x}) = \alpha_1 p_1(\mathbf{x}) + \cdots + \alpha_k p_k(\mathbf{x}),\]

where $0 \leq \alpha_i \leq 1$, $\sum_i \alpha_i = 1$, and each $p_i(\mathbf{x})$ is the pdf of a distribution. If $p(\mathbf{x})$ is supported on the whole space $\mathbb{R}^d$, then

\[ \begin{align*} \boldsymbol{\nabla}_{\mathbf{x}} \log p(\mathbf{x}) & = \frac{1}{p(\mathbf{x})}\boldsymbol{\nabla}_{\mathbf{x}} p(\mathbf{x}) \\ & = \frac{1}{p(\mathbf{x})}\boldsymbol{\nabla}_{\mathbf{x}} \left( \alpha_1 p_1(\mathbf{x}) + \cdots + \alpha_k p_k(\mathbf{x}) \right) \\ & = \frac{1}{p(\mathbf{x})} \left( \alpha_1 \boldsymbol{\nabla}_{\mathbf{x}} p_1(\mathbf{x}) + \cdots + \alpha_k\boldsymbol{\nabla}_{\mathbf{x}} p_k(\mathbf{x}) \right). \end{align*}\]

This would be sufficient if the gradient of the pdf (a gradpdf) were implemented for the distributions in JuliaStats/Distributions.jl, but unfortunately it is not. What we can do, then, is assume that each component distribution $p_i(\mathbf{x})$ is also supported on the whole $\mathbb{R}^d$ and use the fact that

\[ \boldsymbol{\nabla}_{\mathbf{x}} p_i(\mathbf{x}) = p_i(\mathbf{x})\boldsymbol{\nabla}_{\mathbf{x}} \log p_i(\mathbf{x}).\]

In this case, we have the identity

\[ \boldsymbol{\nabla}_{\mathbf{x}} \log p(\mathbf{x}) = \frac{1}{p(\mathbf{x})} \left( \alpha_1 p_1(\mathbf{x})\boldsymbol{\nabla}_{\mathbf{x}} \log p_1(\mathbf{x}) + \cdots + \alpha_k p_k(\mathbf{x})\boldsymbol{\nabla}_{\mathbf{x}} \log p_k(\mathbf{x}) \right).\]

Assuming the Stein score function is implemented for each component distribution, we write the score $s_p(\mathbf{x})$ of the mixture model in terms of the scores $s_i(\mathbf{x})$ of the components as

\[ s_p(\mathbf{x}) = \frac{1}{p(\mathbf{x})}\left(\alpha_1 p_1(\mathbf{x})s_1(\mathbf{x}) + \cdots + \alpha_k p_k(\mathbf{x})s_k(\mathbf{x})\right).\]

Here is the code for that.

# Extension of `Distributions.gradlogpdf` to univariate mixture models (the type
# piracy mentioned above).
using Distributions

function Distributions.gradlogpdf(d::UnivariateMixture, x::Real)
    ps = probs(d)
    cs = components(d)
    # initialize the accumulators with the contribution of the first component
    ps1 = first(ps)
    cs1 = first(cs)
    pdfx1 = pdf(cs1, x)
    pdfx = ps1 * pdfx1
    glp = pdfx * gradlogpdf(cs1, x)
    if iszero(ps1)
        glp = zero(glp)
    end
    # accumulate alpha_i * p_i(x) in `pdfx` and alpha_i * p_i(x) * s_i(x) in `glp`
    # over the remaining components
    @inbounds for (psi, csi) in Iterators.drop(zip(ps, cs), 1)
        if !iszero(psi)
            pdfxi = pdf(csi, x)
            if !iszero(pdfxi)
                pipdfxi = psi * pdfxi
                pdfx += pipdfxi
                glp += pipdfxi * gradlogpdf(csi, x)
            end
        end
    end
    # divide by the mixture pdf p(x) to obtain the score of the mixture
    if !iszero(pdfx) # else glp is already zero
        glp /= pdfx
    end
    return glp
end
function Distributions.gradlogpdf(d::MultivariateMixture, x::AbstractVector{<:Real})
    ps = probs(d)
    cs = components(d)

    # `d` is expected to have at least one component; otherwise `iterate` returns
    # `nothing` and the destructuring below just errors
    psi, idxps = iterate(ps)
    csi, idxcs = iterate(cs)
    pdfx1 = pdf(csi, x)
    pdfx = psi * pdfx1
    glp = pdfx * gradlogpdf(csi, x)
    if iszero(psi)
        fill!(glp, zero(eltype(glp)))
    end

    # accumulate alpha_i * p_i(x) in `pdfx` and alpha_i * p_i(x) * s_i(x) in `glp`
    # over the remaining components
    while (iterps = iterate(ps, idxps)) !== nothing && (itercs = iterate(cs, idxcs)) !== nothing
        psi, idxps = iterps
        csi, idxcs = itercs
        if !iszero(psi)
            pdfxi = pdf(csi, x)
            if !iszero(pdfxi)
                pipdfxi = psi * pdfxi
                pdfx += pipdfxi
                glp .+= pipdfxi .* gradlogpdf(csi, x)
            end
        end
    end
    # divide by the mixture pdf p(x) to obtain the score of the mixture
    if !iszero(pdfx) # else glp is already zero
        glp ./= pdfx
    end
    return glp
end
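
As a quick usage sketch (with illustrative parameters), once the extension above is loaded, the score of a univariate Gaussian mixture can be evaluated and checked against the formula for $s_p(\mathbf{x})$ derived earlier:

using Distributions

mixture = MixtureModel([Normal(-1.0, 0.5), Normal(2.0, 1.0)], [0.3, 0.7])

gradlogpdf(mixture, 0.0)   # score of the mixture at x = 0

# consistency check against s_p = (alpha_1 p_1 s_1 + ... + alpha_k p_k s_k) / p
x = 0.0
num = sum(a * pdf(c, x) * gradlogpdf(c, x)
          for (a, c) in zip(probs(mixture), components(mixture)))
@assert gradlogpdf(mixture, x) ≈ num / pdf(mixture, x)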

References

  1. C. Stein (1972), "A bound for the error in the Normal approximation to the distribution of a sum of dependent random variables", Proceedings of the Sixth Berkeley Symposium on Mathematical Statistics and Probability, 583-602
  2. Q. Liu, J. Lee, M. Jordan (2016), "A kernelized Stein discrepancy for goodness-of-fit tests", Proceedings of The 33rd International Conference on Machine Learning, PMLR 48, 276-284