9 changes: 4 additions & 5 deletions python/torch_mlir/tools/import_onnx/__main__.py
@@ -71,10 +71,7 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
# temp directory instead of a hard-coded path in order to avoid data races
# by default.
input_dir = os.path.dirname(os.path.abspath(args.input_file))
temp_dir = (
Path(input_dir if args.temp_dir is None else args.temp_dir)
/ "onnx-importer-temp"
)
temp_dir = Path(args.temp_dir or input_dir) / "onnx-importer-temp"
shutil.rmtree(temp_dir, ignore_errors=True)
temp_dir.mkdir(exist_ok=True)

@@ -121,10 +118,13 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
# onnx.shape_inference.infer_shapes_path(temp_raw_file, temp_inferred_file)
# inferred_model = onnx.load(temp_inferred_file)

data_dir = Path(args.data_dir or input_dir)

# Model is too big for in-memory inference: do file-based shape inference
# to a temp file.
# First need to save the model when it has been changed (e.g. by version conversion).
if raw_model_modified:
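        # The modified model's external data is re-saved into the temp dir
        # below, so that is where it must be loaded from afterwards.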
data_dir = temp_dir
temp_raw_file = temp_dir / "raw.onnx"
onnx.save(raw_model, temp_raw_file, save_as_external_data=True)
temp_inferred_file = temp_dir / "inferred.onnx"
@@ -146,7 +146,6 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:

# Load the temp file and the external data.
inferred_model = onnx.load(temp_inferred_file, load_external_data=False)
data_dir = Path(input_dir if args.temp_dir is None else args.data_dir)
onnx.load_external_data_for_model(inferred_model, str(data_dir))

# Remove the inferred shape file unless asked to keep it
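The net effect of the revised directory handling can be summarized with a short sketch (resolve_dirs is a hypothetical helper for illustration only, not the importer's literal code, and argparse is assumed to leave --temp-dir and --data-dir as None when omitted):

import os
from pathlib import Path

def resolve_dirs(input_file, temp_dir_arg=None, data_dir_arg=None, raw_model_modified=False):
    # Both directories default to the directory containing the input model.
    input_dir = os.path.dirname(os.path.abspath(input_file))
    temp_dir = Path(temp_dir_arg or input_dir) / "onnx-importer-temp"
    data_dir = Path(data_dir_arg or input_dir)
    # A modified model is re-saved with its external data under temp_dir,
    # so external data must then be loaded from there instead.
    if raw_model_modified:
        data_dir = temp_dir
    return temp_dir, data_dir

For example, resolve_dirs("model.onnx") keeps both the temp files and the external-data lookup next to model.onnx, so the data directory no longer depends on whether --temp-dir was given.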
82 changes: 68 additions & 14 deletions test/python/onnx_importer/command_line_test.py
@@ -13,6 +13,7 @@
import shutil
import sys
import subprocess
import tempfile
import unittest
import unittest.mock

@@ -39,6 +40,8 @@

OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

MOCK_MAXIMUM_PROTOBUF = 1 << 20


def const_model() -> onnx.ModelProto:
# Note: data_path must be relative to model_file
@@ -87,7 +90,26 @@ def linear_model() -> onnx.ModelProto:
return onnx_model


ALL_MODELS = [const_model, linear_model]
def path_based_shape_inference_model() -> onnx.ModelProto:
# Create a model with a serialized form that's large enough to require
# path-based shape inference.
dtype = numpy.float32
byte_size = numpy.dtype(dtype).itemsize
tensor_size = MOCK_MAXIMUM_PROTOBUF // byte_size + 1
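    # With float32 (4 bytes): (1 << 20) // 4 + 1 = 262,145 elements, i.e.
    # 1,048,580 bytes of tensor data, just past the mocked 1 MiB limit.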
large_tensor = numpy.random.rand(tensor_size).astype(dtype)
assert large_tensor.nbytes > MOCK_MAXIMUM_PROTOBUF
node1 = make_node(
"Constant",
[],
["large_const"],
value=numpy_helper.from_array(large_tensor, name="large_const"),
)
X = make_tensor_value_info("large_const", TensorProto.FLOAT, [tensor_size])
graph = make_graph([node1], "large_const_graph", [], [X])
return make_model(graph)


ALL_MODELS = [const_model, linear_model, path_based_shape_inference_model]


class CommandLineTest(unittest.TestCase):
@@ -110,7 +132,12 @@ def run_model_intern(self, onnx_model: onnx.ModelProto, model_name: str):
args = __main__.parse_arguments([str(model_file), "-o", str(mlir_file)])
__main__.main(args)

def run_model_extern(self, onnx_model: onnx.ModelProto, model_name: str):
def run_model_extern(
self,
onnx_model: onnx.ModelProto,
model_name: str,
extra_args: list[str] | None = None,
):
run_path = self.get_run_path(model_name)
model_file = run_path / f"{model_name}-e.onnx"
mlir_file = run_path / f"{model_name}-e.torch.mlir"
@@ -127,20 +154,41 @@ def run_model_extern(self, onnx_model: onnx.ModelProto, model_name: str):
onnx.save(onnx_model, model_file)
temp_dir = run_path / "temp"
temp_dir.mkdir(exist_ok=True)
args = __main__.parse_arguments(
[
str(model_file),
"-o",
str(mlir_file),
"--keep-temps",
"--temp-dir",
str(temp_dir),
"--data-dir",
str(run_path),
]
)
raw_args = [
str(model_file),
"-o",
str(mlir_file),
"--keep-temps",
"--temp-dir",
str(temp_dir),
"--data-dir",
str(run_path),
]
if extra_args:
raw_args.extend(extra_args)
args = __main__.parse_arguments(raw_args)
__main__.main(args)

@unittest.mock.patch("onnx.checker.MAXIMUM_PROTOBUF", MOCK_MAXIMUM_PROTOBUF)
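    # Lower the size threshold above which the importer switches to
    # file-based shape inference, so the ~1 MiB test model exercises that
    # path without needing a >2 GiB protobuf.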
def run_model_explicit_temp_implicit_data(
self, onnx_model: onnx.ModelProto, model_name: str
):
run_path = self.get_run_path(model_name)
model_file = run_path / f"{model_name}-explicit_temp_implicit_data.onnx"
mlir_file = run_path / f"{model_name}-explicit_temp_implicit_data.torch.mlir"
onnx.save(onnx_model, model_file)
with tempfile.TemporaryDirectory(dir=run_path) as temp_dir:
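            # No --data-dir is passed; the importer should fall back to
            # loading external data from the model file's own directory.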
args = __main__.parse_arguments(
[
str(model_file),
"-o",
str(mlir_file),
"--temp-dir",
str(temp_dir),
]
)
__main__.main(args)

def test_all(self):
for model_func in ALL_MODELS:
model_name = model_func.__name__
@@ -150,6 +198,12 @@ def test_all(self):
self.run_model_intern(model, model_name)
with self.subTest("External data"):
self.run_model_extern(model, model_name)
with self.subTest("External data, raw model modified"):
self.run_model_extern(
model, model_name, extra_args=["--clear-domain"]
)
with self.subTest("Explicit temp dir, implicit data dir"):
self.run_model_explicit_temp_implicit_data(model, model_name)


if __name__ == "__main__":