From dc08470c5aaeaadcf9e052c176cc0224ec342291 Mon Sep 17 00:00:00 2001 From: cyber-pioneer Date: Mon, 10 Apr 2023 15:50:08 +0000 Subject: [PATCH] add arg num_workers for ernie3.0 --- .../benchmark/modules/ernie3_for_sequence_classification.py | 2 +- tests/test_tipc/benchmark/options.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_tipc/benchmark/modules/ernie3_for_sequence_classification.py b/tests/test_tipc/benchmark/modules/ernie3_for_sequence_classification.py index 034b4ee861fb..90f8574380a4 100644 --- a/tests/test_tipc/benchmark/modules/ernie3_for_sequence_classification.py +++ b/tests/test_tipc/benchmark/modules/ernie3_for_sequence_classification.py @@ -145,7 +145,7 @@ def create_data_loader(self, args, **kwargs): train_loader = DataLoader( dataset=train_ds, batch_sampler=train_batch_sampler, - num_workers=4, # when paddlepaddle<=2.4.1, if we use dynamicTostatic mode, we need set num_workeks > 0 + num_workers=args.num_workers, # when paddlepaddle<=2.4.1, if we use dynamicTostatic mode, we need set num_workeks > 0 ) self.num_batch = len(train_loader) diff --git a/tests/test_tipc/benchmark/options.py b/tests/test_tipc/benchmark/options.py index 54ff08294330..690364debf74 100644 --- a/tests/test_tipc/benchmark/options.py +++ b/tests/test_tipc/benchmark/options.py @@ -136,6 +136,12 @@ def get_parser(): parser.add_argument("--epoch", type=int, default=10, help="Number of epochs. ") parser.add_argument("--generated_inputs", action="store_true", help="Use generated inputs. ") + parser.add_argument( + "--num_workers", + type=int, + default=4, + help="num_workers of dataloader. When paddlepaddle<=2.4.1, if we use dynamicTostatic mode, we need set num_workeks > 0 ", + ) # For benchmark. parser.add_argument(