No Description
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

FunctionPointerInvokeTransform.cs 26KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5. using Burst.Compiler.IL.Syntax;
  6. using Mono.Cecil;
  7. using Mono.Cecil.Cil;
  8. using Mono.Cecil.Rocks;
  namespace zzzUnity.Burst.CodeGen
  {
      /// <summary>
      /// Transforms a direct invoke on a burst function pointer into a calli, avoiding the need to marshal the delegate back.
      /// </summary>
      /// <remarks>
      /// Usage is two-phase: <see cref="CollectDelegateInvokesFromType"/> (or <see cref="Run"/>) scans method bodies and
      /// records work items, then <see cref="Finish"/> applies them to the Cecil object model and resets the recorded
      /// work items so the same instance can be reused for another assembly. Depending on the static enable flags,
      /// <see cref="Finish"/>:
      /// <list type="bullet">
      /// <item>adds <c>[MonoPInvokeCallback(typeof(TDelegate))]</c> to <c>[BurstCompile]</c> methods passed to
      /// <c>BurstCompiler.CompileFunctionPointer</c> via the <c>ldftn; newobj; call</c> pattern;</item>
      /// <item>adds <c>[UnmanagedFunctionPointer(CallingConvention.Cdecl)]</c> to the delegate types involved, or reports
      /// an error when the delegate already has a non-cdecl <c>UnmanagedFunctionPointer</c> attribute;</item>
      /// <item>rewrites <c>FunctionPointer&lt;T&gt;.get_Invoke</c> + delegate-invoke pairs into <c>get_Value</c> +
      /// <c>calli</c> (only when <see cref="enableCalliOptimisation"/> is set).</item>
      /// </list>
      /// </remarks>
      internal class FunctionPointerInvokeTransform
      {
          /// <summary>A recorded get_Invoke / delegate-invoke pair that may be rewritten into a calli.</summary>
          private struct CaptureInformation
          {
              /// <summary>The delegate method called (via callvirt) on the value returned by get_Invoke.</summary>
              public MethodReference Operand;

              /// <summary>Captured instructions: [0] is the get_Invoke call, [1] is the delegate callvirt.</summary>
              public List<Instruction> Captured;
          }

          // Delegate types that need [UnmanagedFunctionPointer(Cdecl)], with the first method/instruction that referenced
          // them (used as the location for error reports).
          private readonly Dictionary<TypeReference, (MethodDefinition method, Instruction instruction)> _needsNativeFunctionPointer;

          // Implementation methods that need [MonoPInvokeCallback], mapped to the delegate type passed to that attribute.
          private readonly Dictionary<MethodDefinition, TypeReference> _needsIl2cppInvoke;

          // Per-method list of get_Invoke/invoke pairs eligible for the calli rewrite.
          private readonly Dictionary<MethodDefinition, List<CaptureInformation>> _capturedSets;

          // Lazily resolved by Initialize; any of these may remain null when the owning assembly/type cannot be found.
          private MethodDefinition _monoPInvokeAttributeCtorDef;
          private MethodDefinition _nativePInvokeAttributeCtorDef; // Only present on DOTS_PLAYER
          private MethodDefinition _unmanagedFunctionPointerAttributeCtorDef;
          private TypeReference _burstFunctionPointerType;
          private TypeReference _burstCompilerType;
          private TypeReference _systemType;
          private TypeReference _callingConventionType;

          private readonly LogDelegate _debugLog;
          private readonly int _logLevel;
          private readonly AssemblyResolver _loader;
          private readonly ErrorDiagnosticDelegate _errorReport;

          public readonly static bool enableInvokeAttribute = true;
          public readonly static bool enableCalliOptimisation = false; // For now only run the pass on dots player/tiny
          public readonly static bool enableUnmangedFunctionPointerInject = true;

          /// <summary>
          /// Creates the transform. Type and constructor lookups are deferred to <see cref="Initialize"/>.
          /// </summary>
          /// <param name="loader">Resolver used to locate Unity.Burst, UnityEngine.CoreModule and the core library.</param>
          /// <param name="error">Callback used to report user-facing errors against a method/instruction.</param>
          /// <param name="log">Optional debug logger; may be null.</param>
          /// <param name="logLevel">Verbosity for <paramref name="log"/>; some messages are only emitted above levels 1 or 2.</param>
          public FunctionPointerInvokeTransform(AssemblyResolver loader, ErrorDiagnosticDelegate error, LogDelegate log = null, int logLevel = 0)
          {
              _loader = loader;
              _needsNativeFunctionPointer = new Dictionary<TypeReference, (MethodDefinition, Instruction)>();
              _needsIl2cppInvoke = new Dictionary<MethodDefinition, TypeReference>();
              _capturedSets = new Dictionary<MethodDefinition, List<CaptureInformation>>();
              _nativePInvokeAttributeCtorDef = null; // Never resolved here; kept so the DOTS_PLAYER check stays explicit
              _debugLog = log;
              _logLevel = logLevel;
              _errorReport = error;
          }

          /// <summary>Resolves an assembly by name, returning null when the resolver cannot find it.</summary>
          private AssemblyDefinition GetAsmDefinitionFromFile(AssemblyResolver loader, string assemblyName)
          {
              if (loader.TryResolve(AssemblyNameReference.Parse(assemblyName), out var result))
              {
                  return result;
              }
              return null;
          }

          /// <summary>
          /// Resolves the Burst, core-library and UnityEngine types/constructors this pass needs. Once the
          /// MonoPInvokeCallback constructor has been found, later calls return immediately; otherwise lookup is retried.
          /// </summary>
          /// <param name="loader">Resolver used to locate the referenced assemblies.</param>
          /// <param name="assemblyDefinition">The assembly being processed (currently unused by the lookup).</param>
          /// <param name="typeSystem">Type system whose core library supplies System.Type and the interop types.</param>
          public void Initialize(AssemblyResolver loader, AssemblyDefinition assemblyDefinition, TypeSystem typeSystem)
          {
              if (_monoPInvokeAttributeCtorDef != null)
                  return;

              var burstAssembly = GetAsmDefinitionFromFile(loader, "Unity.Burst");
              if (burstAssembly == null)
              {
                  // Without Unity.Burst nothing can be matched; IsBurst*Method return false while the Burst types are null.
                  _debugLog?.Invoke("FunctionPtrInvoke.Initialize: Unable to resolve Unity.Burst - function pointer fixups skipped");
                  return;
              }

              _burstFunctionPointerType = burstAssembly.MainModule.GetType("Unity.Burst.FunctionPointer`1");
              _burstCompilerType = burstAssembly.MainModule.GetType("Unity.Burst.BurstCompiler");

              var corLibrary = loader.Resolve(typeSystem.CoreLibrary as AssemblyNameReference);

              // If the corLibrary is a redirecting assembly, then the type isn't present in Types
              // and GetType() will therefore not find it, so instead we'll have to look it up in ExportedTypes
              Func<string, TypeDefinition> getCorLibTy = (name) =>
                  corLibrary.MainModule.GetType(name) ??
                  corLibrary.MainModule.ExportedTypes.FirstOrDefault(x => x.FullName == name)?.Resolve();

              _systemType = getCorLibTy("System.Type"); // Only needed for MonoPInvokeCallback constructor in Unity

              if (enableUnmangedFunctionPointerInject)
              {
                  var unmanagedFunctionPointerAttribute = getCorLibTy("System.Runtime.InteropServices.UnmanagedFunctionPointerAttribute");
                  _callingConventionType = getCorLibTy("System.Runtime.InteropServices.CallingConvention");

                  // Match the parameter by full type name: MetadataType only says "ValueType" for any enum, so comparing
                  // it would accept any single value-type parameter.
                  _unmanagedFunctionPointerAttributeCtorDef = unmanagedFunctionPointerAttribute.GetConstructors().Single(c =>
                      c.Parameters.Count == 1 &&
                      c.Parameters[0].ParameterType.FullName == _callingConventionType.FullName);
              }

              var asmDef = GetAsmDefinitionFromFile(loader, "UnityEngine.CoreModule");
              // bail if we can't find a reference, handled gracefully later (the ctor stays null)
              if (asmDef == null)
                  return;

              // A missing attribute type is treated like a missing CoreModule: the ctor stays null and the
              // MonoPInvokeCallback fixups are skipped.
              var monoPInvokeAttribute = asmDef.MainModule.GetType("AOT.MonoPInvokeCallbackAttribute");
              _monoPInvokeAttributeCtorDef = monoPInvokeAttribute?.GetConstructors().First();
          }

          /// <summary>
          /// Runs the whole pass over <paramref name="assemblyDefinition"/>: initialize, scan every type, then apply fixups.
          /// </summary>
          /// <returns>True when any attribute or IL change was made.</returns>
          public bool Run(AssemblyDefinition assemblyDefinition)
          {
              Initialize(_loader, assemblyDefinition, assemblyDefinition.MainModule.TypeSystem);

              var types = assemblyDefinition.MainModule.GetTypes().ToArray();
              foreach (var type in types)
              {
                  CollectDelegateInvokesFromType(type);
              }

              return Finish();
          }

          /// <summary>Scans every method with a body on <paramref name="type"/> and records work for <see cref="Finish"/>.</summary>
          public void CollectDelegateInvokesFromType(TypeDefinition type)
          {
              foreach (var m in type.Methods)
              {
                  if (m.HasBody)
                  {
                      CollectDelegateInvokes(m);
                  }
              }
          }

          /// <summary>
          /// Adds [UnmanagedFunctionPointer(CallingConvention.Cdecl)] to each recorded delegate type, or reports an error
          /// when the delegate already carries the attribute with a different (or defaulted) calling convention.
          /// </summary>
          /// <returns>True when at least one attribute was added.</returns>
          private bool ProcessUnmanagedAttributeFixups()
          {
              if (_unmanagedFunctionPointerAttributeCtorDef == null)
                  return false;

              var attributeTypeName = _unmanagedFunctionPointerAttributeCtorDef.DeclaringType.FullName;
              bool modified = false;
              foreach (var kp in _needsNativeFunctionPointer)
              {
                  var delegateType = kp.Key;
                  var instruction = kp.Value.instruction;
                  var method = kp.Value.method;
                  var delegateDef = delegateType.Resolve();

                  var hasAttributeAlready = delegateDef.CustomAttributes.FirstOrDefault(x => x.AttributeType.FullName == attributeTypeName);

                  // If there is already an attribute present, validate it instead of adding another
                  if (hasAttributeAlready != null)
                  {
                      if (hasAttributeAlready.ConstructorArguments.Count == 1)
                      {
                          var cc = (System.Runtime.InteropServices.CallingConvention)hasAttributeAlready.ConstructorArguments[0].Value;
                          if (cc == System.Runtime.InteropServices.CallingConvention.Cdecl)
                          {
                              if (_logLevel > 2) _debugLog?.Invoke($"UnmanagedAttributeFixups Skipping appending unmanagedFunctionPointerAttribute as already present aand calling convention matches");
                          }
                          else
                          {
                              // constructor with non cdecl calling convention
                              _errorReport(method, instruction, $"BurstCompiler.CompileFunctionPointer is only compatible with cdecl calling convention, this delegate type already has `[UnmanagedFunctionPointer(CallingConvention.{ Enum.GetName(typeof(System.Runtime.InteropServices.CallingConvention), cc) })]` please remove the attribute if you wish to use this function with Burst.");
                          }
                      }
                      else
                      {
                          // Empty constructor which defaults to Winapi which is incompatible
                          _errorReport(method, instruction, $"BurstCompiler.CompileFunctionPointer is only compatible with cdecl calling convention, this delegate type already has `[UnmanagedFunctionPointer]` please remove the attribute if you wish to use this function with Burst.");
                      }
                      continue;
                  }

                  var attribute = new CustomAttribute(delegateType.Module.ImportReference(_unmanagedFunctionPointerAttributeCtorDef));
                  attribute.ConstructorArguments.Add(new CustomAttributeArgument(delegateType.Module.ImportReference(_callingConventionType), System.Runtime.InteropServices.CallingConvention.Cdecl));
                  delegateDef.CustomAttributes.Add(attribute);
                  modified = true;
              }

              return modified;
          }

          /// <summary>
          /// Adds [MonoPInvokeCallback(typeof(TDelegate))] to each recorded implementation method.
          /// </summary>
          /// <returns>True when at least one attribute was added.</returns>
          private bool ProcessIl2cppInvokeFixups()
          {
              if (_monoPInvokeAttributeCtorDef == null)
                  return false;

              bool modified = false;
              foreach (var invokeNeeded in _needsIl2cppInvoke)
              {
                  var declaringType = invokeNeeded.Value;
                  var implementationMethod = invokeNeeded.Key;

                  // Unity requires a type parameter for the attributecallback
                  if (declaringType == null)
                  {
                      _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing declaringType for {implementationMethod}");
                      continue;
                  }

                  var module = implementationMethod.Module;
                  var attribute = new CustomAttribute(module.ImportReference(_monoPInvokeAttributeCtorDef));
                  attribute.ConstructorArguments.Add(new CustomAttributeArgument(module.ImportReference(_systemType), module.ImportReference(declaringType)));
                  implementationMethod.CustomAttributes.Add(attribute);
                  modified = true;

                  if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Added InvokeCallbackAttribute to {implementationMethod}");
              }

              return modified;
          }

          /// <summary>
          /// Rewrites each captured <c>get_Invoke</c> + delegate <c>callvirt</c> pair into
          /// <c>call get_Value; stloc tmp; ...; ldloc tmp; calli</c> using a cdecl call site.
          /// </summary>
          /// <returns>True when at least one call site was rewritten.</returns>
          private bool ProcessFunctionPointerInvokes()
          {
              var madeChange = false;
              foreach (var capturedData in _capturedSets)
              {
                  var latePatchMethod = capturedData.Key;
                  var capturedList = capturedData.Value;

                  latePatchMethod.Body.SimplifyMacros(); // De-optimise short branches, since we will end up inserting instructions

                  foreach (var capturedInfo in capturedList)
                  {
                      var captured = capturedInfo.Captured;
                      var operand = capturedInfo.Operand;

                      if (captured.Count != 2)
                      {
                          _debugLog?.Invoke($"FunctionPtrInvoke.Finish: expected 2 instructions - Unable to optimise this reference");
                          continue;
                      }

                      if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.Finish:{Environment.NewLine} latePatchMethod:{latePatchMethod}{Environment.NewLine} captureList:{capturedList}{Environment.NewLine} capture0:{captured[0]}{Environment.NewLine} operand:{operand}");

                      var processor = latePatchMethod.Body.GetILProcessor();
                      var genericContext = GenericContext.From(operand, operand.DeclaringType);

                      // Build the cdecl call-site signature from the delegate method's (generic-resolved) signature.
                      CallSite callsite;
                      try
                      {
                          callsite = new CallSite(genericContext.Resolve(operand.ReturnType))
                          {
                              CallingConvention = MethodCallingConvention.C
                          };
                          for (int oo = 0; oo < operand.Parameters.Count; oo++)
                          {
                              var param = operand.Parameters[oo];
                              var ty = genericContext.Resolve(param.ParameterType);
                              callsite.Parameters.Add(new ParameterDefinition(param.Name, param.Attributes, ty));
                          }
                      }
                      catch (NullReferenceException)
                      {
                          _debugLog?.Invoke($"FunctionPtrInvoke.Finish: Failed to resolve the generic context of `{operand}`");
                          continue;
                      }

                      // Make sure everything is in order before we make a change
                      var originalGetInvoke = captured[0];
                      if (!(originalGetInvoke.Operand is MethodReference mmr))
                          continue;

                      // captured[0] was matched by IsBurstFunctionPointerMethod, so its declaring type is a FunctionPointer<T> instance.
                      var genericInstanceType = mmr.DeclaringType as GenericInstanceType;
                      var genericInstanceDef = genericInstanceType.Resolve();

                      // Locate the get_Value accessor on the open FunctionPointer<T> definition
                      MethodReference mr = genericInstanceDef.Methods.FirstOrDefault(x => x.Name == "get_Value");
                      if (mr == null)
                      {
                          _debugLog?.Invoke($"FunctionPtrInvoke.Finish: failed to locate get_Value method on {genericInstanceDef} - Unable to optimise this reference");
                          continue;
                      }

                      // Re-target get_Value at the closed generic instance used at the call site.
                      var newGenericRef = new MethodReference(mr.Name, mr.ReturnType, genericInstanceType)
                      {
                          HasThis = mr.HasThis,
                          ExplicitThis = mr.ExplicitThis,
                          CallingConvention = mr.CallingConvention
                      };
                      foreach (var param in mr.Parameters)
                          newGenericRef.Parameters.Add(new ParameterDefinition(param.ParameterType));
                      foreach (var gparam in mr.GenericParameters)
                          newGenericRef.GenericParameters.Add(new GenericParameter(gparam.Name, newGenericRef));

                      var importRef = latePatchMethod.Module.ImportReference(newGenericRef);
                      var newMethodCall = processor.Create(OpCodes.Call, importRef);

                      // Replace get_invoke with get_Value - Don't use replace though as if the original call is target of a branch
                      // the branch doesn't get updated.
                      originalGetInvoke.OpCode = newMethodCall.OpCode;
                      originalGetInvoke.Operand = newMethodCall.Operand;

                      // Add local to capture result; its type must belong to (be imported into) the patched module.
                      var newLocal = new VariableDefinition(latePatchMethod.Module.ImportReference(mr.ReturnType));
                      latePatchMethod.Body.Variables.Add(newLocal);

                      // Store result of get_Value
                      var storeInst = processor.Create(OpCodes.Stloc, newLocal);
                      processor.InsertAfter(originalGetInvoke, storeInst);

                      // Swap invoke with calli
                      var calli = processor.Create(OpCodes.Calli, callsite);
                      // We can use replace here, since we already checked this is in the same Basic Block, and thus can't be target of a branch
                      processor.Replace(captured[1], calli);

                      // Insert load local prior to calli
                      var loadValue = processor.Create(OpCodes.Ldloc, newLocal);
                      processor.InsertBefore(calli, loadValue);

                      if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.Finish: Optimised {originalGetInvoke} with {newMethodCall}");
                      madeChange = true;
                  }

                  latePatchMethod.Body.OptimizeMacros(); // Re-optimise branches
              }

              return madeChange;
          }

          /// <summary>
          /// Applies all recorded fixups enabled by the static flags, then clears the recorded work so a reused instance
          /// does not re-apply fixups from a previous run (e.g. a duplicate MonoPInvokeCallback attribute).
          /// </summary>
          /// <returns>True when any attribute or IL change was made.</returns>
          public bool Finish()
          {
              try
              {
                  bool madeChange = false;
                  if (enableInvokeAttribute)
                  {
                      madeChange |= ProcessIl2cppInvokeFixups();
                  }
                  if (enableUnmangedFunctionPointerInject)
                  {
                      madeChange |= ProcessUnmanagedAttributeFixups();
                  }
                  if (enableCalliOptimisation)
                  {
                      madeChange |= ProcessFunctionPointerInvokes();
                  }
                  return madeChange;
              }
              finally
              {
                  _needsNativeFunctionPointer.Clear();
                  _needsIl2cppInvoke.Clear();
                  _capturedSets.Clear();
              }
          }

          /// <summary>
          /// True when <paramref name="methodRef"/> is the method named <paramref name="method"/> on a closed
          /// Unity.Burst.FunctionPointer&lt;T&gt;; <paramref name="methodInstance"/> receives the declaring generic instance.
          /// </summary>
          private bool IsBurstFunctionPointerMethod(MethodReference methodRef, string method, out GenericInstanceType methodInstance)
          {
              methodInstance = methodRef?.DeclaringType as GenericInstanceType;
              return methodInstance != null
                  && _burstFunctionPointerType != null
                  && methodInstance.ElementType.FullName == _burstFunctionPointerType.FullName
                  && methodRef.Name == method;
          }

          /// <summary>True when <paramref name="methodRef"/> is the generic method <paramref name="method"/> on Unity.Burst.BurstCompiler.</summary>
          private bool IsBurstCompilerMethod(GenericInstanceMethod methodRef, string method)
          {
              var declaringType = methodRef?.DeclaringType;
              return declaringType != null
                  && _burstCompilerType != null
                  && declaringType.FullName == _burstCompilerType.FullName
                  && methodRef.Name == method;
          }

          /// <summary>
          /// Inspects one instruction for <c>BurstCompiler.CompileFunctionPointer&lt;T&gt;</c> or
          /// <c>FunctionPointer&lt;T&gt;.get_Invoke</c> calls and records the delegate type and implementation method that
          /// need attribute fixups.
          /// </summary>
          private void LocateFunctionPointerTCreation(MethodDefinition m, Instruction i)
          {
              if (i.OpCode != OpCodes.Call)
                  return;

              var genInstMethod = i.Operand as GenericInstanceMethod;
              var isBurstCompilerCompileFunctionPointer = IsBurstCompilerMethod(genInstMethod, "CompileFunctionPointer");
              var isBurstFunctionPointerGetInvoke = IsBurstFunctionPointerMethod(i.Operand as MethodReference, "get_Invoke", out var instanceType);
              if (!(isBurstCompilerCompileFunctionPointer || isBurstFunctionPointerGetInvoke))
                  return;

              if (enableUnmangedFunctionPointerInject)
              {
                  var delegateType = isBurstCompilerCompileFunctionPointer
                      ? genInstMethod.GenericArguments[0].Resolve()
                      : instanceType.GenericArguments[0].Resolve();

                  // We check for null, since unfortunately it is possible that the call is wrapped inside
                  // another open delegate and we cannot determine the delegate type
                  if (delegateType != null && !_needsNativeFunctionPointer.ContainsKey(delegateType))
                  {
                      _needsNativeFunctionPointer.Add(delegateType, (m, i));
                  }
              }

              // No need to process further if its not a CompileFunctionPointer method
              if (!isBurstCompilerCompileFunctionPointer || !enableInvokeAttribute)
                  return;

              // Currently only handles the following pre-pattern (which should cover most common uses)
              // ldftn ...
              // newobj ...
              if (i.Previous?.OpCode != OpCodes.Newobj)
              {
                  _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to not finding NewObj {i.Previous}");
                  return;
              }

              var newObj = i.Previous;
              if (newObj.Previous?.OpCode != OpCodes.Ldftn)
              {
                  _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to not finding LdFtn {newObj.Previous}");
                  return;
              }

              var ldFtn = newObj.Previous;

              // Determine the delegate type from the delegate constructor passed to newobj
              var methodDefinition = newObj.Operand as MethodDefinition;
              var declaringType = methodDefinition?.DeclaringType;

              // Fetch the implementation method
              var implementationMethod = ldFtn.Operand as MethodDefinition;
              if (implementationMethod == null)
              {
                  _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing method from {ldFtn} {ldFtn.Operand}");
                  return;
              }

              // MonoPInvokeCallback could not be resolved (e.g. no UnityEngine.CoreModule); ProcessIl2cppInvokeFixups would
              // skip the work anyway, and the attribute check below would dereference the null constructor.
              if (_monoPInvokeAttributeCtorDef == null)
                  return;

              var hasInvokeAlready = implementationMethod.CustomAttributes.FirstOrDefault(x =>
                  (x.AttributeType.FullName == _monoPInvokeAttributeCtorDef.DeclaringType.FullName)
                  || (_nativePInvokeAttributeCtorDef != null && x.AttributeType.FullName == _nativePInvokeAttributeCtorDef.DeclaringType.FullName));
              if (hasInvokeAlready != null)
              {
                  if (_logLevel > 2) _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Skipping appending Callback Attribute as already present {hasInvokeAlready}");
                  return;
              }

              if (implementationMethod.CustomAttributes.FirstOrDefault(x => x.Constructor.DeclaringType.Name == "BurstCompileAttribute") == null)
              {
                  _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing burst attribute from {implementationMethod}");
                  return;
              }

              // Need to add the custom attribute
              if (!_needsIl2cppInvoke.ContainsKey(implementationMethod))
              {
                  _needsIl2cppInvoke.Add(implementationMethod, declaringType);
              }
          }

          [Obsolete("Will be removed in a future Burst verison")]
          public bool IsInstructionForFunctionPointerInvoke(MethodDefinition m, Instruction i)
          {
              throw new NotImplementedException();
          }

          /// <summary>Returns the capture list for <paramref name="m"/>, creating and registering it on first use.</summary>
          private List<CaptureInformation> GetOrAddCaptureList(MethodDefinition m)
          {
              if (!_capturedSets.TryGetValue(m, out var storage))
              {
                  storage = new List<CaptureInformation>();
                  _capturedSets.Add(m, storage);
              }
              return storage;
          }

          /// <summary>
          /// Scans one method body: records attribute fixups for every instruction and, when the calli optimisation is
          /// enabled, captures get_Invoke / delegate-invoke pairs that occur within the same basic block.
          /// </summary>
          private void CollectDelegateInvokes(MethodDefinition m)
          {
              if (!(enableCalliOptimisation || enableInvokeAttribute || enableUnmangedFunctionPointerInject))
                  return;

              // Matcher state: after a FunctionPointer<T>.get_Invoke call, look for a callvirt on a method of T.
              bool hitGetInvoke = false;
              TypeDefinition delegateType = null;
              List<Instruction> captured = null;

              foreach (var inst in m.Body.Instructions)
              {
                  if (_logLevel > 2) _debugLog?.Invoke($"FunctionPtrInvoke.CollectDelegateInvokes: CurrentInstruction {inst} {inst.Operand}");

                  // Check for a FunctionPointerT creation
                  if (enableUnmangedFunctionPointerInject || enableInvokeAttribute)
                  {
                      LocateFunctionPointerTCreation(m, inst);
                  }

                  if (!enableCalliOptimisation)
                      continue;

                  if (!hitGetInvoke)
                  {
                      if (inst.OpCode != OpCodes.Call) continue;
                      if (!IsBurstFunctionPointerMethod(inst.Operand as MethodReference, "get_Invoke", out var methodInstance)) continue;

                      // At this point we have a call to a FunctionPointer.Invoke
                      hitGetInvoke = true;
                      delegateType = methodInstance.GenericArguments[0].Resolve();
                      // Capture the get_invoke, we will swap this for get_value and a store to local
                      captured = new List<Instruction> { inst };
                      continue;
                  }

                  if (!(inst.OpCode.FlowControl == FlowControl.Next || inst.OpCode.FlowControl == FlowControl.Call))
                  {
                      // Don't perform transform across blocks
                      hitGetInvoke = false;
                      continue;
                  }

                  if (inst.OpCode != OpCodes.Callvirt)
                      continue;

                  if (!(inst.Operand is MethodReference mref))
                  {
                      hitGetInvoke = false;
                      continue;
                  }

                  // Resolve() may return null for an unresolvable reference; treat that as "not the delegate's method".
                  var method = mref.Resolve();
                  if (method == null || method.DeclaringType != delegateType)
                      continue;

                  hitGetInvoke = false;
                  var storage = GetOrAddCaptureList(m);

                  // Capture the invoke - which we will swap for a load local (stored from the get_value) and a calli
                  captured.Add(inst);
                  var captureInfo = new CaptureInformation { Captured = captured, Operand = mref };
                  if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.CollectDelegateInvokes: captureInfo:{captureInfo}{Environment.NewLine}capture0{captured[0]}");
                  storage.Add(captureInfo);
              }
          }
      }
  }