This patch fixes several scripts that generate source files. gen_flatbuffers.sh is changed to call the flatc found on PATH instead of building the bundled copy in third_party/flatbuffers. gen_jit_decompositions.py, gen_mobile_upgraders.py and gen_jit_shape_functions.py depend on the compiled PyTorch library, so an option is added to each of them: when the first command-line argument is "dummy", the script skips the steps that need that library, uses empty tables instead, and generates "dummy" source files for the initial build; these files are corrected later. gen_jit_decompositions.py also has its repository-root lookup changed from parents[3] to parents[2]. codegen_external.py is patched to avoid generating duplicate functions and to add the static keyword, as in the existing generated file. diff --git a/tools/gen_flatbuffers.sh b/tools/gen_flatbuffers.sh index cc0263d..ac34e84 100644 --- a/tools/gen_flatbuffers.sh +++ b/tools/gen_flatbuffers.sh @@ -1,13 +1,13 @@ #!/bin/bash ROOT=$(pwd) -FF_LOCATION="$ROOT/third_party/flatbuffers" -cd "$FF_LOCATION" || exit -mkdir build -cd build || exit -cmake .. -cmake --build . --target flatc -mkdir -p "$ROOT/build/torch/csrc/jit/serialization" -./flatc --cpp --gen-mutable --scoped-enums \ +#FF_LOCATION="$ROOT/third_party/flatbuffers" +#cd "$FF_LOCATION" || exit +#mkdir build +#cd build || exit +#cmake .. +#cmake --build . --target flatc +#mkdir -p "$ROOT/build/torch/csrc/jit/serialization" +flatc --cpp --gen-mutable --scoped-enums \ -o "$ROOT/torch/csrc/jit/serialization" \ -c "$ROOT/torch/csrc/jit/serialization/mobile_bytecode.fbs" echo '// @generated' >> "$ROOT/torch/csrc/jit/serialization/mobile_bytecode_generated.h" diff --git a/torch/csrc/jit/tensorexpr/codegen_external.py b/torch/csrc/jit/tensorexpr/codegen_external.py index 5dcf1b2..0e20b0c 100644 --- a/torch/csrc/jit/tensorexpr/codegen_external.py +++ b/torch/csrc/jit/tensorexpr/codegen_external.py @@ -21,9 +21,14 @@ def gen_external(native_functions_path, tags_path, external_path): native_functions = parse_native_yaml(native_functions_path, tags_path) func_decls = [] func_registrations = [] - for func in native_functions: + done_names = set() + for func in native_functions[0]: schema = func.func name = schema.name.name.base + if name in done_names: + continue + else: + done_names.add(name) args = schema.arguments # Only supports extern calls for 
functions with out variants if not schema.is_out_fn(): @@ -63,7 +68,7 @@ def gen_external(native_functions_path, tags_path, external_path): # print(tensor_decls, name, arg_names) func_decl = f"""\ -void nnc_aten_{name}( +static void nnc_aten_{name}( int64_t bufs_num, void** buf_data, int64_t* buf_ranks, diff --git a/torchgen/decompositions/gen_jit_decompositions.py b/torchgen/decompositions/gen_jit_decompositions.py index 7a0024f..6b2445f 100644 --- a/torchgen/decompositions/gen_jit_decompositions.py +++ b/torchgen/decompositions/gen_jit_decompositions.py @@ -1,8 +1,12 @@ #!/usr/bin/env python3 import os from pathlib import Path +import sys -from torch.jit._decompositions import decomposition_table +if len(sys.argv) < 2 or sys.argv[1] != "dummy": + from torch.jit._decompositions import decomposition_table +else: + decomposition_table = {} # from torchgen.code_template import CodeTemplate @@ -85,7 +89,7 @@ def write_decomposition_util_file(path: str) -> None: def main() -> None: - pytorch_dir = Path(__file__).resolve().parents[3] + pytorch_dir = Path(__file__).resolve().parents[2] upgrader_path = pytorch_dir / "torch" / "csrc" / "jit" / "runtime" write_decomposition_util_file(str(upgrader_path)) diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py index 2907076..6866332 100644 --- a/torchgen/operator_versions/gen_mobile_upgraders.py +++ b/torchgen/operator_versions/gen_mobile_upgraders.py @@ -3,10 +3,12 @@ import os from enum import Enum from operator import itemgetter from pathlib import Path +import sys from typing import Any, Dict, List -import torch -from torch.jit.generate_bytecode import generate_upgraders_bytecode +if len(sys.argv) < 2 or sys.argv[1] != "dummy": + import torch + from torch.jit.generate_bytecode import generate_upgraders_bytecode from torchgen.code_template import CodeTemplate from torchgen.operator_versions.gen_mobile_upgraders_constant import ( @@ -263,7 +265,10 @@ def 
construct_register_size(register_size_from_yaml: int) -> str: def construct_version_maps( upgrader_bytecode_function_to_index_map: Dict[str, Any] ) -> str: - version_map = torch._C._get_operator_version_map() + if len(sys.argv) < 2 or sys.argv[1] != "dummy": + version_map = torch._C._get_operator_version_map() + else: + version_map = {} sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0)) # type: ignore[no-any-return] sorted_version_map = dict(sorted_version_map_) @@ -379,7 +384,10 @@ def sort_upgrader(upgrader_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def main() -> None: - upgrader_list = generate_upgraders_bytecode() + if len(sys.argv) < 2 or sys.argv[1] != "dummy": + upgrader_list = generate_upgraders_bytecode() + else: + upgrader_list = [] sorted_upgrader_list = sort_upgrader(upgrader_list) for up in sorted_upgrader_list: print("after sort upgrader : ", next(iter(up))) diff --git a/torchgen/shape_functions/gen_jit_shape_functions.py b/torchgen/shape_functions/gen_jit_shape_functions.py index bdfd5c7..72b237a 100644 --- a/torchgen/shape_functions/gen_jit_shape_functions.py +++ b/torchgen/shape_functions/gen_jit_shape_functions.py @@ -18,16 +18,20 @@ you are in the root directory of the Pytorch git repo""" if not file_path.exists(): raise Exception(err_msg) # noqa: TRY002 -spec = importlib.util.spec_from_file_location(module_name, file_path) -assert spec is not None -module = importlib.util.module_from_spec(spec) -sys.modules[module_name] = module -assert spec.loader is not None -assert module is not None -spec.loader.exec_module(module) - -bounded_compute_graph_mapping = module.bounded_compute_graph_mapping -shape_compute_graph_mapping = module.shape_compute_graph_mapping +if len(sys.argv) < 2 or sys.argv[1] != "dummy": + spec = importlib.util.spec_from_file_location(module_name, file_path) + assert spec is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + 
assert module is not None + spec.loader.exec_module(module) + + bounded_compute_graph_mapping = module.bounded_compute_graph_mapping + shape_compute_graph_mapping = module.shape_compute_graph_mapping +else: + bounded_compute_graph_mapping = {} + shape_compute_graph_mapping = {} SHAPE_HEADER = r"""