Android studio开发——调用Tensorflow模型

人工智能166

需求

需要在Android studio上开发一个apk,使得可以调用Tensorflow生成的模型进行计算。

配置过程

1.添加依赖包

app/src文件夹下的 build.gradle中添加tensorflow的应用包(需联网下载)。

implementation 'org.tensorflow:tensorflow-android:1.13.1'

Android studio开发——调用Tensorflow模型

2.添加模型文件

把生成好的模型文件放入 assets文件夹中,以便后续的模型调用。
Android studio开发——调用Tensorflow模型

; 3.编写代码

3.1 PredictionTF代码编写

首先,编写模型初始化、调用的代码如下所示。其中 getPredict()用于调用模型进行计算,本例中简化为直接输入数组 inputdata [],实际使用的时候需要接受真实的数据。
注意1:一定要保证输入、输出变量的名称、维度大小正确(如果不正确会有报错提醒,及时调整)
注意2:pd文件在生成时的TensorFlow版本需要和训练时相一致

package com.example.test;

import android.content.res.AssetManager;
import android.util.Log;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public class PredictionTF {
    private static final String TAG = "PredictionTF";

    private static final int IN_COL = 6;
    private static final int IN_ROW = 8;
    private static final int OUT_COL = 1;
    private static final int OUT_ROW = 1;

    private static final String inputName = "actor/InputData/X";

    private static final String outputName = "actor/FullyConnected_1/Softmax";

    TensorFlowInferenceInterface inferenceInterface;
    static {

        System.loadLibrary("tensorflow_inference");
        Log.e(TAG,"libtensorflow_inference.so库加载成功");
    }

    PredictionTF(AssetManager assetManager, String modePath) {

        inferenceInterface = new TensorFlowInferenceInterface(assetManager,modePath);
        Log.e(TAG,"TensoFlow模型文件加载成功");
    }

    public float[] getPredict() {
        float[] inputdata = {0.361f ,-0.422f ,-0.992f ,-0.196f ,-0.564f  ,0.947f ,-0.339f  ,0.167f,
                0.434f  ,0.287f ,-0.704f  ,0.065f ,-0.083f  ,0.747f  ,0.874f ,-0.796f,
                -0.822f ,-0.366f  ,0.21f  ,-0.493f  ,0.97f  ,-0.779f  ,0.947f  ,0.118f,
                0.798f  ,0.911f  ,0.42f  ,-0.219f ,-0.572f  ,0.033f ,-0.515f ,-0.846f,
                -0.994f  ,0.254f  ,0.775f  ,0.782f  ,0.046f ,-0.403f  ,0.056f  ,0.731f,
                -0.714f  ,0.982f  ,0.117f ,-0.912f  ,0.467f ,-0.015f ,-0.998f  ,0.703f};

        inferenceInterface.feed(inputName, inputdata,1, IN_COL, IN_ROW);

        float[] outputs = new float[6];

        inferenceInterface.run(new String[] { outputName }, false);

        inferenceInterface.fetch(outputName, outputs);

        return outputs;
    }
}

3.2 MainActivity代码编写

接着,编写主体的 MainActivity代码,由于网络最后输出的是一个 1*6的概率数组,所以加入了 choose以输出最终的结果。

package com.example.test;
import androidx.appcompat.app.AppCompatActivity;
import android.content.res.AssetManager;
import android.os.Bundle;
import android.util.Log;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
import java.util.Random;

public class MainActivity extends AppCompatActivity {
    PredictionTF preTF;
    private static final String TAG = "MainActivity";
    private static final String MODEL_FILE = "file:///android_asset/frozen_model(1).pb";

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        preTF = new PredictionTF(getAssets(),MODEL_FILE);
        float[] result = preTF.getPredict();
        int true_result = choose(result);
        Log.i(TAG, "输出的结果为:");
        Log.i(TAG, String.valueOf(true_result));
    }

    public int choose(float result[]){
        int rand_times = 100;
        int A_DIM = 6;
        float[] action_cumsum = {0,0,0,0,0,0};
        action_cumsum[0] = result[0];
        for (int i = 1; i < A_DIM; i++) {
            action_cumsum[i] = action_cumsum[i-1] + result[i];
        }
        int[] hitcount = {0,0,0,0,0,0};
        int action = 0;
        int max = 0;
        for (int i = 0; i < rand_times; i++) {
            Random r = new Random();
            float randomnum = r.nextFloat();
            for (int j = 0; j < action_cumsum.length; j++) {
                if (action_cumsum[j] >= randomnum) {
                    hitcount[j] = hitcount[j] + 1;
                    break;
                }
            }
        }
        for(int i = 0; i < hitcount.length; i++){
            if(max < hitcount[i]){
                max = hitcount[i];
                action = i;
            }
        }
        return action;
    }
}

4.输出结果

最后的结果如下所示,成功加载so库、成功加载模型文件、成功输出实验结果。
Android studio开发——调用Tensorflow模型

Original: https://blog.csdn.net/qq_42775328/article/details/122401953
Author: 陈成不姓丞
Title: Android studio开发——调用Tensorflow模型