我想保存一个Tensorflow模型,然后在以后的部署中使用它.我不想使用model.save()来保存它,因为我的目的是以某种方式"pickle"它,并在未安装tensorflow的其他系统中使用它,例如:

model = pickle.load(open(path, 'rb'))
model.predict(prediction_array)

在sklearn的早期,当我酸洗KNN模型时,它是成功的,我能够在不安装sklearn的情况下运行推理.

但当我try pickle我的Tensorflow模型时,我得到了一个错误:

Traceback (most recent call last):
  File "e:/VA_nlu_addition_branch_lite/nlu_stable2/train.py", line 21, in <module>
pickle.dump(model, open('saved/model.p', 'wb'))
TypeError: can't pickle _thread.RLock objects

我的模型是这样的:

model = keras.Sequential([
            keras.Input(shape=(len(x[0]))),
            keras.layers.Dense(units=16, activation='elu'),
            keras.layers.Dense(units=8, activation='elu'),
            keras.layers.Dense(units=len(y[0]), activation='softmax'),
        ])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x, y, epochs=200, batch_size=8)
pickle.dump(model, open('saved/model.p', 'wb'))

Model summary

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 16)                1680
_________________________________________________________________
dense_1 (Dense)              (None, 8)                 136
_________________________________________________________________
dense_2 (Dense)              (None, 20)                180
=================================================================
Total params: 1,996
Trainable params: 1,996
Non-trainable params: 0

关于这个问题,这里有一个StackOverflow question,但答案中的链接已过期.

这里还有another similar question个,但我不太明白.

我有一个非常简单的模型,没有判断点,没有太复杂的东西,所以有没有办法将Tensorflow模型对象保存到一个二进制文件中?或者,即使它有多个二进制文件,我不介意,但它不需要使用tensoflow,如果numpy solution有帮助,我会使用它,但我不知道如何在这里实现它.任何帮助都将不胜感激,谢谢!

推荐答案

使用joblib似乎适用于TF 2.8,因为你有一个非常简单的模型,你可以在Google Colab上训练它,然后在你的另一个系统上使用pickle文件:

import joblib
import tensorflow as tf

model = tf.keras.Sequential([
            tf.keras.layers.Input(shape=(5,)),
            tf.keras.layers.Dense(units=16, activation='elu'),
            tf.keras.layers.Dense(units=8, activation='elu'),
            tf.keras.layers.Dense(units=5, activation='softmax'),
        ])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
x = tf.random.normal((20, 5))
y = tf.keras.utils.to_categorical(tf.random.uniform((20, 1), maxval=5, dtype=tf.int32))
model.fit(x, y, epochs=200, batch_size=8)
joblib.dump(model, 'model.pkl')

不带tf的负荷模型:

import joblib
import numpy as np
print(joblib.__version__)

model = joblib.load("/content/model.pkl")
print(model(np.random.random((1,5))))
1.1.0
tf.Tensor([[0.38729233 0.04049021 0.06067584 0.07901421 0.43252742]], shape=(1, 5), dtype=float32)

但在不了解系统规格的情况下,很难判断这是否真的是"直接的".

Python相关问答推荐

在函数内部使用eval(),将函数的输入作为字符串的一部分

如何在BeautifulSoup中链接Find()方法并处理无?

替换字符串中的多个重叠子字符串

Matlab中是否有Python的f-字符串等效物

处理带有间隙(空)的duckDB上的重复副本并有效填充它们

使用groupby Pandas的一些操作

修复mypy错误-赋值中的类型不兼容(表达式具有类型xxx,变量具有类型yyy)

Python键入协议默认值

我想一列Panadas的Rashrame,这是一个URL,我保存为CSV,可以直接点击

迭代嵌套字典的值

使用NeuralProphet绘制置信区间时出错

给定高度约束的旋转角解析求解

多处理队列在与Forking http.server一起使用时随机跳过项目

从嵌套极轴列的列表中删除元素

Regex用于匹配Python中逗号分隔的AWS区域

文本溢出了Kivy的视区

Matplotlib中的曲线箭头样式

我可以同时更改多个图像吗?

如何在不不断遇到ChromeDriver版本错误的情况下使用Selify?

来自任务调度程序的作为系统的Python文件