Commit ea5f900e authored by Damir Aminev's avatar Damir Aminev
Browse files

add compiler

No related merge requests found
Showing with 42 additions and 2 deletions
+42 -2
......@@ -158,7 +158,7 @@ for grid_res in range(20, 110, 10):
model=Solver(grid, equation, model, 'autograd').solve(use_cache=True, verbose=True, print_every=None,
cache_verbose=True, abs_loss=0.001, step_plot_print=False,
step_plot_save=True,image_save_dir=img_dir)
step_plot_save=True,image_save_dir=img_dir, backend=True)
......
......@@ -187,6 +187,42 @@ class Solver():
plt.show()
plt.close()
def compiler(self, function, backend):
"""
Compile function using torch.compile.
Args:
function: callable, function for compiling
backend: bool or str, backend for compiler
if True the default backends are used.
Returns:
function: callble, compiled function.
"""
torch._dynamo.config.suppress_errors = True
if self.mode == 'NN':
if device_type() == 'cuda':
if backend ==True:
backend = 'inductor'
return torch.compile(function, backend=backend)
else:
if backend ==True:
backend = 'aot_eager'
return torch.compile(function, backend=backend)
elif self.mode == 'autograd':
if backend==True:
backend = 'onnxrt'
print(f'torch.compile with aot_autograd does not currently'
f'support double backwards, default backend="onnxrt"')
torch._dynamo.reset()
return torch.compile(function, backend=backend)
elif self.mode == 'mat':
if backend==True:
backend = 'inductor'
return torch.compile(function, backend=backend)
def solve(self, lambda_bound: Union[int, float] = 10, verbose: bool = False, learning_rate: float = 1e-4,
eps: float = 1e-5, tmin: int = 1000, tmax: float = 1e5, nmodels: Union[int, None] = None,
name: Union[str, None] = None, abs_loss: Union[None, float] = None, use_cache: bool = True,
......@@ -195,7 +231,7 @@ class Solver():
patience: int = 5, loss_oscillation_window: int = 100, no_improvement_patience: int = 1000,
model_randomize_parameter: Union[int, float] = 0, optimizer_mode: str = 'Adam',
step_plot_print: Union[bool, int] = False, step_plot_save: Union[bool, int] = False,
image_save_dir: Union[str, None] = None) -> Any:
image_save_dir: Union[str, None] = None, backend: Union[bool, str] = False) -> Any:
"""
High-level interface for solving equations.
......@@ -222,6 +258,7 @@ class Solver():
step_plot_print: draws a figure through each given step.
step_plot_save: saves a figure through each given step.
image_save_dir: a directory where saved figure in.
backend: backend for torch.compiler, if is True, the default backends used. Available backends: torch._dynamo.list_backends(). Default False.
Returns:
model.
......@@ -271,6 +308,9 @@ class Solver():
cur_loss = loss.item()
return loss
if backend != False:
closure = self.compiler(closure, backend=backend)
stop_dings = 0
t_imp_start = 0
# to stop train proceduce we fit the line in the loss data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment