Add FasterTokenizer model in experiment #1220
Changes from 5 commits
CMakeLists.txt (new file, +221 lines)
cmake_minimum_required(VERSION 3.0)
project(cpp_inference_demo CXX C)
option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." ON)
option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF)
option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON)
option(USE_TENSORRT "Compile demo with TensorRT." OFF)
option(WITH_ROCM "Compile demo with rocm." OFF)

if(NOT WITH_STATIC_LIB)
  add_definitions("-DPADDLE_WITH_SHARED_LIB")
else()
  # PD_INFER_DECL is mainly used to set the dllimport/dllexport attribute in dynamic library mode.
  # Set it to empty in static library mode to avoid compilation issues.
  add_definitions("/DPD_INFER_DECL=")
endif()

macro(safe_set_static_flag)
  foreach(flag_var
      CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
      CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO)
    if(${flag_var} MATCHES "/MD")
      string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
    endif(${flag_var} MATCHES "/MD")
  endforeach(flag_var)
endmacro()

if(NOT DEFINED PADDLE_LIB)
  message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib")
endif()
if(NOT DEFINED DEMO_NAME)
  message(FATAL_ERROR "please set DEMO_NAME with -DDEMO_NAME=demo_name")
endif()

include_directories("${PADDLE_LIB}/")
set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}utf8proc/include")

link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}utf8proc/lib")
link_directories("${PADDLE_LIB}/paddle/lib")

if(WIN32)
  add_definitions("/DGOOGLE_GLOG_DLL_DECL=")
  option(MSVC_STATIC_CRT "use static C Runtime library by default" ON)
  if(MSVC_STATIC_CRT)
    if(WITH_MKL)
      set(FLAG_OPENMP "/openmp")
    endif()
    set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}")
    set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}")
    set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}")
    set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}")
    safe_set_static_flag()
    if(WITH_STATIC_LIB)
      add_definitions(-DSTATIC_LIB)
    endif()
  endif()
else()
  if(WITH_MKL)
    set(FLAG_OPENMP "-fopenmp")
  endif()
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 ${FLAG_OPENMP}")
endif()

if(WITH_GPU)
  if(NOT WIN32)
    set(CUDA_LIB "/usr/local/cuda/lib64/" CACHE STRING "CUDA Library")
  else()
    if(CUDA_LIB STREQUAL "")
      set(CUDA_LIB "C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\lib\\x64")
    endif()
  endif(NOT WIN32)
endif()

if(USE_TENSORRT AND WITH_GPU)
  set(TENSORRT_ROOT "" CACHE STRING "The root directory of TensorRT library")
  if("${TENSORRT_ROOT}" STREQUAL "")
    message(FATAL_ERROR "The TENSORRT_ROOT is empty, you must assign it a value with CMake command. Such as: -DTENSORRT_ROOT=TENSORRT_ROOT_PATH ")
  endif()
  set(TENSORRT_INCLUDE_DIR ${TENSORRT_ROOT}/include)
  set(TENSORRT_LIB_DIR ${TENSORRT_ROOT}/lib)
  file(READ ${TENSORRT_INCLUDE_DIR}/NvInfer.h TENSORRT_VERSION_FILE_CONTENTS)
  string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION
      "${TENSORRT_VERSION_FILE_CONTENTS}")
  if("${TENSORRT_MAJOR_VERSION}" STREQUAL "")
    file(READ ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h TENSORRT_VERSION_FILE_CONTENTS)
    string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION
        "${TENSORRT_VERSION_FILE_CONTENTS}")
  endif()
  if("${TENSORRT_MAJOR_VERSION}" STREQUAL "")
    message(SEND_ERROR "Failed to detect TensorRT version.")
  endif()
  string(REGEX REPLACE "define NV_TENSORRT_MAJOR +([0-9]+)" "\\1"
      TENSORRT_MAJOR_VERSION "${TENSORRT_MAJOR_VERSION}")
  message(STATUS "Current TensorRT header is ${TENSORRT_INCLUDE_DIR}/NvInfer.h. "
      "Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. ")
  include_directories("${TENSORRT_INCLUDE_DIR}")
  link_directories("${TENSORRT_LIB_DIR}")
endif()

if(WITH_MKL)
  set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml")
  include_directories("${MATH_LIB_PATH}/include")
  if(WIN32)
    set(MATH_LIB ${MATH_LIB_PATH}/lib/mklml${CMAKE_STATIC_LIBRARY_SUFFIX}
        ${MATH_LIB_PATH}/lib/libiomp5md${CMAKE_STATIC_LIBRARY_SUFFIX})
  else()
    set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
        ${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
  endif()
  set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
  if(EXISTS ${MKLDNN_PATH})
    include_directories("${MKLDNN_PATH}/include")
    if(WIN32)
      set(MKLDNN_LIB ${MKLDNN_PATH}/lib/mkldnn.lib)
    else(WIN32)
      set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
    endif(WIN32)
  endif()
else()
  set(OPENBLAS_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}openblas")
  include_directories("${OPENBLAS_LIB_PATH}/include/openblas")
  if(WIN32)
    set(MATH_LIB ${OPENBLAS_LIB_PATH}/lib/openblas${CMAKE_STATIC_LIBRARY_SUFFIX})
  else()
    set(MATH_LIB ${OPENBLAS_LIB_PATH}/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX})
  endif()
endif()

if(WITH_STATIC_LIB)
  set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX})
else()
  if(WIN32)
    set(DEPS ${PADDLE_LIB}/paddle/lib/paddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX})
  else()
    set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX})
  endif()
endif()

if(NOT WIN32)
  set(EXTERNAL_LIB "-lrt -ldl -lpthread")
  set(DEPS ${DEPS}
      ${MATH_LIB} ${MKLDNN_LIB}
      glog gflags protobuf xxhash cryptopp utf8proc
      ${EXTERNAL_LIB})
else()
  set(DEPS ${DEPS}
      ${MATH_LIB} ${MKLDNN_LIB}
      glog gflags_static libprotobuf xxhash cryptopp-static utf8proc_static ${EXTERNAL_LIB})
  set(DEPS ${DEPS} shlwapi.lib)
endif(NOT WIN32)

if(WITH_GPU)
  if(NOT WIN32)
    if(USE_TENSORRT)
      set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX})
      set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX})
    endif()
    set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX})
  else()
    if(USE_TENSORRT)
      set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/nvinfer${CMAKE_STATIC_LIBRARY_SUFFIX})
      set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/nvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX})
      if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
        set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/myelin64_1${CMAKE_STATIC_LIBRARY_SUFFIX})
      endif()
    endif()
    set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX})
    set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX})
    set(DEPS ${DEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX})
  endif()
endif()

if(WITH_ROCM)
  if(NOT WIN32)
    set(DEPS ${DEPS} ${ROCM_LIB}/libamdhip64${CMAKE_SHARED_LIBRARY_SUFFIX})
  endif()
endif()

add_executable(${DEMO_NAME} ${DEMO_NAME}.cc)
target_link_libraries(${DEMO_NAME} ${DEPS})
if(WIN32)
  if(USE_TENSORRT)
    add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
        COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/nvinfer${CMAKE_SHARED_LIBRARY_SUFFIX}
            ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
        COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/nvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX}
            ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
    )
    if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
      add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
          COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/myelin64_1${CMAKE_SHARED_LIBRARY_SUFFIX}
              ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE})
    endif()
  endif()
  if(WITH_MKL)
    add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
        COMMAND ${CMAKE_COMMAND} -E copy ${MATH_LIB_PATH}/lib/mklml.dll ${CMAKE_BINARY_DIR}/Release
        COMMAND ${CMAKE_COMMAND} -E copy ${MATH_LIB_PATH}/lib/libiomp5md.dll ${CMAKE_BINARY_DIR}/Release
        COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_PATH}/lib/mkldnn.dll ${CMAKE_BINARY_DIR}/Release
    )
  else()
    add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
        COMMAND ${CMAKE_COMMAND} -E copy ${OPENBLAS_LIB_PATH}/lib/openblas.dll ${CMAKE_BINARY_DIR}/Release
    )
  endif()
  if(NOT WITH_STATIC_LIB)
    add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
        COMMAND ${CMAKE_COMMAND} -E copy "${PADDLE_LIB}/paddle/lib/paddle_inference.dll" ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
    )
  endif()
endif()
compile.sh (new file, +41 lines)
#!/bin/bash
set +x
set -e

work_path=$(dirname $(readlink -f $0))

# 1. check paddle_inference exists
if [ ! -d "${work_path}/lib/paddle_inference" ]; then
  echo "Please download the paddle_inference lib and move it into cpp_deploy/lib"
  exit 1
fi

# 2. check CMakeLists exists
if [ ! -f "${work_path}/CMakeLists.txt" ]; then
  cp -a "${work_path}/lib/CMakeLists.txt" "${work_path}/"
fi

# 3. compile
mkdir -p build
cd build
rm -rf *

# must match the demo source file name (demo.cc)
DEMO_NAME=demo

WITH_MKL=ON
WITH_GPU=ON

LIB_DIR=${work_path}/lib/paddle_inference
CUDNN_LIB=/usr/lib/x86_64-linux-gnu/
CUDA_LIB=/usr/local/cuda/lib64

cmake .. -DPADDLE_LIB=${LIB_DIR} \
  -DWITH_MKL=${WITH_MKL} \
  -DDEMO_NAME=${DEMO_NAME} \
  -DWITH_GPU=${WITH_GPU} \
  -DWITH_STATIC_LIB=OFF \
  -DCUDNN_LIB=${CUDNN_LIB} \
  -DCUDA_LIB=${CUDA_LIB}

make -j
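For orientation, the directory layout this build script seems to expect (inferred from the checks above; names other than CMakeLists.txt and paddle_inference are assumptions) might look like:

<deploy_dir>/              # e.g. cpp_deploy
├── compile.sh             # this script
├── demo.cc                # DEMO_NAME.cc
├── CMakeLists.txt         # copied from lib/ on first run if missing
└── lib/
    ├── CMakeLists.txt
    └── paddle_inference/  # downloaded Paddle Inference library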
C++ inference demo (new file, +64 lines)
Review comments:
- Don't leave these inexplicable blank lines.
- Don't name the file demo; rename it to infer.cc.
- Or ernie_infer. Later on we should also distinguish between the sequence classification and sequence labeling tasks.
- Done, renamed to text_cls_infer.
- seq_cls_infer / token_cls_infer might be more consistent with the class names.
#include <gflags/gflags.h>
#include <glog/logging.h>

#include <algorithm>
#include <numeric>

#include "paddle/include/paddle_inference_api.h"

using paddle_infer::Config;
using paddle_infer::Predictor;
using paddle_infer::CreatePredictor;

DEFINE_string(model_file, "", "Path of the inference model file.");
DEFINE_string(params_file, "", "Path of the inference params file.");
DEFINE_int32(batch_size, 1, "Batch size of the inference.");
DEFINE_bool(use_gpu, true, "Enable GPU for inference.");

std::shared_ptr<Predictor> InitPredictor() {
  Config config;
  config.SetModel(FLAGS_model_file, FLAGS_params_file);
  if (FLAGS_use_gpu) {
    config.EnableUseGpu(100, 0);
  }
  return CreatePredictor(config);
}

void Run(Predictor* predictor,
         std::vector<std::string>* input_data,
Review comments:
- Inputs should be passed by const reference; only outputs should be pointers.
- Changed to … [a hedged sketch of this revision follows the function below]
         std::vector<float>* probs) {
  auto input_names = predictor->GetInputNames();

  auto text = predictor->GetInputHandle(input_names[0]);
  text->ReshapeStrings(input_data->size());
  text->CopyStringsFromCpu(input_data);

  predictor->Run();

  auto output_names = predictor->GetOutputNames();
  auto logits = predictor->GetOutputHandle(output_names[0]);
  std::vector<int> output_shape = logits->shape();
  int logits_num = std::accumulate(
      output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
  probs->resize(logits_num);
  logits->CopyToCpu(probs->data());
}
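As a hedged illustration of the review comment above (the revised code from later commits is not shown in this diff), the Run signature could take the input by const reference and keep only the output as a pointer, roughly like this:

// Hypothetical revision, not the code in this commit: input by const reference,
// output by pointer. CopyStringsFromCpu expects a mutable vector pointer in the
// API used above, so a local copy is made here; the real fix may differ.
void Run(Predictor* predictor,
         const std::vector<std::string>& input_data,
         std::vector<float>* probs) {
  auto input_names = predictor->GetInputNames();
  auto text = predictor->GetInputHandle(input_names[0]);
  text->ReshapeStrings(input_data.size());
  std::vector<std::string> buffer(input_data);  // mutable copy for the API call
  text->CopyStringsFromCpu(&buffer);
  // ... remainder identical to the Run() body above ...
}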
int main(int argc, char* argv[]) {
  google::ParseCommandLineFlags(&argc, &argv, true);
  auto predictor = InitPredictor();

  std::vector<std::string> data{
      "这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般",
      "怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的"
      "动画片",
      "作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上"
      "办理入住手续,节省时间。"};

  std::vector<float> probs;
  Run(predictor.get(), &data, &probs);
Review comments:
- The printed results need to be shown here. [a hedged print sketch follows this file]
- The output should be a const reference, keeping data as the input.
- If this demo is for classification, state that clearly and keep it separate from the sequence labeling one.
- seq_cls_infer / token_cls_infer might be more consistent with the class names.
- Done, renamed to seq_cls_infer.
  return 0;
}
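The reviewer also asks for the predictions to be printed. A minimal sketch of what that could look like, assuming the exported graph ends in softmax so probs holds per-class probabilities laid out as [num_sentences, num_classes], and reusing the label names from the export script (an illustration, not the code in this commit; it additionally needs <iostream>):

  // Hypothetical print of the predicted label and probability per input sentence.
  const std::vector<std::string> label_map{"negative", "positive"};
  const size_t num_classes = label_map.size();
  for (size_t i = 0; i < data.size(); ++i) {
    const float* p = probs.data() + i * num_classes;
    const size_t pred = std::max_element(p, p + num_classes) - p;
    std::cout << "text: " << data[i] << "\tlabel: " << label_map[pred]
              << "\tprob: " << p[pred] << std::endl;
  }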
Run script (new file, +11 lines)
#!/bin/bash
set +x
set -e

work_path=$(dirname $(readlink -f $0))

# 1. compile
bash ${work_path}/compile.sh

# 2. run
./build/demo --model_file ../export/inference.pdmodel --params_file ../export/inference.pdiparams
Model export script (new file, +54 lines)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import paddle
import paddlenlp
from paddlenlp.experimental import FastSequenceClassificationModel
Review comments:
- Faster — we use "Faster" uniformly as the code name across this whole feature.
- Done.
# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--params_path", type=str, required=True, default='./checkpoint/model_900.pdparams', help="The path to model parameters to be loaded.")
parser.add_argument("--output_path", type=str, default='./export', help="The path of model parameter in static graph to be saved.")
parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after tokenization. "
    "Sequences longer than this will be truncated, sequences shorter will be padded.")
args = parser.parse_args()
# yapf: enable

if __name__ == "__main__":
    # The number of labels should be in accordance with the training dataset.
    label_map = {0: 'negative', 1: 'positive'}
    model = FastSequenceClassificationModel.from_pretrained(
        'ernie-1.0',
        num_classes=len(label_map),
        max_seq_len=args.max_seq_length)

    if args.params_path and os.path.isfile(args.params_path):
        state_dict = paddle.load(args.params_path)
        model.set_dict(state_dict)
        print("Loaded parameters from %s" % args.params_path)
    model.eval()

    # Convert to static graph with specific input description
    model = paddle.jit.to_static(
Review comments:
- This to_static call should be wrapped into …
- I think the dynamic-to-static export should automatically include the softmax, instead of requiring callers to apply softmax themselves afterwards. [a hedged sketch of this follows the file below]
- Done.
        model,
        input_spec=[
            paddle.static.InputSpec(
                shape=[None, None], dtype=paddlenlp.ops.Strings),  # texts
        ])
    # Save in static graph model.
    save_path = os.path.join(args.output_path, "inference")
    paddle.jit.save(model, save_path)
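On the softmax discussion above: if FastSequenceClassificationModel (or its Faster successor) returns raw logits — an assumption, since the experimental API may already emit probabilities — one way to export probabilities directly is to wrap the model before calling paddle.jit.to_static, so the static graph itself ends in softmax. A minimal sketch, not part of this PR:

import paddle.nn.functional as F

# Hypothetical wrapper: make the exported graph return per-class probabilities
# so C++ callers need no extra softmax step.
class ProbOutputWrapper(paddle.nn.Layer):
    def __init__(self, seq_cls_model):
        super().__init__()
        self.seq_cls_model = seq_cls_model

    def forward(self, texts):
        logits = self.seq_cls_model(texts)  # raw per-class scores
        return F.softmax(logits, axis=-1)   # per-class probabilities

# model = paddle.jit.to_static(ProbOutputWrapper(model), input_spec=[...])

A typical invocation of the export script (its file name is not shown in this diff) would pass --params_path and --output_path as defined in the argparse block above.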
Review comments:
- Don't call it DEMO; this is not a demo.
- Changed to text_cls_infer.