Coverage for kwave/utils/interp.py: 11%
130 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-24 12:06 -0700
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-24 12:06 -0700
1import logging
3import numpy as np
4from beartype import beartype as typechecker
5from beartype.typing import List, Optional, Tuple, Union
6from numpy.fft import fft, fftshift
7from scipy.interpolate import interpn
8from scipy.signal import resample
10from .conversion import grid2cart
11from .data import scale_time
12from .matrix import sort_rows
13from .tictoc import TicToc
def interpolate3d(
    grid_points: List[np.ndarray],
    grid_values: np.ndarray,
    interp_locs: List[np.ndarray],
    method="linear",
    copy_nans=True,
) -> np.ndarray:
    """
    Interpolates input grid values at the given locations on a structured 3D grid.

    Matlab version of this function assumes unstructured grid. Interpolating such grid in Python using
    SciPy is very expensive, so a structured (rectilinear) grid is assumed instead.
    3D coordinate arrays are still accepted for backward compatibility; each one is reduced to the
    1D axis vector it varies along (x along axis 0, y along axis 1, z along axis 2).

    Args:
        grid_points: List of three 1D or 3D Numpy arrays (x, y, z coordinates of the grid).
        grid_values: A 3D Numpy array which holds values at grid_points.
        interp_locs: List of three 1D or 3D Numpy arrays (x, y, z coordinates of the queries).
            The queries are the 'ij' mesh spanned by the three axis vectors.
        method: Interpolation method forwarded to `scipy.interpolate.interpn`. Defaults to "linear".
        copy_nans: If True (default), query points outside the grid take the value of `grid_values`
            at the same index, which requires the query mesh to have the same shape as `grid_values`.
            If False, such points are left as NaN and the query mesh may have any shape.

    Returns:
        3D Numpy array of shape (len(q_x), len(q_y), len(q_z)) with the interpolated values.

    """
    assert len(grid_points) == 3, "interpolate3D supports only 3D interpolation"
    assert len(grid_points) == len(interp_locs)

    def unpack_and_make_1D(pts):
        pts_x, pts_y, pts_z = pts
        # a 3D coordinate array is assumed to vary only along its own axis
        if pts_x.ndim == 3:
            pts_x = pts_x[:, 0, 0]
        if pts_y.ndim == 3:
            pts_y = pts_y[0, :, 0]
        if pts_z.ndim == 3:
            pts_z = pts_z[0, 0, :]
        return pts_x, pts_y, pts_z

    g_x, g_y, g_z = unpack_and_make_1D(grid_points)
    q_x, q_y, q_z = unpack_and_make_1D(interp_locs)

    # 'ij' indexing is crucial for Matlab compatibility
    queries = np.array(np.meshgrid(q_x, q_y, q_z, indexing="ij"))
    # Queries are just a list of 3D points
    queries = queries.reshape(3, -1).T

    # Out of bound points will get NaN values
    result = interpn((g_x, g_y, g_z), grid_values, queries, method=method, bounds_error=False, fill_value=np.nan)
    # Go back from list of interpolated values to 3D volume
    result = result.reshape((q_x.size, q_y.size, q_z.size))

    if copy_nans:
        # the index-wise copy below only makes sense when both volumes line up
        assert result.shape == grid_values.shape, "copy_nans=True requires the query mesh to match grid_values in shape"
        # set values outside of the interpolation range to original values
        nan_mask = np.isnan(result)
        result[nan_mask] = grid_values[nan_mask]
    return result
def interpolate2d(
    grid_points: List[np.ndarray], grid_values: np.ndarray, interp_locs: List[np.ndarray], method="linear", copy_nans=True
) -> np.ndarray:
    """
    Interpolates input grid values at the given locations on a structured 2D grid.

    Matlab version of this function assumes unstructured grid. Interpolating such grid in Python using
    SciPy is very expensive, so a structured (rectilinear) grid is assumed instead.
    2D coordinate arrays are still accepted for backward compatibility; each one is reduced to the
    1D axis vector it varies along (x along axis 0, y along axis 1).

    Args:
        grid_points: List of two 1D or 2D Numpy arrays (x and y coordinates of the grid).
        grid_values: A 2D Numpy array which holds values at grid_points.
        interp_locs: List of two 1D or 2D Numpy arrays (x and y coordinates of the queries).
            The queries are the 'ij' mesh spanned by the two axis vectors.
        method: Interpolation method forwarded to `scipy.interpolate.interpn`.
        copy_nans: If True, query points outside the grid take the value of `grid_values`
            at the same index; the query mesh must then have the same shape as `grid_values`.

    Returns:
        2D Numpy array of shape (len(q_x), len(q_y)) with the interpolated values.

    """
    assert len(grid_points) == 2, "interpolate2D supports only 2D interpolation"
    assert len(grid_points) == len(interp_locs)

    def as_axis_vectors(coords):
        along_x, along_y = coords
        # a 2D coordinate array is assumed to vary only along its own axis
        along_x = along_x[:, 0] if along_x.ndim == 2 else along_x
        along_y = along_y[0, :] if along_y.ndim == 2 else along_y
        return along_x, along_y

    axis_x, axis_y = as_axis_vectors(grid_points)
    query_x, query_y = as_axis_vectors(interp_locs)

    # 'ij' indexing is crucial for Matlab compatibility
    mesh_x, mesh_y = np.meshgrid(query_x, query_y, indexing="ij")
    # one (x, y) row per query point, in C order of the mesh
    points = np.stack([mesh_x.ravel(), mesh_y.ravel()], axis=1)

    # Out of bound points will get NaN values
    flat = interpn((axis_x, axis_y), grid_values, points, method=method, bounds_error=False, fill_value=np.nan)
    # Go back from list of interpolated values to the 2D mesh
    result = flat.reshape((query_x.size, query_y.size))

    if not copy_nans:
        return result

    assert result.shape == grid_values.shape
    # set values outside of the interpolation range to original values
    outside = np.isnan(result)
    result[outside] = grid_values[outside]
    return result
def interpolate2d_with_queries(
    grid_points: List[np.ndarray], grid_values: np.ndarray, queries: np.ndarray, method="linear", copy_nans=True
) -> np.ndarray:
    """
    Interpolates input grid values at an explicit list of 2D query points.

    Simplified version of `interpolate2d`: the query locations are given directly as [N, 2]
    coordinates and no meshgrid is built from them.
    WARNING: supposed to support only 2D interpolation!

    Args:
        grid_points: List of two 1D Numpy arrays (x and y coordinates of the grid).
        grid_values: Numpy array which holds values at grid_points (first two axes are x and y).
        queries: Numpy array with shape [N, 2], one (x, y) row per query point.
        method: Interpolation method forwarded to `scipy.interpolate.interpn`.
        copy_nans: If True, query points outside the grid take the original grid value at the
            same flat (C-order) index. This requires exactly one query per grid node, i.e.
            N == grid_values.shape[0] * grid_values.shape[1], ordered like the 'ij' mesh
            flattened in C order. If False, such points are left as NaN.

    Returns:
        Numpy array with one interpolated value per query row (shape [N] for 2D grid_values).

    """
    assert len(grid_points) == 2, "interpolate2D supports only 2D interpolation"

    g_x, g_y = grid_points

    assert g_x.ndim == 1  # is a list
    assert g_y.ndim == 1  # is a list
    assert queries.ndim == 2 and queries.shape[1] == 2

    # Out of bound points will get NaN values
    result = interpn((g_x, g_y), grid_values, queries, method=method, bounds_error=False, fill_value=np.nan)
    if copy_nans:
        # interpn returns one entry per query row, so the grid has to be flattened over its
        # two spatial axes before it can be compared or indexed with the same mask
        originals = np.reshape(grid_values, (-1,) + tuple(np.shape(grid_values)[2:]))
        assert result.shape == originals.shape, "copy_nans=True requires exactly one query per grid node"
        # set values outside the interpolation range to original values
        nan_mask = np.isnan(result)
        result[nan_mask] = originals[nan_mask]
    return result
def get_bli(
    func: np.ndarray,
    dx: Optional[float] = 1,
    up_sampling_factor: Optional[int] = 20,
    plot: Optional[bool] = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Calculates the band-limited interpolant of a 1D input function.

    The interpolant is the Fourier series built from the FFT of `func`, evaluated on a grid
    that is `up_sampling_factor` times finer than the input sampling.

    Args:
        func: The 1D input function (singleton dimensions are squeezed away).
        dx: Spatial sampling in meters. Defaults to 1.
        up_sampling_factor: Up-sampling factor used to sample the underlying BLI. Defaults to 20.
        plot: Whether to plot the BLI. Defaults to False. Plotting is not implemented.

    Returns:
        A tuple containing the BLI and the x-grid for the BLI.

    Raises:
        NotImplementedError: If `plot` is truthy.

    """
    # atleast_1d keeps a single-sample input (squeezed to 0-d) usable
    func = np.atleast_1d(np.squeeze(func))
    assert len(func.shape) == 1, f"func not 1D but rather {len(func.shape)}D"
    if plot:
        raise NotImplementedError

    nx = len(func)

    dk = 2 * np.pi / (dx * nx)
    if nx % 2:
        # odd
        k_min = -np.pi / dx + dk / 2
    else:
        # even
        k_min = -np.pi / dx

    # Build both grids from integer index ranges: np.arange with a float step can return
    # one element too many or too few, which would break the product below.
    k = k_min + dk * np.arange(nx)
    x_fine = (dx / up_sampling_factor) * np.arange((nx - 1) * up_sampling_factor + 1)

    # centred Fourier coefficients, ordered to match k
    func_k = fftshift(fft(func)) / nx

    # evaluate the Fourier series at every fine-grid position
    bli = np.real(func_k @ np.exp(1j * np.outer(k, x_fine)))
    return bli, x_fine
def interp_cart_data(kgrid, cart_sensor_data, cart_sensor_mask, binary_sensor_mask, interp="nearest"):
    """
    Takes a matrix of time-series data recorded over a set
    of Cartesian sensor points given by cart_sensor_mask and computes the
    equivalent time-series at each sensor position on the binary sensor
    mask binary_sensor_mask using interpolation. The properties of
    binary_sensor_mask are defined by the k-Wave grid object kgrid.
    Two and three-dimensional data are supported.

    Usage:
        binary_sensor_data = interp_cart_data(kgrid, cart_sensor_data, cart_sensor_mask, binary_sensor_mask)
        binary_sensor_data = interp_cart_data(kgrid, cart_sensor_data, cart_sensor_mask, binary_sensor_mask, interp)

    Args:
        kgrid: k-Wave grid object returned by kWaveGrid
        cart_sensor_data: original sensor data measured over
                          cart_sensor_mask indexed as
                          cart_sensor_data(sensor position, time)
        cart_sensor_mask: Cartesian sensor mask over which
                          cart_sensor_data is measured, indexed as
                          cart_sensor_mask(coordinate, sensor position)
        binary_sensor_mask: binary sensor mask at which equivalent
                            time-series are computed via interpolation

        interp: (optional) interpolation mode used to compute the
                time-series, both 'nearest' and 'linear'
                (two-point) modes are supported
                (default = 'nearest')

    Returns:
        array of time-series corresponding to the sensor positions given by binary_sensor_mask

    Raises:
        ValueError: if kgrid is not two- or three-dimensional, if interp is not
                    'nearest' or 'linear', or if 'linear' is requested with fewer
                    than two Cartesian sensor points.

    """

    # make timer
    timer = TicToc()
    # start the clock
    timer.tic()

    # extract the number of data points
    num_cart_data_points, num_time_points = cart_sensor_data.shape
    # plain int so it can be used as an array dimension and loop bound
    num_binary_sensor_points = int(np.sum(binary_sensor_mask))

    # update command line status
    logging.log(logging.INFO, "Interpolating Cartesian sensor data...")
    logging.log(logging.INFO, f"  interpolation mode: {interp}")
    logging.log(logging.INFO, f"  number of Cartesian sensor points: {num_cart_data_points}")
    logging.log(logging.INFO, f"  number of binary sensor points: {num_binary_sensor_points}")

    binary_sensor_data = np.zeros((num_binary_sensor_points, num_time_points))

    # Check dimensionality of data passed
    if kgrid.dim not in [2, 3]:
        raise ValueError("Data must be two- or three-dimensional.")

    # validate the mode once, before any work, instead of on every loop iteration
    if interp not in ("nearest", "linear"):
        raise ValueError("Unknown interpolation option.")
    if interp == "linear" and num_cart_data_points < 2:
        raise ValueError("Linear interpolation requires at least two Cartesian sensor points.")

    cart_bsm, _ = grid2cart(kgrid, binary_sensor_mask)

    # one row per Cartesian sensor point, restricted to the grid's dimensions
    cart_points = cart_sensor_mask[: kgrid.dim, :].T

    for point_index in range(num_binary_sensor_points):
        # distance from this binary sensor point to every measured data point
        dist = np.linalg.norm(cart_bsm[:, point_index] - cart_points, ord=2, axis=1)

        if interp == "nearest":
            # assign the time-series of the closest measured point
            binary_sensor_data[point_index, :] = cart_sensor_data[np.argmin(dist), :]
        else:
            # indices of the two closest measured points (stable, so ties keep input order)
            closest, second_closest = np.argsort(dist, kind="stable")[:2]
            dist_sum = dist[closest] + dist[second_closest]

            # linearly interpolate between the two closest points; the closer point gets the
            # larger weight. If both coincide with the target, use the first one as-is.
            perc = dist[second_closest] / dist_sum if dist_sum > 0 else 1.0
            binary_sensor_data[point_index, :] = (
                perc * cart_sensor_data[closest, :] + (1 - perc) * cart_sensor_data[second_closest, :]
            )

    # update command line status
    logging.log(logging.INFO, f"  computation completed in {scale_time(timer.toc())}")
    return binary_sensor_data
def interpftn(x, sz: tuple, win=None):
    """
    Resamples an N-D matrix to the size given in sz using Fourier interpolation.

    Each axis whose requested length differs from the current one is resampled in turn
    with `scipy.signal.resample`; axes that already have the requested length are skipped.

    Args:
        x: matrix to interpolate
        sz: list or tuple of new size, one entry per leading axis of x
        win: (optional) name of windowing function to use, passed to `resample` as `window`

    Returns:
        Resampled matrix

    Raises:
        ValueError: if len(sz) differs from the number of non-singleton dimensions of x

    Examples:
        >>> y = interpftn(x, sz)
        >>> y = interpftn(x, sz, win)

    """
    # size of the input matrix
    original_shape = x.shape

    # one target length is expected for every non-singleton dimension
    non_singleton_dims = sum(1 for extent in original_shape if extent != 1)
    if non_singleton_dims != len(sz):
        raise ValueError("The number of scaling coefficients must equal the number of dimensions in x.")

    # resample one axis at a time, skipping axes that are already the right length
    y = x
    for axis, target_len in enumerate(sz):
        if target_len == original_shape[axis]:
            continue
        y = resample(y, target_len, axis=axis, window=win)

    return y
@typechecker
def get_delta_bli(Nx: int, dx: float, x: np.ndarray, x0: Union[int, float], include_imag: bool = False) -> np.ndarray:
    """
    Exact BLI of an arbitrarily positioned delta function.

    Calculates the exact Band-Limited Interpolation (BLI) of an arbitrarily positioned delta function.
    For grid dimensions with an evenly-sampled periodicity, a small Nyquist frequency sinusoid is added.
    This sinusoid is invisible on grid samples and has zero amplitude when the delta function lies on a grid node.
    It is important when the evaluation points aren't grid nodes, and when the delta function is off-grid.
    It serves to ensure conjugate symmetry in the BLI's Fourier transform.

    Args:
        Nx: Number of grid points in the relevant Cartesian direction.
        dx: Grid point spacing [m].
        x: Coordinates at which the BLI is evaluated [m].
        x0: Coordinate at which the BLI is centered [m].
        include_imag: Whether to include the imaginary component of the off-grid delta function.
                      Only has an effect when Nx is even. Defaults to False.

    Returns:
        f: Value of the BLI at the specified coordinates.

    """

    # check whether the grid has even or odd samples per period
    is_even = Nx % 2 == 0

    offset = x - x0
    # points that sit exactly on the delta; the periodic sinc is 0/0 there
    on_delta = offset == 0

    # compute periodic sinc function; the 0/0 at x == x0 is expected and fixed right after,
    # so keep NumPy from emitting a RuntimeWarning for it
    with np.errstate(divide="ignore", invalid="ignore"):
        if is_even:
            f = np.sin(np.pi * offset / dx) / (Nx * np.tan(np.pi * offset / (Nx * dx)))
        else:
            f = np.sin(np.pi * offset / dx) / (Nx * np.sin(np.pi * offset / (Nx * dx)))

    # correct indeterminate points
    f[on_delta] = 1

    if is_even:
        # add Nyquist sinusoid to ensure conjugate symmetry
        f = f - np.sin(np.pi * x0 / dx) / Nx * np.sin(np.pi * x / dx)
        if include_imag:
            f = f + 1j * np.cos(np.pi * x0 / dx) / Nx * np.sin(np.pi * x / dx)

    return f