2020年7月15日 星期三

Siamese Neural Network 以及 Triplet Network

機器學習中有一類問題稱為 One-shot Learning 或 Few-shots Learning,意思是每個類別中的資料量都非常少,甚至一個類別只有一個樣本。打個比方來說 Omniglot dataset [1] 之中有 1623 個來自 50 種語言的字母,但是每個字母卻只有 20 個樣本。因此有人說 Omniglot 是 MNIST dataset 的轉置,因為在 MNIST 中總共 10 個類別,每個類別都有幾千個樣本。本篇文章的目的是簡介一種解決這個機器學習的方法:Siamese Neural Network 以及相關的 Triplet Network。

Siamese Neural Network

Siamese Neural Network 的概念很簡單:訓練的是任意兩個樣本的距離。拿 MNIST 手寫數字辨識來說,就是當兩張圖都為同一個數字時,我們希望訓練出來得到的距離是 0,而兩張圖是不同數字時,我們希望訓練出來得到的距離是 1。

從 Siamese Neural Networks for One-shot Image Recognition [3] 這篇文章的圖可以更容易地解釋概念:

Siamese Neural Network

\(x_1\) 與 \(x_2\) 為兩個樣本,中間各自經過一些 hidden layers,最後經由一層 distance layer 輸出兩個樣本之間的距離 p。在這篇文章中的輸入是兩張圖片,hidden layers 為 convolutional neural network,distance layer 為以下式子: \[ \sigma(\sum_j \alpha_j \left | h_{1,L-1}^j - h_{2,L-1}^j \right |) \]意思是先將兩者 distance layer 輸入相減、取 L1 norm、乘上一個學到的參數 \(\sigma_j\)、最後再由 sigmoid 函數轉換成一個 0 到 1 的數字。而本篇文章用的 loss function 是常見的 cross-entropy loss。


Triplet Network

Triplet Network [4] 的輸入是一次拿三個樣本 \(x, x^+, x^-\),目標是同時學習到相同類別的距離要小,不同類別的距離要大。這個方法的精神是一次同時從兩個距離中學習,而不是像上面的方法一次只能利用一個距離來學習。以下是這篇文章的概念圖:

Triplet Network

最後來簡介一下 Triplet Network 的 loss function: \[ Loss(d_+,d_-)=\left \| d_+,d_- - 1 \right \|^2_2 \\ d_+ = \frac{exp(\left \| Net(x) - Net(x^+) \right \|_2)}{exp(\left \| Net(x) - Net(x^+) \right \|_2) + exp(\left \| Net(x) - Net(x^-) \right \|_2)} \\ d_- = \frac{exp(\left \| Net(x) - Net(x^-) \right \|_2)}{exp(\left \| Net(x) - Net(x^+) \right \|_2) + exp(\left \| Net(x) - Net(x^-) \right \|_2)} \] 訓練目標是要讓 \( \frac{\left \| Net(x) - Net(x^+) \right \|}{\left \| Net(x) - Net(x^-) \right \|} \) 趨近於 0,此時 \(Loss(d_+,d_-)\) 也會趨近於 0。

沒有留言:

張貼留言