Skip to content

Commit

Permalink
Add write_flush in two loggers, fix argument passing in WandbLogger (#581)
Browse files Browse the repository at this point in the history
  • Loading branch information
Trinkle23897 authored Mar 30, 2022
1 parent 6ab9860 commit f13e415
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
7 changes: 6 additions & 1 deletion tianshou/utils/logger/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class TensorboardLogger(BaseLogger):
:param int update_interval: the log interval in log_update_data(). Default to 1000.
:param int save_interval: the save interval in save_data(). Default to 1 (save at
the end of each epoch).
:param bool write_flush: whether to flush tensorboard result after each
add_scalar operation. Default to True.
"""

def __init__(
Expand All @@ -26,16 +28,19 @@ def __init__(
test_interval: int = 1,
update_interval: int = 1000,
save_interval: int = 1,
write_flush: bool = True,
) -> None:
super().__init__(train_interval, test_interval, update_interval)
self.save_interval = save_interval
self.write_flush = write_flush
self.last_save_step = -1
self.writer = writer

def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
for k, v in data.items():
self.writer.add_scalar(k, v, global_step=step)
self.writer.flush() # issue #482
if self.write_flush: # issue 580
self.writer.flush() # issue #482

def save_data(
self,
Expand Down
11 changes: 10 additions & 1 deletion tianshou/utils/logger/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class WandbLogger(BaseLogger):
:param int test_interval: the log interval in log_test_data(). Default to 1.
:param int update_interval: the log interval in log_update_data().
Default to 1000.
:param int save_interval: the save interval in save_data(). Default to 1 (save at
the end of each epoch).
:param bool write_flush: whether to flush tensorboard result after each
add_scalar operation. Default to True.
:param str project: W&B project name. Default to "tianshou".
:param str name: W&B run name. Default to None. If None, random name is assigned.
:param str entity: W&B team/organization name. Default to None.
Expand All @@ -44,6 +48,7 @@ def __init__(
test_interval: int = 1,
update_interval: int = 1000,
save_interval: int = 1000,
write_flush: bool = True,
project: Optional[str] = None,
name: Optional[str] = None,
entity: Optional[str] = None,
Expand All @@ -53,6 +58,7 @@ def __init__(
super().__init__(train_interval, test_interval, update_interval)
self.last_save_step = -1
self.save_interval = save_interval
self.write_flush = write_flush
self.restored = False
if project is None:
project = os.getenv("WANDB_PROJECT", "tianshou")
Expand All @@ -72,7 +78,10 @@ def __init__(

def load(self, writer: SummaryWriter) -> None:
self.writer = writer
self.tensorboard_logger = TensorboardLogger(writer)
self.tensorboard_logger = TensorboardLogger(
writer, self.train_interval, self.test_interval, self.update_interval,
self.save_interval, self.write_flush
)

def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
if self.tensorboard_logger is None:
Expand Down

0 comments on commit f13e415

Please sign in to comment.