Android——Tensorflow-Lite简单使用

个人博客:haichenyi.com。感谢关注

  项目里面用到了tflite,用于做简单的图片处理,不是判断图片是什么类型,就是传进去图片,生成新图片,类似于前面一篇讲的GPUImage的滤镜功能,但是比滤镜功能更加强大。

  我这里要做的就是集成,拿人家训练好的模型直接来用,我不用去训练模型。

第一步 依赖

//依赖库
implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'


android {
    ···
    //set no compress models
    aaptOptions {
        noCompress "tflite"
    }
}

第二步 加载训练模型

  网上很多介绍资料都是把训练模型直接copy到项目main目录下的assets目录(不存在就创建)与java目录平级,自然,这样的加载方式就是

// load infer model
    private void loadModel(String model) {
        try {
            tflite = new Interpreter(loadModelFile(model));
            Log.d(TAG, model + " model load success");
            tflite.setNumThreads(4);
            load_result = true;
        } catch (IOException e) {
            Log.d(TAG, model + " model load fail");
            load_result = false;
            e.printStackTrace();
        }
    }
    
    
    /**
     * Memory-map the model file in Assets.
     */
    private MappedByteBuffer loadModelFile(String model) throws IOException {
        AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite");
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

  一个tflite文件就好几M,甚至十几M,全部copy到项目里面不显示,所以,我们一般项目里面用都是先下载,然后再使用,那,这样的方式,我们要怎么加载训练模型呢?

  我们先分析一下再assets目录下面怎么加载的?说白了就是新建一个Interpreter对象,就是加载模型。上面的方法都过时了,我们可以找到Interpreter类,里面你会看到如下的方法

//第一个参数传tflite文件,第二个参数传一个Interpreter静态内部类对象
public Interpreter(@NonNull File modelFile, Interpreter.Options options) {
        this.wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options);
}
    
//所以,我们自己项目里面加载模型,用如下方式即可
Interpreter.Options options = new Interpreter.Options();
options.setNumThreads(4);
tflite = new Interpreter(new File(""), options);

第三步 执行run方法

tflite.run(in, out);

  通过执行这个run方法,获取我们需要的东西,第一个参数,输入对象,第二个参数,输出参数。

重点,敲黑板

重点,敲黑板

重点,敲黑板

  重点就在这里,这里的输入和输出参数要怎么传?我这里训练模型是用Python做的,它需要传入一个四维数组,所以,输出我们自然也要用一个四维数组接收。

  这里的四维数组怎么传递呐?就要说到Android里面的bitmap知识了,它的每个像素点都是一个ARGB数组。即透明度,红色,绿色,蓝色。我们前面的灰色滤镜之类的东西,实际上就是改变RGB三原色的值,让颜色变成灰色,然后改变亮度之类的就是改变每个管道的透明度。网上有很多这样的知识。

  再来说说这个四维数组,我项目里面用到的这个四维数组:1 X 256 X 256 X 3,这几个值怎么理解呢?

1:表示一张图片

256X256:表示图片的宽高

3:表示RGB色值

  那我们怎么把bitmap对象,转换成我们需要的四维数组呐?

//定义了一个一维数组,里面就是我们需要的参数,便于修改
private int[] ddims = {1, 256, 256, 3};

    /**
     * 获取图片的四维数组
     * @param bitmap bitmap对象
     * @param ddims 参数数组
     * @return 图片四维数组
     */
public float[][][][] getScaledMatrix(Bitmap bitmap, int[] ddims) {
        //新建一个1*256*256*3的四维数组
        float[][][][] inFloat = new float[ddims[0]][ddims[1]][ddims[2]][ddims[3]];
        //新建一个一维数组,长度是图片像素点的数量
        int[] pixels = new int[ddims[1] * ddims[2]];
        //把原图缩放成我们需要的图片大小
        Bitmap bm = Bitmap.createScaledBitmap(bitmap, ddims[1], ddims[2], false);
        //把图片的每个像素点的值放到我们前面新建的一维数组中
        bm.getPixels(pixels, 0, bm.getWidth(), 0, 0, ddims[1], ddims[2]);
        int pixel = 0;
        //for循环,把每个像素点的值转换成RBG的值,存放到我们的目标数组中
        for (int i = 0; i < ddims[1]; ++i) {
            for (int j = 0; j < ddims[2]; ++j) {
                final int val = pixels[pixel++];
                float red = ((val >> 16) & 0xFF);
                float green = ((val >> 8) & 0xFF);
                float blue = (val & 0xFF);
                float[] arr = {red, green, blue};
                inFloat[0][i][j] = arr;
            }
        }
        if (bm.isRecycled()) {
            bm.recycle();
        }
        return inFloat;
    }

  上面代码注释写的很清楚了吧?每一行都有注释,for循环的作用也标的很清楚,通过这个方法,我们得到的就是我们想要的四维数组了,这里的四维数组的格式,图片的大小,都是tflite文件建模型的时候设置好的,看你们训练模型的工程师是怎么定义的,你就怎么传。

  然后,新建一个一模一样格式的数组去接收输出值,也是一个四维数组,那么,我们怎么把这个四维数组转换成我们需要的bitmap呢?

//创建bitmap的方法,
Bitmap.createBitmap(@NonNull @ColorInt int[] colors,
            int width, int height, Config config);

  就是这个方法,传一个一维颜色数组,图片的宽高,还有一个图片的格式,那我们这里就是要把这个四维数组转成一个一维的颜色数组了。

    /**
     * 四维数组转成bitmap对象
     * @param outArr 数组
     * @param ddims 格式
     * @return bitmap
     */
    public Bitmap getBitmap(float[][][][] outArr, int[] ddims) {
        //获取图片的三维数组
        float[][][] temp = outArr[0];
        int n = 0;
        //新建一个接收的颜色数组,长度就是图片的宽高之积,类似于上面的像素那个数组
        int[] colorArr = new int[ddims[1] * ddims[2]];
        //for循环遍历把图片的ARGB色值转成一个颜色值,放入颜色数组中
        for (int i = 0; i < ddims[1]; i++) {
            for (int j = 0; j < ddims[2]; j++) {
                float[] arr = temp[i][j];
                int alpha = 255;
                int red = (int) arr[0];
                int green = (int) arr[1];
                int blue = (int) arr[2];
                int tempARGB = (alpha << 24) | (red << 16) | (green << 8) | blue;
                colorArr[n++] = tempARGB;
            }
        }
        //创建bitmap对象
        return Bitmap.createBitmap(colorArr, ddims[1], ddims[2], Bitmap.Config.ARGB_8888);
    }

  至此,我们就拿到了,我们需要的bitmap对象了,然后再做后续的逻辑即可。

项目链接

    原文作者:海晨忆
    原文地址: https://www.jianshu.com/p/f72c87efb198
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞