diff --git a/dump.yml b/dump.yml
deleted file mode 100644
index 2d93b8be..00000000
--- a/dump.yml
+++ /dev/null
@@ -1,948 +0,0 @@
-# SCUNet.SCUNet()
-m_head
-  m_head.0.weight: Tensor float32 Size([64, 3, 3, 3])
[...943 further deleted lines listing every m_down1/m_down2/m_down3, m_body, and m_up3/m_up2/m_up1 tensor with its dtype and shape...]
-m_tail
-  m_tail.0.weight: Tensor float32 Size([3, 64, 3, 3])
\ No newline at end of file
diff --git a/src/spandrel/architectures/Compact/__init__.py b/src/spandrel/architectures/Compact/__init__.py
index c20af7c7..9fe634d3 100644
--- a/src/spandrel/architectures/Compact/__init__.py
+++ b/src/spandrel/architectures/Compact/__init__.py
@@ -1,38 +1,8 @@
-from __future__ import annotations
-
-import math
-
 from ...__helpers.model_descriptor import SRModelDescriptor, StateDict
-from ..__arch_helpers.state import get_max_seq_index
+from ..__arch_helpers.state import get_max_seq_index, get_scale_and_output_channels
 from .arch.SRVGG import SRVGGNetCompact
 
 
-def _get_scale_and_output_channels(x: int, input_channels: int) -> tuple[int, int]:
-    # Unfortunately, we do not have enough information to determine both the scale and
-    # number output channels correctly *in general*. However, we can make some
-    # assumptions to make it good enough.
-    #
-    # What we know:
-    # - x = scale * scale * output_channels
-    # - output_channels is likely equal to input_channels
-    # - output_channels and input_channels is likely 1, 3, or 4
-    # - scale is likely 1, 2, 4, or 8
-
-    def is_square(n: int) -> bool:
-        return math.sqrt(n) == int(math.sqrt(n))
-
-    # just try out a few candidates and see which ones fulfill the requirements
-    candidates = [input_channels, 3, 4, 1]
-    for c in candidates:
-        if x % c == 0 and is_square(x // c):
-            return int(math.sqrt(x // c)), c
-
-    raise AssertionError(
-        f"Expected output channels to be either 1, 3, or 4."
- f" Could not find a pair (scale, out_nc) such that `scale**2 * out_nc = {x}`" - ) - - def load(state_dict: StateDict) -> SRModelDescriptor[SRVGGNetCompact]: state = state_dict @@ -43,7 +13,7 @@ def load(state_dict: StateDict) -> SRModelDescriptor[SRVGGNetCompact]: num_conv = (highest_num - 2) // 2 pixelshuffle_shape = state[f"body.{highest_num}.bias"].shape[0] - scale, out_nc = _get_scale_and_output_channels(pixelshuffle_shape, in_nc) + scale, out_nc = get_scale_and_output_channels(pixelshuffle_shape, in_nc) model = SRVGGNetCompact( num_in_ch=in_nc, diff --git a/src/spandrel/architectures/OmniSR/__init__.py b/src/spandrel/architectures/OmniSR/__init__.py index 2c864d00..e181204b 100644 --- a/src/spandrel/architectures/OmniSR/__init__.py +++ b/src/spandrel/architectures/OmniSR/__init__.py @@ -1,66 +1,57 @@ import math from ...__helpers.model_descriptor import SizeRequirements, SRModelDescriptor, StateDict +from ..__arch_helpers.state import ( + get_scale_and_output_channels, + get_seq_len, +) from .arch.OmniSR import OmniSR def load(state_dict: StateDict) -> SRModelDescriptor[OmniSR]: - state = state_dict - - block_num = 1 # Fine to assume this for now - ffn_bias = True + num_in_ch = 3 + num_out_ch = 3 + num_feat = 64 + block_num = 1 pe = True + window_size = 8 + res_num = 1 + up_scale = 4 + bias = True - num_feat = state_dict["input.weight"].shape[0] or 64 - num_in_ch = state_dict["input.weight"].shape[1] or 3 - num_out_ch = num_in_ch # we can just assume this for now. pixelshuffle smh + num_feat = state_dict["input.weight"].shape[0] + num_in_ch = state_dict["input.weight"].shape[1] + bias = "input.bias" in state_dict pixelshuffle_shape = state_dict["up.0.weight"].shape[0] - up_scale = math.sqrt(pixelshuffle_shape / num_out_ch) - if up_scale - int(up_scale) > 0: - print( - "out_nc is probably different than in_nc, scale calculation might be wrong" - ) - up_scale = int(up_scale) - res_num = 0 - for key in state_dict.keys(): - if "residual_layer" in key: - temp_res_num = int(key.split(".")[1]) - if temp_res_num > res_num: - res_num = temp_res_num - res_num = res_num + 1 # zero-indexed + up_scale, num_out_ch = get_scale_and_output_channels(pixelshuffle_shape, num_in_ch) - res_num = res_num + res_num = get_seq_len(state_dict, "residual_layer") + block_num = get_seq_len(state_dict, "residual_layer.0.residual_layer") - 1 - if ( + rel_pos_bias_key = ( "residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight" - in state_dict.keys() - ): - rel_pos_bias_weight = state_dict[ - "residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight" - ].shape[0] + ) + if rel_pos_bias_key in state_dict: + pe = True + # rel_pos_bias_weight = (2 * window_size - 1) ** 2 + rel_pos_bias_weight = state_dict[rel_pos_bias_key].shape[0] window_size = int((math.sqrt(rel_pos_bias_weight) + 1) / 2) else: - window_size = 8 + pe = False model = OmniSR( num_in_ch=num_in_ch, num_out_ch=num_out_ch, num_feat=num_feat, block_num=block_num, - ffn_bias=ffn_bias, pe=pe, window_size=window_size, res_num=res_num, up_scale=up_scale, - bias=True, + bias=bias, ) - in_nc = num_in_ch - out_nc = num_out_ch - num_feat = num_feat - scale = up_scale - tags = [ f"{num_feat}nf", f"w{window_size}", @@ -69,13 +60,13 @@ def load(state_dict: StateDict) -> SRModelDescriptor[OmniSR]: return SRModelDescriptor( model, - state, + state_dict, architecture="OmniSR", tags=tags, supports_half=True, # TODO: Test this supports_bfloat16=True, - scale=scale, - input_channels=in_nc, - output_channels=out_nc, + scale=up_scale, + 
+        input_channels=num_in_ch,
+        output_channels=num_out_ch,
         size_requirements=SizeRequirements(minimum=16),
     )
diff --git a/src/spandrel/architectures/OmniSR/arch/OSA.py b/src/spandrel/architectures/OmniSR/arch/OSA.py
index 8641e276..60636fb2 100644
--- a/src/spandrel/architectures/OmniSR/arch/OSA.py
+++ b/src/spandrel/architectures/OmniSR/arch/OSA.py
@@ -498,8 +498,6 @@ class OSA_Block(nn.Module):
     def __init__(
         self,
         channel_num=64,
-        bias=True,
-        ffn_bias=True,
         window_size=8,
         with_pe=False,
         dropout=0.0,
diff --git a/src/spandrel/architectures/OmniSR/arch/OSAG.py b/src/spandrel/architectures/OmniSR/arch/OSAG.py
index 5d580a02..9921bc9e 100644
--- a/src/spandrel/architectures/OmniSR/arch/OSAG.py
+++ b/src/spandrel/architectures/OmniSR/arch/OSAG.py
@@ -22,7 +22,6 @@ def __init__(
         channel_num=64,
         bias=True,
         block_num=4,
-        ffn_bias=False,
         window_size=0,
         pe=False,
     ):
@@ -42,8 +41,6 @@ def __init__(
         for _ in range(block_num):
             temp_res = block_class(
                 channel_num,
-                bias,
-                ffn_bias=ffn_bias,
                 window_size=window_size,
                 with_pe=pe,
             )
diff --git a/src/spandrel/architectures/OmniSR/arch/OmniSR.py b/src/spandrel/architectures/OmniSR/arch/OmniSR.py
index 9ac8ba50..71fed5a1 100644
--- a/src/spandrel/architectures/OmniSR/arch/OmniSR.py
+++ b/src/spandrel/architectures/OmniSR/arch/OmniSR.py
@@ -1,5 +1,4 @@
 #!/usr/bin/env python3
-# type: ignore
 #############################################################
 # File: OmniSR.py
 # Created Date: Tuesday April 28th 2022
@@ -25,7 +24,6 @@ def __init__(
         num_out_ch=3,
         num_feat=64,
         block_num=1,
-        ffn_bias=True,
         pe=True,
         window_size=8,
         res_num=1,
@@ -45,7 +43,6 @@ def __init__(
             channel_num=num_feat,
             bias=bias,
             block_num=block_num,
-            ffn_bias=ffn_bias,
             window_size=self.window_size,
             pe=pe,
         )
diff --git a/src/spandrel/architectures/__arch_helpers/state.py b/src/spandrel/architectures/__arch_helpers/state.py
index df7454f7..352daf75 100644
--- a/src/spandrel/architectures/__arch_helpers/state.py
+++ b/src/spandrel/architectures/__arch_helpers/state.py
@@ -1,3 +1,9 @@
+from __future__ import annotations
+
+import math
+from typing import Any
+
+
 def get_max_seq_index(state: dict, key_pattern: str, start: int = 0) -> int:
     """
     Returns the maximum number `i` such that `key_pattern.format(str(i))` is in `state`.
@@ -18,3 +24,57 @@ def get_max_seq_index(state: dict, key_pattern: str, start: int = 0) -> int:
         if key not in state:
             return i - 1
         i += 1
+
+
+def get_seq_len(state: dict[str, Any], seq_key: str) -> int:
+    """
+    Returns the length of a sequence in the state dict.
+
+    The length is detected by finding the maximum index `i` such that
+    `{seq_key}.{i}.{suffix}` is in `state` for some suffix.
+
+    Example:
+        get_seq_len(state, "body") -> 5
+    """
+    prefix = seq_key + "."
+
+    keys: set[int] = set()
+    for k in state.keys():
+        if k.startswith(prefix):
+            index = k[len(prefix) :].split(".", maxsplit=1)[0]
+            keys.add(int(index))
+
+    if len(keys) == 0:
+        return 0
+    return max(keys) + 1
+
+
+def get_scale_and_output_channels(x: int, input_channels: int) -> tuple[int, int]:
+    """
+    Returns a scale and number of output channels such that `scale**2 * out_nc = x`.
+
+    This is commonly used for pixelshuffle layers.
+    """
+    # Unfortunately, we do not have enough information to determine both the scale and
+    # number of output channels correctly *in general*. However, we can make some
+    # assumptions to make it good enough.
+    #
+    # What we know:
+    # - x = scale * scale * output_channels
+    # - output_channels is likely equal to input_channels
+    # - output_channels and input_channels are likely 1, 3, or 4
+    # - scale is likely 1, 2, 4, or 8
+
+    def is_square(n: int) -> bool:
+        return math.sqrt(n) == int(math.sqrt(n))
+
+    # just try out a few candidates and see which ones fulfill the requirements
+    candidates = [input_channels, 3, 4, 1]
+    for c in candidates:
+        if x % c == 0 and is_square(x // c):
+            return int(math.sqrt(x // c)), c
+
+    raise AssertionError(
+        f"Expected output channels to be either 1, 3, or 4."
+        f" Could not find a pair (scale, out_nc) such that `scale**2 * out_nc = {x}`"
+    )
diff --git a/tests/test_OmniSR.py b/tests/test_OmniSR.py
index 4c5d2ca2..3c8a2288 100644
--- a/tests/test_OmniSR.py
+++ b/tests/test_OmniSR.py
@@ -1,7 +1,36 @@
 from spandrel import ModelLoader
-from spandrel.architectures.OmniSR import OmniSR
+from spandrel.architectures.OmniSR import OmniSR, load
 
-from .util import ModelFile, TestImage, assert_image_inference, disallowed_props
+from .util import (
+    ModelFile,
+    TestImage,
+    assert_image_inference,
+    assert_loads_correctly,
+    disallowed_props,
+)
+
+
+def test_OmniSR_load():
+    assert_loads_correctly(
+        load,
+        lambda: OmniSR(),
+        lambda: OmniSR(num_in_ch=1, num_out_ch=1),
+        lambda: OmniSR(num_in_ch=3, num_out_ch=3),
+        lambda: OmniSR(num_in_ch=4, num_out_ch=4),
+        lambda: OmniSR(num_in_ch=1, num_out_ch=3),
+        lambda: OmniSR(num_feat=32),
+        lambda: OmniSR(block_num=2),
+        lambda: OmniSR(pe=False),
+        lambda: OmniSR(bias=False),
+        lambda: OmniSR(window_size=5),
+        lambda: OmniSR(res_num=3),
+        lambda: OmniSR(up_scale=5),
+        condition=lambda a, b: (
+            a.res_num == b.res_num
+            and a.up_scale == b.up_scale
+            and a.window_size == b.window_size
+        ),
+    )
 
 
 def test_OmniSR_community1(snapshot):
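
A quick sanity sketch of the two helpers added in __arch_helpers/state.py above. This is illustration only, not part of the diff: the state-dict keys are invented, and the import reaches into the private `__arch_helpers` module.

# Illustration of get_seq_len and get_scale_and_output_channels (hypothetical keys).
from spandrel.architectures.__arch_helpers.state import (
    get_scale_and_output_channels,
    get_seq_len,
)

# Toy state dict: only the key names matter for get_seq_len.
state = {
    "residual_layer.0.residual_layer.0.layer.0.weight": ...,
    "residual_layer.0.residual_layer.1.layer.0.weight": ...,
    "residual_layer.1.residual_layer.0.layer.0.weight": ...,
}

# Indices 0 and 1 occur directly under "residual_layer.", so the length is 2.
assert get_seq_len(state, "residual_layer") == 2

# 48 = scale**2 * out_nc; the first candidate out_nc == input_channels == 3
# gives 48 // 3 = 16, a perfect square, so scale = 4 and out_nc = 3.
assert get_scale_and_output_channels(48, 3) == (4, 3)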
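
And a round-trip check (again illustration only, not from the diff) of the window_size recovery in the new OmniSR load(): per the comment in the diff, the relative position bias table has `(2 * window_size - 1) ** 2` entries, so the loader inverts that to read window_size back out of the checkpoint.

import math

for window_size in (5, 8):
    entries = (2 * window_size - 1) ** 2  # 81 entries for 5, 225 for 8
    # Same expression as in load(): sqrt(entries) = 2*w - 1, so (sqrt + 1) / 2 = w.
    recovered = int((math.sqrt(entries) + 1) / 2)
    assert recovered == window_size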