Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT] Cleanup legacy flags and policies #71279

Merged
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ set(PYBIND_SRCS
xpu_streams_py.cc
jit.cc
auto_parallel_py.cc
sot/eval_frame_tools.cc
sot/skip_files.cc
sot/cpython_internals.c
sot/frame_proxy.c
sot/eval_frame.c
Expand Down
20 changes: 1 addition & 19 deletions paddle/fluid/pybind/jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ limitations under the License. */
#include "paddle/fluid/jit/layer.h"
#include "paddle/fluid/jit/serializer.h"
#include "paddle/fluid/pybind/sot/eval_frame.h"
#include "paddle/fluid/pybind/sot/eval_frame_tools.h"
#include "paddle/fluid/pybind/sot/frame_proxy.h"
#include "paddle/fluid/pybind/sot/guards.h"
#include "paddle/fluid/pybind/sot/macros.h"
#include "paddle/fluid/pybind/sot/skip_files.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/utils/pybind.h"
Expand Down Expand Up @@ -150,24 +150,6 @@ void BindSot(pybind11::module *m) {
return type->tp_getattro != PyObject_GenericGetAttr;
});

m->def(
"sot_setup_codes_with_graph",
[](const py::object &py_codes) {
auto ret = setup_codes_with_graph(py_codes.ptr());
auto obj = py::reinterpret_borrow<py::object>(ret);
return obj;
},
py::arg("py_codes"));

m->def(
"sot_set_with_graph",
[](const py::object &py_codes) {
auto ret = set_with_graph(py_codes.ptr());
auto obj = py::reinterpret_borrow<py::object>(ret);
return obj;
},
py::arg("py_codes"));

m->def(
"eval_frame_no_skip_codes",
[](const py::object &py_codes) {
Expand Down
13 changes: 1 addition & 12 deletions paddle/fluid/pybind/sot/eval_frame.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ limitations under the License. */
#if SOT_IS_SUPPORTED

#include "paddle/fluid/pybind/sot/cpython_internals.h"
#include "paddle/fluid/pybind/sot/eval_frame_tools.h"
#include "paddle/fluid/pybind/sot/frame_proxy.h"
#include "paddle/fluid/pybind/sot/skip_files.h"

#include <Python.h>

Expand Down Expand Up @@ -329,17 +329,6 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate,
Py_DECREF(f_locals);
#endif

// code status
if (is_code_without_graph(code == Py_None ? PyFrame_GET_CODE(frame)
: (PyCodeObject *)code) &&
disable_eval_frame == Py_False) {
out = eval_frame_default(tstate, frame, throw_flag);
eval_frame_callback_set(callback);
Py_DECREF(code);
Py_DECREF(disable_eval_frame);
return out;
}

// run code
if (disable_eval_frame != Py_True) {
// Re-enable custom behavior
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/sot/eval_frame_tools.h"
#include "paddle/fluid/pybind/sot/skip_files.h"

#include <Python.h>

#include <unordered_set>

#include "paddle/common/errors.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/platform/profiler/event_tracing.h"

#if SOT_IS_SUPPORTED
#define END_OF_STRING '\0'
Expand Down Expand Up @@ -142,77 +141,6 @@ int SkipCodeInfo::in_skip_path(PyObject* filename) {
return root->check_filename(name);
}

/*========================== code status ==============================*/
enum CodeState { UNKNOW, WITH_GRAPH, WITHOUT_GRAPH };

class CodeInfo {
public:
CodeState state;
int counter;
};

class CodeStatus {
public:
static CodeStatus& Instance();
int is_code_without_graph(PyCodeObject* code);
void set_with_graph(PyCodeObject* code);
void add_with_graph_code(PyCodeObject* code);
void clear();

private:
CodeStatus() { code_map = std::unordered_map<PyCodeObject*, CodeInfo*>(); }
~CodeStatus() { clear(); }
std::unordered_map<PyCodeObject*, CodeInfo*> code_map;
};

CodeStatus& CodeStatus::Instance() {
static CodeStatus _instance;
return _instance;
}

int CodeStatus::is_code_without_graph(PyCodeObject* code) {
CodeInfo* code_info;
if (code_map.find(code) != code_map.end()) {
code_info = code_map[code];
} else {
code_info = new CodeInfo();
code_map.emplace(code, code_info);
}
if (code_info->state == WITHOUT_GRAPH) return 1;
if (code_info->state == UNKNOW) {
code_info->counter += 1;
if (code_info->counter >= 10) code_info->state = WITHOUT_GRAPH;
}
return 0;
}

void CodeStatus::set_with_graph(PyCodeObject* code) {
CodeInfo* code_info;
if (code_map.find(code) != code_map.end()) {
code_info = code_map[code];
code_info->state = WITH_GRAPH;
}
}

void CodeStatus::add_with_graph_code(PyCodeObject* code) {
CodeInfo* code_info;
if (code_map.find(code) != code_map.end()) {
code_info = code_map[code];
code_info->state = WITH_GRAPH;
} else {
code_info = new CodeInfo();
code_info->state = WITH_GRAPH;
code_map.emplace(code, code_info);
}
}

void CodeStatus::clear() {
for (auto& iter : code_map) {
delete iter.second;
}
code_map.clear();
}

/*========================== interfaces ===============================*/

int need_skip(FrameObject* frame) {
Expand Down Expand Up @@ -245,29 +173,7 @@ int need_skip(FrameObject* frame) {
return result;
}

int is_code_without_graph(PyCodeObject* code) {
auto& code_status = CodeStatus::Instance();
return code_status.is_code_without_graph(code);
}

/*========================== pybind ===============================*/
PyObject* set_with_graph(PyObject* code) {
auto& code_status = CodeStatus::Instance();
code_status.set_with_graph((PyCodeObject*)code); // NOLINT
return Py_None;
}

PyObject* setup_codes_with_graph(PyObject* code_tuple) {
auto& code_status = CodeStatus::Instance();
Py_ssize_t size = PyTuple_GET_SIZE(code_tuple);
for (Py_ssize_t i = 0; i < size; i++) {
PyCodeObject* code =
(PyCodeObject*)PyTuple_GetItem(code_tuple, i); // NOLINT
code_status.add_with_graph_code(code);
}
return Py_None;
}

PyObject* no_skip_codes(PyObject* code_tuple) {
auto& skip_info = SkipCodeInfo::Instance();
Py_ssize_t size = PyTuple_GET_SIZE(code_tuple);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ extern "C" {
#if SOT_IS_SUPPORTED

int need_skip(FrameObject* frame);
int is_code_without_graph(PyCodeObject* code);

PyObject* set_with_graph(PyObject* code);
PyObject* setup_codes_with_graph(PyObject* code_tuple);
PyObject* no_skip_codes(PyObject* code_tuple);
PyObject* skip_file_prefix(PyObject* filepath_tuple);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
if TYPE_CHECKING:
from .function_graph import CompileGraphResult, FunctionGraph

SUPPORT_COMPARE_OP = {
COMPARE_OP_NAME_TO_FN = {
">": operator.gt,
"<": operator.lt,
">=": operator.ge,
Expand Down Expand Up @@ -1347,7 +1347,7 @@ def COMPARE_OP(self, instr: Instruction):
right, left = self.stack.pop(), self.stack.pop()
self.stack.push(
BuiltinVariable(
SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker()
COMPARE_OP_NAME_TO_FN[op], self._graph, DanglingTracker()
)(left, right)
)

Expand All @@ -1366,7 +1366,7 @@ def IS_OP(self, instr: Instruction):
op = "is" if instr.arg == 0 else "is not"
self.stack.push(
BuiltinVariable(
SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker()
COMPARE_OP_NAME_TO_FN[op], self._graph, DanglingTracker()
)(left, right)
)

Expand Down Expand Up @@ -1583,7 +1583,7 @@ def CONTAINS_OP(self, instr: Instruction):
op = "in" if instr.arg == 0 else "not in"
self.stack.push(
BuiltinVariable(
SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker()
COMPARE_OP_NAME_TO_FN[op], self._graph, DanglingTracker()
)(left, right)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
FallbackError,
InnerError,
ResumeFnNameFactory,
is_clean_code,
list_contain_by_id,
list_find_index_by_id,
no_eval_frame,
Expand Down Expand Up @@ -516,8 +515,6 @@ def gen_disable_eval_frame(self):
"""
Generates instructions to disable the evaluation frame.
"""
if is_clean_code():
return
self.gen_load_object(
paddle.framework.core.set_eval_frame, "paddle_set_eval_frame_fn"
)
Expand All @@ -529,8 +526,6 @@ def gen_enable_eval_frame(self):
"""
Generates instructions to enable the evaluation frame.
"""
if is_clean_code():
return
self.gen_load_object(
paddle.framework.core.set_eval_frame, "paddle_set_eval_frame_fn"
)
Expand Down
6 changes: 0 additions & 6 deletions python/paddle/jit/sot/opcode_translator/skip_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,7 @@ def _module_dir(m: types.ModuleType):

no_skip_code = {paddle.nn.Sequential.forward.__code__}

with_graph_codes = (
paddle.nn.Layer.__call__.__code__,
paddle.nn.Layer._dygraph_call_func.__code__,
)


def setup_skip_files():
paddle.framework.core.eval_frame_skip_file_prefix(tuple(skip_file_names))
paddle.framework.core.eval_frame_no_skip_codes(tuple(no_skip_code))
paddle.framework.core.sot_setup_codes_with_graph(with_graph_codes)
4 changes: 0 additions & 4 deletions python/paddle/jit/sot/symbolic/compile_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
InfoCollector,
NewSymbolHitRateInfo,
Singleton,
StepInfoManager,
SubGraphRelationInfo,
log,
log_do,
Expand Down Expand Up @@ -219,9 +218,6 @@ def collect_subgraph_relation(self, inputs, outputs, partial_program_layer):

def __call__(self, *args, **kwargs):
with EventGuard(f"FallbackWrapper: {self.SIR.name}"):
if StepInfoManager().need_back_trace:
trace_back_frames()

log_do(
2,
lambda: print("[FallbackWrapper] start run SIR: \n", self.SIR),
Expand Down
55 changes: 17 additions & 38 deletions python/paddle/jit/sot/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
from .utils import (
GraphLogger,
InfoCollector,
StepInfoManager,
StepState,
log_do,
)

Expand Down Expand Up @@ -96,42 +94,23 @@ def symbolic_translate(fn: Callable[P, R], **kwargs) -> Callable[P, R]:
def callback(frame):
return eval_frame_callback(frame, **kwargs)

def impl_sot(*args: P.args, **kwargs: P.kwargs) -> R:
assert hasattr(
fn, "__code__"
), "Target function doesn't have code for simulating."
StepInfoManager().sot_step()
GraphLogger().clear()
InfoCollector().clear_step_info()
paddle.framework.core.set_eval_frame(callback)
try:
outs = fn(*args, **kwargs)
except Exception as e:
raise e
finally:
paddle.framework.core.set_eval_frame(None)

log_do(1, lambda: GraphLogger().print_info())
InfoCollector().print_step_report()
return outs

def impl_dynamic(*args: P.args, **kwargs: P.kwargs) -> R:
outs = fn(*args, **kwargs)
return outs

def impl(*args: P.args, **kwargs: P.kwargs) -> R:
with StepInfoManager().step_guard(fn.__code__), SotStepProfilerGuard():
state = StepInfoManager().current_state

if state == StepState.RUN_SOT:
return impl_sot(*args, **kwargs)
elif state == StepState.RUN_DYN:
return impl_dynamic(*args, **kwargs)
elif state == StepState.COLLECT_INFO:
return StepInfoManager().collect_info(
impl_dynamic, impl_sot, *args, **kwargs
)
else:
raise RuntimeError("Unknown state.")
with SotStepProfilerGuard():
assert hasattr(
fn, "__code__"
), "Target function doesn't have code for simulating."
GraphLogger().clear()
InfoCollector().clear_step_info()
paddle.framework.core.set_eval_frame(callback)
try:
outs = fn(*args, **kwargs)
except Exception as e:
raise e
finally:
paddle.framework.core.set_eval_frame(None)

log_do(1, lambda: GraphLogger().print_info())
InfoCollector().print_step_report()
return outs

return impl
Loading