diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py
index 78a54abf13..9444c713d0 100755
--- a/examples/model-conversion/scripts/causal/run-org-model.py
+++ b/examples/model-conversion/scripts/causal/run-org-model.py
@@ -193,7 +193,7 @@
 print(f"Input text: {repr(prompt)}")
 print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
 with torch.no_grad():
-    outputs = model(input_ids)
+    outputs = model(input_ids.to(model.device))
     logits = outputs.logits
 
 # Extract logits for the last token (next token prediction)