# -*- coding: utf-8 -*-
"""PhonemeRecognizer.ipynb
This file was exported directly from a Jupyter notebook and is not meant to be run as a script.

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1U_Ws9T7QrVxaipStWXF94jw3gV7bPJhO
"""

import matplotlib.pyplot as plt
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from torch.autograd import Variable

import time
from os import path
import pickle
import gc

# Mount Google Drive so the pickled TIMIT data under "My Drive" can be read
# (only works inside Google Colab).
from google.colab import drive
drive.mount('/content/gdrive')

# The code below calls .cuda() unconditionally, so stop early when no GPU exists.
# NOTE(review): because of this assert the "cpu" fallback on the next line can
# never be selected, and `device` is not referenced anywhere else in this file.
assert torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Locations of the pickled training and validation splits.
TIMIT_root = '/content/gdrive/My Drive/data/TIMIT'
file_name = "0.pkl"
train_path = path.join(TIMIT_root, "train", file_name)
val_path = path.join(TIMIT_root, "val", file_name)
# Each split pickle unpacks into a pair that is bound to (X, y).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(train_path, 'rb') as infile:
    X_train, y_train = pickle.load(infile)
with open(val_path, 'rb') as infile:
    X_test, y_test = pickle.load(infile)

# The audio pickle also unpacks into a pair; only its first element is kept.
audio_path = path.join(TIMIT_root, "train/audios.pkl")
with open(audio_path, 'rb') as infile:
    audios, _ = pickle.load(infile)

# Quick visual check: plot the first entry of `audios`.
plt.plot(audios[0])

class TIMITRawDataset(Dataset):
    """Dataset that pairs the idx-th signal with the idx-th label."""

    def __init__(self, signal, y):
        super().__init__()
        # Both containers are stored as given and indexed in parallel.
        self.data, self.label = signal, y

    def __len__(self):
        """Return the number of stored signals."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(signal, label)`` pair stored at position ``idx``."""
        sample = self.data[idx]
        target = self.label[idx]
        return sample, target

def custom_collate(batch):
    """Collate ``(signal, label)`` pairs into a padded, length-sorted batch.

    Args:
        batch: Sequence of ``(signal, label)`` pairs, where ``signal`` is a
            1-D sequence of samples and ``label`` is an integer class index.

    Returns:
        A tuple ``(signal, lengths, y)``:
            signal: float tensor of shape ``(batch, 1, max_length)``,
                zero-padded on the right.
            lengths: tensor with the unpadded length of every signal, in
                decreasing order.
            y: long tensor with the labels, in the same order as ``signal``.
    """
    # Sort the pairs themselves exactly once.  Sorting signals and labels
    # separately breaks ties differently (labels by value, signals by original
    # position) and pairs equal-length signals with the wrong labels.
    ordered = sorted(batch, key=lambda item: len(item[0]), reverse=True)
    signal = [torch.tensor(item[0]) for item in ordered]
    lengths = [len(item[0]) for item in ordered]
    y = [item[1] for item in ordered]
    signal = pad_sequence(signal, batch_first=True).unsqueeze(1).type(dtype=torch.float)
    return signal, torch.tensor(lengths), torch.tensor(y).type(dtype=torch.LongTensor)

# Length of every training example; generate_segments() draws segment lengths from this list.
train_lengths = [len(x) for x in X_train]
def generate_segments(batch_size, audio_pool=None, length_pool=None):
    """Cut ``batch_size`` randomly placed segments out of random recordings.

    Each segment length is a length drawn from ``length_pool`` plus Gaussian
    noise (standard deviation 100), clamped to at least one sample.

    Args:
        batch_size: Number of segments to generate.
        audio_pool: Sequence of 1-D waveforms to cut from.  Defaults to the
            module-level ``audios``.
        length_pool: Sequence of candidate segment lengths.  Defaults to the
            module-level ``train_lengths``.

    Returns:
        A tuple ``(segments, lengths)``: a float tensor of shape
        ``(batch_size, 1, max_length)`` zero-padded on the right, and a tensor
        with the unpadded lengths in decreasing order.
    """
    if audio_pool is None:
        audio_pool = audios
    if length_pool is None:
        length_pool = train_lengths

    segments = []
    for _ in range(batch_size):
        # Pick by index: np.random.choice needs a 1-D array and cannot sample
        # from a list of (possibly differently sized) waveforms.
        audio = audio_pool[np.random.randint(len(audio_pool))]
        # Keep the start at least 1000 samples from the end when possible;
        # recordings shorter than that always start at 0 instead of raising.
        start = np.random.randint(0, max(1, len(audio) - 1000))
        length = int(np.random.choice(length_pool)) + int(np.random.normal(scale=100))
        # A non-positive length would yield an empty slice or, worse, a
        # negative end index that wraps around to the end of the recording.
        length = max(1, length)
        segments.append(torch.tensor(audio[start:start + length]))

    segments.sort(key=len, reverse=True)
    lengths = [len(segment) for segment in segments]
    padded = pad_sequence(segments, batch_first=True).unsqueeze(1).type(dtype=torch.float)
    return padded, torch.tensor(lengths)

# Notebook cell: draw one batch of 128 random segments and display their lengths.
x, l = generate_segments(128)
l

class PhonemeRecognizer(nn.Module):
    """Assign one scalar score to each raw-audio segment.

    Three strided Conv1d layers, each followed by a length-preserving Conv1d,
    shorten the time axis.  The result is packed and fed to a 3-layer LSTM
    whose final hidden states (one per layer) are concatenated and mapped to
    a single output value.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 64, 30, stride=10)
        self.conv1_1 = nn.Conv1d(64, 64, 15, stride=1, padding=7)
        self.conv2 = nn.Conv1d(64, 128, 15, stride=2)
        self.conv2_1 = nn.Conv1d(128, 128, 5, stride=1, padding=2)
        self.conv3 = nn.Conv1d(128, 256, 5, stride=2)
        self.conv3_1 = nn.Conv1d(256, 256, 3, stride=1, padding=1)
        self.lstm = nn.LSTM(256, 256, 3, bidirectional=False, dropout=0.1)
        self.linear = nn.Linear(256*3, 256)
        self.linear2 = nn.Linear(256, 1)

    def forward(self, x, l):
        """Compute the scores of a padded batch.

        Args:
            x: Float tensor of shape ``(batch, 1, time)`` with padded signals.
            l: 1-D tensor with the unpadded length of every signal, in
                decreasing order (``pack_padded_sequence`` requires sorted
                lengths by default).

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        # F.dropout defaults to training=True; pass the module's flag so that
        # model.eval() really disables dropout.
        x = F.dropout(F.relu(self.conv1(x)), p=0.1, training=self.training)
        x = F.relu(self.conv1_1(x))
        x = F.dropout(F.relu(self.conv2(x)), p=0.1, training=self.training)
        x = F.relu(self.conv2_1(x))
        x = F.dropout(F.relu(self.conv3(x)), p=0.1, training=self.training)
        x = F.relu(self.conv3_1(x))

        # Rescale the sample-domain lengths to the shortened time axis.
        factor = l.max().item() / x.size(2)
        new_lengths = l.type(dtype=torch.float) / factor
        new_lengths = new_lengths.type(dtype=torch.int)
        # Every sequence must keep at least one time step to be packable.
        new_lengths[new_lengths == 0] = 1

        # pack_padded_sequence expects time-major input and CPU lengths, even
        # when the data itself lives on the GPU.
        x = pack_padded_sequence(x.permute(2, 0, 1), new_lengths.cpu())
        _, (h_n, _) = self.lstm(x)
        # (layers, batch, hidden) -> (batch, 1, layers * hidden)
        h_n = h_n.permute(1, 0, 2).reshape(-1, 1, 3*256)
        y = F.relu(self.linear(h_n))
        return self.linear2(y).squeeze(1)
    
class PhonemeClassifier(nn.Module):
    """Map each raw-audio segment to 61 class scores.

    Three strided Conv1d layers, each followed by a length-preserving Conv1d,
    shorten the time axis.  The result is packed and fed to a 3-layer LSTM
    whose final hidden states (one per layer) are concatenated and mapped to
    61 output values.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 64, 30, stride=10)
        self.conv1_1 = nn.Conv1d(64, 64, 15, stride=1, padding=7)
        self.conv2 = nn.Conv1d(64, 128, 15, stride=2)
        self.conv2_1 = nn.Conv1d(128, 128, 5, stride=1, padding=2)
        self.conv3 = nn.Conv1d(128, 256, 5, stride=2)
        self.conv3_1 = nn.Conv1d(256, 256, 3, stride=1, padding=1)
        self.lstm = nn.LSTM(256, 256, 3, bidirectional=False, dropout=0.1)
        self.linear = nn.Linear(256*3, 256)
        self.linear2 = nn.Linear(256, 61)

    def forward(self, x, l):
        """Compute the class scores of a padded batch.

        Args:
            x: Float tensor of shape ``(batch, 1, time)`` with padded signals.
            l: 1-D tensor with the unpadded length of every signal, in
                decreasing order (``pack_padded_sequence`` requires sorted
                lengths by default).

        Returns:
            Tensor of shape ``(batch, 61)`` with unnormalized scores.
        """
        # F.dropout defaults to training=True; pass the module's flag so that
        # model.eval() really disables dropout.
        x = F.dropout(F.relu(self.conv1(x)), p=0.1, training=self.training)
        x = F.relu(self.conv1_1(x))
        x = F.dropout(F.relu(self.conv2(x)), p=0.1, training=self.training)
        x = F.relu(self.conv2_1(x))
        x = F.dropout(F.relu(self.conv3(x)), p=0.1, training=self.training)
        x = F.relu(self.conv3_1(x))

        # Rescale the sample-domain lengths to the shortened time axis.
        factor = l.max().item() / x.size(2)
        new_lengths = l.type(dtype=torch.float) / factor
        new_lengths = new_lengths.type(dtype=torch.int)
        # Every sequence must keep at least one time step to be packable.
        new_lengths[new_lengths == 0] = 1

        # pack_padded_sequence expects time-major input and CPU lengths, even
        # when the data itself lives on the GPU.
        x = pack_padded_sequence(x.permute(2, 0, 1), new_lengths.cpu())
        _, (h_n, _) = self.lstm(x)
        # (layers, batch, hidden) -> (batch, 1, layers * hidden)
        h_n = h_n.permute(1, 0, 2).reshape(-1, 1, 3*256)
        y = F.relu(self.linear(h_n))
        return self.linear2(y).squeeze(1)

# Notebook cell: smoke-test an untrained recognizer on the random batch drawn earlier.
model = PhonemeRecognizer()
model.cuda()
# The weights now live on the GPU, so the inputs have to be moved there too;
# .cuda() is a no-op for tensors that are already on the GPU.
scores = model(x.cuda(), l.cuda())
scores.mean()

class NoamOpt:
    """Optimizer wrapper that sets the learning rate before every update.

    The rate at a given step is::

        factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

    i.e. it grows linearly for ``warmup`` steps and then decays with the
    inverse square root of the step number.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        """Advance the step counter, apply the new rate and update the parameters."""
        self._step += 1
        rate = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the learning rate for ``step`` (default: the current step).

        Steps that are not positive yield ``0.0``, the value the schedule
        takes at its start, instead of raising ``ZeroDivisionError`` from
        ``0 ** -0.5``.
        """
        if step is None:
            step = self._step
        if step <= 0:
            return 0.0
        return self.factor * \
            (self.model_size ** (-0.5) *
             min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def zero_grad(self):
        """Clear the gradients held by the wrapped optimizer."""
        self.optimizer.zero_grad()

class LabelSmoothing(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution."""

    def __init__(self, size, smoothing=0.0):
        super().__init__()
        self.size = size
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        self.criterion = nn.KLDivLoss(reduction="batchmean")
        # Target distribution built by the most recent forward() call.
        self.true_dist = None

    def forward(self, x, target):
        """Return the loss of logits ``x`` against the smoothed ``target``.

        ``x`` has one column per class (``size`` columns) and ``target`` holds
        one class index per row.  The target class gets probability
        ``confidence``; the smoothing mass is split evenly over the others.
        """
        assert x.size(1) == self.size
        off_target = self.smoothing / (self.size - 1)
        smoothed = torch.full_like(x.data, off_target)
        smoothed.scatter_(1, target.data.unsqueeze(1), self.confidence)
        self.true_dist = smoothed
        log_probs = F.log_softmax(x, dim=-1)
        return self.criterion(log_probs, smoothed)

def save_model(model, optimizer, name, root_path=F'/content/gdrive/My Drive/models/'):
    """Write the state dicts of ``model`` and ``optimizer`` to one checkpoint.

    The checkpoint is a dict with the keys ``'model'`` and ``'optimizer'`` and
    is stored as ``name`` inside ``root_path``.
    """
    checkpoint = dict(model=model.state_dict(), optimizer=optimizer.state_dict())
    destination = path.join(root_path, name)
    torch.save(checkpoint, destination)

def load_model(name, root_path='/content/gdrive/My Drive/models/'):
    """Restore a PhonemeRecognizer and its Adam optimizer from a checkpoint.

    Args:
        name: File name of the checkpoint inside ``root_path``.
        root_path: Directory holding the checkpoint.  Defaults to the same
            directory that ``save_model`` writes to by default.

    Returns:
        A tuple ``(model, optimizer)``; the model has been moved to the GPU.
        The checkpoint must contain the keys ``'model'`` and ``'optimizer'``.
    """
    model = PhonemeRecognizer()
    optimizer = optim.Adam(model.parameters(), lr=0.0002, betas=(0.5, 0.999))
    model.cuda()
    checkpoint = torch.load(path.join(root_path, name))
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer

def load_classifier(name, root_path='/content/gdrive/My Drive/models/'):
    """Restore a PhonemeClassifier and its NoamOpt wrapper from a checkpoint.

    Args:
        name: File name of the checkpoint inside ``root_path``.
        root_path: Directory holding the checkpoint.  Defaults to the same
            directory that ``save_model`` writes to by default.

    Returns:
        A tuple ``(model, noam)``; the model has been moved to the GPU.

    Besides ``'model'`` and ``'optimizer'`` the checkpoint must provide the
    schedule fields ``'_step'``, ``'warmup'``, ``'factor'``, ``'model_size'``
    and ``'_rate'``.
    NOTE(review): ``save_model`` in this file stores only ``'model'`` and
    ``'optimizer'``, so such checkpoints must come from a different saver.
    """
    model = PhonemeClassifier()
    # The constructor arguments are placeholders; the checkpoint overwrites them below.
    noam = NoamOpt(256, 0.5, 2000,
                   optim.Adam(model.parameters(),
                              lr=0))
    model.cuda()
    checkpoint = torch.load(path.join(root_path, name))
    model.load_state_dict(checkpoint['model'])
    noam.optimizer.load_state_dict(checkpoint['optimizer'])
    noam._step = checkpoint['_step']
    noam.warmup = checkpoint['warmup']
    noam.factor = checkpoint['factor']
    noam.model_size = checkpoint['model_size']
    noam._rate = checkpoint['_rate']
    return model, noam

# NOTE(review): in file order `optimizer` is only assigned in the training cell
# further down, so this call raises NameError when the file is run top to bottom.
# Presumably a notebook cell that was executed after training; confirm before reuse.
save_model(model, optimizer, "recognizer.pt")

# Train the recognizer to score real training segments high and randomly cut
# segments low.
gc.collect()
batch_size = 128
epochs = 5
print_every = 10
save = True
model = PhonemeRecognizer()
# Alternative optimizer, kept for reference: Adam wrapped in the NoamOpt schedule.
# optimizer = NoamOpt(256, 0.5, 2000,
#                 optim.Adam(model.parameters(),
#                            lr=0))
optimizer = optim.Adam(model.parameters(), lr=0.0002, betas=(0.5, 0.999))
model.cuda()
dataset = TIMITRawDataset(X_train, y_train)
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        collate_fn=custom_collate)
# Mean scores of the batches seen since the last progress line.
ave_real_score = []
ave_fake_score = []
T = len(dataloader)
total_elapsed = 0
begin = time.time()
for e in range(epochs):
    start = time.time()
    model.train()
    for i, (x, l, _) in enumerate(dataloader):
        # Scores of real segments; the labels of the batch are not used.
        x, l = x.cuda(), l.cuda()
        r_scores = model(x, l)
        r_clamp = torch.clamp(r_scores, -10, 10)
        r_residual = torch.abs(r_scores - r_clamp)

        # Scores of randomly cut ("fake") segments.
        x_hat, l_hat = generate_segments(batch_size)
        x_hat, l_hat = x_hat.cuda(), l_hat.cuda()
        f_scores = model(x_hat, l_hat)
        f_clamp = torch.clamp(f_scores, -10, 10)
        f_residual = torch.abs(f_scores - f_clamp)

        # Push fake scores down and real scores up; the residual terms
        # penalize any score that leaves the range [-10, 10].
        loss = f_clamp.mean() - r_clamp.mean() + \
               r_residual.mean() + f_residual.mean()

        ave_real_score.append(r_scores.mean().item())
        ave_fake_score.append(f_scores.mean().item())
        # The graph is rebuilt on every iteration and back-propagated exactly
        # once, so there is no reason to retain it.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if i % print_every == 0:
            elapsed = time.time() - start
            print("time: {:.2f} sec, progress: {}/{}, real: {:.4f}, fake: {:.4f}"
                  .format(elapsed, i, T,
                          np.mean(ave_real_score), np.mean(ave_fake_score)))
            # Start a fresh averaging window.
            ave_real_score = []
            ave_fake_score = []

    total_elapsed = time.time() - begin
    # Extrapolate the remaining time from the average duration of an epoch.
    remaining = (total_elapsed / (e+1)) * (epochs - e - 1)
    print("epoch: {}, elapsed: {:.2f} min, remaining: {:.2f} min"
          .format(e+1, total_elapsed/60, remaining/60))
    print()
    if save:
        save_model(model, optimizer, "recognizer.pt")

# Notebook scratch cell: fetch one real batch, then immediately overwrite x and l
# with a batch of 128 random segments, move them to the GPU and display the
# batch shape together with the lengths.
x, l, _ = next(iter(dataloader))
x, l = generate_segments(128)
x, l = x.cuda(), l.cuda()
x.shape, l

# Segment the first recording with the trained recognizer: score a window that
# starts 500 samples before `k` and grows by 10 samples per iteration, and
# record a boundary whenever the score exceeds 5.
audio = audios[0]
model.eval()
model.cuda()
i = 0    # extra samples appended to the right edge of the window
k = 500  # reference position; the window spans audio[k - 500:k + 500 + i]
segments = []
# NOTE(review): this loop runs without torch.no_grad(), so autograd graphs are
# built for every forward pass although nothing is back-propagated.
while k + 500 + i < len(audio):
    x = torch.FloatTensor(audio[k - 500:k + 500 + i])
    l = torch.tensor(len(x))
    # Shape the single window as a batch of one: x -> (1, 1, time), l -> (1,).
    x, l = x.view(1, 1, -1).cuda(), l.unsqueeze(0).cuda()
    score = model(x,l)
    i += 10
    if score > 5:
        # `i` has already been incremented, so the boundary is k + i.
        segments.append(k + i)
        # NOTE(review): `i` is not reset after a boundary, so the next window
        # starts out as wide as the previous one ended; confirm this is intended.
        k = k + i
# Close the last segment at the final sample of the recording.
segments.append(len(audio) - 1)

# The 61 phone labels; a label's position in this array is its class index.
phones = np.array([
    'aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay', 'b',
    'bcl', 'ch', 'd', 'dcl', 'dh', 'dx', 'eh', 'el', 'em', 'en',
    'eng', 'epi', 'er', 'ey', 'f', 'g', 'gcl', 'h#', 'hh', 'hv',
    'ih', 'ix', 'iy', 'jh', 'k', 'kcl', 'l', 'm', 'n', 'ng',
    'nx', 'ow', 'oy', 'p', 'pau', 'pcl', 'q', 'r', 's', 'sh',
    't', 'tcl', 'th', 'uh', 'uw', 'ux', 'v', 'w', 'y', 'z',
    'zh',
])
# Reverse lookup from phone label to class index.
phone_to_idx = dict(zip(phones, range(len(phones))))

# Label every detected segment with the pretrained classifier.
classifier, _ = load_classifier("lstm_model_padding.pt")
# Inference only: switch the freshly loaded model out of training mode, as is
# done for the recognizer before segmentation.
classifier.eval()

s0 = 0  # start of the current segment; each boundary in `segments` ends it
for s in segments:
    sample = torch.FloatTensor(audio[s0:s])
    length = torch.tensor(len(sample))
    # Shape the segment as a batch of one: sample -> (1, 1, time), length -> (1,).
    sample, length = sample.view(1, 1, -1).cuda(), length.unsqueeze(0).cuda()
    p = classifier(sample, length)
    # Index of the highest-scoring class.
    idx = p.topk(1)[1].item()
    print(s0, s, phones[idx])
    s0 = s

segments

