Coverage for kwave/utils/colormap.py: 94%

31 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-24 12:06 -0700

1from typing import Optional 

2 

3import numpy as np 

4from beartype import beartype as typechecker 

5from jaxtyping import Float 

6from matplotlib.colors import ListedColormap 

7 

8 

@typechecker
def get_color_map(num_colors: Optional[int] = None) -> ListedColormap:
    """
    Returns the default color map used for display and visualisation across
    the k-Wave Toolbox. Zero values are displayed as white, positive values
    are displayed as yellow through red to black, and negative values are
    displayed as light to dark blue-greys. If no value for `num_colors` is
    provided, the color map will have 256 colors.

    Args:
        num_colors: The number of colors in the color map (default is 256).
            Must be at least 1. Odd values are supported; the extra color is
            assigned to the positive (hot) half so that the map always has
            exactly `num_colors` entries.

    Returns:
        A matplotlib `ListedColormap` with exactly `num_colors` RGB entries,
        ordered from the negative (bone-based) half to the positive
        (hot-based) half.

    Raises:
        ValueError: If `num_colors` is less than 1.

    """
    if num_colors is None:
        num_colors = 256
    if num_colors < 1:
        raise ValueError(f"num_colors must be a positive integer, got {num_colors}.")

    # Number of darkest rows trimmed from the bone ramp: round(48 * n / 256)
    # with halves rounded away from zero (MATLAB-style `round`), done in exact
    # integer arithmetic. Python's built-in round() rounds halves to even and
    # would give 4 instead of 5 for num_colors=24. For n=256 this is 48.
    neg_pad = (48 * num_colors + 128) // 256

    num_neg = num_colors // 2
    # When num_colors is odd, the positive half receives the extra color.
    num_pos = num_colors - num_neg

    # define colour spectrums
    # bone() ascends from dark to light; dropping its first `neg_pad` rows
    # discards the darkest entries and leaves `num_neg` rows.
    neg = bone(num_neg + neg_pad)[neg_pad:, :]
    # hot() ends at white; flip it so the positive half starts at white.
    pos = np.flipud(hot(num_pos))

    colors = np.vstack([neg, pos])
    return ListedColormap(colors)

39 

40 

@typechecker
def hot(m: int) -> Float[np.ndarray, "N 3"]:
    """
    Build the "hot" colormap with `m` entries.

    The red channel ramps up first, then green, then blue, so the map
    progresses from (near) black through red and yellow to white.

    Args:
        m: The number of rows (colours) in the colormap.

    Returns:
        An m-by-3 float array of RGB values in [0, 1].

    """

    def ramp(length: int) -> np.ndarray:
        # 1/length, 2/length, ..., 1 (empty when length is 0)
        return np.arange(1, length + 1) / length

    seg = int(np.fix(3 / 8 * m))  # length of the red and green ramps
    tail = m - 2 * seg  # length of the blue ramp

    red = np.concatenate((ramp(seg), np.ones(m - seg)))
    green = np.concatenate((np.zeros(seg), ramp(seg), np.ones(tail)))
    blue = np.concatenate((np.zeros(2 * seg), ramp(tail)))

    return np.column_stack((red, green, blue))

62 

63 

@typechecker
def bone(m: int) -> Float[np.ndarray, "N 3"]:
    """
    Build the "bone" colormap with `m` entries.

    The map blends a linear grayscale ramp (weight 7) with the `hot`
    colormap whose channel order is reversed (weight 1), which gives
    greys with a blue tint.

    Args:
        m: The number of rows (colours) in the colormap.

    Returns:
        An m-by-3 float array of RGB values.
    """
    grey_weighted = 7 * gray(m)
    hot_reversed = hot(m)[:, ::-1]  # swap the red and blue channels
    return (grey_weighted + hot_reversed) / 8

76 

77 

@typechecker
def gray(m: int) -> Float[np.ndarray, "N 3"]:
    """
    Build a linear grayscale colormap with `m` entries.

    Every row has equal R, G and B components. They increase evenly from
    0 (black) in the first row to 1 (white) in the last row; a single-row
    map is just black.

    Args:
        m: The number of rows (colours) in the colormap.

    Returns:
        An m-by-3 float array containing the grayscale colormap.

    """
    # Guard against dividing by zero (or a negative) when m <= 1.
    denominator = m - 1 if m > 1 else 1
    levels = np.arange(m) / denominator
    return np.repeat(levels[:, np.newaxis], 3, axis=1)