Skip to content

Commit

Permalink
cleanup cost model
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo committed Feb 26, 2025
1 parent 74200f8 commit 4af8662
Show file tree
Hide file tree
Showing 11 changed files with 21 additions and 426 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ set(PYBIND_SRCS
xpu_streams_py.cc
jit.cc
auto_parallel_py.cc
sot/eval_frame_tools.cc
sot/skip_files.cc
sot/cpython_internals.c
sot/frame_proxy.c
sot/eval_frame.c
Expand Down
20 changes: 1 addition & 19 deletions paddle/fluid/pybind/jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ limitations under the License. */
#include "paddle/fluid/jit/layer.h"
#include "paddle/fluid/jit/serializer.h"
#include "paddle/fluid/pybind/sot/eval_frame.h"
#include "paddle/fluid/pybind/sot/eval_frame_tools.h"
#include "paddle/fluid/pybind/sot/frame_proxy.h"
#include "paddle/fluid/pybind/sot/guards.h"
#include "paddle/fluid/pybind/sot/macros.h"
#include "paddle/fluid/pybind/sot/skip_files.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/utils/pybind.h"
Expand Down Expand Up @@ -150,24 +150,6 @@ void BindSot(pybind11::module *m) {
return type->tp_getattro != PyObject_GenericGetAttr;
});

m->def(
"sot_setup_codes_with_graph",
[](const py::object &py_codes) {
auto ret = setup_codes_with_graph(py_codes.ptr());
auto obj = py::reinterpret_borrow<py::object>(ret);
return obj;
},
py::arg("py_codes"));

m->def(
"sot_set_with_graph",
[](const py::object &py_codes) {
auto ret = set_with_graph(py_codes.ptr());
auto obj = py::reinterpret_borrow<py::object>(ret);
return obj;
},
py::arg("py_codes"));

m->def(
"eval_frame_no_skip_codes",
[](const py::object &py_codes) {
Expand Down
13 changes: 1 addition & 12 deletions paddle/fluid/pybind/sot/eval_frame.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ limitations under the License. */
#if SOT_IS_SUPPORTED

#include "paddle/fluid/pybind/sot/cpython_internals.h"
#include "paddle/fluid/pybind/sot/eval_frame_tools.h"
#include "paddle/fluid/pybind/sot/frame_proxy.h"
#include "paddle/fluid/pybind/sot/skip_files.h"

#include <Python.h>

Expand Down Expand Up @@ -329,17 +329,6 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate,
Py_DECREF(f_locals);
#endif

// code status
if (is_code_without_graph(code == Py_None ? PyFrame_GET_CODE(frame)
: (PyCodeObject *)code) &&
disable_eval_frame == Py_False) {
out = eval_frame_default(tstate, frame, throw_flag);
eval_frame_callback_set(callback);
Py_DECREF(code);
Py_DECREF(disable_eval_frame);
return out;
}

// run code
if (disable_eval_frame != Py_True) {
// Re-enable custom behavior
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/sot/eval_frame_tools.h"
#include "paddle/fluid/pybind/sot/skip_files.h"

#include <Python.h>

#include <unordered_set>

#include "paddle/common/errors.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/platform/profiler/event_tracing.h"

#if SOT_IS_SUPPORTED
#define END_OF_STRING '\0'
Expand Down Expand Up @@ -142,77 +141,6 @@ int SkipCodeInfo::in_skip_path(PyObject* filename) {
return root->check_filename(name);
}

/*========================== code status ==============================*/
enum CodeState { UNKNOW, WITH_GRAPH, WITHOUT_GRAPH };

class CodeInfo {
public:
CodeState state;
int counter;
};

class CodeStatus {
public:
static CodeStatus& Instance();
int is_code_without_graph(PyCodeObject* code);
void set_with_graph(PyCodeObject* code);
void add_with_graph_code(PyCodeObject* code);
void clear();

private:
CodeStatus() { code_map = std::unordered_map<PyCodeObject*, CodeInfo*>(); }
~CodeStatus() { clear(); }
std::unordered_map<PyCodeObject*, CodeInfo*> code_map;
};

CodeStatus& CodeStatus::Instance() {
static CodeStatus _instance;
return _instance;
}

int CodeStatus::is_code_without_graph(PyCodeObject* code) {
CodeInfo* code_info;
if (code_map.find(code) != code_map.end()) {
code_info = code_map[code];
} else {
code_info = new CodeInfo();
code_map.emplace(code, code_info);
}
if (code_info->state == WITHOUT_GRAPH) return 1;
if (code_info->state == UNKNOW) {
code_info->counter += 1;
if (code_info->counter >= 10) code_info->state = WITHOUT_GRAPH;
}
return 0;
}

void CodeStatus::set_with_graph(PyCodeObject* code) {
CodeInfo* code_info;
if (code_map.find(code) != code_map.end()) {
code_info = code_map[code];
code_info->state = WITH_GRAPH;
}
}

void CodeStatus::add_with_graph_code(PyCodeObject* code) {
CodeInfo* code_info;
if (code_map.find(code) != code_map.end()) {
code_info = code_map[code];
code_info->state = WITH_GRAPH;
} else {
code_info = new CodeInfo();
code_info->state = WITH_GRAPH;
code_map.emplace(code, code_info);
}
}

void CodeStatus::clear() {
for (auto& iter : code_map) {
delete iter.second;
}
code_map.clear();
}

/*========================== interfaces ===============================*/

int need_skip(FrameObject* frame) {
Expand Down Expand Up @@ -245,29 +173,7 @@ int need_skip(FrameObject* frame) {
return result;
}

int is_code_without_graph(PyCodeObject* code) {
auto& code_status = CodeStatus::Instance();
return code_status.is_code_without_graph(code);
}

/*========================== pybind ===============================*/
PyObject* set_with_graph(PyObject* code) {
auto& code_status = CodeStatus::Instance();
code_status.set_with_graph((PyCodeObject*)code); // NOLINT
return Py_None;
}

PyObject* setup_codes_with_graph(PyObject* code_tuple) {
auto& code_status = CodeStatus::Instance();
Py_ssize_t size = PyTuple_GET_SIZE(code_tuple);
for (Py_ssize_t i = 0; i < size; i++) {
PyCodeObject* code =
(PyCodeObject*)PyTuple_GetItem(code_tuple, i); // NOLINT
code_status.add_with_graph_code(code);
}
return Py_None;
}

PyObject* no_skip_codes(PyObject* code_tuple) {
auto& skip_info = SkipCodeInfo::Instance();
Py_ssize_t size = PyTuple_GET_SIZE(code_tuple);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ extern "C" {
#if SOT_IS_SUPPORTED

int need_skip(FrameObject* frame);
int is_code_without_graph(PyCodeObject* code);

PyObject* set_with_graph(PyObject* code);
PyObject* setup_codes_with_graph(PyObject* code_tuple);
PyObject* no_skip_codes(PyObject* code_tuple);
PyObject* skip_file_prefix(PyObject* filepath_tuple);

Expand Down
6 changes: 0 additions & 6 deletions python/paddle/jit/sot/opcode_translator/skip_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,7 @@ def _module_dir(m: types.ModuleType):

no_skip_code = {paddle.nn.Sequential.forward.__code__}

with_graph_codes = (
paddle.nn.Layer.__call__.__code__,
paddle.nn.Layer._dygraph_call_func.__code__,
)


def setup_skip_files():
paddle.framework.core.eval_frame_skip_file_prefix(tuple(skip_file_names))
paddle.framework.core.eval_frame_no_skip_codes(tuple(no_skip_code))
paddle.framework.core.sot_setup_codes_with_graph(with_graph_codes)
4 changes: 0 additions & 4 deletions python/paddle/jit/sot/symbolic/compile_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
InfoCollector,
NewSymbolHitRateInfo,
Singleton,
StepInfoManager,
SubGraphRelationInfo,
log,
log_do,
Expand Down Expand Up @@ -219,9 +218,6 @@ def collect_subgraph_relation(self, inputs, outputs, partial_program_layer):

def __call__(self, *args, **kwargs):
with EventGuard(f"FallbackWrapper: {self.SIR.name}"):
if StepInfoManager().need_back_trace:
trace_back_frames()

log_do(
2,
lambda: print("[FallbackWrapper] start run SIR: \n", self.SIR),
Expand Down
55 changes: 17 additions & 38 deletions python/paddle/jit/sot/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
from .utils import (
GraphLogger,
InfoCollector,
StepInfoManager,
StepState,
log_do,
)

Expand Down Expand Up @@ -96,42 +94,23 @@ def symbolic_translate(fn: Callable[P, R], **kwargs) -> Callable[P, R]:
def callback(frame):
return eval_frame_callback(frame, **kwargs)

def impl_sot(*args: P.args, **kwargs: P.kwargs) -> R:
assert hasattr(
fn, "__code__"
), "Target function doesn't have code for simulating."
StepInfoManager().sot_step()
GraphLogger().clear()
InfoCollector().clear_step_info()
paddle.framework.core.set_eval_frame(callback)
try:
outs = fn(*args, **kwargs)
except Exception as e:
raise e
finally:
paddle.framework.core.set_eval_frame(None)

log_do(1, lambda: GraphLogger().print_info())
InfoCollector().print_step_report()
return outs

def impl_dynamic(*args: P.args, **kwargs: P.kwargs) -> R:
outs = fn(*args, **kwargs)
return outs

def impl(*args: P.args, **kwargs: P.kwargs) -> R:
with StepInfoManager().step_guard(fn.__code__), SotStepProfilerGuard():
state = StepInfoManager().current_state

if state == StepState.RUN_SOT:
return impl_sot(*args, **kwargs)
elif state == StepState.RUN_DYN:
return impl_dynamic(*args, **kwargs)
elif state == StepState.COLLECT_INFO:
return StepInfoManager().collect_info(
impl_dynamic, impl_sot, *args, **kwargs
)
else:
raise RuntimeError("Unknown state.")
with SotStepProfilerGuard():
assert hasattr(
fn, "__code__"
), "Target function doesn't have code for simulating."
GraphLogger().clear()
InfoCollector().clear_step_info()
paddle.framework.core.set_eval_frame(callback)
try:
outs = fn(*args, **kwargs)
except Exception as e:
raise e
finally:
paddle.framework.core.set_eval_frame(None)

log_do(1, lambda: GraphLogger().print_info())
InfoCollector().print_step_report()
return outs

return impl
1 change: 0 additions & 1 deletion python/paddle/jit/sot/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
ResumeFnNameFactory,
Singleton,
SotUndefinedVar,
StepInfoManager,
StepState,
count_if,
current_symbol_registry,
Expand Down
Loading

0 comments on commit 4af8662

Please sign in to comment.