Coverage for kwave/utils/io.py: 9%

155 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-24 12:06 -0700

1import os 

2import platform 

3import socket 

4from datetime import datetime 

5from typing import Optional 

6 

7import cv2 

8import h5py 

9import numpy as np 

10 

11import kwave 

12 

13from .conversion import cast_to_type 

14from .data import get_date_string 

15from .dotdictionary import dotdict 

16 

17 

def get_h5_literals():
    """
    Return the attribute names and constant values used when writing k-Wave HDF5 files.

    A new ``dotdict`` is built on every call, so entries can be read as attributes
    (e.g. ``get_h5_literals().DATA_TYPE_ATT_NAME``).

    Returns:
        dotdict: mapping of literal name -> value.
    """
    # data type attribute and the MATLAB / C++ names of the supported types
    data_type = {
        "DATA_TYPE_ATT_NAME": "data_type",
        "MATRIX_DATA_TYPE_MATLAB": "single",
        "MATRIX_DATA_TYPE_C": "float",
        "INTEGER_DATA_TYPE_MATLAB": "uint64",
        "INTEGER_DATA_TYPE_C": "long",
    }
    # real / complex domain attribute
    domain_type = {
        "DOMAIN_TYPE_ATT_NAME": "domain_type",
        "DOMAIN_TYPE_REAL": "real",
        "DOMAIN_TYPE_COMPLEX": "complex",
    }
    # file-level descriptor attribute names
    file_descriptors = {
        "FILE_MAJOR_VER_ATT_NAME": "major_version",
        "FILE_MINOR_VER_ATT_NAME": "minor_version",
        "FILE_DESCR_ATT_NAME": "file_description",
        "FILE_CREATION_DATE_ATT_NAME": "creation_date",
        "CREATED_BY_ATT_NAME": "created_by",
    }
    # file type attribute and its allowed values
    file_type = {
        "FILE_TYPE_ATT_NAME": "file_type",
        "HDF_INPUT_FILE": "input",
        "HDF_OUTPUT_FILE": "output",
        "HDF_CHECKPOINT_FILE": "checkpoint",
    }
    # file format version (stored as strings) and default compression level
    version_and_compression = {
        "HDF_FILE_MAJOR_VERSION": "1",
        "HDF_FILE_MINOR_VERSION": "2",
        "HDF_COMPRESSION_LEVEL": 0,
    }

    return dotdict({**data_type, **domain_type, **file_descriptors, **file_type, **version_and_compression})

50 

51 

def write_matrix(filename, matrix: np.ndarray, matrix_name: str, compression_level: int = None, auto_chunk: bool = True):
    """
    Write a matrix to an HDF5 file as a 3-D dataset tagged with the attributes k-Wave C++ reads.

    The axes of 2-D and 3-D inputs are reversed (C <=> Fortran ordering), so the first axis of
    ``matrix`` becomes the last axis of the dataset. 1-D inputs are stored along that same last
    axis, and 0-D inputs as a (1, 1, 1) dataset. The dataset ``/<matrix_name>`` is created in
    ``filename`` (opened in append mode) together with the ``domain_type`` and ``data_type``
    string attributes.

    Args:
        filename: Path of the HDF5 file.
        matrix: Array with at most 3 dimensions and dtype ``float32`` or ``uint64``.
        matrix_name: Name of the dataset created at the root of the file.
        compression_level: Value passed to h5py as ``compression``; 0 disables compression.
            Defaults to ``HDF_COMPRESSION_LEVEL``. Forced to 0 for 0-D/1-D inputs with at most
            ``1024**2 // 8`` elements.
        auto_chunk: If True, let h5py choose the chunk shape; otherwise use the chunk shape
            computed here.

    Raises:
        AssertionError: If ``auto_chunk`` is not a bool.
        ValueError: If ``matrix`` has more than 3 dimensions or an unsupported dtype.
    """
    h5_literals = get_h5_literals()

    assert isinstance(auto_chunk, bool), "auto_chunk must be a boolean."

    if compression_level is None:
        compression_level = h5_literals.HDF_COMPRESSION_LEVEL

    dims = len(matrix.shape)
    if dims > 3:
        raise ValueError("Input matrix must have 1, 2 or 3 dimensions.")

    # reverse the axes (C <=> Fortran ordering) and get the size of the matrix as stored in the file
    if dims == 3:
        matrix = np.transpose(matrix, [2, 1, 0])
        Nx, Ny, Nz = matrix.shape
    elif dims == 2:
        matrix = np.transpose(matrix)
        Ny, Nz = matrix.shape
        Nx = 1
    elif dims == 1:
        # store a vector along the last dataset axis (where the first axis of 2-D/3-D inputs goes);
        # a (1, 1, 1) shape would not match the data for vectors with more than one element
        Nx, Ny, Nz = 1, 1, matrix.shape[0]
    else:
        Nx, Ny, Nz = 1, 1, 1

    # chunk shape (only used when auto_chunk is False) and compression level
    if dims == 3:
        chunk_size = (Nx, Ny, 1)
    elif dims == 2:
        chunk_size = (Nx, 1, 1)
    else:
        # 1 MB expressed as a count of 8-byte elements; kept an int because chunk sizes must be integers
        one_mb = (1024**2) // 8
        if matrix.size > one_mb:
            # chunk along the longest axis, 1 MB at a time
            if Nx > Ny:
                chunk_size = (one_mb, 1, 1)
            elif Ny > Nz:
                chunk_size = (1, one_mb, 1)
            else:
                chunk_size = (1, 1, one_mb)
        else:
            # small matrix: no compression, chunk spanning the longest axis
            compression_level = 0
            if matrix.size == 1:
                chunk_size = (1, 1, 1)
            elif Nx > Ny:
                chunk_size = (Nx, 1, 1)
            elif Ny > Nz:
                chunk_size = (1, Ny, 1)
            else:
                chunk_size = (1, 1, Nz)

    # only single precision (float in C++) and uint64 (unsigned long in C++) are supported
    if matrix.dtype == np.float32:
        data_type_matlab = h5_literals.MATRIX_DATA_TYPE_MATLAB
        data_type_c = h5_literals.MATRIX_DATA_TYPE_C
    elif matrix.dtype == np.uint64:
        data_type_matlab = h5_literals.INTEGER_DATA_TYPE_MATLAB
        data_type_c = h5_literals.INTEGER_DATA_TYPE_C
    else:
        raise ValueError("Input matrix must be of type single or uint64.")

    # NOTE(review): the complex branches below are unreachable. The dtype check above only admits
    # float32 and uint64, and for those np.isreal(...).all() is always True. They are kept as-is but
    # are known to be broken: np.concatenate receives real/imag as two positional arguments, which
    # raises TypeError. The reshapes also appear not to interleave real/imag parts correctly under
    # NumPy's row-major order. Fixing only the call would silently write wrong data, so rework and
    # test them before admitting complex dtypes above.
    if np.isreal(matrix).all():
        domain_type = h5_literals.DOMAIN_TYPE_REAL
    elif dims == 3:
        domain_type = h5_literals.DOMAIN_TYPE_COMPLEX

        # rearrange the data so the real and imaginary parts are stored in the same matrix
        matrix = np.concatenate(matrix.real, matrix.imag, axis=0)
        matrix = matrix.reshape((Nx, 2, Ny, Nz))
        matrix = np.transpose(matrix, (1, 0, 2, 3))
        matrix = matrix.reshape((2 * Nx, Ny, Nz))

        # update the size of Nx
        Nx = 2 * Nx
    elif dims <= 1:
        domain_type = h5_literals.DOMAIN_TYPE_COMPLEX

        # rearrange the data so the real and imaginary parts are stored in the same matrix
        nelems = matrix.size
        matrix = matrix.reshape((nelems, 1))
        matrix = np.concatenate(matrix.real, matrix.imag, axis=0)
        matrix = matrix.reshape((nelems, 2, 1, 1))
        matrix = np.transpose(matrix, (1, 0, 2, 3))

        # update the matrix size
        Nx = Nx * (2 - np.array(Nx == 1).astype(float))
        Ny = Ny * (2 - np.array(Ny == 1).astype(float))
        Nz = Nz * (2 - np.array(Nz == 1).astype(float))

        # double store in x-direction if a complex scalar
        if Nx == 1 and Ny == 1 and Nz == 1:
            Nx = 2 * Nx

        # put in correct dimension
        matrix = matrix.reshape((Nx, Ny, Nz))
    else:
        raise NotImplementedError("Currently there is no support for saving 2D complex matrices.")

    # dataset creation options
    opts = {"dtype": data_type_matlab, "chunks": True if auto_chunk else tuple(chunk_size)}
    if compression_level != 0:
        opts["compression"] = compression_level

    # write the matrix into the file
    with h5py.File(filename, "a") as f:
        dataset = f.create_dataset(f"/{matrix_name}", [Nx, Ny, Nz], data=matrix, **opts)

        # set attributes for the matrix (used by k-Wave++)
        assign_str_attr(dataset.attrs, h5_literals.DOMAIN_TYPE_ATT_NAME, domain_type)
        assign_str_attr(dataset.attrs, h5_literals.DATA_TYPE_ATT_NAME, data_type_c)

189 

190 

def write_attributes(filename: str, file_description: Optional[str] = None) -> None:
    """
    Write the k-Wave file-level attributes to the root of an HDF5 file.

    The file is opened in append mode and these fixed-length string attributes are written:
    file major/minor version, ``created_by`` (``k-Wave <kwave.__version__>``), file description,
    file type (``input``) and creation date.

    Args:
        filename: The name of the HDF5 file.
        file_description: The description of the file. If not provided, a default description
            mentioning the current user name and operating system type is used.
    """
    h5_literals = get_h5_literals()

    # get computer info (only user_name, matlab_version and operating_system_type are used below)
    comp_info = dotdict(
        {
            "date": datetime.now().strftime("%d-%b-%Y"),
            "computer_name": socket.gethostname(),
            "operating_system_type": platform.system(),
            "operating_system": platform.system() + " " + platform.release() + " " + platform.version(),
            # USERNAME is set on Windows; POSIX systems normally provide USER / LOGNAME instead
            "user_name": os.environ.get("USERNAME") or os.environ.get("USER") or os.environ.get("LOGNAME"),
            "matlab_version": "N/A",
            "kwave_version": "1.3",
            "kwave_path": "N/A",
        }
    )

    # set file description if not provided by user
    if file_description is None:
        file_description = (
            f"Input data created by {comp_info.user_name} running MATLAB "
            f"{comp_info.matlab_version} on {comp_info.operating_system_type}"
        )

    attributes = {
        h5_literals.FILE_MAJOR_VER_ATT_NAME: h5_literals.HDF_FILE_MAJOR_VERSION,
        h5_literals.FILE_MINOR_VER_ATT_NAME: h5_literals.HDF_FILE_MINOR_VERSION,
        h5_literals.CREATED_BY_ATT_NAME: f"k-Wave {kwave.__version__}",
        h5_literals.FILE_DESCR_ATT_NAME: file_description,
        h5_literals.FILE_TYPE_ATT_NAME: h5_literals.HDF_INPUT_FILE,
        h5_literals.FILE_CREATION_DATE_ATT_NAME: get_date_string(),
    }

    # assign each attribute to the root of the file as a fixed-length string
    with h5py.File(filename, "a") as f:
        for key, value in attributes.items():
            assign_str_attr(f.attrs, key, value)

244 

245 

def write_flags(filename):
    """
    Derive the source and medium flags from the datasets in an HDF5 input file and write them back.

    For example, if the file contains a dataset named ``BonA``, ``nonlinear_flag`` is written as
    true. Conditional flags are also written: a source mode flag is set to 1 (additive; 0 is
    Dirichlet) when the corresponding source is present and the mode is not already in the file.

    Flags that are always written:
        ux_source_flag, uy_source_flag, uz_source_flag,
        sxx_source_flag, sxy_source_flag, sxz_source_flag,
        syy_source_flag, syz_source_flag, szz_source_flag,
        p_source_flag, p0_source_flag, transducer_source_flag,
        nonuniform_grid_flag, nonlinear_flag, absorbing_flag,
        axisymmetric_flag, elastic_flag, sensor_mask_type

    Conditional flags:
        u_source_mode, u_source_many,
        p_source_mode, p_source_many,
        s_source_mode, s_source_many

    Every flag is cast to uint64 and written with ``write_matrix``.

    Args:
        filename: Path of the HDF5 input file. It is read and closed first; ``write_matrix`` then
            reopens it in append mode for each flag.

    Raises:
        ValueError: If neither ``sensor_mask_index`` nor ``sensor_mask_corners`` is in the file.
    """
    # (input dataset prefix, "many" flag it sets) for every time-varying source
    source_inputs = [
        ("ux_source", "u_source_many"),
        ("uy_source", "u_source_many"),
        ("uz_source", "u_source_many"),
        ("sxx_source", "s_source_many"),
        ("syy_source", "s_source_many"),
        ("szz_source", "s_source_many"),
        ("sxy_source", "s_source_many"),
        ("sxz_source", "s_source_many"),
        ("syz_source", "s_source_many"),
        ("p_source", "p_source_many"),
    ]

    variable_list = {}

    # read everything needed while the file is open read-only, and close it before any writes,
    # because write_matrix reopens the same file in append mode
    with h5py.File(filename, "r") as hf:
        # snapshot the names: an h5py keys view is only valid while the file is open
        names = set(hf.keys())

        for prefix, many_flag_key in source_inputs:
            inp_name = f"{prefix}_input"
            flag_name = f"{prefix}_flag"
            if inp_name in names:
                input_shape = hf[inp_name].shape
                # flag = size of axis 1 of the input; "many" = axis 0 is not a singleton
                variable_list[flag_name] = input_shape[1]
                variable_list[many_flag_key] = input_shape[0] != 1
            else:
                variable_list[flag_name] = 0

    # u source: write u_source_mode if not already in file (1 is Additive, 0 is Dirichlet)
    u_flags = ["ux_source_flag", "uy_source_flag", "uz_source_flag"]
    if any(variable_list[flag] for flag in u_flags) and "u_source_mode" not in names:
        variable_list["u_source_mode"] = 1

    # s source: write s_source_mode if not already in file (1 is Additive, 0 is Dirichlet)
    s_flags = [
        "sxx_source_flag",
        "syy_source_flag",
        "szz_source_flag",
        "sxy_source_flag",
        "sxz_source_flag",
        "syz_source_flag",
    ]
    if any(variable_list[flag] for flag in s_flags) and "s_source_mode" not in names:
        variable_list["s_source_mode"] = 1

    # p source: write p_source_mode if not already in file (1 is Additive, 0 is Dirichlet)
    if variable_list["p_source_flag"] and "p_source_mode" not in names:
        variable_list["p_source_mode"] = 1

    # flags that only depend on whether a dataset is present
    variable_list["p0_source_flag"] = "p0_source_input" in names
    variable_list["transducer_source_flag"] = "transducer_source_input" in names
    variable_list["nonlinear_flag"] = "BonA" in names
    variable_list["absorbing_flag"] = "alpha_coeff" in names
    variable_list["elastic_flag"] = "lambda" in names

    # axisymmetric and nonuniform grids are always written as false
    variable_list["axisymmetric_flag"] = 0
    variable_list["nonuniform_grid_flag"] = 0

    # sensor mask type: 0 for an index mask, 1 for a corners mask
    if "sensor_mask_index" in names:
        variable_list["sensor_mask_type"] = 0
    elif "sensor_mask_corners" in names:
        variable_list["sensor_mask_type"] = 1
    else:
        raise ValueError("Either sensor_mask_index or sensor_mask_corners must be defined in the input file")

    # write every flag as a 64-bit unsigned integer (long in C++)
    for key, value in variable_list.items():
        write_matrix(filename, np.array(value, dtype=np.uint64), key)

396 

397 

def write_grid(filename, grid_size, grid_spacing, pml_size, pml_alpha, Nt, dt, c_ref):
    """
    Write the grid, time-step and PML parameters required by the k-Wave C++ code to an HDF5 file.

    Each value is cast with ``cast_to_type`` and written as its own dataset with ``write_matrix``.

    Written as single precision (float in C++), in this order:
        dt, dx, dy, dz, pml_x_alpha, pml_y_alpha, pml_z_alpha, c_ref
    Written as uint64 (long in C++), in this order:
        Nx, Ny, Nz, Nt, pml_x_size, pml_y_size, pml_z_size

    Args:
        filename: HDF5 file to write to.
        grid_size: Indexable; entries 0, 1, 2 are written as ``Nx``, ``Ny``, ``Nz``.
        grid_spacing: Indexable; entries 0, 1, 2 are written as ``dx``, ``dy``, ``dz``.
        pml_size: Indexable; entries 0, 1, 2 are written as ``pml_x_size``, ``pml_y_size``, ``pml_z_size``.
        pml_alpha: Indexable; entries 0, 1, 2 are written as ``pml_x_alpha``, ``pml_y_alpha``, ``pml_z_alpha``.
        Nt: Value written to the ``Nt`` dataset.
        dt: Value written to the ``dt`` dataset.
        c_ref: Value written to the ``c_ref`` dataset.
    """
    literals = get_h5_literals()

    def store_all(entries, matlab_type):
        # cast each value to the requested type, then write it as its own dataset
        for dataset_name, raw_value in entries:
            write_matrix(filename, cast_to_type(raw_value, matlab_type), dataset_name)

    # floats first; the integer entries below are only looked up after these are written
    store_all(
        (
            ("dt", dt),
            ("dx", grid_spacing[0]),
            ("dy", grid_spacing[1]),
            ("dz", grid_spacing[2]),
            ("pml_x_alpha", pml_alpha[0]),
            ("pml_y_alpha", pml_alpha[1]),
            ("pml_z_alpha", pml_alpha[2]),
            ("c_ref", c_ref),
        ),
        literals.MATRIX_DATA_TYPE_MATLAB,
    )

    store_all(
        (
            ("Nx", grid_size[0]),
            ("Ny", grid_size[1]),
            ("Nz", grid_size[2]),
            ("Nt", Nt),
            ("pml_x_size", pml_size[0]),
            ("pml_y_size", pml_size[1]),
            ("pml_z_size", pml_size[2]),
        ),
        literals.INTEGER_DATA_TYPE_MATLAB,
    )

467 

468 

def assign_str_attr(attrs, attr_name, attr_val):
    """
    Assign an HDF5 attribute whose value is stored as a fixed-length byte string.

    ``str`` values are UTF-8 encoded and the fixed length is the encoded byte count. Non-ASCII
    text (e.g. a user name inside a file description) is therefore stored rather than failing
    numpy's ASCII-only ``str`` -> ``S`` conversion. ASCII values are stored exactly as before.

    Args:
        attrs: HDF5 attribute manager (e.g. ``dataset.attrs`` or ``file.attrs``).
        attr_name: name of attribute
        attr_val: value of attribute, as ``str`` or ``bytes``
    """
    encoded = attr_val.encode("utf-8") if isinstance(attr_val, str) else attr_val
    # HDF5 string types need a positive size, so an empty value gets a 1-byte (null-padded) type
    attrs.create(attr_name, encoded, None, dtype=f"<S{max(len(encoded), 1)}")

480 

481 

def load_image(path, is_gray):
    """
    Load a grayscale image and rescale it to [0, 1] with inverted intensities.

    After loading, pixel values are converted to float and mapped so that the brightest pixel
    becomes 0 and the darkest becomes 1 (``max - img``, then division by the new maximum).
    A uniform image, where the new maximum is 0, is returned as all zeros instead of being
    divided by zero (which produced NaNs).

    Args:
        path: Path of the image file, passed to ``cv2.imread``.
        is_gray: Must be True; only grayscale loading is implemented.

    Returns:
        The rescaled image as a float array.

    Raises:
        NotImplementedError: If ``is_gray`` is False.
        OSError: If OpenCV cannot read an image from ``path``.
    """
    if not is_gray:
        # colour loading is not implemented; reference (MATLAB) for summing the channels:
        # im = squeeze(double(im(:, :, 1)) + double(im(:, :, 2)) + double(im(:, :, 3)));
        raise NotImplementedError

    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising
        raise OSError(f"Could not read an image from {path!r} (missing, unreadable or unsupported file).")
    img = img.astype(float)

    # invert and scale pixel values from 0 -> 1
    img = img.max() - img
    peak = img.max()
    if peak > 0:
        # multiply by the reciprocal (as before) so results are bit-identical for non-uniform images
        img = img * (1 / peak)
    return img