Skip to content

Commit

Permalink
Add fetch API for java, refine android log (#1558)
Browse files Browse the repository at this point in the history
* Clear no persistable tensor array before predicting, fix crash when predicting with gpu debugging mode

* Fix code style

* Add fetch API for java, refine android log
  • Loading branch information
hjchen2 authored Apr 15, 2019
1 parent 563f0cc commit 8505700
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 25 deletions.
12 changes: 8 additions & 4 deletions src/common/log.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,20 @@ static const char *ANDROID_LOG_TAG =

#define ANDROIDLOGI(...) \
__android_log_print(ANDROID_LOG_INFO, ANDROID_LOG_TAG, __VA_ARGS__); \
printf("%s\n", __VA_ARGS__);
fprintf(stderr, "%s\n", __VA_ARGS__); \
fflush(stderr)
#define ANDROIDLOGW(...) \
__android_log_print(ANDROID_LOG_WARNING, ANDROID_LOG_TAG, __VA_ARGS__); \
printf("%s\n", __VA_ARGS__);
fprintf(stderr, "%s\n", __VA_ARGS__); \
fflush(stderr)
#define ANDROIDLOGD(...) \
__android_log_print(ANDROID_LOG_DEBUG, ANDROID_LOG_TAG, __VA_ARGS__); \
printf("%s\n", __VA_ARGS__)
fprintf(stderr, "%s\n", __VA_ARGS__); \
fflush(stderr)
#define ANDROIDLOGE(...) \
__android_log_print(ANDROID_LOG_ERROR, ANDROID_LOG_TAG, __VA_ARGS__); \
printf("%s\n", __VA_ARGS__)
fprintf(stderr, "%s\n", __VA_ARGS__); \
fflush(stderr)
#else
#define ANDROIDLOGI(...)
#define ANDROIDLOGW(...)
Expand Down
2 changes: 2 additions & 0 deletions src/io/jni/PML.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ public class PML {
*/
public static native float[] predictImage(float[] buf, int[] ddims);

public static native float[] fetch(String varName);

public static native float[] predictYuv(byte[] buf, int imgWidth, int imgHeight, int[] ddims, float[] meanValues);

// predict with variable length input
Expand Down
60 changes: 40 additions & 20 deletions src/io/jni/paddle_mobile_jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ limitations under the License. */

#ifdef ANDROID

#include "paddle_mobile_jni.h"
#include "io/jni/paddle_mobile_jni.h"
#include <cmath>
#include <string>
#include <vector>
Expand Down Expand Up @@ -193,11 +193,9 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
env->DeleteLocalRef(ddims);
env->ReleaseFloatArrayElements(buf, dataPointer, 0);
env->DeleteLocalRef(buf);

} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
}

#else
jsize ddim_size = env->GetArrayLength(ddims);
if (ddim_size != 4) {
Expand Down Expand Up @@ -231,18 +229,43 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
#endif

ANDROIDLOGI("predictImage finished");
return result;
}

JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_fetch(JNIEnv *env,
jclass thiz,
jstring varName) {
jfloatArray result = NULL;

#ifdef ENABLE_EXCEPTION
try {
auto output =
getPaddleMobileInstance()->Fetch(jstring2cppstring(env, varName));
int count = output->numel();
result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>());
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
}
#else
auto output =
getPaddleMobileInstance()->Fetch(jstring2cppstring(env, varName));
int count = output->numel();
result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>());
#endif

return result;
}

inline int yuv_to_rgb(int y, int u, int v, float *r, float *g, float *b) {
int r1 = (int)(y + 1.370705 * (v - 128));
int g1 = (int)(y - 0.698001 * (u - 128) - 0.703125 * (v - 128));
int b1 = (int)(y + 1.732446 * (u - 128));
int r1 = (int)(y + 1.370705 * (v - 128)); // NOLINT
int g1 = (int)(y - 0.698001 * (u - 128) - 0.703125 * (v - 128)); // NOLINT
int b1 = (int)(y + 1.732446 * (u - 128)); // NOLINT

r1 = (int)fminf(255, fmaxf(0, r1));
g1 = (int)fminf(255, fmaxf(0, g1));
b1 = (int)fminf(255, fmaxf(0, b1));
r1 = (int)fminf(255, fmaxf(0, r1)); // NOLINT
g1 = (int)fminf(255, fmaxf(0, g1)); // NOLINT
b1 = (int)fminf(255, fmaxf(0, b1)); // NOLINT
*r = r1;
*g = g1;
*b = b1;
Expand Down Expand Up @@ -299,14 +322,14 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
framework::DDim ddim = framework::make_ddim(
{ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]});
int length = framework::product(ddim);
float matrix[length];
float matrix[length]; // NOLINT
jbyte *yuv = env->GetByteArrayElements(yuv_, NULL);
float *meansPointer = nullptr;
if (nullptr != meanValues) {
meansPointer = env->GetFloatArrayElements(meanValues, NULL);
}
convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, imgHeight, ddim[3],
ddim[2], meansPointer);
convert_nv21_to_matrix(reinterpret_cast<uint8_t *>(yuv), matrix, imgwidth,
imgHeight, ddim[3], ddim[2], meansPointer);
int count = 0;
framework::Tensor input;
input.Resize(ddim);
Expand Down Expand Up @@ -335,14 +358,14 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
framework::DDim ddim = framework::make_ddim(
{ddim_ptr[0], ddim_ptr[1], ddim_ptr[2], ddim_ptr[3]});
int length = framework::product(ddim);
float matrix[length];
float matrix[length]; // NOLINT
jbyte *yuv = env->GetByteArrayElements(yuv_, NULL);
float *meansPointer = nullptr;
if (nullptr != meanValues) {
meansPointer = env->GetFloatArrayElements(meanValues, NULL);
}
convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, imgHeight, ddim[3],
ddim[2], meansPointer);
convert_nv21_to_matrix((uint8_t *)yuv, matrix, imgwidth, // NOLINT
imgHeight, ddim[3], ddim[2], meansPointer);
int count = 0;
framework::Tensor input;
input.Resize(ddim);
Expand Down Expand Up @@ -408,13 +431,12 @@ JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_setThread(JNIEnv *env,
ANDROIDLOGI("setThreadCount %d", threadCount);
#ifdef ENABLE_EXCEPTION
try {
getPaddleMobileInstance()->SetThreadNum((int)threadCount);
getPaddleMobileInstance()->SetThreadNum(static_cast<int>(threadCount));
} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
}
#else
getPaddleMobileInstance()->SetThreadNum((int)threadCount);

getPaddleMobileInstance()->SetThreadNum(static_cast<int>(threadCount));
#endif
}

Expand All @@ -425,13 +447,11 @@ JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_clear(JNIEnv *env,
#ifdef ENABLE_EXCEPTION
try {
getPaddleMobileInstance()->Clear();

} catch (paddle_mobile::PaddleMobileException &e) {
ANDROIDLOGE("jni got an PaddleMobileException! ", e.what());
}
#else
getPaddleMobileInstance()->Clear();

#endif
}

Expand Down
4 changes: 4 additions & 0 deletions src/io/jni/paddle_mobile_jni.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_loadCombinedQualified(
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
JNIEnv *env, jclass thiz, jfloatArray buf, jintArray ddims);

JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_fetch(JNIEnv *env,
jclass thiz,
jstring varName);

/**
* object detection for anroid
*/
Expand Down
1 change: 1 addition & 0 deletions src/operators/kernel/arm/convolution/conv_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {

bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1];

if (param->Filter()->type() == type_id<int8_t>().hash_code()) {
#ifndef __aarch64__
if (depth3x3 && param->Strides()[0] < 3 &&
Expand Down
2 changes: 1 addition & 1 deletion tools/android-cmake/android.toolchain.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ endif()

# Generic flags.
list(APPEND ANDROID_COMPILER_FLAGS
-g
# -g
-DANDROID
-ffunction-sections
-funwind-tables
Expand Down

0 comments on commit 8505700

Please sign in to comment.