diff --git a/examples/molecule_search/experiment.py b/examples/molecule_search/experiment.py index 79cea2bbe9e8fab48111a07e1b6ffb2e6309935c..52b94715b7d4b99021ee4b0f197606a3055e2087 100644 --- a/examples/molecule_search/experiment.py +++ b/examples/molecule_search/experiment.py @@ -9,6 +9,8 @@ from rdkit.Chem.rdchem import BondType from examples.molecule_search.mol_adapter import MolAdapter from examples.molecule_search.mol_advisor import MolChangeAdvisor +from examples.molecule_search.mol_encoders import ECFP, RDKF, atom_pair, topological_torsion, mol_descriptors, Mol2Vec, \ + MoleculeTransformer from examples.molecule_search.mol_graph import MolGraph from examples.molecule_search.mol_graph_parameters import MolGraphRequirements from examples.molecule_search.mol_mutations import CHEMICAL_MUTATIONS @@ -45,9 +47,18 @@ def get_all_mol_metrics() -> Dict[str, Callable]: 'norm_log_p': normalized_logp} return metrics +encoders = {'ECFP': ECFP, + 'RDKF': RDKF, + 'atom_pair': atom_pair, + 'topological_torsion': topological_torsion, + 'mol_descriptors': mol_descriptors, + 'mol2vec': Mol2Vec(), + 'transformer': MoleculeTransformer()} + def molecule_search_setup(optimizer_cls: Type[GraphOptimizer] = EvoGraphOptimizer, adaptive_kind: MutationAgentTypeEnum = MutationAgentTypeEnum.random, + encoder: Optional[Callable] = None, max_heavy_atoms: int = 50, atom_types: Optional[List[str]] = None, bond_types: Sequence[BondType] = (BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE), @@ -78,6 +89,7 @@ def molecule_search_setup(optimizer_cls: Type[GraphOptimizer] = EvoGraphOptimize mutation_types=CHEMICAL_MUTATIONS, crossover_types=[CrossoverTypesEnum.none], adaptive_mutation_type=adaptive_kind, + context_agent_type=encoders[encoder] ) graph_gen_params = GraphGenerationParams( adapter=MolAdapter(), @@ -145,6 +157,7 @@ def pretrain_agent(optimizer: EvoGraphOptimizer, objective: Objective, results_d def run_experiment(optimizer_setup: Callable, optimizer_cls: Type[GraphOptimizer] = EvoGraphOptimizer, adaptive_kind: MutationAgentTypeEnum = MutationAgentTypeEnum.random, + encoder: Optional[Callable] = None, max_heavy_atoms: int = 50, atom_types: Optional[List[str]] = None, bond_types: Sequence[BondType] = (BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE), @@ -161,7 +174,7 @@ def run_experiment(optimizer_setup: Callable, metrics = metrics or ['qed_score'] optimizer_id = optimizer_cls.__name__.lower()[:3] experiment_id = f'Experiment [optimizer={optimizer_id} metrics={", ".join(metrics)} pop_size={pop_size}]' - exp_name = f'{optimizer_id}_{adaptive_kind.value}_popsize{pop_size}_min{trial_timeout}_{"_".join(metrics)}' + exp_name = f'{adaptive_kind.value}_{encoder}_popsize{pop_size}_iter{trial_iterations}_{"_".join(metrics)}' atom_types = atom_types or ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br'] trial_results = [] @@ -171,6 +184,7 @@ def run_experiment(optimizer_setup: Callable, for trial in range(num_trials): optimizer, objective = optimizer_setup(optimizer_cls, adaptive_kind, + encoder, max_heavy_atoms, atom_types, bond_types, @@ -224,12 +238,16 @@ def plot_experiment_comparison(experiment_ids: Sequence[str], metric_id: int = 0 if __name__ == '__main__': - run_experiment(molecule_search_setup, - adaptive_kind=MutationAgentTypeEnum.bandit, - max_heavy_atoms=38, - trial_timeout=6, - pop_size=50, - visualize=True, - num_trials=5, - pretrain_dir=os.path.join(project_root(), 'examples', 'molecule_search', 'histories') - ) + metrics = ['qed_score', 'cl_score', 'penalized_logp'] + encoders = encoders.keys() + for metric in metrics: + for encoder in encoders: + run_experiment(molecule_search_setup, + adaptive_kind=MutationAgentTypeEnum.contextual_bandit, + encoder=encoder, + max_heavy_atoms=38, + trial_iterations=150, + pop_size=50, + visualize=False, + num_trials=20 + )