AIxiv專(zhuān)欄是機(jī)器之心發(fā)布學(xué)術(shù)、技術(shù)內(nèi)容的欄目。過(guò)去數(shù)年,機(jī)器之心AIxiv專(zhuān)欄接收?qǐng)?bào)道了2000多篇內(nèi)容,覆蓋全球各大高校與企業(yè)的頂級(jí)實(shí)驗(yàn)室,有效促進(jìn)了學(xué)術(shù)交流與傳播。如果您有優(yōu)秀的工作想要分享,歡迎投稿或者聯(lián)系報(bào)道。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com
王家豪,香港大學(xué)計(jì)算機(jī)系二年級(jí)博士,導(dǎo)師為羅平教授,研究方向?yàn)樯窠?jīng)網(wǎng)絡(luò)輕量化。碩士畢業(yè)于清華大學(xué)自動(dòng)化系,已在 NeurIPS、CVPR 等頂級(jí)會(huì)議上發(fā)表了數(shù)篇論文。
太長(zhǎng)不看版:香港大學(xué)聯(lián)合上海人工智能實(shí)驗(yàn)室,華為諾亞方舟實(shí)驗(yàn)室提出高效擴(kuò)散模型 LiT:探索了擴(kuò)散模型中極簡(jiǎn)線(xiàn)性注意力的架構(gòu)設(shè)計(jì)和訓(xùn)練策略。LiT-0.6B 可以在斷網(wǎng)狀態(tài),離線(xiàn)部署在 Windows 筆記本電腦上,遵循用戶(hù)指令快速生成 1K 分辨率逼真圖片。
圖 1:LiT 在 Windows 筆記本電腦的離線(xiàn)端側(cè)部署:LiT 可以在端側(cè),斷網(wǎng)狀態(tài),以完全離線(xiàn)的方式遵循用戶(hù)指令,快速生成 1K 分辨率圖片
- 論文名稱(chēng):LiT: Delving into a Simplified Linear Diffusion Transformer for Image Generation
- 論文地址:https://arxiv.org/pdf/2501.12976v1
- 項(xiàng)目主頁(yè):https://techmonsterwang.github.io/LiT/
為了提高擴(kuò)散模型的計(jì)算效率,一些工作使用 Sub-quadratic 計(jì)算復(fù)雜度的模塊來(lái)替代二次計(jì)算復(fù)雜度的自注意力(Self-attention)機(jī)制。這其中,線(xiàn)性注意力的主要特點(diǎn)是:1) 簡(jiǎn)潔;2) 并行化程度高。這對(duì)于大型語(yǔ)言模型、擴(kuò)散模型這樣的大尺寸、大計(jì)算的模型而言很重要。
就在幾天前,MiniMax 團(tuán)隊(duì)著名的《MiniMax-01: Scaling Foundation Models with Lightning Attention》已經(jīng)在大型語(yǔ)言模型中驗(yàn)證了線(xiàn)性模型的有效性。而在擴(kuò)散模型中,關(guān)于「線(xiàn)性注意力要怎么樣設(shè)計(jì),如何訓(xùn)練好基于純線(xiàn)性注意力的擴(kuò)散模型」的討論仍然不多。
本文針對(duì)這個(gè)問(wèn)題,該團(tuán)隊(duì)提出了幾條「拿來(lái)即用」的解決方案,向社區(qū)讀者報(bào)告了可以如何設(shè)計(jì)和訓(xùn)練你的線(xiàn)性擴(kuò)散 Transformer(linear diffusion Transformers)。列舉如下:
- 使用極簡(jiǎn)線(xiàn)性注意力機(jī)制足夠擴(kuò)散模型完成圖像生成。除此之外,線(xiàn)性注意力還有一個(gè)「免費(fèi)午餐」,即:使用更少的頭(head),可以在增加理論 GMACs 的同時(shí) (給模型更多計(jì)算),不增加實(shí)際的 GPU 延遲。
- 線(xiàn)性擴(kuò)散 Transformer 強(qiáng)烈建議從一個(gè)預(yù)訓(xùn)練好的 Diffusion Transformer 里做權(quán)重繼承。但是,繼承權(quán)重的時(shí)候,不要繼承自注意力中的任何權(quán)重(Query, Key, Value, Output 的投影權(quán)重)。
- 可以使用知識(shí)蒸餾(Knowledge Distillation)加速訓(xùn)練。但是,在設(shè)計(jì) KD 策略時(shí),我們強(qiáng)烈建議不但蒸餾噪聲預(yù)測(cè)結(jié)果,同樣也蒸餾方差預(yù)測(cè)結(jié)果 (這一項(xiàng)權(quán)重更小)
LiT 將上述方案匯總成了 5 條指導(dǎo)原則,方便社區(qū)讀者拿來(lái)即用。
在標(biāo)準(zhǔn) ImageNet 基準(zhǔn)上,LiT 只使用 DiT 20% 和 23% 的訓(xùn)練迭代數(shù),即可實(shí)現(xiàn)相當(dāng) FID 結(jié)果。LiT 同樣比肩基于 Mamba 和門(mén)控線(xiàn)性注意力的擴(kuò)散模型。
在文生圖任務(wù)中,LiT-0.6B 可以在斷網(wǎng)狀態(tài),離線(xiàn)部署在 Windows 筆記本電腦上,遵循用戶(hù)指令快速生成 1K 分辨率逼真圖片,助力 AIPC 時(shí)代降臨。
目錄
1 LiT 研究背景
2 線(xiàn)性注意力計(jì)算范式
3 線(xiàn)性擴(kuò)散 Transformer 的架構(gòu)設(shè)計(jì)
4 線(xiàn)性擴(kuò)散 Transformer 的訓(xùn)練方法
5 圖像生成實(shí)驗(yàn)驗(yàn)證
6 文生圖實(shí)驗(yàn)驗(yàn)證
7 離線(xiàn)端側(cè)部署
1 LiT 研究背景
Diffusion Transformer 正在助力文生圖應(yīng)用的商業(yè)化,展示出了極強(qiáng)的商業(yè)價(jià)值和潛力。但是,自注意力的二次計(jì)算復(fù)雜度也成為了 Diffusion Transformer 的一個(gè)老大難問(wèn)題。因?yàn)檫@對(duì)于高分辨率的場(chǎng)景,或者端側(cè)設(shè)備的部署都不算友好。
常見(jiàn)的 Sub-quadratic 計(jì)算復(fù)雜度的模塊有 Mamba 的狀態(tài)空間模型(SSM)、門(mén)控線(xiàn)性注意力(GLA)、線(xiàn)性注意力等等。目前也有相關(guān)的工作將其用在基于類(lèi)別的(class-conditional)圖像生成領(lǐng)域 (非文生圖),比如使用了 Mamba 的 DiM、使用了 GLA 的 DiG 。但是,雖然這些工作確實(shí)實(shí)現(xiàn)了 Sub-quadratic 的計(jì)算復(fù)雜度,但是,這些做法也存在明顯的不足:
- 其一,SSM 和 GLA 模塊都依賴(lài)遞歸的狀態(tài) (State) 變量,需要序列化迭代計(jì)算,對(duì)于并行化并不友好。
- 其二,SSM 和 GLA 模塊的計(jì)算圖相對(duì)于 線(xiàn)性注意力 而言更加復(fù)雜,而且會(huì)引入一些算數(shù)強(qiáng)度 (arithmetic-intensity) 比較低的操作,比如逐元素乘法。
而線(xiàn)性注意力相比前兩者,如下圖 2 所示,不但設(shè)計(jì)簡(jiǎn)單,而且很容易實(shí)現(xiàn)并行化。這樣的特點(diǎn)使得線(xiàn)性注意力對(duì)于高分辨率極其友好。比如對(duì)于 2048px 分辨率圖片,線(xiàn)性注意力比自注意力快約 9 倍,對(duì)于 DiT-S/2 生成所需要的 GPU 內(nèi)存也可以從約 14GB 降低到 4GB。因此,訓(xùn)練出一個(gè)性能優(yōu)異的基于線(xiàn)性注意力的擴(kuò)散模型很有價(jià)值。
圖 2:與 SSM 和 GLA 相比,線(xiàn)性注意力同樣實(shí)現(xiàn) sub-quadratic 的計(jì)算復(fù)雜度,同時(shí)設(shè)計(jì)極其簡(jiǎn)潔,且不依賴(lài)遞歸的狀態(tài)變量
但是,對(duì)于有挑戰(zhàn)性的圖像生成任務(wù),怎么快速,有效地訓(xùn)練好基于線(xiàn)性注意力的擴(kuò)散模型呢?
這個(gè)問(wèn)題很重要,因?yàn)橐环矫妫M管線(xiàn)性注意力在視覺(jué)識(shí)別領(lǐng)域已經(jīng)被探索很多,可以取代自注意力,但是在圖像生成中仍然是一個(gè)探索不足的問(wèn)題。另一方面,從頭開(kāi)始訓(xùn)練擴(kuò)散模型成本高昂。比如訓(xùn)練 RAPHAEL 需要 60K A100 GPU days ( 中報(bào)告)。因此,針對(duì)線(xiàn)性擴(kuò)散 Transformer 的高性?xún)r(jià)比訓(xùn)練策略仍然值得探索。
LiT 從架構(gòu)設(shè)計(jì)和訓(xùn)練策略中系統(tǒng)地研究了純線(xiàn)性注意力的擴(kuò)散 Transformer 實(shí)現(xiàn)。LiT 是一種使用純線(xiàn)性注意力的 Diffusion Transformer。LiT 訓(xùn)練時(shí)的成本效率很高,同時(shí)在推理過(guò)程中保持高分辨率友好屬性,并且可以在 Windows 11 筆記本電腦上離線(xiàn)部署。在基于類(lèi)別的 ImageNet 256×256 基準(zhǔn)上面,100K 訓(xùn)練步數(shù)的 LiT-S/B/L 在 FID 方面優(yōu)于 400K 訓(xùn)練步數(shù)的 DiT-S/B/L。對(duì)于 ImageNet 256×256 和 512×512,LiT-XL/2 在訓(xùn)練步驟只有 20% 和 23% 的條件下,實(shí)現(xiàn)了與 DiT-XL/2 相當(dāng)?shù)?FID。在文生圖任務(wù)中,LiT-0.6B 可以在斷網(wǎng)狀態(tài),離線(xiàn)部署在 Windows 筆記本電腦上,遵循用戶(hù)指令快速生成 1K 分辨率逼真圖片。
2 線(xiàn)性注意力計(jì)算范式
3 線(xiàn)性擴(kuò)散 Transformer 的架構(gòu)設(shè)計(jì)
鑒于對(duì)生成任務(wù)上的線(xiàn)性擴(kuò)散 Transformer 的探索不多,LiT 先以 DiT 為基礎(chǔ),構(gòu)建了一個(gè)使用線(xiàn)性注意力的基線(xiàn)模型。基線(xiàn)模型與 DiT 共享相同的宏觀架構(gòu),唯一的區(qū)別是將自注意力替換為 線(xiàn)性注意力。所有實(shí)驗(yàn)均在基于類(lèi)別的 ImageNet 256×256 基準(zhǔn)上進(jìn)行,使用 256 的 Batch Size 訓(xùn)練了 400K 迭代次數(shù)。
Guideline 1:Simplified 線(xiàn)性注意力對(duì)于基于 DiT 的圖像生成擴(kuò)散模型完全足夠。
我們首先嘗試了在通用視覺(jué)基礎(chǔ)模型中成功驗(yàn)證的常見(jiàn)線(xiàn)性注意力的架構(gòu)設(shè)計(jì),比如 ReLU 線(xiàn)性注意力 (使用 ReLU 激活函數(shù)作為線(xiàn)性注意力的 Kernel Function)。
對(duì)于性能參考,將其與 DiT 進(jìn)行比較,其中任何性能差異都可以歸因于線(xiàn)性注意力對(duì)生成質(zhì)量的影響。如圖 4 中所示。與 DiT 相比,使用 ReLU 線(xiàn)性注意力的 LiT-S/2 和 B/2 性能下降很大。結(jié)果表明,視覺(jué)識(shí)別中常用的線(xiàn)性注意力在噪聲預(yù)測(cè)任務(wù)中有改進(jìn)的空間。
然后我們探索以下方法:
- 簡(jiǎn)化型線(xiàn)性注意力 (圖 3,相當(dāng)于在 ReLU 線(xiàn)性注意力的基礎(chǔ)上加上 Depth-wise 卷積)。
- Focused 線(xiàn)性注意力。
- Focused 線(xiàn)性注意力 (使用 GELU 替換 ReLU)。
這些選擇中的每一個(gè)都保持了線(xiàn)性復(fù)雜度,保持了 LiT 在計(jì)算效率方面的優(yōu)勢(shì)。我們使用相對(duì)較大的卷積核 (Kernel Size 5) 來(lái)確保在預(yù)測(cè)噪聲時(shí)足夠大的感受野。
圖 3:在 Simplified 線(xiàn)性注意力中使用更少的 heads
圖 4:不同架構(gòu)的線(xiàn)性注意力消融研究
實(shí)驗(yàn)結(jié)果如圖 4 所示。加了 DWC 的模塊都可以取得大幅的性能提升,我們認(rèn)為這是因?yàn)槟P驮陬A(yù)測(cè)給定像素的噪聲時(shí)關(guān)注相鄰像素信息。同時(shí),我們發(fā)現(xiàn) Focused Function 的有效性有限,我們將其歸因于其設(shè)計(jì)動(dòng)機(jī),以幫助線(xiàn)性注意聚焦于特定區(qū)域。此功能可能適合分類(lèi)模型,但可能不是噪聲預(yù)測(cè)所必需的。為了簡(jiǎn)單起見(jiàn),最后使用簡(jiǎn)化 線(xiàn)性注意力。
Guideline 2:在線(xiàn)性注意力中建議使用很少的頭,可以在增加計(jì)算的同時(shí)不增加時(shí)延。
多頭自注意力和線(xiàn)性注意力的計(jì)算量分別為:
直覺(jué)上似乎使用更多頭可以減少計(jì)算壓力。但相反,我們建議使用更少的頭,因?yàn)槲覀冇^察到線(xiàn)性注意力存在 Free Lunch 效應(yīng),如圖 5 所示。圖 5 展示了使用線(xiàn)性注意力的 Small,Base,Large,XLarge 模型使用不同頭數(shù)量的延遲和 GMACs 變化。
圖 5:線(xiàn)性注意力中的 Free Lunch 效應(yīng):不同頭數(shù)量線(xiàn)性注意的延遲與理論 GMACs 比較
我們使用 NVIDIA A100 GPU 生成 256×256 分辨率的圖像,批量大小為 8 (NVIDIA V100 GPU 出現(xiàn)類(lèi)似現(xiàn)象)。結(jié)果表明,減小頭數(shù)量會(huì)導(dǎo)致理論 GMACs 穩(wěn)定增加,實(shí)際延遲卻并沒(méi)有呈現(xiàn)出增加的趨勢(shì),甚至出現(xiàn)下降。我們將這種現(xiàn)象總結(jié)為線(xiàn)性注意力的「免費(fèi)午餐(Free Lunch)」效應(yīng)。
我們認(rèn)為在線(xiàn)性注意力中使用更少的頭之后,允許模型有較高的理論計(jì)算,根據(jù) scaling law,允許模型在生成性能上達(dá)到更高的上限。
實(shí)驗(yàn)結(jié)果如圖 6 所示,對(duì)于不同的模型尺度,線(xiàn)性注意力中使用更少的頭數(shù) (比如,2,3,4) 優(yōu)于 DiT 中的默認(rèn)設(shè)置。相反,使用過(guò)多的頭(例如,S/2 的 96 或 B/2 的 192),則會(huì)嚴(yán)重阻礙生成質(zhì)量。
4 線(xiàn)性擴(kuò)散 Transformer 的訓(xùn)練方法
LiT 與 DiT 共享一些相同的結(jié)構(gòu),允許權(quán)重繼承自預(yù)訓(xùn)練的 DiT 架構(gòu)。這些權(quán)重包含豐富的與噪聲預(yù)測(cè)相關(guān)的知識(shí),有望以成本高效的方式轉(zhuǎn)移到 LiT。因此,在這個(gè)部分我們探索把預(yù)訓(xùn)練的 DiT 權(quán)重 (FFN 模塊、adaLN、位置編碼和 Conditional Embedding 相關(guān)的參數(shù)) 繼承給線(xiàn)性 DiT,除了線(xiàn)性注意力部分。
圖 6:線(xiàn)性擴(kuò)散 Transformer 的權(quán)重繼承策略
Guideline 3:線(xiàn)性擴(kuò)散 Transformer 的參數(shù)應(yīng)該從一個(gè)預(yù)訓(xùn)練到收斂的 DiT 初始化。
我們首先預(yù)訓(xùn)練 DiT-S/2 不同的訓(xùn)練迭代次數(shù):200K、300K、400K、600K 和 800K,并且在每個(gè)實(shí)驗(yàn)中,分別將這些預(yù)訓(xùn)練的權(quán)重加載到 LiT-S/2 中,同時(shí)線(xiàn)性注意力部分的參數(shù)保持隨機(jī)。然后將初始化的 LiT-S/2 在 ImageNet 上訓(xùn)練 400K 迭代次數(shù),結(jié)果如圖 6 所示。
我們觀察到一些有趣的發(fā)現(xiàn):
- DiT 的預(yù)訓(xùn)練權(quán)重,即使只訓(xùn)練了 200K 步,也起著重要作用,將 FID 從 63.24 提高到 57.84。
- 使用預(yù)訓(xùn)練權(quán)重的指數(shù)移動(dòng)平均 (EMA) 影響很小。
- DiT 訓(xùn)練更收斂時(shí) (800K 步),更適合作為 LiT 的初始化,即使架構(gòu)沒(méi)有完全對(duì)齊。
我們認(rèn)為這種現(xiàn)象的一種可能解釋是 Diffusion Transformer 中不同模塊的功能是解耦的。盡管 DiT 和 LiT 具有不同的架構(gòu),但它們的共享組件 (例如 FFN 和 adaLN) 的行為非常相似。因此,可以遷移這些組件預(yù)訓(xùn)練參數(shù)中的知識(shí)。同時(shí),即使把 DiT 訓(xùn)練到收斂并遷移共享組件的權(quán)重,也不會(huì)阻礙線(xiàn)性注意力部分的優(yōu)化。
圖 7:ImageNet 256×256 上的權(quán)重繼承消融實(shí)驗(yàn)結(jié)果
Guideline 4:線(xiàn)性注意力中的 Query、Key、Value 和 Output 投影矩陣參數(shù)應(yīng)該隨機(jī)初始化,不要繼承自自注意力。
在 LiT 中,線(xiàn)性注意力中的一些權(quán)重與 DiT 的自注意力中的權(quán)重重疊,包括 Query、Key、Value 和 Output 投影矩陣。盡管計(jì)算范式存在差異,但這些權(quán)重可以直接從 DiT 加載到 LiT 中,而不需要從頭訓(xùn)練。但是,這是否可以加速其收斂性仍然是一個(gè)懸而未決的問(wèn)題。
我們使用經(jīng)過(guò) 600K 次迭代預(yù)訓(xùn)練的 DiT-S/2 進(jìn)行消融實(shí)驗(yàn)。探索了 5 種不同類(lèi)型的加載策略,包括:
- 加載 Query,Key 和 Value 投影矩陣。
- 加載 Key 和 Value 投影矩陣。
- 加載 Value 投影矩陣。
- 加載 Query 投影矩陣。
- 加載 Output 投影矩陣。
結(jié)果如圖 7 所示。與沒(méi)有加載自注意力權(quán)重的基線(xiàn)相比,沒(méi)有一個(gè)探索的策略顯示出更好的生成性能。這種現(xiàn)象可歸因于計(jì)算范式的差異。具體來(lái)說(shuō),線(xiàn)性注意力直接計(jì)算鍵和值矩陣的乘積,但是自注意力就不是這樣的。因此,自注意力中的 Key 和 Value 相關(guān)的權(quán)重對(duì)線(xiàn)性注意力的好處有限。
我們建議繼承除線(xiàn)性注意力之外的所有預(yù)訓(xùn)練參數(shù)從預(yù)訓(xùn)練好的 DiT 中,因?yàn)樗子趯?shí)現(xiàn)并且非常適合基于 Transformer 架構(gòu)的擴(kuò)散模型。
圖 8:混合知識(shí)蒸餾訓(xùn)練線(xiàn)性擴(kuò)散 Transformer
Guideline 5:使用混合知識(shí)蒸餾訓(xùn)練線(xiàn)性擴(kuò)散 Transformer 很關(guān)鍵,不僅蒸餾噪聲預(yù)測(cè)結(jié)果,還蒸餾方差的預(yù)測(cè)結(jié)果。
知識(shí)蒸餾通常采用教師網(wǎng)絡(luò)來(lái)幫助訓(xùn)練輕量級(jí)學(xué)生網(wǎng)絡(luò)。對(duì)于擴(kuò)散模型,蒸餾通常側(cè)重于減少目標(biāo)模型的采樣步驟。相比之下,我們專(zhuān)注于在保持采樣步驟的前提下,從復(fù)雜的模型蒸餾出更簡(jiǎn)單的模型。
圖 9:ImageNet 256×256 上的知識(shí)蒸餾實(shí)驗(yàn)結(jié)果,帶有下劃線(xiàn)的結(jié)果表示不使用知識(shí)蒸餾
到目前為止,LiT 遵循 DiT 的宏觀 / 微觀設(shè)計(jì),但采用了高效的線(xiàn)性注意力。使用我們的訓(xùn)練策略,LiT-S/2 顯著地提高了 FID。接下來(lái),我們?cè)诟蟮淖凅w (例如 B/L/XL) 和具有挑戰(zhàn)性的任務(wù) (比如 T2I) 上驗(yàn)證它。
5 圖像生成實(shí)驗(yàn)驗(yàn)證
ImageNet 256×256 基準(zhǔn)
我們首先在 ImageNet 256×256 基準(zhǔn)上驗(yàn)證 LiT。LiT-S/2、B/2、L/2、XL/2 配置與 DiT 一致,只是線(xiàn)性注意力的頭分別設(shè)置為 2/3/4/4。對(duì)于所有模型變體,DWC Kernel Size 都設(shè)置為 5。我們以 256 的 Batch Size 訓(xùn)練 400K 步。對(duì)于 LiT-XL/2,將訓(xùn)練步數(shù)擴(kuò)展到 1.4M 步 (只有 DiT-XL/2 7M 的 20%)。我們使用預(yù)訓(xùn)練的 DiT 初始化 LiT 的參數(shù)。Lambda_1 和 lambda_2 在混合知識(shí)蒸餾中設(shè)置為 0.5 和 0.05。
圖 10 和 11 比較了 LiT 和 DiT 的不同尺寸模型的結(jié)果。值得注意的是,僅 100K 訓(xùn)練迭代次數(shù)訓(xùn)練的 LiT 已經(jīng)在各種評(píng)估指標(biāo)和不同尺寸的模型中優(yōu)于 400K 訓(xùn)練迭代次數(shù)訓(xùn)練的 DiT。使用 400K 訓(xùn)練迭代次數(shù)的額外訓(xùn)練,模型的性能繼續(xù)提高。盡管訓(xùn)練步驟只有 DiT-XL/2 的 20%,但 LiT-XL/2 仍然取得與 DiT 相當(dāng)?shù)?FID 結(jié)果 (2.32 對(duì) 2.27)。此外,LiT 與基于 U-Net 的基線(xiàn)性能相當(dāng)。這些結(jié)果表明,當(dāng)線(xiàn)性注意力結(jié)合合適的優(yōu)化策略時(shí),可以可靠地用于圖像生成應(yīng)用。
圖 10:ImageNet 256×256 基準(zhǔn)實(shí)驗(yàn)結(jié)果,與基于自注意力的 DiT 和基于門(mén)控線(xiàn)性注意力的 DiG 的比較
圖 11:ImageNet 256×256 基準(zhǔn)實(shí)驗(yàn)結(jié)果
ImageNet 512×512 基準(zhǔn)
我們繼續(xù)在 ImageNet 512×512 基準(zhǔn)上進(jìn)一步驗(yàn)證了 LiT-XL/2。使用預(yù)訓(xùn)練的 DiT-XL/2 作為教師模型,使用其權(quán)重初始化 LiT-XL/2。對(duì)于知識(shí)蒸餾,分別設(shè)置 Lambda_1 和 lambda_2 為 1.0 和 0.05,并且只訓(xùn)練 LiT-XL/2 700K 訓(xùn)練迭代次數(shù) (是 DiT 3M 訓(xùn)練迭代次數(shù)的 23%)。
值得注意的是,與使用 256 的 Batch Size 的 DiT 不同,我們采用 128 的較小 Batch Size。這其實(shí)不占便宜,因?yàn)?128 的 Batch Size 相比 256 的情況,完成 1 Epoch 需要 2 倍的訓(xùn)練迭代次數(shù)。也就是說(shuō),我們 700K 的訓(xùn)練迭代次數(shù)其實(shí)只等效為 256 Batch Size 下的 350K。盡管如此,使用純線(xiàn)性注意力的 LiT 實(shí)現(xiàn)了 3.69 的 FID,與 3M 步訓(xùn)練的 DiT 相當(dāng),將訓(xùn)練步驟減少了約 77%。此外,LiT 優(yōu)于幾個(gè)強(qiáng)大的 Baseline。這些結(jié)果證明了我們提出的成本高效的訓(xùn)練策略在高分辨率數(shù)據(jù)集上的有效性。實(shí)驗(yàn)結(jié)果如圖 12 所示。
圖 12:ImageNet 512×512 基準(zhǔn)實(shí)驗(yàn)結(jié)果
6 文生圖實(shí)驗(yàn)驗(yàn)證
文生圖對(duì)于擴(kuò)散模型的商業(yè)應(yīng)用極為重要。LiT 遵循 PixArt-α 的做法,將交叉注意力添加到 LiT-XL/2 中使其支持文本嵌入。LiT 將線(xiàn)性注意力的頭數(shù)設(shè)置為 2,DWC Kernel Size 設(shè)置為 5。遵循 PixArt-Σ 的做法,使用預(yù)訓(xùn)練的 SDXL VAE Encoder 和 T5 編碼器 (即 Flan-T5-XXL) 分別提取圖像和文本特征。
LiT 使用 PixArt-Σ 作為教師來(lái)監(jiān)督其訓(xùn)練,分別設(shè)置 Lambda_1 和 lambda_2 為 1.0 和 0.05。LiT 從 PixArt-Σ 繼承權(quán)重,除了自注意力的參數(shù)。隨后在內(nèi)部數(shù)據(jù)集上訓(xùn)練,學(xué)習(xí)率為 2e-5,僅訓(xùn)練 45400 步,明顯低于 PixArt-α 的多階段訓(xùn)練。圖 13 為 LiT 生成的 512px 圖像采樣結(jié)果。盡管在每個(gè) Block 中都使用了線(xiàn)性注意力,以及我們的成本高效的訓(xùn)練策略,LiT 仍然可以產(chǎn)生異常逼真的圖像。
圖 13:LiT 根據(jù)用戶(hù)指令生成的 512px 圖片
我們還將分辨率進(jìn)一步增加到 1K。更多的實(shí)驗(yàn)細(xì)節(jié)請(qǐng)參閱原論文。圖 14 是生成的結(jié)果采樣。盡管用廉價(jià)的線(xiàn)性注意力替換所有自注意力,但 LiT 仍然能夠以高分辨率生成逼真的圖像。
圖 14:LiT 根據(jù)用戶(hù)指令生成的 1K 分辨率圖片
7 離線(xiàn)端側(cè)部署
我們還將 1K 分辨率的 LiT-XL/2 模型部署到一臺(tái) Windows 11 操作系統(tǒng)驅(qū)動(dòng)的筆記本電腦上,以驗(yàn)證其 On-device 的能力。考慮到筆記本電腦的 GPU 內(nèi)存的限制,我們將文本編碼器量化為 8-bit,同時(shí)在線(xiàn)性注意力計(jì)算期間保持 fp16 精度。圖 1 顯示了我們的部署結(jié)果。預(yù)訓(xùn)練的 LiT 可以在離線(xiàn)設(shè)置 (沒(méi)有網(wǎng)絡(luò)連接) 的情況下快速生成照片逼真的 1K 分辨率圖像。這些結(jié)果說(shuō)明 LiT 作為一種 On-device 的擴(kuò)散模型的成功實(shí)現(xiàn),推進(jìn)邊緣設(shè)備上的高分辨率文生圖任務(wù)。
下面提供了一個(gè)視頻 Demo:
https://mp.weixin.qq.com/s/XEQJnt5cJ63spqSG67WLGw?token=1784997338&lang=zh_CN
展示了在斷網(wǎng)狀態(tài)下離線(xiàn)使用 LiT 完成 1K 分辨率文生圖任務(wù)的過(guò)程。
特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺(tái)“網(wǎng)易號(hào)”用戶(hù)上傳并發(fā)布,本平臺(tái)僅提供信息存儲(chǔ)服務(wù)。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.