Coverage for kwave/utils/filters.py: 8%

191 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-24 12:06 -0700

1import logging 

2from typing import List, Optional, Tuple, Union 

3 

4import numpy as np 

5import scipy 

6from scipy.fftpack import fft, fftshift, ifft, ifftshift 

7from scipy.signal import lfilter 

8 

9from kwave.utils.conversion import create_index_at_dim 

10 

11from .checks import is_number 

12from .data import scale_SI 

13from .math import find_closest, gaussian, next_pow2, sinc 

14from .matrix import num_dim, num_dim2 

15from .signals import get_win 

16 

17 

def single_sided_correction(func_fft: np.ndarray, fft_len: int, dim: int) -> np.ndarray:
    """Correct the single-sided magnitude by multiplying the symmetric points by 2.

    The DC component (index 0) is unique and is never doubled. For even FFT lengths the last
    retained bin along ``dim`` is taken to be the Nyquist component, which is also unique and
    is not doubled. Odd FFT lengths have no Nyquist bin, so every bin after DC is doubled.

    Args:
        func_fft: The single-sided FFT to correct (as produced in ``spect``: the first
            ``ceil((fft_len + 1) / 2)`` bins along ``dim``). Modified in place, so it must be
            writable and have a dtype that can hold the doubled values.
        fft_len: The length of the full FFT the single-sided spectrum was taken from.
        dim: The dimension along which to apply the correction.

    Returns:
        The same array object as ``func_fft`` after in-place modification (returned for convenience).
    """
    index = [slice(None)] * func_fft.ndim
    # odd length: no Nyquist bin -> double everything after DC;
    # even length: keep the final (Nyquist) bin unchanged
    index[dim] = slice(1, None) if fft_len % 2 else slice(1, -1)
    func_fft[tuple(index)] *= 2
    return func_fft

44 

45 

def spect(
    func: np.ndarray,
    fs: float,
    dim: Optional[Union[int, str]] = "auto",
    fft_len: Optional[int] = 0,
    power_two: Optional[bool] = False,
    unwrap_phase: Optional[bool] = False,
    window: Optional[str] = "Rectangular",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Calculates the single-sided amplitude and phase spectrum of a signal.

    Args:
        func: The signal to analyse (array or array-like).
        fs: The sampling frequency in Hz.
        dim: The dimension over which the spectrum is calculated. If 'auto' (default), the first
            non-singleton dimension is used.
        fft_len: The length of the FFT. If None, non-positive, or smaller than the signal length,
            the default is used instead (the signal length, or the next power of two when
            ``power_two`` is set). Defaults to 0.
        power_two: Whether the default FFT length is rounded up to the next power of 2
            (default = False). Only applies when the default FFT length is being used.
        unwrap_phase: Whether to unwrap the phase spectrum along ``dim`` (default = False).
        window: The window type used to filter the signal before the FFT is taken
            (default = 'Rectangular'). Any valid input types for get_win may be used.

    Returns:
        f: Frequency array [Hz] with ``ceil((fft_len + 1) / 2)`` points.
        func_as: Single-sided amplitude spectrum, normalised by the signal length and the
            coherent gain of the window.
        func_ps: Single-sided phase spectrum [rad].

    Raises:
        ValueError: If the input signal is scalar, has more than 4 dimensions, or all of its
            dimensions are singleton when ``dim='auto'``.
    """
    func = np.asarray(func)
    sz = func.shape

    # check input isn't scalar
    if np.size(func) == 1:
        raise ValueError("Input signal cannot be scalar.")

    # check input doesn't have more than 4 dimensions
    if len(sz) > 4:
        raise ValueError("Input signal must have 1, 2, 3, or 4 dimensions.")

    # automatically set dimension to first non-singleton dimension
    if dim == "auto":
        dim = int(np.argmax(np.array(sz) > 1))
        if sz[dim] <= 1:
            raise ValueError("All dimensions are singleton; unable to determine valid dimension.")

    # number of points being analysed
    func_length = sz[dim]

    # fall back to the default FFT length when none (or an unusable one) is given
    if fft_len is None or fft_len <= 0 or fft_len < func_length:
        if power_two:
            # smallest FFT length of the form 2 ^ N that is >= the signal length
            fft_len = 2 ** (next_pow2(func_length))
        else:
            fft_len = func_length

    # window the signal, reshaping the window to lie along `dim`
    win, coherent_gain = get_win(func_length, type_=window, symmetric=False)
    win_shape = [1] * len(sz)
    win_shape[dim] = func_length
    win = np.reshape(win, tuple(win_shape))
    func = win * func

    # if fft_len > func_length, the input signal is zero-padded by the FFT
    func_fft = np.fft.fft(func, n=fft_len, axis=dim)

    # correct for the FFT magnitude scaling and the coherent gain of the window
    # (the correction uses func_length, NOT fft_len); epsilon guards against a zero gain
    epsilon = 1e-10
    func_fft = func_fft / (func_length * coherent_gain + epsilon)

    # keep the unique points: N / 2 + 1 for even FFT lengths, (N + 1) / 2 for odd
    num_unique_pts = int(np.ceil((fft_len + 1) / 2))
    slicing = [slice(None)] * len(sz)
    slicing[dim] = slice(0, num_unique_pts)
    func_fft = func_fft[tuple(slicing)]

    # double the non-unique bins in place
    single_sided_correction(func_fft, fft_len, dim)

    # frequency axis
    f = np.arange(0, num_unique_pts) * fs / fft_len

    func_as = np.abs(func_fft)
    func_ps = np.angle(func_fft)

    if unwrap_phase:
        func_ps = np.unwrap(func_ps, axis=dim)

    return f, func_as, func_ps

147 

148 

def extract_amp_phase(
    data: np.ndarray, fs: float, source_freq: float, dim: Union[str, int] = "auto", fft_padding: int = 3, window: str = "Hanning"
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Extract the amplitude and phase information at a specified frequency from a vector or matrix of time series data.

    The amplitude and phase are extracted from the frequency spectrum, which is calculated using a windowed and zero
    padded FFT. The values are extracted at the frequency closest to source_freq. By default, the time dimension is set
    to the highest non-singleton dimension.

    Args:
        data: Matrix of time signals [s]
        fs: Sampling frequency [Hz]
        source_freq: Frequency at which the amplitude and phase should be extracted [Hz]
        dim: The (0-based) time dimension of the input data; negative values count from the end.
            If 'auto', the highest non-singleton dimension is used.
        fft_padding: The FFT length is ``fft_padding`` times the number of time samples.
        window: The windowing function to use for the FFT (any type accepted by get_win).

    Returns:
        A tuple of the amplitude (corrected for the window's coherent gain), the phase, and the
        frequency at which they were extracted. Amplitude and phase are squeezed.

    Raises:
        ValueError: If ``dim='auto'`` and every dimension of ``data`` is singleton.
    """
    data = np.asarray(data)

    # resolve the time dimension as a non-negative, 0-based axis index
    if isinstance(dim, str) and dim == "auto":
        non_singleton = [axis for axis, size in enumerate(data.shape) if size > 1]
        if not non_singleton:
            raise ValueError("All dimensions are singleton; unable to determine the time dimension.")
        dim = non_singleton[-1]
    else:
        dim = int(dim) % data.ndim

    num_samples = data.shape[dim]

    # create a 1D window and orient it along the time dimension so it broadcasts correctly
    win, coherent_gain = get_win(num_samples, window)
    win_shape = [1] * data.ndim
    win_shape[dim] = num_samples
    win = np.reshape(win, win_shape)

    # apply window to time dimension of input data
    data = win * data

    # compute amplitude and phase spectra (the window is already applied, so spect uses its default)
    f, func_as, func_ps = spect(data, fs, fft_len=fft_padding * num_samples, dim=dim)

    # correct for coherent gain of the window applied above
    func_as = func_as / coherent_gain

    # find the index of the frequency component closest to source_freq
    _, f_index = find_closest(f, source_freq)

    # extract amplitude and phase at f_index along the time dimension
    idx = create_index_at_dim(func_as.ndim, dim, f_index)
    amp = func_as[idx]
    phase = func_ps[idx]

    return amp.squeeze(), phase.squeeze(), f[f_index]

208 

209 

def fwhm(f, x):
    """
    Calculate the Full Width at Half Maximum (FWHM) of a positive 1D function f(x).

    Starting at the (first) global maximum of ``f``, the samples are scanned outwards in both
    directions to the first sample at or below half of the maximum. The exact crossing positions
    are found by linear interpolation between that sample and its neighbour on the peak side.

    Args:
        f: Samples of f(x). Any array-like that squeezes to one dimension.
        x: Positions of the samples of f, with the same number of points as ``f``.

    Returns:
        A tuple ``(fwhm_val, (x_leading, x_trailing))`` where ``x_leading`` and ``x_trailing`` are
        the interpolated half-maximum positions either side of the peak and
        ``fwhm_val = x_trailing - x_leading``.

    Raises:
        ValueError: If ``f`` is not 1D, ``x`` does not match ``f``, the maximum of ``f`` is not
            positive, or a half-maximum crossing lies outside the sampled range.
    """
    f = np.asarray(f, dtype=float).squeeze()
    if f.ndim != 1:
        raise ValueError("Input function must be 1-dimensional.")
    x = np.asarray(x, dtype=float).squeeze()
    if x.shape != f.shape:
        raise ValueError("x must contain the same number of points as f.")

    peak = int(np.argmax(f))
    half = f[peak] / 2.0
    if not half < f[peak]:
        raise ValueError("Input function must have a positive maximum.")

    # last sample before the peak, and first sample after it, that are at or below half maximum
    left = np.flatnonzero(f[:peak] <= half)
    right = np.flatnonzero(f[peak + 1 :] <= half)
    if left.size == 0:
        raise ValueError("Leading edge of input function could not be found.")
    if right.size == 0:
        raise ValueError("Trailing edge of input function could not be found.")
    lead = int(left[-1])
    trail = peak + 1 + int(right[0])

    def _interp(lo, hi):
        # x position where the straight line through (x[lo], f[lo]) and (x[hi], f[hi]) equals half;
        # f[lo] != f[hi] is guaranteed because one is <= half and the other > half
        return x[lo] + (half - f[lo]) * (x[hi] - x[lo]) / (f[hi] - f[lo])

    x_leading = float(_interp(lead, lead + 1))
    x_trailing = float(_interp(trail - 1, trail))

    return x_trailing - x_leading, (x_leading, x_trailing)

244 

245 

def gaussian_filter(
    signal: Union[np.ndarray, List[float]], fs: float, frequency: float, bandwidth: float
) -> Union[np.ndarray, List[float]]:
    """
    Applies a frequency domain Gaussian filter with the specified center frequency and
    percentage bandwidth to the input signal. The filter acts along the last axis, so if the
    input signal is given as a matrix, the filter is applied to each matrix row.

    Args:
        signal: Signal to filter [channel, samples] (array or list).
        fs: Sampling frequency [Hz]
        frequency: Center frequency of filter [Hz]
        bandwidth: Bandwidth of filter in percentage (used as the full width at half maximum
            of the Gaussian, relative to ``frequency``).

    Returns:
        The filtered signal (real part of the inverse FFT) as a numpy array.
    """
    signal = np.asarray(signal)
    num_samples = signal.shape[-1]

    # frequency axis in fftshift order (DC at index num_samples // 2), for even and odd lengths
    f = (np.arange(num_samples) - num_samples // 2) * fs / num_samples

    # convert the percentage FWHM bandwidth into a Gaussian variance
    variance = (bandwidth / 100 * frequency / (2 * np.sqrt(2 * np.log(2)))) ** 2
    magnitude = 1

    # double-sided Gaussian filter (positive and negative frequency lobes);
    # its 1D shape broadcasts against the last axis of the signal
    gfilter = np.fmax(gaussian(f, magnitude, frequency, variance), gaussian(f, magnitude, -frequency, variance))

    # apply the filter along the time (last) axis only
    spectrum = fftshift(fft(signal), axes=-1)
    return np.real(ifft(ifftshift(gfilter * spectrum, axes=-1)))

287 

288 

def filter_time_series(
    kgrid: "kWaveGrid",
    medium: "kWaveMedium",
    signal: np.ndarray,
    ppw: Optional[int] = 3,
    rppw: Optional[int] = 0,
    stop_band_atten: Optional[int] = 60,
    transition_width: Optional[float] = 0.1,
    zerophase: Optional[bool] = False,
    plot_spectrums: Optional[bool] = False,
    plot_signals: Optional[bool] = False,
) -> np.ndarray:
    """
    Filters a time-domain signal using the Kaiser windowing method.

    The signal is low-pass filtered (via ``apply_filter``) at a cut-off frequency derived from the
    maximum frequency supported by ``kgrid`` for the minimum sound speed in ``medium``. An optional
    raised-cosine start-up ramp can then be applied to the first samples of the filtered signal.

    Args:
        kgrid: The kWaveGrid grid (``dt``, ``t_array`` and ``k_max_all`` are read).
        medium: The kWaveMedium (``sound_speed``, or ``sound_speed_compression`` and
            ``sound_speed_shear`` for the elastic code).
        signal: The time-domain signal to filter. Must be a vector (1D, row or column).
        ppw: Points per wavelength at the filter cut-off. The cut-off is ``2 * f_max / ppw``, where
            ``f_max`` is the maximum frequency supported by the grid (two points per wavelength),
            so higher values give a lower cut-off. If 0, no filtering is applied. Defaults to 3.
        rppw: Points per wavelength of the raised-cosine start-up ramp applied to the beginning of
            the filtered signal. If 0, no ramp is applied. Requires ``ppw != 0``. Defaults to 0.
        stop_band_atten: The stop-band attenuation in dB passed to ``apply_filter``. Defaults to 60.
        transition_width: The transition width as a proportion of the sampling frequency, passed
            to ``apply_filter``. Defaults to 0.1.
        zerophase: Whether to implement the filter as a zero-phase filter. Defaults to False.
        plot_spectrums: Plotting is not implemented; must be False.
        plot_signals: Plotting is not implemented; must be False.

    Raises:
        TypeError: If ``signal`` is not a vector.
        ValueError: If no sound speed is defined on ``medium``, or a ramp is requested
            (``rppw != 0``) without filtering (``ppw == 0``).
        NotImplementedError: If plotting is requested.
        AssertionError: If ``kgrid.t_array`` has not been explicitly defined.

    Returns:
        The filtered signal, with the same shape as ``signal``.
    """

    def _si(value):
        # NOTE(review): scale_SI appears to return a tuple whose first item is the formatted
        # string; fall back to the raw result if it does not — confirm against kwave.utils.data.
        scaled = scale_SI(value)
        return scaled[0] if isinstance(scaled, tuple) else scaled

    # check the input is a vector; it is processed internally as a (1, N) row vector
    signal = np.asarray(signal)
    if np.squeeze(signal).ndim != 1:
        raise TypeError("Input signal must be a vector.")
    if plot_signals or plot_spectrums:
        raise NotImplementedError
    if rppw != 0 and ppw == 0:
        raise ValueError("A start-up ramp (rppw != 0) requires filtering to be enabled (ppw != 0).")
    input_shape = signal.shape
    row_signal = signal.reshape(1, -1)

    # update the command line status
    logging.log(logging.INFO, "Filtering input signal...")

    # extract the time step
    assert not isinstance(kgrid.t_array, str) or kgrid.t_array != "auto", "kgrid.t_array must be explicitly defined."

    # compute the sampling frequency
    fs = 1 / kgrid.dt

    # extract the minimum sound speed
    if medium.sound_speed is not None:
        # for the fluid code, use medium.sound_speed (scalar or array)
        c0 = np.min(medium.sound_speed)
    elif all(medium.is_defined("sound_speed_compression", "sound_speed_shear")):  # pragma: no cover
        # for the elastic code, combine the shear and compression sound speeds and ignore zeros
        ss = np.concatenate([np.ravel(medium.sound_speed_compression), np.ravel(medium.sound_speed_shear)]).astype(float)
        ss[ss == 0] = np.nan
        c0 = np.nanmin(ss)
    else:
        raise ValueError(
            "The input fields medium.sound_speed or medium.sound_speed_compression and medium.sound_speed_shear must " "be defined."
        )

    # extract the maximum supported frequency (two points per wavelength)
    f_max = kgrid.k_max_all * c0 / (2 * np.pi)

    # filter the signal if required
    if ppw != 0:
        # filter cut-off frequency, and its wavelength expressed in time steps
        filter_cutoff_f = 2 * f_max / ppw
        filter_wavelength = (2 * np.pi / filter_cutoff_f) / kgrid.dt
        filtered_signal = apply_filter(
            row_signal,
            fs,
            float(filter_cutoff_f),
            "LowPass",
            zero_phase=zerophase,
            stop_band_atten=float(stop_band_atten),
            transition_width=transition_width,
        )
    else:
        filtered_signal = row_signal.copy()

    # add a raised-cosine start-up ramp to the first samples if required
    if rppw != 0:
        ramp_length = int(round(rppw * filter_wavelength / (2 * ppw)))
        ramp = (1 - np.cos(np.arange(ramp_length) * np.pi / ramp_length)) / 2
        # the ramp starts at the very first sample; clip it if it is longer than the signal
        n_ramp = min(ramp_length, filtered_signal.shape[-1])
        filtered_signal[..., :n_ramp] = filtered_signal[..., :n_ramp] * ramp[:n_ramp]

    # update the command line status
    logging.log(logging.INFO, f" maximum frequency supported by kgrid: {_si(f_max)}Hz (2 PPW)")
    if ppw != 0:
        logging.log(logging.INFO, f" filter cutoff frequency: {_si(filter_cutoff_f)}Hz ({ppw} PPW)")
    if rppw != 0 and ramp_length > 0:
        logging.log(logging.INFO, f" ramp frequency: {_si(2 * np.pi / (2 * ramp_length * kgrid.dt))}Hz ({rppw} PPW)")
    logging.log(logging.INFO, " computation complete.")

    # restore the original vector orientation / shape
    return np.reshape(filtered_signal, input_shape)

434 

435 

def apply_filter(
    signal: np.ndarray,
    fs: float,
    cutoff_f: float,
    filter_type: str,
    zero_phase: Optional[bool] = False,
    transition_width: Optional[float] = 0.1,
    stop_band_atten: Optional[int] = 60,
) -> np.ndarray:
    """
    Filters an input signal using a FIR filter with Kaiser window coefficients based on the specified cut-off frequency and filter type.
    Both causal and zero phase filters can be applied.

    A high-pass filter is designed as a low-pass filter at ``fs / 2 - cutoff_f`` whose coefficients
    are multiplied by an alternating ``(-1) ** k`` sequence, which shifts its response by half the
    sampling frequency. A band-pass filter is a low-pass at ``cutoff_f[1]`` followed by a high-pass
    at ``cutoff_f[0]``.

    Args:
        signal: The input signal. Must be a vector (1D, row or column); it is processed as a 1D sequence.
        fs: The sampling frequency of the signal [Hz].
        cutoff_f: The cut-off frequency of the filter [Hz]. For 'BandPass', a list/tuple/array of two
            frequencies ``(low, high)``.
        filter_type: The type of filter to apply, either 'HighPass', 'LowPass' or 'BandPass'.
        zero_phase: Whether to filter forwards and then backwards (zero phase). The stop-band
            attenuation used to design the filter is halved in this case. Defaults to False.
        transition_width: The transition width of the filter, as a proportion of the sampling frequency. Defaults to 0.1.
        stop_band_atten: The stop-band attenuation of the filter in dB. Defaults to 60.

    Returns:
        The filtered signal as a (1, N) array, where N is the number of samples in ``signal``.

    Raises:
        AssertionError: If 'BandPass' is requested without exactly two cut-off frequencies.
        ValueError: If ``filter_type`` is unknown, ``signal`` is not a vector, or the design
            parameters give a filter order below 1.
    """

    # for a bandpass filter, use apply_filter recursively
    if filter_type == "BandPass":
        assert isinstance(cutoff_f, (list, tuple, np.ndarray)), "List of two frequencies required as for filter type 'BandPass'"
        assert len(cutoff_f) == 2, "List of two frequencies required as for filter type 'BandPass'"

        # low-pass at the upper edge, then high-pass at the lower edge
        func_filt_lp = apply_filter(
            signal, fs, cutoff_f[1], "LowPass", stop_band_atten=stop_band_atten, transition_width=transition_width, zero_phase=zero_phase
        )
        return apply_filter(
            func_filt_lp, fs, cutoff_f[0], "HighPass", stop_band_atten=stop_band_atten, transition_width=transition_width, zero_phase=zero_phase
        )

    # check filter type
    if filter_type == "LowPass":
        high_pass = False
    elif filter_type == "HighPass":
        high_pass = True
        # design the equivalent low-pass; it is frequency-shifted by fs / 2 below
        cutoff_f = fs / 2 - cutoff_f
    else:
        raise ValueError(f'Unknown filter type {filter_type}. Options are "LowPass, HighPass, BandPass"')

    # accept 1D, row and column vectors; process as a flat sequence
    samples = np.asarray(signal)
    if np.squeeze(samples).ndim > 1:
        raise ValueError("Input signal must be a vector.")
    samples = samples.ravel()

    # correct the stopband attenuation if a zero phase filter is being used (applied twice)
    if zero_phase:
        stop_band_atten = stop_band_atten / 2

    # decide the filter order
    order = int(np.ceil((stop_band_atten - 7.95) / (2.285 * (transition_width * np.pi))))
    if order < 1:
        raise ValueError("stop_band_atten and transition_width give a filter order below 1.")

    # impulse response of the ideal low-pass filter h(n) = 2 fc sin(2 pi fc n) / (2 pi fc n),
    # sampled at n = -order/2, ..., order/2 - 1
    taps = np.arange(order)
    fc = cutoff_f / fs  # normalised cut-off
    n = taps - order / 2
    h = 2 * fc * np.sinc(2 * fc * n)  # np.sinc(x) = sin(pi x) / (pi x)

    # Kaiser window parameter beta
    if stop_band_atten > 50:
        beta = 0.1102 * (stop_band_atten - 8.7)
    elif stop_band_atten >= 21:
        beta = 0.5842 * (stop_band_atten - 21) ** 0.4 + 0.07886 * (stop_band_atten - 21)
    else:
        beta = 0

    # Kaiser smoothing window w(n)
    w = np.real(scipy.special.iv(0, np.pi * beta * np.sqrt(1 - (2 * taps / order - 1) ** 2))) / np.real(scipy.special.iv(0, np.pi * beta))

    # window the ideal impulse response to obtain the FIR filter coefficients hw(n)
    hw = w * h

    # convert to a high-pass filter by modulating with (-1) ** k, k = 1..order
    if high_pass:
        hw = hw * (-1.0) ** np.arange(1, order + 1)

    # prepend zeros to give the (reverse) filtering room to settle
    padded = np.concatenate([np.zeros(order), samples])

    # apply the filter (forwards, then backwards for zero phase)
    filtered_signal = lfilter(hw, 1, padded)
    if zero_phase:
        filtered_signal = lfilter(hw, 1, filtered_signal[::-1])[::-1]

    # remove the part of the signal corresponding to the added zeros
    filtered_signal = filtered_signal[order:]

    return filtered_signal[np.newaxis]

547 

548 

def smooth(a: np.ndarray, restore_max: Optional[bool] = False, window_type: Optional[str] = "Blackman") -> np.ndarray:
    """
    Smooths a matrix by multiplying its spectrum by a window in the frequency domain.

    Args:
        a: The spatial distribution to smooth. Boolean input is converted to integers first.
        restore_max: Boolean controlling whether the maximum absolute value of ``a`` is restored
            after smoothing. Defaults to False.
        window_type: Shape of the smoothing window. Any valid inputs to get_win are supported. Defaults to 'Blackman'.

    Returns:
        a_sm: The smoothed matrix (real part of the inverse FFT), with the same shape as ``a``.
    """

    DEF_USE_ROTATION = True

    if a.dtype == bool:
        a = a.astype(int)

    assert is_number(a) and np.all(~np.isinf(a))
    assert isinstance(restore_max, bool)
    assert isinstance(window_type, str)

    # remove singleton dimensions so the window is built over the non-trivial axes only
    # (np.squeeze applied to the shape tuple itself would never drop the 1s)
    grid_size = a.shape
    non_singleton = [size for size in grid_size if size != 1]
    if non_singleton and len(non_singleton) != len(grid_size):
        grid_size = np.array(non_singleton)

    # use a symmetric filter for odd grid sizes, and a non-symmetric filter for
    # even grid sizes to ensure the DC component of the window has a value of unity
    window_symmetry = (np.array(grid_size) % 2).astype(bool)

    win, _ = get_win(grid_size, type_=window_type, rotation=DEF_USE_ROTATION, symmetric=window_symmetry)

    # the absolute value discards machine-precision negative values; reshaping re-inserts the
    # singleton axes so the window lines up with `a` (covers row and column vectors too)
    win = np.reshape(np.abs(win), a.shape)

    # apply the filter
    a_sm = np.real(np.fft.ifftn(np.fft.fftn(a) * np.fft.ifftshift(win)))

    # restore magnitude if required (skipped when the smoothed result is identically zero)
    if restore_max:
        smoothed_peak = np.abs(a_sm).max()
        if smoothed_peak > 0:
            a_sm = (np.abs(a).max() / smoothed_peak) * a_sm
    return a_sm