Coverage for kwave/utils/math.py: 23%
139 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-24 12:06 -0700
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-24 12:06 -0700
1import math
2import warnings
3from functools import wraps
4from itertools import compress
5from typing import List, Optional, Tuple, Union
7import numpy as np
8from deprecated import deprecated
9from scipy import ndimage
10from scipy.spatial.transform import Rotation
12from kwave import __version__
13from kwave.data import Vector
@deprecated(
    version="0.4.1",
    reason="Use scipy.spatial.transform.Rotation.from_euler('x', angle, degrees=True).as_matrix() instead",
)
def Rx(theta: float) -> np.ndarray:
    """Return the 3x3 matrix that rotates by ``theta`` degrees about the x-axis.

    Deprecated since 0.4.1: build the matrix with SciPy's ``Rotation.from_euler`` directly.

    Args:
        theta: Rotation angle in degrees.

    Returns:
        3x3 rotation matrix.
    """
    rotation = Rotation.from_euler("x", theta, degrees=True)
    return rotation.as_matrix()
@deprecated(
    version="0.4.1",
    reason="Use scipy.spatial.transform.Rotation.from_euler('y', angle, degrees=True).as_matrix() instead",
)
def Ry(theta: float) -> np.ndarray:
    """Return the 3x3 matrix that rotates by ``theta`` degrees about the y-axis.

    Deprecated since 0.4.1: build the matrix with SciPy's ``Rotation.from_euler`` directly.

    Args:
        theta: Rotation angle in degrees.

    Returns:
        3x3 rotation matrix.
    """
    rotation = Rotation.from_euler("y", theta, degrees=True)
    return rotation.as_matrix()
@deprecated(
    version="0.4.1",
    reason="Use scipy.spatial.transform.Rotation.from_euler('z', angle, degrees=True).as_matrix() instead",
)
def Rz(theta: float) -> np.ndarray:
    """Return the 3x3 matrix that rotates by ``theta`` degrees about the z-axis.

    Deprecated since 0.4.1: build the matrix with SciPy's ``Rotation.from_euler`` directly.

    Args:
        theta: Rotation angle in degrees.

    Returns:
        3x3 rotation matrix.
    """
    rotation = Rotation.from_euler("z", theta, degrees=True)
    return rotation.as_matrix()
@deprecated(
    version="0.4.1",
    reason="Use make_affine() instead. It provides the same functionality with a clearer name and better documentation.",
)
def get_affine_matrix(translation: Vector, rotation: Union[float, List[float]], seq: str = "xyz") -> np.ndarray:
    """Deprecated alias of ``make_affine``; every argument is forwarded unchanged.

    Args:
        translation: [dx, dy] or [dx, dy, dz]
        rotation: Single angle (degrees) for 2D or list of angles for 3D
        seq: Rotation sequence for 3D (default: 'xyz')

    Returns:
        3x3 (2D) or 4x4 (3D) affine transformation matrix, exactly as ``make_affine`` returns it.
    """
    return make_affine(translation=translation, rotation=rotation, seq=seq)
def make_affine(translation: Vector, rotation: Union[float, List[float]], seq: str = "xyz") -> np.ndarray:
    """
    Build a homogeneous affine matrix from a rotation and a translation.

    A 2-element ``translation`` produces a 2D transform whose rotation is an
    in-plane rotation about the z-axis. Any other length takes the 3D path, where
    ``rotation`` holds Euler angles combined in the order given by ``seq``.
    The rotation part comes from ``scipy.spatial.transform.Rotation.from_euler``.

    Args:
        translation: [dx, dy] or [dx, dy, dz]
        rotation: Single angle (degrees) for 2D or list of angles (degrees) for 3D
        seq: Euler sequence used in the 3D case (default: 'xyz'); ignored in 2D

    Returns:
        3x3 (2D) or 4x4 (3D) affine transformation matrix

    Examples:
        # 2D transform (rotation around z-axis)
        T1 = make_affine([1, 2], 45)

        # 3D transform with xyz Euler angles
        T2 = make_affine([1, 2, 3], [45, 30, 60])

        # 3D transform with custom sequence
        T3 = make_affine([1, 2, 3], [45, 30], 'xy')
    """
    planar = len(translation) == 2

    if planar:
        # Keep only the upper-left 2x2 block of the z-rotation: the in-plane rotation.
        linear = Rotation.from_euler("z", rotation, degrees=True).as_matrix()[:2, :2]
    else:
        linear = Rotation.from_euler(seq, rotation, degrees=True).as_matrix()

    size = 3 if planar else 4
    affine = np.eye(size)
    affine[: size - 1, : size - 1] = linear
    affine[: size - 1, size - 1] = translation
    return affine
def cosd(angle_in_degrees):
    """Return the cosine of an angle given in degrees (scalars or NumPy arrays)."""
    angle_in_radians = np.deg2rad(angle_in_degrees)
    return np.cos(angle_in_radians)
def sind(angle_in_degrees):
    """Return the sine of an angle given in degrees (scalars or NumPy arrays)."""
    angle_in_radians = np.deg2rad(angle_in_degrees)
    return np.sin(angle_in_radians)
def largest_prime_factor(n: int) -> int:
    """
    Finds the largest prime factor of a positive integer.

    Uses trial division: every divisor found is divided out completely, so once
    ``divisor**2`` exceeds what remains of ``n``, the remainder is itself prime
    (or 1) and is the largest prime factor.

    Args:
        n: The positive integer to be factored.

    Returns:
        The largest prime factor of n. For n == 1, which has no prime factors, 1 is returned.

    Raises:
        ValueError: If n is less than 1. Without this check, n == 0 would loop forever
            (0 is divisible by every candidate divisor).
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}.")

    divisor = 2
    while divisor * divisor <= n:
        if n % divisor:
            divisor += 1
        else:
            n //= divisor
    return n
def rwh_primes(n: int) -> List[int]:
    """
    Generates a list of prime numbers less than a given integer.

    Uses an odd-only Sieve of Eratosthenes: sieve index ``i`` stands for the odd
    number ``2*i + 1``. The prime 2 is prepended separately.

    Args:
        n: The exclusive upper bound for the list of primes.

    Returns:
        The primes ``p`` with ``p < n`` in increasing order; an empty list when ``n <= 2``.
    """
    if n < 3:
        # There are no primes below 2. Without this guard the final line would
        # unconditionally include 2, which is not less than n here.
        return []

    sieve = bytearray([True]) * (n // 2 + 1)
    for i in range(1, int(n**0.5) // 2 + 1):
        if sieve[i]:
            # Strike out multiples of p = 2*i + 1, starting at p**2 (sieve index 2*i*(i + 1))
            # and stepping by 2*p in value, i.e. by p in sieve index.
            sieve[2 * i * (i + 1) :: 2 * i + 1] = bytearray((n // 2 - 2 * i * (i + 1)) // (2 * i + 1) + 1)
    return [2, *compress(range(3, n, 2), sieve[1:])]
def primefactors(n: int) -> List[int]:
    """
    Finds the prime factors of a given integer.

    Factors are found by trial division and returned in non-decreasing order, with
    repeated factors listed once per multiplicity (e.g. 12 -> [2, 2, 3]).
    Integer division is used throughout, so integer input yields integer factors
    without float precision loss.

    Args:
        n: The positive integer to factor.

    Returns:
        A list of prime factors of n; empty for n == 1.

    Raises:
        ValueError: If n is less than 1. Without this check, n == 0 would loop forever
            (0 is divisible by every candidate).
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}.")

    factors = []
    while n % 2 == 0:
        factors.append(2)
        n //= 2

    # n is now odd. Only odd divisors need testing, and the bound shrinks as n is
    # divided down: once divisor**2 > n, whatever remains is prime (or 1).
    divisor = 3
    while divisor * divisor <= n:
        while n % divisor == 0:
            factors.append(divisor)
            n //= divisor
        divisor += 2

    if n > 1:
        factors.append(n)

    return factors
def next_pow2(n: int) -> int:
    """
    Return the exponent of the next power of 2 that is greater than or equal to ``|n|``.

    Like MATLAB's ``nextpow2``, this returns the smallest integer ``p`` with
    ``2**p >= abs(n)``, not the power of two itself. Use ``2 ** next_pow2(n)`` for
    that. For ``n`` equal to 0, 1 or -1 the result is 0.

    Args:
        n: The integer to find the next power of 2 for. Any type supporting
            ``__index__`` (e.g. Python or NumPy integers) is accepted.

    Returns:
        The exponent ``p`` as a Python int.

    Raises:
        TypeError: If ``n`` is not an integer type.
    """
    import operator

    magnitude = abs(operator.index(n))
    if magnitude <= 1:
        return 0
    # (m - 1).bit_length() is the number of bits needed to represent m - 1, which is
    # exactly the smallest p with 2**p >= m. Python ints keep this exact at any size.
    return (magnitude - 1).bit_length()
def phase_shift_interpolate(data: np.ndarray, shift: float, shift_dim: Optional[int] = None) -> np.ndarray:
    """
    Interpolates array data using phase shifts in the Fourier domain.

    This function resamples the input data along the specified dimension using a
    regular grid that is offset by the non-dimensional distance ``shift`` (in units
    of the grid spacing). The resampling uses a Fourier interpolant, so the data is
    treated as periodic along that dimension.

    A positive shift samples the data at ``index + shift``. For example, ``shift=1``
    is a circular shift equal to ``np.roll(data, -1, axis)``.

    This can be used to shift the acoustic particle velocity recorded by the
    first-order simulation functions to the regular (non-staggered) temporal
    grid by setting shift to 1/2.

    Example:
        # Move velocity data from staggered to regular grid points
        v_regular = phase_shift_interpolate(v_staggered, shift=0.5)

    Args:
        data: The input array to be interpolated. Array-likes are converted with ``np.asarray``.
        shift: Non-dimensional shift amount, where:
            0 = no shift
            1/2 = shift for staggered grid
            1 = full grid point
        shift_dim: 1-based (MATLAB-style) dimension along which to shift: 1, 2, 3 or 4.
            Defaults to the last dimension of ``data``, except for a 2-D column vector
            of shape (N, 1), where dimension 1 is used. If it exceeds ``data.ndim``, a
            warning is issued and the last dimension is used instead.

    Returns:
        The real part of the shifted array, with the same shape as ``data``.

    Raises:
        ValueError: If ``shift_dim`` is not 1, 2, 3 or 4.
    """
    data = np.asarray(data)

    # Resolve the 0-based numpy axis to shift along (matching MATLAB behavior)
    if shift_dim is None:
        axis = data.ndim - 1
        if data.ndim == 2 and data.shape[1] == 1:
            axis = 0
    else:
        if not (1 <= shift_dim <= 4):
            raise ValueError("Input dim must be 1, 2, 3 or 4.")
        axis = shift_dim - 1  # 1-based MATLAB dimension -> 0-based numpy axis
        if axis >= data.ndim:
            warnings.warn(
                f"Shift dimension {shift_dim} is greater than the number of dimensions in the input array {data.ndim}."
            )
            axis = data.ndim - 1

    # Shift only along the chosen axis; a zero shift leaves the other axes untouched
    shifts = np.zeros(data.ndim)
    shifts[axis] = shift

    fft_data = np.fft.fft(data, axis=axis)

    # scipy.ndimage.fourier_shift moves data towards higher indices for a positive
    # shift (output[i] = input[i - s]), the opposite of MATLAB fourierShift, so negate it.
    shifted_data = ndimage.fourier_shift(fft_data, -shifts)

    # Return to the original domain, keeping only the real part
    return np.real(np.fft.ifft(shifted_data, axis=axis))
@deprecated(
    version="0.4.1",
    reason="This function has been renamed to phase_shift_interpolate() to better reflect its functionality.",
)
def fourier_shift(data: np.ndarray, shift: float, shift_dim: Optional[int] = None) -> np.ndarray:
    """Deprecated alias of ``phase_shift_interpolate``; all arguments are forwarded unchanged."""
    return phase_shift_interpolate(data=data, shift=shift, shift_dim=shift_dim)
def round_even(x):
    """
    Rounds to the nearest even integer.

    Computed as ``2 * round(x / 2)``. When ``x`` is an odd integer, two even integers
    are equally near. The tie then follows Python's round-half-to-even applied to
    ``x / 2``, so ``round_even(1) == 0`` but ``round_even(3) == 4``.

    Args:
        x: Input value

    Returns:
        Nearest even integer.
    """
    halves = x / 2
    return round(halves) * 2
def round_odd(x):
    """
    Rounds to the nearest odd integer.

    Computed as ``2 * round((x + 1) / 2) - 1``. When ``x`` is an even integer, two odd
    integers are equally near. The tie then follows Python's round-half-to-even applied
    to ``(x + 1) / 2``, so ``round_odd(2) == 3``, ``round_odd(4) == 3`` and ``round_odd(0) == -1``.

    Args:
        x: Input value

    Returns:
        Nearest odd integer.
    """
    shifted_half = (x + 1) / 2
    return round(shifted_half) * 2 - 1
def find_closest(A: np.ndarray, a: Union[float, int]) -> Tuple[Union[float, int], Tuple[int, ...]]:
    """
    Returns the value and index of the item in A that is closest to the value a.

    The search runs over every element of ``A`` at once, whatever its number of
    dimensions. It is NOT performed column-by-column. The element minimising
    ``|A - a|`` is returned. If several elements are equally close, the first one in
    row-major (C) order wins.

    Args:
        A: The array to search.
        a: The value to find.

    Returns:
        A tuple ``(value, index)``. ``value`` is the closest element of ``A`` and
        ``index`` is its N-D index as a tuple with one entry per dimension of ``A``.

    Raises:
        AssertionError: If ``A`` is not a NumPy array (note: skipped under ``python -O``).
    """
    assert isinstance(A, np.ndarray), "A must be an np.array"

    # argmin over the flattened array returns the first minimum; unravel it back to an N-D index.
    flat_index = np.argmin(np.abs(A - a))
    idx = np.unravel_index(flat_index, A.shape)
    return A[idx], idx
def sinc(x: Union[int, float, np.ndarray]) -> Union[int, float, np.ndarray]:
    """
    Calculates the unnormalised sinc function, ``sin(x) / x``, with ``sinc(0) == 1``.

    Args:
        x: The value or array of values for which to calculate the sinc function.

    Returns:
        The sinc function of x.
    """
    # np.sinc is the normalised form sin(pi*t)/(pi*t); substituting t = x/pi yields sin(x)/x.
    scaled = x / np.pi
    return np.sinc(scaled)
def gaussian(
    x: Union[int, float, np.ndarray],
    magnitude: Optional[Union[int, float]] = None,
    mean: Optional[float] = 0,
    variance: Optional[float] = 1,
) -> Union[int, float, np.ndarray]:
    """
    Returns a Gaussian distribution f(x) with the specified magnitude, mean, and variance:

        f(x) = magnitude * exp(-(x - mean)**2 / (2 * variance))

    If magnitude is not specified, it is set to 1 / sqrt(2 * pi * variance) so the
    distribution is normalised. mean and variance default to 0 and 1; passing None for
    either also selects that default. For example running:

        import matplotlib.pyplot as plt
        x = np.linspace(-3, 3, 121)
        plt.plot(x, gaussian(x))

    will plot a normalised Gaussian distribution on [-3, 3].

    Note, the full width at half maximum of the resulting distribution can be calculated by FWHM = 2 * sqrt(2 * log(2) * variance).

    Args:
        x: The input values.
        magnitude: Bell height. Defaults to normalised.
        mean: Mean or expected value. Defaults to 0.
        variance: Variance, or bell width. Defaults to 1.

    Returns:
        A Gaussian distribution.
    """
    # The annotations allow None, so map it onto the documented defaults instead of failing.
    if mean is None:
        mean = 0
    if variance is None:
        variance = 1
    if magnitude is None:
        magnitude = (2 * math.pi * variance) ** -0.5

    return magnitude * np.exp(-((x - mean) ** 2) / (2 * variance))
def _compute_direction(start_pos: np.ndarray, end_pos: np.ndarray) -> Tuple[np.ndarray, float]:
    """Return the unit vector from ``start_pos`` to ``end_pos`` and the distance between them.

    If the two points coincide (distance exactly 0), a zero vector is returned instead
    of dividing by zero. Dividing would produce NaNs and a floating-point warning.
    """
    direction = end_pos - start_pos
    magnitude = np.linalg.norm(direction)
    if magnitude == 0:
        return np.zeros_like(direction, dtype=float), magnitude
    return direction / magnitude, magnitude
def _compute_rotation_axis(reference: np.ndarray, direction: np.ndarray) -> Tuple[np.ndarray, float]:
    """Return the rotation axis ``reference x direction`` and its Euclidean norm.

    The axis is NOT normalised here; the caller must divide by the returned norm when
    it is non-zero. A (near-)zero norm means the two vectors are parallel or
    anti-parallel, so no unique rotation axis exists.
    """
    axis = np.cross(reference, direction)
    axis_norm = np.linalg.norm(axis)
    return axis, axis_norm
def _create_rotation_matrix(axis: np.ndarray, angle: float) -> np.ndarray:
    """Build a 3x3 rotation matrix with Rodrigues' formula.

    R = cos(angle) * I + sin(angle) * [axis]_x + (1 - cos(angle)) * axis axis^T,
    where [axis]_x is the cross-product (skew-symmetric) matrix of ``axis``.
    ``angle`` is in radians, and the rotation follows the right-hand rule about ``axis``.
    The result is only a proper rotation when ``axis`` is a unit vector.
    """
    ax, ay, az = axis[0], axis[1], axis[2]
    cross_matrix = np.array([[0, -az, ay], [az, 0, -ax], [-ay, ax, 0]])

    c, s = np.cos(angle), np.sin(angle)
    return c * np.eye(3) + s * cross_matrix + (1 - c) * np.outer(axis, axis)
def compute_rotation_between_vectors(start_pos: np.ndarray, end_pos: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the matrix taking the canonical direction [0, 0, -1] onto the start->end direction.

    Whenever the points are distinct, ``rot_mat @ [0, 0, -1]`` equals the returned unit
    direction. Points closer than ``np.isclose`` tolerance are treated as coincident and
    yield the identity matrix with a zero direction.

    Args:
        start_pos: Starting position vector
        end_pos: Ending position vector

    Returns:
        Tuple containing:
            - 3x3 matrix (see the anti-parallel note below)
            - Normalized direction vector
    """
    direction, magnitude = _compute_direction(start_pos, end_pos)
    if np.isclose(magnitude, 0):
        # Coincident points define no direction.
        return np.eye(3), np.zeros(3)

    canonical = np.array([0.0, 0.0, -1.0])
    alignment = np.dot(canonical, direction)
    raw_axis, raw_axis_norm = _compute_rotation_axis(canonical, direction)

    if raw_axis_norm > np.finfo(float).eps:
        # General case: rotate about the unit normal of the plane spanned by both vectors.
        theta = np.arccos(np.clip(alignment, -1.0, 1.0))
        rotation = _create_rotation_matrix(raw_axis / raw_axis_norm, theta)
    elif alignment > 0:
        # Already aligned with the canonical direction.
        rotation = np.eye(3)
    else:
        # Anti-parallel: -I maps [0, 0, -1] onto [0, 0, 1].
        # NOTE(review): -I has determinant -1 (a point inversion, not a proper rotation);
        # kept as-is for behavioural compatibility — confirm whether callers rely on it.
        rotation = -np.eye(3)
    return rotation, direction
def compute_linear_transform(pos1, pos2, offset=None):
    """
    Compute the linear transformation associated with the direction from pos1 to pos2.

    The returned matrix maps the canonical direction [0, 0, -1] onto the unit vector
    pointing from pos1 to pos2 (``rot_mat @ [0, 0, -1]`` equals that unit vector).
    Its transpose therefore maps the pos1 -> pos2 direction back onto [0, 0, -1].
    If pos1 and pos2 coincide, the identity matrix is returned.

    Args:
        pos1: Starting position (3D point). Lists/tuples are converted to arrays.
        pos2: Ending position (3D point). Lists/tuples are converted to arrays.
        offset: Optional distance to move from pos1 along the unit pos1 -> pos2
            direction (an array is multiplied elementwise with that direction).

    Returns:
        Tuple containing:
            - 3x3 rotation matrix
            - offset position ``pos1 + offset * direction``, or the scalar 0 when offset is None
    """
    # asanyarray leaves ndarray subclasses untouched and only converts plain sequences.
    pos1 = np.asanyarray(pos1)
    pos2 = np.asanyarray(pos2)

    rot_mat, direction = compute_rotation_between_vectors(pos1, pos2)
    if offset is None:
        offset_pos = 0
    else:
        offset_pos = pos1 + offset * direction
    return rot_mat, offset_pos