Coverage for kwave/kWaveSimulation_helper/save_to_disk_func.py: 10%
210 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-24 12:06 -0700
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-24 12:06 -0700
1import logging
2import os
4import numpy as np
5from scipy.io import savemat
7from kwave.kgrid import kWaveGrid
8from kwave.kmedium import kWaveMedium
9from kwave.options.simulation_options import SimulationOptions
10from kwave.utils.data import scale_time
11from kwave.utils.dotdictionary import dotdict
12from kwave.utils.io import write_attributes, write_matrix
13from kwave.utils.matrix import num_dim2
14from kwave.utils.tictoc import TicToc
def save_to_disk_func(
    kgrid: kWaveGrid, medium: kWaveMedium, source, opt: SimulationOptions, auto_chunk: bool, values: dotdict, flags: dotdict
):
    """Gather every simulation input and save it to ``opt.input_filename``.

    Inputs are split into integer variables (grid sizes, flags, indices) and float variables
    (spacings, material properties, source signals). The savers cast them to ``np.uint64`` and
    ``np.float32`` respectively and write an HDF5 (``.h5``, the default when no extension is
    given) or MATLAB (``.mat``) file.

    Args:
        kgrid: simulation grid.
        medium: medium properties (nonlinearity and absorption settings are read from it).
        source: source definition; a source field is saved only when its source flag is set.
        opt: simulation options (PML sizes/alphas, ``input_filename``, ``hdf_compression_level``).
        auto_chunk: forwarded to the HDF5 writer.
        values: precomputed quantities (``dt``, ``c0``, source/sensor indices, ...).
        flags: precomputed simulation flags (source flags, ``elastic_code``, ``axisymmetric``, ...).
    """
    # update command line status
    # The elapsed time is passed as a lazy %-style argument, so the message needs a "%s"
    # placeholder; without it logging fails with "not all arguments converted".
    logging.log(logging.INFO, " precomputation completed in %s", scale_time(TicToc.toc()))
    TicToc.tic()
    logging.log(logging.INFO, " saving input files to disk...")

    # NOTE: the check for a binary sensor mask or cuboid corners is disabled temporarily
    # (modified by Farid). Original check:
    #   assert self.binary_sensor_mask or self.cuboid_corners, \
    #       "Optional input ''save_to_disk'' only supported for sensor masks defined as a binary matrix
    #        or the opposing corners of a rectangle (2D) or cuboid (3D)."

    # =========================================================================
    # VARIABLE LIST
    # =========================================================================
    integer_variables = dotdict()
    float_variables = dotdict()

    grab_integer_variables(integer_variables, kgrid, flags, medium)
    grab_pml_size(integer_variables, opt)
    grab_float_variables(float_variables, kgrid, opt, values, flags.elastic_code, flags.axisymmetric)

    # overwrite z-values for 2D simulations
    if kgrid.dim == 2:
        integer_variables.Nz = 1
        integer_variables.pml_z_size = 0

    grab_medium_props(integer_variables, float_variables, medium, flags.elastic_code)
    grab_source_props(
        integer_variables,
        float_variables,
        source,
        values.u_source_pos_index,
        values.s_source_pos_index,
        values.p_source_pos_index,
        values.transducer_input_signal,
        values.delay_mask,
    )

    grab_sensor_props(integer_variables, kgrid.dim, values.sensor_mask_index, values.record.cuboid_corners_list)
    grab_nonuniform_grid_props(float_variables, kgrid, flags.nonuniform_grid)

    # =========================================================================
    # DATACAST AND SAVING
    # =========================================================================
    remove_z_dimension(float_variables, kgrid.dim)
    save_file(opt.input_filename, integer_variables, float_variables, opt.hdf_compression_level, auto_chunk=auto_chunk)

    # update command line status
    logging.log(logging.INFO, " completed in %s", scale_time(TicToc.toc()))
def grab_integer_variables(integer_variables, kgrid, flags, medium):
    """Add grid sizes and simulation flags to ``integer_variables`` (in place).

    These integers are used within the time loop for all codes. ``absorbing_flag`` is only
    reserved here (``None``); ``grab_medium_props`` sets its real value. ``sensor_mask_type``
    mirrors ``flags.cuboid_corners``: 0 means binary mask indices, 1 means cuboid corners.
    """
    grid_sizes = {dim_name: getattr(kgrid, dim_name) for dim_name in ("Nx", "Ny", "Nz", "Nt")}

    source_fields = ("p", "p0", "ux", "uy", "uz", "sxx", "syy", "szz", "sxy", "sxz", "syz")
    source_flags = {f"{field}_source_flag": getattr(flags, f"source_{field}") for field in source_fields}

    entries = {**grid_sizes, **source_flags}
    entries["transducer_source_flag"] = flags.transducer_source
    entries["nonuniform_grid_flag"] = flags.nonuniform_grid
    entries["nonlinear_flag"] = medium.is_nonlinear()
    entries["absorbing_flag"] = None
    entries["elastic_flag"] = flags.elastic_code
    entries["axisymmetric_flag"] = flags.axisymmetric
    # pseudonym for the sensor flags -> 0: binary mask indices, 1: cuboid corners
    entries["sensor_mask_type"] = flags.cuboid_corners

    integer_variables.update(entries)
def grab_pml_size(integer_variables, opt):
    """Copy the PML thickness for each axis from ``opt`` into ``integer_variables``.

    These values are not used within the time loop; they are stored directly in the output file.
    """
    for axis in "xyz":
        option_name = f"pml_{axis}_size"
        integer_variables[option_name] = getattr(opt, option_name)
def grab_float_variables(float_variables: dotdict, kgrid, opt, values, is_elastic_code, is_axisymmetric):
    """Add grid spacings, PML absorption and time-loop material variables to ``float_variables``.

    The spacings (``dx``, ``dy``, ``dz``) and PML alphas are written for every code. What follows
    depends on the solver. The elastic and axisymmetric codes delegate to their own helpers.
    Otherwise ``dt``, ``c0``, ``c_ref``, ``rho0`` and the three staggered-grid densities are
    copied from ``values``.
    """
    # single precision variables not used within the time loop but stored directly in the file
    spacings = {f"d{axis}": getattr(kgrid, f"d{axis}") for axis in "xyz"}
    pml_alphas = {f"pml_{axis}_alpha": getattr(opt, f"pml_{axis}_alpha") for axis in "xyz"}
    float_variables.update({**spacings, **pml_alphas})

    if is_elastic_code:  # pragma: no cover
        grab_elastic_code_variables(float_variables, kgrid, values)
        return

    if is_axisymmetric:
        grab_axisymmetric_variables(float_variables, values)
        return

    # single precision variables used within the time loop
    for name in ("dt", "c0", "c_ref", "rho0", "rho0_sgx", "rho0_sgy", "rho0_sgz"):
        float_variables[name] = getattr(values, name)
def grab_elastic_code_variables(float_variables, kgrid, values):  # pragma: no cover
    """Add the elastic-code variables to ``float_variables`` (excluded from coverage).

    Most entries are written as ``None`` placeholders; only the x-direction k-space shift
    operators (taken from ``values``) and the x/y/z shift operators computed here hold data.
    NOTE(review): this looks like an incomplete port — the ``None`` entries presumably still
    need real values; confirm before relying on the elastic path.
    """
    # single precision variables used within the time loop
    float_variables["dt"] = None
    float_variables["c_ref"] = None
    float_variables["lambda"] = None
    float_variables["mu"] = None

    float_variables["rho0_sgx"] = None
    float_variables["rho0_sgy"] = None
    float_variables["rho0_sgz"] = None

    float_variables["mu_sgxy"] = None
    float_variables["mu_sgxz"] = None
    float_variables["mu_sgyz"] = None

    # create shift variables used for calculating u_non_staggered and I outputs
    # NOTE(review): these are complex (np.exp(-1j * ...)); save_h5_file casts every float
    # variable to np.float32, which would drop the imaginary part — confirm intended storage.
    x_shift_neg = np.fft.ifftshift(np.exp(-1j * kgrid.k_vec.x * kgrid.dx / 2))
    y_shift_neg = np.fft.ifftshift(np.exp(-1j * kgrid.k_vec.y * kgrid.dy / 2)).T
    # NOTE(review): the (1, 2, 0) transpose requires a 3-D array; it raises if
    # kgrid.k_vec.z is 1-D — TODO confirm the expected shape of k_vec.z.
    z_shift_neg = np.transpose(np.fft.ifftshift(np.exp(-1j * kgrid.k_vec.z * kgrid.dz / 2)), (1, 2, 0))

    # create reduced variables for use with real-to-complex FFT
    # (N // 2 + 1 entries along each axis; a 2D grid is treated as Nz = 1)
    Nz = kgrid.Nz if kgrid.dim != 2 else 1
    Nx_r = kgrid.Nx // 2 + 1
    Ny_r = kgrid.Ny // 2 + 1
    Nz_r = Nz // 2 + 1

    ddx_k_shift_pos = values.ddx_k_shift_pos
    ddx_k_shift_neg = values.ddx_k_shift_neg

    # keep only the first Nx_r entries of the x-direction operators; y/z are placeholders
    float_variables["ddx_k_shift_pos_r"] = ddx_k_shift_pos[:Nx_r]
    float_variables["ddy_k_shift_pos"] = None
    float_variables["ddz_k_shift_pos"] = None

    float_variables["ddx_k_shift_neg_r"] = ddx_k_shift_neg[:Nx_r]
    float_variables["ddy_k_shift_neg"] = None
    float_variables["ddz_k_shift_neg"] = None

    # slicing is along the first axis of each shift array
    float_variables["x_shift_neg_r"] = x_shift_neg[:Nx_r]
    float_variables["y_shift_neg_r"] = y_shift_neg[:Ny_r]
    float_variables["z_shift_neg_r"] = z_shift_neg[:Nz_r]

    del x_shift_neg

    # PML placeholders (regular grid, staggered grid, and their multi-axial variants)
    float_variables["pml_x"] = None
    float_variables["pml_y"] = None
    float_variables["pml_z"] = None

    float_variables["pml_x_sgx"] = None
    float_variables["pml_y_sgy"] = None
    float_variables["pml_z_sgz"] = None

    float_variables["mpml_x_sgx"] = None
    float_variables["mpml_y_sgy"] = None
    float_variables["mpml_z_sgz"] = None

    float_variables["mpml_x"] = None
    float_variables["mpml_y"] = None
    float_variables["mpml_z"] = None
def grab_axisymmetric_variables(float_variables, values):
    """Copy the axisymmetric-code time-loop variables from ``values`` into ``float_variables``.

    Unlike the default branch of ``grab_float_variables``, no ``rho0_sgz`` entry is written.
    """
    time_loop_names = ("dt", "c0", "c_ref", "rho0", "rho0_sgx", "rho0_sgy")
    for name in time_loop_names:
        float_variables[name] = getattr(values, name)
def grab_medium_props(integer_variables, float_variables, medium, is_elastic_code):
    """Add the nonlinearity and absorption settings of ``medium`` to the variable lists.

    - ``BonA`` is written only when ``medium.is_nonlinear()`` is true.
    - ``absorbing_flag`` is 0 for a non-absorbing medium. For an absorbing medium it is 2 when
      ``medium.stokes`` is truthy and 1 otherwise.
    - For an absorbing medium, the elastic code reserves ``chi``/``eta*`` placeholders (``None``).
      Other codes copy ``alpha_coeff`` and ``alpha_power`` from ``medium``.
    """
    # variables used in nonlinear simulations
    if medium.is_nonlinear():
        float_variables["BonA"] = medium.BonA

    # variables used in absorbing simulations
    if not medium.absorbing:
        integer_variables.absorbing_flag = 0
        return

    integer_variables.absorbing_flag = 2 if medium.stokes else 1

    if is_elastic_code:  # pragma: no cover
        for placeholder in ("chi", "eta", "eta_sgxy", "eta_sgxz", "eta_sgyz"):
            float_variables[placeholder] = None
    else:
        float_variables["alpha_coeff"] = medium.alpha_coeff
        float_variables["alpha_power"] = medium.alpha_power
def grab_source_props(
    integer_variables,
    float_variables,
    source,
    u_source_pos_index,
    s_source_pos_index,
    p_source_pos_index,
    transducer_input_signal,
    delay_mask,
):
    """Add every source-related variable: modes, indices and time-varying inputs.

    Source modes and indices are only defined when the corresponding source flag is set. The
    mode describes whether the source is added or replaced; the indices describe which grid
    points act as the source. ``u_source_index`` is shared by the velocity sources and the
    transducer source, so the pressure/transducer step runs after the velocity step.
    """
    grab_velocity_source_props(integer_variables, source, u_source_pos_index=u_source_pos_index)
    grab_stress_source_props(integer_variables, source, s_source_pos_index=s_source_pos_index)
    grab_pressure_source_props(
        integer_variables,
        source,
        p_source_pos_index=p_source_pos_index,
        u_source_pos_index=u_source_pos_index,
    )
    grab_time_varying_source_props(
        integer_variables,
        float_variables,
        source,
        transducer_input_signal=transducer_input_signal,
        delay_mask=delay_mask,
    )
def grab_velocity_source_props(integer_variables, source, u_source_pos_index):
    """Add velocity-source mode, multiplicity flag and indices when any velocity source is set.

    ``u_source_mode`` maps ``source.u_mode`` to 0 (dirichlet), 1 (additive-no-correction)
    or 2 (additive). ``u_source_many`` is true when ``num_dim2`` of the first active component's
    input (checked in x, y, z order) exceeds 1.
    """
    first_active_axis = next(
        (axis for axis in "xyz" if integer_variables.get(f"u{axis}_source_flag")),
        None,
    )
    if first_active_axis is None:
        return

    mode_codes = {
        "dirichlet": 0,
        "additive-no-correction": 1,
        "additive": 2,
    }
    integer_variables["u_source_mode"] = mode_codes[source.u_mode]
    integer_variables["u_source_many"] = num_dim2(getattr(source, f"u{first_active_axis}")) > 1
    integer_variables.u_source_index = u_source_pos_index
def grab_stress_source_props(integer_variables, source, s_source_pos_index):
    """Add stress-source mode, multiplicity flag and indices when any stress source is set.

    ``s_source_mode`` is true for any ``source.s_mode`` other than ``"dirichlet"``.
    ``s_source_many`` is true when ``num_dim2`` of the first active component's input
    (checked in sxx, syy, szz, sxy, sxz, syz order) exceeds 1.
    """
    components = ("sxx", "syy", "szz", "sxy", "sxz", "syz")
    first_active = next(
        (component for component in components if getattr(integer_variables, f"{component}_source_flag")),
        None,
    )
    if first_active is None:
        return

    integer_variables.s_source_mode = source.s_mode != "dirichlet"
    integer_variables.s_source_many = num_dim2(getattr(source, first_active)) > 1
    integer_variables.s_source_index = s_source_pos_index
def grab_pressure_source_props(integer_variables, source, p_source_pos_index, u_source_pos_index):
    """Add pressure-source settings and, for a transducer source, the shared velocity indices.

    ``p_source_mode`` maps ``source.p_mode`` to 0 (dirichlet), 1 (additive-no-correction)
    or 2 (additive). ``p_source_many`` is true when ``num_dim2(source.p)`` exceeds 1.
    """
    has_pressure_source = integer_variables.p_source_flag
    if has_pressure_source:
        pressure_mode_codes = {"dirichlet": 0, "additive-no-correction": 1, "additive": 2}
        integer_variables.p_source_mode = pressure_mode_codes[source.p_mode]
        integer_variables.p_source_many = num_dim2(source.p) > 1
        integer_variables.p_source_index = p_source_pos_index

    # the transducer source reuses u_source_index
    has_transducer_source = integer_variables.transducer_source_flag
    if has_transducer_source:
        integer_variables.u_source_index = u_source_pos_index
def grab_time_varying_source_props(integer_variables, float_variables, source, transducer_input_signal, delay_mask):
    """Add the actual source values for every active source.

    Each ``<field>_source_input`` entry is only defined when its source flag is set. The values
    are indexed as (position_index, time_index). A transducer source also stores ``delay_mask``
    as an integer variable. The initial pressure ``p0`` is defined everywhere on the grid, so it
    has no indices.
    """
    time_varying_fields = ("ux", "uy", "uz", "sxx", "syy", "szz", "sxy", "sxz", "syz", "p")
    for field in time_varying_fields:
        if getattr(integer_variables, f"{field}_source_flag"):
            setattr(float_variables, f"{field}_source_input", getattr(source, field))

    if integer_variables.transducer_source_flag:
        float_variables.transducer_source_input = transducer_input_signal
        integer_variables.delay_mask = delay_mask

    # initial pressure source variable (only defined if the p0 source flag is set)
    if integer_variables.p0_source_flag:
        float_variables.p0_source_input = source.p0
def grab_sensor_props(integer_variables, kgrid_dim, sensor_mask_index, cuboid_corners_list):
    """Add the sensor-mask description to ``integer_variables``.

    ``integer_variables.sensor_mask_type`` selects the representation:

    - 0: the mask is a list of grid indices, stored as ``sensor_mask_index``.
    - 1: the mask is a list of cuboid corners, stored as ``sensor_mask_corners``. For 3D grids
      the list is stored unchanged. For 2D grids rows 0-3 of ``cuboid_corners_list`` are
      copied into rows 0, 1, 3, 4 of a 6-row array (same number of columns) whose remaining
      rows 2 and 5 are 1 (presumably the z coordinates; TODO confirm against the reader).

    Raises:
        NotImplementedError: for any other ``sensor_mask_type``.
    """
    mask_type = integer_variables.sensor_mask_type

    if mask_type == 0:
        # mask is defined as a list of grid indices
        integer_variables.sensor_mask_index = sensor_mask_index
        return

    if mask_type != 1:
        raise NotImplementedError("unknown option for sensor_mask_type")

    # mask is defined as a list of cuboid corners
    if kgrid_dim == 2:
        sensor_mask_corners = np.ones((6, cuboid_corners_list.shape[1]))
        sensor_mask_corners[[0, 1, 3, 4], :] = cuboid_corners_list[[0, 1, 2, 3], :]
    else:
        sensor_mask_corners = cuboid_corners_list
    integer_variables.sensor_mask_corners = sensor_mask_corners
def grab_nonuniform_grid_props(float_variables, kgrid, is_nonuniform_grid):
    """Add the nonuniform-grid mapping gradients to ``float_variables``.

    Nothing is added unless ``is_nonuniform_grid`` is set. For each axis, the gradient on the
    regular grid (``kgrid.dudn``) and on the staggered grid (``kgrid.dudn_sg``) is stored. A
    single-element value is replaced by an array of ones shaped to broadcast along its own
    axis (the bsxfun formulation): ``(Nx, 1)``, ``(1, Ny)`` or ``(1, 1, Nz)``.
    """
    if not is_nonuniform_grid:
        return

    # regular-grid gradients
    dxudxn = kgrid.dudn.x
    if np.array(dxudxn).size == 1:
        dxudxn = np.ones((kgrid.Nx, 1))
    float_variables["dxudxn"] = dxudxn

    dyudyn = kgrid.dudn.y
    if np.array(dyudyn).size == 1:
        dyudyn = np.ones((1, kgrid.Ny))
    float_variables["dyudyn"] = dyudyn

    dzudzn = kgrid.dudn.z
    if np.array(dzudzn).size == 1:
        dzudzn = np.ones((1, 1, kgrid.Nz))
    float_variables["dzudzn"] = dzudzn

    # staggered-grid gradients: each check must test the staggered value itself. Testing the
    # regular-grid variable instead (almost) never matches, because a scalar regular-grid
    # value has already been expanded to an array above.
    dxudxn_sgx = kgrid.dudn_sg.x
    if np.array(dxudxn_sgx).size == 1:
        dxudxn_sgx = np.ones((kgrid.Nx, 1))
    float_variables["dxudxn_sgx"] = dxudxn_sgx

    dyudyn_sgy = kgrid.dudn_sg.y
    if np.array(dyudyn_sgy).size == 1:
        dyudyn_sgy = np.ones((1, kgrid.Ny))
    float_variables["dyudyn_sgy"] = dyudyn_sgy

    dzudzn_sgz = kgrid.dudn_sg.z
    if np.array(dzudzn_sgz).size == 1:
        dzudzn_sgz = np.ones((1, 1, kgrid.Nz))
    float_variables["dzudzn_sgz"] = dzudzn_sgz
def remove_z_dimension(float_variables, kgrid_dim):
    """Delete, in place, every float variable whose name contains ``"z"`` when the grid is 2D.

    Used before saving 2D files; other dimensionalities leave ``float_variables`` untouched.
    """
    if kgrid_dim != 2:
        return
    z_names = [name for name in float_variables if "z" in name]
    for name in z_names:
        del float_variables[name]
def enforce_filename_standards(filepath):
    """Return ``(filepath, extension)``, defaulting to ``.h5`` when no extension is given.

    When ``filepath`` has no extension, ``.h5`` is appended to it and reported as the
    extension; otherwise both are returned unchanged.
    """
    _, extension = os.path.splitext(filepath)
    if extension:
        return filepath, extension
    return filepath + ".h5", ".h5"
def save_file(filepath, integer_variables, float_variables, hdf_compression_level, auto_chunk):
    """Write the variable lists to ``filepath`` as HDF5 (``.h5``) or MATLAB (``.mat``).

    A path without an extension gets ``.h5`` appended. ``hdf_compression_level`` and
    ``auto_chunk`` are only used for HDF5 output.

    Raises:
        NotImplementedError: for any other file extension.
    """
    resolved_path, extension = enforce_filename_standards(filepath)

    if extension == ".mat":
        save_mat_file(resolved_path, integer_variables, float_variables)
    elif extension == ".h5":
        save_h5_file(resolved_path, integer_variables, float_variables, hdf_compression_level, auto_chunk)
    else:
        # throw error for unknown filetype
        raise NotImplementedError("unknown file extension for save_to_disk filename")
def save_h5_file(filepath, integer_variables, float_variables, hdf_compression_level, auto_chunk):
    """Write all variables to a fresh HDF5 file at ``filepath``.

    Any existing file is deleted first, because the hdf5 library gives an error if the file
    already exists. Float variables are written first, cast to ``np.float32`` (float in C++).
    Integer variables follow, cast to ``np.uint64`` (64-bit unsigned, long in C++). The file
    attributes are written last.
    """
    if os.path.exists(filepath):
        os.remove(filepath)

    variable_groups = (
        (float_variables, np.float32),
        (integer_variables, np.uint64),
    )
    for variables, target_dtype in variable_groups:
        for name, data in variables.items():
            write_matrix(filepath, np.array(data, dtype=target_dtype), name, hdf_compression_level, auto_chunk)

    # set additional file attributes
    write_attributes(filepath)
def save_mat_file(filepath, integer_variables, float_variables):
    """Write all variables to a MATLAB ``.mat`` file at ``filepath``.

    Float variables are cast to ``np.float32`` (single precision, float in C++) and integer
    variables to ``np.uint64``. The input mappings are left unmodified; the cast copies are
    collected in a new dict.

    Raises:
        TypeError: if the same name appears in both variable groups.
    """
    float_cast = {name: np.array(value, dtype=np.float32) for name, value in float_variables.items()}
    integer_cast = {name: np.array(value, dtype=np.uint64) for name, value in integer_variables.items()}

    # dict(**a, **b) (rather than {**a, **b}) raises TypeError on a duplicate name instead of
    # silently letting one group overwrite the other.
    savemat(filepath, dict(**float_cast, **integer_cast))