# -*- coding: utf-8 -*-
"""PhonemeClassifier.ipynb
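Exported directly from a Jupyter notebook. Do not run it as a script: it expects Google Colab (for the Drive mount) and a CUDA GPU, and its cells were meant to be run interactively.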

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1tbUasNvb-cGE52bP0ETewAO6IKUijw8F
"""

import matplotlib.pyplot as plt
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from torch.autograd import Variable

import time
from os import path
import pickle
import gc

from google.colab import drive
drive.mount('/content/gdrive')

# The notebook requires a GPU: later cells call .cuda() unconditionally, so the
# CPU fallback in `device` can never be taken.
assert torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pickled (signals, labels) pairs for the train and validation splits.
# pickle.load can execute arbitrary code - only load files you produced yourself.
TIMIT_root = '/content/gdrive/My Drive/data/TIMIT'
file_name = "500.pkl"
train_path = path.join(TIMIT_root, "train", file_name)
val_path = path.join(TIMIT_root, "val", file_name)
with open(train_path, 'rb') as infile:
    X_train, y_train = pickle.load(infile)
with open(val_path, 'rb') as infile:
    X_test, y_test = pickle.load(infile)

class TIMITRawDataset(Dataset):
    """Map-style dataset over parallel sequences of raw signals and labels.

    Item ``idx`` is the pair ``(signal[idx], y[idx])``; nothing is converted
    or copied here - batching and padding are left to the collate function.
    """

    def __init__(self, signal, y):
        super(TIMITRawDataset, self).__init__()
        self.data, self.label = signal, y

    def __len__(self):
        # The dataset size is taken from the signals; labels are expected to be parallel.
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.label[idx]
        return sample, target

def custom_collate(batch):
    """Collate ``(signal, label)`` pairs into a padded batch sorted by length.

    Items are ordered longest first (the order ``pack_padded_sequence`` expects
    by default) using ONE permutation for signals, lengths and labels, so each
    label stays with its own signal even when several signals share a length.

    Returns:
        signal: float32 tensor, ``(batch, 1, max_len)`` for 1-D signals,
            right-padded with zeros.
        lengths: int64 tensor of the unpadded lengths, in descending order.
        labels: int64 tensor aligned with the rows of ``signal``.
    """
    # sorted() is stable, so equal-length items keep their relative order.
    ordered = sorted(batch, key=lambda item: len(item[0]), reverse=True)
    signal = [torch.tensor(item[0]) for item in ordered]
    lengths = torch.tensor([len(item[0]) for item in ordered])
    labels = torch.tensor([item[1] for item in ordered], dtype=torch.long)
    signal = pad_sequence(signal, batch_first=True).unsqueeze(1).float()
    return signal, lengths, labels

class PhonemeClassifier(nn.Module):
    """Raw-waveform phoneme classifier: 1-D CNN front end + LSTM + MLP head.

    Input is a zero-padded batch of waveforms shaped (batch, 1, samples), as
    produced by ``custom_collate``, plus the unpadded length of each waveform.
    Three conv stages downsample the waveform into a frame sequence, a 3-layer
    unidirectional LSTM reads the valid frames of each item, and the final
    hidden states of all three layers are concatenated (3 * 256) and mapped
    to 61 class logits.
    """

    def __init__(self):
        super(PhonemeClassifier, self).__init__()
        self.conv1 = nn.Conv1d(1, 64, 30, stride=10)
        self.conv1_1 = nn.Conv1d(64, 64, 15, stride=1, padding=7)
        self.conv2 = nn.Conv1d(64, 128, 15, stride=2)
        self.conv2_1 = nn.Conv1d(128, 128, 5, stride=1, padding=2)
        self.conv3 = nn.Conv1d(128, 256, 5, stride=2)
        self.conv3_1 = nn.Conv1d(256, 256, 3, stride=1, padding=1)
        self.lstm = nn.LSTM(256, 256, 3, bidirectional=False, dropout=0.1)
        self.linear = nn.Linear(256*3, 256)
        self.linear2 = nn.Linear(256, 61)

    def forward(self, x, l):
        """Return (batch, 61) logits for padded waveforms ``x`` with lengths ``l``.

        ``l`` holds the unpadded length of each row of ``x`` in samples. It
        does not need to be sorted and may live on any device.
        """
        # training=self.training: F.dropout defaults to training=True, so
        # without it these dropouts stayed active after model.eval().
        x = F.dropout(F.relu(self.conv1(x)), p=0.1, training=self.training)
        x = F.relu(self.conv1_1(x))
        x = F.dropout(F.relu(self.conv2(x)), p=0.1, training=self.training)
        x = F.relu(self.conv2_1(x))
        x = F.dropout(F.relu(self.conv3(x)), p=0.1, training=self.training)
        x = F.relu(self.conv3_1(x))
        # Map sample lengths to frame lengths via the downsampling ratio of the
        # longest item; every item keeps at least one (and at most all) frames.
        factor = l.max().item() / x.size(2)
        new_lengths = (l.float() / factor).long().clamp(1, x.size(2))
        # pack_padded_sequence requires CPU lengths; enforce_sorted=False also
        # accepts unsorted batches and returns h_n in the caller's batch order.
        x = pack_padded_sequence(x.permute(2, 0, 1), new_lengths.cpu(),
                                 enforce_sorted=False)
        _, (h_n, _) = self.lstm(x)
        # h_n: (num_layers=3, batch, 256) -> (batch, 1, 3*256)
        h_n = h_n.permute(1, 0, 2).reshape(-1, 1, 3*256)
        y = F.relu(self.linear(h_n))
        return self.linear2(y).squeeze(1)

# Smoke test: run one 3-item training batch through a freshly initialised model
# (on the CPU here - neither the model nor the batch is moved to CUDA) and
# display the shape of the output logits (3 rows x 61 classes).
dataset = TIMITRawDataset(X_train, y_train)
dataloader = DataLoader(dataset,
                        batch_size=3,
                        shuffle=True,
                        collate_fn=custom_collate)
x, l, y = next(iter(dataloader))

model = PhonemeClassifier()
y_hat = model(x, l)
y_hat.shape

class NoamOpt:
    """Optimizer wrapper that applies the Noam learning-rate schedule.

    Before every update the learning rate of all parameter groups is set to

        lr(step) = factor * model_size**-0.5 * min(step**-0.5, step * warmup**-1.5)

    i.e. a linear warm-up for ``warmup`` steps followed by inverse-square-root
    decay. ``_step``, ``warmup``, ``factor``, ``model_size`` and ``_rate`` are
    plain attributes so they can be checkpointed and restored directly.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        """Advance the schedule one step, apply the new rate, then update parameters."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the scheduled learning rate at ``step`` (default: current step).

        Steps <= 0 return 0.0 - the limit of the warm-up ramp - instead of
        raising ZeroDivisionError (step 0) or producing a complex number
        (negative step) from ``step ** -0.5``.
        """
        if step is None:
            step = self._step
        if step <= 0:
            return 0.0
        return self.factor * \
            (self.model_size ** (-0.5) *
             min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def zero_grad(self):
        """Clear the gradients tracked by the wrapped optimizer."""
        self.optimizer.zero_grad()

class LabelSmoothing(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    For logits ``x`` of shape (batch, size) and integer class targets, the
    target distribution puts ``1 - smoothing`` on the true class and spreads
    ``smoothing`` evenly over the other ``size - 1`` classes. The loss is
    ``KLDivLoss(reduction="batchmean")`` between that distribution and
    ``log_softmax(x)``. The most recent target distribution is kept in
    ``self.true_dist`` for inspection.
    """

    def __init__(self, size, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="batchmean")
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        """Return the smoothed loss for logits ``x`` and class indices ``target``.

        Raises:
            ValueError: if ``x.size(1)`` differs from the configured ``size``.
        """
        if x.size(1) != self.size:
            raise ValueError("expected {} classes, got logits with {}"
                             .format(self.size, x.size(1)))
        # The target is a constant: build it outside autograd rather than via
        # the deprecated Variable / .data API.
        with torch.no_grad():
            true_dist = torch.full_like(x, self.smoothing / (self.size - 1))
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        self.true_dist = true_dist
        return self.criterion(F.log_softmax(x, dim=-1), true_dist)

# Sanity check of LabelSmoothing on random logits: 3 samples, 5 classes,
# targets 2, 1 and 4. The cell displays the resulting scalar loss.
criterion = LabelSmoothing(5, smoothing=0.1)
x = torch.rand(3, 5)
y = torch.LongTensor([2, 1, 4])
criterion(x, y)

# Plot the Noam learning-rate schedule (model_size=256, factor=0.8,
# warmup=2000) over the first 20k steps. rate() never touches the wrapped
# optimizer, so None is passed. Note that training below uses factor=0.5.
opt = NoamOpt(256, 0.8, 2000, None)
plt.plot(np.arange(1, 20000), [opt.rate(i) for i in range(1, 20000)])

def save_model(model, noam, name, root_path='/content/gdrive/My Drive/models/'):
    """Save model weights and NoamOpt state to ``root_path/name``.

    Args:
        model: module whose ``state_dict()`` is stored under ``'model'``.
        noam: NoamOpt-like wrapper; its inner optimizer state and its schedule
            fields (``_step``, ``warmup``, ``factor``, ``model_size``,
            ``_rate``) are stored so ``load_model`` can resume the schedule.
        name: checkpoint file name.
        root_path: directory to write into; created if it does not exist.
    """
    import os  # local import: the file only imports os.path (as `path`)

    os.makedirs(root_path, exist_ok=True)
    state = {
        'model': model.state_dict(),
        'optimizer': noam.optimizer.state_dict(),
        '_step': noam._step,
        'warmup': noam.warmup,
        'factor': noam.factor,
        'model_size': noam.model_size,
        '_rate': noam._rate,
    }
    torch.save(state, path.join(root_path, name))

def load_model(name, root_path='/content/gdrive/My Drive/models/'):
    """Rebuild a CUDA PhonemeClassifier and its NoamOpt from a checkpoint.

    Reads ``root_path/name`` as written by ``save_model`` (same default
    directory) and restores the model weights, the Adam state and the Noam
    schedule fields, so training can resume where it stopped.

    Returns:
        ``(model, noam)`` tuple.
    """
    model = PhonemeClassifier()
    # Move to the GPU before building the optimizer, as the nn.Module.cuda
    # docs recommend, so Adam is guaranteed to hold the CUDA parameters.
    model.cuda()
    noam = NoamOpt(256, 0.5, 2000,
                   optim.Adam(model.parameters(), lr=0))
    # Load onto the CPU first; load_state_dict copies the weights into the CUDA
    # parameters, and Optimizer.load_state_dict moves its state to their device.
    checkpoint = torch.load(path.join(root_path, name), map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    noam.optimizer.load_state_dict(checkpoint['optimizer'])
    # The constructor arguments above are placeholders; the saved schedule wins.
    for key in ('_step', 'warmup', 'factor', 'model_size', '_rate'):
        setattr(noam, key, checkpoint[key])
    return model, noam

gc.collect()
batch_size = 128
epochs = 20
print_every = 100
save = True
load = False
checkpoint_name = "lstm_model_padding.pt"
if load:
    # load_model requires the checkpoint name; calling it without one raised TypeError.
    model, optimizer = load_model(checkpoint_name)
else:
    model = PhonemeClassifier()
    # Move to the GPU before building the optimizer (nn.Module.cuda docs).
    model.cuda()
    optimizer = NoamOpt(256, 0.5, 2000,
                        optim.Adam(model.parameters(),
                                   lr=0))

criterion = LabelSmoothing(61, smoothing=0.1)
criterion.cuda()
dataset = TIMITRawDataset(X_train, y_train)
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        collate_fn=custom_collate)
val_dataset = TIMITRawDataset(X_test, y_test)
# Evaluation order does not affect accuracy, so keep validation batches deterministic.
val_loader = DataLoader(val_dataset,
                        batch_size=128,
                        shuffle=False,
                        collate_fn=custom_collate)
ave_loss = []  # training losses since the last progress report
T = len(dataloader)
total_elapsed = 0
begin = time.time()
for e in range(epochs):
    start = time.time()
    model.train()
    for i, (x, l, y) in enumerate(dataloader):
        x, l, y = x.cuda(), l.cuda(), y.cuda()
        y_hat = model(x, l)
        loss = criterion(y_hat, y)
        ave_loss.append(loss.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if i % print_every == 0:
            elapsed = time.time() - start
            print("time: {:.2f} sec, progress: {}/{}, loss: {:.4f}"
                  .format(elapsed, i, T, np.mean(ave_loss)))
            ave_loss = []

    # Top-1 validation accuracy. no_grad skips building autograd graphs;
    # argmax gives the same top-1 prediction as softmax + topk(1) + squeeze.
    total_correct = 0
    total_n = 0
    model.eval()
    with torch.no_grad():
        for x, l, y in val_loader:
            x, l, y = x.cuda(), l.cuda(), y.cuda()
            pred = model(x, l).argmax(dim=-1)
            total_correct += (pred == y).sum().item()
            total_n += y.size(0)
    accuracy = total_correct / total_n
    total_elapsed = time.time() - begin
    # Remaining time extrapolated from the mean duration of the epochs so far.
    remaining = (total_elapsed / (e+1)) * (epochs - e - 1)
    print("epoch: {}, accuracy: {:.2f}%, elapsed: {:.2f} min, remaining: {:.2f} min"
          .format(e+1, accuracy*100, total_elapsed/60, remaining/60))
    print()
    if save:
        # Overwrites the same checkpoint file after every epoch.
        save_model(model, optimizer, checkpoint_name)

model, _ = load_model("lstm_model_padding.pt")

# Top-k validation accuracy: a sample counts as correct when its label is
# among the model's `topm` highest-scoring classes. Softmax is monotonic, so
# ranking the raw logits gives the same top-k.
topm = 10
total_correct = 0
total_n = 0
model.eval()
val_dataset = TIMITRawDataset(X_test, y_test)
val_loader = DataLoader(val_dataset,
                        batch_size=128,
                        shuffle=True,
                        collate_fn=custom_collate)
with torch.no_grad():
    for x, l, y in val_loader:
        x, l, y = x.cuda(), l.cuda(), y.cuda()
        y_hat = model(x, l)
        # Keep pred as (batch, topm): the old .squeeze() collapsed a final
        # single-item batch to (topm,), making pred[:, i] raise IndexError.
        pred = y_hat.topk(topm, dim=-1)[1]
        total_correct += (pred == y.unsqueeze(1)).any(dim=1).sum().item()
        total_n += y.size(0)
accuracy = total_correct / total_n
accuracy

# The 61 TIMIT phone symbols in index order. Presumably label k in the pickled
# data corresponds to phones[k] - TODO confirm against the preprocessing that
# produced the .pkl files. Defined first so the lookups below do not depend on
# running cells out of order.
phones = np.array([
    'aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay', 'b', 'bcl', 'ch', 'd', 'dcl', 'dh', 'dx', 'eh', 'el', 'em',
    'en', 'eng', 'epi', 'er', 'ey', 'f', 'g', 'gcl', 'h#', 'hh', 'hv', 'ih', 'ix', 'iy', 'jh', 'k', 'kcl', 'l', 'm',
    'n', 'ng', 'nx', 'ow', 'oy', 'p', 'pau', 'pcl', 'q', 'r', 's', 'sh', 't', 'tcl', 'th', 'uh', 'uw', 'ux', 'v', 'w',
    'y', 'z', 'zh'
])
phone_to_idx = {phone: i for i, phone in enumerate(phones)}

# --- Qualitative probes on single validation segments ---
model.eval()
val_dataset = TIMITRawDataset(X_test, y_test)
val_loader = DataLoader(val_dataset,
                        batch_size=1,
                        shuffle=True,
                        collate_fn=custom_collate)

# One random segment (x, l) and its first half (x_, l_).
x, l, y = next(iter(val_loader))
x, l = x.cuda(), l.cuda()
x_, l_ = x[:, :, :int(l[0] / 2)], l / 2

# Four independently drawn random segments joined end to end into one
# waveform; its length is that of the concatenation, not of `x` above.
L = [next(iter(val_loader)) for _ in range(4)]
_x = torch.cat([L[i][0] for i in range(4)], dim=2)
_l = torch.tensor([_x.size(2)])
_x, _l = _x.cuda(), _l.cuda()

with torch.no_grad():
    y_hat = model(x, l)
    _y_hat = model(_x, _l)
    y__hat = model(x_, l_)
y_hat.topk(3), _y_hat.topk(3), y__hat.topk(3)

y

phones[27]  # 'h#'

# --- Per-phoneme confusion counts over the validation set ---
val_dataset = TIMITRawDataset(X_test, y_test)
val_loader = DataLoader(val_dataset,
                        batch_size=128,
                        shuffle=True,
                        collate_fn=custom_collate)
total_occur = np.zeros(61)       # validation samples per true phoneme
total_correct = np.zeros(61)     # correct top-1 predictions per true phoneme
miss_occur = np.zeros((61, 61))  # miss_occur[true, predicted] for wrong predictions
with torch.no_grad():
    for x, l, y in val_loader:
        x, l = x.cuda(), l.cuda()
        # argmax keeps a 1-D result even for a single-item final batch, where
        # topk(1)[1].squeeze() produced a 0-d tensor that p[i] could not index.
        p = model(x, l).argmax(dim=-1).cpu().numpy()
        t = y.numpy()
        hit = p == t
        # np.add.at accumulates repeated indices (plain fancy-index += would not).
        np.add.at(total_occur, t, 1)
        np.add.at(total_correct, t[hit], 1)
        np.add.at(miss_occur, (t[~hit], p[~hit]), 1)

# Per-phoneme report. Phonemes that never occur (or are never misclassified)
# get 0 instead of NaN via np.divide(..., where=...).
rate = np.divide(total_correct, total_occur,
                 out=np.zeros_like(total_correct), where=total_occur > 0)
p_occur = total_occur / total_occur.sum()
# miss_rate[i, j]: fraction of all samples of phoneme i predicted as j (j != i).
occur_col = total_occur[:, np.newaxis]
miss_rate = np.divide(miss_occur, occur_col,
                      out=np.zeros_like(miss_occur), where=occur_col > 0)
miss_rate_idx = np.argsort(miss_rate, axis=1)
for i, phone in enumerate(phones):
    # Three most frequent wrong predictions for this phoneme (argsort is ascending).
    miss1, miss2, miss3 = miss_rate_idx[i, -1], miss_rate_idx[i, -2], miss_rate_idx[i, -3]
    print("Phoneme: {}, rate: {:.2f}%, occurrence: {:.2f}%"
          .format(phone,
                  rate[i]*100,
                  p_occur[i]*100))
    print("\tMisclassification - {}: {:.2f}%, {}: {:.2f}%, {}: {:.2f}%"
          .format(phones[miss1], miss_rate[i, miss1]*100,
                  phones[miss2], miss_rate[i, miss2]*100,
                  phones[miss3], miss_rate[i, miss3]*100))
    print()

# Conditional error distribution: row i is normalised by the number of
# misclassified samples of phoneme i, so each non-empty row sums to 1.
miss_counts = miss_occur.sum(axis=1, keepdims=True)
miss_rate = np.divide(miss_occur, miss_counts,
                      out=np.zeros_like(miss_occur), where=miss_counts > 0)

np.argsort(miss_rate)

# Row-normalised errors (the former `miss_occur / miss_occur.sum(axis=1)`
# broadcast the row sums across columns, dividing along the wrong axis).
miss_rate

