From 32abc8da301352a0cb417f08f95d4477b178663e Mon Sep 17 00:00:00 2001 From: Igor Sugak Date: Thu, 17 Oct 2024 20:57:00 -0700 Subject: [PATCH] [Codemod][PSS] Upgrade fbcode/pytorch to Python Scientific Stack 2 (#3845) Summary: X-link: https://github.com/pytorch/opacus/pull/680 X-link: https://github.com/pytorch/captum/pull/1387 X-link: https://github.com/pytorch/botorch/pull/2584 Differential Revision: D64008689 --- test/torchaudio_unittest/prototype/functional/dsp_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/torchaudio_unittest/prototype/functional/dsp_utils.py b/test/torchaudio_unittest/prototype/functional/dsp_utils.py index fb0300a9d6..44c0cac3c3 100644 --- a/test/torchaudio_unittest/prototype/functional/dsp_utils.py +++ b/test/torchaudio_unittest/prototype/functional/dsp_utils.py @@ -1,4 +1,5 @@ import numpy as np +import numpy.typing as npt def oscillator_bank( @@ -43,8 +44,8 @@ def freq_ir(magnitudes): def exp_sigmoid( - input: np.ndarray, exponent: float = 10.0, max_value: float = 2.0, threshold: float = 1e-7 -) -> np.ndarray: + input: npt.NDArray, exponent: float = 10.0, max_value: float = 2.0, threshold: float = 1e-7 +) -> npt.NDArray: """Exponential Sigmoid pointwise nonlinearity (Numpy version). Implements the equation: ``max_value`` * sigmoid(``input``) ** (log(``exponent``)) + ``threshold``