Commit 71ddd163 authored by K Z

Merge branch 'article' of https://github.com/aimclub/rostok into article

Showing with 213 additions and 101 deletions
@@ -4,10 +4,13 @@ import pickle
import time
import numpy as np
import torch
from rostok.graph_generators.graph_game import GraphGrammarGame as Game
from rostok.neural_network.wrappers import AlphaZeroWrapper
from MCTS import MCTS
from Coach import Coach
from obj_grasp.objects import get_obj_easy_box
@@ -17,15 +20,34 @@ from utils import dotdict
import hyperparameters as hp
import optmizers_config
from rule_sets import rule_extention, rule_extention_graph
from rule_sets.ruleset_old_style_graph import create_rules
log = logging.getLogger(__name__)
CURRENT_PLAYER = 1
mcts_args = dotdict({
"numMCTSSims" : 1000,
"cpuct" : 1/np.sqrt(2),
"epochs": 10
coach_args = dotdict({
"numMCTSSims" : 10,
"cpuct" : 5,
"train_offline_epochs": 50,
"train_online_epochs":10,
"num_learn_epochs": 100,
"tempThreshold":4,
"maxlenOfQueue":20000,
"numItersForTrainExamplesHistory":20,
"offline_iters":5,
"online_iters":10,
"update_weights":2,
"checkpoint": "./temp/"
})
args_train = dotdict({
"batch_size": 2,
"cuda": torch.cuda.is_available(),
"nhid": 512,
"pooling_ratio": 0.7,
"dropout_ratio": 0.3
})
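Both configuration blocks above rely on the dotdict helper imported from utils. As a point of reference, a minimal sketch of the assumed behaviour (a plain dict whose keys are also readable as attributes) follows; the real helper ships with the alpha-zero-general framework, and this stand-in is illustrative only.

# Minimal sketch of the assumed dotdict helper: a dict with attribute-style access.
# The actual implementation lives in the alpha-zero-general utils module.
class dotdict(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

coach_args_demo = dotdict({"numMCTSSims": 10, "cpuct": 5})
assert coach_args_demo.numMCTSSims == 10 and coach_args_demo["cpuct"] == 5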
def executeEpisode(game, mcts, temp_threshold):
@@ -42,33 +64,38 @@ def executeEpisode(game, mcts, temp_threshold):
action = np.random.choice(len(pi), p=pi)
graph, __ = game.getNextState(graph, CURRENT_PLAYER, action)
trainExamples.append((graph, pi, None))
r = game.getGameEnded(graph, CURRENT_PLAYER)
trainExamples.append([graph, CURRENT_PLAYER, pi, r])
if r != 0:
return tuple(trainExamples)
return [(x[0],x[1], r) for x in trainExamples]
def main():
def preconfigure():
# List of weights for each criterion (force, time, COG)
WEIGHT = hp.CRITERION_WEIGHTS
# At least 20 iterations are needed for good results
rule_vocabul = deepcopy(rule_extention_graph.rule_vocab)
cfg = optmizers_config.get_cfg_graph(rule_extention_graph.torque_dict)
rule_vocabul, torque_dict = create_rules()
cfg = optmizers_config.get_cfg_graph(torque_dict)
cfg.get_rgab_object_callback = get_obj_easy_box
control_optimizer = ControlOptimizer(cfg)
graph_game = Game(rule_vocabul, control_optimizer, hp.MAX_NUMBER_RULES)
return graph_game
examples = [[] for __ in range(mcts_args.epochs)]
for epoch in range(mcts_args.epochs):
def main():
graph_game = preconfigure()
examples = [[] for __ in range(coach_args.epochs)]
for epoch in range(coach_args.epochs):
log.info('Loading %s ...', Game.__name__)
log.info('Loading %s ...', AlphaZeroWrapper.__name__)
nnet = AlphaZeroWrapper(graph_game)
mcts_searcher = MCTS(graph_game, nnet, mcts_args)
mcts_searcher = MCTS(graph_game, nnet, coach_args)
start = time.time()
examples[epoch] = executeEpisode(graph_game, mcts_searcher, 15)
@@ -76,21 +103,39 @@ def main():
print(f"epoch: {epoch:3}, time: {ex}")
struct = time.localtime(time.time())
str_time = time.strftime('%d%m%Y%H%M', struct)
name_file = f"train_mcts_data_{mcts_args.epochs}e_{mcts_args.cpuct}c_{mcts_args}mcts_{hp.MAX_NUMBER_RULES}rule_{str_time}t.pickle"
name_file = f"train_mcts_data_{coach_args.epochs}e_{coach_args.cpuct:f}c_{coach_args.numMCTSSims}mcts_{hp.MAX_NUMBER_RULES}rule_{str_time}t.pickle"
with open(name_file, "wb+") as file:
pickle.dump(examples, file)
def load_train(path_to_data):
def load_train_data(path_to_data):
with open(path_to_data, "rb") as input_file:
train = pickle.load(input_file)
print(train)
formatting_train_data = [[]+list(example) for episode in train for example in episode]
formatting_train_data = list(map(lambda x: (x[0], x[2], x[3]), formatting_train_data))
return train
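load_train_data builds formatting_train_data, flattening the pickled per-epoch episodes into flat (graph, pi, reward) tuples, although the function as written returns the raw train list; downstream, pretrain_model passes the result straight to nnet.train, which expects flat examples, so the flattened form appears to be the intent. A sketch of that flattening, assuming the [graph, player, pi, reward] entry layout implied by the mapping above (flatten_episodes and its argument are illustrative names):

def flatten_episodes(episodes):
    # episodes: list of per-epoch episode lists; each entry is assumed to be
    # [graph, player, pi, reward], matching the (x[0], x[2], x[3]) mapping above.
    flat = []
    for episode in episodes:
        for graph, _player, pi, reward in episode:
            flat.append((graph, pi, reward))
    return flat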
def pretrain_model(path_to_data, list_name_data):
for name_data in list_name_data:
train_examples = load_train_data(path_to_data+name_data+".pickle")
graph_game = preconfigure()
log.info('Loading %s ...', AlphaZeroWrapper.__name__)
nnet = AlphaZeroWrapper(graph_game, args_train)
nnet.train(train_examples)
nnet.save_checkpoint()
if __name__ == "__main__":
initial_time = time.time()
main()
final_ex = time.time() - initial_time
print(f"full_time :{final_ex}")
# for idx in range(10):
# initial_time = time.time()
# main()
# final_ex = time.time() - initial_time
# print(f"train {idx} index, full_time: {final_ex}")
load_train("train_data_10e_1000mcts_2302.pickle")
\ No newline at end of file
# load_train("train_data_10e_1000mcts_2302.pickle")
game_graph = preconfigure()
graph_nnet = AlphaZeroWrapper(game_graph, args_train)
coacher = Coach(game_graph, graph_nnet, coach_args)
coacher.learn()
\ No newline at end of file
MAX_NUMBER_RULES = 20
MAX_NUMBER_RULES = 6
BASE_ITERATION_LIMIT = 500
@@ -7,9 +7,10 @@ ITERATION_REDUCTION_TIME = 0.7
CRITERION_WEIGHTS = [5, 5, 2, 5]
CONTROL_OPTIMIZATION_ITERATION = 20
TIME_OPTIMIZATION = 100
TIME_STEP_SIMULATION = 0.0025
TIME_SIMULATION = 3
TIME_SIMULATION = 1
FLAG_TIME_NO_CONTACT = 1.5
FLAG_TIME_SLIPOUT = 0.4
\ No newline at end of file
@@ -26,6 +26,7 @@ def get_cfg_standart():
# Init configuration of control optimizing
cfg = ConfigVectorJoints()
cfg.bound = (0, 15)
cfg.time_optimization = hp.TIME_OPTIMIZATION
cfg.iters = hp.CONTROL_OPTIMIZATION_ITERATION
cfg.time_step = hp.TIME_STEP_SIMULATION
cfg.time_sim = hp.TIME_SIMULATION
@@ -47,6 +48,7 @@ def get_cfg_graph(torque_dict: dict[Node, float]):
WEIGHT = hp.CRITERION_WEIGHTS
# Init configuration of control optimizing
cfg = ConfigGraphControl()
cfg.time_optimization = hp.TIME_OPTIMIZATION
cfg.time_step = hp.TIME_STEP_SIMULATION
cfg.time_sim = hp.TIME_SIMULATION
cfg.flags = [FlagMaxTime(cfg.time_sim),
@@ -67,6 +69,7 @@ def get_cfg_standart_anealing():
cfg = ConfigVectorJoints()
cfg.optimizer_scipy = partial(dual_annealing)
cfg.bound = (0, 15)
cfg.time_optimization = hp.TIME_OPTIMIZATION
cfg.iters = hp.CONTROL_OPTIMIZATION_ITERATION
cfg.time_step = hp.TIME_STEP_SIMULATION
cfg.time_sim = hp.TIME_SIMULATION
@@ -89,6 +92,7 @@ def get_cfg_standart_step():
# Init configuration of control optimizing
CONST_TORQUE = 12
cfg = ConfigVectorJoints()
cfg.time_optimization = hp.TIME_OPTIMIZATION
cfg.bound = (0, hp.TIME_SIMULATION / 3)
cfg.iters = hp.CONTROL_OPTIMIZATION_ITERATION
cfg.time_step = hp.TIME_STEP_SIMULATION
......
@@ -86,6 +86,7 @@ node_vocab.add_node(ROOT)
node_vocab.create_node("L")
node_vocab.create_node("F")
node_vocab.create_node("M")
node_vocab.create_node("J")
node_vocab.create_node("EF")
node_vocab.create_node("EM")
node_vocab.create_node("SML")
@@ -231,7 +232,7 @@ rule_vocab.create_rule("InitMechanism_4", ["ROOT"],
rule_vocab.create_rule("InitMechanism_4_A", ["ROOT"],
["F", "SMLPA", "SMLMA", "SMRPA", "SMRMA", "EM", "EM", "EM", "EM"], 0, 0,
[(0, 1), (0, 2), (0, 3), (0, 4), (1, 5), (2, 6), (3, 7), (4, 8)])
rule_vocab.create_rule("FingerUpper", ["EM"], ["J1", "L", "EM"], 0, 2, [(0, 1), (1, 2)])
rule_vocab.create_rule("FingerUpper", ["EM"], ["J", "L", "EM"], 0, 2, [(0, 1), (1, 2)])
rule_vocab.create_rule("TerminalFlat1", ["F"], ["F1"], 0, 0)
rule_vocab.create_rule("TerminalFlat2", ["F"], ["F2"], 0, 0)
......
@@ -26,7 +26,6 @@ class GraphGrammarGame(Game):
self.id2rule: dict[int, str] = {t[0]: t[1] for t in list(enumerate(sorted_name_rule))}
self.rule2id: dict[str, int] = {t[1]: t[0] for t in self.id2rule.items()}
self.max_no_terminal_rules: int = max_no_terminal_rules
self.counter_actions: int = 0
self.control_optimization = control_optimization
self.terminal_graphs: dict[list[list[str]], tuple(float, list[list[float]])] = {}
@@ -43,8 +42,6 @@
rule = self.rule_vocabulary.rule_dict[name_rule]
next_state_graph = deepcopy(graph)
next_state_graph.apply_rule(rule)
if name_rule in self.rule_vocabulary.rules_nonterminal_node_set:
self.counter_actions += 1
return (next_state_graph, -player)
@@ -53,7 +50,7 @@
def getValidMoves(self, graph: GraphGrammar, player: int):
if self.counter_actions < self.max_no_terminal_rules:
if graph.counter_nonterminal_rules < self.max_no_terminal_rules:
possible_rules_name = self.rule_vocabulary.get_list_of_applicable_rules(graph)
else:
possible_rules_name = self.rule_vocabulary.get_list_of_applicable_terminal_rules(graph)
@@ -68,26 +65,23 @@
def getGameEnded(self, graph: GraphGrammar, player):
if len(self.rule_vocabulary.get_list_of_applicable_terminal_rules(graph)) != 0:
return 0
if (len(self.rule_vocabulary.get_list_of_applicable_nonterminal_rules(graph)) and
self.counter_actions < self.max_no_terminal_rules):
return 0
flatten_graph = self.stringRepresentation(graph)
self.counter_actions = 0
if flatten_graph in set(self.terminal_graphs.keys()):
reward, movments_trajectory = self.terminal_graphs[flatten_graph]
terminal_nodes = [node[1]["Node"].is_terminal for node in graph.nodes.items()]
if sum(terminal_nodes) == len(terminal_nodes):
flatten_graph = self.stringRepresentation(graph)
if flatten_graph in set(self.terminal_graphs.keys()):
reward, movments_trajectory = self.terminal_graphs[flatten_graph]
else:
result_optimizer = self.control_optimization.start_optimisation(graph)
reward = -result_optimizer[0]
if reward == 0:
reward = 0.01
movments_trajectory = result_optimizer[1]
self.terminal_graphs[flatten_graph] = (reward, movments_trajectory)
return reward
else:
result_optimizer = self.control_optimization.start_optimisation(graph)
reward = -result_optimizer[0]
if reward == 0:
reward = 0.01
movments_trajectory = result_optimizer[1]
self.terminal_graphs[flatten_graph] = (reward, movments_trajectory)
return reward
return 0
def stringRepresentation(self, graph: GraphGrammar):
flatten_graph, __ = self._converter.flatting_sorted_graph(graph)
......
@@ -109,6 +109,7 @@ class GraphGrammar(nx.DiGraph):
super().__init__(**attr)
self.__uniq_id_counter = -1
self.add_node(self._get_uniq_id(), Node=ROOT)
self.counter_nonterminal_rules = 0
def _get_uniq_id(self):
self.__uniq_id_counter += 1
@@ -208,6 +209,9 @@
return root_id
def apply_rule(self, rule: Rule):
if not rule.is_terminal:
self.counter_nonterminal_rules += 1
ids = self.find_nodes(rule.replaced_node)
edge_list = list(self.edges)
id_closest = self.closest_node_to_root(ids)
......
@@ -13,7 +13,7 @@ class SAGPoolToAlphaZero(torch.nn.Module):
self.num_rules = args.num_rules
self.pooling_ratio = args.pooling_ratio
self.dropout_ratio = args.dropout_ratio
self.conv1 = GCNConv(self.num_features, self.nhid)
self.pool1 = SAGPooling(self.nhid, ratio=self.pooling_ratio, GNN=GCNConv)
self.conv2 = GCNConv(self.nhid, self.nhid)
@@ -25,7 +25,7 @@
self.fc_bn1 = torch.nn.BatchNorm1d(self.nhid)
self.lin2 = torch.nn.Linear(self.nhid, self.nhid//2)
self.fc_bn2 = torch.nn.BatchNorm1d(self.nhid//2)
self.fc_to_pi = torch.nn.Linear(self.nhid//2, self.num_rules)
self.fc_to_v = torch.nn.Linear(self.nhid//2, 1)
@@ -46,12 +46,14 @@
x = x1 + x2 + x3
x = F.relu(self.fc_bn1(self.lin1(x)))
x = F.dropout(x, p=self.dropout_ratio, training=self.training)
x = F.relu(self.fc_bn2(self.lin2(x)))
x = F.dropout(x, p=self.dropout_ratio, training=self.training)
x = F.relu(self.lin1(x))
# x = self.fc_bn1(x)
# x = F.dropout(x, p=self.dropout_ratio, training=self.training)
x = F.relu(self.lin2(x))
# x = self.fc_bn2(x)
# x = F.dropout(x, p=self.dropout_ratio, training=self.training)
pi = F.log_softmax(self.fc_to_pi(x), dim=1)
v = torch.tanh(self.fc_to_v(x))
v = F.relu(self.fc_to_v(x))
return pi, v
\ No newline at end of file
from typing import Union
import networkx as nx
import numpy as np
import torch
from torch_geometric.data import Data
@@ -24,21 +26,44 @@ class ConverterToPytorchGeometric:
return dict_id_label_nodes, dict_label_id_nodes
def flatting_sorted_graph(self, graph: GraphGrammar) -> tuple[list[int], list[list[int]]]:
def flatting_sorted_graph(self, graph: Union[GraphGrammar, tuple]):
if isinstance(graph, GraphGrammar):
sorted_id_nodes = list(
nx.lexicographical_topological_sort(graph, key=lambda x: graph.get_node_by_id(x).label))
sorted_name_nodes = list(map(lambda x: graph.get_node_by_id(x).label, sorted_id_nodes))
sorted_id_nodes = list(
nx.lexicographical_topological_sort(graph, key=lambda x: graph.get_node_by_id(x).label))
sorted_name_nodes = list(map(lambda x: graph.get_node_by_id(x).label, sorted_id_nodes))
id_node2list = {id[1]: id[0] for id in enumerate(sorted_id_nodes)}
id_node2list = {id[1]: id[0] for id in enumerate(sorted_id_nodes)}
list_edges_on_id = list(graph.edges)
list_id_edge_links = list(map(lambda x: [id_node2list[n] for n in x], list_edges_on_id))
if not list_id_edge_links:
list_id_edge_links = [[0, 0]]
list_edges_on_id = list(graph.edges)
list_id_edge_links = list(map(lambda x: [id_node2list[n] for n in x], list_edges_on_id))
return sorted_name_nodes, list_id_edge_links
if not list_id_edge_links:
list_id_edge_links = [[0, 0]]
return sorted_name_nodes, list_id_edge_links
if isinstance(graph, tuple):
list_sorted_name_nodes = []
list_list_id_edge_links = []
for g in graph:
sorted_id_nodes = list(
nx.lexicographical_topological_sort(g, key=lambda x: g.get_node_by_id(x).label))
sorted_name_nodes = list(map(lambda x: g.get_node_by_id(x).label, sorted_id_nodes))
id_node2list = {id[1]: id[0] for id in enumerate(sorted_id_nodes)}
list_edges_on_id = list(g.edges)
list_id_edge_links = list(map(lambda x: [id_node2list[n] for n in x], list_edges_on_id))
if not list_id_edge_links:
list_id_edge_links = [[0, 0]]
list_sorted_name_nodes.append(sorted_name_nodes)
list_list_id_edge_links.append(list_id_edge_links)
return list_sorted_name_nodes, list_list_id_edge_links
def one_hot_encodding(self, label_node: str) -> list[int]:
@@ -48,13 +73,38 @@ class ConverterToPytorchGeometric:
def transform_digraph(self, graph: GraphGrammar):
node_label_list, edge_id_list = self.flatting_sorted_graph(graph)
one_hot_list = list(map(self.one_hot_encodding, node_label_list))
edge_index = torch.t(torch.tensor(edge_id_list, dtype=torch.long))
x = torch.tensor(one_hot_list, dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
return data
\ No newline at end of file
if isinstance(graph, GraphGrammar):
node_label_list, edge_id_list = self.flatting_sorted_graph(graph)
one_hot_list = list(map(self.one_hot_encodding, node_label_list))
edge_index = torch.t(torch.tensor(edge_id_list, dtype=torch.long))
x = torch.tensor(one_hot_list, dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
return data
if isinstance(graph, tuple):
x = []
edge_index = []
batch = []
node_label_list, edge_id_list = self.flatting_sorted_graph(graph)
for id_batch, (n_label_list, e_id_list) in enumerate(zip(node_label_list, edge_id_list)):
one_hot_list = list(map(self.one_hot_encodding, n_label_list))
x += one_hot_list
edge_index += e_id_list
batch += [id_batch for __ in range(len(one_hot_list))]
edge_index = torch.t(torch.tensor(edge_index, dtype=torch.long))
x = torch.tensor(x, dtype=torch.float)
batch = torch.tensor(batch, dtype=torch.long)
data = Data(x=x, edge_index=edge_index, batch=batch)
return data
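The tuple branch above builds one Data object for a whole batch by concatenating node features and assembling a batch vector by hand. When edge lists are concatenated this way, each graph's node indices normally have to be offset by the number of nodes already appended, a bookkeeping step that torch_geometric's Batch.from_data_list performs automatically. A sketch of the same batching with that helper, assuming per-graph Data objects are acceptable to the network (batch_graphs and its arguments are illustrative names):

# Sketch: batch several GraphGrammar graphs via torch_geometric's Batch helper,
# which offsets each graph's edge indices and builds the batch vector itself.
import torch
from torch_geometric.data import Batch, Data

def batch_graphs(converter, graphs):
    data_list = []
    for g in graphs:
        labels, edges = converter.flatting_sorted_graph(g)
        x = torch.tensor([converter.one_hot_encodding(lbl) for lbl in labels], dtype=torch.float)
        edge_index = torch.t(torch.tensor(edges, dtype=torch.long))
        data_list.append(Data(x=x, edge_index=edge_index))
    return Batch.from_data_list(data_list)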
@@ -13,57 +13,49 @@ from NeuralNet import NeuralNet
import torch
import torch.optim as optim
from torch_geometric.loader import DataLoader
from rostok.neural_network.SAGPool import SAGPoolToAlphaZero as graph_nnet
from rostok.neural_network.converter import ConverterToPytorchGeometric
args_train = dotdict({
"epochs": 10,
"batch_size": 1,
"cuda": torch.cuda.is_available(),
})
class AlphaZeroWrapper(NeuralNet):
def __init__(self, game: GraphGrammarGame):
def __init__(self, game: GraphGrammarGame , args_train):
self.action_size = game.getActionSize()
self.converter = ConverterToPytorchGeometric(game.rule_vocabulary.node_vocab)
args_network = dotdict({"num_features":len(self.converter.label2id),
"num_rules": game.getActionSize(),
"nhid": 64,
"pooling_ratio": 0.3,
"dropout_ratio": 0.3})
self.args_train = args_train
self.args_train["num_features"] =len(self.converter.label2id)
self.args_train["num_rules"] = game.getActionSize()
self.nnet = graph_nnet(args_network)
self.nnet = graph_nnet(self.args_train)
if args_train.cuda:
if self.args_train.cuda:
self.nnet.cuda()
def train(self, examples):
def train(self, examples, epochs):
"""
examples: list of examples, each example is of form (graph, pi, v)
"""
optimizer = optim.Adam(self.nnet.parameters())
for epoch in range(args_train.epochs):
for epoch in range(epochs):
print('EPOCH ::: ' + str(epoch + 1))
self.nnet.train()
pi_losses = AverageMeter()
v_losses = AverageMeter()
batch_count = int(len(examples) / args_train.batch_size)
batch_count = int(len(examples) / self.args_train.batch_size)
t = tqdm(range(batch_count), desc='Training Net')
for _ in t:
sample_ids = np.random.randint(len(examples), size=args_train.batch_size)
sample_ids = np.random.randint(len(examples), size=self.args_train.batch_size)
graph, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
data_graph = self.converter.transform_digraph(graph)
target_pis = torch.FloatTensor(np.array(pis))
target_vs = torch.FloatTensor(np.array(vs).astype(np.float64))
# predict
if args_train.cuda:
if self.args_train.cuda:
data_graph, target_pis, target_vs = data_graph.contiguous().cuda(), target_pis.contiguous().cuda(), target_vs.contiguous().cuda()
# compute output
@@ -91,7 +83,7 @@ class AlphaZeroWrapper(NeuralNet):
# preparing input
data_graph = self.converter.transform_digraph(graph)
if args_train.cuda: data_graph = data_graph.contiguous().cuda()
if self.args_train.cuda: data_graph = data_graph.contiguous().cuda()
self.nnet.eval()
with torch.no_grad():
pi, v = self.nnet(data_graph)
@@ -121,6 +113,6 @@ class AlphaZeroWrapper(NeuralNet):
filepath = os.path.join(folder, filename)
if not os.path.exists(filepath):
raise ("No model in path {}".format(filepath))
map_location = None if args_train.cuda else 'cpu'
map_location = None if self.args_train.cuda else 'cpu'
checkpoint = torch.load(filepath, map_location=map_location)
self.nnet.load_state_dict(checkpoint['state_dict'])
from dataclasses import dataclass, field
from typing import Callable
from typing import Union
import time
import warnings
import pychrono as chrono
from scipy.optimize import direct, shgo, dual_annealing
@@ -8,8 +11,22 @@ from rostok.block_builder.blocks_utils import NodeFeatures
from rostok.graph_grammar.node import GraphGrammar
from rostok.virtual_experiment.robot import Robot
from rostok.virtual_experiment.simulation_step import (SimOut, SimulationStepOptimization)
from typing import Union
import time
class TimeOptimizerStopper(object):
def __init__(self, max_sec=0.3):
self.max_sec = max_sec
self.start = time.time()
def __call__(self, xk=None, convergence=None):
elapsed = time.time() - self.start
if elapsed > self.max_sec:
warnings.warn("Terminating optimization: time limit reached")
return True
else:
# you might want to report other stuff here
print("Elapsed: %.3f sec" % elapsed)
return False
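TimeOptimizerStopper is intended to be handed to a SciPy optimizer as its callback. Its (xk, convergence) signature matches scipy.optimize.differential_evolution, which halts once the callback returns True; whether a truthy return actually stops the search differs between SciPy routines, so this should be checked for the optimizer in use. A minimal usage sketch with differential_evolution and a toy problem (the objective, bounds, and time budget are placeholders):

# Usage sketch: a time-budgeted differential_evolution run that stops early
# once TimeOptimizerStopper (defined above) returns True from the callback.
import numpy as np
from scipy.optimize import differential_evolution

stopper = TimeOptimizerStopper(max_sec=0.5)
result = differential_evolution(lambda x: float(np.sum(x ** 2)),
                                bounds=[(-5.0, 5.0)] * 3,
                                callback=stopper)
print(result.x, result.fun)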
@dataclass
class _ConfigRewardFunction:
@@ -27,6 +44,7 @@ class _ConfigRewardFunction:
sim_config: dict[str, str] = field(default_factory=dict)
time_step: float = 0.001
time_sim: float = 2
time_optimization = 100
flags: list = field(default_factory=list)
criterion_callback: Callable[[SimOut, Robot], float] = None
get_rgab_object_callback: Callable[[], chrono.ChBody] = None
@@ -42,7 +60,7 @@ class ConfigVectorJoints(_ConfigRewardFunction):
"""
bound: tuple[float, float] = (-1, 1)
iters: int = 10
optimizer_scipy = partial(direct)
optimizer_scipy = partial(shgo)
class ConfigGraphControl(_ConfigRewardFunction):
@@ -126,7 +144,8 @@ class ControlOptimizer():
multi_bound = create_multidimensional_bounds(generated_graph, self.cfg.bound)
if len(multi_bound) == 0:
return (0, 0)
result = self.cfg.optimizer_scipy(reward_fun, multi_bound, maxiter=self.cfg.iters)
time_stopper = TimeOptimizerStopper(self.cfg.time_optimization)
result = self.cfg.optimizer_scipy(reward_fun, multi_bound, callback=time_stopper)#,maxiter=self.cfg.iters,)
return (result.fun, result.x)
elif isinstance(self.cfg, ConfigGraphControl):
n_joint = num_joints(generated_graph)
......
File deleted