15
15
import os
16
16
import uuid
17
17
import shutil
18
+ import llvmlite .ir
18
19
19
20
from typing import Optional , List
20
21
21
22
from typed_python .compiler .binary_shared_object import LoadedBinarySharedObject , BinarySharedObject
22
23
from typed_python .compiler .directed_graph import DirectedGraph
23
24
from typed_python .compiler .typed_call_target import TypedCallTarget
25
+ import typed_python .compiler .native_ast as native_ast
24
26
from typed_python .SerializationContext import SerializationContext
25
27
from typed_python import Dict , ListOf
26
28
@@ -67,6 +69,9 @@ def __init__(self, cacheDir):
67
69
self .targetsLoaded : Dict [str , TypedCallTarget ] = {}
68
70
# the set of link_names for functions with linked and validated globals (i.e. ready to be run).
69
71
self .targetsValidated = set ()
72
+ # the total number of instructions for each link_name
73
+ self .targetComplexity = Dict (str , int )()
74
+
70
75
# link_name -> link_name
71
76
self .function_dependency_graph = DirectedGraph ()
72
77
# dict from link_name to list of global names (should be llvm keys in serialisedGlobalDefinitions)
@@ -90,6 +95,20 @@ def getTarget(self, func_name: str) -> TypedCallTarget:
90
95
self .loadForSymbol (link_name )
91
96
return self .targetsLoaded [link_name ]
92
97
98
+ def getIR (self , func_name : str ) -> llvmlite .ir .Function :
99
+ if not self .hasSymbol (func_name ):
100
+ raise ValueError (f'symbol not found for func_name { func_name } ' )
101
+ link_name = self ._select_link_name (func_name )
102
+ module_hash = self .link_name_to_module_hash [link_name ]
103
+ return self .loadedBinarySharedObjects [module_hash ].binarySharedObject .functionIRs [func_name ]
104
+
105
+ def getDefinition (self , func_name : str ) -> native_ast .Function :
106
+ if not self .hasSymbol (func_name ):
107
+ raise ValueError (f'symbol not found for func_name { func_name } ' )
108
+ link_name = self ._select_link_name (func_name )
109
+ module_hash = self .link_name_to_module_hash [link_name ]
110
+ return self .loadedBinarySharedObjects [module_hash ].binarySharedObject .functionDefinitions [func_name ]
111
+
93
112
def _generate_link_name (self , func_name : str , module_hash : str ) -> str :
94
113
return func_name + "." + module_hash
95
114
@@ -126,6 +145,14 @@ def loadForSymbol(self, linkName: str) -> None:
126
145
if not self .loadedBinarySharedObjects [moduleHash ].validateGlobalVariables (definitionsToLink ):
127
146
raise RuntimeError ('failed to validate globals when loading:' , linkName )
128
147
148
+ def complexityForSymbol (self , func_name : str ) -> int :
149
+ """Get the total number of LLVM instructions for a given symbol."""
150
+ try :
151
+ link_name = self ._select_link_name (func_name )
152
+ return self .targetComplexity [link_name ]
153
+ except KeyError as e :
154
+ raise ValueError (f'No complexity value cached for { func_name } ' ) from e
155
+
129
156
def loadModuleByHash (self , moduleHash : str ) -> None :
130
157
"""Load a module by name.
131
158
@@ -139,23 +166,23 @@ def loadModuleByHash(self, moduleHash: str) -> None:
139
166
140
167
# TODO (Will) - store these names as module consts, use one .dat only
141
168
with open (os .path .join (targetDir , "type_manifest.dat" ), "rb" ) as f :
142
- # func_name -> typedcalltarget
143
169
callTargets = SerializationContext ().deserialize (f .read ())
144
-
145
170
with open (os .path .join (targetDir , "globals_manifest.dat" ), "rb" ) as f :
146
171
serializedGlobalVarDefs = SerializationContext ().deserialize (f .read ())
147
-
148
172
with open (os .path .join (targetDir , "native_type_manifest.dat" ), "rb" ) as f :
149
173
functionNameToNativeType = SerializationContext ().deserialize (f .read ())
150
-
151
174
with open (os .path .join (targetDir , "submodules.dat" ), "rb" ) as f :
152
175
submodules = SerializationContext ().deserialize (f .read (), ListOf (str ))
153
-
154
176
with open (os .path .join (targetDir , "function_dependencies.dat" ), "rb" ) as f :
155
177
dependency_edgelist = SerializationContext ().deserialize (f .read ())
156
-
157
178
with open (os .path .join (targetDir , "global_dependencies.dat" ), "rb" ) as f :
158
179
globalDependencies = SerializationContext ().deserialize (f .read ())
180
+ with open (os .path .join (targetDir , "function_complexities.dat" ), "rb" ) as f :
181
+ functionComplexities = SerializationContext ().deserialize (f .read ())
182
+ with open (os .path .join (targetDir , "function_irs.dat" ), "rb" ) as f :
183
+ functionIRs = SerializationContext ().deserialize (f .read ())
184
+ with open (os .path .join (targetDir , "function_definitions.dat" ), "rb" ) as f :
185
+ functionDefinitions = SerializationContext ().deserialize (f .read ())
159
186
160
187
# load the submodules first
161
188
for submodule in submodules :
@@ -167,7 +194,10 @@ def loadModuleByHash(self, moduleHash: str) -> None:
167
194
modulePath ,
168
195
serializedGlobalVarDefs ,
169
196
functionNameToNativeType ,
170
- globalDependencies
197
+ globalDependencies ,
198
+ functionComplexities ,
199
+ functionIRs ,
200
+ functionDefinitions
171
201
).loadFromPath (modulePath )
172
202
173
203
self .loadedBinarySharedObjects [moduleHash ] = loaded
@@ -177,8 +207,9 @@ def loadModuleByHash(self, moduleHash: str) -> None:
177
207
assert link_name not in self .targetsLoaded
178
208
self .targetsLoaded [link_name ] = callTarget
179
209
180
- link_name_global_dependencies = { self ._generate_link_name ( x , moduleHash ): y for x , y in globalDependencies . items ()}
210
+ self .targetComplexity . update ( functionComplexities )
181
211
212
+ link_name_global_dependencies = {self ._generate_link_name (x , moduleHash ): y for x , y in globalDependencies .items ()}
182
213
assert not any (key in self .global_dependencies for key in link_name_global_dependencies )
183
214
184
215
self .global_dependencies .update (link_name_global_dependencies )
@@ -314,6 +345,15 @@ def writeModuleToDisk(self, binarySharedObject, hashToUse, nameToTypedCallTarget
314
345
with open (os .path .join (tempTargetDir , "global_dependencies.dat" ), "wb" ) as f :
315
346
f .write (SerializationContext ().serialize (binarySharedObject .globalDependencies ))
316
347
348
+ with open (os .path .join (tempTargetDir , "function_complexities.dat" ), "wb" ) as f :
349
+ f .write (SerializationContext ().serialize (binarySharedObject .functionComplexities ))
350
+
351
+ with open (os .path .join (tempTargetDir , "function_irs.dat" ), "wb" ) as f :
352
+ f .write (SerializationContext ().serialize (binarySharedObject .functionIRs ))
353
+
354
+ with open (os .path .join (tempTargetDir , "function_definitions.dat" ), "wb" ) as f :
355
+ f .write (SerializationContext ().serialize (binarySharedObject .functionDefinitions ))
356
+
317
357
try :
318
358
os .rename (tempTargetDir , targetDir )
319
359
except IOError :
0 commit comments