Stochastic variational inference for low-rank stochastic block models, or how i re-discovered SBM unnecessarily

Prologue

a few weeks ago, i listened to Sebastian Seung’s mini-lecture at Flatiron Institute (CCM) about the recently completed fruit fly brain connectome. near the end of the mini-lecture, sebastian talked about the necessity of graph node clustering based on the type-level connectivity patterns instead of node-level connectivity patterns. i thought that would be obviously easy to solve with latent variable modeling and ChatGPT. i was so wrong, because ChatGPT misled me into every possible wrong corner of the solution space over the next two weeks or so.

eventually, i implemented a simple variational inference approach to latent variable clustering, which took only 3-4 hours in total (ugh ChatGPT!). after a couple of weeks, re-reading the whole post, i realized that i had re-discovered the stochastic block model 🤦.

anyhow, out of spite, i’ve now canceled the ChatGPT subscription in favour of Google Gemini (the 2.0 family looks very nice), and turned the original markdown at https://hackmd.io/r55wOtjtTbCdsmE5M_cc9w into the wordpress markdown using gemini-exp-1206. it works beautifully, but of course, i must warn you that there may be some wrong conversions here and there, especially in math expressions. i mean … gemini capitalized every sentence for me below …

my apologies in advance.

Motivation: inter-node connectivity based clustering

Many conventional approaches to unsupervised vertex (node) clustering in a graph can be summarized as grouping nodes based on the overlaps in their neighbours. As a typical and representative example, we can use \(k\)-means clustering for vertex clustering of a graph \(\mathcal{G}=(V, E)\) by representing each node using the following feature vector \(v_i \in \mathbb{R}^{2|V|}\):

\[
v_i^j = \begin{cases}
1, & \text{if } (i,j) \in E \\
0, & \text{otherwise}
\end{cases}
\]

for \(j=1, \ldots, |V|\), and

\[
v_i^{|V|+j} = \begin{cases}
1, & \text{if } (j,i) \in E \\
0, & \text{otherwise}
\end{cases}
\]

By running \(k\)-means clustering on this feature representation, two nodes that share a similar set of incoming and outgoing \(1\)-hop neighbours will likely belong to the same cluster. It is common practice to perform dimensionality reduction before running \(k\)-means clustering or one of its variants.
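
As a rough illustration, here is a minimal sketch of this conventional pipeline in Python, assuming the graph is given as a list of directed edges over nodes indexed \(0, \ldots, |V|-1\); the truncated SVD stands in for the dimensionality-reduction step, and all function and parameter names are placeholders.

```python
import numpy as np
from scipy.sparse import csr_matrix, hstack
from sklearn.decomposition import TruncatedSVD
from sklearn.cluster import KMeans

def neighbourhood_kmeans(edges, n_nodes, n_clusters=16, n_components=32, seed=0):
    # sparse adjacency matrix A, where A[i, j] = 1 iff (i, j) is a directed edge
    rows, cols = zip(*edges)
    A = csr_matrix((np.ones(len(edges)), (rows, cols)), shape=(n_nodes, n_nodes))
    # each node is represented by its outgoing and incoming 1-hop neighbourhoods
    features = hstack([A, A.T]).tocsr()     # shape: (n_nodes, 2 * n_nodes)
    # dimensionality reduction followed by k-means, as described above
    reduced = TruncatedSVD(n_components=n_components, random_state=seed).fit_transform(features)
    return KMeans(n_clusters=n_clusters, random_state=seed).fit_predict(reduced)
```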

Such an approach is desirable especially when the goal is to coarsen a given graph by merging a set of vertices belonging to each cluster into a single vertex in a newly coarsened (or compressed) graph. In other words, this approach looks for nodes that are interchangeable.

We are, however, often interested in clustering nodes according to inter-cluster connectivity patterns rather than inter-node connectivity patterns. For instance, we can have two clusters that are bipartite (no intra-cluster edges) and one-to-one (each node in one cluster is connected to exactly one node in the other cluster). Because no two vertices in one cluster share any neighbours in the other cluster, the conventional approach would not group them into a single cluster.

In this note, we are interested in a simple approach to vertex clustering that almost entirely focuses on the inter-cluster connectivity rather than the inter-node connectivity. The proposed approach can be thought of as a special case of a stochastic block model (SBM) for a directed graph, such that it is readily applicable to a large graph. In the next section, I will describe this algorithm and its relationship to SBM.

Latent variable based vertex clustering

The proposed approach can be described at a high level as a latent-variable generative model of directed edges. By sampling the existence \(e_{i,j} \in \left\{0, 1\right\}\) of each directed edge, one at a time for every vertex pair in \(V \times V\), we end up with a graph instance \(\mathcal{G}\). The probability of the directed edge \(i\to j\) for a node pair \((i,j)\) depends not on the nodes themselves but on their latent cluster assignments, \(z_i \in \left\{1, \ldots, K\right\}\) and \(z_j \in \left\{1, \ldots, K\right\}\), respectively.

In other words, for each node pair \((i,j)\), we first determine their latent clusters by drawing \(z_i\) and \(z_j\) from the prior distributions (which could simply be uniform over \(\left\{1, \ldots, K\right\}\)). We then stochastically decide whether there is an edge \(i\to j\) by drawing a sample from \(p(e_{i,j} | z_i, z_j)\), and likewise whether there is an edge \(j \to i\) by drawing a sample from \(p(e_{j,i} | z_j, z_i)\). This edge probability is parameterized with a small number of parameters.

The goal is then to simultaneously infer the posterior distribution over \(z_1, \ldots, z_{|V|}\) and estimate the edge-probability parameters given an observed graph \(\mathcal{G}\). In order to avoid the (potential) intractability of doing so, we can rely on stochastic gradient based variational inference to estimate the factorized approximate posterior and assign each node to the most likely cluster according to its own approximate posterior.

In this approach, the chance of having an edge from one node to another is entirely up to their latent clusters. You can almost think of this as inferring the type of each node and considering the connectivity at this type level.

For instance, consider natural language text. Instead of considering how surface-level tokens are connected (that is, placed next to each other), we can infer each token's part-of-speech tag (e.g., verb, noun, etc.) and consider connectivity at that level. This would make it easier to realize that, for instance, in English verbs often follow nouns, even when particular noun-verb (subject-verb) pairs are observed extremely sparsely.

In the next section, we consider and describe one particular implementation of this approach in detail.

Details

We will detail the algorithm and learning procedure here. We start by letting \(\alpha_i \in \Lambda^{K-1}\) be a distribution over \(K\) clusters for the \(i\)-th node in a graph \(\mathcal{G}=(V, E)\), where \(\Lambda^{K-1}\) is the \((K-1)\)-dimensional simplex such that

\[
\begin{aligned}
&\alpha_i^k \geq 0,\text{ for all } k, \\
&\sum_{k=1}^K \alpha_i^k = 1.
\end{aligned}
\]

In other words,

\[
q(z_i = k) = \alpha_i^k,
\]

where \(z_i \in \left\{1, \ldots, K\right\}\) is a cluster assignment variable. This distribution \(q(z_i)\) will be the approximate posterior for each node \(i\).

Given a pair of \(z_i\) and \(z_j\), we can compute the directed edge probability as

\[
p(e_{i,j} =1 | z_i, z_j) =
\sigma(
u_s^{z_i} \cdot u_t^{z_j} + b
),
\]

where

  • \(e_{i,j} \in \left\{0, 1\right\}\) indicates the existence of the directed edge from \(i\) to \(j\).
  • \(u_s^k \in \mathbb{R}^d\) is the source feature vector of the \(k\)-th cluster.
  • \(u_t^k \in \mathbb{R}^d\) is the target feature vector of the \(k\)-th cluster.
  • \(b \in \mathbb{R}\) is a bias capturing the overall sparsity of the graph \(\mathcal{G}\).
  • \(\sigma(a) = \frac{1}{1+\exp(-a)}\) is a sigmoid function.
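
To make this parameterization and the generative process concrete, here is a minimal PyTorch sketch; all tensor names and sizes are illustrative assumptions, and the cluster assignments are drawn from a uniform prior.

```python
import torch

K, d, n_nodes = 8, 4, 100     # number of clusters, feature dimension, number of nodes
u_s = torch.randn(K, d)       # source feature vectors, one per cluster
u_t = torch.randn(K, d)       # target feature vectors, one per cluster
b = torch.tensor(-3.0)        # bias capturing the overall sparsity

# block connectivity: B[k, k'] = p(e_{i,j} = 1 | z_i = k, z_j = k')
B = torch.sigmoid(u_s @ u_t.T + b)          # (K, K)

# generative process: draw cluster assignments, then edges
z = torch.randint(K, (n_nodes,))            # uniform prior over {1, ..., K}
edge_prob = B[z][:, z]                      # (n_nodes, n_nodes) pairwise edge probabilities
E = torch.bernoulli(edge_prob)              # a sampled binary adjacency matrix
```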

We can then write down the objective function (the variational lower bound on the log-likelihood) as

\[
\begin{aligned}
\mathcal{J}(u_s, u_t, b, \alpha)
=&
\sum_{i=1}^{|V|}
\sum_{j=1}^{|V|}
\mathbb{E}_{q(z_i)q(z_j)}
\left[
e_{i,j} \log p(e_{i,j}=1|z_i, z_j)
+ (1-e_{i,j}) \log (1- p(e_{i,j}=1|z_i, z_j))
\right] \\
&+\sum_{i=1}^{|V|} \mathcal{H}(q(z_i))
\\
=&
\sum_{i=1}^{|V|}
\sum_{j=1}^{|V|}
\sum_{k=1}^K
\sum_{k'=1}^K
\alpha_i^k
\alpha_j^{k'}
\left(
e_{i,j} \log p(e_{i,j}=1|k, k')
+ (1-e_{i,j}) \log (1- p(e_{i,j}=1|k, k'))
\right) \\
&-
\sum_{i=1}^{|V|} \sum_{k=1}^K \alpha_i^k \log \alpha_i^k,
\end{aligned}
\]

where \(\mathcal{H}\) is the entropy functional, and \(e_{i,j}\) is set to the observed existence of the directed edge \(i\to j\).

We realize that \(p(e_{i,j} = 1 | k, k')\) depends on neither \(i\) nor \(j\); it only depends on \(k\) and \(k'\). Let us use \(s_{k,k'}\) to denote this quantity. Then, we can simplify the first term into

\[
\begin{aligned}
&\sum_{k=1}^K \sum_{k'=1}^K
\log s_{k, k'}
\sum_{i=1}^{|V|} \sum_{j=1}^{|V|}
\alpha_i^k \alpha_j^{k'}
e_{i,j}
+
\sum_{k=1}^K \sum_{k'=1}^K
\log (1-s_{k, k'})
\sum_{i=1}^{|V|} \sum_{j=1}^{|V|}
\alpha_i^k \alpha_j^{k'}
(1-e_{i,j}) \\
&=
\sum_{k=1}^K \sum_{k'=1}^K
\left(
\log s_{k, k'}
\sum_{i=1}^{|V|} \alpha_i^k
\sum_{j=1}^{|V|} \alpha_j^{k'} e_{i,j}
+
\log (1-s_{k, k'})
\sum_{i=1}^{|V|}
\alpha_i^k
\sum_{j=1}^{|V|}
\alpha_j^{k'}
(1-e_{i,j})
\right),
\end{aligned}
\]

which enables us to create an efficient implementation based on matrix-vector operations.
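
As an illustration of this matrix-vector form, here is a minimal PyTorch sketch of the full objective (reconstruction term plus entropy), assuming E is the \(|V|\times|V|\) binary adjacency matrix (dense or sparse) and alpha is the \(|V|\times K\) matrix of approximate posteriors; the function and variable names are placeholders.

```python
import torch

def elbo(E, alpha, u_s, u_t, b, eps=1e-8):
    # s[k, k'] = sigma(u_s^k . u_t^{k'} + b): the low-rank block connectivity matrix
    s = torch.sigmoid(u_s @ u_t.T + b)         # (K, K)
    # M[k, k'] = sum_{i,j} alpha_i^k alpha_j^{k'} e_{i,j}: expected edge count per cluster pair
    M = alpha.T @ (E @ alpha)                  # (K, K); E @ alpha also works if E is sparse
    # expected non-edge counts: sum_{i,j} alpha_i^k alpha_j^{k'} (1 - e_{i,j})
    tot = alpha.sum(dim=0)                     # (K,)
    N = torch.outer(tot, tot) - M              # (K, K)
    # reconstruction term plus the entropy of the factorized approximate posterior
    recon = (M * torch.log(s + eps)).sum() + (N * torch.log(1.0 - s + eps)).sum()
    entropy = -(alpha * torch.log(alpha + eps)).sum()
    return recon + entropy
```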

For a large graph \(\mathcal{G}\) with many nodes (\(|V| \gg 0\)), we can use minibatch stochastic gradient descent to maximize this objective w.r.t. \(\alpha\), \(u_s\), \(u_t\) and \(b\), all together. In this case, we sample a minibatch by sampling a random subgraph \(\tilde{\mathcal{G}} \subset \mathcal{G}\). Because the approximate posterior is factorized over the nodes, minibatching works without any sophisticated implementation techniques.

Instead of \(\alpha_i\), which requires us to take into account the constraints (non-negativity and sum-to-1), we can reparameterize it as

\[
\alpha_i^k = \frac{\exp(\beta_i^k)}{\sum_{k'=1}^K \exp(\beta_i^{k'})},
\]

and perform optimization w.r.t. the unconstrained \(\beta_i\). This last trick dramatically simplifies the implementation and allows us to rely almost entirely on the standard components of any deep learning library, such as PyTorch.

Once optimization is over, we can either use the estimated approximate posterior \(\alpha_i\) to determine the cluster assignment of each vertex \(i\), e.g. by \(\hat{z}_i = \arg\max_k \alpha_i^k\), or perform exact posterior inference, though the latter could be extremely costly.
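
Putting the pieces together, here is a minimal full-batch sketch of this optimization loop, reusing the elbo function sketched above; a minibatch variant would instead sample a random subgraph and the corresponding rows of beta at each step. All names, sizes and hyperparameters are illustrative assumptions.

```python
import torch

def fit(E, K=16, d=8, n_steps=2000, lr=1e-2, seed=0):
    torch.manual_seed(seed)
    n_nodes = E.shape[0]
    # unconstrained parameters: beta is mapped to alpha through a softmax
    beta = torch.zeros(n_nodes, K, requires_grad=True)
    u_s = torch.randn(K, d, requires_grad=True)
    u_t = torch.randn(K, d, requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    opt = torch.optim.Adam([beta, u_s, u_t, b], lr=lr)
    for _ in range(n_steps):
        opt.zero_grad()
        alpha = torch.softmax(beta, dim=-1)     # approximate posterior q(z_i) for every node
        loss = -elbo(E, alpha, u_s, u_t, b)     # maximize the variational lower bound
        loss.backward()
        opt.step()
    # assign each node to its most likely cluster under the approximate posterior
    return torch.softmax(beta, dim=-1).argmax(dim=-1)
```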

Relationship to a stochastic block model (SBM)

The proposed approach can be thought of as a stochastic block model (SBM) for a directed graph with the following modifications:

  1. Low-rank approximation to the block connectivity matrix \(B \in \left[0,1\right]^{K \times K}\): the block connectivity matrix \(B_{k, k'}\) in an SBM is approximated by \(\sigma(u_s^{k} \cdot u_t^{k'} + b)\).
  2. Stochastic variational inference with Adam optimizer: by reparameterizing both \(B\) and the approximate posterior to admit gradient-based optimization, we use the latest stochastic optimization technique from deep learning to seamlessly scale inference to a very large graph.
  3. Potential for future extensions: because we parametrize the block-block connectivity, there are many opportunities to extend the proposed approach in the future. For instance, we can contextualize the whole process by making \(u_s\) and \(u_t\) vectors conditioned on the context.

Simulation: clustering neurons

Task and Data: FlyWire

As a demonstration, we use the fruit fly connectome data from FlyWire. In particular, we use the connection data. From this connection data, we create a connectivity matrix of size 134,181×134,181 with 2,700,513 non-zero entries. If there are multiple synapses between two neurons, we simply collapse them into a single entry of \(1\). In other words, we end up with a sparse, binary matrix of size 134,181×134,181.
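
As a rough sketch of this preprocessing step, the following assumes the FlyWire connection table has been loaded into a pandas dataframe with hypothetical integer columns pre_id and post_id indexing the pre- and post-synaptic neurons; the column names are assumptions, not the actual FlyWire schema.

```python
import numpy as np
from scipy.sparse import csr_matrix

def build_binary_connectivity(connections, n_neurons):
    # one entry per (pre, post) synapse; duplicates are summed when building the CSR matrix
    rows = connections["pre_id"].to_numpy()
    cols = connections["post_id"].to_numpy()
    W = csr_matrix((np.ones(len(connections)), (rows, cols)), shape=(n_neurons, n_neurons))
    # suppress multiple synapses between the same pair of neurons to a single 1
    W.data[:] = 1.0
    return W
```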

Here, we take the visual neurons as the neurons of interest. After clustering all 134,181 neurons, we compare the cluster assignments of the subset corresponding to the visual neurons against the annotated visual neuron types from FlyWire. There are 729 types in total, covering 46,479 neurons.

Evaluation metric

Because we have the ground-truth clustering assignments, we use the Hungarian algorithm to maximize the following objective function and use the achieved objective as the quality metric of clustering:

\[
\max_{X \in \left\{0, 1 \right\}^{K \times K'}}
\sum_{k=1}^K \sum_{k'=1}^{K'}
C_{k,k'} X_{k,k'},
\]

where \(C_{k,k'}\) is the number of instances assigned to the \(k\)-th cluster by one of the two clustering algorithms in comparison and to the \(k'\)-th cluster by the other, and \(X\) is constrained to be a one-to-one matching between clusters (each row and each column of \(X\) contains at most one \(1\)). When one of the algorithms is the ground-truth one, a higher objective implies a better alignment between the ground-truth and predicted clustering assignments.

In addition to the score we get, we also get the actual alignment between the clusters from \(X\). This allows us to visually inspect the relationship between two clustering assignments.
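
Here is a minimal sketch of this evaluation, using scipy's linear_sum_assignment as the Hungarian solver on the contingency matrix \(C\) between two label vectors; the function name is a placeholder.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics.cluster import contingency_matrix

def matching_score(labels_a, labels_b):
    # C[k, k'] counts instances assigned to cluster k by one algorithm and k' by the other
    C = contingency_matrix(labels_a, labels_b)
    # Hungarian algorithm: maximize the total count over one-to-one cluster matchings
    row_ind, col_ind = linear_sum_assignment(C, maximize=True)
    score = C[row_ind, col_ind].sum()
    return score, list(zip(row_ind, col_ind))   # the score and the cluster alignment X
```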

Comparison: PCA+\(k\)-means clustering

In order to highlight how node clustering based on inter-cluster connectivity differs from a conventional approach, we implement one simple iterative algorithm for vertex clustering as a point of comparison. We first train a linear autoencoder using stochastic gradient descent to minimize the following loss function:

\[
\min_{U \in \mathbb{R}^{N \times d}}
\left\|
(W U) U^\top + b - W
\right\|^2_F,
\]

where \(W \in \left\{0,1\right\}^{N \times N}\), with \(N = |V|\), is the observed sparse, binary connectivity matrix, and \(b\) is a scalar bias. We optionally orthogonalize \(U\), although this may not be necessary. We then run \(k\)-means clustering on the \(N\) rows of the estimated \(U\).
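
Here is a minimal full-batch sketch of this baseline; the actual run uses minibatch stochastic gradient descent over rows of \(W\), and all names and hyperparameters below are placeholders.

```python
import torch
from sklearn.cluster import KMeans

def pca_kmeans(W, d=64, n_clusters=1024, n_steps=2000, lr=1e-2):
    # W: (N, N) binary connectivity matrix as a dense torch tensor (sketch only;
    # a large graph would require minibatching over rows of W instead)
    N = W.shape[0]
    U = torch.randn(N, d, requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    opt = torch.optim.Adam([U, b], lr=lr)
    for _ in range(n_steps):
        opt.zero_grad()
        recon = (W @ U) @ U.T + b                 # linear autoencoder reconstruction
        loss = ((recon - W) ** 2).sum()           # squared (Frobenius) reconstruction error
        loss.backward()
        opt.step()
    # k-means on the N rows of the estimated U
    return KMeans(n_clusters=n_clusters).fit_predict(U.detach().numpy())
```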

Random assignment baseline

In order to establish that the clustering does something meaningful, we use two random cluster assignment strategies as minimal baselines. If the scores of these baselines were comparable to those of either the proposed latent-variable approach or the PCA+\(k\)-means clustering approach, we would have to conclude that a comparison between the latter two is not meaningful.

The first random assignment is optimistic, in that the cluster marginal distribution follows that of the ground-truth assignment. We create this assignment by randomly permuting the ground-truth cluster assignment. The second one is pessimistic, in that we assign each vertex uniformly to one of the 729 clusters at random. This is still not the most pessimistic version in that we assume we know the true number of clusters.
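
For concreteness, here is a short sketch of the two random baselines, assuming gt_labels is an array holding the ground-truth type of each visual neuron (a placeholder name):

```python
import numpy as np

def random_baselines(gt_labels, n_types=729, seed=0):
    rng = np.random.default_rng(seed)
    # optimistic: preserve the ground-truth marginal over types by permuting the labels
    optimistic = rng.permutation(gt_labels)
    # pessimistic: assign each neuron uniformly at random to one of the types
    pessimistic = rng.integers(n_types, size=len(gt_labels))
    return optimistic, pessimistic
```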

Implementation

Some very naive and potentially buggy implementation is available at

https://github.com/kyunghyuncho/flywrite-cell-types

The code relies heavily on PyTorch, scikit-learn and SciPy. The code is released under the MIT License.

Result

We first use 1,024 clusters for both algorithms: PCA+\(k\)-means clustering and latent-variable-based clustering. In the case of the former, we update \(U\) and \(b\) 10,000 times using stochastic gradient descent with minibatches of size 512. In the case of the proposed approach, we run stochastic gradient descent until the clustering assignment converges. For both, we use Adam as the optimizer.

|  | PCA+\(k\) | LV | Rand(Opt) | Rand(Pess) |
| --- | --- | --- | --- | --- |
| GT | 1550 | 3019 | 1235 | 985 |
| PCA+\(k\) |  | 1433 | 1029 | 1238 |
| LV |  |  | 1357 | 1162 |
| Rand(Opt) |  |  |  | 985 |

From the result table above, we make a few important observations. First, we see that random assignment generally gets a score of approximately 1,000. The optimistic assignment receives 1235 against the ground-truth assignment, while the pessimistic one receives 985. Both PCA+\(k\) and LV receive significantly higher scores than either of these random assignments, suggesting that the proposed quality metric is sensible.

Second, the agreement between PCA+\(k\) and LV is significantly higher than 1000. This implies that there is some shared structure behind the graph that is captured by both PCA+\(k\) and LV. Despite the distinct goals behind these algorithms, inter-node and inter-cluster connectivity patterns are not independent of each other. We leave for the future how to design an algorithm that can exploit both in a flexible and controllable manner.

Finally, and perhaps most importantly, the proposed approach (LV) receives a substantially higher score of 3019 than PCA+\(k\), which received 1550. 1550 is significantly higher than any score obtained by random assignment, but the gap is even greater when compared against LV. This suggests that the types of visual neurons correlate more strongly with inter-cluster connectivity than with inter-node connectivity, and that the proposed approach can capture these inter-cluster connectivity patterns.

Limitations and future directions

A major limitation of the proposed approach is that it only takes into account the first-order neighbours in the cluster space. As it has become quite obvious from the advances and successes of language models in recent years, there are a lot of signals that can be squeezed out by considering an increasingly larger context. Instead of considering just an immediate neighbouring vertex’s cluster assignment, it will be important in the future to consider a larger neighbourhood.

In order to take into account multi-hop neighbours, a graph neural net could be an interesting alternative. It will however make learning and inference much more challenging, as the first term in the objective function above will no longer be exactly computable and will require a sampling-based approximation. It remains to be seen whether such a sampling-based approach would have low enough variance to be useful.

Another limitation concerns not the method but the evaluation. We chose visual neuron type classification because these visual neurons are classified according to which other types of neurons they are connected to rather than which other individual neurons they are connected to. Even then, it is unclear whether the proposed approach is well suited to this problem. It will be necessary to test the proposed approach on a more diverse set of benchmark problems in the future.

Finally, there is yet another limitation on the evaluation protocol we used. For evaluation, I used all the neurons for clustering and then selected the subset of visual neurons to check whether their assigned clusters make sense. This may not be an ideal choice, since some of the clusters may be dedicated to non-visual neurons only, although a large portion of fly neurons are visual neurons. Clustering could potentially improve dramatically if we considered only the visual neurons and their immediate neighbours. We leave this for the future.
