Coverage for kwave/kWaveSimulation_helper/save_to_disk_func.py: 10%

210 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-24 12:06 -0700

1import logging 

2import os 

3 

4import numpy as np 

5from scipy.io import savemat 

6 

7from kwave.kgrid import kWaveGrid 

8from kwave.kmedium import kWaveMedium 

9from kwave.options.simulation_options import SimulationOptions 

10from kwave.utils.data import scale_time 

11from kwave.utils.dotdictionary import dotdict 

12from kwave.utils.io import write_attributes, write_matrix 

13from kwave.utils.matrix import num_dim2 

14from kwave.utils.tictoc import TicToc 

15 

16 

def save_to_disk_func(
    kgrid: kWaveGrid, medium: kWaveMedium, source, opt: SimulationOptions, auto_chunk: bool, values: dotdict, flags: dotdict
):
    """Collect the simulation inputs and write them to ``opt.input_filename``.

    Gathers the integer and single-precision variables describing the grid,
    PML, medium, source, sensor and (optionally) non-uniform grid. For 2D
    grids it forces ``Nz = 1`` / ``pml_z_size = 0`` and drops the float
    variables whose names contain "z", then hands everything to ``save_file``.

    Args:
        kgrid: simulation grid.
        medium: medium whose nonlinearity/absorption properties are stored.
        source: source object whose fields are copied into the input file.
        opt: simulation options (PML sizes/alphas, ``input_filename``,
            ``hdf_compression_level``).
        auto_chunk: forwarded to ``save_file`` (used by the HDF5 writer).
        values: precomputed quantities (``dt``, densities, source/sensor
            indices, transducer signal, ...).
        flags: switches describing which features are active.
    """
    # update command line status; the elapsed time goes through a %s
    # placeholder (passing it without one makes logging report
    # "not all arguments converted during string formatting")
    logging.log(logging.INFO, " precomputation completed in %s", scale_time(TicToc.toc()))
    TicToc.tic()
    logging.log(logging.INFO, " saving input files to disk...")

    # check for a binary sensor mask or cuboid corners
    # modified by Farid | disabled temporarily!
    # assert self.binary_sensor_mask or self.cuboid_corners, \
    #     "Optional input ''save_to_disk'' only supported for sensor masks defined as a binary matrix
    #     or the opposing corners of a rectangle (2D) or cuboid (3D)."

    # =========================================================================
    # VARIABLE LIST
    # =========================================================================
    integer_variables = dotdict()
    float_variables = dotdict()

    grab_integer_variables(integer_variables, kgrid, flags, medium)
    grab_pml_size(integer_variables, opt)
    grab_float_variables(float_variables, kgrid, opt, values, flags.elastic_code, flags.axisymmetric)

    # overwrite z-values for 2D simulations
    if kgrid.dim == 2:
        integer_variables.Nz = 1
        integer_variables.pml_z_size = 0

    grab_medium_props(integer_variables, float_variables, medium, flags.elastic_code)
    grab_source_props(
        integer_variables,
        float_variables,
        source,
        values.u_source_pos_index,
        values.s_source_pos_index,
        values.p_source_pos_index,
        values.transducer_input_signal,
        values.delay_mask,
    )

    grab_sensor_props(integer_variables, kgrid.dim, values.sensor_mask_index, values.record.cuboid_corners_list)
    grab_nonuniform_grid_props(float_variables, kgrid, flags.nonuniform_grid)

    # =========================================================================
    # DATACAST AND SAVING
    # =========================================================================

    remove_z_dimension(float_variables, kgrid.dim)
    save_file(opt.input_filename, integer_variables, float_variables, opt.hdf_compression_level, auto_chunk=auto_chunk)

    # update command line status
    logging.log(logging.INFO, " completed in %s", scale_time(TicToc.toc()))

70 

71 

def grab_integer_variables(integer_variables, kgrid, flags, medium):
    """Add the integer variables used inside the time loop to ``integer_variables``.

    Stores the grid size (``Nx``, ``Ny``, ``Nz``), the number of time steps
    (``Nt``) and one flag per source type / solver feature. ``absorbing_flag``
    is only reserved here (set to ``None``); ``grab_medium_props`` assigns it.
    ``sensor_mask_type`` mirrors ``flags.cuboid_corners``:
    0 -> binary mask indices, 1 -> cuboid corners.
    """
    entries = {}

    # grid size and number of time steps, read straight off the grid
    for size_name in ("Nx", "Ny", "Nz", "Nt"):
        entries[size_name] = getattr(kgrid, size_name)

    # one "<kind>_source_flag" per source kind, taken from flags.source_<kind>
    for kind in ("p", "p0", "ux", "uy", "uz", "sxx", "syy", "szz", "sxy", "sxz", "syz"):
        entries[kind + "_source_flag"] = getattr(flags, "source_" + kind)

    entries["transducer_source_flag"] = flags.transducer_source
    entries["nonuniform_grid_flag"] = flags.nonuniform_grid
    entries["nonlinear_flag"] = medium.is_nonlinear()
    entries["absorbing_flag"] = None
    entries["elastic_flag"] = flags.elastic_code
    entries["axisymmetric_flag"] = flags.axisymmetric
    # pseudonym for the sensor flags (0: binary mask indices, 1: cuboid corners)
    entries["sensor_mask_type"] = flags.cuboid_corners

    integer_variables.update(entries)

105 

106 

def grab_pml_size(integer_variables, opt):
    """Copy the PML size along x, y and z from ``opt`` into ``integer_variables``.

    These integers are not used within the time loop; they are stored
    directly in the output file.
    """
    for axis in ("x", "y", "z"):
        key = f"pml_{axis}_size"
        integer_variables[key] = getattr(opt, key)

112 

113 

def grab_float_variables(float_variables: dotdict, kgrid, opt, values, is_elastic_code, is_axisymmetric):
    """Add the single-precision variables to ``float_variables``.

    The grid spacing (``dx``, ``dy``, ``dz``) and the PML alphas are always
    stored. The time-loop variables depend on the solver. The elastic and
    axisymmetric codes delegate to their own helpers. Otherwise ``dt``,
    ``c0``, ``c_ref``, ``rho0`` and ``rho0_sgx``/``rho0_sgy``/``rho0_sgz``
    are copied from ``values``.
    """
    # not used within the time loop but stored directly in the output file
    spacing_and_pml = {
        "dx": kgrid.dx,
        "dy": kgrid.dy,
        "dz": kgrid.dz,
        "pml_x_alpha": opt.pml_x_alpha,
        "pml_y_alpha": opt.pml_y_alpha,
        "pml_z_alpha": opt.pml_z_alpha,
    }
    float_variables.update(spacing_and_pml)

    if is_elastic_code:  # pragma: no cover
        grab_elastic_code_variables(float_variables, kgrid, values)
        return
    if is_axisymmetric:
        grab_axisymmetric_variables(float_variables, values)
        return

    # single precision variables used within the time loop
    for name in ("dt", "c0", "c_ref", "rho0", "rho0_sgx", "rho0_sgy", "rho0_sgz"):
        float_variables[name] = getattr(values, name)

142 

143 

def grab_elastic_code_variables(float_variables, kgrid, values):  # pragma: no cover
    """Register the float variables for the elastic code in ``float_variables``.

    Most entries are ``None`` placeholders. Only the x derivative shift
    operators and the x/y/z shift operators are computed, and those are
    reduced to their real-to-complex FFT length.

    NOTE(review): no code here fills in the ``None`` placeholders, and the
    function is excluded from coverage. The elastic save-to-disk path is
    presumably not fully ported yet; confirm before relying on it.

    Args:
        float_variables: mapping updated in place.
        kgrid: grid providing ``k_vec``, ``dx``/``dy``/``dz``,
            ``Nx``/``Ny``/``Nz`` and ``dim``.
        values: provides ``ddx_k_shift_pos`` and ``ddx_k_shift_neg``.
    """
    # single precision variables used within the time loop (placeholders)
    float_variables["dt"] = None
    float_variables["c_ref"] = None
    float_variables["lambda"] = None
    float_variables["mu"] = None

    float_variables["rho0_sgx"] = None
    float_variables["rho0_sgy"] = None
    float_variables["rho0_sgz"] = None

    float_variables["mu_sgxy"] = None
    float_variables["mu_sgxz"] = None
    float_variables["mu_sgyz"] = None

    # create shift variables used for calculating u_non_staggered and I outputs
    # (each is exp(-1j * k * d / 2) along one axis, reordered with ifftshift)
    x_shift_neg = np.fft.ifftshift(np.exp(-1j * kgrid.k_vec.x * kgrid.dx / 2))
    y_shift_neg = np.fft.ifftshift(np.exp(-1j * kgrid.k_vec.y * kgrid.dy / 2)).T
    # NOTE(review): the (1, 2, 0) transpose requires a 3D array; behaviour for
    # 2D grids (where k_vec.z may not be 3D) is not visible here -- confirm.
    z_shift_neg = np.transpose(np.fft.ifftshift(np.exp(-1j * kgrid.k_vec.z * kgrid.dz / 2)), (1, 2, 0))

    # create reduced variables for use with real-to-complex FFT
    # (N // 2 + 1 entries per axis; a 2D grid is treated as Nz = 1)
    Nz = kgrid.Nz if kgrid.dim != 2 else 1
    Nx_r = kgrid.Nx // 2 + 1
    Ny_r = kgrid.Ny // 2 + 1
    Nz_r = Nz // 2 + 1

    ddx_k_shift_pos = values.ddx_k_shift_pos
    ddx_k_shift_neg = values.ddx_k_shift_neg

    # only the x operators are reduced; y/z derivative operators are placeholders
    float_variables["ddx_k_shift_pos_r"] = ddx_k_shift_pos[:Nx_r]
    float_variables["ddy_k_shift_pos"] = None
    float_variables["ddz_k_shift_pos"] = None

    float_variables["ddx_k_shift_neg_r"] = ddx_k_shift_neg[:Nx_r]
    float_variables["ddy_k_shift_neg"] = None
    float_variables["ddz_k_shift_neg"] = None

    # NOTE(review): every slice below acts on the array's FIRST axis. y_shift_neg
    # was transposed and z_shift_neg permuted first, so whether [:Ny_r] / [:Nz_r]
    # reduce the intended axis depends on the shapes of k_vec -- verify.
    # NOTE(review): these arrays are complex; save_h5_file casts every float
    # variable with np.array(..., dtype=np.float32), which discards the
    # imaginary part -- confirm the intended handling.
    float_variables["x_shift_neg_r"] = x_shift_neg[:Nx_r]
    float_variables["y_shift_neg_r"] = y_shift_neg[:Ny_r]
    float_variables["z_shift_neg_r"] = z_shift_neg[:Nz_r]

    # drops only the local reference to the full x shift array
    del x_shift_neg

    # PML placeholders
    float_variables["pml_x"] = None
    float_variables["pml_y"] = None
    float_variables["pml_z"] = None

    float_variables["pml_x_sgx"] = None
    float_variables["pml_y_sgy"] = None
    float_variables["pml_z_sgz"] = None

    float_variables["mpml_x_sgx"] = None
    float_variables["mpml_y_sgy"] = None
    float_variables["mpml_z_sgz"] = None

    float_variables["mpml_x"] = None
    float_variables["mpml_y"] = None
    float_variables["mpml_z"] = None

202 

203 

def grab_axisymmetric_variables(float_variables, values):
    """Copy the axisymmetric-code time-loop variables from ``values``.

    Writes ``dt``, ``c0``, ``c_ref``, ``rho0``, ``rho0_sgx`` and ``rho0_sgy``
    (in that order); no ``rho0_sgz`` entry is produced.
    """
    for name in ("dt", "c0", "c_ref", "rho0", "rho0_sgx", "rho0_sgy"):
        float_variables[name] = getattr(values, name)

211 float_variables["rho0_sgy"] = values.rho0_sgy 

212 

213 

def grab_medium_props(integer_variables, float_variables, medium, is_elastic_code):
    """Record the nonlinearity and absorption properties of ``medium``.

    * ``BonA`` is stored only when ``medium.is_nonlinear()`` is true.
    * ``integer_variables.absorbing_flag`` becomes 0 for a non-absorbing medium,
      2 when absorbing and ``medium.stokes`` is truthy, and 1 otherwise.
    * For an absorbing medium, the elastic code reserves ``chi``/``eta``/
      ``eta_sg*`` placeholders (``None``). Other codes store ``alpha_coeff``
      and ``alpha_power``.
    """
    # variables used in nonlinear simulations
    if medium.is_nonlinear():
        float_variables["BonA"] = medium.BonA

    # variables used in absorbing simulations
    if not medium.absorbing:
        integer_variables.absorbing_flag = 0
        return

    integer_variables.absorbing_flag = 2 if medium.stokes else 1

    if is_elastic_code:  # pragma: no cover
        for placeholder in ("chi", "eta", "eta_sgxy", "eta_sgxz", "eta_sgyz"):
            float_variables[placeholder] = None
    else:
        float_variables["alpha_coeff"] = medium.alpha_coeff
        float_variables["alpha_power"] = medium.alpha_power

242 

243 

def grab_source_props(
    integer_variables,
    float_variables,
    source,
    u_source_pos_index,
    s_source_pos_index,
    p_source_pos_index,
    transducer_input_signal,
    delay_mask,
):
    """Collect every source-related variable into the two containers.

    Source modes and indices are only defined when the matching source flag
    is set. The mode says whether the source is added or replaced, and the
    index lists the grid points that act as the source. ``u_source_index``
    is shared by the velocity sources and the transducer source.
    """
    # modes, multiplicity and indices of each source kind
    grab_velocity_source_props(
        integer_variables=integer_variables,
        source=source,
        u_source_pos_index=u_source_pos_index,
    )
    grab_stress_source_props(
        integer_variables=integer_variables,
        source=source,
        s_source_pos_index=s_source_pos_index,
    )
    grab_pressure_source_props(
        integer_variables=integer_variables,
        source=source,
        p_source_pos_index=p_source_pos_index,
        u_source_pos_index=u_source_pos_index,
    )
    # the actual source values
    grab_time_varying_source_props(
        integer_variables=integer_variables,
        float_variables=float_variables,
        source=source,
        transducer_input_signal=transducer_input_signal,
        delay_mask=delay_mask,
    )

267 

268 

def grab_velocity_source_props(integer_variables, source, u_source_pos_index):
    """Record mode, multiplicity and indices of a velocity source.

    Does nothing unless one of ``ux/uy/uz_source_flag`` is set. The mode is
    encoded as 0 ("dirichlet"), 1 ("additive-no-correction") or
    2 ("additive"). ``u_source_many`` is ``num_dim2(...) > 1`` evaluated on
    the first active component, checked in the order x, y, z.
    """
    flag_names = ("ux_source_flag", "uy_source_flag", "uz_source_flag")
    if not any(integer_variables.get(name) for name in flag_names):
        return

    mode_codes = {"dirichlet": 0, "additive-no-correction": 1, "additive": 2}
    integer_variables["u_source_mode"] = mode_codes[source.u_mode]

    for component in ("ux", "uy", "uz"):
        if getattr(integer_variables, component + "_source_flag"):
            many = num_dim2(getattr(source, component)) > 1
            break
    integer_variables["u_source_many"] = many

    integer_variables.u_source_index = u_source_pos_index

287 

288 

def grab_stress_source_props(integer_variables, source, s_source_pos_index):
    """Record mode, multiplicity and indices of a stress source.

    Does nothing unless one of the six ``s??_source_flag`` entries is set.
    ``s_source_mode`` is stored as the boolean ``source.s_mode != "dirichlet"``.
    ``s_source_many`` is ``num_dim2(...) > 1`` evaluated on the first active
    component, checked in the order xx, yy, zz, xy, xz, yz.
    """
    components = ("sxx", "syy", "szz", "sxy", "sxz", "syz")
    first_active = next(
        (c for c in components if getattr(integer_variables, c + "_source_flag")),
        None,
    )
    if first_active is None:
        return

    integer_variables.s_source_mode = source.s_mode != "dirichlet"
    integer_variables.s_source_many = num_dim2(getattr(source, first_active)) > 1
    integer_variables.s_source_index = s_source_pos_index

314 

315 

def grab_pressure_source_props(integer_variables, source, p_source_pos_index, u_source_pos_index):
    """Record pressure-source settings and the transducer source index.

    With ``p_source_flag`` set, stores ``p_source_mode`` (0 "dirichlet",
    1 "additive-no-correction", 2 "additive"), ``p_source_many``
    (``num_dim2(source.p) > 1``) and ``p_source_index``. With
    ``transducer_source_flag`` set, the transducer reuses ``u_source_index``.
    """
    pressure_active = integer_variables.p_source_flag
    if pressure_active:
        mode_codes = {"dirichlet": 0, "additive-no-correction": 1, "additive": 2}
        integer_variables.p_source_mode = mode_codes[source.p_mode]
        integer_variables.p_source_many = num_dim2(source.p) > 1
        integer_variables.p_source_index = p_source_pos_index

    transducer_active = integer_variables.transducer_source_flag
    if transducer_active:
        integer_variables.u_source_index = u_source_pos_index

330 

331 

def grab_time_varying_source_props(integer_variables, float_variables, source, transducer_input_signal, delay_mask):
    """Copy the actual source values for every active source flag.

    For each velocity/stress/pressure kind ``k`` whose ``k_source_flag`` is
    set, ``source.k`` is stored as ``k_source_input``. These values are
    indexed as (position_index, time_index). A transducer source also stores
    its input signal and ``delay_mask``. The initial pressure ``p0`` is
    defined everywhere on the grid and therefore has no indices.
    """
    for kind in ("ux", "uy", "uz", "sxx", "syy", "szz", "sxy", "sxz", "syz", "p"):
        if getattr(integer_variables, kind + "_source_flag"):
            setattr(float_variables, kind + "_source_input", getattr(source, kind))

    if integer_variables.transducer_source_flag:
        float_variables.transducer_source_input = transducer_input_signal
        integer_variables.delay_mask = delay_mask

    if integer_variables.p0_source_flag:
        float_variables.p0_source_input = source.p0

376 

377 

def grab_sensor_props(integer_variables, kgrid_dim, sensor_mask_index, cuboid_corners_list):
    """Store the sensor definition in ``integer_variables``.

    ``integer_variables.sensor_mask_type`` selects the representation:

    * 0 -> the mask is a list of grid indices (``sensor_mask_index``).
    * 1 -> the mask is a list of cuboid corners. In 2D, the first four rows
      of ``cuboid_corners_list`` are copied into rows 0, 1, 3, 4 of a 6-row
      array of ones, leaving rows 2 and 5 (presumably the z coordinates) at 1.
      Otherwise the list is stored as is.

    Raises:
        NotImplementedError: for any other ``sensor_mask_type``.
    """
    mask_type = integer_variables.sensor_mask_type

    if mask_type == 0:
        # mask is defined as a list of grid indices
        integer_variables.sensor_mask_index = sensor_mask_index

    elif mask_type == 1:
        # mask is defined as a list of cuboid corners
        if kgrid_dim == 2:
            sensor_mask_corners = np.ones((6, cuboid_corners_list.shape[1]))
            sensor_mask_corners[[0, 1, 3, 4], :] = cuboid_corners_list[[0, 1, 2, 3], :]
        else:
            sensor_mask_corners = cuboid_corners_list
        integer_variables.sensor_mask_corners = sensor_mask_corners

    else:
        raise NotImplementedError("unknown option for sensor_mask_type")

402 

403 

def _ones_if_scalar(value, shape):
    """Return ``np.ones(shape)`` when ``value`` has exactly one element, else ``value`` unchanged."""
    return np.ones(shape) if np.array(value).size == 1 else value


def grab_nonuniform_grid_props(float_variables, kgrid, is_nonuniform_grid):
    """Store the non-uniform grid stretching factors in ``float_variables``.

    Nothing is stored unless ``is_nonuniform_grid`` is truthy. Each factor
    (regular grid ``kgrid.dudn`` and staggered grid ``kgrid.dudn_sg``) is
    copied as is, except that a single-element value is replaced by ones of
    shape ``(Nx, 1)``, ``(1, Ny)`` or ``(1, 1, Nz)`` for its axis.

    Each staggered-grid factor is now checked against its OWN size. The
    previous code tested the regular-grid factor instead, so a scalar
    staggered factor was never expanded.
    """
    if not is_nonuniform_grid:
        return

    float_variables["dxudxn"] = _ones_if_scalar(kgrid.dudn.x, (kgrid.Nx, 1))
    float_variables["dyudyn"] = _ones_if_scalar(kgrid.dudn.y, (1, kgrid.Ny))
    float_variables["dzudzn"] = _ones_if_scalar(kgrid.dudn.z, (1, 1, kgrid.Nz))

    float_variables["dxudxn_sgx"] = _ones_if_scalar(kgrid.dudn_sg.x, (kgrid.Nx, 1))
    float_variables["dyudyn_sgy"] = _ones_if_scalar(kgrid.dudn_sg.y, (1, kgrid.Ny))
    float_variables["dzudzn_sgz"] = _ones_if_scalar(kgrid.dudn_sg.z, (1, 1, kgrid.Nz))

444 

445 

def remove_z_dimension(float_variables, kgrid_dim):
    """For a 2D grid, delete every float variable whose name contains "z".

    The mapping is modified in place; for any other dimension it is left alone.
    """
    if kgrid_dim != 2:
        return
    z_keys = [key for key in float_variables if "z" in key]
    for key in z_keys:
        del float_variables[key]

452 

453 

def enforce_filename_standards(filepath):
    """Default ``filepath`` to an HDF5 file when it has no extension.

    Returns:
        tuple: ``(filepath, extension)``. When ``filepath`` has no extension,
        ".h5" is appended and returned as the extension; otherwise the path
        and its existing extension are returned unchanged.
    """
    _, extension = os.path.splitext(filepath)
    if extension:
        return filepath, extension
    return filepath + ".h5", ".h5"

463 

464 

def save_file(filepath, integer_variables, float_variables, hdf_compression_level, auto_chunk):
    """Write the simulation inputs to ``filepath`` as HDF5 or MATLAB .mat.

    A path without an extension gets ".h5" appended. ``hdf_compression_level``
    and ``auto_chunk`` are only used for HDF5 output.

    Raises:
        NotImplementedError: for any extension other than ".h5" or ".mat".
    """
    filepath, filename_ext = enforce_filename_standards(filepath)

    if filename_ext == ".h5":
        save_h5_file(filepath, integer_variables, float_variables, hdf_compression_level, auto_chunk)
    elif filename_ext == ".mat":
        save_mat_file(filepath, integer_variables, float_variables)
    else:
        # name the rejected extension; the old message was glued together from
        # adjacent literals (a leftover of MATLAB '' quoting) and omitted it
        raise NotImplementedError(f"unknown file extension '{filename_ext}' for 'save_to_disk' filename")

477 

478 

def save_h5_file(filepath, integer_variables, float_variables, hdf_compression_level, auto_chunk):
    """Write all variables to a fresh HDF5 file at ``filepath``.

    Any existing file is deleted first (the HDF5 library would otherwise give
    an error). Float variables are cast to single precision (``float`` in C++)
    and written first. Integer variables are then cast to 64-bit unsigned
    integers (``long`` in C++). Finally the file attributes are written.
    """
    if os.path.exists(filepath):
        os.remove(filepath)

    for variables, dtype in ((float_variables, np.float32), (integer_variables, np.uint64)):
        for name, data in variables.items():
            write_matrix(filepath, np.array(data, dtype=dtype), name, hdf_compression_level, auto_chunk)

    write_attributes(filepath)

507 

508 

def save_mat_file(filepath, integer_variables, float_variables):
    """Write all variables to ``filepath`` as a MATLAB binary (.mat) file.

    Float variables are stored in single precision (float32) and integer
    variables as 64-bit unsigned integers, floats first. The casts are
    applied to a new mapping, so the caller's dictionaries are left
    untouched. Previously their values were overwritten in place.

    Raises:
        ValueError: if the same name appears in both dictionaries.
    """
    contents = {name: np.array(value, dtype=np.float32) for name, value in float_variables.items()}

    # a shared name would otherwise be ambiguous in the output file
    clashing = contents.keys() & integer_variables.keys()
    if clashing:
        raise ValueError(f"variables defined as both float and integer: {sorted(clashing)}")

    for name, value in integer_variables.items():
        contents[name] = np.array(value, dtype=np.uint64)

    # save the input variables to disk as a MATLAB binary file
    savemat(filepath, contents)