Keine Beschreibung
Du kannst nicht mehr als 25 Themen auswählen. Themen müssen mit entweder einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

FunctionPointerInvokeTransform.cs 25KB

(Zeilennummern 1–539)
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5. using Mono.Cecil;
  6. using Mono.Cecil.Cil;
  7. using Mono.Cecil.Rocks;
  namespace zzzUnity.Burst.CodeGen
  {
      /// <summary>
      /// Transforms a direct invoke on a burst function pointer into a calli, avoiding the need to marshal the delegate back.
      /// </summary>
      /// <remarks>
      /// The class drives up to three passes, selected by the static <c>enable*</c> flags:
      /// <list type="bullet">
      /// <item>adding a callback attribute (<c>MonoPInvokeCallback</c>) to <c>[BurstCompile]</c> methods passed to <c>BurstCompiler.CompileFunctionPointer</c>;</item>
      /// <item>adding <c>[UnmanagedFunctionPointer(CallingConvention.Cdecl)]</c> to the delegate types used with <c>CompileFunctionPointer</c>;</item>
      /// <item>rewriting <c>FunctionPointer&lt;T&gt;.Invoke(...)</c> call sites into <c>get_Value</c> + <c>calli</c>.</item>
      /// </list>
      /// Candidates are collected per method by <see cref="CollectDelegateInvokesFromType"/> and applied by <see cref="Finish"/>.
      /// </remarks>
      internal class FunctionPointerInvokeTransform
      {
          /// <summary>A pending calli rewrite: the delegate Invoke operand plus the captured [get_Invoke, Invoke] instructions.</summary>
          private struct CaptureInformation
          {
              public MethodReference Operand;
              public List<Instruction> Captured;
          }

          // Delegate types used as the generic argument of CompileFunctionPointer, with the first call site seen (used for error reports).
          private readonly Dictionary<TypeReference, (MethodDefinition method, Instruction instruction)> _needsNativeFunctionPointer;
          // Burst implementation methods needing a callback attribute, mapped to the delegate type constructed from them.
          private readonly Dictionary<MethodDefinition, TypeReference> _needsIl2cppInvoke;
          // Per method, the get_Invoke/Invoke pairs eligible for the calli rewrite.
          private readonly Dictionary<MethodDefinition, List<CaptureInformation>> _capturedSets;
          private MethodDefinition _monoPInvokeAttributeCtorDef;
          private MethodDefinition _nativePInvokeAttributeCtorDef; // Only present on DOTS_PLAYER
          private MethodDefinition _unmanagedFunctionPointerAttributeCtorDef;
          private TypeReference _burstFunctionPointerType;
          private TypeReference _burstCompilerType;
          private TypeReference _systemType;
          private TypeReference _callingConventionType;
          private readonly LogDelegate _debugLog;
          private readonly int _logLevel;
          private readonly AssemblyLoader _loader;
          private readonly ErrorDiagnosticDelegate _errorReport;

          public static readonly bool enableInvokeAttribute = true;
  #if UNITY_DOTSPLAYER
          public static readonly bool enableCalliOptimisation = true;
          public static readonly bool enableUnmangedFunctionPointerInject = false;
  #else
          public static readonly bool enableCalliOptimisation = false; // For now only run the pass on dots player/tiny
          public static readonly bool enableUnmangedFunctionPointerInject = true;
  #endif

          /// <summary>Creates the transform; the Cecil types it needs are resolved lazily by <see cref="Initialize"/>.</summary>
          /// <param name="loader">Loader used to locate Unity.Burst and the Unity runtime assemblies.</param>
          /// <param name="error">Callback used to report user-facing errors (e.g. an incompatible calling convention).</param>
          /// <param name="log">Optional debug logger.</param>
          /// <param name="logLevel">Verbosity; higher values (&gt; 1, &gt; 2) enable more detailed log lines.</param>
          public FunctionPointerInvokeTransform(AssemblyLoader loader, ErrorDiagnosticDelegate error, LogDelegate log = null, int logLevel = 0)
          {
              _loader = loader;
              _needsNativeFunctionPointer = new Dictionary<TypeReference, (MethodDefinition, Instruction)>();
              _needsIl2cppInvoke = new Dictionary<MethodDefinition, TypeReference>();
              _capturedSets = new Dictionary<MethodDefinition, List<CaptureInformation>>();
              _debugLog = log;
              _logLevel = logLevel;
              _errorReport = error;
          }

          /// <summary>
          /// Loads <paramref name="filename"/> from the first of the loader's search directories that contains it.
          /// </summary>
          /// <returns>The loaded assembly, or null when no search directory contains the file.</returns>
          private AssemblyDefinition GetAsmDefinitionFromFile(AssemblyLoader loader, string filename)
          {
              foreach (var folder in loader.GetSearchDirectories())
              {
                  var path = Path.Combine(folder, filename);
                  if (File.Exists(path))
                      return loader.LoadFromFile(path);
              }
              return null;
          }

          /// <summary>
          /// Looks up a top-level type in the core library, following type forwards.
          /// </summary>
          /// <remarks>
          /// When the core library is a forwarding facade the type is not defined in its <c>Types</c>, only listed in
          /// <c>ExportedTypes</c>, so <c>GetType</c> alone returns null.
          /// </remarks>
          /// <returns>The resolved type definition, or null when it can be found neither directly nor via a forward.</returns>
          private static TypeDefinition FindCorLibType(AssemblyDefinition corLibrary, string fullName)
          {
              var module = corLibrary.MainModule;
              return module.GetType(fullName)
                  ?? module.ExportedTypes.FirstOrDefault(exported => exported.FullName == fullName)?.Resolve();
          }

          /// <summary>
          /// Resolves the Burst, core-library and callback-attribute types used by the passes.
          /// </summary>
          /// <remarks>
          /// Results are only treated as cached once the callback attribute constructor has been found; until then every
          /// call retries the lookup. Missing assemblies leave the corresponding fields null, which the passes treat as
          /// "nothing to do".
          /// </remarks>
          public void Initialize(AssemblyLoader loader, AssemblyDefinition assemblyDefinition, TypeSystem typeSystem)
          {
              if (_monoPInvokeAttributeCtorDef != null)
                  return;

              var burstAssembly = GetAsmDefinitionFromFile(loader, "Unity.Burst.dll");
              if (burstAssembly == null)
              {
                  // Without Unity.Burst no call can match; the IsBurst*Method helpers treat the null types as "no match".
                  _debugLog?.Invoke("FunctionPtrInvoke.Initialize: Unable to locate Unity.Burst.dll - function pointer transforms will be skipped");
                  return;
              }
              _burstFunctionPointerType = burstAssembly.MainModule.GetType("Unity.Burst.FunctionPointer`1");
              _burstCompilerType = burstAssembly.MainModule.GetType("Unity.Burst.BurstCompiler");

              var corLibrary = loader.Resolve(typeSystem.CoreLibrary as AssemblyNameReference);
              _systemType = FindCorLibType(corLibrary, "System.Type"); // Only needed for MonoPInvokeCallback constructor in Unity
              if (enableUnmangedFunctionPointerInject)
              {
                  var unmanagedFunctionPointerAttribute = FindCorLibType(corLibrary, "System.Runtime.InteropServices.UnmanagedFunctionPointerAttribute");
                  _callingConventionType = FindCorLibType(corLibrary, "System.Runtime.InteropServices.CallingConvention");
                  _unmanagedFunctionPointerAttributeCtorDef = unmanagedFunctionPointerAttribute.GetConstructors().Single(c => c.Parameters.Count == 1 && c.Parameters[0].ParameterType.MetadataType == _callingConventionType.MetadataType);
              }
  #if UNITY_DOTSPLAYER
              var asmDef = assemblyDefinition;
              if (asmDef.Name.Name != "Unity.Runtime")
                  asmDef = GetAsmDefinitionFromFile(loader, "Unity.Runtime.dll");
              if (asmDef == null)
                  return;
              var monoPInvokeAttribute = asmDef.MainModule.GetType("Unity.Jobs.MonoPInvokeCallbackAttribute");
              _monoPInvokeAttributeCtorDef = monoPInvokeAttribute.GetConstructors().First();
              var nativePInvokeAttribute = asmDef.MainModule.GetType("NativePInvokeCallbackAttribute");
              _nativePInvokeAttributeCtorDef = nativePInvokeAttribute.GetConstructors().First();
  #else
              var asmDef = GetAsmDefinitionFromFile(loader, "UnityEngine.CoreModule.dll");
              // bail if we can't find a reference, handled gracefully later
              if (asmDef == null)
                  return;
              var monoPInvokeAttribute = asmDef.MainModule.GetType("AOT.MonoPInvokeCallbackAttribute");
              _monoPInvokeAttributeCtorDef = monoPInvokeAttribute.GetConstructors().First();
  #endif
          }

          /// <summary>
          /// Runs the transform over every type of the assembly's main module and applies the collected fixups.
          /// </summary>
          /// <returns>True if any enabled pass modified the assembly.</returns>
          public bool Run(AssemblyDefinition assemblyDefinition)
          {
              Initialize(_loader, assemblyDefinition, assemblyDefinition.MainModule.TypeSystem);
              // Snapshot the type list before any pass mutates the module.
              foreach (var type in assemblyDefinition.MainModule.GetTypes().ToArray())
              {
                  CollectDelegateInvokesFromType(type);
              }
              return Finish();
          }

          /// <summary>Collects transform candidates from every method of <paramref name="type"/> that has a body.</summary>
          public void CollectDelegateInvokesFromType(TypeDefinition type)
          {
              foreach (var m in type.Methods)
              {
                  if (m.HasBody)
                  {
                      CollectDelegateInvokes(m);
                  }
              }
          }

          /// <summary>
          /// Adds <c>[UnmanagedFunctionPointer(CallingConvention.Cdecl)]</c> to each collected delegate type, or reports an
          /// error if the delegate already carries the attribute with a different calling convention.
          /// </summary>
          /// <returns>True if at least one attribute was added.</returns>
          private bool ProcessUnmanagedAttributeFixups()
          {
              if (_unmanagedFunctionPointerAttributeCtorDef == null)
                  return false;
              bool modified = false;
              foreach (var kp in _needsNativeFunctionPointer)
              {
                  var delegateType = kp.Key;
                  var instruction = kp.Value.instruction;
                  var method = kp.Value.method;
                  var delegateDef = delegateType.Resolve();
                  var hasAttributeAlready = delegateDef.CustomAttributes.FirstOrDefault(x => x.AttributeType.FullName == _unmanagedFunctionPointerAttributeCtorDef.DeclaringType.FullName);
                  // If there is already an attribute present, validate it rather than adding a second one
                  if (hasAttributeAlready != null)
                  {
                      if (hasAttributeAlready.ConstructorArguments.Count == 1)
                      {
                          var cc = (System.Runtime.InteropServices.CallingConvention)hasAttributeAlready.ConstructorArguments[0].Value;
                          if (cc == System.Runtime.InteropServices.CallingConvention.Cdecl)
                          {
                              if (_logLevel > 2) _debugLog?.Invoke($"UnmanagedAttributeFixups Skipping appending unmanagedFunctionPointerAttribute as already present aand calling convention matches");
                          }
                          else
                          {
                              // constructor with non cdecl calling convention
                              _errorReport(method, instruction, $"BurstCompiler.CompileFunctionPointer is only compatible with cdecl calling convention, this delegate type already has `[UnmanagedFunctionPointer(CallingConvention.{ Enum.GetName(typeof(System.Runtime.InteropServices.CallingConvention), cc) })]` please remove the attribute if you wish to use this function with Burst.");
                          }
                      }
                      else
                      {
                          // Empty constructor which defaults to Winapi which is incompatible
                          _errorReport(method, instruction, $"BurstCompiler.CompileFunctionPointer is only compatible with cdecl calling convention, this delegate type already has `[UnmanagedFunctionPointer]` please remove the attribute if you wish to use this function with Burst.");
                      }
                      continue;
                  }
                  var attribute = new CustomAttribute(delegateType.Module.ImportReference(_unmanagedFunctionPointerAttributeCtorDef));
                  attribute.ConstructorArguments.Add(new CustomAttributeArgument(delegateType.Module.ImportReference(_callingConventionType), System.Runtime.InteropServices.CallingConvention.Cdecl));
                  delegateDef.CustomAttributes.Add(attribute);
                  modified = true;
              }
              return modified;
          }

          /// <summary>
          /// Adds the callback attribute to each collected Burst implementation method.
          /// </summary>
          /// <remarks>
          /// On DOTS player the attribute takes no arguments; in Unity it takes the delegate type, so entries whose
          /// delegate type is unknown are skipped with a log message.
          /// </remarks>
          /// <returns>True if at least one attribute was added.</returns>
          private bool ProcessIl2cppInvokeFixups()
          {
              if (_monoPInvokeAttributeCtorDef == null)
                  return false;
              bool modified = false;
              foreach (var invokeNeeded in _needsIl2cppInvoke)
              {
                  var declaringType = invokeNeeded.Value;
                  var implementationMethod = invokeNeeded.Key;
  #if UNITY_DOTSPLAYER
                  // At present always uses monoPInvokeAttribute because we don't currently know if the burst is enabled
                  var attribute = new CustomAttribute(implementationMethod.Module.ImportReference(_monoPInvokeAttributeCtorDef));
                  implementationMethod.CustomAttributes.Add(attribute);
                  modified = true;
  #else
                  // Unity requires a type parameter for the attributecallback
                  if (declaringType == null)
                  {
                      _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing declaringType for {implementationMethod}");
                      continue;
                  }
                  var attribute = new CustomAttribute(implementationMethod.Module.ImportReference(_monoPInvokeAttributeCtorDef));
                  attribute.ConstructorArguments.Add(new CustomAttributeArgument(implementationMethod.Module.ImportReference(_systemType), implementationMethod.Module.ImportReference(declaringType)));
                  implementationMethod.CustomAttributes.Add(attribute);
                  modified = true;
                  if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Added InvokeCallbackAttribute to {implementationMethod}");
  #endif
              }
              return modified;
          }

          /// <summary>
          /// Rewrites each captured <c>FunctionPointer&lt;T&gt;.get_Invoke</c> ... <c>Invoke</c> pair into
          /// <c>get_Value; stloc tmp</c> ... <c>ldloc tmp; calli</c> with a cdecl call site.
          /// </summary>
          /// <returns>True if at least one call site was rewritten.</returns>
          private bool ProcessFunctionPointerInvokes()
          {
              var madeChange = false;
              foreach (var capturedData in _capturedSets)
              {
                  var latePatchMethod = capturedData.Key;
                  var capturedList = capturedData.Value;
                  latePatchMethod.Body.SimplifyMacros(); // De-optimise short branches, since we will end up inserting instructions
                  foreach (var capturedInfo in capturedList)
                  {
                      var captured = capturedInfo.Captured;
                      var operand = capturedInfo.Operand;
                      if (captured.Count != 2)
                      {
                          _debugLog?.Invoke($"FunctionPtrInvoke.Finish: expected 2 instructions - Unable to optimise this reference");
                          continue;
                      }
                      if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.Finish:{Environment.NewLine} latePatchMethod:{latePatchMethod}{Environment.NewLine} captureList:{capturedList}{Environment.NewLine} capture0:{captured[0]}{Environment.NewLine} operand:{operand}");
                      var processor = latePatchMethod.Body.GetILProcessor();
                      var callsite = new CallSite(operand.ReturnType) { CallingConvention = MethodCallingConvention.C };
                      foreach (var parameter in operand.Parameters)
                      {
                          // Copy rather than share: adding a ParameterDefinition to another signature re-parents it (owner and
                          // index), which would corrupt the delegate Invoke signature the operand points at.
                          callsite.Parameters.Add(new ParameterDefinition(parameter.ParameterType));
                      }
                      // Make sure everything is in order before we make a change
                      var originalGetInvoke = captured[0];
                      if (originalGetInvoke.Operand is MethodReference getInvokeRef)
                      {
                          // Collection already established that the declaring type is an instance of FunctionPointer<T>
                          var genericInstanceType = getInvokeRef.DeclaringType as GenericInstanceType;
                          var genericInstanceDef = genericInstanceType.Resolve();
                          var getValueDef = genericInstanceDef.Methods.FirstOrDefault(candidate => candidate.Name == "get_Value");
                          if (getValueDef == null)
                          {
                              _debugLog?.Invoke($"FunctionPtrInvoke.Finish: failed to locate get_Value method on {genericInstanceDef} - Unable to optimise this reference");
                              continue;
                          }
                          // Build a get_Value reference on the closed generic instance rather than the open definition
                          var newGenericRef = new MethodReference(getValueDef.Name, getValueDef.ReturnType, genericInstanceType)
                          {
                              HasThis = getValueDef.HasThis,
                              ExplicitThis = getValueDef.ExplicitThis,
                              CallingConvention = getValueDef.CallingConvention
                          };
                          foreach (var param in getValueDef.Parameters)
                              newGenericRef.Parameters.Add(new ParameterDefinition(param.ParameterType));
                          foreach (var gparam in getValueDef.GenericParameters)
                              newGenericRef.GenericParameters.Add(new GenericParameter(gparam.Name, newGenericRef));
                          var importRef = latePatchMethod.Module.ImportReference(newGenericRef);
                          var newMethodCall = processor.Create(OpCodes.Call, importRef);
                          // Replace get_invoke with get_Value in place - don't use Replace, as if the original call is the
                          // target of a branch the branch doesn't get updated.
                          originalGetInvoke.OpCode = newMethodCall.OpCode;
                          originalGetInvoke.Operand = newMethodCall.Operand;
                          // Add local to capture result
                          var newLocal = new VariableDefinition(getValueDef.ReturnType);
                          latePatchMethod.Body.Variables.Add(newLocal);
                          // Store result of get_Value
                          var storeInst = processor.Create(OpCodes.Stloc, newLocal);
                          processor.InsertAfter(originalGetInvoke, storeInst);
                          // Swap invoke with calli
                          var calli = processor.Create(OpCodes.Calli, callsite);
                          // Replace is used here on the assumption that the Invoke is not a branch target: collection only
                          // guarantees no control flow between the two instructions. NOTE(review): incoming branches are not checked.
                          processor.Replace(captured[1], calli);
                          // Insert load local prior to calli
                          var loadValue = processor.Create(OpCodes.Ldloc, newLocal);
                          processor.InsertBefore(calli, loadValue);
                          if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.Finish: Optimised {originalGetInvoke} with {newMethodCall}");
                          madeChange = true;
                      }
                  }
                  latePatchMethod.Body.OptimizeMacros(); // Re-optimise branches
              }
              return madeChange;
          }

          /// <summary>
          /// Applies every enabled pass to the candidates collected so far, then forgets them.
          /// </summary>
          /// <returns>True if any pass modified the assembly.</returns>
          public bool Finish()
          {
              bool madeChange = false;
              if (enableInvokeAttribute)
              {
                  madeChange |= ProcessIl2cppInvokeFixups();
              }
              if (enableUnmangedFunctionPointerInject)
              {
                  madeChange |= ProcessUnmanagedAttributeFixups();
              }
              if (enableCalliOptimisation)
              {
                  madeChange |= ProcessFunctionPointerInvokes();
              }
              // The pending work has been applied; drop it so a later Run/Finish on this instance does not add the
              // callback attributes a second time or re-patch IL that has already been rewritten.
              _needsNativeFunctionPointer.Clear();
              _needsIl2cppInvoke.Clear();
              _capturedSets.Clear();
              return madeChange;
          }

          /// <summary>
          /// True when <paramref name="methodRef"/> is the method named <paramref name="method"/> on an instance of
          /// <c>Unity.Burst.FunctionPointer`1</c>; the generic instance is returned through <paramref name="methodInstance"/>.
          /// </summary>
          private bool IsBurstFunctionPointerMethod(MethodReference methodRef, string method, out GenericInstanceType methodInstance)
          {
              methodInstance = methodRef?.DeclaringType as GenericInstanceType;
              return methodInstance != null
                  && _burstFunctionPointerType != null
                  && methodInstance.ElementType.FullName == _burstFunctionPointerType.FullName
                  && methodRef.Name == method;
          }

          /// <summary>True when <paramref name="methodRef"/> is the method named <paramref name="method"/> on <c>Unity.Burst.BurstCompiler</c>.</summary>
          private bool IsBurstCompilerMethod(GenericInstanceMethod methodRef, string method)
          {
              var declaringType = methodRef?.DeclaringType;
              return declaringType != null
                  && _burstCompilerType != null
                  && declaringType.FullName == _burstCompilerType.FullName
                  && methodRef.Name == method;
          }

          /// <summary>
          /// If <paramref name="i"/> is a call to <c>BurstCompiler.CompileFunctionPointer&lt;T&gt;</c>, records the delegate
          /// type for the unmanaged-attribute pass and, for the <c>ldftn; newobj; call</c> pattern, the implementation method
          /// for the callback-attribute pass.
          /// </summary>
          private void LocateFunctionPointerTCreation(MethodDefinition m, Instruction i)
          {
              if (i.OpCode != OpCodes.Call)
                  return;
              var genInstMethod = i.Operand as GenericInstanceMethod;
              if (!IsBurstCompilerMethod(genInstMethod, "CompileFunctionPointer")) return;

              if (enableUnmangedFunctionPointerInject)
              {
                  var delegateType = genInstMethod.GenericArguments[0].Resolve();
                  // We check for null, since unfortunately it is possible that the call is wrapped inside
                  // another open delegate and we cannot determine the delegate type
                  if (delegateType != null && !_needsNativeFunctionPointer.ContainsKey(delegateType))
                  {
                      _needsNativeFunctionPointer.Add(delegateType, (m, i));
                  }
              }

              if (!enableInvokeAttribute)
                  return;
              // Without the callback attribute constructor nothing collected here could be applied, and the
              // attribute-name comparison below would dereference null.
              if (_monoPInvokeAttributeCtorDef == null)
                  return;
              // Currently only handles the following pre-pattern (which should cover most common uses)
              // ldftn ...
              // newobj ...
              if (i.Previous?.OpCode != OpCodes.Newobj)
              {
                  _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to not finding NewObj {i.Previous}");
                  return;
              }
              var newObj = i.Previous;
              if (newObj.Previous?.OpCode != OpCodes.Ldftn)
              {
                  _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to not finding LdFtn {newObj.Previous}");
                  return;
              }
              var ldFtn = newObj.Previous;
              // Determine the delegate type from the constructor being called. The constructor is only a MethodDefinition
              // when the delegate is declared in this module, so read it as a MethodReference to also cover other modules.
              var declaringType = (newObj.Operand as MethodReference)?.DeclaringType;
              // Fetch the implementation method
              var implementationMethod = ldFtn.Operand as MethodDefinition;
              var hasInvokeAlready = implementationMethod?.CustomAttributes.FirstOrDefault(x =>
                  (x.AttributeType.FullName == _monoPInvokeAttributeCtorDef.DeclaringType.FullName)
                  || (_nativePInvokeAttributeCtorDef != null && x.AttributeType.FullName == _nativePInvokeAttributeCtorDef.DeclaringType.FullName));
              if (hasInvokeAlready != null)
              {
                  if (_logLevel > 2) _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Skipping appending Callback Attribute as already present {hasInvokeAlready}");
                  return;
              }
              if (implementationMethod == null)
              {
                  _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing method from {ldFtn} {ldFtn.Operand}");
                  return;
              }
              if (implementationMethod.CustomAttributes.FirstOrDefault(x => x.Constructor.DeclaringType.Name == "BurstCompileAttribute") == null)
              {
                  _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing burst attribute from {implementationMethod}");
                  return;
              }
              // Need to add the custom attribute
              if (!_needsIl2cppInvoke.ContainsKey(implementationMethod))
              {
                  _needsIl2cppInvoke.Add(implementationMethod, declaringType);
              }
          }

          [Obsolete("Will be removed in a future Burst verison")]
          public bool IsInstructionForFunctionPointerInvoke(MethodDefinition m, Instruction i)
          {
              throw new NotImplementedException();
          }

          /// <summary>
          /// Walks the body of <paramref name="m"/> collecting candidates for every enabled pass.
          /// </summary>
          /// <remarks>
          /// For the calli pass, a <c>call FunctionPointer&lt;T&gt;::get_Invoke</c> followed (without intervening control flow)
          /// by a <c>callvirt</c> on a method of the delegate type T is captured as a rewrite pair.
          /// </remarks>
          private void CollectDelegateInvokes(MethodDefinition m)
          {
              if (!(enableCalliOptimisation || enableInvokeAttribute || enableUnmangedFunctionPointerInject))
                  return;
              bool hitGetInvoke = false;
              TypeDefinition delegateType = null;
              List<Instruction> captured = null;
              foreach (var inst in m.Body.Instructions)
              {
                  if (_logLevel > 2) _debugLog?.Invoke($"FunctionPtrInvoke.CollectDelegateInvokes: CurrentInstruction {inst} {inst.Operand}");
                  // Check for a FunctionPointerT creation
                  if (enableUnmangedFunctionPointerInject || enableInvokeAttribute)
                  {
                      LocateFunctionPointerTCreation(m, inst);
                  }
                  if (!enableCalliOptimisation)
                      continue;

                  if (!hitGetInvoke)
                  {
                      if (inst.OpCode != OpCodes.Call) continue;
                      if (!IsBurstFunctionPointerMethod(inst.Operand as MethodReference, "get_Invoke", out var methodInstance)) continue;
                      // At this point we have a call to a FunctionPointer.Invoke
                      hitGetInvoke = true;
                      delegateType = methodInstance.GenericArguments[0].Resolve();
                      // Capture the get_invoke, we will swap this for get_value and a store to local
                      captured = new List<Instruction> { inst };
                      continue;
                  }

                  if (!(inst.OpCode.FlowControl == FlowControl.Next || inst.OpCode.FlowControl == FlowControl.Call))
                  {
                      // Don't perform transform across blocks
                      hitGetInvoke = false;
                      continue;
                  }
                  if (inst.OpCode != OpCodes.Callvirt)
                      continue;
                  if (!(inst.Operand is MethodReference mref))
                  {
                      hitGetInvoke = false;
                      continue;
                  }
                  // Resolve can return null for references it cannot resolve; such a call cannot be the delegate Invoke we want.
                  var method = mref.Resolve();
                  if (method == null || method.DeclaringType != delegateType)
                      continue;

                  hitGetInvoke = false;
                  if (!_capturedSets.TryGetValue(m, out var storage))
                  {
                      storage = new List<CaptureInformation>();
                      _capturedSets.Add(m, storage);
                  }
                  // Capture the invoke - which we will swap for a load local (stored from the get_value) and a calli
                  captured.Add(inst);
                  var captureInfo = new CaptureInformation { Captured = captured, Operand = mref };
                  if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.CollectDelegateInvokes: captureInfo:{captureInfo}{Environment.NewLine}capture0{captured[0]}");
                  storage.Add(captureInfo);
              }
          }
      }
  }