I ran into a problem when switching my model from some local dummy data to a TF dataset.

Apologies that the model code is quite long; I've shortened it as much as I can.

The following works fine:

import tensorflow as tf
import tensorflow_recommenders as tfrs
from transformers import AutoTokenizer, TFAutoModel


MODEL_PATH = 'sentence-transformers/all-MiniLM-L6-v2'
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = TFAutoModel.from_pretrained(MODEL_PATH, from_pt=True)


class SBert(tf.keras.layers.Layer):
    def __init__(self, tokenizer, model):
        super(SBert, self).__init__()
        
        self.tokenizer = tokenizer
        self.model = model
        
    def tf_encode(self, inputs):
        def encode(inputs):
            inputs = [x[0].decode("utf-8") for x in inputs.numpy()]
            outputs = self.tokenizer(inputs, padding=True, truncation=True, return_tensors='tf')
            return outputs['input_ids'], outputs['token_type_ids'], outputs['attention_mask']
        return tf.py_function(func=encode, inp=[inputs], Tout=[tf.int32, tf.int32, tf.int32])
    
    def process(self, i, t, a):
        def __call(i, t, a):
            model_output = self.model(
                {'input_ids': i.numpy(), 'token_type_ids': t.numpy(), 'attention_mask': a.numpy()}
            )
            return model_output[0]
        return tf.py_function(func=__call, inp=[i, t, a], Tout=[tf.float32])

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = tf.squeeze(tf.stack(model_output), axis=0)
        input_mask_expanded = tf.cast(
            tf.broadcast_to(tf.expand_dims(attention_mask, -1), tf.shape(token_embeddings)),
            tf.float32
        )
        a = tf.math.reduce_sum(token_embeddings * input_mask_expanded, axis=1)
        b = tf.clip_by_value(tf.math.reduce_sum(input_mask_expanded, axis=1), 1e-9, tf.float32.max)
        embeddings = a / b
        embeddings, _ = tf.linalg.normalize(embeddings, 2, axis=1)
        return embeddings
    
    def call(self, inputs):
        input_ids, token_type_ids, attention_mask = self.tf_encode(inputs)
        model_output = self.process(input_ids, token_type_ids, attention_mask)
        embeddings = self.mean_pooling(model_output, attention_mask)
        return embeddings


sbert = SBert(tokenizer, model)
inputs = tf.keras.layers.Input(shape=(1,), dtype=tf.string)
outputs = sbert(inputs)
model = tf.keras.Model(inputs, outputs)
model(tf.constant(['some text', 'more text']))

Calling the model outputs a tensor, yippee :)

Now I want to use this layer in a larger two-tower model:

class Encoder(tf.keras.Model):
    def __init__(self):
        super().__init__()
        
        self.text_embedding = self._build_text_embedding()
    
    def _build_text_embedding(self):
        sbert = SBert(tokenizer, model)
        inputs = tf.keras.layers.Input(shape=(1,), dtype=tf.string)
        outputs = sbert(inputs)
        return tf.keras.Model(inputs, outputs)
    
    def call(self, inputs):
        return self.text_embedding(inputs)

    
class RecModel(tfrs.models.Model):
    def __init__(self):
        super().__init__()
        
        self.query_model = tf.keras.Sequential([
            Encoder(),
            tf.keras.layers.Dense(32)
        ])
        
        self.candidate_model = tf.keras.Sequential([
            Encoder(),
            tf.keras.layers.Dense(32)
        ])
    
        self.retrieval_task = tfrs.tasks.Retrieval(
            metrics=tfrs.metrics.FactorizedTopK(
                candidates=tf.data.Dataset.from_tensor_slices(
                    data['text']
                ).batch(1).map(self.candidate_model),
            ),
            batch_metrics=[
                tf.keras.metrics.TopKCategoricalAccuracy(k=5)
            ]
        )

    def call(self, features):
        query_embeddings = self.query_model(features['query'])
        candidate_embeddings = self.candidate_model(features['text'])
        return (
            query_embeddings,
            candidate_embeddings,
        )   

    def compute_loss(self, features, training=False):
        query_embeddings, candidate_embeddings = self(features)
        retrieval_loss = self.retrieval_task(query_embeddings, candidate_embeddings)
        return retrieval_loss

Create a small dummy dataset:

data = {
    'query': ['blue', 'cat', 'football'],
    'text': ['a nice colour', 'a type of animal', 'a sport']
}

ds = tf.data.Dataset.from_tensor_slices(data).batch(1)
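
For reference, the batched dataset yields 1-D string tensors of shape (None,) per element, which is the shape reported later in the error. A quick check (added here for illustration):

print(ds.element_spec)
# {'query': TensorSpec(shape=(None,), dtype=tf.string, name=None),
#  'text': TensorSpec(shape=(None,), dtype=tf.string, name=None)}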

And try to compile:

model = RecModel()
model.compile(optimizer=tf.keras.optimizers.Adagrad())

We get the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-11-df4cc46e0307> in <module>
----> 1 model = RecModel()
      2 model.compile(optimizer=tf.keras.optimizers.Adagrad())

<ipython-input-8-a774041744b9> in __init__(self)
     33                 candidates=tf.data.Dataset.from_tensor_slices(
     34                     data['text']
---> 35                 ).batch(1).map(self.candidate_model),
     36             ),
     37             batch_metrics=[

~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in map(self, map_func, num_parallel_calls, deterministic, name)
   2014         warnings.warn("The `deterministic` argument has no effect unless the "
   2015                       "`num_parallel_calls` argument is specified.")
-> 2016       return MapDataset(self, map_func, preserve_cardinality=True, name=name)
   2017     else:
   2018       return ParallelMapDataset(

~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, input_dataset, map_func, use_inter_op_parallelism, preserve_cardinality, use_legacy_function, name)
   5193         self._transformation_name(),
   5194         dataset=input_dataset,
-> 5195         use_legacy_function=use_legacy_function)
   5196     self._metadata = dataset_metadata_pb2.Metadata()
   5197     if name:

~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/data/ops/structured_function.py in __init__(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)
    269         fn_factory = trace_tf_function(defun_kwargs)
    270 
--> 271     self._function = fn_factory()
    272     # There is no graph to add in eager mode.
    273     add_to_graph &= not context.executing_eagerly()

~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/eager/function.py in get_concrete_function(self, *args, **kwargs)
   3069     """
   3070     graph_function = self._get_concrete_function_garbage_collected(
-> 3071         *args, **kwargs)
   3072     graph_function._garbage_collector.release()  # pylint: disable=protected-access
   3073     return graph_function

~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
   3034       args, kwargs = None, None
   3035     with self._lock:
-> 3036       graph_function, _ = self._maybe_define_function(args, kwargs)
   3037       seen_names = set()
   3038       captured = object_identity.ObjectIdentitySet(

~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3290 
   3291           self._function_cache.add_call_context(cache_key.call_context)
-> 3292           graph_function = self._create_graph_function(args, kwargs)
   3293           self._function_cache.add(cache_key, cache_key_deletion_observer,
   3294                                    graph_function)

~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3138             arg_names=arg_names,
   3139             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3140             capture_by_value=self._capture_by_value),
   3141         self._function_attributes,
   3142         function_spec=self.function_spec,

~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1159         _, original_func = tf_decorator.unwrap(python_func)
   1160 
-> 1161       func_outputs = python_func(*func_args, **func_kwargs)
   1162 
   1163       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/data/ops/structured_function.py in wrapped_fn(*args)
    246           attributes=defun_kwargs)
    247       def wrapped_fn(*args):  # pylint: disable=missing-docstring
--> 248         ret = wrapper_helper(*args)
    249         ret = structure.to_tensor_list(self._output_structure, ret)
    250         return [ops.convert_to_tensor(t) for t in ret]

~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/data/ops/structured_function.py in wrapper_helper(*args)
    175       if not _should_unpack(nested_args):
    176         nested_args = (nested_args,)
--> 177       ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
    178       if _should_pack(ret):
    179         ret = tuple(ret)

~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    687       try:
    688         with conversion_ctx:
--> 689           return converted_call(f, args, kwargs, options=options)
    690       except Exception as e:  # pylint:disable=broad-except
    691         if hasattr(e, 'ag_error_metadata'):

~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
    375 
    376   if not options.user_requested and conversion.is_allowlisted(f):
--> 377     return _call_unconverted(f, args, kwargs, options)
    378 
    379   # internal_convert_user_code is for example turned off when issuing a dynamic

~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py in _call_unconverted(f, args, kwargs, options, update_cache)
    456 
    457   if kwargs is not None:
--> 458     return f(*args, **kwargs)
    459   return f(*args)
    460 

~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/keras/layers/core/dense.py in build(self, input_shape)
    137     last_dim = tf.compat.dimension_value(input_shape[-1])
    138     if last_dim is None:
--> 139       raise ValueError('The last dimension of the inputs to a Dense layer '
    140                        'should be defined. Found None. '
    141                        f'Full input shape received: {input_shape}')

ValueError: Exception encountered when calling layer "sequential_5" (type Sequential).

The last dimension of the inputs to a Dense layer should be defined. Found None. Full input shape received: <unknown>

Call arguments received:
  • inputs=tf.Tensor(shape=(None,), dtype=string)
  • training=None
  • mask=None

I'm not quite sure where the shape should be set, since everything works fine with plain tensors instead of a TF dataset.

Accepted answer

You have to explicitly set the shapes of the tensors that come out of the tf.py_function calls. Using None allows variable input lengths, but the BERT output dimension (384) has to be defined:
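
To see why the shape information goes missing in the first place, here is a small illustration (an addition, not part of the original answer): inside a traced graph, the output of a tf.py_function has an unknown static shape, and that unknown shape is exactly what reaches the Dense layer when the model is traced inside Dataset.map.

import tensorflow as tf

@tf.function
def traced(x):
    # py_function defers to eager Python, so the graph cannot infer a shape.
    y = tf.py_function(lambda v: v + 1.0, inp=[x], Tout=tf.float32)
    print('static shape inside the graph:', y.shape)  # <unknown>
    return y

traced(tf.constant([1.0, 2.0]))

With that in mind, the fix is to restore the known shapes right after each py_function call: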

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel

MODEL_PATH = 'sentence-transformers/all-MiniLM-L6-v2'
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = TFAutoModel.from_pretrained(MODEL_PATH, from_pt=True)

class SBert(tf.keras.layers.Layer):
    def __init__(self, tokenizer, model):
        super(SBert, self).__init__()
        
        self.tokenizer = tokenizer
        self.model = model
        
    def tf_encode(self, inputs):
        def encode(inputs):
            inputs = [x[0].decode("utf-8") for x in inputs.numpy()]
            outputs = self.tokenizer(inputs, padding=True, truncation=True, return_tensors='tf')
            return outputs['input_ids'], outputs['token_type_ids'], outputs['attention_mask']
        return tf.py_function(func=encode, inp=[inputs], Tout=[tf.int32, tf.int32, tf.int32])
    
    def process(self, i, t, a):
        def __call(i, t, a):
            model_output = self.model(
                {'input_ids': i.numpy(), 'token_type_ids': t.numpy(), 'attention_mask': a.numpy()}
            )
            return model_output[0]
        return tf.py_function(func=__call, inp=[i, t, a], Tout=[tf.float32])

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = tf.squeeze(tf.stack(model_output), axis=0)
        input_mask_expanded = tf.cast(
            tf.broadcast_to(tf.expand_dims(attention_mask, -1), tf.shape(token_embeddings)),
            tf.float32
        )
        a = tf.math.reduce_sum(token_embeddings * input_mask_expanded, axis=1)
        b = tf.clip_by_value(tf.math.reduce_sum(input_mask_expanded, axis=1), 1e-9, tf.float32.max)
        embeddings = a / b
        embeddings, _ = tf.linalg.normalize(embeddings, 2, axis=1)
        return embeddings

    def call(self, inputs):
        input_ids, token_type_ids, attention_mask = self.tf_encode(inputs)
        # py_function outputs lose their static shape; restore (batch, seq_len).
        input_ids.set_shape(tf.TensorShape((None, None)))
        token_type_ids.set_shape(tf.TensorShape((None, None)))
        attention_mask.set_shape(tf.TensorShape((None, None)))

        model_output = self.process(input_ids, token_type_ids, attention_mask)
        # The MiniLM hidden size is 384, so pin (batch, seq_len, 384) for the layers downstream.
        model_output[0].set_shape(tf.TensorShape((None, None, 384)))
        embeddings = self.mean_pooling(model_output, attention_mask)
        return embeddings

    
sbert = SBert(tokenizer, model)
inputs = tf.keras.layers.Input((1,), dtype=tf.string)
outputs = sbert(inputs)
outputs = tf.keras.layers.Dense(32)(outputs)
model = tf.keras.Model(inputs, outputs)
print(model(tf.constant(['some text', 'more text'])))
print(model.summary())
tf.Tensor(
[[-0.06719425 -0.02954631 -0.05811356 -0.1456391  -0.13001677  0.00145465
   0.0401044   0.05949172 -0.02589339  0.07255618 -0.00958113  0.01159782
   0.02508018  0.03075579 -0.01910635 -0.03231853  0.00875124  0.01143366
  -0.04365401 -0.02090197  0.07030752 -0.02872834  0.10535908  0.05691438
  -0.017165   -0.02044982  0.02580127 -0.04564123 -0.0631128  -0.00303708
   0.00133517  0.01613527]
 [-0.11922387  0.02304137 -0.02670465 -0.13117084 -0.11492493  0.03961402
   0.08129141 -0.05999354  0.0039564   0.02892766  0.00493046  0.00440936
  -0.07966737  0.11354238  0.03141225  0.00048972  0.04658606 -0.03658888
  -0.05292419 -0.04639702  0.08445395  0.00522146  0.04359548  0.0290177
  -0.02171512 -0.03399373 -0.00418095 -0.04019783 -0.04733383 -0.03972956
   0.01890458 -0.03927581]], shape=(2, 32), dtype=float32)
Model: "model_12"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_18 (InputLayer)       [(None, 1)]               0         
                                                                 
 s_bert_17 (SBert)           (None, 384)               22713216  
                                                                 
 dense_78 (Dense)            (None, 32)                12320     
                                                                 
=================================================================
Total params: 22,725,536
Trainable params: 22,725,536
Non-trainable params: 0
_________________________________________________________________
None
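
With the shapes restored, the Dataset.map call that originally failed should also build, because the Dense layer now sees a defined last dimension. Below is a minimal, untested sketch (not part of the original answer) that reuses the model built above (the SBert + Dense(32) Keras model); the strings are wrapped in single-element lists so each batched element has the (batch, 1) shape that the tokenizer step indexes into:

texts = [['a nice colour'], ['a type of animal'], ['a sport']]

candidate_ds = (
    tf.data.Dataset.from_tensor_slices(texts)
    .batch(1)       # elements of shape (1, 1), dtype=string
    .map(model)     # the SBert + Dense(32) Keras model defined above
)

for emb in candidate_ds.take(1):
    print(emb.shape)  # expected: (1, 32)

The same idea carries over to the two-tower RecModel: as long as SBert.call pins down the (None, None, 384) output shape, the downstream Dense layers can be built while mapping over the dataset.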
