2020年8月19日 星期三

Variational Inference 簡介

Variational Inference 是用來近似複雜機率分布的方法。最簡單的概念是我們想要依照現有的資料 x 來估計 x 背後的機率分布 p(x),但是當 p(x) 太複雜或是不容易表達時,我們便改成尋找容易表達及分解(或積分)的機率分布 q(比如說高斯分布),並且最佳化 p 與 q 的距離,當 p 與 q 非常接近時,我們就把 q 當成 p 的近似機率分布。本篇文章目的是簡介 Variational Inference 背後的數學原理,以及如何用在生成模型 generative model 上(也就是 Variational Auto-encoder)。

隱變量模型 Latent Variable Models

在這個機率模型中我們用 x 代表觀測到的資料,也可以說是可觀測的隨機變數;用 z 代表無法觀測的隨機變數,而在此模型中隱藏的隨機變數 z 會影響觀測到的資料 x。拿混合高斯模型打個比方,假設這個機率模型是由 K 的高斯分布所組成的,並假設這 K 個高斯分布的參數為 \(\mu=\{\mu_1, \cdots , \mu_K\}\)。而 x 有 N 個樣本,每個樣本生成時我們都先選定 K 個高斯分布中的其中一個,再將這個樣本生成出來。描述這件事情的數學式為 [1]: \[ \mu_k \sim N(0, \sigma^2) \\ c_i \sim Categorical(\frac{1}{K}, \cdots, \frac{1}{K}) \\ x_i | c_i, \mu \sim N(c_i^T \mu, 1) \] 因此在這個例子中,\(\mu\) 與 \(c_i\) 都為隱變量 z,因為我們無法觀測到它們。我們能觀測到的就只有 N 筆資料 \(x_i\)。下圖 [2] 為此例子的示意圖:
 
Laten Variable Model Example

 

為什麼計算 p(x) 會太困難?

隱變量模型的 joint probability distribution (聯合分布)p(x, z) 是當機率模型設計出來時便已知的,也就是說只要有像上面的示意圖,我們就一定能寫出對應的 joint probability distribution。比如說上面混合高斯模型例子的 joint probability distribution 為: \[ p(\mu, c, x) = p(\mu)\prod_{i=1}^{N}p(c_i)p(x_i|c_i, \mu) \] 其中 \(p(\mu)\) 為這組參數出現的機率,\(p(c_i)\) 是此樣本被分配到第 i 個高斯分布的機率,而 \(p(x_i|c_i, \mu)\) 為高斯分布的機率。 
 
 上式中 \(\mu\) 與 c 為隱變量,而我們想要求的是 p(x),因此得將上式積分: \[ p(x) = \int p(\mu) \prod_{i=1}^{N} \sum_{c_i} p(c_i)p(x_i|c_i, \mu) d\mu \] 計算此積分的複雜度太高了,無法直接算出 p(x),因此下面介紹的 Variational Inference 便是用來近似 p(x) 的方法。

Variational Inference

從最大似然估計的觀點(請參考前文)來看,我們想要最佳化的是以下似然函數 likelihood function: \[ log(p(x)) = \sum_1^N log\ p_\theta (x_i) \] 此時我們引入隨機變數 z 以及其對應的機率分布 q(z) 來幫助計算,q(z) 如前面所說可以是任何的機率分布,只要它容易表達及分解即可。我們這邊先寫出數學式子推導過程再說明其意義: \[ log(p(x)) = \int_z q(z) log(p(x)) dz \\ = \int_z q(z) log(\frac{p(z,x)}{p(z|x)}) dz = \int_z q(z) log(\frac{p(z,x)}{q(z)} \frac{q(z)}{p(z|x)})dz \\ = \int_z q(z) log(\frac{p(z,x)}{q(z)})dz + \int_z q(z) log(\frac{q(z)}{p(z|x)})dz \] 上面第一個等式是由於 q(z) 為任意的機率分布,積分後等同於計算 log(p(x)) 的期望值,也就是 log(p(x));第二個等式是條件機率的代換,而第三個等式為一個小技巧。上式的結果可以拆成兩項,我們可以看出第二項便是 q(z) 與 p(z|x) 的 KL Divergence:\(D_{KL}(q(x)||p(z|x))\)(請參考前文)。 而我們已知 KL Divergence 一定大於等於 0,因此可以將上式改成以下不等式: \[ log(p(x)) \geq \int_z q(z) log(\frac{p(z,x)}{q(z)})dz = L_b \] 我們稱不等式中右邊這一項 \(L_b\) 為 Evidence Lower Bound (ELBO)。因此 Variational Inference 的精神便是不斷地調整 q(z) 使得 ELBO 不斷提升,我們可以用 EM 演算法來調整 q(z), 細節請參考文章 [3]。另外如果使用 Importance Sampling 也可以推導出一樣的結果,推導過程也可以在同一篇參考資料中找到。

參考資料

[1]  Variational Inference: A Review for Statisticians, David Blei, Alp Kucukelbir, Jon McAuliffe, 2018

沒有留言:

張貼留言