Augmented Neural ODEs

Emilien Dupont, University of Oxford, dupont@stats.ox.ac.uk
Arnaud Doucet, University of Oxford, doucet@stats.ox.ac.uk
Yee Whye Teh, University of Oxford, y.w.teh@stats.ox.ac.uk

Abstract

We show that Neural Ordinary Differential Equations (ODEs) learn representations that preserve the topology of the input space and prove that this implies the existence of functions Neural ODEs cannot represent. To address these limitations, we introduce Augmented Neural ODEs which, in addition to being more expressive models, are empirically more stable, generalize better and have a lower computational cost than Neural ODEs.

1 Introduction

The relationship between neural networks and differential equations has been studied in several recent works (Weinan, 2017; Lu et al., 2017; Haber & Ruthotto, 2017; Chen et al., 2018). In particular, it has been shown that Residual Networks (He et al., 2016) can be interpreted as discretized ODEs. Taking the discretization step to zero gives rise to a family of models called Neural ODEs (Chen et al., 2018). These models can be efficiently trained with backpropagation and have shown great promise on a number of tasks, including modeling continuous time data and building normalizing flows with low computational cost (Chen et al., 2018; Grathwohl et al., 2018).

In this work, we explore some of the consequences of taking this continuous limit and the restrictions this might create compared with regular neural nets. In particular, we show that there are simple classes of functions Neural ODEs (NODEs) cannot represent. While it is often possible for NODEs to approximate these functions in practice, the resulting flows are complex and lead to ODE problems that are computationally expensive to solve. To overcome these limitations, we introduce Augmented Neural ODEs (ANODEs), which are a simple extension of NODEs. ANODEs augment the space on which the ODE is solved, allowing the model to use the additional dimensions to learn more complex functions using simpler flows (see Fig. 1). In addition to being more expressive models, ANODEs significantly reduce the computational cost of both the forward and backward passes of the model compared with NODEs. Our experiments also show that ANODEs generalize better, achieve lower losses with fewer parameters and are more stable to train.

Figure 1: Learned flows for a Neural ODE and an Augmented Neural ODE. The flows (shown as lines with arrows) map input points to linearly separable features for binary classification. Augmented Neural ODEs learn simpler flows that are easier for the ODE solver to compute.

33rd Conference on Neural Information Processing Systems (NeurIPS 2019), Vancouver, Canada.
2 Neural ODEs

NODEs are a family of deep neural network models that can be interpreted as a continuous equivalent of Residual Networks (ResNets). To see this, consider the transformation of a hidden state from a layer t to t + 1 in ResNets,

$$h_{t+1} = h_t + f_t(h_t),$$

where h_t ∈ R^d is the hidden state at layer t and f_t : R^d → R^d is some differentiable function which preserves the dimension of h_t (typically a CNN). The difference h_{t+1} − h_t can be interpreted as a discretization of the derivative h'(t) with timestep Δt = 1. Letting Δt → 0, we see that

$$\lim_{\Delta t \to 0} \frac{h_{t+\Delta t} - h_t}{\Delta t} = \frac{dh(t)}{dt} = f(h(t), t),$$

so the hidden state can be parameterized by an ODE. We can then map a data point x into a set of features φ(x) by solving the Initial Value Problem (IVP)

$$\frac{dh(t)}{dt} = f(h(t), t), \qquad h(0) = x$$

to some time T. The hidden state at time T, i.e. h(T), corresponds to the features learned by the model.
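As a concrete illustration of this construction (a sketch under our own assumptions, not the implementation used in the paper), the snippet below maps a batch of inputs to features φ(x) = h(T) by solving the IVP with a fixed-step explicit Euler solver, which makes the correspondence with a ResNet block h ← h + Δt·f(h, t) explicit. In practice an adaptive solver would be used (e.g. the one provided by the torchdiffeq package of Chen et al., 2018); the network architecture, step count and final time T = 1 here are illustrative choices.

```python
import torch
import torch.nn as nn

class ODEFunc(nn.Module):
    """Vector field f(h, t); a small MLP chosen purely for illustration."""
    def __init__(self, dim: int, hidden: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden), nn.Tanh(), nn.Linear(hidden, dim)
        )

    def forward(self, h: torch.Tensor, t: float) -> torch.Tensor:
        # Append t as an extra input so the vector field can depend on time.
        t_col = torch.full((h.shape[0], 1), t)
        return self.net(torch.cat([h, t_col], dim=1))

def node_features(f: ODEFunc, x: torch.Tensor, T: float = 1.0, steps: int = 20) -> torch.Tensor:
    """Map x to phi(x) = h(T) by solving dh/dt = f(h, t), h(0) = x, with explicit Euler.
    Each Euler step h <- h + dt * f(h, t) is exactly a (weight-tied) ResNet block."""
    h, dt = x, T / steps
    for k in range(steps):
        h = h + dt * f(h, k * dt)
    return h

x = torch.randn(8, 2)                      # a batch of 2-d inputs
phi = node_features(ODEFunc(dim=2), x)
print(phi.shape)                           # torch.Size([8, 2]); features keep the input dimension
```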
The analogy with ResNets can then be made more explicit. In ResNets, we map an input x to some output y by a forward pass of the neural network. We then adjust the weights of the network to match y with some y_true. In NODEs, we map an input x to an output y by solving an ODE starting from x. We then adjust the dynamics of the system (encoded by f) such that the ODE transforms x to a y which is close to y_true.

Figure 2: Diagram of the Neural ODE architecture.

ODE flows. We also define the flow associated to the vector field f(h(t), t) of the ODE. The flow φ_t : R^d → R^d is defined as the hidden state at time t, i.e. φ_t(x) = h(t), when solving the ODE from the initial condition h(0) = x. The flow measures how the states of the ODE at a given time t depend on the initial condition x. We define the features of the ODE as φ(x) := φ_T(x), i.e. the flow at the final time T to which we solve the ODE.
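To make the distinction between the flow φ_t and the features φ(x) = φ_T(x) concrete, the short sketch below records the state of an ODE at intermediate times for a hand-picked (not learned) linear vector field; everything in it is an illustrative assumption.

```python
import torch

# An illustrative linear vector field f(h, t) = A h (a rotation); not a learned model.
A = torch.tensor([[0.0, -1.0],
                  [1.0,  0.0]])

def flow_trajectory(x: torch.Tensor, T: float = 1.0, steps: int = 100) -> list[torch.Tensor]:
    """Return [phi_0(x), phi_dt(x), ..., phi_T(x)] computed with explicit Euler steps."""
    h, dt, traj = x.clone(), T / steps, [x.clone()]
    for _ in range(steps):
        h = h + dt * (h @ A.T)
        traj.append(h.clone())
    return traj

traj = flow_trajectory(torch.tensor([[1.0, 0.0]]))
print(traj[len(traj) // 2])   # phi_{T/2}(x): the flow at an intermediate time
print(traj[-1])               # phi_T(x) = phi(x): the features of x under this flow
```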
NODEs for regression and classification. We can use ODEs to map input data x ∈ R^d to a set of features or representations φ(x) ∈ R^d. However, we are often interested in learning functions from R^d to R, e.g. for regression or classification. To define a model from R^d to R, we follow the example given in Lin & Jegelka (2018) and apply a linear layer to the features φ(x). Consider, for instance, a one-dimensional function g_{1d}(x) with g_{1d}(1) = −1 and g_{1d}(−1) = 1. For an ODE flow to represent g_{1d}(x), the trajectories mapping 1 to −1 and −1 to 1 must intersect each other (see Fig. 3). However, ODE trajectories cannot cross each other, so the flow of an ODE cannot represent g_{1d}(x). This simple observation is at the core of all the examples provided in this paper and forms the basis for many of the limitations of NODEs.

Experiments. We verify this behavior experimentally by training an ODE flow on the identity mapping and on g_{1d}(x). The resulting flows are shown in Fig. 3. As can be seen, the model easily learns the identity mapping but cannot represent g_{1d}(x). Indeed, since the trajectories cannot cross, the model maps all input points to zero to minimize the mean squared error.

ResNets vs NODEs. NODEs can be interpreted as continuous equivalents of ResNets, so it is interesting to consider why ResNets can represent g_{1d}(x) but NODEs cannot. The reason for this is precisely that ResNets are a discretization of the ODE, allowing the trajectories to make discrete jumps to cross each other (see Fig. 3). Indeed, the error arising when taking discrete steps allows the ResNet trajectories to cross. In this sense, ResNets can be interpreted as ODE solutions with large errors, with these errors allowing them to represent more functions.

Figure 3: (Top left) Continuous trajectories mapping 1 to −1 (red) and −1 to 1 (blue) must intersect each other, which is not possible for an ODE. (Top right) Solutions of the ODE are shown in solid lines and solutions using the Euler method (which corresponds to ResNets) are shown in dashed lines. As can be seen, the discretization error allows the trajectories to cross. (Bottom) Resulting vector fields and trajectories from training on the identity function (left) and g_{1d}(x) (right).
4 Functions Neural ODEs cannot represent

Figure 4: Diagram of g(x) for d = 2.

We now introduce classes of functions in arbitrary dimension d which NODEs cannot represent. Let 0 < r_1 < r_2 < r_3 and let g : R^d → R be a function such that

$$g(x) = \begin{cases} -1 & \text{if } \|x\| \le r_1, \\ 1 & \text{if } r_2 \le \|x\| \le r_3, \end{cases}$$

where ‖·‖ is the Euclidean norm. An illustration of this function for d = 2 is shown in Fig. 4. The function maps all points inside the blue sphere to −1 and all points in the red annulus to 1.
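For concreteness, here is one way such training data could be generated for d = 2; the radii, sample counts and the uniform-in-area sampling are illustrative assumptions rather than the paper's experimental settings.

```python
import math
import torch

def annulus_data(n: int = 1000, r1: float = 0.5, r2: float = 1.0, r3: float = 1.5):
    """Sample points labelled -1 inside the disk of radius r1 and +1 in the annulus [r2, r3]."""
    angles = 2 * math.pi * torch.rand(n)
    # Half the points in the inner disk, half in the annulus; sqrt gives uniform density in area.
    inner = r1 * torch.sqrt(torch.rand(n // 2))
    outer = torch.sqrt(r2**2 + (r3**2 - r2**2) * torch.rand(n - n // 2))
    radii = torch.cat([inner, outer])
    x = torch.stack([radii * torch.cos(angles), radii * torch.sin(angles)], dim=1)
    y = torch.cat([-torch.ones(n // 2), torch.ones(n - n // 2)])
    return x, y

x, y = annulus_data()
print(x.shape, y.shape)   # torch.Size([1000, 2]) torch.Size([1000])
```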
Proposition 2. Neural ODEs cannot represent g(x).

A proof is given in the appendix. While the proof requires tools from ODE theory and topology, the intuition behind it is simple. In order for the linear layer to map the blue and red points to −1 and 1 respectively, the features φ(x) for the blue and red points must be linearly separable. Since the blue region is enclosed by the red region, points in the blue region must cross over the red region to become linearly separable, requiring the trajectories to intersect, which is not possible.

In fact, we can make more general statements about which features Neural ODEs can learn.

Proposition 3. The feature mapping φ(x) is a homeomorphism, so the features of Neural ODEs preserve the topology of the input space.

A proof is given in the appendix. This statement is a consequence of the flow of an ODE being a homeomorphism, i.e. a continuous bijection whose inverse is also continuous; see, e.g., Younes (2010). This implies that NODEs can only continuously deform the input space and cannot, for example, tear a connected region apart.
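The invertibility underlying this homeomorphism can be checked numerically: integrating the same vector field backwards in time approximately recovers the input. The sketch below uses an arbitrary fixed vector field and a fixed-step Euler solver, so the recovery is only accurate up to discretization error; all names and constants are illustrative.

```python
import torch

W = torch.tensor([[0.5, -1.0],
                  [1.0,  0.3]])

def f(h: torch.Tensor) -> torch.Tensor:
    """An arbitrary smooth, time-independent vector field used only for this check."""
    return torch.tanh(h @ W)

def integrate(h: torch.Tensor, T: float, steps: int = 1000) -> torch.Tensor:
    dt = T / steps
    for _ in range(steps):
        h = h + dt * f(h)
    return h

x = torch.randn(4, 2)
z = integrate(x, T=1.0)        # phi_T(x): the features
x_rec = integrate(z, T=-1.0)   # run the flow backwards in time from the features
print(torch.allclose(x, x_rec, atol=1e-2))   # typically True, up to Euler discretization error
```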
Discrete points and continuous regions. It is worthwhile to consider what these results mean in practice. Indeed, when optimizing NODEs we train on inputs which are sampled from the continuous regions of the annulus and the sphere (see Fig. 4). The flow could then squeeze through the gaps between sampled points, making it possible for the NODE to learn a good approximation of the function. However, flows that need to stretch and squeeze the input space in such a way are likely to lead to ill-posed ODE problems that are numerically expensive to solve. In order to explore this, we run a number of experiments (the code to reproduce all experiments in this paper is publicly available).

Figure 5: Comparison of training losses of NODEs and ResNets on (a) g(x) in d = 1, (b) g(x) in d = 2 and (c) a separable function in d = 2. Compared to ResNets, NODEs struggle to fit g(x) both in d = 1 and d = 2. The difference between ResNets and NODEs is less pronounced for the separable function.

4.1 Experiments

We first compare the performance of ResNets and NODEs on simple regression tasks. To provide a baseline, we not only train on g(x) but also on data which can be made linearly separable without altering the topology of the space (implying that Neural ODEs should be able to easily learn this function). To ensure a fair comparison, we run large hyperparameter searches for each model and repeat each experiment 20 times to ensure results are meaningful across initializations (see appendix for details). We show results for experiments with d = 1 and d = 2 in Fig. 5. For d = 1, the ResNet easily fits the function, while the NODE cannot approximate g(x). For d = 2, the NODE eventually learns to approximate g(x), but struggles compared to ResNets. This problem is less severe for the separable function, presumably because the flow does not need to break apart any regions to linearly separate them.
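As a rough illustration of how such a regression experiment could be set up (a sketch with assumed architecture, optimizer settings and data generation, not the authors' configuration), one can fit a NODE-style model to the annulus data with a mean squared error loss and a linear readout on the features.

```python
import math
import torch
import torch.nn as nn

# Toy data for g(x) in d = 2: label -1 inside radius 0.5, +1 in the annulus [1.0, 1.5].
n = 512
angles = 2 * math.pi * torch.rand(n)
radii = torch.cat([0.5 * torch.rand(n // 2), 1.0 + 0.5 * torch.rand(n - n // 2)])
x = torch.stack([radii * torch.cos(angles), radii * torch.sin(angles)], dim=1)
y = torch.cat([-torch.ones(n // 2), torch.ones(n - n // 2)])

class NODERegressor(nn.Module):
    """Euler-discretized ODE flow followed by a linear readout to R (illustrative sizes)."""
    def __init__(self, dim: int = 2, hidden: int = 32, steps: int = 20):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(dim, hidden), nn.Tanh(), nn.Linear(hidden, dim))
        self.readout = nn.Linear(dim, 1)
        self.steps = steps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h, dt = x, 1.0 / self.steps
        for _ in range(self.steps):
            h = h + dt * self.f(h)          # explicit Euler step of dh/dt = f(h)
        return self.readout(h).squeeze(-1)  # linear layer applied to the features phi(x)

model = NODERegressor()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(200):
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
print(float(loss))   # final training MSE on g(x)
```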
4.2 Computational Cost and Number of Function Evaluations

One of the known limitations of NODEs is that, as training progresses and the flow gets increasingly complex, the number of steps required to solve the ODE increases (Chen et al., 2018; Grathwohl et al., 2018). As the ODE solver evaluates the function f at each step, this problem is often referred to as the increasing number of function evaluations (NFE). In Fig. 6, we visualize the evolution of the feature space during training and the corresponding NFEs. The NODE initially tries to move the inner sphere out of the annulus by pushing against and stretching the barrier. Eventually, since we are mapping discrete points and not a continuous region, the flow is able to break apart the annulus to let the flow through. However, this results in a large increase in NFEs, implying that the ODE stretching the space to separate the two regions becomes more difficult to solve, making the computation slower.

Figure 6: Evolution of the feature space as training progresses and the corresponding number of function evaluations required to solve the ODE. As the ODE needs to break apart the annulus, the number of function evaluations increases.
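One common way to track NFEs (a sketch of how this could be instrumented, not the authors' code) is to increment a counter each time the solver calls the vector field; an adaptive solver such as the one in the torchdiffeq package then exposes the cost of a solve directly.

```python
import torch
import torch.nn as nn

class CountedODEFunc(nn.Module):
    """Wraps a vector field and counts how often the ODE solver evaluates it."""
    def __init__(self, dim: int, hidden: int = 32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.Tanh(), nn.Linear(hidden, dim))
        self.nfe = 0   # number of function evaluations for the current solve

    def forward(self, t: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        self.nfe += 1
        return self.net(h)

# Usage with an adaptive solver, assuming torchdiffeq is installed:
#   from torchdiffeq import odeint
#   func = CountedODEFunc(dim=2)
#   features = odeint(func, torch.randn(8, 2), torch.tensor([0.0, 1.0]))[-1]
#   print(func.nfe)   # NFEs for this solve; this tends to grow as the learned flow gets more complex
```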
5 Augmented Neural ODEs

Motivated by our theory and experiments, we introduce Augmented Neural ODEs (ANODEs) which provide a simple solution to the problems we have discussed. We augment the space on which we learn and solve the ODE from R^d to R^{d+p}, allowing the ODE flow to lift points into the additional dimensions to avoid trajectories intersecting each other. Letting a(t) ∈ R^p denote a point in the augmented part of the space, we can formulate the augmented ODE problem as

$$\frac{d}{dt}\begin{bmatrix} h(t) \\ a(t) \end{bmatrix} = f\left(\begin{bmatrix} h(t) \\ a(t) \end{bmatrix}, t\right), \qquad \begin{bmatrix} h(0) \\ a(0) \end{bmatrix} = \begin{bmatrix} x \\ 0 \end{bmatrix},$$

i.e. we concatenate every data point x with a vector of zeros and solve the ODE on this augmented space. We hypothesize that this will also make the learned (augmented) f smoother, giving rise to simpler flows that the ODE solver can compute in fewer steps.
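In code, the augmentation amounts to concatenating p zeros to each input before the solve, so that the vector field acts on R^{d+p}. The sketch below is an illustrative assumption of architecture and solver (again fixed-step Euler), not the paper's implementation; a downstream linear layer can then be applied to the (d+p)-dimensional features exactly as for a regular NODE.

```python
import torch
import torch.nn as nn

class AugmentedNODE(nn.Module):
    """Solves the ODE on R^{d+p}: each input x is concatenated with p zeros, i.e. a(0) = 0."""
    def __init__(self, dim: int, aug_dim: int, hidden: int = 32):
        super().__init__()
        self.aug_dim = aug_dim
        d = dim + aug_dim
        self.f = nn.Sequential(nn.Linear(d, hidden), nn.Tanh(), nn.Linear(hidden, d))

    def forward(self, x: torch.Tensor, T: float = 1.0, steps: int = 20) -> torch.Tensor:
        zeros = torch.zeros(x.shape[0], self.aug_dim)   # a(0) = 0
        h = torch.cat([x, zeros], dim=1)                # [h(0); a(0)] = [x; 0]
        dt = T / steps
        for _ in range(steps):                          # explicit Euler on the augmented state
            h = h + dt * self.f(h)
        return h                                        # features live in R^{d+p}

model = AugmentedNODE(dim=2, aug_dim=3)
print(model(torch.randn(8, 2)).shape)   # torch.Size([8, 5])
```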
In the following sections, we verify this behavior experimentally and show, both on toy and image datasets, that ANODEs achieve lower losses, better generalization and lower computational cost than regular NODEs.

5.1 Experiments

We first compare the performance of NODEs and ANODEs on toy datasets. As in previous experiments, we run large hyperparameter searches to ensure a fair comparison. As can be seen in Fig. 7, when trained on g(x) in different dimensions, ANODEs are able to fit the functions NODEs cannot and learn much faster than NODEs, despite the increased dimension of the input. The corresponding flows learned by the model are also shown in Fig. 7. As can be seen, in d = 1, the ANODE moves into a higher dimension to linearly separate the points, resulting in a simple, nearly linear flow. Similarly, in d = 2, the NODE learns a complicated flow whereas ANODEs simply lift out the inner circle to separate the data. This effect can also be visualized as the features evolve during training (see Fig. 8).

Computational cost and number of function evaluations. As ANODEs learn simpler flows, they would presumably require fewer iterations to compute. To test this, we measure the NFEs for NODEs and ANODEs when training on g(x). As can be seen in Fig. 8, the NFEs required by ANODEs hardly increase during training, while they nearly double for NODEs. We obtain similar results when training NODEs and ANODEs on image datasets (see Section 5.2).

Generalization. As ANODEs learn simpler flows, we also hypothesize that they generalize better to unseen data than NODEs. To test this, we first visualize to which value each point in the input space gets mapped by a NODE and an ANODE that have been optimized to approximately zero training loss. As can be seen in Fig. 9, since NODEs can only continuously deform the input space, the learned flow must squeeze the points in the inner circle through the annulus, leading to poor generalization. ANODEs, in contrast, map all points in the input space to reasonable values. As a further test, we can also create a validation set by removing random slices of the input space (e.g. removing all points whose angle lies in a fixed interval) from the training set. We train both NODEs and ANODEs on the training set and plot the evolution of the validation loss during training in Fig. 9. While there is a large generalization gap for NODEs, presumably because the flow moves through the gaps in the training set, ANODEs generalize much better and achieve near zero validation loss.

Figure 7: (Left) Loss plots for NODEs and ANODEs trained on g(x) in d = 1 (top) and d = 2 (bottom). ANODEs easily approximate the functions and are consistently faster than NODEs. (Right) Flows learned by NODEs and ANODEs (panels: inputs, flow and features for ANODE 1D, NODE 2D and ANODE 2D). ANODEs learn simple, nearly linear flows while NODEs learn complex flows that are difficult for the ODE solver to compute.

Figure 8: (Left) Evolution of features during training for ANODEs. The top left tile shows the feature space for a randomly initialized ANODE and the bottom right tile shows the features after training. (Right) Evolution of the NFEs during training for NODEs and ANODEs trained on g(x) in d = 1.

As we have shown, experimentally we obtain lower losses, simpler flows, better generalization and ODEs requiring fewer NFEs to solve when using ANODEs. We now test this behavior on image data by training models on MNIST, CIFAR10, SVHN and 200 classes of 64×64 ImageNet.
5.2 Image Experiments

We perform experiments on image datasets using convolutional architectures for f(h(t), t). As the input x is an image, the hidden state h(t) is now in R^{c×h×w}, where c is the number of channels and h and w are the height and width respectively. In the case where h(t) ∈ R^d we augmented the space to h(t) ∈ R^{d+p}. For images we augment the space from R^{c×h×w} to R^{(c+p)×h×w}, i.e. we add p channels of zeros to the input image. While there are other ways to augment the space, we found that increasing the number of channels works well in practice and use this method for all experiments.
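A minimal sketch of this channel augmentation (the channel count p and image size below are illustrative choices):

```python
import torch

def augment_channels(x: torch.Tensor, p: int) -> torch.Tensor:
    """Append p all-zero channels to a batch of images: (N, c, h, w) -> (N, c + p, h, w)."""
    n, _, h, w = x.shape
    zeros = torch.zeros(n, p, h, w, dtype=x.dtype, device=x.device)
    return torch.cat([x, zeros], dim=1)

imgs = torch.randn(16, 1, 28, 28)          # e.g. a batch of MNIST-sized images
print(augment_channels(imgs, p=5).shape)   # torch.Size([16, 6, 28, 28])
```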
Full training and architecture details can be found in the appendix.

Results for models trained with and without augmentation are shown in Fig. 10. As can be seen, ANODEs train faster and obtain lower losses at a smaller computational cost than NODEs. On MNIST, for example, ANODEs with 10 augmented channels achieve the same loss in roughly 10 times fewer iterations (for CIFAR10, ANODEs are roughly 5 times faster). Perhaps most interestingly, we can plot the NFEs against the loss to understand roughly how complex a flow (i.e. how many NFEs) is required to model a function that achieves a certain loss. For example, to compute a function which obtains a loss of 0.8 on CIFAR10, a NODE requires approximately 100 function evaluations whereas ANODEs only require 50. Similar observations can be made for other datasets, implying that ANODEs can model equally rich functions at half the computational cost of NODEs.

Figure 9: (Left) Plots of how NODEs and ANODEs map points in the input space to different outputs (both models achieve approximately zero training loss). As can be seen, the ANODE generalizes better. (Middle) Training and validation losses for the NODE. (Right) Training and validation losses for the ANODE.

Figure 10: Training losses, NFEs and NFEs vs loss for various augmented models on MNIST (top row) and CIFAR10 (bottom row). Note that p indicates the size of the augmented dimension, so p = 0 corresponds to a regular NODE model. Further results on SVHN and 64×64 ImageNet can be found in the appendix.

Parameter efficiency. As we augment the dimension of the ODEs, we also increase the number of parameters of the models, so it may be that the improved performance of ANODEs is due to the higher number of parameters. To test this, we train a NODE and an ANODE with the same number of parameters on MNIST (84k weights), CIFAR10 (172k weights), SVHN (172k weights) and 64×64 ImageNet (366k weights).