The ICLR deadline is approaching, and of course, it’s time to write a short blog post that has absolutely nothing to do with any of my manuscripts in preparation. i’d like to thank Ed Grefenstette, Tim Rocktäschel and Phu Mon Htut for fruitful discussion.

Let’s consider the following meta-optimization objective function:

$$\mathcal{L}'(D'; \theta_0 - \eta \nabla_{\theta} \mathcal{L}(D; \theta_0))$$

which we want to minimize w.r.t. θ₀. thanks to the success of MAML and its earlier and more recent variants, it has recently become popular to minimize such a meta-optimization objective function with gradient descent. the gradient can be written down as*

$$\nabla_{\theta_0} \mathcal{L}'(D'; \theta_0 - \eta \nabla_\theta \mathcal{L}(D; \theta_0)) = \nabla_{\theta'} \mathcal{L}'(D'; \theta_0 - \eta \nabla_\theta \mathcal{L}(D; \theta_0)) \times (1-\eta \nabla_{\theta_0} \nabla_{\theta} \mathcal{L}(D; \theta_0)),$$

where θ’ is the updated parameter set. in this derivation, what we see is that the gradient w.r.t. the original parameter set θ₀ is propagated from the outer objective function L’ via θ’ which was computed using the gradient of the inner objective function L w.r.t. θ evaluated at the original parameter set θ₀.
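as a sanity check on the chain rule above, here is a minimal scalar sketch. the quadratic inner and outer losses, and the constants `a`, `b`, `c` and `eta`, are my own illustrative assumptions and are not part of the argument; the only point is that the analytic expression agrees with a finite-difference estimate of the meta-objective's slope.

```python
# Scalar toy problem (all choices here are illustrative assumptions):
#   inner loss  L(theta)  = 0.5 * a * (theta - b)**2
#   outer loss  L'(theta) = 0.5 * (theta - c)**2
a, b, c, eta = 2.0, 1.0, 3.0, 0.1


def inner_grad(theta):
    # dL/dtheta
    return a * (theta - b)


def inner_hess(theta):
    # d^2L/dtheta^2 (constant for a quadratic)
    return a


def outer_loss(theta):
    return 0.5 * (theta - c) ** 2


def outer_grad(theta):
    # dL'/dtheta
    return theta - c


def meta_objective(theta0):
    # L'(theta0 - eta * dL/dtheta(theta0))
    return outer_loss(theta0 - eta * inner_grad(theta0))


def meta_grad(theta0):
    # chain rule: dL'/dtheta' evaluated at theta', times (1 - eta * d^2L/dtheta^2)
    theta_prime = theta0 - eta * inner_grad(theta0)
    return outer_grad(theta_prime) * (1.0 - eta * inner_hess(theta0))


def finite_difference(f, x, eps=1e-6):
    # central difference approximation of df/dx
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


theta0 = 0.5
analytic = meta_grad(theta0)
numeric = finite_difference(meta_objective, theta0)
print(analytic, numeric)
```

the two printed numbers should agree up to floating-point error.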

so far so good, but what if the inner optimization procedure was *stochastic*? that is, what if the meta-optimization objective function was:

$$\mathcal{L}'(D'; \theta_0 - \eta \mathbb{E}_z \nabla_\theta \mathcal{L}_z(D; \theta_0)),$$

where z is used to absorb any stochasticity present in this gradient descent procedure. for instance, z could be used to sample a subset of D to build a minibatch gradient. after all, this is often what we do in deep learning, rather than the full-batch, deterministic gradient descent shown above.

in this case, the gradient of the meta-objective function w.r.t. θ₀ looks slightly different from above:*

$$\nabla_{\theta_0} \mathcal{L}'(D'; \theta_0 - \eta \mathbb{E}_z \nabla_\theta \mathcal{L}_z(D; \theta_0)) = \nabla_{\theta'}\mathcal{L}'(D'; \theta_0 - \eta \mathbb{E}_z \nabla_\theta \mathcal{L}_z(D; \theta_0)) \times (1-\eta \mathbb{E}_z \nabla_{\theta_0} \nabla_{\theta} \mathcal{L}_z(D; \theta_0)).$$

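what’s really important to notice here is that the gradient suddenly contains **two** expectations, rather than the single expectation in the meta-objective function itself: one inside θ’, and one inside the second-derivative factor. this makes a huge difference, because we now need two independent sets of samples of z to estimate the meta-objective gradient w.r.t. θ₀. if we reused the same sample in both places, we would be estimating the expectation of a product rather than the product of two expectations.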
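to make the difference concrete, here is a small sketch that enumerates z exactly instead of sampling it. everything specific in it is an assumption of mine: z picks one of three scalar quadratics `L_z(theta) = 0.5 * a_z * (theta - b_z)**2` with equal probability, and the outer loss is quadratic so that its gradient is linear in θ’ (which keeps the first factor unbiased under a single draw of z and isolates the effect of sharing the sample).

```python
from itertools import product

# z is uniform over three (a_z, b_z) pairs; the numbers are arbitrary illustrative choices.
components = [(1.0, -1.0), (3.0, 2.0), (0.5, 4.0)]
c, eta, theta0 = 3.0, 0.1, 0.5


def grad_z(theta, z):
    # dL_z/dtheta for L_z(theta) = 0.5 * a_z * (theta - b_z)**2
    a_z, b_z = components[z]
    return a_z * (theta - b_z)


def hess_z(z):
    # d^2L_z/dtheta^2
    return components[z][0]


def outer_grad(theta):
    # L'(theta) = 0.5 * (theta - c)**2, so the gradient is linear in theta
    return theta - c


n = len(components)
zs = range(n)

# Exact meta-gradient: both expectations computed in closed form.
mean_grad = sum(grad_z(theta0, z) for z in zs) / n
mean_hess = sum(hess_z(z) for z in zs) / n
exact = outer_grad(theta0 - eta * mean_grad) * (1.0 - eta * mean_hess)


def estimate(z_fwd, z_bwd):
    # z_fwd builds theta'; z_bwd builds the (1 - eta * Hessian) factor.
    theta_prime = theta0 - eta * grad_z(theta0, z_fwd)
    return outer_grad(theta_prime) * (1.0 - eta * hess_z(z_bwd))


# Expected value of the estimator with two independent draws of z ...
independent = sum(estimate(z1, z2) for z1, z2 in product(zs, zs)) / n**2
# ... and with a single draw of z reused for both factors.
shared = sum(estimate(z, z) for z in zs) / n

print(exact, independent, shared)
```

in this toy setting the estimator with independent draws matches the exact meta-gradient in expectation, while the one that reuses z is off, because the gradient and the curvature of `L_z` are correlated through z.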

how would this be implemented in practice? we first draw one minibatch and update θ₀ to θ’. we then draw another minibatch and update θ₀ to θ” (notice the double prime here!). we draw a validation minibatch D’ to evaluate θ’ using the meta-objective function L’. then we backprop up until θ’ (using the same validation minibatch). we then suddenly switch to θ” and backprop through it until θ₀. in other words, we use two separate paths between θ₀ and θ’ for the forward and backward passes, which is pretty different from the usual practice.
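the procedure above can be sketched in plain python as follows. this is only a toy under assumptions of my own: a scalar regression `y ~ theta * x`, hand-written derivatives instead of an autodiff library, and hypothetical names such as `meta_grad_estimate`. because the inner loss is quadratic, the derivative of the θ₀ → θ” update does not depend on θ, so the sketch computes that factor directly rather than materializing θ”.

```python
import random

random.seed(0)


def make_data(n, slope):
    # toy scalar regression data: y = slope * x + small gaussian noise
    xs = [random.uniform(-1.0, 1.0) for _ in range(n)]
    return [(x, slope * x + random.gauss(0.0, 0.1)) for x in xs]


train = make_data(200, slope=2.0)  # plays the role of D
valid = make_data(200, slope=2.0)  # plays the role of D'


def grad(theta, batch):
    # gradient of the mean of 0.5 * (theta * x - y)**2 over the batch
    return sum(x * (theta * x - y) for x, y in batch) / len(batch)


def hess(batch):
    # second derivative of the same loss; it does not depend on theta here
    return sum(x * x for x, _ in batch) / len(batch)


def meta_grad_estimate(theta0, eta, batch_size=16):
    batch_fwd = random.sample(train, batch_size)  # first draw of z
    batch_bwd = random.sample(train, batch_size)  # second, independent draw of z
    batch_val = random.sample(valid, batch_size)  # validation minibatch from D'

    # forward path: theta0 -> theta'
    theta_prime = theta0 - eta * grad(theta0, batch_fwd)
    # backprop from L' down to theta'
    d_outer = grad(theta_prime, batch_val)
    # switch paths: derivative of the update theta0 -> theta'' built from batch_bwd
    d_update = 1.0 - eta * hess(batch_bwd)
    return d_outer * d_update


theta0, eta, meta_lr = 0.0, 0.5, 0.5
for _ in range(200):
    theta0 -= meta_lr * meta_grad_estimate(theta0, eta)
print(theta0)
```

with an autodiff framework one would instead build θ” explicitly from the second minibatch and backpropagate through that update; the structure (forward through one minibatch, backward through another) stays the same.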

what does this imply? what it implies is that **correct meta-objective optimization looks for θ₀ that is robust to the optimization trajectory taken due to the inherent stochasticity in SGD**. in order to do so it must consider what would have happened had a different optimization trajectory been used, and this can be estimated well by using separate minibatches for forward and backward passes. i believe Ferenc Huszar made a similar argument in “What is missing? Stochasticity” section of his recent blog post.

an interesting question here is what z is and what kind of distribution we should impose on z. for instance, can we fold the choice of optimization algorithm into z in addition to other stochastic behaviours such as data permutation, dropout and others? if so, can we extend MAML to find a more robust initialization that would not only be robust to the stochasticity behind a select optimization algorithm but robust to the choice of optimization algorithm itself?

(*) i’m being massively sloppy with scalars, vectors, matrices, gradient and jacobian, and my apologies in advance. you could simply think of scalars only and the whole argument still largely holds.