Skip to content

Commit

Permalink
[SOT] Cleanup legacy flags and policies (PaddlePaddle#71279)
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo authored and Enigmatisms committed Mar 5, 2025
1 parent f66b093 commit 37e121e
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 260 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@
if TYPE_CHECKING:
from .function_graph import CompileGraphResult, FunctionGraph

SUPPORT_COMPARE_OP = {
COMPARE_OP_NAME_TO_FN = {
">": operator.gt,
"<": operator.lt,
">=": operator.ge,
Expand Down Expand Up @@ -1352,7 +1352,7 @@ def COMPARE_OP(self, instr: Instruction):
right, left = self.stack.pop(), self.stack.pop()
self.stack.push(
BuiltinVariable(
SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker()
COMPARE_OP_NAME_TO_FN[op], self._graph, DanglingTracker()
)(left, right)
)

Expand All @@ -1371,7 +1371,7 @@ def IS_OP(self, instr: Instruction):
op = "is" if instr.arg == 0 else "is not"
self.stack.push(
BuiltinVariable(
SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker()
COMPARE_OP_NAME_TO_FN[op], self._graph, DanglingTracker()
)(left, right)
)

Expand Down Expand Up @@ -1588,7 +1588,7 @@ def CONTAINS_OP(self, instr: Instruction):
op = "in" if instr.arg == 0 else "not in"
self.stack.push(
BuiltinVariable(
SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker()
COMPARE_OP_NAME_TO_FN[op], self._graph, DanglingTracker()
)(left, right)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
FallbackError,
InnerError,
ResumeFnNameFactory,
is_clean_code,
list_contain_by_id,
list_find_index_by_id,
no_eval_frame,
Expand Down Expand Up @@ -516,8 +515,6 @@ def gen_disable_eval_frame(self):
"""
Generates instructions to disable the evaluation frame.
"""
if is_clean_code():
return
self.gen_load_object(
paddle.framework.core.set_eval_frame, "paddle_set_eval_frame_fn"
)
Expand All @@ -529,8 +526,6 @@ def gen_enable_eval_frame(self):
"""
Generates instructions to enable the evaluation frame.
"""
if is_clean_code():
return
self.gen_load_object(
paddle.framework.core.set_eval_frame, "paddle_set_eval_frame_fn"
)
Expand Down
52 changes: 16 additions & 36 deletions python/paddle/jit/sot/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
GraphLogger,
InfoCollector,
StepInfoManager,
StepState,
log_do,
)

Expand Down Expand Up @@ -96,42 +95,23 @@ def symbolic_translate(fn: Callable[P, R], **kwargs) -> Callable[P, R]:
def callback(frame):
return eval_frame_callback(frame, **kwargs)

def impl_sot(*args: P.args, **kwargs: P.kwargs) -> R:
assert hasattr(
fn, "__code__"
), "Target function doesn't have code for simulating."
StepInfoManager().sot_step()
GraphLogger().clear()
InfoCollector().clear_step_info()
paddle.framework.core.set_eval_frame(callback)
try:
outs = fn(*args, **kwargs)
except Exception as e:
raise e
finally:
paddle.framework.core.set_eval_frame(None)

log_do(1, lambda: GraphLogger().print_info())
InfoCollector().print_step_report()
return outs

def impl_dynamic(*args: P.args, **kwargs: P.kwargs) -> R:
outs = fn(*args, **kwargs)
return outs

def impl(*args: P.args, **kwargs: P.kwargs) -> R:
with StepInfoManager().step_guard(fn.__code__), SotStepProfilerGuard():
state = StepInfoManager().current_state

if state == StepState.RUN_SOT:
return impl_sot(*args, **kwargs)
elif state == StepState.RUN_DYN:
return impl_dynamic(*args, **kwargs)
elif state == StepState.COLLECT_INFO:
return StepInfoManager().collect_info(
impl_dynamic, impl_sot, *args, **kwargs
)
else:
raise RuntimeError("Unknown state.")
assert hasattr(
fn, "__code__"
), "Target function doesn't have code for simulating."
GraphLogger().clear()
InfoCollector().clear_step_info()
paddle.framework.core.set_eval_frame(callback)
try:
outs = fn(*args, **kwargs)
except Exception as e:
raise e
finally:
paddle.framework.core.set_eval_frame(None)

log_do(1, lambda: GraphLogger().print_info())
InfoCollector().print_step_report()
return outs

return impl
5 changes: 0 additions & 5 deletions python/paddle/jit/sot/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

from .call_ast_utils import get_static_function, try_ast_func # noqa: F401
from .envs import ( # noqa: F401
ENV_CLEAN_CODE,
ENV_COST_MODEL,
ENV_MIN_GRAPH_SIZE,
ENV_SOT_ALLOW_DYNAMIC_SHAPE,
ENV_SOT_ENABLE_FASTER_GUARD,
Expand All @@ -26,7 +24,6 @@
ENV_SOT_WITH_CONTROL_FLOW,
ENV_STRICT_MODE,
allow_dynamic_shape_guard,
cost_model_guard,
export_guard,
faster_guard_guard,
guard_tree_guard,
Expand Down Expand Up @@ -74,7 +71,6 @@
Singleton,
SotUndefinedVar,
StepInfoManager,
StepState,
count_if,
current_symbol_registry,
execute_time,
Expand All @@ -87,7 +83,6 @@
in_paddle_module,
is_break_graph_api,
is_builtin_fn,
is_clean_code,
is_comprehensive_name,
is_paddle_api,
is_strict_mode,
Expand Down
8 changes: 0 additions & 8 deletions python/paddle/jit/sot/utils/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@
StringListEnvironmentVariable,
)

ENV_COST_MODEL = BooleanEnvironmentVariable("COST_MODEL", False)
ENV_MIN_GRAPH_SIZE = IntegerEnvironmentVariable("MIN_GRAPH_SIZE", 10)
ENV_SOT_LOG_LEVEL = IntegerEnvironmentVariable("SOT_LOG_LEVEL", 0)
ENV_STRICT_MODE = BooleanEnvironmentVariable("STRICT_MODE", False)
ENV_CLEAN_CODE = BooleanEnvironmentVariable("CLEAN_CODE", False)
ENV_SOT_WITH_CONTROL_FLOW = BooleanEnvironmentVariable(
"SOT_WITH_CONTROL_FLOW", True
)
Expand Down Expand Up @@ -60,12 +58,6 @@
)


@contextmanager
def cost_model_guard(value: bool):
with EnvironmentVariableGuard(ENV_COST_MODEL, value):
yield


@contextmanager
def strict_mode_guard(value: bool):
with EnvironmentVariableGuard(ENV_STRICT_MODE, value):
Expand Down
88 changes: 0 additions & 88 deletions python/paddle/jit/sot/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import weakref
from collections import OrderedDict
from contextlib import contextmanager
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from weakref import WeakValueDictionary

Expand All @@ -32,8 +31,6 @@
from paddle.utils import flatten, map_structure

from .envs import (
ENV_CLEAN_CODE,
ENV_COST_MODEL,
ENV_SOT_LOG_LEVEL,
ENV_STRICT_MODE,
)
Expand Down Expand Up @@ -319,10 +316,6 @@ def is_strict_mode():
return ENV_STRICT_MODE.get()


def is_clean_code() -> bool:
return ENV_CLEAN_CODE.get()


def list_find_index_by_id(li: list[Any], item: Any) -> int:
return [id(it) for it in li].index(id(item))

Expand Down Expand Up @@ -409,69 +402,15 @@ def printable(obj):
return False


class StepState(Enum):
COLLECT_INFO = 1
RUN_SOT = 2
RUN_DYN = 3


class StepInfo:
REQUIRED_DYN_INFOS = 10
REQUIRED_SOT_INFOS = 10

USED_DYN_INFOS = 5

COLLECT_INFO_MAX_STEP = 50
CV_BOUNDARY = 0.1

BACK_TRACE_STEPS = 20

def __init__(self):
self.step_count = -1
self.state = (
StepState.COLLECT_INFO
if ENV_COST_MODEL.get()
else StepState.RUN_SOT
)
self.dyn_time_costs = []
self.avg_dyn_time = 0
self.sot_time_costs = []
self.sot_step = -1

def add_dynamic_time_info(self, time_cost):
self.dyn_time_costs.append(time_cost)
if len(self.dyn_time_costs) == self.REQUIRED_DYN_INFOS:
self.avg_dyn_time = np.mean(
self.dyn_time_costs[-self.USED_DYN_INFOS :]
)

def add_sot_time_info(self, time_cost, current_code):
self.sot_time_costs.append(time_cost)
if len(self.sot_time_costs) == self.REQUIRED_SOT_INFOS:
avg_sot_time = np.mean(self.sot_time_costs)
log(
1,
f"[Cost Model] sot: {avg_sot_time}, dyn: {self.avg_dyn_time}\n",
)
if avg_sot_time < self.avg_dyn_time:
log(1, f"[Cost Model] Switch to RUN_SOT: {current_code} \n")
self.state = StepState.RUN_SOT
elif (
self.step_count > self.COLLECT_INFO_MAX_STEP
or np.std(self.sot_time_costs) / avg_sot_time < self.CV_BOUNDARY
):
log(1, f"[Cost Model] Switch to RUN_DYN: {current_code}\n")
self.state = StepState.RUN_DYN
else:
log(1, f"[Cost Model] Decision delayed: {current_code}\n")
self.sot_time_costs.clear()

def need_back_trace(self):
return self.step_count < self.BACK_TRACE_STEPS

def need_dynamic_info(self):
return len(self.dyn_time_costs) < self.REQUIRED_DYN_INFOS


class StepInfoManager(metaclass=Singleton):
def __init__(self):
Expand All @@ -491,34 +430,11 @@ def step_guard(self, code):
self.current_step_info = self.step_record[code]

self.current_step_info.step_count += 1

log(
2,
f"[Cost Model] New step start, current state is {self.current_state}\n",
)
yield
finally:
self.current_code = old_code
self.current_step_info = old_info

def sot_step(self):
self.current_step_info.sot_step += 1

def collect_info(self, impl_dynamic, impl_sot, /, *args, **kwargs):
if self.current_step_info.need_dynamic_info():
start_time = time.perf_counter()
outs = impl_dynamic(*args, **kwargs)
time_cost = time.perf_counter() - start_time
self.current_step_info.add_dynamic_time_info(time_cost)
else:
start_time = time.perf_counter()
outs = impl_sot(*args, **kwargs)
time_cost = time.perf_counter() - start_time
self.current_step_info.add_sot_time_info(
time_cost, self.current_code
)
return outs

@property
def need_back_trace(self):
return self.current_step_info.need_back_trace()
Expand All @@ -527,10 +443,6 @@ def need_back_trace(self):
def current_step(self):
return self.current_step_info.step_count

@property
def current_state(self):
return self.current_step_info.state

def clear(self):
self.step_record.clear()
self.current_code = None
Expand Down
Loading

0 comments on commit 37e121e

Please sign in to comment.