#implicitfunction #jax #autodiff




Many problems in Machine Learning involve loops of inner and outer optimization. Finding update steps for the outer loop is usually difficult because it requires differentiating through the inner loop's procedure over multiple steps. Such loop unrolling is memory-hungry and is therefore limited to a small number of steps. Other papers have found ways around unrolling, but only for specific, individual problems. This paper proposes a unified framework for implicit differentiation of inner optimization procedures without unrolling and provides implementations that integrate seamlessly into JAX.
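To make the unrolling problem concrete, here is a minimal JAX sketch under assumed toy settings: a hypothetical bi-level ridge-regression problem where the outer variable is the penalty `lam`, the inner loop is plain gradient descent, and the hypergradient comes from letting `jax.grad` trace back through every inner step. The data, step count and learning rate are illustrative choices, not taken from the paper.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 for a tight numerical check

# Hypothetical toy bi-level problem: tune a ridge penalty `lam` (outer variable)
# on validation data, while the weights `w` are fit on training data by an
# inner gradient-descent loop.
key_tr, key_val = jax.random.split(jax.random.PRNGKey(0))
X_tr = jax.random.normal(key_tr, (20, 5))
X_val = jax.random.normal(key_val, (10, 5))
w_true = jnp.arange(1.0, 6.0)
y_tr, y_val = X_tr @ w_true, X_val @ w_true

def inner_loss(w, lam):
    # Training objective minimized by the inner loop.
    return 0.5 * jnp.sum((X_tr @ w - y_tr) ** 2) + 0.5 * lam * jnp.sum(w ** 2)

def unrolled_solver(lam, num_steps=2000, lr=0.01):
    # Gradient descent via lax.fori_loop; a static trip count lets reverse-mode
    # autodiff trace back through every single step (loop unrolling).
    def step(_, w):
        return w - lr * jax.grad(inner_loss)(w, lam)
    return jax.lax.fori_loop(0, num_steps, step, jnp.zeros(X_tr.shape[1]))

def outer_loss(lam):
    # Validation error of the inner solution, viewed as a function of lam.
    w = unrolled_solver(lam)
    return 0.5 * jnp.sum((X_val @ w - y_val) ** 2)

# Hypergradient obtained by differentiating through all num_steps iterations.
hypergrad = jax.grad(outer_loss)(1.0)
```

Because reverse mode has to keep the residuals of every iteration, memory grows linearly with `num_steps`, which is why unrolling is usually restricted to short inner loops.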




OUTLINE:


0:00 - Intro & Overview


2:05 - Automatic Differentiation of Inner Optimizations


4:30 - Example: Meta-Learning


7:45 - Unrolling Optimization


13:00 - Unified Framework Overview & Pseudocode


21:10 - Implicit Function Theorem


25:45 - More Technicalities


28:45 - Experiments




ERRATA:


- Dataset Distillation is done with respect to the training set, not the validation or test set.




Paper: https://arxiv.org/abs/2105.15183


Code coming soon




Abstract:


Automatic differentiation (autodiff) has revolutionized machine learning. It allows expressing complex computations by composing elementary ones in creative ways and removes the burden of computing their derivatives by hand. More recently, differentiation of optimization problem solutions has attracted widespread attention with applications such as optimization as a layer, and in bi-level problems such as hyper-parameter optimization and meta-learning. However, the formulas for these derivatives often involve case-by-case tedious mathematical derivations. In this paper, we propose a unified, efficient and modular approach for implicit differentiation of optimization problems. In our approach, the user defines (in Python in the case of our implementation) a function F capturing the optimality conditions of the problem to be differentiated. Once this is done, we leverage autodiff of F and implicit differentiation to automatically differentiate the optimization problem. Our approach thus combines the benefits of implicit differentiation and autodiff. It is efficient as it can be added on top of any state-of-the-art solver and modular as the optimality condition specification is decoupled from the implicit differentiation mechanism. We show that seemingly simple principles allow us to recover many recently proposed implicit differentiation methods and create new ones easily. We demonstrate the ease of formulating and solving bi-level optimization problems using our framework. We also showcase an application to the sensitivity analysis of molecular dynamics.
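As a rough illustration of the abstract's idea (not the authors' API; the description says their code is still to come), the sketch below hand-rolls implicit differentiation in JAX for a hypothetical ridge-regression problem. The user supplies only an optimality condition `F(w, lam)`, here the gradient of the inner objective obtained by autodiff, and a `jax.custom_vjp` rule applies the implicit function theorem, dw*/dlam = -(dF/dw)^{-1} dF/dlam, so the iterative solver is never differentiated through. The names `F`, `solver` and `argmin` and all data are assumptions made for this example.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 for a tight numerical check

# Hypothetical toy data for a ridge-regression inner problem.
key_tr, key_val = jax.random.split(jax.random.PRNGKey(0))
X_tr = jax.random.normal(key_tr, (20, 5))
X_val = jax.random.normal(key_val, (10, 5))
w_true = jnp.arange(1.0, 6.0)
y_tr, y_val = X_tr @ w_true, X_val @ w_true

def inner_objective(w, lam):
    return 0.5 * jnp.sum((X_tr @ w - y_tr) ** 2) + 0.5 * lam * jnp.sum(w ** 2)

def F(w, lam):
    # Optimality condition: F(w*(lam), lam) = 0 at the inner solution
    # (stationarity, obtained by autodiff of the objective).
    return jax.grad(inner_objective)(w, lam)

def solver(lam, num_steps=2000, lr=0.01):
    # Any solver returning an accurate root of F can go here; it is only
    # run forward and never differentiated through.
    def step(_, w):
        return w - lr * F(w, lam)
    return jax.lax.fori_loop(0, num_steps, step, jnp.zeros(X_tr.shape[1]))

@jax.custom_vjp
def argmin(lam):
    return solver(lam)

def argmin_fwd(lam):
    w_star = solver(lam)
    return w_star, (w_star, lam)

def argmin_bwd(res, g):
    w_star, lam = res
    # VJP of the implicit function: g^T dw*/dlam = -u^T dF/dlam,
    # where u solves (dF/dw)^T u = g.
    A = jax.jacobian(F, argnums=0)(w_star, lam)
    u = jnp.linalg.solve(A.T, g)
    _, vjp_lam = jax.vjp(lambda l: F(w_star, l), lam)
    return (-vjp_lam(u)[0],)

argmin.defvjp(argmin_fwd, argmin_bwd)

def outer_loss(lam):
    # Validation error of the inner solution, differentiated implicitly.
    w = argmin(lam)
    return 0.5 * jnp.sum((X_val @ w - y_val) ** 2)

hypergrad = jax.grad(outer_loss)(1.0)
```

Swapping `solver` for any other method that returns an accurate root of `F` leaves the backward rule unchanged, which is the modularity the abstract describes. For large problems, the dense `jax.jacobian` plus `jnp.linalg.solve` would typically be replaced by matrix-free Jacobian-vector products and an iterative linear solver such as conjugate gradient.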




Authors: Mathieu Blondel, Quentin Berthet, Marco Cuturi, Roy Frostig, Stephan Hoyer, Felipe Llinares-López, Fabian Pedregosa, Jean-Philippe Vert




Links:


TabNine Code Completion (Referral): http://bit.ly/tabnine-yannick


YouTube: https://www.youtube.com/c/yannickilcher


Twitter: https://twitter.com/ykilcher


Discord: https://discord.gg/4H8xxDF


BitChute: https://www.bitchute.com/channel/yann...


Minds: https://www.minds.com/ykilcher


Parler: https://parler.com/profile/YannicKilcher


LinkedIn: https://www.linkedin.com/in/yannic-ki...


BiliBili: https://space.bilibili.com/1824646584




If you want to support me, the best thing to do is to share out the content :)




If you want to support me financially (completely optional and voluntary, but a lot of people have asked for this):


SubscribeStar: https://www.subscribestar.com/yannick...


Patreon: https://www.patreon.com/yannickilcher


Bitcoin (BTC): bc1q49lsw3q325tr58ygf8sudx2dqfguclvngvy2cq
