淺入淺出 TPU (上集):為何 TPU 運算那麼便宜?
從 AllReduce 到 3D Torus 解析
訓練大型 AI model 時,單張晶片通常不夠,因為資料量跟模型參數太大,單顆 TPU memory 根本裝不下,因此要把多顆 TPU 組成 cluster 一起運算。
但 cluster 不是把很多機器放在一起而已。
當多台機器運行時,首先要解決的問題是:每台機器算完自己的部分後,結果要怎麼同步?
如果同步成本太高,晶片算得再快也沒用,因為大量時間會卡在 network communication。
而 TPU cluster 便宜的關鍵之一,是不單靠昂貴網路硬撐 bandwidth,而是透過 collective algorithm 與 network topology 設計,降低同步資料的成本。
那麼 TPU cluster 到底在同步什麼?
分散式運算的基本想法很簡單:把一坨資料切成很多小分區,分給不同 machine 運算。
例如 AI training 中,可把 input 切成多份 batch shard,讓不同 TPU 處理不同 batch shard,每台 machine 算完自己的部分後,要把結果用某種方式合併,否則每台 machine 只知道自己的局部結果,整個 cluster 沒有共同狀態。
而說到分散式計算的資料同步,最直覺會想到分散式資料計算的 MapReduce。
什麼是 MapReduce?AI 需要的是 MapReduce 嗎?
先用 SQL 建立直覺。
假設我們要算:
select uid, sum(amount) from orders group by uid;當 orders 量很大時,MapReduce 會先把資料 partition 分到不同 machine 做局部計算,此時每個 machine 會有不同 uid 的局部 sum,接著進入 shuffle 。
shuffle 的目的是讓相同 uid 的資料被送到同一台 reducer,最後每台 reducer 算出自己負責 key 的結果。
例如:
sum | uid | machine
10 | 1 | 1
20 | 2 | 2
30 | 3 | 3也就是說,MapReduce 中不同 machine (reducer) 只持有自己負責 key 的完整結果。
但 AI 不一樣,他需要 AllReduce,也就是 reduce 完後,還需要 broadcast 結果給大家讓每台 machine 都有所有 key 的聚合結果。
也就是:
sum | uid | machine
10,20,30 | 1,2,3 | 1
10,20,30 | 1,2,3 | 2
10,20,30 | 1,2,3 | 3因此可以先把 AllReduce 理解成:
Reduce + broadcast result to all machines
但 AI training 為什麼需要 All Reduce?
AI model 可以先簡化成:
y = w1*x1+w2*x2+...+wN*xN = wTxx 是 input, w 是模型權重, y 是模型輸出,訓練目標是調整 w,讓模型輸出的 y 越來越接近正確答案。
調整方式大致是:
1. 輸入 x,算出預測值 y
2. 用 loss function 算出預測值跟正確答案的差距
3. 透過微分計算每個 weight 對 loss 的影響程度,也就是 gradient
4. 用 gradient 更新 weight,讓下一次 loss 變小
note: 精準地說,gradient 是某個 weight 改變時,loss 變動的程度 (變大多少或變小多少)。
當訓練的 input data 很多,會把 input 切成多份 batch shard 分給不同 TPU,每台 TPU 都有相同的 AI model,但處理不同 batch shard,因此會算出不同的 local gradient vector。
例如三台 TPU 各自算出:
machine | w1_gradient | w2_gradient | w3_gradient
a | 1 | 2 | 3
b | 3 | 2 | 1
c | 5 | 3 | 1如果每台 TPU 直接用自己的 local gradient 更新 model,三台 TPU 的 model 很快就會變得不一樣。
所以更新 model 前,需要把所有 local gradient 聚合成一份 global gradient vector,並讓每台 TPU 都拿到相同結果。
類似
select weight_id, avg(gradient_value)
from gradient_vectors
group by weight_id;後 broadcast result to all machines。
也就是:
machine | w1_gradient_avg | w2_gradient_avg | w3_gradient_avg
a | 3 | 2.33 | 1.67
b | 3 | 2.33 | 1.67
c | 3 | 2.33 | 1.67隨後每台 machine 用相同的 gradient vector 更新各自的 local model 確保大家 model 一致。
但如果每輪 training 都要同步 gradient,最暴力的做法可行嗎?
暴力 All-to-All 為什麼太貴?
實現 All Reduce 最直覺作法是 all-to-all broadcast:
每台 machine 都把自己完整 gradient vector 傳給其他所有 machine,然後在把收到的 vector 加總或平均。
假設有 n 台 machine,每份 gradient vector 大小是 D。
那每台 machine 大約要接收 (n - 1) * D 資料量,如果 TPU 數量變大,n 會變大,如果 model 變大,gradient vector 的大小 D 也會變大。
因此 all-to-all 很快就會讓 network bandwidth 爆炸,成本變高,尤其每輪 training 都要做一次時,network communication 會變成主要瓶頸。
為了解決這個問題,需要換一種 AllReduce algorithm - Ring AllReduce。
什麼是 Ring AllReduce?
Ring AllReduce 是一種 collective communication algorithm。
collective communication 指的是多個 device 同時進行同一類通訊操作。
Ring AllReduce 會拆成兩步:ReduceScatter + AllGather
假設有三台 machine 在物理層面用網線排成一個 ring:
a -> b -> c -> a三台 machine 各自算出自己的 gradient vector:
w1 w2 w3
a: [1, 2, 3]
b: [3, 2, 1]
c: [5, 3, 1]ReduceScatter 會先把 gradient vector 切成 shard (e.g [w1, w2, w3] => [w1], [w2], [w3]),並定義每個 shard 最後由哪台 machine 持有 reduce 結果。
例如:
a 負責 w1
b 負責 w2
c 負責 w3接著所有 machine 根據固定的 network topology & shard 的 final reducer,決定要把哪個 shard 傳給自己的 neighbor,neighbor 收到後會把 shard 加上自身結果 (partial sum) 往下一個 neighbor 傳,具體邏輯:
a 先找出距離 n-1 的節點是 c,且 c 負責 w3 的 shard,那 a 會把他的 w3 結果傳給他的鄰居 (b)
b 收到後把 a 的 w3 加上自己的 w3 後傳給 c
最後 c 會收到 a+b 的 w3,加上自己的 w3 就能會得最終 w3 結果 (i.e 3+1+3=5)。
而 collective communication 代表 a & b & c 同時會透過上面邏輯傳 shard 給對應的鄰居 :
first communication round:
a - w3(a) -> b w3(a+b)
b - w1(b) -> c w1(b+c)
c - w2(c) -> a w3(c+a)
second communication round:
b - w3(a+b) -> c w3(a+b+c)
c - w1(b+c) -> a w1(b+c+a)
a - w2(c+a) -> b w2(c+a+b)
ReduceScatter 結束後,每台 machine 只持有自己負責的 reduced shard:
a: w1(b+c+a) = 1 + 3 + 5 = 9
b: w2(c+a+b) = 2 + 2 + 3 = 7
c: w3(b+c+a) = 3 + 1 + 1 = 5接著進入 AllGather,把每台 machine 持有的 reduced shard 再沿著 ring 傳給其他人:
first communication round:
a - w1 -> b
b - w2 -> c
c - w3 -> a
second communication round:
a - w3 -> b
b - w1 -> c
c - w2 -> a最後每台 machine 都拿到完整結果:
a: [9, 7, 5]
b: [9, 7, 5]
c: [9, 7, 5]Ring AllReduce 的好處是,每台 machine 每輪只傳 gradient vector 的一個 shard,而不是完整 vector。
對 n 台 machine、vector 大小 D 來說,Ring AllReduce 的 per-machine traffic 約:
2 * (n-1) * (D/N)
當 n 很大時,這大約接近 2D。
相比 all-to-all 的 (n - 1) * D,Ring AllReduce 大幅降低了每台 machine 要處理的資料量,但 Ring AllReduce 還是有缺點的:
若所有 machine 用 single global ring 串連,每個 machine 只有前後兩個 neighbor ,導致資料只能沿著固定 ring 傳給下一個 neighbor,machine 數量越多,communication rounds 越多。
換句話說,Ring AllReduce 降低了傳輸量,但如果 topology 只有一條長 ring,communication round 數還是會限制整體效率。
那麼 TPU 怎麼解決 communication round 數量問題?
不用 single global ring,TPU 採用的是 2D / 3D torus。
例如 3D torus 讓每個 machine 可有三個 (x, y, z) 不同方向的鄰居,總共6個鄰居,因此每次 communication round 可往多個 neighbor 傳遞訊息。
我們先用 2D torus (x, y) 為例,觀察當 neighbor 變多時,communication round 如何變少。
假設有四台 machine,每個 machine 有 (x, y) 兩個方向的鄰居:
a —— b (X軸向)
| |
c —— d (X軸向)
(Y軸向)note: 上面很像 2D mesh 但其實是 2D torus,差別在 torus 是一個 ring,也就是說上面的 a - b 其實是 a - b - a。
在 single global ring 中,每台 machine 每輪通常只能沿固定方向把資料推給下一個 neighbor。
但在 2D torus 中,machine 可以沿 x & y 兩個維度傳資料,假設 gradient vector 長度是四 且 ReduceScatter 過程 a 負責 w1, b 負責 w2,c 負責 w3, d 負責 w4。
那麼 a 要收集完整的 w1 需要:
first communication round (X軸並行)
b -> w1(b) -> a (a+b)
d -> w1(d) -> c (c+d)
second communication round (Y軸並行)
c -> w1(c+d) -> a (a+b+d+c)兩個 communication round ,a 就能收集完 w1,而 single global ring 需要三個 。
note: 上面只示範 w1 的路徑;實際 ReduceScatter 中,w2、w3、w4 也會同時沿其他 link 傳遞 partial sum。
而 3D torus 提供更多維度的鄰居,collective algorithm 可用更多 parallel path 傳遞資料,在大規模 cluster 中,這能有效減少 communication round。
因此 TPU cluster 變便宜的關鍵是不依靠昂貴的網路設備硬撐 bandwidth,而是透過:
Ring AllReduce 降低每台 machine 要傳的資料量
2D / 3D torus 提供更多維度鄰居,提高 parallel 傳輸量
但 topology 只回答 TPU chip 之間怎麼連,真正在 cluster 上運算還需要 compiler & runtime。
如何在 TPU Cluster 網路拓樸中運算資料?
這裡可用 Hive 類比,早期要用 MapReduce 寫分散式運算,要自己寫 mapper 和 reducer。MapReduce runtime 會把不同 partition 的資料分散到不同 machine,執行你的 mapper / reducer 邏輯。
但如果要寫複雜查詢,例如:
select uid, sum(amount)
from orders
join products on orders.pid = products.id
where products.name = 'XXX'
group by uid;直接手寫 mapper 和 reducer 會很麻煩,因此 Hive 提供一個更高層的介面:使用者寫 SQL,Hive 把 SQL compile 成 MapReduce stages。
TPU 訓練也有類似的抽象層,使用者在 TensorFlow 或 JAX 中撰寫的是高階 Tensor Computation,例如 Forward Pass、Loss Computation、Backward Pass 與 Optimizer Update,但 TPU 硬體本身只支援低階的 Tensor Operations,例如 Matrix Multiplication、Reduce、AllReduce、Load/Store Tensor 等。
因此需要 XLA 作為 Compiler,將高階 Tensor Computation Graph 轉換成 HLO(High Level Optimizer IR),並最佳化並產生 TPU Executable Program,與 Hive 類似,XLA 不只是做語法轉換,還會進行 Sharding、Operator Fusion、Device Placement、Collective Communication Planning 等最佳化。
例如,在分散式 TPU 訓練中,XLA 會根據 device mesh 和 TPU Network Topology 資訊安排 Collective Operations(ReduceScatter、AllGather)執行方式。
因此 TPU 的設計本質上是 Hardware、Network、Compiler 三者共同設計(Co-design)的系統,而不只是單純的加速器硬體。
下集預告:為何 TPU 運算那麼快?ICI 與 TPU Pod 內部資料解析
這篇先回答「為何 TPU 運算那麼便宜」。
重點不是單顆 TPU chip 神奇地省錢,而是 TPU cluster 透過 Ring AllReduce、2D / 3D torus topology,以及 XLA / runtime 的 collective schedule,降低每輪 training 的網路同步成本。
但這篇只講到 topology 和 computation runtime,下一個問題是:資料在 TPU Pod 裡交換時,要如何降低 Pod 收到資料的延遲?


