Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
itmo-sai-code
ReDKG
Commits
ab599862
Commit
ab599862
authored
2 years ago
by
Egor
Browse files
Options
Download
Email Patches
Plain Diff
Update model.py
parent
32b09bb2
main
HEAD
RLfix
append-docs
basic-models-tests
basic-models-visualization
bellman-ford
custom-models
eegoro-fixed
eegoro-fixed-1
fix-gitlab-mirror
fix-readme-badges
fixes
gnn-examples
mirror
test_bellman_ford
test_ford_bellman
vis-fixes
visualization
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
redkg/model.py
+19
-19
redkg/model.py
with
19 additions
and
19 deletions
+19
-19
redkg/model.py
View file @
ab599862
...
...
@@ -54,20 +54,19 @@ class GraphConvolution(nn.Module):
class
GCN_GRU
(
nn
.
Module
):
def
__init__
(
self
,
config
,
nfeat
,
entity_vocab
,
relation_vocab
):
super
(
GCN_GRU
,
self
).
__init__
()
entity_max
=
max
(
entity_vocab
.
values
())
relation_max
=
max
(
relation_vocab
.
values
())
self
.
n_hop_kg
=
pickle_load
(
f
'
{
config
.
preprocess_results_dir
}
/n_hop_kg.pkl'
)
# self.entity_emb = torch.nn.Embedding(entity_max, nfeat)
# self.kge_model.entity_embedding = torch.nn.Embedding(entity_max, nfeat)
# uniform_range = 6 / np.sqrt(nfeat)
# self.entity_emb.weight.data.uniform_(-uniform_range, uniform_range)
# self.
kge_model.
entity_emb
edding
.weight.data.uniform_(-uniform_range, uniform_range)
# self.relation_emb = torch.nn.Embedding(relation_max, nfeat)
# uniform_range = 6 / np.sqrt(nfeat)
# self.relation_emb.weight.data.uniform_(-uniform_range, uniform_range)
self
.
kge_model
=
KGEModel
(
entity_max
=
max
(
entity_vocab
.
values
())
relation_max
=
max
(
relation_vocab
.
values
())
self
.
n_hop_kg
=
pickle_load
(
f
'
{
config
.
preprocess_results_dir
}
/n_hop_kg.pkl'
)
self
.
kge_model
=
KGEModel
(
model_name
=
"TransE"
,
nentity
=
entity_max
,
nrelation
=
relation_max
,
...
...
@@ -77,7 +76,6 @@ class GCN_GRU(nn.Module):
double_relation_embedding
=
True
,
evaluator
=
None
)
self
.
gc1
=
GraphConvolution
(
nfeat
,
nfeat
)
self
.
gc2
=
GraphConvolution
(
nfeat
,
nfeat
)
...
...
@@ -94,11 +92,13 @@ class GCN_GRU(nn.Module):
heads
=
triplets
[:,
0
]
relations
=
triplets
[:,
1
]
tails
=
triplets
[:,
2
]
return
(
self
.
entity_emb
.
weight
[
heads
]
+
self
.
relation_emb
.
weight
[
relations
]
-
self
.
entity_emb
.
weight
[
tails
]).
norm
(
p
=
1
,
dim
=
1
)
return
(
self
.
kge_model
.
entity_embedding
.
weight
[
heads
]
+
self
.
relation_emb
.
weight
[
relations
]
-
self
.
kge_model
.
entity_embedding
.
weight
[
tails
]).
norm
(
p
=
1
,
dim
=
1
)
def
TransE_forward
(
self
,
pos_triplet
,
neg_triplet
):
# -1 to avoid nan for OOV vector
self
.
entity_emb
.
weight
.
data
[:
-
1
,
:].
div_
(
self
.
entity_emb
.
weight
.
data
[:
-
1
,
:].
norm
(
p
=
2
,
dim
=
1
,
keepdim
=
True
))
self
.
kge_model
.
entity_emb
edding
.
weight
.
data
[:
-
1
,
:].
div_
(
self
.
kge_model
.
entity_emb
edding
.
weight
.
data
[:
-
1
,
:].
norm
(
p
=
2
,
dim
=
1
,
keepdim
=
True
))
pos_distance
=
self
.
distance
(
pos_triplet
)
neg_distance
=
self
.
distance
(
neg_triplet
)
...
...
@@ -109,14 +109,14 @@ class GCN_GRU(nn.Module):
def
forward_GCN
(
self
,
x
):
# GCN
out
=
F
.
relu
(
self
.
gc1
(
self
.
entity_emb
.
weight
))
out
=
F
.
relu
(
self
.
gc1
(
self
.
kge_model
.
entity_emb
edding
.
weight
))
out
=
self
.
gc2
(
out
)
out
=
F
.
log_softmax
(
out
,
dim
=
1
)[
x
]
return
out
def
forward
(
self
,
x
):
# GCN
out
=
F
.
relu
(
self
.
gc1
(
self
.
entity_emb
.
weight
))
out
=
F
.
relu
(
self
.
gc1
(
self
.
kge_model
.
entity_emb
edding
.
weight
))
out
=
self
.
gc2
(
out
)
out
=
F
.
log_softmax
(
out
,
dim
=
1
)[
x
].
reshape
(
1
,
1
,
-
1
)
...
...
@@ -134,9 +134,9 @@ class GRU(nn.Module):
relation_max
=
max
(
relation_vocab
.
values
())
self
.
n_hop_kg
=
pickle_load
(
f
'
{
config
.
preprocess_results_dir
}
/n_hop_kg.pkl'
)
self
.
entity_emb
=
torch
.
nn
.
Embedding
(
entity_max
,
nfeat
)
self
.
kge_model
.
entity_emb
edding
=
torch
.
nn
.
Embedding
(
entity_max
,
nfeat
)
uniform_range
=
6
/
np
.
sqrt
(
nfeat
)
self
.
entity_emb
.
weight
.
data
.
uniform_
(
-
uniform_range
,
uniform_range
)
self
.
kge_model
.
entity_emb
edding
.
weight
.
data
.
uniform_
(
-
uniform_range
,
uniform_range
)
self
.
relation_emb
=
torch
.
nn
.
Embedding
(
relation_max
,
nfeat
)
uniform_range
=
6
/
np
.
sqrt
(
nfeat
)
...
...
@@ -155,11 +155,11 @@ class GRU(nn.Module):
heads
=
triplets
[:,
0
]
relations
=
triplets
[:,
1
]
tails
=
triplets
[:,
2
]
return
(
self
.
entity_emb
.
weight
[
heads
]
+
self
.
relation_emb
.
weight
[
relations
]
-
self
.
entity_emb
.
weight
[
tails
]).
norm
(
p
=
1
,
dim
=
1
)
return
(
self
.
kge_model
.
entity_emb
edding
.
weight
[
heads
]
+
self
.
relation_emb
.
weight
[
relations
]
-
self
.
kge_model
.
entity_emb
edding
.
weight
[
tails
]).
norm
(
p
=
1
,
dim
=
1
)
def
TransE_forward
(
self
,
pos_triplet
,
neg_triplet
):
# -1 to avoid nan for OOV vector
self
.
entity_emb
.
weight
.
data
[:
-
1
,
:].
div_
(
self
.
entity_emb
.
weight
.
data
[:
-
1
,
:].
norm
(
p
=
2
,
dim
=
1
,
keepdim
=
True
))
self
.
kge_model
.
entity_emb
edding
.
weight
.
data
[:
-
1
,
:].
div_
(
self
.
kge_model
.
entity_emb
edding
.
weight
.
data
[:
-
1
,
:].
norm
(
p
=
2
,
dim
=
1
,
keepdim
=
True
))
pos_distance
=
self
.
distance
(
pos_triplet
)
neg_distance
=
self
.
distance
(
neg_triplet
)
...
...
@@ -170,13 +170,13 @@ class GRU(nn.Module):
def
forward_GCN
(
self
,
x
):
# GCN
out
=
F
.
relu
(
self
.
gc1
(
self
.
entity_emb
.
weight
))
out
=
F
.
relu
(
self
.
gc1
(
self
.
kge_model
.
entity_emb
edding
.
weight
))
out
=
self
.
gc2
(
out
)
out
=
F
.
log_softmax
(
out
,
dim
=
1
)[
x
]
return
out
def
forward
(
self
,
x
):
x
=
self
.
entity_emb
.
weight
[
x
]
x
=
self
.
kge_model
.
entity_emb
edding
.
weight
[
x
]
x
=
x
.
reshape
(
1
,
1
,
-
1
)
# GRU
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment
Menu
Projects
Groups
Snippets
Help