Commit af988561 authored by Nikita Demyanchuk

Refactor + Navier-Stokes example

Showing with 425 additions and 212 deletions
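# Navier-Stokes example: a coupled 1D system for velocity u (variable 0) and
# pressure p (variable 1) is trained on the domain x in [0, 5], t in [0, 2]
# with a single network that outputs both fields.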
import torch
import math
import matplotlib.pyplot as plt
import scipy
import os
import sys
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
sys.path.pop()
sys.path.append(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..')))
from tedeous.input_preprocessing import Equation
from tedeous.solver import Solver
from tedeous.metrics import Solution
from tedeous.device import solver_device
solver_device('cuda')
grid_res = 30
x = torch.linspace(0, 5, grid_res + 1)
t = torch.linspace(0, 2, grid_res + 1)
h = abs((t[1] - t[0]).item())
grid = torch.cartesian_prod(x, t)
def bconds_1():
    # Boundary conditions at x=0
    bnd1_u = torch.cartesian_prod(torch.Tensor([0]), t)
    # u(0,t) = 2
    bndval1_u = 2 * torch.ones_like(bnd1_u[:, 0])
    bnd_type_1_u = 'dirichlet'
    # Boundary conditions at x=5
    bnd2_u = torch.cartesian_prod(torch.Tensor([5]), t)
    # u(5,t) = 2
    bndval2_u = 2 * torch.ones_like(bnd2_u[:, 0])
    bnd_type_2_u = 'dirichlet'
    # Boundary conditions at x=0
    bnd1_p = torch.cartesian_prod(torch.Tensor([0]), t)
    # p(0,t) = 0
    bndval1_p = torch.zeros_like(bnd1_p[:, 0])
    bnd_type_1_p = 'dirichlet'
    # Boundary conditions at x=5
    bnd2_p = torch.cartesian_prod(torch.Tensor([5]), t)
    # p(5,t) = 0
    bndval2_p = torch.zeros_like(bnd2_p[:, 0])
    bnd_type_2_p = 'dirichlet'
    # Initial condition at t=0
    ics_u = torch.cartesian_prod(x, torch.Tensor([0]))
    icsval_u = torch.sin((2 * math.pi * x) / 5) + 2
    ics_type_u = 'dirichlet'
    bconds = [[bnd1_u, bndval1_u, 0, bnd_type_1_u],
              [bnd1_p, bndval1_p, 1, bnd_type_1_p],
              [bnd2_u, bndval2_u, 0, bnd_type_2_u],
              [bnd2_p, bndval2_p, 1, bnd_type_2_p],
              [ics_u, icsval_u, 0, ics_type_u, 'initial']]
    return bconds
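# Alternative set of conditions: instead of fixing u on the boundaries, bconds_2
# prescribes the operator condition du/dt = t*sin(t) at x = 0 and x = 5,
# with the same Dirichlet conditions for p and a different initial profile for u.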
def bconds_2():
    # Boundary conditions at x=0
    bnd1_u = torch.cartesian_prod(torch.Tensor([0]), t)
    bop1_u = {
        'du/dt':
            {
                'coeff': 1,
                'term': [1],
                'pow': 1,
                'var': 0
            }
    }
    # u_t = t*sin(t)
    bval1_u = t * torch.sin(t)
    bnd_type_1_u = 'operator'
    # Boundary conditions at x=5
    bnd2_u = torch.cartesian_prod(torch.Tensor([5]), t)
    bop2_u = {
        'du/dt':
            {
                'coeff': 1,
                'term': [1],
                'pow': 1,
                'var': 0
            }
    }
    # u_t = t*sin(t)
    bval2_u = t * torch.sin(t)
    bnd_type_2_u = 'operator'
    # Boundary conditions at x=0
    bnd1_p = torch.cartesian_prod(torch.Tensor([0]), t)
    # p(0,t) = 0
    bndval1_p = torch.zeros_like(bnd1_p[:, 0])
    bnd_type_1_p = 'dirichlet'
    # Boundary conditions at x=5
    bnd2_p = torch.cartesian_prod(torch.Tensor([5]), t)
    # p(5,t) = 0
    bndval2_p = torch.zeros_like(bnd2_p[:, 0])
    bnd_type_2_p = 'dirichlet'
    # Initial condition at t=0
    ics_u = torch.cartesian_prod(x, torch.Tensor([0]))
    icsval_u = torch.sin((math.pi * x) / 5) + 1
    ics_type_u = 'dirichlet'
    bconds = [[bnd1_u, bop1_u, bval1_u, 0, bnd_type_1_u],
              [bnd1_p, bndval1_p, 1, bnd_type_1_p],
              [bnd2_u, bop2_u, bval2_u, 0, bnd_type_2_u],
              [bnd2_p, bndval2_p, 1, bnd_type_2_p],
              [ics_u, icsval_u, 0, ics_type_u, 'initial']]
    return bconds
bconds = bconds_1()
ro = 1
mu = 1
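# Each equation is a dictionary of terms. For every term 'coeff' is the multiplier,
# 'term' lists the differentiation axes (0 -> d/dx, 1 -> d/dt, None -> the function itself),
# 'pow' is the power of the term and 'var' selects the network output (0 -> u, 1 -> p).
# NS_1 encodes the first equation of the system: du/dx = 0.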
NS_1 = {
    'du/dx':
        {
            'coeff': 1,
            'term': [0],
            'pow': 1,
            'var': 0
        }
}
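# NS_2 encodes the momentum equation:
# du/dt + u * du/dx + (1/ro) * dp/dx - mu * d2u/dx2 = 0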
NS_2 = {
    'du/dt':
        {
            'coeff': 1,
            'term': [1],
            'pow': 1,
            'var': 0
        },
    'u * du/dx':
        {
            'coeff': 1,
            'term': [[None], [0]],
            'pow': [1, 1],
            'var': [0, 0]
        },
    '1/ro * dp/dx':
        {
            'coeff': 1/ro,
            'term': [0],
            'pow': 1,
            'var': 1
        },
    '-mu * d2u/dx2':
        {
            'coeff': -mu,
            'term': [0, 0],
            'pow': 1,
            'var': 0
        }
}
navier_stokes = [NS_1, NS_2]
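# The coupled system is passed as a list of equation dicts; the network below maps
# the pair (x, t) to the pair (u, p), hence 2 inputs and 2 outputs.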
model = torch.nn.Sequential(
    torch.nn.Linear(2, 100),
    torch.nn.Tanh(),
    torch.nn.Linear(100, 100),
    torch.nn.Tanh(),
    torch.nn.Linear(100, 100),
    torch.nn.Tanh(),
    torch.nn.Linear(100, 100),
    torch.nn.Tanh(),
    torch.nn.Linear(100, 2)
)
equation = Equation(grid, navier_stokes, bconds, h=h).set_strategy('autograd')
img_dir = os.path.join(os.path.dirname(__file__), 'navier_stokes_img')
model = Solver(grid, equation, model, 'autograd').solve(lambda_bound=1000, verbose=True, learning_rate=1e-4,
eps=1e-8, tmax=1e6, use_cache=False, cache_verbose=True,
save_always=True, print_every=500, model_randomize_parameter=1e-5,
optimizer_mode='LBFGS', no_improvement_patience=1000,
step_plot_print=500, step_plot_save=False, image_save_dir=img_dir)
......@@ -92,7 +92,7 @@ bop4= {
bconds = [[bnd1, bndval1, 'dirichlet'],
[bnd2, bndval2, 'dirichlet'],
[bnd3, bndval3, 'dirichlet'],
[bnd4, bop4, bndval4, 'operator']]
[bnd4, bop4, bndval4, 'operator', 'initial']]
wave_eq = {
'd2u/dt2**1':
......@@ -124,7 +124,7 @@ equation = Equation(grid, wave_eq, bconds, h=h).set_strategy('autograd')
img_dir = os.path.join(os.path.dirname(__file__), 'wave_example_paper_img')
model = Solver(grid, equation, model, 'autograd').solve(update_every_lambdas=1000, verbose=True, learning_rate=1e-3,
model = Solver(grid, equation, model, 'autograd').solve(loss_term = 3, update_every_lambdas=500, verbose=True, learning_rate=1e-3,
eps=1e-8, tmax=1e6, use_cache=False, cache_verbose=True,lr_decay=True,
save_always=True, print_every=500, model_randomize_parameter=1e-5,
optimizer_mode='Adam', no_improvement_patience=1000,patience= 8,
......
......@@ -114,7 +114,6 @@ class Model_prepare():
continue
model = model.to(device)
if self.lambda_update is int:
l = Solution(self.grid, self.equal_cls,
self.model, self.mode,
weak_form, self.lambda_update).evaluate(lambda_bound)
......
......@@ -38,9 +38,11 @@ class Boundary():
bcond[0] = check_device(bcond[0])
bcond[1] = check_device(bcond[1])
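# Unified internal record: [points, operator, values, variable index, condition type,
# 'boundary' | 'initial']; a 5-element user condition supplies the extra tag itself.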
if len(bcond) == 3:
boundary = [bcond[0], None, bcond[1], 0, bcond[2]]
boundary = [bcond[0], None, bcond[1], 0, bcond[2], 'boundary']
elif len(bcond) == 4:
boundary = [bcond[0], None, bcond[1], bcond[2], bcond[3]]
boundary = [bcond[0], None, bcond[1], bcond[2], bcond[3], 'boundary']
elif len(bcond) == 5:
boundary = [bcond[0], None, bcond[1], bcond[2], bcond[3], bcond[4]]
else:
raise NameError('Incorrect Dirichlet condition')
return boundary
......@@ -64,10 +66,16 @@ class Boundary():
bcond[2] = check_device(bcond[2])
if len(bcond) == 4:
bcond[1] = EquationMixin.equation_unify(bcond[1])
boundary = [bcond[0], bcond[1], bcond[2], None, bcond[3]]
boundary = [bcond[0], bcond[1], bcond[2], None, bcond[3], 'boundary']
elif len(bcond) == 5:
bcond[1] = EquationMixin.equation_unify(bcond[1])
boundary = [bcond[0], bcond[1], bcond[2], None, bcond[4]]
if bcond[-1] == 'initial':
boundary = [bcond[0], bcond[1], bcond[2], None, bcond[3], bcond[4]]
else:
boundary = [bcond[0], bcond[1], bcond[2], None, bcond[3], 'boundary']
elif len(bcond) == 6:
bcond[1] = EquationMixin.equation_unify(bcond[1])
boundary = [bcond[0], bcond[1], bcond[2], None, bcond[4], bcond[5]]
else:
raise NameError('Incorrect operator condition')
return boundary
......@@ -91,14 +99,21 @@ class Boundary():
bcond[0][i] = check_device(bcond[0][i])
if len(bcond) == 2:
b_val = torch.zeros(bcond[0][0].shape[0])
boundary = [bcond[0], None, b_val, 0, bcond[1]]
boundary = [bcond[0], None, b_val, 0, bcond[1], 'boundary']
elif len(bcond) == 3 and type(bcond[1]) is int:
b_val = torch.zeros(bcond[0][0].shape[0])
boundary = [bcond[0], None, b_val, bcond[1], bcond[2]]
elif type(bcond[1]) is dict:
boundary = [bcond[0], None, b_val, bcond[1], bcond[2], 'boundary']
elif len(bcond) == 3 and type(bcond[1]) is dict:
b_val = torch.zeros(bcond[0][0].shape[0])
bcond[1] = EquationMixin.equation_unify(bcond[1])
boundary = [bcond[0], bcond[1], b_val, None, bcond[2]]
boundary = [bcond[0], bcond[1], b_val, None, bcond[2], 'boundary']
elif len(bcond) == 4 and type(bcond[1]) is int:
b_val = torch.zeros(bcond[0][0].shape[0])
boundary = [bcond[0], None, b_val, bcond[1], bcond[2], bcond[3]]
elif len(bcond) == 4 and type(bcond[1]) is dict:
b_val = torch.zeros(bcond[0][0].shape[0])
bcond[1] = EquationMixin.equation_unify(bcond[1])
boundary = [bcond[0], bcond[1], b_val, None, bcond[2], bcond[3]]
else:
raise NameError('Incorrect periodic condition')
return boundary
......@@ -114,14 +129,25 @@ class Boundary():
return unified condition.
"""
if bcond[-1] == 'periodic':
bnd = self.periodic(bcond)
elif bcond[-1] == 'dirichlet':
bnd = self.dirichlet(bcond)
elif bcond[-1] == 'operator':
bnd = self.neumann(bcond)
if bcond[-1] == 'initial':
if bcond[-2] == 'periodic':
bnd = self.periodic(bcond)
elif bcond[-2] == 'dirichlet':
bnd = self.dirichlet(bcond)
elif bcond[-2] == 'operator':
bnd = self.neumann(bcond)
else:
raise NameError('TEDEouS can not use ' + bcond[-2] + ' condition type')
else:
raise NameError('TEDEouS can not use ' + bcond[-1] + ' condition type')
if bcond[-1] == 'periodic':
bnd = self.periodic(bcond)
elif bcond[-1] == 'dirichlet':
bnd = self.dirichlet(bcond)
elif bcond[-1] == 'operator':
bnd = self.neumann(bcond)
else:
raise NameError('TEDEouS can not use ' + bcond[-1] + ' condition type')
return bnd
def bnd_unify(self) -> list:
......@@ -136,7 +162,7 @@ class Boundary():
for bcond in self.bconds:
bnd = {}
bnd['bnd'], bnd['bop'], bnd['bval'], bnd['var'], \
bnd['type'] = self.bnd_choose(bcond)
bnd['type'], bnd['condition'] = self.bnd_choose(bcond)
unified_bnd.append(bnd)
return unified_bnd
......@@ -664,7 +690,7 @@ class Equation():
"""
Interface for preparing equations according to the chosen calculation method.
"""
def __init__(self, grid: torch.Tensor, operator: dict, bconds: list, h: float = 0.001,
def __init__(self, grid: torch.Tensor, operator: Union[dict, list], bconds: list, h: float = 0.001,
inner_order: str ='1', boundary_order: str ='2'):
"""
Args:
......
......@@ -12,7 +12,8 @@ from tedeous.utils import *
flatten_list = lambda t: [item for sublist in t for item in sublist]
def integration(func: torch.tensor, grid, pow: Union[int, float] = 2) -> Union[Tuple[float, float], Tuple[list, torch.Tensor]]:
def integration(func: torch.tensor, grid, pow: Union[int, float] = 2) -> Union[
Tuple[float, float], Tuple[list, torch.Tensor]]:
"""
Computes a one-dimensional integral of the integrand,
where func = (L(u) - f) * weak_form is the subintegral function and
......@@ -52,12 +53,19 @@ def integration(func: torch.tensor, grid, pow: Union[int, float] = 2) -> Union[T
grid = grid[index, :-1]
return result, grid
class Operator():
def __init__(self,operator, grid, model, mode):
def __init__(self, operator, grid, model, mode, weak_form):
self.operator = operator
self.grid = grid
self.model = model
self.mode = mode
self.weak_form = weak_form
if self.mode == 'NN':
self.grid_dict = Points_type(self.grid).grid_sort()
self.sorted_grid = torch.cat(list(self.grid_dict.values()))
elif self.mode == 'autograd' or self.mode == 'mat':
self.sorted_grid = self.grid
def apply_op(self, operator, grid_points) -> torch.Tensor:
"""
......@@ -81,9 +89,66 @@ class Operator():
total = dif
return total
def pde_compute(self) -> torch.Tensor:
"""
Computes PDE residual.
Returns:
PDE residual.
"""
num_of_eq = len(self.operator)
if num_of_eq == 1:
op = self.apply_op(
self.operator[0], self.sorted_grid)
else:
op_list = []
for i in range(num_of_eq):
op_list.append(self.apply_op(
self.operator[i], self.sorted_grid))
op = torch.cat(op_list, 1)
return op
def weak_pde_compute(self, weak_form) -> torch.Tensor:
"""
Computes PDE residual in weak form.
Args:
weak_form: list of basis functions
Returns:
weak PDE residual.
"""
device = device_type()
if self.mode == 'NN':
grid_central = self.grid_dict['central']
elif self.mode == 'autograd':
grid_central = self.grid
op = self.pde_compute()
sol_list = []
for i in range(op.shape[-1]):
sol = op[:, i]
for func in weak_form:
sol = sol * func(grid_central).to(device).reshape(-1)
grid_central1 = torch.clone(grid_central)
for k in range(grid_central.shape[-1]):
sol, grid_central1 = integration(sol, grid_central1)
sol_list.append(sol.reshape(-1, 1))
if len(sol_list) == 1:
return sol_list[0]
else:
return torch.cat(sol_list)
@counter
def operator_compute(self):
if self.weak_form == None or self.weak_form == []:
return self.pde_compute()
else:
return self.weak_pde_compute(self.weak_form)
class Bounds(Operator):
def __init__(self, bconds, grid, model, mode):
super().__init__(bconds, grid, model, mode)
def __init__(self, bconds, grid, model, mode, weak_form):
super().__init__(bconds, grid, model, mode, weak_form)
self.bconds = bconds
def apply_bconds_set(self, operator_set: list) -> torch.Tensor:
......@@ -120,7 +185,7 @@ class Bounds(Operator):
elif self.mode == 'autograd':
b_op_val = self.apply_op(bop, bnd)
elif self.mode == 'mat':
b_op_val = self.apply_op(operator = bop, grid_points=self.grid)
b_op_val = self.apply_op(operator=bop, grid_points=self.grid)
b_val = []
for position in bnd:
if self.grid.dim() == 1 or min(self.grid.shape) == 1:
......@@ -154,6 +219,7 @@ class Bounds(Operator):
Returns:
calculated operator on the boundary.
"""
if bcond['type'] == 'dirichlet':
b_op_val = self.apply_dirichlet(bcond['bnd'], bcond['var'])
elif bcond['type'] == 'operator':
......@@ -167,7 +233,8 @@ class Bounds(Operator):
truebval = bcond['bval'].reshape(-1, 1)
b_op_val = self.b_op_val_calc(bcond)
return b_op_val, truebval
def apply_bconds_default(self) -> Tuple[torch.Tensor, torch.Tensor]:
def bcs(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Auxiliary function. Serves only to evaluate boundary values and true boundary values.
Returns:
......@@ -177,6 +244,7 @@ class Bounds(Operator):
true_b_val_list = []
b_val_list = []
for bcond in self.bconds:
b_op_val, truebval = self.compute_bconds(bcond)
b_val_list.append(b_op_val)
......@@ -186,63 +254,61 @@ class Bounds(Operator):
b_val = torch.cat(b_val_list).reshape(-1, 1)
return b_val, true_b_val
def apply_bconds_separate(self):
true_b_val_list_sep = []
b_val_list_sep = []
def bcs_ics(self):
true_ics_list = []
ics_list = []
bcs_list = []
true_bcs_list = []
true_b_val_list_op = []
b_val_list_op = []
for bcond in self.bconds:
if bcond['type'] == 'operator':
b_val_op, true_b_val_op = self.compute_bconds(bcond)
b_val_list_op.append(b_val_op)
true_b_val_list_op.append(true_b_val_op)
if bcond['condition'] == 'initial':
ics_pred, true_ics = self.compute_bconds(bcond)
ics_list.append(ics_pred)
true_ics_list.append(true_ics)
else:
b_val_sep, true_b_val_sep = self.compute_bconds(bcond)
b_val_list_sep.append(b_val_sep)
true_b_val_list_sep.append(true_b_val_sep)
bcs_pred, true_bcs = self.compute_bconds(bcond)
bcs_list.append(bcs_pred)
true_bcs_list.append(true_bcs)
true_b_val_separate = torch.cat(true_b_val_list_sep)
b_val_op_separate = torch.cat(b_val_list_sep)
true_b_val_operator = torch.cat(true_b_val_list_op)
b_val_op_operator = torch.cat(b_val_list_op)
true_bcs = torch.cat(true_bcs_list)
bcs_pred = torch.cat(bcs_list)
b_val = [b_val_op_separate, b_val_op_operator]
true_b_val = [true_b_val_separate, true_b_val_operator]
true_ics = torch.cat(true_ics_list)
ics_pred = torch.cat(ics_list)
b_val = [bcs_pred, ics_pred]
true_b_val = [true_bcs, true_ics]
return b_val, true_b_val
# def apply_bconds_separate(self):
# loss_bcs = []
# loss_ics = []
# for bcond in self.bconds:
# if bcond['type'] == 'operator':
# b_val_op, true_b_val_op = self.compute_bconds(bcond)
# loss_ics.append(torch.mean(torch.square(b_val_op)))
# true_b_val_list_op.append(true_b_val_op)
# else:
# b_val_sep, true_b_val_sep = self.compute_bconds(bcond)
# b_val_list_sep.append(b_val_sep)
# true_b_val_list_sep.append(true_b_val_sep)
def apply_bnd(self, mode):
if type(mode) is int:
return self.apply_bconds_separate()
def apply_bnd(self, num_of_terms):
if num_of_terms == 3:
return self.bcs_ics()
else:
return self.apply_bconds_default()
return self.bcs()
class Loss():
def __init__(self, *args, operator, bval, true_bval, mode, weak_form):
if len(args) == 3:
self.l_bnd, self.l_bop, self.l_op = args[0], args[1], args[2]
def __init__(self, operator, bval, true_bval, lambda_op, lambda_bcs, lambda_ics, mode, weak_form, num_of_loss_term):
self.operator = operator
self.bval = bval
self.true_bval = true_bval
self.mode = mode
self.weak_form = weak_form
def l2_loss(self, lambda_bound: Union[int, float] = 10) -> torch.Tensor:
if num_of_loss_term == 2:
self.lambda_op = 1
self.lambda_bound = lambda_bcs
self.lambda_initial = 0
self.bval = [bval, torch.zeros_like(bval)]
self.true_bval = [true_bval, torch.zeros_like(true_bval)]
elif num_of_loss_term == 3:
self.lambda_op = lambda_op
self.lambda_bound = lambda_bcs
self.lambda_initial = lambda_ics
self.bval = bval
self.true_bval = true_bval
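# num_of_loss_term == 2 reproduces the old two-term loss (PDE residual + boundary penalty,
# lambda_initial = 0), while num_of_loss_term == 3 weights the residual, boundary and
# initial-condition terms with their own lambdas.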
def l2_loss(self) -> torch.Tensor:
"""
Computes l2 loss.
Args:
......@@ -255,13 +321,14 @@ class Loss():
if self.mode == 'mat':
loss = torch.mean((self.operator) ** 2) + \
lambda_bound * torch.mean((self.bval - self.true_bval) ** 2)
self.lambda_bound * torch.mean((self.bval[0] - self.true_bval[0]) ** 2) + \
self.lambda_initial * torch.mean((self.bval[1] - self.true_bval[1]) ** 2)
else:
loss = torch.sum(torch.mean((self.operator) ** 2, 0)) + \
lambda_bound * torch.sum(torch.mean((self.bval - self.true_bval) ** 2, 0))
loss = self.lambda_op * torch.sum(torch.mean((self.operator) ** 2, 0)) + \
self.lambda_bound * torch.sum(torch.mean((self.bval[0] - self.true_bval[0]) ** 2, 0)) + \
self.lambda_initial * torch.sum(torch.mean((self.bval[1] - self.true_bval[1]) ** 2, 0))
return loss
def weak_loss(self, lambda_bound: Union[int, float] = 10) -> torch.Tensor:
def weak_loss(self) -> torch.Tensor:
"""
Weak solution of O/PDE problem.
Args:
......@@ -275,29 +342,12 @@ class Loss():
# we apply no boundary conditions operators if they are all None
loss = torch.sum(self.operator) + \
lambda_bound * torch.sum(torch.mean((self.bval - self.true_bval) ** 2, 0))
return loss
def l2_loss_sep(self, lambda_bound, lambda_bound_op, lambda_op) -> torch.Tensor:
"""
Args:
Returns:
"""
if self.bval == None:
return torch.sum(torch.mean((self.operator) ** 2, 0))
loss = lambda_op * torch.sum(torch.mean((self.operator) ** 2), 0) + \
lambda_bound * torch.sum(torch.mean((self.bval[0] - self.true_bval[0]) ** 2, 0)) + \
lambda_bound_op * torch.sum(torch.mean((self.bval[1] - self.true_bval[1]) ** 2, 0))
loss = self.lambda_op * torch.sum(self.operator) + \
self.lambda_bound * torch.sum(torch.mean((self.bval[0] - self.true_bval[0]) ** 2, 0)) + \
self.lambda_initial * torch.sum(torch.mean((self.bval[1] - self.true_bval[1]) ** 2, 0))
return loss
def compute(self, *args, lambda_bound = 10) -> Union[l2_loss, weak_loss]:
def compute(self) -> Union[l2_loss, weak_loss]:
"""
Setting the required loss calculation method.
Args:
......@@ -307,113 +357,50 @@ class Loss():
A given calculation method.
"""
if self.mode == 'mat' or self.mode == 'autograd':
if self.bval == None:
print('No bconds is not possible, returning infinite loss')
return np.inf
if self.weak_form == None or self.weak_form == []:
if len(args) == 3:
return self.l2_loss_sep(self.l_bnd, self.l_bop, self.l_op)
else:
return self.l2_loss(lambda_bound=lambda_bound)
return self.l2_loss()
else:
return self.weak_loss(lambda_bound=lambda_bound)
return self.weak_loss()
class Solution():
def __init__(self, grid: torch.Tensor, equal_cls: Union[tedeous.input_preprocessing.Equation_NN,
tedeous.input_preprocessing.Equation_mat, tedeous.input_preprocessing.Equation_autograd],
model: Union[torch.nn.Sequential, torch.Tensor], mode: str, weak_form, update_every_lambdas):
self.grid = check_device(grid)
model: Union[torch.nn.Sequential, torch.Tensor], mode: str, weak_form, update_every_lambdas,
loss_term, lambda_op,lambda_bcs, lambda_ics):
grid = check_device(grid)
equal_copy = deepcopy(equal_cls)
self.prepared_operator = equal_copy.operator_prepare()
self.prepared_bconds = equal_copy.bnd_prepare()
prepared_operator = equal_copy.operator_prepare()
prepared_bconds = equal_copy.bnd_prepare()
self.model = model.to(device_type())
self.mode = mode
self.update_every_lambdas = update_every_lambdas
self.weak_form = weak_form
if self.mode == 'NN':
self.grid_dict = Points_type(self.grid).grid_sort()
self.sorted_grid = torch.cat(list(self.grid_dict.values()))
elif self.mode == 'autograd' or self.mode == 'mat':
self.sorted_grid = self.grid
self.operator = Operator(operator = self.prepared_operator, grid = self.grid, model = self.model, mode = self.mode)
self.l_bnd, self.l_bop, self.l_op = 1, 1, 1
def pde_compute(self) -> torch.Tensor:
"""
Computes PDE residual.
Returns:
PDE residual.
"""
num_of_eq = len(self.prepared_operator)
if num_of_eq == 1:
op = self.operator.apply_op(
self.prepared_operator[0], self.sorted_grid)
else:
op_list = []
for i in range(num_of_eq):
op_list.append(self.operator.apply_op(
self.prepared_operator[i], self.sorted_grid))
op = torch.cat(op_list, 1)
return op
def weak_pde_compute(self, weak_form) -> torch.Tensor:
"""
Computes PDE residual in weak form.
Args:
weak_form: list of basis functions
Returns:
weak PDE residual.
"""
device = device_type()
if self.mode == 'NN':
grid_central = self.grid_dict['central']
elif self.mode == 'autograd':
grid_central = self.grid
op = self.pde_compute()
sol_list = []
for i in range(op.shape[-1]):
sol = op[:, i]
for func in weak_form:
sol = sol * func(grid_central).to(device).reshape(-1)
grid_central1 = torch.clone(grid_central)
for k in range(grid_central.shape[-1]):
sol, grid_central1 = integration(sol, grid_central1)
sol_list.append(sol.reshape(-1, 1))
if len(sol_list) == 1:
return sol_list[0]
else:
return torch.cat(sol_list)
@counter
def operator_compute(self):
if self.weak_form == None or self.weak_form == []:
return self.pde_compute()
else:
return self.weak_pde_compute(self.weak_form)
self.loss_term = loss_term
self.lambda_op = lambda_op
self.lambda_bcs = lambda_bcs
self.lambda_ics = lambda_ics
def evaluate(self, lambda_bound):
it = self.operator_compute.count
op = self.operator_compute()
bval, true_bval = Bounds(self.prepared_bconds,self.grid, self.model,self.mode).apply_bnd(mode=self.update_every_lambdas)
loss = Loss(operator=op,bval=bval,true_bval=true_bval,mode=self.mode, weak_form = self.weak_form)
self.operator = Operator(prepared_operator, grid, self.model, self.mode, self.weak_form)
self.boundary = Bounds(prepared_bconds, grid, self.model, self.mode, self.weak_form)
if type(self.update_every_lambdas) is int:
bnd = bval[0]
true_bnd = true_bval[0]
def evaluate(self, iter):
op = self.operator.operator_compute()
bval, true_bval = self.boundary.apply_bnd(num_of_terms=self.loss_term)
bop = bval[1]
true_bop = true_bval[1]
if self.update_every_lambdas is not None and iter % self.update_every_lambdas == 0:
self.lambda_bcs, self.lambda_ics, self.lambda_op = LambdaCompute(bval, true_bval, op, self.model).update()
lambdas = LambdaCompute(bnd - true_bnd, bop, op, self.model)
loss = Loss(self.l_bnd, self.l_bop, self.l_op, operator=op,bval=bval,true_bval=true_bval,mode=self.mode, weak_form = self.weak_form)
if it % self.update_every_lambdas == 0:
self.l_bnd, self.l_bop, self.l_op = lambdas.update()
return loss.compute(self.l_bnd, self.l_bop, self.l_op)
loss = Loss(operator=op, bval=bval, true_bval=true_bval, mode=self.mode,
weak_form=self.weak_form, lambda_op=self.lambda_op, lambda_bcs=self.lambda_bcs, lambda_ics=self.lambda_ics,
num_of_loss_term=self.loss_term)
return loss.compute(lambda_bound)
\ No newline at end of file
return loss.compute()
......@@ -187,7 +187,8 @@ class Solver():
plt.show()
plt.close()
def solve(self, lambda_bound: Union[int, float] = 10, update_every_lambdas: Union[None, int] = None, verbose: bool = False, learning_rate: float = 1e-4,
def solve(self, lambda_bound: Union[int, float] = 10, lambda_op = 1, lambda_initial = 0,
update_every_lambdas: Union[None, int] = None, verbose: bool = False, learning_rate: float = 1e-4,
eps: float = 1e-5, tmin: int = 1000, tmax: float = 1e5, nmodels: Union[int, None] = None,
name: Union[str, None] = None, abs_loss: Union[None, float] = None, use_cache: bool = True,
cache_dir: str = '../cache/', cache_verbose: bool = False, save_always: bool = False,
......@@ -195,7 +196,7 @@ class Solver():
patience: int = 5, loss_oscillation_window: int = 100, no_improvement_patience: int = 1000,
model_randomize_parameter: Union[int, float] = 0, optimizer_mode: str = 'Adam',
step_plot_print: Union[bool, int] = False, step_plot_save: Union[bool, int] = False,
image_save_dir: Union[str, None] = None, lr_decay = False, decay_rate = 1000) -> Any:
image_save_dir: Union[str, None] = None, lr_decay = False, decay_rate = 1000, loss_term = 2) -> Any:
"""
High-level interface for solving equations.
......@@ -244,13 +245,17 @@ class Solver():
Solution_class = Solution(self.grid, self.equal_cls,
self.model, self.mode,
self.weak_form, update_every_lambdas)
self.weak_form, update_every_lambdas,
loss_term, lambda_op=lambda_op,
lambda_bcs=lambda_bound, lambda_ics=lambda_initial)
else:
Solution_class = Solution(self.grid, self.equal_cls,
self.model, self.mode,
self.weak_form, update_every_lambdas)
self.weak_form, update_every_lambdas,
loss_term, lambda_op=lambda_op,
lambda_bcs=lambda_bound, lambda_ics=lambda_initial)
min_loss = Solution_class.evaluate(lambda_bound)
min_loss = Solution_class.evaluate(-1)
optimizer = self.optimizer_choice(optimizer_mode, learning_rate)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer = optimizer, gamma = 0.9)
......@@ -267,7 +272,7 @@ class Solver():
def closure():
nonlocal cur_loss
optimizer.zero_grad()
loss = Solution_class.evaluate(lambda_bound)
loss = Solution_class.evaluate(t)
loss.backward()
cur_loss = loss.item()
return loss
......@@ -351,9 +356,9 @@ class Solver():
solution_print=step_plot_print,
solution_save=step_plot_save,
save_dir=image_save_dir)
if type(update_every_lambdas) is int:
l_bnd, l_bop, l_op = Solution_class.l_bnd, Solution_class.l_bop, Solution_class.l_op
print('lambda bound: {:.3e}, lambda boundary operator: {:.3e}, lambda operator: {:.3e}'.format(l_bnd, l_bop, l_op))
l_bnd, l_ics, l_op = Solution_class.lambda_bcs, Solution_class.lambda_ics, Solution_class.lambda_op
print('lambda bound: {:.3e}, lambda initial: {:.3e}, lambda operator: {:.3e}'.format(l_bnd, l_ics, l_op))
t += 1
if t > tmax:
break
......
......@@ -2,22 +2,29 @@
import torch
def list_to_vector(list_):
return torch.cat([x.reshape(-1) for x in list_])
def counter(fu):
def inner(*a,**kw):
inner.count+=1
return fu(*a,**kw)
def inner(*a, **kw):
inner.count += 1
return fu(*a, **kw)
inner.count = 0
return inner
class LambdaCompute():
def __init__(self, bounds, bounds_op, operator, model):
self.bnd = bounds
self.bop = bounds_op
def __init__(self, bnd, true_bnd, operator, model):
self.bnd = bnd[0]
self.ics = bnd[1]
self.true_bnd = true_bnd[0]
self.true_ics = true_bnd[1]
self.op = operator
self.model = model
self.num_of_eq = operator.shape[-1]
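# The lambdas are rebalanced with an NTK-style heuristic: each weight is the ratio of the
# total kernel trace to the trace of the kernel built from that term's gradients, so terms
# with weaker gradients receive proportionally larger weights.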
def jacobian(self, f):
jac = {}
......@@ -27,12 +34,12 @@ class LambdaCompute():
grad, = torch.autograd.grad(op, param, retain_graph=True, allow_unused=True)
if grad is None:
grad = torch.tensor([0.])
jac1.append(grad.reshape(1,-1))
jac1.append(grad.reshape(1, -1))
jac[name] = torch.cat(jac1)
return jac
def compute_ntk(self, J1_dict, J2_dict):
def ntk(self, J1_dict, J2_dict):
keys = list(J1_dict.keys())
size = J1_dict[keys[0]].shape[0]
Ker = torch.zeros((size, size))
......@@ -42,22 +49,28 @@ class LambdaCompute():
K = J1 @ J2.T
Ker = Ker + K
return Ker
def trace(self, f):
J_f = self.jacobian(f)
ntk = self.ntk(J_f, J_f)
tr = torch.trace(ntk)
return tr
def update(self):
trace_bnd = self.trace(self.bnd)
trace_ics = self.trace(self.ics)
J_bnd = self.jacobian(self.bnd)
J_bop = self.jacobian(self.bop)
J_op = self.jacobian(self.op)
K_bnd = self.compute_ntk(J_bnd, J_bnd)
K_bop = self.compute_ntk(J_bop, J_bop)
K_op = self.compute_ntk(J_op, J_op)
if self.num_of_eq > 1:
trace_op = torch.zeros(self.num_of_eq)
for i in range(self.num_of_eq):
trace_op[i] = self.trace(self.op[:, i: i + 1])
trace_op = torch.mean(trace_op)
else:
trace_op = self.trace(self.op)
trace_K = torch.trace(K_bnd) + torch.trace(K_bop) + \
torch.trace(K_op)
trace_K = trace_bnd + trace_ics + trace_op
l_bnd = trace_K / torch.trace(K_bnd)
l_bop = trace_K / torch.trace(K_bop)
l_op = trace_K / torch.trace(K_op)
l_bnd = trace_K / trace_bnd
l_ics = trace_K / trace_ics
l_op = trace_K / trace_op
return l_bnd, l_bop, l_op
\ No newline at end of file
return l_bnd, l_ics, l_op