我有以下code条:

from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids

sequence_ids = model.generate(input_ids)
sequences = tokenizer.batch_decode(sequence_ids)
sequences

目前,它生产的产品是:

['<pad><extra_id_0> park offers<extra_id_1> the<extra_id_2> park.</s>']

有没有办法防止生成器产生某些单词(例如stopwords = ["park", "offer"])?

推荐答案

在查看文档后发现,您可以在generate()中传递一个bad_words_ids参数

给定一个坏单词列表,您可以使用以下命令创建id列表

tokenizer(bad_words, add_special_tokens=False).input_ids
input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
bad_words = ["park", "offers"]
bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids 
#[[2447], [704]]

sequence_ids = model.generate(input_ids, bad_words_ids=bad_words_ids)
#tensor([[    0, 32099,  1061,    19,     3,     9,   710,  1482,   550,    45, 32098,     8, 32097,  1061,     5,     1]])

sequences = tokenizer.batch_decode(sequence_ids)
print(sequences) 
#['<pad><extra_id_0> Park is a short walk away from<extra_id_1> the<extra_id_2> Park.</s>']

Colab demo

Python相关问答推荐

Pandas—合并数据帧,在公共列上保留非空值,在另一列上保留平均值

所有列的滚动标准差,忽略NaN

当递归函数的返回值未绑定到变量时,非局部变量不更新:

计算分布的标准差

调用decorator返回原始函数的输出

与命令行相比,相同的Python代码在Companyter Notebook中运行速度慢20倍

如何在TensorFlow中分类多个类

Geopandas未返回正确的缓冲区(单位:米)

pandas:对多级列框架的列进行排序/重新排序

numpy.unique如何消除重复列?

ModuleNotFoundError:没有模块名为x时try 运行我的代码''

判断Python操作:如何从字面上得到所有decorator ?

我什么时候应该使用帆布和标签?

删除Dataframe中的第一个空白行并重新索引列

无法在盐流道中获得柱子

将相应的值从第2列合并到第1列(Pandas )

如何在不不断遇到ChromeDriver版本错误的情况下使用Selify?

如何将ManyToManyfield用于Self类

401使用有效的OAuth令牌向Google Apps脚本Web App发出POST请求时出现未经授权的错误(";

Python-迭代PANAS中的数据框并替换列表中不包含字符串的值