123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198 |
- #ifndef THREADING
- #define THREADING
-
- ///
- /// Compute Shader Threading Utilities
- ///
- /// This file is intended to provide a portable implementation of the wave-level operations in DirectX Shader Model 6.0.
- ///
- /// The functions in this file will automatically resolve to native intrinsics when possible.
- /// A fallback groupshared memory implementation is used when native support is not available.
- ///
- /// Usage:
- ///
- /// To use this file, define all required preprocessor symbols and then include this file in your compute shader.
- ///
- /// Required Preprocessor Symbols:
- ///
- /// THREADING_BLOCK_SIZE
- /// - The compute shader's flattened thread group size (the total number of threads in one thread group)
- ///
- /// Optional Preprocessor Symbols:
- ///
- /// THREADING_WAVE_SIZE
- /// - The size of a wave within the compute shader
- /// - This symbol MUST be defined when authoring shader code that requires a specific wave size for correctness!
- ///
- /// THREADING_FORCE_WAVE_EMULATION
- /// - If defined, forces usage of the fallback groupshared memory implementation
- ///
-
- // THREADING_BLOCK_SIZE is a required symbol: fail the compile early with a clear message when it is missing.
- #ifndef THREADING_BLOCK_SIZE
- #error THREADING_BLOCK_SIZE must be defined as the flattened thread group size.
- #endif
-
- // The emulation path is automatically enabled when we're running on hardware that doesn't meet minimum requirements.
- //
- // In order to use the non-emulated path, the current device must have native support for wave-level operations.
- // If THREADING_WAVE_SIZE is provided, then the device's wave size must also match the size specified by THREADING_WAVE_SIZE.
- //
- // The emulation path can also be forced on via the THREADING_FORCE_WAVE_EMULATION preprocessor symbol for debug/testing purposes.
- //
- // NOTE: Both switches are resolved to a literal 0 or 1 right here instead of being defined as expressions that contain 'defined(...)'.
- // A 'defined' operator that only shows up as the result of macro expansion inside an #if has undefined behavior under the C/C++
- // preprocessor rules, and preprocessors do not agree on how to evaluate it. Resolving the value up front keeps '#if <switch>' portable.
-
- // 1 when native wave operations can be used (and, if requested, the hardware wave size matches THREADING_WAVE_SIZE), 0 otherwise.
- #if defined(UNITY_HW_SUPPORTS_WAVE) && (!defined(THREADING_WAVE_SIZE) || (defined(UNITY_HW_WAVE_SIZE) && (UNITY_HW_WAVE_SIZE == THREADING_WAVE_SIZE)))
- #define _THREADING_IS_HW_SUPPORTED 1
- #else
- #define _THREADING_IS_HW_SUPPORTED 0
- #endif
-
- // 1 when the fallback implementation must be used: either the hardware path is unavailable or emulation was explicitly forced.
- #if !_THREADING_IS_HW_SUPPORTED || defined(THREADING_FORCE_WAVE_EMULATION)
- #define _THREADING_ENABLE_WAVE_EMULATION 1
- #else
- #define _THREADING_ENABLE_WAVE_EMULATION 0
- #endif
-
- // Number of 32-bit dwords that provide at least THREADING_BLOCK_SIZE bits (THREADING_BLOCK_SIZE / 32, rounded up).
- #define _THREADING_GROUP_BALLOT_DWORDS ((THREADING_BLOCK_SIZE + 31u) / 32u)
-
- namespace Threading
- {
- // Per-thread handle for wave-level operations.
- //
- // Only declarations live in this struct. The member function bodies are not visible in this file; they are expected to come
- // from the implementation file included at the bottom of this header (ThreadingEmuImpl.hlsl or ThreadingSM6Impl.hlsl).
- // A Wave is zero-initialized and then set up through Init(groupIndex) before use (see Group::GetWave).
- struct Wave
- {
- // The fields below are not meant to be part of the public API, but they cannot be hidden:
- // unfortunately 'private' is a reserved keyword in HLSL.
- // NOTE(review): the exact meaning of indexG / indexW / indexL is not visible here. Presumably they are the thread's index
- // within the group, the wave's index within the group and the lane's index within the wave - confirm against Init().
- uint indexG;
- uint indexW;
- // These two fields only exist when the emulation path is compiled in.
- #if _THREADING_ENABLE_WAVE_EMULATION
- uint indexL;
- uint offset; // Per-wave offset into LDS scratch space.
- #endif
-
- // NOTE(review): presumably returns the index of this wave (indexW) - confirm in the implementation file.
- uint GetIndex();
-
- // Initializes this Wave from the calling thread's flattened group index.
- void Init(uint groupIndex);
-
- // Declares the set of per-type wave operations for a single scalar TYPE.
- // The macro is expanded below once for every supported type, which produces one overload of each function per type.
- // NOTE(review): the names suggest these mirror the Shader Model 6.0 wave intrinsics mentioned in the file header
- // (e.g. Sum ~ WaveActiveSum, PrefixSum ~ WavePrefixSum as an exclusive prefix, InclusivePrefixSum as its inclusive variant) - confirm.
- #define DECLARE_API_FOR_TYPE(TYPE) \
- bool AllEqual(TYPE v); \
- TYPE Product(TYPE v); \
- TYPE Sum(TYPE v); \
- TYPE Max(TYPE v); \
- TYPE Min(TYPE v); \
- TYPE InclusivePrefixSum(TYPE v); \
- TYPE InclusivePrefixProduct(TYPE v); \
- TYPE PrefixSum(TYPE v); \
- TYPE PrefixProduct(TYPE v); \
- TYPE ReadLaneAt(TYPE v, uint i); \
- TYPE ReadLaneFirst(TYPE v); \
-
- // Currently just support scalars.
- DECLARE_API_FOR_TYPE(uint)
- DECLARE_API_FOR_TYPE(int)
- DECLARE_API_FOR_TYPE(float)
-
- // The following intrinsics need only be declared once.
- // They either take no value argument or operate on a single fixed type (bool or uint), so no per-type overloads are needed.
- uint GetLaneCount();
- uint GetLaneIndex();
- bool IsFirstLane();
- bool AllTrue(bool v);
- bool AnyTrue(bool v);
- // The wave-level ballot is returned as a uint4, i.e. 128 bits.
- uint4 Ballot(bool v);
- uint CountBits(bool v);
- uint PrefixCountBits(bool v);
- // Bitwise reductions are only declared for uint.
- uint And(uint v);
- uint Or(uint v);
- uint Xor(uint v);
- };
-
- // Result of a group-wide ballot, stored as _THREADING_GROUP_BALLOT_DWORDS 32-bit dwords.
- struct GroupBallot
- {
-     uint dwords[_THREADING_GROUP_BALLOT_DWORDS];
-
-     // Returns the number of set bits across every dword of the ballot.
-     uint CountBits()
-     {
-         uint total = 0;
-
-         // The trip count is a compile-time constant, so the loop is fully unrolled.
-         [unroll]
-         for (uint i = 0; i < _THREADING_GROUP_BALLOT_DWORDS; i++)
-             total += countbits(dwords[i]);
-
-         return total;
-     }
- };
-
- // Per-thread handle for thread-group-level operations.
- //
- // The three fields are bound to the SV_GroupIndex, SV_GroupID and SV_DispatchThreadID system-value semantics.
- // Apart from GetWave() and RemapLaneTo8x16(), the member functions are only declared here; their bodies are not visible in this file.
- struct Group
- {
-     uint groupIndex : SV_GroupIndex;
-     uint3 groupID : SV_GroupID;
-     uint3 dispatchID : SV_DispatchThreadID;
-
-     // Creates the Wave handle for the calling thread: zero-initialized first, then set up from this thread's groupIndex.
-     Wave GetWave()
-     {
-         Wave result = (Wave)0;
-         result.Init(groupIndex);
-         return result;
-     }
-
-     // Lane remap which is safe for both portability (different wave sizes up to 128) and for 2D wave reductions.
-     // 6543210
-     // =======
-     // ..xx..x
-     // yy..yy.
-     // Details,
-     // LANE TO 8x16 MAPPING
-     // ====================
-     // 00 01 08 09 10 11 18 19
-     // 02 03 0a 0b 12 13 1a 1b
-     // 04 05 0c 0d 14 15 1c 1d
-     // 06 07 0e 0f 16 17 1e 1f
-     // 20 21 28 29 30 31 38 39
-     // 22 23 2a 2b 32 33 3a 3b
-     // 24 25 2c 2d 34 35 3c 3d
-     // 26 27 2e 2f 36 37 3e 3f
-     // .......................
-     // ... repeat the 8x8 ....
-     // .... pattern, but .....
-     // .... for 40 to 7f .....
-     // .......................
-     // NOTE: This function is only intended to be used with one dimensional thread groups
-     uint2 RemapLaneTo8x16()
-     {
-         // Note the BFIs used for MSBs have "strange offsets" due to leaving space for the LSB bits replaced in the BFI.
-         uint xMsbs = BitFieldExtract(groupIndex, 2u, 3u);
-         uint yLsbs = BitFieldExtract(groupIndex, 1u, 2u);
-         uint yMsbs = BitFieldExtract(groupIndex, 3u, 4u);
-
-         uint2 remapped;
-         remapped.x = BitFieldInsert(1u, groupIndex, xMsbs);
-         remapped.y = BitFieldInsert(3u, yLsbs, yMsbs);
-         return remapped;
-     }
-
-     uint GetWaveCount();
-
-     // Declares the set of per-type group operations for a single scalar TYPE.
-     // Expanded below once per supported type, producing one overload of each function per type.
-     #define DECLARE_API_FOR_TYPE_GROUP(TYPE) \
-     bool AllEqual(TYPE v); \
-     TYPE Product(TYPE v); \
-     TYPE Sum(TYPE v); \
-     TYPE Max(TYPE v); \
-     TYPE Min(TYPE v); \
-     TYPE InclusivePrefixSum(TYPE v); \
-     TYPE InclusivePrefixProduct(TYPE v); \
-     TYPE PrefixSum(TYPE v); \
-     TYPE PrefixProduct(TYPE v); \
-     TYPE ReadThreadAt(TYPE v, uint i); \
-     TYPE ReadThreadFirst(TYPE v); \
-     TYPE ReadThreadShuffle(TYPE v, uint i); \
-
-     // Currently just support scalars.
-     DECLARE_API_FOR_TYPE_GROUP(uint)
-     DECLARE_API_FOR_TYPE_GROUP(int)
-     DECLARE_API_FOR_TYPE_GROUP(float)
-
-     // The following intrinsics need only be declared once.
-     uint GetThreadCount();
-     uint GetThreadIndex();
-     bool IsFirstThread();
-     bool AllTrue(bool v);
-     bool AnyTrue(bool v);
-     GroupBallot Ballot(bool v);
-     uint CountBits(bool v);
-     uint PrefixCountBits(bool v);
-     uint And(uint v);
-     uint Or(uint v);
-     uint Xor(uint v);
- };
- }
-
- #if _THREADING_ENABLE_WAVE_EMULATION
- #include "ThreadingEmuImpl.hlsl"
- #else
- #include "ThreadingSM6Impl.hlsl"
- #endif
-
- #endif
|