TokenMixer-Large: 工业推荐系统中大规模排序模型的扩展¶
1. 研究背景与动机¶
推荐系统中的深度学习排序模型(DLRMs)已经成为工业界的核心组件。受 LLM scaling laws 的启发,研究者开始探索 DLRMs 是否也能通过增加参数量和计算量获得持续的性能提升。早期工作如 Wukong、HiFormer、DHEN 尝试设计更精细的模型结构并进行扩展,但常常忽略硬件协同设计,导致次优性能。
当前工业界 SOTA 的排序模型结构是 TokenMixer——一种高度简化的 Transformer 变体,用轻量级 token mixing 操作替换了 self-attention。RankMixer [40] 将 TokenMixer 作为 backbone 并验证了其有效性,通过硬件感知协同设计显著提升了 Model FLOPs Utilization (MFU)。
然而,作者识别出 RankMixer(TokenMixer) 存在以下关键局限:
- 次优的残差设计(Sub-optimal Residual Design):RankMixer 的 mixing 操作要求新 token 数 $T'$ 与原始 token 数 $T$ 匹配,且在 add & norm 操作中直接将 mixing 前后的 token 相加,导致语义错位(semantic misalignment)
- 臃肿的模型架构(Impure Model Architecture):由于历史迭代,很多已部署模型保留了大量碎片化算子(如 LHUC、DCNv2),这些 memory-bound 算子拖低了整体 MFU
- 深层模型梯度更新不足(Insufficient Gradient Updates in Deep Models):TokenMixer 通常配置为浅层(如 2 层),缺乏针对深层网络的设计
- 不完善的 MoE 稀疏化(Inadequate MoE Sparsification):RankMixer 的 Sparse MoE 采用"Dense Train, Sparse Infer"范式,且 ReLU-MoE 存在激活动态性问题
- 有限的扩展探索(Limited Scaling Exploration):RankMixer 参数规模仅推到约 1B
2. 方法论¶
2.1 整体架构¶
TokenMixer-Large 框架由三部分组成:
- Tokenization:将高维稀疏 one-hot 特征转化为稠密 embedding,再对齐为维度一致的语义 token
- Token Mixing & Channel Mixing:采用"Mixing and Reverting"范式解决 RankMixer(TokenMixer) 中的维度不匹配问题
- Sparse-Pertoken MoE:升级版的 Pertoken-FFN/relu-MoE
最终使用 mean pooling 聚合输出 token 用于预测。
2.2 Tokenization¶
2.2.1 Semantic Group-wise Tokenizer¶
工业推荐系统包含用户特征、物品特征、序列特征(短期 DIN、长期 SIM、超长期 LONGER)和交叉特征等。第一步将所有特征投影到 embedding 空间:
$$e_i = \text{Embedding}(F_i, d_i) \in \mathbb{R}^{d_i} \tag{1}$$
由于 TokenMixer-Large 是高度并行架构,需要将变长 embedding 转换为维度对齐的 token。做法是按语义分组 $\{G_1, ..., G_{T-1}\}$,每组使用不同的 DNN 映射以保留异质性:
$$X_i = \text{MLP}_i(\text{concat}[e_l, ..., e_m]), \quad e_l, ..., e_m \in G_i \tag{2}$$
2.2.2 Global Token¶
引入一个 global token(类似 BERT 的 [CLS]),用于封装全局信息并传播:
$$X_G = \text{MLP}_g(\text{concat}[G_1, ..., G_{T-1}]) \tag{3}$$
最终输入:
$$\mathbf{X} = \text{concat}[X_G, X_0, ..., X_{T-1}] \in \mathbb{R}^{T \times D} \tag{4}$$
2.3 TokenMixer-Large Block¶
每个 block 包含三个组件:mixing module、pertoken SwiGLU、normalization。
2.3.1 Mixing & Reverting¶
RankMixer(TokenMixer) block 的公式为:
$$[\ldots, [\mathbf{x}_t^{(0)}, \ldots, \mathbf{x}_t^{(H)}], \ldots] = \text{split}(\mathbf{X}) \in \mathbb{R}^{T \times H \times (D/H)} \tag{5}$$
$$\mathbf{H}_h = \text{concat}[\mathbf{x}_1^{(h)}, \ldots, \mathbf{x}_T^{(h)}] \in \mathbb{R}^{T \times D/H} \tag{6}$$
$$\mathbf{H} = \text{concat}[\mathbf{H}_1, \ldots, \mathbf{H}_h] \in \mathbb{R}^{H \times (T \cdot D/H)} \tag{7}$$
$$\mathbf{H}^{\text{next}} = \text{Norm}(\text{pSwiGLU}(\mathbf{H}) + \mathbf{H}) \in \mathbb{R}^{H \times (T \cdot D/H)} \tag{8}$$
问题在于:除非每层的 $H$ 保持与 $T$ 相等,残差连接无法跨层顺畅传播。
TokenMixer-Large 设计了对称的两层结构——第一层负责 mixing,第二层负责 reverting:
$$[\ldots, [\mathbf{x}_t^{(1)}, \mathbf{x}_t^{(2)}, \ldots, \mathbf{x}_t^{(H)}], \ldots] = \text{split}(\mathbf{X}) \in \mathbb{R}^{T \times H \times (D/H)} \tag{9}$$
$$\mathbf{H}_h = \text{concat}[\mathbf{x}_1^{(h)}, \mathbf{x}_2^{(h)}, \ldots, \mathbf{x}_T^{(h)}] \in \mathbb{R}^{(T \cdot D/H)} \tag{10}$$
$$\mathbf{H} = \text{concat}[\mathbf{H}_1, \mathbf{H}_2, \ldots, \mathbf{H}_H] \in \mathbb{R}^{H \times (T \cdot D/H)} \tag{11}$$
$$\mathbf{H}^{\text{next}} = \text{Norm}(\text{pSwiGLU}(\mathbf{H}) + \mathbf{H}) \in \mathbb{R}^{H \times (T \cdot D/H)} \tag{12}$$
$$[\ldots, [\mathbf{x}_t'^{(h)}, \mathbf{x}_t'^{(h)}, \ldots, \mathbf{x}_T'^{(h)}], \ldots] = \text{split}(\mathbf{H}^{\text{next}}) \in \mathbb{R}^{T \times H \times (D/H)} \tag{13}$$
$$\mathbf{X}_t^{\text{revert}} = \text{concat}[\mathbf{x}_t'^{(1)}, \mathbf{x}_t'^{(2)}, \ldots, \mathbf{x}_t'^{(H)}] \in \mathbb{R}^D \tag{14}$$
$$\mathbf{X}^{\text{revert}} = \text{concat}[\mathbf{X}_1^{\text{revert}}, \mathbf{X}_2^{\text{revert}}, \ldots, \mathbf{X}_T^{\text{revert}}] \in \mathbb{R}^{T \times D} \tag{15}$$
$$\mathbf{X}^{\text{next}} = \text{Norm}(\text{pSwiGLU}(\mathbf{X}^{\text{revert}}) + \mathbf{X}) \in \mathbb{R}^{T \times D} \tag{16}$$
这种"mixing-reverting"设计确保输入输出维度一致,建立从初始输入到深层的连续信号通路,实现稳定的残差连接。
2.3.2 Pertoken SwiGLU¶
将 RankMixer 中的 Pertoken FFN 升级为 Pertoken SwiGLU:
$$\text{pSwiGLU}(\cdot) = \text{FC}_{\text{down}}(\text{Swish}(\text{FC}_{\text{gate}}(\cdot)) \odot \text{FC}_{\text{up}}(\cdot)) \tag{17}$$
其中:
$$\text{FC}_i(\mathbf{x}) = \mathbf{W}_i^t \mathbf{x}_t + \mathbf{b}_i^t, \quad i \in \{\text{up}, \text{gate}, \text{down}\} \tag{18}$$
$\{\mathbf{W}_{\text{up}}^t, \mathbf{W}_{\text{gate}}^t\} \in \mathbb{R}^{D \times nD}$,$\mathbf{W}_{\text{down}}^t \in \mathbb{R}^{nD \times D}$,$n$ 是超参数控制 hidden dimension。
2.3.3 Residual & Normalization¶
- 用 RMSNorm 替换 LayerNorm(去掉 mean-centering),减少计算开销
- 采用 Pre-Norm 替换原始的 Post-Norm,避免数值爆炸和梯度不稳定
2.3.4 Inter-Residual & Auxiliary Loss¶
除了标准残差连接,还引入 inter-residual connections(通常每隔 2-3 层添加一次),用于:
- 解决梯度消失问题
- 加速早期层参数收敛
- 缓解深层网络中梯度的逐渐衰减
注意:inter-residual 不推荐用于最后一层,因为最后一层的主要功能是提炼高层抽象特征,引入过多低层信息可能损害最终任务性能。
同时,结合低层和高层的 logits 计算联合 loss,形成轻量级 auxiliary loss,使低层学会"估计高层特征的偏差",增强特征表示能力。
2.4 Sparse-Pertoken MoE¶
2.4.1 First Enlarge, Then Sparse¶
采用"先扩大,再稀疏"的迭代策略:先放大模型获得效果,再拆分 pertoken SwiGLU 并稀疏激活获得效率。
Sparse-Pertoken MoE (S-P MoE) 公式:
$$\text{S-P MoE}(\cdot) = \sum_{j=1}^{k} g_j(\cdot) \cdot \text{Expert}_j(\cdot), \quad \textbf{if j is chosen} \tag{19}$$
$$\text{Expert}_j(\cdot) = \text{FC}_{\text{down}, j}(\text{Swish}(\text{FC}_{\text{gate}, j}(\cdot)) \odot \text{FC}_{\text{up}, j}(\cdot)) \tag{19}$$
其中 $\{\mathbf{W}_{\text{up}}^t, \mathbf{W}_{\text{gate}}^t\} \in \mathbb{R}^{D \times nD/E}$,$E$ 是 expert 数量,$k$ 是 router 的 top-k 数。
2.4.2 Shared Expert¶
每个 token 分配一个 shared expert(不是全局共享,而是 per-token 共享):
$$\text{S-P MoE}(\cdot) = \sum_{i=1}^{k-1} g_i(\cdot) \cdot \text{Expert}_i(\cdot) + \text{SharedExpert}_i(\cdot) \tag{20}$$
2.4.3 Gate Value Scaling¶
由于 softmax 使路由权重求和为 1,导致 SwiGLU 梯度更新不足。添加常数缩放因子 $\alpha$:
$$\text{S-P MoE}(\cdot) = \alpha \cdot \sum_{i=1}^{k-1} g_i(\cdot) \cdot \text{Expert}_i(\cdot) + \text{SharedExpert}_i(\cdot) \tag{21}$$
$\alpha$ 与稀疏比例反比关系:1:2 稀疏时 $\alpha=2$ 最优,1:4 稀疏时 $\alpha=4$ 最优。
2.4.4 Down-Matrix Small Init¶
受 Rezero 启发,降低 SwiGLU 中 $\mathbf{W}_{\text{down}}$ 的初始化方差(xavier_uniform 的标准差超参从默认 1 降到 0.01),使 $F(\mathbf{x}) + \mathbf{x}$ 在训练初期近似恒等映射,提升深层模型训练稳定性。
2.5 训练/服务优化¶
2.5.1 高性能自定义算子¶
- MoEPermute:将输入从 batch-first 重排为 expert-first
- MoEGroupedSwiGLU / MoEGroupedGemm:用单个 kernel 完成所有 expert FFN 计算
- MoEUnpermute:计算激活 expert 的加权求和
Table 1 显示 MoEGroupedFFN 在训练中占 89.18% 时间(Compute Bound),在推理中占 98.35%(Memory Bound)。
2.5.2 FP8 量化¶
采用 FP8 E4M3 post-training 量化,推理中保持 bfloat16 精度训练。FP8 serving 提供 1.7x 加速且不损失模型精度。
2.5.3 Token Parallel¶
为解决分布式多设备环境下的通信瓶颈,提出 Token Parallel——按 token 维度分片。相比朴素的模型并行(每步需要 all2all 切换 sharding layout),Token Parallel 将通信开销从 $4L$ 降到 $2L+1$。
实验结果:4-way token parallelism(global batch size 320)实现 29.2% 吞吐量提升,叠加通信-计算 overlap 后提升至 96.6%。
3. 实验¶
3.1 实验设置¶
数据集:字节跳动抖音主 feed 的电商场景真实数据集,包含 500+ 特征(数值型、ID型、交叉型、序列型),覆盖数亿用户,经采样后约 4 亿条/天,跨两年收集。离线实验还包含抖音广告和抖音直播数据,日量分别达 3 亿和 170 亿条。
评估指标:AUC(User-Level AUC)和 UAUC 作为主要指标;dense parameter count、FLOPs、MFU 作为效率指标。
训练环境:64 GPU 混合分布式训练(电商 256 GPU用于 Feed-Ads 和 Live-Streaming),稀疏参数异步更新,稠密参数同步更新,Adagrad 优化器(dense lr=0.01, sparse lr=0.05)。
Baselines:DLRM-MLP、DCNv2、DHEN、AutoInt、HiFormer、Wukong、Group Transformer、FAT、RankMixer(TokenMixer)。
3.2 与 SOTA 方法对比(Table 2)¶
~500M 参数模型在电商场景的性能与效率对比:
| Model | CTCVR ΔAUC↑ | Params | FLOPs/Batch |
|---|---|---|---|
| DLRM-MLP-500M | – | 499 M | 125.1 T |
| HiFormer | +0.44% | 570 M | 28.8 T |
| DCNv2 | +0.49% | 502 M | 125.8 T |
| DHEN | +0.63% | 415 M | 103.4 T |
| AutoInt | +0.75% | 549 M | 138.6 T |
| Wukong | +0.76% | 513 M | 4.6 T |
| Group Transformer | +0.81% | 550 M | 4.5 T |
| FAT | +0.82% | 551 M | 4.59 T |
| RankMixer(TokenMixer) | +0.84% | 567 M | 4.6 T |
| TokenMixer-Large 500M | +0.94% | 501 M | 4.2 T |
| TokenMixer-Large 4B | +1.14% | 4.6 B | 29.8 T |
| TokenMixer-Large 7B | +1.20% | 7.6 B | 49.0 T |
| TokenMixer-Large 4B SP-MoE | +1.14% | 2.3 B in 4.6 B | 15.1 T |
关键结论:
- TokenMixer-Large 在所有 SOTA 模型中表现最优,500M 规模下 CTCVR AUC 比 MLP-500M baseline 提升 +1.14%
- Wukong、Group Transformer、RankMixer 等采用 pertoken 网络设计的模型 FLOPs 显著更低(batch size 2048)
- Sparse-Pertoken MoE(4B SP-MoE)仅激活一半参数即可匹配 dense 4B 模型性能,大幅提升训练/推理 ROI
3.3 与 RankMixer(TokenMixer) 的详细对比(Table 3)¶
定义三个设计维度:
- Standard Residual (SR):block 间是否存在标准残差连接
- Original Token Residual (OTR):原始 token 的语义信息是否在残差 $F(x') + x$ 中得到保留和传播
- Token Semantic Alignment in Residual (TSA):残差操作前后 token 语义是否一致
| Model Version | SR | OTR | TSA | AUC↑ | Params | FLOPs |
|---|---|---|---|---|---|---|
| Group Transformer | ✓ | ✓ | ✓ | – | 500M | 4.5T |
| RankMixer w/o SR&OTR | ✗ | ✗ | ✗ | -0.20% | 510M | 4.2T |
| RankMixer w/o OTR | ✓ | ✗ | ✗ | -0.13% | 510M | 4.2T |
| RankMixer | ✓ | ✓ | ✗ | +0.03% | 567M | 4.6T |
| TokenMixer-Large | ✓ | ✓ | ✓ | +0.13% | 500M | 4.2T |
TokenMixer-Large 满足所有三个设计属性,取得最佳性能。
3.4 Scaling Laws¶
3.4.1 SOTA 模型的 Scaling Laws(Figure 4 & 5)¶
在 Feed Ads 15B、E-Commerce 7B、Live Streaming 4B 三个场景中验证了 dense scaling law,AUC 与参数量/FLOPs 的对数呈线性关系。TokenMixer-Large 展现出比 RankMixer 更陡的改进斜率。
3.4.2 不同场景的 Scaling Laws¶
通过增加维度 $D$、深度 $L$ 和 SwiGLU hidden expansion factor $N$ 进行扩展。两个关键发现:
- 均衡扩展各维度效果更好:单独扩展 width/depth/scaling factor 会逐渐遇到瓶颈,超过 1B 后需要均衡增长
- 更大模型需要更多数据才能收敛:30M→90M 仅需 14 天数据,500M→2B 需要 60 天
Table 4: 抖音直播模型收敛对比
| Param* | Convergence Day | ΔUAUC |
|---|---|---|
| 30m | – | – |
| 90m↑ | 14d | +0.94% |
| 500m↑ | 30d | +0.62% |
| 2.3B↑ | 30d | +0.41% |
| 2.3B↑(60d) | 60d | +0.70% |
*每行以前一行为 baseline
最终在 Feed Ads、E-Commerce、Live Streaming 场景分别扩展至 15B、7B、4B(离线),线上分别为 7B、4B、2B。
3.5 消融实验¶
3.5.1 TokenMixer-Large Block 消融(Table 5)¶
| Setting | ΔAUC |
|---|---|
| w/o Global Token | -0.02% |
| w/o Mixing & Reverting | -0.27% |
| w/o Residual | -0.15% |
| w/o Internal Residual & AuxLoss | -0.04% |
| Pertoken SwiGLU → SwiGLU | -0.21% |
| Pertoken SwiGLU → Pertoken FFN | -0.10% |
结论:Mixing & Reverting 和 Pertoken SwiGLU 对整体性能影响最大。
3.5.2 Sparse-Pertoken MoE 消融(Table 6)¶
| Setting | ΔAUC | ΔParams | ΔFLOPs |
|---|---|---|---|
| w/o Shared Expert | -0.02% | 0.0% | 0.0% |
| w/o Gate Value Scaling | -0.03% | 0.0% | 0.0% |
| w/o Down-Matrix Small Init | -0.03% | 0.0% | 0.0% |
| Sparse-Pertoken MoE → Sparse MoE | -0.10% | 0.0% | 0.0% |
结论:所有组件都正向贡献,且不引入额外参数或 FLOPs。Sparse-Pertoken MoE 对比标准 Sparse MoE 有显著优势(-0.10%),因为 per-token 的 expert 不共享,相当于给标准 MoE 加了路由先验,避免了 router 早期学习困难。
Pure Model Design¶
随着 TokenMixer-Large 参数量增加,碎片化算子(DCN、LHUC 等)的增益被模型自身吸收(Table 9):
| Params | DCN Gain |
|---|---|
| 150M | +0.09% |
| 500M | +0.04% |
| 700M | +0.00% |
Normalization(Table 10)¶
| Setting | ΔAUC |
|---|---|
| Pre-Norm | – |
| Post-Norm | +0.01% → NaN |
| Sandwich-Norm | -0.03% |
Pre-Norm 虽效果略弱于 Post-Norm,但能确保稳定训练。
Mixing Strategy(Table 11)¶
| Setting | ΔAUC |
|---|---|
| Vertical division | – |
| Diagonal division | +0.00% |
| Random division | +0.00% |
| Vertical division (half of raw tokens) | -0.08% |
不同的 split-concat 策略不影响性能,关键是每个新 mixed token 包含所有原始 token 信息。
Model Sparsity 与 α 设置(Table 12)¶
| Version | Sparsity | α Setting | AUC Performance |
|---|---|---|---|
| 4B dense | – | – | – |
| 2B in 4B MoE | 1:2 | 1 | -0.04% |
| 2 | -0.00% | ||
| 3 | -0.02% | ||
| 4 | -0.05% | ||
| 1B in 4B MoE | 1:4 | 1 | -0.07% |
| 2 | -0.05% | ||
| 4 | -0.03% | ||
| 6 | -0.04% | ||
| 8 | -0.05% |
$\alpha$ 与稀疏比例的倒数成正比是最优选择。
Down-Matrix Small Init(Table 14)¶
| Version | Stddev Init Value [FC_up, FC_gate, FC_down] | ΔAUC |
|---|---|---|
| Base | [1, 1, 1] | – |
| SmallInit-001 | [1, 1, 0.01] | +0.03% |
| SmallInit-01 | [1, 1, 0.1] | +0.02% |
| SmallInit-001-All | [0.01, 0.01, 0.01] | -0.10% |
| SmallInit-001-Reverse | [0.01, 0.01, 1] | -0.01% |
仅对 $\text{FC}_{\text{down}}$ 应用 small init 效果最佳。
3.6 在线性能(Table 7)¶
| Feed Ads | E-Commerce | Live Streaming | |||||
|---|---|---|---|---|---|---|---|
| ΔAUC | ADSS | ΔAUC | Order | GMV | ΔUAUC | Pay | |
| Lift↑ | 0.35% | 2.0% | 0.51% | 1.66% | 2.98% | 0.7% | 1.4% |
- 抖音直播电商:订单量增长 +1.66%,人均预览付款 GMV 提升 +2.98%
- 抖音信息流广告:ADSS 提升 +2.0%
- 抖音直播(非电商):收入增长 +1.4%
在线 baselines 分别为 RankMixer-1B(广告)、RankMixer(电商)、RankMixer-500M(直播),扩展至 TokenMixer-Large 7B、4B、2B。
4. 总结¶
TokenMixer-Large 是 TokenMixer 的系统性升级,通过以下核心创新解决了原有设计在深层扩展中的瓶颈:
- Mixing & Reverting 操作确保跨层维度一致和残差信号的稳定传播
- Inter-Residual 和 Auxiliary Loss 解决深层网络梯度不足问题
- Sparse-Pertoken MoE 实现"Sparse Train, Sparse Infer",配合 Gate Value Scaling 和 Down-Matrix Small Init
- Pure Model 设计哲学:去除碎片化算子,仅保留 parameterless mixing/reverting 和 GroupedGemm,实现 60% MFU
- Token Parallel 和 FP8 量化等工程优化
成功将模型扩展到 15B 参数(离线),在字节跳动多个在线场景服务数亿用户并取得显著业务收益。