Ei kuvausta
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ThreadingEmuImpl.hlsl 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
#ifndef THREADING_EMU_IMPL
#define THREADING_EMU_IMPL
// If the user didn't specify a wave size, we assume that their code is "wave size independent" and that they don't
// care which size is actually used. In this case, we automatically select an arbitrary size for them since the
// emulation logic depends on having *some* known size.
#ifndef THREADING_WAVE_SIZE
#define THREADING_WAVE_SIZE 32
#endif
namespace Threading
{
// Currently we only cover scalar types as at the time of writing this utility library we only needed emulation for those.
// Support for vector types is currently not there but can be added as needed (and this comment removed).
// Shared scratch memory backing every emulated wave/group operation below: one uint slot per
// thread in the group. Non-uint 32-bit payloads are round-tripped through asuint()/as<TYPE>().
// NOTE(review): THREADING_BLOCK_SIZE must be #defined by the includer before this header — confirm at call sites.
groupshared uint g_Scratch[THREADING_BLOCK_SIZE];
  14. #define EMULATED_WAVE_REDUCE(TYPE, OP) \
  15. GroupMemoryBarrierWithGroupSync(); \
  16. g_Scratch[indexG] = asuint(v); \
  17. GroupMemoryBarrierWithGroupSync(); \
  18. [unroll] \
  19. for (uint s = THREADING_WAVE_SIZE / 2u; s > 0u; s >>= 1u) \
  20. { \
  21. if (indexL < s) \
  22. g_Scratch[indexG] = asuint(as##TYPE(g_Scratch[indexG]) OP as##TYPE(g_Scratch[indexG + s])); \
  23. GroupMemoryBarrierWithGroupSync(); \
  24. } \
  25. return as##TYPE(g_Scratch[offset]); \
  26. #define EMULATED_WAVE_REDUCE_CMP(TYPE, OP) \
  27. GroupMemoryBarrierWithGroupSync(); \
  28. g_Scratch[indexG] = asuint(v); \
  29. GroupMemoryBarrierWithGroupSync(); \
  30. [unroll] \
  31. for (uint s = THREADING_WAVE_SIZE / 2u; s > 0u; s >>= 1u) \
  32. { \
  33. if (indexL < s) \
  34. g_Scratch[indexG] = asuint(OP(as##TYPE(g_Scratch[indexG]), as##TYPE(g_Scratch[indexG + s]))); \
  35. GroupMemoryBarrierWithGroupSync(); \
  36. } \
  37. return as##TYPE(g_Scratch[offset]); \
  38. #define EMULATED_WAVE_PREFIX(TYPE, OP, FILL_VALUE) \
  39. GroupMemoryBarrierWithGroupSync(); \
  40. g_Scratch[indexG] = asuint(v); \
  41. GroupMemoryBarrierWithGroupSync(); \
  42. [unroll] \
  43. for (uint s = 1u; s < THREADING_WAVE_SIZE; s <<= 1u) \
  44. { \
  45. TYPE nv = FILL_VALUE; \
  46. if (indexL >= s) \
  47. { \
  48. nv = as##TYPE(g_Scratch[indexG - s]); \
  49. } \
  50. nv = as##TYPE(g_Scratch[indexG]) OP nv; \
  51. GroupMemoryBarrierWithGroupSync(); \
  52. g_Scratch[indexG] = asuint(nv); \
  53. GroupMemoryBarrierWithGroupSync(); \
  54. } \
  55. TYPE result = FILL_VALUE; \
  56. if (indexL > 0u) \
  57. result = as##TYPE(g_Scratch[indexG - 1]); \
  58. return result; \
// Index of this emulated wave within the thread group.
uint Wave::GetIndex() { return indexW; }

// Initializes the emulated wave state from the flattened thread index within the group
// (presumably SV_GroupIndex — confirm against the caller).
void Wave::Init(uint groupIndex)
{
    indexG = groupIndex;                          // thread index within the whole group
    indexW = indexG / THREADING_WAVE_SIZE;        // which emulated wave this thread belongs to
    indexL = indexG & (THREADING_WAVE_SIZE - 1);  // lane index within the wave; mask form assumes power-of-two wave size
    offset = indexW * THREADING_WAVE_SIZE;        // first g_Scratch slot owned by this wave
}
  67. // WARNING:
  68. // These emulated functions do not emulate the execution mask.
  69. // So they WILL produce incorrect results if you have divergent lanes.
  70. #define DEFINE_API_FOR_TYPE(TYPE) \
  71. bool Wave::AllEqual(TYPE v) { return AllTrue(ReadLaneFirst(v) == v); } \
  72. TYPE Wave::Product(TYPE v) { EMULATED_WAVE_REDUCE(TYPE, *) } \
  73. TYPE Wave::Sum(TYPE v) { EMULATED_WAVE_REDUCE(TYPE, +) } \
  74. TYPE Wave::Max(TYPE v) { EMULATED_WAVE_REDUCE_CMP(TYPE, max) } \
  75. TYPE Wave::Min(TYPE v) { EMULATED_WAVE_REDUCE_CMP(TYPE, min) } \
  76. TYPE Wave::InclusivePrefixSum (TYPE v) { return PrefixSum(v) + v; } \
  77. TYPE Wave::InclusivePrefixProduct (TYPE v) { return PrefixProduct(v) * v; } \
  78. TYPE Wave::PrefixSum (TYPE v) { EMULATED_WAVE_PREFIX(TYPE, +, (TYPE)0) } \
  79. TYPE Wave::PrefixProduct (TYPE v) { EMULATED_WAVE_PREFIX(TYPE, *, (TYPE)1) } \
  80. TYPE Wave::ReadLaneAt(TYPE v, uint i) { GroupMemoryBarrierWithGroupSync(); g_Scratch[indexG] = asuint(v); GroupMemoryBarrierWithGroupSync(); return as##TYPE(g_Scratch[offset + i]); } \
  81. TYPE Wave::ReadLaneFirst(TYPE v) { return ReadLaneAt(v, 0u); } \
// Currently just support scalars.
// Instantiate the emulated Wave API for each supported 32-bit scalar type.
DEFINE_API_FOR_TYPE(uint)
DEFINE_API_FOR_TYPE(int)
DEFINE_API_FOR_TYPE(float)
// The following emulated functions need only be declared once.
// Number of lanes in the emulated wave (always the compile-time THREADING_WAVE_SIZE).
uint Wave::GetLaneCount() { return THREADING_WAVE_SIZE; }
// This thread's lane index inside its emulated wave.
uint Wave::GetLaneIndex() { return indexL; }
// True only for lane 0 of the wave.
bool Wave::IsFirstLane() { return indexL == 0u; }
// bool promotes to 0/1 before the reduction, so AND != 0 means every lane was true,
// and OR != 0 means at least one lane was true.
bool Wave::AllTrue(bool v) { return And(v) != 0u; }
bool Wave::AnyTrue(bool v) { return Or (v) != 0u; }
// Exclusive count of lanes before this one whose v is true.
uint Wave::PrefixCountBits(bool v) { return PrefixSum((uint)v); }
// Bitwise wave-wide reductions.
uint Wave::And(uint v) { EMULATED_WAVE_REDUCE(uint, &) }
uint Wave::Or (uint v) { EMULATED_WAVE_REDUCE(uint, |) }
uint Wave::Xor(uint v) { EMULATED_WAVE_REDUCE(uint, ^) }
// Emulated WaveActiveBallot: returns a bit mask where bit indexL (within the matching dword)
// is set for every lane whose v is true. Dword components beyond the wave size remain zero.
uint4 Wave::Ballot(bool v)
{
    uint indexDw = indexL % 32u;                      // bit position inside this lane's dword
    uint offsetDw = (indexL / 32u) * 32u;             // first lane of the 32-lane chunk this lane falls in
    uint indexScratch = offset + offsetDw + indexDw;  // algebraically equal to indexG; spelled out to mirror the per-dword reduction
    GroupMemoryBarrierWithGroupSync();
    g_Scratch[indexG] = v << indexDw;                 // bool promotes to 0/1 before the shift
    GroupMemoryBarrierWithGroupSync();
    // OR-reduce each 32-lane chunk down into its first scratch slot.
    [unroll]
    for (uint s = min(THREADING_WAVE_SIZE / 2u, 16u); s > 0u; s >>= 1u)
    {
        if (indexDw < s)
            g_Scratch[indexScratch] = g_Scratch[indexScratch] | g_Scratch[indexScratch + s];
        GroupMemoryBarrierWithGroupSync();
    }
    // Gather one dword per 32 lanes; unused components stay 0.
    uint4 result = uint4(g_Scratch[offset], 0, 0, 0);
#if THREADING_WAVE_SIZE > 32
    result.y = g_Scratch[offset + 32];
#endif
#if THREADING_WAVE_SIZE > 64
    result.z = g_Scratch[offset + 64];
#endif
#if THREADING_WAVE_SIZE > 96
    result.w = g_Scratch[offset + 96];
#endif
    return result;
}
  123. uint Wave::CountBits(bool v)
  124. {
  125. uint4 ballot = Ballot(v);
  126. uint result = countbits(ballot.x);
  127. #if THREADING_WAVE_SIZE > 32
  128. result += countbits(ballot.y);
  129. #endif
  130. #if THREADING_WAVE_SIZE > 64
  131. result += countbits(ballot.z);
  132. #endif
  133. #if THREADING_WAVE_SIZE > 96
  134. result += countbits(ballot.w);
  135. #endif
  136. return result;
  137. }
  138. #define EMULATED_GROUP_REDUCE(TYPE, OP) \
  139. GroupMemoryBarrierWithGroupSync(); \
  140. g_Scratch[groupIndex] = asuint(v); \
  141. GroupMemoryBarrierWithGroupSync(); \
  142. [unroll] \
  143. for (uint s = THREADING_BLOCK_SIZE / 2u; s > 0u; s >>= 1u) \
  144. { \
  145. if (groupIndex < s) \
  146. g_Scratch[groupIndex] = asuint(as##TYPE(g_Scratch[groupIndex]) OP as##TYPE(g_Scratch[groupIndex + s])); \
  147. GroupMemoryBarrierWithGroupSync(); \
  148. } \
  149. return as##TYPE(g_Scratch[0]); \
  150. #define EMULATED_GROUP_REDUCE_CMP(TYPE, OP) \
  151. GroupMemoryBarrierWithGroupSync(); \
  152. g_Scratch[groupIndex] = asuint(v); \
  153. GroupMemoryBarrierWithGroupSync(); \
  154. [unroll] \
  155. for (uint s = THREADING_BLOCK_SIZE / 2u; s > 0u; s >>= 1u) \
  156. { \
  157. if (groupIndex < s) \
  158. g_Scratch[groupIndex] = asuint(OP(as##TYPE(g_Scratch[groupIndex]), as##TYPE(g_Scratch[groupIndex + s]))); \
  159. GroupMemoryBarrierWithGroupSync(); \
  160. } \
  161. return as##TYPE(g_Scratch[0]); \
  162. #define EMULATED_GROUP_PREFIX(TYPE, OP, FILL_VALUE) \
  163. GroupMemoryBarrierWithGroupSync(); \
  164. g_Scratch[groupIndex] = asuint(v); \
  165. GroupMemoryBarrierWithGroupSync(); \
  166. [unroll] \
  167. for (uint s = 1u; s < THREADING_BLOCK_SIZE; s <<= 1u) \
  168. { \
  169. TYPE nv = FILL_VALUE; \
  170. if (groupIndex >= s) \
  171. { \
  172. nv = as##TYPE(g_Scratch[groupIndex - s]); \
  173. } \
  174. nv = as##TYPE(g_Scratch[groupIndex]) OP nv; \
  175. GroupMemoryBarrierWithGroupSync(); \
  176. g_Scratch[groupIndex] = asuint(nv); \
  177. GroupMemoryBarrierWithGroupSync(); \
  178. } \
  179. TYPE result = FILL_VALUE; \
  180. if (groupIndex > 0u) \
  181. result = as##TYPE(g_Scratch[groupIndex - 1]); \
  182. return result; \
  183. uint Group::GetWaveCount()
  184. {
  185. return THREADING_BLOCK_SIZE / THREADING_WAVE_SIZE;
  186. }
  187. #define DEFINE_API_FOR_TYPE_GROUP(TYPE) \
  188. bool Group::AllEqual(TYPE v) { return AllTrue(ReadThreadFirst(v) == v); } \
  189. TYPE Group::Product(TYPE v) { EMULATED_GROUP_REDUCE(TYPE, *) } \
  190. TYPE Group::Sum(TYPE v) { EMULATED_GROUP_REDUCE(TYPE, +) } \
  191. TYPE Group::Max(TYPE v) { EMULATED_GROUP_REDUCE_CMP(TYPE, max) } \
  192. TYPE Group::Min(TYPE v) { EMULATED_GROUP_REDUCE_CMP(TYPE, min) } \
  193. TYPE Group::InclusivePrefixSum (TYPE v) { return PrefixSum(v) + v; } \
  194. TYPE Group::InclusivePrefixProduct (TYPE v) { return PrefixProduct(v) * v; } \
  195. TYPE Group::PrefixSum (TYPE v) { EMULATED_GROUP_PREFIX(TYPE, +, (TYPE)0) } \
  196. TYPE Group::PrefixProduct (TYPE v) { EMULATED_GROUP_PREFIX(TYPE, *, (TYPE)1) } \
  197. TYPE Group::ReadThreadAt(TYPE v, uint i) { GroupMemoryBarrierWithGroupSync(); g_Scratch[groupIndex] = asuint(v); GroupMemoryBarrierWithGroupSync(); return as##TYPE(g_Scratch[i]); } \
  198. TYPE Group::ReadThreadFirst(TYPE v) { return ReadThreadAt(v, 0u); } \
  199. TYPE Group::ReadThreadShuffle(TYPE v, uint i) { return ReadThreadAt(v, i); } \
// Currently just support scalars.
// Instantiate the emulated Group API for each supported 32-bit scalar type.
DEFINE_API_FOR_TYPE_GROUP(uint)
DEFINE_API_FOR_TYPE_GROUP(int)
DEFINE_API_FOR_TYPE_GROUP(float)
// The following emulated functions need only be declared once.
// Number of threads in the group (compile-time constant).
uint Group::GetThreadCount() { return THREADING_BLOCK_SIZE; }
// This thread's flattened index within the group.
uint Group::GetThreadIndex() { return groupIndex; }
// True only for thread 0 of the group.
bool Group::IsFirstThread() { return groupIndex == 0u; }
// bool promotes to 0/1 before the reduction, so AND != 0 means every thread was true,
// and OR != 0 means at least one thread was true.
bool Group::AllTrue(bool v) { return And(v) != 0u; }
bool Group::AnyTrue(bool v) { return Or (v) != 0u; }
// Exclusive count of threads before this one whose v is true.
uint Group::PrefixCountBits(bool v) { return PrefixSum((uint)v); }
// Bitwise group-wide reductions.
uint Group::And(uint v) { EMULATED_GROUP_REDUCE(uint, &) }
uint Group::Or (uint v) { EMULATED_GROUP_REDUCE(uint, |) }
uint Group::Xor(uint v) { EMULATED_GROUP_REDUCE(uint, ^) }
// Emulated group-wide ballot: one bit per thread, packed into the dwords of a GroupBallot
// (declared elsewhere; _THREADING_GROUP_BALLOT_DWORDS presumably sized from
// THREADING_BLOCK_SIZE — confirm against the declaring header).
GroupBallot Group::Ballot(bool v)
{
    uint indexDw = groupIndex % 32u;           // bit position inside this thread's dword
    uint offsetDw = (groupIndex / 32u) * 32u;  // first thread of the 32-thread chunk
    uint indexScratch = offsetDw + indexDw;    // algebraically equal to groupIndex; spelled out to mirror the per-dword reduction
    GroupMemoryBarrierWithGroupSync();
    g_Scratch[groupIndex] = v << indexDw;      // bool promotes to 0/1 before the shift
    GroupMemoryBarrierWithGroupSync();
    // OR-reduce each 32-thread chunk down into its first scratch slot.
    [unroll]
    for (uint s = min(THREADING_BLOCK_SIZE / 2u, 16u); s > 0u; s >>= 1u)
    {
        if (indexDw < s)
            g_Scratch[indexScratch] = g_Scratch[indexScratch] | g_Scratch[indexScratch + s];
        GroupMemoryBarrierWithGroupSync();
    }
    GroupBallot ballot = (GroupBallot)0;
    // Explicitly mark this loop as "unroll" to avoid warnings about assigning to an array reference
    [unroll]
    for (uint dwordIndex = 0; dwordIndex < _THREADING_GROUP_BALLOT_DWORDS; ++dwordIndex)
    {
        ballot.dwords[dwordIndex] = g_Scratch[dwordIndex * 32];
    }
    return ballot;
}
  238. uint Group::CountBits(bool v)
  239. {
  240. return Ballot(v).CountBits();
  241. }
} // namespace Threading
#endif // THREADING_EMU_IMPL