diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py index 00f8928d2ead..377a3666fd45 100644 --- a/llm/run_pretrain.py +++ b/llm/run_pretrain.py @@ -265,8 +265,8 @@ def _collate_data(data, stack_fn=Stack()): tokens = tokens_[:, :-1] return { - "input_ids": tokens, - "labels": labels, + "input_ids": paddle.to_tensor(tokens), + "labels": paddle.to_tensor(labels), } if need_data: