作者|吳育昕
1
【資料圖】
為什么是TorchDynamo
Graph capture 把用戶 Python 寫的模型代碼變成 graph,是一切編譯的根基。而 PyTorch 在試了這么多方案之后似乎已經(jīng)鎖定 TorchDynamo 作為 graph capture 的未來方向了,所以寫一點(diǎn)關(guān)于 TorchDynamo 的內(nèi)容,主要是解釋到底為什么要做這個(gè)東西(離開FB一年了,內(nèi)容主要憑自己的猜測(cè)和理解)。
一句話盡量解釋 TorchDynamo 干了什么:利用?PEP523(https://peps.python.org/pep-0523/)?的 API 在用戶執(zhí)行每個(gè) python frame 前,?拿到這個(gè) frame 的 bytecode,把其中認(rèn)識(shí)的部分用 tracing 的方式提取出 graph (并送給后端編譯),不認(rèn)識(shí)的部分維持原樣。把修改后的??bytecode還給 CPython 跑。
由于 LazyTensor 和 TorchDynamo 都做 tracing,并且都是 best-effort graph capture,即只編譯自己能 capture 的部分,capture 不到的用 Python 跑 (aka Python fallback),所以觀感上兩者可能會(huì)差不多。
然而,這兩個(gè)方案的差別正是 TorchDynamo 關(guān)鍵的地方:
LazyTensor 是個(gè)純靠 tracing 的方案,不可避免的問題是「只能看見 trace 到的部分,只有 trace 一下才知道哪里不能 trace」。而每次執(zhí)行模型的時(shí)候,不能 trace 的部分可能不太一樣。為了保證正確性,LazyTensor 就不得不每次執(zhí)行都要重新 trace。舉個(gè)極端的例子,模型里寫了一個(gè)torch.add(tensor, random.random()) ,其中 random 是個(gè) LazyTensor 看不見摸不著的 Python 函數(shù),那只有重新 trace 才能保證正確性。
而當(dāng) TorchDynamo 修改 bytecode 的時(shí)候,事情就不太一樣了:
在 bytecode 里能夠看得見所有需要的信息,所以能夠證明「這段模型代碼沒有用到奇怪的東西所以不需要重新 trace」。
光證明了「不需要 trace」不代表可以真的不 trace,因?yàn)橛脩舻拇a還是一行行給 Python 來跑的。但是 TorchDynamo 又來了:CPython 到底跑什么 bytecode 是可以被它換掉的!
因此它可以做到這么一件事:當(dāng)用戶 call 一個(gè)被 capture 過的模型時(shí),模型里大部分 Python 代碼都相當(dāng)于不存在了,連 symbolic execution 的 overhead 都沒有,而被換成了編譯后的 native code。這一點(diǎn)在以前所有的 partial graph capture 的方案里是做不到的: ?
LazyTensor 即使編譯過的 graph 也要每次重新在 Python 里 trace 一遍,才能發(fā)現(xiàn)「哦,這個(gè) graph 我曾見過的」。
@torch.jit.script?、@tf.function、?@jax.jit?可以把裝飾的 python code 換成編譯后的,但是這都依賴用戶把這個(gè) subgraph refactor 出來放到一個(gè)單獨(dú)的函數(shù)里。而 TorchDynamo 是全自動(dòng)不需要用戶改代碼的。
這種 refactor 除了增加額外的工作量之外,還可能與用戶的代碼結(jié)構(gòu)沖突,因?yàn)?「用來編譯的graph的邊界」與「用戶代碼需要的抽象邊界」很可能不 match:例如用戶本來希望寫三個(gè)函數(shù),但是最佳的優(yōu)化是把其中兩個(gè)半函數(shù)變成一個(gè) graph,這會(huì)讓用戶很尷尬。
這只是一個(gè)最直接的例子。由于能夠讀寫 bytecode,理論上 TorchDynamo 能 access 更多 LazyTensor 根本沒有的信息,做更多事情(后面會(huì)提到)。而讀寫 bytecode 的難度比?source code要低不少,所以成為了一個(gè)可行的方案。
2whole-graph capture用處不大?
有的人可能會(huì)說,上面提到的東西對(duì) whole-graph capture 沒太大用啊。? 我覺得確實(shí)是這樣:TorchDynamo 是一個(gè)對(duì) partial-graph capture 追求極致的方案,能夠?qū)缀跛械?Python 實(shí)現(xiàn)的模型開箱即用有加速,不用改代碼——前提是還要跑 Python 作為 fallback。但是部署一般需要的是 whole-graph capture?整個(gè)模型在一個(gè) graph 里不能用 Python。
用 tracing 做 whole-graph capture 的前提是用戶要在 Python 代碼里避免所有不能被 trace 的東西,最常見的用戶要做的三件事是:使用?symbolic shape,使用 symbolic control flow,禁用除了當(dāng)前 tensor library之外的所有其它?library。如果用戶做到了這些,那只要一個(gè)普通的 symbolic tracing 就能 capture 到完整的 graph 了,不需要 TorchDynamo 這么復(fù)雜的機(jī)制。TorchDynamo 可能可以略微簡(jiǎn)化用戶做這些的工作量,但我感覺不會(huì)有本質(zhì)不同。
我個(gè)人的觀點(diǎn)是,從實(shí)用角度出發(fā),要求用戶做上面幾件事不算是太復(fù)雜的要求:禁用其他 library 理所應(yīng)當(dāng)就不說了;即使今天 PyTorch 還沒有很好的 symbolic {shape, control flow},但是只要用 @torch.jit.script_if_tracing 來處理少量的 symbolic shape 和 symbolic control flow,大多數(shù)模型都是可以正確的被 torch.jit.tracecapture 的。Meta 應(yīng)該有幾十上百個(gè) vision 模型實(shí)現(xiàn)在 detectron2/d2go 里,?目前基本都是走這條路部署的(我另有篇文章https://ppwwyyxx.com/blog/2022/TorchScript-Tracing-vs-Scripting/介紹這里面的細(xì)節(jié))。
TensorFlow 的 whole-graph capture 就簡(jiǎn)單了:TF 從第一天就有很好的 symbolic shape 和 symbolic control flow,用就完了。tf.autograph 甚至還自動(dòng)化了一部分 control flow 的改寫工作。
所以,用戶少量改代碼仍然是必須的。當(dāng)然,TorchDynamo 畢竟有著"改變用戶要跑的 bytecode" 的超能力。所以如果愿意的話,理論上可以讓用戶的 whole-graph capture 工作變得更簡(jiǎn)單。例如: ?
模型中間的一些像 if x.shape[0] > 100 的分支,有的可以通過 shape inference 等價(jià)轉(zhuǎn)移到模型開頭的。這樣的話就可以 capture 到更大的沒有分支的 subgraph。?這件事在 TorchDynamo 里現(xiàn)在叫做 "guard"。?
理論上可以把 python control flow 自動(dòng)替換成 symbolic 的,類似tf.autograph 做的事情,只不過輸入是 bytecode 而不是 source code。? ?
目前 TorchDynamo 的 "nopython" 模式就是 whole-graph capture 了。不過似乎還不是工作重心 (以下引用自https://docs.google.com/document/d/1tlgPcR2YmC3PcQuYDPUORFmEaBPQEmo8dsh4eUjnlyI/edit#heading=h.rmxeybu31e0):
不過與此同時(shí),PyTorch 2.0 最近在完善 symbolic shape 的支持;functorch 里也加入了少量 control flow operator。這算是利好 whole-graph capture 的消息。
3總結(jié)
總的來說,由于?TorchDynamo 在 bytecode 層面做文章,能做到一些其他方案做不到的事情。它的優(yōu)點(diǎn)主要為 partial graph capture 服務(wù): 讓用戶的 Python 模型代碼在 0 修改的情況下就能 capture 并獲得加速。這體現(xiàn)了 PyTorch 對(duì)于 "Python first" 哲學(xué)的執(zhí)念。這種執(zhí)著是否有必要,見仁見智。
TorchDynamo 的主要優(yōu)勢(shì)來自對(duì) bytecode 的讀寫。JIT scripting compiler 的失敗表明在 source code level 做不了太多事,TorchDynamo 能在 bytecode level 做事情確實(shí)很巧妙。不過,要完整的復(fù)刻 CPython bytecode interpreter,它的工作量、維護(hù)難度(以及出 bug 的概率)都是不小的。
另外,TorchDynamo 對(duì) whole-graph capture 沒有很大的幫助。?對(duì)于復(fù)雜的模型,用戶該做的改寫還是得做。不過我估計(jì) 2.0 至少能對(duì)「用戶該做什么」有個(gè)清晰的說法。
當(dāng)然,最后 PT2 到底能不能把 compiler 做好,還有很多其他因素:IR 怎么設(shè)計(jì),何時(shí)specialize/recompile,各種 backend 不同的特性等等。比如 TorchDynamo 和 LazyTensor 使用的 IR 其實(shí)也不一樣。但是本文只討論 graph capture,其他問題就不提了。 (本文經(jīng)授權(quán)后發(fā)布。原文:https://www.zhihu.com/question/570220953/answer/2798657470) ?
其他人都在看
李白:你的模型權(quán)重很不錯(cuò),可惜被我沒收了
單RTX 3090訓(xùn)練YOLOv5s,時(shí)間減少11小時(shí)
OpenAI掌門Sam Altman:AI下一個(gè)發(fā)展階段
32篇年度最佳AI論文;Python編譯器Codon開源
對(duì)比四大深度學(xué)習(xí)框架,我發(fā)現(xiàn)都關(guān)注兩大問題
比快更快,開源Stable Diffusion刷新作圖速度
OneEmbedding:單卡訓(xùn)練TB級(jí)推薦模型不是夢(mèng)