tflite_runtime,但更简单。
项目描述
tflit 🔥
为什么是 tflite_runtime
?
interpreter.invoke()
? interpreter.set_tensor(input_details[0]['index'], X)
?
必须从这里选择特定平台的URL?
不,当然不是 🔥。
这个功能的作用
- 检测您的平台和Python版本,因此您无需选择正确的URL,您可以添加
tflite_runtime
作为依赖项,而无需选择单个平台进行支持。 - 为模型创建一个类似
keras
的界面,因此您可以使用tflit.Model(path).predict(X)
,而无需考虑张量索引或三步预测,或批处理。
有用的链接
安装
pip install tflit
用法
我试图提供一个尽可能接近Keras的接口。
import tflit
model = tflit.Model('path/to/model.tflite')
model.summary() # prints input and output details
print(model.input_shape) # (10, 30) - a single input
print(model.output_shape) # [(5, 2), (1, 2)] - two outputs
print(model.dtype) # 'float32'
# *see notes below
print(model.input_names) # may not preserve names (based on how you export)
print(model.output_names) # doesn't preserve names atm
# predict over batches of outputs.
y_pred = model.predict(np.random.randn(32, 10, 30))
# predict single output at a time
y_pred = model.predict_batch(np.random.randn(1, 10, 30))
不在Keras中的额外tflite功能
# remember, you can access the tflite_runtime interpreter directly
# so if something is being weird, please submit an issue, but also
# there's not that much code in here so just look here to figure out
# the right way:
# https://tensorflowcn.cn/lite/api_docs/python/tf/lite/Interpreter
interpreter = model.interpreter
# change the model's batch size
model.set_batch_size(64)
# reset the model variables
model.reset()
# get tensor by index
model.input(1) # 2nd input
model.output(0) # 1st output
# get tensor value copy by index
model.input_value(1) # 2nd input
model.output_value(0) # 1st output
黑暗时代
仅作参考,这是我过去是如何做的
def load_tflite_model_function(model_path, **kw):
import tflite_runtime.interpreter as tflite
compute = prepare_model_function(tflite.Interpreter(model_path), **kw)
compute.model_path = model_path
return compute
def prepare_model_function(model, verbose=False):
# assumes a single input and output
in_dets = model.get_input_details()[0]
out_dets = model.get_output_details()[0]
model.allocate_tensors()
def compute(x):
# set inputs
model.set_tensor(in_dets['index'], X.astype(in_dets['dtype']))
# compute outputs
model.invoke()
# get outputs
return model.get_tensor(out_dets['index'])
if verbose:
print('-- Input details --')
print(in_dets, '\n')
print('-- Output details --')
print(out_dets, '\n')
# set input and output shapes so they're easily accessible
compute.input_shape = in_dets['shape'][1:]
compute.output_shape = out_dets['shape'][1:]
return compute
这比我从代码中提取出来的代码更简洁,但它仍然过于复杂,我在将代码复制到第三个项目后感到厌倦。这也不处理像多个输入/输出或批处理之类的事情。
注意
-
更新日期 7/8/21:Tensorflow 在安装 tflite_runtime 方面有所改进 - 然而,我不明白为什么他们不直接发布到 PyPI,我确信这只是一个或两个谷歌的开发者的问题,但如果已经在推送谷歌 coral,为什么不能也安装 twine 呢?所以现在你可以使用以下命令安装:
pip install --index-url https://google-coral.github.io/py-repo/ tflite_runtime
,但是,你必须明确指定索引 URL(哦!) -
我在将 tflite_runtime 作为
setup.py
的依赖项进行安装时遇到了麻烦,所以现在它只有在第一次运行时未安装的情况下才会安装。我可能会在某个时候修复它...但我还有其他事情要做,现在这个方法可行。希望 tensorflow 也能开始部署到 pypi,这样所有问题都会得到解决。不知道那是什么问题... -
可能
tflite_runtime
没有为您的系统提供构建版本。请查看这个链接以验证。 -
当前的
tflite
转换器存在一个bug,它不会复制输入和输出名称。然而,在导出时这样做,输入名称将会被保存
converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.experimental_new_converter = True # <<< this tflite_model = converter.convert()
但是输出名称仍然没有成功 :/。为了清楚起见,这是一个 tensorflow 问题,我对此无能为力。
-
我打算有一个model.set_batch_size
方法来在运行时更改批量大小,但它目前不起作用,因为 tflite 对于增加的张量大小感到恐慌(它不知道如何广播)。这也是一个 tensorflow 问题。目前,我们一次只计算一个批次,并在最后将它们连接起来。如果模型的固定批量大小不能均匀分割,它将引发错误。默认情况下,tflite 将None
批量大小转换为1
,所以大多数情况下不会有问题。要计算单个帧,直接使用model.predict_batch(X)
更有效。
我非常希望解决这个问题,但它们不在我的控制范围内,而且我实际上没有足够的时间和紧迫性去解决这些问题。
项目详情
tflit-0.1.2.tar.gz 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 76a945d864558a5322e5aa892b1c091d9f0d038876f41904dff29532577ee610 |
|
MD5 | 85b327b2ed2bf3249f8ca1ff454b9db4 |
|
BLAKE2b-256 | 34e2cda67aa652074f760c9aa3e17929224a9b3c42187256c8cf366908ce3c34 |