Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT] Cleanup legacy flags and policies #71279

Merged
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
if TYPE_CHECKING:
from .function_graph import CompileGraphResult, FunctionGraph

SUPPORT_COMPARE_OP = {
COMPARE_OP_NAME_TO_FN = {
">": operator.gt,
"<": operator.lt,
">=": operator.ge,
Expand Down Expand Up @@ -1347,7 +1347,7 @@ def COMPARE_OP(self, instr: Instruction):
right, left = self.stack.pop(), self.stack.pop()
self.stack.push(
BuiltinVariable(
SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker()
COMPARE_OP_NAME_TO_FN[op], self._graph, DanglingTracker()
)(left, right)
)

Expand All @@ -1366,7 +1366,7 @@ def IS_OP(self, instr: Instruction):
op = "is" if instr.arg == 0 else "is not"
self.stack.push(
BuiltinVariable(
SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker()
COMPARE_OP_NAME_TO_FN[op], self._graph, DanglingTracker()
)(left, right)
)

Expand Down Expand Up @@ -1583,7 +1583,7 @@ def CONTAINS_OP(self, instr: Instruction):
op = "in" if instr.arg == 0 else "not in"
self.stack.push(
BuiltinVariable(
SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker()
COMPARE_OP_NAME_TO_FN[op], self._graph, DanglingTracker()
)(left, right)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
FallbackError,
InnerError,
ResumeFnNameFactory,
is_clean_code,
list_contain_by_id,
list_find_index_by_id,
no_eval_frame,
Expand Down Expand Up @@ -516,8 +515,6 @@ def gen_disable_eval_frame(self):
"""
Generates instructions to disable the evaluation frame.
"""
if is_clean_code():
return
self.gen_load_object(
paddle.framework.core.set_eval_frame, "paddle_set_eval_frame_fn"
)
Expand All @@ -529,8 +526,6 @@ def gen_enable_eval_frame(self):
"""
Generates instructions to enable the evaluation frame.
"""
if is_clean_code():
return
self.gen_load_object(
paddle.framework.core.set_eval_frame, "paddle_set_eval_frame_fn"
)
Expand Down
52 changes: 16 additions & 36 deletions python/paddle/jit/sot/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
GraphLogger,
InfoCollector,
StepInfoManager,
StepState,
log_do,
)

Expand Down Expand Up @@ -96,42 +95,23 @@ def symbolic_translate(fn: Callable[P, R], **kwargs) -> Callable[P, R]:
def callback(frame):
return eval_frame_callback(frame, **kwargs)

def impl_sot(*args: P.args, **kwargs: P.kwargs) -> R:
assert hasattr(
fn, "__code__"
), "Target function doesn't have code for simulating."
StepInfoManager().sot_step()
GraphLogger().clear()
InfoCollector().clear_step_info()
paddle.framework.core.set_eval_frame(callback)
try:
outs = fn(*args, **kwargs)
except Exception as e:
raise e
finally:
paddle.framework.core.set_eval_frame(None)

log_do(1, lambda: GraphLogger().print_info())
InfoCollector().print_step_report()
return outs

def impl_dynamic(*args: P.args, **kwargs: P.kwargs) -> R:
outs = fn(*args, **kwargs)
return outs

def impl(*args: P.args, **kwargs: P.kwargs) -> R:
with StepInfoManager().step_guard(fn.__code__), SotStepProfilerGuard():
state = StepInfoManager().current_state

if state == StepState.RUN_SOT:
return impl_sot(*args, **kwargs)
elif state == StepState.RUN_DYN:
return impl_dynamic(*args, **kwargs)
elif state == StepState.COLLECT_INFO:
return StepInfoManager().collect_info(
impl_dynamic, impl_sot, *args, **kwargs
)
else:
raise RuntimeError("Unknown state.")
assert hasattr(
fn, "__code__"
), "Target function doesn't have code for simulating."
GraphLogger().clear()
InfoCollector().clear_step_info()
paddle.framework.core.set_eval_frame(callback)
try:
outs = fn(*args, **kwargs)
except Exception as e:
raise e
finally:
paddle.framework.core.set_eval_frame(None)

log_do(1, lambda: GraphLogger().print_info())
InfoCollector().print_step_report()
return outs

return impl
5 changes: 0 additions & 5 deletions python/paddle/jit/sot/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

from .call_ast_utils import get_static_function, try_ast_func # noqa: F401
from .envs import ( # noqa: F401
ENV_CLEAN_CODE,
ENV_COST_MODEL,
ENV_MIN_GRAPH_SIZE,
ENV_SOT_ALLOW_DYNAMIC_SHAPE,
ENV_SOT_ENABLE_FASTER_GUARD,
Expand All @@ -26,7 +24,6 @@
ENV_SOT_WITH_CONTROL_FLOW,
ENV_STRICT_MODE,
allow_dynamic_shape_guard,
cost_model_guard,
export_guard,
faster_guard_guard,
guard_tree_guard,
Expand Down Expand Up @@ -65,7 +62,6 @@
Singleton,
SotUndefinedVar,
StepInfoManager,
StepState,
count_if,
current_symbol_registry,
execute_time,
Expand All @@ -78,7 +74,6 @@
in_paddle_module,
is_break_graph_api,
is_builtin_fn,
is_clean_code,
is_comprehensive_name,
is_paddle_api,
is_strict_mode,
Expand Down
8 changes: 0 additions & 8 deletions python/paddle/jit/sot/utils/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@
StringListEnvironmentVariable,
)

ENV_COST_MODEL = BooleanEnvironmentVariable("COST_MODEL", False)
ENV_MIN_GRAPH_SIZE = IntegerEnvironmentVariable("MIN_GRAPH_SIZE", 10)
ENV_SOT_LOG_LEVEL = IntegerEnvironmentVariable("SOT_LOG_LEVEL", 0)
ENV_STRICT_MODE = BooleanEnvironmentVariable("STRICT_MODE", False)
ENV_CLEAN_CODE = BooleanEnvironmentVariable("CLEAN_CODE", False)
ENV_SOT_WITH_CONTROL_FLOW = BooleanEnvironmentVariable(
"SOT_WITH_CONTROL_FLOW", True
)
Expand Down Expand Up @@ -60,12 +58,6 @@
)


@contextmanager
def cost_model_guard(value: bool):
with EnvironmentVariableGuard(ENV_COST_MODEL, value):
yield


@contextmanager
def strict_mode_guard(value: bool):
with EnvironmentVariableGuard(ENV_STRICT_MODE, value):
Expand Down
88 changes: 0 additions & 88 deletions python/paddle/jit/sot/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import weakref
from collections import OrderedDict
from contextlib import contextmanager
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from weakref import WeakValueDictionary

Expand All @@ -32,8 +31,6 @@
from paddle.utils import flatten, map_structure

from .envs import (
ENV_CLEAN_CODE,
ENV_COST_MODEL,
ENV_SOT_LOG_LEVEL,
ENV_STRICT_MODE,
)
Expand Down Expand Up @@ -319,10 +316,6 @@ def is_strict_mode():
return ENV_STRICT_MODE.get()


def is_clean_code() -> bool:
return ENV_CLEAN_CODE.get()


def list_find_index_by_id(li: list[Any], item: Any) -> int:
return [id(it) for it in li].index(id(item))

Expand Down Expand Up @@ -409,69 +402,15 @@ def printable(obj):
return False


class StepState(Enum):
COLLECT_INFO = 1
RUN_SOT = 2
RUN_DYN = 3


class StepInfo:
REQUIRED_DYN_INFOS = 10
REQUIRED_SOT_INFOS = 10

USED_DYN_INFOS = 5

COLLECT_INFO_MAX_STEP = 50
CV_BOUNDARY = 0.1

BACK_TRACE_STEPS = 20

def __init__(self):
self.step_count = -1
self.state = (
StepState.COLLECT_INFO
if ENV_COST_MODEL.get()
else StepState.RUN_SOT
)
self.dyn_time_costs = []
self.avg_dyn_time = 0
self.sot_time_costs = []
self.sot_step = -1

def add_dynamic_time_info(self, time_cost):
self.dyn_time_costs.append(time_cost)
if len(self.dyn_time_costs) == self.REQUIRED_DYN_INFOS:
self.avg_dyn_time = np.mean(
self.dyn_time_costs[-self.USED_DYN_INFOS :]
)

def add_sot_time_info(self, time_cost, current_code):
self.sot_time_costs.append(time_cost)
if len(self.sot_time_costs) == self.REQUIRED_SOT_INFOS:
avg_sot_time = np.mean(self.sot_time_costs)
log(
1,
f"[Cost Model] sot: {avg_sot_time}, dyn: {self.avg_dyn_time}\n",
)
if avg_sot_time < self.avg_dyn_time:
log(1, f"[Cost Model] Switch to RUN_SOT: {current_code} \n")
self.state = StepState.RUN_SOT
elif (
self.step_count > self.COLLECT_INFO_MAX_STEP
or np.std(self.sot_time_costs) / avg_sot_time < self.CV_BOUNDARY
):
log(1, f"[Cost Model] Switch to RUN_DYN: {current_code}\n")
self.state = StepState.RUN_DYN
else:
log(1, f"[Cost Model] Decision delayed: {current_code}\n")
self.sot_time_costs.clear()

def need_back_trace(self):
return self.step_count < self.BACK_TRACE_STEPS

def need_dynamic_info(self):
return len(self.dyn_time_costs) < self.REQUIRED_DYN_INFOS


class StepInfoManager(metaclass=Singleton):
def __init__(self):
Expand All @@ -491,34 +430,11 @@ def step_guard(self, code):
self.current_step_info = self.step_record[code]

self.current_step_info.step_count += 1

log(
2,
f"[Cost Model] New step start, current state is {self.current_state}\n",
)
yield
finally:
self.current_code = old_code
self.current_step_info = old_info

def sot_step(self):
self.current_step_info.sot_step += 1

def collect_info(self, impl_dynamic, impl_sot, /, *args, **kwargs):
if self.current_step_info.need_dynamic_info():
start_time = time.perf_counter()
outs = impl_dynamic(*args, **kwargs)
time_cost = time.perf_counter() - start_time
self.current_step_info.add_dynamic_time_info(time_cost)
else:
start_time = time.perf_counter()
outs = impl_sot(*args, **kwargs)
time_cost = time.perf_counter() - start_time
self.current_step_info.add_sot_time_info(
time_cost, self.current_code
)
return outs

@property
def need_back_trace(self):
return self.current_step_info.need_back_trace()
Expand All @@ -527,10 +443,6 @@ def need_back_trace(self):
def current_step(self):
return self.current_step_info.step_count

@property
def current_state(self):
return self.current_step_info.state

def clear(self):
self.step_record.clear()
self.current_code = None
Expand Down
Loading