Commit 0d3ba3fb authored by Lyubov Yamshchikova's avatar Lyubov Yamshchikova
Browse files

Experiment code

parent c6696458
Showing with 28 additions and 10 deletions
+28 -10
......@@ -9,6 +9,8 @@ from rdkit.Chem.rdchem import BondType
from examples.molecule_search.mol_adapter import MolAdapter
from examples.molecule_search.mol_advisor import MolChangeAdvisor
from examples.molecule_search.mol_encoders import ECFP, RDKF, atom_pair, topological_torsion, mol_descriptors, Mol2Vec, \
MoleculeTransformer
from examples.molecule_search.mol_graph import MolGraph
from examples.molecule_search.mol_graph_parameters import MolGraphRequirements
from examples.molecule_search.mol_mutations import CHEMICAL_MUTATIONS
......@@ -45,9 +47,18 @@ def get_all_mol_metrics() -> Dict[str, Callable]:
'norm_log_p': normalized_logp}
return metrics
encoders = {'ECFP': ECFP,
'RDKF': RDKF,
'atom_pair': atom_pair,
'topological_torsion': topological_torsion,
'mol_descriptors': mol_descriptors,
'mol2vec': Mol2Vec(),
'transformer': MoleculeTransformer()}
def molecule_search_setup(optimizer_cls: Type[GraphOptimizer] = EvoGraphOptimizer,
adaptive_kind: MutationAgentTypeEnum = MutationAgentTypeEnum.random,
encoder: Optional[Callable] = None,
max_heavy_atoms: int = 50,
atom_types: Optional[List[str]] = None,
bond_types: Sequence[BondType] = (BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE),
......@@ -78,6 +89,7 @@ def molecule_search_setup(optimizer_cls: Type[GraphOptimizer] = EvoGraphOptimize
mutation_types=CHEMICAL_MUTATIONS,
crossover_types=[CrossoverTypesEnum.none],
adaptive_mutation_type=adaptive_kind,
context_agent_type=encoders[encoder]
)
graph_gen_params = GraphGenerationParams(
adapter=MolAdapter(),
......@@ -145,6 +157,7 @@ def pretrain_agent(optimizer: EvoGraphOptimizer, objective: Objective, results_d
def run_experiment(optimizer_setup: Callable,
optimizer_cls: Type[GraphOptimizer] = EvoGraphOptimizer,
adaptive_kind: MutationAgentTypeEnum = MutationAgentTypeEnum.random,
encoder: Optional[Callable] = None,
max_heavy_atoms: int = 50,
atom_types: Optional[List[str]] = None,
bond_types: Sequence[BondType] = (BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE),
......@@ -161,7 +174,7 @@ def run_experiment(optimizer_setup: Callable,
metrics = metrics or ['qed_score']
optimizer_id = optimizer_cls.__name__.lower()[:3]
experiment_id = f'Experiment [optimizer={optimizer_id} metrics={", ".join(metrics)} pop_size={pop_size}]'
exp_name = f'{optimizer_id}_{adaptive_kind.value}_popsize{pop_size}_min{trial_timeout}_{"_".join(metrics)}'
exp_name = f'{adaptive_kind.value}_{encoder}_popsize{pop_size}_iter{trial_iterations}_{"_".join(metrics)}'
atom_types = atom_types or ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br']
trial_results = []
......@@ -171,6 +184,7 @@ def run_experiment(optimizer_setup: Callable,
for trial in range(num_trials):
optimizer, objective = optimizer_setup(optimizer_cls,
adaptive_kind,
encoder,
max_heavy_atoms,
atom_types,
bond_types,
......@@ -224,12 +238,16 @@ def plot_experiment_comparison(experiment_ids: Sequence[str], metric_id: int = 0
if __name__ == '__main__':
run_experiment(molecule_search_setup,
adaptive_kind=MutationAgentTypeEnum.bandit,
max_heavy_atoms=38,
trial_timeout=6,
pop_size=50,
visualize=True,
num_trials=5,
pretrain_dir=os.path.join(project_root(), 'examples', 'molecule_search', 'histories')
)
metrics = ['qed_score', 'cl_score', 'penalized_logp']
encoders = encoders.keys()
for metric in metrics:
for encoder in encoders:
run_experiment(molecule_search_setup,
adaptive_kind=MutationAgentTypeEnum.contextual_bandit,
encoder=encoder,
max_heavy_atoms=38,
trial_iterations=150,
pop_size=50,
visualize=False,
num_trials=20
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment