diff --git a/kilosort/gui/sorter.py b/kilosort/gui/sorter.py
index 19e5440c..813c94ad 100644
--- a/kilosort/gui/sorter.py
+++ b/kilosort/gui/sorter.py
@@ -13,8 +13,8 @@
     setup_logger, initialize_ops, compute_preprocessing, compute_drift_correction,
     detect_spikes, cluster_spikes, save_sorting
     )
-
 from kilosort.io import save_preprocessing
+from kilosort.utils import log_performance, log_cuda_details
 
 #logger = setup_logger(__name__)
@@ -128,7 +128,11 @@ def run(self):
             ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
                 save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0)
-        except:
+        except Exception as e:
+            if isinstance(e, torch.cuda.OutOfMemoryError):
+                logger.exception('Out of memory error, printing performance...')
+                log_performance(logger, level='info')
+                log_cuda_details(logger)
             # This makes sure the full traceback is written to log file.
             logger.exception('Encountered error in `run_kilosort`:')
             # Annoyingly, this will print the error message twice for console
diff --git a/kilosort/hierarchical.py b/kilosort/hierarchical.py
index 79a6aba1..22340dc9 100644
--- a/kilosort/hierarchical.py
+++ b/kilosort/hierarchical.py
@@ -1,7 +1,5 @@
 from scipy.sparse import csr_matrix
 import numpy as np
-import faiss
-from sklearn.cluster import KMeans
 
 
 def cluster_qr(M, iclust, iclust0):
diff --git a/kilosort/run_kilosort.py b/kilosort/run_kilosort.py
index efde7502..08c6db75 100644
--- a/kilosort/run_kilosort.py
+++ b/kilosort/run_kilosort.py
@@ -230,7 +230,12 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
                 save_extra_vars=save_extra_vars,
                 save_preprocessed_copy=save_preprocessed_copy
                 )
-    except:
+    except Exception as e:
+        if isinstance(e, torch.cuda.OutOfMemoryError):
+            logger.exception('Out of memory error, printing performance...')
+            log_performance(logger, level='info')
+            log_cuda_details(logger)
+
         # This makes sure the full traceback is written to log file.
         logger.exception('Encountered error in `run_kilosort`:')
         # Annoyingly, this will print the error message twice for console, but
diff --git a/kilosort/spikedetect.py b/kilosort/spikedetect.py
index 40109eb3..b946977c 100644
--- a/kilosort/spikedetect.py
+++ b/kilosort/spikedetect.py
@@ -1,5 +1,7 @@
 from io import StringIO
+import os
 import logging
+import warnings
 logger = logging.getLogger(__name__)
 
 from torch.nn.functional import max_pool2d, avg_pool2d, conv1d, max_pool1d
@@ -11,8 +13,6 @@
 
 from kilosort.utils import template_path, log_performance
 
-device = torch.device('cuda')
-
 
 def my_max2d(X, dt):
     Xmax = max_pool2d(
@@ -72,9 +72,16 @@ def extract_wPCA_wTEMP(ops, bfile, nt=61, twav_min=20, Th_single_ch=6, nskip=25,
     model = TruncatedSVD(n_components=ops['settings']['n_pcs']).fit(clips)
     wPCA = torch.from_numpy(model.components_).to(device).float()
 
-    model = KMeans(n_clusters=ops['settings']['n_templates'], n_init = 10).fit(clips)
-    wTEMP = torch.from_numpy(model.cluster_centers_).to(device).float()
-    wTEMP = wTEMP / (wTEMP**2).sum(1).unsqueeze(1)**.5
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", message="")
+        # Prevents memory leak for KMeans when using MKL on Windows
+        msg = 'KMeans is known to have a memory leak on Windows with MKL'
+        nthread = os.environ.get('OMP_NUM_THREADS', msg)
+        os.environ['OMP_NUM_THREADS'] = '7'
+        model = KMeans(n_clusters=ops['settings']['n_templates'], n_init = 10).fit(clips)
+        wTEMP = torch.from_numpy(model.cluster_centers_).to(device).float()
+        wTEMP = wTEMP / (wTEMP**2).sum(1).unsqueeze(1)**.5
+        os.environ['OMP_NUM_THREADS'] = nthread
 
     return wPCA, wTEMP
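
Note on the new error handling: both `run_kilosort` and the GUI worker now catch the exception, check for `torch.cuda.OutOfMemoryError`, and dump performance/CUDA details before logging the full traceback. Below is a minimal standalone sketch of the same pattern, assuming only torch and the standard logging module; `log_memory_snapshot` and `run_step` are hypothetical stand-ins, not kilosort's `log_performance`/`log_cuda_details`:

    import logging
    import torch

    logger = logging.getLogger(__name__)

    def log_memory_snapshot(logger):
        # Hypothetical helper: report CUDA memory usage so an OOM failure
        # leaves a diagnostic trail in the log file.
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1e9
            reserved = torch.cuda.memory_reserved() / 1e9
            logger.info(f'CUDA memory: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved')

    def run_step(step):
        try:
            step()
        except Exception as e:
            if isinstance(e, torch.cuda.OutOfMemoryError):
                # Record memory state before the traceback so the log explains the failure.
                logger.exception('Out of memory error, printing performance...')
                log_memory_snapshot(logger)
            # Full traceback goes to the log file either way.
            logger.exception('Encountered error:')
            raise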
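
Note on the KMeans change in `extract_wPCA_wTEMP`: scikit-learn's KMeans is known to leak memory on Windows with MKL unless OMP_NUM_THREADS is capped, so the patch sets it to 7 around the fit and restores the previous value afterwards. A minimal sketch of the same idea as a reusable context manager, assuming scikit-learn and NumPy are available; `limit_omp_threads` is a hypothetical helper, not part of kilosort:

    import os
    from contextlib import contextmanager

    import numpy as np
    from sklearn.cluster import KMeans

    @contextmanager
    def limit_omp_threads(n):
        # Temporarily cap OMP_NUM_THREADS, then restore the previous value
        # (or unset it again if it was not set before).
        previous = os.environ.get('OMP_NUM_THREADS')
        os.environ['OMP_NUM_THREADS'] = str(n)
        try:
            yield
        finally:
            if previous is None:
                os.environ.pop('OMP_NUM_THREADS', None)
            else:
                os.environ['OMP_NUM_THREADS'] = previous

    # Example: cluster random 'clips' into 6 L2-normalized templates.
    clips = np.random.randn(1000, 61).astype(np.float32)
    with limit_omp_threads(7):
        model = KMeans(n_clusters=6, n_init=10).fit(clips)
    wTEMP = model.cluster_centers_ / np.linalg.norm(model.cluster_centers_, axis=1, keepdims=True)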