这个问题是tensorflow 2 TextVectorization process tensor and dataset error人的后续问题

我想用tnesorflow 2.8在Jupyter上为处理过的文本做一个单词嵌入.

def standardize(input_data):

    input_data = tf.strings.lower(input_data)
    input_data = tf.strings.regex_replace(input_data, f"[{re.escape(string.punctuation)}]", " ")
    return input_data

# the input data loaded from text files by TfRecordDataset(file_paths, "GZIP")
# each file can be 200+MB, totally about 300 files
# each file hold the data with multiple columns
# some columns are text
# after loading, the dataset will be accessed by column name 
# e.g. one column is "sports", so the input_dataset["sports"] 
# return a tensor, which is like the following example

input_data = tf.constant([["SWIM 2008-07 Baseball"], ["Football"]], shape=(2, 1), dtype=tf.string)

text_layer = tf.keras.layers.TextVectorization( standardize = standardize, max_tokens = 10, output_mode = 'int', output_sequence_length=10 )

dataset = tf.data.Dataset.from_tensors( input_data )

dataset = dataset.batch(2)

text_layer.adapt(dataset)

process_text = dataset.map(text_layer)

emb_layer = layers.Embedding(10, 10)

emb_layer(process_text) # error 

错误:

 AttributeError: Exception encountered when calling layer "embedding_7" (type Embedding).

'MapDataset' object has no attribute 'dtype'

Call arguments received:

 • inputs=<MapDataset element_spec=TensorSpec(shape=(None, 2, 10), dtype=tf.int64, name=None)>

如何转换tf.数据集二tf.张量?

TensorFlow: convert tf.Dataset to tf.Tensor美元帮不了我.

上述各层将在机器学习神经网络模型中实现.

loading data --> processing features (multiple text columns) --> tokens --> embedding --> average pooling --> some dense layers --> output layer

谢谢

推荐答案

您不能将tf.data.Dataset直接馈送到Embedding层,您可以使用.map(...):

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import re
import string 
def standardize(input_data):

    input_data = tf.strings.lower(input_data)
    input_data = tf.strings.regex_replace(input_data, f"[{re.escape(string.punctuation)}]", " ")
    return input_data

input_data = tf.constant([["SWIM 2008-07 Baseball"], ["Football"]], shape=(2, 1), dtype=tf.string)

text_layer = tf.keras.layers.TextVectorization( standardize = standardize, max_tokens = 10, output_mode = 'int', output_sequence_length=10 )

dataset = tf.data.Dataset.from_tensors( input_data )

dataset = dataset.batch(2).map(lambda x: tf.squeeze(x, axis=0))

text_layer.adapt(dataset)

process_text = dataset.map(text_layer)

emb_layer = layers.Embedding(10, 10)
process_text = process_text.map(emb_layer)

或者定义模型并将数据集输入model.fit(...):

import tensorflow as tf
import re
import string 
def standardize(input_data):

    input_data = tf.strings.lower(input_data)
    input_data = tf.strings.regex_replace(input_data, f"[{re.escape(string.punctuation)}]", " ")
    return input_data

input_data = tf.constant([["SWIM 2008-07 Baseball"], ["Football"]], shape=(2, 1), dtype=tf.string)

text_layer = tf.keras.layers.TextVectorization( standardize = standardize, max_tokens = 10, output_mode = 'int', output_sequence_length=10 )

dataset = tf.data.Dataset.from_tensors( input_data )

dataset = dataset.batch(2)

text_layer.adapt(dataset)

process_text = dataset.map(lambda x: (text_layer(tf.squeeze(x, axis=0)), tf.random.uniform((2, ), maxval=2, dtype=tf.int32))) # add random label to each entry

inputs = tf.keras.layers.Input((10, ))
emb_layer = tf.keras.layers.Embedding(10, 10)
x = emb_layer(inputs)
x = tf.keras.layers.GlobalAveragePooling1D()(x)
outputs = tf.keras.layers.Dense(1, 'sigmoid')(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='binary_crossentropy')
model.fit(process_text)

Python相关问答推荐

Pandas基于另一列的价值的新列

将大小为n*512的数组绘制到另一个大小为n*256的数组的PC组件

Pandas使用过滤器映射多列

在Python中根据id填写年份系列

PyQt5如何将pyuic 5生成的Python类添加到QStackedWidget中?

如何使用bs 4从元素中提取文本

多处理代码在while循环中不工作

如何从具有多个嵌入选项卡的网页中Web抓取td类元素

try 与gemini-pro进行多轮聊天时出错

在Wayland上使用setCellWidget时,try 编辑QTable Widget中的单元格时,PyQt 6崩溃

如何在Python数据框架中加速序列的符号化

如何将一个动态分配的C数组转换为Numpy数组,并在C扩展模块中返回给Python

我对我应该做什么以及我如何做感到困惑'

利用Selenium和Beautiful Soup实现Web抓取JavaScript表

无法在Docker内部运行Python的Matlab SDK模块,但本地没有问题

使用groupby方法移除公共子字符串

什么是合并两个embrame的最佳方法,其中一个有日期范围,另一个有日期没有任何共享列?

如何在Python中使用另一个数据框更改列值(列表)

如何检测鼠标/键盘的空闲时间,而不是其他输入设备?

Python避免mypy在相互引用中从另一个类重定义类时失败