
Commit 730654a

Attempt to turn max_train_len into a parameter - #1465
Preserve existing models with the new config attribute
AngledLuffa committed Mar 1, 2025
1 parent de77f17 commit 730654a
Showing 4 changed files with 17 additions and 1 deletion.
stanza/models/coref/config.py (1 addition, 0 deletions)

@@ -64,3 +64,4 @@ class Config:  # pylint: disable=too-many-instance-attributes, too-few-public-methods
     log_norms: bool
     singletons: bool
 
+    max_train_len: int
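Why load_model below needs a backfill: Config appears to be a dataclass with only required fields (the bare annotations here and the Config(**config) call in model.py suggest as much), so a checkpoint saved before this commit has no max_train_len key and reconstructing its config would fail. A minimal sketch under that assumption, with a stand-in Config rather than the full stanza class:

from dataclasses import dataclass

@dataclass
class Config:
    log_norms: bool
    singletons: bool
    max_train_len: int  # the field added by this commit

# A config dict as an old checkpoint would have saved it: no max_train_len.
old_config = {"log_norms": False, "singletons": True}

# Config(**old_config) would raise TypeError: missing required argument
# 'max_train_len', so the loader backfills the old default first
# (this mirrors the load_model change below).
if "max_train_len" not in old_config:
    old_config["max_train_len"] = 5000

config = Config(**old_config)
print(config.max_train_len)  # 5000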
stanza/models/coref/coref_config.toml (3 additions, 0 deletions)

@@ -116,6 +116,9 @@ bce_loss_weight = 0.5
 # The directory that will contain conll prediction files
 conll_log_dir = "data/conll_logs"
 
+# Skip any documents longer than this length
+max_train_len = 5000
+
 # =============================================================================
 # Extra keyword arguments to be passed to bert tokenizers of specified models
 [DEFAULT.tokenizer_kwargs]
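A quick way to check that the new key parses as expected, sketched with the stdlib tomllib (Python 3.11+). Stanza itself may load this file differently, and the [DEFAULT] section name is an inference from the [DEFAULT.tokenizer_kwargs] table that follows the new key:

import tomllib

# Parse the coref config and pull out the new training-length cap.
with open("stanza/models/coref/coref_config.toml", "rb") as fin:
    raw = tomllib.load(fin)

print(raw["DEFAULT"]["max_train_len"])  # 5000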
stanza/models/coref/model.py (5 additions, 1 deletion)

@@ -315,6 +315,10 @@ def load_model(path: str,
         config = state_dicts.pop('config', None)
         if config is None:
             raise ValueError("Cannot load this format model without config in the dicts")
+        if 'max_train_len' not in config:
+            # TODO: this is to keep old models working.
+            # Can get rid of it if those models are rebuilt
+            config['max_train_len'] = 5000
         if isinstance(config, dict):
             config = Config(**config)
         if config_update:

@@ -456,7 +460,7 @@ def train(self, log=False):
             doc = docs[doc_id]
 
             # skip very long documents during training time
-            if len(doc["subwords"]) > 5000:
+            if len(doc["subwords"]) > self.config.max_train_len:
                 continue
 
             for optim in self.optimizers.values():
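Replacing the literal 5000 with self.config.max_train_len is what lets the --no_max_train_len flag (added in wl_coref.py below) disable the filter entirely: len(...) > float("inf") is never true. A standalone sketch of the comparison, with doc as a stand-in for one training document:

# doc stands in for one coref training document.
doc = {"subwords": ["tok"] * 6000}

max_train_len = 5000                          # the default cap
print(len(doc["subwords"]) > max_train_len)   # True: document is skipped

max_train_len = float("inf")                  # what --no_max_train_len stores
print(len(doc["subwords"]) > max_train_len)   # False: document is kept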
stanza/models/wl_coref.py (8 additions, 0 deletions)

@@ -132,6 +132,11 @@ def deterministic() -> None:
     argparser.add_argument("--seed", type=int, default=2020,
                            help="Random seed to set")
 
+    argparser.add_argument("--max_train_len", type=int, default=5000,
+                           help="Skip any documents longer than this maximum length")
+    argparser.add_argument("--no_max_train_len", action="store_const", const=float("inf"), dest="max_train_len",
+                           help="Do not skip any documents for being too long")
+
     argparser.add_argument("--train_data", default=None, help="File to use for train data")
     argparser.add_argument("--dev_data", default=None, help="File to use for dev data")
     argparser.add_argument("--test_data", default=None, help="File to use for test data")

@@ -184,6 +189,9 @@ def deterministic() -> None:
     if args.test_data:
         config.test_data = args.test_data
 
+    if args.max_train_len:
+        config.max_train_len = args.max_train_len
+
     # if wandb, generate wandb configuration
     if args.mode == "train":
        if args.wandb:
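The two flags share one argparse destination: --no_max_train_len is a store_const action writing float("inf") into dest="max_train_len", so whichever flag appears last on the command line wins. Since both the 5000 default and the inf sentinel are truthy, the if args.max_train_len: guard in the second hunk then copies the value into the config. A minimal sketch of the flag pair in isolation:

import argparse

argparser = argparse.ArgumentParser()
argparser.add_argument("--max_train_len", type=int, default=5000,
                       help="Skip any documents longer than this maximum length")
argparser.add_argument("--no_max_train_len", action="store_const",
                       const=float("inf"), dest="max_train_len",
                       help="Do not skip any documents for being too long")

print(argparser.parse_args([]).max_train_len)                            # 5000
print(argparser.parse_args(["--max_train_len", "8000"]).max_train_len)   # 8000
print(argparser.parse_args(["--no_max_train_len"]).max_train_len)        # inf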
