Coverage for kwave/utils/math.py: 23%

139 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-24 12:06 -0700

1import math 

2import warnings 

3from functools import wraps 

4from itertools import compress 

5from typing import List, Optional, Tuple, Union 

6 

7import numpy as np 

8from deprecated import deprecated 

9from scipy import ndimage 

10from scipy.spatial.transform import Rotation 

11 

12from kwave import __version__ 

13from kwave.data import Vector 

14 

15 

@deprecated(
    version="0.4.1",
    reason="Use scipy.spatial.transform.Rotation.from_euler('x', angle, degrees=True).as_matrix() instead",
)
def Rx(theta: float) -> np.ndarray:
    """Return the matrix for a rotation of `theta` degrees about the x-axis.

    Deprecated: build the matrix with scipy's ``Rotation.from_euler`` directly.

    Args:
        theta: Rotation angle in degrees.

    Returns:
        3x3 rotation matrix.
    """
    rotation = Rotation.from_euler("x", theta, degrees=True)
    return rotation.as_matrix()

30 

31 

@deprecated(
    version="0.4.1",
    reason="Use scipy.spatial.transform.Rotation.from_euler('y', angle, degrees=True).as_matrix() instead",
)
def Ry(theta: float) -> np.ndarray:
    """Return the matrix for a rotation of `theta` degrees about the y-axis.

    Deprecated: build the matrix with scipy's ``Rotation.from_euler`` directly.

    Args:
        theta: Rotation angle in degrees.

    Returns:
        3x3 rotation matrix.
    """
    rotation = Rotation.from_euler("y", theta, degrees=True)
    return rotation.as_matrix()

46 

47 

@deprecated(
    version="0.4.1",
    reason="Use scipy.spatial.transform.Rotation.from_euler('z', angle, degrees=True).as_matrix() instead",
)
def Rz(theta: float) -> np.ndarray:
    """Return the matrix for a rotation of `theta` degrees about the z-axis.

    Deprecated: build the matrix with scipy's ``Rotation.from_euler`` directly.

    Args:
        theta: Rotation angle in degrees.

    Returns:
        3x3 rotation matrix.
    """
    rotation = Rotation.from_euler("z", theta, degrees=True)
    return rotation.as_matrix()

62 

63 

@deprecated(
    version="0.4.1",
    reason="Use make_affine() instead. It provides the same functionality with a clearer name and better documentation.",
)
def get_affine_matrix(translation: Vector, rotation: Union[float, List[float]], seq: str = "xyz") -> np.ndarray:
    """Deprecated alias of ``make_affine``.

    All arguments are forwarded unchanged; see ``make_affine`` for their meaning.
    """
    return make_affine(translation=translation, rotation=rotation, seq=seq)

70 

71 

def make_affine(translation: Vector, rotation: Union[float, List[float]], seq: str = "xyz") -> np.ndarray:
    """
    Create an affine transformation matrix combining rotation and translation.
    Uses scipy.spatial.transform.Rotation internally.

    Args:
        translation: [dx, dy] for a 2D transform or [dx, dy, dz] for a 3D transform.
        rotation: Rotation angle(s) in degrees. For 2D, a single angle about the
            z-axis, given as a scalar or a one-element sequence. For 3D, one angle
            per axis named in `seq`.
        seq: Euler rotation sequence for 3D (default: 'xyz'). Ignored in 2D.

    Returns:
        3x3 (2D) or 4x4 (3D) homogeneous affine transformation matrix.

    Raises:
        ValueError: If `translation` does not have 2 or 3 components, or if a 2D
            rotation is given as more than one angle.

    Examples:
        # 2D transform (rotation around z-axis)
        T1 = make_affine([1, 2], 45)

        # 3D transform with xyz Euler angles
        T2 = make_affine([1, 2, 3], [45, 30, 60])

        # 3D transform with custom sequence
        T3 = make_affine([1, 2, 3], [45, 30], 'xy')
    """
    n_dims = len(translation)

    if n_dims == 2:
        # 2D transformation: a single rotation about the z-axis.
        angle = np.asarray(rotation, dtype=float)
        if angle.size != 1:
            raise ValueError(f"A 2D transform takes a single rotation angle, got {angle.size} values.")
        # Pass a scalar so scipy builds one rotation; a one-element list such as [45]
        # would otherwise produce a stack of shape (1, 3, 3) that cannot be sliced to 2x2.
        R = Rotation.from_euler("z", angle.item(), degrees=True).as_matrix()[:2, :2]
        T = np.eye(3)
        T[:2, :2] = R
        T[:2, 2] = translation
        return T

    if n_dims != 3:
        raise ValueError(f"translation must have 2 or 3 components, got {n_dims}.")

    # 3D transformation: Euler angles applied in the order given by `seq`.
    R = Rotation.from_euler(seq, rotation, degrees=True)
    T = np.eye(4)
    T[:3, :3] = R.as_matrix()
    T[:3, 3] = translation
    return T

109 

110 

def cosd(angle_in_degrees):
    """Return the cosine of an angle given in degrees (scalar or array)."""
    angle_in_radians = np.radians(angle_in_degrees)
    return np.cos(angle_in_radians)

114 

115 

def sind(angle_in_degrees):
    """Return the sine of an angle given in degrees (scalar or array)."""
    angle_in_radians = np.radians(angle_in_degrees)
    return np.sin(angle_in_radians)

119 

120 

def largest_prime_factor(n: int) -> int:
    """
    Finds the largest prime factor of a positive integer.

    Uses trial division. Every factor found is divided out of n, so once the
    candidate divisor exceeds the square root of what remains, the remainder
    is the largest prime factor. Inputs below 4 are returned unchanged, so
    ``largest_prime_factor(1) == 1``.

    Args:
        n: The positive integer to be factored.

    Returns:
        The largest prime factor of n.
    """
    candidate = 2
    while candidate * candidate <= n:
        if n % candidate == 0:
            # Divide out this factor and retry the same candidate (it may repeat).
            n //= candidate
            continue
        candidate += 1
    return n

139 

140 

def rwh_primes(n: int) -> List[int]:
    """
    Generates a list of prime numbers less than a given integer.

    Uses an odd-only Sieve of Eratosthenes: ``sieve[i]`` stands for the odd
    number ``2 * i + 1``, and 2 is prepended to the result separately.

    Args:
        n: The (exclusive) upper bound for the list of primes.

    Returns:
        A list of prime numbers less than n, in ascending order. Empty when n <= 2.
    """
    # No prime is strictly less than 2. Without this guard the function returned [2]
    # for n <= 2 and crashed for negative n (n**0.5 is complex).
    if n <= 2:
        return []

    sieve = bytearray([True]) * (n // 2 + 1)
    for i in range(1, int(n**0.5) // 2 + 1):
        if sieve[i]:
            # Clear the odd multiples of p = 2*i + 1, starting at p*p (whose index is 2*i*(i + 1)).
            sieve[2 * i * (i + 1) :: 2 * i + 1] = bytearray((n // 2 - 2 * i * (i + 1)) // (2 * i + 1) + 1)
    return [2, *compress(range(3, n, 2), sieve[1:])]

157 

158 

def primefactors(n: int) -> List[int]:
    """
    Finds the prime factors of a given integer.

    Factors are found by trial division and returned in ascending order. A
    repeated factor appears once per multiplicity, e.g. ``primefactors(12) == [2, 2, 3]``.
    An input of 1 has no prime factors and gives an empty list.

    Args:
        n: The positive integer to factor.

    Returns:
        A list of prime factors of n.

    Raises:
        ValueError: If n is less than 1. Zero previously looped forever; negative
            values previously failed inside math.sqrt.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}.")

    factors = []

    # Remove all factors of two first, so only odd trial divisors are needed afterwards.
    # Floor division keeps the factors exact integers (true division produced floats
    # and lost precision above 2**53).
    while n % 2 == 0:
        factors.append(2)
        n //= 2

    # n is now odd. Re-check the bound against the shrinking n so the loop stops
    # as soon as the remainder is known to be 1 or prime.
    divisor = 3
    while divisor * divisor <= n:
        while n % divisor == 0:
            factors.append(divisor)
            n //= divisor
        divisor += 2

    # Any remainder above 1 has no divisor up to its square root, so it is prime.
    if n > 1:
        factors.append(n)

    return factors

185 

186 

def next_pow2(n: int) -> float:
    """
    Calculate the exponent of the next power of 2 that is greater than or equal to `n`.

    The return value is the exponent ``p`` (not the power itself) such that
    ``2 ** p`` is the smallest power of 2 greater than or equal to `n`. This is
    the same convention as MATLAB's ``nextpow2``. The power itself is
    ``2 ** next_pow2(n)``.

    Args:
        n: A positive integer (Python int, NumPy integer scalar, or NumPy integer
            array, which is handled element-wise). Results are exact for
            ``1 <= n <= 2**62``. Inputs ``n <= 0`` yield ``-inf``, and NumPy emits a
            divide-by-zero RuntimeWarning.

    Returns:
        The exponent ``p`` as a float (the result of ``np.log2``), e.g.
        ``next_pow2(5) == 3.0`` and ``next_pow2(8) == 3.0``.
    """
    # Decrement `n` so that an exact power of 2 maps to itself rather than the next one.
    n = n - 1

    # Set every bit below the highest set bit. After this, n + 1 is a power of 2.
    # The `>> 32` step extends the smear to 64-bit values; without it, inputs
    # above 2**32 gave non-integer exponents.
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32

    # n + 1 is now the power of 2; return its exponent.
    return np.log2(n + 1)

213 

214 

def phase_shift_interpolate(data: np.ndarray, shift: float, shift_dim: Optional[int] = None) -> np.ndarray:
    """
    Interpolates array data using phase shifts in the Fourier domain.

    This function resamples the input data along the specified dimension using a
    regular grid that is offset by the non-dimensional distance shift.
    The resampling is performed using a Fourier interpolant.

    This can be used to shift the acoustic particle velocity recorded by the
    first-order simulation functions to the regular (non-staggered) temporal
    grid by setting shift to 1/2.

    Example:
        # Move velocity data from staggered to regular grid points
        v_regular = phase_shift_interpolate(v_staggered, shift=0.5)

    Args:
        data: The input array to be interpolated.
        shift: Non-dimensional shift amount, where:
            0 = no shift
            1/2 = shift for staggered grid
            1 = full grid point
            A positive shift samples the data at ``x + shift``.
        shift_dim: The dimension along which to apply the phase shift, counted
            from 1 (MATLAB-style), so 1 is the first axis. Valid values are
            1, 2, 3 and 4. Default is the highest non-singleton dimension
            (the first axis if every dimension has size 1).

    Returns:
        The interpolated, real-valued array, with the same shape as `data`.

    Raises:
        ValueError: If `shift_dim` is given and is not 1, 2, 3 or 4.

    Warns:
        UserWarning: If `shift_dim` exceeds the number of dimensions of `data`;
            the last axis is used instead.
    """
    if shift_dim is None:
        # Find the highest non-singleton axis. Checking every axis (not only the
        # 2D column-vector case) means trailing singleton axes, e.g. shape
        # (1, N, 1), no longer select a size-1 axis where the shift is a no-op.
        non_singleton = [axis for axis, size in enumerate(data.shape) if size > 1]
        shift_dim = non_singleton[-1] if non_singleton else 0
    else:
        # Convert the 1-based (MATLAB-style) dimension to a 0-based NumPy axis.
        shift_dim = shift_dim - 1
        if not (0 <= shift_dim <= 3):
            raise ValueError("Input dim must be 1, 2, 3 or 4.")
        elif shift_dim >= data.ndim:
            warnings.warn(f"Shift dimension {shift_dim} is greater than the number of dimensions in the input array {data.ndim}.")
            shift_dim = data.ndim - 1

    # Create shift array with zeros except for the shift dimension
    shifts = np.zeros(data.ndim)
    shifts[shift_dim] = shift

    # Take the FFT along the shift axis only. The other axes have a zero shift,
    # so fourier_shift's multiplier along them is 1 and they need no transform.
    fft_data = np.fft.fft(data, axis=shift_dim)

    # Apply the Fourier shift (scipy expects input in the Fourier domain).
    # scipy.ndimage.fourier_shift moves content forward by +shift, giving f(x - shift);
    # negating it samples f(x + shift) instead.
    shifted_data = ndimage.fourier_shift(fft_data, -shifts)

    # Return to spatial domain, ensuring real output
    return np.real(np.fft.ifft(shifted_data, axis=shift_dim))

271 

272 

@deprecated(
    version="0.4.1",
    reason="This function has been renamed to phase_shift_interpolate() to better reflect its functionality.",
)
def fourier_shift(data: np.ndarray, shift: float, shift_dim: Optional[int] = None) -> np.ndarray:
    """Deprecated alias of ``phase_shift_interpolate``.

    All arguments are forwarded unchanged; see ``phase_shift_interpolate`` for details.
    """
    return phase_shift_interpolate(data=data, shift=shift, shift_dim=shift_dim)

280 

281 

def round_even(x):
    """
    Rounds to the nearest even integer.

    Computed as twice the rounded half of `x`. Python's ``round`` (and NumPy's
    ``ndarray.__round__``) resolves exact ties to the even neighbour, so odd
    integers go to the adjacent multiple of 4: ``round_even(1) == 0`` and
    ``round_even(3) == 4``.
    NOTE(review): MATLAB's ``round`` breaks ties away from zero, so a MATLAB
    ``2*round(x/2)`` gives 2 for x = 1. Confirm which convention callers expect.

    Args:
        x: Input value (a number, or a NumPy array, which ``round`` handles element-wise).

    Returns:
        Nearest even integer.
    """
    half = x / 2
    return round(half) * 2

294 

295 

def round_odd(x):
    """
    Rounds to the nearest odd integer.

    Shifts `x` up by one, rounds to the nearest even integer (twice the rounded
    half), then shifts back down. Because Python's ``round`` breaks ties to
    even, even integers resolve to the neighbour that is 3 modulo 4:
    ``round_odd(2) == 3``, ``round_odd(4) == 3`` and ``round_odd(0) == -1``.

    Args:
        x: Input value.

    Returns:
        Nearest odd integer.
    """
    half_of_shifted = (x + 1) / 2
    return round(half_of_shifted) * 2 - 1

308 

309 

def find_closest(A: np.ndarray, a: Union[float, int]) -> Tuple[Union[float, int], Tuple[int, ...]]:
    """
    Returns the value and index of the item in A that is closest to the value a.

    The search runs over every element of A at once; it is not done column by
    column, so a 2-D or N-D input still gives a single element. That element
    minimises ``|A - a|``. If several elements are equally close, the first one
    in row-major (C) order is returned.

    Args:
        A: The array to search. Must be a NumPy array.
        a: The value to find.

    Returns:
        A tuple ``(value, index)``. ``value`` is the closest element of A, and
        ``index`` is its N-D position as a tuple of integer indices, so
        ``A[index] == value``.

    Raises:
        AssertionError: If A is not a NumPy array.
    """
    assert isinstance(A, np.ndarray), "A must be an np.array"

    # argmin returns a flat index; unravel it back into an N-D index tuple.
    idx = np.unravel_index(np.argmin(np.abs(A - a)), A.shape)
    return A[idx], idx

332 

333 

def sinc(x: Union[int, float, np.ndarray]) -> Union[int, float, np.ndarray]:
    """
    Calculates the unnormalised sinc function, sin(x) / x, with sinc(0) = 1.

    NumPy's ``np.sinc`` is the normalised form sin(pi*t) / (pi*t), so the
    argument is divided by pi before being passed to it.

    Args:
        x: The value or array of values for which to calculate the sinc function.

    Returns:
        The sinc function of x.
    """
    normalised_argument = x / np.pi
    return np.sinc(normalised_argument)

346 

347 

def gaussian(
    x: Union[int, float, np.ndarray],
    magnitude: Optional[Union[int, float]] = None,
    mean: Optional[float] = 0,
    variance: Optional[float] = 1,
) -> Union[int, float, np.ndarray]:
    """
    Returns a Gaussian distribution f(x) with the specified magnitude, mean, and variance. If these values are not specified,
    the magnitude is normalised and values of variance = 1 and mean = 0 are used. For example running:

        import matplotlib.pyplot as plt
        x = np.arange(-3, 3.05, 0.05)
        plt.plot(x, gaussian(x))

    will plot a normalised Gaussian distribution.

    Note, the full width at half maximum of the resulting distribution can be calculated by FWHM = 2 * sqrt(2 * log(2) * variance).

    Args:
        x: The input values.
        magnitude: Bell height. Defaults to normalised (1 / sqrt(2 * pi * variance)).
        mean: Mean or expected value. Defaults to 0; None is treated as 0.
        variance: Variance, or bell width. Defaults to 1; None is treated as 1.

    Returns:
        A Gaussian distribution evaluated at x.
    """
    # The Optional hints allow None, so treat an explicit None like the documented defaults
    # instead of failing on arithmetic with None.
    if mean is None:
        mean = 0
    if variance is None:
        variance = 1

    if magnitude is None:
        # Normalising constant, so the distribution integrates to 1.
        magnitude = (2 * math.pi * variance) ** -0.5

    gauss_distr = magnitude * np.exp(-((x - mean) ** 2) / (2 * variance))

    return gauss_distr

382 

383 

384def _compute_direction(start_pos: np.ndarray, end_pos: np.ndarray) -> Tuple[np.ndarray, float]: 

385 """Compute normalized direction vector and magnitude between two points.""" 

386 direction = end_pos - start_pos 

387 magnitude = np.linalg.norm(direction) 

388 direction = direction / magnitude 

389 return direction, magnitude 

390 

391 

392def _compute_rotation_axis(reference: np.ndarray, direction: np.ndarray) -> Tuple[np.ndarray, float]: 

393 """Compute normalized rotation axis and its magnitude.""" 

394 axis = np.cross(reference, direction) 

395 axis_norm = np.linalg.norm(axis) 

396 return axis, axis_norm 

397 

398 

399def _create_rotation_matrix(axis: np.ndarray, angle: float) -> np.ndarray: 

400 """Create rotation matrix using Rodrigues' formula.""" 

401 cos_theta = np.cos(angle) 

402 sin_theta = np.sin(angle) 

403 

404 # Skew-symmetric matrix of axis 

405 skew = np.array([[0, -axis[2], axis[1]], [axis[2], 0, -axis[0]], [-axis[1], axis[0], 0]]) 

406 

407 # Outer product 

408 outer = np.outer(axis, axis) 

409 

410 return cos_theta * np.eye(3) + sin_theta * skew + (1 - cos_theta) * outer 

411 

412 

def compute_rotation_between_vectors(start_pos: np.ndarray, end_pos: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the rotation taking the canonical axis [0, 0, -1] onto the start->end direction.

    Args:
        start_pos: Starting position vector (3D).
        end_pos: Ending position vector (3D).

    Returns:
        Tuple containing:
            - 3x3 matrix R with R @ [0, 0, -1] equal to the unit direction from
              start_pos to end_pos (the identity when the points coincide)
            - Normalized direction vector (all zeros when the points coincide)
    """
    direction, distance = _compute_direction(start_pos, end_pos)

    # Coincident points: no direction is defined, so fall back to "no rotation".
    if np.isclose(distance, 0):
        return np.eye(3), np.zeros(3)

    canonical = np.array([0.0, 0.0, -1.0])
    axis, axis_length = _compute_rotation_axis(canonical, direction)

    if axis_length > np.finfo(float).eps:
        # General case: rotate about the unit axis by the angle between the two vectors.
        unit_axis = axis / axis_length
        cos_angle = np.clip(np.dot(canonical, direction), -1.0, 1.0)
        return _create_rotation_matrix(unit_axis, np.arccos(cos_angle)), direction

    # Degenerate axis: direction is parallel or anti-parallel to the canonical axis.
    # NOTE(review): -I in the anti-parallel case is a point reflection (det = -1),
    # not a proper rotation. A 180-degree turn about x, diag(1, -1, -1), would also
    # map [0, 0, -1] to [0, 0, 1]. Presumably kept for parity with a reference
    # implementation; confirm before changing.
    same_sense = np.dot(canonical, direction) > 0
    return (np.eye(3) if same_sense else -np.eye(3)), direction

443 

444 

def compute_linear_transform(pos1, pos2, offset=None):
    """
    Compute the linear transformation associated with the direction from pos1 to pos2.

    The returned rotation matrix maps the canonical direction [0, 0, -1] onto
    the unit vector pointing from pos1 to pos2, i.e. ``R @ [0, 0, -1] == direction``.
    It is the identity when the two points coincide.

    Args:
        pos1: Starting position (3D point).
        pos2: Ending position (3D point).
        offset: Optional offset along the pos1 -> pos2 direction. The offset
            position is ``pos1 + offset * direction``. Presumably a scalar
            distance; TODO confirm whether vector offsets are intended.

    Returns:
        Tuple containing:
            - 3x3 rotation matrix
            - Offset position ``pos1 + offset * direction``, or the integer 0 when
              `offset` is None
    """
    rot_mat, direction = compute_rotation_between_vectors(pos1, pos2)
    offset_pos = 0 if offset is None else pos1 + offset * direction
    return rot_mat, offset_pos