On the causality view of “Context-Aware Learning for Neural Machine Translation”

[Notice: what unfortunate timing! This post is definitely NOT an April Fools’ joke.]

Sebastien Jean and I had a paper titled <Context-Aware Learning for Neural Machine Translation> rejected from NAACL’19, which is perhaps understandable because we did not report any substantial gain in BLEU score. As I finally found some time to read Pearl’s <Book of Why> due to a personal reason (yes, personal reasons sometimes can help), I thought I would write a short note on how the idea in this paper was originally motivated. As I was never educated in causal inference or causal learning, I have been scared of using the term “causal” in any of my papers so far, and this paper was no exception. I feel this intentional avoidance of the term may have made the paper more obscure, and perhaps it’s not a bad idea to use a blog post (and time from that personal reason) to write out my original motivation without worrying about academic scrutiny.

 
Let me focus on building a translation model that takes as input both the current source sentence X and the previous source sentence C, and outputs the translation Y of the current source sentence X, although there is no reason to restrict C to being only the single immediately preceding sentence. Let’s introduce a variable Z that represents all that we do not observe directly, such as the world state, the author’s intention and the actual meaning behind the text. You can think of Z as also including both benign and detrimental common sense, such as “bananas are always yellow” (when I was in Rwanda just a few weeks ago, I learned this to be false; see the picture on the left I took in Kimironko Market, Kigali), “presidents are often male; e.g., Monsieur President vs. Madame President”, …
 
If I were to draw a causal diagram, following Pearl, one version would look like the one below, where I used a dashed circle to explicitly indicate that Z is not observed:
The document, a part of which is represented by C and X, is created from (caused by) Z. The current sentence X is also caused by Z but not necessarily by its preceding sentence C. This is one assumption that I am not comfortable with, but it can be understood generously if we consider that in many cases we can more easily reorder sentences in a paragraph than reorder words in a sentence. Once we know both X and C, the translation Y of the source sentence X is determined (caused) by the source sentence X and the previous sentence C. Why is there an arrow from C to Y that bypasses X? This is due to the difference between the source and target languages; consider, for example, translating from a language without gendered pronouns into one with gendered pronouns.
 
Based on this diagram, what we want to know next in this “context-aware neural machine translation” is the effect of the previous sentence C on the translation Y of the source sentence X.
 
Now, a fair warning before we proceed: because I have only given <Book of Why> a quick read, I may be completely off here.
 
Let’s consider two paths from C to Y in the diagram above: C->Y and C<-Z->X->Y. The first path corresponds to the direct effect of C on Y, and the second could be thought of as a path with X as a mediator. The effect of C on Y will be some function of these two, and, if all the relationships are linear, the total effect of C on Y will be the sum of the effects along these two paths.
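To make the linear case concrete, here is a toy example of my own (the standardized variables and linear structural equations are my assumption, not something from the paper): let

$$C = a Z + \epsilon_C, \qquad X = b Z + \epsilon_X, \qquad Y = c\,C + d\,X + \epsilon_Y.$$

Then the direct path C->Y contributes c, the path C<-Z->X->Y contributes a·b·d by Wright’s path-tracing rules, and the total effect in the above sense is c + abd.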
 
Now, obviously, it would be best if we could somehow estimate the coefficient (a set of neural net parameters) associated with each arrow in the diagram above. Then we could compute the total effect exactly, and that would be the end of the story for causal inference. Unfortunately, other than the coefficients of the two arrows C->Y and X->Y, which can be estimated from data by fitting a neural machine translation system, it looks quite unrealistic for us to estimate the parameters of Z->C and Z->X.
 
This is where we move away from causal inference and toward machine learning (in particular, machine translation). Instead of trying to estimate those coefficients and infer the causal effect of C on Y, our goal is now to train a neural machine translation system to maximally exploit the effect of C on Y. That is, we train a context-aware neural machine translation system such that the context C maximally influences (causes) Y, in addition to the source sentence X, according to the causal diagram above.
Under this goal, the second path C<-Z->X->Y (the path colored blue above) is of interest to us, as this path contains two arrows whose coefficients we do not know how to estimate. We notice that this path goes through the confounder Z, which we neither observe nor control for (though it could be an interesting future exercise to control for Z by finely partitioning a corpus). One classical technique in this case is to run a randomized trial on C, which effectively cuts the arrow from Z to C.

This cut indicates that the choice of C no longer depends on Z. In the case of training a context-aware neural machine translation system, this can be thought of as replacing the previous sentence with a randomly drawn sentence from a large corpus (though it is not at all clear what the distribution should be, and we discuss a few alternatives in Sec. 4.3 of the paper). Then, by contrasting the effect of C and X on Y with that of a randomly drawn C and X on Y, we can measure the effect of C on Y. This can be expressed in an equation:

$$\delta(Y|X,C) = s(Y|X, C) - s(Y|X, r(C))$$

Here, r(C) is a randomized context, and we use the conditional log-probability of Y given X and C (or r(C)) as the causal-effect score s(Y|X,C). This formulation naturally lends itself to a new regularization term that encourages the context-aware neural machine translation system to maximize the effect of the context C on Y. We use a margin loss together with this causal effect at three different levels (minibatch, sentence and token). Here, let me write out the sentence-level regularization term:
$$\mathcal{R}(\theta; \mathcal{D}) = \alpha_s \sum_{n=1}^N \left[ T_n \delta_s - s^{\text{sent}}(Y_n|X_n, C_n) + s^{\text{sent}}(Y_n|X_n) \right]_+$$
 
Minimizing this term literally maximizes the causal effect of C on Y until it is at least as large as some predefined threshold (δ_s) multiplied by the length T_n of each sentence.
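To make this concrete, here is a minimal sketch of the sentence-level regularizer in plain Python. The names (log_prob, examples, corpus) are hypothetical, and I approximate s^sent(Y_n|X_n) with a single randomly drawn context r(C_n), following the randomization above; in practice the scores would be differentiable model outputs rather than plain floats.

```python
import random

def context_regularizer(log_prob, examples, corpus, delta_s=1.0, alpha_s=1.0):
    """Sentence-level term from the equation above:
    R = alpha_s * sum_n [ T_n * delta_s
                          - s_sent(Y_n | X_n, C_n)
                          + s_sent(Y_n | X_n, r(C_n)) ]_+
    log_prob(y, x, c) is assumed to return the sentence-level conditional
    log-probability of y given x and context c under the model."""
    reg = 0.0
    for x, y, c in examples:              # (X_n, Y_n, C_n) triples
        r_c = random.choice(corpus)       # r(C_n): randomized context (cuts Z -> C)
        margin = (len(y) * delta_s        # T_n * delta_s
                  - log_prob(y, x, c)     # s_sent(Y_n | X_n, C_n)
                  + log_prob(y, x, r_c))  # approximates s_sent(Y_n | X_n)
        reg += max(0.0, margin)           # hinge [.]_+
    return alpha_s * reg
```

The sketch covers only the sentence level; as mentioned above, the paper applies the analogous construction at the minibatch and token levels as well.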
 
We call this regularization technique “context-aware learning” (or context-aware regularization), as I was actively avoiding the term “causal” everywhere. Indeed, this technique helps in the sense that the final, trained neural machine translation system actually degrades when a wrong context is provided, as opposed to a usual context-aware translation system, which is often trained without considering this causal effect. Compare (c) and (d) below while contrasting the columns “Normal” and “Context-Marginalized”. We also observed some improvement even when the correct context was given (Normal), but the reviewers were not impressed.
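For the comparison itself, here is a hedged sketch of how one might reproduce the “Normal” vs. “Context-Marginalized” contrast. The translate function is hypothetical, and I use sacrebleu only as one possible scorer, not necessarily what we used in the paper:

```python
import random
import sacrebleu

def contrast_contexts(translate, test_set, corpus):
    """BLEU with the true previous sentence ("Normal") versus a randomly
    drawn one ("Context-Marginalized"); a system trained with the
    regularizer above should visibly degrade in the second condition."""
    refs, normal, marginalized = [], [], []
    for x, y, c in test_set:              # (source, reference, true context)
        refs.append(y)
        normal.append(translate(x, c))
        marginalized.append(translate(x, random.choice(corpus)))
    return (sacrebleu.corpus_bleu(normal, [refs]).score,
            sacrebleu.corpus_bleu(marginalized, [refs]).score)
```

Here translate stands in for whatever decoding procedure the underlying system uses.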
As you may have noticed, this approach is agnostic to the underlying machine translation system. As long as you can train the underlying system with the proposed regularization term, this framework carries over naturally. It is furthermore decoupled from the actual problem of machine translation: the proposed approach can be applied to any other problem where we have a set of input modalities, some of which are only weakly correlated with the output but are known to cause it.
 
Phew, there you go! I’m glad that I found some time today to fulfill my deep desire to say “causality” out loud. 
 
PS1. I had another ill-fated attempt to apply this framework to generic supervised (and unsupervised) learning and to explain it without mentioning anything about causality or randomized trials: https://openreview.net/forum?id=SJlh2jR9FX&noteId=rJxOVW7714. Though, I cannot tell whether Adji Dieng noticed this 🙂
 
PS2. The diagram above is slightly less satisfying, as there is no arrow from C to X. A natural next step would be the following:

We probably want to randomize both C and X, though I am pretty sure there must be better ways to do so.

 
PS3. While this paper was under review at NAACL’19, I saw a talk by Natasha Jaques, who visited NYU. Her work nicely incorporates counterfactual analysis (now at the level of individual causal inference) into learning a set of coordinating neural net agents, in a manner similar to my paper. Definitely worth a read: Social Influence as Intrinsic Motivation for Multi-Agent Deep Reinforcement Learning.
