In the previous post on the MVDR beamformer, we saw how speaker-specific “masks” can be used in conjunction with a multi-channel input signal to extract the speaker-specific signal in the presence of background noise or interfering speakers. Even earlier, we saw how such masks can be estimated by modeling the mixture STFT using complex angular central GMMs. Although this approach is useful since it provides an unsupervised way of learning the masks, it has its limitations:

- Iterative mask estimation (using the EM algorithm) is slow and may not be feasible to use in real-time applications.
- The model makes the assumption that STFT bins are linear combinations of speaker models, and so it may not work for highly non-linear mixtures.

## Estimating masks using neural networks

Neural networks are the most popular methods for approximating non-linear functions. As such, we can naively think about replacing CACGMM-based mask estimation with a neural network. However, there are some points that need to be kept in mind:

- NNs are discriminative models and require paired training data, i.e., they must be trained with simulated mixtures.
- An NN-based mask estimator trained to predict 2 masks cannot predict 3 masks at test time, i.e., the number of speakers is fixed.
- Decoding long recordings can be intractable from a memory perspective, so decoding must be performed in chunks.
- Chunk-based decoding can cause a permutation problem between chunks of the same recording.

**Continuous speech separation** provides an elegant solution to all of the above concerns.
The core idea is that even though the overall recording may contain an arbitrary number of
speakers, there are usually at most a small number (say 2 or 3) speakers if we consider a
small segment of the recording. So mask estimation can be performed on these small chunks
to estimate a fixed number of masks. This solves the first 3 concerns, i.e., we have a fixed
number of outputs and decoding smaller segments is feasible on the constrained GPU memory.
However, we still need to solve the permutation problem across chunks.

For this, CSS proposes a “stitching” mechanism. Chunks are created in a strided window manner with hop less than the chunk size, so that adjacent chunks share a small portion. When computing masks for a new chunk, the masks are reordered so that the cross-entropy with the masks in the preceding chunk is minimized. This process lends some “continuity” to the chunks.

In practice, the stitching process can be implemented using an Overlap-Add with the mask estimation network as the transform function.
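The stitching step can be sketched in a few lines of NumPy. This is only an illustration under several assumptions that are not from the original CSS work: masks are arrays of shape `(speakers, frames, freq)` with values in `[0, 1]`, the "cross-entropy" is taken to be a mean binary cross-entropy on the overlapping frames, the best order is found by brute force over permutations, and overlapping frames are simply averaged. The helper names `align_to_previous` and `stitch` are hypothetical.

```python
import itertools
import numpy as np

def cross_entropy(target, pred, eps=1e-8):
    """Mean binary cross-entropy between two masks with values in [0, 1]."""
    pred = np.clip(pred, eps, 1.0 - eps)
    return float(-np.mean(target * np.log(pred) + (1.0 - target) * np.log(1.0 - pred)))

def align_to_previous(prev_masks, new_masks, overlap):
    """Reorder new_masks so its first `overlap` frames best match the tail of prev_masks.

    Both arrays have shape (speakers, frames, freq); overlap must be > 0.
    """
    prev_tail = prev_masks[:, -overlap:]
    best_perm, best_cost = None, np.inf
    for perm in itertools.permutations(range(new_masks.shape[0])):
        cost = cross_entropy(prev_tail, new_masks[list(perm), :overlap])
        if cost < best_cost:
            best_perm, best_cost = perm, cost
    return new_masks[list(best_perm)]

def stitch(chunk_masks, hop):
    """Overlap-add the aligned chunk masks, averaging frames shared by two chunks."""
    num_spk, chunk_len, num_freq = chunk_masks[0].shape
    overlap = chunk_len - hop
    total = hop * (len(chunk_masks) - 1) + chunk_len
    out = np.zeros((num_spk, total, num_freq))
    weight = np.zeros((1, total, 1))
    prev = None
    for i, masks in enumerate(chunk_masks):
        if prev is not None:
            masks = align_to_previous(prev, masks, overlap)
        start = i * hop
        out[:, start:start + chunk_len] += masks
        weight[:, start:start + chunk_len] += 1.0
        prev = masks
    return out / weight

# Toy example: two chunks of 6 frames with hop 4; the second chunk comes out
# of the (imaginary) network with its speakers swapped.
truth = np.stack([np.full((10, 3), 0.9), np.full((10, 3), 0.1)])
chunks = [truth[:, 0:6], truth[::-1, 4:10]]
stitched = stitch(chunks, hop=4)
```

In a real system the chunk masks would come from the mask estimation network, and the brute-force search is only reasonable because the number of outputs per chunk is small.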

## Generalizing with graph-PIT

While the above formulation of CSS has shown strong performance, it is still constrained
in the sense that it assumes each segment to have at most 2 speakers, which may not always
be the case. Recently, it was shown that this constraint can be softened: we just
need to assume that at most 2 speakers are active *at any instant of time*. This is possible
by generalizing the permutation invariant training (PIT) objective that is often used for
training the mask estimation networks.

To generalize PIT, we basically assign utterances to the 2 output channels so as to avoid having overlapping utterances in the same channel. This can be formulated as a graph coloring problem, hence the name graph-PIT. The idea is that if we model utterances as nodes in a graph such that overlapping utterances are connected with an edge, the problem of assigning utterances to channels is equivalent to the problem of coloring nodes with 2 colors such that adjacent nodes have different colors.

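Nevertheless, the search space is large: with $U$ utterances and $N$ channels, there are up to $N^U$ candidate assignments, and graph coloring is NP-hard in general, so a naive search scales exponentially with the number of utterances.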
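To make the coloring view concrete, here is a brute-force sketch in plain Python. It assumes utterances are given as hypothetical `(start, end)` intervals in seconds, builds the overlap graph, and enumerates all $N^U$ assignments while keeping only the valid colorings. It is meant to illustrate the problem, not to be an efficient solver.

```python
import itertools

def overlap_edges(utterances):
    """Return index pairs of utterances whose (start, end) intervals overlap."""
    edges = []
    for i, j in itertools.combinations(range(len(utterances)), 2):
        (s1, e1), (s2, e2) = utterances[i], utterances[j]
        if s1 < e2 and s2 < e1:
            edges.append((i, j))
    return edges

def valid_colorings(num_utts, edges, num_channels):
    """Yield every channel assignment in which overlapping utterances differ."""
    for coloring in itertools.product(range(num_channels), repeat=num_utts):
        if all(coloring[i] != coloring[j] for i, j in edges):
            yield coloring

# Three utterances; consecutive ones overlap, the first and last do not.
utterances = [(0.0, 4.0), (3.0, 7.0), (6.0, 9.0)]
edges = overlap_edges(utterances)
colorings = list(valid_colorings(len(utterances), edges, num_channels=2))
print(edges)      # [(0, 1), (1, 2)]
print(colorings)  # [(0, 1, 0), (1, 0, 1)]
```

Out of $2^3 = 8$ candidate assignments, only 2 are valid here: utterances 1 and 3 share a channel and utterance 2 takes the other.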

## Speeding up uPIT and graph-PIT

In a follow-up paper, the authors of graph-PIT proposed methods to mitigate the exponential complexity issue that graph-PIT faces. It is a short and elegant paper with some nice matrix math, so I will spend a bit of time on it in this post.

To understand how the idea works, let us first look at the task definition and 2 types of loss functions. We have a mixture containing $U$ utterances $\mathbf{S} = [\mathbf{s}_1,\ldots,\mathbf{s}_U] \in \mathbb{R}^{T\times U}$. Suppose we estimate the outputs on $C$ channels $\hat{\mathbf{S}} = [\hat{\mathbf{s}}_1,\ldots,\hat{\mathbf{s}}_C] \in \mathbb{R}^{T\times C}$. Assume $U=C$ for simplicity. We need to compute some loss $\mathcal{L}(\mathbf{S},\hat{\mathbf{S}})$.

The loss function $\mathcal{L}$ can be one of 2 types. In the first type, it can be decomposed into an aggregate over pairwise losses $\mathcal{L}(\mathbf{s},\hat{\mathbf{s}})$. Let us call these “aggregated” losses. An example is averaged SDR (a-SDR), which is an average over pairwise SDR:

\[\begin{aligned} \mathcal{L}^{\mathrm{a-SDR}}(\mathbf{S},\hat{\mathbf{S}}) &= -\frac{1}{C}\sum_{c=1}^C 10 \log_{10} \frac{\Vert \mathbf{s}_c\Vert^2}{\Vert \mathbf{s}_c - \hat{\mathbf{s}}_c\Vert^2} \\ &= \frac{1}{C}\sum_{c=1}^C \left( - 10 \log_{10} \frac{\Vert \mathbf{s}_c\Vert^2}{\Vert \mathbf{s}_c - \hat{\mathbf{s}}_c\Vert^2} \right) \\ &= \frac{1}{C}\sum_{c=1}^C \mathcal{L}^{\mathrm{SDR}} (\mathbf{s}_c,\hat{\mathbf{s}}_c). \end{aligned}\]

The second type of loss is defined over the whole group of utterances and cannot be decomposed into pairwise losses. Let us call these “group” losses. An example is source-aggregated SDR (SA-SDR):

\[\mathcal{L}^{\mathrm{sa-SDR}}(\mathbf{S},\hat{\mathbf{S}}) = - 10 \log_{10} \frac{\sum_{c=1}^C \Vert \mathbf{s}_c\Vert^2}{\sum_{c=1}^C \Vert \mathbf{s}_c - \hat{\mathbf{s}}_c\Vert^2}.\]

In utterance-level PIT (uPIT), we define a “meta”-loss function over an underlying loss (which can be an aggregated loss or a group loss) such that we minimize the loss over all permutations of the source utterances, i.e.,

\[\mathcal{J}^{\mathrm{uPIT}} (\hat{\mathbf{S}},\mathbf{S}) = \min_{\mathbf{P} \in \mathcal{P}_C} \mathcal{L}(\hat{\mathbf{S}}, \mathbf{SP}),\]

where $\mathcal{P}_C$ is the set of all $C \times C$ permutation matrices. There are 2 ways to solve the above problem, depending on whether the underlying $\mathcal{L}$ is an aggregated loss or a group loss.
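Before looking at the two cases, it may help to see the definitions as code. The following NumPy sketch implements a-SDR, SA-SDR, and a naive uPIT that tries all $C!$ permutations of the source columns. The function names and the random toy signals are my own assumptions for illustration; the naive search is the baseline that the rest of this section tries to avoid.

```python
import itertools
import numpy as np

def a_sdr_loss(S, S_hat):
    """Negative SDR in dB, averaged over channels. S and S_hat have shape (T, C)."""
    signal = np.sum(S ** 2, axis=0)
    error = np.sum((S - S_hat) ** 2, axis=0)
    return float(-np.mean(10.0 * np.log10(signal / error)))

def sa_sdr_loss(S, S_hat):
    """Negative source-aggregated SDR in dB: energies are summed before the ratio."""
    signal = np.sum(S ** 2)
    error = np.sum((S - S_hat) ** 2)
    return float(-10.0 * np.log10(signal / error))

def upit_brute_force(loss_fn, S, S_hat):
    """Return (loss, perm) minimizing loss_fn(S[:, perm], S_hat) over all permutations."""
    best_loss, best_perm = np.inf, None
    for perm in itertools.permutations(range(S.shape[1])):
        loss = loss_fn(S[:, list(perm)], S_hat)
        if loss < best_loss:
            best_loss, best_perm = loss, perm
    return best_loss, best_perm

# Toy data: the estimates are a shuffled, slightly noisy copy of the sources.
rng = np.random.default_rng(0)
S = rng.standard_normal((1000, 3))
S_hat = S[:, [2, 0, 1]] + 0.1 * rng.standard_normal((1000, 3))

loss_a, perm_a = upit_brute_force(a_sdr_loss, S, S_hat)
loss_sa, perm_sa = upit_brute_force(sa_sdr_loss, S, S_hat)
```

Both losses should recover the shuffle `(2, 0, 1)` on this toy example, at the cost of evaluating the full loss once per permutation.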

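**Case 1 ($\mathcal{L}$ is an aggregated loss):** Let $\mathbf{M} \in \mathbb{R}^{C\times C}$ be
the matrix of pairwise losses, with entries $m_{ij} = \mathcal{L}(\mathbf{s}_j,\hat{\mathbf{s}}_i)$. Then, for a mean aggregate such as a-SDR,

\[\mathcal{J}^{\mathrm{uPIT}} (\hat{\mathbf{S}},\mathbf{S}) = \min_{\mathbf{P} \in \mathcal{P}_C} \frac{1}{C} \mathrm{Tr}(\mathbf{MP}),\]

which is a linear sum assignment problem and can be solved in $\mathcal{O}(C^3)$ time using the Hungarian algorithm.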
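As a rough sketch of Case 1, the pairwise loss matrix can be handed to SciPy's `linear_sum_assignment`, which solves the same assignment problem that the Hungarian algorithm addresses. The helper `pairwise_neg_sdr` and the toy data are assumptions for illustration; the brute-force minimum is computed only as a sanity check.

```python
import itertools
import numpy as np
from scipy.optimize import linear_sum_assignment

def pairwise_neg_sdr(S, S_hat):
    """M[i, j] = negative SDR (dB) of estimate i scored against source j."""
    signal = np.sum(S ** 2, axis=0)                # (C,)
    diff = S_hat[:, :, None] - S[:, None, :]       # (T, C_est, C_src)
    error = np.sum(diff ** 2, axis=0)              # (C_est, C_src)
    return -10.0 * np.log10(signal[None, :] / error)

rng = np.random.default_rng(0)
C = 3
S = rng.standard_normal((1000, C))
S_hat = S[:, [2, 0, 1]] + 0.1 * rng.standard_normal((1000, C))

M = pairwise_neg_sdr(S, S_hat)

# Assignment solver: cols[i] is the source assigned to estimate i.
rows, cols = linear_sum_assignment(M)
hungarian_loss = float(M[rows, cols].mean())

# Brute force over all C! permutations, for comparison only.
brute_loss = min(
    float(M[np.arange(C), list(p)].mean())
    for p in itertools.permutations(range(C))
)
```

The pairwise matrix is computed once, and the solver then works on $C^2$ numbers instead of re-evaluating the loss for each of the $C!$ permutations.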

**Case 2 ($\mathcal{L}$ is a group loss):** We cannot directly use the Hungarian algorithm since
group losses cannot be decomposed into pairwise losses. However, we can still use it if it is
possible to write the loss as

\[\mathcal{J}^{\mathrm{uPIT}} (\hat{\mathbf{S}},\mathbf{S}) = f\left(\min_{\mathbf{P} \in \mathcal{P}_C} \mathrm{Tr}(\mathbf{MP}),\hat{\mathbf{S}},\mathbf{S}\right),\]

where $\mathbf{M} \in \mathbb{R}^{C\times C}$, and $f$ is strictly monotonically increasing in its first argument. There can be multiple ways of doing this decomposition. For the SA-SDR loss, one such way is:

\[f(x, \hat{\mathbf{S}},\mathbf{S}) = -10 \log_{10} \frac{\mathrm{Tr}(\mathbf{S}^T\mathbf{S})}{\mathrm{Tr}(\mathbf{S}^T\mathbf{S}) + \mathrm{Tr}(\hat{\mathbf{S}}^T\hat{\mathbf{S}}) + 2x},\]

with $\mathbf{M}$ defined as $\mathbf{M} = -\hat{\mathbf{S}}^T \mathbf{S}$.
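The decomposition follows from expanding the error energy: $\sum_c \Vert (\mathbf{SP})_c - \hat{\mathbf{s}}_c\Vert^2 = \mathrm{Tr}(\mathbf{S}^T\mathbf{S}) + \mathrm{Tr}(\hat{\mathbf{S}}^T\hat{\mathbf{S}}) - 2\,\mathrm{Tr}(\hat{\mathbf{S}}^T\mathbf{SP})$, where only the last term depends on $\mathbf{P}$. A quick numerical check of this, on random toy data of my own choosing, compares the directly minimized SA-SDR with $f$ applied to the minimal trace. Permutations are enumerated here for clarity; in practice the minimum of $\mathrm{Tr}(\mathbf{MP})$ would come from an assignment solver.

```python
import itertools
import numpy as np

def sa_sdr_loss(S, S_hat):
    """Negative source-aggregated SDR in dB for (T, C) arrays."""
    return float(-10.0 * np.log10(np.sum(S ** 2) / np.sum((S - S_hat) ** 2)))

def f(x, S_hat, S):
    """Map x = Tr(MP) back to the SA-SDR loss value."""
    src_energy = np.trace(S.T @ S)
    est_energy = np.trace(S_hat.T @ S_hat)
    return float(-10.0 * np.log10(src_energy / (src_energy + est_energy + 2.0 * x)))

rng = np.random.default_rng(0)
C = 3
S = rng.standard_normal((1000, C))
S_hat = S[:, [2, 0, 1]] + 0.1 * rng.standard_normal((1000, C))

# Column c of S @ P is source perm[c] when P = I[:, perm].
perm_matrices = [np.eye(C)[:, list(p)] for p in itertools.permutations(range(C))]

# Direct route: evaluate the full group loss for every permutation.
direct = min(sa_sdr_loss(S @ P, S_hat) for P in perm_matrices)

# Trace route: one matrix product, then only C x C traces per permutation.
M = -S_hat.T @ S
via_trace = f(min(np.trace(M @ P) for P in perm_matrices), S_hat, S)
```

Because $f$ is increasing in $x$, minimizing the trace and minimizing the loss pick the same permutation, so `direct` and `via_trace` should agree up to floating-point error.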

For graph-PIT, the loss function can be written as:

\[\begin{aligned} \mathcal{J}^{\mathrm{Graph-PIT}} (\hat{\mathbf{S}},\mathbf{S}) &= \min_{\mathbf{P}\in \mathcal{B}_{G,C}} \mathcal{L}(\hat{\mathbf{S}},\mathbf{SP}) \\ &= f\left(\min_{\mathbf{P}\in \mathcal{B}_{G,C}} \mathrm{Tr}(\mathbf{MP}), \hat{\mathbf{S}},\mathbf{S}\right), \end{aligned}\]

where $\mathcal{B}_{G,C}$ is the set of assignment matrices that correspond to valid colorings of the overlap graph $G$ with $C$ colors, so $\mathbf{P} \in \{0,1\}^{U\times C}$ is no longer a square matrix. If we use SA-SDR as the underlying loss, we can still use the definitions of $f$ and $\mathbf{M}$ as defined above, since they work on the whole group of sources and outputs.
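The same trick can be sketched for graph-PIT on a toy example. Everything below is an assumption made for illustration: three synthetic utterances placed at hypothetical sample spans, 2 output channels, and a plain enumeration of valid colorings rather than any of the faster search strategies from the paper. The point is only that $\mathbf{M} = -\hat{\mathbf{S}}^T\mathbf{S}$ is computed once, and each candidate $\mathbf{P}$ then costs a small trace.

```python
import itertools
import numpy as np

T, C = 900, 2
spans = [(0, 400), (300, 700), (600, 900)]   # consecutive utterances overlap
U = len(spans)

rng = np.random.default_rng(0)
S = np.zeros((T, U))                         # one zero-padded column per utterance
for u, (start, end) in enumerate(spans):
    S[start:end, u] = rng.standard_normal(end - start)

# Overlap graph: an edge joins two utterances that are active at the same time.
edges = [(i, j) for i, j in itertools.combinations(range(U), 2)
         if spans[i][0] < spans[j][1] and spans[j][0] < spans[i][1]]

def assignment_matrix(coloring, num_channels):
    """P[u, c] = 1 if utterance u is assigned to channel c, so P has shape (U, C)."""
    P = np.zeros((len(coloring), num_channels))
    P[np.arange(len(coloring)), list(coloring)] = 1.0
    return P

candidates = [assignment_matrix(col, C)
              for col in itertools.product(range(C), repeat=U)
              if all(col[i] != col[j] for i, j in edges)]

# Simulated network output: the true assignment plus a little noise.
P_true = assignment_matrix((1, 0, 1), C)
S_hat = S @ P_true + 0.1 * rng.standard_normal((T, C))

def f(x, S_hat, S):
    src = np.trace(S.T @ S)
    return float(-10.0 * np.log10(src / (src + np.trace(S_hat.T @ S_hat) + 2.0 * x)))

M = -S_hat.T @ S                             # shape (C, U), shared by all candidates
traces = [np.trace(M @ P) for P in candidates]
P_best = candidates[int(np.argmin(traces))]
loss_via_trace = f(min(traces), S_hat, S)

# Reference: SA-SDR evaluated directly on the chosen targets S @ P_best.
target = S @ P_best
loss_direct = float(-10.0 * np.log10(np.sum(target ** 2) / np.sum((target - S_hat) ** 2)))
```

The two loss values agree here because utterances that share a channel never overlap, so the energy of $\mathbf{SP}$ equals $\mathrm{Tr}(\mathbf{S}^T\mathbf{S})$ for every valid coloring.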