retrieval-augmented generation (RAG) is all the rage in the world of LLMs (i heard.) RAG confuses me quite a bit, since it’s unclear to me how RAG should work. in particular, i have a major confusion about how language models should be trained to be good at retrieval-augmented generation. it’s a simple confusion; let me describe it here.
let $D$ be an entire training corpus i have prepared to train a language model. a naive way to train a language model is to
\[
\max_{\theta} \sum_{x \in D} \log p_{\theta}(x).
\]
this whole process of learning can be thought of as compressing $D$ into $\theta$, so that given a new instance $x'$ in the future, the language model approximately goes through $D$, checks how similar this new instance is to the instances within $D$, and returns a similarity score, i.e., $\log p_{\theta}(x')$.
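to make this concrete, here’s a minimal pytorch-style sketch of the naive objective (purely illustrative; `model` is a hypothetical decoder-only language model that maps token ids to next-token logits):

```python
import torch
import torch.nn.functional as F

def naive_lm_loss(model, batch):
    """negative log-likelihood -log p_theta(x), factorized autoregressively,
    averaged over all next-token predictions in the batch."""
    inputs, targets = batch[:, :-1], batch[:, 1:]  # batch: token ids, (B, T)
    logits = model(inputs)                         # (B, T-1, V)
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),       # (B*(T-1), V)
        targets.reshape(-1),                       # (B*(T-1),)
    )
```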
let’s say i want to ensure that my language model is good at retrieval-augmented generation. then, i would change the objective above into
\[
\max_{\theta} \sum_{x \in D} \mathbb{E}_{i \sim \mathrm{uniform}(1,|x|)} \log p_{\theta}(x_{i+1:|x|}|x_{1:i}, \mathrm{retrieval}(x_{1:i}, D)),
\]
where $|x|$ is the length of $x$ and $\mathrm{retrieval}(x_{1:i}, D)$ is the retrieval function that retrieves passages from $D$ that are similar to $x_{1:i}$. sounds good so far, right?
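here’s a minimal sketch of a monte-carlo estimate of this objective for a single $x$, again in pytorch style (the `retrieval` and `tokenizer` helpers are hypothetical placeholders, and i’m assuming $x$ is a list of token ids; i sample $i$ up to $|x|-1$ so the continuation is never empty):

```python
import random
import torch
import torch.nn.functional as F

def rag_lm_loss(model, x, corpus, retrieval, tokenizer):
    """monte-carlo estimate of the retrieval-augmented objective for one x:
    sample a split point i, retrieve passages conditioned on the prefix,
    and score the continuation x_{i+1:|x|} given [retrieved; prefix]."""
    i = random.randint(1, len(x) - 1)        # i ~ uniform(1, |x|-1)
    prefix, continuation = x[:i], x[i:]
    passages = retrieval(prefix, corpus)     # passages from D similar to x_{1:i}
    context = tokenizer(passages) + prefix   # prepend retrieved text to the prefix
    inputs = torch.tensor(context + continuation).unsqueeze(0)
    logits = model(inputs[:, :-1])           # (1, L-1, V)
    targets = inputs[:, 1:].clone()
    targets[:, : len(context) - 1] = -100    # only the continuation contributes
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=-100,
    )
```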
… not really, because there is a genuine ambiguity in whether the language model should use information from $D$ relevant to predicting $x_{i+1:|x|}$ given $x_{1:i}$ via its parameters $\theta$ or via $\mathrm{retrieval}(x_{1:i}, D)$. after all, $\theta$ is a compressed version of $D$ in a way that facilitates retrieval of relevant information by the language model. what would make the language model prefer to rely on the passages retrieved by $\mathrm{retrieval}$ rather than on its own (unknown) internal mechanism?
of course, we can fix this explicitly by splitting $D$ into a training set $D_{\mathrm{train}}$ and a retrieval set $D_{\mathrm{retrieval}}$, such that $D_{\mathrm{train}} \cup D_{\mathrm{retrieval}} = D$ and $D_{\mathrm{train}} \cap D_{\mathrm{retrieval}} = \emptyset$. we can then change the training objective above into
\[
\max_{\theta} \sum_{x \in D_{\mathrm{train}}} \mathbb{E}_{i \sim \mathrm{uniform}(1,|x|)} \log p_{\theta}(x_{i+1:|x|}|x_{1:i}, \mathrm{retrieval}(x_{1:i}, D_{\mathrm{retrieval}})).
\]
in this case, the language model must learn to rely on the retrieved passages, since the retrieved passages are not included in the training set and thereby not in $\theta$ (at least not their verbatim copies.)
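a minimal sketch of such a disjoint split (the 50/50 fraction is an arbitrary placeholder, which is exactly the first question below):

```python
import random

def split_corpus(corpus, retrieval_fraction=0.5, seed=0):
    """disjoint split of D into D_train and D_retrieval,
    so retrieved passages are never seen as training targets."""
    docs = list(corpus)
    random.Random(seed).shuffle(docs)
    n_retrieval = int(len(docs) * retrieval_fraction)
    return docs[n_retrieval:], docs[:n_retrieval]  # (d_train, d_retrieval)

# training then draws x from d_train while retrieving only from d_retrieval:
#   d_train, d_retrieval = split_corpus(D)
#   loss = rag_lm_loss(model, x, d_retrieval, retrieval, tokenizer)
```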
of course, this approach brings up perhaps even more difficult questions. first, how big should $D_{\mathrm{retrieval}}$ be? we might want it to be as big as possible, but of course, that implies that $D_{\mathrm{train}}$ is smaller. we know the importance of using a large corpus to train a language model, and it’s unclear how much of a compromise we can make on the size of the training set.
second, what if the retrieval function misses relevant information from $D_{\mathrm{retrieval}}$? when that happens, such relevant information is missed entirely during training, effectively leading to the loss of information from the overall corpus $D$.
so, are you training your language models to excel at RAG? if so, how are you doing it?
p.s. this question came to my mind and was crystallized over Andrew Drozdov’s dissertation defense that just finished (and yes, he successfully defended his dissertation!)