The Annotated Multi-Task Ranker: An MMoE Code Example

18 minute read

Natural Language Processing (NLP) has an abundance of intuitively explained tutorials with code, such as Andrej Karpathy’s Neural Networks: Zero to Hero, the viral The Illustrated Transformer and its successor The Annotated Transformer, and Umar Jamil’s YouTube series dissecting SOTA models along with its companion repo, among others.

When it comes to Search/Ads/Recommendations (“搜广推”), however, intuitive explanations accompanied by code are rare. Company engineering blogs tend to focus on high-level system designs, and many papers at top conferences (e.g., KDD/RecSys/SIGIR) don’t share code. In this post, I explain the iconic Multi-gate Mixture-of-Experts (MMoE) paper (Ma et al., 2018) using its implementation in the popular DeepCTR-Torch repo, to teach myself and readers how the authors’ blueprint translates into code.

The Paper (Ma et al., 2018)

A huge appeal of deep learning is its ability to optimize for multiple task objectives at once, such as clicks and conversions in search/ads/feed ranking. In traditional machine learning, you would have to build one model per task, making the system hard to maintain (each model needs its own data/training/serving pipelines) and missing out on the opportunity for transfer learning between tasks.

In early designs, all tasks shared the same backbone that feeds into task-specific towers (e.g., Caruana, 1993, 1998). The Shared-Bottom architecture is simple and intuitive — the drawback is that low task correlation can hurt model performance.

As a solution to the above issue, we can replace the shared bottom with a group of bottom networks (“experts”), which explicitly learn relationships between tasks and how each task uses the shared representations. This is the Mixture-of-Experts (MoE) architecture (e.g., Jacobs et al., 1991, Eigen et al., 2013, Shazeer et al., 2017).

In the original Mixture of Experts (MoE) model, a “gating network” assembles expert outputs by learning each expert’s weight from input features (weights sum to 1) and returning the weighted sum of expert outputs as the output to the next layer:

$$y = \sum_{i=1}^n g(x)_i f_i(x)$$

where $g(x)_i$ is the weight of the $i$th expert and $f_i(x)$ is that expert’s output.
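
To make this concrete, here is a minimal PyTorch sketch of the MoE combination (toy layer sizes of my choosing; in practice each expert is a deeper network):

import torch
import torch.nn as nn

n_experts, in_dim, out_dim = 3, 8, 4  # toy sizes for illustration
experts = nn.ModuleList([nn.Linear(in_dim, out_dim) for _ in range(n_experts)])
gate = nn.Linear(in_dim, n_experts)

x = torch.randn(2, in_dim)  # a batch of 2 instances
g = gate(x).softmax(dim=1)  # (2, n_experts); each row sums to 1
f = torch.stack([expert(x) for expert in experts], dim=1)  # (2, n_experts, out_dim)
y = (g.unsqueeze(-1) * f).sum(dim=1)  # weighted sum of expert outputs: (2, out_dim)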

The Multi-gate Mixture-of-Experts (MMoE) model has as many gating networks as there are tasks. Each gate learns a task-specific way to leverage expert outputs. In contrast, a One-gate Mixture-of-Experts (OMoE) model uses a single gating network to find the best way to leverage expert outputs across all tasks.
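
In the paper’s notation, task $k$ combines the same experts with its own gate $g^k$ before feeding its tower $h^k$:

$$y_k = h^k\left(\sum_{i=1}^n g^k(x)_i f_i(x)\right)$$

where the gate is a linear transformation of the input followed by a softmax, $g^k(x) = \text{softmax}(W_{gk} x)$.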

As task correlation decreases, MMoE’s advantage over OMoE grows. Both MoE variants outperform Shared-Bottom regardless of task correlation. In today’s web-scale ranking systems, MMoE is one of the most widely adopted multi-task architectures.

The Data (ByteRec)

Large-scale benchmark data played a pivotal role in the resurgence of deep learning. A prominent example is the ImageNet dataset, with 14 million images from more than 20,000 categories, on which the CNN-based AlexNet achieved groundbreaking accuracy, outperforming non-DL models by a huge margin. Unlike in computer vision, ranking benchmarks are often set by companies famous for recommender systems, such as Netflix (the Netflix Prize) and ByteDance (the Short Video Understanding Challenge).

The ByteDance data (henceforth “ByteRec”) is particularly suitable for multi-task learning, since there are two desired user behaviors to predict — finish and like.

ByteRec has only a handful of features, a simplification from real search/feed logs:

  • Dense features: Only duration_time (the video’s duration in seconds)
  • Sparse features: Categorical IDs (uid, item_id, author_id, music_id, channel, device) and locations (user_city, item_city)
  • Targets: Whether (1) or not (0) the user finished or liked a video

The Code (DeepCTR-Torch)

For learning deep-learning ranking architectures, I find DeepCTR-Torch (TensorFlow version: DeepCTR) highly educational. The repo covers SOTA models spanning the last decade (e.g., DCN (Deep & Cross), DCN v2, DIN, DIEN, PNN, MMoE, etc.), even though it may lack some functionality needed by production-grade rankers (e.g., hash encoding for ID features). There is a doc accompanying the code.

Below, I’ll explain the MMoE architecture and how it is used in the author’s example training script. As with The Annotated Transformer, this post aims not to create an original implementation, but to provide a line-by-line explanation of an existing one. Please find my commented code in this Colab notebook.

Load Data

For testing, we can use a toy sample with 200 rows rather than the full data. At work, I also like testing models on small datasets, so I can fail and debug fast.

import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import log_loss, roc_auc_score
from sklearn.preprocessing import LabelEncoder, MinMaxScaler

from deepctr_torch.models.basemodel import BaseModel
from deepctr_torch.layers import DNN, PredictionLayer
from deepctr_torch.inputs import (
    SparseFeat,
    DenseFeat,
    get_feature_names,
    combined_dnn_input,
)

# sample data with 200 rows
url = "https://raw.githubusercontent.com/shenweichen/DeepCTR-Torch/master/examples/byterec_sample.txt"

data = pd.read_csv(
    url,
    sep="\t",
    names=[
        "uid",
        "user_city",
        "item_id",
        "author_id",
        "item_city",
        "channel",
        "finish",
        "like",
        "music_id",
        "device",
        "time",
        "duration_time",
    ],
)

Transform Features

While deep learning models are known for automated feature engineering, we still need to encode sparse categorical features and scale dense numerical features to a limited range (e.g., [0, 1]) before feeding data to the input layer.

Specify Feature + Target Columns

For easy reference, we first specify the different types of feature and target columns:

sparse_features = [
    "uid",  # watcher's user id
    "user_city",  # watcher's city
    "item_id",  # video's id
    "author_id",  # author's user id
    "item_city",  # video's city
    "channel",  # author's channel
    "music_id",  # soundtrack id if the video contains music
    "device",  # user's device
]

dense_features = ["duration_time"]
target = ["finish", "like"]

Encode Sparse Features

For each sparse feature, we can instantiate a LabelEncoder to assign an integer from 0 to (n - 1) to each of the n unique feature values.

for feat in sparse_features:
    lbe = LabelEncoder()
    data[feat] = lbe.fit_transform(data[feat])

This method faces the out-of-vocabulary (OOV) problem: at inference time, a feature value not seen during training (such as a new author ID) has no encoding. If all unseen values fall into one default bucket, the model can pick up systematic biases: if unknown authors typically had low engagement in historical data, the model learns to downrank all new-author content. One solution is hash encoding: OOV feature values are randomly scattered across many buckets, temporarily “borrowing” the model’s learning for each bucket and avoiding a systematic bias. Hash encoding is not implemented in DeepCTR-Torch.
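
For illustration, here is a minimal sketch of the hashing trick (hash_encode and num_buckets are hypothetical names, not DeepCTR-Torch APIs):

import hashlib

def hash_encode(value, num_buckets=10000):
    """Map any feature value, seen or unseen, to one of num_buckets ids."""
    # md5 is stable across runs, unlike Python's built-in hash()
    digest = hashlib.md5(str(value).encode("utf-8")).hexdigest()
    return int(digest, 16) % num_buckets

# an OOV author id still maps to a valid embedding index
print(hash_encode("brand_new_author_42"))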

Scale Dense Features

For each dense feature, we can use a MinMaxScaler to rescale its values to [0, 1].

mms = MinMaxScaler(feature_range=(0, 1))
data[dense_features] = mms.fit_transform(data[dense_features])

If exact values don’t matter, we can discretize dense features into buckets (e.g., age $\rightarrow$ age groups) and process them as categorical (e.g., one-hot encoding).
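
For instance, a hypothetical age feature could be bucketized with pandas before one-hot encoding:

import pandas as pd

ages = pd.Series([5, 17, 25, 40, 72])  # made-up values
age_groups = pd.cut(
    ages, bins=[0, 18, 35, 60, 120], labels=["minor", "young", "middle", "senior"]
)
print(age_groups.tolist())  # ['minor', 'minor', 'young', 'middle', 'senior']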

Create Training Data

Different deep learning libraries expect input data in different formats. DeepCTR-Torch uses named tuples such as SparseFeat and DenseFeat to store feature metadata such as names, data types, and dimensions. See definitions here.

# columns with sparse features
sparse_feature_columns = [
    SparseFeat(feat, vocabulary_size=data[feat].max() + 1, embedding_dim=4)
    for feat in sparse_features
]

# columns with dense features
dense_feature_columns = [
    DenseFeat(
        feat,
        1,
    )
    for feat in dense_features
]

# columns with all features
fixlen_feature_columns = sparse_feature_columns + dense_feature_columns
dnn_feature_columns = fixlen_feature_columns
linear_feature_columns = fixlen_feature_columns

After the train-test split, we store the two datasets in two dictionaries, where the keys are feature names and the values are lists of feature values.

# split data by rows
split_boundary = int(data.shape[0] * 0.8)
train, test = data[:split_boundary], data[split_boundary:]

# get feature names
feature_names = get_feature_names(linear_feature_columns + dnn_feature_columns)

# prepare input dicts: {feature name : feature list}
train_model_input = {name: train[name] for name in feature_names}
test_model_input = {name: test[name] for name in feature_names}

Model Training

Set Up Environment

We use GPU for training whenever available; otherwise, we use CPU instead:

device = "cpu"
use_cuda = True
if use_cuda and torch.cuda.is_available():
    print("cuda ready...")
    device = "cuda:0"

Instantiate the Model

I will explain the inner workings of the MMOE model class in the next section.

To instantiate a new model, we provide a feature column list for the “deep” part of the network (our version of MMoE doesn’t have a “wide” part), a list of task types (binary or regression), the L2 regularization strength, the names of the tasks (e.g., ["finish", "like"]), and so on. In the DeepCTR-Torch API, we then run model.compile() to specify the optimizer, the loss function for each task, and the metrics to monitor during training before we can train the model.

# instantiate a MMoE model
model = MMOE(
    dnn_feature_columns,
    task_types=["binary", "binary"],
    l2_reg_embedding=1e-5,
    task_names=target,
    device=device,
)

# specify optimizer, loss functions for each task, and metrics
model.compile(
    "adagrad",
    loss=["binary_crossentropy", "binary_crossentropy"],
    metrics=["binary_crossentropy"],
)

Train the Model

The model training and inference API adopts the familiar fit and predict methods. model.fit takes a dictionary of features ({feature name: feature list}) and a target array with one column per task, along with training setup details (e.g., batch size, number of epochs, output verbosity). model.predict takes a similar feature dictionary and the inference batch size, and outputs predictions for each task (dimension: [batch_size, num_tasks]).

# fit model to training data
history = model.fit(
    train_model_input, train[target].values, batch_size=32, epochs=10, verbose=2
)
# generate predictions for test data
pred_ans = model.predict(test_model_input, 256)  # inference batch size: 256
print("")
for i, target_name in enumerate(target):
    log_loss_value = round(log_loss(test[target[i]].values, pred_ans[:, i]), 4)
    auc_value = round(roc_auc_score(test[target[i]].values, pred_ans[:, i]), 4)

    print(f"{target_name} test LogLoss: {log_loss_value}")
    print(f"{target_name} test AUC: {auc_value}")

The Anatomy of the MMOE Class

The “meat” of this post is the class definition of the MMOE model. You can find the source code here. The snippet below includes comments added by me. The model design is a faithful translation of the MMoE architecture (doodles also by me):

  • Input: For each training instance, embeddings are flattened and concatenated with dense features into a single vector, which feeds into gates and experts.
  • Expert outputs: Each expert receives the same input and generates a vector output.
  • Gating outputs: Each gate gets the same input and generates a scalar weight for each expert; for each gate, the weights of all experts sum up to 1.
  • MMoE outputs: For each task, we use the corresponding gate network to output a weighted sum of expert outputs, which is the input to the task-specific tower.
  • Task outputs: Each tower gets a task-specific input and generates an output.

Now, let’s digest the model code bit by bit. You can read it in its entirety first:

class MMOE(BaseModel):
    """Instantiates the multi-gate mixture-of-experts architecture.

    :param dnn_feature_columns: an iterable containing all the features used by deep part of the model.
    :param num_experts: integer, number of experts.
    :param expert_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of expert dnn.
    :param gate_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of gate dnn.
    :param tower_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of task-specific dnn.
    :param l2_reg_linear: float, l2 regularizer strength applied to linear part.
    :param l2_reg_embedding: float, l2 regularizer strength applied to embedding vector.
    :param l2_reg_dnn: float, l2 regularizer strength applied to dnn.
    :param init_std: float, to use as the initial std of the embedding vector.
    :param seed: integer, to use as random seed.
    :param dnn_dropout: float in [0,1), the probability we will drop out a given dnn coordinate.
    :param dnn_activation: activation function to use in dnn.
    :param dnn_use_bn: bool, whether to use batch normalization before activation in dnn.
    :param task_types: list of str, indicating the loss of each task, ``"binary"`` for binary logloss, ``"regression"`` for regression loss. e.g. ['binary', 'regression'].
    :param task_names: list of str, indicating the prediction target of each task.
    :param device: str, ``"cpu"`` or ``"cuda:0"``.
    :param gpus: list of int or torch.device for multiple gpus. if none, run on `device`. `gpus[0]` should be the same gpu as `device`.

    :return: a pytorch model instance.
    """

    def __init__(
        self,
        dnn_feature_columns,  # a list of features used by the deep part
        num_experts=3,  # number of experts
        expert_dnn_hidden_units=(256, 128),  # each expert dnn has 2 layers
        gate_dnn_hidden_units=(64,),  # each gate dnn has 1 layer
        tower_dnn_hidden_units=(64,),  # each tower dnn has 1 layer
        l2_reg_linear=0.00001,  # l2 regularizer strength for linear part
        l2_reg_embedding=0.00001,  # l2 regularizer strength for emb part
        l2_reg_dnn=0,  # l2 regularizer strength for DNN part
        init_std=0.0001,
        seed=1024,
        dnn_dropout=0,
        dnn_activation="relu",
        dnn_use_bn=False,  # whether to use batch norm
        task_types=("binary", "binary"),
        task_names=("ctr", "ctcvr"),
        device="cpu",
        gpus=None,
    ):
        super(MMOE, self).__init__(
            linear_feature_columns=[],
            dnn_feature_columns=dnn_feature_columns,
            l2_reg_linear=l2_reg_linear,
            l2_reg_embedding=l2_reg_embedding,
            init_std=init_std,
            seed=seed,
            device=device,
            gpus=gpus,
        )
        self.num_tasks = len(task_names)  # infer task count from task names

        # perform input validations
        if self.num_tasks <= 1:
            raise ValueError(
                "num_tasks must be greater than 1"
            )  # multi-task model must have multiple tasks
        if num_experts <= 1:
            raise ValueError(
                "num_experts must be greater than 1"
            )  # multi-expert model must have multiple experts
        if len(dnn_feature_columns) == 0:
            raise ValueError(
                "dnn_feature_columns is null!"
            )  # the deep part must have features
        if len(task_types) != self.num_tasks:
            raise ValueError(
                "num_tasks must be equal to the length of task_types"
            )  # make sure we specify a type for each task
        for task_type in task_types:
            if task_type not in ["binary", "regression"]:
                raise ValueError(
                    f"task must be binary or regression, {task_type} is illegal"
                )  # make sure task type is valid

        self.num_experts = num_experts
        self.task_names = task_names
        self.input_dim = self.compute_input_dim(dnn_feature_columns)
        self.expert_dnn_hidden_units = expert_dnn_hidden_units
        self.gate_dnn_hidden_units = gate_dnn_hidden_units
        self.tower_dnn_hidden_units = tower_dnn_hidden_units

        # expert dnn: each element is an expert network
        self.expert_dnn = nn.ModuleList(
            [
                DNN(
                    self.input_dim,
                    expert_dnn_hidden_units,
                    activation=dnn_activation,
                    l2_reg=l2_reg_dnn,
                    dropout_rate=dnn_dropout,
                    use_bn=dnn_use_bn,
                    init_std=init_std,
                    device=device,
                )
                for _ in range(self.num_experts)
            ]
        )

        # gate dnn: each element is a gate for a task
        if len(gate_dnn_hidden_units) > 0:
            self.gate_dnn = nn.ModuleList(
                [
                    DNN(
                        self.input_dim,
                        gate_dnn_hidden_units,
                        activation=dnn_activation,
                        l2_reg=l2_reg_dnn,
                        dropout_rate=dnn_dropout,
                        use_bn=dnn_use_bn,
                        init_std=init_std,
                        device=device,
                    )
                    for _ in range(self.num_tasks)
                ]
            )
            # select weights to regularize
            self.add_regularization_weight(
                filter(
                    lambda x: "weight" in x[0] and "bn" not in x[0],
                    self.gate_dnn.named_parameters(),
                ),
                l2=l2_reg_dnn,
            )
        # a list of linear layers, one for each task
        self.gate_dnn_final_layer = nn.ModuleList(
            [
                nn.Linear(
                    gate_dnn_hidden_units[-1]
                    if len(gate_dnn_hidden_units) > 0
                    else self.input_dim,
                    self.num_experts,
                    bias=False,
                )
                for _ in range(self.num_tasks)
            ]
        )

        # tower dnn: each element is a tower for a task
        if len(tower_dnn_hidden_units) > 0:
            self.tower_dnn = nn.ModuleList(
                [
                    DNN(
                        expert_dnn_hidden_units[-1],
                        tower_dnn_hidden_units,
                        activation=dnn_activation,
                        l2_reg=l2_reg_dnn,
                        dropout_rate=dnn_dropout,
                        use_bn=dnn_use_bn,
                        init_std=init_std,
                        device=device,
                    )
                    for _ in range(self.num_tasks)
                ]
            )
            # select weights to regularize
            self.add_regularization_weight(
                filter(
                    lambda x: "weight" in x[0] and "bn" not in x[0],
                    self.tower_dnn.named_parameters(),
                ),
                l2=l2_reg_dnn,
            )
        # a list of linear layers, one for each task
        self.tower_dnn_final_layer = nn.ModuleList(
            [
                nn.Linear(
                    tower_dnn_hidden_units[-1]
                    if len(tower_dnn_hidden_units) > 0
                    else expert_dnn_hidden_units[-1],
                    1,
                    bias=False,
                )
                for _ in range(self.num_tasks)
            ]
        )
        # each task type has an output
        self.out = nn.ModuleList([PredictionLayer(task) for task in task_types])

        # add final parameters to be regularized
        regularization_modules = [
            self.expert_dnn,
            self.gate_dnn_final_layer,
            self.tower_dnn_final_layer,
        ]
        for module in regularization_modules:
            self.add_regularization_weight(
                filter(
                    lambda x: "weight" in x[0] and "bn" not in x[0],
                    module.named_parameters(),
                ),
                l2=l2_reg_dnn,
            )
        self.to(device)

    def forward(self, X):
        # list of embedding and dense feature values
        sparse_embedding_list, dense_value_list = self.input_from_feature_columns(
            X, self.dnn_feature_columns, self.embedding_dict
        )
        # concat features into a single vector for each instance
        dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list)

        # expert dnn: collect output from each expert
        expert_outs = []
        for i in range(self.num_experts):
            expert_out = self.expert_dnn[i](dnn_input)
            expert_outs.append(expert_out)
        expert_outs = torch.stack(expert_outs, 1)  # (bs, num_experts, dim)

        # gate dnn: each gate has a way to combine expert outputs
        mmoe_outs = []
        for i in range(self.num_tasks):
            if (
                len(self.gate_dnn_hidden_units) > 0
            ):  # input => gate dnn => final gate layer
                gate_dnn_out = self.gate_dnn[i](dnn_input)
                gate_dnn_out = self.gate_dnn_final_layer[i](gate_dnn_out)
            else:  # input => final gate layer
                gate_dnn_out = self.gate_dnn_final_layer[i](dnn_input)
            # matrix multiplication between post-softmax gate weights and expert outputs
            gate_mul_expert = torch.matmul(
                gate_dnn_out.softmax(1).unsqueeze(1), expert_outs
            )  # (bs, 1, dim)
            mmoe_outs.append(gate_mul_expert.squeeze())

        # tower dnn: each tower generates output for a specific task
        task_outs = []
        for i in range(self.num_tasks):
            if (
                len(self.tower_dnn_hidden_units) > 0
            ):  # input => tower dnn => final tower layer
                tower_dnn_out = self.tower_dnn[i](mmoe_outs[i])
                tower_dnn_logit = self.tower_dnn_final_layer[i](tower_dnn_out)
            else:  # input => final tower layer
                tower_dnn_logit = self.tower_dnn_final_layer[i](mmoe_outs[i])
            output = self.out[i](tower_dnn_logit)
            task_outs.append(output)
        task_outs = torch.cat(task_outs, -1)  # output dimension: (bs, num_tasks)
        return task_outs

When coding up deep learning models, I think of the constructor (def __init__) as the LEGO blocks: you specify what the pieces are and the properties of each piece, so that you have them at your disposal later. The forward pass is the building process — you put the pieces together in a specific way, allowing the data (X) to flow through the network architecture and return the output. For each component, we’ll look at the LEGO pieces and how they are used in the forward method.

Model Inputs

MMOE inherits from BaseModel (code), which contains components and methods shared by various model architectures. It uses the input_from_feature_columns method from BaseModel and the combined_dnn_input function from the inputs module to process the raw input into a format that can be consumed by the DNN model.

# list of embedding and dense feature values
sparse_embedding_list, dense_value_list = self.input_from_feature_columns(
    X, self.dnn_feature_columns, self.embedding_dict
)
# concat feature lists into input for the deep part
dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list)

The input_from_feature_columns method first separates sparse features from dense features. For sparse features, it performs an embedding lookup: for each embedding_name (e.g., uid), it finds the embedding matrix, and then for a given feature value (a particular uid), it retrieves the corresponding embedding vector. The output is a list of embedding tensors (sparse_embedding_list: [emb1, emb2, emb3, ...]). For dense features, it returns a list of raw values (dense_value_list: [val1, val2, val3, ...]). The two lists are returned separately.

def input_from_feature_columns(self, X, feature_columns, embedding_dict, support_dense=True):

    sparse_feature_columns = list(
        filter(lambda x: isinstance(x, SparseFeat), feature_columns)) if len(feature_columns) else []
    dense_feature_columns = list(
        filter(lambda x: isinstance(x, DenseFeat), feature_columns)) if len(feature_columns) else []

    varlen_sparse_feature_columns = list(
        filter(lambda x: isinstance(x, VarLenSparseFeat), feature_columns)) if feature_columns else []

    if not support_dense and len(dense_feature_columns) > 0:
        raise ValueError(
            "DenseFeat is not supported in dnn_feature_columns")

    sparse_embedding_list = [embedding_dict[feat.embedding_name](
        X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]].long()) for
        feat in sparse_feature_columns]

    sequence_embed_dict = varlen_embedding_lookup(X, self.embedding_dict, self.feature_index,
                                                  varlen_sparse_feature_columns)
    varlen_sparse_embedding_list = get_varlen_pooling_list(sequence_embed_dict, X, self.feature_index,
                                                           varlen_sparse_feature_columns, self.device)

    dense_value_list = [X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]] for feat in
                        dense_feature_columns]

    return sparse_embedding_list + varlen_sparse_embedding_list, dense_value_list

Then the combined_dnn_input function takes the two feature lists as inputs and outputs a matrix ([batch_size, input_dim]), where each row is a feature vector with all embeddings and dense features concatenated together:

  • sparse_dnn_input: Concatenate sparse embeddings in sparse_embedding_list along the last dimension; flatten the result starting from the 2nd dimension.
  • dense_dnn_input: Concatenate dense values in dense_value_list along the last dimension; flatten the result starting from the 2nd dimension.
  • dnn_input: Concatenate sparse_dnn_input (if available) and dense_dnn_input (if available) along the feature dimension; one of them must be available.

def combined_dnn_input(sparse_embedding_list, dense_value_list):
    if len(sparse_embedding_list) > 0 and len(dense_value_list) > 0:
        sparse_dnn_input = torch.flatten(
            torch.cat(sparse_embedding_list, dim=-1), start_dim=1)
        dense_dnn_input = torch.flatten(
            torch.cat(dense_value_list, dim=-1), start_dim=1)
        return concat_fun([sparse_dnn_input, dense_dnn_input])
    elif len(sparse_embedding_list) > 0:
        return torch.flatten(torch.cat(sparse_embedding_list, dim=-1), start_dim=1)
    elif len(dense_value_list) > 0:
        return torch.flatten(torch.cat(dense_value_list, dim=-1), start_dim=1)
    else:
        raise NotImplementedError

Expert Networks

In this implementation, each expert network is a two-layer fully connected DNN with ReLU activations. The first layer has 256 neurons and the second has 128, as specified in expert_dnn_hidden_units=(256, 128). nn.ModuleList stores and manages the list of expert networks as a single module.

self.expert_dnn = nn.ModuleList(
    [
        DNN(
            self.input_dim,
            expert_dnn_hidden_units,
            activation=dnn_activation,
            l2_reg=l2_reg_dnn,
            dropout_rate=dnn_dropout,
            use_bn=dnn_use_bn,
            init_std=init_std,
            device=device,
        )
        for _ in range(self.num_experts)
    ]
)

Each expert network processes dnn_input independently. Outputs from all experts are stacked into a tensor of shape [batch_size, num_experts, output_dim].

expert_outs = []
for i in range(self.num_experts):
    expert_out = self.expert_dnn[i](dnn_input)
    expert_outs.append(expert_out)
expert_outs = torch.stack(expert_outs, 1)

Gating Networks

The gating network has a simpler structure than the expert network — each gate is a one-layer DNN with 64 neurons, as specified in gate_dnn_hidden_units=(64,).

self.gate_dnn = nn.ModuleList(
    [
        DNN(
            self.input_dim,
            gate_dnn_hidden_units,
            activation=dnn_activation,
            l2_reg=l2_reg_dnn,
            dropout_rate=dnn_dropout,
            use_bn=dnn_use_bn,
            init_std=init_std,
            device=device,
        )
        for _ in range(self.num_tasks)
    ]
)

The output from each gating network is passed to a linear layer that produces one weight per expert; these weights are used to combine expert outputs for the corresponding task.

self.gate_dnn_final_layer = nn.ModuleList(
    [
        nn.Linear(
            gate_dnn_hidden_units[-1]
            if len(gate_dnn_hidden_units) > 0
            else self.input_dim,
            self.num_experts,
            bias=False,
        )
        for _ in range(self.num_tasks)
    ]
)

Each gate takes dnn_input and outputs expert-specific weights, which are then multiplied with expert outputs to get the input for the corresponding tower.

mmoe_outs = []
for i in range(self.num_tasks):
    if (
        len(self.gate_dnn_hidden_units) > 0
    ):  # input => gate dnn => final gate layer
        gate_dnn_out = self.gate_dnn[i](dnn_input)
        gate_dnn_out = self.gate_dnn_final_layer[i](gate_dnn_out)
    else:  # input => final gate layer
        gate_dnn_out = self.gate_dnn_final_layer[i](dnn_input)
    # matrix multiplication between post-softmax gate weights and expert outputs
    gate_mul_expert = torch.matmul(
        gate_dnn_out.softmax(1).unsqueeze(1), expert_outs
    )  # (bs, 1, dim)
    mmoe_outs.append(gate_mul_expert.squeeze())
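
To see why the matmul works, here is a toy shape check (hypothetical sizes: batch size 32, 3 experts, expert output dimension 128):

import torch

bs, num_experts, dim = 32, 3, 128  # hypothetical sizes
gate_dnn_out = torch.randn(bs, num_experts)
expert_outs = torch.randn(bs, num_experts, dim)

weights = gate_dnn_out.softmax(1).unsqueeze(1)  # (32, 1, 3); each row sums to 1
combined = torch.matmul(weights, expert_outs)  # (32, 1, 3) x (32, 3, 128) -> (32, 1, 128)
print(combined.squeeze().shape)  # torch.Size([32, 128])

One caveat: .squeeze() with no arguments would also drop the batch dimension if the batch size were 1; .squeeze(1) would be the safer call.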

Tower Networks

Similar to the gating networks above, each tower network is also a one-layer DNN with 64 neurons, as specified in tower_dnn_hidden_units=(64,).

self.tower_dnn = nn.ModuleList(
    [
        DNN(
            expert_dnn_hidden_units[-1],
            tower_dnn_hidden_units,
            activation=dnn_activation,
            l2_reg=l2_reg_dnn,
            dropout_rate=dnn_dropout,
            use_bn=dnn_use_bn,
            init_std=init_std,
            device=device,
        )
        for _ in range(self.num_tasks)
    ]
)

The output from each tower is passed to a linear layer to generate the logit for a given task; a task-specific PredictionLayer then converts the logit into the final prediction. These prediction layers are collected in an nn.ModuleList.

# a list of linear layers, one for each task
self.tower_dnn_final_layer = nn.ModuleList(
    [
        nn.Linear(
            tower_dnn_hidden_units[-1]
            if len(tower_dnn_hidden_units) > 0
            else expert_dnn_hidden_units[-1],
            1,
            bias=False,
        )
        for _ in range(self.num_tasks)
    ]
)
# each task type has an output
self.out = nn.ModuleList([PredictionLayer(task) for task in task_types])

During training (fit in BaseModel), the total loss from all tasks and regularization is used to compute gradients and update weights — total_loss.backward().
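
In simplified form, each training step does something like the sketch below (my toy reconstruction, not the exact BaseModel code; the random tensors stand in for real model outputs and labels):

import torch
import torch.nn.functional as F

num_tasks = 2
y_pred = torch.rand(32, num_tasks, requires_grad=True)  # stand-in for model output
y = torch.randint(0, 2, (32, num_tasks)).float()  # toy binary labels

# one binary cross-entropy loss per task, summed into a single scalar;
# BaseModel also adds model.get_regularization_loss() before backprop
task_losses = [
    F.binary_cross_entropy(y_pred[:, i], y[:, i]) for i in range(num_tasks)
]
total_loss = sum(task_losses)
total_loss.backward()  # in the real model, gradients flow back through towers, gates, experts, embeddings
print(y_pred.grad.shape)  # torch.Size([32, 2])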

Read More

Personally, I think deep learning ranking expertise is harder to come by than deep learning NLP expertise. Companies of any size can support fine-tuning “classic” NLP models such as BERT or BART, but many tech companies are still using Gradient Boosted Decision Trees (GBDT) in their ranking stack and struggling with the transition to deep learning. I’ll be collecting sources as I learn more. Below are some resources I find particularly useful at this moment:

  1. Ranking papers: A repository curated by a Meta engineer, collecting industry papers published by Meta, Airbnb, Amazon, and many more.
  2. Gaurav Chakravorty’s repo: Toy models created by a Meta E7 for educational purposes.