拯救Transformer推理能力!DeepMind新研究TransNAR:給模型嵌入「算法推理大腦」

0 評論 860 瀏覽 1 收藏 19 分鐘

DeepMind最近發(fā)表的一篇論文提出用混合架構(gòu)的方法解決Transformer模型的推理缺陷。將Transformer的NLU技能與基于GNN的神經(jīng)算法推理器(NAR)的強(qiáng)大算法推理能力相結(jié)合,可以實(shí)現(xiàn)更加泛化、穩(wěn)健、準(zhǔn)確的LLM推理。

如今的NLP領(lǐng)域,已然是Transformer架構(gòu)的天下。

從Bert到GPT,再到Llama、Claude,LLM模型使用Transformer已經(jīng)是再正常不過的事情。

Transformer的「大一統(tǒng)」局面正是由于其簡單、高效的架構(gòu),以及在理解自然語言方面無與倫比的泛化能力。

然而,隨著研究的逐漸深入,Transformer的一個致命缺陷也逐漸暴露出來——無法勝任算法推理任務(wù),尤其是不能進(jìn)行精確、穩(wěn)健的推理。

這嚴(yán)重限制了模型在數(shù)學(xué)、代碼等領(lǐng)域下游任務(wù)的應(yīng)用,近年來對Transformer的各種調(diào)優(yōu)、修改似乎也收效甚微。

于是DeepMind的研究人員想到了混合架構(gòu)——將Transformers的語言理解能力與基于圖神經(jīng)網(wǎng)絡(luò)(GNN)的神經(jīng)算法推理器(NAR)的穩(wěn)健性結(jié)合起來,提升其算法推理能力。

他們最近在arxiv上的一篇論文就提出了這個名為TransNAR的架構(gòu),但遺憾的是,目前還沒有公布源代碼。

論文地址:https://arxiv.org/abs/2406.09308

神經(jīng)算法推理(NAR)由本文作者之一Petar Veleckovic在2021年與人合著的一篇論文中提出,并被接收為Patterns期刊的opinion paper。

論文地址:https://arxiv.org/abs/2105.02761

NAR被稱為「構(gòu)建能執(zhí)行算法的神經(jīng)網(wǎng)絡(luò)的藝術(shù)」。作者提出,算法與深度學(xué)習(xí)的本質(zhì)不同,但如果神經(jīng)網(wǎng)絡(luò)能夠更好地模仿算法,它甚至可能具備算法的強(qiáng)泛化性。

更進(jìn)一步,神經(jīng)網(wǎng)絡(luò)若能表示出算法中連續(xù)空間內(nèi)的元素,就會使已知算法更接近現(xiàn)實(shí)世界的問題,提出的解決方案可能超過人類科學(xué)家。

如上圖所示,NAR的整體想法是訓(xùn)練出一個高維隱空間中的處理器網(wǎng)絡(luò)P(processor network),旨在不斷逼近算法的運(yùn)行結(jié)果A(x)。

但由于算法的輸入和輸出一般是圖、樹、矩陣等抽象、結(jié)構(gòu)化的形式,這與深度學(xué)習(xí)模型高維、嘈雜且多變的輸入很不兼容,因此還需要訓(xùn)練編碼器f和解碼器g,將抽象形式轉(zhuǎn)換為自然形式。

NAR發(fā)布后,有多項(xiàng)研究證實(shí)了它有同時執(zhí)行多種算法的能力,也能部署在各種下游任務(wù)中。更重要的是,它的泛化能力似乎遠(yuǎn)遠(yuǎn)優(yōu)于Transformer架構(gòu)。

原則上,NAR可以擴(kuò)展到比訓(xùn)練數(shù)據(jù)的分布大幾個數(shù)量級的系統(tǒng)上,有時這個數(shù)量級能達(dá)到1.8萬倍。

在使用適當(dāng)?shù)臍w納偏差(inductive biases)時,即使輸入比訓(xùn)練集大6倍,NAR也能在高度復(fù)雜的算法任務(wù)中保持完美的泛化能力。

找到了Transformer和NAR這兩種十分強(qiáng)大且各有所長的架構(gòu),下面最關(guān)鍵的問題就是如何進(jìn)行相應(yīng)的調(diào)整和修改,使這兩個似乎完全不相容的模型真正實(shí)現(xiàn)溝通和Embedding交換。

TransNAR:用預(yù)訓(xùn)練NAR增強(qiáng)Transformer

如何實(shí)現(xiàn)NAR+Transformer的有效溝通?作者從多模態(tài)LLM中找到了靈感。

多模態(tài)LLM可以同時接收文本和圖像兩種模態(tài)的輸入,TransNAR也是如此。一邊是算法運(yùn)行需要的圖結(jié)構(gòu),一邊是描述問題的自然語言。

作者的設(shè)想是,將預(yù)訓(xùn)練的NAR作為Transformer中編碼的調(diào)制器(modulator),二者通過embedding溝通,同時借鑒VLM和Flamingo模型中所用的交叉注意算子,融合不同模態(tài)的信息。

TransNAR接受雙重輸入,包括文本形式的算法問題規(guī)范(T個token)及其對應(yīng)的圖表征(N個節(jié)點(diǎn)),并輸出問題的文本答案。其中輸入的圖表征遵循算法推理基準(zhǔn)CLRS-30的格式。

我們可以假設(shè),編碼完成后,文本輸入存儲在T ∈ R^(T×k)中,圖輸入存儲在G ∈ R^(N×l)中。

TransNAR的前向傳播過程如下:

首先,我們通過設(shè)置T^(0) = T和G^(0) = G來正確初始化輸入。

接下來,為了計算第(t+1)步的表征,文本(token)表征被輸入到Transformer的當(dāng)前層:

其中,Qt,Kt ∈ Rk×d_k,Vt ∈ Rk×k分別是鍵、查詢和值矩陣的變換,F(xiàn)FN是一個前饋神經(jīng)網(wǎng)絡(luò)。

以類似的方式,圖表征被輸入到NAR層,例如實(shí)現(xiàn)一個標(biāo)準(zhǔn)的max-MPNN:

其中,ψ,? : Rk × Rk → Rk分別是可學(xué)習(xí)的消息函數(shù)和更新函數(shù),max是逐元素最大值聚合。

需要注意的是,方程2僅簡要提供了節(jié)點(diǎn)之間的成對交互——實(shí)際上,這里的NAR是一個Triplet-GMPNN,它還包含三元組交互和一個門控機(jī)制。

此外,還需注意,NAR的可學(xué)習(xí)部分沒有時間步索引——每一步都應(yīng)用相同的共享函數(shù)。這很好地契合了圖算法計算的迭代和重復(fù)性質(zhì)。

一旦兩個流都準(zhǔn)備好它們的表征Θt+1和Gt+1,圖中的節(jié)點(diǎn)嵌入將對Transformer的token嵌入進(jìn)行條件設(shè)置,從而產(chǎn)生Transformer流中TransNAR塊的最終結(jié)果:

其中,Qt×,Kt× ∈ Rk×d_k, Vtx ∈ Rk×k分別是交叉注意力的鍵、查詢和值變換。在結(jié)束這一層之前,對Gt+1不進(jìn)行額外的變換。

這個過程會一直重復(fù),直到最后的第Nl層,在這一層中,從TN_l讀取最終的文本輸出。

最終輸出通過最后一層生成的預(yù)測頭轉(zhuǎn)換為token logits,并通過標(biāo)準(zhǔn)的下一個token預(yù)測來監(jiān)督訓(xùn)練。

在開始TransNAR微調(diào)之前,首先預(yù)訓(xùn)練NAR,使其能夠穩(wěn)健地執(zhí)行CLRS-30覆蓋的三十個算法。這種方法已知可以在圖空間中實(shí)現(xiàn)高達(dá)4倍輸入規(guī)模的分布外泛化。

在微調(diào)過程中,NAR的參數(shù)通常保持凍結(jié)狀態(tài),因?yàn)轭~外的梯度會削弱模型的原有穩(wěn)健性特性。同樣的原因,圖嵌入不會執(zhí)行交叉注意力。

LLM本身可以在大規(guī)模數(shù)據(jù)集上進(jìn)行預(yù)訓(xùn)練,以建立其一般語言先驗(yàn),即使在開始時隨機(jī)初始化LM,也能獲得相同的實(shí)驗(yàn)結(jié)果。

實(shí)驗(yàn)設(shè)置

在實(shí)驗(yàn)中,作者展示了TransNAR為大語言模型架構(gòu)中的分布外推理帶來的顯著優(yōu)勢。Transformer架構(gòu)和初始化

論文使用Chinchilla家族的一個decoder-only架構(gòu)、6層的Transformer模型,首先在MassiveText上進(jìn)行了預(yù)訓(xùn)練,參數(shù)量有70M,上下文大小為2048。

為了探究初始化設(shè)置的影響,作者設(shè)計了兩個變體進(jìn)行消融實(shí)驗(yàn)。

第一個變體中,Transformer權(quán)重用預(yù)訓(xùn)練的結(jié)果初始化,模擬微調(diào)場景;第二個變體則是完全隨機(jī)的初始化。這兩個模型分別被標(biāo)記為「預(yù)訓(xùn)練」和「未訓(xùn)練」。隨機(jī)位置編碼

之前DeepMind的一篇論文論證過,隨機(jī)位置編碼可以增強(qiáng)Transformer的長度泛化與推理穩(wěn)健性。

論文地址:https://arxiv.org/abs/2305.16843

作者也提到,隨機(jī)位置嵌入確實(shí)在基線模型和TransNAR上都帶來了顯著增益,因此本文中的所有實(shí)驗(yàn)也都使用隨機(jī)位置嵌入。預(yù)訓(xùn)練NAR

論文使用CLRS-30基準(zhǔn)中的問題預(yù)訓(xùn)練了一個多任務(wù)、基于MPNN的NAR,輸入問題規(guī)模最多達(dá)16個。

由于CLRS-30的標(biāo)準(zhǔn)圖結(jié)構(gòu)表達(dá),這樣訓(xùn)練出來的NAR有很強(qiáng)的分布外(OOD)泛化能力,有時在4倍大小的圖上仍保持競爭力,這種豐富的知識表達(dá)正是文本模型可資利用的。結(jié)合節(jié)點(diǎn)和邊緣的跨注意力貢獻(xiàn)

在上述的算法描述中,我們將NAR模型的圖輸入限于N個節(jié)點(diǎn),但作者注意到了之前的研究曾嘗試過,同時對圖的節(jié)點(diǎn)和邊生成隱變量表達(dá),也許可以添加有用的互補(bǔ)信息。

于是實(shí)驗(yàn)中引入圖中邊的特征E(t) ∈ RN×N×k,并再次應(yīng)用公式3讓Θ(t)對E(t)進(jìn)行交叉注意力。

作者也嘗試其他方法,希望將E(t)和G(t)結(jié)合起來,比如拼接后加線性層組合、向量求和、2層MLP,或者用Gram-Schmidt過程使二者的貢獻(xiàn)正交化,但這些都沒有給原始方法帶來提升。數(shù)據(jù)集

訓(xùn)練數(shù)據(jù)使用CLRS-Text基準(zhǔn),即CLRS-30基準(zhǔn)的文本版本,以確定性的方式直接從基于圖的CLRS-30中派生,因此這兩個數(shù)據(jù)集傳達(dá)的是完全相同的信息。

表1展示了該數(shù)據(jù)集的幾個樣本,以及它們的輸入大小和token數(shù)量。

由于語言模型上下文長度的限制,實(shí)驗(yàn)選擇用規(guī)模為4、8、12的問題訓(xùn)練,并在規(guī)模為110、12、14的問題上評估。

值得注意的是,與當(dāng)前的評估環(huán)境相比,CLRS-Text是對LM最具挑戰(zhàn)性的長程推理任務(wù)之一——相比小學(xué)數(shù)學(xué),復(fù)雜度顯著提高。

CLRS-Text的挑戰(zhàn)性主要源于它允許顯式控制分布外泛化。然而,每個問題都有清晰的多項(xiàng)式時間解法,這意味當(dāng)今典型LLM的參數(shù)量應(yīng)該足以解決這些問題。

該數(shù)據(jù)集每種算法的每種輸入規(guī)模包含一萬個樣本,總共240萬個數(shù)據(jù)點(diǎn),其中70%用于訓(xùn)練、30%用于驗(yàn)證。

訓(xùn)練細(xì)節(jié)

實(shí)驗(yàn)將batch大小設(shè)置為256訓(xùn)練了7個epoch,并使用Adam優(yōu)化器,學(xué)習(xí)率為10-4。

如前所述,在所有Chinchilla Transformer的旋轉(zhuǎn)位置編碼(RoPE)之上應(yīng)用隨機(jī)位置編碼,最大長度為8192,且訓(xùn)練期間保持NAR凍結(jié)。評估指標(biāo)

作者提出,合適的評估指標(biāo)應(yīng)該反映模型在特定樣本上失敗的原因,且需要度量型輸出與正確答案的接近程度。因此,使用精確字符串匹配來計算模型準(zhǔn)確性是絕對不可行的。

論文選擇的性能指標(biāo)包括以下三個:

1. 形狀分?jǐn)?shù):一個二元指標(biāo),用于判斷輸出是否具有正確的形狀。例如,在排序任務(wù)中,輸出應(yīng)與輸入有完全相同的元素數(shù)量?;蛘?,如果輸出是一個矩陣,我們需要確保其形狀與輸入和任務(wù)一致。

2. 解析分?jǐn)?shù):一個二元指標(biāo),用于判斷輸出是否不含任何非法字符。例如,在對數(shù)字列表進(jìn)行排序的任務(wù)中,輸出不應(yīng)包含任何字母。

3. CLRS分?jǐn)?shù):輸出中與真實(shí)答案匹配的元素百分比,也常用于CLRS-30測試。形狀分?jǐn)?shù)為0時,CLRS分?jǐn)?shù)也會自動置零。

這種多方面的指標(biāo)設(shè)計能夠捕捉到LLM在文本上進(jìn)行推理任務(wù)的各種失敗模式。

比如在某個問題規(guī)模上過度專門化訓(xùn)練(導(dǎo)致輸出的形狀不正確)、無法處理看不見的數(shù)字組合(導(dǎo)致解析錯誤),由于推理錯誤造成的答案不一致則由CLRS分?jǐn)?shù)反映。

結(jié)果

實(shí)驗(yàn)結(jié)果顯示,TransNAR整體上顯著優(yōu)于Transformer模型,在動態(tài)規(guī)劃、幾何、圖、貪心算法、排序、字符串等任務(wù)上的OOD推理能力都有大幅提升。

并且在大多數(shù)單個算法上,無論是在分布內(nèi)還是分布外都表現(xiàn)更佳。

特別值得注意的是,這種方法不僅增強(qiáng)了Transformer原有的OOD泛化能力,還激發(fā)了一些模型先前完全不具備的能力。

比如Graham掃描(graham_scan)、最長公子串長度(lcs_length)、強(qiáng)連通分量(scc)這些經(jīng)典問題中,基線模型得分為零或接近零,但TransNAR卻實(shí)現(xiàn)了突破。

分析形狀分?jǐn)?shù)可以進(jìn)一步解釋,為什么TransNAR表現(xiàn)如此出色。

首先,回顧一下,如果形狀不匹配,CLRS得分必然為零。

從形狀得分來看,將Transformer的輸出建立在NAR嵌入基礎(chǔ)上顯著提高了答案中形狀正確的比例——這表明TransNAR緩解了一種特定的LLM故障模式。

此外,通過對比「預(yù)訓(xùn)練」和「未訓(xùn)練」兩種初始化方式的分?jǐn)?shù),可以看到模型較好的穩(wěn)定性和可用性。在隨機(jī)初始化時,也能訓(xùn)練到與微調(diào)相當(dāng)?shù)乃疁?zhǔn)。

然而,在一些算法中,TransNAR仍未能超越基線,且在分布內(nèi)和分布外都是如此。

這些算法包括二分搜索、尋找最大子數(shù)組、最小值和快速選擇等,都涉及在輸入列表中按照索引搜索特定元素。

這暗示了TransNAR的一種故障模式:模型無法泛化到訓(xùn)練數(shù)據(jù)中未見過的新索引邊界。因此,使用索引提示或許是一條有前景的改進(jìn)途徑。

另一種可能的解釋是,NAR最終計算出的隱藏狀態(tài)難以在交叉注意力層以可泛化的方式被解碼。如果原因在此,解決途徑可以是增加交叉注意力的容量,或者采用漸進(jìn)式解碼。

此外,TransNAR在架構(gòu)上有一個本質(zhì)的局限性,就是必需一個能得出ground truth的模擬器或者數(shù)據(jù)標(biāo)簽,用于將輸入的文本轉(zhuǎn)換為圖結(jié)構(gòu),再作為模型輸入。

但是作者強(qiáng)調(diào),TransNAR的概念對于未來研究是有借鑒意義的??梢钥紤]將這種混合架構(gòu)的想法移植到單模態(tài)LLM,或者將TransNAR訓(xùn)練后獲得的知識提煉出來注入到普通的Transformer中。參考資料:

https://arxiv.org/abs/2406.09308

新智元報道 編輯:喬楊 好困

本文由人人都是產(chǎn)品經(jīng)理作者【新智元】,微信公眾號:【新智元】,原創(chuàng)/授權(quán) 發(fā)布于人人都是產(chǎn)品經(jīng)理,未經(jīng)許可,禁止轉(zhuǎn)載。

題圖來自Unsplash,基于 CC0 協(xié)議。

更多精彩內(nèi)容,請關(guān)注人人都是產(chǎn)品經(jīng)理微信公眾號或下載App
評論
評論請登錄
  1. 目前還沒評論,等你發(fā)揮!