Coverage for kwave/utils/filters.py: 8%
191 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-24 12:06 -0700
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-24 12:06 -0700
1import logging
2from typing import List, Optional, Tuple, Union
4import numpy as np
5import scipy
6from scipy.fftpack import fft, fftshift, ifft, ifftshift
7from scipy.signal import lfilter
9from kwave.utils.conversion import create_index_at_dim
11from .checks import is_number
12from .data import scale_SI
13from .math import find_closest, gaussian, next_pow2, sinc
14from .matrix import num_dim, num_dim2
15from .signals import get_win
18def single_sided_correction(func_fft: np.ndarray, fft_len: int, dim: int) -> np.ndarray:
19 """Correct the single-sided magnitude by multiplying the symmetric points by 2.
21 The DC and Nyquist components are unique and are not multiplied by 2.
22 The Nyquist component only exists for even numbered FFT lengths.
24 Args:
25 func_fft: The FFT of the function to be corrected.
26 fft_len: The length of the FFT.
27 dim: The dimension along which to apply the correction.
29 Returns:
30 None, modifies the input array in place to have the corrected FFT of the function.
31 """
32 # Create a slice object for each dimension
33 slices = [slice(None)] * func_fft.ndim
35 if fft_len % 2: # odd FFT length
36 # Set slice for the specified dimension to select all elements except the first
37 slices[dim] = slice(1, None)
38 else: # even FFT length
39 # Set slice for the specified dimension to select all elements except first and last
40 slices[dim] = slice(1, -1)
42 # Apply the slicing and multiply by 2
43 func_fft[tuple(slices)] *= 2
46def spect(
47 func: np.ndarray,
48 fs: float,
49 dim: Optional[Union[int, str]] = "auto",
50 fft_len: Optional[int] = 0,
51 power_two: Optional[bool] = False,
52 unwrap_phase: Optional[bool] = False,
53 window: Optional[str] = "Rectangular",
54) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
55 """
56 Calculates the spectrum of a signal.
58 Args:
59 func: The signal to analyse.
60 fs: The sampling frequency in Hz.
61 dim: The dimension over which the spectrum is calculated. Defaults to 'auto'.
62 fft_len: The length of the FFT. If the set length is smaller than the signal length, the default value is used
63 instead (default = signal length).
64 power_two: Whether the FFT length is forced to be the next highest power of 2 (default = False).
65 unwrap_phase: Whether to unwrap the phase spectrum (default = False).
66 window: (str) The window type used to filter the signal before the FFT is taken (default = 'Rectangular'). Any valid
67 input types for get_win may be used.
69 Returns:
70 f: Frequency array
71 func_as: Single-sided amplitude spectrum
72 func_ps: Single-sided phase spectrum
74 Raises:
75 ValueError: If the input signal is scalar or has more than 4 dimensions.
76 """
78 # check the size of the input
79 sz = func.shape
81 # check input isn't scalar
82 if np.size(func) == 1:
83 raise ValueError("Input signal cannot be scalar.")
85 # check input doesn't have more than 4 dimensions
86 if len(sz) > 4:
87 raise ValueError("Input signal must have 1, 2, 3, or 4 dimensions.")
89 # automatically set dimension to first non - singleton dimension
90 if dim == "auto":
91 dim = np.argmax(np.array(sz) > 1)
92 if sz[dim] <= 1:
93 raise ValueError("All dimensions are singleton; unable to determine valid dimension.")
95 # assign the number of points being analysed
96 func_length = sz[dim]
98 # set the length of the FFT
99 if fft_len <= 0 or fft_len < func_length:
100 if power_two:
101 # find an appropriate FFT length of the form 2 ^ N that is equal to or
102 # larger than the length of the input signal
103 fft_len = 2 ** (next_pow2(func_length))
104 else:
105 # set the FFT length to the function length
106 fft_len = func_length
108 # window the signal, reshaping the window to be in the correct direction
109 win, coherent_gain = get_win(func_length, type_=window, symmetric=False)
110 win_shape = [1] * len(sz)
111 win_shape[dim] = func_length
112 win = np.reshape(win, tuple(win_shape))
113 func = win * func
115 # compute the fft using the defined FFT length, if fft_len >
116 # func_length, the input signal is padded with zeros
117 func_fft = np.fft.fft(func, n=fft_len, axis=dim)
119 # correct for the magnitude scaling of the FFT and the coherent gain of the
120 # window(note that the correction is equal to func_length NOT fft_len)
121 epsilon = 1e-10 # Small value to prevent division by zero
122 func_fft = func_fft / (func_length * coherent_gain + epsilon)
124 # reduce to a single sided spectrum where the number of unique points for
125 # even numbered FFT lengths is given by N / 2 + 1, and for odd(N + 1) / 2
126 num_unique_pts = int(np.ceil((fft_len + 1) / 2))
127 slicing = [slice(None)] * len(sz)
128 slicing[dim] = slice(0, num_unique_pts)
129 func_fft = func_fft[tuple(slicing)]
131 single_sided_correction(func_fft, fft_len, dim)
133 # create the frequency axis variable
134 f = np.arange(0, num_unique_pts) * fs / fft_len
136 # calculate the amplitude spectrum
137 func_as = np.abs(func_fft)
139 # calculate the phase spectrum
140 func_ps = np.angle(func_fft)
142 # unwrap the phase spectrum if required
143 if unwrap_phase:
144 func_ps = np.unwrap(func_ps, axis=dim)
146 return f, func_as, func_ps
149def extract_amp_phase(
150 data: np.ndarray, fs: float, source_freq: float, dim: Tuple[str, int] = "auto", fft_padding: int = 3, window: str = "Hanning"
151) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
152 """
153 Extract the amplitude and phase information at a specified frequency from a vector or matrix of time series data.
155 The amplitude and phase are extracted from the frequency spectrum, which is calculated using a windowed and zero
156 padded FFT. The values are extracted at the frequency closest to source_freq. By default, the time dimension is set
157 to the highest non-singleton dimension.
159 Args:
160 data: Matrix of time signals [s]
161 fs: Sampling frequency [Hz]
162 source_freq: Frequency at which the amplitude and phase should be extracted [Hz]
163 dim: The time dimension of the input data. If 'auto', the highest non-singleton dimension is used.
164 fft_padding: The amount of zero padding to apply to the FFT.
165 window: The windowing function to use for the FFT.
167 Returns:
168 A tuple of the amplitude, phase and frequency of the extracted signal.
170 """
172 # check for the dim input
173 if dim == "auto":
174 dim = num_dim(data)
175 if dim == 2 and data.shape[1] == 1:
176 dim = 1
178 # create 1D window and reshape to be oriented in the time dimension of the
179 # input data
180 win, coherent_gain = get_win(data.shape[dim], window)
181 # this list magic in Python comes from the use of ones in MATLAB
182 # TODO: simplify this
183 win = np.reshape(win, [1] * (dim - 1) + [len(win)])
185 # apply window to time dimension of input data
186 data = win * data
188 # compute amplitude and phase spectra
189 f, func_as, func_ps = spect(data, fs, fft_len=fft_padding * data.shape[dim], dim=dim)
191 # correct for coherent gain
192 func_as = func_as / coherent_gain
194 # find the index of the frequency component closest to source_freq
195 _, f_index = find_closest(f, source_freq)
197 # get size of output variable, collapsing the time dimension
198 sz = list(data.shape)
199 sz[dim - 1] = 1
201 # extract amplitude and relative phase at freq_index
202 # Create a tuple of slice objects with the frequency index at the correct dimension
203 idx = create_index_at_dim(func_as.ndim, dim, f_index)
204 amp = func_as[idx]
205 phase = func_ps[idx]
207 return amp.squeeze(), phase.squeeze(), f[f_index]
210def fwhm(f, x):
211 """
212 fwhm calculates the Full Width at Half Maximum (FWHM) of a positive
213 1D input function f(x) with spacing given by x.
216 Args:
217 f: f(x)
218 x: x
220 Returns:
221 FWHM of f(x) along with the position of the leading and trailing edges as a tuple
223 """
225 # ensure f is numpy array
226 f = np.array(f)
227 if len(f.squeeze().shape) != 1:
228 raise ValueError("Input function must be 1-dimensional.")
230 def lin_interp(x, y, i, half):
231 return x[i] + (x[i + 1] - x[i]) * ((half - y[i]) / (y[i + 1] - y[i]))
233 def half_max_x(x, y):
234 half = max(y) / 2.0
235 signs = np.sign(np.add(y, -half))
236 zero_crossings = signs[0:-2] != signs[1:-1]
237 zero_crossings_i = np.where(zero_crossings)[0]
238 return [lin_interp(x, y, zero_crossings_i[0], half), lin_interp(x, y, zero_crossings_i[1], half)]
240 hmx = half_max_x(x, f)
241 fwhm_val = hmx[1] - hmx[0]
243 return fwhm_val, tuple(hmx)
246def gaussian_filter(
247 signal: Union[np.ndarray, List[float]], fs: float, frequency: float, bandwidth: float
248) -> Union[np.ndarray, List[float]]:
249 """
250 Applies a frequency domain Gaussian filter with the
251 specified center frequency and percentage bandwidth to the input
252 signal. If the input signal is given as a matrix, the filter is
253 applied to each matrix row.
255 Args:
256 signal: Signal to filter [channel, samples]
257 fs: Sampling frequency [Hz]
258 frequency: Center frequency of filter [Hz]
259 bandwidth: Bandwidth of filter in percentage
261 Returns:
262 The filtered signal
264 """
266 N = signal.shape[-1]
267 if N % 2 == 0:
268 f = np.arange(-N / 2, N / 2) * fs / N
269 else:
270 f = np.arange(-(N - 1) / 2, (N - 1) / 2 + 1) * fs / N
272 mean = frequency
273 variance = (bandwidth / 100 * frequency / (2 * np.sqrt(2 * np.log(2)))) ** 2
274 magnitude = 1
276 # create double-sided Gaussain filter
277 gfilter = np.fmax(gaussian(f, magnitude, mean, variance), gaussian(f, magnitude, -mean, variance))
279 # add dimensions to filter to be broadcastable to signal shape
280 if len(signal.shape) == 2:
281 gfilter = gfilter[np.newaxis, :]
283 # apply filter
284 signal = np.real(ifft(ifftshift(gfilter * fftshift(fft(signal)))))
286 return signal
289def filter_time_series(
290 kgrid: "kWaveGrid",
291 medium: "kWaveMedium",
292 signal: np.ndarray,
293 ppw: Optional[int] = 3,
294 rppw: Optional[int] = 0,
295 stop_band_atten: Optional[int] = 60,
296 transition_width: Optional[float] = 0.1,
297 zerophase: Optional[bool] = False,
298 plot_spectrums: Optional[bool] = False,
299 plot_signals: Optional[bool] = False,
300) -> np.ndarray:
301 """
302 Filters a time-domain signal using the Kaiser windowing method.
304 The filter is designed to attenuate high-frequency noise in the signal while preserving
305 the signal's important features. The filter design parameters can be adjusted to trade off
306 between the amount of noise reduction and the amount of signal distortion.
308 Args:
309 kgrid: The kWaveGrid grid.
310 medium: The kWavemedium.
311 signal: The time-domain signal to filter.
312 ppw: The minimum number of points per wavelength in the signal. This determines the
313 minimum frequency that will be passed through the filter. Higher values of ppw
314 result in a lower cut-off frequency and more noise reduction, but may also result
315 in more signal distortion. Defaults to 3.
316 rppw: The number of points per wavelength in the smoothing ramp applied to the beginning
317 of the signal. This can be used to reduce ringing artifacts caused by the sudden
318 transition from the filtered signal to the unfiltered signal. Defaults to 0.
319 stop_band_atten: The stop-band attenuation in dB. This determines the steepness of the
320 filter's transition from the pass-band to the stop-band. Higher values result in a
321 steeper transition and more noise reduction, but may also result in more signal
322 distortion. Defaults to 60.
323 transition_width: The transition width as a proportion of the sampling frequency. This
324 determines the width of the transition region between the pass-band and the stop-band.
325 Smaller values result in a narrower transition and more noise reduction, but may also
326 result in more signal distortion. Defaults to 0.1.
327 zerophase: Whether to implement the filter as a zero-phase filter. Zero-phase filtering
328 can be used to preserve the phase information in the signal, which can be important
329 for some applications. However, it may also result in more signal distortion.
330 Defaults to False.
331 plot_spectrums: Whether to plot the spectrums of the input and filtered signals.
332 Defaults to False.
333 plot_signals: Whether to plot the input and filtered signals. Defaults to False.
335 Raises:
336 ValueError: Checks correctness of passed arguments.
337 NotImplementedError: Cannot currently plot anything.
339 Returns:
340 The filtered signal.
342 """
344 # check the input is a row vector
345 if num_dim2(signal) == 1:
346 m, n = signal.shape
347 if n == 1:
348 signal = signal.T
349 rotate_signal = True
350 else:
351 rotate_signal = False
352 else:
353 raise TypeError("Input signal must be a vector.")
355 # update the command line status
356 logging.log(logging.INFO, "Filtering input signal...")
358 # extract the time step
359 assert not isinstance(kgrid.t_array, str) or kgrid.t_array != "auto", "kgrid.t_array must be explicitly defined."
361 # compute the sampling frequency
362 fs = 1 / kgrid.dt
364 # extract the minimum sound speed
365 if medium.sound_speed is not None:
366 # for the fluid code, use medium.sound_speed
367 c0 = medium.sound_speed.min()
369 elif all(medium.is_defined("sound_speed_compression", "sound_speed_shear")): # pragma: no cover
370 # for the elastic code, combine the shear and compression sound speeds and remove zeros values
371 ss = np.hstack([medium.sound_speed_compression, medium.sound_speed_shear])
372 ss[ss == 0] = np.nan
373 c0 = np.nanmin(ss)
375 # cleanup unused variables
376 del ss
378 else:
379 raise ValueError(
380 "The input fields medium.sound_speed or medium.sound_speed_compression and medium.sound_speed_shear must " "be defined."
381 )
383 # extract the maximum supported frequency (two points per wavelength)
384 f_max = kgrid.k_max_all * c0 / (2 * np.pi)
386 # calculate the filter cut-off frequency
387 filter_cutoff_f = 2 * f_max / ppw
389 # calculate the wavelength of the filter cut-off frequency as a number of time steps
390 filter_wavelength = (2 * np.pi / filter_cutoff_f) / kgrid.dt
392 # filter the signal if required
393 if ppw != 0:
394 filtered_signal = apply_filter(
395 signal,
396 fs,
397 float(filter_cutoff_f),
398 "LowPass",
399 zero_phase=zerophase,
400 stop_band_atten=float(stop_band_atten),
401 transition_width=transition_width,
402 )
404 # add a start-up ramp if required
405 if rppw != 0:
406 # calculate the length of the ramp in time steps
407 ramp_length = round(rppw * filter_wavelength / (2 * ppw))
409 # create the ramp
410 ramp = (-np.cos(np.arange(0, ramp_length - 1 + 1) * np.pi / ramp_length) + 1) / 2
412 # apply the ramp
413 filtered_signal[1:ramp_length] = filtered_signal[1:ramp_length] * ramp
415 # restore the original vector orientation if modified
416 if rotate_signal:
417 filtered_signal = filtered_signal.T
419 # update the command line status
420 logging.log(logging.INFO, f" maximum frequency supported by kgrid: {scale_SI(f_max)}Hz (2 PPW)")
421 if ppw != 0:
422 logging.log(logging.INFO, f" filter cutoff frequency: {scale_SI(filter_cutoff_f)}Hz ({ppw} PPW)")
423 if rppw != 0:
424 logging.log(
425 logging.INFO, f" ramp frequency: {scale_SI(2 * np.pi / (2 * ramp_length * kgrid.dt))}Hz (ramp_points_per_wavelength PPW)"
426 )
427 logging.log(logging.INFO, " computation complete.")
429 # plot signals if required
430 if plot_signals or plot_spectrums:
431 raise NotImplementedError
433 return filtered_signal
436def apply_filter(
437 signal: np.ndarray,
438 fs: float,
439 cutoff_f: float,
440 filter_type: str,
441 zero_phase: Optional[bool] = False,
442 transition_width: Optional[float] = 0.1,
443 stop_band_atten: Optional[int] = 60,
444) -> np.ndarray:
445 """
446 Filters an input signal using a FIR filter with Kaiser window coefficients based on the specified cut-off frequency and filter type.
447 Both causal and zero phase filters can be applied.
449 Args:
450 signal: The input signal.
451 fs: The sampling frequency of the signal.
452 cutoff_f: The cut-off frequency of the filter.
453 filter_type: The type of filter to apply, either 'HighPass', 'LowPass' or 'BandPass'.
454 zero_phase: Whether to apply a zero-phase filter. Defaults to False.
455 transition_width: The transition width of the filter, as a proportion of the sampling frequency. Defaults to 0.1.
456 stop_band_atten: The stop-band attenuation of the filter in dB. Defaults to 60.
458 Returns:
459 The filtered signal.
461 """
463 # for a bandpass filter, use applyFilter recursively
464 if filter_type == "BandPass":
465 assert isinstance(cutoff_f, list), "List of two frequencies required as for filter type 'BandPass'"
466 assert len(cutoff_f) == 2, "List of two frequencies required as for filter type 'BandPass'"
468 # apply the low pass filter
469 func_filt_lp = apply_filter(
470 signal, fs, cutoff_f[1], "LowPass", stop_band_atten=stop_band_atten, transition_width=transition_width, zero_phase=zero_phase
471 )
473 # apply the high pass filter
474 filtered_signal = apply_filter(
475 func_filt_lp,
476 fs,
477 cutoff_f[0],
478 "HighPass",
479 stop_band_atten=stop_band_atten,
480 transition_width=transition_width,
481 zero_phase=zero_phase,
482 )
484 else:
485 # check filter type
486 if filter_type == "LowPass":
487 high_pass = False
488 elif filter_type == "HighPass":
489 high_pass = True
490 cutoff_f = fs / 2 - cutoff_f
491 else:
492 raise ValueError(f'Unknown filter type {filter_type}. Options are "LowPass, HighPass, BandPass"')
494 # make sure input is the correct way around
495 m, n = signal.shape
496 if m > n:
497 signal = signal.T
499 # correct the stopband attenuation if a zero phase filter is being used
500 if zero_phase:
501 stop_band_atten = stop_band_atten / 2
503 # decide the filter order
504 N = np.ceil((stop_band_atten - 7.95) / (2.285 * (transition_width * np.pi)))
505 N = int(N)
507 # construct impulse response of ideal bandpass filter h(n), a sinc function
508 fc = cutoff_f / fs # normalised cut-off
509 n = np.arange(-N / 2, N / 2)
510 h = 2 * fc * sinc(2 * np.pi * fc * n)
512 # if no window is given, use a Kaiser window
513 # TODO: there is no window argument
514 if "w" not in locals():
515 # compute Kaiser window parameter beta
516 if stop_band_atten > 50:
517 beta = 0.1102 * (stop_band_atten - 8.7)
518 elif stop_band_atten >= 21:
519 beta = 0.5842 * (stop_band_atten - 21) ** 0.4 + 0.07886 * (stop_band_atten - 21)
520 else:
521 beta = 0
523 # construct the Kaiser smoothing window w(n)
524 m = np.arange(0, N)
525 w = np.real(scipy.special.iv(0, np.pi * beta * np.sqrt(1 - (2 * m / N - 1) ** 2))) / np.real(scipy.special.iv(0, np.pi * beta))
527 # window the ideal impulse response with Kaiser window to obtain the FIR filter coefficients hw(n)
528 hw = w * h
530 # modify to make a high_pass filter
531 if high_pass:
532 hw = (-1 * np.ones((1, len(hw))) ** (np.arange(1, len(hw) + 1))) * hw
534 # add some zeros to allow the reverse (zero phase) filtering room to work
535 L = signal.size # length of original input signal
536 filtered_signal = np.hstack([np.zeros((1, N)), signal]).squeeze()
538 # apply the filter
539 filtered_signal = lfilter(hw.squeeze(), 1, filtered_signal)
540 if zero_phase:
541 filtered_signal = np.fliplr(lfilter(hw.squeeze(), 1, filtered_signal[np.arange(L + N, 1, -1)]))
543 # remove the part of the signal corresponding to the added zeros
544 filtered_signal = filtered_signal[N:]
546 return filtered_signal[np.newaxis]
549def smooth(a: np.ndarray, restore_max: Optional[bool] = False, window_type: Optional[str] = "Blackman") -> np.ndarray:
550 """
551 Smooths a matrix.
553 Args:
554 a: The spatial distribution to smooth.
555 restore_max: Boolean controlling whether the maximum value is restored after smoothing. Defaults to False.
556 window_type: Shape of the smoothing window. Any valid inputs to get_win are supported. Defaults to 'Blackman'.
558 Returns:
559 a_sm: The smoothed matrix.
561 """
563 DEF_USE_ROTATION = True
565 if a.dtype == bool:
566 a = a.astype(int)
568 assert is_number(a) and np.all(~np.isinf(a))
569 assert isinstance(restore_max, bool)
570 assert isinstance(window_type, str)
572 # get the grid size
573 grid_size = a.shape
575 # remove singleton dimensions
576 if num_dim2(a) != len(grid_size):
577 grid_size = np.squeeze(grid_size)
579 # use a symmetric filter for odd grid sizes, and a non-symmetric filter for
580 # even grid sizes to ensure the DC component of the window has a value of
581 # unity
582 window_symmetry = (np.array(grid_size) % 2).astype(bool)
584 # get the window, taking the absolute value to discard machine precision
585 # negative values
586 from .signals import get_win
588 win, _ = get_win(grid_size, type_=window_type, rotation=DEF_USE_ROTATION, symmetric=window_symmetry)
589 win = np.abs(win)
591 # rotate window if input mat is (1, N)
592 if a.shape[0] == 1: # is row?
593 win = win.T
595 # apply the filter
596 a_sm = np.real(np.fft.ifftn(np.fft.fftn(a) * np.fft.ifftshift(win)))
598 # restore magnitude if required
599 if restore_max:
600 a_sm = (np.abs(a).max() / np.abs(a_sm).max()) * a_sm
601 return a_sm