-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #151 from decargroup/bugfix/150-fix-performance-is…
…sues-in-predict_trajectory Fix performance issues in `pykoop.predict_trajectory()`
- Loading branch information
Showing
13 changed files
with
437 additions
and
122 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -288,6 +288,7 @@ dmypy.json | |
|
||
# profiling data | ||
.prof | ||
*.prof | ||
|
||
### Vim ### | ||
# Swap | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
"""Benchmark :func:`pykoop.predict_trajectory()`. | ||
Outputs a ``.prof`` file that can be visualized using ``snakeviz``. | ||
""" | ||
|
||
import cProfile | ||
|
||
import pykoop | ||
|
||
|
||
def main(): | ||
"""Benchmark :func:`pykoop.predict_trajectory()`.""" | ||
pykoop.set_config(skip_validation=True) | ||
|
||
# Get example mass-spring-damper data | ||
eg = pykoop.example_data_pendulum() | ||
# Create pipeline | ||
kp = pykoop.KoopmanPipeline( | ||
lifting_functions=[ | ||
('pl', pykoop.PolynomialLiftingFn(order=2)), | ||
('dl', pykoop.DelayLiftingFn(n_delays_state=2, n_delays_input=2)), | ||
], | ||
regressor=pykoop.Edmd(alpha=1), | ||
) | ||
# Fit the pipeline | ||
kp.fit( | ||
eg['X_train'], | ||
n_inputs=eg['n_inputs'], | ||
episode_feature=eg['episode_feature'], | ||
) | ||
# Predict using the pipeline | ||
with cProfile.Profile() as pr: | ||
X_pred = kp.predict_trajectory(eg['X_train']) | ||
pr.dump_stats('benchmark_predict_trajectory.prof') | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
"""Benchmark :func:`pykoop.koopman_pipeline._unique_episodes()`. | ||
It's very hard to do better than :func:`pandas.unique`, so I will stop messing | ||
with it. Another approach could be to store the unique episodes somewhere for | ||
reuse, but that could be convoluted. | ||
""" | ||
|
||
import timeit | ||
|
||
import numpy as np | ||
|
||
import pykoop | ||
|
||
|
||
def main(): | ||
"""Benchmark :func:`pykoop.koopman_pipeline._unique_episodes()`.""" | ||
pykoop.set_config(skip_validation=True) | ||
"""Benchmark :func:`pykoop.unique_episodes()`.""" | ||
X_ep = np.array([0] * 100 + [1] * 1000 + [2] * 500 + [10] * 1000) | ||
n_loop = 100_000 | ||
time = timeit.timeit(lambda: pykoop.unique_episodes(X_ep), number=n_loop) | ||
print(f' Total time: {time} s') | ||
print(f'Time per loop: {time / n_loop} s') | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
BSD 3-Clause License | ||
|
||
Copyright (c) 2007-2022 The scikit-learn developers. | ||
All rights reserved. | ||
|
||
Redistribution and use in source and binary forms, with or without | ||
modification, are permitted provided that the following conditions are met: | ||
|
||
* Redistributions of source code must retain the above copyright notice, this | ||
list of conditions and the following disclaimer. | ||
|
||
* Redistributions in binary form must reproduce the above copyright notice, | ||
this list of conditions and the following disclaimer in the documentation | ||
and/or other materials provided with the distribution. | ||
|
||
* Neither the name of the copyright holder nor the names of its | ||
contributors may be used to endorse or promote products derived from | ||
this software without specific prior written permission. | ||
|
||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | ||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | ||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
"""Global configuration for ``pykoop``. | ||
Based on code from the ``scikit-learn`` project. Original author of the file | ||
is Joel Nothman. Specifically, the original file is | ||
``scikit-learn/sklearn/_config.py`` from commit ``894b335``. | ||
Distributed under the BSD-3-Clause License. See ``LICENSE`` in this directory | ||
for the full license. | ||
""" | ||
|
||
import contextlib | ||
import os | ||
import threading | ||
from typing import Any, Dict, Optional | ||
|
||
_global_config = { | ||
'skip_validation': False, | ||
} | ||
_threadlocal = threading.local() | ||
|
||
|
||
def _get_threadlocal_config() -> Dict[str, Any]: | ||
"""Get a threadlocal mutable configuration. | ||
If the configuration does not exist, copy the default global configuration. | ||
""" | ||
if not hasattr(_threadlocal, 'global_config'): | ||
_threadlocal.global_config = _global_config.copy() | ||
return _threadlocal.global_config | ||
|
||
|
||
def get_config() -> Dict[str, Any]: | ||
"""Retrieve current values for configuration set by :func:`set_config`. | ||
Returns | ||
------- | ||
config : dict | ||
Keys are parameter names that can be passed to :func:`set_config`. | ||
Examples | ||
-------- | ||
Get configuation | ||
>>> pykoop.get_config() | ||
{'skip_validation': False} | ||
""" | ||
# Return a copy of the threadlocal configuration so that users will | ||
# not be able to modify the configuration with the returned dict. | ||
return _get_threadlocal_config().copy() | ||
|
||
|
||
def set_config(skip_validation: Optional[bool] = None) -> None: | ||
"""Set global configuration. | ||
Parameters | ||
---------- | ||
skip_validation : Optional[bool] | ||
Set to ``True`` to skip all parameter validation. Can save significant | ||
time, especially in func:`pykoop.predict_trajectory()` but risks | ||
crashes. | ||
Examples | ||
-------- | ||
Set configuation | ||
>>> pykoop.set_config(skip_validation=False) | ||
""" | ||
local_config = _get_threadlocal_config() | ||
# Set parameters | ||
if skip_validation is not None: | ||
local_config['skip_validation'] = skip_validation | ||
|
||
|
||
@contextlib.contextmanager | ||
def config_context(*, skip_validation=None): | ||
"""Context manager for global configuration. | ||
Parameters | ||
---------- | ||
skip_validation : Optional[bool] | ||
Set to ``True`` to skip all parameter validation. Can save significant | ||
time, especially in func:`pykoop.predict_trajectory()` but risks | ||
crashes. | ||
Examples | ||
-------- | ||
Use config context manager | ||
>>> with pykoop.config_context(skip_validation=False): | ||
... pykoop.KoopmanPipeline() | ||
KoopmanPipeline() | ||
""" | ||
old_config = get_config() | ||
set_config(skip_validation=skip_validation) | ||
|
||
try: | ||
yield | ||
finally: | ||
set_config(**old_config) |
Oops, something went wrong.