本文為機器學習系統設計的好書 Machine Learning Design Patterns 中介紹的第一個模式:Hashed Feature 的筆記。
用 One-hot encoding 可能會碰到的問題
一般在處理 categorical 的資料時(比如說要預測飛機會不會準時抵達時,出發跟目的地的機場代碼就是屬於 categorical 資料),我們用 one-hot encoding 來表示。假設現在有三個機場 A、B、C,我們可以用 [1, 0, 0] 來表示機場 A。而當類別太多的時候在機器學習的系統中會產生三個問題:
- 必須先看過所有的資料才能列舉出所有的類別,但是在資料太多的情況下只能用 random sampling 來列舉資料,那麼列舉出來的類別可能不完整,無法表示整組資料。
- 當類別數目大的時候,此向量的維度就會很大,因此模型的參數數目也得跟著增加,但是通常我們沒有這麼多資料。
- 當有新的類別產生後,我們得重新用 one-hot encoding 來表示新的資料。
Hashed Feature 的解法是將類別對應的字串(例如機場代碼)當成 hash 函數的輸入,利用一個 deterministic 及 portable 的演算法產生一個 hash 結果,再利用 modulo 來將輸入資料分至其中一個 bucket。(官方 Github 有一個例子可供參考)
Hashed Feature 如何解決上面三個問題
- 任何的機場都能透過 hash 函數及 modulo 找出其對應的 bucket,因此對於任何的資料都不會出錯。以上面的例子來說,如果分成 10 個 bucket 的話,那平均每個 bucket 會有 35 個機場(因為整個資料中有 347 個不同的機場代碼)。
- 當 bucket 的數目變小以後,one-hot encoding high cardinality 的問題也就消失了。
- 跟上述問題 (1) 一樣,當有新的資料類別出現時直接用此 hash 函數及 module 就可以找出對應的 bucket。而書中建議是先嘗試讓每個 bucket 包含大概五種類別,這樣通常可以達到不錯的效果。
其他要考慮的設計問題
- Bucket collision:利用 hash 函數及 module 代表會有不同的類別分到同樣的 bucket,這也就是常見的生日問題。
- Skew:當好幾個類別分到同一個 bucket 時,有可能某些類別的資料量遠大於其他的。比如說機場 A 與 B 都分到同一個 bucket,但是這個 bucket 中有 95% 的資料都是機場 A 的,只有 5% 是機場 B 的,因此機器學習模型很有可能將 A 預測得很準,但有關於 B 的都是錯的。
- Aggregate feature:用來解決上面提到的問題,當分到同一個 bucket 的資料類別數量差距懸殊時,我們可以用多一個 feature 來當作輸入,比如說機場 A 在整個 dataset 中佔 19%,而 B 佔 1%,我們可以將 0.19 及 0.01 當成另一個輸入 feature。
- Cryptographic hash:我們要考慮怎麼樣的 hash 函數才是適合的。在機器學習系統中我們需要的是 deterministic 及 portable,讓我們對於同樣的資料總是能得到相同的 hash 結果。這跟密碼系統中的 hash 函數的要求不太一樣,一般密碼系統要求的是 uniformly distributed,但不保證一定要 deterministic。
- 要注意 hash 函數結果可能會 overflow。
沒有留言:
張貼留言