From 74200f89ce7b72f81b8be18a3e042c1703a9b4cd Mon Sep 17 00:00:00 2001 From: SigureMo Date: Tue, 25 Feb 2025 22:00:48 +0800 Subject: [PATCH 01/15] cleanup CLEAN_CODE --- .../jit/sot/opcode_translator/executor/pycode_generator.py | 5 ----- python/paddle/jit/sot/utils/__init__.py | 2 -- python/paddle/jit/sot/utils/envs.py | 1 - python/paddle/jit/sot/utils/utils.py | 5 ----- 4 files changed, 13 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py index 8d81a478cfb2ea..02c313c3d45eba 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py +++ b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py @@ -33,7 +33,6 @@ FallbackError, InnerError, ResumeFnNameFactory, - is_clean_code, list_contain_by_id, list_find_index_by_id, no_eval_frame, @@ -514,8 +513,6 @@ def gen_disable_eval_frame(self): """ Generates instructions to disable the evaluation frame. """ - if is_clean_code(): - return self.gen_load_object( paddle.framework.core.set_eval_frame, "paddle_set_eval_frame_fn" ) @@ -527,8 +524,6 @@ def gen_enable_eval_frame(self): """ Generates instructions to enable the evaluation frame. """ - if is_clean_code(): - return self.gen_load_object( paddle.framework.core.set_eval_frame, "paddle_set_eval_frame_fn" ) diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py index 601da1af1c144b..4d75c8e2574e4d 100644 --- a/python/paddle/jit/sot/utils/__init__.py +++ b/python/paddle/jit/sot/utils/__init__.py @@ -14,7 +14,6 @@ from .call_ast_utils import get_static_function, try_ast_func # noqa: F401 from .envs import ( # noqa: F401 - ENV_CLEAN_CODE, ENV_COST_MODEL, ENV_MIN_GRAPH_SIZE, ENV_SOT_ALLOW_DYNAMIC_SHAPE, @@ -77,7 +76,6 @@ in_paddle_module, is_break_graph_api, is_builtin_fn, - is_clean_code, is_comprehensive_name, is_paddle_api, is_strict_mode, diff --git a/python/paddle/jit/sot/utils/envs.py b/python/paddle/jit/sot/utils/envs.py index e84753866d44f4..94c9fd27586eb0 100644 --- a/python/paddle/jit/sot/utils/envs.py +++ b/python/paddle/jit/sot/utils/envs.py @@ -29,7 +29,6 @@ ENV_MIN_GRAPH_SIZE = IntegerEnvironmentVariable("MIN_GRAPH_SIZE", 10) ENV_SOT_LOG_LEVEL = IntegerEnvironmentVariable("SOT_LOG_LEVEL", 0) ENV_STRICT_MODE = BooleanEnvironmentVariable("STRICT_MODE", False) -ENV_CLEAN_CODE = BooleanEnvironmentVariable("CLEAN_CODE", False) ENV_SOT_WITH_CONTROL_FLOW = BooleanEnvironmentVariable( "SOT_WITH_CONTROL_FLOW", True ) diff --git a/python/paddle/jit/sot/utils/utils.py b/python/paddle/jit/sot/utils/utils.py index 1246e97fb38ebf..6865427d61c692 100644 --- a/python/paddle/jit/sot/utils/utils.py +++ b/python/paddle/jit/sot/utils/utils.py @@ -32,7 +32,6 @@ from paddle.utils import flatten, map_structure from .envs import ( - ENV_CLEAN_CODE, ENV_COST_MODEL, ENV_SOT_LOG_LEVEL, ENV_STRICT_MODE, @@ -319,10 +318,6 @@ def is_strict_mode(): return ENV_STRICT_MODE.get() -def is_clean_code() -> bool: - return ENV_CLEAN_CODE.get() - - def list_find_index_by_id(li: list[Any], item: Any) -> int: return [id(it) for it in li].index(id(item)) From 4af86623b13667bf06bbb2b248eda4bca324d5a6 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 26 Feb 2025 10:21:22 +0800 Subject: [PATCH 02/15] cleanup cost model --- paddle/fluid/pybind/CMakeLists.txt | 2 +- paddle/fluid/pybind/jit.cc | 20 +-- paddle/fluid/pybind/sot/eval_frame.c | 13 +- .../{eval_frame_tools.cc => skip_files.cc} | 96 +------------ .../sot/{eval_frame_tools.h => skip_files.h} | 4 - .../jit/sot/opcode_translator/skip_files.py | 6 - .../paddle/jit/sot/symbolic/compile_cache.py | 4 - python/paddle/jit/sot/translate.py | 55 +++----- python/paddle/jit/sot/utils/__init__.py | 1 - python/paddle/jit/sot/utils/utils.py | 132 ------------------ test/sot/test_sot_cost_model.py | 114 --------------- 11 files changed, 21 insertions(+), 426 deletions(-) rename paddle/fluid/pybind/sot/{eval_frame_tools.cc => skip_files.cc} (66%) rename paddle/fluid/pybind/sot/{eval_frame_tools.h => skip_files.h} (87%) delete mode 100644 test/sot/test_sot_cost_model.py diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index efd74ede2943e8..cacc145eba54c6 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -140,7 +140,7 @@ set(PYBIND_SRCS xpu_streams_py.cc jit.cc auto_parallel_py.cc - sot/eval_frame_tools.cc + sot/skip_files.cc sot/cpython_internals.c sot/frame_proxy.c sot/eval_frame.c diff --git a/paddle/fluid/pybind/jit.cc b/paddle/fluid/pybind/jit.cc index 6db6d45b136ed9..2721931db0f4d4 100644 --- a/paddle/fluid/pybind/jit.cc +++ b/paddle/fluid/pybind/jit.cc @@ -21,10 +21,10 @@ limitations under the License. */ #include "paddle/fluid/jit/layer.h" #include "paddle/fluid/jit/serializer.h" #include "paddle/fluid/pybind/sot/eval_frame.h" -#include "paddle/fluid/pybind/sot/eval_frame_tools.h" #include "paddle/fluid/pybind/sot/frame_proxy.h" #include "paddle/fluid/pybind/sot/guards.h" #include "paddle/fluid/pybind/sot/macros.h" +#include "paddle/fluid/pybind/sot/skip_files.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/place.h" #include "paddle/utils/pybind.h" @@ -150,24 +150,6 @@ void BindSot(pybind11::module *m) { return type->tp_getattro != PyObject_GenericGetAttr; }); - m->def( - "sot_setup_codes_with_graph", - [](const py::object &py_codes) { - auto ret = setup_codes_with_graph(py_codes.ptr()); - auto obj = py::reinterpret_borrow(ret); - return obj; - }, - py::arg("py_codes")); - - m->def( - "sot_set_with_graph", - [](const py::object &py_codes) { - auto ret = set_with_graph(py_codes.ptr()); - auto obj = py::reinterpret_borrow(ret); - return obj; - }, - py::arg("py_codes")); - m->def( "eval_frame_no_skip_codes", [](const py::object &py_codes) { diff --git a/paddle/fluid/pybind/sot/eval_frame.c b/paddle/fluid/pybind/sot/eval_frame.c index 073d0d3780d429..cc2d2534459ba8 100644 --- a/paddle/fluid/pybind/sot/eval_frame.c +++ b/paddle/fluid/pybind/sot/eval_frame.c @@ -17,8 +17,8 @@ limitations under the License. */ #if SOT_IS_SUPPORTED #include "paddle/fluid/pybind/sot/cpython_internals.h" -#include "paddle/fluid/pybind/sot/eval_frame_tools.h" #include "paddle/fluid/pybind/sot/frame_proxy.h" +#include "paddle/fluid/pybind/sot/skip_files.h" #include @@ -329,17 +329,6 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, Py_DECREF(f_locals); #endif - // code status - if (is_code_without_graph(code == Py_None ? PyFrame_GET_CODE(frame) - : (PyCodeObject *)code) && - disable_eval_frame == Py_False) { - out = eval_frame_default(tstate, frame, throw_flag); - eval_frame_callback_set(callback); - Py_DECREF(code); - Py_DECREF(disable_eval_frame); - return out; - } - // run code if (disable_eval_frame != Py_True) { // Re-enable custom behavior diff --git a/paddle/fluid/pybind/sot/eval_frame_tools.cc b/paddle/fluid/pybind/sot/skip_files.cc similarity index 66% rename from paddle/fluid/pybind/sot/eval_frame_tools.cc rename to paddle/fluid/pybind/sot/skip_files.cc index 35b1a507e3a9a9..5b8007ce98eb4c 100644 --- a/paddle/fluid/pybind/sot/eval_frame_tools.cc +++ b/paddle/fluid/pybind/sot/skip_files.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pybind/sot/eval_frame_tools.h" +#include "paddle/fluid/pybind/sot/skip_files.h" #include @@ -20,7 +20,6 @@ #include "paddle/common/errors.h" #include "paddle/phi/core/enforce.h" -#include "paddle/phi/core/platform/profiler/event_tracing.h" #if SOT_IS_SUPPORTED #define END_OF_STRING '\0' @@ -142,77 +141,6 @@ int SkipCodeInfo::in_skip_path(PyObject* filename) { return root->check_filename(name); } -/*========================== code status ==============================*/ -enum CodeState { UNKNOW, WITH_GRAPH, WITHOUT_GRAPH }; - -class CodeInfo { - public: - CodeState state; - int counter; -}; - -class CodeStatus { - public: - static CodeStatus& Instance(); - int is_code_without_graph(PyCodeObject* code); - void set_with_graph(PyCodeObject* code); - void add_with_graph_code(PyCodeObject* code); - void clear(); - - private: - CodeStatus() { code_map = std::unordered_map(); } - ~CodeStatus() { clear(); } - std::unordered_map code_map; -}; - -CodeStatus& CodeStatus::Instance() { - static CodeStatus _instance; - return _instance; -} - -int CodeStatus::is_code_without_graph(PyCodeObject* code) { - CodeInfo* code_info; - if (code_map.find(code) != code_map.end()) { - code_info = code_map[code]; - } else { - code_info = new CodeInfo(); - code_map.emplace(code, code_info); - } - if (code_info->state == WITHOUT_GRAPH) return 1; - if (code_info->state == UNKNOW) { - code_info->counter += 1; - if (code_info->counter >= 10) code_info->state = WITHOUT_GRAPH; - } - return 0; -} - -void CodeStatus::set_with_graph(PyCodeObject* code) { - CodeInfo* code_info; - if (code_map.find(code) != code_map.end()) { - code_info = code_map[code]; - code_info->state = WITH_GRAPH; - } -} - -void CodeStatus::add_with_graph_code(PyCodeObject* code) { - CodeInfo* code_info; - if (code_map.find(code) != code_map.end()) { - code_info = code_map[code]; - code_info->state = WITH_GRAPH; - } else { - code_info = new CodeInfo(); - code_info->state = WITH_GRAPH; - code_map.emplace(code, code_info); - } -} - -void CodeStatus::clear() { - for (auto& iter : code_map) { - delete iter.second; - } - code_map.clear(); -} - /*========================== interfaces ===============================*/ int need_skip(FrameObject* frame) { @@ -245,29 +173,7 @@ int need_skip(FrameObject* frame) { return result; } -int is_code_without_graph(PyCodeObject* code) { - auto& code_status = CodeStatus::Instance(); - return code_status.is_code_without_graph(code); -} - /*========================== pybind ===============================*/ -PyObject* set_with_graph(PyObject* code) { - auto& code_status = CodeStatus::Instance(); - code_status.set_with_graph((PyCodeObject*)code); // NOLINT - return Py_None; -} - -PyObject* setup_codes_with_graph(PyObject* code_tuple) { - auto& code_status = CodeStatus::Instance(); - Py_ssize_t size = PyTuple_GET_SIZE(code_tuple); - for (Py_ssize_t i = 0; i < size; i++) { - PyCodeObject* code = - (PyCodeObject*)PyTuple_GetItem(code_tuple, i); // NOLINT - code_status.add_with_graph_code(code); - } - return Py_None; -} - PyObject* no_skip_codes(PyObject* code_tuple) { auto& skip_info = SkipCodeInfo::Instance(); Py_ssize_t size = PyTuple_GET_SIZE(code_tuple); diff --git a/paddle/fluid/pybind/sot/eval_frame_tools.h b/paddle/fluid/pybind/sot/skip_files.h similarity index 87% rename from paddle/fluid/pybind/sot/eval_frame_tools.h rename to paddle/fluid/pybind/sot/skip_files.h index 417a4a5ed89777..8a8e0001e8d5b9 100644 --- a/paddle/fluid/pybind/sot/eval_frame_tools.h +++ b/paddle/fluid/pybind/sot/skip_files.h @@ -25,10 +25,6 @@ extern "C" { #if SOT_IS_SUPPORTED int need_skip(FrameObject* frame); -int is_code_without_graph(PyCodeObject* code); - -PyObject* set_with_graph(PyObject* code); -PyObject* setup_codes_with_graph(PyObject* code_tuple); PyObject* no_skip_codes(PyObject* code_tuple); PyObject* skip_file_prefix(PyObject* filepath_tuple); diff --git a/python/paddle/jit/sot/opcode_translator/skip_files.py b/python/paddle/jit/sot/opcode_translator/skip_files.py index bf84a4c32c5acf..75bb6000e10ca4 100644 --- a/python/paddle/jit/sot/opcode_translator/skip_files.py +++ b/python/paddle/jit/sot/opcode_translator/skip_files.py @@ -134,13 +134,7 @@ def _module_dir(m: types.ModuleType): no_skip_code = {paddle.nn.Sequential.forward.__code__} -with_graph_codes = ( - paddle.nn.Layer.__call__.__code__, - paddle.nn.Layer._dygraph_call_func.__code__, -) - def setup_skip_files(): paddle.framework.core.eval_frame_skip_file_prefix(tuple(skip_file_names)) paddle.framework.core.eval_frame_no_skip_codes(tuple(no_skip_code)) - paddle.framework.core.sot_setup_codes_with_graph(with_graph_codes) diff --git a/python/paddle/jit/sot/symbolic/compile_cache.py b/python/paddle/jit/sot/symbolic/compile_cache.py index d8b5c3d00816fb..09d86e54b81ad7 100644 --- a/python/paddle/jit/sot/symbolic/compile_cache.py +++ b/python/paddle/jit/sot/symbolic/compile_cache.py @@ -31,7 +31,6 @@ InfoCollector, NewSymbolHitRateInfo, Singleton, - StepInfoManager, SubGraphRelationInfo, log, log_do, @@ -219,9 +218,6 @@ def collect_subgraph_relation(self, inputs, outputs, partial_program_layer): def __call__(self, *args, **kwargs): with EventGuard(f"FallbackWrapper: {self.SIR.name}"): - if StepInfoManager().need_back_trace: - trace_back_frames() - log_do( 2, lambda: print("[FallbackWrapper] start run SIR: \n", self.SIR), diff --git a/python/paddle/jit/sot/translate.py b/python/paddle/jit/sot/translate.py index 958526346a3b0f..b0f91874f4979e 100644 --- a/python/paddle/jit/sot/translate.py +++ b/python/paddle/jit/sot/translate.py @@ -25,8 +25,6 @@ from .utils import ( GraphLogger, InfoCollector, - StepInfoManager, - StepState, log_do, ) @@ -96,42 +94,23 @@ def symbolic_translate(fn: Callable[P, R], **kwargs) -> Callable[P, R]: def callback(frame): return eval_frame_callback(frame, **kwargs) - def impl_sot(*args: P.args, **kwargs: P.kwargs) -> R: - assert hasattr( - fn, "__code__" - ), "Target function doesn't have code for simulating." - StepInfoManager().sot_step() - GraphLogger().clear() - InfoCollector().clear_step_info() - paddle.framework.core.set_eval_frame(callback) - try: - outs = fn(*args, **kwargs) - except Exception as e: - raise e - finally: - paddle.framework.core.set_eval_frame(None) - - log_do(1, lambda: GraphLogger().print_info()) - InfoCollector().print_step_report() - return outs - - def impl_dynamic(*args: P.args, **kwargs: P.kwargs) -> R: - outs = fn(*args, **kwargs) - return outs - def impl(*args: P.args, **kwargs: P.kwargs) -> R: - with StepInfoManager().step_guard(fn.__code__), SotStepProfilerGuard(): - state = StepInfoManager().current_state - - if state == StepState.RUN_SOT: - return impl_sot(*args, **kwargs) - elif state == StepState.RUN_DYN: - return impl_dynamic(*args, **kwargs) - elif state == StepState.COLLECT_INFO: - return StepInfoManager().collect_info( - impl_dynamic, impl_sot, *args, **kwargs - ) - else: - raise RuntimeError("Unknown state.") + with SotStepProfilerGuard(): + assert hasattr( + fn, "__code__" + ), "Target function doesn't have code for simulating." + GraphLogger().clear() + InfoCollector().clear_step_info() + paddle.framework.core.set_eval_frame(callback) + try: + outs = fn(*args, **kwargs) + except Exception as e: + raise e + finally: + paddle.framework.core.set_eval_frame(None) + + log_do(1, lambda: GraphLogger().print_info()) + InfoCollector().print_step_report() + return outs return impl diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py index 4d75c8e2574e4d..10127032ab3370 100644 --- a/python/paddle/jit/sot/utils/__init__.py +++ b/python/paddle/jit/sot/utils/__init__.py @@ -63,7 +63,6 @@ ResumeFnNameFactory, Singleton, SotUndefinedVar, - StepInfoManager, StepState, count_if, current_symbol_registry, diff --git a/python/paddle/jit/sot/utils/utils.py b/python/paddle/jit/sot/utils/utils.py index 6865427d61c692..4ccf598dc773f8 100644 --- a/python/paddle/jit/sot/utils/utils.py +++ b/python/paddle/jit/sot/utils/utils.py @@ -22,17 +22,13 @@ import weakref from collections import OrderedDict from contextlib import contextmanager -from enum import Enum from typing import TYPE_CHECKING, Any, Callable, TypeVar from weakref import WeakValueDictionary -import numpy as np - import paddle from paddle.utils import flatten, map_structure from .envs import ( - ENV_COST_MODEL, ENV_SOT_LOG_LEVEL, ENV_STRICT_MODE, ) @@ -404,134 +400,6 @@ def printable(obj): return False -class StepState(Enum): - COLLECT_INFO = 1 - RUN_SOT = 2 - RUN_DYN = 3 - - -class StepInfo: - REQUIRED_DYN_INFOS = 10 - REQUIRED_SOT_INFOS = 10 - - USED_DYN_INFOS = 5 - - COLLECT_INFO_MAX_STEP = 50 - CV_BOUNDARY = 0.1 - - BACK_TRACE_STEPS = 20 - - def __init__(self): - self.step_count = -1 - self.state = ( - StepState.COLLECT_INFO - if ENV_COST_MODEL.get() - else StepState.RUN_SOT - ) - self.dyn_time_costs = [] - self.avg_dyn_time = 0 - self.sot_time_costs = [] - self.sot_step = -1 - - def add_dynamic_time_info(self, time_cost): - self.dyn_time_costs.append(time_cost) - if len(self.dyn_time_costs) == self.REQUIRED_DYN_INFOS: - self.avg_dyn_time = np.mean( - self.dyn_time_costs[-self.USED_DYN_INFOS :] - ) - - def add_sot_time_info(self, time_cost, current_code): - self.sot_time_costs.append(time_cost) - if len(self.sot_time_costs) == self.REQUIRED_SOT_INFOS: - avg_sot_time = np.mean(self.sot_time_costs) - log( - 1, - f"[Cost Model] sot: {avg_sot_time}, dyn: {self.avg_dyn_time}\n", - ) - if avg_sot_time < self.avg_dyn_time: - log(1, f"[Cost Model] Switch to RUN_SOT: {current_code} \n") - self.state = StepState.RUN_SOT - elif ( - self.step_count > self.COLLECT_INFO_MAX_STEP - or np.std(self.sot_time_costs) / avg_sot_time < self.CV_BOUNDARY - ): - log(1, f"[Cost Model] Switch to RUN_DYN: {current_code}\n") - self.state = StepState.RUN_DYN - else: - log(1, f"[Cost Model] Decision delayed: {current_code}\n") - self.sot_time_costs.clear() - - def need_back_trace(self): - return self.step_count < self.BACK_TRACE_STEPS - - def need_dynamic_info(self): - return len(self.dyn_time_costs) < self.REQUIRED_DYN_INFOS - - -class StepInfoManager(metaclass=Singleton): - def __init__(self): - self.step_record = {} - self.current_code = None - self.current_step_info = None - - @contextmanager - def step_guard(self, code): - try: - old_code = self.current_code - old_info = self.current_step_info - - self.current_code = code - if code not in self.step_record: - self.step_record[code] = StepInfo() - self.current_step_info = self.step_record[code] - - self.current_step_info.step_count += 1 - - log( - 2, - f"[Cost Model] New step start, current state is {self.current_state}\n", - ) - yield - finally: - self.current_code = old_code - self.current_step_info = old_info - - def sot_step(self): - self.current_step_info.sot_step += 1 - - def collect_info(self, impl_dynamic, impl_sot, /, *args, **kwargs): - if self.current_step_info.need_dynamic_info(): - start_time = time.perf_counter() - outs = impl_dynamic(*args, **kwargs) - time_cost = time.perf_counter() - start_time - self.current_step_info.add_dynamic_time_info(time_cost) - else: - start_time = time.perf_counter() - outs = impl_sot(*args, **kwargs) - time_cost = time.perf_counter() - start_time - self.current_step_info.add_sot_time_info( - time_cost, self.current_code - ) - return outs - - @property - def need_back_trace(self): - return self.current_step_info.need_back_trace() - - @property - def current_step(self): - return self.current_step_info.step_count - - @property - def current_state(self): - return self.current_step_info.state - - def clear(self): - self.step_record.clear() - self.current_code = None - self.current_step = -1 - - def get_api_fullname(api): api_name = api.__name__ module_str = api.__module__ diff --git a/test/sot/test_sot_cost_model.py b/test/sot/test_sot_cost_model.py deleted file mode 100644 index eed690a1e77815..00000000000000 --- a/test/sot/test_sot_cost_model.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -import unittest - -from test_case_base import TestCaseBase - -import paddle -from paddle.jit.sot import psdb, symbolic_translate -from paddle.jit.sot.utils import StepInfoManager, StepState, cost_model_guard - - -def dyn_fast(x, net, iter_): - for i in iter_: - x = net(x) - return x - - -def sot_fast_with_single_graph(x, net): - if not psdb.in_sot(): - time.sleep(0.1) - return x + 1 - - -def sot_fast_with_multi_graph(x, net): - if not psdb.in_sot(): - time.sleep(0.1) - x = x + 1 - psdb.breakgraph() - x = x + 2 - return x - - -class Net(paddle.nn.Layer): - def __init__(self): - super().__init__() - self.linear = paddle.nn.Linear(10, 10) - - def forward(self, x): - if not psdb.in_sot(): - time.sleep(0.1) - x = x / 3 - x = x + 5 - x = self.linear(x) - return x - - -class TestCostModel(TestCaseBase): - @cost_model_guard(True) - def test_dyn_fast(self): - x = paddle.rand([10]) - net = paddle.nn.Linear(10, 10) - sot_fn = symbolic_translate(dyn_fast) - for i in range(60): - sot_fn(x, net, iter(range(10))) - - state = StepInfoManager().step_record[dyn_fast.__code__].state - assert state == StepState.RUN_DYN - - @cost_model_guard(True) - def test_sot_fast_with_multi_graph(self): - x = paddle.rand([10]) - net = paddle.nn.Linear(10, 10) - sot_fn = symbolic_translate(sot_fast_with_multi_graph) - for i in range(30): - sot_fn(x, net) - - state = ( - StepInfoManager() - .step_record[sot_fast_with_multi_graph.__code__] - .state - ) - assert state == StepState.RUN_SOT - - @cost_model_guard(True) - def test_sot_fast_with_single_graph(self): - x = paddle.rand([10]) - net = paddle.nn.Linear(10, 10) - for i in range(30): - symbolic_translate(sot_fast_with_single_graph)(x, net) - - state = ( - StepInfoManager() - .step_record[sot_fast_with_single_graph.__code__] - .state - ) - assert state == StepState.RUN_SOT - - @cost_model_guard(True) - def test_net(self): - x = paddle.rand([10]) - net = Net() - net = paddle.jit.to_static(net, full_graph=False) - for i in range(30): - x = net(x) - - state = StepInfoManager().step_record[Net.forward.__code__].state - assert state == StepState.RUN_SOT - - -if __name__ == "__main__": - unittest.main() From 22d06e3070f539d7817391c5a22ee44bdf034a96 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 26 Feb 2025 10:24:20 +0800 Subject: [PATCH 03/15] cleanup env cost model --- python/paddle/jit/sot/utils/__init__.py | 2 -- python/paddle/jit/sot/utils/envs.py | 7 ------- 2 files changed, 9 deletions(-) diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py index 10127032ab3370..bcb02b0c8b7be1 100644 --- a/python/paddle/jit/sot/utils/__init__.py +++ b/python/paddle/jit/sot/utils/__init__.py @@ -14,7 +14,6 @@ from .call_ast_utils import get_static_function, try_ast_func # noqa: F401 from .envs import ( # noqa: F401 - ENV_COST_MODEL, ENV_MIN_GRAPH_SIZE, ENV_SOT_ALLOW_DYNAMIC_SHAPE, ENV_SOT_ENABLE_FASTER_GUARD, @@ -25,7 +24,6 @@ ENV_SOT_WITH_CONTROL_FLOW, ENV_STRICT_MODE, allow_dynamic_shape_guard, - cost_model_guard, export_guard, faster_guard_guard, guard_tree_guard, diff --git a/python/paddle/jit/sot/utils/envs.py b/python/paddle/jit/sot/utils/envs.py index 94c9fd27586eb0..0af7e46fd86c9d 100644 --- a/python/paddle/jit/sot/utils/envs.py +++ b/python/paddle/jit/sot/utils/envs.py @@ -25,7 +25,6 @@ StringListEnvironmentVariable, ) -ENV_COST_MODEL = BooleanEnvironmentVariable("COST_MODEL", False) ENV_MIN_GRAPH_SIZE = IntegerEnvironmentVariable("MIN_GRAPH_SIZE", 10) ENV_SOT_LOG_LEVEL = IntegerEnvironmentVariable("SOT_LOG_LEVEL", 0) ENV_STRICT_MODE = BooleanEnvironmentVariable("STRICT_MODE", False) @@ -59,12 +58,6 @@ ) -@contextmanager -def cost_model_guard(value: bool): - with EnvironmentVariableGuard(ENV_COST_MODEL, value): - yield - - @contextmanager def strict_mode_guard(value: bool): with EnvironmentVariableGuard(ENV_STRICT_MODE, value): From 12eaf3e75e8677171be3bae713a89d51d0ae9258 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 26 Feb 2025 11:06:27 +0800 Subject: [PATCH 04/15] cleanup StepState --- python/paddle/jit/sot/utils/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py index bcb02b0c8b7be1..b1dacc769bb0ca 100644 --- a/python/paddle/jit/sot/utils/__init__.py +++ b/python/paddle/jit/sot/utils/__init__.py @@ -61,7 +61,6 @@ ResumeFnNameFactory, Singleton, SotUndefinedVar, - StepState, count_if, current_symbol_registry, execute_time, From 31f87326011b3b2ce9f229c777f2e41ae0178c22 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 26 Feb 2025 14:03:38 +0800 Subject: [PATCH 05/15] cleanup no_eval_frame --- python/paddle/jit/sot/utils/utils.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/python/paddle/jit/sot/utils/utils.py b/python/paddle/jit/sot/utils/utils.py index 4ccf598dc773f8..1a2cb3092dcb00 100644 --- a/python/paddle/jit/sot/utils/utils.py +++ b/python/paddle/jit/sot/utils/utils.py @@ -145,20 +145,6 @@ def log_enabled(level): return level <= ENV_SOT_LOG_LEVEL.get_with_cache() -def no_eval_frame(func): - def no_eval_frame_func(*args, **kwargs): - old_cb = paddle.framework.core.set_eval_frame(None) - try: - retval = func(*args, **kwargs) - except: - raise - finally: - paddle.framework.core.set_eval_frame(old_cb) - return retval - - return no_eval_frame_func - - def is_comprehensive_name(name): return name in ["", "", "", ""] From 5fd05eb7c0a353b35688629f1abcaa572d92fc52 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 26 Feb 2025 14:10:40 +0800 Subject: [PATCH 06/15] refine name COMPARE_OP_NAME_TO_FN --- .../jit/sot/opcode_translator/executor/opcode_executor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index b9987250261e55..7db892f243ac5c 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -113,7 +113,7 @@ if TYPE_CHECKING: from .function_graph import CompileGraphResult -SUPPORT_COMPARE_OP = { +COMPARE_OP_NAME_TO_FN = { ">": operator.gt, "<": operator.lt, ">=": operator.ge, @@ -1358,7 +1358,7 @@ def COMPARE_OP(self, instr: Instruction): right, left = self.stack.pop(), self.stack.pop() self.stack.push( BuiltinVariable( - SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() + COMPARE_OP_NAME_TO_FN[op], self._graph, DanglingTracker() )(left, right) ) @@ -1377,7 +1377,7 @@ def IS_OP(self, instr: Instruction): op = "is" if instr.arg == 0 else "is not" self.stack.push( BuiltinVariable( - SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() + COMPARE_OP_NAME_TO_FN[op], self._graph, DanglingTracker() )(left, right) ) @@ -1594,7 +1594,7 @@ def CONTAINS_OP(self, instr: Instruction): op = "in" if instr.arg == 0 else "not in" self.stack.push( BuiltinVariable( - SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() + COMPARE_OP_NAME_TO_FN[op], self._graph, DanglingTracker() )(left, right) ) From 3d0d9aa3da6cc7ec44ebd84ac0ef3305a99708f2 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 26 Feb 2025 15:23:16 +0800 Subject: [PATCH 07/15] Revert "cleanup no_eval_frame" This reverts commit 31f87326011b3b2ce9f229c777f2e41ae0178c22. --- python/paddle/jit/sot/utils/utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/paddle/jit/sot/utils/utils.py b/python/paddle/jit/sot/utils/utils.py index 1a2cb3092dcb00..4ccf598dc773f8 100644 --- a/python/paddle/jit/sot/utils/utils.py +++ b/python/paddle/jit/sot/utils/utils.py @@ -145,6 +145,20 @@ def log_enabled(level): return level <= ENV_SOT_LOG_LEVEL.get_with_cache() +def no_eval_frame(func): + def no_eval_frame_func(*args, **kwargs): + old_cb = paddle.framework.core.set_eval_frame(None) + try: + retval = func(*args, **kwargs) + except: + raise + finally: + paddle.framework.core.set_eval_frame(old_cb) + return retval + + return no_eval_frame_func + + def is_comprehensive_name(name): return name in ["", "", "", ""] From 60f0da0ae9c9b6b6ef7785446bb4bc0304f6bd1f Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 27 Feb 2025 14:18:09 +0800 Subject: [PATCH 08/15] Revert "cleanup StepState" This reverts commit 12eaf3e75e8677171be3bae713a89d51d0ae9258. --- python/paddle/jit/sot/utils/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py index b1dacc769bb0ca..bcb02b0c8b7be1 100644 --- a/python/paddle/jit/sot/utils/__init__.py +++ b/python/paddle/jit/sot/utils/__init__.py @@ -61,6 +61,7 @@ ResumeFnNameFactory, Singleton, SotUndefinedVar, + StepState, count_if, current_symbol_registry, execute_time, From fd8516eaf9b49b69fe99ba5076514359a693dd53 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 27 Feb 2025 14:18:13 +0800 Subject: [PATCH 09/15] Revert "cleanup env cost model" This reverts commit 22d06e3070f539d7817391c5a22ee44bdf034a96. --- python/paddle/jit/sot/utils/__init__.py | 2 ++ python/paddle/jit/sot/utils/envs.py | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py index bcb02b0c8b7be1..10127032ab3370 100644 --- a/python/paddle/jit/sot/utils/__init__.py +++ b/python/paddle/jit/sot/utils/__init__.py @@ -14,6 +14,7 @@ from .call_ast_utils import get_static_function, try_ast_func # noqa: F401 from .envs import ( # noqa: F401 + ENV_COST_MODEL, ENV_MIN_GRAPH_SIZE, ENV_SOT_ALLOW_DYNAMIC_SHAPE, ENV_SOT_ENABLE_FASTER_GUARD, @@ -24,6 +25,7 @@ ENV_SOT_WITH_CONTROL_FLOW, ENV_STRICT_MODE, allow_dynamic_shape_guard, + cost_model_guard, export_guard, faster_guard_guard, guard_tree_guard, diff --git a/python/paddle/jit/sot/utils/envs.py b/python/paddle/jit/sot/utils/envs.py index 0af7e46fd86c9d..94c9fd27586eb0 100644 --- a/python/paddle/jit/sot/utils/envs.py +++ b/python/paddle/jit/sot/utils/envs.py @@ -25,6 +25,7 @@ StringListEnvironmentVariable, ) +ENV_COST_MODEL = BooleanEnvironmentVariable("COST_MODEL", False) ENV_MIN_GRAPH_SIZE = IntegerEnvironmentVariable("MIN_GRAPH_SIZE", 10) ENV_SOT_LOG_LEVEL = IntegerEnvironmentVariable("SOT_LOG_LEVEL", 0) ENV_STRICT_MODE = BooleanEnvironmentVariable("STRICT_MODE", False) @@ -58,6 +59,12 @@ ) +@contextmanager +def cost_model_guard(value: bool): + with EnvironmentVariableGuard(ENV_COST_MODEL, value): + yield + + @contextmanager def strict_mode_guard(value: bool): with EnvironmentVariableGuard(ENV_STRICT_MODE, value): From 61eed67466be9c7b8087109db83f78e5f75f77c9 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 27 Feb 2025 14:18:17 +0800 Subject: [PATCH 10/15] Revert "cleanup cost model" This reverts commit 4af86623b13667bf06bbb2b248eda4bca324d5a6. --- paddle/fluid/pybind/CMakeLists.txt | 2 +- paddle/fluid/pybind/jit.cc | 20 ++- paddle/fluid/pybind/sot/eval_frame.c | 13 +- .../{skip_files.cc => eval_frame_tools.cc} | 96 ++++++++++++- .../sot/{skip_files.h => eval_frame_tools.h} | 4 + .../jit/sot/opcode_translator/skip_files.py | 6 + .../paddle/jit/sot/symbolic/compile_cache.py | 4 + python/paddle/jit/sot/translate.py | 55 +++++--- python/paddle/jit/sot/utils/__init__.py | 1 + python/paddle/jit/sot/utils/utils.py | 132 ++++++++++++++++++ test/sot/test_sot_cost_model.py | 114 +++++++++++++++ 11 files changed, 426 insertions(+), 21 deletions(-) rename paddle/fluid/pybind/sot/{skip_files.cc => eval_frame_tools.cc} (66%) rename paddle/fluid/pybind/sot/{skip_files.h => eval_frame_tools.h} (87%) create mode 100644 test/sot/test_sot_cost_model.py diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index cacc145eba54c6..efd74ede2943e8 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -140,7 +140,7 @@ set(PYBIND_SRCS xpu_streams_py.cc jit.cc auto_parallel_py.cc - sot/skip_files.cc + sot/eval_frame_tools.cc sot/cpython_internals.c sot/frame_proxy.c sot/eval_frame.c diff --git a/paddle/fluid/pybind/jit.cc b/paddle/fluid/pybind/jit.cc index 2721931db0f4d4..6db6d45b136ed9 100644 --- a/paddle/fluid/pybind/jit.cc +++ b/paddle/fluid/pybind/jit.cc @@ -21,10 +21,10 @@ limitations under the License. */ #include "paddle/fluid/jit/layer.h" #include "paddle/fluid/jit/serializer.h" #include "paddle/fluid/pybind/sot/eval_frame.h" +#include "paddle/fluid/pybind/sot/eval_frame_tools.h" #include "paddle/fluid/pybind/sot/frame_proxy.h" #include "paddle/fluid/pybind/sot/guards.h" #include "paddle/fluid/pybind/sot/macros.h" -#include "paddle/fluid/pybind/sot/skip_files.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/place.h" #include "paddle/utils/pybind.h" @@ -150,6 +150,24 @@ void BindSot(pybind11::module *m) { return type->tp_getattro != PyObject_GenericGetAttr; }); + m->def( + "sot_setup_codes_with_graph", + [](const py::object &py_codes) { + auto ret = setup_codes_with_graph(py_codes.ptr()); + auto obj = py::reinterpret_borrow(ret); + return obj; + }, + py::arg("py_codes")); + + m->def( + "sot_set_with_graph", + [](const py::object &py_codes) { + auto ret = set_with_graph(py_codes.ptr()); + auto obj = py::reinterpret_borrow(ret); + return obj; + }, + py::arg("py_codes")); + m->def( "eval_frame_no_skip_codes", [](const py::object &py_codes) { diff --git a/paddle/fluid/pybind/sot/eval_frame.c b/paddle/fluid/pybind/sot/eval_frame.c index cc2d2534459ba8..073d0d3780d429 100644 --- a/paddle/fluid/pybind/sot/eval_frame.c +++ b/paddle/fluid/pybind/sot/eval_frame.c @@ -17,8 +17,8 @@ limitations under the License. */ #if SOT_IS_SUPPORTED #include "paddle/fluid/pybind/sot/cpython_internals.h" +#include "paddle/fluid/pybind/sot/eval_frame_tools.h" #include "paddle/fluid/pybind/sot/frame_proxy.h" -#include "paddle/fluid/pybind/sot/skip_files.h" #include @@ -329,6 +329,17 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, Py_DECREF(f_locals); #endif + // code status + if (is_code_without_graph(code == Py_None ? PyFrame_GET_CODE(frame) + : (PyCodeObject *)code) && + disable_eval_frame == Py_False) { + out = eval_frame_default(tstate, frame, throw_flag); + eval_frame_callback_set(callback); + Py_DECREF(code); + Py_DECREF(disable_eval_frame); + return out; + } + // run code if (disable_eval_frame != Py_True) { // Re-enable custom behavior diff --git a/paddle/fluid/pybind/sot/skip_files.cc b/paddle/fluid/pybind/sot/eval_frame_tools.cc similarity index 66% rename from paddle/fluid/pybind/sot/skip_files.cc rename to paddle/fluid/pybind/sot/eval_frame_tools.cc index 5b8007ce98eb4c..35b1a507e3a9a9 100644 --- a/paddle/fluid/pybind/sot/skip_files.cc +++ b/paddle/fluid/pybind/sot/eval_frame_tools.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pybind/sot/skip_files.h" +#include "paddle/fluid/pybind/sot/eval_frame_tools.h" #include @@ -20,6 +20,7 @@ #include "paddle/common/errors.h" #include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/platform/profiler/event_tracing.h" #if SOT_IS_SUPPORTED #define END_OF_STRING '\0' @@ -141,6 +142,77 @@ int SkipCodeInfo::in_skip_path(PyObject* filename) { return root->check_filename(name); } +/*========================== code status ==============================*/ +enum CodeState { UNKNOW, WITH_GRAPH, WITHOUT_GRAPH }; + +class CodeInfo { + public: + CodeState state; + int counter; +}; + +class CodeStatus { + public: + static CodeStatus& Instance(); + int is_code_without_graph(PyCodeObject* code); + void set_with_graph(PyCodeObject* code); + void add_with_graph_code(PyCodeObject* code); + void clear(); + + private: + CodeStatus() { code_map = std::unordered_map(); } + ~CodeStatus() { clear(); } + std::unordered_map code_map; +}; + +CodeStatus& CodeStatus::Instance() { + static CodeStatus _instance; + return _instance; +} + +int CodeStatus::is_code_without_graph(PyCodeObject* code) { + CodeInfo* code_info; + if (code_map.find(code) != code_map.end()) { + code_info = code_map[code]; + } else { + code_info = new CodeInfo(); + code_map.emplace(code, code_info); + } + if (code_info->state == WITHOUT_GRAPH) return 1; + if (code_info->state == UNKNOW) { + code_info->counter += 1; + if (code_info->counter >= 10) code_info->state = WITHOUT_GRAPH; + } + return 0; +} + +void CodeStatus::set_with_graph(PyCodeObject* code) { + CodeInfo* code_info; + if (code_map.find(code) != code_map.end()) { + code_info = code_map[code]; + code_info->state = WITH_GRAPH; + } +} + +void CodeStatus::add_with_graph_code(PyCodeObject* code) { + CodeInfo* code_info; + if (code_map.find(code) != code_map.end()) { + code_info = code_map[code]; + code_info->state = WITH_GRAPH; + } else { + code_info = new CodeInfo(); + code_info->state = WITH_GRAPH; + code_map.emplace(code, code_info); + } +} + +void CodeStatus::clear() { + for (auto& iter : code_map) { + delete iter.second; + } + code_map.clear(); +} + /*========================== interfaces ===============================*/ int need_skip(FrameObject* frame) { @@ -173,7 +245,29 @@ int need_skip(FrameObject* frame) { return result; } +int is_code_without_graph(PyCodeObject* code) { + auto& code_status = CodeStatus::Instance(); + return code_status.is_code_without_graph(code); +} + /*========================== pybind ===============================*/ +PyObject* set_with_graph(PyObject* code) { + auto& code_status = CodeStatus::Instance(); + code_status.set_with_graph((PyCodeObject*)code); // NOLINT + return Py_None; +} + +PyObject* setup_codes_with_graph(PyObject* code_tuple) { + auto& code_status = CodeStatus::Instance(); + Py_ssize_t size = PyTuple_GET_SIZE(code_tuple); + for (Py_ssize_t i = 0; i < size; i++) { + PyCodeObject* code = + (PyCodeObject*)PyTuple_GetItem(code_tuple, i); // NOLINT + code_status.add_with_graph_code(code); + } + return Py_None; +} + PyObject* no_skip_codes(PyObject* code_tuple) { auto& skip_info = SkipCodeInfo::Instance(); Py_ssize_t size = PyTuple_GET_SIZE(code_tuple); diff --git a/paddle/fluid/pybind/sot/skip_files.h b/paddle/fluid/pybind/sot/eval_frame_tools.h similarity index 87% rename from paddle/fluid/pybind/sot/skip_files.h rename to paddle/fluid/pybind/sot/eval_frame_tools.h index 8a8e0001e8d5b9..417a4a5ed89777 100644 --- a/paddle/fluid/pybind/sot/skip_files.h +++ b/paddle/fluid/pybind/sot/eval_frame_tools.h @@ -25,6 +25,10 @@ extern "C" { #if SOT_IS_SUPPORTED int need_skip(FrameObject* frame); +int is_code_without_graph(PyCodeObject* code); + +PyObject* set_with_graph(PyObject* code); +PyObject* setup_codes_with_graph(PyObject* code_tuple); PyObject* no_skip_codes(PyObject* code_tuple); PyObject* skip_file_prefix(PyObject* filepath_tuple); diff --git a/python/paddle/jit/sot/opcode_translator/skip_files.py b/python/paddle/jit/sot/opcode_translator/skip_files.py index 75bb6000e10ca4..bf84a4c32c5acf 100644 --- a/python/paddle/jit/sot/opcode_translator/skip_files.py +++ b/python/paddle/jit/sot/opcode_translator/skip_files.py @@ -134,7 +134,13 @@ def _module_dir(m: types.ModuleType): no_skip_code = {paddle.nn.Sequential.forward.__code__} +with_graph_codes = ( + paddle.nn.Layer.__call__.__code__, + paddle.nn.Layer._dygraph_call_func.__code__, +) + def setup_skip_files(): paddle.framework.core.eval_frame_skip_file_prefix(tuple(skip_file_names)) paddle.framework.core.eval_frame_no_skip_codes(tuple(no_skip_code)) + paddle.framework.core.sot_setup_codes_with_graph(with_graph_codes) diff --git a/python/paddle/jit/sot/symbolic/compile_cache.py b/python/paddle/jit/sot/symbolic/compile_cache.py index 09d86e54b81ad7..d8b5c3d00816fb 100644 --- a/python/paddle/jit/sot/symbolic/compile_cache.py +++ b/python/paddle/jit/sot/symbolic/compile_cache.py @@ -31,6 +31,7 @@ InfoCollector, NewSymbolHitRateInfo, Singleton, + StepInfoManager, SubGraphRelationInfo, log, log_do, @@ -218,6 +219,9 @@ def collect_subgraph_relation(self, inputs, outputs, partial_program_layer): def __call__(self, *args, **kwargs): with EventGuard(f"FallbackWrapper: {self.SIR.name}"): + if StepInfoManager().need_back_trace: + trace_back_frames() + log_do( 2, lambda: print("[FallbackWrapper] start run SIR: \n", self.SIR), diff --git a/python/paddle/jit/sot/translate.py b/python/paddle/jit/sot/translate.py index b0f91874f4979e..958526346a3b0f 100644 --- a/python/paddle/jit/sot/translate.py +++ b/python/paddle/jit/sot/translate.py @@ -25,6 +25,8 @@ from .utils import ( GraphLogger, InfoCollector, + StepInfoManager, + StepState, log_do, ) @@ -94,23 +96,42 @@ def symbolic_translate(fn: Callable[P, R], **kwargs) -> Callable[P, R]: def callback(frame): return eval_frame_callback(frame, **kwargs) + def impl_sot(*args: P.args, **kwargs: P.kwargs) -> R: + assert hasattr( + fn, "__code__" + ), "Target function doesn't have code for simulating." + StepInfoManager().sot_step() + GraphLogger().clear() + InfoCollector().clear_step_info() + paddle.framework.core.set_eval_frame(callback) + try: + outs = fn(*args, **kwargs) + except Exception as e: + raise e + finally: + paddle.framework.core.set_eval_frame(None) + + log_do(1, lambda: GraphLogger().print_info()) + InfoCollector().print_step_report() + return outs + + def impl_dynamic(*args: P.args, **kwargs: P.kwargs) -> R: + outs = fn(*args, **kwargs) + return outs + def impl(*args: P.args, **kwargs: P.kwargs) -> R: - with SotStepProfilerGuard(): - assert hasattr( - fn, "__code__" - ), "Target function doesn't have code for simulating." - GraphLogger().clear() - InfoCollector().clear_step_info() - paddle.framework.core.set_eval_frame(callback) - try: - outs = fn(*args, **kwargs) - except Exception as e: - raise e - finally: - paddle.framework.core.set_eval_frame(None) - - log_do(1, lambda: GraphLogger().print_info()) - InfoCollector().print_step_report() - return outs + with StepInfoManager().step_guard(fn.__code__), SotStepProfilerGuard(): + state = StepInfoManager().current_state + + if state == StepState.RUN_SOT: + return impl_sot(*args, **kwargs) + elif state == StepState.RUN_DYN: + return impl_dynamic(*args, **kwargs) + elif state == StepState.COLLECT_INFO: + return StepInfoManager().collect_info( + impl_dynamic, impl_sot, *args, **kwargs + ) + else: + raise RuntimeError("Unknown state.") return impl diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py index 10127032ab3370..4d75c8e2574e4d 100644 --- a/python/paddle/jit/sot/utils/__init__.py +++ b/python/paddle/jit/sot/utils/__init__.py @@ -63,6 +63,7 @@ ResumeFnNameFactory, Singleton, SotUndefinedVar, + StepInfoManager, StepState, count_if, current_symbol_registry, diff --git a/python/paddle/jit/sot/utils/utils.py b/python/paddle/jit/sot/utils/utils.py index 4ccf598dc773f8..6865427d61c692 100644 --- a/python/paddle/jit/sot/utils/utils.py +++ b/python/paddle/jit/sot/utils/utils.py @@ -22,13 +22,17 @@ import weakref from collections import OrderedDict from contextlib import contextmanager +from enum import Enum from typing import TYPE_CHECKING, Any, Callable, TypeVar from weakref import WeakValueDictionary +import numpy as np + import paddle from paddle.utils import flatten, map_structure from .envs import ( + ENV_COST_MODEL, ENV_SOT_LOG_LEVEL, ENV_STRICT_MODE, ) @@ -400,6 +404,134 @@ def printable(obj): return False +class StepState(Enum): + COLLECT_INFO = 1 + RUN_SOT = 2 + RUN_DYN = 3 + + +class StepInfo: + REQUIRED_DYN_INFOS = 10 + REQUIRED_SOT_INFOS = 10 + + USED_DYN_INFOS = 5 + + COLLECT_INFO_MAX_STEP = 50 + CV_BOUNDARY = 0.1 + + BACK_TRACE_STEPS = 20 + + def __init__(self): + self.step_count = -1 + self.state = ( + StepState.COLLECT_INFO + if ENV_COST_MODEL.get() + else StepState.RUN_SOT + ) + self.dyn_time_costs = [] + self.avg_dyn_time = 0 + self.sot_time_costs = [] + self.sot_step = -1 + + def add_dynamic_time_info(self, time_cost): + self.dyn_time_costs.append(time_cost) + if len(self.dyn_time_costs) == self.REQUIRED_DYN_INFOS: + self.avg_dyn_time = np.mean( + self.dyn_time_costs[-self.USED_DYN_INFOS :] + ) + + def add_sot_time_info(self, time_cost, current_code): + self.sot_time_costs.append(time_cost) + if len(self.sot_time_costs) == self.REQUIRED_SOT_INFOS: + avg_sot_time = np.mean(self.sot_time_costs) + log( + 1, + f"[Cost Model] sot: {avg_sot_time}, dyn: {self.avg_dyn_time}\n", + ) + if avg_sot_time < self.avg_dyn_time: + log(1, f"[Cost Model] Switch to RUN_SOT: {current_code} \n") + self.state = StepState.RUN_SOT + elif ( + self.step_count > self.COLLECT_INFO_MAX_STEP + or np.std(self.sot_time_costs) / avg_sot_time < self.CV_BOUNDARY + ): + log(1, f"[Cost Model] Switch to RUN_DYN: {current_code}\n") + self.state = StepState.RUN_DYN + else: + log(1, f"[Cost Model] Decision delayed: {current_code}\n") + self.sot_time_costs.clear() + + def need_back_trace(self): + return self.step_count < self.BACK_TRACE_STEPS + + def need_dynamic_info(self): + return len(self.dyn_time_costs) < self.REQUIRED_DYN_INFOS + + +class StepInfoManager(metaclass=Singleton): + def __init__(self): + self.step_record = {} + self.current_code = None + self.current_step_info = None + + @contextmanager + def step_guard(self, code): + try: + old_code = self.current_code + old_info = self.current_step_info + + self.current_code = code + if code not in self.step_record: + self.step_record[code] = StepInfo() + self.current_step_info = self.step_record[code] + + self.current_step_info.step_count += 1 + + log( + 2, + f"[Cost Model] New step start, current state is {self.current_state}\n", + ) + yield + finally: + self.current_code = old_code + self.current_step_info = old_info + + def sot_step(self): + self.current_step_info.sot_step += 1 + + def collect_info(self, impl_dynamic, impl_sot, /, *args, **kwargs): + if self.current_step_info.need_dynamic_info(): + start_time = time.perf_counter() + outs = impl_dynamic(*args, **kwargs) + time_cost = time.perf_counter() - start_time + self.current_step_info.add_dynamic_time_info(time_cost) + else: + start_time = time.perf_counter() + outs = impl_sot(*args, **kwargs) + time_cost = time.perf_counter() - start_time + self.current_step_info.add_sot_time_info( + time_cost, self.current_code + ) + return outs + + @property + def need_back_trace(self): + return self.current_step_info.need_back_trace() + + @property + def current_step(self): + return self.current_step_info.step_count + + @property + def current_state(self): + return self.current_step_info.state + + def clear(self): + self.step_record.clear() + self.current_code = None + self.current_step = -1 + + def get_api_fullname(api): api_name = api.__name__ module_str = api.__module__ diff --git a/test/sot/test_sot_cost_model.py b/test/sot/test_sot_cost_model.py new file mode 100644 index 00000000000000..eed690a1e77815 --- /dev/null +++ b/test/sot/test_sot_cost_model.py @@ -0,0 +1,114 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot import psdb, symbolic_translate +from paddle.jit.sot.utils import StepInfoManager, StepState, cost_model_guard + + +def dyn_fast(x, net, iter_): + for i in iter_: + x = net(x) + return x + + +def sot_fast_with_single_graph(x, net): + if not psdb.in_sot(): + time.sleep(0.1) + return x + 1 + + +def sot_fast_with_multi_graph(x, net): + if not psdb.in_sot(): + time.sleep(0.1) + x = x + 1 + psdb.breakgraph() + x = x + 2 + return x + + +class Net(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.linear = paddle.nn.Linear(10, 10) + + def forward(self, x): + if not psdb.in_sot(): + time.sleep(0.1) + x = x / 3 + x = x + 5 + x = self.linear(x) + return x + + +class TestCostModel(TestCaseBase): + @cost_model_guard(True) + def test_dyn_fast(self): + x = paddle.rand([10]) + net = paddle.nn.Linear(10, 10) + sot_fn = symbolic_translate(dyn_fast) + for i in range(60): + sot_fn(x, net, iter(range(10))) + + state = StepInfoManager().step_record[dyn_fast.__code__].state + assert state == StepState.RUN_DYN + + @cost_model_guard(True) + def test_sot_fast_with_multi_graph(self): + x = paddle.rand([10]) + net = paddle.nn.Linear(10, 10) + sot_fn = symbolic_translate(sot_fast_with_multi_graph) + for i in range(30): + sot_fn(x, net) + + state = ( + StepInfoManager() + .step_record[sot_fast_with_multi_graph.__code__] + .state + ) + assert state == StepState.RUN_SOT + + @cost_model_guard(True) + def test_sot_fast_with_single_graph(self): + x = paddle.rand([10]) + net = paddle.nn.Linear(10, 10) + for i in range(30): + symbolic_translate(sot_fast_with_single_graph)(x, net) + + state = ( + StepInfoManager() + .step_record[sot_fast_with_single_graph.__code__] + .state + ) + assert state == StepState.RUN_SOT + + @cost_model_guard(True) + def test_net(self): + x = paddle.rand([10]) + net = Net() + net = paddle.jit.to_static(net, full_graph=False) + for i in range(30): + x = net(x) + + state = StepInfoManager().step_record[Net.forward.__code__].state + assert state == StepState.RUN_SOT + + +if __name__ == "__main__": + unittest.main() From 2e34028d63b6159dba464ef71caf268ade178f76 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 27 Feb 2025 14:33:26 +0800 Subject: [PATCH 11/15] cleanup cost model --- python/paddle/jit/sot/translate.py | 52 +++++---------- python/paddle/jit/sot/utils/__init__.py | 1 - python/paddle/jit/sot/utils/envs.py | 7 -- python/paddle/jit/sot/utils/utils.py | 85 ------------------------- 4 files changed, 16 insertions(+), 129 deletions(-) diff --git a/python/paddle/jit/sot/translate.py b/python/paddle/jit/sot/translate.py index 958526346a3b0f..2fdb61b6f3f211 100644 --- a/python/paddle/jit/sot/translate.py +++ b/python/paddle/jit/sot/translate.py @@ -26,7 +26,6 @@ GraphLogger, InfoCollector, StepInfoManager, - StepState, log_do, ) @@ -96,42 +95,23 @@ def symbolic_translate(fn: Callable[P, R], **kwargs) -> Callable[P, R]: def callback(frame): return eval_frame_callback(frame, **kwargs) - def impl_sot(*args: P.args, **kwargs: P.kwargs) -> R: - assert hasattr( - fn, "__code__" - ), "Target function doesn't have code for simulating." - StepInfoManager().sot_step() - GraphLogger().clear() - InfoCollector().clear_step_info() - paddle.framework.core.set_eval_frame(callback) - try: - outs = fn(*args, **kwargs) - except Exception as e: - raise e - finally: - paddle.framework.core.set_eval_frame(None) - - log_do(1, lambda: GraphLogger().print_info()) - InfoCollector().print_step_report() - return outs - - def impl_dynamic(*args: P.args, **kwargs: P.kwargs) -> R: - outs = fn(*args, **kwargs) - return outs - def impl(*args: P.args, **kwargs: P.kwargs) -> R: with StepInfoManager().step_guard(fn.__code__), SotStepProfilerGuard(): - state = StepInfoManager().current_state - - if state == StepState.RUN_SOT: - return impl_sot(*args, **kwargs) - elif state == StepState.RUN_DYN: - return impl_dynamic(*args, **kwargs) - elif state == StepState.COLLECT_INFO: - return StepInfoManager().collect_info( - impl_dynamic, impl_sot, *args, **kwargs - ) - else: - raise RuntimeError("Unknown state.") + assert hasattr( + fn, "__code__" + ), "Target function doesn't have code for simulating." + GraphLogger().clear() + InfoCollector().clear_step_info() + paddle.framework.core.set_eval_frame(callback) + try: + outs = fn(*args, **kwargs) + except Exception as e: + raise e + finally: + paddle.framework.core.set_eval_frame(None) + + log_do(1, lambda: GraphLogger().print_info()) + InfoCollector().print_step_report() + return outs return impl diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py index 4d75c8e2574e4d..d8c23f1b499395 100644 --- a/python/paddle/jit/sot/utils/__init__.py +++ b/python/paddle/jit/sot/utils/__init__.py @@ -14,7 +14,6 @@ from .call_ast_utils import get_static_function, try_ast_func # noqa: F401 from .envs import ( # noqa: F401 - ENV_COST_MODEL, ENV_MIN_GRAPH_SIZE, ENV_SOT_ALLOW_DYNAMIC_SHAPE, ENV_SOT_ENABLE_FASTER_GUARD, diff --git a/python/paddle/jit/sot/utils/envs.py b/python/paddle/jit/sot/utils/envs.py index 94c9fd27586eb0..0af7e46fd86c9d 100644 --- a/python/paddle/jit/sot/utils/envs.py +++ b/python/paddle/jit/sot/utils/envs.py @@ -25,7 +25,6 @@ StringListEnvironmentVariable, ) -ENV_COST_MODEL = BooleanEnvironmentVariable("COST_MODEL", False) ENV_MIN_GRAPH_SIZE = IntegerEnvironmentVariable("MIN_GRAPH_SIZE", 10) ENV_SOT_LOG_LEVEL = IntegerEnvironmentVariable("SOT_LOG_LEVEL", 0) ENV_STRICT_MODE = BooleanEnvironmentVariable("STRICT_MODE", False) @@ -59,12 +58,6 @@ ) -@contextmanager -def cost_model_guard(value: bool): - with EnvironmentVariableGuard(ENV_COST_MODEL, value): - yield - - @contextmanager def strict_mode_guard(value: bool): with EnvironmentVariableGuard(ENV_STRICT_MODE, value): diff --git a/python/paddle/jit/sot/utils/utils.py b/python/paddle/jit/sot/utils/utils.py index 6865427d61c692..2d12081fc269e1 100644 --- a/python/paddle/jit/sot/utils/utils.py +++ b/python/paddle/jit/sot/utils/utils.py @@ -22,17 +22,13 @@ import weakref from collections import OrderedDict from contextlib import contextmanager -from enum import Enum from typing import TYPE_CHECKING, Any, Callable, TypeVar from weakref import WeakValueDictionary -import numpy as np - import paddle from paddle.utils import flatten, map_structure from .envs import ( - ENV_COST_MODEL, ENV_SOT_LOG_LEVEL, ENV_STRICT_MODE, ) @@ -404,69 +400,15 @@ def printable(obj): return False -class StepState(Enum): - COLLECT_INFO = 1 - RUN_SOT = 2 - RUN_DYN = 3 - - class StepInfo: - REQUIRED_DYN_INFOS = 10 - REQUIRED_SOT_INFOS = 10 - - USED_DYN_INFOS = 5 - - COLLECT_INFO_MAX_STEP = 50 - CV_BOUNDARY = 0.1 - BACK_TRACE_STEPS = 20 def __init__(self): self.step_count = -1 - self.state = ( - StepState.COLLECT_INFO - if ENV_COST_MODEL.get() - else StepState.RUN_SOT - ) - self.dyn_time_costs = [] - self.avg_dyn_time = 0 - self.sot_time_costs = [] - self.sot_step = -1 - - def add_dynamic_time_info(self, time_cost): - self.dyn_time_costs.append(time_cost) - if len(self.dyn_time_costs) == self.REQUIRED_DYN_INFOS: - self.avg_dyn_time = np.mean( - self.dyn_time_costs[-self.USED_DYN_INFOS :] - ) - - def add_sot_time_info(self, time_cost, current_code): - self.sot_time_costs.append(time_cost) - if len(self.sot_time_costs) == self.REQUIRED_SOT_INFOS: - avg_sot_time = np.mean(self.sot_time_costs) - log( - 1, - f"[Cost Model] sot: {avg_sot_time}, dyn: {self.avg_dyn_time}\n", - ) - if avg_sot_time < self.avg_dyn_time: - log(1, f"[Cost Model] Switch to RUN_SOT: {current_code} \n") - self.state = StepState.RUN_SOT - elif ( - self.step_count > self.COLLECT_INFO_MAX_STEP - or np.std(self.sot_time_costs) / avg_sot_time < self.CV_BOUNDARY - ): - log(1, f"[Cost Model] Switch to RUN_DYN: {current_code}\n") - self.state = StepState.RUN_DYN - else: - log(1, f"[Cost Model] Decision delayed: {current_code}\n") - self.sot_time_costs.clear() def need_back_trace(self): return self.step_count < self.BACK_TRACE_STEPS - def need_dynamic_info(self): - return len(self.dyn_time_costs) < self.REQUIRED_DYN_INFOS - class StepInfoManager(metaclass=Singleton): def __init__(self): @@ -486,34 +428,11 @@ def step_guard(self, code): self.current_step_info = self.step_record[code] self.current_step_info.step_count += 1 - - log( - 2, - f"[Cost Model] New step start, current state is {self.current_state}\n", - ) yield finally: self.current_code = old_code self.current_step_info = old_info - def sot_step(self): - self.current_step_info.sot_step += 1 - - def collect_info(self, impl_dynamic, impl_sot, /, *args, **kwargs): - if self.current_step_info.need_dynamic_info(): - start_time = time.perf_counter() - outs = impl_dynamic(*args, **kwargs) - time_cost = time.perf_counter() - start_time - self.current_step_info.add_dynamic_time_info(time_cost) - else: - start_time = time.perf_counter() - outs = impl_sot(*args, **kwargs) - time_cost = time.perf_counter() - start_time - self.current_step_info.add_sot_time_info( - time_cost, self.current_code - ) - return outs - @property def need_back_trace(self): return self.current_step_info.need_back_trace() @@ -522,10 +441,6 @@ def need_back_trace(self): def current_step(self): return self.current_step_info.step_count - @property - def current_state(self): - return self.current_step_info.state - def clear(self): self.step_record.clear() self.current_code = None From a8020d6694d10f141b4c33fc5ed45ff0bb4eaa28 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 27 Feb 2025 16:16:45 +0800 Subject: [PATCH 12/15] remove ut --- test/sot/test_sot_cost_model.py | 114 -------------------------------- 1 file changed, 114 deletions(-) delete mode 100644 test/sot/test_sot_cost_model.py diff --git a/test/sot/test_sot_cost_model.py b/test/sot/test_sot_cost_model.py deleted file mode 100644 index eed690a1e77815..00000000000000 --- a/test/sot/test_sot_cost_model.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -import unittest - -from test_case_base import TestCaseBase - -import paddle -from paddle.jit.sot import psdb, symbolic_translate -from paddle.jit.sot.utils import StepInfoManager, StepState, cost_model_guard - - -def dyn_fast(x, net, iter_): - for i in iter_: - x = net(x) - return x - - -def sot_fast_with_single_graph(x, net): - if not psdb.in_sot(): - time.sleep(0.1) - return x + 1 - - -def sot_fast_with_multi_graph(x, net): - if not psdb.in_sot(): - time.sleep(0.1) - x = x + 1 - psdb.breakgraph() - x = x + 2 - return x - - -class Net(paddle.nn.Layer): - def __init__(self): - super().__init__() - self.linear = paddle.nn.Linear(10, 10) - - def forward(self, x): - if not psdb.in_sot(): - time.sleep(0.1) - x = x / 3 - x = x + 5 - x = self.linear(x) - return x - - -class TestCostModel(TestCaseBase): - @cost_model_guard(True) - def test_dyn_fast(self): - x = paddle.rand([10]) - net = paddle.nn.Linear(10, 10) - sot_fn = symbolic_translate(dyn_fast) - for i in range(60): - sot_fn(x, net, iter(range(10))) - - state = StepInfoManager().step_record[dyn_fast.__code__].state - assert state == StepState.RUN_DYN - - @cost_model_guard(True) - def test_sot_fast_with_multi_graph(self): - x = paddle.rand([10]) - net = paddle.nn.Linear(10, 10) - sot_fn = symbolic_translate(sot_fast_with_multi_graph) - for i in range(30): - sot_fn(x, net) - - state = ( - StepInfoManager() - .step_record[sot_fast_with_multi_graph.__code__] - .state - ) - assert state == StepState.RUN_SOT - - @cost_model_guard(True) - def test_sot_fast_with_single_graph(self): - x = paddle.rand([10]) - net = paddle.nn.Linear(10, 10) - for i in range(30): - symbolic_translate(sot_fast_with_single_graph)(x, net) - - state = ( - StepInfoManager() - .step_record[sot_fast_with_single_graph.__code__] - .state - ) - assert state == StepState.RUN_SOT - - @cost_model_guard(True) - def test_net(self): - x = paddle.rand([10]) - net = Net() - net = paddle.jit.to_static(net, full_graph=False) - for i in range(30): - x = net(x) - - state = StepInfoManager().step_record[Net.forward.__code__].state - assert state == StepState.RUN_SOT - - -if __name__ == "__main__": - unittest.main() From 5fd03d4a5200767281c176182a4bc74bced80be9 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 27 Feb 2025 16:32:19 +0800 Subject: [PATCH 13/15] fix implicit conflict --- python/paddle/jit/sot/utils/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/jit/sot/utils/utils.py b/python/paddle/jit/sot/utils/utils.py index e061a98ced3f2d..6cbc05582afcc8 100644 --- a/python/paddle/jit/sot/utils/utils.py +++ b/python/paddle/jit/sot/utils/utils.py @@ -25,6 +25,8 @@ from typing import TYPE_CHECKING, Any, Callable, TypeVar from weakref import WeakValueDictionary +import numpy as np + import paddle from paddle.utils import flatten, map_structure From 6dd27c155281e67f629753a604ed1e291a6a15fc Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 27 Feb 2025 18:30:17 +0800 Subject: [PATCH 14/15] cleanup cost model guard --- python/paddle/jit/sot/utils/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py index ba7322255f6030..6ff0c9651e275f 100644 --- a/python/paddle/jit/sot/utils/__init__.py +++ b/python/paddle/jit/sot/utils/__init__.py @@ -24,7 +24,6 @@ ENV_SOT_WITH_CONTROL_FLOW, ENV_STRICT_MODE, allow_dynamic_shape_guard, - cost_model_guard, export_guard, faster_guard_guard, guard_tree_guard, From ecc92fe6975cc1a70f0c6d9e8a36a97103ceb21b Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 27 Feb 2025 19:08:13 +0800 Subject: [PATCH 15/15] cleanup StepState --- python/paddle/jit/sot/utils/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py index 6ff0c9651e275f..3ceccc605ec09c 100644 --- a/python/paddle/jit/sot/utils/__init__.py +++ b/python/paddle/jit/sot/utils/__init__.py @@ -62,7 +62,6 @@ Singleton, SotUndefinedVar, StepInfoManager, - StepState, count_if, current_symbol_registry, execute_time,