Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
ITMO-NSS-team
GOLEM
Commits
0d3ba3fb
Commit
0d3ba3fb
authored
1 year ago
by
Lyubov Yamshchikova
Browse files
Options
Download
Email Patches
Plain Diff
Experiment code
parent
c6696458
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
examples/molecule_search/experiment.py
+28
-10
examples/molecule_search/experiment.py
with
28 additions
and
10 deletions
+28
-10
examples/molecule_search/experiment.py
View file @
0d3ba3fb
...
...
@@ -9,6 +9,8 @@ from rdkit.Chem.rdchem import BondType
from
examples.molecule_search.mol_adapter
import
MolAdapter
from
examples.molecule_search.mol_advisor
import
MolChangeAdvisor
from
examples.molecule_search.mol_encoders
import
ECFP
,
RDKF
,
atom_pair
,
topological_torsion
,
mol_descriptors
,
Mol2Vec
,
\
MoleculeTransformer
from
examples.molecule_search.mol_graph
import
MolGraph
from
examples.molecule_search.mol_graph_parameters
import
MolGraphRequirements
from
examples.molecule_search.mol_mutations
import
CHEMICAL_MUTATIONS
...
...
@@ -45,9 +47,18 @@ def get_all_mol_metrics() -> Dict[str, Callable]:
'norm_log_p'
:
normalized_logp
}
return
metrics
encoders
=
{
'ECFP'
:
ECFP
,
'RDKF'
:
RDKF
,
'atom_pair'
:
atom_pair
,
'topological_torsion'
:
topological_torsion
,
'mol_descriptors'
:
mol_descriptors
,
'mol2vec'
:
Mol2Vec
(),
'transformer'
:
MoleculeTransformer
()}
def
molecule_search_setup
(
optimizer_cls
:
Type
[
GraphOptimizer
]
=
EvoGraphOptimizer
,
adaptive_kind
:
MutationAgentTypeEnum
=
MutationAgentTypeEnum
.
random
,
encoder
:
Optional
[
Callable
]
=
None
,
max_heavy_atoms
:
int
=
50
,
atom_types
:
Optional
[
List
[
str
]]
=
None
,
bond_types
:
Sequence
[
BondType
]
=
(
BondType
.
SINGLE
,
BondType
.
DOUBLE
,
BondType
.
TRIPLE
),
...
...
@@ -78,6 +89,7 @@ def molecule_search_setup(optimizer_cls: Type[GraphOptimizer] = EvoGraphOptimize
mutation_types
=
CHEMICAL_MUTATIONS
,
crossover_types
=
[
CrossoverTypesEnum
.
none
],
adaptive_mutation_type
=
adaptive_kind
,
context_agent_type
=
encoders
[
encoder
]
)
graph_gen_params
=
GraphGenerationParams
(
adapter
=
MolAdapter
(),
...
...
@@ -145,6 +157,7 @@ def pretrain_agent(optimizer: EvoGraphOptimizer, objective: Objective, results_d
def
run_experiment
(
optimizer_setup
:
Callable
,
optimizer_cls
:
Type
[
GraphOptimizer
]
=
EvoGraphOptimizer
,
adaptive_kind
:
MutationAgentTypeEnum
=
MutationAgentTypeEnum
.
random
,
encoder
:
Optional
[
Callable
]
=
None
,
max_heavy_atoms
:
int
=
50
,
atom_types
:
Optional
[
List
[
str
]]
=
None
,
bond_types
:
Sequence
[
BondType
]
=
(
BondType
.
SINGLE
,
BondType
.
DOUBLE
,
BondType
.
TRIPLE
),
...
...
@@ -161,7 +174,7 @@ def run_experiment(optimizer_setup: Callable,
metrics
=
metrics
or
[
'qed_score'
]
optimizer_id
=
optimizer_cls
.
__name__
.
lower
()[:
3
]
experiment_id
=
f
'Experiment [optimizer=
{
optimizer_id
}
metrics=
{
", "
.
join
(
metrics
)
}
pop_size=
{
pop_size
}
]'
exp_name
=
f
'
{
optimizer_id
}
_
{
adaptive_kind
.
value
}
_popsize
{
pop_size
}
_
min
{
trial_
timeout
}
_
{
"_"
.
join
(
metrics
)
}
'
exp_name
=
f
'
{
adaptive_kind
.
value
}
_
{
encoder
}
_
popsize
{
pop_size
}
_
iter
{
trial_
iterations
}
_
{
"_"
.
join
(
metrics
)
}
'
atom_types
=
atom_types
or
[
'C'
,
'N'
,
'O'
,
'F'
,
'P'
,
'S'
,
'Cl'
,
'Br'
]
trial_results
=
[]
...
...
@@ -171,6 +184,7 @@ def run_experiment(optimizer_setup: Callable,
for
trial
in
range
(
num_trials
):
optimizer
,
objective
=
optimizer_setup
(
optimizer_cls
,
adaptive_kind
,
encoder
,
max_heavy_atoms
,
atom_types
,
bond_types
,
...
...
@@ -224,12 +238,16 @@ def plot_experiment_comparison(experiment_ids: Sequence[str], metric_id: int = 0
if
__name__
==
'__main__'
:
run_experiment
(
molecule_search_setup
,
adaptive_kind
=
MutationAgentTypeEnum
.
bandit
,
max_heavy_atoms
=
38
,
trial_timeout
=
6
,
pop_size
=
50
,
visualize
=
True
,
num_trials
=
5
,
pretrain_dir
=
os
.
path
.
join
(
project_root
(),
'examples'
,
'molecule_search'
,
'histories'
)
)
metrics
=
[
'qed_score'
,
'cl_score'
,
'penalized_logp'
]
encoders
=
encoders
.
keys
()
for
metric
in
metrics
:
for
encoder
in
encoders
:
run_experiment
(
molecule_search_setup
,
adaptive_kind
=
MutationAgentTypeEnum
.
contextual_bandit
,
encoder
=
encoder
,
max_heavy_atoms
=
38
,
trial_iterations
=
150
,
pop_size
=
50
,
visualize
=
False
,
num_trials
=
20
)
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment