Commit 69c6eb1a authored by Mikhail Chaikovskii

fix merge

Showing with 31 additions and 9 deletions
@@ -58,7 +58,7 @@ control_optimizer = ControlOptimizer(cfg)
 # %% Init mcts parameters
 # Hyperparameters for MCTS
-iteration_limit = 20
+iteration_limit = 50
 # Initialize MCTS
 searcher = mcts.mcts(iterationLimit=iteration_limit)
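For context on what this limit controls: each search call in the `mcts` package runs `iterationLimit` select/expand/rollout/backpropagate iterations before committing to an action, so raising it from 20 to 50 buys deeper exploration of the rule space per step at proportional runtime cost. A minimal sketch of the surrounding usage; the driver loop and the `env` object are assumptions, not part of this commit:

import mcts

# One search() call = `iterationLimit` MCTS iterations from the current root state.
searcher = mcts.mcts(iterationLimit=50)

# Hypothetical driver loop: the state object must expose the package's
# interface (getPossibleActions, takeAction, isTerminal, getReward, ...),
# as the GraphVocabularyEnvironment edited below does.
# while not env.isTerminal():
#     action = searcher.search(initialState=env)
#     env = env.takeAction(action)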
@@ -185,8 +185,8 @@ def init_extension_rules():
     rule_vocab.create_rule("InitMechanism_3_R_A", ["ROOT"], ["F", "SML", "SMRPA", "SMRMA", "EM", "EM", "EM"], 0, 0, [(0,1),(0,2),(0,3),(1,4),(2,5),(3,6)])
     rule_vocab.create_rule("InitMechanism_3_L", ["ROOT"], ["F", "SMLP", "SMLM", "SMR", "EM", "EM", "EM"], 0, 0, [(0,1),(0,2),(0,3),(1,4),(2,5),(3,6)])
     rule_vocab.create_rule("InitMechanism_3_L_A", ["ROOT"], ["F", "SMLPA", "SMLMA", "SMR", "EM", "EM", "EM"], 0, 0, [(0,1),(0,2),(0,3),(1,4),(2,5),(3,6)])
-    # rule_vocab.create_rule("InitMechanism_4", ["ROOT"], ["F", "SMLP", "SMLM", "SMRP", "SMRM", "EM", "EM", "EM", "EM"], 0, 0, [(0,1),(0,2),(0,3),(0,4),(1,5),(2,6),(3,7),(4,8)])
-    # rule_vocab.create_rule("InitMechanism_4_A", ["ROOT"], ["F", "SMLPA", "SMLMA", "SMRPA", "SMRMA", "EM", "EM", "EM", "EM"], 0, 0, [(0,1),(0,2),(0,3),(0,4),(1,5),(2,6),(3,7),(4,8)])
+    rule_vocab.create_rule("InitMechanism_4", ["ROOT"], ["F", "SMLP", "SMLM", "SMRP", "SMRM", "EM", "EM", "EM", "EM"], 0, 0, [(0,1),(0,2),(0,3),(0,4),(1,5),(2,6),(3,7),(4,8)])
+    rule_vocab.create_rule("InitMechanism_4_A", ["ROOT"], ["F", "SMLPA", "SMLMA", "SMRPA", "SMRMA", "EM", "EM", "EM", "EM"], 0, 0, [(0,1),(0,2),(0,3),(0,4),(1,5),(2,6),(3,7),(4,8)])
     rule_vocab.create_rule("FingerUpper", ["EM"], ["J", "L", "EM"], 0, 2, [(0,1),(1,2)])
     rule_vocab.create_rule("TerminalFlat1", ["F"], ["F1"], 0, 0)
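Uncommenting InitMechanism_4 and InitMechanism_4_A puts four-finger starting mechanisms back into the rule vocabulary. To visualize what such a rule instantiates, here is a small illustrative sketch that treats the final argument as parent/child index pairs into the label list (an assumption based on the three-finger rules above) and builds the right-hand-side graph with networkx:

import networkx as nx

# Labels and edges copied from the re-enabled InitMechanism_4 rule.
labels = ["F", "SMLP", "SMLM", "SMRP", "SMRM", "EM", "EM", "EM", "EM"]
edges = [(0, 1), (0, 2), (0, 3), (0, 4), (1, 5), (2, 6), (3, 7), (4, 8)]

rhs = nx.DiGraph()
rhs.add_nodes_from((i, {"label": lbl}) for i, lbl in enumerate(labels))
rhs.add_edges_from(edges)

# Prints the flat base F fanning out to four side mounts,
# each terminated by an extendable EM node.
for node in rhs.nodes:
    print(labels[node], "->", [labels[child] for child in rhs.successors(node)])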
@@ -214,16 +214,28 @@ class GraphVocabularyEnvironment(GraphEnvironment):
         return list(possible_actions)
 
     def getReward(self):
+        if len(seen_graphs) > 0:
+            seen_graphs_t = list(zip(*seen_graphs))
+            i = 0
+            for graph in seen_graphs_t[0]:
+                if graph == self.graph:
+                    self.reward = seen_graphs_t[1][i]
+                    self.movments_trajectory = seen_graphs_t[2][i]
+                    reporter.add_reward(self.state, self.reward, self.movments_trajectory)
+                    print('seen reward:', self.reward)
+                    return self.reward
+                i += 1
         result_optimizer = self.optimizer.start_optimisation(self.graph)
         self.reward = -result_optimizer[0]
         self.movments_trajectory = result_optimizer[1]
+        seen_graphs.append([deepcopy(self.graph), self.reward, deepcopy(self.movments_trajectory)])
         reporter.add_reward(self.state, self.reward, self.movments_trajectory)
         if self.reward > reporter.best_reward:
             reporter.best_reward = self.reward
             reporter.best_control = self.movments_trajectory
             reporter.best_state = self.state
         print(self.reward)
         return self.reward
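The getReward change memoizes the expensive control optimization: before calling start_optimisation, it linearly scans seen_graphs for a structurally equal graph (equality is redefined in the next hunk) and replays the cached reward and trajectory. The scan is a list walk rather than a dict lookup, since defining __eq__ without __hash__ leaves GraphGrammar unhashable. A condensed sketch of the same pattern, with hypothetical names:

from copy import deepcopy

seen_graphs = []  # entries: [graph, reward, trajectory]

def cached_reward(graph, optimize):
    # Linear scan relying on graph.__eq__; unhashable graphs cannot key a dict.
    for cached_graph, reward, trajectory in seen_graphs:
        if cached_graph == graph:
            return reward, trajectory
    reward, trajectory = optimize(graph)
    seen_graphs.append([deepcopy(graph), reward, deepcopy(trajectory)])
    return reward, trajectory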
@@ -323,9 +323,19 @@ class GraphGrammar(nx.DiGraph):
         return list(dfs_preorder_nodes(self, self.get_root_id()))
 
-    def __eq__(self, __o) -> bool:
-        if isinstance(__o, GraphGrammar):
-            is_node_eq = __o.nodes == self.nodes
-            is_edge_eq = __o.edges == self.edges
-            return is_edge_eq and is_node_eq
+    def __eq__(self, __rhs) -> bool:
+        if isinstance(__rhs, GraphGrammar):
+            self_dfs_paths = self.graph_partition_dfs()
+            self_dfs_paths_lbl = []
+            for path in self_dfs_paths:
+                self_dfs_paths_lbl.append([self.get_node_by_id(x).label for x in path])
+            self_dfs_paths_lbl.sort(key=lambda x: "".join(x))
+            rhs_dfs_paths = __rhs.graph_partition_dfs()
+            rhs_dfs_paths_lbl = []
+            for path in rhs_dfs_paths:
+                rhs_dfs_paths_lbl.append([__rhs.get_node_by_id(x).label for x in path])
+            rhs_dfs_paths_lbl.sort(key=lambda x: "".join(x))
+            return self_dfs_paths_lbl == rhs_dfs_paths_lbl
+        return False
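With this change, two grammars compare equal when their root-partitioned DFS label paths match, independent of the internal node ids that made the old node/edge comparison reject structurally identical graphs built in a different order. A toy illustration of the comparison core, using made-up label paths in place of graph_partition_dfs output:

# Two path sets with identical contents in different orders.
lhs_paths = [["F1", "J", "L"], ["F1", "J", "L", "J", "L"]]
rhs_paths = [["F1", "J", "L", "J", "L"], ["F1", "J", "L"]]

# Sorting on the joined label string, as both sides of __eq__ do,
# makes the final list comparison order-independent.
key = lambda path: "".join(path)
assert sorted(lhs_paths, key=key) == sorted(rhs_paths, key=key)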