add TTM model #20

Open · wants to merge 8 commits into main

Conversation


@frndtls frndtls commented Mar 4, 2025

  1. add TTM model (simplify some settings)
  2. delete a print statement in dataloader.py

layers/Mlp.py Outdated
from layers.SelfAttention_Family import TTMGatedAttention


class TTMmlp(nn.Module):
Collaborator

As a class name, MLP should be in uppercase, and the file name should use uppercase MLP as well.

layers/Mlp.py Outdated
class TTMmlp(nn.Module):
    def __init__(self, in_features, out_features, e_factor, dropout):
        super().__init__()
        num_hidden = in_features * e_factor
Collaborator

Just use factor instead of e_factor.

layers/Mlp.py Outdated
from layers.SelfAttention_Family import TTMGatedAttention


class TTMmlp(nn.Module):
Collaborator

Consider adding some comments, e.g. what the model-specific parameters (such as e_factor) mean in the paper.
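
A minimal sketch folding the three comments above together (uppercase MLP naming, factor instead of e_factor, and a docstring for the expansion factor); the layer bodies are an assumption, since the full implementation is not shown in this thread:

import torch.nn as nn
import torch.nn.functional as F


# layers/MLP.py (renamed from layers/Mlp.py)
class TTMMLP(nn.Module):
    """TTM feed-forward block.

    Args:
        in_features: input feature dimension.
        out_features: output feature dimension.
        factor: hidden-size expansion factor from the TTM paper;
            the hidden width is in_features * factor.
        dropout: dropout probability applied after each linear layer.
    """

    def __init__(self, in_features, out_features, factor, dropout):
        super().__init__()
        num_hidden = in_features * factor
        self.fc1 = nn.Linear(in_features, num_hidden)
        self.fc2 = nn.Linear(num_hidden, out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout(F.gelu(self.fc1(x)))  # expand to in_features * factor
        return self.dropout(self.fc2(x))       # project back to out_features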

layers/Mlp.py Outdated
residual = x # [B M N P]
x = self.norm(x)

assert self.mode in ["patch", "feature", "channel"]
Collaborator

Please provide some comments on what mode means.
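
For example, the comment could spell out which axis each mode mixes (the [B, M, N, P] layout is taken from the residual line above; the exact wording is only a suggestion):

# mode selects which axis of the input [B, M, N, P]
# (batch, channels/variates, patches, per-patch features) the MLP mixes:
#   "patch"   - mix information across the N patch positions
#   "feature" - mix within each patch's P hidden features
#   "channel" - mix across the M channels/variates
assert self.mode in ["patch", "feature", "channel"]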

layers/Mlp.py Outdated
elif self.mode == "channel":
    x = x.permute(0, 3, 2, 1)  # [B P N M]
else:
    pass
Collaborator

This could raise NotImplementedError here.
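
Assuming the branches not shown in the diff handle "patch" and "feature", the fall-through then becomes a defensive error, roughly:

if self.mode == "patch":
    ...                        # earlier branches elided in the diff
elif self.mode == "channel":
    x = x.permute(0, 3, 2, 1)  # [B P N M]
else:
    raise NotImplementedError(f"Unsupported mode: {self.mode}")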

@@ -81,6 +81,17 @@
parser.add_argument('--kernel_size', type=int, default=25)
parser.add_argument('--stride', type=int, default=8)

# TTM
Collaborator

These TTM arguments could be given concrete --help descriptions taken from the paper.
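
For instance (using the factor rename suggested earlier; the help strings are paraphrases and should be checked against the paper):

# TTM
parser.add_argument("--factor", type=int, default=2,
                    help="MLP expansion factor; hidden width = in_features * factor")
parser.add_argument("--mode", type=str, default="mix_channel",
                    help="channel handling of the mixer backbone, e.g. mix_channel")
parser.add_argument("--AP_levels", type=int, default=0,
                    help="number of adaptive-patching levels; level l splits each patch into 2**l sub-patches")
parser.add_argument("--head_dropout", type=float, default=0.2,
                    help="dropout rate of the forecasting head")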

return y_hat


class TTMAPBlock(nn.Module):
Collaborator

Consider adding some comments, especially for the TTM-specific parameters.
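
e.g. a docstring along these lines (the description follows the adaptive-patching section of the TTM paper; the wording is a suggestion):

import torch.nn as nn


class TTMAPBlock(nn.Module):
    """TTM adaptive-patching (AP) block.

    At adaptive-patch level l, each input patch is reshaped into 2**l
    shorter sub-patches before mixing, so stacked AP blocks process the
    series at several patch resolutions, coarse to fine.
    """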



class Model(nn.Module):
"""
Collaborator

For the open-source release, please also note the Hugging Face source of the pretrained model.
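
e.g. a reference block in the model docstring; the paper is arXiv:2401.03955, and the checkpoint name below is illustrative (replace it with the one actually ported):

class Model(nn.Module):
    """TTM: Tiny Time Mixers.

    Paper: https://arxiv.org/abs/2401.03955
    Pretrained weights: e.g. https://huggingface.co/ibm-granite/granite-timeseries-ttm-r1
    """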

@@ -140,3 +140,14 @@ def forward(self, queries, keys, values, attn_mask, n_vars=None, n_tokens=None,

return self.out_projection(out), attn


class TTMGatedAttention(nn.Module):
Collaborator

This could be renamed to TTMGatedLayer; the forward pass has little to do with attention.
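
The rename fits: the module only projects the input, softmaxes the result, and multiplies it back in; there is no query/key/value interaction. A sketch of the renamed layer (the body is assumed, with input and output width taken to be equal so the element-wise product broadcasts):

import torch.nn as nn


class TTMGatedLayer(nn.Module):
    """Element-wise gating unit (formerly TTMGatedAttention)."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        gate = self.softmax(self.proj(x))  # per-position weights in (0, 1)
        return x * gate                    # reweight the input; no QKV attention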

):
    super().__init__()
    self.adapt_patch_level = adapt_patch_level
    adaptive_patch_factor = 2**adapt_patch_level
Collaborator

What does the factor of 2 here mean? Can this be cleaned up?
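
The power of two comes from adaptive patching: each level halves the patch length and doubles the patch count. A comment (or a named constant) would make that explicit, e.g.:

# Adaptive patching: at level l, every patch of length P is reshaped into
# 2**l sub-patches of length P // 2**l, so each successive level halves
# the patch length and doubles the number of patches.
adaptive_patch_factor = 2**adapt_patch_level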

run.py Outdated
parser.add_argument("--e_factor", type=int, default=2)
parser.add_argument("--mode", type=str, default="mix_channel")
parser.add_argument("--AP_levels", type=int, default=0)
parser.add_argument("--head_dropout", type=float, default=0.2)
Collaborator

If dropout has little impact, it can be fixed as a default argument in the model's constructor; the same applies to the arguments above.
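
i.e., if ablations show these settings barely matter, they could leave run.py and become constructor defaults, something like:

class Model(nn.Module):
    def __init__(self, configs, factor=2, mode="mix_channel",
                 AP_levels=0, head_dropout=0.2):
        ...

so the CLI surface stays small while the values remain overridable in code.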

@frndtls frndtls closed this Mar 5, 2025
@frndtls frndtls reopened this Mar 5, 2025