暂无描述
您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

ThreadingSM6Impl.hlsl 8.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  #ifndef THREADING_SM6_IMPL
  #define THREADING_SM6_IMPL

namespace Threading
{
    // Shader Model 6 implementation of the Threading::Wave and Threading::Group APIs.
    // Wave operations forward directly to the SM6 wave intrinsics; Group operations are emulated
    // with group-wide barriers and the g_Scratch groupshared buffer declared below.
    //
    // Currently we only cover scalar types as at the time of writing this utility library we only
    // needed emulation for those. Support for vector types is currently not there but can be added
    // as needed (and this comment removed).
    //
    // NOTE: every macro below ends on a line WITHOUT a trailing '\'. A trailing line continuation on
    // the last line of a macro splices the next source line into the macro body (line splicing
    // happens before comments and directives are processed), which silently swallows whatever follows.

    // One 32-bit slot per thread of the group. Values of every supported scalar type are stored
    // bit-cast to uint (asuint / asint / asfloat), so a single buffer serves all of them.
    groupshared uint g_Scratch[THREADING_BLOCK_SIZE];

    // -----------------------------------------------------------------------------------------------
    // Wave
    // -----------------------------------------------------------------------------------------------

    // Index of this thread's wave within the group, as computed by Wave::Init.
    uint Wave::GetIndex() { return indexW; }

    // Stores the thread's flattened group index and derives its wave index from it.
    // Assumes threads are packed into waves in contiguous runs of GetLaneCount() — TODO confirm per platform.
    void Wave::Init(uint groupIndex)
    {
        indexG = groupIndex;
        indexW = indexG / GetLaneCount();
    }

    // Note: The HLSL intrinsics should be correctly replaced by console-specific intrinsics by our API library.
    // Per-type wave operations as thin wrappers over the SM6 wave intrinsics. The "Prefix" variants are
    // exclusive (they exclude the calling lane's own value); the "InclusivePrefix" variants fold it back in.
    #define DEFINE_API_FOR_TYPE(TYPE)                                                                  \
        bool Wave::AllEqual(TYPE v)               { return WaveActiveAllEqual(v);      }              \
        TYPE Wave::Product(TYPE v)                { return WaveActiveProduct(v);       }              \
        TYPE Wave::Sum(TYPE v)                    { return WaveActiveSum(v);           }              \
        TYPE Wave::Max(TYPE v)                    { return WaveActiveMax(v);           }              \
        TYPE Wave::Min(TYPE v)                    { return WaveActiveMin(v);           }              \
        TYPE Wave::InclusivePrefixSum(TYPE v)     { return WavePrefixSum(v) + v;       }              \
        TYPE Wave::InclusivePrefixProduct(TYPE v) { return WavePrefixProduct(v) * v;   }              \
        TYPE Wave::PrefixSum(TYPE v)              { return WavePrefixSum(v);           }              \
        TYPE Wave::PrefixProduct(TYPE v)          { return WavePrefixProduct(v);       }              \
        TYPE Wave::ReadLaneAt(TYPE v, uint i)     { return WaveReadLaneAt(v, i);       }              \
        TYPE Wave::ReadLaneFirst(TYPE v)          { return WaveReadLaneFirst(v);       }

    // Currently just support scalars.
    DEFINE_API_FOR_TYPE(uint)
    DEFINE_API_FOR_TYPE(int)
    DEFINE_API_FOR_TYPE(float)

    // The following intrinsics are not type-dependent, so they need only be declared once.
    uint  Wave::GetLaneCount()           { return WaveGetLaneCount();       }
    uint  Wave::GetLaneIndex()           { return WaveGetLaneIndex();       }
    bool  Wave::IsFirstLane()            { return WaveIsFirstLane();        }
    bool  Wave::AllTrue(bool v)          { return WaveActiveAllTrue(v);     }
    bool  Wave::AnyTrue(bool v)          { return WaveActiveAnyTrue(v);     }
    uint4 Wave::Ballot(bool v)           { return WaveActiveBallot(v);      }
    uint  Wave::CountBits(bool v)        { return WaveActiveCountBits(v);   }
    uint  Wave::PrefixCountBits(bool v)  { return WavePrefixCountBits(v);   }
    uint  Wave::And(uint v)              { return WaveActiveBitAnd(v);      }
    uint  Wave::Or (uint v)              { return WaveActiveBitOr(v);       }
    uint  Wave::Xor(uint v)              { return WaveActiveBitXor(v);      }

    // -----------------------------------------------------------------------------------------------
    // Group (emulated through g_Scratch)
    // -----------------------------------------------------------------------------------------------
    //
    // The EMULATED_* helpers expand to complete function bodies (they end in a return). They expect
    // `v` (the calling thread's input) and `groupIndex` to be in scope. `groupIndex` is presumably the
    // Group member holding the thread's flattened index, declared outside this file — verify.
    // They contain GroupMemoryBarrierWithGroupSync(), so they should be reached by every thread of the
    // group. Each starts with a barrier so that reads of g_Scratch made by a previous emulated call have
    // completed before the buffer is overwritten.
    //
    // EMULATED_GROUP_REDUCE (infix OP) / EMULATED_GROUP_REDUCE_CMP (binary function OP such as min/max):
    // tree reduction that halves the active range every step. Slot 0 ends up holding the result, which
    // every thread then reads. All slots are combined only when THREADING_BLOCK_SIZE is a power of two.
    #define EMULATED_GROUP_REDUCE(TYPE, OP)                                                            \
        GroupMemoryBarrierWithGroupSync();                                                             \
        g_Scratch[groupIndex] = asuint(v);                                                             \
        GroupMemoryBarrierWithGroupSync();                                                             \
        [unroll]                                                                                       \
        for (uint s = THREADING_BLOCK_SIZE / 2u; s > 0u; s >>= 1u)                                     \
        {                                                                                              \
            if (groupIndex < s)                                                                        \
                g_Scratch[groupIndex] = asuint(as##TYPE(g_Scratch[groupIndex]) OP                      \
                                               as##TYPE(g_Scratch[groupIndex + s]));                   \
            GroupMemoryBarrierWithGroupSync();                                                         \
        }                                                                                              \
        return as##TYPE(g_Scratch[0]);

    #define EMULATED_GROUP_REDUCE_CMP(TYPE, OP)                                                        \
        GroupMemoryBarrierWithGroupSync();                                                             \
        g_Scratch[groupIndex] = asuint(v);                                                             \
        GroupMemoryBarrierWithGroupSync();                                                             \
        [unroll]                                                                                       \
        for (uint s = THREADING_BLOCK_SIZE / 2u; s > 0u; s >>= 1u)                                     \
        {                                                                                              \
            if (groupIndex < s)                                                                        \
                g_Scratch[groupIndex] = asuint(OP(as##TYPE(g_Scratch[groupIndex]),                     \
                                                  as##TYPE(g_Scratch[groupIndex + s])));               \
            GroupMemoryBarrierWithGroupSync();                                                         \
        }                                                                                              \
        return as##TYPE(g_Scratch[0]);

    // EMULATED_GROUP_PREFIX: inclusive scan that doubles the look-back distance every step
    // (Hillis-Steele style). Each step reads, synchronizes, then writes, so no thread overwrites a slot
    // another thread still has to read. Afterwards each thread returns its predecessor's inclusive value,
    // i.e. an exclusive prefix. Thread 0 returns FILL_VALUE, which must be the identity element of OP.
    #define EMULATED_GROUP_PREFIX(TYPE, OP, FILL_VALUE)                                                \
        GroupMemoryBarrierWithGroupSync();                                                             \
        g_Scratch[groupIndex] = asuint(v);                                                             \
        GroupMemoryBarrierWithGroupSync();                                                             \
        [unroll]                                                                                       \
        for (uint s = 1u; s < THREADING_BLOCK_SIZE; s <<= 1u)                                          \
        {                                                                                              \
            TYPE nv = FILL_VALUE;                                                                      \
            if (groupIndex >= s)                                                                       \
            {                                                                                          \
                nv = as##TYPE(g_Scratch[groupIndex - s]);                                              \
            }                                                                                          \
            nv = as##TYPE(g_Scratch[groupIndex]) OP nv;                                                \
            GroupMemoryBarrierWithGroupSync();                                                         \
            g_Scratch[groupIndex] = asuint(nv);                                                        \
            GroupMemoryBarrierWithGroupSync();                                                         \
        }                                                                                              \
        TYPE result = FILL_VALUE;                                                                      \
        if (groupIndex > 0u)                                                                           \
            result = as##TYPE(g_Scratch[groupIndex - 1]);                                              \
        return result;

    // Number of waves needed to cover the whole group.
    // Rounds up so that a group that is smaller than, or not a multiple of, the wave size reports its
    // partially filled wave. Plain division would return 0 for e.g. a 16-thread group on 32-wide waves.
    uint Group::GetWaveCount()
    {
        const uint laneCount = WaveGetLaneCount();
        return (THREADING_BLOCK_SIZE + laneCount - 1u) / laneCount;
    }

    // Per-type emulated group operations. As with the wave API, "Prefix" is exclusive and
    // "InclusivePrefix" folds the caller's own value back in. ReadThreadAt publishes every thread's
    // value to g_Scratch and returns the one written by thread i.
    #define DEFINE_API_FOR_TYPE_GROUP(TYPE)                                                            \
        bool Group::AllEqual(TYPE v)               { return AllTrue(ReadThreadFirst(v) == v); }       \
        TYPE Group::Product(TYPE v)                { EMULATED_GROUP_REDUCE(TYPE, *) }                 \
        TYPE Group::Sum(TYPE v)                    { EMULATED_GROUP_REDUCE(TYPE, +) }                 \
        TYPE Group::Max(TYPE v)                    { EMULATED_GROUP_REDUCE_CMP(TYPE, max) }           \
        TYPE Group::Min(TYPE v)                    { EMULATED_GROUP_REDUCE_CMP(TYPE, min) }           \
        TYPE Group::InclusivePrefixSum(TYPE v)     { return PrefixSum(v) + v; }                       \
        TYPE Group::InclusivePrefixProduct(TYPE v) { return PrefixProduct(v) * v; }                   \
        TYPE Group::PrefixSum(TYPE v)              { EMULATED_GROUP_PREFIX(TYPE, +, (TYPE)0) }        \
        TYPE Group::PrefixProduct(TYPE v)          { EMULATED_GROUP_PREFIX(TYPE, *, (TYPE)1) }        \
        TYPE Group::ReadThreadAt(TYPE v, uint i)                                                       \
        {                                                                                              \
            GroupMemoryBarrierWithGroupSync();                                                         \
            g_Scratch[groupIndex] = asuint(v);                                                         \
            GroupMemoryBarrierWithGroupSync();                                                         \
            return as##TYPE(g_Scratch[i]);                                                             \
        }                                                                                              \
        TYPE Group::ReadThreadFirst(TYPE v)           { return ReadThreadAt(v, 0u); }                 \
        TYPE Group::ReadThreadShuffle(TYPE v, uint i) { return ReadThreadAt(v, i); }

    // Currently just support scalars.
    DEFINE_API_FOR_TYPE_GROUP(uint)
    DEFINE_API_FOR_TYPE_GROUP(int)
    DEFINE_API_FOR_TYPE_GROUP(float)

    // The following emulated functions are not type-dependent, so they need only be declared once.
    uint Group::GetThreadCount()         { return THREADING_BLOCK_SIZE; }
    uint Group::GetThreadIndex()         { return groupIndex; }
    bool Group::IsFirstThread()          { return groupIndex == 0u; }
    // Booleans are reduced as 0/1 integers: AND of all of them is 1 only if every thread passed true.
    bool Group::AllTrue(bool v)          { return And((uint)v) != 0u; }
    bool Group::AnyTrue(bool v)          { return Or((uint)v) != 0u; }
    // Exclusive count of threads with a lower index whose v is true.
    uint Group::PrefixCountBits(bool v)  { return PrefixSum((uint)v); }
    uint Group::And(uint v)              { EMULATED_GROUP_REDUCE(uint, &) }
    uint Group::Or (uint v)              { EMULATED_GROUP_REDUCE(uint, |) }
    uint Group::Xor(uint v)              { EMULATED_GROUP_REDUCE(uint, ^) }

    // Emulated group-wide ballot.
    // Thread t contributes bit (t % 32) of dword (t / 32), set when its v is true. Each run of 32
    // consecutive g_Scratch slots is OR-reduced into the run's first slot, which becomes that dword.
    GroupBallot Group::Ballot(bool v)
    {
        // Bit position of this thread inside its 32-thread dword.
        const uint bitIndex = groupIndex % 32u;

        GroupMemoryBarrierWithGroupSync();
        g_Scratch[groupIndex] = (uint)v << bitIndex;
        GroupMemoryBarrierWithGroupSync();

        // Start at half the dword width, or at half the group for groups smaller than 32 threads,
        // so that groupIndex + s never leaves the current 32-slot run.
        [unroll]
        for (uint s = min(THREADING_BLOCK_SIZE / 2u, 16u); s > 0u; s >>= 1u)
        {
            if (bitIndex < s)
                g_Scratch[groupIndex] = g_Scratch[groupIndex] | g_Scratch[groupIndex + s];
            GroupMemoryBarrierWithGroupSync();
        }

        GroupBallot ballot = (GroupBallot)0;
        // Explicitly mark this loop as "unroll" to avoid warnings about assigning to an array reference
        [unroll]
        for (uint dwordIndex = 0; dwordIndex < _THREADING_GROUP_BALLOT_DWORDS; ++dwordIndex)
        {
            ballot.dwords[dwordIndex] = g_Scratch[dwordIndex * 32];
        }
        return ballot;
    }

    // Number of threads in the group whose v is true, counted from the emulated ballot.
    uint Group::CountBits(bool v)
    {
        return Ballot(v).CountBits();
    }
}

  #endif // THREADING_SM6_IMPL