The Annotated Multi-Task Ranker: An MMoE Code Example
Natural Language Processing (NLP) has an abundance of intuitively explained tutorials with code, such as Andrej Karpathy’s Neural Networks: Zero to Hero, the viral The Illustrated Transformer, The Annotated Transformer, and Umar Jamil’s YouTube series dissecting SOTA models along with its companion repo, among others.
When it comes to Search/Ads/Recommendations (“搜广推”), however, intuitive explanations accompanied by code are rare. Company engineering blogs tend to focus on high-level system design, and many papers at top conferences (e.g., KDD/RecSys/SIGIR) don’t share code. In this post, I explain the iconic Multi-gate Mixture-of-Experts (MMoE) paper (Ma et al., 2018) using the implementation in the popular DeepCTR-Torch repo, to teach myself and readers how the authors' blueprint translates into code.
The Paper (Ma et al., 2018)
A huge appeal of deep learning is its ability to optimize for multiple task objectives at once, such as clicks and conversions in search/ads/feed ranking. In traditional machine learning, you would typically build one model per task. That makes the system hard to maintain, because each model needs its own data, training, and serving pipelines, and it misses out on the opportunity for transfer learning between tasks.
In early designs, all tasks shared the same backbone, which fed into task-specific towers (e.g., Caruana, 1993, 1998). This Shared-Bottom architecture is simple and intuitive. Its drawback is that when tasks are weakly correlated, forcing them to share one representation can hurt model performance.
As a solution, we can replace the shared bottom with a group of bottom networks (“experts”) whose outputs are combined by learned gates. This lets the model explicitly learn relationships between tasks and how each task uses the shared representations. The idea builds on the Mixture-of-Experts (MoE) architecture (e.g., Jacobs et al., 1991; Eigen et al., 2013; Shazeer et al., 2017).
In the original Mixture of Experts (MoE) model, a “gating network” assembles expert outputs by learning each expert’s weight from input features (weights sum to 1) and returning the weighted sum of expert outputs as the output to the next layer:
$$y = \sum_{i=1}^n g(x)_i f_i(x)$$
where $g(x)_i$ is the weight of the $i$th expert and $f_i(x)$ is the output of that expert.
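To make the formula concrete, here is a minimal NumPy sketch of a single-gate MoE forward pass for one input vector. It assumes that the gate is a linear layer followed by a softmax, $g(x) = \mathrm{softmax}(W_g x)$, as in the MMoE paper. The experts (one-layer ReLU networks) and all weights are made up for illustration.

```python
import numpy as np


def softmax(z):
    # subtract the max for numerical stability; the result sums to 1
    e = np.exp(z - np.max(z))
    return e / e.sum()


def moe_forward(x, expert_weights, gate_weight):
    """Single-gate MoE: y = sum_i g(x)_i * f_i(x)."""
    # f_i(x): each expert is a one-layer ReLU network (illustrative choice)
    expert_outs = np.stack([np.maximum(W @ x, 0.0) for W in expert_weights])  # (n, d_out)
    # g(x): one weight per expert, computed from the same input
    g = softmax(gate_weight @ x)  # (n,)
    # weighted sum of expert outputs
    y = g @ expert_outs  # (d_out,)
    return y, g, expert_outs


rng = np.random.default_rng(0)
d_in, d_out, n_experts = 5, 3, 4
x = rng.normal(size=d_in)
expert_weights = [rng.normal(size=(d_out, d_in)) for _ in range(n_experts)]
gate_weight = rng.normal(size=(n_experts, d_in))

y, g, expert_outs = moe_forward(x, expert_weights, gate_weight)
```

The gate weights `g` sum to 1, so `y` is a convex combination of the expert outputs.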
The Multi-gate Mixture-of-Experts (MMoE) model has as many gating networks as there are tasks, so each gate learns its own way to weight the expert outputs for its task. In contrast, a One-gate Mixture-of-Experts (OMoE) model uses a single gating network, so all tasks combine the experts with the same weights.
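In the paper's notation, MMoE gives each task $k$ its own gate $g^k$ over the same $n$ shared experts, and the task-specific tower $h^k$ consumes the gated mixture. Here $d$ is the input feature dimension:

$$\begin{aligned}
f^k(x) &= \sum_{i=1}^{n} g^k(x)_i \, f_i(x), \\
g^k(x) &= \mathrm{softmax}\left(W_{gk}\, x\right), \quad W_{gk} \in \mathbb{R}^{n \times d}, \\
y_k &= h^k\left(f^k(x)\right).
\end{aligned}$$

OMoE corresponds to a single shared gate ($g^k = g$ for every $k$), so every tower receives the same mixture.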
In the paper's experiments, MMoE's advantage over OMoE grows as task correlation decreases, and both MoE models outperform Shared-Bottom regardless of task correlation. MMoE has since become one of the most widely adopted multi-task architectures in web-scale ranking systems.
The Data (ByteRec)
Large-scale benchmark data played a pivotal role in the resurgence of deep learning. A prominent example is ImageNet, with over 14 million images in more than 20,000 categories. On its 1,000-class ILSVRC 2012 challenge, the CNN-based AlexNet outperformed non-deep-learning models by a wide margin (top-5 error of roughly 15% versus roughly 26% for the runner-up). Unlike in computer vision, ranking benchmarks are often released by companies famous for their recommendation systems, such as Netflix (Netflix Prize) and ByteDance (Short Video Understanding Challenge).
The ByteDance data (henceforth “ByteRec”) is particularly suitable for multi-task learning, since there are two user behaviors to predict: `finish` and `like`.
ByteRec has only a handful of features, a simplification of real search/feed logs:
- Dense features: only `duration_time` (video duration in seconds)
- Sparse features: categorical features such as IDs (`uid`, `item_id`, `author_id`, `music_id`, `channel`, `device`) and locations (`user_city`, `item_city`)
- Targets: whether (`1`) or not (`0`) a user finished (`finish`) or liked (`like`) a video
The Code (DeepCTR-Torch)
For learning the architectures of deep learning ranking models, I find DeepCTR-Torch (TensorFlow version: DeepCTR) highly educational. The repo covers SOTA models spanning the last decade (e.g., DCN, DCN v2, DIN, DIEN, PNN, MMoE), although it lacks some functionality needed by production-grade rankers (e.g., hash encoding for ID features). The repo also comes with documentation.
Below, I’ll explain the MMoE architecture and how it was used in the example training script the author provided. As with The Annotated Transformer, this post aims not to create an original implementation, but to provide a line-by-line explanation of an existing one. Please find my commented code in this Colab notebook.
Load Data
For testing, we can use a toy sample with 200 rows rather than the full data. At work, I also like testing models on small datasets, so I can fail and debug fast.
```python
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import log_loss, roc_auc_score
from sklearn.preprocessing import LabelEncoder, MinMaxScaler

from deepctr_torch.models.basemodel import BaseModel
from deepctr_torch.layers import DNN, PredictionLayer
from deepctr_torch.inputs import (
    SparseFeat,
    DenseFeat,
    get_feature_names,
    combined_dnn_input,
)

# sample data with 200 rows
url = "https://raw.githubusercontent.com/shenweichen/DeepCTR-Torch/master/examples/byterec_sample.txt"

data = pd.read_csv(
    url,
    sep="\t",
    names=[
        "uid",
        "user_city",
        "item_id",
        "author_id",
        "item_city",
        "channel",
        "finish",
        "like",
        "music_id",
        "device",
        "time",
        "duration_time",
    ],
)
```
Transform Features
While deep learning models are known for automated feature engineering, we still need to encode sparse categorical features and scale dense numerical features to a limited range (e.g., [0, 1]) before feeding data to the input layer.
Specify Feature + Target Columns
For easy reference, we first specify the feature and target columns by type:
```python
sparse_features = [
    "uid",  # watcher's user id
    "user_city",  # watcher's city
    "item_id",  # video's id
    "author_id",  # author's user id
    "item_city",  # video's city
    "channel",  # author's channel
    "music_id",  # soundtrack id if the video contains music
    "device",  # user's device
]

dense_features = ["duration_time"]
target = ["finish", "like"]
```
Encode Sparse Features
For each sparse feature, we can instantiate a `LabelEncoder` to assign an integer from `0` to `n - 1` to each of the `n` unique feature values.
```python
for feat in sparse_features:
    lbe = LabelEncoder()
    data[feat] = lbe.fit_transform(data[feat])
```
This method faces the out-of-vocabulary (OOV) problem: at inference time, a feature may take a value not seen during training, such as a new author ID. A fitted `LabelEncoder` raises an error on unseen values. A common workaround is to map all of them to a single “unknown” index, but this can backfire. If unknown authors typically had low engagement in historical data, the model learns to score that index low and systematically downranks content from new authors. One solution is hash encoding: a hash function maps every feature value, seen or unseen, to one of a fixed number of buckets. A new value then “borrows” what the model has learned for its bucket, which it shares with other values, instead of inheriting a systematic bias. Hash encoding is not implemented in DeepCTR-Torch.
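As a rough illustration of hash encoding (which, again, DeepCTR-Torch does not provide), the sketch below maps raw ID strings to a fixed number of buckets. The helper name `hash_bucket` and the bucket count are hypothetical. It uses `hashlib.sha256` rather than Python's built-in `hash()`, because the built-in string hash is salted per process and would assign different buckets at training and serving time.

```python
import hashlib


def hash_bucket(value, num_buckets):
    """Hypothetical helper: map any feature value to a bucket in [0, num_buckets)."""
    # sha256 gives a hash that is stable across processes and machines
    digest = hashlib.sha256(str(value).encode("utf-8")).hexdigest()
    return int(digest, 16) % num_buckets


NUM_BUCKETS = 1000  # assumption: far fewer buckets than distinct author IDs

train_authors = ["author_17", "author_42", "author_99"]
train_ids = [hash_bucket(a, NUM_BUCKETS) for a in train_authors]

# an author never seen in training still gets a valid embedding index
new_author_id = hash_bucket("author_123456", NUM_BUCKETS)
```

With this scheme the embedding table has `NUM_BUCKETS` rows regardless of the training data. The trade-off is that IDs that collide in a bucket share one embedding.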
Scale Dense Features
For each dense feature, we can use a `MinMaxScaler` to linearly rescale its values to `[0, 1]`.
```python
mms = MinMaxScaler(feature_range=(0, 1))
data[dense_features] = mms.fit_transform(data[dense_features])
```
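Under the hood, min-max scaling is the linear map $x' = (x - x_{\min}) / (x_{\max} - x_{\min})$, where the minimum and maximum come from the data the scaler was fit on. The NumPy sketch below reproduces it with made-up durations. It also shows one consequence: because the scaler rescales rather than clips by default, values outside the fitted range map outside `[0, 1]`. Here the scaler is fit on all 200 rows before the split, so this only matters for genuinely new data.

```python
import numpy as np

# made-up video durations (seconds) used to fit the scaler
fit_values = np.array([5.0, 10.0, 15.0, 45.0])
x_min, x_max = fit_values.min(), fit_values.max()


def min_max_scale(x):
    # linear rescaling with the fitted min/max; no clipping
    return (x - x_min) / (x_max - x_min)


scaled_fit = min_max_scale(fit_values)  # spans exactly [0, 1]
scaled_new = min_max_scale(np.array([60.0]))  # longer than any fitted value
```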
If exact values don’t matter, we can discretize dense features into buckets (e.g., age $\rightarrow$ age groups) and process them as categorical (e.g., one-hot encoding).
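Here is a minimal sketch of that bucketization using `np.digitize`. The age boundaries are arbitrary examples. Each bucket index can then be one-hot encoded, or label-encoded and embedded like any other sparse feature.

```python
import numpy as np

ages = np.array([15, 22, 37, 45, 70])
# hypothetical bucket boundaries: <18, 18-29, 30-44, 45-64, 65+
boundaries = np.array([18, 30, 45, 65])
age_group = np.digitize(ages, boundaries)  # bucket index per age (left-closed bins)

# one-hot encoding: one column per bucket
one_hot = np.eye(len(boundaries) + 1)[age_group]
```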
Create Training Data
Different deep learning libraries expect input data in different formats. DeepCTR-Torch uses named tuples such as `SparseFeat` and `DenseFeat` to store feature metadata such as names, data types, and dimensions. Their definitions are in the `deepctr_torch.inputs` module.
```python
# columns with sparse features
# vocabulary_size = max label + 1, since LabelEncoder maps values to 0..n-1
sparse_feature_columns = [
    SparseFeat(feat, vocabulary_size=data[feat].max() + 1, embedding_dim=4)
    for feat in sparse_features
]

# columns with dense features (each dense feature is a scalar: dimension 1)
dense_feature_columns = [
    DenseFeat(
        feat,
        1,
    )
    for feat in dense_features
]

# columns with all features
fixlen_feature_columns = sparse_feature_columns + dense_feature_columns
dnn_feature_columns = fixlen_feature_columns
linear_feature_columns = fixlen_feature_columns
```
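To see what this metadata buys us, here is a simplified stand-in for the two named tuples. The names `ToySparseFeat`/`ToyDenseFeat` and their fields are hypothetical, not DeepCTR-Torch's actual definitions. A small helper uses them to compute the width of the concatenated DNN input: one `embedding_dim` per sparse feature plus one `dimension` per dense feature. We assume this mirrors what `compute_input_dim` in `BaseModel` does. With 8 sparse features of `embedding_dim=4` and one dense feature, each instance becomes a 33-dimensional vector.

```python
from collections import namedtuple

# hypothetical, simplified stand-ins for DeepCTR-Torch's SparseFeat / DenseFeat
ToySparseFeat = namedtuple("ToySparseFeat", ["name", "vocabulary_size", "embedding_dim"])
ToyDenseFeat = namedtuple("ToyDenseFeat", ["name", "dimension"])

sparse_names = ["uid", "user_city", "item_id", "author_id",
                "item_city", "channel", "music_id", "device"]
# vocabulary_size=100 is a placeholder; it does not affect the input width
columns = [ToySparseFeat(n, vocabulary_size=100, embedding_dim=4) for n in sparse_names]
columns.append(ToyDenseFeat("duration_time", dimension=1))


def toy_input_dim(feature_columns):
    # each sparse feature contributes its embedding width, each dense feature its dimension
    sparse_dim = sum(f.embedding_dim for f in feature_columns if isinstance(f, ToySparseFeat))
    dense_dim = sum(f.dimension for f in feature_columns if isinstance(f, ToyDenseFeat))
    return sparse_dim + dense_dim


input_dim = toy_input_dim(columns)
```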
After a row-wise train-test split (the first 80% of rows for training, the rest for testing), we store the two datasets in two dictionaries. The keys are feature names, and the values are the corresponding columns of feature values.
```python
# split data by rows
split_boundary = int(data.shape[0] * 0.8)
train, test = data[:split_boundary], data[split_boundary:]

# get feature names
feature_names = get_feature_names(linear_feature_columns + dnn_feature_columns)

# prepare input dicts: {feature name : feature list}
train_model_input = {name: train[name] for name in feature_names}
test_model_input = {name: test[name] for name in feature_names}
```
Model Training
Set Up Environment
We train on a GPU whenever one is available and fall back to the CPU otherwise:
```python
device = "cpu"
use_cuda = True
if use_cuda and torch.cuda.is_available():
    print("cuda ready...")
    device = "cuda:0"
```
Instantiate the Model
I will explain the inner workings of the `MMOE` model class in the next section.
To instantiate a new model, we need to provide several arguments:
- a feature column list for the “deep” part of the network (our version of MMoE doesn’t have a “wide” part)
- a list of task types (`binary` or `regression`)
- the L2 regularization strength
- the names of the tasks (e.g., `["finish", "like"]`)
- and so on.

In the DeepCTR-Torch API, we then need to call `model.compile()` before training. It specifies the optimizer, the loss function for each task, and the metrics to monitor during training.
```python
# instantiate an MMoE model
model = MMOE(
    dnn_feature_columns,
    task_types=["binary", "binary"],
    l2_reg_embedding=1e-5,
    task_names=target,
    device=device,
)

# specify optimizer, loss functions for each task, and metrics
model.compile(
    "adagrad",
    loss=["binary_crossentropy", "binary_crossentropy"],
    metrics=["binary_crossentropy"],
)
```
Train the Model
The training and inference API follows the familiar `fit` and `predict` pattern. `model.fit` takes a dictionary of features (`{feature name: feature list}`) and a 2-D target array with one column per task (shape `[num_samples, num_tasks]`). It also takes training settings such as batch size, number of epochs, and output verbosity. `model.predict` takes a similar feature dictionary and the inference batch size. It returns predictions for each task with shape `[num_samples, num_tasks]`.
```python
# fit model to training data
history = model.fit(
    train_model_input, train[target].values, batch_size=32, epochs=10, verbose=2
)
# generate predictions for test data
pred_ans = model.predict(test_model_input, 256)  # inference batch size: 256
print("")
for i, target_name in enumerate(target):
    log_loss_value = round(log_loss(test[target[i]].values, pred_ans[:, i]), 4)
    auc_value = round(roc_auc_score(test[target[i]].values, pred_ans[:, i]), 4)

    print(f"{target_name} test LogLoss: {log_loss_value}")
    print(f"{target_name} test AUC: {auc_value}")
```
The Anatomy of the MMOE Class
The “meat” of this post is the class definition of the `MMOE` model, whose source code lives in the DeepCTR-Torch repo. The snippet below includes comments added by me. The model design is a faithful translation of the MMoE architecture (doodles also by me):
- Input: For each training instance, embeddings are flattened and concatenated with dense features into a single vector, which feeds into gates and experts.
- Expert outputs: Each expert gets the same input and generates a vector output.
- Gating outputs: Each gate gets the same input and generates a scalar weight for each expert; for each gate, the weights of all experts sum up to 1.
- MMoE outputs: For each task, we use the corresponding gate network to output a weighted sum of expert outputs, which is the input to the task-specific tower.
- Task outputs: Each tower gets a task-specific input and generates an output.
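The five steps above can be traced with a toy NumPy forward pass. To keep it short, every expert, gate, and tower is a single linear layer: ReLU for experts, softmax for gates, and sigmoid for towers. All sizes and random weights are illustrative assumptions, not the DeepCTR-Torch defaults.

```python
import numpy as np

rng = np.random.default_rng(42)
batch_size, input_dim, expert_dim = 8, 33, 16
num_experts, num_tasks = 3, 2


def softmax(z, axis=-1):
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


x = rng.normal(size=(batch_size, input_dim))  # input: one vector per instance
W_experts = rng.normal(scale=0.1, size=(num_experts, input_dim, expert_dim))
W_gates = rng.normal(scale=0.1, size=(num_tasks, input_dim, num_experts))
W_towers = rng.normal(scale=0.1, size=(num_tasks, expert_dim, 1))

# expert outputs: (batch_size, num_experts, expert_dim)
expert_outs = np.stack([np.maximum(x @ W, 0.0) for W in W_experts], axis=1)

task_outs = []
for k in range(num_tasks):
    gate = softmax(x @ W_gates[k])  # (batch_size, num_experts); rows sum to 1
    # MMoE output for task k: per-instance weighted sum of expert outputs
    mixed = np.einsum("be,bed->bd", gate, expert_outs)  # (batch_size, expert_dim)
    task_outs.append(sigmoid(mixed @ W_towers[k]))  # (batch_size, 1)

preds = np.concatenate(task_outs, axis=-1)  # (batch_size, num_tasks)
```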
Now, let’s digest the model code bit by bit. You can read it in its entirety first:
```python
class MMOE(BaseModel):
    """Instantiates the multi-gate mixture-of-experts architecture.

    :param dnn_feature_columns: an iterable containing all the features used by deep part of the model.
    :param num_experts: integer, number of experts.
    :param expert_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of expert dnn.
    :param gate_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of gate dnn.
    :param tower_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of task-specific dnn.
    :param l2_reg_linear: float, l2 regularizer strength applied to linear part.
    :param l2_reg_embedding: float, l2 regularizer strength applied to embedding vector.
    :param l2_reg_dnn: float, l2 regularizer strength applied to dnn.
    :param init_std: float, to use as the initialize std of embedding vector.
    :param seed: integer, to use as random seed.
    :param dnn_dropout: float in [0,1), the probability we will drop out a given dnn coordinate.
    :param dnn_activation: activation function to use in dnn.
    :param dnn_use_bn: bool, whether use batchnormalization before activation or not in dnn.
    :param task_types: list of str, indicating the loss of each tasks, ``"binary"`` for binary logloss, ``"regression"`` for regression loss. e.g. ['binary', 'regression'].
    :param task_names: list of str, indicating the predict target of each tasks.
    :param device: str, ``"cpu"`` or ``"cuda:0"``.
    :param gpus: list of int or torch.device for multiple gpus. if none, run on `device`. `gpus[0]` should be the same gpu with `device`.

    :return: a pytorch model instance.
    """

    def __init__(
        self,
        dnn_feature_columns,  # a list of features used by the deep part
        num_experts=3,  # number of experts
        expert_dnn_hidden_units=(256, 128),  # each expert dnn has 2 layers
        gate_dnn_hidden_units=(64,),  # each gate dnn has 1 layer
        tower_dnn_hidden_units=(64,),  # each tower dnn has 1 layer
        l2_reg_linear=0.00001,  # l2 regularizer strength for linear part
        l2_reg_embedding=0.00001,  # l2 regularizer strength for emb part
        l2_reg_dnn=0,  # l2 regularizer strength for DNN part
        init_std=0.0001,
        seed=1024,
        dnn_dropout=0,
        dnn_activation="relu",
        dnn_use_bn=False,  # whether to use batch norm
        task_types=("binary", "binary"),
        task_names=("ctr", "ctcvr"),
        device="cpu",
        gpus=None,
    ):
        super(MMOE, self).__init__(
            linear_feature_columns=[],
            dnn_feature_columns=dnn_feature_columns,
            l2_reg_linear=l2_reg_linear,
            l2_reg_embedding=l2_reg_embedding,
            init_std=init_std,
            seed=seed,
            device=device,
            gpus=gpus,
        )
        self.num_tasks = len(task_names)  # infer task count from task names

        # performs input validations
        if self.num_tasks <= 1:
            raise ValueError(
                "num_tasks must be greater than 1"
            )  # multi-task model must have multiple tasks
        if num_experts <= 1:
            raise ValueError(
                "num_experts must be greater than 1"
            )  # multi-expert model must have multiple experts
        if len(dnn_feature_columns) == 0:
            raise ValueError(
                "dnn_feature_columns is null!"
            )  # the deep part must have features
        if len(task_types) != self.num_tasks:
            raise ValueError(
                "num_tasks must be equal to the length of task_types"
            )  # make sure we specify a type for each task
        for task_type in task_types:
            if task_type not in ["binary", "regression"]:
                raise ValueError(
                    f"task must be binary or regression, {task_type} is illegal"
                )  # make sure task type is valid

        self.num_experts = num_experts
        self.task_names = task_names
        self.input_dim = self.compute_input_dim(dnn_feature_columns)
        self.expert_dnn_hidden_units = expert_dnn_hidden_units
        self.gate_dnn_hidden_units = gate_dnn_hidden_units
        self.tower_dnn_hidden_units = tower_dnn_hidden_units

        # expert dnn: each element is an expert network
        self.expert_dnn = nn.ModuleList(
            [
                DNN(
                    self.input_dim,
                    expert_dnn_hidden_units,
                    activation=dnn_activation,
                    l2_reg=l2_reg_dnn,
                    dropout_rate=dnn_dropout,
                    use_bn=dnn_use_bn,
                    init_std=init_std,
                    device=device,
                )
                for _ in range(self.num_experts)
            ]
        )

        # gate dnn: each element is a gate for a task
        if len(gate_dnn_hidden_units) > 0:
            self.gate_dnn = nn.ModuleList(
                [
                    DNN(
                        self.input_dim,
                        gate_dnn_hidden_units,
                        activation=dnn_activation,
                        l2_reg=l2_reg_dnn,
                        dropout_rate=dnn_dropout,
                        use_bn=dnn_use_bn,
                        init_std=init_std,
                        device=device,
                    )
                    for _ in range(self.num_tasks)
                ]
            )
            # select weights to regularize
            self.add_regularization_weight(
                filter(
                    lambda x: "weight" in x[0] and "bn" not in x[0],
                    self.gate_dnn.named_parameters(),
                ),
                l2=l2_reg_dnn,
            )
        # a list of linear layers, one for each task
        self.gate_dnn_final_layer = nn.ModuleList(
            [
                nn.Linear(
                    gate_dnn_hidden_units[-1]
                    if len(gate_dnn_hidden_units) > 0
                    else self.input_dim,
                    self.num_experts,
                    bias=False,
                )
                for _ in range(self.num_tasks)
            ]
        )

        # tower dnn: each element is a tower for each task
        if len(tower_dnn_hidden_units) > 0:
            self.tower_dnn = nn.ModuleList(
                [
                    DNN(
                        expert_dnn_hidden_units[-1],
                        tower_dnn_hidden_units,
                        activation=dnn_activation,
                        l2_reg=l2_reg_dnn,
                        dropout_rate=dnn_dropout,
                        use_bn=dnn_use_bn,
                        init_std=init_std,
                        device=device,
                    )
                    for _ in range(self.num_tasks)
                ]
            )
            # select weights to regularize
            self.add_regularization_weight(
                filter(
                    lambda x: "weight" in x[0] and "bn" not in x[0],
                    self.tower_dnn.named_parameters(),
                ),
                l2=l2_reg_dnn,
            )
        # a list of linear layers, one for each task
        self.tower_dnn_final_layer = nn.ModuleList(
            [
                nn.Linear(
                    tower_dnn_hidden_units[-1]
                    if len(tower_dnn_hidden_units) > 0
                    else expert_dnn_hidden_units[-1],
                    1,
                    bias=False,
                )
                for _ in range(self.num_tasks)
            ]
        )
        # each task type has an output
        self.out = nn.ModuleList([PredictionLayer(task) for task in task_types])

        # add final parameters to be regularized
        regularization_modules = [
            self.expert_dnn,
            self.gate_dnn_final_layer,
            self.tower_dnn_final_layer,
        ]
        for module in regularization_modules:
            self.add_regularization_weight(
                filter(
                    lambda x: "weight" in x[0] and "bn" not in x[0],
                    module.named_parameters(),
                ),
                l2=l2_reg_dnn,
            )
        self.to(device)

    def forward(self, X):
        # list of embedding and dense feature values
        sparse_embedding_list, dense_value_list = self.input_from_feature_columns(
            X, self.dnn_feature_columns, self.embedding_dict
        )
        # concat features into a single vector for each instance
        dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list)

        # expert dnn: collect output from each expert
        expert_outs = []
        for i in range(self.num_experts):
            expert_out = self.expert_dnn[i](dnn_input)
            expert_outs.append(expert_out)
        expert_outs = torch.stack(expert_outs, 1)  # (bs, num_experts, dim)

        # gate dnn: each gate has a way to combine expert outputs
        mmoe_outs = []
        for i in range(self.num_tasks):
            if (
                len(self.gate_dnn_hidden_units) > 0
            ):  # input => gate dnn => final gate layer
                gate_dnn_out = self.gate_dnn[i](dnn_input)
                gate_dnn_out = self.gate_dnn_final_layer[i](gate_dnn_out)
            else:  # input => final gate layer
                gate_dnn_out = self.gate_dnn_final_layer[i](dnn_input)
            # performs matrix multiplication between post-softmax gate dnn output and expert outputs
            gate_mul_expert = torch.matmul(
                gate_dnn_out.softmax(1).unsqueeze(1), expert_outs
            )  # (bs, 1, dim)
            mmoe_outs.append(gate_mul_expert.squeeze())

        # tower dnn: each tower generates output for a specific task
        task_outs = []
        for i in range(self.num_tasks):
            if (
                len(self.tower_dnn_hidden_units) > 0
            ):  # input => tower dnn => final tower layer
                tower_dnn_out = self.tower_dnn[i](mmoe_outs[i])
                tower_dnn_logit = self.tower_dnn_final_layer[i](tower_dnn_out)
            else:  # input => final tower layer
                tower_dnn_logit = self.tower_dnn_final_layer[i](mmoe_outs[i])
            output = self.out[i](tower_dnn_logit)
            task_outs.append(output)
        task_outs = torch.cat(task_outs, -1)  # output dimension: (bs, num_tasks)
        return task_outs
```
When coding up deep learning models, I think of the constructor (`def __init__`) as the LEGO blocks. You specify what the pieces are and the properties of each piece, so you can use them later. The `forward` pass is the building process. You put the pieces together in a specific way so that the data (`X`) flows through the network and produces the output. For each component, we’ll look at the LEGO pieces and how they are used in the `forward` method.
Model Inputs
`MMOE` inherits from `BaseModel`, which contains components and methods shared by various model architectures. It uses the `input_from_feature_columns` method from `BaseModel` and the `combined_dnn_input` function from the `inputs` module to turn the raw input into a format the DNN can consume.
```python
# list of embedding and dense feature values
sparse_embedding_list, dense_value_list = self.input_from_feature_columns(
    X, self.dnn_feature_columns, self.embedding_dict
)
# concat feature lists into input for the deep part
dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list)
```
The `input_from_feature_columns` method first identifies which features are sparse and which are dense. For sparse features, it performs an embedding lookup. For each `embedding_name` (e.g., `uid`), it finds the embedding table, and for a given feature value (a particular `uid`), it retrieves the corresponding embedding vector. The output is a list of embedding tensors (`sparse_embedding_list`: `[emb1, emb2, emb3, ...]`), each of shape `[batch_size, 1, embedding_dim]`. Variable-length sparse features (not used here) are pooled and appended to the same list. For dense features, it returns a list of value tensors (`dense_value_list`: `[val1, val2, val3, ...]`). The two lists are returned separately.
```python
def input_from_feature_columns(self, X, feature_columns, embedding_dict, support_dense=True):
    # split the feature columns by type
    sparse_feature_columns = list(
        filter(lambda x: isinstance(x, SparseFeat), feature_columns)) if len(feature_columns) else []
    dense_feature_columns = list(
        filter(lambda x: isinstance(x, DenseFeat), feature_columns)) if len(feature_columns) else []

    varlen_sparse_feature_columns = list(
        filter(lambda x: isinstance(x, VarLenSparseFeat), feature_columns)) if feature_columns else []

    if not support_dense and len(dense_feature_columns) > 0:
        raise ValueError(
            "DenseFeat is not supported in dnn_feature_columns")

    # embedding lookup: slice each feature's column(s) out of X and index its embedding table
    sparse_embedding_list = [embedding_dict[feat.embedding_name](
        X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]].long()) for
        feat in sparse_feature_columns]

    # variable-length sparse features are looked up, then pooled
    sequence_embed_dict = varlen_embedding_lookup(X, self.embedding_dict, self.feature_index,
                                                  varlen_sparse_feature_columns)
    varlen_sparse_embedding_list = get_varlen_pooling_list(sequence_embed_dict, X, self.feature_index,
                                                           varlen_sparse_feature_columns, self.device)

    # dense features are sliced out of X as-is
    dense_value_list = [X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]] for feat in
                        dense_feature_columns]

    return sparse_embedding_list + varlen_sparse_embedding_list, dense_value_list
```
Then the `combined_dnn_input` function takes the two feature lists and outputs a matrix of shape `[batch_size, input_dim]`. Each row is a feature vector with all embeddings and dense features concatenated together:
- `sparse_dnn_input`: Concatenate the sparse embeddings in `sparse_embedding_list` along the last dimension, then flatten the result starting from the 2nd dimension.
- `dense_dnn_input`: Concatenate the dense values in `dense_value_list` along the last dimension, then flatten the result starting from the 2nd dimension.
- `dnn_input`: Concatenate `sparse_dnn_input` (if available) and `dense_dnn_input` (if available) along the feature dimension. At least one of them must be available.
```python
def combined_dnn_input(sparse_embedding_list, dense_value_list):
    if len(sparse_embedding_list) > 0 and len(dense_value_list) > 0:
        sparse_dnn_input = torch.flatten(
            torch.cat(sparse_embedding_list, dim=-1), start_dim=1)
        dense_dnn_input = torch.flatten(
            torch.cat(dense_value_list, dim=-1), start_dim=1)
        return concat_fun([sparse_dnn_input, dense_dnn_input])
    elif len(sparse_embedding_list) > 0:
        return torch.flatten(torch.cat(sparse_embedding_list, dim=-1), start_dim=1)
    elif len(dense_value_list) > 0:
        return torch.flatten(torch.cat(dense_value_list, dim=-1), start_dim=1)
    else:
        raise NotImplementedError
```
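The shapes are easiest to see with dummy arrays. The sketch below uses NumPy standing in for PyTorch, with made-up sizes. It assumes each sparse feature's embedding has shape `[batch_size, 1, embedding_dim]`, since one column of `X` is sliced per feature, and each dense value has shape `[batch_size, 1]`. Here `reshape(batch_size, -1)` plays the role of `torch.flatten(..., start_dim=1)`.

```python
import numpy as np

batch_size, embedding_dim = 4, 4
num_sparse, num_dense = 8, 1

# one embedding per sparse feature: (batch_size, 1, embedding_dim)
sparse_embedding_list = [np.ones((batch_size, 1, embedding_dim)) for _ in range(num_sparse)]
# one value per dense feature: (batch_size, 1)
dense_value_list = [np.ones((batch_size, 1)) for _ in range(num_dense)]

# concat along the last dim -> (batch_size, 1, 32); flatten from dim 1 -> (batch_size, 32)
sparse_dnn_input = np.concatenate(sparse_embedding_list, axis=-1).reshape(batch_size, -1)
# concat along the last dim -> (batch_size, 1); flattening keeps (batch_size, 1)
dense_dnn_input = np.concatenate(dense_value_list, axis=-1).reshape(batch_size, -1)
# final DNN input: (batch_size, 33)
dnn_input = np.concatenate([sparse_dnn_input, dense_dnn_input], axis=-1)
```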
Expert Networks
In this implementation, each expert network is a two-layer fully connected DNN. The first layer has 256 neurons and the second has 128, as specified in `expert_dnn_hidden_units=(256, 128)`. Each layer is followed by a ReLU activation (`dnn_activation="relu"`). `nn.ModuleList` registers the expert networks as submodules of a single module, so their parameters are tracked and trained with the rest of the model.
```python
self.expert_dnn = nn.ModuleList(
    [
        DNN(
            self.input_dim,
            expert_dnn_hidden_units,
            activation=dnn_activation,
            l2_reg=l2_reg_dnn,
            dropout_rate=dnn_dropout,
            use_bn=dnn_use_bn,
            init_std=init_std,
            device=device,
        )
        for _ in range(self.num_experts)
    ]
)
```
Each expert network processes `dnn_input` independently. The expert outputs are then stacked along a new dimension 1 into a tensor of shape `[batch_size, num_experts, output_dim]`.
```python
expert_outs = []
for i in range(self.num_experts):
    expert_out = self.expert_dnn[i](dnn_input)
    expert_outs.append(expert_out)
expert_outs = torch.stack(expert_outs, 1)
```
Gating Networks
The gating network has a simpler structure than the expert network. Each gate (one per task) is a one-layer DNN with 64 neurons, as specified in `gate_dnn_hidden_units=(64,)`.
```python
self.gate_dnn = nn.ModuleList(
    [
        DNN(
            self.input_dim,
            gate_dnn_hidden_units,
            activation=dnn_activation,
            l2_reg=l2_reg_dnn,
            dropout_rate=dnn_dropout,
            use_bn=dnn_use_bn,
            init_std=init_std,
            device=device,
        )
        for _ in range(self.num_tasks)
    ]
)
```
The output of each gating network is passed to a bias-free linear layer that produces one logit per expert. In the forward pass, a softmax turns these logits into the weights used to combine the expert outputs for that task.
```python
self.gate_dnn_final_layer = nn.ModuleList(
    [
        nn.Linear(
            gate_dnn_hidden_units[-1]
            if len(gate_dnn_hidden_units) > 0
            else self.input_dim,
            self.num_experts,
            bias=False,
        )
        for _ in range(self.num_tasks)
    ]
)
```
Each gate takes `dnn_input` and outputs expert-specific weights. These weights are multiplied with the expert outputs, a batched matrix multiplication over the expert dimension, to produce the input for the corresponding tower.
```python
mmoe_outs = []
for i in range(self.num_tasks):
    if (
        len(self.gate_dnn_hidden_units) > 0
    ):  # input => gate dnn => final gate layer
        gate_dnn_out = self.gate_dnn[i](dnn_input)
        gate_dnn_out = self.gate_dnn_final_layer[i](gate_dnn_out)
    else:  # input => final gate layer
        gate_dnn_out = self.gate_dnn_final_layer[i](dnn_input)
    # performs matrix multiplication between post-softmax gate dnn output and expert outputs
    gate_mul_expert = torch.matmul(
        gate_dnn_out.softmax(1).unsqueeze(1), expert_outs
    )  # (bs, 1, dim)
    mmoe_outs.append(gate_mul_expert.squeeze())
```
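The `torch.matmul` line is a batched weighted sum: a `[bs, 1, num_experts]` weight tensor times a `[bs, num_experts, dim]` expert tensor gives `[bs, 1, dim]`. The NumPy sketch below, with made-up sizes, checks that equivalence; `np.matmul` broadcasts over the batch dimension the same way here. It also shows a subtlety of the bare `.squeeze()` call. It drops every size-1 dimension, so a batch containing a single instance loses its batch dimension too. Squeezing only axis 1 (`.squeeze(1)` in PyTorch) always returns `[bs, dim]`.

```python
import numpy as np

rng = np.random.default_rng(0)
bs, num_experts, dim = 5, 3, 4
expert_outs = rng.normal(size=(bs, num_experts, dim))
gate_logits = rng.normal(size=(bs, num_experts))

# softmax over experts (axis 1), mirroring gate_dnn_out.softmax(1)
weights = np.exp(gate_logits) / np.exp(gate_logits).sum(axis=1, keepdims=True)

# (bs, 1, num_experts) @ (bs, num_experts, dim) -> (bs, 1, dim)
mixed = np.matmul(weights[:, None, :], expert_outs)

# same result as an explicit per-instance weighted sum over experts: (bs, dim)
explicit = (weights[:, :, None] * expert_outs).sum(axis=1)

# with a single-instance batch, squeezing all size-1 axes also drops the batch axis
single = np.matmul(weights[:1, None, :], expert_outs[:1])  # (1, 1, dim)
squeezed_all = single.squeeze()  # (dim,)
squeezed_axis1 = single.squeeze(axis=1)  # (1, dim)
```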
Tower Networks
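Like the gating networks above, each tower network is a one-layer DNN with 64 neurons, as specified in `tower_dnn_hidden_units=(64,)`. Its input is the gated mixture of expert outputs, so its input dimension is `expert_dnn_hidden_units[-1]` (128).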
```python
self.tower_dnn = nn.ModuleList(
    [
        DNN(
            expert_dnn_hidden_units[-1],
            tower_dnn_hidden_units,
            activation=dnn_activation,
            l2_reg=l2_reg_dnn,
            dropout_rate=dnn_dropout,
            use_bn=dnn_use_bn,
            init_std=init_std,
            device=device,
        )
        for _ in range(self.num_tasks)
    ]
)
```
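The output from each tower is passed to a linear layer that produces a single logit for its task. A task-specific `PredictionLayer` then turns that logit into the final prediction (for a `binary` task, a sigmoid probability). Each set of per-task layers is stored in an `nn.ModuleList`, and in `forward` the per-task outputs are concatenated into a `[batch_size, num_tasks]` tensor.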
```python
# a list of linear layers, one for each task
self.tower_dnn_final_layer = nn.ModuleList(
    [
        nn.Linear(
            tower_dnn_hidden_units[-1]
            if len(tower_dnn_hidden_units) > 0
            else expert_dnn_hidden_units[-1],
            1,
            bias=False,
        )
        for _ in range(self.num_tasks)
    ]
)
# each task type has an output
self.out = nn.ModuleList([PredictionLayer(task) for task in task_types])
```
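For a `binary` task, the prediction layer squashes the tower's logit into a probability. The function below is a hypothetical NumPy stand-in; `toy_prediction_layer` is not a DeepCTR-Torch function. It assumes that `PredictionLayer` adds a learnable scalar bias and applies a sigmoid for `binary` tasks, leaving `regression` outputs unchanged. That assumed bias would explain why the final linear layers use `bias=False`.

```python
import numpy as np


def toy_prediction_layer(logit, task, bias=0.0):
    """Hypothetical stand-in: add a scalar bias, then apply a task-specific output."""
    out = logit + bias
    if task == "binary":
        out = 1.0 / (1.0 + np.exp(-out))  # sigmoid -> probability in (0, 1)
    return out  # "regression": raw value


# one (bs, 1) logit per task
tower_logits = [np.array([[0.0], [2.0]]), np.array([[-1.0], [3.0]])]
task_types = ["binary", "regression"]
task_outs = [toy_prediction_layer(logit, t) for logit, t in zip(tower_logits, task_types)]
preds = np.concatenate(task_outs, axis=-1)  # (bs, num_tasks), like torch.cat(task_outs, -1)
```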
During training (`fit` in `BaseModel`), the total loss combines the losses of all tasks with the regularization terms. It is backpropagated with `total_loss.backward()` to compute gradients, which the optimizer then uses to update the weights.
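Conceptually, the training objective is the sum of the per-task losses plus the L2 penalties registered via `add_regularization_weight`. The NumPy sketch below computes such an objective for our two binary tasks. Equal task weights, mean reduction, and the helper names are assumptions for illustration, not a transcription of `BaseModel.fit`.

```python
import numpy as np


def binary_log_loss(y_true, y_pred, eps=1e-7):
    # clip predictions to avoid log(0)
    p = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))


def total_loss(y_true, y_pred, weights, l2):
    """Sum of per-task log losses (columns = tasks) plus an L2 penalty on `weights`."""
    task_losses = [binary_log_loss(y_true[:, k], y_pred[:, k]) for k in range(y_true.shape[1])]
    reg = l2 * sum(np.sum(w ** 2) for w in weights)
    return sum(task_losses) + reg, task_losses


y_true = np.array([[1, 0], [0, 0], [1, 1]], dtype=float)  # columns: finish, like
y_pred = np.array([[0.8, 0.1], [0.3, 0.2], [0.6, 0.7]])
weights = [np.ones((2, 2))]  # stand-in for the regularized weight matrices
loss, task_losses = total_loss(y_true, y_pred, weights, l2=1e-5)
```

Because this single scalar is backpropagated, gradients from both towers flow into the shared experts. The gates decide how much each task's gradient shapes each expert.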
Read More
Personally, I think deep learning ranking expertise is harder to come by than deep learning NLP expertise. Companies of any size can support fine-tuning “classic” NLP models such as BERT or BART, but many tech companies are still using Gradient Boosted Decision Trees (GBDT) in their ranking stack and struggling with the transition to deep learning. I’ll be collecting sources as I learn more. Below are some resources I find particularly useful at this moment:
- Ranking papers: A repository curated by a Meta engineer, collecting industry papers published by Meta, Airbnb, Amazon, and many more.
- Gaurav Chakravorty’s repo: Toy models created by a Meta E7 for educational purposes.