Coverage for kwave/kgrid.py: 40%

281 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-24 12:06 -0700

1import math 

2import sys 

3from dataclasses import dataclass 

4 

5import numpy as np 

6 

7from kwave.data import FlexibleVector, Vector 

8from kwave.enums import DiscreteCosine, DiscreteSine 

9from kwave.utils import matlab 

10from kwave.utils.math import largest_prime_factor 

11 

12 

@dataclass
class kWaveGrid(object):
    """
    kWaveGrid is the grid class used across the k-Wave Toolbox. An object
    of the kWaveGrid class contains the grid coordinates and wavenumber
    matrices used within the simulation and reconstruction functions in
    k-Wave. The grid matrices are indexed as: (x, 1) in 1D; (x, y) in
    2D; and (x, y, z) in 3D. The grid is assumed to be a regularly spaced
    Cartesian grid, with grid spacing given by dx, dy, dz (typically the
    grid spacing in each direction is constant).

    NOTE(review): ``@dataclass`` is applied to a class that declares no
    annotated fields and defines its own ``__init__``. The decorator keeps the
    hand-written ``__init__``, but the generated ``__eq__`` compares an empty
    field tuple (so any two grids compare equal) and ``__hash__`` is set to
    None. Left unchanged for backward compatibility.
    """

    # default CFL number
    CFL_DEFAULT = 0.3

    # machine precision; used as the absolute tolerance when checking that
    # an assigned t_array is evenly spaced
    MACHINE_PRECISION = 100 * sys.float_info.epsilon

    def __init__(self, N, spacing):
        """
        Args:
            N: grid size in each dimension [grid points]
            spacing: grid point spacing in each direction [m]

        Raises:
            AssertionError: if N or spacing is not one-dimensional, does not have
                1 to 3 entries, or if their lengths differ.
        """
        N, spacing = np.atleast_1d(N), np.atleast_1d(spacing)  # if inputs are lists
        assert N.ndim == 1 and spacing.ndim == 1, "N and spacing must be one-dimensional."
        assert (1 <= N.size <= 3) and (1 <= spacing.size <= 3), "Only 1D, 2D and 3D grids are supported."
        assert N.size == spacing.size, "Size list N and spacing list do not have the same size."

        self.N = N.astype(int)  #: grid size in each dimension [grid points]
        self.spacing = spacing  #: grid point spacing in each direction [m]
        self.dim = self.N.size  #: Number of dimensions (1, 2 or 3)

        self.nonuniform = False  #: flag that indicates grid non-uniformity
        self.dt = "auto"  #: size of time step [s]
        self.Nt = "auto"  #: number of time steps

        # originally there was [xn_vec, yn_vec, zn_vec]
        self.n_vec = FlexibleVector([0] * self.dim)  #: position vectors for the grid points in [0, 1]
        # originally there was [xn_vec_sgx, yn_vec_sgy, zn_vec_sgz]
        self.n_vec_sg = FlexibleVector([0] * self.dim)  #: position vectors for the staggered grid points in [0, 1]

        # originally there was [dxudxn, dyudyn, dzudzn]
        self.dudn = FlexibleVector([0] * self.dim)  #: transformation gradients between uniform and staggered grids
        # originally there was [dxudxn_sgx, dyudyn_sgy, dzudzn_sgz]
        self.dudn_sg = FlexibleVector([0] * self.dim)  #: transformation gradients between uniform and staggered grids

        # x-direction wavenumber components (originally kx_vec): Nx x 1 [rad/m]
        self.k_vec = FlexibleVector([self.makeDim(self.Nx, self.dx)])

        if self.dim == 1:
            # the scalar wavenumber is the magnitude of the x component
            self.k = abs(self.k_vec.x)  #: scalar wavenumber

        if self.dim >= 2:
            # y-direction wavenumber components: Ny x 1 [rad/m]
            self.k_vec = self.k_vec.append(self.makeDim(self.Ny, self.dy))

            if self.dim == 2:
                # broadcast the (Nx, 1) and (1, Ny) components into an (Nx, Ny) grid
                self.k = np.zeros((self.Nx, self.Ny))
                self.k = np.reshape(self.k_vec.x, (-1, 1)) ** 2 + self.k
                self.k = np.reshape(self.k_vec.y, (1, -1)) ** 2 + self.k
                self.k = np.sqrt(self.k)  #: scalar wavenumber

        if self.dim == 3:
            # z-direction wavenumber components: Nz x 1 [rad/m]
            self.k_vec = self.k_vec.append(self.makeDim(self.Nz, self.dz))

            # broadcast the three components into an (Nx, Ny, Nz) grid
            self.k = np.zeros((self.Nx, self.Ny, self.Nz))
            self.k = np.reshape(self.k_vec.x, (-1, 1, 1)) ** 2 + self.k
            self.k = np.reshape(self.k_vec.y, (1, -1, 1)) ** 2 + self.k
            self.k = np.reshape(self.k_vec.z, (1, 1, -1)) ** 2 + self.k
            self.k = np.sqrt(self.k)  #: scalar wavenumber

    @property
    def t_array(self):
        """
        Time array [s].

        Returns:
            The string "auto" while Nt or dt is still "auto"; otherwise a
            (1, Nt) array equal to ``np.arange(0, Nt) * dt``.
        """
        # TODO (walter): I would change this functionality to return a time array even if Nt or dt are not yet set
        # (e.g. if they are still 'auto')
        if self.Nt == "auto" or self.dt == "auto":
            return "auto"
        t_array = np.arange(0, self.Nt) * self.dt
        # TODO: adding this extra dimension seems unnecessary
        # This leads to an extra squeeze when plotting e.g. in example "array as sensor" on lines 110 and 111
        return np.expand_dims(t_array, axis=0)

    @t_array.setter
    def t_array(self, t_array):
        """
        Set Nt and dt from a time array, or reset both to "auto".

        Args:
            t_array: the string "auto", or an evenly spaced, increasing sequence
                of time points starting at zero [s]. The input is flattened, so
                the (1, Nt) array returned by the getter can be assigned back.

        Raises:
            AssertionError: if the array has fewer than two points, does not
                begin at zero, is not evenly spaced, or is not increasing.
        """
        # the isinstance guard avoids an element-wise (ambiguous) comparison
        # when an ndarray is passed
        if isinstance(t_array, str) and t_array == "auto":
            self.Nt = "auto"
            self.dt = "auto"
            return

        # flatten so that row/column vectors (e.g. the getter's output) work
        t_array = np.asarray(t_array).ravel()
        assert t_array.size >= 2, "t_array must contain at least two time points."

        # extract property values
        Nt_temp = t_array.size
        dt_temp = t_array[1] - t_array[0]

        # check the time array begins at zero
        assert t_array[0] == 0, "t_array must begin at zero."

        # check the time array is evenly spaced; abs() so that steps shorter
        # than dt_temp are rejected as well as longer ones
        assert np.abs(np.diff(t_array) - dt_temp).max() < self.MACHINE_PRECISION, "t_array must be evenly spaced."

        # check the time steps are increasing
        assert dt_temp > 0, "t_array must be monotonically increasing."

        # assign values
        self.Nt = Nt_temp
        self.dt = dt_temp

    def setTime(self, Nt, dt) -> None:
        """
        Set Nt and dt based on user input.

        Args:
            Nt: number of time steps; a positive Python or NumPy integer.
            dt: size of the time step [s]; must be positive.

        Returns: None

        Raises:
            AssertionError: if Nt is not a positive integer or dt is not positive.
        """
        # isinstance against np.integer accepts every NumPy integer width and,
        # unlike np.issubdtype (which expects a dtype), does not raise TypeError
        # for non-integer values such as floats
        assert isinstance(Nt, (int, np.integer)) and Nt > 0, "Nt must be a positive integer."

        # check the value for dt
        assert dt > 0, "dt must be positive."

        # assign values
        self.Nt = Nt
        self.dt = dt

    @property
    def Nx(self):
        """
        grid size in x-direction [grid points]
        """
        return self.N[0]

    @property
    def Ny(self):
        """
        grid size in y-direction [grid points]; 0 for 1D grids
        """
        return self.N[1] if self.N.size >= 2 else 0

    @property
    def Nz(self):
        """
        grid size in z-direction [grid points]; 0 for 1D and 2D grids
        """
        return self.N[2] if self.N.size == 3 else 0

    @property
    def dx(self):
        """
        grid point spacing in x-direction [m]
        """
        return self.spacing[0]

    @property
    def dy(self):
        """
        grid point spacing in y-direction [m]; 0 for 1D grids
        """
        return self.spacing[1] if self.spacing.size >= 2 else 0

    @property
    def dz(self):
        """
        grid point spacing in z-direction [m]; 0 for 1D and 2D grids
        """
        return self.spacing[2] if self.spacing.size == 3 else 0

    @property
    def x_vec(self):
        """
        Nx x 1 vector of the grid coordinates in the x-direction [m]
        """
        # calculate x_vec based on kx_vec
        return self.size[0] * self.k_vec.x * self.dx / (2 * np.pi)

    @property
    def y_vec(self):
        """
        Ny x 1 vector of the grid coordinates in the y-direction [m]; NaN for 1D grids
        """
        # calculate y_vec based on ky_vec
        if self.dim < 2:
            return np.nan
        return self.size[1] * self.k_vec.y * self.dy / (2 * np.pi)

    @property
    def z_vec(self):
        """
        Nz x 1 vector of the grid coordinates in the z-direction [m]; NaN for 1D/2D grids
        """
        # calculate z_vec based on kz_vec
        if self.dim < 3:
            return np.nan
        return self.size[2] * self.k_vec.z * self.dz / (2 * np.pi)

    @property
    def x(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the grid coordinates in the x-direction [m]
        """
        return self.size[0] * self.kx * self.dx / (2 * math.pi)

    @property
    def y(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the grid coordinates in the y-direction [m];
        0 for 1D grids
        """
        if self.dim < 2:
            return 0
        return self.size[1] * self.ky * self.dy / (2 * math.pi)

    @property
    def z(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the grid coordinates in the z-direction [m];
        0 for 1D/2D grids
        """
        if self.dim < 3:
            return 0
        return self.size[2] * self.kz * self.dz / (2 * math.pi)

    @property
    def xn(self):
        """
        Plaid non-uniform grid positions in the x-direction.

        Returns:
            0 for a uniform grid; otherwise n_vec.x repeated to shape (Nx, 1),
            (Nx, Ny) or (Nx, Ny, Nz) depending on the grid dimension.
        """
        if not self.nonuniform:
            return 0
        if self.dim == 1:
            return self.n_vec.x
        if self.dim == 2:
            return np.tile(np.reshape(self.n_vec.x, (-1, 1)), (1, self.Ny))
        # reshape to (Nx, 1, 1) first: np.tile would otherwise prepend the new
        # axis and produce a (1, Nx*Ny, Nz) array
        return np.tile(np.reshape(self.n_vec.x, (-1, 1, 1)), (1, self.Ny, self.Nz))

    @property
    def yn(self):
        """
        Plaid non-uniform grid positions in the y-direction.

        Returns:
            NaN for 1D grids; 0 for a uniform grid; otherwise n_vec.y repeated
            to shape (Nx, Ny) or (Nx, Ny, Nz).
        """
        if self.dim < 2:
            return np.nan
        if not self.nonuniform:
            return 0
        if self.dim == 2:
            return np.tile(np.reshape(self.n_vec.y, (1, -1)), (self.Nx, 1))
        # place the y values on axis 1 explicitly before tiling to 3D
        return np.tile(np.reshape(self.n_vec.y, (1, -1, 1)), (self.Nx, 1, self.Nz))

    @property
    def zn(self):
        """
        Plaid non-uniform grid positions in the z-direction.

        Returns:
            NaN for 1D/2D grids; 0 for a uniform grid; otherwise n_vec.z
            repeated to shape (Nx, Ny, Nz).
        """
        if self.dim < 3:
            return np.nan
        if not self.nonuniform:
            return 0
        n_vec_z = np.atleast_1d(np.squeeze(self.n_vec.z))[None, None, :]
        return np.tile(n_vec_z, (self.Nx, self.Ny, 1))

    @property
    def size(self):
        """
        Size of grid in the all directions [m]
        """
        return Vector(self.N * self.spacing)

    @property
    def total_grid_points(self) -> np.ndarray:
        """
        Total number of grid points (equal to Nx * Ny * Nz)
        """
        return np.prod(self.N)

    @property
    def kx(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the wavenumber components in the x-direction [rad/m]

        Returns:
            plaid kx matrix
        """
        if self.dim == 1:
            return self.k_vec.x
        elif self.dim == 2:
            return np.tile(self.k_vec.x, (1, self.Ny))
        else:
            return np.tile(self.k_vec.x[:, :, None], (1, self.Ny, self.Nz))

    @property
    def ky(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the wavenumber components in the y-direction [rad/m]

        Returns:
            plaid ky matrix, or NaN for 1D grids
        """
        if self.dim == 2:
            return np.tile(self.k_vec.y.T, (self.Nx, 1))
        elif self.dim == 3:
            return np.tile(self.k_vec.y[None, :, :], (self.Nx, 1, self.Nz))
        return np.nan

    @property
    def kz(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the wavenumber components in the z-direction [rad/m]

        Returns:
            plaid kz matrix, or NaN for 1D/2D grids
        """
        if self.dim == 3:
            return np.tile(self.k_vec.z.T[None, :, :], (self.Nx, self.Ny, 1))
        else:
            return np.nan

    @property
    def x_size(self):
        """
        Size of grid in the x-direction [m]
        """
        return self.Nx * self.dx

    @property
    def y_size(self):
        """
        Size of grid in the y-direction [m]
        """
        return self.Ny * self.dy

    @property
    def z_size(self):
        """
        Size of grid in the z-direction [m]
        """
        return self.Nz * self.dz

    @property
    def k_max(self):  # added by us, not the same as kWave k_max (see k_max_all for KwaveGrid.k_max)
        """
        Maximum supported spatial frequency in the 3 directions [rad/m]

        Returns:
            Vector of 3 elements each in [rad/m]. Value for higher dimensions set to NaN
        """
        kx_max = np.abs(self.k_vec.x).max()
        ky_max = np.abs(self.k_vec.y).max() if self.dim >= 2 else np.nan
        kz_max = np.abs(self.k_vec.z).max() if self.dim == 3 else np.nan
        return Vector([kx_max, ky_max, kz_max])

    @property
    def k_max_all(self):
        """
        Maximum supported spatial frequency in all directions [rad/m]
        Originally k_max in kWave.kWaveGrid!

        Returns:
            Scalar in [rad/m]: the smallest non-NaN entry of k_max
        """
        return np.nanmin(self.k_max)

    ########################################
    # functions that can only be accessed by class members
    ########################################
    @staticmethod
    # TODO (walter): convert this name to snake case
    def makeDim(num_points, spacing):
        """
        Create the wavenumber vector for a single spatial direction.

        The normalised positions are (i - num_points // 2) / num_points for
        i = 0 .. num_points - 1, so there is always a DC (zero) component at
        index num_points // 2 for both even and odd num_points.

        Args:
            num_points: number of grid points in this direction
            spacing: grid point spacing in this direction [m]

        Returns:
            (num_points, 1) array of wavenumber components [rad/m]
        """
        nx = (np.arange(num_points) - num_points // 2) / num_points

        # force middle value to be exactly zero in case 1/Nx is a recurring
        # number and the series doesn't give exactly zero
        nx[int(num_points // 2)] = 0

        # define the wavenumber vector components
        res = (2 * math.pi / spacing) * nx
        return res[:, None]

    def highest_prime_factors(self, axisymmetric=None) -> np.ndarray:
        """
        Calculate the highest prime factor of the grid size in each direction.

        Args:
            axisymmetric: None, "WSWA" (the y size is multiplied by 4) or
                "WSWS" (the y size becomes 2 * Ny - 2).

        Returns:
            Array of three elements [x, y, z]

        Raises:
            ValueError: for any other axisymmetric value.
        """
        if axisymmetric is not None:
            if axisymmetric == "WSWA":
                prime_facs = [largest_prime_factor(self.Nx), largest_prime_factor(self.Ny * 4), largest_prime_factor(self.Nz)]
            elif axisymmetric == "WSWS":
                prime_facs = [largest_prime_factor(self.Nx), largest_prime_factor(self.Ny * 2 - 2), largest_prime_factor(self.Nz)]
            else:
                raise ValueError("Unknown axisymmetric symmetry.")
        else:
            prime_facs = [largest_prime_factor(self.Nx), largest_prime_factor(self.Ny), largest_prime_factor(self.Nz)]
        return np.array(prime_facs)

    # TODO (walter): convert this name to snake case
    def makeTime(self, c, cfl=CFL_DEFAULT, t_end=None):
        """
        Compute Nt and dt based on the cfl number and grid size, where
        the number of time-steps is chosen based on the time it takes to
        travel from one corner of the grid to the geometrically opposite
        corner. Note, if c is given as a matrix, the calculation for dt
        is based on the maximum value, and the calculation for t_end
        based on the minimum value.

        Args:
            c: sound speed (scalar or array) [m/s]
            cfl: convergence condition by Courant–Friedrichs–Lewy
            t_end: simulation end time [s]; defaults to the grid diagonal
                divided by the minimum sound speed

        Returns:
            Tuple (t_array, dt): the (1, Nt) time array and the time step.
            Nt and dt are also stored on the grid.
        """
        # if c is a matrix, find the minimum and maximum values
        c = np.array(c)
        c_min, c_max = np.min(c), np.max(c)

        # check for user define t_end, otherwise set the simulation
        # length based on the size of the grid diagonal and the minimum
        # sound speed in the medium
        if t_end is None:
            t_end = np.linalg.norm(self.size, ord=2) / c_min

        # extract the smallest grid spacing
        min_grid_dim = np.min(self.spacing)

        # assign time step based on CFL stability criterion
        self.dt = cfl * min_grid_dim / c_max

        # assign number of time steps based on t_end
        self.Nt = int(np.floor(t_end / self.dt) + 1)

        # catch case where dt is a recurring number (mirrors the MATLAB check)
        if (np.floor(t_end / self.dt) != np.ceil(t_end / self.dt)) and (matlab.rem(t_end, self.dt) == 0):
            self.Nt = self.Nt + 1

        return self.t_array, self.dt

    ##################################################
    ####
    #### FUNCTIONS BELOW WERE NOT TESTED FOR CORRECTNESS!
    ####
    ##################################################
    def kx_vec_dtt(self, dtt_type):
        """
        Compute the DTT wavenumber vector in the x-direction

        Args:
            dtt_type: DiscreteCosine / DiscreteSine transform type

        Returns:
            (kx_vec_dtt, M): wavenumber vector and implied period
        """
        kx_vec_dtt, M = self.makeDTTDim(self.Nx, self.dx, dtt_type)
        return kx_vec_dtt, M

    def ky_vec_dtt(self, dtt_type):
        """
        Compute the DTT wavenumber vector in the y-direction

        Args:
            dtt_type: DiscreteCosine / DiscreteSine transform type

        Returns:
            (ky_vec_dtt, M): wavenumber vector and implied period
        """
        ky_vec_dtt, M = self.makeDTTDim(self.Ny, self.dy, dtt_type)
        return ky_vec_dtt, M

    def kz_vec_dtt(self, dtt_type):
        """
        Compute the DTT wavenumber vector in the z-direction

        Args:
            dtt_type: DiscreteCosine / DiscreteSine transform type

        Returns:
            (kz_vec_dtt, M): wavenumber vector and implied period
        """
        kz_vec_dtt, M = self.makeDTTDim(self.Nz, self.dz, dtt_type)
        return kz_vec_dtt, M

    @staticmethod
    # TODO (walter): convert this name to snake case
    def makeDTTDim(Nx, dx, dtt_type):
        """
        Create the DTT grid parameters for a single spatial direction

        Args:
            Nx: number of grid points in this direction
            dx: grid point spacing in this direction [m]
            dtt_type: DiscreteCosine.TYPE_1..TYPE_4 or DiscreteSine.TYPE_1..TYPE_4

        Returns:
            (kx_vec, M): 1D wavenumber vector [rad/m] and the implied period M

        Raises:
            ValueError: for an unrecognised dtt_type.
        """
        # compute the implied period of the input function
        if dtt_type == DiscreteCosine.TYPE_1:
            M = 2 * (Nx - 1)
        elif dtt_type == DiscreteSine.TYPE_1:
            M = 2 * (Nx + 1)
        else:
            M = 2 * Nx

        # calculate the wavenumbers
        if dtt_type == DiscreteCosine.TYPE_1:
            # whole-wavenumber DTT
            # WSWS / DCT-I
            n = np.arange(0, M // 2 + 1).T
            kx_vec = 2 * math.pi * n / (M * dx)
        elif dtt_type == DiscreteCosine.TYPE_2:
            # whole-wavenumber DTT
            # HSHS / DCT-II
            n = np.arange(0, M // 2).T
            kx_vec = 2 * math.pi * n / (M * dx)
        elif dtt_type == DiscreteSine.TYPE_1:
            # whole-wavenumber DTT
            # WAWA / DST-I
            n = np.arange(1, M // 2).T
            kx_vec = 2 * math.pi * n / (M * dx)
        elif dtt_type == DiscreteSine.TYPE_2:
            # whole-wavenumber DTT
            # HAHA / DST-II
            n = np.arange(1, M // 2 + 1).T
            kx_vec = 2 * math.pi * n / (M * dx)
        elif dtt_type in [DiscreteCosine.TYPE_3, DiscreteCosine.TYPE_4, DiscreteSine.TYPE_3, DiscreteSine.TYPE_4]:
            # half-wavenumber DTTs
            # WSWA / DCT-III
            # HSHA / DCT-IV
            # WAWS / DST-III
            # HAHS / DST-IV
            n = np.arange(0, M // 2).T
            kx_vec = 2 * math.pi * (n + 0.5) / (M * dx)
        else:
            raise ValueError(f"Unknown dtt_type: {dtt_type!r}")

        return kx_vec, M

    ########################################
    # functions for non-uniform grids
    ########################################
    # TODO (walter): convert this name to snake case
    def setNUGrid(self, dim, n_vec, dudn, n_vec_sg, dudn_sg):
        """
        Set the non-uniform grid parameters in the specified dimension.

        Args:
            dim: 1-based dimension to set (1 = x, 2 = y, 3 = z)
            n_vec: position vector for the grid points
            dudn: transformation gradient for the grid points
            n_vec_sg: position vector for the staggered grid points
            dudn_sg: transformation gradient for the staggered grid points

        Returns:
            None; the values are stored on the grid and the nonuniform flag is set.

        Raises:
            AssertionError: if dim is not between 1 and the grid dimension.
        """
        # check the dimension to set the nonuniform grid is appropriate
        assert 1 <= dim <= self.dim, f"Cannot set nonuniform parameters for dimension {dim} of {self.dim}-dimensional grid."

        # force non-uniform grid spacing to be column vectors, and the
        # gradients to be in the correct direction for use with bsxfun
        n_vec = np.reshape(n_vec, (-1, 1), order="F")
        n_vec_sg = np.reshape(n_vec_sg, (-1, 1), order="F")

        if dim == 1:
            dudn = np.reshape(dudn, (-1, 1), order="F")
            dudn_sg = np.reshape(dudn_sg, (-1, 1), order="F")
        elif dim == 2:
            dudn = np.reshape(dudn, (1, -1), order="F")
            dudn_sg = np.reshape(dudn_sg, (1, -1), order="F")
        elif dim == 3:
            dudn = np.reshape(dudn, (1, 1, -1), order="F")
            dudn_sg = np.reshape(dudn_sg, (1, 1, -1), order="F")

        # store in the requested dimension (not the grid's highest dimension)
        self.n_vec.assign_dim(dim, n_vec)
        self.n_vec_sg.assign_dim(dim, n_vec_sg)

        self.dudn.assign_dim(dim, dudn)
        self.dudn_sg.assign_dim(dim, dudn_sg)

        # set non-uniform flag
        self.nonuniform = True

    def k_dtt(self, dtt_type):  # Not tested for correctness!
        """
        compute the individual wavenumber vectors, where dtt_type is the
        type of discrete trigonometric transform, which corresponds to
        the assumed input symmetry of the input function, where:

        1. DCT-I    WSWS
        2. DCT-II   HSHS
        3. DCT-III  WSWA
        4. DCT-IV   HSHA
        5. DST-I    WAWA
        6. DST-II   HAHA
        7. DST-III  WAWS
        8. DST-IV   HAHS

        Args:
            dtt_type: a single transform type applied to every dimension, or
                one transform type per dimension

        Returns:
            (k, M): scalar wavenumber array and the product of the implied
            periods in each dimension

        Raises:
            AssertionError: if dtt_type is neither a scalar nor of length self.dim.
        """
        # atleast_1d so that a scalar dtt_type can be indexed like a vector
        dtt_type = np.atleast_1d(np.array(dtt_type))
        assert dtt_type.size in [1, self.dim], f"dtt_type must be a scalar, or {self.dim}D vector"
        if self.dim == 1:
            k, M = self.kx_vec_dtt(dtt_type[0])
            return k, M
        elif self.dim == 2:
            # assign the grid parameters for the x and y spatial directions
            kx_vec_dtt, Mx = self.kx_vec_dtt(dtt_type[0])
            ky_vec_dtt, My = self.ky_vec_dtt(dtt_type[-1])

            # define the wavenumber based on the wavenumber components
            k = np.zeros((self.Nx, self.Ny))
            k += np.reshape(kx_vec_dtt, (-1, 1)) ** 2
            k += np.reshape(ky_vec_dtt, (1, -1)) ** 2
            k = np.sqrt(k)

            # define product of implied period
            M = Mx * My
            return k, M
        elif self.dim == 3:
            # assign the grid parameters for the x, y, and z spatial directions
            # (index len // 2 is 0 for a scalar type and 1 for a 3-vector)
            kx_vec_dtt, Mx = self.kx_vec_dtt(dtt_type[0])
            ky_vec_dtt, My = self.ky_vec_dtt(dtt_type[len(dtt_type) // 2])
            kz_vec_dtt, Mz = self.kz_vec_dtt(dtt_type[-1])

            # define the wavenumber based on the wavenumber components
            k = np.zeros((self.Nx, self.Ny, self.Nz))
            k = np.reshape(kx_vec_dtt, (-1, 1, 1)) ** 2 + k
            k = np.reshape(ky_vec_dtt, (1, -1, 1)) ** 2 + k
            k = np.reshape(kz_vec_dtt, (1, 1, -1)) ** 2 + k
            k = np.sqrt(k)

            # define product of implied period
            M = Mx * My * Mz
            return k, M