我正在使用两个不同的数据集进行文本分类,目的是使用一个数据集进行训练,另一个用于测试.请注意,我不希望为了防止泄漏而合并数据集(我认为这就是它的名称).与训练数据集(16k行)相比,测试数据集(~1000行)要小得多

我使用的是CountVectorizer,由于两个数据集具有不同的词汇表,因此会产生不同的列数--这会在预测步骤中导致错误.

ValueError: X has 55229 features, but DecisionTreeClassifier is expecting 387964 
features as input.

我在GPT和谷歌上搜索已经有一段时间了,我得到的指导褒贬不一.

  1. 将填充为零的列添加到较小的x_test
  2. 使用SCRICKIT-学习管道

下面是代码片段:

# read dfs
df_1 = pd.read_csv("data1.csv",header=0) # for training, has text, and class columns
df_2 = pd.read_csv("data2.csv",header=0) # for testing,  has text, and class columns

# vectorise
CV1 = CountVectorizer(ngram_range=(1,3), stop_words="english").fit(df_1['text']) 
x_train = CV1.transform(df_1['text'])
y_train = df_1['class']

CV2 = CountVectorizer(ngram_range=(1,3), stop_words="english").fit(df_2['text']) 
x_test = CV2.transform(df_2['text'])
y_test = df_test['class']

## shapes of objects
## x_test (1589, 55229), y_test(1589,)
## x_train (16716, 387964), y_train(16716,)

# build classifier and predict
classifier = DecisionTreeClassifier(random_state=1234)
model = classifier.fit(x_train,y_train)
y_pred = model.predict(x_test)

# error ValueError: X has 55229 features, but DecisionTreeClassifier is expecting 387964 features as input.

推荐答案

every的预处理步骤一样,不适合测试集.您应该有一个CountVectorizer的实例,用于fit_transform个训练集和transform个测试集.

在您的 case 中:

CV = CountVectorizer(ngram_range=(1,3), stop_words="english")
x_train = CV.fit_transform(df_1['text'])
y_train = df_1['class']

x_test = CV.transform(df_2['text'])
y_test = df_test['class']

Python相关问答推荐

根据网格和相机参数渲染深度

计算所有前面行(当前行)中列的值

Pydantic 2.7.0模型接受字符串日期时间或无

如何在solve()之后获得症状上的等式的值

如何获得每个组的时间戳差异?

如何调整QscrollArea以正确显示内部正在变化的Qgridlayout?

幂集,其中每个元素可以是正或负""""

在pandas/python中计数嵌套类别

ConversationalRetrivalChain引发键错误

如何获取Python synsets列表的第一个内容?

如何在Python Pandas中填充外部连接后的列中填充DDL值

在极点中读取、扫描和接收有什么不同?

什么是一种快速而优雅的方式来转换一个包含一串重复的列,而不对同一个值多次运行转换,

提取数组每行的非零元素

如何在验证文本列表时使正则表达式无序?

Autocad使用pyautocad/comtypes将对象从一个图形复制到另一个图形

如何在Airflow执行日期中保留日期并将时间转换为00:00

在pandas中,如何在由两列加上一个值列组成的枢轴期间或之后可靠地设置多级列的索引顺序,

启动线程时,Python键盘模块冻结/不工作

在Pandas 中以十六进制显示/打印列?