add TTM model #20
base: main
Conversation
frndtls commented Mar 4, 2025
- add TTM model (simplify some settings)
- delete a print statement in dataloader.py
layers/Mlp.py
Outdated
from layers.SelfAttention_Family import TTMGatedAttention


class TTMmlp(nn.Module):
As a class name, MLP should be uppercase, and the file name should use uppercase MLP as well.
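A minimal sketch of the suggested rename (the class body stays as in the PR):

# layers/MLP.py  (renamed from layers/Mlp.py)
import torch.nn as nn

class TTMMLP(nn.Module):  # renamed from TTMmlp
    pass  # body unchanged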
layers/Mlp.py
Outdated
class TTMmlp(nn.Module):
    def __init__(self, in_features, out_features, e_factor, dropout):
        super().__init__()
        num_hidden = in_features * e_factor
Just name it factor instead of e_factor.
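i.e., a sketch of the signature with the rename applied (rest of the constructor as in the diff above):

import torch.nn as nn

class TTMmlp(nn.Module):
    def __init__(self, in_features, out_features, factor, dropout):  # e_factor -> factor
        super().__init__()
        num_hidden = in_features * factor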
layers/Mlp.py
Outdated
from layers.SelfAttention_Family import TTMGatedAttention


class TTMmlp(nn.Module):
Consider adding some comments, e.g. what the model-specific parameters (such as e_factor) mean in the paper.
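For example, a docstring along these lines (the stated meaning of e_factor is my reading and should be checked against the paper):

import torch.nn as nn

class TTMmlp(nn.Module):
    """MLP block used in the TTM mixer layers.

    Args:
        in_features: size of the input feature dimension.
        out_features: size of the output feature dimension.
        e_factor: expansion factor; the hidden layer has
            in_features * e_factor units.
        dropout: dropout probability applied inside the block.
    """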
layers/Mlp.py
Outdated
        residual = x  # [B M N P]
        x = self.norm(x)

        assert self.mode in ["patch", "feature", "channel"]
Please provide a comment explaining the meaning of mode.
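For example (the meanings below are inferred from the permute calls in this diff; please verify against the paper):

        # mode selects which axis of the [B, M, N, P] tensor this MLP mixes:
        #   "patch"   -> mix across the N patch positions
        #   "feature" -> mix across the P features within each patch
        #   "channel" -> mix across the M channels (input is permuted first)
        assert self.mode in ["patch", "feature", "channel"]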
layers/Mlp.py
Outdated
        elif self.mode == "channel":
            x = x.permute(0, 3, 2, 1)  # [B M N P]
        else:
            pass
You could raise a NotImplementedError here.
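i.e.:

        elif self.mode == "channel":
            x = x.permute(0, 3, 2, 1)  # [B M N P]
        else:
            raise NotImplementedError(f"unsupported mode: {self.mode!r}")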
@@ -81,6 +81,17 @@
parser.add_argument('--kernel_size', type=int, default=25)
parser.add_argument('--stride', type=int, default=8)

# TTM
For these TTM arguments, consider giving concrete --help descriptions based on the paper.
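For example (the help texts below are my guesses at the intended meanings, not taken from the PR):

# TTM
parser.add_argument("--e_factor", type=int, default=2,
                    help="expansion factor of the mixer MLP hidden layer")
parser.add_argument("--mode", type=str, default="mix_channel",
                    help="channel-mixing mode of the TTM backbone")
parser.add_argument("--AP_levels", type=int, default=0,
                    help="number of adaptive patching levels")
parser.add_argument("--head_dropout", type=float, default=0.2,
                    help="dropout rate of the prediction head")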
        return y_hat


class TTMAPBlock(nn.Module):
Consider adding some comments, especially for the TTM-specific parameters.
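For instance, a brief docstring (AP presumably stands for adaptive patching; the description is an assumption to be checked against the paper):

import torch.nn as nn

class TTMAPBlock(nn.Module):
    """TTM adaptive-patching (AP) block.

    At level l the patch dimension is reshaped into 2**l finer
    sub-patches before mixing (assumed behaviour).
    """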


class Model(nn.Module):
    """
Since this is being open-sourced, please also credit the source of the pretrained model on Hugging Face.
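For example, in the model docstring (the Hugging Face path below is a placeholder to be filled in with the actual model card):

import torch.nn as nn

class Model(nn.Module):
    """TTM (Tiny Time Mixers).

    Paper: https://arxiv.org/abs/2401.03955
    Pretrained weights: https://huggingface.co/<org>/<ttm-checkpoint>  # placeholder
    """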
layers/SelfAttention_Family.py
Outdated
@@ -140,3 +140,14 @@ def forward(self, queries, keys, values, attn_mask, n_vars=None, n_tokens=None,

        return self.out_projection(out), attn


class TTMGatedAttention(nn.Module):
This could be renamed to TTMGatedLayer; its forward pass has little to do with attention.
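A sketch of the rename; the body shows how such a gated layer typically looks (a learned softmax gate multiplied element-wise onto the input), which is my assumption rather than the PR's exact code:

import torch.nn as nn

class TTMGatedLayer(nn.Module):  # renamed from TTMGatedAttention
    def __init__(self, in_size, out_size):
        super().__init__()
        self.gate = nn.Linear(in_size, out_size)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # element-wise gating; no query/key/value attention involved
        return x * self.softmax(self.gate(x))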
    ):
        super().__init__()
        self.adapt_patch_level = adapt_patch_level
        adaptive_patch_factor = 2**adapt_patch_level
What does this power of two mean? Could it be cleaned up a bit?
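If the factor implements TTM's adaptive patching, a named comment would make the intent explicit (the interpretation below is an assumption):

        # at adaptive-patching level l each patch is split into 2**l sub-patches,
        # so the patch length shrinks by the same power of two
        adaptive_patch_factor = 2 ** adapt_patch_level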
run.py
Outdated
parser.add_argument("--e_factor", type=int, default=2) | ||
parser.add_argument("--mode", type=str, default="mix_channel") | ||
parser.add_argument("--AP_levels", type=int, default=0) | ||
parser.add_argument("--head_dropout", type=float, default=0.2) |
If dropout has little impact, it can be fixed as a default argument of the model's constructor instead of a CLI flag; the same goes for the arguments above.
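i.e., a sketch of moving the flags into constructor defaults (names follow the flags above; the constructor shape is an assumption):

import torch.nn as nn

class Model(nn.Module):
    def __init__(self, configs, e_factor=2, mode="mix_channel",
                 AP_levels=0, head_dropout=0.2):
        super().__init__()
        # fixed defaults instead of CLI flags; override only when experimenting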