added fix for OMP-related warnings
jacobpennington committed Oct 4, 2024
1 parent aaac099 commit 3f12ff4
Showing 4 changed files with 24 additions and 10 deletions.
8 changes: 6 additions & 2 deletions kilosort/gui/sorter.py
@@ -13,8 +13,8 @@
setup_logger, initialize_ops, compute_preprocessing, compute_drift_correction,
detect_spikes, cluster_spikes, save_sorting
)

from kilosort.io import save_preprocessing
from kilosort.utils import log_performance, log_cuda_details

#logger = setup_logger(__name__)

@@ -128,7 +128,11 @@ def run(self):
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0)

except:
except Exception as e:
if isinstance(e, torch.cuda.OutOfMemoryError):
logger.exception('Out of memory error, printing performance...')
log_performance(logger, level='info')
log_cuda_details(logger)
# This makes sure the full traceback is written to log file.
logger.exception('Encountered error in `run_kilosort`:')
# Annoyingly, this will print the error message twice for console
2 changes: 0 additions & 2 deletions kilosort/hierarchical.py
@@ -1,7 +1,5 @@
from scipy.sparse import csr_matrix
import numpy as np
import faiss
from sklearn.cluster import KMeans


def cluster_qr(M, iclust, iclust0):
7 changes: 6 additions & 1 deletion kilosort/run_kilosort.py
@@ -230,7 +230,12 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
save_extra_vars=save_extra_vars,
save_preprocessed_copy=save_preprocessed_copy
)
except:
except Exception as e:
if isinstance(e, torch.cuda.OutOfMemoryError):
logger.exception('Out of memory error, printing performance...')
log_performance(logger, level='info')
log_cuda_details(logger)

# This makes sure the full traceback is written to log file.
logger.exception('Encountered error in `run_kilosort`:')
# Annoyingly, this will print the error message twice for console, but
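Both kilosort/gui/sorter.py and kilosort/run_kilosort.py make the same change here: a bare `except:` becomes `except Exception as e`, with extra GPU diagnostics logged when the failure is a CUDA out-of-memory error. The sketch below is an illustrative, self-contained version of that pattern, not the commit's exact code; `log_performance(logger, level='info')` and `log_cuda_details(logger)` are the kilosort.utils helpers imported in the first hunk, while the wrapper function itself is hypothetical.

    import logging
    import torch
    from kilosort.utils import log_performance, log_cuda_details

    logger = logging.getLogger(__name__)

    def run_step_with_oom_diagnostics(step, *args, **kwargs):
        # Hypothetical wrapper showing the handler shape this commit adds.
        try:
            return step(*args, **kwargs)
        except Exception as e:
            if isinstance(e, torch.cuda.OutOfMemoryError):
                # On CUDA OOM, record performance and GPU state first, so the
                # log shows what the device looked like at failure time.
                logger.exception('Out of memory error, printing performance...')
                log_performance(logger, level='info')
                log_cuda_details(logger)
            # logger.exception writes the full traceback to the log file.
            logger.exception('Encountered error in `run_kilosort`:')
            raise

Catching `Exception` rather than using a bare `except:` also lets KeyboardInterrupt and SystemExit propagate normally, which is the usual reason for this change.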
17 changes: 12 additions & 5 deletions kilosort/spikedetect.py
@@ -1,5 +1,7 @@
from io import StringIO
import os
import logging
import warnings
logger = logging.getLogger(__name__)

from torch.nn.functional import max_pool2d, avg_pool2d, conv1d, max_pool1d
@@ -11,8 +13,6 @@

from kilosort.utils import template_path, log_performance

device = torch.device('cuda')


def my_max2d(X, dt):
Xmax = max_pool2d(
@@ -72,9 +72,16 @@ def extract_wPCA_wTEMP(ops, bfile, nt=61, twav_min=20, Th_single_ch=6, nskip=25,
model = TruncatedSVD(n_components=ops['settings']['n_pcs']).fit(clips)
wPCA = torch.from_numpy(model.components_).to(device).float()

model = KMeans(n_clusters=ops['settings']['n_templates'], n_init = 10).fit(clips)
wTEMP = torch.from_numpy(model.cluster_centers_).to(device).float()
wTEMP = wTEMP / (wTEMP**2).sum(1).unsqueeze(1)**.5
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="")
# Prevents memory leak for KMeans when using MKL on Windows
msg = 'KMeans is known to have a memory leak on Windows with MKL'
nthread = os.environ.get('OMP_NUM_THREADS', msg)
os.environ['OMP_NUM_THREADS'] = '7'
model = KMeans(n_clusters=ops['settings']['n_templates'], n_init = 10).fit(clips)
wTEMP = torch.from_numpy(model.cluster_centers_).to(device).float()
wTEMP = wTEMP / (wTEMP**2).sum(1).unsqueeze(1)**.5
os.environ['OMP_NUM_THREADS'] = nthread

return wPCA, wTEMP

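The spikedetect.py change wraps the KMeans fit in a `warnings.catch_warnings()` block and temporarily sets OMP_NUM_THREADS, working around scikit-learn's warning that KMeans can leak memory on Windows with MKL. Below is a hedged, self-contained sketch of the same workaround; the function name is hypothetical, and the try/finally restore (including popping the variable when it was previously unset) is an illustrative choice rather than the commit's exact code.

    import os
    import warnings

    import numpy as np
    from sklearn.cluster import KMeans

    def fit_templates_with_omp_cap(clips, n_templates, n_init=10):
        # Illustrative version of the commit's workaround (name is hypothetical).
        with warnings.catch_warnings():
            # Silence sklearn's Windows/MKL memory-leak warning for KMeans.
            warnings.filterwarnings(
                'ignore',
                message='KMeans is known to have a memory leak on Windows with MKL',
            )
            prev = os.environ.get('OMP_NUM_THREADS')  # None if previously unset
            os.environ['OMP_NUM_THREADS'] = '7'       # cap threads, as the commit does
            try:
                model = KMeans(n_clusters=n_templates, n_init=n_init).fit(clips)
            finally:
                # Restore the caller's environment either way.
                if prev is None:
                    os.environ.pop('OMP_NUM_THREADS', None)
                else:
                    os.environ['OMP_NUM_THREADS'] = prev
        centers = model.cluster_centers_
        # Unit-normalize each template, as the diff does with torch:
        # wTEMP / (wTEMP**2).sum(1).unsqueeze(1)**.5
        return centers / np.sqrt((centers ** 2).sum(axis=1, keepdims=True))

One caveat: OMP_NUM_THREADS is typically read when the OpenMP runtime initializes, so setting it mid-process is not guaranteed to take effect, and the warning filter covers the case where the warning still fires.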
