Commit ab599862 authored by Egor's avatar Egor
Browse files

Update model.py

Showing with 19 additions and 19 deletions
+19 -19
......@@ -54,20 +54,19 @@ class GraphConvolution(nn.Module):
class GCN_GRU(nn.Module):
def __init__(self, config, nfeat, entity_vocab, relation_vocab):
super(GCN_GRU, self).__init__()
entity_max = max(entity_vocab.values())
relation_max = max(relation_vocab.values())
self.n_hop_kg = pickle_load(f'{config.preprocess_results_dir}/n_hop_kg.pkl')
# self.entity_emb = torch.nn.Embedding(entity_max, nfeat)
# self.kge_model.entity_embedding = torch.nn.Embedding(entity_max, nfeat)
# uniform_range = 6 / np.sqrt(nfeat)
# self.entity_emb.weight.data.uniform_(-uniform_range, uniform_range)
# self.kge_model.entity_embedding.weight.data.uniform_(-uniform_range, uniform_range)
# self.relation_emb = torch.nn.Embedding(relation_max, nfeat)
# uniform_range = 6 / np.sqrt(nfeat)
# self.relation_emb.weight.data.uniform_(-uniform_range, uniform_range)
self.kge_model = KGEModel(
entity_max = max(entity_vocab.values())
relation_max = max(relation_vocab.values())
self.n_hop_kg = pickle_load(f'{config.preprocess_results_dir}/n_hop_kg.pkl')
self.kge_model = KGEModel(
model_name="TransE",
nentity=entity_max,
nrelation=relation_max,
......@@ -77,7 +76,6 @@ class GCN_GRU(nn.Module):
double_relation_embedding=True,
evaluator=None
)
self.gc1 = GraphConvolution(nfeat, nfeat)
self.gc2 = GraphConvolution(nfeat, nfeat)
......@@ -94,11 +92,13 @@ class GCN_GRU(nn.Module):
heads = triplets[:, 0]
relations = triplets[:, 1]
tails = triplets[:, 2]
return (self.entity_emb.weight[heads] + self.relation_emb.weight[relations] - self.entity_emb.weight[tails]).norm(p=1, dim=1)
return (self.kge_model.entity_embedding.weight[heads] +
self.relation_emb.weight[relations] -
self.kge_model.entity_embedding.weight[tails]).norm(p=1, dim=1)
def TransE_forward(self, pos_triplet, neg_triplet):
# -1 to avoid nan for OOV vector
self.entity_emb.weight.data[:-1, :].div_(self.entity_emb.weight.data[:-1, :].norm(p=2, dim=1, keepdim=True))
self.kge_model.entity_embedding.weight.data[:-1, :].div_(self.kge_model.entity_embedding.weight.data[:-1, :].norm(p=2, dim=1, keepdim=True))
pos_distance = self.distance(pos_triplet)
neg_distance = self.distance(neg_triplet)
......@@ -109,14 +109,14 @@ class GCN_GRU(nn.Module):
def forward_GCN(self, x):
# GCN
out = F.relu(self.gc1(self.entity_emb.weight))
out = F.relu(self.gc1(self.kge_model.entity_embedding.weight))
out = self.gc2(out)
out = F.log_softmax(out, dim=1)[x]
return out
def forward(self, x):
# GCN
out = F.relu(self.gc1(self.entity_emb.weight))
out = F.relu(self.gc1(self.kge_model.entity_embedding.weight))
out = self.gc2(out)
out = F.log_softmax(out, dim=1)[x].reshape(1,1,-1)
......@@ -134,9 +134,9 @@ class GRU(nn.Module):
relation_max = max(relation_vocab.values())
self.n_hop_kg = pickle_load(f'{config.preprocess_results_dir}/n_hop_kg.pkl')
self.entity_emb = torch.nn.Embedding(entity_max, nfeat)
self.kge_model.entity_embedding = torch.nn.Embedding(entity_max, nfeat)
uniform_range = 6 / np.sqrt(nfeat)
self.entity_emb.weight.data.uniform_(-uniform_range, uniform_range)
self.kge_model.entity_embedding.weight.data.uniform_(-uniform_range, uniform_range)
self.relation_emb = torch.nn.Embedding(relation_max, nfeat)
uniform_range = 6 / np.sqrt(nfeat)
......@@ -155,11 +155,11 @@ class GRU(nn.Module):
heads = triplets[:, 0]
relations = triplets[:, 1]
tails = triplets[:, 2]
return (self.entity_emb.weight[heads] + self.relation_emb.weight[relations] - self.entity_emb.weight[tails]).norm(p=1, dim=1)
return (self.kge_model.entity_embedding.weight[heads] + self.relation_emb.weight[relations] - self.kge_model.entity_embedding.weight[tails]).norm(p=1, dim=1)
def TransE_forward(self, pos_triplet, neg_triplet):
# -1 to avoid nan for OOV vector
self.entity_emb.weight.data[:-1, :].div_(self.entity_emb.weight.data[:-1, :].norm(p=2, dim=1, keepdim=True))
self.kge_model.entity_embedding.weight.data[:-1, :].div_(self.kge_model.entity_embedding.weight.data[:-1, :].norm(p=2, dim=1, keepdim=True))
pos_distance = self.distance(pos_triplet)
neg_distance = self.distance(neg_triplet)
......@@ -170,13 +170,13 @@ class GRU(nn.Module):
def forward_GCN(self, x):
# GCN
out = F.relu(self.gc1(self.entity_emb.weight))
out = F.relu(self.gc1(self.kge_model.entity_embedding.weight))
out = self.gc2(out)
out = F.log_softmax(out, dim=1)[x]
return out
def forward(self, x):
x = self.entity_emb.weight[x]
x = self.kge_model.entity_embedding.weight[x]
x = x.reshape(1,1,-1)
# GRU
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment