itmo-sai-code / rostok · Commits

Commit 71ddd163
Authored 2 years ago by K Z

    Merge branch 'article' of https://github.com/aimclub/rostok into article

Parents: b8cceb75, 950d72bf
Branches: article_mor, article

Showing 11 changed files with 213 additions and 101 deletions (+213 / -101):

  article/NN_mcts_train.py                           +66  -21
  article/hyperparameters.py                          +3   -2
  article/optmizers_config.py                         +4   -0
  article/rule_sets/rule_extention_graph.py           +2   -1
  rostok/graph_generators/graph_game.py              +17  -23
  rostok/graph_grammar/node.py                        +4   -0
  rostok/neural_network/SAGPool.py                    +9   -7
  rostok/neural_network/converter.py                 +71  -21
  rostok/neural_network/wrappers.py                  +14  -22
  rostok/trajectory_optimizer/control_optimizer.py   +23   -4
  train_data_10e_1000mcts_2302.pickle                 +0   -0  (deleted)
article/NN_mcts_train.py · View file @ 71ddd163

@@ -4,10 +4,13 @@ import pickle
 import time
 import numpy as np
 import torch
 from rostok.graph_generators.graph_game import GraphGrammarGame as Game
 from rostok.neural_network.wrappers import AlphaZeroWrapper
 from MCTS import MCTS
+from Coach import Coach
 from obj_grasp.objects import get_obj_easy_box
@@ -17,15 +20,34 @@ from utils import dotdict
 import hyperparameters as hp
 import optmizers_config
 from rule_sets import rule_extention, rule_extention_graph
+from rule_sets.ruleset_old_style_graph import create_rules

 log = logging.getLogger(__name__)
 CURRENT_PLAYER = 1

-mcts_args = dotdict({
-    "numMCTSSims": 1000,
-    "cpuct": 1 / np.sqrt(2),
-    "epochs": 10
+coach_args = dotdict({
+    "numMCTSSims": 10,
+    "cpuct": 5,
+    "train_offline_epochs": 50,
+    "train_online_epochs": 10,
+    "num_learn_epochs": 100,
+    "tempThreshold": 4,
+    "maxlenOfQueue": 20000,
+    "numItersForTrainExamplesHistory": 20,
+    "offline_iters": 5,
+    "online_iters": 10,
+    "update_weights": 2,
+    "checkpoint": "./temp/"
 })

+args_train = dotdict({
+    "batch_size": 2,
+    "cuda": torch.cuda.is_available(),
+    "nhid": 512,
+    "pooling_ratio": 0.7,
+    "dropout_ratio": 0.3
+})

 def executeEpisode(game, mcts, temp_threshold):
@@ -42,33 +64,38 @@ def executeEpisode(game, mcts, temp_threshold):
         action = np.random.choice(len(pi), p=pi)
         graph, __ = game.getNextState(graph, CURRENT_PLAYER, action)
-        trainExamples.append((graph, pi, None))
         r = game.getGameEnded(graph, CURRENT_PLAYER)
+        trainExamples.append([graph, CURRENT_PLAYER, pi, r])

         if r != 0:
-            return tuple(trainExamples)
+            return [(x[0], x[1], r) for x in trainExamples]

-def main():
+def preconfigure():
     # List of weights for each criterion (force, time, COG)
     WEIGHT = hp.CRITERION_WEIGHTS
     # At least 20 iterations are needed for good results
-    rule_vocabul = deepcopy(rule_extention_graph.rule_vocab)
-    cfg = optmizers_config.get_cfg_graph(rule_extention_graph.torque_dict)
+    rule_vocabul, torque_dict = create_rules()
+    cfg = optmizers_config.get_cfg_graph(torque_dict)
     cfg.get_rgab_object_callback = get_obj_easy_box
     control_optimizer = ControlOptimizer(cfg)
     graph_game = Game(rule_vocabul, control_optimizer, hp.MAX_NUMBER_RULES)
+    return graph_game

-    examples = [[] for __ in range(mcts_args.epochs)]
-    for epoch in range(mcts_args.epochs):
+def main():
+    graph_game = preconfigure()
+    examples = [[] for __ in range(coach_args.epochs)]
+    for epoch in range(coach_args.epochs):
         log.info('Loading %s ...', Game.__name__)
         log.info('Loading %s ...', AlphaZeroWrapper.__name__)
         nnet = AlphaZeroWrapper(graph_game)
-        mcts_searcher = MCTS(graph_game, nnet, mcts_args)
+        mcts_searcher = MCTS(graph_game, nnet, coach_args)
         start = time.time()
         examples[epoch] = executeEpisode(graph_game, mcts_searcher, 15)
@@ -76,21 +103,39 @@ def main():
         print(f"epoch: {epoch:3}, time: {ex}")

     struct = time.localtime(time.time())
     str_time = time.strftime('%d%m%Y%H%M', struct)
-    name_file = f"train_mcts_data_{mcts_args.epochs}e_{mcts_args.cpuct}c_{mcts_args}mcts_{hp.MAX_NUMBER_RULES}rule_{str_time}t.pickle"
+    name_file = f"train_mcts_data_{coach_args.epochs}e_{coach_args.cpuct:f}c_{coach_args.numMCTSSims}mcts_{hp.MAX_NUMBER_RULES}rule_{str_time}t.pickle"
     with open(name_file, "wb+") as file:
         pickle.dump(examples, file)

-def load_train(path_to_data):
+def load_train_data(path_to_data):
     with open(path_to_data, "rb") as input_file:
         train = pickle.load(input_file)
     print(train)
+    formatting_train_data = [[] + list(example) for episode in train for example in episode]
+    formatting_train_data = list(map(lambda x: (x[0], x[2], x[3]), formatting_train_data))
     return train

+def pretrain_model(path_to_data, list_name_data):
+    for name_data in list_name_data:
+        train_examples = load_train_data(path_to_data + name_data + ".pickle")
+        graph_game = preconfigure()
+        log.info('Loading %s ...', AlphaZeroWrapper.__name__)
+        nnet = AlphaZeroWrapper(graph_game, args_train)
+        nnet.train(train_examples)
+        nnet.save_checkpoint()

 if __name__ == "__main__":
     initial_time = time.time()
     main()
     final_ex = time.time() - initial_time
     print(f"full_time : {final_ex}")
     # for idx in range(10):
     #     initial_time = time.time()
     #     main()
     #     final_ex = time.time() - initial_time
     #     print(f"train {idx} index, full_time: {final_ex}")
-    load_train("train_data_10e_1000mcts_2302.pickle")
+    # load_train("train_data_10e_1000mcts_2302.pickle")
+    game_graph = preconfigure()
+    graph_nnet = AlphaZeroWrapper(game_graph, args_train)
+    coacher = Coach(game_graph, graph_nnet, coach_args)
+    coacher.learn()
\ No newline at end of file
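The new load_train_data flattens the per-episode example lists back into (graph, pi, reward) triples before pretraining. A minimal self-contained sketch of that data layout (toy values, no rostok imports; the layout is an assumption read off executeEpisode and load_train_data above):

    # Assumed layout: the pickle holds a list of episodes, each episode a list of
    # [graph, player, pi, reward] records, as appended by executeEpisode.
    episodes = [
        [["graph_0", 1, [0.7, 0.3], 0.5], ["graph_1", 1, [0.2, 0.8], 0.5]],
        [["graph_2", 1, [0.5, 0.5], -1.0]],
    ]
    flat = [list(example) for episode in episodes for example in episode]
    train_triples = [(x[0], x[2], x[3]) for x in flat]   # (graph, pi, v), the form nnet.train expects
    print(train_triples)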
article/hyperparameters.py · View file @ 71ddd163

-MAX_NUMBER_RULES = 20
+MAX_NUMBER_RULES = 6
 BASE_ITERATION_LIMIT = 500
@@ -7,9 +7,10 @@ ITERATION_REDUCTION_TIME = 0.7
 CRITERION_WEIGHTS = [5, 5, 2, 5]
 CONTROL_OPTIMIZATION_ITERATION = 20
+TIME_OPTIMIZATION = 100
 TIME_STEP_SIMULATION = 0.0025
-TIME_SIMULATION = 3
+TIME_SIMULATION = 1
 FLAG_TIME_NO_CONTACT = 1.5
 FLAG_TIME_SLIPOUT = 0.4
\ No newline at end of file
article/optmizers_config.py · View file @ 71ddd163

@@ -26,6 +26,7 @@ def get_cfg_standart():
     # Init configuration of control optimizing
     cfg = ConfigVectorJoints()
     cfg.bound = (0, 15)
+    cfg.time_optimization = hp.TIME_OPTIMIZATION
     cfg.iters = hp.CONTROL_OPTIMIZATION_ITERATION
     cfg.time_step = hp.TIME_STEP_SIMULATION
     cfg.time_sim = hp.TIME_SIMULATION
@@ -47,6 +48,7 @@ def get_cfg_graph(torque_dict: dict[Node, float]):
     WEIGHT = hp.CRITERION_WEIGHTS
     # Init configuration of control optimizing
     cfg = ConfigGraphControl()
+    cfg.time_optimization = hp.TIME_OPTIMIZATION
     cfg.time_step = hp.TIME_STEP_SIMULATION
     cfg.time_sim = hp.TIME_SIMULATION
     cfg.flags = [FlagMaxTime(cfg.time_sim), ...
@@ -67,6 +69,7 @@ def get_cfg_standart_anealing():
     cfg = ConfigVectorJoints()
     cfg.optimizer_scipy = partial(dual_annealing)
     cfg.bound = (0, 15)
+    cfg.time_optimization = hp.TIME_OPTIMIZATION
     cfg.iters = hp.CONTROL_OPTIMIZATION_ITERATION
     cfg.time_step = hp.TIME_STEP_SIMULATION
     cfg.time_sim = hp.TIME_SIMULATION
@@ -89,6 +92,7 @@ def get_cfg_standart_step():
     # Init configuration of control optimizing
     CONST_TORQUE = 12
     cfg = ConfigVectorJoints()
+    cfg.time_optimization = hp.TIME_OPTIMIZATION
     cfg.bound = (0, hp.TIME_SIMULATION / 3)
     cfg.iters = hp.CONTROL_OPTIMIZATION_ITERATION
     cfg.time_step = hp.TIME_STEP_SIMULATION
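Each get_cfg_* factory builds a config object and copies the hyperparameters above onto it; the new time_optimization field is the wall-clock budget later consumed by the optimizer stopper. A minimal sketch of that factory pattern with a hypothetical stand-in dataclass (DemoConfig and get_cfg_demo are illustrative, not the rostok API):

    from dataclasses import dataclass, field

    @dataclass
    class DemoConfig:
        time_step: float = 0.001
        time_sim: float = 2.0
        time_optimization: float = 100.0
        flags: list = field(default_factory=list)

    def get_cfg_demo(time_sim, time_step, time_optimization):
        cfg = DemoConfig()
        cfg.time_sim = time_sim
        cfg.time_step = time_step
        cfg.time_optimization = time_optimization   # new field: optimizer time limit in seconds
        return cfg

    print(get_cfg_demo(1.0, 0.0025, 100.0))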
article/rule_sets/rule_extention_graph.py · View file @ 71ddd163

@@ -86,6 +86,7 @@ node_vocab.add_node(ROOT)
 node_vocab.create_node("L")
 node_vocab.create_node("F")
 node_vocab.create_node("M")
+node_vocab.create_node("J")
 node_vocab.create_node("EF")
 node_vocab.create_node("EM")
 node_vocab.create_node("SML")
@@ -231,7 +232,7 @@ rule_vocab.create_rule("InitMechanism_4", ["ROOT"], ...
 rule_vocab.create_rule("InitMechanism_4_A", ["ROOT"],
                        ["F", "SMLPA", "SMLMA", "SMRPA", "SMRMA", "EM", "EM", "EM", "EM"], 0, 0,
                        [(0, 1), (0, 2), (0, 3), (0, 4), (1, 5), (2, 6), (3, 7), (4, 8)])
-rule_vocab.create_rule("FingerUpper", ["EM"], ["J1", "L", "EM"], 0, 2, [(0, 1), (1, 2)])
+rule_vocab.create_rule("FingerUpper", ["EM"], ["J", "L", "EM"], 0, 2, [(0, 1), (1, 2)])
 rule_vocab.create_rule("TerminalFlat1", ["F"], ["F1"], 0, 0)
 rule_vocab.create_rule("TerminalFlat2", ["F"], ["F2"], 0, 0)
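Each create_rule call describes replacing one labelled node with a small subgraph: the new node labels, the indices that reconnect incoming and outgoing edges, and the internal edges as index pairs. A toy illustration of that expansion on a plain networkx digraph (apply_demo_rule is a simplified stand-in, not the rostok RuleVocabulary API):

    import networkx as nx

    # What a rule like FingerUpper ("EM" -> ["J", "L", "EM"], in-index 0,
    # out-index 2, edges [(0, 1), (1, 2)]) does to a labelled digraph.
    def apply_demo_rule(graph, target, new_labels, in_idx, out_idx, new_edges):
        node = next(n for n, d in graph.nodes(data=True) if d["label"] == target)
        preds, succs = list(graph.predecessors(node)), list(graph.successors(node))
        graph.remove_node(node)
        base = max(graph.nodes, default=-1) + 1
        ids = [base + i for i in range(len(new_labels))]
        graph.add_nodes_from((i, {"label": lab}) for i, lab in zip(ids, new_labels))
        graph.add_edges_from((ids[a], ids[b]) for a, b in new_edges)
        graph.add_edges_from((p, ids[in_idx]) for p in preds)    # reconnect parents
        graph.add_edges_from((ids[out_idx], s) for s in succs)   # reconnect children

    g = nx.DiGraph()
    g.add_node(0, label="F")
    g.add_node(1, label="EM")
    g.add_edge(0, 1)
    apply_demo_rule(g, "EM", ["J", "L", "EM"], 0, 2, [(0, 1), (1, 2)])
    print([d["label"] for _, d in g.nodes(data=True)])   # ['F', 'J', 'L', 'EM']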
rostok/graph_generators/graph_game.py · View file @ 71ddd163

@@ -26,7 +26,6 @@ class GraphGrammarGame(Game):
         self.id2rule: dict[int, str] = {t[0]: t[1] for t in list(enumerate(sorted_name_rule))}
         self.rule2id: dict[str, int] = {t[1]: t[0] for t in self.id2rule.items()}
         self.max_no_terminal_rules: int = max_no_terminal_rules
-        self.counter_actions: int = 0
         self.control_optimization = control_optimization
         self.terminal_graphs: dict[list[list[str]], tuple(float, list[list[float]])] = {}
@@ -43,8 +42,6 @@ class GraphGrammarGame(Game):
         rule = self.rule_vocabulary.rule_dict[name_rule]
         next_state_graph = deepcopy(graph)
         next_state_graph.apply_rule(rule)
-        if name_rule in self.rule_vocabulary.rules_nonterminal_node_set:
-            self.counter_actions += 1
         return (next_state_graph, -player)
@@ -53,7 +50,7 @@ class GraphGrammarGame(Game):
     def getValidMoves(self, graph: GraphGrammar, player: int):
-        if self.counter_actions < self.max_no_terminal_rules:
+        if graph.counter_nonterminal_rules < self.max_no_terminal_rules:
             possible_rules_name = self.rule_vocabulary.get_list_of_applicable_rules(graph)
         else:
             possible_rules_name = self.rule_vocabulary.get_list_of_applicable_terminal_rules(graph)
@@ -68,26 +65,23 @@ class GraphGrammarGame(Game):
     def getGameEnded(self, graph: GraphGrammar, player):
-        if len(self.rule_vocabulary.get_list_of_applicable_terminal_rules(graph)) != 0:
-            return 0
-        if (len(self.rule_vocabulary.get_list_of_applicable_nonterminal_rules(graph))
-                and self.counter_actions < self.max_no_terminal_rules):
-            return 0
-        flatten_graph = self.stringRepresentation(graph)
-        self.counter_actions = 0
-        if flatten_graph in set(self.terminal_graphs.keys()):
-            reward, movments_trajectory = self.terminal_graphs[flatten_graph]
+        terminal_nodes = [node[1]["Node"].is_terminal for node in graph.nodes.items()]
+        if sum(terminal_nodes) == len(terminal_nodes):
+            flatten_graph = self.stringRepresentation(graph)
+            if flatten_graph in set(self.terminal_graphs.keys()):
+                reward, movments_trajectory = self.terminal_graphs[flatten_graph]
+            else:
+                result_optimizer = self.control_optimization.start_optimisation(graph)
+                reward = -result_optimizer[0]
+                if reward == 0:
+                    reward = 0.01
+                movments_trajectory = result_optimizer[1]
+                self.terminal_graphs[flatten_graph] = (reward, movments_trajectory)
+            return reward
-        else:
-            result_optimizer = self.control_optimization.start_optimisation(graph)
-            reward = -result_optimizer[0]
-            if reward == 0:
-                reward = 0.01
-            movments_trajectory = result_optimizer[1]
-            self.terminal_graphs[flatten_graph] = (reward, movments_trajectory)
-        return reward
+        return 0

     def stringRepresentation(self, graph: GraphGrammar):
         flatten_graph, __ = self._converter.flatting_sorted_graph(graph)
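getGameEnded now evaluates a graph only when every node is terminal and memoises the simulated reward under the graph's string representation, so repeated MCTS visits to the same mechanism never rerun the control optimisation. A self-contained sketch of that caching pattern (evaluate is a hypothetical stand-in for the expensive start_optimisation call, and the key plays the role of stringRepresentation):

    class RewardCache:
        def __init__(self, evaluate):
            self._evaluate = evaluate
            self._cache = {}

        def reward(self, key):
            if key not in self._cache:
                value, trajectory = self._evaluate(key)     # expensive simulation/optimisation
                reward = -value if value != 0 else 0.01     # same sign flip and zero guard as above
                self._cache[key] = (reward, trajectory)
            return self._cache[key][0]

    cache = RewardCache(lambda key: (-len(key), [0.0]))      # toy evaluator
    print(cache.reward("F J L EM"))   # computed once
    print(cache.reward("F J L EM"))   # served from the cache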
rostok/graph_grammar/node.py · View file @ 71ddd163

@@ -109,6 +109,7 @@ class GraphGrammar(nx.DiGraph):
         super().__init__(**attr)
         self.__uniq_id_counter = -1
         self.add_node(self._get_uniq_id(), Node=ROOT)
+        self.counter_nonterminal_rules = 0

     def _get_uniq_id(self):
         self.__uniq_id_counter += 1
@@ -208,6 +209,9 @@ class GraphGrammar(nx.DiGraph):
         return root_id

     def apply_rule(self, rule: Rule):
+        if not rule.is_terminal:
+            self.counter_nonterminal_rules += 1
         ids = self.find_nodes(rule.replaced_node)
         edge_list = list(self.edges)
         id_closest = self.closest_node_to_root(ids)
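Counting nonterminal applications on the graph itself (instead of on the game object, as before) means every deepcopy made while branching the search carries its own count. A toy stand-in showing the effect (DemoGraph is hypothetical, not GraphGrammar):

    from copy import deepcopy

    class DemoGraph:
        def __init__(self):
            self.counter_nonterminal_rules = 0

        def apply_rule(self, is_terminal):
            if not is_terminal:
                self.counter_nonterminal_rules += 1

    g = DemoGraph()
    g.apply_rule(is_terminal=False)
    branch = deepcopy(g)                     # each branch keeps its own counter
    branch.apply_rule(is_terminal=False)
    print(g.counter_nonterminal_rules, branch.counter_nonterminal_rules)   # 1 2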
rostok/neural_network/SAGPool.py · View file @ 71ddd163

@@ -13,7 +13,7 @@ class SAGPoolToAlphaZero(torch.nn.Module):
         self.num_rules = args.num_rules
         self.pooling_ratio = args.pooling_ratio
         self.dropout_ratio = args.dropout_ratio
         self.conv1 = GCNConv(self.num_features, self.nhid)
         self.pool1 = SAGPooling(self.nhid, ratio=self.pooling_ratio, GNN=GCNConv)
         self.conv2 = GCNConv(self.nhid, self.nhid)
@@ -25,7 +25,7 @@ class SAGPoolToAlphaZero(torch.nn.Module):
         self.fc_bn1 = torch.nn.BatchNorm1d(self.nhid)
         self.lin2 = torch.nn.Linear(self.nhid, self.nhid // 2)
         self.fc_bn2 = torch.nn.BatchNorm1d(self.nhid // 2)
         self.fc_to_pi = torch.nn.Linear(self.nhid // 2, self.num_rules)
         self.fc_to_v = torch.nn.Linear(self.nhid // 2, 1)
@@ -46,12 +46,14 @@ class SAGPoolToAlphaZero(torch.nn.Module):
         x = x1 + x2 + x3

-        x = F.relu(self.fc_bn1(self.lin1(x)))
-        x = F.dropout(x, p=self.dropout_ratio, training=self.training)
-        x = F.relu(self.fc_bn2(self.lin2(x)))
-        x = F.dropout(x, p=self.dropout_ratio, training=self.training)
+        x = F.relu(self.lin1(x))
+        # x = self.fc_bn1(x)
+        # x = F.dropout(x, p=self.dropout_ratio, training=self.training)
+        x = F.relu(self.lin2(x))
+        # x = self.fc_bn2(x)
+        # x = F.dropout(x, p=self.dropout_ratio, training=self.training)

         pi = F.log_softmax(self.fc_to_pi(x), dim=1)
-        v = torch.tanh(self.fc_to_v(x))
+        v = F.relu(self.fc_to_v(x))
         return pi, v
\ No newline at end of file
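The forward pass ends in two heads: a log-softmax policy over rules and a scalar value, which this commit passes through relu instead of tanh. A minimal sketch of that readout in plain torch (sizes are illustrative; the real model feeds pooled SAGPool features into these layers):

    import torch
    import torch.nn.functional as F

    nhid, num_rules, batch = 16, 5, 4
    lin_pi = torch.nn.Linear(nhid // 2, num_rules)
    lin_v = torch.nn.Linear(nhid // 2, 1)

    x = torch.randn(batch, nhid // 2)
    pi = F.log_softmax(lin_pi(x), dim=1)   # log-probabilities over rules
    v = F.relu(lin_v(x))                   # non-negative value head (was torch.tanh before this commit)
    print(pi.shape, v.shape)               # torch.Size([4, 5]) torch.Size([4, 1])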
rostok/neural_network/converter.py · View file @ 71ddd163

+from typing import Union
 import networkx as nx
 import numpy as np
 import torch
 from torch_geometric.data import Data
@@ -24,21 +26,44 @@ class ConverterToPytorchGeometric:
         return dict_id_label_nodes, dict_label_id_nodes

-    def flatting_sorted_graph(self, graph: GraphGrammar) -> tuple[list[int], list[list[int]]]:
-        sorted_id_nodes = list(
-            nx.lexicographical_topological_sort(graph, key=lambda x: graph.get_node_by_id(x).label))
-        sorted_name_nodes = list(map(lambda x: graph.get_node_by_id(x).label, sorted_id_nodes))
-        id_node2list = {id[1]: id[0] for id in enumerate(sorted_id_nodes)}
-        list_edges_on_id = list(graph.edges)
-        list_id_edge_links = list(map(lambda x: [id_node2list[n] for n in x], list_edges_on_id))
-        if not list_id_edge_links:
-            list_id_edge_links = [[0, 0]]
-        return sorted_name_nodes, list_id_edge_links
+    def flatting_sorted_graph(self, graph: Union[GraphGrammar, tuple]):
+        if isinstance(graph, GraphGrammar):
+            sorted_id_nodes = list(
+                nx.lexicographical_topological_sort(graph, key=lambda x: graph.get_node_by_id(x).label))
+            sorted_name_nodes = list(map(lambda x: graph.get_node_by_id(x).label, sorted_id_nodes))
+            id_node2list = {id[1]: id[0] for id in enumerate(sorted_id_nodes)}
+            list_edges_on_id = list(graph.edges)
+            list_id_edge_links = list(map(lambda x: [id_node2list[n] for n in x], list_edges_on_id))
+            if not list_id_edge_links:
+                list_id_edge_links = [[0, 0]]
+            return sorted_name_nodes, list_id_edge_links
+        if isinstance(graph, tuple):
+            list_sorted_name_nodes = []
+            list_list_id_edge_links = []
+            for g in graph:
+                sorted_id_nodes = list(
+                    nx.lexicographical_topological_sort(g, key=lambda x: g.get_node_by_id(x).label))
+                sorted_name_nodes = list(map(lambda x: g.get_node_by_id(x).label, sorted_id_nodes))
+                id_node2list = {id[1]: id[0] for id in enumerate(sorted_id_nodes)}
+                list_edges_on_id = list(g.edges)
+                list_id_edge_links = list(map(lambda x: [id_node2list[n] for n in x], list_edges_on_id))
+                if not list_id_edge_links:
+                    list_id_edge_links = [[0, 0]]
+                list_sorted_name_nodes.append(sorted_name_nodes)
+                list_list_id_edge_links.append(list_id_edge_links)
+            return list_sorted_name_nodes, list_list_id_edge_links

     def one_hot_encodding(self, label_node: str) -> list[int]:
@@ -48,13 +73,38 @@ class ConverterToPytorchGeometric:
     def transform_digraph(self, graph: GraphGrammar):
-        node_label_list, edge_id_list = self.flatting_sorted_graph(graph)
-        one_hot_list = list(map(self.one_hot_encodding, node_label_list))
-        edge_index = torch.t(torch.tensor(edge_id_list, dtype=torch.long))
-        x = torch.tensor(one_hot_list, dtype=torch.float)
-        data = Data(x=x, edge_index=edge_index)
-        return data
\ No newline at end of file
+        if isinstance(graph, GraphGrammar):
+            node_label_list, edge_id_list = self.flatting_sorted_graph(graph)
+            one_hot_list = list(map(self.one_hot_encodding, node_label_list))
+            edge_index = torch.t(torch.tensor(edge_id_list, dtype=torch.long))
+            x = torch.tensor(one_hot_list, dtype=torch.float)
+            data = Data(x=x, edge_index=edge_index)
+            return data
+        if isinstance(graph, tuple):
+            x = []
+            edge_index = []
+            batch = []
+            node_label_list, edge_id_list = self.flatting_sorted_graph(graph)
+            for id_batch, (n_label_list, e_id_list) in enumerate(zip(node_label_list, edge_id_list)):
+                one_hot_list = list(map(self.one_hot_encodding, n_label_list))
+                x += one_hot_list
+                edge_index += e_id_list
+                batch += [id_batch for __ in range(len(one_hot_list))]
+            edge_index = torch.t(torch.tensor(edge_index, dtype=torch.long))
+            x = torch.tensor(x, dtype=torch.float)
+            batch = torch.tensor(batch, dtype=torch.long)
+            data = Data(x=x, edge_index=edge_index, batch=batch)
+            return data
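Batching a tuple of graphs means concatenating node features and edge lists while recording, per node, which graph it came from. A self-contained sketch of the same idea with plain tensors (the feature matrices and edge lists are toy data; here the edge indices are shifted by a running offset so they stay valid after concatenation, and torch_geometric's Batch.from_data_list is the library's usual way to do this automatically):

    import torch
    from torch_geometric.data import Data

    graphs = [
        (torch.eye(4)[[0, 1, 2]], [[0, 1], [1, 2]]),   # (one-hot node features, edge list)
        (torch.eye(4)[[3, 0]], [[0, 1]]),
    ]
    x, edge_index, batch, offset = [], [], [], 0
    for gid, (feats, edges) in enumerate(graphs):
        x.append(feats)
        edge_index += [[a + offset, b + offset] for a, b in edges]   # re-index into the batch
        batch += [gid] * feats.shape[0]                              # per-node graph id
        offset += feats.shape[0]
    data = Data(x=torch.cat(x),
                edge_index=torch.t(torch.tensor(edge_index, dtype=torch.long)),
                batch=torch.tensor(batch, dtype=torch.long))
    print(data)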
rostok/neural_network/wrappers.py · View file @ 71ddd163

@@ -13,57 +13,49 @@ from NeuralNet import NeuralNet
 import torch
 import torch.optim as optim
 from torch_geometric.loader import DataLoader

 from rostok.neural_network.SAGPool import SAGPoolToAlphaZero as graph_nnet
 from rostok.neural_network.converter import ConverterToPytorchGeometric

-args_train = dotdict({
-    "epochs": 10,
-    "batch_size": 1,
-    "cuda": torch.cuda.is_available(),
-})

 class AlphaZeroWrapper(NeuralNet):

-    def __init__(self, game: GraphGrammarGame):
+    def __init__(self, game: GraphGrammarGame, args_train):
         self.action_size = game.getActionSize()
         self.converter = ConverterToPytorchGeometric(game.rule_vocabulary.node_vocab)
-        args_network = dotdict({
-            "num_features": len(self.converter.label2id),
-            "num_rules": game.getActionSize(),
-            "nhid": 64,
-            "pooling_ratio": 0.3,
-            "dropout_ratio": 0.3
-        })
-        self.nnet = graph_nnet(args_network)
-        if args_train.cuda:
+        self.args_train = args_train
+        self.args_train["num_features"] = len(self.converter.label2id)
+        self.args_train["num_rules"] = game.getActionSize()
+        self.nnet = graph_nnet(self.args_train)
+        if self.args_train.cuda:
             self.nnet.cuda()

-    def train(self, examples):
+    def train(self, examples, epochs):
         """
         examples: list of examples, each example is of form (graph, pi, v)
         """
         optimizer = optim.Adam(self.nnet.parameters())

-        for epoch in range(args_train.epochs):
+        for epoch in range(epochs):
             print('EPOCH ::: ' + str(epoch + 1))
             self.nnet.train()
             pi_losses = AverageMeter()
             v_losses = AverageMeter()

-            batch_count = int(len(examples) / args_train.batch_size)
+            batch_count = int(len(examples) / self.args_train.batch_size)
             t = tqdm(range(batch_count), desc='Training Net')
             for _ in t:
-                sample_ids = np.random.randint(len(examples), size=args_train.batch_size)
+                sample_ids = np.random.randint(len(examples), size=self.args_train.batch_size)
                 graph, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
                 data_graph = self.converter.transform_digraph(graph)
                 target_pis = torch.FloatTensor(np.array(pis))
                 target_vs = torch.FloatTensor(np.array(vs).astype(np.float64))

                 # predict
-                if args_train.cuda:
+                if self.args_train.cuda:
                     data_graph, target_pis, target_vs = data_graph.contiguous().cuda(), target_pis.contiguous().cuda(), target_vs.contiguous().cuda()

                 # compute output
@@ -91,7 +83,7 @@ class AlphaZeroWrapper(NeuralNet):
         # preparing input
         data_graph = self.converter.transform_digraph(graph)
-        if args_train.cuda:
-            data_graph = data_graph.contiguous().cuda()
+        if self.args_train.cuda:
+            data_graph = data_graph.contiguous().cuda()
         self.nnet.eval()
         with torch.no_grad():
             pi, v = self.nnet(data_graph)
@@ -121,6 +113,6 @@ class AlphaZeroWrapper(NeuralNet):
         filepath = os.path.join(folder, filename)
         if not os.path.exists(filepath):
             raise ("No model in path {}".format(filepath))
-        map_location = None if args_train.cuda else 'cpu'
+        map_location = None if self.args_train.cuda else 'cpu'
         checkpoint = torch.load(filepath, map_location=map_location)
         self.nnet.load_state_dict(checkpoint['state_dict'])
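train() now draws random minibatches using the batch size stored on self.args_train and hands the tuple of graphs to the converter in one call. A small self-contained sketch of that sampling step (toy examples, no model or converter; the comment marks where the real code would call transform_digraph):

    import numpy as np

    examples = [(f"graph_{i}", [0.5, 0.5], float(i % 2)) for i in range(10)]   # (graph, pi, v)
    batch_size = 2
    batch_count = int(len(examples) / batch_size)
    for _ in range(batch_count):
        sample_ids = np.random.randint(len(examples), size=batch_size)
        graphs, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
        # the real code would now call converter.transform_digraph(graphs) on this tuple
        print(graphs, np.array(pis).shape, np.array(vs, dtype=np.float64))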
rostok/trajectory_optimizer/control_optimizer.py · View file @ 71ddd163

 from dataclasses import dataclass, field
 from typing import Callable
+from typing import Union
+import time
+import warnings

 import pychrono as chrono
 from scipy.optimize import direct, shgo, dual_annealing
@@ -8,8 +11,22 @@ from rostok.block_builder.blocks_utils import NodeFeatures
 from rostok.graph_grammar.node import GraphGrammar
 from rostok.virtual_experiment.robot import Robot
 from rostok.virtual_experiment.simulation_step import (SimOut, SimulationStepOptimization)
-from typing import Union
-import time

+class TimeOptimizerStopper(object):
+    def __init__(self, max_sec=0.3):
+        self.max_sec = max_sec
+        self.start = time.time()
+
+    def __call__(self, xk=None, convergence=None):
+        elapsed = time.time() - self.start
+        if elapsed > self.max_sec:
+            warnings.warn("Terminating optimization: time limit reached")
+            return True
+        else:
+            # you might want to report other stuff here
+            print("Elapsed: %.3f sec" % elapsed)
+            return False

 @dataclass
 class _ConfigRewardFunction:
@@ -27,6 +44,7 @@ class _ConfigRewardFunction:
     sim_config: dict[str, str] = field(default_factory=dict)
     time_step: float = 0.001
     time_sim: float = 2
+    time_optimization = 100
     flags: list = field(default_factory=list)
     criterion_callback: Callable[[SimOut, Robot], float] = None
     get_rgab_object_callback: Callable[[], chrono.ChBody] = None
@@ -42,7 +60,7 @@ class ConfigVectorJoints(_ConfigRewardFunction):
     """
     bound: tuple[float, float] = (-1, 1)
     iters: int = 10
-    optimizer_scipy = partial(direct)
+    optimizer_scipy = partial(shgo)

 class ConfigGraphControl(_ConfigRewardFunction):
@@ -126,7 +144,8 @@ class ControlOptimizer():
             multi_bound = create_multidimensional_bounds(generated_graph, self.cfg.bound)
             if len(multi_bound) == 0:
                 return (0, 0)
-            result = self.cfg.optimizer_scipy(reward_fun, multi_bound, maxiter=self.cfg.iters)
+            time_stopper = TimeOptimizerStopper(self.cfg.time_optimization)
+            result = self.cfg.optimizer_scipy(reward_fun, multi_bound, callback=time_stopper)  #,maxiter=self.cfg.iters,)
             return (result.fun, result.x)
         elif isinstance(self.cfg, ConfigGraphControl):
             n_joint = num_joints(generated_graph)
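TimeOptimizerStopper is a callback object that asks SciPy to halt once a wall-clock budget is exceeded, replacing the fixed maxiter cap. A self-contained sketch of the pattern with scipy.optimize.differential_evolution, whose documented callback contract (return True to stop) matches the __call__(xk, convergence) signature used above; the objective and the 0.5-second budget here are purely illustrative:

    import time
    import warnings
    from scipy.optimize import differential_evolution

    class TimeStopper:
        """Stop an optimisation after max_sec seconds of wall-clock time."""
        def __init__(self, max_sec=0.5):
            self.max_sec = max_sec
            self.start = time.time()

        def __call__(self, xk=None, convergence=None):
            if time.time() - self.start > self.max_sec:
                warnings.warn("Terminating optimization: time limit reached")
                return True      # SciPy halts when the callback returns True
            return False

    result = differential_evolution(lambda x: (x ** 2).sum(), bounds=[(-5, 5)] * 3,
                                    callback=TimeStopper(max_sec=0.5))
    print(result.fun, result.x)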
train_data_10e_1000mcts_2302.pickle · deleted (100644 → 0) · View file @ b8cceb75

File deleted