Coverage for kwave/utils/matrix.py: 12%

138 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-24 12:06 -0700

1import logging 

2 

3import numpy as np 

4from beartype import beartype as typechecker 

5from beartype.typing import List, Optional, Tuple, Union 

6from jaxtyping import Bool, Int, Num, Real, Shaped 

7from scipy.interpolate import interp1d, interpn 

8 

9import kwave.utils.typing as kt 

10 

11from .data import scale_time 

12from .tictoc import TicToc 

13 

14 

@typechecker
def trim_zeros(data: Num[np.ndarray, "..."]) -> Tuple[Num[np.ndarray, "..."], List[Tuple[Int[kt.ScalarLike, ""], Int[kt.ScalarLike, ""]]]]:
    """
    Create a tight bounding box by removing zeros.

    The input is squeezed first. Then, along every remaining dimension, the
    leading and trailing slices that contain only zeros are removed. An entry
    counts as non-zero when its magnitude is non-zero, so negative and complex
    values are kept.

    Args:
        data: Matrix to trim.

    Returns:
        Tuple containing the trimmed matrix and, for each dimension of the
        squeezed input, a ``(first, last)`` pair such that the kept range along
        that dimension is ``first:last`` (end-exclusive).

    Raises:
        ValueError: If the squeezed input has more than 3 dimensions.
        IndexError: If the input contains no non-zero values.

    Example:
        data = np.array([[0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 3, 0, 0],
                         [0, 0, 1, 3, 4, 0],
                         [0, 0, 1, 3, 4, 0],
                         [0, 0, 1, 3, 0, 0],
                         [0, 0, 0, 0, 0, 0]])

        trimmed_data, indices = trim_zeros(data)

        # Output:
        # trimmed_data:
        # [[0 3 0]
        #  [1 3 4]
        #  [1 3 4]
        #  [1 3 0]]
        #
        # indices:
        # [(1, 5), (2, 5)]

    """
    data = np.squeeze(data)

    # only allow 1D, 2D, and 3D
    if data.ndim > 3:
        raise ValueError("Input data must be 1D, 2D, or 3D.")

    ind = []

    for dim_index in range(data.ndim):
        # Collapse every other dimension so only dim_index remains. Magnitudes
        # are used so negative/complex entries are not mistaken for zeros.
        magnitude = np.abs(data)
        if data.ndim == 1:
            summed_values = magnitude
        else:
            other_axes = tuple(axis for axis in range(data.ndim) if axis != dim_index)
            summed_values = np.sum(magnitude, axis=other_axes)

        # find the first and last non-empty values (IndexError if all zero)
        non_zeros = np.where(summed_values > 0)[0]
        ind_first = non_zeros[0]
        ind_last = non_zeros[-1] + 1

        # trim only along dim_index, leaving the other dimensions untouched
        slicer = [slice(None)] * data.ndim
        slicer[dim_index] = slice(ind_first, ind_last)
        data = data[tuple(slicer)]
        ind.append((ind_first, ind_last))

    return data, ind

91 

92 

@typechecker
def expand_matrix(
    matrix: Union[Num[np.ndarray, "..."], Bool[np.ndarray, "..."]],
    exp_coeff: Union[Shaped[kt.ArrayLike, "dim"], List],
    edge_val: Optional[Real[kt.ScalarLike, ""]] = None,
):
    """
    Enlarge a matrix by extending the edge values.

    expand_matrix enlarges an input matrix (after squeezing singleton
    dimensions) by extending the values at its outer faces: endpoints in 1D,
    outer edges in 2D, outer surfaces in 3D. If edge_val is given instead, all
    added elements take that value. The values in exp_coeff are cast to
    integers.

    Note, indexing is done inline with other k-Wave functions using
    mat(x) in 1D, mat(x, y) in 2D, and mat(x, y, z) in 3D.

    Args:
        matrix: the matrix to enlarge
        exp_coeff: the number of elements to add in each dimension
            in 1D: [a] or [x_start, x_end]
            in 2D: [a] or [x, y] or
                   [x_start, x_end, y_start, y_end]
            in 3D: [a] or [x, y, z] or
                   [x_start, x_end, y_start, y_end, z_start, z_end]
            (here 'a' is applied to all dimensions; the same per-dimension
            and per-side patterns also apply to higher-dimensional input)
        edge_val: value to use in the matrix expansion

    Returns:
        expanded matrix

    Raises:
        ValueError: If the number of coefficients matches none of the
            supported layouts for the matrix dimensionality.

    """

    opts = {}
    matrix = np.squeeze(matrix)

    if edge_val is None:
        opts["mode"] = "edge"
    else:
        opts["mode"] = "constant"
        opts["constant_values"] = edge_val

    exp_coeff = np.array(exp_coeff).astype(int).squeeze()
    n_coeff = exp_coeff.size
    assert n_coeff > 0
    n_dims = matrix.ndim

    if n_coeff == 1:
        # a single value pads both sides of every dimension
        opts["pad_width"] = exp_coeff
    elif n_dims == 1:
        assert n_coeff <= 2
        opts["pad_width"] = exp_coeff
    elif n_coeff == n_dims:
        # one value per dimension, applied symmetrically to both ends
        opts["pad_width"] = [(int(c), int(c)) for c in exp_coeff]
    elif n_coeff == 2 * n_dims:
        # explicit (start, end) pair for every dimension
        opts["pad_width"] = [(int(exp_coeff[2 * d]), int(exp_coeff[2 * d + 1])) for d in range(n_dims)]
    else:
        raise ValueError(
            f"exp_coeff must have 1, {n_dims} or {2 * n_dims} elements for a {n_dims}D matrix, got {n_coeff}."
        )

    return np.pad(matrix, **opts)

157 

158 

def resize(mat: np.ndarray, new_size: Union[int, List[int]], interp_mode: str = "linear") -> np.ndarray:
    """
    Resizes a matrix of spatial samples to a desired resolution or spatial sampling frequency
    via interpolation.

    Each axis of both the input and output grid is mapped onto [0, 1], so the
    first and last samples of every dimension stay aligned.

    Parameters:
        mat: Matrix to be resized (i.e., resampled). Singleton dimensions are squeezed out.
        new_size: Desired output resolution, one entry per non-singleton dimension
            of ``mat``. A bare integer is accepted for 1D data.
        interp_mode: Interpolation method, passed to ``scipy.interpolate.interpn`` as ``method``.

    Returns:
        Resized matrix with shape ``tuple(new_size)``.
    """
    # start the timer
    TicToc.tic()

    # update command line status
    logging.log(logging.INFO, "Resizing matrix...")

    # accept a bare integer for 1D data, as allowed by the type hint
    if isinstance(new_size, (int, np.integer)):
        new_size = [new_size]

    # check inputs
    assert num_dim2(mat) == len(new_size), "Resolution input must have the same number of elements as data dimensions."

    mat = mat.squeeze()

    # normalised sample positions of the input grid and the requested output grid
    points = tuple(np.linspace(0, 1, n) for n in mat.shape)
    new_axes = [np.linspace(0, 1, n) for n in new_size]

    # 'ij' indexing keeps the query grid in matrix axis order, so the flat
    # interpolation result can be reshaped straight to new_size (C order) for
    # any number of dimensions, without undoing the 'xy' axis swap
    grids = np.meshgrid(*new_axes, indexing="ij")
    new_points = np.array([g.flatten() for g in grids]).T
    mat_rs = interpn(points, mat, new_points, method=interp_mode).reshape(new_size)

    # update command line status
    logging.log(logging.INFO, f" completed in {scale_time(TicToc.toc())}")
    assert mat_rs.shape == tuple(new_size), "Resized matrix does not match requested size."
    return mat_rs

207 

208 

def gradient_fd(f, dx=None, dim=None, deriv_order=None, accuracy_order=None) -> List[np.ndarray]:
    """
    Calculate the gradient of an n-dimensional input matrix using finite differences.

    This is a thin wrapper around ``numpy.gradient``. Interior points use
    central differences and boundary points use one-sided differences (numpy's
    default ``edge_order=1``).

    Args:
        f: Input matrix.
        dx: Grid spacing. It may be:
            - a scalar, used for every dimension;
            - a sequence with one entry per dimension of ``f``, where each entry
              is a scalar spacing or a coordinate array for that axis;
            - a coordinate array for 1D ``f``;
            - when ``dim`` is given, the spacing along ``dim``.
        dim: Optional axis (or axes) over which to compute the gradient.
            By default, all axes are used.
        deriv_order: Deprecated and ignored; a warning is logged when it is set.
        accuracy_order: Deprecated and ignored; a warning is logged when it is set.

    Returns:
        Whatever ``numpy.gradient`` returns: one ndarray per differentiated
        axis, or a single ndarray when only one axis is differentiated. Each
        has the same shape as ``f``.

    """

    if deriv_order:
        logging.log(logging.WARN, f"{DeprecationWarning.__name__}: deriv_order is no longer a supported argument.")
    if accuracy_order:
        logging.log(logging.WARN, f"{DeprecationWarning.__name__}: accuracy_order is no longer a supported argument.")

    if dim is not None:
        if dx is not None:
            return np.gradient(f, dx, axis=dim)
        return np.gradient(f, axis=dim)

    if dx is None:
        return np.gradient(f)

    # np.gradient takes one spacing argument per axis (*varargs). A
    # per-dimension sequence must therefore be unpacked. Passing it as a single
    # argument raises for multi-dimensional f.
    is_sequence = isinstance(dx, (list, tuple)) or (isinstance(dx, np.ndarray) and dx.ndim > 0)
    if is_sequence and len(dx) == np.ndim(f):
        return np.gradient(f, *dx)
    return np.gradient(f, dx)

252 

253 

def min_nd(matrix: np.ndarray) -> Tuple[float, Tuple]:
    """
    Locate the smallest element of an n-dimensional numpy array.

    Args:
        matrix: A numpy array of any value type.

    Returns:
        A ``(value, index)`` pair. ``value`` is the minimum of the array.
        ``index`` is a tuple with one entry per dimension (row, column, ...),
        counted from 1 as in MATLAB. On ties, the first occurrence in C order
        is reported.

    Examples:
        >>> value, index = min_nd(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
        >>> int(value), tuple(int(i) for i in index)
        (1, (1, 1))

    """

    smallest = np.min(matrix)
    flat_position = matrix.argmin()

    # translate the flat position back into per-axis, zero-based coordinates
    zero_based = np.unravel_index(flat_position, matrix.shape)

    one_based = tuple(axis_index + 1 for axis_index in zero_based)
    return smallest, one_based

276 

277 

def max_nd(matrix: np.ndarray) -> Tuple[float, Tuple]:
    """
    Locate the largest element of an n-dimensional numpy array.

    Args:
        matrix: n-dimensional array of values.

    Returns:
        A ``(value, index)`` pair. ``value`` is the maximum of the array.
        ``index`` holds its position with one entry per dimension, counted
        from 1 as in MATLAB. On ties, the first occurrence in C order is
        reported.

    """

    largest = np.max(matrix)
    flat_position = matrix.argmax()

    # per-axis, zero-based coordinates of the flat position
    zero_based = np.unravel_index(flat_position, matrix.shape)

    # MATLAB-style coordinates start at 1
    one_based = tuple(axis_index + 1 for axis_index in zero_based)
    return largest, one_based

302 

303 

def broadcast_axis(data: np.ndarray, ndims: int, axis: int) -> np.ndarray:
    """
    Reshape ``data`` so that all its elements lie along one axis of an ``ndims``-dimensional array.

    Every other axis of the result has length 1, so the output broadcasts
    against arrays of rank ``ndims`` along the chosen axis.

    Args:
        data: The data to broadcast.
        ndims: The number of dimensions of the result.
        axis: The axis that receives the data (negative values count from the end).

    Returns:
        ``data`` reshaped to that layout.

    """

    target_shape = [1 for _ in range(ndims)]
    # numpy infers this length from data.size
    target_shape[axis] = -1
    return data.reshape(tuple(target_shape))

321 

322 

def revolve2d(mat2d: np.ndarray) -> np.ndarray:
    """
    Revolve a 2D matrix about its first column to form a 3D matrix.

    Row ``x`` of ``mat2d`` is treated as a radial profile sampled at radii
    ``0, 1, ..., n-1``. For each ``x``, the profile is linearly interpolated
    onto the distance from the centre of a ``(2n-1, 2n-1)`` cross-section.
    Points farther than ``n-1`` from the centre are set to 0.

    Args:
        mat2d: A 2D numpy array of shape ``(m, n)``.

    Returns:
        A float array of shape ``(m, 2n-1, 2n-1)``. The centre of each
        cross-section holds ``mat2d[x, 0]``.

    Examples:
        >>> mat3d = revolve2d(np.array([[1, 2, 3], [4, 5, 6]]))
        >>> mat3d.shape
        (2, 5, 5)
        >>> mat3d[0, 2, :]
        array([3., 2., 1., 2., 3.])

    """

    # Start timer
    TicToc.tic()

    # Update command line status
    logging.log(logging.INFO, "Revolving 2D matrix to form a 3D matrix...")

    # Get size of matrix
    m, n = mat2d.shape

    # Create the reference axis for the 2D image
    r_axis_one_sided = np.arange(0, n)
    r_axis_two_sided = np.arange(-(n - 1), n)

    # Compute the distance from every pixel in the z-y cross-section of the 3D
    # matrix to the rotation axis
    z, y = np.meshgrid(r_axis_two_sided, r_axis_two_sided)
    r = np.sqrt(y**2 + z**2)

    # Create empty image matrix
    mat3D = np.zeros((m, 2 * n - 1, 2 * n - 1))

    # Loop through each cross-section and create the 3D matrix. np.interp with
    # left/right=0 matches linear interpolation with zero fill outside the
    # sampled radii, and it also works when n == 1.
    for x_index in range(m):
        mat3D[x_index, :, :] = np.interp(r, r_axis_one_sided, mat2d[x_index, :], left=0, right=0)

    # Update command line status
    # NOTE(review): resize() logs without the trailing "s"; confirm whether scale_time already includes units
    logging.log(logging.INFO, f" completed in {scale_time(TicToc.toc())}s")
    return mat3D

377 

378 

def sort_rows(arr: np.ndarray, index: int) -> np.ndarray:
    """
    Sort the rows of a 2D numpy array by the values in a specific column.

    The sort is stable, as in MATLAB's ``sortrows``: rows with equal keys keep
    their original relative order.

    Args:
        arr: A 2D numpy array.
        index: The index of the column to sort by.

    Returns:
        A copy of the input array with the rows sorted by the values in the specified column.

    Raises:
        AssertionError: If `arr` is not a 2D numpy array.

    Examples:
        >>> arr = np.array([[3, 2, 1], [1, 3, 2], [2, 1, 3]])
        >>> sort_rows(arr, 0)
        array([[1, 3, 2],
               [2, 1, 3],
               [3, 2, 1]])

    """

    assert arr.ndim == 2, "'sort_rows' currently supports only 2-dimensional matrices"
    # the default quicksort is not stable, so tied rows could be reordered
    return arr[arr[:, index].argsort(kind="stable")]

404 

405 

def num_dim(x: np.ndarray) -> int:
    """
    Count the dimensions of ``x`` that remain once singleton axes are squeezed out.

    Args:
        x: The input array.

    Returns:
        The rank of ``x.squeeze()``.

    """

    squeezed = x.squeeze()
    return squeezed.ndim

419 

420 

def num_dim2(x: np.ndarray) -> int:
    """
    Get the number of dimensions of an array after collapsing singleton dimensions.

    When the squeezed array has more than two dimensions, its rank is returned.
    Otherwise, only dimensions longer than 1 are counted, so zero-length
    dimensions are not included.

    Args:
        x: The input array.

    Returns:
        The number of dimensions of the array after collapsing singleton
        dimensions, as a plain Python ``int``.

    """

    sz = np.squeeze(x).shape

    if len(sz) > 2:
        return len(sz)
    # builtin sum keeps the return type a plain int, matching the annotation
    return sum(1 for size in sz if size > 1)
438 return np.sum(np.array(sz) > 1)