Update (January 22, 2020): After several discussions with Matthew Wiesner, I have added some content to this post (e.g. deriving the derivatives for MMI) and rewritten some parts to make the explanations clearer.
Recently, I came across this paper which compares several sequence discriminative training criteria based on the popular lattice-free MMI (LF-MMI) objective, and concludes that “boosted” LF-MMI outperforms others consistently. Since I couldn’t find the code publicly available, I set out to implement it myself in Kaldi. The idea was that even if the claim turned out to be false, this would give me a hands-on experience with C++ level implementations in Kaldi.
At first glance, the implementation seems trivial if you already have an LF-MMI (also called the “chain” model in Kaldi) implementation available. However, there are several tricks used in Kaldi which are worth pointing out. In this article, I start by giving an overview of LF-MMI and its implementation in the chain models, and then talk about how I implemented boosted LF-MMI. The majority of the theory here is based on this paper which introduced LF-MMI and this doc on the chain model.
MMI – a background
Maximum mutual information, or MMI, is a sequence discriminative training criterion popular in ASR. “Sequence” means that the objective takes into account the utterance as a whole, in contrast to “frame-level” objectives like cross-entropy. “Discriminative” loosely means using an objective function that is tied to a criterion associated with the task, and then optimizing that objective directly using gradient-based methods. Discriminative training for LVCSR was made popular in Dan Povey’s thesis. Formally, the MMI objective for ASR is written as
\[F_{MMI}(\theta) = \sum_{r=1}^R \log \frac{P_{\theta}(O_r|M_{w_r})P(w_r)}{\sum_{\hat{w}}P_{\theta}(O_r|M_{\hat{w}})P(\hat{w})},\]where $M_w$ is the HMM corresponding to the transcription $w$. As you can see, the objective function considers the probability of the whole utterance (acoustic likelihood times language model prior) in the numerator, and normalizes it by the sum of the same quantity over all possible word sequences in the denominator; the logarithm is taken of this ratio. Here, the distributions with subscript $\theta$ are the parametrized distributions that are trained.
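To make the objective concrete, here is a minimal sketch for a single utterance, assuming the set of competing transcripts is small enough to enumerate. The hypotheses and all scores are made up for illustration, and `mmi_objective` is a hypothetical helper, not a Kaldi function.

```python
import math

def mmi_objective(acoustic_loglik, lm_logprob, reference):
    """Log posterior of the reference among all competing hypotheses.

    acoustic_loglik: dict hypothesis -> log P_theta(O | M_w)
    lm_logprob:      dict hypothesis -> log P(w)
    reference:       the correct transcription w_r (a key of both dicts)
    """
    joint = {w: acoustic_loglik[w] + lm_logprob[w] for w in acoustic_loglik}
    # Log-sum-exp over all hypotheses gives the log of the denominator.
    m = max(joint.values())
    log_den = m + math.log(sum(math.exp(v - m) for v in joint.values()))
    return joint[reference] - log_den

# Hypothetical scores for one utterance with three candidate transcripts.
acoustic = {"the cat": -10.0, "the hat": -11.0, "a cat": -12.5}
lm = {"the cat": math.log(0.5), "the hat": math.log(0.3), "a cat": math.log(0.2)}

value = mmi_objective(acoustic, lm, "the cat")
print(value)  # always <= 0, since it is the log of a posterior probability
```

Because the reference is one of the terms in the denominator, the value for one utterance can never exceed zero, and it approaches zero as the model concentrates all probability mass on the reference.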
Deriving gradients for MMI
First-order gradient based methods such as stochastic gradient descent (SGD) are most commonly used for optimization. To use such approaches, we need to first know the gradient of the MMI objective in terms of the parameter $\theta$. In this section, we derive this gradient fairly explicitly. This is mostly taken from these slides with added steps and explanations.
We start with our above formulation of the MMI objective and break the $\log$ into the smaller terms.
\[\begin{align*} F_{MMI}(\theta) &= \sum_{r=1}^R \log \frac{P_{\theta}(O_r|W_r)P(W_r)}{\sum_{\hat{W}}P_{\theta}(O_r|\hat{W})P(\hat{W})} \\ &= \sum_{r=1}^R \left[ \log P_{\theta}(O_r|W_r) + \log P(W_r) - \log \sum_{\hat{W}}P_{\theta}(O_r|\hat{W})P(\hat{W}) \right] \end{align*}\]Now we take the gradient w.r.t parameter $\theta$:
\[\begin{align*} \nabla_{\theta} F_{MMI}(\theta) &= \nabla_{\theta} \sum_{r=1}^R \left[ \log P_{\theta}(O_r|W_r) + \log P(W_r) - \log \sum_{\hat{W}}P_{\theta}(O_r|\hat{W})P(\hat{W}) \right] \\ &= \sum_{r=1}^R \left[ \nabla_{\theta} \log P_{\theta}(O_r|W_r) + \nabla_{\theta} \log P(W_r) - \nabla_{\theta} \log \sum_{\hat{W}}P_{\theta}(O_r|\hat{W})P(\hat{W}) \right] \\ &= \sum_{r=1}^R \left[ \nabla_{\theta} \log P_{\theta}(O_r|W_r) - \nabla_{\theta} \log \sum_{\hat{W}}P_{\theta}(O_r|\hat{W})P(\hat{W}) \right] \\ \end{align*}\]Here we have used $\nabla_{\theta} \log P(W_r) = 0$ since $P(W_r)$ is independent of $\theta$. Now we simplify the second term inside the sum.
\[\begin{align*} \nabla_{\theta} F_{MMI}(\theta) &= \sum_{r=1}^R \left[ \nabla_{\theta} \log P_{\theta}(O_r|W_r) - \frac{\nabla_{\theta} \sum_{\hat{W}}P_{\theta}(O_r|\hat{W})P(\hat{W})}{\sum_{\hat{W}}P_{\theta}(O_r|\hat{W})P(\hat{W})} \right] \\ &= \sum_{r=1}^R \left[ \nabla_{\theta} \log P_{\theta}(O_r|W_r) - \frac{\sum_{\hat{W}} \nabla_{\theta} \left( P_{\theta}(O_r|\hat{W})P(\hat{W}) \right) }{\sum_{\hat{W}}P_{\theta}(O_r|\hat{W})P(\hat{W})} \right] \\ &= \sum_{r=1}^R \left[ \nabla_{\theta} \log P_{\theta}(O_r|W_r) - \frac{\sum_{\hat{W}} P(\hat{W}) \nabla_{\theta} P_{\theta}(O_r|\hat{W}) }{\sum_{\hat{W}}P_{\theta}(O_r|\hat{W})P(\hat{W})} \right] \end{align*}\]Here we have used the fact that $P(\hat{W})$ is independent of $\theta$ so it becomes a constant for the gradient. Now, we know that $\nabla \log x = \frac{\nabla x}{x}$, and so $\nabla x = x \nabla \log x$. Using this, we can substitute $\nabla_{\theta} P_{\theta}(O_r \mid \hat{W})$ in the above to get
\[\nabla_{\theta} F_{MMI}(\theta) = \sum_{r=1}^R \left[ \nabla_{\theta} \log P_{\theta}(O_r|W_r) - \frac{\sum_{\hat{W}} P(\hat{W}) P_{\theta}(O_r|\hat{W}) \nabla_{\theta} \log P_{\theta}(O_r|\hat{W}) }{\sum_{\hat{W}}P_{\theta}(O_r|\hat{W})P(\hat{W})} \right].\]Note that we have done this because our model computes log-probabilities, and so it is easier to compute gradients in the same log domain. When we train an acoustic model, we are essentially maximizing the log-likelihood $\log P_{\theta}(O_r \mid W)$ for the training data. Let us now focus on the gradient for this term. We can see that
\[\begin{align*} \nabla_{\theta} \log P(O_r,W) &= \nabla_{\theta} \log P(O_r|W) P(W) \\ &= \nabla_{\theta} \left( \log P(O_r|W) + \log P(W) \right) \\ &= \nabla_{\theta} \log P(O_r|W) + \nabla_{\theta} \log P(W) \\ &= \nabla_{\theta} \log P(O_r|W). \end{align*}\]So we can work with the joint probability, since it is easier to model with an HMM. Any word sequence $W$ can have several possible HMM state sequences, where the variability arises from the following places:
- Different pronunciations imply that $W$ can be written as various sequences of phonemes.
- Even for a particular phone sequence, we can have different alignments between states and frames. For example, if we have a monophone model, $W$ consists of the phone sequence `a-b-c`, and we have 4 acoustic frames, then we can have any of the following alignments: `a-a-b-c`, `a-b-b-c`, or `a-b-c-c`.
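The alignments in the example above can be enumerated mechanically. The following sketch assumes a one-state-per-phone monophone model in which every phone must occupy at least one frame; `monotonic_alignments` is a hypothetical helper written for this post.

```python
from itertools import combinations

def monotonic_alignments(phones, num_frames):
    """All ways to stretch `phones` over `num_frames` frames, in order,
    with every phone occupying at least one frame."""
    n = len(phones)
    if num_frames < n:
        return []
    alignments = []
    # Choose the n-1 frame indices at which we move on to the next phone.
    for cuts in combinations(range(1, num_frames), n - 1):
        bounds = (0,) + cuts + (num_frames,)
        frames = []
        for phone, start, end in zip(phones, bounds, bounds[1:]):
            frames.extend([phone] * (end - start))
        alignments.append("-".join(frames))
    return alignments

print(sorted(monotonic_alignments(["a", "b", "c"], 4)))
# ['a-a-b-c', 'a-b-b-c', 'a-b-c-c']
```

The number of alignments grows combinatorially with the number of frames, which is why the sums over state sequences below are computed with dynamic programming rather than by enumeration.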
Suppose $W$ has $R$ such possible state sequences. For each sequence $r$, the probability of the sequence can be written as $\prod_{t=1}^{T_r} p_{\theta}(o_t \mid s_t)p(s_t \mid s_{t-1})$, where $p_{\theta}(o_t \mid s_t)$ is the emission probability, which we train, and $p(s_t \mid s_{t-1})$ is the transition probability, which is usually kept fixed. (In the next few equations, the summation index $r$ ranges over these state sequences, while the subscript in $O_r$ still denotes the utterance.) So we can write the overall gradient of our acoustic model as
\[\begin{align*} \nabla_{\theta} \log P(O_r,W) &= \nabla_{\theta} \log \sum_r \prod_{t=1}^{T_r} p_{\theta}(o_t | s_t)p(s_t |s_{t-1}) \\ &= \frac{\sum_r \nabla_{\theta} \prod_{t=1}^{T_r} p_{\theta}(o_t | s_t)p(s_t |s_{t-1})}{\sum_{\sigma} \prod_{t^{\prime}=1}^{T_{\sigma}} p_{\theta}(o_{t^{\prime}} | s_{t^{\prime}}) p(s_{t^{\prime}} | s_{t^{\prime}-1})} \end{align*}.\]Let $x = \prod_{t=1}^{T_r} p_{\theta}(o_t \mid s_t)p(s_t \mid s_{t-1})$. Then
\[\begin{align*} \log x &= \sum_{t=1}^{T_r} \log p_{\theta}(o_t \mid s_t)p(s_t \mid s_{t-1}) \\ \nabla_{\theta} \log x &= \nabla_{\theta} \sum_{t=1}^{T_r} \log p_{\theta}(o_t \mid s_t)p(s_t \mid s_{t-1}) \\ \frac{\nabla_{\theta} x}{x} &= \sum_{t=1}^{T_r} \nabla_{\theta} \left( \log p_{\theta}(o_t \mid s_t) + \log p(s_t \mid s_{t-1}) \right) \\ \nabla_{\theta} x &= x \sum_{t=1}^{T_r} \nabla_{\theta} \log p_{\theta}(o_t \mid s_t). \end{align*}\]Substituting this in above, we have
\[\begin{align*} \nabla_{\theta} \log P(O_r,W) &= \frac{\sum_r \left( \prod_{t=1}^{T_r} p_{\theta}(o_t | s_t)p(s_t |s_{t-1}) \right) \left( \sum_{t=1}^{T_r} \nabla_{\theta} \log p_{\theta}(o_t \mid s_t) \right)}{\sum_{\sigma} \prod_{t^{\prime}=1}^{T_{\sigma}} p_{\theta}(o_{t^{\prime}} | s_{t^{\prime}}) p(s_{t^{\prime}} | s_{t^{\prime}-1})} \end{align*}.\]Consider the numerator: the entire product term is a constant for a given state sequence $r$. So we can exchange the summations as $\sum_r w_r \sum_t a_{rt} = \sum_r \sum_t w_r a_{rt} = \sum_t \sum_r w_r a_{rt}$. Using this, we get:
\[\begin{align*} \nabla_{\theta} \log P(O_r,W) &= \frac{\sum_{t=1}^T \sum_r \left( \nabla_{\theta} \log p_{\theta}(o_t | s_t) \right) \left( \prod_{t^{\prime}=1}^{T_r} p_{\theta}(o_{t^{\prime}} | s_{t^{\prime}}) p(s_{t^{\prime}} | s_{t^{\prime}-1}) \right)}{\sum_{\sigma} \prod_{t^{\prime}=1}^{T_{\sigma}} p_{\theta}(o_{t^{\prime}} | s_{t^{\prime}}) p(s_{t^{\prime}} | s_{t^{\prime}-1})} \end{align*}.\]Here in the numerator, inside the outermost sum, we iterate over all state sequences of $W$ and weight the gradient at frame $t$ by the probability of each sequence. Alternatively, for each frame we can iterate over all possible states $s$ and sum the probabilities of those state sequences of $W$ that occupy $s$ at time $t$ (written $s^{\prime} \in s$ below):
\[\begin{align*} \nabla_{\theta} \log P(O_r,W) &= \sum_{t=1}^T \sum_s \frac{\nabla_{\theta} \log p_{\theta}(o_t|s_t) \sum_{s^{\prime} \in s}p_{\theta}(O_r,s^{\prime}|W)}{p_{\theta}(O_r|W)} \end{align*}.\]Here, note that
\[\gamma_{rt}(s|W) = \frac{\sum_{s^{\prime} \in s}p_{\theta}(O_r,s^{\prime}|W)}{p_{\theta}(O_r|W)} = p_{\theta,t}(s|O_r,W)\]is the word sequence conditioned state posterior (also called occupancy). Finally, we write the gradient of the acoustic model as
\[\nabla_{\theta} \log P(O_r,W) = \sum_{t=1}^T \sum_s \gamma_{rt}(s|W) \nabla_{\theta} \log p_{\theta}(o_t|s_t).\]We now plug this back into our gradient of the MMI objective to get
\[\begin{align*} \nabla_{\theta} F_{MMI}(\theta) &= \sum_{r=1}^R \left[ \nabla_{\theta} \log P_{\theta}(O_r|W_r) - \frac{\sum_{\hat{W}} P(\hat{W}) P_{\theta}(O_r|\hat{W}) \nabla_{\theta} \log P_{\theta}(O_r|\hat{W}) }{\sum_{\hat{W}}P_{\theta}(O_r|\hat{W})P(\hat{W})} \right]. \\ &= \sum_{r=1}^R \left[ \sum_{t=1}^T \sum_s \gamma_{rt}(s|W_r) \nabla_{\theta} \log p_{\theta}(o_t|s_t) - \frac{\sum_{\hat{W}} P(\hat{W}) P_{\theta}(O_r|\hat{W}) \sum_{t=1}^T \sum_s \gamma_{rt}(s|\hat{W}) \nabla_{\theta} \log p_{\theta}(o_t|s_t) }{\sum_{\hat{W}}P_{\theta}(O_r|\hat{W})P(\hat{W})} \right] \\ &= \sum_{r=1}^R \sum_{t=1}^T \sum_s \nabla_{\theta} \log p_{\theta}(o_t|s_t) \left( \gamma_{rt}(s|W_r) - \frac{\sum_{\hat{W}} P(\hat{W}) P_{\theta}(O_r|\hat{W}) \gamma_{rt}(s|\hat{W}) }{\sum_{\hat{W}}P_{\theta}(O_r|\hat{W})P(\hat{W})} \right) \\ &= \sum_{r=1}^R \sum_{t=1}^T \sum_s \nabla_{\theta} \log p_{\theta}(o_t|s_t) \left( \gamma_{rt}(s|W_r) - \gamma_{rt}(s) \right), \\ \end{align*}\]where $\gamma_{rt}(s)$ is the general state posterior.
This gives us the overall gradient required to maximize the MMI objective. Here, the term $\log p_{\theta}(o_t \mid s_t)$ is the score that is usually output by a neural network, so its gradient with respect to $\theta$ is computed by ordinary backpropagation. For the gradient of the overall objective, we multiply this with the term $\gamma_{rt}(s \mid W_r) - \gamma_{rt}(s)$. The key then is to compute the state occupancies for the numerator term and the denominator term.
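The result can be sanity-checked numerically on a toy problem. The sketch below is a brute-force illustration under stated assumptions, not how Kaldi computes anything: it uses a made-up 2-state HMM over 3 frames, treats `scores[t][s]` as the network output $\log p_{\theta}(o_t \mid s)$, defines the "numerator" as the paths accepted by a hypothetical `is_reference_path` constraint, and takes the "denominator" to be all paths. It then compares the difference of occupancies against a finite-difference gradient of the objective.

```python
import math
from itertools import product

def path_logprobs(scores, log_trans, log_init, allowed=None):
    """Brute-force list of (log-probability, path) over state sequences."""
    T, S = len(scores), len(scores[0])
    out = []
    for path in product(range(S), repeat=T):
        if allowed is not None and not allowed(path):
            continue
        lp = log_init[path[0]] + scores[0][path[0]]
        for t in range(1, T):
            lp += log_trans[path[t - 1]][path[t]] + scores[t][path[t]]
        out.append((lp, path))
    return out

def log_total(terms):
    """Log of the summed path probabilities (log-sum-exp)."""
    m = max(lp for lp, _ in terms)
    return m + math.log(sum(math.exp(lp - m) for lp, _ in terms))

def occupancies(terms, T, S):
    """gamma[t][s]: posterior probability of occupying state s at frame t."""
    total = log_total(terms)
    gamma = [[0.0] * S for _ in range(T)]
    for lp, path in terms:
        for t, s in enumerate(path):
            gamma[t][s] += math.exp(lp - total)
    return gamma

def is_reference_path(path):
    # Hypothetical transcript constraint: start in state 0, end in state 1,
    # and never move backwards.
    return path[0] == 0 and path[-1] == 1 and all(b >= a for a, b in zip(path, path[1:]))

def mmi(scores, log_trans, log_init):
    num = path_logprobs(scores, log_trans, log_init, is_reference_path)
    den = path_logprobs(scores, log_trans, log_init)
    return log_total(num) - log_total(den), num, den

# Made-up toy model.
T, S = 3, 2
scores = [[-1.0, -2.0], [-0.5, -1.5], [-2.0, -0.3]]
log_trans = [[math.log(0.7), math.log(0.3)], [math.log(0.4), math.log(0.6)]]
log_init = [math.log(0.6), math.log(0.4)]

objf, num, den = mmi(scores, log_trans, log_init)
gamma_num = occupancies(num, T, S)
gamma_den = occupancies(den, T, S)
analytic = [[gamma_num[t][s] - gamma_den[t][s] for s in range(S)] for t in range(T)]

# Finite-difference gradient of the objective w.r.t. each score.
eps = 1e-6
numeric = [[0.0] * S for _ in range(T)]
for t in range(T):
    for s in range(S):
        bumped = [row[:] for row in scores]
        bumped[t][s] += eps
        numeric[t][s] = (mmi(bumped, log_trans, log_init)[0] - objf) / eps
```

The two matrices `analytic` and `numeric` agree up to the finite-difference error, which is the statement that the derivative of the objective with respect to each network output is the numerator occupancy minus the denominator occupancy.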
Why is this difficult?
Computing the sum in the denominator means summing over an exponentially large number of word sequences, which is not practically feasible. To remedy this, we approximate the sum with either of two methods:
- N-best list: This is computed once and used for all utterances. However, this approximation is less used since it is too crude.
- Lattice structure: This may be word/phone based. A path through the lattice represents a possible word/phone sequence. One limitation with using a lattice is that it requires initialization with a trained model, and usually cross-entropy trained systems are used for this purpose. The older nnet setups in Kaldi used this approach.
With the advent of end-to-end models, such a requirement of a trained system to initialize the lattice comes across as a major drawback of lattice-based MMI. How can we avoid using a lattice?
Lattice-free MMI
First proposed in this paper from Dan Povey, lattice-free MMI is “purely sequence trained” in the sense that no cross-entropy training is required to initialize, since it does not use a lattice. So how does it approximate the sum in the denominator? Simply put, it does not “approximate” it — it computes this sum exactly.
The key idea is that if we represent the denominator as a graph and somehow manage to fit this graph in the GPU, then computation can be performed efficiently. As formulated so far (with a word-level LM at the standard frame rate), the denominator graph is too large to fit in GPU memory. To fix this, two major modifications are applied:
- A phone LM is used instead of a word LM. The number of possible phones is much smaller than the number of possible words, which makes the size of the graph for the phone LM significantly smaller.
- DNN outputs are computed at one-third the standard frame rate, which means that we now have 3 times fewer outputs to compute for any utterance. This is achieved by setting the frame shift to 30 ms instead of the traditional 10 ms.
This reduced frame rate also means that now we cannot use the standard 3-state left-to-right HMM topology that is common in ASR, since we want to be able to traverse the entire HMM in a single frame. Instead, we use an HMM which can emit symbols in the set `ab*`.
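Reading `ab*` as a regular expression over two per-phone labels, `a` for the first frame of a phone and `b` for any further frames, the set of emission sequences for one phone can be checked with a few lines of Python. This is only an illustration of the language the topology accepts; the names are hypothetical.

```python
import re

# One mandatory 'a' followed by zero or more 'b's.
CHAIN_TOPOLOGY = re.compile(r"ab*")

def accepted(label_sequence):
    """True if the per-frame label string is a valid traversal of one phone."""
    return CHAIN_TOPOLOGY.fullmatch(label_sequence) is not None

for seq in ["a", "ab", "abbb", "b", "aab", ""]:
    print(repr(seq), accepted(seq))
```

The single-symbol sequence `a` is accepted, which is what allows a phone to be traversed in one (30 ms) frame.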
To train such a system according to the MMI objective, we need a way to efficiently compute the objective itself and its derivative. In Kaldi, the numerator and denominator are represented as FSTs (corresponding to the HMMs) and the overall objective function is simply the difference of these in log-space. For the derivative, as derived earlier, we need to compute the state occupancies in the numerator FST and the denominator FST.
The denominator and numerator FSTs
Let us start with the denominator FST since it is much more expensive. The process of creating the denominator FST is very similar to the decoding graph creation. The key idea, as in traditional ASR using WFSTs (see Mohri’s well-known paper), is to have separate FSTs for `H` (HMM state graph), `C` (context-dependency), `L` (the lexicon), and `G` (the language model), and use WFST composition algorithms to get the final graph. If we compose the FST over all possible word sequences, then the graph would become too big to fit into GPU memory. To solve this problem, the chain model uses a phone-level LM (which has far fewer possible sequences). Since we are using phones instead of words, we don’t need the `L` graph. So our final graph is actually an `HCP` instead of an `HCLG`, where `P` denotes the phone LM.
At this point, I would like to point out some Kaldi specifics. The phone LM `P` is created in stage `-6` by calling the function `create_phone_lm()`. The denominator FST is created in stage `-5` within the `train.py` script, which internally makes a call to the binary `chain-make-den-fst`. The denominator graph is specified in `chain-den-graph.cc`. It uses the files `$dir/tree` (the tree) and `$dir/0.trans_mdl` (the transition model), which correspond to the `C` and `H` components, and the phone LM that was created in the previous stage.
The phone LM `P` is constructed so that the overall size of the graph is minimized. It is a 4-gram with no backoff lower than 3-gram so that triphones not seen in training cannot be generated. The number of states is limited by completely removing low-count 4-gram states.
Once we have the composed graph `HCP`, a different kind of minimization technique is used, which consists of performing the following operations three times in a row.
- Push the weights
- Minimize the graph
- Reverse the arcs and swap initial and final states.
A trick used to reduce the size of the denominator FST for training on the GPU is to train on chunks of 1-1.5 seconds, instead of the entire utterance. However, to do this, we would also need to break up the transcript, and 1-second chunks may not coincide with word boundaries. How do we solve this?
Recall that the numerator FST is defined to be utterance-specific, and encodes alternative pronunciations of the transcript of the original utterance. This lattice is turned into an FST that constrains at what time the phones can appear, with an error window of 0.05s from their position in the lattice. This is then processed into an FST whose labels are pdf-ids (neural net outputs). We extract fixed size chunks from this FST for training chunks in the denominator FST.
Another issue associated with chunk-level FSTs is that the initial probabilities are now different. We approximate this by running the HMM for a few iterations and then averaging the probabilities to use as the initial probability of any state. This is a crude approximation but it seems to work. We then call this the normalization FST.
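The averaging step can be sketched as follows. This is a simplified illustration on a made-up 3-state transition matrix, assuming the graph has been reduced to a row-stochastic matrix; the function name and the number of iterations are assumptions of this sketch, not values read from the Kaldi source.

```python
def average_state_distribution(trans, start, num_iters=100):
    """Propagate `start` through the transition matrix `trans` for
    `num_iters` steps and return the average of the visited distributions."""
    n = len(trans)
    current = list(start)
    avg = [0.0] * n
    for _ in range(num_iters):
        for i in range(n):
            avg[i] += current[i] / num_iters
        # One step of the chain: nxt[j] = sum_i current[i] * trans[i][j].
        nxt = [0.0] * n
        for i in range(n):
            for j in range(n):
                nxt[j] += current[i] * trans[i][j]
        total = sum(nxt)
        current = [p / total for p in nxt]  # renormalise to guard against drift
    return avg

# Hypothetical transition matrix; state 0 is the start state of the graph.
trans = [[0.9, 0.1, 0.0],
         [0.0, 0.8, 0.2],
         [0.3, 0.0, 0.7]]
init_probs = average_state_distribution(trans, [1.0, 0.0, 0.0])
print(init_probs)
```

The averaged vector gives every state that is reachable within the iteration budget a non-zero initial probability, so a chunk cut from the middle of an utterance is allowed to begin in the middle of the graph.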
The numerator FST is much easier since it just contains the lattice for one utterance, broken into chunks of fixed length. It is the composition of `H`, `C`, and `L` (no `G` since the utterance is known). The only point worth mentioning here (and this will be important when we talk about boosted LF-MMI later) is that the numerator FST is composed with the normalization FST. This is done for two reasons.
- It ensures that the objective function value is always negative, which makes it easier to interpret.
- It also ensures that the numerator FST does not contain sequences that are not allowed by the denominator (or normalization) FST. This happens since the sum of the overall path weights for such sequences will be dominated by the normalization FST part.
Note that to compose the numerator lattice with `HCP`, we can just take the phones at the output of the numerator lattice and project everything on the input. Since the numerator lattice changes for every utterance, we have to perform this composition for each utterance. But since the numerator lattice is small and composition is analogous to intersection, almost all paths in `HCP` get removed and the final FST is very small.
What does the neural network do?
For the chain model, the neural network is just a box that does scoring for $p(o_t \mid s_t)$. Each HMM state has a pdf-id associated with it, and for each frame in the output, the neural network needs to output a score for each pdf-id. Therefore, for each chunk of input frames of width $w$, the nnet output is a matrix of dimensions $N \times w$, where $N$ is the total number of pdf-ids. This matrix itself can be visualized as an FST where the nodes are the frames, the arcs denote pdf-ids, and the weights on the arcs are the scores computed by the neural network. This FST is called a sausage lattice in Kaldi.
The sausage lattice is composed with `HCP` to get the total score. Here, the sausage lattice provides the acoustic score and the graph provides the graph score (read this doc for more details).
We train the neural network with lattice posteriors generated by forced alignment using a trained GMM-HMM model.
Forward-backward computations
Forward-backward computations are required in 2 situations:
- At training time, we need to compute the state occupancies for the numerator FST $\gamma_{rt}(s \mid W_r)$ and the denominator FST $\gamma_{rt}(s)$. These are nothing but the product $\alpha \beta$ normalized by the total probability of the observations, where $\alpha$ and $\beta$ for any state can be obtained using the forward and backward algorithms, respectively.
- At test time, during Viterbi decoding, while pruning the FST, we might want some kind of look-ahead to avoid pruning paths which can have higher probability later. The $\beta$ values can be used as a proxy look-ahead.
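For the first case, the textbook computation of occupancies from $\alpha$ and $\beta$ looks roughly like the sketch below. It is a plain-Python illustration on a made-up 2-state HMM in the probability domain, with no scaling and no FST structure; it is not the Kaldi implementation.

```python
def forward_backward(emit, trans, init):
    """Return (gamma, total), where gamma[t][s] is the posterior of state s
    at frame t and total is the probability of the observation sequence.

    emit[t][s]  : p(o_t | s)
    trans[i][j] : p(s_t = j | s_{t-1} = i)
    init[s]     : initial state probability
    """
    T, S = len(emit), len(init)
    alpha = [[0.0] * S for _ in range(T)]
    beta = [[0.0] * S for _ in range(T)]
    for s in range(S):
        alpha[0][s] = init[s] * emit[0][s]
        beta[T - 1][s] = 1.0
    # Forward recursion.
    for t in range(1, T):
        for j in range(S):
            alpha[t][j] = emit[t][j] * sum(alpha[t - 1][i] * trans[i][j] for i in range(S))
    # Backward recursion.
    for t in range(T - 2, -1, -1):
        for i in range(S):
            beta[t][i] = sum(trans[i][j] * emit[t + 1][j] * beta[t + 1][j] for j in range(S))
    total = sum(alpha[T - 1])
    gamma = [[alpha[t][s] * beta[t][s] / total for s in range(S)] for t in range(T)]
    return gamma, total

# Made-up model and emission likelihoods for 4 frames.
emit = [[0.2, 0.05], [0.1, 0.3], [0.02, 0.4], [0.3, 0.1]]
trans = [[0.7, 0.3], [0.4, 0.6]]
init = [0.6, 0.4]
gamma, total = forward_backward(emit, trans, init)
for row in gamma:
    print(row)  # each row sums to 1
```

Running this on the numerator graph and on the denominator graph and subtracting the two sets of occupancies gives the per-frame derivative used in the gradient derived earlier.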
Since the numerator FST is much smaller, its forward and backward computations are performed on CPU (the process is outlined in `chain-numerator.h`), while those for the denominator FST (outlined in `chain-denominator.h`) are performed on the GPU.
The basic forward and backward algorithms are the same as those well known in the literature, and pseudocode is also given in the extended comments in `chain-denominator.h`. However, this algorithm is susceptible to numeric overflow and underflow. To avoid this, we multiply the emission probability of the frame with a normalizing factor $\frac{1}{\mathrm{alpha}(t)}$ where $\mathrm{alpha}(t) = \sum_{i} \alpha_i (t)$. This is also called an “arbitrary scale” since in principle it can be allowed to be any value and doesn’t affect the posterior. However, we do need to add a quantity $\sum_{t=0}^{T-1} \log \mathrm{alpha}(t)$ to the final log probability obtained to make it equal to the actual log probability. This “arbitrary scaling” is used in both the forward and backward computations.
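The effect of the scaling can be seen in a small sketch. This follows the standard textbook variant, in which $\alpha$ is renormalised to sum to one after every frame and the logs of the normalisers are accumulated; the exact recursion in `chain-denominator.h` may differ in where the scale is applied. The model and emission values are made up.

```python
import math

def forward_plain(emit, trans, init):
    """Unscaled forward pass in the probability domain; returns log P(O)."""
    S = len(init)
    alpha = [init[i] * emit[0][i] for i in range(S)]
    for t in range(1, len(emit)):
        alpha = [sum(alpha[i] * trans[i][j] for i in range(S)) * emit[t][j]
                 for j in range(S)]
    return math.log(sum(alpha))

def forward_scaled(emit, trans, init):
    """Forward pass that rescales alpha at every frame and adds the logs of
    the scaling factors back at the end."""
    S = len(init)
    alpha = [init[i] * emit[0][i] for i in range(S)]
    log_prob = 0.0
    for t in range(len(emit)):
        if t > 0:
            alpha = [sum(alpha[i] * trans[i][j] for i in range(S)) * emit[t][j]
                     for j in range(S)]
        tot = sum(alpha)                  # the per-frame normaliser
        log_prob += math.log(tot)         # correction term: sum_t log tot(t)
        alpha = [a / tot for a in alpha]  # keep alpha in a safe numeric range
    return log_prob

trans = [[0.7, 0.3], [0.4, 0.6]]
init = [0.6, 0.4]
emit_short = [[0.2, 0.05], [0.1, 0.3], [0.02, 0.4], [0.3, 0.1], [0.05, 0.2]]

# On a short sequence the two versions agree.
plain = forward_plain(emit_short, trans, init)
scaled = forward_scaled(emit_short, trans, init)

# On 2000 frames the unscaled alphas underflow to zero in double precision,
# while the scaled version still returns a finite log-probability.
emit_long = emit_short * 400
long_logprob = forward_scaled(emit_long, trans, init)
```

Since the posteriors only depend on ratios of $\alpha\beta$ products within a frame, the rescaling leaves them unchanged, which is why the scale is called arbitrary.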
The actual objective function computation is implemented in `ComputeChainObjfAndDeriv()` defined in `chain-training.cc`. There are two Kaldi-specific things I must point out here.
- The forward-backward computation for the denominator FST on the GPU is not done in the log domain, since computing logs repeatedly makes things slower. However, this also means that the objective function values can occasionally become “bad”. To fix this, the `PenalizeOutOfRange()` function is used to encourage the neural network outputs to stay within the [-30,30] range.
- The denominator computation is performed before the numerator, so as to reduce the maximum memory usage. I am not sure why this helps, but it is important to remember this detail as we move to the implementation of boosted LF-MMI.
Implementing boosted LF-MMI
First, what is boosted LF-MMI? It is the same as LF-MMI, except that now we optimize the following objective function.
\[F_{bMMI}(\lambda) = \sum_{r=1}^R \log \frac{P_{\lambda}(O_r|M_{w_r})P(w_r)}{\sum_{\hat{w}}P_{\lambda}(O_r|M_{\hat{w}})P(\hat{w})e^{-bA(M_{w_r},M_{\hat{w}})}},\]where $b$ is the boosting factor and $A(M_{w_r},M_{\hat{w}})$ is the accuracy function, which measures the number of matching labels between the reference and hypothesis sequences (the parameters are written $\lambda$ here; they play the same role as $\theta$ above). My Kaldi implementation for LF-bMMI can be found in this branch. You may note that most of the changes are cosmetic and only serve to pass the new argument $b$ from the training script to the actual implementation, which is in the function `ComputeBoostedChainObjfAndDeriv()`.
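To see what the boosting term does, here is a toy evaluation of the objective above for one utterance. It assumes the hypotheses can be enumerated as equal-length label sequences and uses made-up joint log-probabilities; `accuracy` and `boosted_mmi` are hypothetical helpers and have nothing to do with the Kaldi code in the branch.

```python
import math

def accuracy(ref, hyp):
    """Number of positions at which the two label sequences agree."""
    return sum(r == h for r, h in zip(ref, hyp))

def boosted_mmi(joint_logprob, reference, b):
    """joint_logprob maps a label tuple w to log P(O | M_w) + log P(w)."""
    # Each denominator term is multiplied by exp(-b * A(reference, hypothesis)).
    terms = [lp - b * accuracy(reference, hyp) for hyp, lp in joint_logprob.items()]
    m = max(terms)
    log_den = m + math.log(sum(math.exp(v - m) for v in terms))
    return joint_logprob[reference] - log_den

# Made-up hypotheses for one utterance of three labels.
ref = ("a", "b", "c")
hyps = {
    ("a", "b", "c"): -5.0,   # the reference itself
    ("a", "b", "b"): -5.5,   # two labels correct
    ("c", "b", "a"): -6.0,   # one label correct
}

plain_mmi = boosted_mmi(hyps, ref, b=0.0)   # b = 0 recovers ordinary MMI
boosted = boosted_mmi(hyps, ref, b=0.5)
print(plain_mmi, boosted)
```

With $b > 0$, hypotheses that share many labels with the reference are down-weighted in the denominator, so the hypotheses with more errors carry relatively more weight among the competitors.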
In our implementation, the only change is in the computation of `num_logprob_weighted`, where we subtract the term `b * num_seq * frames_per_seq` from `numerator.forward()`. This might seem weird at first, since in the expression of the objective function the boosting term appears in the denominator, not the numerator. However, recall that the numerator FST is composed with the normalization FST, so this modification gives the same result as the objective function above.
On trying out LF-bMMI for mini-Librispeech, I found it to be slightly worse than regular LF-MMI (11.86 vs 11.74 WER), and consultation with Vimal Manohar revealed that he had tried LF-bMMI and LF-SMBR along with Hossein Hadian last year to similar results.