Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
ITMO-NSS-team
torch_DE_solver
Commits
ea5f900e
Commit
ea5f900e
authored
1 year ago
by
Damir Aminev
Browse files
Options
Download
Email Patches
Plain Diff
add compiler
parent
c011d7c9
compiler
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
examples/example_wave_paper_autograd.py
+1
-1
examples/example_wave_paper_autograd.py
tedeous/solver.py
+41
-1
tedeous/solver.py
with
42 additions
and
2 deletions
+42
-2
examples/example_wave_paper_autograd.py
View file @
ea5f900e
...
...
@@ -158,7 +158,7 @@ for grid_res in range(20, 110, 10):
model
=
Solver
(
grid
,
equation
,
model
,
'autograd'
).
solve
(
use_cache
=
True
,
verbose
=
True
,
print_every
=
None
,
cache_verbose
=
True
,
abs_loss
=
0.001
,
step_plot_print
=
False
,
step_plot_save
=
True
,
image_save_dir
=
img_dir
)
step_plot_save
=
True
,
image_save_dir
=
img_dir
,
backend
=
True
)
...
...
This diff is collapsed.
Click to expand it.
tedeous/solver.py
View file @
ea5f900e
...
...
@@ -187,6 +187,42 @@ class Solver():
plt
.
show
()
plt
.
close
()
def compiler(self, function, backend):
    """Compile a function with ``torch.compile``, choosing a backend per solver mode.

    Args:
        function: callable, the function to compile (typically the optimizer
            closure built inside ``solve``).
        backend: bool or str, backend for ``torch.compile``. If ``True``, a
            mode-appropriate default backend is selected; if a string, it is
            passed through unchanged.

    Returns:
        callable: the compiled function (or ``None`` for an unknown mode,
        matching the original fall-through behavior).
    """
    # Let dynamo fall back to eager execution on unsupported constructs
    # instead of raising during tracing.
    torch._dynamo.config.suppress_errors = True
    if self.mode == 'NN':
        if backend is True:
            # inductor requires a CUDA toolchain at runtime; on CPU fall
            # back to the aot_eager backend.
            backend = 'inductor' if device_type() == 'cuda' else 'aot_eager'
        return torch.compile(function, backend=backend)
    elif self.mode == 'autograd':
        if backend is True:
            backend = 'onnxrt'
            # Bug fix: the original f-string pair lacked a separating space,
            # printing "currentlysupport".
            print('torch.compile with aot_autograd does not currently '
                  'support double backwards, default backend="onnxrt"')
        # Drop any previously compiled graphs before recompiling.
        torch._dynamo.reset()
        return torch.compile(function, backend=backend)
    elif self.mode == 'mat':
        if backend is True:
            backend = 'inductor'
        return torch.compile(function, backend=backend)
def
solve
(
self
,
lambda_bound
:
Union
[
int
,
float
]
=
10
,
verbose
:
bool
=
False
,
learning_rate
:
float
=
1e-4
,
eps
:
float
=
1e-5
,
tmin
:
int
=
1000
,
tmax
:
float
=
1e5
,
nmodels
:
Union
[
int
,
None
]
=
None
,
name
:
Union
[
str
,
None
]
=
None
,
abs_loss
:
Union
[
None
,
float
]
=
None
,
use_cache
:
bool
=
True
,
...
...
@@ -195,7 +231,7 @@ class Solver():
patience
:
int
=
5
,
loss_oscillation_window
:
int
=
100
,
no_improvement_patience
:
int
=
1000
,
model_randomize_parameter
:
Union
[
int
,
float
]
=
0
,
optimizer_mode
:
str
=
'Adam'
,
step_plot_print
:
Union
[
bool
,
int
]
=
False
,
step_plot_save
:
Union
[
bool
,
int
]
=
False
,
image_save_dir
:
Union
[
str
,
None
]
=
None
)
->
Any
:
image_save_dir
:
Union
[
str
,
None
]
=
None
,
backend
:
Union
[
bool
,
str
]
=
False
)
->
Any
:
"""
High-level interface for solving equations.
...
...
@@ -222,6 +258,7 @@ class Solver():
step_plot_print: draws a figure through each given step.
step_plot_save: saves a figure through each given step.
image_save_dir: a directory where saved figure in.
backend: backend for torch.compile; if True, a default backend is selected automatically. Available backends: torch._dynamo.list_backends(). Default: False (no compilation).
Returns:
model.
...
...
@@ -271,6 +308,9 @@ class Solver():
cur_loss
=
loss
.
item
()
return
loss
if
backend
!=
False
:
closure
=
self
.
compiler
(
closure
,
backend
=
backend
)
stop_dings
=
0
t_imp_start
=
0
# to stop the training procedure we fit a line to the loss data
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment
Menu
Projects
Groups
Snippets
Help