昇思 MindSpore 開(kāi)源社區(qū)將于 2025 年 12 月 25 日在杭州舉辦昇思人工智能框架峰會(huì)。本次峰會(huì)的昇思人工智能框架技術(shù)發(fā)展與行業(yè)實(shí)踐論壇將討論到昇思MindSpore 大模型技術(shù)進(jìn)展與實(shí)踐,并將設(shè)有昇思 AI for Science(AI4S)專(zhuān)題論壇。本文對(duì) AI4S 團(tuán)隊(duì)開(kāi)發(fā)的 MindSpore Protenix 蛋白質(zhì)結(jié)構(gòu)預(yù)測(cè)模型的性能與優(yōu)化進(jìn)行了深入解讀,揭示了如何實(shí)現(xiàn)該模型的訓(xùn)練與推理性能的提升。
背景
蛋白質(zhì)結(jié)構(gòu)預(yù)測(cè)是現(xiàn)代生命科學(xué)的圣杯之一。雖然AlphaFold2等AI工具已實(shí)現(xiàn)單體蛋白結(jié)構(gòu)的高精度預(yù)測(cè),但整個(gè)領(lǐng)域仍面臨兩大核心瓶頸:
第一,預(yù)測(cè)準(zhǔn)確性仍存在系統(tǒng)性盲區(qū)。當(dāng)前模型對(duì)蛋白質(zhì)動(dòng)態(tài)構(gòu)象、翻譯后修飾狀態(tài)、膜蛋白環(huán)境以及多鏈復(fù)合物組裝等關(guān)鍵場(chǎng)景的預(yù)測(cè)精度嚴(yán)重不足。模型在MSA信息稀疏時(shí)(如人工設(shè)計(jì)蛋白、孤兒蛋白)性能會(huì)斷崖式下跌,本質(zhì)上仍是基于進(jìn)化關(guān)聯(lián)的“模式外推”而非真正的物理規(guī)律學(xué)習(xí)。
第二,計(jì)算復(fù)雜性成為應(yīng)用壁壘。最先進(jìn)的預(yù)測(cè)模型需要同時(shí)處理數(shù)千條同源序列的MSA信息,單次推理就需數(shù)十GB顯存和數(shù)小時(shí)GPU時(shí)間。對(duì)于需要高通量掃描的工業(yè)場(chǎng)景或更大尺度的復(fù)合物預(yù)測(cè),算力需求呈指數(shù)級(jí)增長(zhǎng)。這使得前沿技術(shù)難以轉(zhuǎn)化為普惠工具,學(xué)術(shù)實(shí)驗(yàn)室和中小企業(yè)常因算力門(mén)檻而被排除在創(chuàng)新循環(huán)之外。
這兩個(gè)問(wèn)題相互纏繞:要提升對(duì)復(fù)雜場(chǎng)景的預(yù)測(cè)精度,往往需要更龐大的模型和更豐富的輸入特征,而這又會(huì)進(jìn)一步推高計(jì)算成本,形成難以突破的技術(shù)閉環(huán)。
昇思 MindSpore 的 AI for Science 方案詳解
昇思 MindSpore 通過(guò)軟硬件協(xié)同優(yōu)化及高效的 NPU 計(jì)算能力,為行業(yè)提供了高性能的自主創(chuàng)新 AI 解決方案,大幅加速蛋白質(zhì)研究進(jìn)程并降低計(jì)算成本。我們實(shí)現(xiàn)了蛋白質(zhì)結(jié)構(gòu)預(yù)測(cè)模型 Protenix 的 MindSpore 框架版本,并在昇騰硬件平臺(tái)上實(shí)現(xiàn)了高性能的訓(xùn)練和推理。為應(yīng)對(duì)大規(guī)模蛋白質(zhì)結(jié)構(gòu)預(yù)測(cè)的高計(jì)算需求,本項(xiàng)目充分利用 MindSpore 框架的計(jì)算圖優(yōu)化能力與昇騰處理器的硬件優(yōu)勢(shì),在完全繼承了模型推理精度的同時(shí),又顯著提升了模型性能。

圖1 MindSpore Protenix蛋白質(zhì)結(jié)構(gòu)預(yù)測(cè)模型的推理效果
在本文所描述的調(diào)優(yōu)策略下,模型在昇騰A2 64G 單卡上可達(dá)到 768 的最大訓(xùn)練長(zhǎng)度,并且最大單卡推理長(zhǎng)度超過(guò) 3000;以下是相應(yīng)的具體訓(xùn)推時(shí)間:


2.1 模型訓(xùn)練優(yōu)化
重計(jì)算(Recompute)優(yōu)化
在深度模型訓(xùn)練中,顯存占用通常可分為靜態(tài)顯存(Static Memory)與動(dòng)態(tài)顯存(Dynamic Memory)兩個(gè)部分。對(duì)于 Protenix(AF3 類(lèi)結(jié)構(gòu)模型) 這類(lèi)高度依賴(lài)幾何結(jié)構(gòu)建模的網(wǎng)絡(luò)而言,其瓶頸并非權(quán)重規(guī)模,而是激活值數(shù)量極大、計(jì)算路徑復(fù)雜、依賴(lài)大量三元(i,j,k)結(jié)構(gòu)相關(guān)中間張量。通過(guò)在前向傳播階段不保存部分激活值,而是在反向傳播需要梯度時(shí)重新執(zhí)行對(duì)應(yīng)的前向計(jì)算,即可顯著降低顯存占用。
PyTorch 版本 Protenix 中已經(jīng)大量使用了重計(jì)算來(lái)緩解激活膨脹的問(wèn)題。然而受限于硬件顯存容量限制、模型關(guān)鍵結(jié)構(gòu)適配不足,以及考慮到 MindSpore 對(duì)動(dòng)態(tài) shape 的靜態(tài)優(yōu)化與 PyTorch 有一定差異后,我們?cè)?MindSpore 版本中對(duì)重計(jì)算策略做了更細(xì)粒度的優(yōu)化。
如下圖紅框處所示,a 為未優(yōu)化前顯存占用曲線(xiàn),可以看到在紅框處達(dá)到峰值。通過(guò)分析可以確定此處位置用于計(jì)算 smooth_lddt_loss,因此將這個(gè)部分單獨(dú)進(jìn)行重計(jì)算后就得到了下圖的結(jié)果,此處峰值由 55G 下降到 20G 以?xún)?nèi)。

針對(duì)性重計(jì)算設(shè)計(jì)
在 MindSpore 實(shí)現(xiàn)中,我們分別對(duì)核心模塊進(jìn)行了獨(dú)立的重計(jì)算包裝,以精確控制激活緩存范圍并最大化釋放顯存。首先是針對(duì) Triangle Attention 的重計(jì)算,Triangle Attention 在 AF3 / Protenix 中是最重要的結(jié)構(gòu)依賴(lài)模塊之一,其 Q/K/V 計(jì)算與 pair-wise 三元交互的復(fù)雜度為 O(N^3) ,隨著序列的增長(zhǎng)會(huì)產(chǎn)生大量中間激活,在昇騰平臺(tái)上,由于當(dāng)前暫時(shí)缺乏對(duì)等的 fused kernel(如 FlashAttention-like kernel),Triangle Attention 的激活會(huì)占用更大量的顯存。因此針對(duì)一個(gè) PairFormer Layer 中的兩個(gè) Triangle Attention 分別進(jìn)行重計(jì)算。
其次我們對(duì) Triangle Multiplication 進(jìn)行重計(jì)算,因?yàn)?Triangle Multiplication 涉及大量 (i,j,k) 維度重排與張量廣播,且其激活值規(guī)模更大。
最后是 smooth_lddt_loss 計(jì)算的重計(jì)算(大規(guī)模 cdist),smooth_lDDT loss 中一項(xiàng)關(guān)鍵計(jì)算為 pairwise distance(cdist),其生成的距離矩陣為 O(L2 × d),其中L為原子數(shù)量,這與 TriangleAttention 等對(duì)應(yīng)的殘基數(shù)量不同,原子數(shù)量通常比殘基數(shù)大一個(gè)數(shù)量級(jí),因此對(duì)長(zhǎng)序列顯存壓力極大,我們?yōu)?loss 中的該部分單獨(dú)加入了重計(jì)算,使其在反向不需要保留巨大 distance matrix。
實(shí)際顯存收益
在未開(kāi)啟上述重計(jì)算策略時(shí):
? 64GB 顯存僅能訓(xùn)練長(zhǎng)度 64 的序列。
? 動(dòng)態(tài)顯存峰值約為20152 MB。
啟用重計(jì)算后:
? 顯存峰值下降到7025 MB,下降超 60%。
? 最長(zhǎng)可支持訓(xùn)練長(zhǎng)度提升到 768 tokens。
這一優(yōu)化是 Protenix MindSpore 版本能夠在昇騰A2 平臺(tái)上成功支持長(zhǎng)序列訓(xùn)練的關(guān)鍵技術(shù)點(diǎn)之一。
2.2 模型推理優(yōu)化
在這部分工作中,我們基于對(duì)模型性能的分析,逐一找到時(shí)間、內(nèi)存方面的性能瓶頸并予以?xún)?yōu)化。
Profiling 數(shù)據(jù)與分析
MindSpore 支持用戶(hù)使用 Profiler 類(lèi)對(duì)模型的性能進(jìn)行采集,所獲得的 Profiling 數(shù)據(jù)記錄了詳細(xì)的算子時(shí)間線(xiàn),也包括了算子的顯存占用信息。Profiling 數(shù)據(jù)可以通過(guò) MindInsight 工具進(jìn)行可視化分析,可以查看詳細(xì)的算子時(shí)間線(xiàn),以及流之間的調(diào)用關(guān)系。我們可以精確計(jì)算出每個(gè)模塊的位置及其耗時(shí),并據(jù)此來(lái)確定這些模塊是否需要進(jìn)一步的優(yōu)化。例如,下圖展示了我們對(duì)推理過(guò)程中 PairFormer 模塊的定位與拆解,為后續(xù)的時(shí)間、內(nèi)存的分析提供了框架與引導(dǎo):

Unfold 算子重構(gòu)
通過(guò)模型運(yùn)行時(shí)打印算子運(yùn)行時(shí)長(zhǎng)占比,發(fā)現(xiàn) Im2col 占總運(yùn)行時(shí)長(zhǎng)最高,高達(dá) 70.73%,故需要分析并消減該算子的調(diào)用。

定位后可確定為調(diào)用 mindspore.ops.unfold 算子引入問(wèn)題。根據(jù)原本 PyTorch 代碼邏輯,此處實(shí)際使用 torch.Tensor.unfold,其實(shí)際與 torch.nn.functional.unfold 行為不同,差異如下:
? Tensor.unfold:返回原始張量的一個(gè)視圖,該視圖包含在指定維度上從張量中提取的所有大小為 size 的切片。
? nn.functional.unfold:把 4-D 圖像 (N,C,H,W) 的每個(gè) kernel_size 平面窗拉成一列,輸出“二維矩陣”,方便后面用矩陣乘法代替卷積。本質(zhì)是 im2col 操作,為 im2col 的別名 api。
而 MindSpore 中,Tensor.unfold 與 ms.nn.functional.unfold 實(shí)現(xiàn)相同,實(shí)際調(diào)用為 im2col,因此造成實(shí)現(xiàn)差異。故此處整改方案為,使用 MindSpore 實(shí) 現(xiàn) Tensor.unfold 與 torch.Tensor.unfold 相同功能函數(shù)進(jìn)行替換。等價(jià)實(shí)現(xiàn)后,端到端推理性能提升1倍。后續(xù) MindSpore 實(shí)現(xiàn) Tensor.unfold 算子后可進(jìn)一步優(yōu)化顯存占用以提升性能。
融合算子的開(kāi)發(fā)與調(diào)優(yōu)
由于 SelfAttention 的顯存開(kāi)銷(xiāo)與蛋白質(zhì)序列長(zhǎng)度強(qiáng)相關(guān),且當(dāng)前對(duì)該模塊的優(yōu)化并不完全親和生物學(xué)場(chǎng)景,因此我們選擇開(kāi)發(fā)融合算子 EvoformerAttention。對(duì)此,我們實(shí)施了以下關(guān)鍵改進(jìn):
? UB 內(nèi)存布局重構(gòu):消除內(nèi)存碎片,提升 UB 利用率;
? 消除流同步算子:重構(gòu)計(jì)算流水線(xiàn),將串行內(nèi)存拷貝轉(zhuǎn)為并行異步操作;
? 稀疏掩碼優(yōu)化:去除 drop_mask 在 UB 中的顯存占用;
? 動(dòng)態(tài) tiling 調(diào)整:基于 UB 剩余容量自適應(yīng)調(diào)整分塊大小,顯著降低循環(huán)開(kāi)銷(xiāo);以上四個(gè)改進(jìn)總體時(shí)間性能提升約 6.5%;
? API 優(yōu)化:將傳統(tǒng)的 Level 1 API 配合顯式循環(huán)的模式,重構(gòu)為 Level 0 API 的批量處理接口,單步優(yōu)化后時(shí)間性能提升約 5%。
此外,Protenix 中使用了大量的張量計(jì)算,其實(shí)現(xiàn)方式均為 Einsum(Einstein Summation,愛(ài)因斯坦求和約定),因此該算子對(duì)模型整體的性能影響較大。Einsum 中規(guī)定的張量縮并運(yùn)算滿(mǎn)足下標(biāo)表達(dá)式

Einsum 高效實(shí)現(xiàn)在邏輯上離不開(kāi)對(duì)下標(biāo)的重排列(permute)。但 permute 操作的時(shí)間復(fù)雜度是 O(N),我們可以通過(guò)優(yōu)化下標(biāo)排布,減少或消除顯式的 permute 操作,來(lái)進(jìn)一步提升 Einsum 的算子性能。具體操作包括:
? 放棄不必要的 permute 操作,邏輯上改為對(duì)下標(biāo)循環(huán)的重排布,并通過(guò) reshape 操作合并下標(biāo),以實(shí)現(xiàn)批量操作;可將時(shí)間復(fù)雜度降到O(1);
? 使用 Mindspore 接口:ops.MatMul(transpose_a=False, transpose_b=False),該接口適配了最低兩維轉(zhuǎn)置的情況,可以替代符合這種情況下的 permute 操作。

尋找并解決內(nèi)存瓶頸
經(jīng)過(guò)此前的優(yōu)化后,Protenix 模型的 MindSpore 實(shí)現(xiàn)版本在單張 A2上的推理極限大致為包含 2000 個(gè)殘基的蛋白質(zhì)序列,也即推理長(zhǎng)度的極限只有 2k。通過(guò)分析 2k 長(zhǎng)度序列推理的 Profiling 數(shù)據(jù)、調(diào)查模型前期出現(xiàn)的若干個(gè)算子,我們發(fā)現(xiàn)在模型在 PairFormer 階段存在大量的內(nèi)存瓶頸:

通過(guò)對(duì)算子的定位我們可以將內(nèi)存峰值出現(xiàn)的時(shí)間與四次 EvoFormer Iteration 相吻合,最終定位出內(nèi)存瓶頸為該循環(huán)中的 outer_product_mean 計(jì)算。 該模塊主要承擔(dān)張量的縮并計(jì)算(愛(ài)因斯坦求和操作) 和一些線(xiàn)性變換,而內(nèi)存瓶頸正是發(fā)生在外積計(jì)算當(dāng)中:

對(duì)求和的左側(cè)部分進(jìn)行分塊操作,并調(diào)整合適的分塊尺寸(chunk_size),成功降低了內(nèi)存的峰值。我們后續(xù)又定位到其他可能導(dǎo)致內(nèi)存溢出的位置,分別是:
? 位于PairFormer 階段的 msa_attention,msa_transition 和 triangle_multiplication 計(jì)算;
? 位于Diffusion 階段的 transition_block 計(jì)算;
? 位于Confidence 階段的 ConfidenceHead 和 GridSelfAttention 計(jì)算。
關(guān)于分塊操作對(duì)時(shí)間、內(nèi)存以及算法精度上的影響,通過(guò)理論推導(dǎo)與實(shí)驗(yàn)驗(yàn)證,我們得到以下結(jié)論:
? 我們總是避開(kāi)了 LayerNorm,Softmax 等非線(xiàn)性操作所涉及的維度,因此分塊不會(huì)影響最終推理的精度;
? 整體而言,分塊尺寸與計(jì)算時(shí)間呈負(fù)相關(guān)關(guān)系,因此可在內(nèi)存容許的情況下,盡量增大分塊尺寸;下圖展示了 msa_attention 和 GridSelfAttention 在不同分塊下的計(jì)算時(shí)間;

使用以上策略,我們打通了單張 A2 上的 3k 長(zhǎng)度序列推理,成功提高了模型的推理極限。
2.3 jit 裝飾器與靜態(tài)圖編譯
MindSpore 與 PyTorch 的核心差異之一在于:
* PyTorch(Eager Mode)采用運(yùn)行時(shí)逐算子調(diào)度,算子粒度小、靈活但存在較高 launch 開(kāi)銷(xiāo);
* MindSpore 支持通過(guò) **`jit` 裝飾器** 將部分模塊提前編譯為靜態(tài)圖(Graph),在執(zhí)行時(shí)以 **大算子形式一次性下發(fā)**,極大減少算子調(diào)度成本。
在 Protenix 的 MindSpore 復(fù)現(xiàn)中,我們主要對(duì) Transformer 模塊進(jìn)行了 JIT 編譯以提升推理與訓(xùn)練效率。這主要是由于 Protenix 的 Transformer 層結(jié)構(gòu) 較為規(guī)則,輸入維度(hidden size、head_dim、num_heads)均為固定值,適合編譯為計(jì)算圖。在 Diffusion 采樣過(guò)程中,每步都需要調(diào)用 Transformer,共200次,但僅第一次需要編譯,后續(xù)可以直接復(fù)用。以序列長(zhǎng)度 109 的蛋白質(zhì) 5tgy 在 Atlas A2 的端到端推理性能為例(Diffusion 200 steps):
? JIT 編譯耗時(shí)大約30 s;
? 運(yùn)行平穩(wěn)后耗時(shí)約41 s;
? 非 JIT 模式下的推理耗時(shí)為72 s;
? JIT 模式下端到端加速比達(dá)到57%;

總結(jié)
我們成功將蛋白質(zhì)結(jié)構(gòu)預(yù)測(cè)模型 Protenix 從 PyTorch 遷移至 MindSpore 框架,并在昇騰 A2 平臺(tái)上實(shí)現(xiàn)了高性能訓(xùn)推。針對(duì)訓(xùn)練顯存瓶頸,我們?cè)O(shè)計(jì)了細(xì)粒度的重計(jì)算策略,對(duì) Triangle Attention、Triangle Multiplication 等模塊進(jìn)行針對(duì)性?xún)?yōu)化,將動(dòng)態(tài)顯存峰值降低 60% 以上,支持 768 長(zhǎng)度序列訓(xùn)練。推理優(yōu)化方面,通過(guò)重構(gòu) unfold 算子消除冗余 im2col 操作,開(kāi)發(fā) EvoformerAttention 融合算子,優(yōu)化 Einsum 實(shí)現(xiàn)減少數(shù)據(jù)移動(dòng),并采用分塊策略突破outer_product_mean 等模塊的內(nèi)存瓶頸,以及 JIT 編譯加速等,將推理長(zhǎng)度從 2k 擴(kuò)展至 3k 以上。我們驗(yàn)證了自主創(chuàng)新計(jì)算平臺(tái)在前沿蛋白質(zhì)預(yù)測(cè)任務(wù)中的高效性與可行性,為復(fù)雜科學(xué)計(jì)算模型向 MindSpore 生態(tài)遷移提供了實(shí)踐范例。
在蛋白質(zhì)領(lǐng)域,昇思 AI4S 團(tuán)隊(duì)通過(guò)算法與自主創(chuàng)新算力的深度協(xié)同,使實(shí)驗(yàn)室級(jí)的前沿AI工具,成為生物醫(yī)藥產(chǎn)業(yè)可規(guī)模部署的基礎(chǔ)設(shè)施。昇思 AI4S 團(tuán)隊(duì)聚焦于打造面向科學(xué)發(fā)現(xiàn)的專(zhuān)用 AI 框架,致力于構(gòu)建科學(xué)計(jì)算與人工智能融合的新型基礎(chǔ)設(shè)施。團(tuán)隊(duì)支撐范圍涵蓋了生物信息、地球物理、能源、電磁仿真、計(jì)算數(shù)學(xué)和材料化學(xué)等多個(gè)領(lǐng)域,未來(lái)將進(jìn)一步打造開(kāi)源生態(tài)并深化基礎(chǔ)設(shè)施的建造。昇思社區(qū)的 AI4S 開(kāi)源代碼倉(cāng)庫(kù)可見(jiàn) https://atomgit.com/mindspore-lab/mindscience.
本次在杭州舉辦的昇思人工智能框架峰會(huì),將會(huì)邀請(qǐng)思想領(lǐng)袖、專(zhuān)家學(xué)者、企業(yè)領(lǐng)軍人物及明星開(kāi)發(fā)者等產(chǎn)學(xué)研用代表,共探技術(shù)發(fā)展趨勢(shì)、分享創(chuàng)新成果與實(shí)踐經(jīng)驗(yàn)。歡迎各界精英共赴前沿之約,攜手打造開(kāi)放、協(xié)同、可持續(xù)的人工智能框架新生態(tài)!












