Coverage for kwave/utils/checks.py: 19%

93 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-24 12:06 -0700

1import logging 

2import numbers 

3import platform 

4from copy import deepcopy 

5from typing import TYPE_CHECKING, Any, List 

6 

7import numpy as np 

8import scipy 

9import scipy.optimize 

10 

11if TYPE_CHECKING: 11 ↛ 13line 11 didn't jump to line 13 because the condition on line 11 was never true

12 # Found here: https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/ 

13 from kwave.kgrid import kWaveGrid 

14 from kwave.kmedium import kWaveMedium 

15 

16from .conversion import db2neper 

17from .math import primefactors, sinc 

18 

19 

def enforce_fields(dictionary, *fields):
    """
    Ensures that the given dictionary contains the specified fields.

    Args:
        dictionary: A dictionary to check.
        *fields: The fields that must be present in the dictionary.

    Raises:
        AssertionError: If any of the specified fields are not in the dictionary.

    """

    for f in fields:
        # Raise explicitly instead of using `assert` so the check is not
        # stripped when Python runs with -O. The message is a plain string
        # (it used to be wrapped in a one-element list).
        if f not in dictionary.keys():
            raise AssertionError(f"The field {f} must be defined in the given dictionary")

35 

36 

def enforce_fields_obj(obj, *fields):
    """
    Enforces that certain fields are not None in the given object.

    Args:
        obj: Object to check the fields of.
        *fields: List of field names to check.

    Raises:
        AssertionError: If any of the given fields are None in the given object.
        AttributeError: If the object does not have one of the given fields at all
            (propagated from `getattr`).

    """

    for f in fields:
        # Explicit raise (rather than `assert`) keeps the validation active
        # when Python runs with -O; the exception type is unchanged.
        if getattr(obj, f) is None:
            raise AssertionError(f"The field {f} must be not None in the given object")

52 

53 

def check_field_names(dictionary, *fields):
    """
    This method checks if the keys of the given dictionary are valid fields.

    Args:
        dictionary: A dictionary where the keys will be checked for validity.
        *fields: A list of valid field names.

    Raises:
        AssertionError: If any of the keys in the dictionary are not in the list of valid fields.

    """

    for k in dictionary.keys():
        # Explicit raise (rather than `assert`) keeps the validation active
        # when Python runs with -O; the exception type is unchanged.
        if k not in fields:
            raise AssertionError(f"The field {k} is not a valid field for the given dictionary")

69 

70 

def check_str_eq(value, target: str) -> bool:
    """
    Tell whether `value` is a string that equals `target`.

    The type is tested first so that non-string values (for example arrays)
    are never compared against a string, which avoids FutureWarnings.

    Args:
        value: The value to check.
        target: The target string to compare with.

    Returns:
        True if the value is a string and is equal to the target, False otherwise.

    """

    # Guard clause: anything that is not a string can never match.
    if not isinstance(value, str):
        return False
    return value == target

86 

87 

def check_str_in(value, target: List[str]) -> bool:
    """
    Tell whether `value` is a string that is contained in `target`.

    The membership test is only performed for string values, which helps
    to avoid FutureWarnings when value is not a string.
    Added by @Farid

    Args:
        value: The value to check for inclusion in `target`
        target: A list of strings to check for the presence of `value`

    Returns:
        True if `value` is a string and is present in `target`, otherwise False

    """

    if isinstance(value, str):
        return value in target
    # Non-string values are never looked up in the list.
    return False

104 

105 

def is_number(value: Any) -> bool:
    """
    Check if the given value is a numeric type.

    Python ints and floats (including bool, which is an int subclass) count as
    numeric. Strings and None never do. NumPy arrays are judged by their dtype,
    and any other value is judged by the dtype NumPy assigns when converting it
    to an array.

    Args:
        value: The value to check.

    Returns:
        True if the value is numeric, False otherwise.

    """

    if value is None or isinstance(value, str):
        return False
    if isinstance(value, (int, float)):
        return True
    if isinstance(value, np.ndarray):
        return bool(np.issubdtype(value.dtype, np.number))
    try:
        # Pass the dtype explicitly; handing the array itself to
        # np.issubdtype only works through implicit dtype coercion.
        dtype = np.asarray(value).dtype
    except ValueError:
        # e.g. ragged nested sequences that cannot form a regular array
        return False
    return bool(np.issubdtype(dtype, np.number))

127 

128 

def is_unix() -> bool:
    """
    Report whether the code is running on a Unix-like platform.

    Returns:
        True when `platform.system()` is "Linux" or "Darwin", False otherwise.

    """
    system_name = platform.system()
    return system_name == "Linux" or system_name == "Darwin"

138 

139 

def _evaluate_absorbing_dt_stability_limit(kmax, c_ref, medium: "kWaveMedium", xtol=1e-12) -> float:
    """
    Estimate the stability-limited time step for an absorbing medium.

    The limit is found as the fixed point of an expression built from the
    absorption and dispersion constants, solved with
    `scipy.optimize.fixed_point`.

    Args:
        kmax: maximum wavenumber of the grid
        c_ref: reference sound speed
        medium: medium properties; `alpha_coeff`, `alpha_power`, `alpha_mode`
            and `sound_speed` are read. The object is not modified.
        xtol: convergence tolerance passed to `scipy.optimize.fixed_point`

    Returns:
        The time step found by the fixed-point iteration.

    """
    # Work on a local array view of the sound speed instead of writing the
    # converted array back onto `medium`, so the caller's object is untouched
    # and scalar sound speeds are handled as well.
    sound_speed = np.atleast_1d(medium.sound_speed)
    c_max = sound_speed.max()
    alpha_power = medium.alpha_power

    # convert the absorption coefficient to nepers.(rad/s)^-y.m^-1
    alpha_coeff = db2neper(medium.alpha_coeff, alpha_power)

    # calculate the absorption constant
    if medium.alpha_mode != "no_absorption":
        absorb_tau = np.atleast_1d(-2.0 * alpha_coeff * sound_speed ** (alpha_power - 1.0))
    else:
        absorb_tau = np.array([0])

    # calculate the dispersion constant
    if medium.alpha_mode != "no_dispersion":
        absorb_eta = np.atleast_1d(2.0 * alpha_coeff * sound_speed ** (alpha_power) * np.tan(np.pi * alpha_power / 2.0))
    else:
        absorb_eta = np.array([0])

    # estimate the timestep required for stability in the absorbing case by
    # assuming the k-space correction factor, kappa = 1 (note that
    # absorb_tau and absorb_eta are negative quantities)
    temp1 = 1 - absorb_eta.min() * kmax ** (alpha_power - 1)

    def kappa(dt):
        return sinc(c_ref * kmax * dt / 2.0)

    def temp2(dt):
        return c_max * absorb_tau.min() * kappa(dt) * kmax ** (alpha_power - 1)

    def func_to_solve(dt):
        return (temp2(dt) + np.sqrt((temp2(dt)) ** 2.0 + 4.0 * temp1)) / (temp1 * kmax * kappa(dt) * c_max)

    # the value of the expression at dt = 0 is used as the initial guess
    dt_start = func_to_solve(0)

    dt_stability_limit = scipy.optimize.fixed_point(func_to_solve, dt_start, xtol=xtol)
    return dt_stability_limit

177 

178 

def _evaluate_non_absorbing_dt_stability_limit(kmax, c_ref, medium: "kWaveMedium") -> float:
    """
    Compute the stability-limited time step for a non-absorbing medium.

    Args:
        kmax: maximum wavenumber of the grid
        c_ref: reference sound speed
        medium: medium properties; only `sound_speed` is read (scalar or array)

    Returns:
        Inf when `c_ref` is at least the maximum sound speed (the model is
        unconditionally stable), otherwise the largest stable time step.

    """
    # np.max accepts scalars as well as arrays, unlike the `.max()` method
    c_max = np.max(medium.sound_speed)

    if c_ref >= c_max:
        # set the timestep to Inf when the model is unconditionally stable
        return float("inf")

    # set the timestep required for stability when c_ref~=max(medium.sound_speed(:))
    return 2.0 / (c_ref * kmax) * np.arcsin(c_ref / c_max)

187 

188 

def check_stability(kgrid: "kWaveGrid", medium: "kWaveMedium") -> float:
    """
    Calculates the maximum time step for which the k-space
    propagation models are stable.

    These models are unconditionally
    stable when the reference sound speed is equal to or greater than the
    maximum sound speed in the medium and there is no absorption.
    However, when the reference sound speed is less than the maximum
    sound speed the model is only stable for sufficiently small time
    steps. The criterion is more stringent (the time step is smaller) in
    the absorbing case.

    The time steps given are accurate when the medium properties are
    homogeneous. For a heterogeneous media they give a useful, but not
    exact, estimate.

    Args:
        kgrid: simulation grid
        medium: medium properties

    Returns:
        The maximum time step for which the models are stable. This is set to Inf when the model is unconditionally stable.

    Raises:
        NotImplementedError: If `medium.sound_speed_ref` is neither a number nor one of "min", "max" or "mean".

    """

    # why? : this function was migrated from Matlab.
    # Matlab would treat the 'medium' as a "pass by value" argument.
    # In python argument is passed by reference and changes in this function will cause original data to be changed.
    # Instead of making significant changes to the function, we make a deep copy of the argument
    medium = deepcopy(medium)

    # find the maximum wavenumber
    kmax = kgrid.k.max()

    # calculate the reference sound speed for the fluid code, using the
    # maximum by default which ensures the model is unconditionally stable
    reductions = {"min": np.min, "max": np.max, "mean": np.mean}

    # TODO: move this logic to medium
    ss_ref = medium.sound_speed_ref
    if ss_ref is None:
        c_ref = reductions["max"](medium.sound_speed)
    elif isinstance(ss_ref, numbers.Number):
        c_ref = ss_ref
    else:
        # Only the lookup is guarded, so a KeyError can only mean an
        # unknown reduction name; the original error is kept as the cause.
        try:
            reduce_fn = reductions[ss_ref]
        except KeyError as err:
            raise NotImplementedError(f"Unknown value of {ss_ref} for medium.sound_speed_ref.") from err
        c_ref = reduce_fn(medium.sound_speed)

    # the helpers below call array methods on the sound speed
    medium.sound_speed = np.atleast_1d(medium.sound_speed)

    # calculate the timesteps required for stability
    if medium.alpha_coeff is None or np.all(medium.alpha_coeff == 0):
        dt_stability_limit = _evaluate_non_absorbing_dt_stability_limit(kmax, c_ref, medium)
    else:
        dt_stability_limit = _evaluate_absorbing_dt_stability_limit(kmax, c_ref, medium)

    return dt_stability_limit

249 

250 

def check_factors(min_number: int, max_number: int) -> None:
    """
    Log the numbers in a range grouped by their maximum prime factor.

    checkFactors loops through the given range of numbers and finds the
    numbers with the smallest maximum prime factors. This allows suitable
    grid sizes to be selected to maximise the speed of the FFT (this is
    fastest for FFT lengths with small prime factors). The output is
    written through the `logging` module at INFO level; nothing is returned.

    Args:
        min_number: integer specifying the lower bound of values to test (inclusive)
        max_number: integer specifying the upper bound of values to test
            (exclusive, since the values come from `range(min_number, max_number)`)

    """

    # compute the maximum prime factor for each number in the range, calling
    # primefactors only once per number. `default=None` avoids a ValueError
    # from max() should a number have no prime factors
    # (presumably the case for n < 2 - TODO confirm against primefactors).
    max_prime_factor = {}
    for n in range(min_number, max_number):
        max_prime_factor[n] = max(primefactors(n), default=None)

    # print the numbers that match each maximum prime factor
    for factor in (2, 3, 5, 7):
        logging.log(logging.INFO, f"Numbers with a maximum prime factor of {factor}:")
        for n, largest in max_prime_factor.items():
            if largest == factor:
                logging.log(logging.INFO, n)

278 

279 

def check_divisible(number: float, divider: float, rel_tol: float = 1e-12) -> bool:
    """
    Checks whether number is divisible by divider without any remainder
    Why do we need such a function? -> Because due to floating point precision we
    experience rounding errors while using standard modulo operator with floating point numbers

    The quotient is compared with its nearest integer using a relative
    tolerance, so that values such as 0.3 / 0.1 (which evaluates to
    2.9999999999999996) are still recognised as divisible.

    Args:
        number: Number that's supposed to be divided
        divider: Divider that should divide the number
        rel_tol: Relative tolerance used when comparing the quotient with its nearest integer

    Returns:
        True if number is divisible by divider, False otherwise

    """
    import math

    quotient = number / divider
    # inf / nan quotients can never be whole numbers (and cannot be rounded)
    if not math.isfinite(quotient):
        return False
    # A purely relative tolerance: a tiny non-zero quotient is never
    # mistaken for zero, while a few ulps of rounding error are forgiven.
    return math.isclose(quotient, round(quotient), rel_tol=rel_tol)

297 return after_decimal == 0