使用TensorFlow Lite的三个入门问题
Preface
I recently started using TensorFlow on a real project. Reading books had always left me feeling I was only scratching the surface, and sure enough, hands-on work ran into problems the books never mention.
I have written them up here. In hindsight they were all small things, but at the start they still cost real time. Others may run into them too, so I am publishing them.
Problem 1: Failing to load the model file
The most common way to handle a model file when using TensorFlow Lite on Android is to put it in the assets directory, so newcomers frequently hit the following error:
java.lang.RuntimeException: java.lang.reflect.InvocationTargetException
    at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:502)
    at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:930)
Caused by: java.lang.reflect.InvocationTargetException
    at java.lang.reflect.Method.invoke(Native Method)
    at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:492)
    at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:930)
Caused by: java.io.FileNotFoundException: This file can not be opened as a file descriptor; it is probably compressed
    at android.content.res.AssetManager.nativeOpenAssetFd(Native Method)
    at android.content.res.AssetManager.openFd(AssetManager.java:848)
    at org.tensorflow.lite.support.common.FileUtil.loadMappedFile(FileUtil.java:74)
The error message is very clear; it even offers a guess: "it is probably compressed". That guess is correct: by default, AAPT compresses resources under the assets directory (except for the raw subdirectory).
public static MappedByteBuffer loadMappedFile(@NonNull Context context, @NonNull String filePath) throws IOException {
AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
//.......
}
In the code above, the openFd call requires that the resource file being opened be uncompressed:
Open an uncompressed asset by mmapping it and returning an AssetFileDescriptor.
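To see why memory-mapping needs an uncompressed entry, here is a small stdlib-only sketch (plain Java, not Android code) comparing STORED and DEFLATED entries in a ZIP archive — an APK is itself a ZIP, and the entry name and payload below are placeholders:

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.zip.CRC32;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;

public class ZipStoredDemo {
    // Writes `data` as a single "model.tflite" entry into an in-memory
    // archive and returns the total archive size in bytes.
    static long zipSize(byte[] data, int method) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (ZipOutputStream zos = new ZipOutputStream(bos)) {
            ZipEntry entry = new ZipEntry("model.tflite");
            entry.setMethod(method);
            if (method == ZipEntry.STORED) {
                // STORED entries must declare size and CRC-32 up front.
                entry.setSize(data.length);
                CRC32 crc = new CRC32();
                crc.update(data);
                entry.setCrc(crc.getValue());
            }
            zos.putNextEntry(entry);
            zos.write(data);
            zos.closeEntry();
        }
        return bos.size();
    }

    public static void main(String[] args) throws IOException {
        byte[] data = new byte[100_000]; // highly compressible dummy payload
        // STORED keeps the payload byte-for-byte inside the archive, so a
        // file descriptor can point straight at the raw bytes (mmap works).
        System.out.println(zipSize(data, ZipEntry.STORED) > data.length);   // true
        // DEFLATED transforms the bytes; nothing inside the archive can be
        // mapped directly as the original file, hence openFd fails.
        System.out.println(zipSize(data, ZipEntry.DEFLATED) < data.length); // true
    }
}
```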
There is a ready-made way to keep the resource from being compressed: add the noCompress property to aaptOptions in the Gradle build file:
Extensions of files that will not be stored compressed in the APK. Adding an empty extension, i.e., setting noCompress '' will trivially disable compression for all files.
A minimal change looks like this:
android {
aaptOptions {
noCompress "tflite"
}
}
References for this section:
- User Guide – Command Line Tool – AAPT2
- AaptOptions – noCompress
- java.io.FileNotFoundException: This file can not be opened as a file descriptor; it is probably compressed
Problem 2: Buffer size mismatch
When processing images, you load the model and then read in a picture for processing; at this point it is easy to get the image dimensions wrong, which triggers a buffer error like this:
java.lang.IllegalArgumentException: Cannot convert between a TensorFlowLite buffer with 602112 bytes and a ByteBuffer with 4915200 bytes.
    at org.tensorflow.lite.Tensor.throwIfShapeIsIncompatible(Tensor.java:272)
    at org.tensorflow.lite.Tensor.throwIfDataIsIncompatible(Tensor.java:249)
    at org.tensorflow.lite.Tensor.setTo(Tensor.java:110)
    at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:145)
    at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:275)
    at org.tensorflow.lite.Interpreter.run(Interpreter.java:249)
One possible cause is that the ResizeOp operator of the ImageProcessor was configured with the wrong width and height.
val imageProcessor: ImageProcessor = ImageProcessor.Builder()
//......
.add(ResizeOp(imageSizeY, imageSizeX, ResizeMethod.NEAREST_NEIGHBOR))
//......
.build()
To understand ResizeOp, refer to the source code or the documentation:
/**
* Creates a ResizeOp which can resize images to specified size in specified method.
*
* @param targetHeight: The expected height of resized image.
* @param targetWidth: The expected width of resized image.
* @param resizeMethod: The algorithm to use for resizing. Options: {@link ResizeMethod}
*/
public ResizeOp(int targetHeight, int targetWidth, ResizeMethod resizeMethod)
The targetHeight and targetWidth here cannot be chosen arbitrarily; they must be parsed out of the model file:
val imageShape = tflite.getInputTensor(imageTensorIndex).shape() // {1, height, width, 3}
imageSizeY = imageShape[1]
imageSizeX = imageShape[2]
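The two byte counts in the exception are simply shape products for a float32 tensor. As a sanity check — the concrete shapes {1, 224, 224, 3} and {1, 640, 640, 3} below are my assumptions, back-computed from the numbers in the trace:

```java
public class TensorSizeCheck {
    // Byte size of a float32 tensor with the given shape.
    static int expectedBytes(int[] shape) {
        int elements = 1;
        for (int dim : shape) {
            elements *= dim;
        }
        return elements * 4; // float32 = 4 bytes per element
    }

    public static void main(String[] args) {
        // A 224x224 RGB float input reproduces the model's expected size,
        // and a 640x640 RGB float image reproduces the oversized buffer.
        System.out.println(expectedBytes(new int[] {1, 224, 224, 3})); // 602112
        System.out.println(expectedBytes(new int[] {1, 640, 640, 3})); // 4915200
    }
}
```

Running the check like this against both sides of the exception usually points straight at which dimension was swapped or hard-coded.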
Problem 3: Type mismatch
When calling the Interpreter's run method, it is very easy to go wrong by relying on the documentation alone, for example:
java.lang.IllegalArgumentException: DataType error: cannot resolve DataType of org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat
    at org.tensorflow.lite.Tensor.dataTypeOf(Tensor.java:199)
    at org.tensorflow.lite.Tensor.throwIfTypeIsIncompatible(Tensor.java:257)
    at org.tensorflow.lite.Tensor.throwIfDataIsIncompatible(Tensor.java:248)
    at org.tensorflow.lite.Tensor.copyTo(Tensor.java:141)
    at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:161)
    at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:275)
    at org.tensorflow.lite.Interpreter.run(Interpreter.java:249)
The error stems from passing an argument of the wrong type to run:
var outputBuffer: TensorBuffer
//......
tflite.run(inputImage.buffer, outputBuffer)
Although the parameter type is Object, in practice there are strict requirements, as the documentation shows:
/**
* Runs model inference if the model takes only one input, and provides only one output.
*
* <p>Warning: The API is more efficient if a {@link Buffer} (preferably direct, but not required)
* is used as the input/output data type. Please consider using {@link Buffer} to feed and fetch
* primitive data for better performance. The following concrete {@link Buffer} types are
* supported:
*
* <ul>
* <li>{@link ByteBuffer} - compatible with any underlying primitive Tensor type.
* <li>{@link FloatBuffer} - compatible with float Tensors.
* <li>{@link IntBuffer} - compatible with int32 Tensors.
* <li>{@link LongBuffer} - compatible with int64 Tensors.
* </ul>
*
* ......
* @param output a multidimensional array of output data, or a {@link Buffer} of primitive types
* including int, float, long, and byte. When a {@link Buffer} is used, the caller must ensure
* that it is set the appropriate write position. A null value is allowed only if the caller
* is using a {@link Delegate} that allows buffer handle interop, and such a buffer has been
* bound to the output {@link Tensor}. See {@link Options#setAllowBufferHandleOutput()}.
* ......
*/
public void run(Object input, Object output) {
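The stack trace above points at Tensor.dataTypeOf, which resolves a DataType from the argument's runtime class. The sketch below is a simplified imitation of that dispatch (not the actual library source); it shows why a wrapper class like TensorBufferFloat, which is none of the listed Buffer types, blows up:

```java
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;

public class DataTypeDispatch {
    // Simplified imitation of the interpreter's type resolution: only the
    // concrete Buffer types listed in the javadoc are recognized.
    static String dataTypeOf(Object o) {
        if (o instanceof ByteBuffer) return "any primitive type";
        if (o instanceof FloatBuffer) return "float32";
        if (o instanceof IntBuffer) return "int32";
        if (o instanceof LongBuffer) return "int64";
        // Anything else, including support-library wrappers, is rejected.
        throw new IllegalArgumentException(
                "DataType error: cannot resolve DataType of " + o.getClass().getName());
    }

    public static void main(String[] args) {
        System.out.println(dataTypeOf(FloatBuffer.allocate(4))); // float32
    }
}
```

The practical fix in the earlier snippet is to hand run the wrapper's backing buffer rather than the wrapper itself — e.g. `tflite.run(inputImage.buffer, outputBuffer.buffer)` in Kotlin, since TensorBuffer exposes its data as a ByteBuffer via getBuffer().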
In my view, run really ought to be overloaded, with a distinct signature for each supported parameter type; only then would it be truly developer-friendly.
Summary
The three problems above are the ones I hit on my first use of TF Lite; writing them down may help others. I will continue to share the typical problems I encounter as I use it more deeply.