我正在用TensorFlow的TextLineDataset读一个大文本文件.我想对数据集进行标记化,创建一个滑动窗口,并将标记化文本分为两部分——输入和标签.如果文本文件包含以下文本:

Lorem ipsum dolor sit amet...

然后我想创建一个指定长度的序列,用0预先填充.我想迭代文本,并使用除最后一个之外的所有内容作为输入,最后一个作为标签.因此,我的目标是首先将文本标记为如下内容:

Lorem: 1,
ipsum: 2,
dolor: 3,
sit: 4,
amet: 5,
...

然后创建一个序列,比如说5个长度,这样训练一个模型:

X_train = [[0, 0, 0, 0, 1], [0, 0, 0, 1, 2], [0, 0, 1, 2, 3], ...]
y_train = [2, 3, 4, ...] # next word of the sequence in X_train

我使用TextVectorization来标记,但无法找到为大型数据集创建输入和标签的有效方法.

vectorize_layer = tf.keras.layers.TextVectorization(output_mode='int',
                                                    max_tokens=MAX_WORDS,
                                                    output_sequence_length=MAX_SEQUENCE_LENGTH)
vectorize_layer.adapt(train_data)
train_data = train_data.map(vectorize_layer)

在数据集上使用for循环会使设备在试图分配大量内存时耗尽内存.最好的方法是什么?

推荐答案

您可以使用tensorflow-text中的滑动窗口function;然而,TextVectorization层似乎只适用于后期填充:

import tensorflow as tf
import tensorflow_text as tft

with open('data.txt', 'w') as f:
  f.write('Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aliquam efficitur viverra lacus?\n')

train_data = tf.data.TextLineDataset(['/content/data.txt'])

vectorize_layer = tf.keras.layers.TextVectorization(output_mode='int', max_tokens=50, pad_to_max_tokens=True)
vectorize_layer.adapt(train_data)

window_size = 5

def sliding_window(x):
  encoded = vectorize_layer(x)
  x = tft.sliding_window(encoded, width=window_size, axis=0)
  y = tft.sliding_window(encoded, width=window_size + 1, axis=0)[:, -1]
  return x[:tf.shape(y)[0],:], y

train_data = train_data.map(sliding_window)


vocab = tf.constant(vectorize_layer.get_vocabulary())
keys = tf.cast(tf.range(vocab.shape[0]), tf.int64)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, vocab),
    default_value="")

train_data = tf.data.Dataset.zip((train_data.map(lambda x, y: x).flat_map(tf.data.Dataset.from_tensor_slices),
                                 train_data.map(lambda x, y: y).flat_map(tf.data.Dataset.from_tensor_slices)))

for x, y in train_data:
  print('x -->', x, 'y -->', y)
  print('x -->', table.lookup(x), 'y -->', table.lookup(y), '\n')

x --> tf.Tensor([ 4  6  9  3 11], shape=(5,), dtype=int64) y --> tf.Tensor(10, shape=(), dtype=int64)
x --> tf.Tensor([b'lorem' b'ipsum' b'dolor' b'sit' b'amet'], shape=(5,), dtype=string) y --> tf.Tensor(b'consectetur', shape=(), dtype=string) 

x --> tf.Tensor([ 6  9  3 11 10], shape=(5,), dtype=int64) y --> tf.Tensor(13, shape=(), dtype=int64)
x --> tf.Tensor([b'ipsum' b'dolor' b'sit' b'amet' b'consectetur'], shape=(5,), dtype=string) y --> tf.Tensor(b'adipiscing', shape=(), dtype=string) 

x --> tf.Tensor([ 9  3 11 10 13], shape=(5,), dtype=int64) y --> tf.Tensor(7, shape=(), dtype=int64)
x --> tf.Tensor([b'dolor' b'sit' b'amet' b'consectetur' b'adipiscing'], shape=(5,), dtype=string) y --> tf.Tensor(b'elit', shape=(), dtype=string) 

x --> tf.Tensor([ 3 11 10 13  7], shape=(5,), dtype=int64) y --> tf.Tensor(12, shape=(), dtype=int64)
x --> tf.Tensor([b'sit' b'amet' b'consectetur' b'adipiscing' b'elit'], shape=(5,), dtype=string) y --> tf.Tensor(b'aliquam', shape=(), dtype=string) 

x --> tf.Tensor([11 10 13  7 12], shape=(5,), dtype=int64) y --> tf.Tensor(8, shape=(), dtype=int64)
x --> tf.Tensor([b'amet' b'consectetur' b'adipiscing' b'elit' b'aliquam'], shape=(5,), dtype=string) y --> tf.Tensor(b'efficitur', shape=(), dtype=string) 

x --> tf.Tensor([10 13  7 12  8], shape=(5,), dtype=int64) y --> tf.Tensor(2, shape=(), dtype=int64)
x --> tf.Tensor([b'consectetur' b'adipiscing' b'elit' b'aliquam' b'efficitur'], shape=(5,), dtype=string) y --> tf.Tensor(b'viverra', shape=(), dtype=string) 

x --> tf.Tensor([13  7 12  8  2], shape=(5,), dtype=int64) y --> tf.Tensor(5, shape=(), dtype=int64)
x --> tf.Tensor([b'adipiscing' b'elit' b'aliquam' b'efficitur' b'viverra'], shape=(5,), dtype=string) y --> tf.Tensor(b'lacus', shape=(), dtype=string) 

注意,没有相应标签的序列将与第x[:tf.shape(y)[0],:]行一起丢弃.此外,查找表仅用于演示目的,不需要实现您想要的功能.如果你想应用预填充,你可以看tft.pad_along_dimension.

Python相关问答推荐

如何避免Chained when/then分配中的Mypy不兼容类型警告?

如何在虚拟Python环境中运行Python程序?

如何从在虚拟Python环境中运行的脚本中运行需要宿主Python环境的Shell脚本?

Python虚拟环境的轻量级使用

如何将多进程池声明为变量并将其导入到另一个Python文件

Python—从np.array中 Select 复杂的列子集

海上重叠直方图

Pandas—在数据透视表中占总数的百分比

将JSON对象转换为Dataframe

不允许访问非IPM文件夹

如何在turtle中不使用write()来绘制填充字母(例如OEG)

无法连接到Keycloat服务器

如何指定列数据类型

判断solve_ivp中的事件

为什么if2/if3会提供两种不同的输出?

Pandas:计算中间时间条目的总时间增量

如何使用OpenGL使球体遵循Python中的八样路径?

如何检测鼠标/键盘的空闲时间,而不是其他输入设备?

什么是一种快速而优雅的方式来转换一个包含一串重复的列,而不对同一个值多次运行转换,

Pandas:使列中的列表大小与另一列中的列表大小相同