个人博客:haichenyi.com。感谢关注
项目里面用到了tflite,用于做简单的图片处理,不是判断图片是什么类型,就是传进去图片,生成新图片,类似于前面一篇讲的GPUImage的滤镜功能,但是比滤镜功能更加强大。
我这里要做的就是集成,拿人家训练好的模型直接来用,我不用去训练模型。
第一步 依赖
//依赖库
implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
android {
···
//set no compress models
aaptOptions {
noCompress "tflite"
}
}
第二步 加载训练模型
网上很多介绍资料都是把训练模型直接copy到项目main目录下的assets目录(不存在就创建)与java目录平级,自然,这样的加载方式就是
// load infer model
private void loadModel(String model) {
try {
tflite = new Interpreter(loadModelFile(model));
Log.d(TAG, model + " model load success");
tflite.setNumThreads(4);
load_result = true;
} catch (IOException e) {
Log.d(TAG, model + " model load fail");
load_result = false;
e.printStackTrace();
}
}
/**
* Memory-map the model file in Assets.
*/
private MappedByteBuffer loadModelFile(String model) throws IOException {
AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite");
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
一个tflite文件就好几M,甚至十几M,全部copy到项目里面不显示,所以,我们一般项目里面用都是先下载,然后再使用,那,这样的方式,我们要怎么加载训练模型呢?
我们先分析一下再assets目录下面怎么加载的?说白了就是新建一个Interpreter对象,就是加载模型。上面的方法都过时了,我们可以找到Interpreter类,里面你会看到如下的方法
//第一个参数传tflite文件,第二个参数传一个Interpreter静态内部类对象
public Interpreter(@NonNull File modelFile, Interpreter.Options options) {
this.wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options);
}
//所以,我们自己项目里面加载模型,用如下方式即可
Interpreter.Options options = new Interpreter.Options();
options.setNumThreads(4);
tflite = new Interpreter(new File(""), options);
第三步 执行run方法
tflite.run(in, out);
通过执行这个run方法,获取我们需要的东西,第一个参数,输入对象,第二个参数,输出参数。
重点,敲黑板
重点,敲黑板
重点,敲黑板
重点就在这里,这里的输入和输出参数要怎么传?我这里训练模型是用Python做的,它需要传入一个四维数组,所以,输出我们自然也要用一个四维数组接收。
这里的四维数组怎么传递呐?就要说到Android里面的bitmap知识了,它的每个像素点都是一个ARGB数组。即透明度,红色,绿色,蓝色。我们前面的灰色滤镜之类的东西,实际上就是改变RGB三原色的值,让颜色变成灰色,然后改变亮度之类的就是改变每个管道的透明度。网上有很多这样的知识。
再来说说这个四维数组,我项目里面用到的这个四维数组:1 X 256 X 256 X 3,这几个值怎么理解呢?
1:表示一张图片
256X256:表示图片的宽高
3:表示RGB色值
那我们怎么把bitmap对象,转换成我们需要的四维数组呐?
//定义了一个一维数组,里面就是我们需要的参数,便于修改
private int[] ddims = {1, 256, 256, 3};
/**
* 获取图片的四维数组
* @param bitmap bitmap对象
* @param ddims 参数数组
* @return 图片四维数组
*/
public float[][][][] getScaledMatrix(Bitmap bitmap, int[] ddims) {
//新建一个1*256*256*3的四维数组
float[][][][] inFloat = new float[ddims[0]][ddims[1]][ddims[2]][ddims[3]];
//新建一个一维数组,长度是图片像素点的数量
int[] pixels = new int[ddims[1] * ddims[2]];
//把原图缩放成我们需要的图片大小
Bitmap bm = Bitmap.createScaledBitmap(bitmap, ddims[1], ddims[2], false);
//把图片的每个像素点的值放到我们前面新建的一维数组中
bm.getPixels(pixels, 0, bm.getWidth(), 0, 0, ddims[1], ddims[2]);
int pixel = 0;
//for循环,把每个像素点的值转换成RBG的值,存放到我们的目标数组中
for (int i = 0; i < ddims[1]; ++i) {
for (int j = 0; j < ddims[2]; ++j) {
final int val = pixels[pixel++];
float red = ((val >> 16) & 0xFF);
float green = ((val >> 8) & 0xFF);
float blue = (val & 0xFF);
float[] arr = {red, green, blue};
inFloat[0][i][j] = arr;
}
}
if (bm.isRecycled()) {
bm.recycle();
}
return inFloat;
}
上面代码注释写的很清楚了吧?每一行都有注释,for循环的作用也标的很清楚,通过这个方法,我们得到的就是我们想要的四维数组了,这里的四维数组的格式,图片的大小,都是tflite文件建模型的时候设置好的,看你们训练模型的工程师是怎么定义的,你就怎么传。
然后,新建一个一模一样格式的数组去接收输出值,也是一个四维数组,那么,我们怎么把这个四维数组转换成我们需要的bitmap呢?
//创建bitmap的方法,
Bitmap.createBitmap(@NonNull @ColorInt int[] colors,
int width, int height, Config config);
就是这个方法,传一个一维颜色数组,图片的宽高,还有一个图片的格式,那我们这里就是要把这个四维数组转成一个一维的颜色数组了。
/**
* 四维数组转成bitmap对象
* @param outArr 数组
* @param ddims 格式
* @return bitmap
*/
public Bitmap getBitmap(float[][][][] outArr, int[] ddims) {
//获取图片的三维数组
float[][][] temp = outArr[0];
int n = 0;
//新建一个接收的颜色数组,长度就是图片的宽高之积,类似于上面的像素那个数组
int[] colorArr = new int[ddims[1] * ddims[2]];
//for循环遍历把图片的ARGB色值转成一个颜色值,放入颜色数组中
for (int i = 0; i < ddims[1]; i++) {
for (int j = 0; j < ddims[2]; j++) {
float[] arr = temp[i][j];
int alpha = 255;
int red = (int) arr[0];
int green = (int) arr[1];
int blue = (int) arr[2];
int tempARGB = (alpha << 24) | (red << 16) | (green << 8) | blue;
colorArr[n++] = tempARGB;
}
}
//创建bitmap对象
return Bitmap.createBitmap(colorArr, ddims[1], ddims[2], Bitmap.Config.ARGB_8888);
}
至此,我们就拿到了,我们需要的bitmap对象了,然后再做后续的逻辑即可。