Commit d683d60d authored by Yefim Osipov

Formatting

Showing 96 additions and 1898 deletions
# %%
import networkx as nx
import numpy as np
import torch
from tqdm import tqdm
from rostok.graph_grammar.node import GraphGrammar
from rostok.graph_grammar.node_vocabulary import NodeVocabulary
from rostok.graph_grammar.rule_vocabulary import RuleVocabulary
import rule_without_chrono as re
# %%
def get_input_layer(node, dict_id_label_nodes):
input = torch.zeros(len(dict_id_label_nodes)).long()
input[node] = 1
return input
def vocabulary2batch_graph(rule_vocabulary: RuleVocabulary, max_rules: int):
batch_graph = GraphGrammar()
amount_rules = np.random.randint(1, max_rules)
for _ in range(amount_rules):
rules = rule_vocabulary.get_list_of_applicable_rules(batch_graph)
if len(rules) > 0:
rule = rule_vocabulary.get_rule(rules[np.random.choice(len(rules))])
batch_graph.apply_rule(rule)
else:
break
return batch_graph
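# vocabulary2batch_graph grows a random graph by applying up to max_rules - 1
# randomly chosen applicable rules, stopping early if no rule can be applied.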
def random_batch(skip_grams):
random_inputs = []
random_labels = []
random_index = np.random.choice(range(len(skip_grams)), 2, replace=False)
for i in random_index:
random_inputs.append(skip_grams[i][0]) # target
random_labels.append(skip_grams[i][1]) # context word
return random_inputs, random_labels
# %%
class skipgramm_model(torch.nn.Module):
def __init__(self, vocabulary_size: int, embedding_size: int):
super().__init__()
self.embedding = torch.nn.Embedding(vocabulary_size, embedding_size)
self.W = torch.nn.Linear(embedding_size, embedding_size, bias=False)
self.WT = torch.nn.Linear(embedding_size, vocabulary_size, bias=False)
def forward(self, x):
        embeddings = self.embedding(x)
        hidden_layer = torch.nn.functional.relu(self.W(embeddings))
output_layer = self.WT(hidden_layer)
return output_layer
    def get_node_embedding(self, node, sorted_node_labels, dict_label_id_nodes):
        # nn.Embedding expects an index tensor, not a one-hot float vector
        idx = torch.tensor([dict_label_id_nodes[node]], dtype=torch.long)
        return self.embedding(idx).view(1, -1)
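# In this model a node id is looked up in `embedding`, passed through the hidden
# linear layer W with ReLU, and projected by WT to vocabulary-sized logits, so
# CrossEntropyLoss over those logits trains the skip-gram objective.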
def skipgram(paths, dict_label_id_nodes, window_size=1):
idx_pairs = []
for path in paths:
indices = [dict_label_id_nodes[node_label] for node_label in path]
for pos_center_node, node_index in enumerate(indices):
for i in range(-window_size, window_size + 1):
pos_context_node = pos_center_node + i
if pos_context_node < 0 or pos_context_node >= len(
indices) or pos_center_node == pos_context_node:
continue
context_id_node = indices[pos_context_node]
idx_pairs.append((node_index, context_id_node))
return np.array(idx_pairs)
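# A minimal sketch of the pair generation above on a toy path (the labels and the
# index mapping below are illustrative, not taken from the real node vocabulary):
# with window_size=1 every node is paired with its immediate neighbours.
toy_pairs = skipgram([['ROOT', 'F1', 'J']], {'ROOT': 0, 'F1': 1, 'J': 2})
# -> array([[0, 1], [1, 0], [1, 2], [2, 1]])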
def create_dict_node_labels(node_vocabulary: NodeVocabulary):
sorted_node_labels = sorted(node_vocabulary.node_dict.keys())
dict_id_label_nodes = dict(enumerate(sorted_node_labels))
dict_label_id_nodes = {w: idx for (idx, w) in enumerate(sorted_node_labels)}
return dict_id_label_nodes, dict_label_id_nodes
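# For example, if the node vocabulary contained only the labels {'EM', 'F1', 'J'}
# (purely illustrative), the sorted labels would give
# dict_id_label_nodes = {0: 'EM', 1: 'F1', 2: 'J'} and
# dict_label_id_nodes = {'EM': 0, 'F1': 1, 'J': 2}.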
# %%
rule_vocab = re.init_extension_rules()
node_vocabulary = rule_vocab.node_vocab
id2label, label2id = create_dict_node_labels(node_vocabulary)
model = skipgramm_model(len(id2label), 2)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
graph = vocabulary2batch_graph(rule_vocab, 15)
pairs = skipgram(graph.get_uniq_representation(), label2id)
for epoch in tqdm(range(150000)):
    input_batch, target_batch = random_batch(pairs)
    # the embedding layer and CrossEntropyLoss both expect index tensors, not one-hot vectors
    input_batch = torch.LongTensor(input_batch)
    target_batch = torch.LongTensor(target_batch)
    optimizer.zero_grad()
    output = model(input_batch)
    # output : [batch_size, voc_size], target_batch : [batch_size] (LongTensor, not one-hot)
    loss = criterion(output, target_batch)
    if (epoch + 1) % 10000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), ' cost =', '{:.6f}'.format(loss.item()))
    loss.backward()
    optimizer.step()
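# %%
# A minimal sketch of inspecting the embeddings learned above (assuming the
# training loop has run): L2-normalise the embedding matrix and compare the
# first two vocabulary labels by cosine similarity.
with torch.no_grad():
    emb = model.embedding.weight                  # [vocab_size, embedding_size]
    emb = emb / emb.norm(dim=1, keepdim=True)
    label_a, label_b = list(label2id)[:2]         # any two labels from the vocabulary
    sim = float(emb[label2id[label_a]] @ emb[label2id[label_b]])
    print(label_a, label_b, sim)                  # cosine similarity in [-1, 1]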
# %%
import os
import random
import re
from collections import Counter, OrderedDict
from dataclasses import dataclass
from time import monotonic
from typing import List, Union
import numpy as np
import torch
import torch.nn as nn
from scipy.spatial.distance import cosine
from torch.utils.data import DataLoader
from torchtext.data import to_map_style_dataset
from tqdm import tqdm
from rostok.graph_grammar.node import GraphGrammar
@@ -35,9 +32,11 @@ def vocabulary2batch_graph(rule_vocabulary: RuleVocabulary, max_rules: int):
break
return batch_graph
def create_train_valid_data(rule_vocabulary: RuleVocabulary, amount_graph: int,
                            pseudo_length_graph: int):
train_data = []
    for __ in range(round(amount_graph * 0.8)):
flatted_graph = []
graph = vocabulary2batch_graph(rule_vocabulary, pseudo_length_graph)
df_travels = graph.get_uniq_representation()
@@ -45,7 +44,7 @@ def create_train_valid_data(rule_vocabulary: RuleVocabulary, amount_graph: int,
flatted_graph = flatted_graph + path
train_data.append(flatted_graph)
valid_data = []
    for __ in range(round(amount_graph * 0.2)):
flatted_graph = []
graph = vocabulary2batch_graph(rule_vocabulary, pseudo_length_graph)
df_travels = graph.get_uniq_representation()
@@ -54,6 +53,7 @@ def create_train_valid_data(rule_vocabulary: RuleVocabulary, amount_graph: int,
valid_data.append(flatted_graph)
return train_data, valid_data
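# For instance, create_train_valid_data(rule_vocab, 10, 5) would return 8
# flattened training graphs and 2 validation graphs; the pipeline below uses
# 100000 graphs with a pseudo length of 20.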
# %%
@dataclass
class Word2VecParams:
@@ -70,13 +70,15 @@ class Word2VecParams:
EMBED_DIM = 10
EMBED_MAX_NORM = None
N_EPOCHS = 100
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CRITERION = nn.BCEWithLogitsLoss()
# %%
class model_vocabulary:
def __init__(self, node_vocabulary: NodeVocabulary):
sorted_node_labels = sorted(node_vocabulary.node_dict.keys())
self.itos = dict(enumerate(sorted_node_labels))
@@ -86,21 +88,19 @@ class model_vocabulary:
return len(self.stoi) - 1
def get_index(self, label_node: Union[str, List]):
if isinstance(label_node, str):
            if label_node in self.stoi:
return self.stoi.get(label_node)
elif isinstance(label_node, list):
res = []
for n in label_node:
                if n in self.stoi:
res.append(self.stoi.get(n))
return res
else:
            raise ValueError(f"Label node {label_node} is not a string or a list of strings.")
def lookup_token(self, token: Union[int, List]):
if isinstance(token, (int, np.int64)):
@@ -117,69 +117,66 @@ class model_vocabulary:
raise ValueError(f"Token {t} is not a valid index.")
return res
# %%
def calculate_frequency_nodes(vocab: model_vocabulary, flatted_graphs: list):
frequency_nodes = {label: 0 for label in vocab.stoi.keys()}
for graph in flatted_graphs:
for node in graph:
frequency_nodes[node] = int(frequency_nodes.get(node, 0) + 1)
total_nodes = np.nansum([f for f in frequency_nodes.values()], dtype=int)
return frequency_nodes, total_nodes
# %%
class SkipGrams:
def __init__(self, vocab: model_vocabulary, flatted_graph: list, params: Word2VecParams):
self.vocab = vocab
self.params = params
freq_dict, total_tokens = calculate_frequency_nodes(self.vocab, flatted_graph)
self.t = self._t(freq_dict, total_tokens)
self.discard_probs = self._create_discard_dict(freq_dict, total_tokens)
def _t(self, freq_dict, total_tokens):
freq_list = []
for freq in list(freq_dict.values())[1:]:
            freq_list.append(freq / total_tokens)
return np.percentile(freq_list, self.params.T)
def _create_discard_dict(self, freq_dict, total_tokens):
discard_dict = {}
for node, freq in freq_dict.items():
            discard_prob = 1 - np.sqrt(self.t / (freq / total_tokens + self.t))
            discard_dict[self.vocab.stoi[node]] = discard_prob
return discard_dict
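    # Illustrative numbers for the subsampling rule above: with t = 0.01 and a
    # node whose relative frequency is 0.09, the discard probability is
    # 1 - sqrt(0.01 / (0.09 + 0.01)) ≈ 0.68, so frequent nodes are dropped from
    # the training pairs far more often than rare ones.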
def collate_skipgram(self, batch):
        batch_input, batch_output = [], []
for graph in batch:
node_tokens = self.vocab.get_index(graph)
if len(node_tokens) < self.params.SKIPGRAM_N_WORDS * 2 + 1:
continue
            for idx in range(len(node_tokens) - self.params.SKIPGRAM_N_WORDS * 2):
                token_id_sequence = node_tokens[idx:(idx + self.params.SKIPGRAM_N_WORDS * 2 + 1)]
input_ = token_id_sequence.pop(self.params.SKIPGRAM_N_WORDS)
outputs = token_id_sequence
prb = random.random()
del_pair = self.discard_probs.get(input_)
                if input_ == 0 or del_pair >= prb:
continue
else:
for output in outputs:
prb = random.random()
del_pair = self.discard_probs.get(output)
                        if output == 0 or del_pair >= prb:
continue
else:
batch_input.append(input_)
@@ -187,12 +184,15 @@ class SkipGrams:
batch_input = torch.tensor(batch_input, dtype=torch.long)
batch_output = torch.tensor(batch_output, dtype=torch.long)
return batch_input, batch_output
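    # collate_skipgram thus yields two aligned 1-D LongTensors: batch_input[i]
    # is a centre-node id and batch_output[i] one of its surviving context-node
    # ids, which the negative-sampling model below consumes directly.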
# %%
class NegativeSampler:
    def __init__(self, vocab: model_vocabulary, train_graphs: list, ns_exponent: float,
                 ns_array_len: int):
self.vocab = vocab
self.ns_exponent = ns_exponent
self.ns_array_len = ns_array_len
@@ -205,61 +205,58 @@ class NegativeSampler:
frequency_dict, total_tokens = calculate_frequency_nodes(self.vocab, train_graphs)
        frequency_dict_scaled = {
            self.vocab.stoi[node]: max(1, int((freq / total_tokens) * self.ns_array_len))
            for node, freq in frequency_dict.items()
        }
ns_array = []
for node, freq in tqdm(frequency_dict_scaled.items()):
            ns_array = ns_array + [node] * freq
return ns_array
    def sample(self, n_batches: int = 1, n_samples: int = 1):
samples = []
for _ in range(n_batches):
samples.append(random.sample(self.ns_array, n_samples))
samples = torch.as_tensor(np.array(samples))
return samples
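    # For example, sample(n_batches=4, n_samples=5) returns a [4, 5] tensor of
    # node ids drawn roughly in proportion to how often each node label occurs
    # in the training graphs.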
# %%
# Model
class Model(nn.Module):
def __init__(self, vocab: model_vocabulary, params: Word2VecParams):
super().__init__()
self.vocab = vocab
        self.t_embeddings = nn.Embedding(self.vocab.__len__() + 1,
                                         params.EMBED_DIM,
                                         max_norm=params.EMBED_MAX_NORM)
        self.c_embeddings = nn.Embedding(self.vocab.__len__() + 1,
                                         params.EMBED_DIM,
                                         max_norm=params.EMBED_MAX_NORM)
def forward(self, inputs, context):
        # getting embeddings for target & reshaping
target_embeddings = self.t_embeddings(inputs)
n_examples = target_embeddings.shape[0]
n_dimensions = target_embeddings.shape[1]
target_embeddings = target_embeddings.view(n_examples, 1, n_dimensions)
        # get embeddings for context labels & reshaping
# Allows us to do a bunch of matrix multiplications
context_embeddings = self.c_embeddings(context)
# * This transposes each batch
        context_embeddings = context_embeddings.permute(0, 2, 1)
# * custom linear layer
dots = target_embeddings.bmm(context_embeddings)
dots = dots.view(dots.shape[0], dots.shape[2])
        return dots
def normalize_embeddings(self):
embeddings = list(self.t_embeddings.parameters())[0]
        embeddings = embeddings.cpu().detach().numpy()
        norms = (embeddings**2).sum(axis=1)**(1 / 2)
norms = norms.reshape(norms.shape[0], 1)
return embeddings / norms
@@ -270,7 +267,7 @@ class Model(nn.Module):
node_vec = embedding_norms[node_id]
node_vec = np.reshape(node_vec, (node_vec.shape[0], 1))
dists = np.matmul(embedding_norms, node_vec).flatten()
        topN_ids = np.argsort(-dists)[1:n + 1]
topN_dict = {}
for sim_node_id in topN_ids:
@@ -284,16 +281,18 @@ class Model(nn.Module):
if idx1 == 0 or idx2 == 0:
print("One or both words are out of vocabulary")
return
embedding_norms = self.normalize_embeddings()
node1_vec, node2_vec = embedding_norms[idx1], embedding_norms[idx2]
return cosine(node1_vec, node2_vec)
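    # The method above returns SciPy's cosine distance (1 - cosine similarity)
    # between two normalised node embeddings: 0.0 means the embeddings point in
    # the same direction, and larger values mean the nodes are less similar.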
# %%
class Trainer:
    def __init__(self, model: Model, params: Word2VecParams, optimizer, vocab: model_vocabulary,
                 train_iter, valid_iter, skipgrams: SkipGrams):
self.model = model
self.optimizer = optimizer
self.vocab = vocab
@@ -309,46 +308,38 @@ class Trainer:
self.model.to(self.params.DEVICE)
self.params.CRITERION.to(self.params.DEVICE)
        self.negative_sampler = NegativeSampler(vocab=self.vocab,
                                                ns_exponent=.75,
                                                train_graphs=self.train_iter,
                                                ns_array_len=self.params.NS_ARRAY_LEN)
self.testnode = ['F1', 'J', 'L1', 'EM']
def train(self):
self.test_testnode()
for epoch in range(self.params.N_EPOCHS):
# Generate Dataloaders
            self.train_dataloader = DataLoader(self.train_iter,
                                               batch_size=self.params.BATCH_SIZE,
                                               shuffle=False,
                                               collate_fn=self.skipgrams.collate_skipgram)
            self.valid_dataloader = DataLoader(self.valid_iter,
                                               batch_size=self.params.BATCH_SIZE,
                                               shuffle=False,
                                               collate_fn=self.skipgrams.collate_skipgram)
# training the model
st_time = monotonic()
self._train_epoch()
            self.epoch_train_mins[epoch] = round((monotonic() - st_time) / 60, 1)
# validating the model
self._validate_epoch()
print(f"""Epoch: {epoch+1}/{self.params.N_EPOCHS}\n""",
f""" Train Loss: {self.loss['train'][-1]:.2}\n""",
f""" Valid Loss: {self.loss['valid'][-1]:.2}\n""",
f""" Training Time (mins): {self.epoch_train_mins.get(epoch)}"""
"""\n"""
)
print(f"""Epoch: {epoch+1}/{self.params.N_EPOCHS}\n""",
f""" Train Loss: {self.loss['train'][-1]:.2}\n""",
f""" Valid Loss: {self.loss['valid'][-1]:.2}\n""",
f""" Training Time (mins): {self.epoch_train_mins.get(epoch)}"""
"""\n""")
self.test_testnode()
def _train_epoch(self):
self.model.train()
running_loss = []
@@ -358,22 +349,17 @@ class Trainer:
continue
inputs = batch_data[0].to(self.params.DEVICE)
pos_labels = batch_data[1].to(self.params.DEVICE)
            neg_labels = self.negative_sampler.sample(pos_labels.shape[0], self.params.NEG_SAMPLES)
neg_labels = neg_labels.to(self.params.DEVICE)
            context = torch.cat([pos_labels.view(pos_labels.shape[0], 1), neg_labels], dim=1)
            # building the targets tensor
y_pos = torch.ones((pos_labels.shape[0], 1))
y_neg = torch.zeros((neg_labels.shape[0], neg_labels.shape[1]))
y = torch.cat([y_pos, y_neg], dim=1).to(self.params.DEVICE)
self.optimizer.zero_grad()
outputs = self.model(inputs, context)
loss = self.params.CRITERION(outputs, y)
loss.backward()
@@ -382,7 +368,7 @@ class Trainer:
running_loss.append(loss.item())
epoch_loss = np.mean(running_loss)
self.loss['train'].append(epoch_loss)
def _validate_epoch(self):
@@ -396,15 +382,10 @@ class Trainer:
inputs = batch_data[0].to(self.params.DEVICE)
pos_labels = batch_data[1].to(self.params.DEVICE)
            neg_labels = self.negative_sampler.sample(
                pos_labels.shape[0], self.params.NEG_SAMPLES).to(self.params.DEVICE)
            context = torch.cat([pos_labels.view(pos_labels.shape[0], 1), neg_labels], dim=1)
            # building the targets tensor
y_pos = torch.ones((pos_labels.shape[0], 1))
y_neg = torch.zeros((neg_labels.shape[0], neg_labels.shape[1]))
y = torch.cat([y_pos, y_neg], dim=1).to(self.params.DEVICE)
@@ -425,6 +406,7 @@ class Trainer:
print(f"{v} ({sim:.3})", end=' ')
print('\n')
# %%
rule_vocab = re.init_extension_rules()
@@ -432,19 +414,17 @@ rule_vocab = re.init_extension_rules()
params = Word2VecParams()
train_data, valid_data = create_train_valid_data(rule_vocab, 100000, 20)
vocab = model_vocabulary(rule_vocab.node_vocab)
skip_gram = SkipGrams(vocab=vocab, flatted_graph=train_data, params=params)
model = Model(vocab=vocab, params=params).to(params.DEVICE)
optimizer = torch.optim.Adam(params=model.parameters())
# %%
trainer = Trainer(model=model,
                  params=params,
                  optimizer=optimizer,
                  train_iter=train_data,
                  valid_iter=valid_data,
                  vocab=vocab,
                  skipgrams=skip_gram)
trainer.train()
None
\ No newline at end of file
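# %%
# A minimal post-training sketch (illustrative: it uses two labels from the
# testnode list above, assumed present in the vocabulary, and the
# normalize_embeddings method of Model) comparing node embeddings by cosine similarity.
emb_norm = model.normalize_embeddings()              # unit-length rows, numpy array
idx_f1, idx_j = vocab.get_index('F1'), vocab.get_index('J')
print("F1 ~ J:", float(np.dot(emb_norm[idx_f1], emb_norm[idx_j])))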