Skip to content

Commit

Permalink
[SOT] Collect BreakGraph Reason (#71268)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Copilot <[email protected]>
  • Loading branch information
DrRyanHuang and Copilot authored Feb 27, 2025
1 parent c750413 commit 6800335
Show file tree
Hide file tree
Showing 12 changed files with 344 additions and 44 deletions.
5 changes: 2 additions & 3 deletions python/paddle/jit/sot/infer_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,8 @@ def __init__(self):
self.var_name_generator = UniqueNameGenerator(SOT_INFER_META_INNER_VAR)

def gen_name(self, meta):
name = f"{meta.dtype}_{meta.stop_gradient}"
for l in meta.shape:
name += f"_{l}"
name = f"{meta.dtype}_{meta.stop_gradient}_"
name += "_".join(map(str, meta.shape))
return name

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@

# This file stores the customized function that will be called by the dispatch mechanism.

from ...utils import BreakGraphError, FallbackError
from ...utils import BreakGraphError, BreakGraphReasonBase, FallbackError


def raise_break_graph_fn(*args, **kwarg):
raise BreakGraphError("raise by raise_break_graph_fn.")
def create_raise_break_graph_handler(reason: BreakGraphReasonBase):
def raise_break_graph_fn(*args, **kwarg):
raise BreakGraphError(reason)

return raise_break_graph_fn


def raise_not_implement_fn(*args, **kwarg):
Expand Down
19 changes: 15 additions & 4 deletions python/paddle/jit/sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@
map_if,
switch_symbol_registry,
)
from ...utils.exceptions import BreakGraphError, SotExtraInfo
from ...utils.exceptions import (
BreakGraphError,
DygraphInconsistentWithStaticBreak,
InferMetaBreak,
SotExtraInfo,
)
from ..instruction_utils import get_instructions
from .guard import Guard, StringifiedExpression, make_guard
from .mutable_data import MutationDel, MutationNew, MutationSet
Expand Down Expand Up @@ -671,7 +676,9 @@ def try_infer_meta_fn(args, kwargs) -> Any:
):
# TODO(zrr1999): maybe we can continue to fallback to all args are constant.
raise BreakGraphError(
f"InferMeta encount {type(e)}, but all args are not symbolic."
InferMetaBreak(
f"InferMeta encount {type(e)}, but all args are not symbolic."
)
)

args, kwargs = map_if(
Expand All @@ -698,7 +705,9 @@ def try_infer_meta_fn(args, kwargs) -> Any:
for arg in flatten_vars
):
raise BreakGraphError(
f"InferMeta encount {type(e)}, but all args are not symbolic."
InferMetaBreak(
f"InferMeta encount {type(e)}, but all args are not symbolic."
)
)

args, kwargs = map_structure(
Expand All @@ -712,7 +721,9 @@ def try_infer_meta_fn(args, kwargs) -> Any:
except Exception as e:
if SotExtraInfo.from_exception(e).need_breakgraph:
raise BreakGraphError(
f"API {func} encountered a need break graph error {e}"
DygraphInconsistentWithStaticBreak(
f"API {func} encountered a need break graph error {e}"
)
)
raise e

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@
ENV_MIN_GRAPH_SIZE,
ENV_SOT_FORCE_FALLBACK_SIR_IDS,
BreakGraphError,
BuiltinFunctionBreak,
FallbackError,
InnerError,
SotUndefinedVar,
UnsupportedIteratorBreak,
UnsupportedOperationBreak,
get_static_function,
is_comprehensive_name,
log,
Expand Down Expand Up @@ -806,7 +809,9 @@ def LOAD_ATTR(self, instr: Instruction):
def LOAD_SUPER_ATTR(self, instr: Instruction):
# This bytecode is for Python 3.12+, and it will break graph in Python 3.11-.
# We align it's behavior with Python 3.11-.
raise BreakGraphError("call super is not supported")
raise BreakGraphError(
BuiltinFunctionBreak(reason_str="call super is not supported")
)

def LOAD_CONST(self, instr: Instruction):
var = self._co_consts[instr.arg]
Expand Down Expand Up @@ -1082,7 +1087,11 @@ def build_seq_unpack(self, instr: Instruction):
if not isinstance(
item, (TupleVariable, ListVariable, RangeVariable)
):
raise BreakGraphError(f"{type(item)} not support unpack")
raise BreakGraphError(
UnsupportedOperationBreak(
reason_str=f"{type(item)} not support unpack"
)
)
retval.extend(item.get_iter().to_list())

if instr.opname in {
Expand Down Expand Up @@ -1971,7 +1980,9 @@ def FOR_ITER(self, instr):
try:
if not isinstance(iterator, SequenceIterVariable):
raise BreakGraphError(
f"Can not simulate iterator of {type(iterator)}."
UnsupportedIteratorBreak(
f"Can not simulate iterator of {type(iterator)}."
)
)

backup_iter_idx = iterator.idx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
from typing import TYPE_CHECKING

from ...profiler import event_register
from ...utils import BreakGraphError, log
from ...utils import (
BreakGraphError,
DataDependencyControlFlowBreak,
UnsupportedIteratorBreak,
log,
)
from ..instruction_utils import Instruction
from .guard import StringifiedExpression, union_free_vars
from .opcode_executor import OpcodeExecutorBase, Stop
Expand Down Expand Up @@ -298,9 +303,8 @@ def _break_graph_when_if(self, result, instr: Instruction):
result: The result of the operation.
instr (Instruction): The jump instruction.
"""
raise BreakGraphError(
"OpcodeInlineExecutor want break graph when simulate `if`."
)

raise BreakGraphError(DataDependencyControlFlowBreak())

def FOR_ITER(self, instr: Instruction):
iterator = self.stack.top
Expand All @@ -327,5 +331,7 @@ def FOR_ITER(self, instr: Instruction):
else:
self._graph.remove_global_guarded_variable(iterator)
raise BreakGraphError(
f"Found {iterator.__class__.__name__} as iterator."
UnsupportedIteratorBreak(
reason_str=f"Found {iterator.__class__.__name__} as iterator."
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,26 @@

import paddle

from ...utils import BreakGraphError, FallbackError, get_numpy_ufuncs
from ...utils import (
BreakGraphError,
BuiltinFunctionBreak,
FallbackError,
UnsupportedIteratorBreak,
UnsupportedOperationBreak,
get_numpy_ufuncs,
)
from ...utils.magic_methods import (
BINARY_OPS,
UNARY_OPS,
magic_method_builtin_dispatch,
)
from ...utils.paddle_api_config import get_tensor_methods
from .dispatch_functions import (
create_raise_break_graph_handler,
operator_in,
operator_is_none,
operator_is_not_none,
operator_not_in,
raise_break_graph_fn,
tensor_numel,
)
from .dispatcher import Dispatcher, optional
Expand Down Expand Up @@ -120,7 +127,9 @@ def inner(*args, **kwargs):
Dispatcher.register(
operator_in,
("VariableBase", "IterVariable"),
raise_err_handle(BreakGraphError("Codes like: `variable in iterator`.")),
create_raise_break_graph_handler(
UnsupportedIteratorBreak("Codes like: `variable in iterator`.")
),
)

Dispatcher.register(
Expand Down Expand Up @@ -152,8 +161,8 @@ def inner(*args, **kwargs):
Dispatcher.register(
operator_not_in,
("VariableBase", "IterVariable"),
raise_err_handle(
BreakGraphError("Codes like: `variable not in iterator`.")
create_raise_break_graph_handler(
UnsupportedIteratorBreak("Codes like: `variable not in iterator`.")
),
)

Expand Down Expand Up @@ -1025,7 +1034,11 @@ def is_not_func(var: VariableBase, other: VariableBase):
Dispatcher.register(
unary_fn,
("TensorVariable",),
raise_break_graph_fn,
create_raise_break_graph_handler(
BuiltinFunctionBreak(
fn_name=unary_fn, arg_types="TensorVariable"
)
),
)
continue

Expand Down Expand Up @@ -1090,7 +1103,11 @@ def tensor_mod_dispatcher(
):
if var.get_py_type() is str:
raise BreakGraphError(
"(ConstantVariable % TensorVariable) raise a callback. "
UnsupportedOperationBreak(
left_type="ConstantVariable",
right_type="TensorVariable",
operator="__rmod__",
)
)
raise FallbackError("Tensor doesn't support __rmod__")

Expand Down
50 changes: 40 additions & 10 deletions python/paddle/jit/sot/opcode_translator/executor/variables/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,13 @@
from ....utils import (
ENV_SOT_ALLOW_DYNAMIC_SHAPE,
BreakGraphError,
BuiltinFunctionBreak,
ConstTypes,
DataDependencyDynamicShapeBreak,
DataDependencyOperationBreak,
FallbackError,
NameGenerator,
UnsupportedOperationBreak,
get_tensor_methods,
log,
printable,
Expand Down Expand Up @@ -416,7 +420,9 @@ def analyse_dynamic_axes(self, tracker: Tracker):
def __len__(self):
if isinstance(self.meta.shape[0], SymbolicInt):
raise BreakGraphError(
"length of tensor variable with first dimension is dynamic shape causes graph break."
DataDependencyDynamicShapeBreak(
"length of tensor variable with first dimension is dynamic shape causes graph break."
)
)
return self.meta.shape[0]

Expand All @@ -439,7 +445,9 @@ def __hash__(self):
return SotTensor(self.id)

raise BreakGraphError(
"Called TensorVariable.get_py_value. Should not use Tensor's value in simulating."
DataDependencyOperationBreak(
"Called TensorVariable.get_py_value. Should not use Tensor's value in simulating."
)
)

def get_py_type(self):
Expand Down Expand Up @@ -589,7 +597,9 @@ def size(self):
# TODO: maybe break graph.
if self.meta.is_dynamic_shape():
raise BreakGraphError(
f"Getting size for a dynamic shape tensor causes graph break. shape = {self.meta.shape}"
DataDependencyDynamicShapeBreak(
f"Getting size for a dynamic shape tensor causes graph break. shape = {self.meta.shape}"
)
)
elements = reduce(operator.mul, self.meta.shape, 1)
return ConstantVariable(elements, self.graph, DummyTracker([self]))
Expand All @@ -601,7 +611,9 @@ def shape(self):
and self.meta.is_dynamic_shape()
):
raise BreakGraphError(
f"Getting shape for a dynamic shape tensor causes graph break. shape = {self.meta.shape}"
DataDependencyDynamicShapeBreak(
f"Getting shape for a dynamic shape tensor causes graph break. shape = {self.meta.shape}"
)
)
from .container import ListVariable

Expand All @@ -617,7 +629,9 @@ def len(self):
first_dim = self.meta.shape[0]
if isinstance(first_dim, SymbolicInt):
raise BreakGraphError(
"Getting len() for a dynamic shape tensor causes graph break."
DataDependencyDynamicShapeBreak(
"Getting len() for a dynamic shape tensor causes graph break."
)
)

return ConstantVariable(first_dim, self.graph, DummyTracker([self]))
Expand Down Expand Up @@ -662,7 +676,9 @@ def getattr(self, name: str, default=None):
}
if name in ["name", "place", "type"] and self.meta.is_inner_var():
raise BreakGraphError(
f"{self.meta.name} is a middle tensor. get {name} property."
DataDependencyOperationBreak(
f"{self.meta.name} is a middle tensor. Not support to get {name} property."
)
)
if name in [
"dtype",
Expand Down Expand Up @@ -709,7 +725,9 @@ def setattr(self, key, val):
)

def delattr(self, key):
raise BreakGraphError("Don't support TensorVariable delattr")
raise BreakGraphError(
BuiltinFunctionBreak("Don't support TensorVariable delattr")
)

@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
Expand Down Expand Up @@ -814,7 +832,11 @@ def disable_symbolic(var: VariableBase):

def get_py_value(self, allow_tensor: bool = False) -> bool | int | float:
if ENV_SOT_BREAK_GRAPH_ON_GET_SYMBOLIC_VALUE.get():
raise BreakGraphError("get_py_value from SymbolicVariable")
raise BreakGraphError(
DataDependencyOperationBreak(
"get_py_value from SymbolicVariable"
)
)
self.need_guard_value = True
log(
3,
Expand Down Expand Up @@ -1116,10 +1138,18 @@ def _reconstruct(self, codegen: PyCodeGen):
super()._reconstruct(codegen)

def setattr(self, key, val):
raise BreakGraphError("Don't support SliceVariable setattr")
raise BreakGraphError(
UnsupportedOperationBreak(
reason_str="Don't support SliceVariable setattr"
)
)

def delattr(self, key):
raise BreakGraphError("Don't support SliceVariable delattr")
raise BreakGraphError(
UnsupportedOperationBreak(
reason_str="Don't support SliceVariable delattr"
)
)

@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
Expand Down
Loading

0 comments on commit 6800335

Please sign in to comment.