0x0. 前言
更多的深度學(xué)習(xí)編譯器知識(shí)可以在 https://github.com/BBuf/tvm_mlir_learn 找到。同時(shí)也維護(hù)了一個(gè)cuda學(xué)習(xí)倉庫 https://github.com/BBuf/how-to-optim-algorithm-in-cuda 以及一個(gè)如何學(xué)習(xí)深度學(xué)習(xí)框架(PyTorch和OneFlow)的學(xué)習(xí)倉庫,https://github.com/BBuf/how-to-learn-deep-learning-framework , 有需要的小伙伴可以點(diǎn)一點(diǎn)star 。在https://github.com/BBuf/how-to-optim-algorithm-in-cuda/tree/master/large-language-model-note 這個(gè)目錄下收集了一系列和LLM訓(xùn)練,推理相關(guān)的文章。
【省流】上次介紹了深度學(xué)習(xí)編譯器之Layerout Transform優(yōu)化 ,在這篇文章中提到還會(huì)介紹常量折疊優(yōu)化Pass的實(shí)現(xiàn),但在介紹常量折疊Pass之前我想再介紹一個(gè)類似的優(yōu)化方法也就是公共子表達(dá)式消除實(shí)現(xiàn)(CSE)。仍然是以O(shè)neFlow中基于MLIR進(jìn)行實(shí)現(xiàn)的CSE Pass為例子來講解。在解析代碼實(shí)現(xiàn)的過程中,我發(fā)現(xiàn)基于MLIR來做公共子表達(dá)式消除的時(shí)候還順帶做了死代碼消除的功能。另外,在考慮公共子表達(dá)式消除的時(shí)候需要保證兩個(gè)重復(fù)的操作處于同一個(gè)基本塊中以及兩個(gè)重復(fù)操作之間沒有其它具有副作用的操作才可以消除。在OneFlow的實(shí)現(xiàn)中只是對(duì)OneFlow的UserOp的特殊屬性即OpName和SymbolID進(jìn)行了擦除,用一個(gè)魔法屬性來代替,這是因?yàn)檫@兩個(gè)屬性不應(yīng)該去影響公共子表達(dá)式的消除。這個(gè)優(yōu)化還是比較有用的,在OneFlow的Stable Diffusion優(yōu)化中發(fā)揮了不小的作用。
0x1. 效果
公共子表達(dá)式消除的作用很簡單,就是把公共的表達(dá)式折疊為1個(gè)表達(dá)式來避免重復(fù)的計(jì)算開銷。我們以O(shè)neFlow針對(duì)CSE Pass寫的2個(gè)測試為例子來進(jìn)行說明。這兩個(gè)例子在 https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/test/OneFlow/cse.mlir ,這里提供了一個(gè) MLIR Module,包含兩個(gè)函數(shù):@Cast_1__FUSE__ScalarMulByTensor_2 和 @f2。
其中,第一個(gè)函數(shù) @Cast_1__FUSE__ScalarMulByTensor_2 接受一個(gè)形狀為 96x96xi64 的張量作為輸入,并執(zhí)行兩個(gè)類型轉(zhuǎn)換操作,將輸入轉(zhuǎn)換為 96x96xf32 張量。然后,它使用 oneflow.add_n 操作將兩個(gè)結(jié)果張量相加,并返回結(jié)果 96x96xf32 張量。FileCheck 命令驗(yàn)證了具有 "ScalarMulByTensor_2" op_name 屬性的 "oneflow.cast" 和 "oneflow.add_n2" 操作的存在。這里再解釋一下 CHECK 指定,比如CHECK: %[[OUT:[a-zA-Z0-9_]+]] = "oneflow.cast" 是一個(gè) FileCheck 指令,用于驗(yàn)證生成的代碼是否符合預(yù)期。FileCheck 是 LLVM 項(xiàng)目的一部分,用于為編譯器測試提供模式匹配功能。%[[OUT:[a-zA-Z0-9_]+]] 是一個(gè)正則表達(dá)式捕獲組,用于捕獲一個(gè)以 % 開頭、后跟一系列字母、數(shù)字或下劃線的字符串。這個(gè)字符串對(duì)應(yīng)于 MLIR 中的一個(gè)值名稱。"oneflow.cast" 表示我們希望找到一個(gè)名為 "oneflow.cast" 的操作。
第二個(gè)函數(shù) @f2 接受三個(gè)輸入張量:一個(gè)形狀為 2x64x64x320xf16 的張量,一個(gè)形狀為 320x320x3x3xf16 的張量,和一個(gè)形狀為 320xf16 的張量。它將第二個(gè)輸入張量轉(zhuǎn)置兩次,并使用轉(zhuǎn)置后的張量、第一個(gè)輸入張量和第三個(gè)輸入張量執(zhí)行兩個(gè) conv2d 操作。該函數(shù)返回兩個(gè)形狀為 2x64x64x320xf16 的結(jié)果張量。FileCheck 命令驗(yàn)證了具有等于 163 的 scope_symbol_id 屬性的 "oneflow.conv2d" 操作的存在,并檢查輸出的兩個(gè)結(jié)果張量。
這兩個(gè)函數(shù)有一個(gè)共同點(diǎn),那就是它們都存在一個(gè)完全相同的公共Op,我們可以編譯oneflow之后使用下面的命令將CSE Pass添加到opt pass pipline里面來運(yùn)行這個(gè)mlir表達(dá)式做變換,我們可以關(guān)注變換后的表達(dá)式。命令如下:
oneflow/build/oneflow/ir/bin/oneflow-optoneflow/oneflow/ir/test/OneFlow/cse.mlir-cse-with-attributes-ignored-cse-cse-put-attributes-canonicalize
解釋一下這里的幾個(gè)選項(xiàng):
cse-with-attributes-ignored: 此參數(shù)告訴優(yōu)化器在執(zhí)行公共子表達(dá)式消除(CSE)時(shí)忽略O(shè)neFlow IR特有的會(huì)影響CSE的屬性(這里是OpName和SymbolID)。
cse: 這個(gè)參數(shù)開啟公共子表達(dá)式消除(CSE)優(yōu)化。CSE 是一種編譯器優(yōu)化技術(shù),用于刪除冗余的子表達(dá)式,從而減少計(jì)算量和提高程序運(yùn)行速度。
cse-put-attributes: 此參數(shù)指示優(yōu)化器在執(zhí)行 CSE 之后,將原始屬性放回原始操作。這有助于確保在優(yōu)化過程中保留操作的屬性信息。(也暗示我們必須把原始的屬性保存下來)
canonicalize: 這個(gè)參數(shù)開啟規(guī)范化優(yōu)化。規(guī)范化優(yōu)化會(huì)將程序中的操作和表達(dá)式轉(zhuǎn)換為一種統(tǒng)一的標(biāo)準(zhǔn)形式,從而簡化后續(xù)優(yōu)化的實(shí)現(xiàn)和提高效率。(這兩個(gè)給定的例子里,不開啟canonicalize也不會(huì)影響輸出IR的表達(dá))
接下來是運(yùn)行上述命令后輸出的MLIR Module。
module{
func.func@Cast_1__FUSE__ScalarMulByTensor_2(%arg0:tensor<96x96xi64>)->tensor<96x96xf32>{
%0="oneflow.cast"(%arg0){device_name=["0:0"],device_tag="cpu",dtype=2:i32,hierarchy=[1],op_name="Cast_1",op_type_name="cast",pin_memory=false,scope_symbol_id=4611686018427416574:i64}:(tensor<96x96xi64>)->tensor<96x96xf32>
%1="oneflow.add_n2"(%0,%0){device_name=["0:0"],device_tag="cpu",hierarchy=[1],op_name="ScalarMulByTensor_2",op_type_name="add_n",scope_symbol_id=4611686018427416574:i64}:(tensor<96x96xf32>,tensor<96x96xf32>)->tensor<96x96xf32>
return%1:tensor<96x96xf32>
}
func.func@f2(%arg0:tensor<2x64x64x320xf16>,%arg1:tensor<320x320x3x3xf16>,%arg2:tensor<320xf16>)->(tensor<2x64x64x320xf16>,tensor<2x64x64x320xf16>){
%0="oneflow.transpose"(%arg1){device_name=["@0:0"],device_tag="cuda",hierarchy=[1],op_name="unet.down_blocks.0.resnets.0.conv1-conv2d-31_transpose_input_1",perm=[0:si32,2:si32,3:si32,1:si32],scope_symbol_id=163:i64}:(tensor<320x320x3x3xf16>)->tensor<320x3x3x320xf16>
%1="oneflow.conv2d"(%arg0,%0,%arg2){data_format="channels_last",device_name=["@0:0"],device_tag="cuda",dilation_rate=[1:si32,1:si32],filters=320:si32,groups=1:si32,hierarchy=[1],kernel_size=[3:si32,3:si32],op_name="unet.down_blocks.0.resnets.0.conv1-conv2d-31",operand_segment_sizes=array,padding_before=[1:si32,1:si32],scope_symbol_id=163:i64,strides=[1:si32,1:si32],tuning_cache=""}:(tensor<2x64x64x320xf16>,tensor<320x3x3x320xf16>,tensor<320xf16>)->tensor<2x64x64x320xf16>
return%1,%1:tensor<2x64x64x320xf16>,tensor<2x64x64x320xf16>
}
}
和原始的MLIR ModuleOp對(duì)比,我們發(fā)現(xiàn)這兩個(gè)函數(shù)里面的公共子表達(dá)式(cast和transpose)都只保留了一個(gè),實(shí)現(xiàn)了公共子表達(dá)式消除的目的。在OneFlow編譯器中,這個(gè)優(yōu)化率先在OneFlow的Stable Diffusion引人,加速了模型的推理速度。
0x2. 原理&代碼實(shí)現(xiàn)
基于 OneFlow 實(shí)現(xiàn) CSE 的原理是,我們需要先消除 OneFlow 的 UserOp 的 OpName 和 SymbolID 這兩個(gè)屬性,這兩個(gè)屬性對(duì) CSE 來說是沒影響的,但是是由 OneFlow 系統(tǒng)添加的,所以我們需要做個(gè)預(yù)處理忽略掉這兩個(gè)不一致。然后調(diào)用MLIR系統(tǒng)的 CSE Pass 之后我們需要把這個(gè)忽略的屬性加回來。這樣才可以保證優(yōu)化后的IR可以轉(zhuǎn)回OneFlow的圖并正確執(zhí)行。
首先基于ODS在https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/include/OneFlow/OneFlowPasses.td#L156-L172 定義了兩個(gè)CSE相關(guān)的Pass類,MLIR會(huì)自動(dòng)生成這兩個(gè)Pass的定義。我們?cè)敿?xì)看一下細(xì)節(jié):
defCSEWithAttributesIgnored:Pass<"cse-with-attributes-ignored",?"ModuleOp">{//定義了一個(gè)名為"cse-with-attributes-ignored"的Pass,它作用在MLIR中的模塊操作(ModuleOp)上。
letsummary="ignoreoneflowattributestohavecsework";//summary和description:提供了有關(guān)Pass功能的簡短描述和詳細(xì)說明。這個(gè)Pass的目的是執(zhí)行CSE優(yōu)化,同時(shí)忽略O(shè)neFlow屬性(如操作名、符號(hào)ID等)。
letdescription=[{
cseandignoreoneflowattributeslikeopname,symbolid,etc.
}];
letconstructor="mlir::createCSEWithAttributesIgnored()";//指定用于創(chuàng)建這個(gè)Pass的函數(shù),即mlir::createCSEWithAttributesIgnored()。
letdependentDialects=[];//列出這個(gè)Pass依賴的其他方言。在這種情況下,它是空的,表示沒有依賴關(guān)系。
}
defCSEPutAttributes:Pass<"cse-put-attributes",?"ModuleOp">{
letsummary="cseandignoreoneflowattributes";
letdescription=[{
putbackoneflowattributeslikeopname,symbolid,etc.
}];
letconstructor="mlir::createCSEPutAttributes()";
letdependentDialects=[];
}
可以看到 CSE 的預(yù)處理和后處理 Pass 主要就是實(shí)現(xiàn) createCSEWithAttributesIgnored 和 createCSEPutAttributes 這兩個(gè)函數(shù)。它們的定義在:https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/include/OneFlow/Transform/CSEWithAttributesIgnored.h#L25-L33
//CSEState結(jié)構(gòu)體包含兩個(gè)成員:
//scopeSymbolIDs:一個(gè)llvm::DenseMap,將Operation*類型的指針映射到IntegerAttr類型的屬性。這個(gè)映射可能用于存儲(chǔ)操作的范圍符號(hào)ID。
//opNames:一個(gè)llvm::DenseMap,將Operation*類型的指針映射到StringAttr類型的屬性。這個(gè)映射可能用于存儲(chǔ)操作的名稱。
structCSEState{
llvm::DenseMapscopeSymbolIDs;
llvm::DenseMapopNames;
};
//這個(gè)函數(shù)返回一個(gè)std::unique_ptr類型的對(duì)象。根據(jù)函數(shù)名稱,這個(gè)函數(shù)創(chuàng)建一個(gè)CSEPass,其中忽略了屬性。
std::unique_ptrcreateCSEWithAttributesIgnored();
//這個(gè)函數(shù)也返回一個(gè)std::unique_ptr類型的對(duì)象。根據(jù)函數(shù)名稱,這個(gè)函數(shù)創(chuàng)建一個(gè)CSEPass,會(huì)處理或放置屬性。
std::unique_ptrcreateCSEPutAttributes();
//這個(gè)函數(shù)接受一個(gè)std::shared_ptr類型的參數(shù),并返回一個(gè)std::pair,其中包含兩個(gè)std::unique_ptr類型的對(duì)象。這個(gè)函數(shù)創(chuàng)建一對(duì)CSEPass,它們共享給定的CSEState。
std::pair,std::unique_ptr>createCSEPasses(
std::shared_ptrstate);
//這個(gè)函數(shù)接受一個(gè)std::shared_ptr類型的參數(shù)。根據(jù)函數(shù)名稱,這個(gè)函數(shù)可能會(huì)注冊(cè)一組CSEPass,它們共享給定的CSEState。
voidregisterCSEPasses(std::shared_ptrstate);
接下來看下這幾個(gè) Pass 的具體實(shí)現(xiàn)。代碼在 https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/lib/OneFlow/Transform/CSEWithAttributesIgnored.cpp
首先來看createCSEWithAttributesIgnored:
structEraseAttributes:publicmlir::OpInterfaceRewritePattern{ explicitEraseAttributes(mlir::MLIRContext*context,std::shared_ptr state) :OpInterfaceRewritePattern (context,/*benefit=*/1),state_{state}{} mlir::LogicalResultmatchAndRewrite(UserOpCompatibleop, mlir::PatternRewriter&rewriter)constoverride{ if(op->getAttrOfType (OpTrait::IsOpConfCompatible ::getOpNameAttr()) .getValue() .str() !=MAGIC_OP_NAME){ if(state_){ state_->opNames[op]= op->getAttrOfType (OpTrait::IsOpConfCompatible ::getOpNameAttr()); state_->scopeSymbolIDs[op]=op->getAttrOfType ( OpTrait::IsOpConfCompatible ::getScopeSymbolIDAttr()); } op->setAttr(OpTrait::IsOpConfCompatible ::getOpNameAttr(), rewriter.getStringAttr(MAGIC_OP_NAME)); op->setAttr(OpTrait::IsOpConfCompatible ::getScopeSymbolIDAttr(), rewriter.getI64IntegerAttr(MAGIC_SCOPE_SYMBOL_ID)); returnsuccess(); }else{ returnfailure(); } } private: std::shared_ptr state_; }; classCSEWithAttributesIgnored:publicCSEWithAttributesIgnoredBase { public: explicitCSEWithAttributesIgnored(){} explicitCSEWithAttributesIgnored(std::shared_ptr state):state_(state){} voidrunOnOperation()override{ Operation*op=getOperation(); RewritePatternSetpatterns(op->getContext()); patterns.add (op->getContext(),state_); (void)applyPatternsAndFoldGreedily(op,std::move(patterns)); } private: std::shared_ptr state_; }; std::unique_ptr createCSEWithAttributesIgnored(){ returnstd::make_unique (); }
這段代碼定義了一個(gè) EraseAttributes 重寫類, 它會(huì)移除 op 中的某些屬性。它繼承自 OpInterfaceRewritePattern, 意味著它可以匹配實(shí)現(xiàn)了 UserOpCompatible 這個(gè) OpInterface 的 op。然后 EraseAttributes 構(gòu)造函數(shù)接受一個(gè) MLIRContext* 和一個(gè)shared_ptr。CSEState 用于跟蹤已重寫的 op 的屬性。matchAndRewrite 方法檢查 op 是否有名為 OpNameAttr 的 StringAttr 屬性, 如果有, 并且其值不等于 MAGIC_OP_NAME, 則該方法會(huì):
將 op 的 OpNameAttr 和 ScopeSymbolIDAttr 屬性記錄在 CSEState 中。
將 OpNameAttr 設(shè)置為 MAGIC_OP_NAME, 將 ScopeSymbolIDAttr 設(shè)置為 MAGIC_SCOPE_SYMBOL_ID。
然后,CSEWithAttributesIgnored 繼承自 CSEWithAttributesIgnoredBase, 重寫了其 runOnOperation 方法。該方法會(huì)實(shí)例化一個(gè) RewritePatternSet, 添加 EraseAttributes 這個(gè)匹配重寫模板, 然后應(yīng)用該模板, 從而移除user op 中的屬性。它還保存一個(gè)指向CSEState 的 shared_ptr , 可以在 EraseAttributes 中使用。注意這里的 CSEWithAttributesIgnoredBase 是通過ODS自動(dòng)生成的 Pass 類定義。createCSEWithAttributesIgnored 函數(shù)會(huì)創(chuàng)建一個(gè) CSEWithAttributesIgnored pass 并返回。
接著看一下 createCSEPutAttributes 的實(shí)現(xiàn),
structPutAttributes:publicmlir::OpInterfaceRewritePattern{ explicitPutAttributes(mlir::MLIRContext*context,std::shared_ptr state) :OpInterfaceRewritePattern (context,/*benefit=*/1),state_{state}{} mlir::LogicalResultmatchAndRewrite(UserOpCompatibleop, mlir::PatternRewriter&rewriter)constoverride{ if(op->getAttrOfType (OpTrait::IsOpConfCompatible ::getOpNameAttr()) .getValue() .str() ==MAGIC_OP_NAME){ if(state_){ op->setAttr(OpTrait::IsOpConfCompatible ::getOpNameAttr(),state_->opNames[op]); op->setAttr(OpTrait::IsOpConfCompatible ::getScopeSymbolIDAttr(), state_->scopeSymbolIDs[op]); } returnsuccess(); }else{ returnfailure(); } } private: std::shared_ptr state_; }; classCSEPutAttributes:publicCSEPutAttributesBase { public: explicitCSEPutAttributes(){} explicitCSEPutAttributes(std::shared_ptr state){state_=state;} voidrunOnOperation()override{ Operation*op=getOperation(); RewritePatternSetpatterns(op->getContext()); patterns.add (op->getContext(),state_); (void)applyPatternsAndFoldGreedily(op,std::move(patterns)); } private: std::shared_ptr state_; }; std::unique_ptr createCSEPutAttributes(){returnstd::make_unique ();}
這個(gè) PutAttributes 重寫模板與 EraseAttributes 相反, 它會(huì)將先前刪除的屬性恢復(fù)回 op。PutAttributes 構(gòu)造函數(shù)也接受一個(gè) MLIRContext* 和一個(gè) shared_ptr。它使用 CSEState 來查找先前刪除的屬性值。matchAndRewrite 方法檢查 op 是否有一個(gè)名為 OpNameAttr 的 StringAttr 屬性,其值等 于 MAGIC_OP_NAME 。如果是,它會(huì)從 CSEState 中查找原先的 OpNameAttr 和 ScopeSymbolIDAttr 屬性值。將 OpNameAttr 設(shè)置為原先的值,將 ScopeSymbolIDAttr 設(shè)置為原先的值。
上面的2個(gè)Pass都是OneFlow中的預(yù)處理和后處理,而真的CSE Pass則是MLIR自帶的CSE Pass(oneflow/build/oneflow/ir/llvm_monorepo-src/mlir/lib/Transforms/CSE.cpp), 我們來解析一下。
structSimpleOperationInfo:publicllvm::DenseMapInfo{ staticunsignedgetHashValue(constOperation*opC){ returnOperationEquivalence::computeHash( const_cast (opC), /*hashOperands=*/OperationEquivalence::directHashValue, /*hashResults=*/OperationEquivalence::ignoreHashValue, OperationEquivalence::IgnoreLocations); } staticboolisEqual(constOperation*lhsC,constOperation*rhsC){ auto*lhs=const_cast (lhsC); auto*rhs=const_cast (rhsC); if(lhs==rhs) returntrue; if(lhs==getTombstoneKey()||lhs==getEmptyKey()|| rhs==getTombstoneKey()||rhs==getEmptyKey()) returnfalse; returnOperationEquivalence::isEquivalentTo( const_cast (lhsC),const_cast (rhsC), OperationEquivalence::IgnoreLocations); } };
SimpleOperationInfo 這個(gè)結(jié)構(gòu)體繼承自 llvm::DenseMapInfo
getHashValue: 為 Operation* 計(jì)算哈希值。它使用 OperationEquivalence::computeHash 來計(jì)算哈希值,并傳遞 hashOperands=directHashValue 和 hashResults=ignoreHashValue。這意味著它會(huì)直接對(duì) op 的操作數(shù)計(jì)算哈希值,但會(huì)忽略結(jié)果。
isEqual: 檢查兩個(gè) Operation* 是否相等。它首先檢查是否是相同的 op , 如果是,則返回 true。否則,它使用OperationEquivalence::isEquivalentTo 檢查兩個(gè) op 是否等價(jià)。同樣,它傳遞了 IgnoreLocations, 意味著它會(huì)忽略 op 的位置信息。
所以, 這個(gè) DenseMapInfo 允許以忽略結(jié)果和位置的方式將 Operation* 用作 DenseMap 的鍵。操作數(shù)用于等價(jià)性檢查和哈希值計(jì)算。
///Simplecommonsub-expressionelimination. //這是一個(gè)名為CSE(CommonSub-expressionElimination,公共子表達(dá)式消除)的結(jié)構(gòu)體定義,用于執(zhí)行簡單的公共子表達(dá)式消除。CSE是一種編譯器優(yōu)化技術(shù),用于消除程序中的重復(fù)計(jì)算,提高執(zhí)行效率。 structCSE:publicimpl::CSEBase{ ///Sharedimplementationofoperationeliminationandscopedmapdefinitions. //使用AllocatorTy和ScopedMapTy來定義分配器和作用域映射。ScopedMapTy是一個(gè)散列表,用于存儲(chǔ)操作之間的映射關(guān)系。 usingAllocatorTy=llvm::RecyclingAllocator< ??????llvm::BumpPtrAllocator, ??????llvm::ScopedHashTableVal >; usingScopedMapTy=llvm::ScopedHashTable ; ///CacheholdingMemoryEffectsinformationbetweentwooperations.Thefirst ///operationisstoredhasthekey.Thesecondoperationisstoredinsidea ///pairinthevalue.ThepairalsoholdtheMemoryEffectsbetweenthose ///twooperations.IftheMemoryEffectsisnullptrthenweassumethereis ///nooperationwithMemoryEffects::Writebetweenthetwooperations. //MemEffectsCache用于在兩個(gè)操作之間緩存MemoryEffects信息。MemoryEffects表示某個(gè)操作對(duì)內(nèi)存的影響。 usingMemEffectsCache= DenseMap >; ///RepresentsasingleentryinthedepthfirsttraversalofaCFG. //CFGStackNode結(jié)構(gòu)體表示控制流圖(CFG)深度優(yōu)先遍歷中的一個(gè)節(jié)點(diǎn)。包括作用域、節(jié)點(diǎn)、子節(jié)點(diǎn)迭代器等信息。 structCFGStackNode{ CFGStackNode(ScopedMapTy&knownValues,DominanceInfoNode*node) :scope(knownValues),node(node),childIterator(node->begin()){} ///Scopefortheknownvalues. ScopedMapTy::ScopeTyscope; DominanceInfoNode*node; DominanceInfoNode::const_iteratorchildIterator; ///Ifthisnodehasbeenfullyprocessedyetornot. boolprocessed=false; }; ///Attempttoeliminatearedundantoperation.Returnssuccessifthe ///operationwasmarkedforremoval,failureotherwise. //simplifyOperation函數(shù)嘗試消除冗余操作。如果操作被標(biāo)記為移除,則返回成功,否則返回失敗。 LogicalResultsimplifyOperation(ScopedMapTy&knownValues,Operation*op, boolhasSSADominance); //simplifyBlock函數(shù)簡化指定的基本塊(Block)。 voidsimplifyBlock(ScopedMapTy&knownValues,Block*bb,boolhasSSADominance); //simplifyRegion函數(shù)簡化指定的區(qū)域(Region)。 voidsimplifyRegion(ScopedMapTy&knownValues,Region®ion); //runOnOperation函數(shù)是重寫的基類方法,用于執(zhí)行CSE優(yōu)化。 voidrunOnOperation()override; private: //replaceUsesAndDelete函數(shù)用于替換操作的使用和刪除操作。 voidreplaceUsesAndDelete(ScopedMapTy&knownValues,Operation*op, Operation*existing,boolhasSSADominance); ///Checkifthereisside-effectingoperationsotherthanthegiveneffect ///betweenthetwooperations. //hasOtherSideEffectingOpInBetween函數(shù)檢查給定操作之間是否存在其他具有副作用的操作。 boolhasOtherSideEffectingOpInBetween(Operation*fromOp,Operation*toOp); ///Operationsmarkedasdeadandtobeerased. //opsToErase是一個(gè)用于存儲(chǔ)將要?jiǎng)h除的操作的向量。 std::vector opsToErase; //domInfo是一個(gè)指向支配信息(DominanceInfo)的指針。 DominanceInfo*domInfo=nullptr; //memEffectsCache是一個(gè)緩存,用于存儲(chǔ)操作之間的內(nèi)存效果信息。 MemEffectsCachememEffectsCache; }; }//namespace
我們先看一下核心的runOperation方法。
voidCSE::runOnOperation(){ ///Ascopedhashtableofdefiningoperationswithinaregion. //定義一個(gè)名為knownValues的局部變量。它是一個(gè)作用域內(nèi)的哈希表,用于存儲(chǔ)在一個(gè)區(qū)域內(nèi)定義的操作。 ScopedMapTyknownValues; //從DominanceInfo分析中獲取支配關(guān)系信息,并將其存儲(chǔ)在名為domInfo的變量中。 domInfo=&getAnalysis(); //獲取當(dāng)前操作(rootOp),并遍歷其所有區(qū)域。對(duì)每個(gè)區(qū)域執(zhí)行簡化操作(simplifyRegion)。 Operation*rootOp=getOperation(); for(auto®ion:rootOp->getRegions()) simplifyRegion(knownValues,region); //如果opsToErase(要?jiǎng)h除的操作)為空,說明沒有操作被刪除,因此保留所有分析。 //Ifnooperationswereerased,thenwemarkallanalysesaspreserved. if(opsToErase.empty()) returnmarkAllAnalysesPreserved(); ///Eraseanyoperationsthatweremarkedasdeadduringsimplification. //如果opsToErase中有操作,遍歷opsToErase并刪除其中的操作。然后清空opsToErase。 for(auto*op:opsToErase) op->erase(); opsToErase.clear(); //Wecurrentlydon'tremoveregionoperations,somarkdominanceas //preserved. //由于當(dāng)前代碼不會(huì)刪除區(qū)域操作,因此將支配關(guān)系信息(DominanceInfo)和后支配關(guān)系信息(PostDominanceInfo)標(biāo)記為已保留。將domInfo設(shè)置為nullptr。 markAnalysesPreserved (); domInfo=nullptr; }
這里首先會(huì)獲取當(dāng)前 ModuleOp 中 Region 里的支配關(guān)系,以便后續(xù)執(zhí)行完 CSE 之后刪除 Op 后可以更新支配信息。這里的重點(diǎn)是 simplifyRegion 函數(shù),這是執(zhí)行 CSE 的具體細(xì)節(jié)。這個(gè)函數(shù)主要使用支配樹遍歷區(qū)域中的基本塊,并調(diào)用 simplifyBlock() 函數(shù)對(duì)每個(gè)基本塊進(jìn)行簡化。
//函數(shù)接受一個(gè)類型為ScopedMapTy的引用knownValues和一個(gè)類型為Region的引用region作為參數(shù)。
voidCSE::simplifyRegion(ScopedMapTy&knownValues,Region®ion){
//Iftheregionisemptythereisnothingtodo.
if(region.empty())
return;
//判斷區(qū)域是否具有SSA支配關(guān)系(StaticSingleAssignmentDominance),并將結(jié)果存儲(chǔ)在變量hasSSADominance中。
boolhasSSADominance=domInfo->hasSSADominance(®ion);
//Iftheregiononlycontainsoneblock,thensimplifyitdirectly.
//如果區(qū)域只包含一個(gè)基本塊,那么直接對(duì)其進(jìn)行簡化。創(chuàng)建一個(gè)名為scope的ScopedMapTy::ScopeTy對(duì)象,然后調(diào)用simplifyBlock()函數(shù)對(duì)該基本塊進(jìn)行簡化。
if(region.hasOneBlock()){
ScopedMapTy::ScopeTyscope(knownValues);
simplifyBlock(knownValues,®ion.front(),hasSSADominance);
return;
}
//IftheregiondoesnothavedominanceInfo,thenskipit.
//TODO:RegionswithoutSSAdominanceshoulddefineadifferent
//traversalorderwhichisappropriateandcanbeusedhere.
//如果區(qū)域沒有支配關(guān)系信息(hasSSADominance為false),則跳過它。此處提到了一個(gè)TODO:對(duì)于沒有SSA支配關(guān)系的區(qū)域,應(yīng)該定義一個(gè)不同的遍歷順序。
if(!hasSSADominance)
return;
//Note,dequeisbeingusedherebecausetherewassignificantperformance
//gainsovervectorwhenthecontainerbecomesverylargeduetothe
//specificaccesspatterns.If/whentheseperformanceissuesareno
//longeraproblemwecanchangethistovector.Formoreinformationsee
//thellvmmailinglistdiscussiononthis:
//http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
//定義一個(gè)名為stack的std::deque容器,用于存儲(chǔ)CFGStackNode的std::unique_ptr。這里使用deque是因?yàn)樗谌萜髯兇髸r(shí)具有更好的性能表現(xiàn)。
std::deque>stack;
//Processthenodesofthedomtreeforthisregion.
//處理這個(gè)區(qū)域的支配樹節(jié)點(diǎn)。將區(qū)域的根節(jié)點(diǎn)壓入棧中。
stack.emplace_back(std::make_unique(
knownValues,domInfo->getRootNode(®ion)));
//當(dāng)棧不為空時(shí),執(zhí)行以下循環(huán)操作:
while(!stack.empty()){
//獲取棧頂?shù)漠?dāng)前節(jié)點(diǎn)(currentNode)。
auto¤tNode=stack.back();
//Checktoseeifweneedtoprocessthisnode.
//檢查當(dāng)前節(jié)點(diǎn)是否需要被處理。如果未處理,則將其標(biāo)記為已處理,并調(diào)用simplifyBlock()函數(shù)對(duì)當(dāng)前節(jié)點(diǎn)所在的基本塊進(jìn)行簡化。
if(!currentNode->processed){
currentNode->processed=true;
simplifyBlock(knownValues,currentNode->node->getBlock(),
hasSSADominance);
}
//Otherwise,checktoseeifweneedtoprocessachildnode.
//檢查是否需要處理子節(jié)點(diǎn)。如果當(dāng)前節(jié)點(diǎn)的子節(jié)點(diǎn)迭代器未到達(dá)末尾,將子節(jié)點(diǎn)壓入棧中。
if(currentNode->childIterator!=currentNode->node->end()){
auto*childNode=*(currentNode->childIterator++);
stack.emplace_back(
std::make_unique(knownValues,childNode));
}else{
//Finally,ifthenodeandallofitschildrenhavebeenprocessed
//thenwedeletethenode.
//如果當(dāng)前節(jié)點(diǎn)及其所有子節(jié)點(diǎn)都已處理完畢,則將節(jié)點(diǎn)從棧中彈出。
stack.pop_back();
}
}
}
函數(shù)的執(zhí)行流程請(qǐng)看注釋,到這一步之后CSE的具體實(shí)現(xiàn)實(shí)際上就在 simplifyBlock 函數(shù)了,我們繼續(xù)追蹤。函數(shù)接受一個(gè)類型為 ScopedMapTy 的引用 knownValues,一個(gè)類型為 Block 的指針 bb,以及一個(gè)布爾值 hasSSADominance 作為參數(shù)。從代碼中可以推測,該函數(shù)的目的是簡化一個(gè)給定的基本塊。
voidCSE::simplifyBlock(ScopedMapTy&knownValues,Block*bb,
boolhasSSADominance){
//遍歷基本塊bb中的所有操作(op)
for(auto&op:*bb){
//Mostoperationsdon'thaveregions,sofastpaththatcase.
//檢查操作是否包含區(qū)域。如果操作包含區(qū)域,執(zhí)行以下操作:
if(op.getNumRegions()!=0){
//Ifthisoperationisisolatedabove,wecan'tprocessnestedregions
//withthegiven'knownValues'map.Thiswouldcausetheinsertionof
//implicitcapturesinexplicitcaptureonlyregions.
//如果操作具有IsIsolatedFromAbove特性,那么我們不能使用給定的knownValues映射來處理嵌套區(qū)域,
//因?yàn)檫@可能導(dǎo)致在僅顯式捕獲的區(qū)域中插入隱式捕獲。在這種情況下,創(chuàng)建一個(gè)新的nestedKnownValues映射,
//并對(duì)操作的每個(gè)區(qū)域調(diào)用simplifyRegion()函數(shù)。
if(op.mightHaveTrait()){
ScopedMapTynestedKnownValues;
for(auto®ion:op.getRegions())
simplifyRegion(nestedKnownValues,region);
}else{
//Otherwise,processnestedregionsnormally.
//如果操作沒有IsIsolatedFromAbove特性,那么正常處理嵌套區(qū)域。
//對(duì)操作的每個(gè)區(qū)域調(diào)用simplifyRegion()函數(shù),傳入knownValues映射。
for(auto®ion:op.getRegions())
simplifyRegion(knownValues,region);
}
}
//如果操作被簡化(調(diào)用simplifyOperation()函數(shù)并檢查其返回值),則不處理操作包含的任何區(qū)域,繼續(xù)處理下一個(gè)操作。
//Iftheoperationissimplified,wedon'tprocessanyheldregions.
if(succeeded(simplifyOperation(knownValues,&op,hasSSADominance)))
continue;
}
//CleartheMemoryEffectscachesinceitsusageisbyblockonly.
//在處理完所有操作后,清空memEffectsCache,因?yàn)樗氖褂脙H限于單個(gè)基本塊。
memEffectsCache.clear();
}
在 simplifyBlock 中會(huì)進(jìn)一步調(diào)用到 simplifyOperation 來對(duì) Operation 做優(yōu)化。我們最后跟進(jìn)這個(gè)函數(shù)看一下。函數(shù)的參數(shù)和 simplifyBlock 一樣,接受一個(gè)類型為 ScopedMapTy 的引用 knownValues,一個(gè)類型為 Operation 的指針op,以及一個(gè)布爾值 hasSSADominance 作為參數(shù)。
///Attempttoeliminatearedundantoperation.
LogicalResultCSE::simplifyOperation(ScopedMapTy&knownValues,Operation*op,
boolhasSSADominance){
//Don'tsimplifyterminatoroperations.
//如果操作是終止操作(具有IsTerminator特性),則不對(duì)其進(jìn)行簡化。
if(op->hasTrait())
returnfailure();
//Iftheoperationisalreadytriviallydeadjustaddittotheeraselist.
//如果操作已經(jīng)是無關(guān)緊要的死代碼,將其添加到待擦除操作列表opsToErase中,增加死代碼消除計(jì)數(shù),然后返回成功。
if(isOpTriviallyDead(op)){
opsToErase.push_back(op);
++numDCE;
returnsuccess();
}
//Don'tsimplifyoperationswithregionsthathavemultipleblocks.
//TODO:WeneedadditionalteststoverifythatwehandlesuchIRcorrectly.
//不簡化具有多個(gè)基本塊的區(qū)域中的操作。這里提到了一個(gè)TODO:需要額外的測試來驗(yàn)證處理此類IR的正確性。
if(!llvm::all_of(op->getRegions(),[](Region&r){
returnr.getBlocks().empty()||llvm::hasSingleElement(r.getBlocks());
}))
returnfailure();
//Somesimpleusecaseofoperationwithmemoryside-effectaredealtwith
//here.Operationswithnoside-effectaredoneafter.
//首先處理具有內(nèi)存副作用的簡單操作。沒有副作用的操作會(huì)在后面處理。
if(!isMemoryEffectFree(op)){
automemEffects=dyn_cast(op);
//TODO:OnlybasicusecaseforoperationswithMemoryEffects::Readcanbe
//eleminatednow.Moreworkneedstobedoneformorecomplicatedpatterns
//andotherside-effects.
//如果操作不是無內(nèi)存副作用的,嘗試獲取其MemoryEffectOpInterface。
//如果操作沒有MemoryEffectOpInterface,或者它不僅僅具有MemoryEffects::Read副作用,則返回失敗。
if(!memEffects||!memEffects.onlyHasEffect())
returnfailure();
//Lookforanexistingdefinitionfortheoperation.
//查找操作的現(xiàn)有定義。如果找到現(xiàn)有定義,并且操作在同一個(gè)基本塊中,并且兩者之間沒有其它具有副作用的操作,
//則可以刪除冗余操作。調(diào)用replaceUsesAndDelete()函數(shù)替換使用并刪除操作。
if(auto*existing=knownValues.lookup(op)){
if(existing->getBlock()==op->getBlock()&&
!hasOtherSideEffectingOpInBetween(existing,op)){
//Theoperationthatcanbedeletedhasbeenreachwithno
//side-effectingoperationsinbetweentheexistingoperationand
//thisonesowecanremovetheduplicate.
replaceUsesAndDelete(knownValues,op,existing,hasSSADominance);
returnsuccess();
}
}
//將操作插入knownValues映射中,并返回失敗。
knownValues.insert(op,op);
returnfailure();
}
//Lookforanexistingdefinitionfortheoperation.
//查找操作的現(xiàn)有定義。如果找到現(xiàn)有定義,調(diào)用replaceUsesAndDelete()函數(shù)替換使用并刪除操作,
//增加公共子表達(dá)式消除計(jì)數(shù),并返回成功。
if(auto*existing=knownValues.lookup(op)){
replaceUsesAndDelete(knownValues,op,existing,hasSSADominance);
++numCSE;
returnsuccess();
}
//Otherwise,weaddthisoperationtotheknownvaluesmap.
//否則,將此操作添加到knownValues映射中,并返回失敗。
knownValues.insert(op,op);
returnfailure();
}
我們可以看到在 simplifyOperation 中,不僅僅包含公共子表達(dá)式消除(CSE),而且包含了死代碼消除(DCE)。此外,在處理 Operation 時(shí),它會(huì)考慮 Operation 的內(nèi)存副作用以及 Operation 是否在具有多個(gè)基本塊的區(qū)域中。
0x3. 總結(jié)
在閱讀代碼實(shí)現(xiàn)的過程中,我發(fā)現(xiàn)基于MLIR來做公共子表達(dá)式消除的時(shí)候還順帶做了死代碼消除的功能。另外,在考慮公共子表達(dá)式消除的時(shí)候需要保證兩個(gè)重復(fù)的操作處于同一個(gè)基本塊中以及兩個(gè)重復(fù)操作之間沒有其它具有副作用的操作才可以消除。在OneFlow的實(shí)現(xiàn)中只是對(duì)OneFlow的UserOp的特殊屬性即OpName和SymbolID進(jìn)行了擦除,用一個(gè)魔法屬性來代替,這是因?yàn)檫@兩個(gè)屬性不應(yīng)該去影響公共子表達(dá)式的消除。這個(gè)優(yōu)化還是比較有用的,在OneFlow的Stable Diffusion優(yōu)化中發(fā)揮了不小的作用。
-
代碼
+關(guān)注
關(guān)注
30文章
4972瀏覽量
74106 -
編譯器
+關(guān)注
關(guān)注
1文章
1672瀏覽量
51753 -
深度學(xué)習(xí)
+關(guān)注
關(guān)注
73文章
5600瀏覽量
124470
原文標(biāo)題:0x4. 相關(guān)鏈接
文章出處:【微信號(hào):GiantPandaCV,微信公眾號(hào):GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
國產(chǎn)深度學(xué)習(xí)框架的挑戰(zhàn)和機(jī)會(huì)
Nanopi深度學(xué)習(xí)之路(1)深度學(xué)習(xí)框架分析
深度學(xué)習(xí)框架你了解多少
如何學(xué)習(xí)深度學(xué)習(xí)框架
評(píng)論