Coverage for kwave/ksource.py: 19%
125 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-24 12:06 -0700
1import logging
2from dataclasses import dataclass
4import numpy as np
6from kwave.kgrid import kWaveGrid
7from kwave.utils.matrix import num_dim2
@dataclass
class kSource(object):
    """
    Source terms of a k-Wave simulation.

    Every field defaults to ``None``; set only the source terms the simulation
    uses and call :meth:`validate` with the computational grid to check them.
    """

    _p0 = None
    #: time varying pressure at each of the source positions given by source.p_mask
    p = None
    #: binary (or labelled) matrix specifying the positions of the time varying pressure source distribution
    p_mask = None
    #: optional input to control whether the input pressure is injected as a mass source or enforced
    # as a dirichlet boundary condition; valid inputs are 'additive' (the default),
    # 'additive-no-correction' or 'dirichlet'
    p_mode = None
    #: Pressure reference frequency
    p_frequency_ref = None

    #: time varying particle velocity in the x-direction at each of the source positions given by source.u_mask
    ux = None
    #: time varying particle velocity in the y-direction at each of the source positions given by source.u_mask
    uy = None
    #: time varying particle velocity in the z-direction at each of the source positions given by source.u_mask
    uz = None
    #: binary matrix specifying the positions of the time varying particle velocity distribution
    u_mask = None
    #: optional input to control whether the input velocity is applied as a force source or enforced as a dirichlet
    # boundary condition; valid inputs are 'additive' (the default), 'additive-no-correction' or 'dirichlet'
    u_mode = None
    #: Velocity reference frequency
    u_frequency_ref = None

    sxx = None  #: Stress source in x -> x direction
    syy = None  #: Stress source in y -> y direction
    szz = None  #: Stress source in z -> z direction
    sxy = None  #: Stress source in x -> y direction
    sxz = None  #: Stress source in x -> z direction
    syz = None  #: Stress source in y -> z direction
    s_mask = None  #: Stress source mask
    s_mode = None  #: Stress source mode

    # accepted values for p_mode / u_mode, and for s_mode (not dataclass fields: no annotation)
    _PU_MODES = ("additive", "dirichlet", "additive-no-correction")
    _S_MODES = ("additive", "dirichlet")

    def is_p0_empty(self) -> bool:
        """
        Check whether the initial pressure ``p0`` is unset, empty, or contains only zeros.
        """
        # np.any works element-wise for both arrays and nested lists
        return self.p0 is None or np.size(self.p0) == 0 or not np.any(self.p0)

    @property
    def p0(self):
        """
        Initial pressure within the acoustic medium (``None`` when not set).
        """
        return self._p0

    @p0.setter
    def p0(self, val):
        # an empty (or None) initial pressure removes the field
        if val is None or np.size(val) == 0:
            self._p0 = None
        else:
            self._p0 = val

    def validate(self, kgrid: "kWaveGrid") -> None:
        """
        Validate the object fields for correctness.

        When a reference frequency is given, ``p_mode`` / ``u_mode`` is switched
        to ``'additive-no-correction'``.

        Args:
            kgrid: Instance of `~kwave.kgrid.kWaveGrid` class

        Returns:
            None

        Raises:
            ValueError: if ``p0`` or a source mask does not match the grid size, or the
                number of time series does not match the source mask.
            AssertionError: if a required mask is missing, a mask is empty, a mode is
                invalid, or a reference frequency is not a positive in-range scalar.
            AttributeError: if a reference frequency is given but no ``medium`` has been
                attached to this source.
            NotImplementedError: for a labelled ``u_mask`` and for stress sources, whose
                validation has not been ported.
        """
        if self.p0 is not None and np.shape(self.p0) != kgrid.k.shape:
            # throw an error if p0 is not the correct size
            raise ValueError("source.p0 must be the same size as the computational grid.")

        # NOTE: for the elastic code, source.p0 would be reformulated in terms of the
        # stress source terms (source.p = [0.5 0.5] / (2*CFL) is the same as
        # source.p0 = 1); this is not ported.

        # time varying pressure source
        if self.p is not None:
            self._validate_pressure(kgrid)

        # time varying velocity source
        if any(getattr(self, name) is not None for name in ("ux", "uy", "uz", "u_mask")):
            self._validate_velocity(kgrid)

        # time varying stress source
        if any(getattr(self, name) is not None for name in ("sxx", "syy", "szz", "sxy", "sxz", "syz", "s_mask")):
            self._validate_stress(kgrid)

    def _validate_pressure(self, kgrid) -> None:
        """Check the time varying pressure source ``p`` together with its mask, mode and reference frequency."""
        # force p_mask to be given if p is given
        self._require(self.p_mask is not None, "source.p_mask must be given if source.p is given.")

        # check mask is the correct size
        # noinspection PyTypeChecker
        if num_dim2(self.p_mask) != kgrid.dim or np.shape(self.p_mask) != kgrid.k.shape:
            raise ValueError("source.p_mask must be the same size as the computational grid.")

        # check mask is not empty
        self._require(np.sum(self.p_mask) != 0, "source.p_mask must be a binary grid with at least one element set to 1.")

        # don't allow both source.p0 and source.p in the same simulation
        # USERS: please contact us via http://www.k-wave.org/forum if this
        # is a problem
        self._require(self.p0 is None, "source.p0 and source.p can't be defined in the same simulation.")

        if self.p_mode is not None:
            self._require(
                self.p_mode in self._PU_MODES,
                "source.p_mode must be set to ''additive'', ''additive-no-correction'', or ''dirichlet''.",
            )

        if self.p_frequency_ref is not None:
            self._check_frequency_ref(self.p_frequency_ref, kgrid, "p")
            # change source mode to not include the k-space correction
            self.p_mode = "additive-no-correction"

        n_series, n_points = self._signal_shape(self.p)
        if n_points > kgrid.Nt:
            self._warn_too_many_time_points("p")

        # check if the mask is binary or labelled
        p_unique = np.unique(self.p_mask)
        if self._is_binary_mask(p_unique):
            # if more than one time series is given, there must be one per source element
            if n_series > 1 and n_series != np.sum(self.p_mask):
                raise ValueError(
                    "The number of time series in source.p must match the number of source elements in source.p_mask."
                )
        else:
            if not self._labels_are_valid(p_unique):
                raise ValueError(
                    "If using a labelled source.p_mask, the source labels must be monotonically increasing and start from 1."
                )
            # one time series (row of p) per label; the background value is not a label
            if n_series != p_unique.size - 1:
                raise ValueError(
                    "The number of time series in source.p must match the number of labelled source elements in source.p_mask."
                )

    def _validate_velocity(self, kgrid) -> None:
        """Check the time varying velocity sources ``ux``/``uy``/``uz`` together with their mask, mode and reference frequency."""
        # force u_mask to be given
        self._require(self.u_mask is not None, "source.u_mask must be given if source.ux, source.uy or source.uz is given.")

        # check mask is the correct size
        self._require(
            num_dim2(self.u_mask) == kgrid.dim and np.shape(self.u_mask) == kgrid.k.shape,
            "source.u_mask must be the same size as the computational grid.",
        )

        # check mask is not empty
        self._require(np.array(self.u_mask).sum() != 0, "source.u_mask must be a binary grid with at least one element set to 1.")

        if self.u_mode is not None:
            self._require(
                self.u_mode in self._PU_MODES,
                "source.u_mode must be set to ''additive'', ''additive-no-correction'', or ''dirichlet''.",
            )

        if self.u_frequency_ref is not None:
            self._check_frequency_ref(self.u_frequency_ref, kgrid, "u")
            # change source mode to not include the k-space correction
            self.u_mode = "additive-no-correction"

        # only components that are set and have at least one time point take part
        components = [name for name in ("ux", "uy", "uz") if getattr(self, f"flag_{name}")]
        for name in components:
            if getattr(self, f"flag_{name}") > kgrid.Nt:
                self._warn_too_many_time_points(name)

        # check if the mask is binary or labelled
        u_unique = np.unique(self.u_mask)
        if not self._is_binary_mask(u_unique):
            raise NotImplementedError("Validation of a labelled source.u_mask is not implemented.")

        # if more than one time series is given, there must be one per source element
        u_sum = np.sum(self.u_mask)
        n_series = [self._signal_shape(getattr(self, name))[0] for name in components]
        if any(n > 1 for n in n_series) and any(n != u_sum for n in n_series):
            raise ValueError(
                "The number of time series in source.ux (etc) must match the number of source elements in source.u_mask."
            )

    def _validate_stress(self, kgrid) -> None:
        """Check the stress sources and their mask / mode; the remaining checks are not ported and raise."""
        # force s_mask to be given
        self._require(self.s_mask is not None, "source.s_mask must be given if a stress source is given.")

        # check mask is the correct size
        # noinspection PyTypeChecker
        if num_dim2(self.s_mask) != kgrid.dim or np.shape(self.s_mask) != kgrid.k.shape:
            raise ValueError("source.s_mask must be the same size as the computational grid.")

        # check mask is not empty
        self._require(np.sum(self.s_mask) != 0, "source.s_mask must be a binary grid with at least one element set to 1.")

        if self.s_mode is not None:
            self._require(self.s_mode in self._S_MODES, "source.s_mode must be set to ''additive'' or ''dirichlet''.")

        for name in ("sxx", "syy", "szz", "sxy", "sxz", "syz"):
            signal = getattr(self, name)
            if signal is not None and self._signal_shape(signal)[1] > kgrid.Nt:
                self._warn_too_many_time_points(name)

        # the source indexing and time-series-count checks against s_mask have not been ported
        raise NotImplementedError("Validation of stress sources (source.sxx etc.) is not implemented.")

    def _check_frequency_ref(self, frequency_ref, kgrid, prefix: str) -> None:
        """Check that ``frequency_ref`` is a positive scalar not above the maximum frequency supported by the grid."""
        # check frequency is a scalar, positive number
        self._require(
            np.isscalar(frequency_ref) and frequency_ref > 0,
            f"source.{prefix}_frequency_ref must be a positive scalar.",
        )
        # NOTE(review): kSource declares no `medium` field; callers must attach one
        # (e.g. `source.medium = medium`) before using a reference frequency.
        medium = getattr(self, "medium", None)
        if medium is None:
            raise AttributeError(f"source.medium must be set to check source.{prefix}_frequency_ref.")
        # maximum supported frequency: f_max = k_max * c_min / (2 * pi)
        f_max = kgrid.k_max_all * np.min(medium.sound_speed) / (2 * np.pi)
        self._require(
            frequency_ref <= f_max,
            f"source.{prefix}_frequency_ref is higher than the maximum frequency supported by the spatial grid.",
        )

    @staticmethod
    def _require(condition, message: str) -> None:
        """
        Raise ``AssertionError(message)`` when ``condition`` is false.

        Used instead of ``assert`` so the checks survive ``python -O`` while
        callers still see the same exception type.
        """
        if not condition:
            raise AssertionError(message)

    @staticmethod
    def _signal_shape(signal):
        """
        Return ``(number of time series, number of time points)`` of a source signal.

        Signals hold one time series per row; a 1D signal is a single time series.
        """
        rows = np.atleast_2d(signal)
        return rows.shape[0], rows.shape[1]

    @staticmethod
    def _is_binary_mask(unique_values) -> bool:
        """Return True if the sorted distinct mask values describe a binary mask (e.g. ``{0, 1}`` or ``{1}``)."""
        return unique_values.size <= 2 and unique_values.sum() == 1

    @staticmethod
    def _labels_are_valid(unique_values) -> bool:
        """Return True if the sorted distinct labels satisfy ``max - min == count - 1`` and include the label 1."""
        return bool(np.sum(np.diff(unique_values)) == unique_values.size - 1 and np.any(unique_values == 1))

    @staticmethod
    def _warn_too_many_time_points(name: str) -> None:
        """Log that the source signal ``name`` is longer than ``kgrid.Nt``."""
        logging.log(logging.WARN, f" source.{name} has more time points than kgrid.Nt, remaining time points will not be used.")

    @property
    def flag_ux(self):
        """
        Get the number of time points of the source in the x-direction, this allows the
        inputs to be defined independently and be of any length

        Returns:
            Number of time points in ``ux`` (0 when ``ux`` is not set)
        """
        return self._signal_shape(self.ux)[1] if self.ux is not None else 0

    @property
    def flag_uy(self):
        """
        Get the number of time points of the source in the y-direction, this allows the
        inputs to be defined independently and be of any length

        Returns:
            Number of time points in ``uy`` (0 when ``uy`` is not set)
        """
        return self._signal_shape(self.uy)[1] if self.uy is not None else 0

    @property
    def flag_uz(self):
        """
        Get the number of time points of the source in the z-direction, this allows the
        inputs to be defined independently and be of any length

        Returns:
            Number of time points in ``uz`` (0 when ``uz`` is not set)
        """
        return self._signal_shape(self.uz)[1] if self.uz is not None else 0