Fixing DPO but I have a dinner reservation …

Direct preference optimization (DPO; https://arxiv.org/abs/2305.18290) is all the rage, i heard. i also hear from my students that DPO, which minimizes the loss below, often results in weird behaviours, such as an unreasonable preference toward lengthy responses (even when there is no statistical difference in lengths between desirable and undesirable responses). i won’t go into the details of these issues, but i suspect there’s a relatively simple reason behind these pathologies, one that follows from basic calculus.

\[
\mathcal{L}_{\mathrm{dpo}}(\theta) = \log \left(1 + \exp \left(- \log \frac{p_{\theta}(y|x)}{p_{0}(y|x)}
+ \log \frac{p_\theta(y'|x)}{p_{0}(y'|x)}\right)\right),
\]

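where $p_0$ is the so-called reference model from which $y$ and $y'$ were drawn independently given $x$. (DPO also scales the difference of log-ratios by a hyperparameter $\beta$, which i set to 1 here to reduce clutter.)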
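For readers who think in code, here is a minimal Python sketch of the per-triplet loss, written in terms of sequence log-probabilities. It uses the standard sign convention, $-\log \sigma(z) = \log(1 + e^{-z})$, where $z$ is the preferred log-ratio minus the dispreferred one. The loss is therefore non-negative and shrinks as $y$ gains on $y'$. The function and argument names are hypothetical, and `beta` stands for DPO's scaling hyperparameter $\beta$, which the formula above leaves at 1.

```python
import math


def softplus(t: float) -> float:
    """Numerically stable log(1 + exp(t))."""
    if t > 0:
        return t + math.log1p(math.exp(-t))
    return math.log1p(math.exp(t))


def dpo_loss(logp_y: float, logp_yp: float,
             ref_logp_y: float, ref_logp_yp: float,
             beta: float = 1.0) -> float:
    """DPO loss for one triplet (x, y, y') in which y is preferred over y'.

    logp_* are log p_theta(.|x) under the trained model and ref_logp_* are
    log p_0(.|x) under the reference model (hypothetical argument names).
    """
    # z = beta * (log-ratio of the preferred - log-ratio of the dispreferred)
    margin = beta * ((logp_y - ref_logp_y) - (logp_yp - ref_logp_yp))
    # -log sigmoid(z) == log(1 + exp(-z))
    return softplus(-margin)


print(dpo_loss(-10.0, -10.0, -10.0, -10.0))  # zero margin -> log 2
print(dpo_loss(-5.0, -20.0, -10.0, -10.0))   # large positive margin -> close to 0
```

Writing the loss as a softplus of the negative margin avoids overflow when the margin is very negative.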

Let’s consider DPO with a fixed set $D$ of query-response triplets $(x, y, y')$ whose responses were drawn from $p_0$, that is, $y, y' \sim p_0(\cdot | x)$. Without loss of generality, i will always take $y$ to be the preferred response of the two. the overall loss is then:

\[
J_{0} (\theta) =
\mathbb{E}_{x} \mathbb{E}_{y, y' \sim p_0(\cdot | x)} \left[
\mathcal{L}_{\mathrm{dpo}}(\theta)
\right].
\]

what’s the issue here? the issue is that minimizing this loss with respect to $\theta$ does not necessarily yield a $p_{\theta}(\cdot |x)$ from which we draw good responses. the loss only asks that the log-ratio of $y$ grow relative to that of $y'$. nothing in it requires $p_{\theta}(y|x) \gg p_{\theta}(\tilde{y}|x)$, where $y \in D$ and $\tilde{y}$ is an arbitrary sequence.
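Here is a tiny numerical illustration, using made-up log-probabilities rather than real model outputs. The DPO loss sees $\log p_\theta(y|x)$ and $\log p_\theta(y'|x)$ only through their difference. Both can therefore drop together, making $y$ far less likely to be sampled, without the loss changing at all.

```python
import math


def dpo_loss(logp_y, logp_yp, ref_logp_y, ref_logp_yp):
    """log(1 + exp(-z)), z = difference of log-ratios (beta = 1).

    Not overflow-safe for very negative z; fine for this demo.
    """
    z = (logp_y - ref_logp_y) - (logp_yp - ref_logp_yp)
    return math.log1p(math.exp(-z))


ref_y, ref_yp = -12.0, -12.0  # made-up reference log-probs of y and y'

before = dpo_loss(-10.0, -14.0, ref_y, ref_yp)
# Lower BOTH log-probs by 20 nats: y becomes e^20 times less likely ...
after = dpo_loss(-30.0, -34.0, ref_y, ref_yp)
# ... yet the loss is exactly the same.
# Lowering y' a bit more than y even DECREASES the loss.
lower = dpo_loss(-30.0, -39.0, ref_y, ref_yp)
print(before, after, lower)
```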

instead, a proper loss would be the following:

\[
J_{\mathrm{PDPO}} (\theta) =
\mathbb{E}_{x} \mathbb{E}_{y, y' \sim p_{\theta}(\cdot | x)} \left[
\mathcal{L}_{\mathrm{dpo}}(\theta)
\right] =
\sum_{x} p(x)
\sum_{y, y'} p_{\theta}(y|x) p_{\theta}(y'|x) \mathcal{L}_{\mathrm{dpo}}(\theta).
\]
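To make the difference between $J_0$ and $J_{\mathrm{PDPO}}$ concrete, here is an exact computation on a toy problem. Everything in it is a hypothetical assumption for illustration:

- a single query $x$ with three possible responses;
- $p_\theta = \mathrm{softmax}(\theta)$ and a uniform reference model;
- a made-up score per response that decides which response of a pair is preferred.

The two objectives share the same per-pair loss and differ only in the distribution the pairs are drawn from.

```python
import math

# Toy setting (hypothetical): one query x, three possible responses.
SCORES = [0.0, 1.0, 2.0]   # made-up quality; higher score = preferred
THETA0 = [0.0, 0.0, 0.0]   # logits of the reference model p_0 (uniform)


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pair_loss(a, b, theta, theta0):
    """DPO loss for responses (a, b), oriented by SCORES, with beta = 1."""
    w, l = (a, b) if SCORES[a] >= SCORES[b] else (b, a)
    p, p0 = softmax(theta), softmax(theta0)
    z = (math.log(p[w]) - math.log(p0[w])) - (math.log(p[l]) - math.log(p0[l]))
    return math.log1p(math.exp(-z))


def expected_loss(q, theta, theta0):
    """sum_{a,b} q(a) q(b) * loss(a, b): the pair distribution is q x q."""
    n = len(q)
    return sum(q[a] * q[b] * pair_loss(a, b, theta, theta0)
               for a in range(n) for b in range(n))


theta = [0.0, 0.5, 1.5]
j_0 = expected_loss(softmax(THETA0), theta, THETA0)     # pairs from p_0
j_pdpo = expected_loss(softmax(theta), theta, THETA0)   # pairs from p_theta
print(f"J_0 = {j_0:.4f}   J_PDPO = {j_pdpo:.4f}")
```

A pair of identical responses has zero margin and always contributes $\log 2$, whichever distribution it comes from.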

the main difference is that, instead of a fixed set of triplets drawn from $p_0$, we use samples drawn from the latest model $p_{\theta}$. This makes sense, since the responses we care about are those we are likely to draw from the trained model $p_{\theta}$. let’s now look at the gradient of this proper loss $J_{\mathrm{PDPO}}$ with respect to $\theta$.

\[
\begin{array}{rl}
\nabla J_{\mathrm{PDPO}}
=&
\nabla \mathbb{E}_x
\mathbb{E}_{y, y' \sim p_{\theta}(\cdot|x)} \left[ \mathcal{L}_{\mathrm{dpo}}(\theta)\right]
\\
=&
\mathbb{E}_x
\mathbb{E}_{y, y' \sim p_{\theta}(\cdot|x)}
\left[
-\mathcal{L}_{\mathrm{dpo}}(y, y', x)
\nabla_{\theta} (\mathcal{L}_{\mathrm{NLL}}(y, x) + \mathcal{L}_{\mathrm{NLL}}(y', x))
+
\nabla_{\theta} \mathcal{L}_{\mathrm{dpo}}(y, y', x)
\right],
\end{array}
\]

where we use a couple of tricks for computing the derivative: the product rule, $\nabla (a \cdot b) = (\nabla a) b + a (\nabla b)$, and the log-derivative trick, $\nabla a = a \nabla \log a$. we use $\mathcal{L}_{\mathrm{NLL}}(y, x)$ as shorthand for $- \log p_{\theta}(y|x)$, so that $\nabla_{\theta} \log p_{\theta}(y|x) = -\nabla_{\theta} \mathcal{L}_{\mathrm{NLL}}(y, x)$.
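Sign errors are easy to make here, so the following sketch checks the gradient identity numerically against central finite differences. It assumes a hypothetical toy:

- a single query with three responses and $p_\theta = \mathrm{softmax}(\theta)$;
- made-up reference logits;
- made-up scores that fix which response of a pair is preferred, so the preference itself does not depend on $\theta$.

The likelihood term is coded as $\mathcal{L}_{\mathrm{dpo}}\,\nabla_\theta(\log p_\theta(y|x) + \log p_\theta(y'|x))$, which equals $-\mathcal{L}_{\mathrm{dpo}}\,\nabla_\theta(\mathcal{L}_{\mathrm{NLL}}(y, x) + \mathcal{L}_{\mathrm{NLL}}(y', x))$.

```python
import math

# Hypothetical toy: one query, three responses, p_theta = softmax(theta).
SCORES = [0.0, 1.0, 2.0]   # made-up scores; they fix which response wins
THETA0 = [0.0, 0.3, -0.2]  # made-up reference logits for p_0


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def oriented(a, b):
    """(preferred, dispreferred) for a pair of responses."""
    return (a, b) if SCORES[a] >= SCORES[b] else (b, a)


def margin(a, b, theta):
    """Difference of log-ratios log p_theta/p_0, preferred minus dispreferred."""
    w, l = oriented(a, b)
    p, p0 = softmax(theta), softmax(THETA0)
    return (math.log(p[w]) - math.log(p0[w])) - (math.log(p[l]) - math.log(p0[l]))


def j_pdpo(theta):
    """Exact J_PDPO: pairs drawn from p_theta itself."""
    p, n = softmax(theta), len(theta)
    return sum(p[a] * p[b] * math.log1p(math.exp(-margin(a, b, theta)))
               for a in range(n) for b in range(n))


def grad_pdpo(theta):
    """E[L_dpo * grad(log p(a) + log p(b)) + grad L_dpo], enumerated exactly."""
    p, n = softmax(theta), len(theta)
    g = [0.0] * n
    for a in range(n):
        for b in range(n):
            z = margin(a, b, theta)
            loss = math.log1p(math.exp(-z))
            dloss_dz = -1.0 / (1.0 + math.exp(z))   # d/dz log(1 + e^{-z})
            w, l = oriented(a, b)
            for k in range(n):
                # softmax score function: d log p(a) / d theta_k = [a == k] - p[k]
                score = (a == k) - p[k] + (b == k) - p[k]
                dz = (w == k) - (l == k)              # d margin / d theta_k
                g[k] += p[a] * p[b] * (loss * score + dloss_dz * dz)
    return g


def numerical_grad(f, theta, h=1e-5):
    """Central finite differences, one coordinate at a time."""
    g = []
    for k in range(len(theta)):
        up, down = list(theta), list(theta)
        up[k] += h
        down[k] -= h
        g.append((f(up) - f(down)) / (2 * h))
    return g


theta = [0.4, -0.1, 0.7]
analytic = grad_pdpo(theta)
numeric = numerical_grad(j_pdpo, theta)
print(analytic)
print(numeric)
```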

what is interesting is that we automatically end up with two kinds of loss terms. one is the usual DPO loss. the other is a log-likelihood term on both the desirable and undesirable responses, weighted by the DPO loss of the pair. this second term is extremely important, since it ensures that, after training, we are more likely to sample the responses on which the DPO loss was optimized.

now, this proper DPO loss (perhaps i can call it PDPO, since i was told we must give every single math formula an obscure name) is not easy to minimize. to minimize it, we must be able to determine which of an arbitrary pair of responses $(y, y')$ to the query $x$ is more desirable. if $y$ is a description of a molecule, we would need to synthesize both molecules and experiment with them to tell which is better. in other words, this PDPO loss is readily usable only when we have a fast and cheap way to obtain preferences.
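The practical obstacle shows up clearly in a Monte Carlo sketch of the on-policy gradient: every sampled pair needs a fresh preference judgment. Here `prefer` is a hypothetical oracle standing in for whatever expensive process produces preferences, such as annotators or synthesizing and testing molecules. The toy model (one query, three responses, a softmax policy, hidden scores) is an assumption for illustration.

```python
import math
import random

# Hypothetical toy: one query, three responses, p_theta = softmax(theta).
SCORES = [0.0, 1.0, 2.0]   # hidden quality, only visible through the oracle
THETA0 = [0.0, 0.0, 0.0]   # reference logits for p_0 (uniform)
oracle_calls = 0


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def prefer(a, b):
    """Hypothetical preference oracle; every call is assumed to be expensive.

    Returns (preferred, dispreferred).
    """
    global oracle_calls
    oracle_calls += 1
    return (a, b) if SCORES[a] >= SCORES[b] else (b, a)


def pair_grad(a, b, w, l, p, p0):
    """L_dpo * grad(log p(a) + log p(b)) + grad L_dpo for softmax logits."""
    z = math.log(p[w] / p0[w]) - math.log(p[l] / p0[l])
    loss = math.log1p(math.exp(-z))
    dloss_dz = -1.0 / (1.0 + math.exp(z))
    return [loss * ((a == k) - p[k] + (b == k) - p[k])
            + dloss_dz * ((w == k) - (l == k)) for k in range(len(p))]


def on_policy_grad(theta, n_pairs, rng):
    """Monte Carlo estimate of grad J_PDPO from fresh samples of p_theta."""
    p, p0 = softmax(theta), softmax(THETA0)
    g = [0.0] * len(theta)
    for _ in range(n_pairs):
        a, b = rng.choices(range(len(p)), weights=p, k=2)
        w, l = prefer(a, b)                  # one oracle query per pair
        for k, gk in enumerate(pair_grad(a, b, w, l, p, p0)):
            g[k] += gk / n_pairs
    return g


def exact_grad(theta):
    """Same expectation by enumeration; reads SCORES directly (toy only)."""
    p, p0 = softmax(theta), softmax(THETA0)
    g = [0.0] * len(theta)
    for a in range(len(p)):
        for b in range(len(p)):
            w, l = (a, b) if SCORES[a] >= SCORES[b] else (b, a)
            for k, gk in enumerate(pair_grad(a, b, w, l, p, p0)):
                g[k] += p[a] * p[b] * gk
    return g


theta = [0.2, -0.3, 0.5]
estimate = on_policy_grad(theta, 50_000, random.Random(0))
print(estimate)
print(exact_grad(theta))
print("oracle calls:", oracle_calls)
```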

we can instead use importance sampling with the fixed set $D$ of preference triplets $(x, y, y')$, reweighting samples from $p_0$ so that they estimate an expectation under $p_{\theta}$:

\[
\nabla J_{\mathrm{PDPO}}(\theta)
\approx
\nabla J_{\mathrm{PDPO}}^{\mathrm{IS}}(\theta)
=
\frac{1}{|D|}
\sum_{(x, y, y') \in D}
\frac{p_\theta(y|x)}{p_0(y|x)}
\frac{p_\theta(y'|x)}{p_0(y'|x)}
\left(
-\mathcal{L}_{\mathrm{dpo}}(y, y', x)
\nabla_{\theta} (\mathcal{L}_{\mathrm{NLL}}(y, x) + \mathcal{L}_{\mathrm{NLL}}(y', x))
+
\nabla_{\theta} \mathcal{L}_{\mathrm{dpo}}(y, y', x)
\right).
\]
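Here is a sketch of the importance-sampled estimator on a hypothetical toy: one query, three responses, $p_\theta = \mathrm{softmax}(\theta)$, and made-up reference logits and scores. $D$ is drawn once from $p_0$ and labelled once, and each pair's gradient is reweighted by $\frac{p_\theta(y|x)}{p_0(y|x)}\frac{p_\theta(y'|x)}{p_0(y'|x)}$. For a small step away from the reference, the estimate should land close to the exact gradient obtained by enumerating all pairs.

```python
import math
import random

# Hypothetical toy: one query, three responses, p_theta = softmax(theta).
SCORES = [0.0, 1.0, 2.0]    # made-up preference scores
THETA0 = [0.0, 0.3, -0.2]   # made-up reference logits for p_0


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pair_grad(a, b, w, l, p, p0):
    """L_dpo * grad(log p(a) + log p(b)) + grad L_dpo for softmax logits."""
    z = math.log(p[w] / p0[w]) - math.log(p[l] / p0[l])
    loss = math.log1p(math.exp(-z))
    dloss_dz = -1.0 / (1.0 + math.exp(z))
    return [loss * ((a == k) - p[k] + (b == k) - p[k])
            + dloss_dz * ((w == k) - (l == k)) for k in range(len(p))]


def make_dataset(n, rng):
    """Fixed set D: pairs drawn once from p_0 and labelled once."""
    p0 = softmax(THETA0)
    data = []
    for _ in range(n):
        a, b = rng.choices(range(len(p0)), weights=p0, k=2)
        w, l = (a, b) if SCORES[a] >= SCORES[b] else (b, a)
        data.append((a, b, w, l))
    return data


def is_grad(theta, data):
    """(1/|D|) * sum over D of importance weight * per-pair gradient."""
    p, p0 = softmax(theta), softmax(THETA0)
    g = [0.0] * len(theta)
    for a, b, w, l in data:
        weight = (p[a] / p0[a]) * (p[b] / p0[b])
        for k, gk in enumerate(pair_grad(a, b, w, l, p, p0)):
            g[k] += weight * gk / len(data)
    return g


def exact_grad(theta):
    """Exact gradient of J_PDPO by enumerating all pairs (toy only)."""
    p, p0 = softmax(theta), softmax(THETA0)
    g = [0.0] * len(theta)
    for a in range(len(p)):
        for b in range(len(p)):
            w, l = (a, b) if SCORES[a] >= SCORES[b] else (b, a)
            for k, gk in enumerate(pair_grad(a, b, w, l, p, p0)):
                g[k] += p[a] * p[b] * gk
    return g


D = make_dataset(50_000, random.Random(0))
theta = [0.1, 0.2, 0.0]  # a small step away from THETA0
is_estimate = is_grad(theta, D)
print(is_estimate)
print(exact_grad(theta))
```

The estimator is unbiased here only because $p_0$ puts nonzero probability on every response that $p_\theta$ can produce.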

the importance weights, $\frac{p_\theta(y|x)}{p_0(y|x)}\frac{p_\theta(y'|x)}{p_0(y'|x)}$, mean that a pre-collected preference triplet contributes substantially only if its responses are reasonably likely under the current model $p_{\theta}$. this makes sense, as we care about the examples that are likely to be drawn from the current model. unfortunately, this approach is not ideal. as $p_\theta$ drifts away from $p_0$, fewer and fewer triplets in $D$ are likely under $p_\theta$, so the pre-collected triplets become less and less informative.
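One standard way to quantify how quickly $D$ stops being useful is the effective sample size of the importance weights, $(\sum_i w_i)^2 / \sum_i w_i^2$ (Kish's formula). This diagnostic is swapped in here and is not part of the derivation above. The sketch uses a hypothetical five-response toy model and moves $\theta$ away from made-up reference logits along a made-up direction.

```python
import math
import random

THETA0 = [0.0, 0.3, -0.2, 0.1, -0.4]     # hypothetical reference logits (p_0)
DIRECTION = [1.0, -1.0, 0.5, -0.5, 0.0]  # hypothetical direction of drift


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def effective_sample_size(weights):
    """Kish's effective sample size: (sum w)^2 / sum w^2."""
    return sum(weights) ** 2 / sum(v * v for v in weights)


rng = random.Random(0)
p0 = softmax(THETA0)
n = len(p0)
# The fixed set D: 10,000 response pairs drawn once from p_0.
pairs = [tuple(rng.choices(range(n), weights=p0, k=2)) for _ in range(10_000)]

ess_by_drift = {}
for t in [0.0, 0.5, 1.0, 2.0, 4.0]:
    theta = [t0 + t * d for t0, d in zip(THETA0, DIRECTION)]
    p = softmax(theta)
    weights = [(p[a] / p0[a]) * (p[b] / p0[b]) for a, b in pairs]
    ess_by_drift[t] = effective_sample_size(weights)
    print(f"drift {t:.1f}: effective sample size {ess_by_drift[t]:8.1f} of {len(pairs)}")
```

With no drift, every weight is exactly 1 and all 10,000 pairs count. As $\theta$ moves away, a few pairs with large weights start to dominate the estimate.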

so, what should we do? we should (1) draw triplets from the current model $p_{\theta}$ as frequently as the available resources and constraints allow, and (2) update the parameters with the importance-sampled gradient $\nabla J_{\mathrm{PDPO}}^{\mathrm{IS}}(\theta)$ rather than the original DPO gradient. unfortunately, (1) is very difficult because it’s often really costly to measure the preference between two responses. (2) is also difficult, because the variance of this estimator blows up quickly as $\theta$ moves away from the model that generated the triplets.
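Putting (1) and (2) together gives a loop along the following lines, sketched on a hypothetical toy (one query, three responses, a softmax policy, a made-up oracle). Every `refresh_every` steps, the loop samples new pairs from the current model and queries the oracle. In between, it reuses that batch with importance weights. Two choices here are my own assumptions rather than anything derived above: the weights are computed against the snapshot $q$ that generated the batch instead of $p_0$, and all hyperparameters are arbitrary.

```python
import math
import random

# Hypothetical toy: one query, three responses, p_theta = softmax(theta).
SCORES = [0.0, 1.0, 2.0]   # hidden quality; response 2 is the best
THETA0 = [0.0, 0.0, 0.0]   # reference logits for p_0 (uniform)


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def prefer(a, b):
    """Hypothetical (expensive) preference oracle -> (preferred, dispreferred)."""
    return (a, b) if SCORES[a] >= SCORES[b] else (b, a)


def pair_grad(a, b, w, l, p, p0):
    """L_dpo * grad(log p(a) + log p(b)) + grad L_dpo for softmax logits."""
    z = math.log(p[w] / p0[w]) - math.log(p[l] / p0[l])
    loss = math.log1p(math.exp(-z))
    dloss_dz = -1.0 / (1.0 + math.exp(z))
    return [loss * ((a == k) - p[k] + (b == k) - p[k])
            + dloss_dz * ((w == k) - (l == k)) for k in range(len(p))]


def train(steps=60, refresh_every=10, pairs_per_refresh=2000, lr=0.5, seed=0):
    rng = random.Random(seed)
    p0 = softmax(THETA0)
    theta = list(THETA0)
    oracle_calls = 0
    for step in range(steps):
        if step % refresh_every == 0:
            # (1) sample fresh pairs from the *current* model and label them.
            q = softmax(theta)  # snapshot of the model that generated the batch
            batch = []
            for _ in range(pairs_per_refresh):
                a, b = rng.choices(range(len(q)), weights=q, k=2)
                batch.append((a, b) + prefer(a, b))
            oracle_calls += pairs_per_refresh
        # (2) reuse the batch, importance-weighted from q toward p_theta.
        p = softmax(theta)
        grad = [0.0] * len(theta)
        for a, b, w, l in batch:
            weight = (p[a] / q[a]) * (p[b] / q[b])
            for k, gk in enumerate(pair_grad(a, b, w, l, p, p0)):
                grad[k] += weight * gk / len(batch)
        theta = [t - lr * g for t, g in zip(theta, grad)]
    return theta, oracle_calls


theta, calls = train()
print("final p_theta:", [round(v, 3) for v in softmax(theta)], "oracle calls:", calls)
```

The knob `refresh_every` trades oracle cost against the variance of the importance weights. Refreshing more often keeps $p_\theta$ close to $q$ but needs more preference judgments.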

do i have any empirical evidence? unfortunately i have a dinner reservation and need to leave for the restaurant shortly.

Acknowledgement: i’d like to thank Richard Pang, who’s a graduating PhD student at NYU, for spending a couple of hours late on Friday afternoon listening to my rant and (then-incorrect) derivation. Also, i thank Weizhe Yuan and Angie Chen for keeping me up to date on the mysteries and magic people perform each day finetuning language models, which serve as a constant motivation for me to think about this problem.
