From f449160b610b70fd5da04a705b7d640a9b13accb Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Sat, 3 Aug 2024 21:20:55 -0700 Subject: [PATCH] diff_traj: use random traj --- notebooks/compute_dream_pos.ipynb | 59 +++++++++++++++++++++++++++++-- torchdrive/tasks/diff_traj.py | 43 +++++++++++++++------- 2 files changed, 88 insertions(+), 14 deletions(-) diff --git a/notebooks/compute_dream_pos.ipynb b/notebooks/compute_dream_pos.ipynb index cfdc521..2525f64 100644 --- a/notebooks/compute_dream_pos.ipynb +++ b/notebooks/compute_dream_pos.ipynb @@ -2,7 +2,17 @@ "cells": [ { "cell_type": "code", - "execution_count": 69, + "execution_count": 2, + "id": "196d0aa7-d2ce-44d2-a3ce-164063830b57", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload" + ] + }, + { + "cell_type": "code", + "execution_count": 3, "id": "990bef42-c459-4a9a-b42e-514466fdc658", "metadata": { "scrolled": true @@ -46,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 4, "id": "d534dcf9-d88c-4a9d-a42f-8d75817134d3", "metadata": {}, "outputs": [ @@ -77,6 +87,51 @@ "plt.gca().set_aspect(\"equal\")" ] }, + { + "cell_type": "code", + "execution_count": 61, + "id": "80a42d2e-56f8-4489-87e0-a43b95c37e85", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMEAAAPdCAYAAAAzvgzhAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAxAUlEQVR4nO3df2wc533n8c/M/p4hd0nOUD9ormZWdtCDgVYppFgnn/9IAsN24AQ1UPh6QJFKhiHUuDgoYCGthAZV/2iiXi3kgrpG4gKtrw16gJEr2uL6w2lqG0hycWGcHbdwDoprVBIpUZREkdql+GOX++P+WO5KNH8tuTPzPM88nxdAoJYp6kmqd3b2O88+Y7RarRaINGaKXgCRaIyAtMcISHuMgLTHCEh7jIC0xwhIe0nRC+hXs9nE1NQUBgcHYRiG6OWQJFqtFubn5zE2NgbT3Pp/65WPYGpqCsViUfQySFKTk5MYHx/f8nuUj2BwcBBA+z9sPp8XvBqSRaVSQbFY7P792IryEXQugfL5PCOgdXq5ROYbY9IeIyDtMQLSHiMg7TEC0h4jIO0xAtIeIyDtMQLSHiMg7TEC0h4jIO0xAtIeIyDtMQLSHiMg7TEC0h4jIO0xAtIeIyDtMQLSHiMg7TEC0h4jIO0xAtIeIyDtMQLSHiMg7TEC0h4jIO0xAtIeIyDtMQLSHiMg7TEC0h4jIO0xAtKeFBG8/PLL8H0f2WwWR48exTvvvCN6SaQR4RG89tpreOGFF3D27Fm89957OHToEB5//HHcuHFD9NJIE8Ij+MY3voGTJ0/imWeewYMPPohvf/vbsCwLf/qnf7rh91erVVQqlTVfG/kf/+ci/tvrF8JcOsWE0AhqtRreffddPProo91fM00Tjz76KN5+++0Nf8+5c+dQKBS6X8ViccPvuzy7iO/9dDqUdVO8CI1gZmYGjUYDe/fuXfPre/fuxfT0xn+Bz5w5g3K53P2anJzc8PtKro3J2UU0mq3A103xkhS9gJ3KZDLIZDLbfp/n2FhptDB1ewnFESuClZGqhL4SuK6LRCKB69evr/n169evY9++fX39bN9p/8W/dGuhr59D8Sc0gnQ6jcOHD+ONN97o/lqz2cQbb7yBY8eO9fWz7xvKIWkauHRrsd9lUswJvxx64YUXcPz4cRw5cgQPPfQQvvnNb2JhYQHPPPNMXz83mTBRHLFwaYavBLQ14RH8yq/8Cm7evInf+Z3fwfT0ND75yU/i9ddfX/dmeTc8x8JlXg7RNoRHAADPP/88nn/++cB/ru/Y+NFHM4H/XIoX4TfLwuQ7FiZucUxKW4t1BJ5ro9Zo4lp5SfRSSGKxjqDk2ACAy5wQ0RZiHcF9wzkkTIP3CmhLsY4glTAxPpzjmJS2FOsIgPaEiDfMaCsaRMB7BbS12EfgOTYu31pEk2NS2kTsIyi5Nqr1JqYry6KXQpKKfQQed5PSNmIfwfiwBdMALs3wzTFtLPYRpJMmxof55pg2F/sIgPYlES+HaDNaROA7Ni+HaFN6RODauDy7wDEpbUiPCBwLyytN3Jivil4KSUiLCLzV3aQXuYeINqBFBMWRHEwDnBDRhrSIIJNMYGwox410tCEtIgA6EyK+EtB62kTAewW0GW0iKLnt3aStFsektJY2EXiOjaWVBsektI42EXTPJuX7AvoYbSIojlgwDJ48QetpE0E2lcBYIcc3x7SONhEAnBDRxrSKwHe5m5TW0yuC1ZMnOCale2kVgefYWKg1cPMOx6R0l1YRlFyeTUrraRXBgRHeK6D1tIogm0pgfyHLCRGtoVUEAM8mpfX0i8Dl8Su0lnYReKsnT3BMSh3aReA7Nu5U67i1UBO9FJKEfhG47QkRL4moQ7sIvJHOyRN8c0xt2kWQSyewL5/lKwF1aRcB0NlNylcCatMyAp48QffSMwK3HQHHpAToGoFjYb5ax9ziiuilkAS0jIBnk9K9tIyA9wroXlpGYKWT2DOY4YSIAGgaAcAJEd2lbQQen3RPq7SNwHf5uQJq0zcCx0Z5aQVz3E2qPW0j4JPuqUPbCHyePEGrtI1gIJOEO5DhDTPSNwLg7ol0pDe9I+CEiKB7BHwlIGgegefYmFtcQZm7SbWmdQSds0k5JtWb1hEc4L0CguYR5LMpOHaaD+7QnNYRAO0JEd8c6037CPgcM9I+Ap5STYzAtTG7UEN5iWNSXTGC1QnRBF8NtKV9BN2TJ/i+QFvaR1DIpTBip3GZu0m1pX0EAM8m1R0jQGdCxFcCXTECtCPgDTN9MQK0T6SbuVPD/DLHpDpiBLg7IeLnjfXECACUHG6p1hkjAFCwUhiyUnwl0BQjWOU5Nk+e0BQjWFXi5421xQhWedxNqi1GsMp3Ldycr+JOtS56KRQxRrDq7piUl0S6YQSrSrxXoC1GsGrISiGfTXJCpCFGsMowDH7oXlOM4B78vLGeGME9fMfiw/w0xAju4Tk2bsxXsVjjmFQnjOAefHqNnhjBPTonT/DNsV4YwT1G7DQGM0lc5NmkWmEE9+CYVE+M4GN4Nql+GMHH+I7No9o1wwg+xndtTFeWsVRriF4KRYQRfEx3QjTLSyJdMIKP6Wyp5iWRPhjBx7gDaQxkkpwQaYQRfIxhGDybVDOMYAPtCRFfCXTBCDbguzx5QieMYAOeY2OqvIzlFY5JdcAINuCvTogmZvm+QAeMYAO+u/qke74v0AIj2MDoQAZWOsHPFWiCEWygPSa1+TA/TTCCTZQ4IdIGI9iEx92k2mAEm/AdC1PlJY5JNcAINuE7Nlot4MocXw3ijhFsonPyBC+J4o8RbGLPYAbZlMmPWmqAEWzCMAw+5FsTjGAL7Yd883Io7hjBFjzX4lHtGmAEW/AdG1O3l1Ctc0waZ4xgC75jo9kCrswtiV4KhYgRbIG7SfXACLawdzCLTNLk541jjhFswTSN1QkRXwnijBFsgydPxB8j2Ibv8uSJuGME2/AdG1fmFlGrN0UvhULCCLbhOxaaLeDqbY5J44oRbMPr7iblJVFcMYJt7M9nkU5yN2mcMYJtmKYBb8TiRroYYwQ98BybG+lijBH0gCdPxBsj6IHn2Lgyt4SVBsekccQIeuA7NurNFq5yN2ksMYIedHeT8pIolhhBD/YXckgnTE6IYooR9CBhGiiO5DghiilG0KOSyy3VccUIeuTx5InYYgQ98h0LE7OLqHNMGjuMoEe+2x6TTt1eFr0UChgj6FHnOWYck8YPI+jR/kIWqYTBCGKIEfQomTBRHLF4SnUMMYId4MkT8cQIdsBzLD7ML4YYwQ6UXBuTs4toNFuil0IBYgQ74Dk2VhotTPFD97HCCHbAd7ibNI6ERXDp0iU8++yzKJVKyOVyuP/++3H27FnUajVRS9rWfUM5JE2DJ9LFTFLUH3zhwgU0m0288soreOCBB/DBBx/g5MmTWFhYwPnz50Uta0udMell7iaNFWERPPHEE3jiiSe6/3zw4EH87Gc/w7e+9a0tI6hWq6hWq91/rlQqoa7z49pnkzKCOJHqPUG5XMbIyMiW33Pu3DkUCoXuV7FYjGh1be2H+fFyKE6kieCjjz7CSy+9hF//9V/f8vvOnDmDcrnc/ZqcnIxohW2+Y2HiFsekcRJ4BKdPn4ZhGFt+XbhwYc3vuXr1Kp544gk8/fTTOHny5JY/P5PJIJ/Pr/mKkufaqDWamK5wN2lcBP6e4NSpUzhx4sSW33Pw4MHu/z01NYXPfOYzePjhh/HHf/zHQS8ncN3dpDMLuG8oJ3g1FITAIxgdHcXo6GhP33v16lV85jOfweHDh/Hqq6/CNKW5OtvU+HAOCbO9m/Q/PeCKXg4FQNh06OrVq/j0pz8Nz/Nw/vx53Lx5s/vv9u3bJ2pZ20olTIwP5/hRyxgRFsH3v/99fPTRR/joo48wPj6+5t+1WnK/6eTZpPEi7PrjxIkTaLVaG37JruTwbNI4kf8iXEKdkyeaHJPGAiPYBd+1UK1zTBoXjGAX+KH7eGEEuzA+bME0wAlRTDCCXUgnTdw3nOPD/GKCEexSeyMdI4gDRrBLPs8mjQ1GsEudzxVwTKo+RrBLJdfG8koTN+ar238zSY0R7JLHMWlsMIJdKo7kYBrghCgGGMEuZZIJjA3l+FHLGGAEfeDZpPHACPrgORa3VMcAI+hD516BCtu/aXOMoA++a2NppYGbHJMqjRH0oXM2KS+J1MYI+lAcsWBwN6nyGEEfsqkExgo53jBTHCPok+dYfCVQHCPoE0+eUB8j6FPJbZ88wTGpuhhBnzzHxkKtgZk78j5chLbGCPrED92rjxH0yes8x4zvC5TFCPqUTSWwv5DlhEhhjCAAfMi32hhBAEout1SrjBEEwHNsXJ7hblJVMYIA+I6F+WodtxY4JlURIwiA77bHpLwkUhMjCMCBkc6YlBMiFTGCAFjpJPbmM7xhpihGEBA+5FtdjCAgPHlCXYwgIJ7bPnmCY1L1MIKAlBwb88t1zC2uiF4K7RAjCAjPJlUXIwgId5OqixEExM4kMTqYYQQKYgQBKjk2Ls9yTKoaRhAgz7H4SqAgRhAg3+UNMxUxggD5jo3y0gpuL3I3qUoYQYA8nk2qJEYQoLtbqnlJpBJGEKCBTBLuAHeTqoYRBMzn2aTKYQQB49mk6mEEAeucTUrqYAQB8xwbc4srKHM3qTIYQcB4Nql6GEHAPHd1NykjUAYjCFg+m4JjpzkhUggjCAE30qmFEYSgvZGOEaiCEYSg86R7UgMjCIHnWLi1UENlmWNSFTCCEJQ6G+l4LKMSGEEIvBHeK1AJIwhBwUph2EpxQqQIRhASftRSHYwgJDybVB2MICSeY/E9gSIYQUhKro2ZOzXMc0wqPUYQks7ZpLxpJj9GEBLf4W5SVTCCkAxZaQxZKb4SKIARhMhzbN4rUAAjCJHPCZESGEGIPD7MTwmMIEQl18LN+SoWqnXRS6EtMIIQ8RFOamAEIfJ5r0AJjCBEw1YK+WySrwSSYwQhMgyjvZuUY1KpMYKQcUIkP0YQspLDs0llxwhC5jk2rleqWKxxTCorRhAyf/VYRk6I5MUIQnZ3TMpLIlkxgpCN2GkMZpJ8cywxRhAywzDguTybVGaMIAK+w7NJZcYIIsCzSeXGCCLgORaulZexVGuIXgptgBFEoHM26cQsXw1kxAgiwC3VcmMEEXAH0rDTCU6IJMUIItDdTcpXAikxgohwQiQvRhARPsxPXowgIr5rY6q8jOUVjkllwwgi0tlIN8kxqXQYQUQ6Z5Ne5CWRdBhBREYHM7DSCb45lhAjiIhhGKufN+YrgWwYQYR4NqmcGEGE2sev8HJINowgQr5jYaq8hGqdY1KZMIIIeY6NVotjUtkwggh17hXwkkgujCBCe/MZZFMm3xxLhhFEyDAMft5YQowgYp5j8YaZZBhBxPi5Avkwgoj5jo2rc0uo1Zuil0KrGEHEPMdCswVMzvGSSBaMIGKdkyd4Nqk8GEHE9g5mkUmavFcgEUYQMdM02h+15CuBNBiBAD4f4SQVRiCA79p8TyARRiCA51i4MreElQbHpDJgBAKUHBuNZgtX5pZEL4XACITwXJ5NKhNGIMD+fBbppMnDuCTBCAQwTQPeCDfSyYIRCMKTJ+TBCATxeTapNBiBIL5rc0wqCUYgiO/YqDdbmLrNMalojEAQj2eTSoMRCDI2lEM6YXJCJAFGIEjCNFAcyXFCJAFGIJDv2JwQSYARCOTxOWZSYAQClVwLk3OLqHNMKhQjEMhzbKw0Wpi6vSx6KVpjBAL5fNK9FBiBQGNDWaQSBj9lJhgjECiZMFEctnCRJ08IxQgEa59NylcCkRiBYDybVDxGIJjv2JicXUKj2RK9FG0xAsE8x0Kt0eRuUoEYgWB3zyblm2NRGIFg9w3lkDQNvi8QiBEIlkyYGB/OcSOdQIxAAu0JES+HRGEEEuDD/MRiBBLwHAsTs4tockwqBCOQgO/aqNWbuFbhblIRGIEEOrtJL/PNsRBSRFCtVvHJT34ShmHg/fffF72cyI0P55AwDVzk+wIhpIjgN3/zNzE2NiZ6GcKkVsekvGEmhvAI/uEf/gH/+I//iPPnz/f0/dVqFZVKZc1XHHj80L0wQiO4fv06Tp48ie985zuwLKun33Pu3DkUCoXuV7FYDHmV0fD5MD9hhEXQarVw4sQJPPfcczhy5EjPv+/MmTMol8vdr8nJyRBXGR1/9eQJjkmjF3gEp0+fhmEYW35duHABL730Eubn53HmzJkd/fxMJoN8Pr/mKw5810K13sT1eY5Jo5YM+geeOnUKJ06c2PJ7Dh48iDfffBNvv/02MpnMmn935MgR/Oqv/ir+7M/+LOilSc1bHZNenFnA/kJO8Gr0EngEo6OjGB0d3fb7/vAP/xC/93u/1/3nqakpPP7443jttddw9OjRoJclveKwBdNob6l++H7Rq9FL4BH06sCBA2v+eWBgAABw//33Y3x8XMSShEonTdw3zLNJRRA+IqW7eDapGMJeCT7O9320WnpPRjzHwv+9NCd6GdrhK4FEOluqdf8fg6gxAon4jo3llSauV6qil6IVRiAR323fNeeb42gxAokURywYBngiXcQYgUQyyQTGCjmeTRoxRiAZ3+XZpFFjBJJpT4j4ShAlRiCZ9m5SjkmjxAgk4zkWFmsN3JznmDQqjEAynbNJeUkUHUYgmc6YlPcKosMIJJNNJbA/n+VGuggxAgn5Lh/yHSVGICGPZ5NGihFIyHcsXJrhmDQqjEBCvmtjodbAzJ2a6KVogRFIqHs2KS+JIsEIJHRgpL2l+iInRJFgBBLKpRPYX8hyQhQRRiApj8cyRoYRSIqPcIoOI5CU79q4PLPIMWkEGIGkfMfCfLWO2QWOScPGCCTVOZuUl0ThYwSS8pzVkyf4eePQMQJJWekk9uYzvGEWAUYgMc+xcZH3CkLHCCRWWv28MYWLEUjMcy1c5G7S0DECifmOjfnlOuYWV0QvJdYYgcR8jkkjwQgk1hmT8n1BuBiBxOxMEqODGZ5NGjJGIDnf4dmkYWMEkuPZpOFjBJLzXT7ML2yMQHKeY6G8tII57iYNDSOQXPdD97O8JAoLI5Dc3d2kvCQKCyOQ3GA2BXcgzRtmIWIECmg/uIOXQ2FhBArg2aThYgQK6JxNSuFgBArwXRtziysoczdpKBiBAu6OSflqEAZGoADP5dmkYWIECshnU3DsNCdEIWEEiuDZpOFhBIrwHW6kCwsjUAQf5hceRqAIz7Fwa6GGyjLHpEFjBIrojkn5UcvAMQJF8OSJ8DACRRSsFIatFD9vHAJGoBDPsXnyRAgYgUJKLs8mDQMjUEj7hhlfCYLGCBTiOzZm7lQxzzFpoBiBQny386R7vhoEiREoxO+eTcoIgsQIFDJkpVHIpXivIGCMQDE8kS54jEAx7QN6eTkUJEagmPbD/PhKECRGoBjfsXBzvoqFal30UmKDESiGY9LgMQLFcDdp8BiBYoatFAazSUYQIEagGMMw2hvpuJs0MIxAQZwQBYsRKIgP8wsWI1CQ79i4XqliscYxaRAYgYL81WMZJ/gIp0AwAgV5nTEp9xAFghEoyLHTGMwk+SmzgDACBRmGAc/lm+OgMAJFtU+eYARBYASKKvFhfoFhBIryHAvXystYXmmIXoryGIGiuJs0OIxAUdxNGhxGoCh3IA07neCEKACMQFGGYfBs0oAwAoXxbNJgMAKFeTx5IhCMQGG+Y2OqvMQxaZ8YgcJ810arxd2k/WIECuPZpMFgBAobHczASie4pbpPjEBhnTEpb5j1hxEojmeT9o8RKI5bqvvHCBRXci1MlZdQrXNMuluMQHGe0x6TTs4uiV6KshiB4nx+6L5vjEBxewYzyKZMToj6wAgUZ5oGfH7Usi+MIAbaD/nmK8FuMYIY8HnDrC+MIAZ818bVuSXU6k3RS1ESI4gBz7HQbAFX5vi+YDcYQQzwQ/f9YQQxsC+fRSZp4hI/b7wrjCAGTNNY/aglXwl2gxHERPsRTnwl2A1GEBM8eWL3GEFMeI6FK3NLWGlwTLpTjCAmfMdGo9nClTnuJt0pRhATnQN6OSbdOUYQE/vzWaSTJi5zS/WOMYKYME0DB0YsPsdsFxhBjHAj3e4wghjhyRO7wwhixHNtTM4uos4x6Y4wghgpOTbqzRau3uaYdCcYQYx4q2eT8s3xzjCCGBkbyiGdMHnyxA4xghhJmAaKIzlOiHaIEcQMT57YOUYQM55j83JohxhBzPiuhck5jkl3ghHEjO/YWGm0cK28LHopymAEMdP50D2Pa+8dI4iZsaEskqbBT5ntACOImWTC5G7SHWIEMcSTJ3aGEcQQH+G0M4wghkqujcnZJTSaLdFLUQIjiCHPsVBrNHGtzN2kvWAEMXT3EU58c9wLRhBD48M5JE2DG+l6xAhiKJkwMT6c44SoR4wgptoTIl4O9YIRxBTPJu0dI4gpz7FweXYRTY5Jt8UIYsp3bNTqTVyrcDfpdhhBTHXOJuWxjNtjBDE1PpxDwjRwke8LtsUIYiqVMHHfUI6fN+4BI4gx3+XnjXvBCGKMZ5P2hhHEmLd6SjXHpFtjBDFWci1U601cn+eYdCuMIMY87ibtifAI/u7v/g5Hjx5FLpfD8PAwnnrqKdFLio3isAXT4HPMtpMU+Yf/5V/+JU6ePImvf/3r+OxnP4t6vY4PPvhA5JJiJZ00cd8wzybdjrAI6vU6fuM3fgMvvvginn322e6vP/jgg1v+vmq1imq12v3nSqUS2hrjwHdsXObl0JaEXQ699957uHr1KkzTxC/+4i9i//79+NznPrftK8G5c+dQKBS6X8ViMaIVq8lzLL4SbENYBP/+7/8OAPjd3/1dfPWrX8Xf/u3fYnh4GJ/+9KcxOzu76e87c+YMyuVy92tycjKqJSup8zC/Votj0s0EHsHp06dhGMaWXxcuXECz2T4w9rd/+7fxy7/8yzh8+DBeffVVGIaB7373u5v+/Ewmg3w+v+aLNuc7NpZXmrgxX93+mzUV+HuCU6dO4cSJE1t+z8GDB3Ht2jUAa98DZDIZHDx4EBMTE0EvS1u+236E08WZBezNZwWvRk6BRzA6OorR0dFtv+/w4cPIZDL42c9+hkceeQQAsLKygkuXLsHzvKCXpa3xYQuGAVy+tYD/eNARvRwpCZsO5fN5PPfcczh79iyKxSI8z8OLL74IAHj66adFLSt2sqkExgo5nk26BaH3CV588UUkk0l88YtfxNLSEo4ePYo333wTw8PDIpcVO75rcTfpFoRGkEqlcP78eZw/f17kMmLPc2z8ZOK26GVIS/i2CQpfyWmfPMEx6cYYgQY8x8JirYGbdzgm3Qgj0EDnQ/fcTboxRqCBAyPtMSm3T2yMEWggm0pgfz7LE+k2wQg00X7INy+HNsIINOG7Ni+HNsEINNE5eYJj0vUYgSY8x8adah0zd2qilyIdRqCJUudsUl4SrcMINHFgpL2lmhvp1mMEmsilE9iXz3Ij3QYYgUZ8l5833ggj0Ijv2DybdAOMQCPtG2bcTfpxjEAjJdfCfLWO2QWOSe/FCDTSPZuUl0RrMAKNeM7qmJQTojUYgUasdBJ78xneMPsYRqCZ9oM7eDl0L0agGZ9nk67DCDTjuzYucky6BiPQjO/YmF+u4/biiuilSIMRaKYzIeJDvu9iBJrp3CvghOguRqCZgUwSo4MZft74HoxAQ5wQrcUINMR7BWsxAg2VXJvvCe7BCDTkORZuL67g9iJ3kwKMQEs+d5OuwQg01LlXwEuiNkagocFsCu5AGhe5pRoAI9CWx88bdzECTXUe8k2MQFuds0mJEWjLc23MLtRQXuJuUkagqRI30nUxAk0dcHg2aQcj0FQhl8KInebJE2AEWuNu0jZGoDGeTdrGCDTWOZtUd4xAY75r4dZCDZVlvcekjEBjnd2kE5pfEjECjXUi0H0jHSPQWMFKYdhKaX/DjBFojp83ZgTa8x1L+wkRI9Cc7/KVgBFozndszNyp4k61LnopwjACzfHpNYxAe353S7W+l0SMQHPDdhqFXErrjXSMgLSfEDEC0v7kCUZAq2NSvhKQxnzHwo35KhZreo5JGQHdfdK9pg/uYASEkqv3yROMgDBspTCYTWq7fYIREAzDaB/LqOmYlBEQAL0nRIyAAOh9NikjIADtCdF0ZRlLtYbopUSOERAAoOSuPr1mVr9LIkZAAPS+V8AICADg2GkMZJJavjlmBARgdUzqWlreMGME1NU+lpGXQ6QxXU+pZgTU5Ts2rpWXsbyi15iUEVCXv7qRbmJWr0siRkBdnZMndDublBFQ1+hABnY6od2EiBFQl2EYWp5NyghoDd/V7+QJRkBr6HjyBCOgNUqOjanyklZjUkZAa3iOhVYLmNRoTMoIaI3OvQKd3hwzAlpjz2AGuZReY1JGQGu0x6R67SFiBLSOr9luUkZA6+h28gQjoHV8x8LU7SVU63qMSRkBreM5NpotYHJ2SfRSIsEIaB3dziZlBLTOnsEMsilTm3sFjIDWMU0D3og+Z5MyAtqQ7+pzr4AR0IZ8jXaTMgLakOfYuDK3iFq9KXopoWMEtCHftdBsAVfm4v9qwAhoQzo96Z4R0Ib25bPIJE0tTp5gBLQh02zvJtXhhhkjoE3pcvIEI6BN6XI2KSOgTfmujStzS1hpxHtMyghoU75jo9Fs4epcvHeTMgLaVPds0phfEjEC2tRYIYd00sTlmI9JGQFtyjQNHBixYj8hYgS0JR0mRIyAtqTDblJGQFvyXBuTs4uox3hMyghoS75jod5s4ert+I5JGQFtqbObNM5vjhkBbWlsKIdUwoj1RjpGQFtKmAaKI1ast1QzAtpW3CdEjIC25TvxPpuUEdC2fNeK9ZiUEdC2PMfGSqOFa+Vl0UsJBSOgbZW6Y9J4XhIxAtrW2FAWSdOI7b0CRkDbSiZMFEfi+5BvRkA98WN88gQjoJ7E+eQJRkA98R0LE7cW0Wi2RC8lcIyAeuK7NmqNJq6V47eblBFQT+J8NikjoJ7cN5xDwjRiuZGOEVBPUgkTxeFcLCdEjIB6FtcJESOgnvlOPG+YMQLqme/auDy7iGbMxqSMgHrmOzZq9SamK/HaTcoIqGeds0njdknECKhn48MWEjHcTcoIqGfppIn7huI3JmUEtCOeE7+TJxgB7UjJjd/JE4yAdsRbPXkiTmNSRkA74jsWqvUmrs/HZ0zKCGhHfHf1Q/cz8bkkYgS0I+PDOZgGYjUhYgS0I5lkAmNDuVg9zI8R0I75jo3LvBwinfluvJ5jxghoxzqnVLda8RiTMgLaMc+xsbTSwI35quilBIIR0I6V3HjtJhUawYcffohf+qVfguu6yOfzeOSRR/DWW2+JXBL1YHzYgmHE54BeoRF8/vOfR71ex5tvvol3330Xhw4dwuc//3lMT0+LXBZtI5tKYKyQi82WamERzMzM4N/+7d9w+vRp/MIv/AI+8YlP4Pd///exuLiIDz74QNSyqEe+G5+zSYVF4DgOfu7nfg5//ud/joWFBdTrdbzyyivYs2cPDh8+vOnvq1arqFQqa74oep5jx2brhLAIDMPAP/3TP+EnP/kJBgcHkc1m8Y1vfAOvv/46hoeHN/19586dQ6FQ6H4Vi8UIV00dvtO+VxCHMWngEZw+fRqGYWz5deHCBbRaLXzpS1/Cnj178MMf/hDvvPMOnnrqKXzhC1/AtWvXNv35Z86cQblc7n5NTk4G/R+BeuA7NhZrDdy8o/6Y1GgFnPLNmzdx69atLb/n4MGD+OEPf4jHHnsMc3NzyOfz3X/3iU98As8++yxOnz7d059XqVRQKBRQLpfX/BwK14fX5/HYf/8BvvvcMXzKHxG9nHV28vciGfQfPjo6itHR0W2/b3GxfT1pmmtfjEzTRLMZz6ckxsmBkfa9goszC1JGsBPC3hMcO3YMw8PDOH78OP7lX/4FH374Ib7yla/g4sWLePLJJ0Uti3rUHpNmYzEhEhaB67p4/fXXcefOHXz2s5/FkSNH8KMf/Qh/8zd/g0OHDolaFu1AXM4mDfxyaCeOHDmC733veyKXQH3wXQv/eqUsehl9494h2rW47CZlBLRrnmPjTrWOWws10UvpCyOgXfNjspuUEdCueSOrJ08o/uaYEdCu5dIJ7MurPyZlBNSXOJxNygioL3E4m5QRUF86Z5OqPCZlBNQX37Ewv1zHrMJjUkZAfemeTarwJREjoL50nmOm8oSIEVBfrHQSewYzSt8wYwTUN99VezcpI6C++Y7aJ08wAuqb59i4OKPumJQRUN98x0ZluY7biyuil7IrjID61t1NquglESOgvnlO514BIyBNDWSScAcyyp5IxwgoECWFzyZlBBQIz7FxUdF7BYyAAqHyvQJGQIHwXRu3F1dwe1G93aSMgALhO+ruJmUEFAiVd5MyAgrEYDYFdyCt5JiUEVBgPMfmKwHpzXMsXGQEpLOSo+bJE4yAAuO5NmYXaigvqbWblBFQYHxFJ0SMgALjKXqvgBFQYAq5FEbsNC4r9qF7RkCBUnFCxAgoUCpOiBgBBUrFG2aMgALluxZm7tQwv6zOmJQRUKA6u0lVuiRiBBQoX8EP3TMCClTBSmHISil1NikjoMD5ij3pnhFQ4FT7vDEjoMC1zyblKwFprOTamLlTxZ1qXfRSesIIKHCqfd6YEVDgumNSRS6JGAEFbshKIZ9NKnOvgBFQ4AzDWH3INyMgjXmOzcsh0pvvWLwcIr35ro0b81Us1uQfkzICCoWn0G5SRkCh6Jw8ocJGOkZAoRix0xjMJpXYSMcIKBSGYcBX5KOWjIBC4zkWLvJyiHTWvmHGyyHSmOfYmK4sY6nWEL2ULTECCk33bNJZuS+JGAGFxnfV2E3KCCg0jp3GQCYp/YSIEVBoDMOAp8AeIkZAofJd+XeTMgIKlQonTzACCpXn2JgqL2N5Rd4xKSOgUJVWJ0QTs/JeEjECCpWnwG5SRkChGh3IwE4npJ4QMQIKVXtMKvfZpIyAQue7ck+IGAGFTvaTJxgBhc53LEyVl6QdkzICCp3v2Gi1gCtzcr4aMAIKXWc3qazHtTMCCt2ewQxyqYS0b44ZAYVO9t2kjIAi4Uv8pHtGQJHwXHlPnmAEFImSY2Pq9hKqdfnGpIyAIuE5Npot4MrckuilrMMIKBK+K+9uUkZAkdg7mEU2ZUq5kY4RUCRM04A3IufZpIyAIiPr2aSMgCIj69mkjIAi4zk2rswtolZvil7KGoyAIuM71uqYVK5XA0ZAkensJpXtkogRUGT25bNIJ03pNtIxAopMe0xqSXfDjBFQpHxXvpMnGAFFSsazSRkBRcpzbEzOLWGlIc+YlBFQpEqujUazhasS7SZlBBSp7tmkEl0SMQKK1P5CDumEKdWEiBFQpBKmgQOOJdWEiBFQ5GSbEDECipxsp1QzAoqc71iYnF1EXZIxKSOgyPmujXqzhanby6KXAoARkAC+s3o2qSTvCxgBRW5/IYtUwpDmzTEjoMglEyaKI5Y0D+5gBCSE79jS3DVmBCSETKdUMwISouTamJxdRKPZEr0URkBieI6NlUYLU7fF7yZlBCSEL9FuUkZAQtw3lEPSNKTYPsEISIjOmPSyBFuqGQEJI8uEiBGQML4ku0kZAQnjOxYmbokfkzICEsZzbdQaTVwrix2TMgISpuTIcTYpIyBh7hvOIWEawt8cMwISJpUwMT6cE37yBCMgoWSYEDECEkqGkycYAQnlOe3nmDUFjkkZAQlVcm1U601MV8R96J4RkFAynE3KCEio8WELpgGhnzdmBCRUOmlifFjsm2NGQMKJ3k3KCEg437F5OUR6810bl2cXhI1JGQEJ5zsWlleauDFfFfLnMwISzuucTSpoDxEjIOGKI7n2mFTQm2NGQMJlkgmMDeWEfa6AEZAU2hMivhKQxkTeK2AEJIWS295N2mpFPyZlBCQFz7GxtNIQMiZlBCSF7tmkAt4XMAKSQnHEgmGIOXmCEZAUsqkExgo5IW+OGQFJQ9SEiBGQNHxXzG5SRkDS6Jw8EfWYlBGQNDzHxkKtgZt3oh2TMgKSRskVczZpaBF87Wtfw8MPPwzLsjA0NLTh90xMTODJJ5+EZVnYs2cPvvKVr6Ber4e1JJLcgREx9wqSYf3gWq2Gp59+GseOHcOf/MmfrPv3jUYDTz75JPbt24cf//jHuHbtGn7t134NqVQKX//618NaFkksm0pgfyEb/YSoFbJXX321VSgU1v363//937dM02xNT093f+1b3/pWK5/Pt6rVas8/v1wutwC0yuVyEMslwf7LK2+3/utfvNv3z9nJ3wth7wnefvtt/PzP/zz27t3b/bXHH38clUoFP/3pTzf9fdVqFZVKZc0XxYfvRn/8irAIpqen1wQAoPvP09PTm/6+c+fOoVAodL+KxWKo66RoeasnT7QiHJPuKILTp0/DMIwtvy5cuBDWWgEAZ86cQblc7n5NTk6G+udRtHzHxp1qHbcWapH9mTt6Y3zq1CmcOHFiy+85ePBgTz9r3759eOedd9b82vXr17v/bjOZTAaZTKanP4PU47vtCdHlWwtwB6L5//OOIhgdHcXo6Gggf/CxY8fwta99DTdu3MCePXsAAN///veRz+fx4IMPBvJnkHq8kc7JE4s47I1E8meGNiKdmJjA7OwsJiYm0Gg08P777wMAHnjgAQwMDOCxxx7Dgw8+iC9+8Yv4gz/4A0xPT+OrX/0qvvSlL/F/6TWWSyewL5+N9s1x37OoTRw/frwFYN3XW2+91f2eS5cutT73uc+1crlcy3Xd1qlTp1orKys7+nM4Io2f//ztH7ee/5/v9fUzdvL3wmi1BHyoM0CVSgWFQgHlchn5fF70cigAv/W//hX/71oF//vLj+z6Z+zk7wX3DpF0fNfGpQh3kzICko7vWJhfrmNucSWSP48RkHSiPpuUEZB07r1XEAVGQNKx0knsGcxE9pBvRkBSivJsUkZAUopyNykjICl5js3LIdKb79goL63g9mL4u0kZAUmp86T7KMakjICk5Ed48gQjICkNZJJwBzKRfOieEZC0fMeKZEzKCEha7Y10vBwijXXOJg0bIyBpeY6NucUVlEPeTcoISFqds0nDfnPMCEhaBzrPMWMEpKt8NgXHTof+4A5GQFLzXTv0N8eMgKQWxXPMGAFJzY9gNykjIKn5ro3ZhRrKS+GNSRkBSa3zpPuJEF8NGAFJrXvyRIjvCxgBSa2QS2HETuNyiBvpGAFJz3MsvhKQ3nzHDvXDNYyApNeOgK8EpDHftTBzp4b55XDGpIyApNeZEIV1ScQISHolJ9wt1YyApFewUhiyUnwlIL15jh3aGUSMgJRQCvHzxoyAlBDm2aSMgJTguxZuzlexUK0H/rMZASnBD3FCxAhICX6I9woYASlhyEohn03ylYD0ZRhG+1jGEMakjICUEdbnjRkBKSOss0kZASnDc2xcr1SxWAt2TMoISBlhPb2GEZAyOidPBH1JxAhIGSN2GoOZJC4GfDYpIyBldMakfCUgrYVxNikjIKX4jh34Ue2MgJTiuzamK8tYqjUC+5mMgJTSPZt0NrhXA0ZASumeTRrgHiJGQEpxB9IYyCQDnRAxAlKKYRirEyJeDpHG2hMivhKQxnw32N2kjICU4zk2psrLWF4JZkzKCEg5nc8bBzUmZQSkHN9dfdJ9QO8LGAEpZ3QgAyudCOxzBYyAlNMek9qBPcKJEZCSSgFOiBgBKckLcDcpIyAl+Y6FqfJSIGNSRkBK8hwbrRZwZa7/VwNGQEoqrZ48EcQlESMgJe0ZzCCbMgP5qCUjICUZhrF6LCMjII0F9aR7RkDK8txgTp5gBKQs37FxdW4JtXqzr5/DCEhZvmOj2QIm+xyTMgJSVmc3ab/bJxgBKWvvYBaZpNn32aSMgJRlmsbqhIivBKSxIE6eYASkNN+1UV6s9fUzkgGthUiI33riPyBhGn39DL4SkNL6DQBgBESMgIgRkPYYAWmPEZD2GAFpjxGQ9hgBaY8RkPYYAWmPEZD2GAFpjxGQ9hgBaY8RkPYYAWmPEZD2GAFpjxGQ9hgBaY8RkPYYAWmPEZD2GAFpjxGQ9hgBaY8RkPYYAWmPEZD2GAFpjxGQ9hgBaY8RkPYYAWmPEZD2GAFpjxGQ9hgBaY8RkPaUf5h3q9UCAFQqFcErIZl0/j50/n5sRfkI5ufnAQDFYlHwSkhG8/PzKBQKW36P0eolFYk1m01MTU1hcHAQhnH36eaVSgXFYhGTk5PI5/MCV6iWuPz31mq1MD8/j7GxMZjm1lf9yr8SmKaJ8fHxTf99Pp9X+v+ZosThv7ftXgE6+MaYtMcISHuxjSCTyeDs2bPIZDKil6IUHf97U/6NMVG/YvtKQNQrRkDaYwSkPUZA2mMEpL3YRvDyyy/D931ks1kcPXoU77zzjuglSe0HP/gBvvCFL2BsbAyGYeCv//qvRS8pMrGM4LXXXsMLL7yAs2fP4r333sOhQ4fw+OOP48aNG6KXJq2FhQUcOnQIL7/8suilRC6W9wmOHj2KT33qU/ijP/ojAO1NdsViEV/+8pdx+vRpwauTn2EY+Ku/+is89dRTopcSidi9EtRqNbz77rt49NFHu79mmiYeffRRvP322wJXRrKKXQQzMzNoNBrYu3fvml/fu3cvpqenBa2KZBa7CIh2KnYRuK6LRCKB69evr/n169evY9++fYJWRTKLXQTpdBqHDx/GG2+80f21ZrOJN954A8eOHRO4MpKV8p8s28gLL7yA48eP48iRI3jooYfwzW9+EwsLC3jmmWdEL01ad+7cwUcffdT954sXL+L999/HyMgIDhw4IHBlEWjF1EsvvdQ6cOBAK51Otx566KHWP//zP4tektTeeuutFoB1X8ePHxe9tNDF8j4B0U7E7j0B0U4xAtIeIyDtMQLSHiMg7TEC0h4jIO0xAtIeIyDtMQLSHiMg7f1/tbrB7NIqjC4AAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import math\n", + "\n", + "def random_traj(BS = 1, device = \"cpu\", vel = 1):\n", + " # scale from 0.5 to 1.5\n", + " speed = (torch.rand(BS, device=device) + 0.5) * vel\n", + " \n", + " angle = torch.rand(BS, device=device)* math.pi\n", + " x = torch.sin(angle) * torch.arange(seq_len, device=device) / 2 * speed\n", + " y = torch.cos(angle) * torch.arange(seq_len, device=device) / 2 * speed\n", + " \n", + " traj = torch.stack([x, y], dim=-1)\n", + " return traj\n", + "\n", + "traj = random_traj()\n", + "\n", + "plt.plot(traj[..., 0], traj[..., 1], label=\"a\")\n", + "plt.gca().set_aspect(\"equal\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55b0f6be-d6b6-4d15-8f32-0d65fd98c780", + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, diff --git a/torchdrive/tasks/diff_traj.py b/torchdrive/tasks/diff_traj.py index b722d7c..3e76c3e 100644 --- a/torchdrive/tasks/diff_traj.py +++ b/torchdrive/tasks/diff_traj.py @@ -504,7 +504,9 @@ def nll_loss_gmm_direct( log_std1 + log_std2 + 0.5 * torch.log(1 - rho**2) ) # (batch_size, num_timestamps) reg_gmm_exp = (0.5 * 1 / (1 - rho**2)) * ( - (dx**2) / (std1**2) + (dy**2) / (std2**2) - 2 * rho * dx * dy / (std1 * std2) + (dx**2) / (std1**2) + + (dy**2) / (std2**2) + - 2 * rho * dx * dy / (std1 * std2) ) # (batch_size, num_timestamps) reg_loss = ((reg_gmm_log_coefficient + reg_gmm_exp) * gt_valid_mask).sum(dim=-1) @@ -717,6 +719,22 @@ def forward( return losses, nearest_trajs[..., :2], pred_traj +def random_traj( + BS: int, seq_len: int, device: object, vel: torch.Tensor +) -> torch.Tensor: + """Generates a random trajectory at the specified velocity.""" + + # scale from 0.5 to 1.5 + speed = (torch.rand(BS, device=device) + 0.5) * vel + + angle = torch.rand(BS, device=device) * math.pi + x = torch.sin(angle) * torch.arange(seq_len, device=device) / 2 * speed + y = torch.cos(angle) * torch.arange(seq_len, device=device) / 2 * speed + + traj = torch.stack([x, y], dim=-1) + return traj + + class DiffTraj(nn.Module, Van): """ A diffusion model for trajectory detection. @@ -812,14 +830,16 @@ def prepare_inputs( # calculate velocity between first two frames to allow model to understand current speed # TODO: convert this to a categorical embedding - velocity = positions[:, 1] - positions[:, 0] - assert positions.size(-1) == 2 - velocity = torch.linalg.vector_norm(velocity, dim=-1, keepdim=True) # approximately 0.5 fps since video is 12hz positions = positions[:, ::6] mask = mask[:, ::6] + # at 2 hz multiply by 2 to get true velocity + velocity = positions[:, 1] - positions[:, 0] + assert positions.size(-1) == 2 + velocity = torch.linalg.vector_norm(velocity, dim=-1, keepdim=True) * 2 + return positions, mask, velocity def forward( @@ -887,11 +907,14 @@ def forward( losses.update(pred_losses) pred_len = min(pred_traj.size(1), mask[0].sum().item()) + pred_traj_len = min(positions.size(1), pred_traj.size(1)) + + rand_traj = random_traj(BS, pred_traj_len, device=device, vel=velocity) dreamed_imgs = [] for i in range(BS): cond_img = batch.color[cam][i : i + 1, 0] - cond_traj = pred_traj[i : i + 1] + cond_traj = rand_traj[i : i + 1] dreamed_img = self.vista.generate(cond_img, cond_traj) # add last img (frame 10 == 1s) @@ -906,11 +929,10 @@ def forward( normalize_img(dream_img[0, 0]), ) - pred_traj_len = min(positions.size(1), pred_traj.size(1)) dream_target, dream_mask, dream_positions, dream_pred = compute_dream_pos( positions[:, :pred_traj_len], mask[:, :pred_traj_len], - pred_traj[:, :pred_traj_len], + rand_traj[:, :pred_traj_len], step=self.dream_steps, ) @@ -921,7 +943,6 @@ def forward( losses[f"dream-{k}"] = v if writer and log_text: - size = min(pred_traj.size(1), positions.size(1)) ctx.add_scalar( @@ -976,13 +997,11 @@ def forward( pred_positions = dream_pred[0, :pred_len].cpu() plt.plot( - pred_positions[..., 0], pred_positions[..., 1], label="og_pred" + pred_positions[..., 0], pred_positions[..., 1], label="rand_traj" ) pred_positions = dream_traj[0, :pred_len].cpu() - plt.plot( - pred_positions[..., 0], pred_positions[..., 1], label="new_pred" - ) + plt.plot(pred_positions[..., 0], pred_positions[..., 1], label="pred") fig.legend() plt.gca().set_aspect("equal")