Coverage for kwave/utils/checks.py: 19%
93 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-24 12:06 -0700
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-24 12:06 -0700
1import logging
2import numbers
3import platform
4from copy import deepcopy
5from typing import TYPE_CHECKING, Any, List
7import numpy as np
8import scipy
9import scipy.optimize
11if TYPE_CHECKING: 11 ↛ 13line 11 didn't jump to line 13 because the condition on line 11 was never true
12 # Found here: https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/
13 from kwave.kgrid import kWaveGrid
14 from kwave.kmedium import kWaveMedium
16from .conversion import db2neper
17from .math import primefactors, sinc
def enforce_fields(dictionary, *fields):
    """
    Ensure that the given dictionary contains every one of the specified keys.

    Args:
        dictionary: A dictionary to check.
        *fields: The keys that must be present in the dictionary.

    Raises:
        AssertionError: If any of the specified fields is missing from the dictionary.

    """
    for f in fields:
        # The message must be a plain string; a list-wrapped message would be
        # rendered as "['The field ...']" in the AssertionError text.
        assert f in dictionary, f"The field {f} must be defined in the given dictionary"
def enforce_fields_obj(obj, *fields):
    """
    Require that each named attribute of ``obj`` holds a value other than None.

    Attributes are checked in the order given and the first None value stops
    the check.

    Args:
        obj: The object whose attributes are inspected.
        *fields: Names of the attributes that must not be None.

    Raises:
        AssertionError: If one of the named attributes is None.

    """
    for field_name in fields:
        assert getattr(obj, field_name) is not None, (
            f"The field {field_name} must be not None in the given object"
        )
def check_field_names(dictionary, *fields):
    """
    Verify that every key of ``dictionary`` is one of the allowed field names.

    Args:
        dictionary: The dictionary whose keys are validated.
        *fields: The permitted field names.

    Raises:
        AssertionError: On the first key that is not among the permitted names.

    """
    allowed_names = fields
    for key in dictionary.keys():
        assert key in allowed_names, f"The field {key} is not a valid field for the given dictionary"
def check_str_eq(value, target: str) -> bool:
    """
    Compare ``value`` with ``target`` only when ``value`` is a string.

    Non-string values (for example NumPy arrays, whose ``==`` against a string
    can emit a FutureWarning) short-circuit to False without being compared.

    Args:
        value: The value to check.
        target: The string to compare against.

    Returns:
        True if ``value`` is a string equal to ``target``, otherwise False.

    """
    if not isinstance(value, str):
        return False
    return value == target
def check_str_in(value, target: List[str]) -> bool:
    """
    Test membership of ``value`` in ``target``, but only for string values.

    Restricting the membership test to strings avoids FutureWarnings that
    non-string values (e.g. arrays) could trigger when compared to strings.
    Added by @Farid

    Args:
        value: The candidate value.
        target: The list of strings to search.

    Returns:
        True when ``value`` is a string contained in ``target``, else False.

    """
    if isinstance(value, str):
        return value in target
    return False
def is_number(value: Any) -> bool:
    """
    Check if the given value is a numeric type.

    ``None`` and strings are never numeric. Python ``int``/``float`` values
    (including ``bool``, which subclasses ``int``) are numeric. Anything else
    is judged by its NumPy dtype. Arrays use their own dtype; other values are
    converted with ``np.asarray``. So NumPy scalars, arrays and (nested)
    sequences of numbers count as numeric.

    Args:
        value: The value to check.

    Returns:
        True if the value is numeric, False otherwise. This includes values
        NumPy cannot turn into a regular array, such as ragged nested sequences.

    """
    if value is None or isinstance(value, str):
        return False
    if isinstance(value, (int, float)):
        return True
    try:
        # Pass an explicit dtype to np.issubdtype instead of an array object.
        dtype = value.dtype if isinstance(value, np.ndarray) else np.asarray(value).dtype
    except (ValueError, TypeError):
        # Recent NumPy refuses to build arrays from ragged sequences; such a
        # value is not a number, so report False instead of propagating.
        return False
    return bool(np.issubdtype(dtype, np.number))
def is_unix() -> bool:
    """
    Report whether the interpreter runs on Linux or macOS.

    Returns:
        True when ``platform.system()`` is "Linux" or "Darwin", False otherwise.

    """
    current_system = platform.system()
    return current_system in ("Linux", "Darwin")
def _evaluate_absorbing_dt_stability_limit(kmax, c_ref, medium: "kWaveMedium", xtol=1e-12) -> float:
    """
    Estimate the maximum stable time step for an absorbing medium.

    The limit is the fixed point of ``func_to_solve`` below, found with
    ``scipy.optimize.fixed_point`` and seeded with ``func_to_solve(0)``.

    Args:
        kmax: Maximum wavenumber of the simulation grid.
        c_ref: Reference sound speed used in the k-space correction ``kappa``.
        medium: Medium properties. ``alpha_coeff``, ``alpha_power``,
            ``alpha_mode`` and ``sound_speed`` are read; the object is not modified.
        xtol: Convergence tolerance passed to ``scipy.optimize.fixed_point``.

    Returns:
        The estimated stability limit, as returned by ``scipy.optimize.fixed_point``.
    """
    # convert the absorption coefficient to nepers.(rad/s)^-y.m^-1
    alpha_coeff = db2neper(medium.alpha_coeff, medium.alpha_power)

    # Work on a local array copy of the sound speed instead of reassigning
    # medium.sound_speed, so the caller's medium object is left untouched.
    sound_speed = np.atleast_1d(medium.sound_speed)
    c_max = sound_speed.max()

    # calculate the absorption constant
    if medium.alpha_mode != "no_absorption":
        absorb_tau = -2.0 * alpha_coeff * sound_speed ** (medium.alpha_power - 1.0)
    else:
        absorb_tau = np.array([0])

    # calculate the dispersion constant
    if medium.alpha_mode != "no_dispersion":
        absorb_eta = 2.0 * alpha_coeff * sound_speed ** (medium.alpha_power) * np.tan(np.pi * medium.alpha_power / 2.0)
    else:
        absorb_eta = np.array([0])

    # estimate the timestep required for stability in the absorbing case by
    # assuming the k-space correction factor, kappa = 1 (note that
    # absorb_tau and absorb_eta are negative quantities)
    tau_min = np.min(absorb_tau)
    eta_min = np.min(absorb_eta)
    k_pow = kmax ** (medium.alpha_power - 1)

    temp1 = 1 - eta_min * k_pow

    def kappa(dt):
        # k-space correction factor as a function of the time step
        return sinc(c_ref * kmax * dt / 2.0)

    def temp2(dt):
        return c_max * tau_min * kappa(dt) * k_pow

    def func_to_solve(dt):
        t2 = temp2(dt)
        return (t2 + np.sqrt(t2**2.0 + 4.0 * temp1)) / (temp1 * kmax * kappa(dt) * c_max)

    dt_start = func_to_solve(0)

    dt_stability_limit = scipy.optimize.fixed_point(func_to_solve, dt_start, xtol=xtol)
    return dt_stability_limit
def _evaluate_non_absorbing_dt_stability_limit(kmax, c_ref, medium: "kWaveMedium") -> float:
    """
    Compute the maximum stable time step for a non-absorbing medium.

    Args:
        kmax: Maximum wavenumber of the simulation grid.
        c_ref: Reference sound speed.
        medium: Medium properties; only ``sound_speed`` is read. It may be a
            scalar, a sequence or an array.

    Returns:
        ``inf`` when ``c_ref`` is at least the maximum sound speed (the model is
        unconditionally stable). Otherwise ``2 / (c_ref * kmax) * arcsin(c_ref / c_max)``.
    """
    # np.max accepts scalars, sequences and arrays alike, so the caller does
    # not have to coerce medium.sound_speed to an ndarray beforehand.
    c_max = np.max(medium.sound_speed)
    if c_ref >= c_max:
        # set the timestep to Inf when the model is unconditionally stable
        return float("inf")
    # set the timestep required for stability when c_ref < max(medium.sound_speed)
    return 2.0 / (c_ref * kmax) * np.arcsin(c_ref / c_max)
def check_stability(kgrid: "kWaveGrid", medium: "kWaveMedium") -> float:
    """
    Calculates the maximum time step for which the k-space
    propagation models are stable.

    These models are unconditionally
    stable when the reference sound speed is equal to or greater than the
    maximum sound speed in the medium and there is no absorption.
    However, when the reference sound speed is less than the maximum
    sound speed the model is only stable for sufficiently small time
    steps. The criterion is more stringent (the time step is smaller) in
    the absorbing case.

    The time steps given are accurate when the medium properties are
    homogeneous. For a heterogeneous media they give a useful, but not
    exact, estimate.

    Args:
        kgrid: simulation grid
        medium: medium properties

    Returns:
        The maximum time step for which the models are stable. This is set to Inf when the model is unconditionally stable.

    Raises:
        NotImplementedError: If medium.sound_speed_ref is neither None, a number,
            nor one of "min", "max" or "mean".

    """

    # why? : this function was migrated from Matlab.
    # Matlab would treat the 'medium' as a "pass by value" argument.
    # In python argument is passed by reference and changes in this function will cause original data to be changed.
    # Instead of making significant changes to the function, we make a deep copy of the argument
    medium = deepcopy(medium)

    # find the maximum wavenumber
    kmax = kgrid.k.max()

    # calculate the reference sound speed for the fluid code, using the
    # maximum by default which ensures the model is unconditionally stable
    reductions = {"min": np.min, "max": np.max, "mean": np.mean}

    # TODO: move this logic to medium
    ss_ref = medium.sound_speed_ref
    if ss_ref is None:
        c_ref = reductions["max"](medium.sound_speed)
    elif isinstance(ss_ref, numbers.Number):
        c_ref = ss_ref
    else:
        try:
            reduce_sound_speed = reductions[ss_ref]
        except (KeyError, TypeError):
            # TypeError covers unhashable values (e.g. lists/arrays), which
            # cannot be dictionary keys; the lookup error itself is noise.
            raise NotImplementedError(f"Unknown value of {ss_ref} for medium.sound_speed_ref.") from None
        c_ref = reduce_sound_speed(medium.sound_speed)

    medium.sound_speed = np.atleast_1d(medium.sound_speed)
    # calculate the timesteps required for stability; np.asarray makes the
    # zero test element-wise for plain Python sequences as well as arrays
    if medium.alpha_coeff is None or np.all(np.asarray(medium.alpha_coeff) == 0):
        dt_stability_limit = _evaluate_non_absorbing_dt_stability_limit(kmax, c_ref, medium)
    else:
        dt_stability_limit = _evaluate_absorbing_dt_stability_limit(kmax, c_ref, medium)

    return dt_stability_limit
def check_factors(min_number: int, max_number: int) -> None:
    """
    Log the numbers in a range grouped by their maximum prime factor.

    checkFactors loops through the given range of numbers and lists those
    whose maximum prime factor is 2, 3, 5 or 7. This allows suitable grid
    sizes to be selected to maximise the speed of the FFT (this is fastest
    for FFT lengths with small prime factors). The output is emitted through
    ``logging`` at INFO level on the root logger, so it is only visible when
    logging is configured to show INFO messages. Nothing is returned.

    Args:
        min_number: integer specifying the lower bound of values to test (inclusive)
        max_number: integer specifying the upper bound of values to test (inclusive)

    """

    # compute the factors and maximum prime factors for each number in the
    # range; range() excludes its stop value, hence the + 1 so that
    # max_number itself is tested too
    factors = {}
    for n in range(min_number, max_number + 1):
        prime_factors = primefactors(n)  # computed once per number
        # default=None guards against an empty factor list (presumably what
        # primefactors returns for n = 1 -- TODO confirm)
        factors[n] = {"factors": prime_factors, "max_prime_factor": max(prime_factors, default=None)}

    # print the numbers that match each maximum prime factor
    for factor in [2, 3, 5, 7]:
        logging.log(logging.INFO, f"Numbers with a maximum prime factor of {factor}:")
        for n, info in factors.items():
            if info["max_prime_factor"] == factor:
                logging.log(logging.INFO, n)
def check_divisible(number: float, divider: float, rel_tol: float = 1e-12, abs_tol: float = 1e-12) -> bool:
    """
    Checks whether number is divisible by divider without any remainder.

    Why do we need such a function? Floating point rounding makes an exact
    test unreliable. For example ``0.3 / 0.1`` evaluates to
    ``2.9999999999999996``, so checking that the fractional part of the
    quotient is exactly zero gives the wrong answer. Instead, the quotient is
    compared against its nearest integer with a tolerance.

    Args:
        number: Number that's supposed to be divided
        divider: Divider that should divide the number
        rel_tol: Allowed deviation from the nearest integer, relative to the
            magnitude of the quotient.
        abs_tol: Minimum absolute allowed deviation from the nearest integer;
            this matters for quotients close to zero.

    Returns:
        True if number / divider is within tolerance of an integer, False
        otherwise (including non-finite quotients). Array inputs yield an
        element-wise boolean array.

    Raises:
        ZeroDivisionError: If ``divider`` is a Python zero (as with plain ``/``).

    """
    quotient = number / divider
    with np.errstate(invalid="ignore"):
        # inf - round(inf) is nan; silence that warning, since
        # np.isfinite below already rejects such quotients.
        deviation = np.abs(quotient - np.round(quotient))
        tolerance = np.maximum(abs_tol, rel_tol * np.abs(quotient))
        divisible = np.isfinite(quotient) & (deviation <= tolerance)
    return bool(divisible) if np.ndim(divisible) == 0 else divisible