Coverage for kwave/utils/matlab.py: 30%
47 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-24 12:06 -0700
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-24 12:06 -0700
1from typing import List, Optional, Tuple, Union
3import numpy as np
4from beartype import beartype as typechecker
def rem(x, y, rtol=1e-05, atol=1e-08):
    """
    Return the remainder after division of x by y, compensating for floating point round-off.

    The result is ``x - fix(x / y) * y`` (sign follows ``x``), as in MATLAB's ``rem``.
    Wherever ``x / y`` is within tolerance of an integer, the remainder is set to exactly 0.
    This check is done element by element. Wherever ``y == 0`` the result is NaN, which
    matches MATLAB's ``rem(x, 0)``.

    Args:
        x (float, list, or ndarray): The dividend(s).
        y (float, list, or ndarray): The divisor(s); must broadcast against ``x``.
        rtol (float): The relative tolerance parameter (see numpy.isclose).
        atol (float): The absolute tolerance parameter (see numpy.isclose).

    Returns:
        float or ndarray: The remainder(s). Scalar inputs give a NumPy scalar; array-like
        inputs give an array of the broadcast shape.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    # Division by zero is handled explicitly below, so silence the inf/nan warnings it causes.
    with np.errstate(divide="ignore", invalid="ignore"):
        quotient = x / y
        remainder = x - np.fix(quotient) * y
        # Per-element round-off compensation: a quotient like 2.9999999999999996 is really 3.
        near_integer = np.isclose(quotient, np.round(quotient), rtol=rtol, atol=atol)

    remainder = np.where(near_integer, 0.0, remainder)
    remainder = np.where(y == 0, np.nan, remainder)

    # Unwrap 0-d results so scalar inputs produce a scalar.
    return remainder[()] if remainder.ndim == 0 else remainder
def matlab_assign(matrix: np.ndarray, indices: Union[int, np.ndarray], values: Union[int, float, np.ndarray]) -> np.ndarray:
    """
    Assign values at column-major (MATLAB-ordered) linear indices and return the result.

    Args:
        matrix: The array to assign into. It is never modified; a new array is returned.
        indices: 0-based linear indices into the Fortran-order (column-major) flattening
            of ``matrix``. Can be a single integer or a NumPy array.
        values: The value(s) to assign; applied with ordinary NumPy indexed assignment.

    Returns:
        A new array with the same shape as ``matrix`` holding the assigned values.
    """
    original_shape = matrix.shape
    # np.ravel returns a view for 1-D / Fortran-contiguous inputs and a copy otherwise.
    # Copy explicitly so the caller's array is never mutated, whatever its memory layout.
    flat = np.ravel(matrix, order="F").copy()
    flat[indices] = values
    return flat.reshape(original_shape, order="F")
def matlab_find(arr: Union[List[int], np.ndarray], val: int = 0, mode: str = "neq") -> np.ndarray:
    """
    Locate elements that differ from (or equal) a value, in the style of MATLAB ``find``.

    The input is scanned in column-major (Fortran) order, and matches are reported as
    1-based linear indices.

    Args:
        arr: List or NumPy array to scan. Non-ndarray input is converted with ``np.array``.
        val: Value each element is compared with. Defaults to 0.
        mode: ``'neq'`` selects elements ``!= val``; any other string (normally ``'eq'``)
            selects elements ``== val``. Defaults to ``'neq'``.

    Returns:
        Column array of shape ``(n, 1)`` holding the 1-based indices of the matches.
    """
    data = arr if isinstance(arr, np.ndarray) else np.array(arr)
    column_major = data.flatten(order="F")
    hits = (column_major != val) if mode == "neq" else (column_major == val)
    positions = np.where(hits)[0]
    # Shift to MATLAB's 1-based indexing and return as an (n, 1) column.
    return (positions + 1)[:, np.newaxis]
@typechecker
def matlab_mask(arr: np.ndarray, mask: np.ndarray, diff: Optional[int] = None) -> np.ndarray:
    """
    Gather elements of ``arr`` at column-major (Fortran-order) linear indices.

    Args:
        arr: The array to read from. It is flattened in Fortran order before indexing.
        mask: Linear indices into the Fortran-order flattening of ``arr``. May have any
            shape; it is itself flattened in Fortran order before use.
        diff: Optional integer offset added to every index before it is used (for example
            ``-1`` shifts 1-based indices to 0-based ones).

    Returns:
        A column array of shape ``(n, 1)`` holding the selected elements.
    """
    # Unsigned index arrays cannot represent negative offsets and wrap on overflow. The
    # previous uint8 -> int8 cast also turned every index >= 128 negative, so widen to
    # int64 instead.
    if np.issubdtype(mask.dtype, np.unsignedinteger):
        mask = mask.astype(np.int64)

    flat_mask = mask.ravel(order="F")
    if diff is not None:
        flat_mask = flat_mask + diff
    return np.expand_dims(arr.ravel(order="F")[flat_mask], axis=-1)  # compatibility, n => [n, 1]
def unflatten_matlab_mask(arr: np.ndarray, mask: np.ndarray, diff: Optional[int] = None) -> Tuple[Union[int, np.ndarray], ...]:
    """
    Convert column-major linear indices into per-dimension subscript arrays for ``arr``.

    Args:
        arr: The n-dimensional array the indices refer to; only its shape is used.
        mask: Linear indices into the Fortran-order flattening of ``arr``. May have any
            shape; it is flattened in Fortran order first.
        diff: Optional integer offset added to every index before conversion.

    Returns:
        The tuple produced by ``np.unravel_index`` (Fortran order): one array of subscripts
        per dimension of ``arr``, suitable for direct use as a NumPy index.
    """
    flat_indices = mask.ravel(order="F")
    # Adding an offset to an unsigned index array can wrap around (e.g. uint8 255 + 1 -> 0)
    # or fail for negative offsets, so widen to a signed 64-bit type first.
    if np.issubdtype(flat_indices.dtype, np.unsignedinteger):
        flat_indices = flat_indices.astype(np.int64)
    if diff is not None:
        flat_indices = flat_indices + diff
    return np.unravel_index(flat_indices, arr.shape, order="F")
def ind2sub(array_shape: Tuple[int, ...], ind: int) -> Tuple[int, ...]:
    """
    Convert a MATLAB-style (1-based, column-major) linear index to 1-based subscripts.

    Args:
        array_shape: Shape of the array being indexed.
        ind: 1-based linear index (an integer or an array of integers) into the
            Fortran-order flattening of an array of shape ``array_shape``.

    Returns:
        A tuple with one entry per dimension holding the 1-based subscripts. Each entry
        is passed through ``np.squeeze``, so a scalar ``ind`` yields scalar-like values.
    """
    zero_based = np.unravel_index(ind - 1, array_shape, order="F")
    # Build a real tuple (the annotated return type) rather than a one-shot generator,
    # so the result can be indexed, measured with len(), and iterated more than once.
    return tuple(np.squeeze(sub) + 1 for sub in zero_based)
def sub2ind(array_shape: Tuple[int, int, int], x: np.ndarray, y: np.ndarray, z: np.ndarray) -> np.ndarray:
    """
    Convert 3D subscripts to column-major (Fortran-order) linear indices.

    The ordering matches MATLAB's column-major layout. However, the subscripts and the
    returned indices are both 0-based, because they go straight through
    ``np.ravel_multi_index``.

    Args:
        array_shape: A tuple containing the shape of the 3D array.
        x: 0-based subscripts for the first dimension (singleton dimensions are squeezed).
        y: 0-based subscripts for the second dimension (singleton dimensions are squeezed).
        z: 0-based subscripts for the third dimension (singleton dimensions are squeezed).

    Returns:
        A numpy array of 0-based linear indices, one per (x, y, z) triple. It is 1-D for
        vector or single-element inputs.
    """
    # atleast_1d keeps single-element inputs, which squeeze to 0-d, as length-1 vectors.
    subscripts = tuple(np.atleast_1d(np.squeeze(s)) for s in (x, y, z))
    # One vectorized call replaces the per-element Python loop.
    return np.ravel_multi_index(subscripts, dims=array_shape, order="F")