Coverage for kwave/utils/io.py: 9%
155 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-24 12:06 -0700
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-24 12:06 -0700
1import os
2import platform
3import socket
4from datetime import datetime
5from typing import Optional
7import cv2
8import h5py
9import numpy as np
11import kwave
13from .conversion import cast_to_type
14from .data import get_date_string
15from .dotdictionary import dotdict
def get_h5_literals():
    """
    Collect the constants used when writing k-Wave HDF5 files.

    Returns:
        A ``dotdict`` holding attribute names, data/domain type tags, file
        type tags, the file format version and the default compression
        level. Entries are read by attribute access elsewhere in this
        module (e.g. ``literals.HDF_COMPRESSION_LEVEL``).
    """
    # tags describing how matrix data is typed
    data_type_literals = {
        "DATA_TYPE_ATT_NAME": "data_type",
        "MATRIX_DATA_TYPE_MATLAB": "single",
        "MATRIX_DATA_TYPE_C": "float",
        "INTEGER_DATA_TYPE_MATLAB": "uint64",
        "INTEGER_DATA_TYPE_C": "long",
    }

    # tags distinguishing real from complex data
    domain_literals = {
        "DOMAIN_TYPE_ATT_NAME": "domain_type",
        "DOMAIN_TYPE_REAL": "real",
        "DOMAIN_TYPE_COMPLEX": "complex",
    }

    # names of the file-level descriptor attributes
    descriptor_literals = {
        "FILE_MAJOR_VER_ATT_NAME": "major_version",
        "FILE_MINOR_VER_ATT_NAME": "minor_version",
        "FILE_DESCR_ATT_NAME": "file_description",
        "FILE_CREATION_DATE_ATT_NAME": "creation_date",
        "CREATED_BY_ATT_NAME": "created_by",
    }

    # file type attribute and its possible values
    file_type_literals = {
        "FILE_TYPE_ATT_NAME": "file_type",
        "HDF_INPUT_FILE": "input",
        "HDF_OUTPUT_FILE": "output",
        "HDF_CHECKPOINT_FILE": "checkpoint",
    }

    # file format version and default compression level
    version_literals = {
        "HDF_FILE_MAJOR_VERSION": "1",
        "HDF_FILE_MINOR_VERSION": "2",
        "HDF_COMPRESSION_LEVEL": 0,
    }

    # merge in the same order the keys were originally declared
    return dotdict(
        {
            **data_type_literals,
            **domain_literals,
            **descriptor_literals,
            **file_type_literals,
            **version_literals,
        }
    )
def write_matrix(filename, matrix: np.ndarray, matrix_name: str, compression_level: Optional[int] = None, auto_chunk: bool = True):
    """
    Store a matrix as a 3D dataset in an HDF5 file and tag it with the
    ``domain_type`` and ``data_type`` string attributes.

    2D and 3D inputs are transposed (axis order reversed) before being
    written. Inputs with fewer than two dimensions are written with a dataset
    shape of ``[1, 1, 1]``.

    Args:
        filename: HDF5 file to write to; it is opened in append mode ("a").
        matrix: Data to store. Its dtype must be ``float32`` or ``uint64``
            and it must have at most 3 dimensions.
        matrix_name: Name of the dataset created at the root of the file.
        compression_level: Passed to h5py as the ``compression`` option when
            non-zero. Defaults to the ``HDF_COMPRESSION_LEVEL`` literal.
        auto_chunk: When True, ``chunks=True`` is passed to h5py; otherwise
            the chunk shape computed in this function is used.

    Raises:
        AssertionError: If ``auto_chunk`` is not a boolean.
        ValueError: If the matrix has more than 3 dimensions or an
            unsupported dtype.
        NotImplementedError: If a complex 2D matrix reaches the complex
            handling code.
    """
    h5_literals = get_h5_literals()

    assert isinstance(auto_chunk, bool), "auto_chunk must be a boolean."

    if compression_level is None:
        compression_level = h5_literals.HDF_COMPRESSION_LEVEL

    dims = len(matrix.shape)
    if dims > 3:
        raise ValueError("Input matrix must have 1, 2 or 3 dimensions.")

    # reorder the axes, then derive the dataset size and the manual chunk shape
    if dims == 3:
        matrix = np.transpose(matrix, [2, 1, 0])  # C <=> Fortran ordering
        Nx, Ny, Nz = matrix.shape
        # one chunk per Nx * Ny slab
        chunk_size = [Nx, Ny, 1]
    elif dims == 2:
        matrix = np.transpose(matrix)  # C <=> Fortran ordering
        Ny, Nz = matrix.shape
        Nx = 1
        chunk_size = [Nx, 1, 1]
    else:
        # NOTE(review): the dataset size is fixed at 1 x 1 x 1 here, so a 1D
        # array holding more than one element will not match the dataset
        # shape passed to h5py below - confirm callers only pass scalars.
        Nx, Ny, Nz = 1, 1, 1

        # Presumably 1 MB expressed as a count of 8-byte elements. Integer
        # division keeps the value usable as a chunk length (a float is not).
        one_mb = (1024**2) // 8
        if matrix.size > one_mb:
            # Nx == Ny == Nz == 1, so the chunk is laid out along the last axis
            chunk_size = [1, 1, one_mb]
        else:
            # small data: no compression, chunk equal to the grid size
            compression_level = 0
            chunk_size = [1, 1, 1]

    # only single precision (float in C++) and uint64 (unsigned long in C++)
    # are supported
    if matrix.dtype == np.float32:
        data_type_matlab = h5_literals.MATRIX_DATA_TYPE_MATLAB
        data_type_c = h5_literals.MATRIX_DATA_TYPE_C
    elif matrix.dtype == np.uint64:
        data_type_matlab = h5_literals.INTEGER_DATA_TYPE_MATLAB
        data_type_c = h5_literals.INTEGER_DATA_TYPE_C
    else:
        raise ValueError("Input matrix must be of type single or uint64.")

    # NOTE: the dtype check above only lets float32 / uint64 through, so
    # np.isreal(...) is always all-True and the complex branches below are
    # currently unreachable. They are kept, with their NumPy calls corrected,
    # for the day complex dtypes are admitted.
    if np.isreal(matrix).all():
        domain_type = h5_literals.DOMAIN_TYPE_REAL

    elif dims == 3:
        domain_type = h5_literals.DOMAIN_TYPE_COMPLEX

        # store the real and imaginary parts in the same matrix;
        # np.concatenate expects the arrays as a single sequence argument
        matrix = np.concatenate((matrix.real, matrix.imag), axis=0)
        matrix = matrix.reshape((Nx, 2, Ny, Nz))
        matrix = np.transpose(matrix, (1, 0, 2, 3))
        matrix = matrix.reshape((2 * Nx, Ny, Nz))

        # the x-direction now holds both parts
        Nx = 2 * Nx

    elif dims <= 1:
        domain_type = h5_literals.DOMAIN_TYPE_COMPLEX

        # store the real and imaginary parts in the same matrix
        nelems = matrix.size
        matrix = matrix.reshape((nelems, 1))
        matrix = np.concatenate((matrix.real, matrix.imag), axis=0)
        matrix = matrix.reshape((nelems, 2, 1, 1))
        matrix = np.transpose(matrix, (1, 0, 2, 3))

        # double every non-singleton dimension; kept as integers so the
        # values remain valid array / dataset dimensions
        Nx = Nx * (2 - int(Nx == 1))
        Ny = Ny * (2 - int(Ny == 1))
        Nz = Nz * (2 - int(Nz == 1))

        # double store in x-direction if a complex scalar
        if Nx == 1 and Ny == 1 and Nz == 1:
            Nx = 2 * Nx

        # put in correct dimension
        matrix = matrix.reshape((Nx, Ny, Nz))

    else:
        raise NotImplementedError("Currently there is no support for saving 2D complex matrices.")

    # options for the new dataset
    opts = {"dtype": data_type_matlab, "chunks": auto_chunk if auto_chunk is True else tuple(chunk_size)}
    if compression_level != 0:
        opts["compression"] = compression_level

    # write the matrix into the file
    with h5py.File(filename, "a") as f:
        dataset = f.create_dataset(f"/{matrix_name}", [Nx, Ny, Nz], data=matrix, **opts)

        # set attributes for the matrix (used by k-Wave++)
        assign_str_attr(dataset.attrs, h5_literals.DOMAIN_TYPE_ATT_NAME, domain_type)
        assign_str_attr(dataset.attrs, h5_literals.DATA_TYPE_ATT_NAME, data_type_c)
def write_attributes(filename: str, file_description: Optional[str] = None) -> None:
    """
    Write the file-level k-Wave attributes to an HDF5 file.

    The attributes written are the major and minor file version, the
    ``created_by`` tag (``k-Wave <kwave.__version__>``), the file description,
    the file type (always the input-file tag) and the creation date. Each is
    stored as a fixed-length string through ``assign_str_attr``.

    Args:
        filename: The name of the HDF5 file; it is opened in append mode.
        file_description: The description of the file. If not provided, a
            default description naming the current user and operating
            system is used.
    """
    h5_literals = get_h5_literals()

    # set file description if not provided by user
    if file_description is None:
        # USERNAME is the Windows variable; POSIX environments normally
        # export USER instead, so fall back to it before giving up.
        user_name = os.environ.get("USERNAME") or os.environ.get("USER") or "unknown"
        matlab_version = "N/A"
        operating_system_type = platform.system()
        file_description = f"Input data created by {user_name} running MATLAB " f"{matlab_version} on {operating_system_type}"

    # attribute name -> attribute value, in the order they are written
    attributes = {
        h5_literals.FILE_MAJOR_VER_ATT_NAME: h5_literals.HDF_FILE_MAJOR_VERSION,
        h5_literals.FILE_MINOR_VER_ATT_NAME: h5_literals.HDF_FILE_MINOR_VERSION,
        h5_literals.CREATED_BY_ATT_NAME: f"k-Wave {kwave.__version__}",
        h5_literals.FILE_DESCR_ATT_NAME: file_description,
        h5_literals.FILE_TYPE_ATT_NAME: h5_literals.HDF_INPUT_FILE,
        h5_literals.FILE_CREATION_DATE_ATT_NAME: get_date_string(),
    }

    # set additional file attributes
    with h5py.File(filename, "a") as f:
        for key, value in attributes.items():
            assign_str_attr(f.attrs, key, value)
def write_flags(filename):
    """
    Derive the source, medium and sensor flags from the datasets present in an
    HDF5 input file and write them back to that file.

    The file is opened read-only to inspect its dataset names. Afterwards
    every flag is cast to a 64-bit unsigned integer and written through
    ``write_matrix``. For example, if the file contains a dataset named
    'BonA', ``nonlinear_flag`` is written as 1.

    Flags that are always written:
        ux_source_flag, uy_source_flag, uz_source_flag,
        sxx_source_flag, sxy_source_flag, sxz_source_flag,
        syy_source_flag, syz_source_flag, szz_source_flag,
        p_source_flag, p0_source_flag, transducer_source_flag,
        nonuniform_grid_flag, nonlinear_flag, absorbing_flag,
        axisymmetric_flag, elastic_flag, sensor_mask_type

    Flags that are written conditionally:
        u_source_many, s_source_many, p_source_many:
            written when a matching ``*_source_input`` dataset exists.
        u_source_mode, s_source_mode, p_source_mode:
            written as 1 (1 is Additive, 0 is Dirichlet) when a source of
            that family is active and the mode is not already in the file.

    Args:
        filename: Path of the HDF5 file to inspect and update.

    Raises:
        ValueError: If the file contains neither ``sensor_mask_index`` nor
            ``sensor_mask_corners``.
    """
    # source family -> components, in the order their flags are emitted
    source_families = (
        ("u", ("ux", "uy", "uz")),
        ("s", ("sxx", "syy", "szz", "sxy", "sxz", "syz")),
        ("p", ("p",)),
    )

    # flags that simply mirror the presence of a dataset
    presence_flags = (
        ("p0_source_flag", "p0_source_input"),
        ("transducer_source_flag", "transducer_source_input"),
        ("nonlinear_flag", "BonA"),
        ("absorbing_flag", "alpha_coeff"),
        ("elastic_flag", "lambda"),
    )

    flags = {}

    with h5py.File(filename, "r") as h5_file:
        dataset_names = h5_file.keys()

        # per-component source flags plus the shared "many" flag per family
        for family, components in source_families:
            many_key = f"{family}_source_many"
            for component in components:
                input_name = f"{component}_source_input"
                flag_key = f"{component}_source_flag"
                if input_name not in dataset_names:
                    flags[flag_key] = 0
                    continue
                input_shape = h5_file[input_name].shape
                flags[flag_key] = input_shape[1]
                flags[many_key] = input_shape[0] != 1

        # source modes: default to additive (1) unless the file already has one
        for family, components in source_families:
            mode_key = f"{family}_source_mode"
            family_active = any(flags[f"{component}_source_flag"] for component in components)
            if family_active and mode_key not in dataset_names:
                flags[mode_key] = 1

        # flags that depend only on whether a dataset exists
        for flag_key, dataset_name in presence_flags:
            flags[flag_key] = dataset_name in dataset_names

        # axisymmetric and nonuniform grids are always flagged as off
        flags["axisymmetric_flag"] = 0
        flags["nonuniform_grid_flag"] = 0

        # sensor mask type: 0 for an index mask, 1 for a corners mask
        if "sensor_mask_index" in dataset_names:
            flags["sensor_mask_type"] = 0
        elif "sensor_mask_corners" in dataset_names:
            flags["sensor_mask_type"] = 1
        else:
            raise ValueError("Either sensor_mask_index or sensor_mask_corners must be defined in the input file")

    # write every flag as a 64-bit unsigned integer (long in C++)
    for flag_key, flag_value in flags.items():
        write_matrix(filename, np.array(flag_value, dtype=np.uint64), flag_key)
def write_grid(filename, grid_size, grid_spacing, pml_size, pml_alpha, Nt, dt, c_ref):
    """
    Write the grid and PML parameters required by the k-Wave C++ code to the
    given HDF5 file.

    Floating point parameters are cast with ``cast_to_type`` to the matrix
    data type literal, integer parameters to the integer data type literal,
    and each value is then stored through ``write_matrix``.

    Written as floats (in this order):
        dt, dx, dy, dz, pml_x_alpha, pml_y_alpha, pml_z_alpha, c_ref

    Written as integers (in this order):
        Nx, Ny, Nz, Nt, pml_x_size, pml_y_size, pml_z_size

    Args:
        filename: HDF5 file the parameters are written to.
        grid_size: Indexable with elements 0..2 giving Nx, Ny, Nz.
        grid_spacing: Indexable with elements 0..2 giving dx, dy, dz.
        pml_size: Indexable with elements 0..2 giving the PML size per axis.
        pml_alpha: Indexable with elements 0..2 giving the PML alpha per axis.
        Nt: Value stored as ``Nt``.
        dt: Value stored as ``dt``.
        c_ref: Value stored as ``c_ref``.
    """
    literals = get_h5_literals()

    # ---------------------------------------------------------------------
    # single precision values (float in C++)
    # ---------------------------------------------------------------------
    float_parameters = {
        "dt": dt,
        "dx": grid_spacing[0],
        "dy": grid_spacing[1],
        "dz": grid_spacing[2],
        "pml_x_alpha": pml_alpha[0],
        "pml_y_alpha": pml_alpha[1],
        "pml_z_alpha": pml_alpha[2],
        "c_ref": c_ref,
    }
    for name, raw_value in float_parameters.items():
        as_single = cast_to_type(raw_value, literals.MATRIX_DATA_TYPE_MATLAB)
        write_matrix(filename, as_single, name)

    # ---------------------------------------------------------------------
    # 64-bit unsigned integer values (long in C++)
    # ---------------------------------------------------------------------
    integer_parameters = {
        "Nx": grid_size[0],
        "Ny": grid_size[1],
        "Nz": grid_size[2],
        "Nt": Nt,
        "pml_x_size": pml_size[0],
        "pml_y_size": pml_size[1],
        "pml_z_size": pml_size[2],
    }
    for name, raw_value in integer_parameters.items():
        as_uint64 = cast_to_type(raw_value, literals.INTEGER_DATA_TYPE_MATLAB)
        write_matrix(filename, as_uint64, name)
def assign_str_attr(attrs, attr_name, attr_val):
    """
    Create an HDF5 attribute whose value is stored as a fixed-length string.

    The string dtype is sized from ``len(attr_val)``.

    Args:
        attrs: HDF5 attribute object exposing a ``create`` method
        attr_name: name of attribute
        attr_val: value of attribute

    """
    # little-endian byte string exactly as long as the value
    fixed_length_dtype = "<S" + str(len(attr_val))
    attrs.create(attr_name, attr_val, None, dtype=fixed_length_dtype)
def load_image(path, is_gray):
    """
    Load a grayscale image and return it inverted and scaled to [0, 1].

    The brightest input pixel maps to 0 and the darkest to 1. An image whose
    pixels are all equal is returned as all zeros (there is no range to
    scale by).

    Args:
        path: Location of the image file, passed to ``cv2.imread``.
        is_gray: Must be truthy; colour images are not supported yet.

    Returns:
        The inverted, scaled image as a float array.

    Raises:
        NotImplementedError: If ``is_gray`` is falsy.
        OSError: If ``cv2.imread`` could not read an image from ``path``.
    """
    if not is_gray:
        # Colour support is not implemented. Formula left by the original
        # author for reference (sum of the three channels):
        # im = squeeze(double(im(:, :, 1)) + double(im(:, :, 2)) + double(im(:, :, 3)));
        raise NotImplementedError

    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising
        raise OSError(f"Could not read image file: {path!r}")
    img = img.astype(float)

    # invert, then scale pixel values from 0 -> 1
    img = img.max() - img
    peak = img.max()
    if peak == 0:
        # uniform image: already all zeros, and 1 / peak would give NaNs
        return img
    return img * (1 / peak)