From 4fc117f4b803e094ebf5c365c7710d5ab9f2bd11 Mon Sep 17 00:00:00 2001 From: Viraj Nadkarni Date: Sun, 22 May 2022 11:55:31 -0500 Subject: [PATCH] cnn files --- models.py | 1147 +++++++++++++++++++++++++++++++++ polar.py | 1298 +++++++++++++++++++++++++++++++++++++ utils.py | 534 +++++++++++++++ xformer_all.py | 1678 ++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 4657 insertions(+) create mode 100644 models.py create mode 100644 polar.py create mode 100644 utils.py create mode 100644 xformer_all.py diff --git a/models.py b/models.py new file mode 100644 index 0000000..af9659d --- /dev/null +++ b/models.py @@ -0,0 +1,1147 @@ +__author__ = 'vivien98' + +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torch.utils.data + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + +from IPython import display + +import imageio +import pickle +import os +import time +from datetime import datetime +import matplotlib +matplotlib.use('AGG') +import matplotlib.pyplot as plt + +from utils import snr_db2sigma, errors_ber, errors_bitwise_ber, errors_bler, min_sum_log_sum_exp, moving_average, extract_block_errors, extract_block_nonerrors + +from polar import * +from pac_code import * + +from sklearn.manifold import TSNE +import math +import random +import numpy as np +from tqdm import tqdm +from collections import namedtuple +import sys +import csv + + + + + +class ScaledDotProductAttention(nn.Module): + ''' Scaled Dot-Product Attention ''' + + def __init__(self, temperature, attn_dropout=0.1): + super().__init__() + self.temperature = temperature + self.dropout = nn.Dropout(attn_dropout) + + def forward(self, q, k, v, mask=None,causal=False): + + attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) + if mask is not None: + # if args.model == 'gpt': + # attn = attn.masked_fill(mask == 0, -1e9) + # else: + mask=mask.unsqueeze(1) + attn = attn.masked_fill(mask == 0, -1e9) + + attn = self.dropout(F.softmax(attn, dim=-1)) + output = torch.matmul(attn, v) + return output, attn + +class ScalarMult(nn.Module): + '''scalar multiplaication layer''' + + def __init__(self): + super().__init__() + self.alpha = nn.Parameter(1e-10*torch.ones(1)) + + def forward(self, x): + out = self.alpha*x + return out + + +class MultiHeadAttention(nn.Module): + ''' Multi-Head Attention module ''' + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + super().__init__() + + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + + self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) + self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) + self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) + self.fc = nn.Linear(n_head * d_v, d_model, bias=False) + + self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) + + self.dropout = nn.Dropout(dropout) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + self.scalar = ScalarMult() + + + def forward(self, q, k, v, mask=None,causal=False): + + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) + + residual = q + + # Pass through the pre-attention projection: b x lq x (n*dv) + # Separate different heads: b x lq x n x dv + q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) + k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) + v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) + + # Transpose for attention dot product: b x n x lq x dv + q, k, v = q.transpose(1, 2), 
k.transpose(1, 2), v.transpose(1, 2) + + if mask is not None: + mask = mask.unsqueeze(1) # For head axis broadcasting. + + q, attn = self.attention(q, k, v, mask=mask) + # if len(list(q.size()))==4: + # q = q.view(q.size(0)*sz_b,q.size(2),q.size(3),q.size(4)).transpose(1, 2).contiguous().view(sz_b, len_q, -1) + # else: + # Transpose to move the head dimension back: b x lq x n x dv + # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) + q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) + q = self.dropout(self.fc(q)) + #q = self.scalar(q) + q += residual + q = self.layer_norm(q) + + + return q, attn + + +class PositionwiseFeedForward(nn.Module): + ''' A two-feed-forward-layer module ''' + + def __init__(self, d_in, d_hid, dropout=0.1): + super().__init__() + self.w_1 = nn.Linear(d_in, d_hid) # position-wise + self.w_2 = nn.Linear(d_hid, d_in) # position-wise + self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) + self.dropout = nn.Dropout(dropout) + self.scalar = ScalarMult() + + def forward(self, x): + + residual = x + + x = self.w_2(F.gelu(self.w_1(x))) #F.gelu + x = self.dropout(x) + #x = self.scalar(x) + x += residual + + x = self.layer_norm(x) + + return x + + +class EncoderLayer(nn.Module): + ''' Compose with two layers ''' + + def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): + super(EncoderLayer, self).__init__() + self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) + self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) + + def forward(self, enc_input, slf_attn_mask=None): + enc_output, enc_slf_attn = self.slf_attn( + enc_input, enc_input, enc_input, mask=slf_attn_mask) + enc_output = self.pos_ffn(enc_output) + return enc_output, enc_slf_attn + + +class DecoderLayer(nn.Module): + ''' Compose with three layers ''' + + def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): + super(DecoderLayer, self).__init__() + self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) + self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) + self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) + + def forward( + self, dec_input, enc_output, + slf_attn_mask=None, dec_enc_attn_mask=None, cross_attend=True): + dec_enc_attn=[] + dec_output, dec_slf_attn = self.slf_attn( + dec_input, dec_input, dec_input, mask=slf_attn_mask) + if cross_attend: + dec_output, dec_enc_attn = self.enc_attn( + dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) + dec_output = self.pos_ffn(dec_output) + return dec_output, dec_slf_attn, dec_enc_attn + +class PositionalEncoding(nn.Module): + + def __init__(self, d_hid, n_position=200,num=10000): + super(PositionalEncoding, self).__init__() + + # Not a parameter + self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid,num)) + + def _get_sinusoid_encoding_table(self, n_position, d_hid,num): + ''' Sinusoid position encoding table ''' + # TODO: make it with torch instead of numpy + + def get_position_angle_vec(position,num): + return [position / np.power(num, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i,num) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + def forward(self, x): + return x + self.pos_table[:, 
:x.size(1)].clone().detach() + +class XFormerEncoder(nn.Module): + def __init__(self, config, layer_idx=None): + super(XFormerEncoder,self).__init__() + self.embed_dim = config.embed_dim + self.block_len = config.max_len + self.pos_emb = nn.Embedding(config.N+1, config.embed_dim,padding_idx=0) + self.position_enc = PositionalEncoding(self.embed_dim, n_position=self.block_len) + self.dropout = nn.Dropout(p=config.dropout) + self.layer_stack = nn.ModuleList([ + EncoderLayer(config.embed_dim, config.embed_dim*4, config.n_head, config.embed_dim//config.n_head, config.embed_dim//config.n_head, dropout=config.dropout) + for _ in range(config.n_layers)]) + self.layer_norm = nn.LayerNorm(config.embed_dim, eps=1e-6) + + def forward(self,noisy_enc,src_mask,device,return_attns=False): + position_indices = torch.arange(1,self.block_len+1, device=device) + pos_enc = self.pos_emb(position_indices) + enc_output = noisy_enc*pos_enc #<---- addition instead of multiplication? + enc_output = self.position_enc(enc_output) + + enc_output = self.dropout(enc_output) + enc_output = self.layer_norm(enc_output) + enc_slf_attn_list = [] + for enc_layer in self.layer_stack: + enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask) + enc_slf_attn_list += [enc_slf_attn] if return_attns else [] + if return_attns: + return enc_output, enc_slf_attn_list + return enc_output # [b_size,block_len,embed_dim] + +class XFormerDecoder(nn.Module): + def __init__(self, config, layer_idx=None): + super(XFormerDecoder,self).__init__() + self.embed_dim = config.embed_dim + self.block_len = config.max_len + self.emb_auto = nn.Embedding(config.N+1, config.embed_dim,padding_idx=0) + self.emb_cross = nn.Embedding(config.N+1, config.embed_dim,padding_idx=0) + self.emb_inputs = nn.Embedding(4, config.embed_dim,padding_idx=3) + self.position_enc_auto = PositionalEncoding(self.embed_dim, n_position=self.block_len) + self.position_enc_cross = PositionalEncoding(self.embed_dim, n_position=self.block_len,num=5000) + self.dropout = nn.Dropout(p=config.dropout) + self.dropout_cross = nn.Dropout(p=config.dropout) + self.layer_stack = nn.ModuleList([ + DecoderLayer(config.embed_dim, config.embed_dim*4, config.n_head, config.embed_dim//config.n_head, config.embed_dim//config.n_head, dropout=config.dropout) + for _ in range(config.n_layers)]) + self.layer_norm = nn.LayerNorm(config.embed_dim, eps=1e-6) + self.layer_norm_cross = nn.LayerNorm(config.embed_dim, eps=1e-6) + + + def forward(self,noisy_enc,src_mask,trg_seq,trg_mask,device,return_attns=False): + dec_slf_attn_list, dec_enc_attn_list = [], [] + position_indices = torch.arange(1,self.block_len+1, device=device) + emb_self = self.emb_auto(position_indices) + emb_cross = self.emb_cross(position_indices) + enc_output = noisy_enc*emb_cross #<---- addition instead of multiplication? 
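+        # Shape note on the line above: as called from XFormerEndToEndDecoder, `noisy_enc`
+        # arrives already broadcast to [batch, block_len, embed_dim], so this multiply scales
+        # each received symbol by the learned per-position embedding `emb_cross`. The additive
+        # variant hinted at in the comment would be a one-line sketch such as
+        # `enc_output = noisy_enc + emb_cross` (an untested alternative, not used here).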
+ dec_output = self.emb_inputs(trg_seq) + enc_output = self.position_enc_cross(enc_output) + dec_output = self.position_enc_auto(dec_output) + + dec_output = self.dropout(dec_output) + dec_output = self.layer_norm(dec_output) + + enc_output = self.dropout_cross(enc_output) + enc_output = self.layer_norm_cross(enc_output) + + cross_attend = [False for _ in self.layer_stack] + cross_attend[0] = True + for dec_layer in self.layer_stack: + dec_output, dec_slf_attn, dec_enc_attn = dec_layer( + dec_output, enc_output, slf_attn_mask=trg_mask, dec_enc_attn_mask=src_mask) + dec_slf_attn_list += [dec_slf_attn] if return_attns else [] + dec_enc_attn_list += [dec_enc_attn] if return_attns else [] + + if return_attns: + return dec_output, dec_slf_attn_list + return dec_output # [b_size,block_len,embed_dim] + +class XFormerGPT(nn.Module): + def __init__(self, config, layer_idx=None): + super(XFormerGPT,self).__init__() + self.embed_dim = config.embed_dim + self.block_len = config.max_len + self.position_enc_auto = PositionalEncoding(self.embed_dim, n_position=self.block_len) + self.dropout = nn.Dropout(p=config.dropout) + #self.pos_emb = nn.Embedding(config.N, config.embed_dim) + #self.dropout_cross = nn.Dropout(p=config.dropout) + self.layer_stack = nn.ModuleList([ + EncoderLayer(config.embed_dim, config.embed_dim*4, config.n_head, config.embed_dim//config.n_head, config.embed_dim//config.n_head, dropout=config.dropout) + for _ in range(config.n_layers)]) + self.layer_norm = nn.LayerNorm(config.embed_dim, eps=1e-6) + self.layer_norm_cross = nn.LayerNorm(config.embed_dim, eps=1e-6) + + + def forward(self,trg_seq,trg_mask,device,return_attns=False,return_layer=None): + #position_indices = torch.arange(1,self.block_len+1, device=device) + #pos_enc = self.pos_emb(position_indices) + dec_slf_attn_list, dec_enc_attn_list = [], [] + dec_output = self.position_enc_auto(trg_seq) + dec_output = self.dropout(dec_output) + #dec_output = self.layer_norm(dec_output) + layer=1 + intermediate_layer_out = None + for dec_layer in self.layer_stack: + dec_output, dec_slf_attn = dec_layer( + dec_output, slf_attn_mask=trg_mask) + dec_slf_attn_list += [dec_slf_attn] if return_attns else [] + if return_layer is not None: + if layer == return_layer: + intermediate_layer_out = dec_output + layer += 1 + if return_attns: + return dec_output, dec_slf_attn_list + if return_layer is not None: + return dec_output, intermediate_layer_out + return dec_output # [b_size,block_len,embed_dim] + + + +class XFormerEndToEndGPT(nn.Module): + def __init__(self,config): + super(XFormerEndToEndGPT,self).__init__() + self.embed_dim = config.embed_dim + self.block_len = config.max_len + self.trg_pad_idx = 2 + self.start_embed_layer = nn.Sequential( + nn.Linear(config.N,self.embed_dim), + nn.GELU(), + nn.Linear(self.embed_dim,self.embed_dim), + nn.GELU(), + nn.Linear(self.embed_dim,self.embed_dim), + ) + self.learnt_pos = True + if not self.learnt_pos: + self.emb_inputs = nn.Embedding(2, self.embed_dim) + #self.emb_inputs = nn.Embedding(4, self.embed_dim,padding_idx=3) + else: + self.pos_emb = nn.Embedding(self.block_len, config.embed_dim) + self.layer_norm_inp = nn.LayerNorm(self.embed_dim, eps=1e-6) + self.layer_norm_out = nn.LayerNorm(self.embed_dim, eps=1e-6) + self.Decoder = XFormerGPT(config) + self.Lin_Decoder = nn.Linear(config.embed_dim,1) + + def forward(self,noisy_enc,mask,trg_seq,device,return_layer = None): + src_mask = mask + trg_seq = trg_seq[:,:-1] + if not self.learnt_pos: + trg_seq = 
torch.cat((torch.ones((trg_seq.size(0),1),device=device).long(),(trg_seq==-1).long()),-1) # shift inputs forward by one token + trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(trg_seq) # batch_size x max_len x max_len + trg_seq = self.emb_inputs(trg_seq) + else: + trg_seq = torch.cat((torch.ones((trg_seq.size(0),1),device=device),trg_seq),-1) # shift inputs forward by one token + trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(trg_seq) # batch_size x max_len x max_len + trg_seq = torch.ones(self.embed_dim,device=device)*trg_seq.unsqueeze(-1) + position_indices = torch.arange(self.block_len, device=device) + pos_enc = self.pos_emb(position_indices) + trg_seq = trg_seq*pos_enc + + start_emb = self.start_embed_layer(noisy_enc) + trg_seq[:,0] = start_emb + if return_layer is not None: + output,intermediate_layer_out = self.Decoder(trg_seq,trg_mask,device,return_layer=return_layer) + else: + output = self.Decoder(trg_seq,trg_mask,device) + logits = self.Lin_Decoder(output) + + decoded_msg_bits = logits.sign() + output = torch.sigmoid(logits) + output = torch.cat((1-output,output),-1) + out_mask = mask + + if return_layer is not None: + return output,decoded_msg_bits,out_mask,logits,intermediate_layer_out + + return output,decoded_msg_bits,out_mask,logits # [b_size,block_len,2] + + def decode(self,noisy_enc,info_positions,mask,device): + start_emb = self.start_embed_layer(noisy_enc) + inp_seq = torch.ones((noisy_enc.size(0),self.block_len,self.embed_dim),device=device) + inp_seq[:,0] = start_emb + inp_mask = mask.unsqueeze(1) & get_subsequent_mask(noisy_enc) + output_bits = torch.ones((noisy_enc.size(0),self.block_len),device=device) + for i in range(noisy_enc.size(1)): + if i in info_positions: + mask_i = inp_mask[:,i,:].unsqueeze(1) + output = self.Decoder(inp_seq,mask_i,device) + output = self.Lin_Decoder(output) + next_bit = output[:,i].sign() + else: + next_bit = torch.ones((noisy_enc.size(0),1),device=device) + output_bits[:,i] = next_bit[:,0] + #print(next_bit) + if i < noisy_enc.size(1)-1: + if not self.learnt_pos: + embed_next_bit = self.emb_inputs((next_bit==1).long()) + inp_seq[:,i+1] = embed_next_bit[:,0] + else: + embed_next_bit = next_bit*self.pos_emb(torch.tensor(i+1,device=device)).unsqueeze(0) + inp_seq[:,i+1] = embed_next_bit + + out_mask = mask + return output_bits,out_mask + +class StartEmbedder(nn.Module): + def __init__(self,inp_dim,hidden_dim,num_layers): + super(StartEmbedder,self).__init__() + self.inp_dim = inp_dim + self.hidden_dim = hidden_dim + self.layers = nn.ModuleList([nn.Linear(self.inp_dim,self.hidden_dim)]+[nn.Linear(hidden_dim,hidden_dim) for i in range(num_layers-1)]) + + def forward(self,x): + out = self.layers[0](x) + res = out + out = F.gelu(out) + for layer in self.layers[1:-1]: + out = layer(out) + out = F.gelu(out) + out = self.layers[-1](out) + out = out + res + return out + +class rnnAttn(nn.Module): + def __init__(self, args): + super(rnnAttn, self).__init__() + + #self.vocab_size = params['vocab_size'] + self.d_emb = 1#args.embed_dim#params['d_emb'] + self.d_hid = args.embed_dim#params['d_hid'] + self.block_len = args.N + self.n_layer = 2 + self.btz = args.batch_size + self.feature1 = multiplyFeature(args.mat) + #self.encoder = nn.Embedding(self.vocab_size, self.d_emb) + self.attn = Attention(self.d_hid) + self.rnn = nn.GRU(self.d_emb, self.d_hid, self.n_layer, batch_first=True) + self.startEmbedder1 = StartEmbedder(args.N,self.d_hid,3) + self.startEmbedder2 = StartEmbedder(args.N,self.d_hid,3) + # 
the combined_W maps the combined hidden states and context vectors to d_hid + self.combined_W = nn.Linear(self.d_hid * 3, self.d_hid) + self.decoder = nn.Sequential( + nn.Linear(self.d_hid,self.d_hid), + nn.GELU(), + nn.Linear(self.d_hid,self.d_hid), + nn.GELU(), + nn.Linear(self.d_hid,1), + ) + + + def forward(self,noisy_enc,mask,trg_seq,device,return_layer = None, return_attn_weights=False): + + """ + IMPLEMENT ME! + Copy your implementation of RNNLM, make sure it passes the RNNLM check + In addition to that, you need to add the following 3 things + 1. pass rnn output to attention module, get context vectors and attention weights + 2. concatenate the context vec and rnn output, pass the combined + vector to the layer dealing with the combined vectors (self.combined_W) + 3. if return_attn_weights, instead of return the [N, L, V] + matrix, return the attention weight matrix + of dimension [N, L, L] which returned from the forrward function of Attnetion module + """ + batch_size, seq_len= noisy_enc.shape + #multFeat = self.feature1(noisy_enc,device) + trg_seq = trg_seq[:,:-1] + trg_seq = torch.cat((torch.ones((trg_seq.size(0),1),device=device).long(),trg_seq),-1) + start_hidden = self.startEmbedder1(noisy_enc) + #dumb_decode = self.startEmbedder2(multFeat) + hidden = torch.cat((start_hidden.unsqueeze(1),start_hidden.unsqueeze(1)),1) + hidden = torch.transpose(hidden,0,1) + hidden = hidden.contiguous() + start_hidden = (torch.ones((batch_size,seq_len,1),device=device)*start_hidden.unsqueeze(1)) + #dumb_decode = (torch.ones((batch_size,seq_len,1),device=device)*dumb_decode.unsqueeze(1)) + #init=torch.zeros(self.n_layer, batch_size, self.d_hid).to(device) + #wordvecs = self.encoder(batch) + #print(hidden.size()) + outs,last_hidden = self.rnn(trg_seq.unsqueeze(-1),hidden) + context_vec,attn_weights = self.attn(outs) + + cat_vec = torch.cat((context_vec,outs,start_hidden),dim = -1) + dec = self.combined_W(cat_vec) + logits = self.decoder(torch.tanh(dec)) + + decoded_msg_bits = logits.sign() + output = torch.sigmoid(logits) + output = torch.cat((1-output,output),-1) + out_mask = mask + + return output,decoded_msg_bits,out_mask,logits + + def decode(self,noisy_enc,info_positions,mask,device): + batch_size, seq_len= noisy_enc.shape + inp_seq = torch.ones((noisy_enc.size(0),self.block_len),device=device) + inp_seq[:,0] = 1 + output_bits = torch.ones((noisy_enc.size(0),self.block_len),device=device) + + #multFeat = self.feature1(noisy_enc,device) + start_hidden = self.startEmbedder1(noisy_enc) + #dumb_decode = self.startEmbedder2(multFeat) + + hidden = torch.cat((start_hidden.unsqueeze(1),start_hidden.unsqueeze(1)),1) + hidden = torch.transpose(hidden,0,1) + hidden = hidden.contiguous() + + outs_arr = torch.ones((noisy_enc.size(0),self.block_len,self.d_hid),device=device) + + start_hidden = (torch.ones((batch_size,seq_len,1),device=device)*start_hidden.unsqueeze(1)) + #dumb_decode = (torch.ones((batch_size,seq_len,1),device=device)*dumb_decode.unsqueeze(1)) + for i in range(noisy_enc.size(1)): + if i in info_positions: + outs,last_hidden = self.rnn(inp_seq[:,i].unsqueeze(-1).unsqueeze(-1),hidden) + outs_arr[:,i,:] = outs.squeeze() + context_vec,_ = self.attn(outs_arr) + + cat_vec = torch.cat((context_vec,outs_arr,start_hidden),dim = -1) + dec = self.combined_W(cat_vec) + logits = self.decoder(torch.tanh(dec)) + hidden = last_hidden + next_bit = logits[:,i].sign().squeeze() + else: + outs,last_hidden = self.rnn(inp_seq[:,i].unsqueeze(-1).unsqueeze(-1),hidden) + outs_arr[:,i,:] = outs.squeeze() + 
hidden = last_hidden + next_bit = torch.ones((noisy_enc.size(0)),device=device) + #print(output_bits[:,i].size()) + #print(next_bit.size()) + output_bits[:,i] = next_bit + #print(next_bit) + if i < noisy_enc.size(1)-1: + inp_seq[:,i+1] = next_bit + out_mask = mask + return output_bits,out_mask + +class Attention(nn.Module): + def __init__(self, d_hidden): + super(Attention, self).__init__() + self.linear_w1 = nn.Linear(d_hidden, d_hidden) + self.linear_w2 = nn.Linear(d_hidden, 1) + + + def forward(self, x): + + """ + IMPLEMENT ME! + For each time step t + 1. Obtain attention scores for step 0 to (t-1) + This should be a dot product between current hidden state (x[:,t:t+1,:]) + and all previous states x[:, :t, :]. While t=0, since there is not + previous context, the context vector and attention weights should be of zeros. + You might find torch.bmm useful for computing over the whole batch. + 2. Turn the scores you get for 0 to (t-1) steps to a distribution. + You might find F.softmax to be helpful. + 3. Obtain the sum of hidden states weighted by the attention distribution + Concat the context vector you get in step 3. to a matrix. + + Also remember to store the attention weights, the attention matrix + for each training instance should be a lower triangular matrix. Specifically, + each row, element 0 to t-1 should sum to 1, the rest should be padded with 0. + e.g. + [ [0.0000, 0.0000, 0.0000, 0.0000], + [1.0000, 0.0000, 0.0000, 0.0000], + [0.4246, 0.5754, 0.0000, 0.0000], + [0.2798, 0.3792, 0.3409, 0.0000] ] + + Return the context vector matrix and the attention weight matrix + + """ + batch_seq_len = x.shape[1] + modif_hidden = self.linear_w1(x) + attn_logits = torch.bmm(x,modif_hidden.transpose(1,2)) + mask = torch.triu(-10000000000000000.0*torch.ones((batch_seq_len,batch_seq_len),device=x.device)) + attn_weights = nn.functional.softmax(attn_logits + mask,-1) + mult_mask = torch.ones(attn_weights.shape,device=x.device) + mult_mask[:,0,:]=0 + attn_weights = attn_weights*mult_mask + context_vecs = torch.bmm(attn_weights,x) + return context_vecs, attn_weights + +class XFormerEndToEndDecoder(nn.Module): + def __init__(self,config): + super(XFormerEndToEndDecoder,self).__init__() + self.embed_dim = config.embed_dim + self.block_len = config.max_len + self.trg_pad_idx = 3 + self.start_idx = 2 + self.Decoder = XFormerDecoder(config) + self.Lin_Decoder = nn.Linear(config.embed_dim,1) + + def forward(self,noisy_enc,mask,trg_seq,device): + src_mask = mask + trg_seq = trg_seq[:,:-1] + trg_seq = torch.cat((2*torch.ones((trg_seq.size(0),1),device=device).long(),(trg_seq==1).long()),-1) + trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(trg_seq) + batch_size = trg_mask.size(0) + max_len = trg_mask.size(1) + trg_mask = trg_mask.view(batch_size*max_len,max_len) + trg_maskh = torch.cat((trg_mask[:,1:],torch.zeros((trg_mask.size(0),1),device=device)),-1).float() + trg_seq = (trg_seq*torch.ones((max_len,batch_size,max_len),device=device)).long().permute((1,0,2)).reshape(batch_size*max_len,max_len) + noisy_enc = (noisy_enc*torch.ones((max_len,batch_size,max_len),device=device)).permute((1,0,2)).reshape(batch_size*max_len,max_len) + src_mask = (src_mask*torch.ones((max_len,batch_size,max_len),device=device)).permute((1,0,2)).reshape(batch_size*max_len,max_len) + + #noisy_enc : [b_size, block_len] + output = torch.ones(self.embed_dim,device=device)*noisy_enc.unsqueeze(-1) + #print(trg_mask.size()) + #noisy_enc : [b_size,block_len,embed_dim] + + output = 
self.Decoder(output,src_mask,trg_seq,trg_mask,device) + logits = self.Lin_Decoder(output) + decoded_msg_bits = logits.sign() + output = torch.sigmoid(logits) + output = torch.cat((1-output,output),-1) + out_mask = trg_mask.float() - trg_maskh + return output,decoded_msg_bits,out_mask,logits # [b_size,block_len,2] + + def decode(self,noisy_enc,info_positions, mask,device): + enc_input = torch.ones(self.embed_dim,device=device)*noisy_enc.unsqueeze(-1) + inp_seq = torch.ones((noisy_enc.size(0),self.block_len),device=device).long() + inp_seq[:,0] = 2 + inp_mask = mask.unsqueeze(1) & get_subsequent_mask(noisy_enc) + output_bits = torch.ones((noisy_enc.size(0),self.block_len),device=device) + for i in range(noisy_enc.size(1)): + if i in info_positions: + mask_i = inp_mask[:,i,:] + output = self.Decoder(enc_input,mask,inp_seq,mask_i,device) + output = self.Lin_Decoder(output) + next_bit = output[:,i].sign() + else: + next_bit = torch.ones((noisy_enc.size(0),1),device=device) + output_bits[:,i] = next_bit[:,0] + embed_next_bit = ((next_bit==1).long()) + if i < noisy_enc.size(1)-1: + inp_seq[:,i+1] = embed_next_bit[:,0] + out_mask = mask + return output_bits,out_mask + + + + + + + +class XFormerEndToEndEncoder(nn.Module): + def __init__(self,config): + super(XFormerEndToEndEncoder,self).__init__() + self.embed_dim = config.embed_dim + self.block_len = config.max_len + self.Encoder = XFormerEncoder(config) + self.Lin_Decoder = nn.Linear(config.embed_dim,1) + + def forward(self,noisy_enc,mask,trg_seq,device): + #noisy_enc : [b_size, block_len] + output = torch.ones(self.embed_dim,device=device)*noisy_enc.unsqueeze(-1) + #noisy_enc : [b_size,block_len,embed_dim] + output = self.Encoder(output,mask,device) + logits = self.Lin_Decoder(output) + decoded_msg_bits = logits.sign() + output = torch.sigmoid(logits) + output = torch.cat((1-output,output),-1) + out_mask = mask + return output,decoded_msg_bits,out_mask,logits # [b_size,block_len,2] + + def decode(self,noisy_enc,info_positions,mask,device,trg_seq=None): + _,decoded_msg_bits,out_mask,_ = self.forward(noisy_enc,mask,trg_seq,device) + #decoded_msg_bits = (decoded_msg_bits==1).long() + return decoded_msg_bits,out_mask +class smallNet(nn.Module): + def __init__(self,config): + super(smallNet,self).__init__() + self.hidden_dim = config.embed_dim + self.output_len = config.N + self.layers = nn.Sequential( + nn.Linear(self.output_len , self.hidden_dim), + nn.GELU(), + nn.Linear(self.hidden_dim , self.hidden_dim), + nn.GELU(), + nn.Linear(self.hidden_dim , self.output_len) + ) + def forward(self,noisy_enc,mask,trg_seq,device): + + output = self.layers(noisy_enc) + #print(output.size()) + logits = output.squeeze().unsqueeze(-1) + decoded_msg_bits = logits.sign() + output = torch.sigmoid(logits) + output = torch.cat((1-output,output),-1) + out_mask = mask + return output,decoded_msg_bits,out_mask,logits # [b_size,block_len,2] + + def decode(self,noisy_enc,info_positions,mask,device,trg_seq=None): + _,decoded_msg_bits,out_mask,_ = self.forward(noisy_enc,mask,trg_seq,device) + #decoded_msg_bits = (decoded_msg_bits==1).long() + return decoded_msg_bits,out_mask + +class multiplyFeature(nn.Module): + def __init__(self,mat): + super(multiplyFeature,self).__init__() + self.register_buffer('mat',mat) + def forward(self,x,device): + matOut = x.unsqueeze(1)*self.mat + matOut = matOut.masked_fill((matOut==0),1.) 
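+        # Here matOut[b, i, j] = x[b, j] * mat[i, j]; assuming `mat` is a 0/1 selection matrix,
+        # the unselected (zeroed) entries are reset to 1 so the product over the last axis
+        # reduces to the product of the selected received values. Under the 0 -> +1, 1 -> -1
+        # convention used elsewhere in this code, the sign of such a product corresponds to
+        # the XOR of the underlying bits.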
+ return torch.prod(matOut,-1) + +class multiplyLayer(nn.Module): + def __init__(self,dim): + super(multiplyLayer,self).__init__() + self.linearOut = nn.Linear(dim,dim) + def forward(self,x1,x2): + prod = x1*x2 + linOut = self.linearOut(prod) + return linOut + +class convUnit(nn.Module): + def __init__(self,config): + super(convUnit,self).__init__() + self.hidden_dim = int(config.embed_dim/8) + self.input_len = config.max_len + self.output_len = 1 + bias = not config.dont_use_bias + self.kernel = 7 + self.padding = int((self.kernel-1)/2) + + self.layers1 = nn.Sequential( + nn.Conv1d(1,int(self.hidden_dim/2),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=2*self.padding,dilation=2,bias=bias), + nn.GELU(), + ) + self.layers2 = nn.Sequential( + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=4*self.padding,dilation=4,bias=bias), + nn.GELU(), + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + ) + self.layers3 = nn.Sequential( + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=2*self.padding,dilation=2,bias=bias), + nn.GELU(), + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=4*self.padding,dilation=4,bias=bias), + nn.GELU(), + ) + self.layers4 = nn.Sequential( + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=2*self.padding,dilation=2,bias=bias), + nn.GELU(), + ) + self.layers5 = nn.Sequential( + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim),self.kernel,padding=4*self.padding,dilation=4,bias=bias), + nn.GELU(), + nn.Conv1d(self.hidden_dim,self.hidden_dim,self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + ) + self.layersFin = nn.Sequential( + nn.Linear(self.input_len*self.hidden_dim , 1), + ) + + self.layer_norm = nn.LayerNorm(1, eps=1e-6) + self.dropout = nn.Dropout(config.dropout) + + def forward(self,noisy_enc): + input1 = noisy_enc.unsqueeze(1) + + input2 = self.layers1(input1) + + residual2 = input2 + input3 = self.layers2(input2) + residual2 + + residual3 = input3 + input4 = self.layers3(input3)+ residual3 + + residual4 = input4 + input5 = self.layers4(input4) + residual4 + + residual5 = input5 + input6 = self.layers5(input5) + + + output = self.layer_norm(self.dropout(self.layersFin(torch.flatten(input6,start_dim=1)))).squeeze() + + return output + +class bitConvNet(nn.Module): + def __init__(self,config): + super(bitConvNet,self).__init__() + self.hidden_dim = int(config.embed_dim/2) + self.input_len = config.max_len + self.output_len = config.N + self.nets = nn.ModuleList([convUnit(config) for _ in range(self.output_len)]) + + + def forward(self,noisy_enc,mask,trg_seq,device): + output = torch.ones((noisy_enc.shape[0],self.output_len)) + for i in range(self.output_len): + output[:,i] = self.nets[i](noisy_enc) + + logits = output.squeeze().unsqueeze(-1) + decoded_msg_bits = logits.sign() + output = torch.sigmoid(logits) + output = torch.cat((1-output,output),-1) + out_mask = mask + return output,decoded_msg_bits,out_mask,logits,input4 + + def decode(self,noisy_enc,info_positions,mask,device,trg_seq=None): + _,decoded_msg_bits,out_mask,_,_ = self.forward(noisy_enc,mask,trg_seq,device) + #decoded_msg_bits = (decoded_msg_bits==1).long() + return decoded_msg_bits,out_mask + + +class 
multConvNet(nn.Module): + def __init__(self,config,mat): + super(multConvNet,self).__init__() + self.hidden_dim = int(config.embed_dim/2) + self.input_len = config.max_len + self.output_len = config.N + self.mat = mat + bias = not config.dont_use_bias + self.kernel = 65 + self.padding = int((self.kernel-1)/2) + + self.feature1 = multiplyFeature(mat) + self.layers1 = nn.Sequential( + nn.Conv1d(1,int(self.hidden_dim/2),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + ) + #self.mult1 = multiplyLayer(config.N) + #self.layer_norm1 = nn.LayerNorm(self.output_len, eps=1e-6) + self.layers2 = nn.Sequential( + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + ) + #self.mult2 = multiplyLayer(config.N) + #self.layer_norm2 = nn.LayerNorm(self.output_len, eps=1e-6) + self.layers3 = nn.Sequential( + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + ) + #self.mult3 = multiplyLayer(config.N) + #self.layer_norm3 = nn.LayerNorm(self.output_len, eps=1e-6) + self.layers4 = nn.Sequential( + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + ) + #self.mult4 = multiplyLayer(config.N) + #self.layer_norm4 = nn.LayerNorm(self.output_len, eps=1e-6) + self.layers5 = nn.Sequential( + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + nn.Conv1d(self.hidden_dim,self.hidden_dim,self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + ) + #self.mult5 = multiplyLayer(config.N) + #self.layer_norm5 = nn.LayerNorm(self.output_len, eps=1e-6) + self.layersFin = nn.Sequential( + nn.Linear(self.hidden_dim*self.output_len , 4*self.output_len), + nn.GELU(), + nn.Linear(4*self.output_len , self.output_len), + nn.GELU(), + nn.Linear(self.output_len , self.output_len) + ) + + self.layer_norm = nn.LayerNorm(self.output_len, eps=1e-6) + self.dropout = nn.Dropout(config.dropout) + + def forward(self,noisy_enc,mask,trg_seq,device): + input1 = noisy_enc.unsqueeze(1) + + #inputFeat = self.feature1(input1,device) + input2 = self.layers1(input1) #+ inputFeat.unsqueeze(1) + #input2 = self.layer_norm1(self.mult1(input2[:,:int(input2.shape[1]/2)],input2[:,int(input2.shape[1]/2):])) + + residual2 = input2 + input3 = self.layers2(input2) + residual2 + #input3 = self.layer_norm2(self.mult2(input3[:,:int(input3.shape[1]/2)],input3[:,int(input3.shape[1]/2):])) + residual2 + + residual3 = input3 + input4 = self.layers3(input3) + residual3 + #input4 = self.layer_norm3(self.mult3(input4[:,:int(input4.shape[1]/2)],input4[:,int(input4.shape[1]/2):])) + residual3 + + residual4 = input4 + input5 = self.layers4(input4) + residual4 + #input5 = self.layer_norm4(self.mult4(input5[:,:int(input5.shape[1]/2)],input5[:,int(input5.shape[1]/2):])) + residual4 + + residual5 = input5 + input6 = self.layers5(input5) + + + output = self.layer_norm(self.dropout(self.layersFin(torch.flatten(input6,start_dim=1)))) + #print(output[:,[61,62,63]]) + 
logits = output.squeeze().unsqueeze(-1) + decoded_msg_bits = logits.sign() + output = torch.sigmoid(logits) + output = torch.cat((1-output,output),-1) + out_mask = mask + return output,decoded_msg_bits,out_mask,logits,input4 + + def decode(self,noisy_enc,info_positions,mask,device,trg_seq=None): + _,decoded_msg_bits,out_mask,_,_ = self.forward(noisy_enc,mask,trg_seq,device) + #decoded_msg_bits = (decoded_msg_bits==1).long() + return decoded_msg_bits,out_mask + +class simpleNet(nn.Module): + def __init__(self,config): + super(simpleNet,self).__init__() + self.hidden_dim = config.embed_dim + self.input_len = config.max_len + self.output_len = config.N + self.kernel = 19 + self.padding = int((self.kernel-1)/2) + + self.layers1 = nn.Sequential( + nn.Conv1d(1,self.hidden_dim,self.kernel,padding=self.padding), + nn.GELU(), + nn.Conv1d(self.hidden_dim,self.hidden_dim,self.kernel,padding=2*self.padding,dilation=2), + nn.GELU(), + nn.Conv1d(self.hidden_dim,self.hidden_dim,self.kernel,padding=self.padding), + nn.GELU(), + ) + self.layers2 = nn.Sequential( + nn.Conv1d(self.hidden_dim,self.hidden_dim,self.kernel,padding=self.padding), + nn.GELU(), + nn.Conv1d(self.hidden_dim,self.hidden_dim,self.kernel,padding=2*self.padding,dilation=2), + nn.GELU(), + nn.Conv1d(self.hidden_dim,self.hidden_dim,self.kernel,padding=self.padding), + nn.GELU(), + ) + #self.layers3 = nn.Conv1d(self.hidden_dim,1,self.kernel,padding=self.padding) + self.layers3 = nn.Sequential( + nn.Linear(self.hidden_dim*self.output_len , self.output_len), + nn.GELU(), + nn.Linear(self.output_len , self.output_len), + nn.GELU(), + nn.Linear(self.output_len , self.output_len) + ) + def forward(self,noisy_enc,mask,trg_seq,device): + #noisy_enc : [b_size, block_len] + input1 = noisy_enc.unsqueeze(1) + residual1 = input1 + #noisy_enc : [b_size,block_len,embed_dim] + + input2 = self.layers1(input1) + residual1 + residual2 = input2 + output2 = self.layers2(input2) + residual2 + + #output = self.layers3(output2) + output = self.layers3(torch.flatten(output2,start_dim=1)) + #print(output.size()) + logits = output.squeeze().unsqueeze(-1) + decoded_msg_bits = logits.sign() + output = torch.sigmoid(logits) + output = torch.cat((1-output,output),-1) + out_mask = mask + return output,decoded_msg_bits,out_mask,logits # [b_size,block_len,2] + + def decode(self,noisy_enc,info_positions,mask,device,trg_seq=None): + _,decoded_msg_bits,out_mask,_ = self.forward(noisy_enc,mask,trg_seq,device) + #decoded_msg_bits = (decoded_msg_bits==1).long() + return decoded_msg_bits,out_mask + +class convNet(nn.Module): + def __init__(self,config): + super(convNet,self).__init__() + self.hidden_dim = config.embed_dim + self.input_len = config.max_len + self.output_len = config.N + bias = not config.dont_use_bias + self.kernel = 7 + self.padding = int((self.kernel-1)/2) + + self.layers1 = nn.Sequential( + nn.Conv1d(1,int(self.hidden_dim/2),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=2*self.padding,dilation=2,bias=bias), + nn.GELU(), + ) + self.layers2 = nn.Sequential( + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=4*self.padding,dilation=4,bias=bias), + nn.GELU(), + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + ) + self.layers3 = nn.Sequential( + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=2*self.padding,dilation=2,bias=bias), + nn.GELU(), + 
nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=4*self.padding,dilation=4,bias=bias), + nn.GELU(), + ) + self.layers4 = nn.Sequential( + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim/2),self.kernel,padding=2*self.padding,dilation=2,bias=bias), + nn.GELU(), + ) + self.layers5 = nn.Sequential( + nn.Conv1d(int(self.hidden_dim/2),int(self.hidden_dim),self.kernel,padding=4*self.padding,dilation=4,bias=bias), + nn.GELU(), + nn.Conv1d(self.hidden_dim,self.hidden_dim,self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + ) + self.layersFin = nn.Sequential( + nn.Linear(self.hidden_dim*self.output_len , 4*self.output_len), + nn.GELU(), + nn.Linear(4*self.output_len , self.output_len), + nn.GELU(), + nn.Linear(self.output_len , self.output_len) + ) + + self.layer_norm = nn.LayerNorm(self.output_len, eps=1e-6) + self.dropout = nn.Dropout(config.dropout) + + def forward(self,noisy_enc,mask,trg_seq,device): + input1 = noisy_enc.unsqueeze(1) + + input2 = self.layers1(input1) + + residual2 = input2 + input3 = self.layers2(input2) + residual2 + + residual3 = input3 + input4 = self.layers3(input3)+ residual3 + + residual4 = input4 + input5 = self.layers4(input4) + residual4 + + residual5 = input5 + input6 = self.layers5(input5) + + + output = self.layer_norm(self.dropout(self.layersFin(torch.flatten(input6,start_dim=1)))) + #print(output.size()) + logits = output.squeeze().unsqueeze(-1) + decoded_msg_bits = logits.sign() + output = torch.sigmoid(logits) + output = torch.cat((1-output,output),-1) + out_mask = mask + return output,decoded_msg_bits,out_mask,logits,input4 # [b_size,block_len,2] + + def decode(self,noisy_enc,info_positions,mask,device,trg_seq=None): + _,decoded_msg_bits,out_mask,_,_ = self.forward(noisy_enc,mask,trg_seq,device) + #decoded_msg_bits = (decoded_msg_bits==1).long() + return decoded_msg_bits,out_mask + +class bigConvNet(nn.Module): + def __init__(self,config): + super(bigConvNet,self).__init__() + self.hidden_dim = config.embed_dim + self.input_len = config.max_len + self.output_len = config.N + bias = not config.dont_use_bias + self.kernel = 33 + self.padding = int((self.kernel-1)/2) + + self.layers1 = nn.Sequential( + nn.Conv1d(1,int(self.hidden_dim),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + nn.Conv1d(int(self.hidden_dim),int(self.hidden_dim),self.kernel,padding=2*self.padding,dilation=2,bias=bias), + nn.GELU(), + nn.MaxPool1d(5,stride=1, padding=2) + ) + self.layers2 = nn.Sequential( + nn.Conv1d(int(self.hidden_dim),int(self.hidden_dim),self.kernel,padding=4*self.padding,dilation=4,bias=bias), + nn.GELU(), + nn.Conv1d(int(self.hidden_dim),int(self.hidden_dim),self.kernel,padding=self.padding,bias=bias), + nn.GELU(), + nn.MaxPool1d(5,stride=1, padding=2) + ) + # self.layers3 = nn.Sequential( + # nn.Conv1d(int(self.hidden_dim),int(self.hidden_dim),self.kernel,padding=2*self.padding,dilation=2,bias=bias), + # nn.GELU(), + # nn.Conv1d(int(self.hidden_dim),int(self.hidden_dim),self.kernel,padding=4*self.padding,dilation=4,bias=bias), + # nn.GELU(), + # ) + # self.layers4 = nn.Sequential( + # nn.Conv1d(int(self.hidden_dim),int(self.hidden_dim),self.kernel,padding=self.padding,bias=bias), + # nn.GELU(), + # nn.Conv1d(int(self.hidden_dim),int(self.hidden_dim),self.kernel,padding=2*self.padding,dilation=2,bias=bias), + # nn.GELU(), + # ) + # self.layers5 = nn.Sequential( + # 
nn.Conv1d(int(self.hidden_dim),int(self.hidden_dim),self.kernel,padding=4*self.padding,dilation=4,bias=bias), + # nn.GELU(), + # nn.Conv1d(self.hidden_dim,self.hidden_dim,self.kernel,padding=self.padding,bias=bias), + # nn.GELU(), + # ) + self.layersFin = nn.Sequential( + nn.Linear(self.hidden_dim*self.output_len , 4*self.output_len), + nn.GELU(), + nn.Linear(4*self.output_len , self.output_len), + nn.GELU(), + nn.Linear(self.output_len , self.output_len) + ) + + self.layer_norm = nn.LayerNorm(self.output_len, eps=1e-6) + self.dropout = nn.Dropout(config.dropout) + + def forward(self,noisy_enc,mask,trg_seq,device): + input1 = noisy_enc.unsqueeze(1) + + residual1 = input1 + input2 = (self.layers1(input1))+residual1 + + residual2 = input2 + input3 = (self.layers2(input2)) + residual2 + + # residual3 = input3 + # input4 = self.layers3(input3)+ residual3 + + # residual4 = input4 + # input5 = self.layers4(input4) + residual4 + + # residual5 = input5 + # input6 = self.layers5(input5) + + + output = (self.layer_norm(self.layersFin(torch.flatten(input3,start_dim=1))))#self.layer_norm(self.dropout(self.layersFin(torch.flatten(input6,start_dim=1)))) + #print(output.size()) + logits = output.squeeze().unsqueeze(-1) + decoded_msg_bits = logits.sign() + output = torch.sigmoid(logits) + output = torch.cat((1-output,output),-1) + out_mask = mask + return output,decoded_msg_bits,out_mask,logits,input2 # [b_size,block_len,2] + + def decode(self,noisy_enc,info_positions,mask,device,trg_seq=None): + _,decoded_msg_bits,out_mask,_,_ = self.forward(noisy_enc,mask,trg_seq,device) + #decoded_msg_bits = (decoded_msg_bits==1).long() + return decoded_msg_bits,out_mask \ No newline at end of file diff --git a/polar.py b/polar.py new file mode 100644 index 0000000..63fbc87 --- /dev/null +++ b/polar.py @@ -0,0 +1,1298 @@ +__author__ = 'hebbarashwin' +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +import matplotlib +matplotlib.use('AGG') +import matplotlib.pyplot as plt +plt.rcParams["font.family"] = "Times New Roman" +plt.rcParams.update({'font.size': 15}) +import pickle +import os +import argparse +import sys +from collections import namedtuple + +from utils import log_sum_exp, log_sum_avoid_zero_NaN, snr_db2sigma, STEQuantize, Clamp, min_sum_log_sum_exp, errors_ber, errors_bler +#from xformer_all import dec2bitarray + +def dec2bitarray(in_number, bit_width): + """ + Converts a positive integer to NumPy array of the specified size containing + bits (0 and 1). + Parameters + ---------- + in_number : int + Positive integer to be converted to a bit array. + bit_width : int + Size of the output bit array. + Returns + ------- + bitarray : 1D ndarray of ints + Array containing the binary representation of the input decimal. 
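+    Examples
+    --------
+    A small illustrative check of the conversion:
+    >>> dec2bitarray(5, 4)
+    array([0, 1, 0, 1])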
+ """ + + binary_string = bin(in_number) + length = len(binary_string) + bitarray = np.zeros(bit_width, 'int') + for i in range(length-2): + bitarray[bit_width-i-1] = int(binary_string[length-i-1]) + + return bitarray +def get_args(): + parser = argparse.ArgumentParser(description='(N,K) Polar code') + + parser.add_argument('--N', type=int, default=4, help='Polar code parameter N') + parser.add_argument('--K', type=int, default=3, help='Polar code parameter K') + parser.add_argument('--rate_profile', type=str, default='polar', choices=['RM', 'polar', 'sorted', 'sorted_last', 'rev_polar'], help='Polar rate profiling') + parser.add_argument('--hard_decision', dest = 'hard_decision', default=False, action='store_true') + parser.add_argument('--only_args', dest = 'only_args', default=False, action='store_true') + parser.add_argument('--list_size', type=int, default=1, help='SC List size') + parser.add_argument('--crc_len', type=int, default='0', choices=[0, 3, 8, 16], help='CRC length') + + parser.add_argument('--batch_size', type=int, default=10000, help='size of the batches') + parser.add_argument('--test_ratio', type = float, default = 1, help = 'Number of test samples x batch_size') + parser.add_argument('--test_snr_start', type=float, default=-2., help='testing snr start') + parser.add_argument('--test_snr_end', type=float, default=4., help='testing snr end') + parser.add_argument('--snr_points', type=int, default=7, help='testing snr num points') + args = parser.parse_args() + + return args + +class PolarCode: + + def __init__(self, n, K, args, F = None, rs = None, use_cuda = True, infty = 1000.): + + assert n>=1 + self.args = args + self.n = n + self.N = 2**n + self.K = K + self.G2 = np.array([[1,0],[1,1]]) + self.G = np.array([1]) + for i in range(n): + self.G = np.kron(self.G, self.G2) + self.G = torch.from_numpy(self.G).float() + self.device = torch.device("cuda" if use_cuda else "cpu") + clamp_class = Clamp() + self.clamp = clamp_class.apply + self.infty = infty + + if F is not None: + assert len(F) == self.N - self.K + self.frozen_positions = F + self.unsorted_frozen_positions = self.frozen_positions + self.frozen_positions.sort() + + self.info_positions = np.array(list(set(self.frozen_positions) ^ set(np.arange(self.N)))) + self.unsorted_info_positions = self.info_positions + self.info_positions.sort() + else: + if rs is None: + # in increasing order of reliability + self.reliability_seq = np.arange(1023, -1, -1) + self.rs = self.reliability_seq[self.reliability_seq +1, 1 -> -1 + # Therefore, xor(a, b) = a*b + if custom_info_positions is not None: + info_positions = custom_info_positions + else: + info_positions = self.info_positions + u = torch.ones(message.shape[0], self.N, dtype=torch.float).to(message.device) + u[:, info_positions] = message + + for d in range(0, self.n): + num_bits = 2**d + for i in np.arange(0, self.N, 2*num_bits): + # [u v] encoded to [u xor(u,v)] + u = torch.cat((u[:, :i], u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits], u[:, i+num_bits:]), dim=1) + # u[:, i:i+num_bits] = u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits].clone + if scaling is not None: + u = (scaling * np.sqrt(self.N)*u)/torch.norm(scaling) + return u + + def neural_encode_plotkin(self, message, power_constraint_type = 'hard_power_block'): + + # message shape is (batch, k) + # BPSK convention : 0 -> +1, 1 -> -1 + # Therefore, xor(a, b) = a*b + + u = torch.ones(message.shape[0], self.N, dtype=torch.float).to(self.device) + u[:, self.info_positions] = 
message.to(self.device) + + for d in range(0, self.n): + depth = self.n - d + num_bits = 2**d + for i in np.arange(0, self.N, 2*num_bits): + # [u v] encoded to [u xor(u,v)] + + u = torch.cat((u[:, :i], self.gnet_dict[depth-1](u[:, i:i+2*num_bits]), u[:, i+num_bits:]), dim=1) + # u = torch.cat((u[:, :i], u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits], u[:, i+num_bits:]), dim=1) + # u[:, i:i+num_bits] = u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits].clone + return self.power_constraint(u, None, power_constraint_type, 'train') + + def power_constraint(self, codewords, gnet_top, power_constraint_type, training_mode): + + + if power_constraint_type in ['soft_power_block','soft_power_bit']: + + this_mean = codewords.mean(dim=0) if power_constraint_type == 'soft_power_bit' else codewords.mean() + this_std = codewords.std(dim=0) if power_constraint_type == 'soft_power_bit' else codewords.std() + + if training_mode == 'train': # Training + power_constrained_codewords = (codewords - this_mean)*1.0 / this_std + + gnet_top.update_normstats_for_test(this_mean, this_std) + + elif training_mode == 'test': # For inference + power_constrained_codewords = (codewords - gnet_top.mean_scalar)*1.0/gnet_top.std_scalar + + # else: # When updating the stat parameters of g2net. Just don't do anything + # power_constrained_codewords = _ + + return power_constrained_codewords + + + elif power_constraint_type == 'hard_power_block': + + return F.normalize(codewords, p=2, dim=1)*np.sqrt(self.N) + + + else: # 'hard_power_bit' + + return codewords/codewords.abs() + + def channel(self, code, snr): + sigma = snr_db2sigma(snr) + + noise = (sigma* torch.randn(code.shape, dtype = torch.float)).to(code.device) + r = code + noise + + return r + + def sc_decode(self, noisy_code, snr): + # Successive cancellation decoder for polar codes + + noise_sigma = snr_db2sigma(snr) + llrs = (2/noise_sigma**2)*noisy_code + assert noisy_code.shape[1] == self.N + decoded_bits = torch.zeros(noisy_code.shape[0], self.N) + + depth = 0 + + # function is recursively called (DFS) + # arguments: Beliefs at the input of node (LLRs at top node), depth of children, bit_position (zero at top node) + decoded_codeword, decoded_bits = self.decode(llrs, depth, 0, decoded_bits) + decoded_message = torch.sign(decoded_bits)[:, self.info_positions] + + return decoded_message + + def decode(self, llrs, depth, bit_position, decoded_bits=None): + # Function to call recursively, for SC decoder + + # print("DEPTH = {}, bit_position = {}".format(depth, bit_position)) + half_index = 2 ** (self.n - depth - 1) + + # n = 2 tree case + if depth == self.n - 1: + # Left child + left_bit_position = 2*bit_position + if left_bit_position in self.frozen_positions: + # If frozen decoded bit is 0 + u_hat = torch.ones_like(llrs[:, :half_index], dtype=torch.float) + else: + # Lu = log_sum_exp(torch.cat([llrs[:, :half_index].unsqueeze(2), llrs[:, half_index:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True) + Lu = log_sum_avoid_zero_NaN(llrs[:, :half_index], llrs[:, half_index:]).sum(dim=1, keepdim=True) + if self.args.hard_decision: + u_hat = torch.sign(Lu) + else: + u_hat = torch.tanh(Lu/2) + + # Right child + right_bit_position = 2*bit_position + 1 + if right_bit_position in self.frozen_positions: + # If frozen decoded bit is 0 + v_hat = torch.ones_like(llrs[:, :half_index], dtype = torch.float) + else: + Lv = u_hat * llrs[:, :half_index] + llrs[:, half_index:] + if self.args.hard_decision: + v_hat = torch.sign(Lv) + else: + v_hat = 
torch.tanh(Lv/2) + + #print("DECODED: Bit positions {} : {} and {} : {}".format(left_bit_position, u_hat, right_bit postion, v_hat)) + + decoded_bits[:, left_bit_position] = u_hat.squeeze(1) + decoded_bits[:, right_bit_position] = v_hat.squeeze(1) + + return torch.cat((u_hat * v_hat, v_hat), dim = 1).float(), decoded_bits + + # General case + else: + # LEFT CHILD + # Find likelihood of (u xor v) xor (v) = u + # Lu = log_sum_exp(torch.cat([llrs[:, :half_index].unsqueeze(2), llrs[:, half_index:].unsqueeze(2)], dim=2).permute(0, 2, 1)) + Lu = log_sum_avoid_zero_NaN(llrs[:, :half_index], llrs[:, half_index:]) + + u_hat, decoded_bits = self.decode(Lu, depth+1, bit_position*2, decoded_bits) + + # RIGHT CHILD + Lv = u_hat * llrs[:, :half_index] + llrs[:, half_index:] + v_hat, decoded_bits = self.decode(Lv, depth+1, bit_position*2 + 1, decoded_bits) + + return torch.cat((u_hat * v_hat, v_hat), dim=1), decoded_bits + + def sc_decode_soft(self, noisy_code, snr, priors=None): + # Soft successive cancellation decoder for polar codes + # Left subtree : L_u^ = LSE(L_1, L_2) + prior (like normal) + # Right subtree : L_v^ = LSE(L_u^, L_1) + L_2 + # Return up: L_1^, L_2^ = LSE(L_u^, L_v^), L_v^ + + + noise_sigma = snr_db2sigma(snr) + llrs = (2/noise_sigma**2)*noisy_code + assert noisy_code.shape[1] == self.N + decoded_bits = torch.zeros(noisy_code.shape[0], self.N) + + if priors is None: + priors = torch.zeros(self.N) + + depth = 0 + + # function is recursively called (DFS) + # arguments: Beliefs at the input of node (LLRs at top node), depth of children, bit_position (zero at top node) + decoded_codeword, decoded_bits = self.decode_soft(llrs, depth, 0, priors, decoded_bits) + decoded_message = torch.sign(decoded_bits)[:, self.info_positions] + + return decoded_message + + def decode_soft(self, llrs, depth, bit_position, prior, decoded_bits=None): + # Function to call recursively, for soft SC decoder + + # print("DEPTH = {}, bit_position = {}".format(depth, bit_position)) + half_index = 2 ** (self.n - depth - 1) + + # n = 2 tree case + if depth == self.n - 1: + # Left child + left_bit_position = 2*bit_position + + # Lu = log_sum_exp(torch.cat([llrs[:, :half_index].unsqueeze(2), llrs[:, half_index:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True) + Lu = log_sum_avoid_zero_NaN(llrs[:, :half_index], llrs[:, half_index:]).sum(dim=1, keepdim=True) + Lu = self.clamp(Lu + prior[left_bit_position]*torch.ones_like(Lu), -1000, 1000) + if self.args.hard_decision: + u_hat = torch.sign(Lu) + else: + u_hat = torch.tanh(Lu/2) + L_uv = log_sum_avoid_zero_NaN(Lu, llrs[:, :half_index]).sum(dim=1, keepdim=True) + + # Right child + right_bit_position = 2*bit_position + 1 + + Lv = L_uv + llrs[:, half_index:] + Lv = self.clamp(Lv + prior[right_bit_position]*torch.ones_like(Lv), -1000, 1000) + if self.args.hard_decision: + v_hat = torch.sign(Lv) + else: + v_hat = torch.tanh(Lv/2) + + #print("DECODED: Bit positions {} : {} and {} : {}".format(left_bit_position, u_hat, right_bit postion, v_hat)) + + decoded_bits[:, left_bit_position] = u_hat.squeeze(1) + decoded_bits[:, right_bit_position] = v_hat.squeeze(1) + + # print(depth, Lu.shape, Lv.shape, log_sum_avoid_zero_NaN(Lu, Lv).shape, torch.cat((log_sum_avoid_zero_NaN(Lu, Lv).sum(dim=1, keepdim=True), Lv), dim = 1).shape) + return torch.cat((log_sum_avoid_zero_NaN(Lu, Lv).sum(dim=1, keepdim=True), Lv), dim = 1).float(), decoded_bits + + # General case + else: + # LEFT CHILD + # Find likelihood of (u xor v) xor (v) = u + # Lu = log_sum_exp(torch.cat([llrs[:, 
:half_index].unsqueeze(2), llrs[:, half_index:].unsqueeze(2)], dim=2).permute(0, 2, 1)) + Lu = log_sum_avoid_zero_NaN(llrs[:, :half_index], llrs[:, half_index:]) + + L_u, decoded_bits = self.decode_soft(Lu, depth+1, bit_position*2, prior, decoded_bits) + L_uv = log_sum_avoid_zero_NaN(L_u, llrs[:, :half_index]) + + # RIGHT CHILD + Lv = L_uv + llrs[:, half_index:] + L_v, decoded_bits = self.decode_soft(Lv, depth+1, bit_position*2 + 1, prior, decoded_bits) + # print(depth, L_u.shape, L_v.shape, log_sum_avoid_zero_NaN(L_u, L_v).shape, torch.cat((log_sum_avoid_zero_NaN(L_u, L_v).sum(dim=1, keepdim=True), L_v), dim = 1).shape) + + return torch.cat((log_sum_avoid_zero_NaN(L_u, L_v), L_v), dim = 1).float(), decoded_bits + + + def define_partial_arrays(self, llrs): + # Initialize arrays to store llrs and partial_sums useful to compute the partial successive cancellation process. + llr_array = torch.zeros(llrs.shape[0], self.n+1, self.N, device=llrs.device) + llr_array[:, self.n] = llrs + partial_sums = torch.zeros(llrs.shape[0], self.n+1, self.N, device=llrs.device) + return llr_array, partial_sums + + + def updateLLR(self, leaf_position, llrs, partial_llrs = None, prior = None): + + #START + depth = self.n + decoded_bits = partial_llrs[:,0].clone() + if prior is None: + prior = torch.zeros(self.N) #priors + llrs, partial_llrs, decoded_bits = self.partial_decode(llrs, partial_llrs, depth, 0, leaf_position, prior, decoded_bits) + return llrs, decoded_bits + + + def partial_decode(self, llrs, partial_llrs, depth, bit_position, leaf_position, prior, decoded_bits=None): + # Function to call recursively, for partial SC decoder. + # We are assuming that u_0, u_1, .... , u_{leaf_position -1} bits are known. + # Partial sums computes the sums got through Plotkin encoding operations of known bits, to avoid recomputation. 
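+        # In the recursion below, the left child combines the two halves of the parent LLRs
+        # with min_sum_log_sum_exp (the SC "f"/box-plus step), while the right child applies
+        # the "g" step Lv = u_hat * L_left + L_right, where u_hat is the sign (hard) or
+        # tanh (soft) estimate of the left sub-block under the 0 -> +1, 1 -> -1 mapping.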
+ # this function is implemented for rate 1 (not accounting for frozen bits in polar SC decoding) + + # print("DEPTH = {}, bit_position = {}".format(depth, bit_position)) + half_index = 2 ** (depth - 1) + leaf_position_at_depth = leaf_position // 2**(depth-1) # will tell us whether left_child or right_child + + # n = 2 tree case + if depth == 1: + # Left child + left_bit_position = 2*bit_position + if leaf_position_at_depth > left_bit_position: + u_hat = partial_llrs[:, depth-1, left_bit_position:left_bit_position+1] + elif leaf_position_at_depth == left_bit_position: + Lu = min_sum_log_sum_exp(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]).sum(dim=1, keepdim=True) + # Lu = log_sum_avoid_zero_NaN(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]).sum(dim=1, keepdim=True) + llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index] = Lu + prior[left_bit_position]*torch.ones_like(Lu) + if self.args.hard_decision: + u_hat = torch.sign(Lu) + else: + u_hat = torch.tanh(Lu/2) + + decoded_bits[:, left_bit_position] = u_hat.squeeze(1) + + return llrs, partial_llrs, decoded_bits + + # Right child + right_bit_position = 2*bit_position + 1 + if leaf_position_at_depth > right_bit_position: + pass + elif leaf_position_at_depth == right_bit_position: + Lv = u_hat * llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index] + llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index] + llrs[:, depth-1, right_bit_position*half_index:(right_bit_position+1)*half_index] = Lv + prior[right_bit_position] * torch.ones_like(Lv) + if self.args.hard_decision: + v_hat = torch.sign(Lv) + else: + v_hat = torch.tanh(Lv/2) + decoded_bits[:, right_bit_position] = v_hat.squeeze(1) + return llrs, partial_llrs, decoded_bits + + # General case + else: + # LEFT CHILD + # Find likelihood of (u xor v) xor (v) = u + # Lu = log_sum_exp(torch.cat([llrs[:, :half_index].unsqueeze(2), llrs[:, half_index:].unsqueeze(2)], dim=2).permute(0, 2, 1)) + + left_bit_position = 2*bit_position + if leaf_position_at_depth > left_bit_position: + Lu = llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index] + u_hat = partial_llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index] + else: + + Lu = min_sum_log_sum_exp(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]) + # Lu = log_sum_avoid_zero_NaN(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]) + llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index] = Lu + llrs, partial_llrs, decoded_bits = self.partial_decode(llrs, partial_llrs, depth-1, left_bit_position, leaf_position, prior, decoded_bits) + + return llrs, partial_llrs, decoded_bits + + # RIGHT CHILD + right_bit_position = 2*bit_position + 1 + + Lv = u_hat * llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index] + llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index] + llrs[:, depth-1, right_bit_position*half_index:(right_bit_position+1)*half_index] = Lv + llrs, partial_llrs, decoded_bits = self.partial_decode(llrs, 
partial_llrs, depth-1, right_bit_position, leaf_position, prior, decoded_bits) + + return llrs, partial_llrs, decoded_bits + + def updatePartialSums(self, leaf_position, decoded_bits, partial_llrs): + + u = decoded_bits.clone() + u[:, leaf_position+1:] = 0 + + for d in range(0, self.n): + partial_llrs[:, d] = u + num_bits = 2**d + for i in np.arange(0, self.N, 2*num_bits): + # [u v] encoded to [u xor(u,v)] + u = torch.cat((u[:, :i], u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits], u[:, i+num_bits:]), dim=1) + partial_llrs[:, self.n] = u + return partial_llrs + + def sc_decode_new(self, corrupted_codewords, snr, use_gt = None): + + # step-wise implementation using updateLLR and updatePartialSums + sigma = snr_db2sigma(snr) + llrs = (2/sigma**2)*corrupted_codewords + + priors = torch.zeros(self.N) + priors[self.frozen_positions] = self.infty + + u_hat = torch.zeros(corrupted_codewords.shape[0], self.N, device=corrupted_codewords.device) + llr_array, partial_llrs = self.define_partial_arrays(llrs) + for ii in range(self.N): + llr_array , decoded_bits = self.updateLLR(ii, llr_array.clone(), partial_llrs, priors) + if use_gt is None: + u_hat[:, ii] = torch.sign(llr_array[:, 0, ii]) + else: + u_hat[:, ii] = use_gt[:, ii] + partial_llrs = self.updatePartialSums(ii, u_hat, partial_llrs) + decoded_bits = u_hat[:, self.info_positions] + return llr_array[:, 0, :].clone(), decoded_bits + + def updateLLR_soft(self, leaf_position, llrs, partial_llrs, prior = None): + + #START + depth = self.n + decoded_bits = partial_llrs[:,0].clone() + if prior is None: + prior = torch.zeros(self.N) #priors + llrs, partial_llrs, decoded_bits = self.partial_decode_soft(llrs, partial_llrs, depth, 0, leaf_position, prior, decoded_bits) + return llrs, decoded_bits + + + def partial_decode_soft(self, llrs, partial_llrs, depth, bit_position, leaf_position, prior, decoded_bits=None): + # Function to call recursively, for partial SC decoder. + # We are assuming that u_0, u_1, .... , u_{leaf_position -1} bits are known. + # Partial sums computes the sums got through Plotkin encoding operations of known bits, to avoid recomputation. 
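For reference, the hard and soft partial-sum propagation used by updatePartialSums and updatePartialSums_soft further below differ only in the combine step: in the +1/-1 domain the hard Plotkin step maps [u, v] to [u*v, v] (an elementwise product is an XOR), while the soft step propagates LLRs as [boxplus(Lu, Lv), Lv]. A compact illustration using a plain logaddexp form of the boxplus; the file's log_sum_avoid_zero_NaN computes the same quantity with extra zero/NaN safeguards (helper names are mine):

import torch

def plotkin_hard(u, v):
    # [u, v] -> [u xor v, v] in the +1/-1 domain
    return torch.cat((u * v, v), dim=-1)

def boxplus(lu, lv):
    # LLR of (u xor v): log((1 + e^{lu+lv}) / (e^{lu} + e^{lv}))
    return torch.logaddexp(torch.zeros_like(lu), lu + lv) - torch.logaddexp(lu, lv)

def plotkin_soft(lu, lv):
    # [Lu, Lv] -> [LLR(u xor v), Lv]
    return torch.cat((boxplus(lu, lv), lv), dim=-1)

u, v = torch.tensor([[1., -1.]]), torch.tensor([[-1., -1.]])
print(plotkin_hard(u, v))    # tensor([[-1.,  1., -1., -1.]])
print(plotkin_soft(torch.tensor([[4., -3.]]), torch.tensor([[-2., -5.]])))
# ≈ tensor([[-1.8756,  2.8735, -2.0000, -5.0000]])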
+ # this function is implemented for rate 1 (not accounting for frozen bits in polar SC decoding) + + # print("DEPTH = {}, bit_position = {}".format(depth, bit_position)) + half_index = 2 ** (depth - 1) + leaf_position_at_depth = leaf_position // 2**(depth-1) # will tell us whether left_child or right_child + + # n = 2 tree case + if depth == 1: + # Left child + left_bit_position = 2*bit_position + if leaf_position_at_depth > left_bit_position: + Lu = llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index] + L_u = partial_llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index] + #L_uv = log_sum_avoid_zero_NaN(L_u, llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index]) + elif leaf_position_at_depth == left_bit_position: + Lu = log_sum_avoid_zero_NaN(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]).sum(dim=1, keepdim=True) + Lu = self.clamp(Lu + prior[left_bit_position]*torch.ones_like(Lu), -1000, 1000) + + llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index] = Lu + prior[left_bit_position]*torch.ones_like(Lu) + if self.args.hard_decision: + u_hat = torch.sign(Lu) + else: + u_hat = torch.tanh(Lu/2) + + decoded_bits[:, left_bit_position] = u_hat.squeeze(1) + + return llrs, partial_llrs, decoded_bits + + # Right child + right_bit_position = 2*bit_position + 1 + if leaf_position_at_depth > right_bit_position: + pass + elif leaf_position_at_depth == right_bit_position: + + L_uv = log_sum_avoid_zero_NaN(L_u, llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index]) + Lv = L_uv + llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index] + Lv = self.clamp(Lv + prior[right_bit_position]*torch.ones_like(Lv), -1000, 1000) + + llrs[:, depth-1, right_bit_position*half_index:(right_bit_position+1)*half_index] = Lv + prior[right_bit_position] * torch.ones_like(Lv) + if self.args.hard_decision: + v_hat = torch.sign(Lv) + else: + v_hat = torch.tanh(Lv/2) + decoded_bits[:, right_bit_position] = v_hat.squeeze(1) + + return llrs, partial_llrs, decoded_bits + + # General case + else: + # LEFT CHILD + # Find likelihood of (u xor v) xor (v) = u + # Lu = log_sum_exp(torch.cat([llrs[:, :half_index].unsqueeze(2), llrs[:, half_index:].unsqueeze(2)], dim=2).permute(0, 2, 1)) + + left_bit_position = 2*bit_position + if leaf_position_at_depth > left_bit_position: + Lu = llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index] + L_u = partial_llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index] + # L_uv = log_sum_avoid_zero_NaN(L_u, llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index]) + else: + + Lu = log_sum_avoid_zero_NaN(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]) + llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index] = Lu + llrs, partial_llrs, decoded_bits = self.partial_decode_soft(llrs, partial_llrs, depth-1, left_bit_position, leaf_position, prior, decoded_bits) + + return llrs, partial_llrs, decoded_bits + + # RIGHT CHILD + right_bit_position = 2*bit_position + 1 + L_uv = log_sum_avoid_zero_NaN(L_u, llrs[:,depth, (left_bit_position)*half_index:(left_bit_position+1)*half_index]) + Lv = L_uv + llrs[:,depth, 
(left_bit_position+1)*half_index:(left_bit_position+2)*half_index] + llrs[:, depth-1, right_bit_position*half_index:(right_bit_position+1)*half_index] = Lv + llrs, partial_llrs, decoded_bits = self.partial_decode_soft(llrs, partial_llrs, depth-1, right_bit_position, leaf_position, prior, decoded_bits) + + return llrs, partial_llrs, decoded_bits + + def updatePartialSums_soft(self, leaf_position, leaf_llrs, partial_llrs): + # In the partial sum array, we store the L^ of the decoded positions. + # LLR for (u^ xor v^, v^) will be (LSE(L_u^, L_v^), L_v^) + u = leaf_llrs.clone() + u[:, leaf_position+1:] = 0 + + for d in range(0, self.n): + partial_llrs[:, d] = u + num_bits = 2**d + for i in np.arange(0, self.N, 2*num_bits): + # [Lu Lv] encoded to [lse(Lu, Lv) Lv] + u = torch.cat((u[:, :i], log_sum_avoid_zero_NaN(u[:, i:i+num_bits].clone(), u[:, i+num_bits: i+2*num_bits]).float(), u[:, i+num_bits:]), dim=1) + partial_llrs[:, self.n] = u + + return partial_llrs + + def sc_decode_soft_new(self, corrupted_codewords, snr, priors=None): + # uses updateLLR_soft and updatePartialSums_soft + + sigma = snr_db2sigma(snr) + llrs = (2/sigma**2)*corrupted_codewords + if priors is None: + priors = torch.zeros(self.N) + + u_hat = torch.zeros(corrupted_codewords.shape[0], self.N, device=corrupted_codewords.device) + llr_array, partial_llrs = self.define_partial_arrays(llrs) + for ii in range(self.N): + llr_array , decoded_bits = self.updateLLR_soft(ii, llr_array.clone(), partial_llrs, priors) + u_hat[:, ii] = torch.sign(llr_array[:, 0, ii]) + partial_llrs = self.updatePartialSums_soft(ii, llr_array[:, 0, :], partial_llrs) + decoded_bits = u_hat[:, self.info_positions] + return decoded_bits + + def neural_sc_decode(self, noisy_code, snr, p = None): + + noise_sigma = snr_db2sigma(snr) + llrs = ((2/noise_sigma**2)*noisy_code).to(self.device) + + assert noisy_code.shape[1] == self.N + # if frozen bit, llr = very large (high likelihood of 0) (P.S.: after BPSK, 0 -> +1 , 1 -> -1) + decoded_llrs = 1000*torch.ones(noisy_code.shape[0], self.N).to(self.device) + + depth = 0 + if p is None: + p = 0.5*torch.ones(self.N) + # function is recursively called (DFS) + # arguments: Beliefs at the input of node (LLRs at top node), depth of children, bit_position (zero at top node) + # depth of root node = 0, => depth of leaves will be n + + decoded_codeword, decoded_llrs = self.neural_decode(llrs, depth, 0, decoded_llrs, p) + # decoded_message = torch.sign(decoded_bits)[:, self.info_positions] + + return decoded_llrs[:, self.info_positions] + + def neural_decode(self, llrs, depth, bit_position, decoded_llrs=None, p=None): + + # print("DEPTH = {}, bit_position = {}".format(depth, bit_position)) + half_index = 2 ** (self.n - depth - 1) # helper variable: half of length of belief (LLR) vector + + + if depth == self.n - 1: # n = 2 tree case - penultimate layer of tree + # Left child + left_bit_position = 2*bit_position + if left_bit_position in self.frozen_positions: + # If frozen decoded bit is 0 + u_hat = torch.ones_like(llrs[:, :half_index], dtype=torch.float) + else: + if self.args.no_sharing_weights: + Lu = self.fnet_dict[depth+1][2*bit_position](llrs) + else: + Lu = self.fnet_dict[depth+1]['left'](llrs) + if self.args.augment: + Lu = Lu + log_sum_exp(torch.cat([llrs[:, :half_index].unsqueeze(2), llrs[:, half_index:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True) + #Lu = log_sum_exp(torch.cat([llrs[:, :half_index].unsqueeze(2), llrs[:, half_index:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True) 
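The f/g networks referenced through fnet_dict and gnet_dict are registered elsewhere in this file; neural_decode only relies on their shapes: at a node whose LLR vector has length 2*m, the 'left' net maps 2*m inputs to m outputs, the 'right' net maps 3*m inputs (LLRs concatenated with u_hat) to m outputs, and the g net maps 2*m inputs ([u_hat, v_hat]) back to m. A minimal stand-in with that interface, purely illustrative and not the architecture actually used here:

import torch.nn as nn

class PlugInNet(nn.Module):
    # illustrative MLP matching the input/output sizes neural_decode expects
    def __init__(self, in_dim, out_dim, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.SELU(),
            nn.Linear(hidden, hidden), nn.SELU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        return self.net(x)

# fnet_dict[depth+1]['left']  ~ PlugInNet(2 * m, m)   # input: llrs
# fnet_dict[depth+1]['right'] ~ PlugInNet(3 * m, m)   # input: [llrs, u_hat]
# gnet_dict[depth]            ~ PlugInNet(2 * m, m)   # input: [u_hat, v_hat]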
+ + prior = torch.log((p[left_bit_position])/(1 - p[left_bit_position])) + Lu = self.clamp(Lu - torch.ones_like(Lu)*prior.item(), -1000, 1000) + decoded_llrs[:, left_bit_position] = Lu.squeeze(1) + + if self.args.hard_decision: + u_hat = torch.sign(Lu) + else: + u_hat = torch.tanh(Lu/2) + + # Right child + right_bit_position = 2*bit_position + 1 + if right_bit_position in self.frozen_positions: + # If frozen decoded bit is 0 + v_hat = torch.ones_like(llrs[:, :half_index], dtype = torch.float) + else: + if self.args.no_sharing_weights: + Lv = self.fnet_dict[depth+1][2*bit_position+1](torch.cat((llrs, u_hat), dim=1)) + else: + Lv = self.fnet_dict[depth+1]['right'](torch.cat((llrs, u_hat), dim=1)) + if self.args.augment: + Lv = Lv + u_hat * llrs[:, :half_index] + llrs[:, half_index:] + prior = torch.log((p[right_bit_position])/(1 - p[right_bit_position])) + Lv = self.clamp(Lv - torch.ones_like(Lv)*prior.item(), -1000, 1000) + + decoded_llrs[:, right_bit_position] = Lv.squeeze(1) + # Lv = u_hat * llrs[:, :half_index] + llrs[:, half_index:] + if self.args.hard_decision: + v_hat = torch.sign(Lv) + else: + v_hat = torch.tanh(Lv/2) + + #print("DECODED: Bit positions {} : {} and {} : {}".format(left_bit_position, u_hat, right_bit postion, v_hat)) + + + + + if self.args.no_sharing_weights: + num_positions_on_level = 2**depth + if bit_position == num_positions_on_level - 1: + return torch.cat((u_hat * v_hat, v_hat), dim = 1).float(), decoded_llrs + else: + p0 = self.gnet_dict[depth][bit_position](torch.cat((u_hat, v_hat), dim = 1)) + return torch.cat((p0, v_hat), dim=1), decoded_llrs + else: + p0 = self.gnet_dict[depth](torch.cat((u_hat, v_hat), dim = 1)) + return torch.cat((p0, v_hat), dim=1), decoded_llrs + # return torch.cat((u_hat * v_hat, v_hat), dim = 1).float(), decoded_bits + + else: + # LEFT CHILD + # Find likelihood of (u xor v) xor (v) = u + + #Lu = log_sum_exp(torch.cat([llrs[:, :half_index].unsqueeze(2), llrs[:, half_index:].unsqueeze(2)], dim=2).permute(0, 2, 1)) + if self.args.no_sharing_weights: + Lu = self.fnet_dict[depth+1][2*bit_position](llrs) + else: + # print('LLRs device: ', llrs.device) + Lu = self.fnet_dict[depth+1]['left'](llrs.to(self.device)) + if self.args.augment: + Lu = Lu + log_sum_exp(torch.cat([llrs[:, :half_index].unsqueeze(2), llrs[:, half_index:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True) + + u_hat, decoded_llrs = self.neural_decode(Lu, depth+1, bit_position*2, decoded_llrs, p) + + # RIGHT CHILD + #Lv = u_hat * llrs[:, :half_index] + llrs[:, half_index:] + # need to verify dimensions + if self.args.no_sharing_weights: + Lv = self.fnet_dict[depth+1][2*bit_position+1](torch.cat((llrs, u_hat), dim=1)) + else: + Lv = self.fnet_dict[depth+1]['right'](torch.cat((llrs, u_hat), dim=1)) + if self.args.augment: + Lv = Lv + u_hat * llrs[:, :half_index] + llrs[:, half_index:] + v_hat, decoded_llrs = self.neural_decode(Lv, depth+1, bit_position*2 + 1, decoded_llrs, p) + + if self.args.no_sharing_weights: + num_positions_on_level = 2**depth + if bit_position == num_positions_on_level - 1: # no need to learn reconstruction of codeword + return torch.cat((u_hat * v_hat, v_hat), dim=1), decoded_llrs + else: + #reconstruct parent llr, p0 + p0 = self.gnet_dict[depth][bit_position](torch.cat((u_hat, v_hat), dim = 1)) + return torch.cat((p0, v_hat), dim=1), decoded_llrs + + else: + p0 = self.gnet_dict[depth](torch.cat((u_hat, v_hat), dim = 1)) + return torch.cat((p0, v_hat), dim=1), decoded_llrs + + def get_CRC(self, message): + + # need to optimize. 
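get_CRC and CRC_check implement the CRC as polynomial long division over GF(2), with the generator taken from CRC_polynomials (defined elsewhere in this file). A self-contained restatement of that division, shown with a hypothetical 3-bit generator x^3 + x + 1 rather than the polynomials this class actually uses:

import torch

def crc_remainder(msg_bits, poly):
    # msg_bits: 1D 0/1 int tensor; poly: generator of length r+1 (leading 1)
    # returns the r CRC bits: remainder of msg * x^r divided by poly over GF(2)
    r = len(poly) - 1
    padded = torch.cat([msg_bits.clone(), torch.zeros(r, dtype=torch.int)])
    for i in range(len(msg_bits)):
        if padded[i] == 1:
            padded[i:i + r + 1] ^= poly
    return padded[len(msg_bits):]

msg = torch.tensor([1, 1, 0, 1], dtype=torch.int)
poly = torch.tensor([1, 0, 1, 1], dtype=torch.int)    # hypothetical CRC-3: x^3 + x + 1
crc = crc_remainder(msg, poly)
print(crc)                                            # tensor([0, 0, 1], dtype=torch.int32)
print(crc_remainder(torch.cat([msg, crc]), poly))     # all zeros, i.e. the check passes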
+ # inout message should be int + + padded_bits = torch.cat([message, torch.zeros(polar.CRC_len).int()]) + while len(padded_bits[0:polar.K_minus_CRC].nonzero()): + cur_shift = (padded_bits != 0).int().argmax(0) + padded_bits[cur_shift: cur_shift + polar.CRC_len + 1] ^= polar.CRC_polynomials[polar.CRC_len] + + return padded_bits[self.K_minus_CRC:] + + def CRC_check(self, message): + + # need to optimize. + # input message should be int + + padded_bits = message + while len(padded_bits[0:polar.K_minus_CRC].nonzero()): + cur_shift = (padded_bits != 0).int().argmax(0) + padded_bits[cur_shift: cur_shift + polar.CRC_len + 1] ^= polar.CRC_polynomials[polar.CRC_len] + + if padded_bits[polar.K_minus_CRC:].sum()>0: + return 0 + else: + return 1 + + def encode_with_crc(self, message, CRC_len): + self.CRC_len = CRC_len + self.K_minus_CRC = self.K - CRC_len + + if CRC_len == 0: + return self.encode_plotkin(message) + else: + crcs = 1-2*torch.vstack([self.get_CRC((0.5+0.5*message[jj]).int()) for jj in range(message.shape[0])]) + encoded = self.encode_plotkin(torch.cat([message, crcs], 1)) + + return encoded + + def pruneLists(self, llr_array_list, partial_llrs_list, u_hat_list, metric_list, L): + _, inds = torch.topk(-1*metric_list, L, 0) # select L gratest indices in every row + sorted_inds, _ = torch.sort(inds, 0) + batch_size = partial_llrs_list.shape[1] + + # llr_array_list = torch.gather(llr_array_list, 0, sorted_inds.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, llr_array_list.shape[2], llr_array_list.shape[3])) + # partial_llrs_list = torch.gather(partial_llrs_list, 0, sorted_inds.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, partial_llrs_list.shape[2], partial_llrs_list.shape[3])) + # metric_list = torch.gather(metric_list, 0, sorted_inds) + # u_hat_list = torch.gather(u_hat_list, 0, sorted_inds.unsqueeze(-1).repeat(1, 1, u_hat_list.shape[2])) + llr_array_list = llr_array_list[sorted_inds, torch.arange(batch_size)] + partial_llrs_list = partial_llrs_list[sorted_inds, torch.arange(batch_size)] + metric_list = metric_list[sorted_inds, torch.arange(batch_size)] + u_hat_list = u_hat_list[sorted_inds, torch.arange(batch_size)] + + return llr_array_list, partial_llrs_list, u_hat_list, metric_list + + def scl_decode(self, corrupted_codewords, snr, L=1, use_CRC = False): + + # step-wise implementation using updateLLR and updatePartialSums + sigma = snr_db2sigma(snr) + llrs = (2/sigma**2)*corrupted_codewords + batch_size = corrupted_codewords.shape[0] + + priors = torch.zeros(self.N) + # add frozen priors later only + #priors[self.frozen_positions] = self.infty + + u_hat_list = torch.zeros(1, corrupted_codewords.shape[0], self.N, device=corrupted_codewords.device) + llr_array, partial_llrs = self.define_partial_arrays(llrs) + llr_array_list = llr_array.unsqueeze(0) + partial_llrs_list = partial_llrs.unsqueeze(0) + metric_list = torch.zeros(1, llrs.shape[0]) + for ii in range(self.N): + list_size = llr_array_list.shape[0] + if ii in self.frozen_positions: + llr_array , decoded_bits = self.updateLLR(ii, llr_array_list.reshape(-1, self.n+1, self.N).clone(), partial_llrs_list.reshape(-1, self.n+1, self.N), priors) + metric = torch.abs(llr_array[:, 0, ii])*(llr_array[:, 0, ii].sign() != 1*torch.ones(llr_array.shape[0])).float() + # add the infty prior only later, since metric uses |LLR| + llr_array[:, 0, ii] = llr_array[:, 0, ii] + self.infty * torch.ones_like(llr_array[:, 0, ii]) + + u_hat_list[:, :, ii] = torch.ones(list_size, batch_size, device=corrupted_codewords.device) + partial_llrs = self.updatePartialSums(ii, 
u_hat_list.reshape(-1, self.N), partial_llrs_list.reshape(-1, self.n+1, self.N).clone()) + + + llr_array_list = llr_array.reshape(list_size, batch_size, self.n+1, self.N) + partial_llrs_list = partial_llrs.reshape(list_size, batch_size, self.n+1, self.N) + metric_list = metric_list + metric.reshape(list_size, batch_size) + + assert llr_array_list.shape[0] == partial_llrs_list.shape[0] == metric_list.shape[0] == u_hat_list.shape[0] + + else: + llr_array , decoded_bits = self.updateLLR(ii, llr_array_list.reshape(-1, self.n+1, self.N).clone(), partial_llrs_list.reshape(-1, self.n+1, self.N), priors) + metric = torch.abs(llr_array[:, 0, ii]) + + #Duplicate lists + u_hat_list = torch.vstack([u_hat_list, u_hat_list]) + u_hat_list[:list_size, :, ii] = torch.sign(llr_array[:, 0, ii]).reshape(list_size, batch_size) + u_hat_list[list_size:, :, ii] = -1* torch.sign(llr_array[:, 0, ii]).reshape(list_size, batch_size) + + # same LLRs for both decisions + llr_array_list = torch.vstack([llr_array.reshape(list_size, batch_size, self.n+1, self.N), llr_array.reshape(list_size, batch_size, self.n+1, self.N)]) + llr_array_list = torch.vstack([llr_array.reshape(list_size, batch_size, self.n+1, self.N), llr_array.reshape(list_size, batch_size, self.n+1, self.N)]) + + # update partial sums for both decisions + partial_llrs_list = self.updatePartialSums(ii, u_hat_list.reshape(-1, self.N), torch.vstack([partial_llrs_list, partial_llrs_list]).reshape(-1, self.n+1, self.N).clone()).reshape(2*list_size, batch_size, self.n+1, self.N) + # no additional penalty for SC path + metric_list = torch.vstack([metric_list, metric_list + metric.reshape(list_size, batch_size)]) + + if llr_array_list.shape[0] > L: # prune list + llr_array_list, partial_llrs_list, u_hat_list, metric_list = self.pruneLists(llr_array_list, partial_llrs_list, u_hat_list, metric_list, L) + + list_size = llr_array_list.shape[0] + if use_CRC: + u_hat = u_hat_list[:, :, self.info_positions] + decoded_bits = torch.zeros(batch_size, self.K_minus_CRC) + llr_array = torch.zeros(batch_size, self.N) + + # optimize this later + crc_checked = torch.zeros(list_size).int() + for jj in range(batch_size): + for kk in range(list_size): + crc_checked[kk] = self.CRC_check((0.5+0.5*u_hat[kk, jj]).int()) + + if crc_checked.sum() == 0: #no code in list passes. pick lowest metric + decoded_bits[jj] = u_hat[metric_list[:, jj].argmin(), jj, :self.K_minus_CRC] + llr_array[jj] = llr_array_list[metric_list[:, jj].argmin(), jj, 0, :] + else: # pick code that has lowest metric among ones that passed crc + inds = crc_checked.nonzero() + decoded_bits[jj] = u_hat[inds[metric_list[inds, jj].argmin()], jj, :self.K_minus_CRC] + llr_array[jj] = llr_array_list[inds[metric_list[inds, jj].argmin()], jj, 0, :] + + else: # do ML decision among the L messages in the list + u_hat = u_hat_list[:, :, self.info_positions] + codeword_list = self.encode_plotkin(u_hat.reshape(-1, self.K)).reshape(list_size, batch_size, self.N) + inds = ((codeword_list - corrupted_codewords.unsqueeze(0))**2).sum(2).argmin(0) + # get ML decision for each sample. 
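The list bookkeeping above follows the usual SCL path-metric rule: a path is penalised by |LLR| exactly when its hard decision contradicts the sign of that LLR (this covers both the frozen-bit penalty and the duplicated information-bit paths), and pruneLists then keeps the L smallest metrics. A compact restatement, with a helper name that is mine:

import torch

def path_metric_update(metric, llr, decision):
    # pay |llr| only when the +1/-1 decision disagrees with the LLR sign
    return metric + llr.abs() * (decision != llr.sign()).float()

metric = torch.zeros(2)                  # two paths
llr = torch.tensor([1.8, 1.8])           # same LLR seen by both paths
decision = torch.tensor([1.0, -1.0])     # path 0 follows the LLR, path 1 flips the bit
print(path_metric_update(metric, llr, decision))   # tensor([0.0000, 1.8000])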
+ decoded_bits = u_hat[inds, torch.arange(batch_size)] + llr_array = llr_array_list[inds, torch.arange(batch_size), 0, :] + + return llr_array, decoded_bits + + + def bitwise_MAP(self,noisy_enc,device,snr): # take bitwise independent map decisions and return output, this does not use the approximation -> log sum exp = max + sigma = snr_db2sigma(snr) + noisy_enc=(2/sigma**2)*noisy_enc + all_msg_bits = [] + for i in range(2**self.K): + d = dec2bitarray(i, self.K) + all_msg_bits.append(d) + all_message_bits = torch.from_numpy(np.array(all_msg_bits)).to(device) + all_message_bits = 1 - 2*all_message_bits.float() + #codebooks = [] + outputs = torch.ones(noisy_enc.shape[0],self.K,device=device) + for bit in range(self.K): + codebook1 = self.encode_plotkin(all_message_bits[all_message_bits[:,bit]==1.]) + codebook2 = self.encode_plotkin(all_message_bits[all_message_bits[:,bit]==-1.]) + dec1 = torch.logsumexp(torch.matmul(codebook1,noisy_enc.T).T,-1).unsqueeze(0) + dec2 = torch.logsumexp(torch.matmul(codebook2,noisy_enc.T).T,-1).unsqueeze(0) + dec = torch.cat((dec1,dec2),0) + bit_dec = 1.-2.*torch.max(dec,0).indices + outputs[:,bit] = bit_dec.T + + return outputs + + + def get_generator_matrix(self,custom_info_positions=None): + if custom_info_positions is not None: + info_inds = custom_info_positions + else: + info_inds = self.info_positions + msg = 1-2*torch.eye(self.K) + code = 1.*(self.encode_plotkin(msg)==-1.) + mat = torch.zeros((self.N,self.N)) + mat[info_inds,:] = code + mat = mat.T + return mat + + def get_min_xor_matrix(self): + gen_mat = self.get_generator_matrix() + xor_mat = gen_mat[polar.info_positions,:] + return xor_mat + + def get_difficulty_seq(self,unrolling_seq): + difficulty_seq = torch.zeros((self.N,self.K)) + gen_mat = self.get_generator_matrix() + count = 0 + for bit in unrolling_seq: + u = unrolling_seq[0:count+1] + u.sort() + difficulty = torch.sum(gen_mat[:,u],1) + difficulty_seq[u,count] = difficulty[u] + count += 1 + fin = difficulty_seq[self.info_positions,:] + shifted = fin.clone() + transfer = fin.clone() + transfer[:,0] = 0 + shifted[:,:-1] = shifted[:,1:]-shifted[:,:-1] + transfer[:,1:] = shifted[:,:-1] + return fin,transfer + + def calculate_transfer_metric(self,unrolling_seq): + _,deltas = self.get_difficulty_seq(unrolling_seq) + avg = torch.sum(deltas)/torch.sum(1.0*(deltas > 0)) + return torch.max(deltas).item(),avg.item() + + + def plot_standard_schemes(self,path='data'): + h2e = self.unsorted_info_positions.tolist() + e2h = self.unsorted_info_positions.tolist() + e2h.reverse() + l2r = self.info_positions.tolist() + r2l = self.info_positions.tolist() + r2l.reverse() + bottom = -1 + top = 10 + + path = path + '/polar_transfer_{0}_{1}'.format(self.K,self.N) + os.makedirs(path, exist_ok=True) + + diff_seq_h2e1,diff_seq_h2e_transfer1 = self.get_difficulty_seq(h2e) + diff_seq_h2e = diff_seq_h2e1.tolist() + diff_seq_h2e_transfer = diff_seq_h2e_transfer1.tolist() + plt.figure(figsize = (20,10)) + for i in range((self.K)): + plt.step([float(elem) for elem in list(range(len(h2e)))], diff_seq_h2e[i], label = 'Bit {0}'.format(i)) + plt.ylim(bottom=bottom) + plt.ylim(top=top) + plt.ylabel("Learning Difficulty") + plt.xlabel("Progressive training") + plt.title("Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.savefig(path +'/polar_h2e_all_{0}_{1}.pdf'.format(self.K,self.N)) + plt.title("H2E plot , Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.close() + 
plt.figure(figsize = (20,10)) + for i in range((self.K)): + plt.step([float(elem) for elem in list(range(len(h2e)))], diff_seq_h2e_transfer[i], label = 'Bit {0}'.format(i)) + plt.ylim(bottom=bottom) + plt.ylim(top=top) + plt.title("Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.ylabel("Transfer Difficulty") + plt.xlabel("Progressive training") + plt.savefig(path +'/polar_transfer_h2e_all_{0}_{1}.pdf'.format(self.K,self.N)) + plt.title("H2E plot , Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.close() + + diff_seq_e2h1,diff_seq_e2h_transfer1 = self.get_difficulty_seq(e2h) + diff_seq_e2h = diff_seq_e2h1.tolist() + diff_seq_e2h_transfer = diff_seq_e2h_transfer1.tolist() + plt.figure(figsize = (20,10)) + for i in range((self.K)): + plt.step([float(elem) for elem in list(range(len(e2h)))], diff_seq_e2h[i], label = 'Bit {0}'.format(i)) + plt.ylim(bottom=bottom) + plt.ylim(top=top) + plt.title("Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.ylabel("Learning Difficulty") + plt.xlabel("Progressive training") + plt.savefig(path +'/polar_e2h_all_{0}_{1}.pdf'.format(self.K,self.N)) + plt.title("e2h plot , Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.close() + plt.figure(figsize = (20,10)) + for i in range((self.K)): + plt.step([float(elem) for elem in list(range(len(h2e)))], diff_seq_e2h_transfer[i], label = 'Bit {0}'.format(i)) + plt.ylim(bottom=bottom) + plt.ylim(top=top) + plt.title("Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.savefig(path +'/polar_transfer_e2h_all_{0}_{1}.pdf'.format(self.K,self.N)) + plt.title("e2h plot , Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.close() + + diff_seq_l2r1,diff_seq_l2r_transfer1 = self.get_difficulty_seq(l2r) + diff_seq_l2r = diff_seq_l2r1.tolist() + diff_seq_l2r_transfer = diff_seq_l2r_transfer1.tolist() + plt.figure(figsize = (20,10)) + for i in range((self.K)): + plt.step([float(elem) for elem in list(range(len(l2r)))], diff_seq_l2r[i], label = 'Bit {0}'.format(i)) + plt.ylim(bottom=bottom) + plt.ylim(top=top) + plt.title("Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.ylabel("Learning Difficulty") + plt.xlabel("Progressive training") + plt.savefig(path +'/polar_l2r_all_{0}_{1}.pdf'.format(self.K,self.N)) + plt.title("l2r plot , Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.close() + plt.figure(figsize = (20,10)) + for i in range((self.K)): + plt.step([float(elem) for elem in list(range(len(h2e)))], diff_seq_l2r_transfer[i], label = 'Bit {0}'.format(i)) + plt.ylim(bottom=bottom) + plt.ylim(top=top) + plt.title("Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.savefig(path +'/polar_transfer_l2r_all_{0}_{1}.pdf'.format(self.K,self.N)) + plt.title("l2r plot , Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.close() + + diff_seq_r2l1,diff_seq_r2l_transfer1 = self.get_difficulty_seq(r2l) + diff_seq_r2l = diff_seq_r2l1.tolist() + diff_seq_r2l_transfer = diff_seq_r2l_transfer1.tolist() + plt.figure(figsize = (20,10)) + for i in range((self.K)): + plt.step([float(elem) for 
elem in list(range(len(r2l)))], diff_seq_r2l[i], label = 'Bit {0}'.format(i)) + plt.ylim(bottom=bottom) + plt.ylim(top=top) + plt.title("Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.ylabel("Learning Difficulty") + plt.xlabel("Progressive training") + plt.savefig(path +'/polar_r2l_all_{0}_{1}.pdf'.format(self.K,self.N)) + plt.title("r2l plot , Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.close() + plt.figure(figsize = (20,10)) + for i in range((self.K)): + plt.step([float(elem) for elem in list(range(len(h2e)))], diff_seq_r2l_transfer[i], label = 'Bit {0}'.format(i)) + plt.ylim(bottom=bottom) + plt.ylim(top=top) + plt.title("Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.savefig(path +'/polar_transfer_r2l_all_{0}_{1}.pdf'.format(self.K,self.N)) + plt.title("r2l plot , Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.close() + + max1,avg1 = self.calculate_transfer_metric(h2e) + print("Max Transfer Difficulty for H2E : {0}".format(max1)) + print("Avg Transfer Difficulty for H2E : {0}".format(avg1)) + print("\n") + max1,avg1 = self.calculate_transfer_metric(e2h) + print("Max Transfer Difficulty for E2H : {0}".format(max1)) + print("Avg Transfer Difficulty for E2h : {0}".format(avg1)) + print("\n") + max1,avg1 = self.calculate_transfer_metric(l2r) + print("Max Transfer Difficulty for L2R : {0}".format(max1)) + print("Avg Transfer Difficulty for L2R : {0}".format(avg1)) + print("\n") + max1,avg1 = self.calculate_transfer_metric(r2l) + print("Max Transfer Difficulty for R2L : {0}".format(max1)) + print("Avg Transfer Difficulty for R2L : {0}".format(avg1)) + print("\n") + + + for i in range(self.K): + fig = plt.figure(figsize = (20,10)) + ax = fig.add_subplot(1, 1, 1) + #plt.step([float(elem) for elem in list(range(len(h2e)))], diff_seq_h2e[i], label = 'H2E Bit {0}'.format(i), marker='*', linewidth=1.5) + #plt.step([float(elem) for elem in list(range(len(e2h)))], diff_seq_e2h[i], label = 'E2H Bit {0}'.format(i), marker='^', linewidth=1.5) + plt.step([float(elem) for elem in list(range(len(l2r)))], diff_seq_l2r[i], label = ' ', where='post', color='tab:orange', marker='^', linewidth=2.5) + plt.step([float(elem) for elem in list(range(len(r2l)))], diff_seq_r2l[i], label = ' ', where='post', color='g',marker='v', linewidth=2.5) + plt.ylim(bottom=bottom) + plt.ylim(top=top) + #plt.title("Plot for bit {0}, Hardest to easiest order : {1}".format(i,np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.legend(prop={'size': 15},loc='lower left', bbox_to_anchor=(0.85,0.9)) + major_ticks = np.arange(0, self.K+1, 1) + majory_ticks = np.arange(0, 10, 1) + ax.set_xticks(major_ticks) + ax.set_yticks(majory_ticks) + ax.grid(which='major') + #plt.ylabel("Learning Difficulty",fontsize=20) + #plt.xlabel("Progressive training steps",fontsize=20) + plt.savefig(path +'/polar_all_{0}_{1}_bit_{2}.pdf'.format(self.K,self.N,i)) + plt.close() + + fig = plt.figure(figsize = (20,10)) + ax = fig.add_subplot(1, 1, 1) + plt.step([float(elem) for elem in list(range(len(h2e)))], diff_seq_h2e_transfer[i], label = 'H2E Bit {0}'.format(i), marker='*', linewidth=1.5) + plt.step([float(elem) for elem in list(range(len(e2h)))], diff_seq_e2h_transfer[i], label = 'E2H Bit {0}'.format(i), marker='^', linewidth=1.5) + plt.step([float(elem) for elem in list(range(len(l2r)))], 
diff_seq_l2r_transfer[i], label = 'L2R Bit {0}'.format(i), marker='o', linewidth=1.5) + plt.step([float(elem) for elem in list(range(len(r2l)))], diff_seq_r2l_transfer[i], label = 'R2L Bit {0}'.format(i), marker='x', linewidth=1.5) + plt.ylim(bottom=bottom) + plt.ylim(top=top) + plt.title("Transfer Plot for bit {0}, Hardest to easiest order : {1}".format(i,np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.legend(prop={'size': 15},loc='upper right', bbox_to_anchor=(1.1, 1)) + major_ticks = np.arange(0, self.K+1, 1) + ax.set_xticks(major_ticks) + ax.set_yticks(majory_ticks) + ax.grid(which='major') + plt.ylabel("Transfer Difficulty",fontsize=20) + plt.xlabel("Progressive training steps",fontsize=20) + plt.savefig(path +'/polar_transfer_all_{0}_{1}_bit_{2}.pdf'.format(self.K,self.N,i)) + plt.close() + + fig = plt.figure(figsize = (20,10)) + ax = fig.add_subplot(1, 1, 1) + plt.step([float(elem) for elem in list(range(len(h2e)))], torch.sum(diff_seq_h2e1,0), label = 'H2E sum', marker='*', linewidth=1.5) + plt.step([float(elem) for elem in list(range(len(e2h)))], torch.sum(diff_seq_e2h1,0), label = 'E2H sum', marker='^', linewidth=1.5) + plt.step([float(elem) for elem in list(range(len(l2r)))], torch.sum(diff_seq_l2r1,0), label = 'L2R sum', marker='o', linewidth=1.5) + plt.step([float(elem) for elem in list(range(len(r2l)))], torch.sum(diff_seq_r2l1,0), label = 'R2L sum', marker='x', linewidth=1.5) + plt.ylim(bottom=bottom) + plt.ylim(top=top) + plt.title("Sum plot, Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.legend(prop={'size': 15},loc='upper right', bbox_to_anchor=(1.1, 1)) + major_ticks = np.arange(0, self.K+1, 1) + majory_ticks = np.arange(0, 180, 10) + ax.set_xticks(major_ticks) + ax.set_yticks(majory_ticks) + ax.grid(which='major') + plt.ylabel("Learning Difficulty") + plt.xlabel("Progressive training steps") + plt.savefig(path +'/all_polar_transfer_sum.pdf') + plt.close() + + fig = plt.figure(figsize = (20,10)) + ax = fig.add_subplot(1, 1, 1) + plt.step([float(elem) for elem in list(range(len(h2e)))], torch.max(diff_seq_h2e1,0).values, label = 'H2E max', marker='*', linewidth=1.5) + plt.step([float(elem) for elem in list(range(len(e2h)))], torch.max(diff_seq_e2h1,0).values, label = 'E2H max', marker='^', linewidth=1.5) + plt.step([float(elem) for elem in list(range(len(l2r)))], torch.max(diff_seq_l2r1,0).values, label = 'L2R max', marker='o', linewidth=1.5) + plt.step([float(elem) for elem in list(range(len(r2l)))], torch.max(diff_seq_r2l1,0).values, label = 'R2L max', marker='x', linewidth=1.5) + plt.ylim(bottom=bottom) + plt.ylim(top=top) + plt.title("Max plot, Hardest to easiest order : {0}".format(np.argsort(np.argsort(self.unsorted_info_positions.copy())))) + plt.legend(prop={'size': 15},loc='upper right', bbox_to_anchor=(1.1, 1)) + major_ticks = np.arange(0, self.K+1, 1) + majory_ticks = np.arange(0, 10, 1) + ax.set_xticks(major_ticks) + ax.set_yticks(majory_ticks) + ax.grid(which='major') + plt.ylabel("Learning Difficulty") + plt.xlabel("Progressive training steps") + plt.savefig(path +'/all_polar_transfer_max.pdf') + plt.close() + + + + +if __name__ == '__main__': + args = get_args() + + n = int(np.log2(args.N)) + + + if args.ratel2rofile == 'polar': + # computed for SNR = 0 + if n == 5: + rs = np.array([31, 30, 29, 27, 23, 15, 28, 26, 25, 22, 21, 14, 19, 13, 11, 24, 7, 20, 18, 12, 17, 10, 9, 6, 5, 3, 16, 8, 4, 2, 1, 0]) + + elif n == 4: + rs = np.array([15, 14, 13, 11, 7, 12, 10, 9, 6, 5, 3, 8, 4, 
2, 1, 0]) + elif n == 3: + rs = np.array([7, 6, 5, 3, 4, 2, 1, 0]) + elif n == 2: + rs = np.array([3, 2, 1, 0]) + + rs = np.array([256 ,255 ,252 ,254 ,248 ,224 ,240 ,192 ,128 ,253 ,244 ,251 ,250 ,239 ,238 ,247 ,246 ,223 ,222 ,232 ,216 ,236 ,220 ,188 ,208 ,184 ,191 ,190 ,176 ,127 ,126 ,124 ,120 ,249 ,245 ,243 ,242 ,160 ,231 ,230 ,237 ,235 ,234 ,112 ,228 ,221 ,219 ,218 ,212 ,215 ,214 ,189 ,187 ,96 ,186 ,207 ,206 ,183 ,182 ,204 ,180 ,200 ,64 ,175 ,174 ,172 ,125 ,123 ,122 ,119 ,159 ,118 ,158 ,168 ,241 ,116 ,111 ,233 ,156 ,110 ,229 ,227 ,217 ,108 ,213 ,152 ,226 ,95 ,211 ,94 ,205 ,185 ,104 ,210 ,203 ,181 ,92 ,144 ,202 ,179 ,199 ,173 ,178 ,63 ,198 ,121 ,171 ,88 ,62 ,117 ,170 ,196 ,157 ,167 ,60 ,115 ,155 ,109 ,166 ,80 ,114 ,154 ,107 ,56 ,225 ,151 ,164 ,106 ,93 ,150 ,209 ,103 ,91 ,143 ,201 ,102 ,48 ,148 ,177 ,90 ,142 ,197 ,87 ,100 ,61 ,169 ,195 ,140 ,86 ,59 ,32 ,165 ,194 ,113 ,79 ,58 ,153 ,84 ,136 ,55 ,163 ,78 ,105 ,149 ,162 ,54 ,76 ,101 ,47 ,147 ,89 ,52 ,141 ,99 ,46 ,146 ,72 ,85 ,139 ,98 ,31 ,44 ,193 ,138 ,57 ,83 ,30 ,135 ,77 ,40 ,82 ,134 ,161 ,28 ,53 ,75 ,132 ,24 ,51 ,74 ,45 ,145 ,71 ,50 ,16 ,97 ,70 ,43 ,137 ,68 ,42 ,29 ,39 ,81 ,27 ,133 ,38 ,26 ,36 ,131 ,23 ,73 ,22 ,130 ,49 ,15 ,20 ,69 ,14 ,12 ,67 ,41 ,8 ,66 ,37 ,25 ,35 ,34 ,21 ,129 ,19 ,13 ,18 ,11 ,10 ,7 ,65 ,6 ,4 ,33 ,17 ,9 ,5 ,3 ,2 ,1 ]) - 1 + rs = rs[rs 0.5).float() + # msg_bits = 1 - 2*torch.zeros(args.batch_size, args.K).float() + # scl_msg_bits = 1 - 2*(torch.rand(args.batch_size, args.K - args.crc_len) > 0.5).float() + codes = polar.encode_plotkin(msg_bits) + # scl_codes = polar.encode_with_crc(scl_msg_bits, args.crc_len) + + + for snr_ind, snr in enumerate(snr_range): + + # codes_G = polar.encode_G(msg_bits_bpsk) + noisy_code = polar.channel(codes, snr) + noise = noisy_code - codes + # scl_noisy_code = scl_codes + noise + + SC_llrs, decoded_SC_msg_bits = polar.sc_decode_new(noisy_code, snr) + ber_SC = errors_ber(msg_bits, decoded_SC_msg_bits.sign()).item() + bler_SC = errors_bler(msg_bits, decoded_SC_msg_bits.sign()).item() + + # print("SNR = {}, BER = {}, BLER = {}".format(snr, bit_error_rate, bler)) + bers_SC[snr_ind] += ber_SC/args.test_ratio + blers_SC[snr_ind] += bler_SC/args.test_ratio + + # SCL_llrs, decoded_SCL_msg_bits = polar.scl_decode(scl_noisy_code, snr, args.list_size) + # ber_SCL = errors_ber(scl_msg_bits, decoded_SCL_msg_bits.sign()).item() + # bler_SCL = errors_bler(scl_msg_bits, decoded_SCL_msg_bits.sign()).item() + # print("SNR = {}, BER = {}, BLER = {}".format(snr, bit_error_rate, bler)) + + SCL_llrs, decoded_SCL_msg_bits = polar.scl_decode(noisy_code, snr, args.list_size, use_CRC = False) + ber_SCL = errors_ber(msg_bits, decoded_SCL_msg_bits.sign()).item() + bler_SCL = errors_bler(msg_bits, decoded_SCL_msg_bits.sign()).item() + + bers_SCL[snr_ind] += ber_SCL/args.test_ratio + blers_SCL[snr_ind] += bler_SCL/args.test_ratio + + print("Test SNRs : ", snr_range) + print("BERs of SC: {0}".format(bers_SC)) + print("BERs of SCL: {0}".format(bers_SCL)) + print("BLERs of SC: {0}".format(blers_SC)) + print("BLERs of SCL: {0}".format(blers_SCL)) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..0274b9c --- /dev/null +++ b/utils.py @@ -0,0 +1,534 @@ +import torch +import numpy as np +from collections import Counter, OrderedDict + +def snr_db2sigma(train_snr): + return 10**(-train_snr*1.0/20) + +def get_msg_bits_batch(data_generator): + msg_bits_batch = next(data_generator) + return msg_bits_batch + +def moving_average(a, n=3) : + ret = np.cumsum(a, dtype=float) + ret[n:] = ret[n:] - ret[:-n] + return ret[n - 
1:] / n + +def errors_ber(y_true, y_pred, mask=None): + if mask == None: + mask=torch.ones(y_true.size(),device=y_true.device) + y_true = y_true.view(y_true.shape[0], -1, 1) + y_pred = y_pred.view(y_pred.shape[0], -1, 1) + mask = mask.view(mask.shape[0], -1, 1) + myOtherTensor = (mask*torch.ne(torch.round(y_true), torch.round(y_pred))).float() + res = sum(sum(myOtherTensor))/(torch.sum(mask)) + return res + +def errors_bitwise_ber(y_true, y_pred, mask=None): + if mask == None: + mask=torch.ones(y_true.size(),device=y_true.device) + y_true = y_true.view(y_true.shape[0], -1, 1) + y_pred = y_pred.view(y_pred.shape[0], -1, 1) + mask = mask.view(mask.shape[0], -1, 1) + myOtherTensor = (mask*torch.ne(torch.round(y_true), torch.round(y_pred))).float() + res = torch.sum(myOtherTensor,0)/torch.sum(mask,0) + return res + +def errors_bler(y_true, y_pred, get_pos = False): + y_true = y_true.view(y_true.shape[0], -1, 1) + y_pred = y_pred.view(y_pred.shape[0], -1, 1) + + decoded_bits = torch.round(y_pred).cpu() + X_test = torch.round(y_true).cpu() + tp0 = (abs(decoded_bits-X_test)).view([X_test.shape[0],X_test.shape[1]]) + tp0 = tp0.detach().cpu().numpy() + bler_err_rate = sum(np.sum(tp0,axis=1)>0)*1.0/(X_test.shape[0]) + + if not get_pos: + return bler_err_rate + else: + err_pos = list(np.nonzero((np.sum(tp0,axis=1)>0).astype(int))[0]) + return bler_err_rate, err_pos + +def extract_block_errors(y_true, y_pred, thresh=0): + y_true_out = y_true.clone() + y_pred_out = y_pred.clone() + y_true = y_true.view(y_true.shape[0], -1, 1) + y_pred = y_pred.view(y_pred.shape[0], -1, 1) + + decoded_bits = torch.round(y_pred).cpu() + X_test = torch.round(y_true).cpu() + tp0 = (abs(decoded_bits-X_test)).view([X_test.shape[0],X_test.shape[1]]) + tp0 = tp0.detach().cpu().numpy() + bler_err_rate = (np.sum(tp0,axis=1)>thresh)*1.0 + return np.where(bler_err_rate > 0) + +def extract_block_nonerrors(y_true, y_pred, thresh=1): + y_true_out = y_true.clone() + y_pred_out = y_pred.clone() + y_true = y_true.view(y_true.shape[0], -1, 1) + y_pred = y_pred.view(y_pred.shape[0], -1, 1) + + decoded_bits = torch.round(y_pred).cpu() + X_test = torch.round(y_true).cpu() + tp0 = (abs(decoded_bits-X_test)).view([X_test.shape[0],X_test.shape[1]]) + tp0 = tp0.detach().cpu().numpy() + bler_correct_rate = (np.sum(tp0,axis=1) 0) + +def get_epos(k1, k2): + # return counter for bit ocations of first-errors + bb = torch.ne(k1.cpu().sign(), k2.cpu().sign()) + # inds = torch.nonzero(bb)[:, 1].numpy() + idx = [] + for ii in range(bb.shape[0]): + try: + iii = list(bb.cpu().float().numpy()[ii]).index(1) + idx.append(iii) + except: + pass + counter = Counter(idx) + ordered_counter = OrderedDict(sorted(counter.items())) + return ordered_counter + +def countSetBits(n): + count = 0 + while (n): + n &= (n-1) + count+= 1 + return count + +def get_minD(code): + all_msg_bits = [] + + for i in range(2**code.K): + d = dec2bitarray(i, code.K) + all_msg_bits.append(d) + all_message_bits = torch.from_numpy(np.array(all_msg_bits)) + all_message_bits = 1 - 2*all_message_bits.float() + + codebook = 0.5*code.encode(all_message_bits)+0.5 + b_codebook = codebook.unsqueeze(0) + dist = 1000 + for ii in range(codebook.shape[0]): + a = ((b_codebook[:, ii] - codebook)**2).sum(1) + a[ii] = 1000 + m = torch.min(a) + if m < dist: + dist = m + return dist + +def get_pairwiseD(code, size = None): + if size is None: + all_msg_bits = [] + + for i in range(2**code.K): + d = dec2bitarray(i, code.K) + all_msg_bits.append(d) + all_message_bits = torch.from_numpy(np.array(all_msg_bits)) + 
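The all-message enumeration used here (and in bitwise_MAP and get_minD) builds on dec2bitarray, defined later in this file, followed by the BPSK map b -> 1 - 2b. A two-line usage check, assuming this file is importable as utils:

import torch
from utils import dec2bitarray

bits = torch.from_numpy(dec2bitarray(10, 4)).float()   # [1, 0, 1, 0], MSB first
print(1 - 2 * bits)                                    # tensor([-1.,  1., -1.,  1.]): 0 -> +1, 1 -> -1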
all_message_bits = 1 - 2*all_message_bits.float() + else: + all_message_bits = 1 - 2 *(torch.rand(size, code.K) < 0.5).float() + codebook = 0.5*code.encode(all_message_bits)+0.5 + b_codebook = codebook.unsqueeze(0) + + dist_counts = {} + for ii in range(codebook.shape[0]): + a = ((b_codebook[:, ii] - codebook)**2).sum(1) + counts = Counter(a.int().numpy()) + for key in counts: + if key not in dist_counts.keys(): + dist_counts[key] = counts[key] + else: + dist_counts[key] += counts[key] + + # minimum distance : np.sqrt(min(dist_counts.keys())) + # average distance : np.array([np.sqrt(key)*value for (key, value) in dist_counts.items()]).sum()/np.array(list(dist_counts.values())).sum() + return {key:value//2 for (key, value) in dist_counts.items() if key != 0} + +def get_pairwiseD_weight(code, size = None): + if size is None: + all_msg_bits = [] + + for i in range(2**code.K): + d = dec2bitarray(i, code.K) + all_msg_bits.append(d) + all_message_bits = torch.from_numpy(np.array(all_msg_bits)) + all_message_bits = 1 - 2*all_message_bits.float() + else: + all_message_bits = 1 - 2 *(torch.rand(size, code.K) < 0.5).float() + codebook = 1 - (0.5*code.encode(all_message_bits)+0.5) + b_codebook = codebook.unsqueeze(0) + + a = codebook.sum(1) + dist_counts = Counter(a.int().numpy()) + # minimum distance : np.sqrt(min(dist_counts.keys())) + # average distance : np.array([np.sqrt(key)*value for (key, value) in dist_counts.items()]).sum()/np.array(list(dist_counts.values())).sum() + return dist_counts + + +def dec2bitarray(in_number, bit_width): + """ + Converts a positive integer to NumPy array of the specified size containing + bits (0 and 1). + Parameters + ---------- + in_number : int + Positive integer to be converted to a bit array. + bit_width : int + Size of the output bit array. + Returns + ------- + bitarray : 1D ndarray of ints + Array containing the binary representation of the input decimal. + """ + + binary_string = bin(in_number) + length = len(binary_string) + bitarray = np.zeros(bit_width, 'int') + for i in range(length-2): + bitarray[bit_width-i-1] = int(binary_string[length-i-1]) + + return bitarray + +def bitarray2dec(in_bitarray): + """ + Converts an input NumPy array of bits (0 and 1) to a decimal integer. + Parameters + ---------- + in_bitarray : 1D ndarray of ints + Input NumPy array of bits. + Returns + ------- + number : int + Integer representation of input bit array. + """ + + number = 0 + + for i in range(len(in_bitarray)): + number = number + in_bitarray[i]*pow(2, len(in_bitarray)-1-i) + + return number + +class STEQuantize(torch.autograd.Function): + #self.args.fb_quantize_limit, self.args.fb_quantize_level + @staticmethod + def forward(ctx, inputs, quant_limit=1, quant_level=2): + + ctx.save_for_backward(inputs) + + x_lim_abs = quant_limit + x_lim_range = 2.0 * x_lim_abs + x_input_norm = torch.clamp(inputs, -x_lim_abs, x_lim_abs) + + if quant_level == 2: + outputs_int = torch.sign(x_input_norm) + else: + outputs_int = torch.round((x_input_norm +x_lim_abs) * ((quant_level - 1.0)/x_lim_range)) * x_lim_range/(quant_level - 1.0) - x_lim_abs + + return outputs_int + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + + # let's see what happens.... 
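For context, STEQuantize is used through its .apply handle: the forward pass quantizes (a plain sign for two levels), while the backward pass clips the incoming gradient to [-0.25, 0.25] and passes it straight through, keeping the quantizer trainable. A small sanity check of that behaviour:

import torch

quantize = STEQuantize.apply                 # straight-through quantizer defined above

x = torch.tensor([0.3, -1.7, 0.9], requires_grad=True)
y = quantize(x, 1.0, 2)                      # two-level quantization of clamp(x, -1, 1)
print(y)                                     # values: [ 1., -1.,  1.]
y.sum().backward()
print(x.grad)                                # tensor([0.2500, 0.2500, 0.2500]): clipped pass-through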
+ # grad_output[torch.abs(input)>1.5]=0 + # grad_output[torch.abs(input)<0.5]=0 + + # grad_output[input>1.0]=0 + # grad_output[input<-1.0]=0 + + grad_output = torch.clamp(grad_output, -0.25, +0.25) + + grad_input = grad_output.clone() + + return grad_input, None, None, None + +class STESign(torch.autograd.Function): + + @staticmethod + def forward(ctx, input): + return torch.sign(input) + + @staticmethod + def backward(ctx, grad_output): + return grad_output.clamp_(-1, 1) + +class Clamp(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, min=0, max=1): + return input.clamp(min=min+(1e-10), max=max-(1e-10)) + + @staticmethod + def backward(ctx, grad_output): + return grad_output.clone(), None, None + +def snr_db2sigma(train_snr): + return 10**(-train_snr*1.0/20) + +def min_sum_log_sum_exp(x, y): + + log_sum_ms = torch.min(torch.abs(x), torch.abs(y))*torch.sign(x)*torch.sign(y) + return log_sum_ms + +def log_sum_exp_diff(x, y): + + c1 = torch.max(x+y,torch.zeros_like(x)) + c2 = torch.max(x, y) + + # log_sum_standard = torch.log(1 + (x+y).exp()) - x - torch.log(1 + (y-x).exp() ) + log_sum_standard = c1 + torch.log((-c1).exp() + (x+y-c1).exp()) - c2 - torch.log((x-c2).exp() + (y-c2).exp()) + + # log_sum_standard = torch.min(torch.abs(x), torch.abs(y))*torch.sign(x)*torch.sign(y) + return log_sum_standard + +def log_sum_exp(LLR_vector): + + sum_vector = LLR_vector.sum(dim=1, keepdim=True) + sum_concat = torch.cat([sum_vector, torch.zeros_like(sum_vector)], dim=1) + + return torch.logsumexp(sum_concat, dim=1)- torch.logsumexp(LLR_vector, dim=1) + +def log_sum_avoid_NaN(x, y): + + a = torch.max(x, y) + b = torch.min(x, y) + + log_sum_standard = torch.log(1 + (x+y).exp()) - x - torch.log(1 + (y-x).exp() ) + + # print("Original one:", log_sum_standard) + + ## Check for NaN or infty or -infty once here. + if (torch.isnan(log_sum_standard).sum() > 0) | ((log_sum_standard == float('-inf')).sum() > 0 )| ( (log_sum_standard == float('inf')).sum() > 0) : + + # print("Had to avoid NaNs!") + # 80 for float32 and 707 for float64. + #big_threshold = 80. if log_sum_standard.dtype == torch.float32 else 700. + big_threshold = 200. if log_sum_standard.dtype == torch.float32 else 700. + idx_1 = (x + y > big_threshold) + subset_1 = idx_1 & ((x-y).abs() < big_threshold) + + idx_2 = (x + y < -big_threshold) + subset_2 = idx_2 & ((x-y).abs() < big_threshold) + + idx_3 = ((x - y).abs() > big_threshold) & ( (x+y).abs() < big_threshold ) + + # Can be fastened + if idx_1.sum() > 0 : + + if subset_1.sum() > 0: + log_sum_standard[subset_1] = y[subset_1]- torch.log(1 + (y[subset_1] - x[subset_1]).exp() ) + # print("After 11 modification", log_sum_standard) + + if (idx_1 ^ subset_1).sum() > 0: + log_sum_standard[idx_1 ^ subset_1] = b[idx_1 ^ subset_1] + # print("After 12 modification", log_sum_standard) + + if idx_2.sum() > 0: + + if subset_2.sum() > 0: + log_sum_standard[subset_2] = -x[subset_2]- torch.log(1 + (y[subset_2] - x[subset_2]).exp() ) + # print("After 21 modification", log_sum_standard) + + if (idx_2 ^ subset_2).sum() > 0: + log_sum_standard[idx_2 ^ subset_2] = -a[idx_2 ^ subset_2] + # print("After 22 modification", log_sum_standard) + + if idx_3.sum() > 0: + + log_sum_standard[idx_3] = torch.log(1 + (x[idx_3]+ y[idx_3]).exp() ) - a[idx_3] + # print("After 3 modification", log_sum_standard) + + return log_sum_standard + + +def log_sum_avoid_zero_NaN(x, y): + + avoided_NaN = log_sum_avoid_NaN(x,y) + + zero_idx = (avoided_NaN == 0.) 
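All of the log_sum_* variants in this block compute the same boxplus quantity, log((1 + e^{x+y}) / (e^x + e^y)); they differ only in how they avoid overflow, NaNs and exact zeros at extreme LLRs. A short reference implementation and the min-sum shortcut, handy for spot-checking these routines (function names are mine):

import torch

def boxplus_exact(x, y):
    # log((1 + e^{x+y}) / (e^x + e^y)), written with logaddexp for stability
    return torch.logaddexp(torch.zeros_like(x), x + y) - torch.logaddexp(x, y)

def boxplus_minsum(x, y):
    # approximation used by min_sum_log_sum_exp above
    return torch.sign(x) * torch.sign(y) * torch.min(x.abs(), y.abs())

x = torch.tensor([0.5, 4.0, -30.0])
y = torch.tensor([1.0, -2.0, 25.0])
print(boxplus_exact(x, y))    # ≈ tensor([  0.2273,  -1.8756, -24.9933])
print(boxplus_minsum(x, y))   # tensor([  0.5000,  -2.0000, -25.0000])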
+ + data_type = x.dtype + + if zero_idx.sum() > 0: + + # print("Had to avoid zeros!") + + x_subzero = x[zero_idx] + y_subzero = y[zero_idx] + + nume = torch.relu(x_subzero + y_subzero) + denom = torch.max(x_subzero , y_subzero) + delta = 1e-7 if data_type == torch.float32 else 1e-16 + + term_1 = 0.5 *( (-nume).exp() + (x_subzero + y_subzero - nume).exp() ) + term_2 = 0.5 * ( (x_subzero - denom).exp() + (y_subzero - denom).exp() ) + + # close_1 = torch.tensor( (term_1 - 1).abs() < delta, dtype= data_type) + close_1 = ((term_1 - 1).abs() < delta).clone().float() + T_1 = (term_1 - 1.) * close_1 + torch.log(term_1) * (1-close_1) + + # close_2 = torch.tensor( (term_2 - 1).abs() < delta, dtype= data_type) + close_2 = ((term_2 - 1).abs() < delta).clone().float() + T_2 = (term_2 - 1.) * close_2 + torch.log(term_2) * (1-close_2) + + corrected_ans = nume - denom + T_1 - T_2 + + further_zero = (corrected_ans == 0.) + + if further_zero.sum() > 0: + + x_sub_subzero = x_subzero[further_zero] + y_sub_subzero = y_subzero[further_zero] + + positive_idx = ( x_sub_subzero + y_sub_subzero > 0.) + + spoiled_brat = torch.min(- x_sub_subzero, - y_sub_subzero) + + spoiled_brat[positive_idx] = torch.min(x_sub_subzero[positive_idx], y_sub_subzero[positive_idx]) + + corrected_ans[further_zero] = spoiled_brat + + avoided_NaN[zero_idx] = corrected_ans + + return avoided_NaN + + +def new_log_sum(x, y): + # log_sum_standard = torch.nan_to_num(torch.nan_to_num(torch.log(torch.abs((1 - (x+y).exp()))) - torch.log(torch.abs(((y-x).exp() - 1) ))) - x) + # log_sum_standard = torch.log(torch.abs((1 - (x+y).exp())/((y-x).exp() - 1) )) - x + # log_sum_standard = torch.nan_to_num(torch.nan_to_num(torch.log(torch.abs(((x-y).exp() - (2*x).exp()))) - torch.log(torch.abs(((x-y).exp() - 1) ))) - x) + # log_sum_standard = torch.log(torch.abs(((x-y).exp() - (2*x).exp()))) - torch.log(torch.abs(((x-y).exp() - 1) )) - x + + x = x.clone() + torch.isclose(x,y).float()* 1e-5 #otherwise we get inf + c1 = torch.max(x+y,torch.zeros_like(x)) + c2 = torch.max(x, y) + log_sum_standard = torch.log(torch.abs((-1*c1).exp() - (x+y-c1).exp())) + c1 - torch.log(torch.abs((x-c2).exp() - (y-c2).exp())) - c2 + + if torch.isnan(log_sum_standard).any(): + print(c1, c2, (-1*c1).exp(), (x+y-c1).exp(), (x-c2).exp(), (y-c2).exp()) + id1 = np.random.randint(0,100) + + torch.save([x, y], 'errors_' + str(id1)+'.pt') + print('Nan detected! Saved at {}'.format('errors_' + str(id1)+'.pt')) + + if torch.isinf(log_sum_standard).any(): + print(c1, c2, (-1*c1).exp(), (x+y-c1).exp(), (x-c2).exp(), (y-c2).exp()) + id1 = np.random.randint(0,100) + + torch.save([x, y], 'errors_inf_' + str(id1)+'.pt') + print('Inf detected! Saved at {}'.format('errors_inf_' + str(id1)+'.pt')) + + return log_sum_standard + +def new_log_sum_avoid_NaN(x, y): + + a = torch.max(x, y) + b = torch.min(x, y) + + x = x.clone() + torch.isclose(x,y).float()* 1e-5 #otherwise we get inf + + # log_sum_standard = torch.log(torch.abs((1 - (x+y).exp())/(y.exp() - x.exp())) ) + log_sum_standard = torch.log(torch.abs((1 - (x+y).exp()))) - x - torch.log(torch.abs(((y-x).exp() - 1) )) + + # log_sum_standard = torch.log(torch.abs((-1*x).exp() - (-1*y).exp()) ) + + # print("Original one:", log_sum_standard) + + ## Check for NaN or infty or -infty once here. + if (torch.isnan(log_sum_standard).sum() > 0) | ((log_sum_standard == float('-inf')).sum() > 0 )| ( (log_sum_standard == float('inf')).sum() > 0) : + + # print("Had to avoid NaNs!") + # 80 for float32 and 707 for float64. + #big_threshold = 80. 
if log_sum_standard.dtype == torch.float32 else 700. + big_threshold = 200. if log_sum_standard.dtype == torch.float32 else 700. + idx_1 = (x + y > big_threshold) + subset_1 = idx_1 & ((x-y).abs() < big_threshold) + + idx_2 = (x + y < -big_threshold) + subset_2 = idx_2 & ((x-y).abs() < big_threshold) + + idx_3 = ((x - y).abs() > big_threshold) & ( (x+y).abs() < big_threshold ) + + # Can be fastened + if idx_1.sum() > 0 : + + if subset_1.sum() > 0: + log_sum_standard[subset_1] = y[subset_1]- torch.log(1 + (y[subset_1] - x[subset_1]).exp() ) + # print("After 11 modification", log_sum_standard) + + if (idx_1 ^ subset_1).sum() > 0: + log_sum_standard[idx_1 ^ subset_1] = b[idx_1 ^ subset_1] + # print("After 12 modification", log_sum_standard) + + if idx_2.sum() > 0: + + if subset_2.sum() > 0: + log_sum_standard[subset_2] = -x[subset_2]- torch.log(1 + (y[subset_2] - x[subset_2]).exp() ) + # print("After 21 modification", log_sum_standard) + + if (idx_2 ^ subset_2).sum() > 0: + log_sum_standard[idx_2 ^ subset_2] = -a[idx_2 ^ subset_2] + # print("After 22 modification", log_sum_standard) + + if idx_3.sum() > 0: + + log_sum_standard[idx_3] = torch.log(1 + (x[idx_3]+ y[idx_3]).exp() ) - a[idx_3] + # print("After 3 modification", log_sum_standard) + + return log_sum_standard + + +def new_log_sum_avoid_zero_NaN(x, y): + + avoided_NaN = new_log_sum_avoid_NaN(x,y) + + zero_idx = (avoided_NaN == 0.) + + data_type = x.dtype + + if zero_idx.sum() > 0: + + # print("Had to avoid zeros!") + + x_subzero = x[zero_idx] + y_subzero = y[zero_idx] + + nume = torch.relu(x_subzero + y_subzero) + denom = torch.max(x_subzero , y_subzero) + delta = 1e-7 if data_type == torch.float32 else 1e-16 + + term_1 = 0.5 *( (-nume).exp() + (x_subzero + y_subzero - nume).exp() ) + term_2 = 0.5 * ( (x_subzero - denom).exp() + (y_subzero - denom).exp() ) + + # close_1 = torch.tensor( (term_1 - 1).abs() < delta, dtype= data_type) + close_1 = ((term_1 - 1).abs() < delta).clone().float() + T_1 = (term_1 - 1.) * close_1 + torch.log(term_1) * (1-close_1) + + # close_2 = torch.tensor( (term_2 - 1).abs() < delta, dtype= data_type) + close_2 = ((term_2 - 1).abs() < delta).clone().float() + T_2 = (term_2 - 1.) * close_2 + torch.log(term_2) * (1-close_2) + + corrected_ans = nume - denom + T_1 - T_2 + + further_zero = (corrected_ans == 0.) + + if further_zero.sum() > 0: + + x_sub_subzero = x_subzero[further_zero] + y_sub_subzero = y_subzero[further_zero] + + positive_idx = ( x_sub_subzero + y_sub_subzero > 0.) 
+ + spoiled_brat = torch.min(- x_sub_subzero, - y_sub_subzero) + + spoiled_brat[positive_idx] = torch.min(x_sub_subzero[positive_idx], y_sub_subzero[positive_idx]) + + corrected_ans[further_zero] = spoiled_brat + + avoided_NaN[zero_idx] = corrected_ans + + return avoided_NaN diff --git a/xformer_all.py b/xformer_all.py new file mode 100644 index 0000000..d05c941 --- /dev/null +++ b/xformer_all.py @@ -0,0 +1,1678 @@ +__author__ = 'vivien98' + +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torch.utils.data + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + +from IPython import display + +import imageio +import pickle +import os +import time +from datetime import datetime +import matplotlib +matplotlib.use('AGG') +import matplotlib.pyplot as plt + +from utils import snr_db2sigma, errors_ber, errors_bitwise_ber, errors_bler, min_sum_log_sum_exp, moving_average, extract_block_errors, extract_block_nonerrors +from models import convNet,XFormerEndToEndGPT,XFormerEndToEndDecoder,XFormerEndToEndEncoder,simpleNet,bigConvNet,smallNet,multConvNet,rnnAttn,bitConvNet +from polar import * +from pac_code import * + +from sklearn.manifold import TSNE +import math +import random +import numpy as np +from tqdm import tqdm +from collections import namedtuple +import sys +import csv + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + +def get_args(): + parser = argparse.ArgumentParser(description='Polar code - xformer decoder') + + parser.add_argument('--id', type=str, default=None, help='ID: optional, to run multiple runs of same hyperparameters') #Will make a folder like init_932 , etc. + + parser.add_argument('--previous_id', type=str, default=None, help='ID: optional, to run multiple runs of same hyperparameters') #Will make a folder like init_932 , etc. 
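(Editorial aside.) The str2bool helper defined above lets boolean options accept explicit values; the --run_dumer flag further down uses it with nargs='?' and const=True. A minimal standalone illustration of how such a flag parses (demo is a throwaway parser, not part of the patch):

import argparse

demo = argparse.ArgumentParser()
# Same pattern as --run_dumer: bare flag -> True, or an explicit
# yes/no/true/false/1/0 value may be supplied.
demo.add_argument("--run_dumer", type=str2bool, nargs='?', const=True, default=True)

print(demo.parse_args([]).run_dumer)                     # True (default)
print(demo.parse_args(["--run_dumer"]).run_dumer)        # True (const)
print(demo.parse_args(["--run_dumer", "no"]).run_dumer)  # False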
+ + parser.add_argument('--code', type=str, default='pac',choices=['pac', 'polar'], help='code to be tested/trained on') + + parser.add_argument('--previous_code', type=str, default=None,choices=[None,'pac', 'polar'], help='code to load model from') + + parser.add_argument('--N', type=int, default=32)#, choices=[4, 8, 16, 32, 64, 128], help='Polar code parameter N') + + parser.add_argument('--previous_N', type=int, default=32)#, choices=[4, 8, 16, 32, 64, 128], help='Polar code parameter N') + + parser.add_argument('--max_len', type=int, default=32)#, choices=[4, 8, 16, 32, 64, 128], help='Polar code parameter N') + + parser.add_argument('--K', type=int, default=8)#, choices= [3, 4, 8, 16, 32, 64], help='Polar code parameter K') + + parser.add_argument('--previous_K', type=int, default=8)#, choices= [3, 4, 8, 16, 32, 64], help='Polar code parameter K') + + parser.add_argument('--test', dest = 'test', default=False, action='store_true', help='Testing?') + + parser.add_argument('--plot_progressive', dest = 'plot_progressive', default=False, action='store_true', help='plot merged progressive ber vs time') + + parser.add_argument('--do_range_training', dest = 'do_range_training', default=False, action='store_true', help="training on dec_train_snr + 1 and + 2 also?") + + parser.add_argument('--rate_profile', type=str, default='RM', choices=['RM', 'polar', 'sorted', 'last', 'custom'], help='PAC rate profiling') + + parser.add_argument('--previous_rate_profile', type=str, default=None, choices=[None,'RM', 'polar', 'sorted', 'last', 'custom'], help='PAC rate profiling') + + parser.add_argument('--embed_dim', type=int, default=64)# embedding size / hidden size of input vectors/hidden outputs between layers + + parser.add_argument('--dropout', type=int, default=0.1)# dropout + + parser.add_argument('--n_head', type=int, default=8)# number of attention heads + + parser.add_argument('--n_layers', type=int, default=6)# number of transformer layers + + parser.add_argument('--num_devices', type=int, default=2)# number of transformer layers + + parser.add_argument('--load_previous', dest = 'load_previous', default=False, action='store_true', help='load previous model at step --model_iters') + + parser.add_argument('--parallel', dest = 'parallel', default=False, action='store_true', help='gpu parallel') + + parser.add_argument('--dont_use_bias', dest = 'dont_use_bias', default=False, action='store_true', help='dont use bias in neural net')# load previous while training? 
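(Editorial aside.) One small issue in the argument block above: --dropout is declared with type=int even though it is a dropout probability with a float default of 0.1. The default itself is unaffected (argparse does not pass defaults through type), but supplying e.g. --dropout 0.2 on the command line would fail with "invalid int value". The declaration presumably intended:

parser.add_argument('--dropout', type=float, default=0.1)  # dropout probability in [0, 1)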
+ + parser.add_argument('--include_previous_block_errors', dest = 'include_previous_block_errors', default=False, action='store_true', help='train again on block errors of the previous step') + + parser.add_argument('--dec_train_snr', type=float, default=-1., help='SNR at which decoder is trained') + + parser.add_argument('--test_snr_start', type=float, default=-2., help='testing snr start') + + parser.add_argument('--test_snr_end', type=float, default=4., help='testing snr end') + + parser.add_argument('--model_iters', type=int, default=None, help='by default load final model, option to load a model of x episodes') + + parser.add_argument('--run', type=int, default=None)#, choices= [3, 4, 8, 16, 32, 64], help='Polar code parameter K') + + parser.add_argument('--num_steps', type=int, default=400000)#, choices=[100, 20000, 40000], help='number of blocks') + + parser.add_argument('--batch_size', type=int, default=128)#, choices=[64, 128, 256, 1024], help='number of blocks') + + parser.add_argument('--mult', type=int, default=1)#, multiplying factor to increase effective batch size + + parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate') + + parser.add_argument('--cosine', dest = 'cosine', default=False, action='store_true', help='cosine annealing') + + parser.add_argument('--num_restarts',type=int, default=200, help='number of restarts while cosine annealing') + + parser.add_argument('--print_freq', type=int, default=1000, help='validation every x steps') + + parser.add_argument('--activation', type=str, default='selu', choices=['selu', 'relu', 'elu', 'tanh', 'sigmoid'], help='activation function') + + parser.add_argument('--prog_mode', type=str, default='e2h', choices=['e2h', 'h2e', 'r2l', 'l2r','random'], help='hard 2 easy progressive training, etc.') + + parser.add_argument('--target_K', type=int, default=16, help='target K while training progressively') + + # TRAINING parameters + parser.add_argument('--model', type=str, default='gpt', choices=['simple','conv','encoder', 'decoder', 'gpt','denoiser','bigConv','small','multConv','rnnAttn','bitConv'], help='model to be trained') + + parser.add_argument('--initialization', type=str, default='Xavier', choices=['Dontknow', 'He', 'Xavier'], help='initialization') + + parser.add_argument('--optimizer_type', type=str, default='AdamW', choices=['Adam', 'RMS', 'AdamW','SGD'], help='optimizer type') + + parser.add_argument('--loss', type=str, default='MSE', choices=['Huber', 'MSE','NLL','Block'], help='loss function') + + parser.add_argument('--loss_on_all', dest = 'loss_on_all', default=False, action='store_true', help='loss on all bits or only info bits') + + parser.add_argument('--split_batch', dest = 'split_batch', default=False, action='store_true', help='split batch - for teacher forcing') + + + + parser.add_argument('--lr_decay', type=int, default=None, help='learning rate decay frequency (in episodes)') + + parser.add_argument('--T_anneal', type=int, default=None, help='Number of iterations to midway in cosine lr') + + + parser.add_argument('--lr_decay_gamma', type=float, default=None, help='learning rate decay factor') + + parser.add_argument('--clip', type=float, default=0.25, help='gradient clipping factor') + + parser.add_argument('--validation_snr', type=float, default=None, help='snr at validation') + + parser.add_argument('--no_detach', dest = 'no_detach', default=False, action='store_true', help='detach previous output during rnn training?') + + # TEACHER forcing + # if only tfr_max is given assume no annealing 
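(Editorial aside.) The teacher-forcing flags declared next (--tfr_min, --tfr_max, --tfr_decay, --teacher_steps) are consumed elsewhere, so the exact schedule is not visible in this hunk. A common anneal consistent with the comment above (constant at tfr_max when no tfr_min is given, otherwise exponential decay from tfr_max toward tfr_min after an initial teacher-only phase) would look roughly like the sketch below; teacher_forcing_ratio is an illustrative name, not a function from this repository.

import math

def teacher_forcing_ratio(step, tfr_max, tfr_min=None, tfr_decay=10000, teacher_steps=0):
    # Illustrative only -- the schedule actually used by the training code
    # is not shown in this excerpt.
    if step < teacher_steps:      # pure teacher forcing for the first steps
        return 1.0
    if tfr_min is None:           # "only tfr_max is given" -> no annealing
        return tfr_max
    return tfr_min + (tfr_max - tfr_min) * math.exp(-step / tfr_decay)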
+ parser.add_argument('--tfr_min', type=float, default=None, help='teacher forcing ratio minimum') + + parser.add_argument('--tfr_max', type=float, default=0., help='teacher forcing ratio maximum') + + parser.add_argument('--tfr_decay', type=float, default=10000, help='teacher forcing ratio decay parameter') + + parser.add_argument('--teacher_steps', type=int, default=-10000, help='initial number of steps to do teacher forcing only') + + # TESTING parameters + + parser.add_argument('--model_save_per', type=int, default=5000, help='num of episodes after which model is saved') + + parser.add_argument('--snr_points', type=int, default=7, help='testing snr num points') + + parser.add_argument('--test_batch_size', type=int, default=1000, help='number of blocks') + + parser.add_argument('--test_size', type=int, default=50000, help='size of the batches') + + + + parser.add_argument('--test_load_path', type=str, default=None, help='load test model given path') + + parser.add_argument('--run_fano', dest = 'run_fano', default=False, action='store_true', help='run fano decoding') + + parser.add_argument('--random_test', dest = 'random_test', default=False, action='store_true', help='run test on random data (default action is to test on same samples as Fano did)') + + parser.add_argument('--save_path', type=str, default=None, help='save name') + + parser.add_argument('--load_path', type=str, default=None, help='load name') + + parser.add_argument("--run_dumer", type=str2bool, nargs='?', const=True, default=True, help="run dumer during test?") + # parser.add_argument('-id', type=int, default=100000) + parser.add_argument('--hard_decision', dest = 'hard_decision', default=False, action='store_true', help='polar code sc decoding hard decision?') + + parser.add_argument('--gpu', type=int, default= -1, help='gpus used for training - e.g 0,1,3') # -1 if run on any available gpu + + parser.add_argument('--anomaly', dest = 'anomaly', default=False, action='store_true', help='enable anomaly detection') + + parser.add_argument('--only_args', dest = 'only_args', default=False, action='store_true') + + args = parser.parse_args() + + if args.N == 4: + args.g = 7 # Convolutional coefficients are [1,1, 0, 1] + # args.M = 2 # log N + + + elif args.N == 8: + args.g = 13 # Convolutional coefficients are [1, 0, 1, 1] + # args.M = 3 # log N + + elif args.N == 16: + args.g = 21 # [1, 0, 1, 0, 1] + + + elif args.N == 32: + args.g = 53 # [1, 1, 0, 1, 0, 1] + + else: + args.g = 91 + + args.are_we_doing_ML = True if args.K <=16 and args.N <= 32 else False + + # args.hard_decision = True # use hard-SC + return args + +def get_pad_mask(seq, pad_idx): + return (seq != pad_idx).unsqueeze(-2) + + +def get_subsequent_mask(seq): + ''' For masking out the subsequent info. ''' + sz_b, len_s = seq.size() + subsequent_mask = (1 - torch.triu( + torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)).bool() + return subsequent_mask + +def dec2bitarray(in_number, bit_width): + """ + Converts a positive integer to NumPy array of the specified size containing + bits (0 and 1). + Parameters + ---------- + in_number : int + Positive integer to be converted to a bit array. + bit_width : int + Size of the output bit array. + Returns + ------- + bitarray : 1D ndarray of ints + Array containing the binary representation of the input decimal. 
+ """ + + binary_string = bin(in_number) + length = len(binary_string) + bitarray = np.zeros(bit_width, 'int') + for i in range(length-2): + bitarray[bit_width-i-1] = int(binary_string[length-i-1]) + + return bitarray + +def countSetBits(n): + + count = 0 + while (n): + n &= (n-1) + count+= 1 + + return count + +def get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases + linearly between 0 and the initial lr set in the optimizer. + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`int`, *optional*, defaults to 1): + The number of hard restarts to use. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + if progress >= 1.0: + return 0.0 + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def testXformer(net, polar, snr_range, Test_Data_Generator,device,Test_Data_Mask=None, run_ML=False, bitwise_snr_idx = -1): + num_test_batches = len(Test_Data_Generator) + + bers_Xformer_test = [0. for ii in snr_range] + bers_bitwise_Xformer_test = torch.zeros((1,polar.K),device=device) + blers_Xformer_test = [0. for ii in snr_range] + + bers_SC_test = [0. for ii in snr_range] + blers_SC_test = [0. for ii in snr_range] + + bers_SCL_test = [0. for ii in snr_range] + blers_SCL_test = [0. for ii in snr_range] + + bers_ML_test = [0. for ii in snr_range] + blers_ML_test = [0. for ii in snr_range] + + bers_bitwise_MAP_test = [0. for ii in snr_range] + blers_bitwise_MAP_test = [0. 
for ii in snr_range] + + for (k, msg_bits) in tqdm(enumerate(Test_Data_Generator)): + + msg_bits = msg_bits.to(device) + polar_code = polar.encode_plotkin(msg_bits) + + + for snr_ind, snr in enumerate(snr_range): + noisy_code = polar.channel(polar_code, snr) + noise = noisy_code - polar_code + if Test_Data_Mask == None: + mask = torch.ones(noisy_code.size(),device=device).long() + SC_llrs, decoded_SC_msg_bits = polar.sc_decode_new(noisy_code, snr) + if not run_ML: + SCL_llrs, decoded_SCL_msg_bits = polar.scl_decode(noisy_code.cpu(), snr, 4, use_CRC = False) + ber_SCL = errors_ber(msg_bits.cpu(), decoded_SCL_msg_bits.sign().cpu()).item() + bler_SCL = errors_bler(msg_bits.cpu(), decoded_SCL_msg_bits.sign().cpu()).item() + bers_SCL_test[snr_ind] += ber_SCL/num_test_batches + blers_SCL_test[snr_ind] += bler_SCL/num_test_batches + + ber_SC = errors_ber(msg_bits.cpu(), decoded_SC_msg_bits.sign().cpu()).item() + bler_SC = errors_bler(msg_bits.cpu(), decoded_SC_msg_bits.sign().cpu()).item() + + decoded_bits,out_mask = net.decode(noisy_code,polar.info_positions, mask,device) + decoded_Xformer_msg_bits = decoded_bits[:, polar.info_positions].sign() + + ber_Xformer = errors_ber(msg_bits, decoded_Xformer_msg_bits.sign(), mask = mask[:, polar.info_positions]).item() + if snr_ind==bitwise_snr_idx: + ber_bitwise_Xformer = errors_bitwise_ber(msg_bits, decoded_Xformer_msg_bits.sign(), mask = mask[:, polar.info_positions]).squeeze() + bers_bitwise_Xformer_test += ber_bitwise_Xformer/num_test_batches + print(ber_bitwise_Xformer) + bler_Xformer = errors_bler(msg_bits, decoded_Xformer_msg_bits.sign()).item() + if run_ML: + b_noisy = noisy_code.unsqueeze(1).repeat(1, 2**args.K, 1) + diff = (b_noisy - b_codebook).pow(2).sum(dim=2) + idx = diff.argmin(dim=1) + decoded = all_message_bits[idx, :] + decoded_bitwiseMAP_msg_bits = polar.bitwise_MAP(noisy_code,device,snr) + + ber_ML = errors_ber(msg_bits.to(decoded.device), decoded.sign()).item() + bler_ML = errors_bler(msg_bits.to(decoded.device), decoded.sign()).item() + ber_bitwiseMAP = errors_ber(msg_bits.cpu(), decoded_bitwiseMAP_msg_bits.sign().cpu()).item() + bler_bitwiseMAP = errors_bler(msg_bits.cpu(), decoded_bitwiseMAP_msg_bits.sign().cpu()).item() + bers_ML_test[snr_ind] += ber_ML/num_test_batches + blers_ML_test[snr_ind] += bler_ML/num_test_batches + bers_bitwise_MAP_test[snr_ind] += ber_bitwiseMAP/num_test_batches + blers_bitwise_MAP_test[snr_ind] += bler_bitwiseMAP/num_test_batches + + bers_Xformer_test[snr_ind] += ber_Xformer/num_test_batches + bers_SC_test[snr_ind] += ber_SC/num_test_batches + + blers_Xformer_test[snr_ind] += bler_Xformer/num_test_batches + blers_SC_test[snr_ind] += bler_SC/num_test_batches + + + print(bers_bitwise_Xformer_test) + return bers_Xformer_test, blers_Xformer_test, bers_SC_test, blers_SC_test,bers_SCL_test, blers_SCL_test, bers_ML_test, blers_ML_test,bers_bitwise_Xformer_test,bers_bitwise_MAP_test,blers_bitwise_MAP_test + +def PAC_MAP_decode(noisy_codes, b_codebook): + + b_noisy = noisy_codes.unsqueeze(1).repeat(1, 2**args.K, 1) + + diff = (b_noisy - b_codebook).pow(2).sum(dim=2) + + idx = diff.argmin(dim=1) + + MAP_decoded_bits = all_message_bits[idx, :] + + return MAP_decoded_bits + + +def test_RNN_and_Dumer_batch(net, pac, msg_bits, corrupted_codewords, snr, run_dumer=True,Test_Data_Mask =None,bitwise_snr = 1): + + state = corrupted_codewords + + ### DQN decoding + info_inds = pac.B + + if Test_Data_Mask == None: + mask = torch.ones(corrupted_codewords.size(),device=device).long() + else: + mask = Test_Data_Mask + + 
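(Editorial aside.) errors_ber and errors_bler are imported from utils.py, which is outside this excerpt; for reading the metrics accumulated here, their intended semantics are presumably bit error rate and block (any-bit) error rate over ±1-valued tensors, roughly as below. The real helpers also accept an optional mask argument, omitted in this sketch; ref_ber and ref_bler are names introduced here.

import torch

def ref_ber(msg_bits, decoded_bits):
    # fraction of individual bits that differ (inputs in {-1, +1})
    return (msg_bits.sign() != decoded_bits.sign()).float().mean()

def ref_bler(msg_bits, decoded_bits):
    # fraction of codewords with at least one bit error
    return (msg_bits.sign() != decoded_bits.sign()).any(dim=1).float().mean()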
decoded_bits,out_mask = net.decode(corrupted_codewords,info_inds, mask,device) + decoded_Xformer_msg_bits = decoded_bits[:, info_inds].sign() + + ber_Xformer = errors_ber(msg_bits, decoded_Xformer_msg_bits.sign(), mask = mask[:, info_inds]).item() + bler_Xformer = errors_bler(msg_bits, decoded_Xformer_msg_bits.sign()).item() + ber_bitwise_Xformer = -1 + if snr==bitwise_snr: + ber_bitwise_Xformer = errors_bitwise_ber(msg_bits, decoded_Xformer_msg_bits.sign(), mask = mask[:, info_inds]).squeeze() + + + if run_dumer: + _, decoded_Dumer_msg_bits, _ = pac.pac_sc_decode(corrupted_codewords, snr) + ber_Dumer = errors_ber(msg_bits, decoded_Dumer_msg_bits.sign()).item() + bler_Dumer = errors_bler(msg_bits, decoded_Dumer_msg_bits.sign()).item() + else: + ber_Dumer = 0. + bler_Dumer = 0. + + if args.are_we_doing_ML: + MAP_decoded_bits = PAC_MAP_decode(corrupted_codewords, b_codebook) + + ber_ML = errors_ber(msg_bits, MAP_decoded_bits).item() + bler_ML = errors_bler(msg_bits, MAP_decoded_bits).item() + + return ber_Xformer, bler_Xformer, ber_Dumer, bler_Dumer, ber_ML, bler_ML, ber_bitwise_Xformer + + else: + + return ber_Xformer, bler_Xformer, ber_Dumer, bler_Dumer, ber_bitwise_Xformer + +def test_fano(pac,msg_bits, noisy_code, snr): + + msg_bits = msg_bits.to('cpu') # run fano on cpu. required? + sigma = snr_db2sigma(snr) + noisy_code = noisy_code.to('cpu') + llrs = (2/sigma**2)*noisy_code + + decoded_bits = torch.empty_like(msg_bits) + for ii, vv in enumerate(llrs): + v_hat, pm = pac.fano_decode(vv.unsqueeze(0), delta = 2, verbose = 0, maxDiversions = 1000, bias_type = 'p_e') + decoded_bits[ii] = pac.extract(v_hat) + + ber_fano = errors_ber(msg_bits, decoded_bits).item() + bler_fano = errors_bler(msg_bits, decoded_bits).item() + + return ber_fano, bler_fano + +def test_full_data(net, pac, snr_range, Test_Data_Generator, run_fano = False, run_dumer = True, Test_Data_Mask=None): + + num_test_batches = len(Test_Data_Generator) + + bers_RNN_test = [0. for ii in snr_range] + blers_RNN_test = [0. for ii in snr_range] + + bers_Dumer_test = [0. for ii in snr_range] + blers_Dumer_test = [0. for ii in snr_range] + + bers_ML_test = [0. for ii in snr_range] + blers_ML_test = [0. for ii in snr_range] + + bers_fano_test = [0. for ii in snr_range] + blers_fano_test = [0. 
for ii in snr_range] + + for (k, msg_bits) in tqdm(enumerate(Test_Data_Generator)): + + msg_bits = msg_bits.to(device) + pac_code = pac.pac_encode(msg_bits, scheme = args.rate_profile) + + for snr_ind, snr in enumerate(snr_range): + noisy_code = pac.channel(pac_code, snr) + if Test_Data_Mask == None: + mask = torch.ones(noisy_code.size(),device=device).long() + if args.are_we_doing_ML: + + ber_RNN, bler_RNN, ber_Dumer, bler_Dumer, ber_ML, bler_ML = test_RNN_and_Dumer_batch(net, pac, msg_bits, noisy_code, snr, Test_Data_Mask=mask) + + else: + + ber_RNN, bler_RNN, ber_Dumer, bler_Dumer,_ = test_RNN_and_Dumer_batch(net, pac, msg_bits, noisy_code, snr, run_dumer, Test_Data_Mask=mask) + + bers_RNN_test[snr_ind] += ber_RNN/num_test_batches + bers_Dumer_test[snr_ind] += ber_Dumer/num_test_batches + + blers_RNN_test[snr_ind] += bler_RNN/num_test_batches + blers_Dumer_test[snr_ind] += bler_Dumer/num_test_batches + + if args.are_we_doing_ML: + bers_ML_test[snr_ind] += ber_ML/num_test_batches + blers_ML_test[snr_ind] += bler_ML/num_test_batches + + if run_fano: + ber_fano, bler_fano = test_fano(msg_bits, noisy_code, snr) + bers_fano_test[snr_ind] += ber_fano/num_test_batches + blers_fano_test[snr_ind] += bler_fano/num_test_batches + + return bers_RNN_test, blers_RNN_test, bers_Dumer_test, blers_Dumer_test, bers_ML_test, blers_ML_test, bers_fano_test, blers_fano_test + +def test_standard(net, pac, msg_bits_all, received, run_fano = False, run_dumer = True, Test_Data_Mask=None,bitwise_snr_idx = 3): + + snr_range = list(received.keys()) + bers_RNN_test = [0. for ii in snr_range] + blers_RNN_test = [0. for ii in snr_range] + + bers_Dumer_test = [0. for ii in snr_range] + blers_Dumer_test = [0. for ii in snr_range] + + bers_ML_test = [0. for ii in snr_range] + blers_ML_test = [0. for ii in snr_range] + + bers_fano_test = [0. for ii in snr_range] + blers_fano_test = [0. for ii in snr_range] + + bers_bitwise_Xformer_test = torch.zeros((1,pac.K),device=device) + + msg_bits_all = msg_bits_all.to(device) + # quick fix to get this running. 
need to modify to support other test batch sizes ig + num_test_batches = msg_bits_all.shape[0]//args.test_batch_size + for snr_ind, (snr, noisy_code_all) in enumerate(received.items()): + noisy_code_all = noisy_code_all.to(device) + if snr_ind == bitwise_snr_idx: + bitwise_snr = snr + else: + bitwise_snr = -100 + for ii in range(num_test_batches): + msg_bits = msg_bits_all[ii*args.test_batch_size: (ii+1)*args.test_batch_size] + noisy_code = noisy_code_all[ii*args.test_batch_size: (ii+1)*args.test_batch_size] + if Test_Data_Mask == None: + mask = torch.ones(noisy_code.size(),device=device).long() + if args.are_we_doing_ML: + + ber_RNN, bler_RNN, ber_Dumer, bler_Dumer, ber_ML, bler_ML, ber_bitwise_Xformer = test_RNN_and_Dumer_batch(net, pac, msg_bits, noisy_code, snr, Test_Data_Mask=mask, bitwise_snr = bitwise_snr) + + else: + + ber_RNN, bler_RNN, ber_Dumer, bler_Dumer, ber_bitwise_Xformer = test_RNN_and_Dumer_batch(net, pac, msg_bits, noisy_code, snr, Test_Data_Mask=mask, bitwise_snr = bitwise_snr) + if snr_ind==bitwise_snr_idx: + #ber_bitwise_Xformer = errors_bitwise_ber(msg_bits, decoded_Xformer_msg_bits.sign(), mask = mask[:, polar.info_positions]).squeeze() + bers_bitwise_Xformer_test += ber_bitwise_Xformer/num_test_batches + bers_RNN_test[snr_ind] += ber_RNN/num_test_batches + bers_Dumer_test[snr_ind] += ber_Dumer/num_test_batches + + blers_RNN_test[snr_ind] += bler_RNN/num_test_batches + blers_Dumer_test[snr_ind] += bler_Dumer/num_test_batches + + if args.are_we_doing_ML: + bers_ML_test[snr_ind] += ber_ML/num_test_batches + blers_ML_test[snr_ind] += bler_ML/num_test_batches + + if run_fano: + ber_fano, bler_fano = test_fano(msg_bits_all, noisy_code_all, snr) + bers_fano_test[snr_ind] += ber_fano + blers_fano_test[snr_ind] += bler_fano + + return bers_RNN_test, blers_RNN_test, bers_Dumer_test, blers_Dumer_test, bers_ML_test, blers_ML_test, bers_fano_test, blers_fano_test,bers_bitwise_Xformer_test + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + +def plot_tsne(args,net,device,focus_on_bit=1,num_samples=100,num_neighbours=5,small_net=False): + K = args.K + N = args.N + n = int(np.log2(N)) + snr = args.dec_train_snr + rs = np.array([256 ,255 ,252 ,254 ,248 ,224 ,240 ,192 ,128 ,253 ,244 ,251 ,250 ,239 ,238 ,247 ,246 ,223 ,222 ,232 ,216 ,236 ,220 ,188 ,208 ,184 ,191 ,190 ,176 ,127 ,126 ,124 ,120 ,249 ,245 ,243 ,242 ,160 ,231 ,230 ,237 ,235 ,234 ,112 ,228 ,221 ,219 ,218 ,212 ,215 ,214 ,189 ,187 ,96 ,186 ,207 ,206 ,183 ,182 ,204 ,180 ,200 ,64 ,175 ,174 ,172 ,125 ,123 ,122 ,119 ,159 ,118 ,158 ,168 ,241 ,116 ,111 ,233 ,156 ,110 ,229 ,227 ,217 ,108 ,213 ,152 ,226 ,95 ,211 ,94 ,205 ,185 ,104 ,210 ,203 ,181 ,92 ,144 ,202 ,179 ,199 ,173 ,178 ,63 ,198 ,121 ,171 ,88 ,62 ,117 ,170 ,196 ,157 ,167 ,60 ,115 ,155 ,109 ,166 ,80 ,114 ,154 ,107 ,56 ,225 ,151 ,164 ,106 ,93 ,150 ,209 ,103 ,91 ,143 ,201 ,102 ,48 ,148 ,177 ,90 ,142 ,197 ,87 ,100 ,61 ,169 ,195 ,140 ,86 ,59 ,32 ,165 ,194 ,113 ,79 ,58 ,153 ,84 ,136 ,55 ,163 ,78 ,105 ,149 ,162 ,54 ,76 ,101 ,47 ,147 ,89 ,52 ,141 ,99 ,46 ,146 ,72 ,85 ,139 ,98 ,31 ,44 ,193 ,138 ,57 ,83 ,30 ,135 ,77 ,40 ,82 ,134 ,161 ,28 ,53 ,75 ,132 ,24 ,51 ,74 ,45 ,145 ,71 ,50 ,16 ,97 ,70 ,43 ,137 ,68 ,42 ,29 ,39 ,81 ,27 ,133 ,38 ,26 ,36 ,131 ,23 ,73 ,22 ,130 ,49 ,15 ,20 ,69 ,14 ,12 ,67 ,41 ,8 ,66 ,37 ,25 ,35 ,34 ,21 ,129 ,19 ,13 ,18 ,11 ,10 ,7 ,65 ,6 ,4 ,33 ,17 ,9 ,5 ,3 ,2 ,1 ]) - 1 + polar = PolarCode(n, K, None, rs=rs) + + images = [] + # for filename in filenames: + # images.append(imageio.imread(filename)) + # 
imageio.mimsave('/path/to/movie.gif', images) + all_msg_bits = [] + for i in range(2**K): + d = dec2bitarray(i,K ) + all_msg_bits.append(d) + all_message_bits = torch.from_numpy(np.array(all_msg_bits)) + all_message_bits = 1 - 2*all_message_bits.float() + codebook = polar.encode_plotkin(all_message_bits) + b_codebook = codebook.repeat(1, 1, 1) + + info_inds = polar.info_positions + msg_bits1 = torch.ones((num_samples,K))#1 - 2 * (torch.rand(num_samples, K) < 0.5).float() + msg_bits2 = torch.ones((num_samples,K)) + msg_bits2[:,focus_on_bit] = -1 + msg_bits = torch.cat((msg_bits1,msg_bits2),0) + gt = torch.ones(2*num_samples, N) + gt[:, info_inds] = msg_bits + polar_code = polar.encode_plotkin(msg_bits,custom_info_positions = info_inds) + + # noisy_code = polar_code#polar.channel(polar_code, snr)#a + # b_noisy = noisy_code.unsqueeze(1).repeat(1, 2**K, 1) + # diff = -(b_noisy - b_codebook).pow(2).sum(dim=2) + # distList = torch.topk(diff,1024) + # print(torch.sum(1.0*(distList.values == -32.))) + # all_neighbours = b_codebook[0][distList.indices.tolist()[0][0:num_neighbours],:] + if small_net: + codes = 4*np.random.rand(num_samples,2)-2 + samples = torch.from_numpy(codes).float() + x = samples[:,0] + y = samples[:,1] + num_models = int(args.model_iters/args.model_save_per) + model_iters_arr = [ args.model_save_per*elem for elem in list(range(num_models))] + model_iters_arr[0] = 1 + for iteration in tqdm(model_iters_arr): + if args.previous_code == 'polar': + previous_save_path = './Supervised_Xformer_decoder_Polar_Results/Polar_{0}_{1}/Scheme_{2}/{3}/{4}_depth_{5}'\ + .format(args.previous_K, args.previous_N, args.previous_rate_profile, args.model, args.n_head,args.n_layers) + elif args.previous_code == 'pac': + previous_save_path = './Supervised_Xformer_decoder_PAC_Results/Polar_{0}_{1}/Scheme_{2}/{3}/{4}_depth_{5}'\ + .format(args.previous_K, args.previous_N, args.previous_rate_profile, args.model, args.n_head,args.n_layers) + if args.previous_id is not None: + previous_save_path = previous_save_path + '/' + args.previous_id + ID = args.previous_id + else: + previous_save_path = previous_save_path #+ '/' + ID + ID = 'scratch' + if args.run is not None: + previous_save_path = previous_save_path + '/' + '{0}'.format(args.run) + os.makedirs(previous_save_path+ '/'+'decision_boundaries/', exist_ok=True) + checkpoint1 = torch.load(previous_save_path +'/Models/model_{0}.pt'.format(iteration), map_location=lambda storage, loc: storage) + #xformer.load_state_dict(torch.load(PATH)) + loaded_step = checkpoint1['step'] + net.load_state_dict(checkpoint1['xformer']) + print("Loaded Model for {0},{1} loaded at step {2} from previous model {3},{4}".format(args.K,args.N,loaded_step,args.previous_K,args.previous_N)) + decoded_bits,out_mask = net.decode(samples,[0,1],None,device) + decoded_msg_bits = 1.0 * (decoded_bits > 0) + colors = [2*elem[0] + elem[1] for elem in decoded_msg_bits.squeeze().tolist()] + plt.figure(figsize = (20,10)) + scatt = plt.scatter(x, y,c=colors) + plt.title(ID + ' at Step {0}'.format(iteration)) + plt.savefig(previous_save_path + '/'+'decision_boundaries/tsne_at_step_{0}_bit_{1}.png'.format(iteration,focus_on_bit)) + plt.close() + images.append(imageio.imread(previous_save_path + '/'+'decision_boundaries/tsne_at_step_{0}_bit_{1}.png'.format(iteration,focus_on_bit))) + imageio.mimsave(previous_save_path + '/'+'decision_boundaries/all_movie_bit_{0}.gif'.format(focus_on_bit), images) + return + + samples = polar.channel(polar_code, snr) + + samplesNumpy = samples.numpy() + 
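(Editorial aside.) The code just below projects the noisy codewords to 2-D with scikit-learn's t-SNE and scatters the points coloured by a decoded bit. A minimal standalone illustration of that step, with random data standing in for `samples`:

import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
fake_codewords = rng.normal(size=(200, 32)).astype(np.float32)  # stand-in for the noisy codewords

emb = TSNE(n_components=2, init='random').fit_transform(fake_codewords)
x_coords, y_coords = emb[:, 0], emb[:, 1]  # the 2-D coordinates that the plotting code scatters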
#print(decoded_msg_bits[:,focus_on_bit]) + + + embedder = TSNE(n_components=2,init='random') + X_embedded = embedder.fit_transform(samplesNumpy) + + x = X_embedded[:,0] + y = X_embedded[:,1] + + num_models = int(args.model_iters/args.model_save_per) + model_iters_arr = [ args.model_save_per*elem for elem in list(range(num_models))] + model_iters_arr[0] = 10 + for iteration in tqdm(model_iters_arr): + if args.previous_code == 'polar': + previous_save_path = './Supervised_Xformer_decoder_Polar_Results/Polar_{0}_{1}/Scheme_{2}/{3}/{4}_depth_{5}'\ + .format(args.previous_K, args.previous_N, args.previous_rate_profile, args.model, args.n_head,args.n_layers) + elif args.previous_code == 'pac': + previous_save_path = './Supervised_Xformer_decoder_PAC_Results/Polar_{0}_{1}/Scheme_{2}/{3}/{4}_depth_{5}'\ + .format(args.previous_K, args.previous_N, args.previous_rate_profile, args.model, args.n_head,args.n_layers) + if args.previous_id is not None: + previous_save_path = previous_save_path + '/' + args.previous_id + else: + previous_save_path = previous_save_path #+ '/' + ID + if args.run is not None: + previous_save_path = previous_save_path + '/' + '{0}'.format(args.run) + os.makedirs(previous_save_path+ '/'+'decision_boundaries/', exist_ok=True) + checkpoint1 = torch.load(previous_save_path +'/Models/model_{0}.pt'.format(iteration), map_location=lambda storage, loc: storage) + #xformer.load_state_dict(torch.load(PATH)) + loaded_step = checkpoint1['step'] + net.load_state_dict(checkpoint1['xformer']) + print("Loaded Model for {0},{1} loaded at step {2} from previous model {3},{4}".format(args.K,args.N,loaded_step,args.previous_K,args.previous_N)) + decoded_bits,out_mask = net.decode(samples,info_inds,None,device) + decoded_msg_bits = decoded_bits[:, info_inds] + colors = decoded_msg_bits[:,focus_on_bit].squeeze().tolist() + plt.figure(figsize = (20,10)) + scatt = plt.scatter(x, y,c=colors) + plt.title('Step {0}'.format(iteration)) + plt.savefig(previous_save_path + '/'+'decision_boundaries/tsne_at_step_{0}_bit_{1}.png'.format(iteration,focus_on_bit)) + plt.close() + images.append(imageio.imread(previous_save_path + '/'+'decision_boundaries/tsne_at_step_{0}_bit_{1}.png'.format(iteration,focus_on_bit))) + imageio.mimsave(previous_save_path + '/'+'decision_boundaries/all_movie_bit_{0}.gif'.format(focus_on_bit), images) + +if __name__ == '__main__': + args = get_args() + if args.anomaly: + torch.autograd.set_detect_anomaly(True) + + if args.gpu == -1: #run on any available device + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + else: #run on specified gpu + device = torch.device("cuda:{0}".format(args.gpu)) if torch.cuda.is_available() else torch.device("cpu") + + #torch.manual_seed(37) + kwargs = {'num_workers': 4, 'pin_memory': False} if torch.cuda.is_available() else {} + if args.previous_code is None: + args.previous_code = args.code + if args.previous_rate_profile is None: + args.previous_rate_profile = args.rate_profile + ID = '' if args.id is None else args.id + lr_ = args.lr if args.lr_decay is None else str(args.lr)+'_decay_{}_{}'.format(args.lr_decay, args.lr_decay_gamma) + if args.code == 'polar': + results_save_path = './Supervised_Xformer_decoder_Polar_Results/Polar_{0}_{1}/Scheme_{2}/{3}/{4}_depth_{5}'\ + .format(args.K, args.N, args.rate_profile, args.model, args.n_head,args.n_layers) + if args.save_path is None: + final_save_path = './Supervised_Xformer_decoder_Polar_Results/final_nets/Scheme_{2}/N{1}_K{0}_{3}_{4}_depth_{5}.pt'\ + .format(args.K, 
args.N, args.rate_profile, args.model, args.n_head,args.n_layers) + else: + final_save_path = args.save_path + elif args.code== 'pac': + results_save_path = './Supervised_Xformer_decoder_PAC_Results/Polar_{0}_{1}/Scheme_{2}/{3}/{4}_depth_{5}'\ + .format(args.K, args.N, args.rate_profile, args.model, args.n_head,args.n_layers) + if args.save_path is None: + final_save_path = './Supervised_Xformer_decoder_PAC_Results/final_nets/Scheme_{2}/N{1}_K{0}_{3}_{4}_depth_{5}.pt'\ + .format(args.K, args.N, args.rate_profile, args.model, args.n_head,args.n_layers) + else: + final_save_path = args.save_path + if ID != '': + results_save_path = results_save_path + '/' + ID + final_save_path = final_save_path + '/' + ID + + if args.previous_code == 'polar': + previous_save_path = './Supervised_Xformer_decoder_Polar_Results/Polar_{0}_{1}/Scheme_{2}/{3}/{4}_depth_{5}'\ + .format(args.previous_K, args.previous_N, args.previous_rate_profile, args.model, args.n_head,args.n_layers) + elif args.previous_code == 'pac': + previous_save_path = './Supervised_Xformer_decoder_PAC_Results/Polar_{0}_{1}/Scheme_{2}/{3}/{4}_depth_{5}'\ + .format(args.previous_K, args.previous_N, args.previous_rate_profile, args.model, args.n_head,args.n_layers) + if args.previous_id is not None: + previous_save_path = previous_save_path + '/' + args.previous_id + else: + previous_save_path = previous_save_path #+ '/' + ID + if args.run is not None: + results_save_path = results_save_path + '/' + '{0}'.format(args.run) + final_save_path = final_save_path + '/' + '{0}'.format(args.run) + previous_save_path = previous_save_path + '/' + '{0}'.format(args.run) + + ############ + ## Polar Code parameters + ############ + K = args.K + N = args.N + n = int(np.log2(args.N)) + target_K = args.target_K + + if args.rate_profile == 'polar': + # computed for SNR = 0 + if n == 5: + rs = np.array([31, 30, 29, 27, 23, 15, 28, 26, 25, 22, 21, 14, 19, 13, 11, 24, 7, 20, 18, 12, 17, 10, 9, 6, 5, 3, 16, 8, 4, 2, 1, 0]) + + elif n == 4: + rs = np.array([15, 14, 13, 11, 7, 12, 10, 9, 6, 5, 3, 8, 4, 2, 1, 0]) + elif n == 3: + rs = np.array([7, 6, 5, 3, 4, 2, 1, 0]) + elif n == 2: + rs = np.array([3, 2, 1, 0]) + + rs = np.array([256 ,255 ,252 ,254 ,248 ,224 ,240 ,192 ,128 ,253 ,244 ,251 ,250 ,239 ,238 ,247 ,246 ,223 ,222 ,232 ,216 ,236 ,220 ,188 ,208 ,184 ,191 ,190 ,176 ,127 ,126 ,124 ,120 ,249 ,245 ,243 ,242 ,160 ,231 ,230 ,237 ,235 ,234 ,112 ,228 ,221 ,219 ,218 ,212 ,215 ,214 ,189 ,187 ,96 ,186 ,207 ,206 ,183 ,182 ,204 ,180 ,200 ,64 ,175 ,174 ,172 ,125 ,123 ,122 ,119 ,159 ,118 ,158 ,168 ,241 ,116 ,111 ,233 ,156 ,110 ,229 ,227 ,217 ,108 ,213 ,152 ,226 ,95 ,211 ,94 ,205 ,185 ,104 ,210 ,203 ,181 ,92 ,144 ,202 ,179 ,199 ,173 ,178 ,63 ,198 ,121 ,171 ,88 ,62 ,117 ,170 ,196 ,157 ,167 ,60 ,115 ,155 ,109 ,166 ,80 ,114 ,154 ,107 ,56 ,225 ,151 ,164 ,106 ,93 ,150 ,209 ,103 ,91 ,143 ,201 ,102 ,48 ,148 ,177 ,90 ,142 ,197 ,87 ,100 ,61 ,169 ,195 ,140 ,86 ,59 ,32 ,165 ,194 ,113 ,79 ,58 ,153 ,84 ,136 ,55 ,163 ,78 ,105 ,149 ,162 ,54 ,76 ,101 ,47 ,147 ,89 ,52 ,141 ,99 ,46 ,146 ,72 ,85 ,139 ,98 ,31 ,44 ,193 ,138 ,57 ,83 ,30 ,135 ,77 ,40 ,82 ,134 ,161 ,28 ,53 ,75 ,132 ,24 ,51 ,74 ,45 ,145 ,71 ,50 ,16 ,97 ,70 ,43 ,137 ,68 ,42 ,29 ,39 ,81 ,27 ,133 ,38 ,26 ,36 ,131 ,23 ,73 ,22 ,130 ,49 ,15 ,20 ,69 ,14 ,12 ,67 ,41 ,8 ,66 ,37 ,25 ,35 ,34 ,21 ,129 ,19 ,13 ,18 ,11 ,10 ,7 ,65 ,6 ,4 ,33 ,17 ,9 ,5 ,3 ,2 ,1 ]) - 1 + # Multiple SNRs: + + ############### + ### Polar code + ############## + + ### Encoder + if args.code=='polar': + polar = PolarCode(n, args.K, args, rs=rs) + polarTarget = 
PolarCode(n, args.target_K, args, rs=rs) + elif args.code=='pac': + polar = PAC(args, args.N, args.K, args.g) + polarTarget = PAC(args, args.N, args.target_K, args.g) + elif args.rate_profile == 'RM': + rmweight = np.array([countSetBits(i) for i in range(args.N)]) + Fr = np.argsort(rmweight)[:-args.K] + Fr.sort() + if args.code=='polar': + polar = PolarCode(n, args.K, args, F=Fr) + rmweight = np.array([countSetBits(i) for i in range(args.N)]) + Fr = np.argsort(rmweight)[:-args.target_K] + Fr.sort() + polarTarget = PolarCode(n, args.target_K, args, F=Fr) + elif args.code=='pac': + polar = PAC(args, args.N, args.K, args.g) + polarTarget = PAC(args, args.N, args.target_K, args.g) + + if args.prog_mode == 'e2h': + if args.code == 'polar': + info_inds = polar.info_positions + frozen_inds = polar.frozen_positions + elif args.code == 'pac': + frozen_levels = (polar.rate_profiler(-torch.ones(1, args.K), scheme = args.rate_profile) == 1.)[0].numpy() + info_inds = polar.B + frozen_inds = np.array(list(set(np.arange(args.N))^set(polar.B))) + elif args.prog_mode == 'h2e': + if args.code == 'polar': + info_inds = polarTarget.unsorted_info_positions[:args.K].copy() + frozen_inds = polarTarget.frozen_positions + elif args.code == 'pac': + frozen_levels = (polar.rate_profiler(-torch.ones(1, args.K), scheme = args.rate_profile) == 1.)[0].numpy() + info_inds = polarTarget.unsorted_info_positions[:args.K].copy() + frozen_inds = np.array(list(set(np.arange(args.N))^set(polar.B))) + elif args.prog_mode == 'l2r': + if args.code == 'polar': + info_inds = polarTarget.info_positions[:args.K].copy() + frozen_inds = polarTarget.frozen_positions + elif args.code == 'pac': + frozen_levels = (polar.rate_profiler(-torch.ones(1, args.K), scheme = args.rate_profile) == 1.)[0].numpy() + info_inds = polarTarget.B[:args.K].copy() + frozen_inds = np.array(list(set(np.arange(args.N))^set(polar.B))) + elif args.prog_mode == 'r2l': + if args.code == 'polar': + info_inds = polarTarget.info_positions[-args.K:].copy() + frozen_inds = polarTarget.frozen_positions + elif args.code == 'pac': + frozen_levels = (polar.rate_profiler(-torch.ones(1, args.K), scheme = args.rate_profile) == 1.)[0].numpy() + info_inds = polarTarget.B[-args.K:].copy() + frozen_inds = np.array(list(set(np.arange(args.N))^set(polar.B))) + elif args.prog_mode == 'random': + if args.code == 'polar': + random_info = polarTarget.info_positions.copy() + random.Random(42).shuffle(random_info) + info_inds = random_info[:args.K].copy() + frozen_inds = polarTarget.frozen_positions + elif args.code == 'pac': + frozen_levels = (polar.rate_profiler(-torch.ones(1, args.K), scheme = args.rate_profile) == 1.)[0].numpy() + info_inds = polarTarget.B[-args.K:].copy() + frozen_inds = np.array(list(set(np.arange(args.N))^set(polar.B))) + + info_inds.sort() + if args.code == 'polar': + target_info_inds = polarTarget.info_positions + elif args.code == 'pac': + target_info_inds = polarTarget.B + target_info_inds.sort() + print("Info positions : {}".format(info_inds)) + print("Target Info positions : {}".format(target_info_inds)) + print("Frozen positions : {}".format(frozen_inds)) + print("Code : {0} ".format(args.code)) + print("Type of training : {0}".format(args.prog_mode)) + print("Rate Profile : {0}".format(args.rate_profile)) + print("Validation SNR : {0}".format(args.validation_snr)) + + #___________________Model Definition___________________________________________________# + + gen_mat = torch.eye(args.N,args.N) + #if args.code == 'polar': + gen_mat = PolarCode(n, args.N, 
args).get_generator_matrix() + gen_mat.to(device) + args.mat = gen_mat + if args.model == 'gpt': + xformer = XFormerEndToEndGPT(args) + elif args.model == 'decoder': + xformer = XFormerEndToEndDecoder(args) + elif args.model == 'encoder' or args.model == 'denoiser': + xformer = XFormerEndToEndEncoder(args) + elif args.model == 'simple': + xformer = simpleNet(args) + elif args.model == 'conv': + xformer = convNet(args) + elif args.model == 'bigConv': + xformer = bigConvNet(args) + elif args.model == 'small': + xformer = smallNet(args) + elif args.model == 'multConv': + gen_mat = torch.eye(args.N,args.N) + #if args.code == 'polar': + gen_mat = PolarCode(n, args.K, args).get_generator_matrix(custom_info_positions=info_inds) + gen_mat.to(device) + xformer = multConvNet(args,gen_mat) + elif args.model == 'rnnAttn': + xformer = rnnAttn(args) + elif args.model == 'bitConv': + xformer = bitConvNet(args) + #device = 'cpu' + + + + + + if not args.test: + os.makedirs(results_save_path, exist_ok=True) + os.makedirs(results_save_path +'/Models', exist_ok=True) + os.makedirs(final_save_path , exist_ok=True) + os.makedirs(final_save_path +'/Models', exist_ok=True) + + if args.model_iters is not None and args.load_previous : + checkpoint1 = torch.load(previous_save_path +'/Models/model_{0}.pt'.format(args.model_iters), map_location=lambda storage, loc: storage) + #xformer.load_state_dict(torch.load(PATH)) + loaded_step = checkpoint1['step'] + xformer.load_state_dict(checkpoint1['xformer']) + print("Training Model for {0},{1} loaded at step {2} from previous model {3},{4}".format(args.K,args.N,loaded_step,args.previous_K,args.previous_N)) + else: + print("Training Model for {0},{1} anew".format(args.K,args.N)) + device_ids = range(args.num_devices) + if args.parallel: + xformer = torch.nn.DataParallel(xformer, device_ids=device_ids) + + xformer.to(device) + print("Number of parameters :",count_parameters(xformer)) + + if args.only_args: + print("Loaded args. 
Exiting") + sys.exit() + ############## + ### Optimizers + ############## + if args.optimizer_type == 'Adam': + optimizer = optim.Adam(xformer.parameters(), lr = args.lr) + elif args.optimizer_type == 'AdamW': + optimizer = optim.AdamW(xformer.parameters(), lr = args.lr) + elif args.optimizer_type == 'RMS': + optimizer = optim.RMSprop(xformer.parameters(), lr = args.lr) + elif args.optimizer_type == 'SGD': + optimizer = optim.SGD(xformer.parameters(), lr = args.lr,momentum=1e-4, dampening=0,nesterov = True) + else: + raise Exception("Optimizer not supported yet!") + + if args.lr_decay is None: + scheduler = None + else: + if args.T_anneal is None: + scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_decay*args.K , args.lr_decay_gamma) + else: + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.T_anneal, eta_min=5e-5) + + if args.cosine: + scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer,2200,args.num_steps,num_cycles=args.num_restarts) + + if args.loss == 'Huber': + loss_fn = F.smooth_l1_loss + elif args.loss == 'MSE': + loss_fn = nn.MSELoss(reduction='mean') + elif args.loss == 'NLL': + loss_fn = nn.NLLLoss() + elif args.loss == 'Block': + loss_fn = None + + training_losses = [] + training_bers = [] + valid_bers = [] + valid_bitwise_bers= [] + valid_tgt_bers = [] + valid_blers = [] + valid_tgt_blers = [] + valid_steps = [] + + test_data_path = './data/polar/test/test_N{0}_K{1}.p'.format(args.N, args.K) + try: + test_dict = torch.load(test_data_path) + valid_msg_bits = test_dict['msg'] + valid_received = test_dict['rec'] + print(valid_received.size()) + except: + print("Did not find standard validation data") + + + mavg_steps = 25 + + print("Need to save for:", args.model_save_per) + xformer.train() + first = [info_inds[0]] + range_snr = [args.dec_train_snr,args.dec_train_snr+1,args.dec_train_snr+2] + if args.validation_snr is not None: + valid_snr = args.validation_snr + else: + valid_snr = args.dec_train_snr + info_inds = info_inds.copy() + #acc_error_egs = + kernel = 7 + padding = int((kernel-1)/2) + layersint = nn.Sequential( + nn.Conv1d(64,1,kernel,padding=padding,dilation=1), + ) + layersint.to(device) + try: + for i_step in range(args.num_steps): ## Each episode is like a sample now until memory size is reached. 
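(Editorial aside.) For intuition about the warm-up-plus-hard-restarts schedule constructed above via get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, 2200, args.num_steps, num_cycles=args.num_restarts), a toy run on a dummy parameter with small step counts shows the linear warm-up followed by repeated cosine decays back toward zero. The dummy model and step counts here are illustrative only.

import torch

dummy = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([dummy], lr=1e-3)
sched = get_cosine_with_hard_restarts_schedule_with_warmup(
    opt, num_warmup_steps=10, num_training_steps=100, num_cycles=4)

lrs = []
for _ in range(100):
    opt.step()      # optimizer step before scheduler step (PyTorch >= 1.1 ordering)
    sched.step()
    lrs.append(opt.param_groups[0]['lr'])
# lrs rises linearly over the first 10 steps, then traces 4 cosine cycles,
# each restarting at the base learning rate and decaying to ~0.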
+ randperm = torch.randperm(args.batch_size) + if args.do_range_training: + train_snr = args.dec_train_snr + if args.code == 'polar':# and (i_step < 2000 or args.model=='multConv'): + range_snr = [args.dec_train_snr-1,args.dec_train_snr,args.dec_train_snr+5] + train_snr = range_snr[i_step%3] + if args.code == 'pac':# and i_step < 15000: + range_snr = [args.dec_train_snr,args.dec_train_snr+1,args.dec_train_snr+2] + train_snr = range_snr[i_step%3] + else: + train_snr = args.dec_train_snr + start_time = time.time() + #torch.cuda.empty_cache() + msg_bits = 1 - 2 * (torch.rand(args.batch_size, args.K, device=device) < 0.5).float() + gt = torch.ones(args.batch_size, args.N, device = device) + gt[:, info_inds] = msg_bits + gt_valid = gt.clone() + if args.code == 'polar': + polar_code = polar.encode_plotkin(msg_bits,custom_info_positions = info_inds) + corrupted_codewords = polar.channel(polar_code, train_snr)#args.dec_train_snr) + elif args.code == 'pac': + polar_code = polar.pac_encode(msg_bits, scheme = args.rate_profile,custom_info_positions = info_inds) + corrupted_codewords = polar.channel(polar_code, train_snr)#args.dec_train_snr) + mask = torch.cat((torch.ones((args.batch_size,args.N),device=device),torch.zeros((args.batch_size,args.max_len-args.N),device=device)),1).long() + + if args.include_previous_block_errors and i_step%100 not in [0,1,2,3,4,5,6,7,8]: + #print(error_egs_corrupted.size()) + corrupted_codewords = error_egs_corrupted + gt = error_egs_true + # corrupted_codewords = corrupted_codewords[randperm] + # gt = gt[randperm] + if args.model == 'conv' or args.model == 'bigConv' or args.model == 'multConv': + model_out,decoded_vhat,out_mask,logits,int_layer = xformer(corrupted_codewords,mask,gt,device) + else: + model_out,decoded_vhat,out_mask,logits = xformer(corrupted_codewords,mask,gt,device) + + batch_size = gt.size(0) + max_len = gt.size(1) + + if args.model == 'gpt' or args.model == 'decoder': + pass#gt = (gt*torch.ones((max_len,batch_size,max_len),device=device)).permute((1,0,2)).reshape(batch_size*max_len,max_len) + elif args.model == 'denoiser': + gt = polar_code + #print(decoded_vhat.size()) + decoded_msg_bits = decoded_vhat[:,info_inds] + if args.loss == 'NLL': + loss = loss_fn(torch.log(model_out[:,info_inds,:]).transpose(1,2),(gt[:, info_inds]==1).long()) + elif args.loss == 'MSE': + #out_mask[:,0]=100 + loss = loss_fn(out_mask[:,info_inds]*logits[:,info_inds,0],out_mask[:,info_inds]*gt[:, info_inds])#+0.5*loss_fn(layersint(int_layer).squeeze(),polar_code)#*args.N + #print(logits.size()) + #loss = loss_fn(out_mask[:,first]*logits[:,first,0],out_mask[:,first]*gt[:, first])#*args.N + #out_mask[:,0]=1 + elif args.loss == 'Block': + loss = torch.mean(torch.max(out_mask[:,info_inds]*(logits[:,info_inds,0]-gt[:, info_inds])**2,-1).values) + else: + loss = torch.sum(out_mask[:,info_inds]*(model_out[:,info_inds,1]-(gt[:, info_inds]==1).float())**2)/torch.sum(out_mask[:,info_inds]) + # OLD LOSS: on all bits + # if args.loss_on_all: + # loss = loss_fn(decoded_vhat, gt) + # else: + # # NEW LOSS : only on info bits + # loss = loss_fn(msg_bits, decoded_msg_bits) + ber = errors_ber(gt[:,info_inds].cpu(), decoded_msg_bits.cpu(),out_mask[:,info_inds].cpu()).item() + + if args.include_previous_block_errors and i_step%100 in [0,1,2,3,4,5,6,7,8]: + if i_step == 0: + error_egs_corrupted = corrupted_codewords.clone() + error_egs_true = gt.clone() + error_inds, = extract_block_errors(gt[:,info_inds].cpu(), decoded_msg_bits.cpu(),thresh=5) + + # print(error_inds.size) + _, 
decoded_SCL_msg_bits = polar.scl_decode(corrupted_codewords[error_inds,:].clone().cpu(), train_snr, 4, use_CRC = False) + correct_inds, = extract_block_nonerrors(gt[error_inds,:][:,info_inds].cpu(), decoded_SCL_msg_bits.cpu(),thresh=1) + #print(correct_inds.size) + error_egs_corrupted = torch.cat((corrupted_codewords[correct_inds,:].clone(),error_egs_corrupted),0)[:args.batch_size,:] + error_egs_true = torch.cat((gt[correct_inds,:].clone(),error_egs_true),0)[:args.batch_size,:] + + # print(error_egs_corrupted.size()) + # print('\n') + (loss/args.mult).backward() + torch.nn.utils.clip_grad_norm_(xformer.parameters(), args.clip) # gradient clipping to avoid exploding gradient + + if i_step%args.mult == 0: + optimizer.step() + optimizer.zero_grad() + + if scheduler is not None: + scheduler.step() + + training_losses.append(round(loss.item(),5)) + training_bers.append(round(ber, 5)) + + if i_step % args.print_freq == 0: + xformer.eval() + with torch.no_grad(): + corrupted_codewords_valid = polar.channel(polar_code, valid_snr) + decoded_no_noise,_ = xformer.decode(polar_code,info_inds,mask,device) + decoded_bits,out_mask = xformer.decode(corrupted_codewords_valid,info_inds,mask,device) + decoded_Xformer_msg_bits = decoded_bits[:, info_inds] + decoded_Xformer_msg_bits_no_noise = decoded_no_noise[:, info_inds] + if args.model == 'denoiser': + ber_Xformer = errors_ber(gt_valid[:,info_inds], decoded_Xformer_msg_bits, mask = out_mask[:,info_inds]).item() + else: + ber_Xformer = errors_ber(gt_valid[:,info_inds], decoded_Xformer_msg_bits, mask = out_mask[:,info_inds]).item() + ber_Xformer_noiseless = errors_ber(gt_valid[:,info_inds], decoded_Xformer_msg_bits_no_noise, mask = out_mask[:,info_inds]).item() + bler_Xformer = errors_bler(gt_valid[:,info_inds], decoded_Xformer_msg_bits).item() + bler_Xformer_noiseless = errors_bler(gt_valid[:,info_inds], decoded_Xformer_msg_bits_no_noise).item() + #ber_Xformer = errors_ber(gt[:,first], decoded_bits[:,first], mask = out_mask[:,first]).item() + if args.K < args.target_K: + msg_bits = 1 - 2 * (torch.rand(args.batch_size, args.target_K, device=device) < 0.5).float() + gt = torch.ones(args.batch_size, args.N, device = device) + gt[:, target_info_inds] = msg_bits + + if args.code == 'polar': + polar_code = polarTarget.encode_plotkin(msg_bits) + corrupted_codewords = polarTarget.channel(polar_code, valid_snr)#args.dec_train_snr) + elif args.code == 'pac': + polar_code = polarTarget.pac_encode(msg_bits, scheme = args.rate_profile) + corrupted_codewords = polarTarget.channel(polar_code, valid_snr)#args.dec_train_snr) + decoded_bits,out_mask = xformer.decode(corrupted_codewords,target_info_inds,mask,device) + decoded_Xformer_msg_bits = decoded_bits[:, target_info_inds] + ber_Xformer_tgt = errors_ber(msg_bits, decoded_Xformer_msg_bits, mask = out_mask[:,target_info_inds]).item() + bler_Xformer_tgt = errors_bler(msg_bits, decoded_Xformer_msg_bits).item() + bitwise_ber_Xformer_tgt = errors_bitwise_ber(msg_bits, decoded_Xformer_msg_bits, mask = out_mask[:,target_info_inds]).squeeze().cpu().tolist() + else: + bitwise_ber_Xformer_tgt = errors_bitwise_ber(msg_bits, decoded_Xformer_msg_bits, mask = out_mask[:,target_info_inds]).squeeze().cpu().tolist() + bler_Xformer_tgt = errors_bler(msg_bits, decoded_Xformer_msg_bits).item() + #print(bitwise_ber_Xformer_tgt) + valid_bers.append(round(ber_Xformer, 5)) + valid_blers.append(round(bler_Xformer, 5)) + valid_tgt_blers.append(round(bler_Xformer_tgt, 5)) + if args.K < args.target_K: + valid_tgt_bers.append(round(ber_Xformer_tgt, 
5)) + else: + valid_tgt_bers.append(round(ber_Xformer, 5)) + valid_steps.append(i_step) + valid_bitwise_bers.append(bitwise_ber_Xformer_tgt) + xformer.train() + try: + print('[%d/%d] At %d dB, Loss: %.7f, Train BER (%d dB) : %.7f, Valid BER: %.7f, Tgt BER: %.7f, Noiseless BER %.7f, Valid BLER : %.7f' + % (i_step, args.num_steps, valid_snr, loss,train_snr,ber, ber_Xformer,ber_Xformer_tgt,ber_Xformer_noiseless,bler_Xformer)) + except: + print('[%d/%d] At %d dB, Loss: %.7f, Train BER (%d dB) : %.7f, Valid BER: %.7f, Tgt BER: %.7f, Noiseless BER %.7f, Valid BLER : %.7f' + % (i_step, args.num_steps, valid_snr, loss,train_snr,ber, ber_Xformer,ber_Xformer,ber_Xformer_noiseless,bler_Xformer)) + if i_step == 10: + print("Time for one step is {0:.4f} minutes".format((time.time() - start_time)/60)) + + # Save the model for safety + + if ((i_step+1) % args.model_save_per == 0) or (i_step+1 == 10) or ((i_step+1) % args.num_steps == 0): + + # print(i_episode +1 ) + torch.save({'xformer': xformer.state_dict(), 'step':i_step+1, 'args':args} ,\ + results_save_path+'/Models/model_{0}.pt'.format(i_step+1)) + torch.save({'xformer': xformer.state_dict(), 'step':i_step+1, 'args':args} ,\ + final_save_path+'/Models/model_final.pt') + # torch.save({'xformer': xformer.state_dict(), 'step':i_step+1, 'args':args} ,\ + # final_save_path) + + + episode_x = np.arange(1, 1+len(training_losses)) + episode_x_mavg = np.arange(1+len(training_losses)-len(moving_average(training_losses, n=mavg_steps)), 1+len(training_losses)) + + plt.figure() + plt.plot(episode_x, training_losses) + plt.plot(episode_x_mavg, moving_average(training_losses, n=mavg_steps)) + plt.savefig(results_save_path +'/training_losses.png') + plt.close() + + plt.figure() + plt.plot(episode_x, training_losses) + plt.plot(episode_x_mavg, moving_average(training_losses, n=mavg_steps)) + plt.yscale('log') + plt.savefig(results_save_path +'/training_losses_log.png') + plt.close() + + plt.figure() + plt.plot(episode_x, training_bers) + plt.plot(episode_x_mavg, moving_average(training_bers, n=mavg_steps)) + plt.savefig(results_save_path +'/training_bers.png') + plt.close() + + plt.figure() + plt.plot(episode_x, training_bers) + plt.plot(episode_x_mavg, moving_average(training_bers, n=mavg_steps)) + plt.yscale('log') + plt.savefig(results_save_path +'/training_bers_log.png') + plt.close() + + + with open(os.path.join(results_save_path, 'values_training.csv'), 'w') as f: + + # using csv.writer method from CSV package + write = csv.writer(f) + + write.writerow(episode_x) + write.writerow(training_losses) + write.writerow(training_bers) + + with open(os.path.join(results_save_path, 'values_validation.csv'), 'w') as f: + + # using csv.writer method from CSV package + write = csv.writer(f) + + write.writerow(valid_steps) + write.writerow(valid_bers) + write.writerow(valid_tgt_bers) + + for i in range(target_K): + write.writerow([bitwise_bers[i] for bitwise_bers in valid_bitwise_bers]) + + write.writerow(valid_blers) + write.writerow(valid_tgt_blers) + + print('Complete') + + except KeyboardInterrupt: + torch.save({'xformer': xformer.state_dict(), 'step':i_step+1, 'args':args} ,\ + results_save_path+'/Models/model_{0}.pt'.format(i_step+1)) + torch.save({'xformer': xformer.state_dict(), 'step':i_step+1, 'args':args} ,\ + final_save_path+'/Models/model_final.pt') + # torch.save({'net': xformer.state_dict(), 'step':i_step+1, 'args':args} ,\ + # final_save_path) + + episode_x = np.arange(1, 1+len(training_losses)) + episode_x_mavg = 
np.arange(1+len(training_losses)-len(moving_average(training_losses, n=mavg_steps)), 1+len(training_losses)) + + plt.figure() + plt.plot(episode_x, training_losses) + plt.plot(episode_x_mavg, moving_average(training_losses, n=mavg_steps)) + plt.savefig(results_save_path +'/training_losses.png') + plt.close() + + plt.figure() + plt.plot(episode_x, training_losses) + plt.plot(episode_x_mavg, moving_average(training_losses, n=mavg_steps)) + plt.yscale('log') + plt.savefig(results_save_path +'/training_losses_log.png') + plt.close() + + plt.figure() + plt.plot(episode_x, training_bers) + plt.plot(episode_x_mavg, moving_average(training_bers, n=mavg_steps)) + plt.savefig(results_save_path +'/training_bers.png') + plt.close() + + plt.figure() + plt.plot(episode_x, training_bers) + plt.plot(episode_x_mavg, moving_average(training_bers, n=mavg_steps)) + plt.yscale('log') + plt.savefig(results_save_path +'/training_bers_log.png') + plt.close() + + + print("Exited and saved") + + with open(os.path.join(results_save_path, 'values_training.csv'), 'w') as f: + + # using csv.writer method from CSV package + write = csv.writer(f) + + write.writerow(episode_x) + write.writerow(training_losses) + write.writerow(training_bers) + + with open(os.path.join(results_save_path, 'values_validation.csv'), 'w') as f: + + # using csv.writer method from CSV package + write = csv.writer(f) + + write.writerow(valid_steps) + write.writerow(valid_bers) + write.writerow(valid_tgt_bers) + + for i in range(target_K): + write.writerow([bitwise_bers[i] for bitwise_bers in valid_bitwise_bers]) + else: + print("TESTING :") + + if args.plot_progressive: + k = args.K + plt.figure(figsize = (20,10)) + ber_tgt = [] + net_iters = [0] + snr = args.validation_snr + bers_SC_test = 0. + bers_SCL_test = 0. 
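(Editorial aside.) moving_average, used above to overlay smoothed loss/BER curves, is imported from utils.py and not shown in this excerpt. A conventional trailing-window implementation consistent with how it is indexed (it must return len(x) - n + 1 points, matching episode_x_mavg) would be the following; moving_average_ref is a name introduced here and the exact behaviour of the utils.py version is an assumption.

import numpy as np

def moving_average_ref(x, n=25):
    # Assumed behaviour: simple trailing average over a window of n points,
    # returning len(x) - n + 1 values.
    c = np.cumsum(np.insert(np.asarray(x, dtype=float), 0, 0.0))
    return (c[n:] - c[:-n]) / n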
+ bers_SC_test_bitwise = torch.zeros((1,args.target_K),device=device) + num_batches = 10 + batch=1000 + tot = batch*num_batches + for _ in tqdm(range(num_batches)): + msg_bits = 1 - 2 * (torch.rand(batch, args.target_K, device=device) < 0.5).float() + polar_code = polarTarget.encode_plotkin(msg_bits) + noisy_code = polarTarget.channel(polar_code, snr) + noise = noisy_code - polar_code + SC_llrs, decoded_SC_msg_bits = polarTarget.sc_decode_new(noisy_code, snr) + SCL_llrs, decoded_SCL_msg_bits = polarTarget.scl_decode(noisy_code.cpu(), snr, 4, use_CRC = False) + ber_SC = errors_ber(msg_bits.cpu(), decoded_SC_msg_bits.sign().cpu()).item() + ber_SC_bitwise = errors_bitwise_ber(msg_bits.cpu(), decoded_SC_msg_bits.sign().cpu()).squeeze() + ber_SCL = errors_ber(msg_bits.cpu(), decoded_SCL_msg_bits.sign().cpu()).item() + bers_SC_test += ber_SC/num_batches + bers_SC_test_bitwise += ber_SC_bitwise/num_batches + bers_SCL_test += ber_SCL/num_batches + ber_SC_bitwise = ber_SC_bitwise.squeeze() + while k <= args.target_K: + if args.code == 'polar': + results_save_path = './Supervised_Xformer_decoder_Polar_Results/Polar_{0}_{1}/Scheme_{2}/{3}/{4}_depth_{5}'\ + .format(k, args.N, args.rate_profile, args.model, args.n_head,args.n_layers) + elif args.code== 'pac': + results_save_path = './Supervised_Xformer_decoder_PAC_Results/Polar_{0}_{1}/Scheme_{2}/{3}/{4}_depth_{5}'\ + .format(k, args.N, args.rate_profile, args.model, args.n_head,args.n_layers) + if ID != '': + results_scratch = results_save_path + '/' + 'scratch' + results_save_path = results_save_path + '/' + ID + if args.run is not None: + results_save_path = results_save_path + '/' + '{0}'.format(args.run) + rows = [] + with open(os.path.join(results_save_path, 'values_validation.csv')) as f: + csvRead = csv.reader(f) + for row in csvRead: + rows.append(list(map(float,row))) + + iterations = [it+net_iters[-1]+1 for it in rows[0]] + net_iters = net_iters + iterations + plt.axvline(x = net_iters[-1], color = 'grey', linestyle='dashed') + ber = rows[1] + ber_tgt = ber_tgt + rows[2] + label = '{0},{1}'.format(k,args.N) + plt.semilogy(iterations, ber, label=label) + if k == target_K: + sc = np.ones(len(iterations)) * bers_SC_test + scl = np.ones(len(iterations)) * bers_SCL_test + + plt.semilogy(iterations,sc, label='SC'.format(k,args.N),linestyle='dashed') + plt.semilogy(iterations,scl, label='SCL'.format(k,args.N),linestyle='dashed') + k += 1 + k-=1 + net_iters = net_iters[1:] + print(len(net_iters)) + plt.semilogy(net_iters,ber_tgt, label='{0},{1} prog'.format(k,args.N)) + + try: + rows = [] + with open(os.path.join(results_scratch, 'values_validation.csv')) as f: + csvRead = csv.reader(f) + for row in csvRead: + rows.append(list(map(float,row))) + #print(len(rows[2])) + ber_scratch = rows[2] + if len(ber_scratch) < len(net_iters): + plt.semilogy(net_iters[:len(ber_scratch)],ber_scratch, label='{0},{1} scr'.format(k,args.N)) + else: + plt.semilogy(net_iters,ber_scratch[:len(net_iters)], label='{0},{1} scr'.format(k,args.N)) + except: + print("Did not find model trained from scratch") + plt.legend(prop={'size': 7},loc='upper right', bbox_to_anchor=(1.1, 1)) + plt.ylim(bottom=1e-3) + plt.ylim(top=0.6) + plt.savefig(results_save_path +'/valid_progressive_log.pdf') + plt.close() + + k = args.K + plt.figure(figsize = (20,10)) + bitwise_ber = [[] for _ in range(args.target_K)] + net_iters = [0] + while k <= args.target_K: + if args.code == 'polar': + results_save_path = './Supervised_Xformer_decoder_Polar_Results/Polar_{0}_{1}/Scheme_{2}/{3}/{4}_depth_{5}'\ 
+                .format(k, args.N, args.rate_profile, args.model, args.n_head, args.n_layers)
+            info_inds1 = polarTarget.unsorted_info_positions.copy()
+        elif args.code == 'pac':
+            results_save_path = './Supervised_Xformer_decoder_PAC_Results/Polar_{0}_{1}/Scheme_{2}/{3}/{4}_depth_{5}'\
+                .format(k, args.N, args.rate_profile, args.model, args.n_head, args.n_layers)
+            info_inds1 = polarTarget.unsorted_info_positions.copy()
+        if ID != '':
+            results_scratch = results_save_path + '/' + 'scratch'
+            results_save_path = results_save_path + '/' + ID
+        if args.run is not None:
+            results_save_path = results_save_path + '/' + '{0}'.format(args.run)
+        rows = []
+        with open(os.path.join(results_save_path, 'values_validation.csv')) as f:
+            csvRead = csv.reader(f)
+            for row in csvRead:
+                rows.append(list(map(float, row)))
+        for i in range(len(rows) - 3):
+            bitwise_ber[i] = bitwise_ber[i] + rows[i + 3]
+        iterations = [it + net_iters[-1] + 1 for it in rows[0]]
+        net_iters = net_iters + iterations
+        plt.axvline(x=net_iters[-1], color='grey', linestyle='dashed')
+
+        k += 1
+    k -= 1
+    net_iters = net_iters[1:]  # net_iters[101:]
+    for i in range(len(bitwise_ber)):
+        plt.semilogy(net_iters, bitwise_ber[i], label='Bit {0}'.format(i))
+        plt.annotate('{0}'.format(i), (net_iters[-1], bitwise_ber[i][-1] * 0.99))
+        sc = np.ones(len(net_iters)) * bers_SC_test_bitwise[0][i].item()
+        plt.semilogy(net_iters, sc, label='SC bit {0}'.format(i), linestyle='dashed')
+    plt.title("Hardest to easiest order : {0}".format(np.argsort(np.argsort(info_inds1))))
+    plt.legend(prop={'size': 7}, loc='upper right', bbox_to_anchor=(1.1, 1))
+    plt.savefig(results_save_path + '/valid_progressive_bitwise_log.pdf')
+    plt.close()
+
+    if args.id not in ('h2e', 'e2h', None):
+        sys.exit()
+
+    # Gather per-bit validation BERs for the 'h2e', 'e2h', and from-scratch runs at each training stage.
+    for i in range(len(bitwise_ber)):
+        plt.figure(figsize=(20, 10))
+        bitwise_ber_h2e = []
+        bitwise_ber_e2h = []
+        bitwise_ber_scr = []
+        net_iters = [0]
+        k = args.K  # +1
+
+        while k <= args.target_K:
+            if args.code == 'polar':
+                results_save_path = './Supervised_Xformer_decoder_Polar_Results/Polar_{0}_{1}/Scheme_{2}/{3}/{4}_depth_{5}'\
+                    .format(k, args.N, args.rate_profile, args.model, args.n_head, args.n_layers)
+                info_inds1 = polarTarget.unsorted_info_positions.copy()
+            elif args.code == 'pac':
+                results_save_path = './Supervised_Xformer_decoder_PAC_Results/Polar_{0}_{1}/Scheme_{2}/{3}/{4}_depth_{5}'\
+                    .format(k, args.N, args.rate_profile, args.model, args.n_head, args.n_layers)
+                info_inds1 = polarTarget.unsorted_info_positions.copy()
+            try:
+                if k == args.target_K:
+                    rows = []
+                    with open(os.path.join(results_save_path + '/scratch', 'values_validation.csv')) as f:
+                        csvRead = csv.reader(f)
+                        for row in csvRead:
+                            rows.append(list(map(float, row)))
+                    bitwise_ber_scr = bitwise_ber_scr + rows[i + 3]
+            except:
+                print("Did not find model trained from scratch")
+            try:
+                rows = []
+                with open(os.path.join(results_save_path + '/h2e', 'values_validation.csv')) as f:
+                    csvRead = csv.reader(f)
+                    for row in csvRead:
+                        rows.append(list(map(float, row)))
+                bitwise_ber_h2e = bitwise_ber_h2e + rows[i + 3]
+                rows = []
+                with open(os.path.join(results_save_path + '/e2h', 'values_validation.csv')) as f:
+                    csvRead = csv.reader(f)
+                    for row in csvRead:
+                        rows.append(list(map(float, row)))
+            except:
+                print("Did not find h2e and e2h")
+                # Fall back to the run stored directly under results_save_path.
+                rows = []
+                with open(os.path.join(results_save_path, 'values_validation.csv')) as f:
+                    csvRead = csv.reader(f)
+                    for row in csvRead:
+                        rows.append(list(map(float, row)))
+
+            bitwise_ber_e2h = bitwise_ber_e2h + rows[i + 3]
+            iterations = [it + net_iters[-1] + 1 for it in rows[0]]
+            net_iters = net_iters + iterations
+            plt.axvline(x=net_iters[-1], color='grey', linestyle='dashed')
+            k += 1
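+        # With all stages concatenated, plot bit i for the three schedules against the per-bit SC baseline
+        # ('h2e' / 'e2h' appear to denote the two curriculum orders; 'scr' is training from scratch).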
+        net_iters = net_iters[1:]  # net_iters[101:]
+        plt.semilogy(net_iters, bitwise_ber_h2e, label='Bit {0} H2E'.format(i))
+        plt.semilogy(net_iters, bitwise_ber_e2h, label='Bit {0} E2H'.format(i))
+        if len(bitwise_ber_scr) < len(net_iters):
+            plt.semilogy(net_iters[:len(bitwise_ber_scr)], bitwise_ber_scr, label='Bit {0} scr'.format(i))
+        else:
+            plt.semilogy(net_iters, bitwise_ber_scr[:len(net_iters)], label='Bit {0} scr'.format(i))
+        sc = np.ones(len(net_iters)) * bers_SC_test_bitwise[0][i].item()
+        plt.semilogy(net_iters, sc, label='SC bit {0}'.format(i), linestyle='dashed')
+        plt.legend(prop={'size': 7}, loc='upper right', bbox_to_anchor=(1.1, 1))
+        plt.ylim(bottom=1e-3)
+        plt.ylim(top=0.6)
+        plt.title("Hardest to easiest order : {0}".format(np.argsort(np.argsort(info_inds1))))
+        plt.savefig(results_save_path + '/z_progressive_bitwise_{0}.pdf'.format(i))
+
+        plt.close()
+
+    sys.exit()
+
+# if args.scatter_plot:
+#     msgs = torch.ones((4,args.K)).float()
+#     msgs[1][-1] = -1.
+#     msgs[2][-2] = -1.
+#     msgs[3][-1] = -1.
+#     msgs[3][-2] = -1.
+
+#     polar_code = polarTarget.encode_plotkin(msgs)
+
+times = []
+results_load_path = final_save_path
+# print(results_load_path)
+if args.model_iters is not None:
+    checkpoint1 = torch.load(results_save_path + '/Models/model_{0}.pt'.format(args.model_iters), map_location=lambda storage, loc: storage)
+elif args.test_load_path is not None:
+    checkpoint1 = torch.load(args.test_load_path, map_location=lambda storage, loc: storage)
+else:
+    checkpoint1 = torch.load(results_load_path + '/Models/model_final.pt', map_location=lambda storage, loc: storage)
+try:
+    args.model_iters = i_step + 1
+except:
+    pass
+
+# print(checkpoint1)
+loaded_step = checkpoint1['step']
+xformer.load_state_dict(checkpoint1['xformer'])
+xformer.to(device)
+print("Model loaded at step {}".format(loaded_step))
+
+xformer.eval()
+
+if args.snr_points == 1 and args.test_snr_start == args.test_snr_end:
+    snr_range = [args.test_snr_start]
+else:
+    snrs_interval = (args.test_snr_end - args.test_snr_start) * 1.0 / (args.snr_points - 1)
+    snr_range = [snrs_interval * item + args.test_snr_start for item in range(args.snr_points)]
+
+Test_msg_bits = 2 * (torch.rand(args.test_size, args.K) < 0.5).float() - 1
+Test_Data_Mask = torch.ones(Test_msg_bits.size(), device=device).long()
+
+Test_Data_Generator = torch.utils.data.DataLoader(Test_msg_bits, batch_size=args.test_batch_size, shuffle=False, **kwargs)
+Test_Data_Mask = torch.utils.data.DataLoader(Test_Data_Mask, batch_size=args.test_batch_size, shuffle=False, **kwargs)
+num_test_batches = len(Test_Data_Generator)
+
+
+######
+### MAP decoding stuff
+######
+
+if args.are_we_doing_ML and args.code == 'polar':
+    all_msg_bits = []
+    for i in range(2**args.K):
+        d = dec2bitarray(i, args.K)
+        all_msg_bits.append(d)
+    all_message_bits = torch.from_numpy(np.array(all_msg_bits))
+    all_message_bits = 1 - 2 * all_message_bits.float()
+    codebook = polar.encode_plotkin(all_message_bits)
+    b_codebook = codebook.repeat(args.test_batch_size, 1, 1).to(device)
+if args.are_we_doing_ML and args.code == 'pac':
+    all_msg_bits = []
+    for i in range(2**args.K):
+        d = dec2bitarray(i, args.K)
+        all_msg_bits.append(d)
+    all_message_bits = torch.from_numpy(np.array(all_msg_bits)).to(device)
+    all_message_bits = 1 - 2 * all_message_bits.float()
+    codebook = polar.pac_encode(all_message_bits, scheme=args.rate_profile)
+    b_codebook = codebook.repeat(args.test_batch_size, 1, 1)
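+
+# Illustrative sketch only (defined for reference, never called): with the exhaustive
+# codebook built above when args.are_we_doing_ML is set, ML decoding amounts to a
+# nearest-codeword search. The argument names here are placeholders, not variables
+# from this script; the actual ML path runs inside testXformer / test_full_data / test_standard.
+def _ml_decode_sketch(noisy, codebook, all_message_bits):
+    # noisy: (batch, N) received values; codebook: (2**K, N) BPSK codewords in {-1, +1}
+    dists = torch.cdist(noisy, codebook)   # Euclidean distance to every codeword
+    best = dists.argmin(dim=1)             # index of the closest codeword per received word
+    return all_message_bits[best]          # (batch, K) decoded message bits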
+
+start_time = time.time()
+
+if args.code == 'polar':
+    bers_Xformer_test, blers_Xformer_test, bers_SC_test, blers_SC_test, bers_SCL_test, blers_SCL_test, bers_ML_test, blers_ML_test, bers_bitwise_Xformer_test, bers_bitwise_MAP_test, blers_bitwise_MAP_test = testXformer(xformer, polar, snr_range, Test_Data_Generator, device, run_ML=args.are_we_doing_ML)
+    print("Test SNRs : ", snr_range)
+    print("BERs of Xformer: {0}".format(bers_Xformer_test))
+    print("BERs of SC decoding: {0}".format(bers_SC_test))
+    print("BERs of ML: {0}".format(bers_ML_test))
+    print("BLERs of ML: {0}".format(blers_ML_test))
+    print("BERs of bitML: {0}".format(bers_bitwise_MAP_test))
+    print("BLERs of bitML: {0}".format(blers_bitwise_MAP_test))
+    print("BLERs of Xformer: {0}".format(blers_Xformer_test))
+    print("Time taken = {} seconds".format(time.time() - start_time))
+
+    ## BER
+    plt.figure(figsize=(12, 8))
+    print(bers_bitwise_Xformer_test)
+    plt.semilogy(snr_range, bers_Xformer_test, label="Xformer decoder", marker='*', linewidth=1.5)
+    plt.semilogy(snr_range, bers_SC_test, label="SC decoder", marker='^', linewidth=1.5)
+    plt.semilogy(snr_range, bers_SCL_test, label="SCL decoder", marker='^', linewidth=1.5)
+
+    if args.are_we_doing_ML:
+        plt.semilogy(snr_range, bers_ML_test, label="ML decoder", marker='o', linewidth=1.5)
+        plt.semilogy(snr_range, bers_bitwise_MAP_test, label="Bitwise ML decoder", marker='o', linewidth=1.5)
+    # if args.run_fano:
+
+    ## BLER
+    plt.semilogy(snr_range, blers_Xformer_test, label="Xformer decoder (BLER)", marker='*', linewidth=1.5, linestyle='dashed')
+    plt.semilogy(snr_range, blers_SC_test, label="SC decoder (BLER)", marker='^', linewidth=1.5, linestyle='dashed')
+    plt.semilogy(snr_range, blers_SCL_test, label="SCL decoder (BLER)", marker='^', linewidth=1.5, linestyle='dashed')
+
+    if args.are_we_doing_ML:
+        plt.semilogy(snr_range, blers_ML_test, label="ML decoder", marker='o', linewidth=1.5, linestyle='dashed')
+        plt.semilogy(snr_range, blers_bitwise_MAP_test, label="Bitwise ML decoder", marker='o', linewidth=1.5, linestyle='dashed')
+    # if args.run_fano:
+
+    plt.grid()
+    plt.xlabel("SNR (dB)", fontsize=16)
+    plt.ylabel("Error Rate", fontsize=16)
+    if args.rate_profile == 'polar':
+        plt.title("Polar({1}, {2}): Xformer trained at Dec_SNR = {0} dB".format(args.dec_train_snr, args.K, args.N))
+    elif args.rate_profile == 'RM':
+        plt.title("RM({1}, {2}): Xformer trained at Dec_SNR = {0} dB".format(args.dec_train_snr, args.K, args.N))
+
+    plt.legend(prop={'size': 15})
+    if args.test_load_path is not None:
+        os.makedirs('Xformer_Polar_Results/figures', exist_ok=True)
+        fig_save_path = 'Xformer_Polar_Results/figures/new_plot.pdf'
+    else:
+        fig_save_path = results_load_path + "/step_{}.pdf".format(args.model_iters if args.model_iters is not None else '_final')
+    plt.savefig(fig_save_path)
+
+    plt.close()
+elif args.code == 'pac':
+    plot_fano = False
+    fano_path = './data/pac/fano/Scheme_{3}/N{0}_K{1}_g{2}.p'.format(args.N, args.K, args.g, args.rate_profile)
+    test_data_path = './data/pac/test/Scheme_{3}/test_N{0}_K{1}_g{2}.p'.format(args.N, args.K, args.g, args.rate_profile)
+    if os.path.exists(fano_path):
+        fanos = pickle.load(open(fano_path, 'rb'))
+        snr_range_fano = fanos[0]
+        bers_fano_test = fanos[1]
+        blers_fano_test = fanos[2]
+        run_fano = False
+        plot_fano = True
+    else:
+        snr_range_fano = snr_range
+        bers_fano_test = []
+        blers_fano_test = []
+    if not args.random_test:
+        try:
+            test_dict = torch.load(test_data_path)
+            random_test = False
+        except:
+            random_test = True
+    else:
+        random_test = True
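+    # Prefer the cached PAC test set (fixed messages and per-SNR received words); fall back to random test data if it is missing.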
+    if random_test:
+        print("Testing on random data")
+        bers_RNN_test, blers_RNN_test, bers_Dumer_test, blers_Dumer_test, bers_ML_test, blers_ML_test, bers_fano_temp, blers_fano_temp = test_full_data(xformer, polar, snr_range, Test_Data_Generator, run_fano=args.run_fano, run_dumer=args.run_dumer)
+    else:
+        print("Testing on the standard data")
+        msg_bits = test_dict['msg']
+        received = test_dict['rec']
+        snr_range = list(received.keys())
+        print(snr_range)
+        bers_RNN_test, blers_RNN_test, bers_Dumer_test, blers_Dumer_test, bers_ML_test, blers_ML_test, bers_fano_temp, blers_fano_temp, bers_bitwise_Xformer_test = test_standard(xformer, polar, msg_bits, received, run_fano=args.run_fano, run_dumer=args.run_dumer)
+
+    if not os.path.exists(fano_path):
+        bers_fano_test = bers_fano_temp
+        blers_fano_test = blers_fano_temp
+        snr_range_fano = snr_range
+
+    # if args.run_fano:
+    #     if not os.path.exists(fano_path):
+    #         os.makedirs('./data/pac/fano/Scheme_{}'.format(args.rate_profile), exist_ok=True)
+    #         print("Saving fano error rates at: {}".format(fano_path))
+    #         pickle.dump([snr_range, bers_fano_test, blers_fano_test], open(fano_path, 'wb'))
+    try:
+        print(bers_bitwise_Xformer_test)
+    except:
+        pass
+    print("Test SNRs : ", snr_range)
+    print("BERs of Xformer: {0}".format(bers_RNN_test))
+    print("BERs of SC decoding: {0}".format(bers_Dumer_test))
+    print("BERs of ML: {0}".format(bers_ML_test))
+    print("BERs of Fano: {0}".format(bers_fano_test))
+    print("BLERs of Xformer: {0}".format(blers_RNN_test))
+    print("Time taken = {} seconds".format(time.time() - start_time))
+
+    ## BER
+    plt.figure(figsize=(12, 8))
+    plt.semilogy(snr_range, bers_RNN_test, label="Xformer decoder", marker='*', linewidth=1.5)
+
+    if args.run_dumer:
+        plt.semilogy(snr_range, bers_Dumer_test, label="SC decoder", marker='^', linewidth=1.5)
+
+    if args.are_we_doing_ML:
+        plt.semilogy(snr_range, bers_ML_test, label="ML decoder", marker='o', linewidth=1.5)
+    if plot_fano:
+        plt.semilogy(snr_range_fano, bers_fano_test, label="Fano decoder", marker='P', linewidth=1.5)
+
+    ## BLER
+    plt.semilogy(snr_range, blers_RNN_test, label="Xformer decoder (BLER)", marker='*', linewidth=1.5, linestyle='dashed')
+    if args.run_dumer:
+        plt.semilogy(snr_range, blers_Dumer_test, label="SC decoder (BLER)", marker='^', linewidth=1.5, linestyle='dashed')
+
+    if args.are_we_doing_ML:
+        plt.semilogy(snr_range, blers_ML_test, label="ML decoder", marker='o', linewidth=1.5, linestyle='dashed')
+    if plot_fano:
+        plt.semilogy(snr_range_fano, blers_fano_test, label="Fano decoder", marker='P', linewidth=1.5, linestyle='dashed')
+
+    plt.grid()
+    plt.xlabel("SNR (dB)", fontsize=16)
+    plt.ylabel("Error Rate", fontsize=16)
+    plt.title("PAC({1}, {2}): Xformer trained at Dec_SNR = {0} dB".format(args.dec_train_snr, args.K, args.N))
+    plt.legend(prop={'size': 15})
+    if args.test_load_path is not None:
+        os.makedirs('Xformer_PAC_Results/figures', exist_ok=True)
+        fig_save_path = 'Xformer_PAC_Results/figures/new_plot.pdf'
+    else:
+        fig_save_path = results_load_path + "/step_{}.pdf".format(args.model_iters if args.model_iters is not None else '_final')
+    plt.savefig(fig_save_path)
+
+    plt.close()
+
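+# The commented-out lists below apparently record ML and bitwise-ML baselines from an
+# earlier run (one value per test SNR), kept here for reference.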
+# BERs of ML: [0.28868999987840654, 0.21227812469005586, 0.13184125021100043, 0.06560687508434059, 0.02297749994322657, 0.005485000021290035, 0.0006450000058976001]
+# BLERs of ML: [0.6675300000000001, 0.4999700000000001, 0.3155800000000001, 0.15975999999999999, 0.056729999999999996, 0.013350000000000008, 0.0017400000000000022]
+# BERs of bitML: [0.2826343762874604, 0.20813374981284144, 0.12979624971747403, 0.06462874997407196, 0.022758750058710565, 0.005464999999385329, 0.0006493750043591714]
+# BLERs of bitML: [0.6889200000000001, 0.51716, 0.3261600000000001, 0.16444, 0.05816999999999999, 0.013600000000000004, 0.0017700000000000023]
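+# For reference, the error metrics used throughout (errors_ber, errors_bler, errors_bitwise_ber)
+# come from utils.py, which is added in this patch. A minimal sketch of the quantities they are
+# expected to compute, assuming (batch, K) tensors with +/-1 bit encoding:
+#   errors_ber(msg, dec)  ~ (msg.sign() != dec.sign()).float().mean()
+#   errors_bler(msg, dec) ~ ((msg.sign() != dec.sign()).sum(dim=1) > 0).float().mean()
+# i.e. BER is the fraction of wrongly decoded bits, BLER the fraction of blocks with at least
+# one wrong bit.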