Coverage for kwave/ksource.py: 19%

125 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-24 12:06 -0700

1import logging 

2from dataclasses import dataclass 

3 

4import numpy as np 

5 

6from kwave.kgrid import kWaveGrid 

7from kwave.utils.matrix import num_dim2 

8 

9 

@dataclass
class kSource(object):
    """
    Container for the source terms of a k-Wave simulation.

    Holds an optional initial pressure distribution (`p0`) plus optional
    time varying pressure, particle velocity and stress sources together with
    their masks and modes. `validate` checks the fields against a grid.
    """

    # NOTE: the attributes below are deliberately left without type
    # annotations. `@dataclass` only turns *annotated* names into fields, so
    # annotating them would change the generated `__init__` / `__eq__`.
    _p0 = None
    #: time varying pressure at each of the source positions given by source.p_mask
    p = None
    #: binary matrix specifying the positions of the time varying pressure source distribution
    p_mask = None
    #: optional input to control whether the input pressure is injected as a mass source or enforced
    # as a dirichlet boundary condition; valid inputs are 'additive' (the default) or 'dirichlet'
    p_mode = None
    #: Pressure reference frequency
    p_frequency_ref = None

    #: time varying particle velocity in the x-direction at each of the source positions given by source.u_mask
    ux = None
    #: time varying particle velocity in the y-direction at each of the source positions given by source.u_mask
    uy = None
    #: time varying particle velocity in the z-direction at each of the source positions given by source.u_mask
    uz = None
    #: binary matrix specifying the positions of the time varying particle velocity distribution
    u_mask = None
    #: optional input to control whether the input velocity is applied as a force source or enforced as a dirichlet
    # boundary condition; valid inputs are 'additive' (the default) or 'dirichlet'
    u_mode = None
    #: Velocity reference frequency
    u_frequency_ref = None

    sxx = None  #: Stress source in x -> x direction
    syy = None  #: Stress source in y -> y direction
    szz = None  #: Stress source in z -> z direction
    sxy = None  #: Stress source in x -> y direction
    sxz = None  #: Stress source in x -> z direction
    syz = None  #: Stress source in y -> z direction
    s_mask = None  #: Stress source mask
    s_mode = None  #: Stress source mode

    def is_p0_empty(self) -> bool:
        """
        Check whether the initial pressure is missing or has no effect.

        Returns:
            True if `p0` is unset, has zero length, or contains only zeros;
            False otherwise.
        """
        return self.p0 is None or len(self.p0) == 0 or (np.sum(self.p0 != 0) == 0)

    @property
    def p0(self):
        """
        Initial pressure within the acoustic medium
        """
        return self._p0

    @p0.setter
    def p0(self, val):
        # An empty initial pressure removes the field. `None` is accepted as
        # well so that callers can clear the field explicitly.
        if val is None or len(val) == 0:
            self._p0 = None
        else:
            self._p0 = val

    def validate(self, kgrid: "kWaveGrid") -> None:
        """
        Validate the object fields for correctness

        Args:
            kgrid: Instance of `~kwave.kgrid.kWaveGrid` class

        Returns:
            None

        Raises:
            ValueError: if a mask, `p0`, or the number of time series does not
                match the computational grid / mask.
            AssertionError: if a required mask is missing or empty, a mode is
                invalid, a reference frequency is invalid, or `p0` and `p`
                are both defined.
            NotImplementedError: for a labelled `u_mask` and for any stress
                source input.
        """
        self._validate_p0(kgrid)
        self._validate_pressure(kgrid)
        self._validate_velocity(kgrid)
        self._validate_stress(kgrid)

    def _validate_p0(self, kgrid) -> None:
        """Check that the initial pressure, if given, matches the grid size."""
        if self.p0 is not None and self.p0.shape != kgrid.k.shape:
            # throw an error if p0 is not the correct size
            raise ValueError("source.p0 must be the same size as the computational grid.")

    def _validate_pressure(self, kgrid) -> None:
        """Check the time varying pressure source input."""
        if self.p is None:
            return

        # force p_mask to be given if p is given
        assert self.p_mask is not None

        # check mask is the correct size
        # noinspection PyTypeChecker
        if (num_dim2(self.p_mask) != kgrid.dim) or (self.p_mask.shape != kgrid.k.shape):
            raise ValueError("source.p_mask must be the same size as the computational grid.")

        # check mask is not empty
        assert np.sum(self.p_mask) != 0, "source.p_mask must be a binary grid with at least one element set to 1."

        # don't allow both source.p0 and source.p in the same simulation
        # USERS: please contact us via http://www.k-wave.org/forum if this
        # is a problem
        assert self.p0 is None, "source.p0 and source.p can't be defined in the same simulation."

        # check the source mode input is valid
        if self.p_mode is not None:
            assert self.p_mode in [
                "additive",
                "dirichlet",
                "additive-no-correction",
            ], "source.p_mode must be set to ''additive'', ''additive-no-correction'', or ''dirichlet''."

        # check if a reference frequency is defined
        if self.p_frequency_ref is not None:
            # check frequency is a scalar, positive number
            assert np.isscalar(self.p_frequency_ref) and self.p_frequency_ref > 0

            # check frequency is within range
            self._assert_frequency_supported(self.p_frequency_ref, kgrid, "p_frequency_ref")

            # change source mode to not include k-space correction
            self.p_mode = "additive-no-correction"

        self._warn_if_truncated("p", len(self.p[0]), kgrid)

        # check if the mask is binary or labelled
        p_unique = np.unique(self.p_mask)

        if p_unique.size <= 2 and p_unique.sum() == 1:
            # binary mask: if more than one time series is given, the number
            # of time series must match the number of source elements
            if self.p.shape[0] > 1 and (len(self.p[:, 0]) != self.p_mask.sum()):
                raise ValueError("The number of time series in source.p " "must match the number of source elements in source.p_mask.")
        else:
            # labelled mask: the labels must be contiguous and include 1
            if (np.diff(p_unique).sum() != len(p_unique) - 1) or (not np.any(p_unique == 1)):
                raise ValueError(
                    "If using a labelled source.p_mask, " "the source labels must be monotonically increasing and start from 1."
                )
            # One time series per label. Time series are stored along axis 0
            # (axis 1 is time, as used by the length and binary-mask checks
            # above), so the count has to be taken along axis 0.
            if np.size(self.p, 0) != (np.size(p_unique) - 1):
                raise ValueError(
                    "The number of time series in source.p " "must match the number of labelled source elements in source.p_mask."
                )

    def _validate_velocity(self, kgrid) -> None:
        """Check the time varying particle velocity source input."""
        component_names = ("ux", "uy", "uz")
        if self.u_mask is None and all(getattr(self, name) is None for name in component_names):
            return

        # force u_mask to be given
        assert self.u_mask is not None

        # check mask is the correct size
        assert (
            num_dim2(self.u_mask) == kgrid.dim and self.u_mask.shape == kgrid.k.shape
        ), "source.u_mask must be the same size as the computational grid."

        # check mask is not empty
        assert np.array(self.u_mask).sum() != 0, "source.u_mask must be a binary grid with at least one element set to 1."

        # check the source mode input is valid
        if self.u_mode is not None:
            assert self.u_mode in [
                "additive",
                "dirichlet",
                "additive-no-correction",
            ], "source.u_mode must be set to ''additive'', ''additive-no-correction'', or ''dirichlet''."

        # check if a reference frequency is defined
        if self.u_frequency_ref is not None:
            # check frequency is a scalar, positive number
            assert np.isscalar(self.u_frequency_ref) and self.u_frequency_ref > 0

            # check frequency is within range
            self._assert_frequency_supported(self.u_frequency_ref, kgrid, "u_frequency_ref")

            # change source mode to not include k-space correction
            self.u_mode = "additive-no-correction"

        for name in component_names:
            if getattr(self, name) is not None:
                self._warn_if_truncated(name, getattr(self, f"flag_{name}"), kgrid)

        # check if the mask is binary or labelled
        u_unique = np.unique(self.u_mask)
        if not (u_unique.size <= 2 and u_unique.sum() == 1):
            # TODO: port the labelled-mask checks (contiguous labels starting
            # from 1, one time series per label) as done for source.p_mask.
            raise NotImplementedError("Labelled source.u_mask is not supported yet.")

        # Binary mask: if any component provides more than one time series,
        # every provided component must have one series per source element.
        # Components that are not set (or have no time points) are skipped.
        n_source_points = np.sum(self.u_mask)
        series_counts = [getattr(self, name)[:, 0].size for name in component_names if getattr(self, f"flag_{name}")]
        if any(count > 1 for count in series_counts) and any(count != n_source_points for count in series_counts):
            raise ValueError(
                "The number of time series in source.ux (etc) " "must match the number of source elements in source.u_mask."
            )

    def _validate_stress(self, kgrid) -> None:
        """Check the time varying stress source input (not supported beyond basic checks)."""
        component_names = ("sxx", "syy", "szz", "sxy", "sxz", "syz")
        if self.s_mask is None and all(getattr(self, name) is None for name in component_names):
            return

        # force s_mask to be given
        assert self.s_mask is not None, "source.s_mask must be defined when a stress source is given."

        # check mask is the correct size
        if (num_dim2(self.s_mask) != kgrid.dim) or (self.s_mask.shape != kgrid.k.shape):
            raise ValueError("source.s_mask must be the same size as the computational grid.")

        # check mask is not empty
        assert np.array(self.s_mask).sum() != 0, "source.s_mask must be a binary grid with at least one element set to 1."

        # check the source mode input is valid
        if self.s_mode is not None:
            assert self.s_mode in [
                "additive",
                "dirichlet",
            ], "source.s_mode must be set to ''additive'' or ''dirichlet''."

        for name in component_names:
            series = getattr(self, name)
            if series is not None:
                self._warn_if_truncated(name, len(series[0]), kgrid)

        # TODO: the source indexing and the binary / labelled s_mask series
        # count checks have not been ported yet.
        raise NotImplementedError("Stress sources are not supported yet.")

    def _assert_frequency_supported(self, frequency_ref, kgrid, field_name: str) -> None:
        """
        Assert that a reference frequency does not exceed the grid limit.

        The limit is `k_max * c_min / (2 * pi)`, i.e. the largest wavenumber
        of the grid converted to a temporal frequency at the slowest sound
        speed. The sound speed is read from a `medium` attribute if one has
        been attached to this object; the class does not define it, so the
        check is skipped with a warning when it is absent.
        """
        medium = getattr(self, "medium", None)
        if medium is None:
            logging.log(
                logging.WARN,
                f" source.{field_name} was not checked against the maximum frequency supported by the spatial grid "
                "because no medium is attached to the source.",
            )
            return

        max_frequency = kgrid.k_max_all * np.min(medium.sound_speed) / (2 * np.pi)
        assert (
            frequency_ref <= max_frequency
        ), f"source.{field_name} is higher than the maximum frequency supported by the spatial grid."

    @staticmethod
    def _warn_if_truncated(name: str, n_time_points, kgrid) -> None:
        """Log a warning if a source has more time points than `kgrid.Nt`."""
        if n_time_points > kgrid.Nt:
            logging.log(
                logging.WARN,
                f" source.{name} has more time points than kgrid.Nt, remaining time points will not be used.",
            )

    @property
    def flag_ux(self):
        """
        Get the length of the sources in X-direction, this allows the
        inputs to be defined independently and be of any length

        Returns:
            Length of the sources (0 if `ux` is not set)
        """
        return len(self.ux[0]) if self.ux is not None else 0

    @property
    def flag_uy(self):
        """
        Get the length of the sources in Y-direction, this allows the
        inputs to be defined independently and be of any length

        Returns:
            Length of the sources (0 if `uy` is not set)
        """
        return len(self.uy[0]) if self.uy is not None else 0

    @property
    def flag_uz(self):
        """
        Get the length of the sources in Z-direction, this allows the
        inputs to be defined independently and be of any length

        Returns:
            Length of the sources (0 if `uz` is not set)
        """
        return len(self.uz[0]) if self.uz is not None else 0