diff --git a/lbmpy/creationfunctions.py b/lbmpy/creationfunctions.py index 83a1ecfcc44777e17a59f88133ce2ce590516908..730a9b34d52ea8a19fe083dc6d1d0b829421cf0b 100644 --- a/lbmpy/creationfunctions.py +++ b/lbmpy/creationfunctions.py @@ -56,11 +56,18 @@ from dataclasses import dataclass, field, replace from typing import Union, List, Tuple, Any, Type, Iterable from warnings import warn, filterwarnings -import lbmpy.moment_transforms -import pystencils.astnodes import sympy as sp import sympy.core.numbers +import pystencils.astnodes +from pystencils import CreateKernelConfig, create_kernel +from pystencils.cache import disk_cache_no_fallback +from pystencils.typing import collate_types +from pystencils.field import Field +from pystencils.simp import sympy_cse, SimplificationStrategy + +import lbmpy.moment_transforms +from lbmpy.advanced_streaming.utility import Timestep, get_accessor from lbmpy.enums import Stencil, Method, ForceModel, CollisionSpace import lbmpy.forcemodels as forcemodels from lbmpy.fieldaccess import CollideOnlyInplaceAccessor, PdfFieldAccessor, PeriodicTwoFieldsAccessor @@ -77,12 +84,8 @@ from lbmpy.simplificationfactory import create_simplification_strategy from lbmpy.stencils import LBStencil from lbmpy.turbulence_models import add_smagorinsky_model from lbmpy.updatekernels import create_lbm_kernel, create_stream_pull_with_output_kernel -from lbmpy.advanced_streaming.utility import Timestep, get_accessor -from pystencils import CreateKernelConfig, create_kernel -from pystencils.cache import disk_cache_no_fallback -from pystencils.typing import collate_types -from pystencils.field import Field -from pystencils.simp import sympy_cse, SimplificationStrategy +from lbmpy.utils import update_dataclass_inplace + # needed for the docstring from lbmpy.methods.abstractlbmethod import LbmCollisionRule, AbstractLbMethod from lbmpy.methods.cumulantbased import CumulantBasedLbMethod @@ -496,8 +499,17 @@ class LBMOptimisation: def create_lb_function(ast=None, lbm_config=None, lbm_optimisation=None, config=None, optimization=None, **kwargs): """Creates a Python function for the LB method""" - lbm_config, lbm_optimisation, config = update_with_default_parameters(kwargs, optimization, - lbm_config, lbm_optimisation, config) + tmp_lc, tmp_lo, tmp_co = update_with_default_parameters(kwargs, optimization, lbm_config, lbm_optimisation, config) + + if lbm_config is None: + lbm_config = tmp_lc + + if lbm_optimisation is None: + lbm_optimisation = tmp_lo + + if config is None: + config = tmp_co + if lbm_config.ast is not None: ast = lbm_config.ast @@ -515,8 +527,16 @@ def create_lb_function(ast=None, lbm_config=None, lbm_optimisation=None, config= def create_lb_ast(update_rule=None, lbm_config=None, lbm_optimisation=None, config=None, optimization=None, **kwargs): """Creates a pystencils AST for the LB method""" - lbm_config, lbm_optimisation, config = update_with_default_parameters(kwargs, optimization, - lbm_config, lbm_optimisation, config) + tmp_lc, tmp_lo, tmp_co = update_with_default_parameters(kwargs, optimization, lbm_config, lbm_optimisation, config) + + if lbm_config is None: + lbm_config = tmp_lc + + if lbm_optimisation is None: + lbm_optimisation = tmp_lo + + if config is None: + config = tmp_co if lbm_config.update_rule is not None: update_rule = lbm_config.update_rule @@ -527,7 +547,8 @@ def create_lb_ast(update_rule=None, lbm_config=None, lbm_optimisation=None, conf field_types = set(fa.field.dtype for fa in update_rule.defined_symbols if isinstance(fa, Field.Access)) - config = replace(config, data_type=collate_types(field_types), ghost_layers=1) + new_config = replace(config, data_type=collate_types(field_types), ghost_layers=1) + update_dataclass_inplace(config, new_config) ast = create_kernel(update_rule, config=config) ast.method = update_rule.method @@ -540,8 +561,16 @@ def create_lb_ast(update_rule=None, lbm_config=None, lbm_optimisation=None, conf def create_lb_update_rule(collision_rule=None, lbm_config=None, lbm_optimisation=None, config=None, optimization=None, **kwargs): """Creates an update rule (list of Assignments) for a LB method that describe a full sweep""" - lbm_config, lbm_optimisation, config = update_with_default_parameters(kwargs, optimization, - lbm_config, lbm_optimisation, config) + tmp_lc, tmp_lo, tmp_co = update_with_default_parameters(kwargs, optimization, lbm_config, lbm_optimisation, config) + + if lbm_config is None: + lbm_config = tmp_lc + + if lbm_optimisation is None: + lbm_optimisation = tmp_lo + + if config is None: + config = tmp_co if lbm_config.collision_rule is not None: collision_rule = lbm_config.collision_rule @@ -596,8 +625,13 @@ def create_lb_update_rule(collision_rule=None, lbm_config=None, lbm_optimisation def create_lb_collision_rule(lb_method=None, lbm_config=None, lbm_optimisation=None, config=None, optimization=None, **kwargs): """Creates a collision rule (list of Assignments) for a LB method describing the collision operator (no stream)""" - lbm_config, lbm_optimisation, config = update_with_default_parameters(kwargs, optimization, - lbm_config, lbm_optimisation, config) + tmp_lc, tmp_lo, tmp_co = update_with_default_parameters(kwargs, optimization, lbm_config, lbm_optimisation, config) + + if lbm_config is None: + lbm_config = tmp_lc + + if lbm_optimisation is None: + lbm_optimisation = tmp_lo if lbm_config.lb_method is not None: lb_method = lbm_config.lb_method @@ -698,7 +732,10 @@ def create_lb_collision_rule(lb_method=None, lbm_config=None, lbm_optimisation=N def create_lb_method(lbm_config=None, **params): """Creates a LB method, defined by moments/cumulants for collision space, equilibrium and relaxation rates.""" - lbm_config, _, _ = update_with_default_parameters(params, lbm_config=lbm_config) + if not lbm_config: + lbm_config, _, _ = update_with_default_parameters(params, lbm_config=lbm_config) + else: + update_with_default_parameters(params, lbm_config=lbm_config) relaxation_rates = lbm_config.relaxation_rates dim = lbm_config.stencil.D @@ -792,7 +829,8 @@ def update_with_default_parameters(params, opt_params=None, lbm_config=None, lbm for k, v in config_params.items(): if not hasattr(config, k): raise KeyError(f'{v} is not a valid kwarg. Please look in CreateKernelConfig for valid settings') - config = replace(config, **config_params) + new_config = replace(config, **config_params) + update_dataclass_inplace(config, new_config) lbm_opt_params = ['cse_pdfs', 'cse_global', 'simplification', 'pre_simplification', 'split', 'field_size', 'field_layout', 'symbolic_field', 'symbolic_temporary_field', 'builtin_periodicity'] @@ -808,7 +846,8 @@ def update_with_default_parameters(params, opt_params=None, lbm_config=None, lbm for k, v in opt_params_dict.items(): if not hasattr(lbm_optimisation, k): raise KeyError(f'{v} is not a valid kwarg. Please look in LBMOptimisation for valid settings') - lbm_optimisation = replace(lbm_optimisation, **opt_params_dict) + new_lbm_optimisation = replace(lbm_optimisation, **opt_params_dict) + update_dataclass_inplace(lbm_optimisation, new_lbm_optimisation) if params is None: params = {} @@ -819,6 +858,7 @@ def update_with_default_parameters(params, opt_params=None, lbm_config=None, lbm for k, v in params.items(): if not hasattr(lbm_config, k): raise KeyError(f'{v} is not a valid kwarg. Please look in LBMConfig for valid settings') - lbm_config = replace(lbm_config, **params) + new_config = replace(lbm_config, **params) + update_dataclass_inplace(lbm_config, new_config) return lbm_config, lbm_optimisation, config diff --git a/lbmpy/utils.py b/lbmpy/utils.py index 76919e2ad20a52ddae0881407f85b3274e368714..9047826a1db4505f130fc1b161d0cd0fab1a6d18 100644 --- a/lbmpy/utils.py +++ b/lbmpy/utils.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, fields import sympy as sp @@ -26,3 +27,10 @@ def extract_shear_relaxation_rate(collision_rule, shear_relaxation_rate): shear_relaxation_rate = eq.lhs return shear_relaxation_rate, found_symbolic_shear_relaxation + + +def update_dataclass_inplace(dataclass_to_write: dataclass, dataclass_to_read: dataclass): + """Takes to dataclasses and updates the first dataclass with the values of the second dataclass inplace + """ + for f in fields(dataclass_to_write): + setattr(dataclass_to_write, f.name, getattr(dataclass_to_read, f.name))