Coverage for kwave/utils/interp.py: 11%

130 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-24 12:06 -0700

1import logging 

2 

3import numpy as np 

4from beartype import beartype as typechecker 

5from beartype.typing import List, Optional, Tuple, Union 

6from numpy.fft import fft, fftshift 

7from scipy.interpolate import interpn 

8from scipy.signal import resample 

9 

10from .conversion import grid2cart 

11from .data import scale_time 

12from .matrix import sort_rows 

13from .tictoc import TicToc 

14 

15 

def interpolate3d(
    grid_points: List[np.ndarray],
    grid_values: np.ndarray,
    interp_locs: List[np.ndarray],
    copy_nans: bool = True,
) -> np.ndarray:
    """
    Interpolates input grid values at the given locations
    Added by Farid

    Matlab version of this function assumes unstructured grid. Interpolating such grid in Python using
    SciPy is very expensive. Thankfully, working with structured grid is fine for our purposes.
    We still support 3D arguments for backward compatibility even though they are mapped to 1D grid.
    While mapping we assume that only one axis per 3D grid changes throughout the grid.

    Args:
        grid_points: List of three 1D or 3D Numpy arrays with the grid coordinates along x, y and z.
            3D arrays are reduced to 1D by taking ``[:, 0, 0]`` (x), ``[0, :, 0]`` (y) and ``[0, 0, :]`` (z).
        grid_values: A 3D Numpy array which holds values at grid_points
        interp_locs: List of three 1D or 3D Numpy arrays with the query coordinates (reduced the same way).
        copy_nans: If True (default, the historical behaviour), entries whose interpolated value is NaN
            (e.g. query points outside the grid) are replaced by the entry of ``grid_values`` at the same
            index; this requires the query lattice to have the same shape as ``grid_values``.
            If False, those entries are returned as NaN.

    Returns:
        Array of shape ``(len(q_x), len(q_y), len(q_z))`` holding the linearly interpolated values.

    """

    assert len(grid_points) == 3, "interpolate3D supports only 3D interpolation"
    assert len(grid_points) == len(interp_locs)

    def unpack_and_make_1D(pts):
        # Reduce (possibly ndgrid-style) 3D coordinate arrays to their 1D axis vectors.
        pts_x, pts_y, pts_z = pts
        if pts_x.ndim == 3:
            pts_x = pts_x[:, 0, 0]
        if pts_y.ndim == 3:
            pts_y = pts_y[0, :, 0]
        if pts_z.ndim == 3:
            pts_z = pts_z[0, 0, :]
        return pts_x, pts_y, pts_z

    g_x, g_y, g_z = unpack_and_make_1D(grid_points)
    q_x, q_y, q_z = unpack_and_make_1D(interp_locs)

    # 'ij' indexing is crucial for Matlab compatibility
    queries = np.array(np.meshgrid(q_x, q_y, q_z, indexing="ij"))
    # Queries are just a list of 3D points, shape (N, 3)
    queries = queries.reshape(3, -1).T

    # Out of bound points will get NaN values
    result = interpn((g_x, g_y, g_z), grid_values, queries, method="linear", bounds_error=False, fill_value=np.nan)
    # Go back from list of interpolated values to 3D volume
    result = result.reshape((q_x.size, q_y.size, q_z.size))
    if copy_nans:
        # set values outside of the interpolation range to original values
        nan_mask = np.isnan(result)
        result[nan_mask] = grid_values[nan_mask]
    return result

61 

62 

def interpolate2d(
    grid_points: List[np.ndarray], grid_values: np.ndarray, interp_locs: List[np.ndarray], method="linear", copy_nans=True
) -> np.ndarray:
    """
    Interpolates input grid values at the given locations
    Added by Farid

    Matlab version of this function assumes unstructured grid. Interpolating such grid in Python using
    SciPy is very expensive. Thankfully, working with structured grid is fine for our purposes.
    We still support 2D arguments for backward compatibility even though they are mapped to 1D grid.
    While mapping we assume that only one axis per 2D grid changes throughout the grid.

    Args:
        grid_points: List of two 1D or 2D Numpy arrays with the grid coordinates along x and y.
            2D arrays are reduced to 1D by taking ``[:, 0]`` (x) and ``[0, :]`` (y).
        grid_values: A 2D Numpy array which holds values at grid_points
        interp_locs: List of two 1D or 2D Numpy arrays with the query coordinates (reduced the same way).
        method: Interpolation method forwarded to ``scipy.interpolate.interpn``.
        copy_nans: If True, entries whose interpolated value is NaN (e.g. query points outside the grid)
            are replaced by the entry of ``grid_values`` at the same index; this requires the query
            lattice to have the same shape as ``grid_values``. If False, NaNs are returned as-is.

    Returns:
        Array of shape ``(len(q_x), len(q_y))`` holding the interpolated values.

    """

    assert len(grid_points) == 2, "interpolate2D supports only 2D interpolation"
    assert len(grid_points) == len(interp_locs)

    def unpack_and_make_1D(pts):
        # Reduce (possibly ndgrid-style) 2D coordinate arrays to their 1D axis vectors.
        pts_x, pts_y = pts
        if pts_x.ndim == 2:
            pts_x = pts_x[:, 0]
        if pts_y.ndim == 2:
            pts_y = pts_y[0, :]
        return pts_x, pts_y

    g_x, g_y = unpack_and_make_1D(grid_points)
    q_x, q_y = unpack_and_make_1D(interp_locs)

    # 'ij' indexing is crucial for Matlab compatibility
    queries = np.array(np.meshgrid(q_x, q_y, indexing="ij"))
    # Queries are just a list of 2D points, shape (N, 2)
    queries = queries.reshape(2, -1).T

    # Out of bound points will get NaN values
    result = interpn((g_x, g_y), grid_values, queries, method=method, bounds_error=False, fill_value=np.nan)
    # Go back from list of interpolated values to a 2D image
    result = result.reshape((q_x.size, q_y.size))
    if copy_nans:
        assert result.shape == grid_values.shape
        # set values outside of the interpolation range to original values
        nan_mask = np.isnan(result)
        result[nan_mask] = grid_values[nan_mask]
    return result

111 

112 

def interpolate2d_with_queries(
    grid_points: List[np.ndarray], grid_values: np.ndarray, queries: np.ndarray, method="linear", copy_nans=True
) -> np.ndarray:
    """
    Interpolates input grid values at the given query points
    Added by Farid

    Simplified version of `interpolate2d`: the query locations are given directly as an [N, 2]
    array of (x, y) coordinates, and no meshgrid is built from them.
    WARNING: supposed to support only 2D interpolation!

    Args:
        grid_points: List of two 1D Numpy arrays with the grid coordinates along x and y.
        grid_values: Numpy array of shape (len(x), len(y), ...) which holds values at grid_points.
        queries: Numpy array with shape [N, 2] holding the (x, y) coordinates to interpolate at.
        method: Interpolation method forwarded to ``scipy.interpolate.interpn``.
        copy_nans: If True, entries whose interpolated value is NaN (e.g. queries outside the grid)
            are replaced by the original grid value at the same flattened index. This is only
            meaningful when the queries enumerate the grid nodes in 'ij' (row-major) order, so
            ``grid_values`` must contain exactly as many entries as the result. If False, NaNs
            are returned as-is.

    Returns:
        Array of shape (N, ...) with the interpolated values.

    """
    assert len(grid_points) == 2, "interpolate2D supports only 2D interpolation"

    g_x, g_y = grid_points

    assert g_x.ndim == 1  # is a list
    assert g_y.ndim == 1  # is a list
    assert queries.ndim == 2 and queries.shape[1] == 2

    # Out of bound points will get NaN values
    result = interpn((g_x, g_y), grid_values, queries, method=method, bounds_error=False, fill_value=np.nan)
    if copy_nans:
        # `result` is (N, ...) while `grid_values` is (nx, ny, ...), so their shapes can never be
        # equal; compare sizes instead and align the originals with the flattened query order.
        assert result.size == grid_values.size, "copy_nans requires one query per grid node"
        original_values = np.reshape(grid_values, result.shape)
        # set values outside the interpolation range to original values
        nan_mask = np.isnan(result)
        result[nan_mask] = original_values[nan_mask]
    return result

147 

148 

def get_bli(
    func: np.ndarray,
    dx: Optional[float] = 1,
    up_sampling_factor: Optional[int] = 20,
    plot: Optional[bool] = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Calculates the band-limited interpolant of a 1D input function.

    The function is transformed with an FFT, and its Fourier series is evaluated on a grid that
    is ``up_sampling_factor`` times finer than the input sampling.

    Args:
        func: The 1D input function (singleton dimensions are squeezed away).
        dx: Spatial sampling in meters. Defaults to 1.
        up_sampling_factor: Up-sampling factor used to sample the underlying BLI. Defaults to 20.
        plot: Whether to plot the BLI. Defaults to False. Plotting is not implemented and
            raises NotImplementedError.

    Returns:
        A tuple ``(bli, x_fine)`` containing the BLI and the x-grid for the BLI. Both have
        ``(nx - 1) * up_sampling_factor + 1`` samples, and ``x_fine`` runs from 0 to ``(nx - 1) * dx``.

    """

    func = np.squeeze(func)
    assert len(func.shape) == 1, f"func not 1D but rather {len(func.shape)}D"
    nx = len(func)

    dk = 2 * np.pi / (dx * nx)
    if nx % 2:
        # odd: wavenumbers symmetric about zero
        k_min = -np.pi / dx + dk / 2
    else:
        # even: includes the Nyquist wavenumber -pi/dx
        k_min = -np.pi / dx

    # Build both grids from integer sample counts. A float-step np.arange can return one extra
    # element when (stop - start) / step rounds up, which would break the matrix product below.
    k = k_min + dk * np.arange(nx)
    num_fine = (nx - 1) * up_sampling_factor + 1
    x_fine = np.arange(num_fine) * (dx / up_sampling_factor)

    # Fourier coefficients ordered to match k (negative to positive wavenumbers)
    func_k = fftshift(fft(func)) / nx

    # Evaluate the Fourier series at every fine-grid point
    bli = np.real(func_k @ np.exp(1j * np.outer(k, x_fine)))
    if plot:
        raise NotImplementedError
    return bli, x_fine

196 

197 

def interp_cart_data(kgrid, cart_sensor_data, cart_sensor_mask, binary_sensor_mask, interp="nearest"):
    """
    Takes a matrix of time-series data recorded over a set
    of Cartesian sensor points given by cart_sensor_mask and computes the
    equivalent time-series at each sensor position on the binary sensor
    mask binary_sensor_mask using interpolation. The properties of
    binary_sensor_mask are defined by the k-Wave grid object kgrid.
    Two and three-dimensional data are supported.

    Usage:
        binary_sensor_data = interp_cart_data(kgrid, cart_sensor_data, cart_sensor_mask, binary_sensor_mask)
        binary_sensor_data = interp_cart_data(kgrid, cart_sensor_data, cart_sensor_mask, binary_sensor_mask, interp)

    Args:
        kgrid: k-Wave grid object returned by kWaveGrid
        cart_sensor_data: original sensor data measured over
            cart_sensor_mask indexed as
            cart_sensor_data(sensor position, time)
        cart_sensor_mask: Cartesian sensor mask over which
            cart_sensor_data is measured; the first kgrid.dim rows
            hold the point coordinates
        binary_sensor_mask: binary sensor mask at which equivalent
            time-series are computed via interpolation
        interp: (optional) interpolation mode used to compute the
            time-series, both 'nearest' and 'linear'
            (two-point) modes are supported
            (default = 'nearest'). 'linear' blends the two closest
            Cartesian points, weighting each by the other's distance;
            it falls back to 'nearest' when fewer than two Cartesian
            points are available.

    Returns:
        array of time-series corresponding to the sensor positions given by binary_sensor_mask

    Raises:
        ValueError: if kgrid is not 2D/3D or interp is not a supported option.

    """

    # make timer
    timer = TicToc()
    # start the clock
    timer.tic()

    # extract the number of data points
    num_cart_data_points, num_time_points = cart_sensor_data.shape
    num_binary_sensor_points = int(np.sum(binary_sensor_mask.flatten()))

    # update command line status
    logging.log(logging.INFO, "Interpolating Cartesian sensor data...")
    logging.log(logging.INFO, f" interpolation mode: {interp}")
    logging.log(logging.INFO, f" number of Cartesian sensor points: {num_cart_data_points}")
    logging.log(logging.INFO, f" number of binary sensor points: {num_binary_sensor_points}")

    binary_sensor_data = np.zeros((num_binary_sensor_points, num_time_points))

    # Check dimensionality of data passed
    if kgrid.dim not in [2, 3]:
        raise ValueError("Data must be two- or three-dimensional.")

    # Validate the mode up front so it is rejected even when there are no points to process
    if interp not in ("nearest", "linear"):
        raise ValueError("Unknown interpolation option.")

    use_linear = interp == "linear"
    if use_linear and num_cart_data_points < 2:
        logging.warning("Linear interpolation needs at least two Cartesian sensor points; using nearest.")
        use_linear = False

    cart_bsm, _ = grid2cart(kgrid, binary_sensor_mask)

    # (num_cart_data_points, dim) coordinates of the Cartesian sensor points, hoisted out of the loop
    cart_positions = cart_sensor_mask[: kgrid.dim, :].T

    for point_index in range(num_binary_sensor_points):
        # distance from this binary sensor point to every Cartesian sensor point
        dist = np.linalg.norm(cart_bsm[:, point_index] - cart_positions, ord=2, axis=1)

        if not use_linear:
            # nearest neighbour: copy the time series of the closest measured point
            binary_sensor_data[point_index, :] = cart_sensor_data[np.argmin(dist), :]
            continue

        # the two closest measured points (stable sort matches MATLAB sortrows tie handling)
        closest, second = np.argsort(dist, kind="stable")[:2]
        d_closest, d_second = dist[closest], dist[second]
        total = d_closest + d_second

        # linearly interpolate between the two closest points: the closer point gets the larger
        # weight; if both coincide with the binary point (total == 0) use the closest one only
        perc = d_second / total if total > 0 else 1.0
        binary_sensor_data[point_index, :] = perc * cart_sensor_data[closest, :] + (1 - perc) * cart_sensor_data[second, :]

    # update command line status
    logging.log(logging.INFO, f" computation completed in {scale_time(timer.toc())}")
    return binary_sensor_data

306 

307 

def interpftn(x, sz: tuple, win=None):
    """
    Resamples an N-D matrix to the size given in sz using Fourier interpolation.

    ``sz`` may either list a size for every dimension of ``x`` or, as in the MATLAB original,
    only for its non-singleton dimensions (singleton dimensions are then left untouched).
    Each dimension whose requested size differs from the current one is resampled with
    ``scipy.signal.resample``.

    Args:
        x: matrix to interpolate
        sz: list or tuple of new size
        win: (optional) window forwarded to ``scipy.signal.resample``

    Returns:
        Resampled matrix

    Raises:
        ValueError: if ``len(sz)`` matches neither ``x.ndim`` nor the number of non-singleton
            dimensions of ``x``.

    Examples:
        >>> y = interpftn(x, sz)
        >>> y = interpftn(x, sz, win)

    """
    x = np.asarray(x)

    # extract the size of the input matrix
    x_sz = x.shape
    non_singleton_axes = [axis for axis, length in enumerate(x_sz) if length != 1]

    # map each requested size onto the axis it refers to
    if len(sz) == len(x_sz):
        target_axes = list(range(len(x_sz)))
    elif len(sz) == len(non_singleton_axes):
        # sizes given only for non-singleton dims, e.g. (1, N) with sz=(M,) resamples axis 1
        target_axes = non_singleton_axes
    else:
        raise ValueError("The number of scaling coefficients must equal the number of dimensions in x.")

    # interpolate for each matrix dimension (dimensions with no interpolation required are skipped)
    y = x
    for axis, new_size in zip(target_axes, sz):
        if new_size != x_sz[axis]:
            y = resample(y, new_size, axis=axis, window=win)

    return y

341 

342 

343@typechecker 

344def get_delta_bli(Nx: int, dx: float, x: np.ndarray, x0: Union[int, float], include_imag: bool = False) -> np.ndarray: 

345 """ 

346 Exact BLI of an arbitrarily positioned delta function. 

347 

348 Calculates the exact Band-Limited Interpolation (BLI) of an arbitrarily positioned delta function. 

349 For grid dimensions with an evenly-sampled periodicity, a small Nyquist frequency sinusoid is added. 

350 This sinusoid is invisible on grid samples and has zero amplitude when the delta function lies on a grid node. 

351 It is important when the evaluation points aren't grid nodes, and when the delta function is off-grid. 

352 It serves to ensure conjugate symmetry in the BLI's Fourier transform. 

353 

354 Args: 

355 Nx: Number of grid points in the relevant Cartesian direction. 

356 dx: Grid point spacing [m]. 

357 x: Coordinates at which the BLI is evaluated [m]. 

358 x0: Coordinate at which the BLI is centered [m]. 

359 include_imag: Whether to include the imaginary component of the off-grid delta function. 

360 Defaults to False. 

361 

362 Returns: 

363 f: Value of the BLI at the specified coordinates. 

364 

365 """ 

366 

367 # ignore imaginary component of even function by default 

368 if include_imag is None: 

369 include_imag = False 

370 

371 # check whether the grid has even or odd samples per period 

372 is_even = Nx % 2 == 0 

373 

374 # compute BLI 

375 if is_even: 

376 # compute periodic sinc function 

377 f = np.sin(np.pi * (x - x0) / dx) / (Nx * np.tan(np.pi * (x - x0) / (Nx * dx))) 

378 

379 # correct indeterminate points 

380 f[(x - x0) == 0] = 1 

381 

382 # add Nyquist sinusoid to ensure conjugate symmetry 

383 f = f - np.sin(np.pi * x0 / dx) / Nx * np.sin(np.pi * x / dx) 

384 if include_imag: 

385 f = f + 1j * np.cos(np.pi * x0 / dx) / Nx * np.sin(np.pi * x / dx) 

386 else: 

387 # compute periodic sinc function 

388 f = np.sin(np.pi * (x - x0) / dx) / (Nx * np.sin(np.pi * (x - x0) / (Nx * dx))) 

389 

390 # correct indeterminate points 

391 f[(x - x0) == 0] = 1 

392 

393 return f