Commit de18d3a2 authored by Egor Shikov

fixes

parent fdc10cb2
Showing 11 additions and 13 deletions
%% Cell type:code id:1e7ad0bb-e486-4134-9aab-1dbcf9bb4959 tags:
``` python
%load_ext autoreload
%autoreload 2
%matplotlib inline
from tqdm import tqdm
import torch
import sys
sys.path.append('../')
from redkg.dataloader import TrainDataset, get_info
from redkg.kge import KGEModel
from redkg.config import Config
from redkg.utils import AttributeDict
from redkg.train import train_kge_model
from ogb.linkproppred import LinkPropPredDataset, Evaluator
```
%%%% Output: stream
/data/home/egor/.conda/envs/myenv/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
%% Cell type:markdown id:a15fc0ba-cd99-4326-9047-de9d8f8d8623 tags:
## BioKG
<img src="biokg_logo2.png" width="400">
The **biokg** dataset is a Knowledge Graph (KG) created from a large number of biomedical data repositories. It contains 5 types of entities: diseases (10,687 nodes), proteins (17,499 nodes), drugs (10,533 nodes), side effects (9,969 nodes), and protein functions (45,085 nodes). There are 51 types of directed relations connecting two entity types, including 38 kinds of drug-drug interactions and 8 kinds of protein-protein interactions, as well as drug-protein, drug-side effect, and function-function relations. All relations are modeled as directed edges; relations connecting the same entity type (e.g., protein-protein, drug-drug, function-function) are always symmetric, i.e., the edges are bi-directional.
This dataset is relevant to both biomedical and fundamental ML research. On the biomedical side, it allows us to gain better insight into human biology and to generate predictions that can guide downstream biomedical research. On the fundamental ML side, it presents the challenge of handling a noisy, incomplete KG with possibly contradictory observations, because ogbl-biokg contains heterogeneous interactions that span from the molecular scale (e.g., protein-protein interactions within a cell) to whole populations (e.g., reports of unwanted side effects experienced by patients in a particular country). Further, triplets in the KG come from sources with a variety of confidence levels, including experimental readouts, human-curated annotations, and automatically extracted metadata.
%% Cell type:code id:bf1328f3-ce9d-4853-85d7-0b1f2fe49369 tags:
``` python
# Load the ogbl-biokg link-prediction dataset and its official train/valid/test splits
dataset_name = 'ogbl-biokg'
dataset = LinkPropPredDataset(name=dataset_name, root='../data')
split_edge = dataset.get_edge_split()
train_triples, valid_triples, test_triples = split_edge["train"], split_edge["valid"], split_edge["test"]
# Collect entity/relation counts and the entity dictionary used by the KGE model
info = get_info(dataset, train_triples)
```
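As a quick sanity check, the heterogeneous structure described above can be inspected directly from the loaded triples. This is a minimal sketch assuming the standard OGB split format, where each split is a dictionary of arrays keyed by 'head', 'head_type', 'relation', 'tail', and 'tail_type':
``` python
# Sketch: inspect the typed triple structure of the training split
# (assumes the standard ogbl-biokg split layout)
import numpy as np

print("fields:", list(train_triples.keys()))
print("entity types:", sorted(set(train_triples["head_type"]) | set(train_triples["tail_type"])))
print("relation types:", len(np.unique(train_triples["relation"])))
print("training triples:", len(train_triples["relation"]))
```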
%% Cell type:code id:0f44c042-455a-4c11-9bac-2c09d96959d3 tags:
``` python
# Official OGB evaluator for ogbl-biokg (computes MRR and Hits@k)
evaluator = Evaluator(name=dataset_name)
# TransE knowledge-graph-embedding model with 128-dimensional embeddings and margin gamma = 12
kge_model = KGEModel(
    model_name="TransE",
    nentity=info['nentity'],
    nrelation=info['nrelation'],
    hidden_dim=128,
    gamma=12.0,
    double_entity_embedding=True,
    double_relation_embedding=True,
    evaluator=evaluator
)
```
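For orientation, here is a minimal sketch of the standard TransE scoring function that the model name above refers to. This is an illustration, not the redkg implementation; the 256-dimensional toy embeddings assume that the two double-embedding flags simply double the configured hidden_dim, as in the common OGB KGE examples:
``` python
import torch

def transe_score(head, relation, tail, gamma=12.0, p=1):
    # Standard TransE: a triple is plausible when head + relation lands close to tail,
    # so the score is the margin minus the translation distance.
    return gamma - torch.norm(head + relation - tail, p=p, dim=-1)

# Toy usage with random embeddings of width 2 * hidden_dim (assumed here to be 256)
h, r, t = (torch.randn(4, 256) for _ in range(3))
print(transe_score(h, r, t).shape)  # torch.Size([4])
```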
%% Cell type:markdown id:945f9c89-4035-456b-bcc0-9e6bdb0d5ee9 tags:
## Training model
%% Cell type:code id:3c28f2cf-e03c-464f-8357-6f8296eec1da tags:
``` python
# Training hyperparameters
train_pars = AttributeDict()
train_pars.cuda = False
train_pars.uni_weight = True
train_pars.negative_adversarial_sampling = True
train_pars.regularization = 0.0
train_pars.adversarial_temperature = 1.0
train_pars.train_batch_size = 128
train_pars.negative_sample_size = 128
train_pars.learning_rate = 0.001
train_pars.cpu_num = 10
train_pars.negative_mode = "full"

# Evaluation parameters
test_params = AttributeDict()
test_params.cuda = False
test_params.neg_size_eval_train = 500
test_params.test_log_steps = 1000
test_params.test_batch_size = 128
test_params.cpu_num = 10
test_params.nentity = info['nentity']
test_params.nrelation = info['nrelation']

# The per-step evaluation inside train_kge_model reads its settings from train_pars,
# so the evaluation parameters are mirrored there as well
train_pars.neg_size_eval_train = 500
train_pars.test_log_steps = 1000
train_pars.test_batch_size = 128
train_pars.nentity = info['nentity']
train_pars.nrelation = info['nrelation']
train_pars.do_test = True
```
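The negative_adversarial_sampling and adversarial_temperature settings above refer to the self-adversarial negative sampling loss introduced with RotatE (Sun et al., 2019). A rough sketch of that loss follows, for orientation only and not necessarily identical to redkg's implementation:
``` python
import torch
import torch.nn.functional as F

def self_adversarial_loss(pos_score, neg_score, temperature=1.0):
    # pos_score: (batch,), neg_score: (batch, num_negatives)
    # Harder negatives receive larger weight via a softmax over their own scores;
    # the weights are detached so they act as constants during backprop.
    neg_weights = F.softmax(neg_score * temperature, dim=1).detach()
    positive_sample_loss = -F.logsigmoid(pos_score).mean()
    negative_sample_loss = -(neg_weights * F.logsigmoid(-neg_score)).sum(dim=1).mean()
    loss = (positive_sample_loss + negative_sample_loss) / 2
    # Mirrors the structure of the training_logs entries shown further below
    return positive_sample_loss, negative_sample_loss, loss
```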
%% Cell type:code id:07655037-6675-4393-83bd-aa2761c1b70f tags:
``` python
training_logs, test_logs = train_kge_model(kge_model, train_pars, test_params, info, train_triples, valid_triples, test_triples)
```
%%%% Output: stream
Training...
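After training, the learned embeddings can be persisted in the usual PyTorch way. This assumes KGEModel is a standard torch.nn.Module (which the optimizer setup in redkg.train suggests); the file path is only an example:
``` python
# Save and reload the trained KGE model weights (example path)
torch.save(kge_model.state_dict(), "kge_transe_biokg.pt")
kge_model.load_state_dict(torch.load("kge_transe_biokg.pt"))
```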
%% Cell type:markdown id:aa5defab tags:
## Results visualization
%% Cell type:code id:2daa7f04-2207-479f-93d0-a7357f5264ed tags:
``` python
test_logs
```
%%%% Output: execute_result
[{'hits@1_list': 0.0033673858270049095,
'hits@3_list': 0.009205824695527554,
'hits@10_list': 0.029947325587272644,
'mrr_list': 0.0193993728607893},
{'hits@1_list': 0.003305993042886257,
'hits@3_list': 0.009122944436967373,
'hits@10_list': 0.029704824090003967,
'mrr_list': 0.01936676912009716},
{'hits@1_list': 0.00337966438382864,
'hits@3_list': 0.009211963973939419,
'hits@10_list': 0.029950395226478577,
'mrr_list': 0.019408494234085083},
{'hits@1_list': 0.0033213412389159203,
'hits@3_list': 0.009098388254642487,
'hits@10_list': 0.0294254869222641,
'mrr_list': 0.019350754097104073},
{'hits@1_list': 0.0032752968836575747,
'hits@3_list': 0.009138292632997036,
'hits@10_list': 0.02936716563999653,
'mrr_list': 0.01935487426817417},
{'hits@1_list': 0.003315201960504055,
'hits@3_list': 0.009098388254642487,
'hits@10_list': 0.029385583475232124,
'mrr_list': 0.01936601661145687},
{'hits@1_list': 0.0033029234036803246,
'hits@3_list': 0.008905000984668732,
'hits@10_list': 0.029222892597317696,
'mrr_list': 0.01932138204574585},
{'hits@1_list': 0.00330906268209219,
'hits@3_list': 0.008883513510227203,
'hits@10_list': 0.029272006824612617,
'mrr_list': 0.019309179857373238},
{'hits@1_list': 0.0032691576052457094,
'hits@3_list': 0.008951045572757721,
'hits@10_list': 0.029207544401288033,
'mrr_list': 0.019305111840367317},
{'hits@1_list': 0.0032599486876279116,
'hits@3_list': 0.00898481160402298,
'hits@10_list': 0.029222892597317696,
'mrr_list': 0.019316496327519417}]
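Since test_logs is a list of metric dictionaries (one per evaluation step), one straightforward visualization, sketched here with matplotlib, is to plot each metric over the evaluation steps:
``` python
import matplotlib.pyplot as plt

steps = range(1, len(test_logs) + 1)
for metric in ("mrr_list", "hits@1_list", "hits@3_list", "hits@10_list"):
    plt.plot(steps, [log[metric] for log in test_logs], marker="o", label=metric)
plt.xlabel("evaluation step")
plt.ylabel("metric value")
plt.legend()
plt.title("Validation link-prediction metrics per evaluation step")
plt.show()
```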
%% Cell type:code id:a81335f2-7cf8-4b3c-8cf6-2e255234bfed tags:
``` python
training_logs
```
%%%% Output: execute_result
[{'positive_sample_loss': 5.781314849853516,
'negative_sample_loss': 0.14729920029640198,
'loss': 2.9643070697784424},
{'positive_sample_loss': 5.593881607055664,
'negative_sample_loss': 0.04241343215107918,
'loss': 2.8181474208831787},
{'positive_sample_loss': 5.6944122314453125,
'negative_sample_loss': 0.059163790196180344,
'loss': 2.8767879009246826},
{'positive_sample_loss': 5.279773235321045,
'negative_sample_loss': 0.09559981524944305,
'loss': 2.6876864433288574},
{'positive_sample_loss': 5.365677356719971,
'negative_sample_loss': 0.18596382439136505,
'loss': 2.77582049369812},
{'positive_sample_loss': 5.616865634918213,
'negative_sample_loss': 0.24613013863563538,
'loss': 2.931497812271118},
{'positive_sample_loss': 5.645710468292236,
'negative_sample_loss': 0.15487168729305267,
'loss': 2.9002909660339355},
{'positive_sample_loss': 5.437168598175049,
'negative_sample_loss': 0.08542022109031677,
'loss': 2.761294364929199},
{'positive_sample_loss': 5.42050838470459,
'negative_sample_loss': 0.14526036381721497,
'loss': 2.782884359359741},
{'positive_sample_loss': 5.240655899047852,
'negative_sample_loss': 0.22022339701652527,
'loss': 2.7304396629333496}]
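Similarly, the per-step losses collected in training_logs can be plotted to check that optimization is making progress (a sketch along the same lines):
``` python
import matplotlib.pyplot as plt

steps = range(1, len(training_logs) + 1)
for key in ("positive_sample_loss", "negative_sample_loss", "loss"):
    plt.plot(steps, [log[key] for log in training_logs], marker="o", label=key)
plt.xlabel("training step")
plt.ylabel("loss")
plt.legend()
plt.title("KGE training losses per step")
plt.show()
```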
%% Cell type:code id:1f9cb02c-fb60-430b-bfd0-de0a70eb75f8 tags:
``` python
```
%% Cell type:code id:8b199d5a-09ac-4284-b0e7-8fd0894b4f11 tags:
``` python
```
%% Cell type:code id:bd925e7e-7014-4183-87bf-93fa154d6b64 tags:
``` python
```
%% Cell type:code id:23bb0d87-746d-4731-b3b5-85159ec45651 tags:
``` python
```
File moved
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from redkg.kge import KGEModel
-from utils import pickle_load
+from redkg.utils import pickle_load
 class GraphConvolution(nn.Module):
@@ -6,7 +6,7 @@ import torch.optim as optim
 from env import Config, Simulator
 from model import GCN_GRU, Net
-def train_kge_model(kge_model, train_pars, test_params, info, train_triples, valid_triples, test_triples, max_steps = 10):
+def train_kge_model(kge_model, train_pars, info, train_triples, valid_triples, test_triples, max_steps = 10):
     print('Training...')
     optimizer = torch.optim.Adam(
         filter(lambda p: p.requires_grad, kge_model.parameters()),
@@ -44,8 +44,9 @@ def train_kge_model(kge_model, train_pars, test_params, info, train_triples, val
         log = kge_model.train_step(kge_model, optimizer, train_iterator, train_pars)
         training_logs.append(log)
-        metrics = kge_model.test_step(kge_model, valid_triples, test_params, info['entity_dict'])
-        test_logs.append(metrics)
+        if train_pars.do_test:
+            metrics = kge_model.test_step(kge_model, valid_triples, train_pars, info['entity_dict'])
+            test_logs.append(metrics)
     return training_logs, test_logs