Coverage for kwave/utils/matlab.py: 30%

47 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-24 12:06 -0700

1from typing import List, Optional, Tuple, Union 

2 

3import numpy as np 

4from beartype import beartype as typechecker 

5 

6 

def rem(x, y, rtol=1e-05, atol=1e-08):
    """
    Return the remainder after division of x by y, mirroring MATLAB's ``rem``.

    The result is ``x - fix(x / y) * y``, so it carries the sign of ``x``. The handling is
    element-wise:

    * where the quotient ``x / y`` is within tolerance of an integer, the remainder is exactly 0.
      This compensates for floating-point round-off; for example, ``rem(0.3, 0.1)`` is 0 rather
      than ~0.1.
    * where the divisor is zero, the remainder is NaN.

    x and y must be real and broadcast-compatible.

    Args:
        x (float, list, or ndarray): The dividend(s).
        y (float, list, or ndarray): The divisor(s).
        rtol (float): The relative tolerance parameter (see numpy.isclose).
        atol (float): The absolute tolerance parameter (see numpy.isclose).

    Returns:
        float or ndarray: The remainder after division. A NumPy scalar is returned when both
        inputs are scalars; otherwise an array with the broadcast shape of x and y.
    """
    x_arr = np.asarray(x)
    y_arr = np.asarray(y)

    # Zero divisors are replaced with NaN explicitly below, so silence the divide/invalid
    # warnings that the intermediate inf/nan values would otherwise emit.
    with np.errstate(divide="ignore", invalid="ignore"):
        quotient = x_arr / y_arr
        remainder = x_arr - np.fix(quotient) * y_arr
        # Decide per element. An all-or-nothing check lets round-off leak through for
        # near-integer elements whenever some other element is not near an integer.
        near_integer = np.isclose(quotient, np.round(quotient), rtol=rtol, atol=atol)

    result = np.where(near_integer, 0.0, remainder)
    result = np.where(y_arr == 0, np.nan, result)

    # np.where always yields an ndarray; unwrap the 0-d case so scalar inputs give a scalar.
    return result[()]

35 

36 

def matlab_assign(matrix: np.ndarray, indices: Union[int, np.ndarray], values: Union[int, float, np.ndarray]) -> np.ndarray:
    """
    Assign values at linear indices counted in MATLAB's column-major (Fortran) order.

    This emulates MATLAB's ``A(idx) = v`` with a linear index. The indices are used exactly as
    given; no 1-based to 0-based conversion is applied here.

    Args:
        matrix: Array whose elements are to be overwritten.
        indices: Linear index (int) or array of linear indices into ``matrix`` flattened in
            column-major order.
        values: Scalar or array of values written at ``indices``.

    Returns:
        An array with the same shape as ``matrix`` containing the assigned values.

    Note:
        ``np.ravel`` returns a view when the memory layout allows it (e.g. 1-D or
        Fortran-contiguous input) and a copy otherwise. So ``matrix`` itself may or may not be
        modified; rely on the returned array.
    """
    shape = matrix.shape
    column_major = np.ravel(matrix, order="F")
    column_major[indices] = values
    return np.reshape(column_major, shape, order="F")

54 

55 

def matlab_find(arr: Union[List[int], np.ndarray], val: int = 0, mode: str = "neq") -> np.ndarray:
    """
    Return MATLAB-style (1-based, column-major) linear indices of matching elements.

    The input is flattened in Fortran order, compared against ``val``, and the positions of the
    matches are returned shifted by +1 to match MATLAB's ``find``.

    Args:
        arr: The values to search; a list is converted to a NumPy array first.
        val: The value to compare against. Default is 0.
        mode: ``"neq"`` selects elements not equal to ``val``. Any other value selects elements
            equal to ``val`` (intended as ``"eq"``). Default is ``"neq"``.

    Returns:
        A column array of shape ``(n, 1)`` holding the 1-based indices of the matches.
    """
    flat = np.asanyarray(arr).flatten(order="F")
    hits = (flat != val) if mode == "neq" else (flat == val)
    one_based = np.nonzero(hits)[0] + 1  # MATLAB indices start at 1
    return one_based[:, np.newaxis]  # n => [n, 1] for compatibility

77 

78 

@typechecker
def matlab_mask(arr: np.ndarray, mask: np.ndarray, diff: Optional[int] = None) -> np.ndarray:
    """
    Select elements of ``arr`` in MATLAB column-major order and return them as a column vector.

    Both ``arr`` and ``mask`` are flattened in Fortran ("F") order before indexing, mimicking
    MATLAB's ``arr(mask)``. ``mask`` is interpreted by NumPy's indexing rules on its flattened
    values:

    * A boolean ``mask`` (with ``diff`` left as None) selects the elements where it is True, so
      it must have as many elements as ``arr``.
    * An integer ``mask`` is treated as 0-based linear indices into the flattened ``arr``.

    Args:
        arr: The array to select from.
        mask: Boolean selection mask, or integer linear indices of any shape.
        diff: Optional integer added to every flattened mask value before indexing. Adding an
            integer to a boolean mask turns it into integer indices.

    Returns:
        An array of shape ``(n, 1)`` holding the selected elements in column-major order.
    """
    if mask.dtype == "uint8":
        # Widen uint8 to the native index type so that adding a negative ``diff`` cannot
        # underflow. Casting to int8 instead would wrap every index >= 128 to a negative value
        # and silently select elements counted from the end of the array.
        mask = mask.astype(np.intp)

    flat_mask = mask.ravel(order="F")
    if diff is not None:
        flat_mask = flat_mask + diff
    # Column vector for compatibility with MATLAB-style (n, 1) results.
    return np.expand_dims(arr.ravel(order="F")[flat_mask], axis=-1)

102 

103 

def unflatten_matlab_mask(arr: np.ndarray, mask: np.ndarray, diff: Optional[int] = None) -> Tuple[Union[int, np.ndarray], ...]:
    """
    Convert column-major linear indices into per-dimension subscripts for ``arr``.

    ``mask`` is flattened in MATLAB (Fortran) order, and each value is interpreted as a 0-based
    linear index into ``arr`` laid out in the same order. Only ``arr.shape`` is used.

    Args:
        arr: Array whose shape defines the index space.
        mask: Array of 0-based linear indices, of any shape.
        diff: Optional integer offset added to every index before conversion (for example -1
            to shift 1-based indices to 0-based ones).

    Returns:
        A tuple with one index array per dimension of ``arr``, usable as ``arr[result]``.

    Raises:
        ValueError: If any (offset) index falls outside ``arr`` (raised by np.unravel_index).
    """
    flat = mask.ravel(order="F")
    if np.issubdtype(flat.dtype, np.integer):
        # Widen narrow/unsigned integer masks before applying ``diff``. Otherwise uint8 255 + 1
        # silently wraps to 0, and NumPy 2 refuses to add -1 to a uint8 array at all.
        # np.unravel_index works on intp indices anyway, so this is lossless.
        flat = flat.astype(np.intp)
    if diff is not None:
        flat = flat + diff
    return np.unravel_index(flat, arr.shape, order="F")

122 

123 

def ind2sub(array_shape: Tuple[int, ...], ind: int) -> Tuple[int, ...]:
    """
    Convert a MATLAB (1-based, column-major) linear index into 1-based subscripts.

    Args:
        array_shape: Shape of the array being indexed.
        ind: 1-based linear index, or an integer array of such indices.

    Returns:
        A tuple with one entry per dimension of ``array_shape``, holding the 1-based subscript(s)
        along that dimension (NumPy scalars for a scalar ``ind``, arrays otherwise). A real
        tuple is returned, rather than a one-shot generator, so the result supports ``len()``,
        indexing and repeated iteration as the annotation promises.
    """
    # Shift to 0-based for NumPy, unravel in Fortran order, then shift each subscript back.
    zero_based = np.unravel_index(ind - 1, array_shape, order="F")
    return tuple(np.squeeze(sub) + 1 for sub in zero_based)

140 

141 

def sub2ind(array_shape: Tuple[int, int, int], x: np.ndarray, y: np.ndarray, z: np.ndarray) -> np.ndarray:
    """
    Convert 3D subscripts to linear indices using MATLAB's column-major (Fortran) ordering.

    The element ordering matches MATLAB (the first index varies fastest). Unlike MATLAB's
    ``sub2ind``, however, both the subscripts and the returned indices are 0-based.

    Args:
        array_shape: A tuple containing the shape of the array.
        x: Integer subscripts along the first dimension.
        y: Integer subscripts along the second dimension.
        z: Integer subscripts along the third dimension.

    Returns:
        An integer array of 0-based linear indices, one per subscript triple (1-D for vector
        inputs, including single-element ones).

    Raises:
        ValueError: If a subscript is out of bounds for ``array_shape`` or the subscript arrays
            cannot be broadcast together (raised by np.ravel_multi_index).
    """
    # squeeze drops singleton axes (e.g. column vectors). atleast_1d keeps a single point as a
    # length-1 array instead of a 0-d array, which could not be iterated.
    subs = tuple(np.atleast_1d(np.squeeze(s)) for s in (x, y, z))
    if any(s.size == 0 for s in subs):
        # Nothing to convert; avoid ravel_multi_index's dtype check on empty float arrays.
        return np.array([], dtype=np.intp)
    # One vectorised call instead of a per-element Python loop.
    return np.ravel_multi_index(subs, dims=array_shape, order="F")