@@ -24,7 +24,7 @@ criteria:
24
24
in the thread partition
25
25
- The order of the thread partition should
26
26
follow the fastest changing index in the
27
- datalayout (e.g., VIJ in VIJFH )
27
+ datalayout (e.g., VIJ in VIJHF )
28
28
"""
29
29
function partition end
30
30
@@ -46,25 +46,25 @@ bounds to ensure that the result of
46
46
"""
47
47
function is_valid_index end
48
48
49
##### VIJHF
# Thread/block partition for `VIJHF` data: threads are laid out as
# (v-chunk, i, j) and blocks as (number of v-chunks, Nh).
@inline function partition(data::DataLayouts.VIJHF, n_max_threads::Integer)
    (nij, _, _, nv, nh) = DataLayouts.universal_size(data)
    # Fit as many `v` entries per block as the thread budget allows once the
    # full nij × nij patch is accounted for, never exceeding nv.
    nv_per_block = min(Int(fld(n_max_threads, nij * nij)), nv)
    threads = (nv_per_block, nij, nij)
    blocks = (cld(nv, nv_per_block), nh)
    n_threads = prod(threads)
    @assert n_threads ≤ n_max_threads " threads,n_max_threads=($(n_threads) ,$n_max_threads )"
    return (; threads = threads, blocks = blocks)
end
57
# Build the 5-entry universal index of the calling CUDA thread for `VIJHF`
# data from its thread and block coordinates.
@inline function universal_index(::DataLayouts.VIJHF)
    tid = CUDA.threadIdx()
    bid = CUDA.blockIdx()
    # Thread x is the offset inside a v-chunk; block x selects the chunk.
    v = (bid.x - 1) * CUDA.blockDim().x + tid.x
    # Thread y/z supply (i, j); block y supplies h; slot 3 is fixed at 1.
    return CartesianIndex((tid.y, tid.z, 1, v, bid.y))
end
63
# A universal index is valid for `VIJHF` data when its 4th entry lies
# between 1 and `DataLayouts.get_Nv(us)` inclusive.
@inline function is_valid_index(::DataLayouts.VIJHF, I::CI5, us::UniversalSize)
    v = I[4]
    return 1 ≤ v && v ≤ DataLayouts.get_Nv(us)
end
65
65
66
- # #### IJFH
67
- @inline function partition (data:: DataLayouts.IJFH , n_max_threads:: Integer )
66
+ # #### IJHF
67
+ @inline function partition (data:: DataLayouts.IJHF , n_max_threads:: Integer )
68
68
(Nij, _, _, _, Nh) = DataLayouts. universal_size (data)
69
69
Nh_thread = min (
70
70
Int (fld (n_max_threads, Nij * Nij)),
75
75
@assert prod ((Nij, Nij)) ≤ n_max_threads " threads,n_max_threads=($(prod ((Nij, Nij))) ,$n_max_threads )"
76
76
return (; threads = (Nij, Nij, Nh_thread), blocks = (Nh_blocks,))
77
77
end
78
# Build the 5-entry universal index of the calling CUDA thread for `IJHF`
# data from its thread and block coordinates.
@inline function universal_index(::DataLayouts.IJHF)
    tid = CUDA.threadIdx()
    bid = CUDA.blockIdx()
    # Thread z is the offset inside an h-chunk; block x selects the chunk.
    h = (bid.x - 1) * CUDA.blockDim().z + tid.z
    # Thread x/y supply (i, j); slots 3 and 4 are fixed at 1.
    return CartesianIndex((tid.x, tid.y, 1, 1, h))
end
84
# A universal index is valid for `IJHF` data when its 5th entry lies
# between 1 and `DataLayouts.get_Nh(us)` inclusive.
@inline function is_valid_index(::DataLayouts.IJHF, I::CI5, us::UniversalSize)
    h = I[5]
    return 1 ≤ h && h ≤ DataLayouts.get_Nh(us)
end
86
86
87
##### IHF
# Thread/block partition for `IHF` data: threads are laid out as
# (i, h-chunk) and blocks as (number of h-chunks,).
@inline function partition(data::DataLayouts.IHF, n_max_threads::Integer)
    (ni, _, _, _, nh) = DataLayouts.universal_size(data)
    # Fit as many `h` entries per block as the thread budget allows once the
    # ni entries along `i` are accounted for, never exceeding nh.
    nh_per_block = min(Int(fld(n_max_threads, ni)), nh)
    threads = (ni, nh_per_block)
    blocks = (cld(nh, nh_per_block),)
    n_threads = prod(threads)
    @assert n_threads ≤ n_max_threads " threads,n_max_threads=($(n_threads) ,$n_max_threads )"
    return (; threads = threads, blocks = blocks)
end
95
# Build the 5-entry universal index of the calling CUDA thread for `IHF`
# data from its thread and block coordinates.
@inline function universal_index(::DataLayouts.IHF)
    tid = CUDA.threadIdx()
    bid = CUDA.blockIdx()
    # Thread y is the offset inside an h-chunk; block x selects the chunk.
    h = (bid.x - 1) * CUDA.blockDim().y + tid.y
    # Thread x supplies i; slots 2 through 4 are fixed at 1.
    return CartesianIndex((tid.x, 1, 1, 1, h))
end
101
# A universal index is valid for `IHF` data when its 5th entry lies
# between 1 and `DataLayouts.get_Nh(us)` inclusive.
@inline function is_valid_index(::DataLayouts.IHF, I::CI5, us::UniversalSize)
    h = I[5]
    return 1 ≤ h && h ≤ DataLayouts.get_Nh(us)
end
103
103
104
104
# #### IJF
@@ -125,21 +125,21 @@ end
125
125
end
126
126
@inline function is_valid_index(::DataLayouts.IF, I::CI5, us::UniversalSize)
    # Every index is accepted for `IF` data; no bounds test is performed.
    return true
end
127
127
128
##### VIHF
# Thread/block partition for `VIHF` data: threads are laid out as
# (v-chunk, i) and blocks as (number of v-chunks, Nh).
@inline function partition(data::DataLayouts.VIHF, n_max_threads::Integer)
    (ni, _, _, nv, nh) = DataLayouts.universal_size(data)
    # Fit as many `v` entries per block as the thread budget allows once the
    # ni entries along `i` are accounted for, never exceeding nv.
    nv_per_block = min(Int(fld(n_max_threads, ni)), nv)
    threads = (nv_per_block, ni)
    blocks = (cld(nv, nv_per_block), nh)
    n_threads = prod(threads)
    @assert n_threads ≤ n_max_threads " threads,n_max_threads=($(n_threads) ,$n_max_threads )"
    return (; threads = threads, blocks = blocks)
end
136
# Build the 5-entry universal index of the calling CUDA thread for `VIHF`
# data from its thread and block coordinates.
@inline function universal_index(::DataLayouts.VIHF)
    tid = CUDA.threadIdx()
    bid = CUDA.blockIdx()
    # Thread x is the offset inside a v-chunk; block x selects the chunk.
    v = (bid.x - 1) * CUDA.blockDim().x + tid.x
    # Thread y supplies i; block y supplies h; slots 2 and 3 are fixed at 1.
    return CartesianIndex((tid.y, 1, 1, v, bid.y))
end
142
# A universal index is valid for `VIHF` data when its 4th entry lies
# between 1 and `DataLayouts.get_Nv(us)` inclusive.
@inline function is_valid_index(::DataLayouts.VIHF, I::CI5, us::UniversalSize)
    v = I[4]
    return 1 ≤ v && v ≤ DataLayouts.get_Nv(us)
end
144
144
145
145
# #### VF
0 commit comments