Skip to content

Commit

Permalink
Only read from disk if necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
erick-xanadu committed Nov 19, 2024
1 parent 8cfc766 commit 8cb9c83
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 36 deletions.
11 changes: 2 additions & 9 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,9 +453,7 @@ def run_from_ir(self, ir: str, module_name: str, workspace: Directory):
if self.options.verbose:
print(f"[LIB] Running compiler driver in {workspace}", file=self.options.logfile)

with tempfile.NamedTemporaryFile(
mode="w", suffix=".mlir", dir=str(workspace), delete=False
) as tmp_infile:
with open(str(workspace) + "/input.test", "w") as tmp_infile:

Check notice on line 456 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L456

Using open without explicitly specifying an encoding (unspecified-encoding)
tmp_infile_name = tmp_infile.name
tmp_infile.write(ir)

Expand All @@ -477,20 +475,15 @@ def run_from_ir(self, ir: str, module_name: str, workspace: Directory):
f"catalyst-cli failed with error code {e.returncode}: {e.stderr}"
) from e

with open(output_ir_name, "r", encoding="utf-8") as f:
out_IR = f.read()

if lower_to_llvm:
output = LinkerDriver.run(output_object_name, options=self.options)
output_object_name = str(pathlib.Path(output).absolute())

# Clean up temporary files
if os.path.exists(tmp_infile_name):
os.remove(tmp_infile_name)
if os.path.exists(output_ir_name):
os.remove(output_ir_name)

return output_object_name, out_IR
return output_object_name, output_ir_name

@debug_logger
def run(self, mlir_module, *args, **kwargs):
Expand Down
61 changes: 40 additions & 21 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,9 +481,10 @@ def __init__(self, fn, compile_options):
self.jaxed_function = None
# IRs are only available for the most recently traced function.
self.jaxpr = None
self.mlir = None # string form (historic presence)
self._mlir = None # string form (historic presence)
self.mlir_module = None
self.qir = None
self._qir = None
self.qir_file = None
self.out_type = None
self.overwrite_ir = None

Expand Down Expand Up @@ -514,6 +515,21 @@ def __init__(self, fn, compile_options):

super().__init__("user_function")

@property
def mlir(self):

Check notice on line 519 in frontend/catalyst/jit.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jit.py#L519

Missing function or method docstring (missing-function-docstring)
if not self._mlir and self.mlir_module:
_, mlir_file = self.canonicalize(self.mlir_module)
with open(mlir_file, "r", encoding="utf-8") as f:
self._mlir = f.read()
return self._mlir

@property
def qir(self):

Check notice on line 527 in frontend/catalyst/jit.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jit.py#L527

Missing function or method docstring (missing-function-docstring)
if not self._qir and self.qir_file:
with open(self.qir_file, "r", encoding="utf-8") as f:
self._qir = f.read()
return self._qir

@debug_logger
def __call__(self, *args, **kwargs):
# Transparantly call Python function in case of nested QJIT calls.
Expand Down Expand Up @@ -555,10 +571,10 @@ def aot_compile(self):
)

if self.compile_options.target in ("mlir", "binary"):
self.mlir_module, self.mlir = self.generate_ir()
self.mlir_module = self.generate_ir()

if self.compile_options.target in ("binary",):
self.compiled_function, self.qir = self.compile()
self.compiled_function, self.qir_file = self.compile()
self.fn_cache.insert(
self.compiled_function, self.user_sig, self.out_treedef, self.workspace
)
Expand Down Expand Up @@ -599,8 +615,8 @@ def jit_compile(self, args, **kwargs):
args, **kwargs
)

self.mlir_module, self.mlir = self.generate_ir()
self.compiled_function, self.qir = self.compile()
self.mlir_module = self.generate_ir()
self.compiled_function, self.qir_file = self.compile()

self.fn_cache.insert(self.compiled_function, args, self.out_treedef, self.workspace)

Expand Down Expand Up @@ -696,7 +712,20 @@ def closure(qnode, *args, **kwargs):
PipelineNameUniquer.reset()
return jaxpr, out_type, treedef, dynamic_sig

@instrument(size_from=0, has_finegrained=True)
@debug_logger
def canonicalize(self, mlir_module):
"""Canonicalize the mlir_module"""

# Canonicalize the MLIR since there can be a lot of redundancy coming from JAX.
options = copy.deepcopy(self.compile_options)
options.pipelines = [("0_canonicalize", ["canonicalize"])]
options.lower_to_llvm = False
canonicalizer = Compiler(options)

# TODO: the in-memory and textual form are different after this, consider unification
return canonicalizer.run(mlir_module, self.workspace)

@instrument(has_finegrained=True)
@debug_logger
def generate_ir(self):
"""Generate Catalyst's intermediate representation (IR) as an MLIR module.
Expand All @@ -710,18 +739,8 @@ def generate_ir(self):
# Inject Runtime Library-specific functions (e.g. setup/teardown).
inject_functions(mlir_module, ctx, self.compile_options.seed)

# Canonicalize the MLIR since there can be a lot of redundancy coming from JAX.
options = copy.deepcopy(self.compile_options)
options.pipelines = [("0_canonicalize", ["canonicalize"])]
options.lower_to_llvm = False
canonicalizer = Compiler(options)

# TODO: the in-memory and textual form are different after this, consider unification
_, mlir_string = canonicalizer.run(mlir_module, self.workspace)

return mlir_module, mlir_string
return mlir_module

@instrument(size_from=1, has_finegrained=True)
@debug_logger
def compile(self):
"""Compile an MLIR module to LLVMIR and shared library code.
Expand All @@ -746,19 +765,19 @@ def compile(self):
# `replace` method, so we need to get a regular Python string out of it.
func_name = str(self.mlir_module.body.operations[0].name).replace('"', "")
if self.overwrite_ir:
shared_object, llvm_ir = self.compiler.run_from_ir(
shared_object, llvm_ir_file = self.compiler.run_from_ir(
self.overwrite_ir,
str(self.mlir_module.operation.attributes["sym_name"]).replace('"', ""),
self.workspace,
)
else:
shared_object, llvm_ir = self.compiler.run(self.mlir_module, self.workspace)
shared_object, llvm_ir_file = self.compiler.run(self.mlir_module, self.workspace)

compiled_fn = CompiledFunction(
shared_object, func_name, restype, self.out_type, self.compile_options
)

return compiled_fn, llvm_ir
return compiled_fn, llvm_ir_file

@instrument(has_finegrained=True)
@debug_logger
Expand Down
1 change: 1 addition & 0 deletions frontend/test/pytest/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def workflow():
qml.PauliX(wires=0)
return qml.state()

workflow.mlir

Check notice on line 273 in frontend/test/pytest/test_compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_compiler.py#L273

Statement seems to have no effect (pointless-statement)
directory = os.path.join(os.getcwd(), workflow.__name__)
files = os.listdir(directory)
# The directory is non-empty. Should at least contain the original .mlir file
Expand Down
7 changes: 1 addition & 6 deletions frontend/test/pytest/test_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,12 +423,7 @@ def f(x: float):

f(2.0)

with pytest.raises(
CompileError,
match="Attempting to get output for pipeline: mlir, "
"but no file was found.\nAre you sure the file exists?",
):
get_compilation_stage(f, "mlir")
get_compilation_stage(f, "mlir")

@pytest.mark.parametrize(
"arg",
Expand Down

0 comments on commit 8cb9c83

Please sign in to comment.