Skip to content
Snippets Groups Projects
Commit 57c258c0 authored by Frederik Hennig's avatar Frederik Hennig Committed by Markus Holzer
Browse files

Creationfunctions now update LBMConfig. Small fix to `GenericDiscreteEquilibrium`.

parent 2b26d636
No related branches found
No related tags found
1 merge request!122Creationfunctions now update LBMConfig. Small fix to `GenericDiscreteEquilibrium`.
...@@ -39,6 +39,9 @@ of the generated code is specified. ...@@ -39,6 +39,9 @@ of the generated code is specified.
This step compiles the AST into an executable function, either for CPU or GPUs. This function This step compiles the AST into an executable function, either for CPU or GPUs. This function
behaves like a normal Python function and runs one LBM time step. behaves like a normal Python function and runs one LBM time step.
Each stage (apart from *Function*) also adds its result to the given `LBMConfig` object. The `LBMConfig`
thus coalesces all information defining the LBM kernel.
The function :func:`create_lb_function` runs the whole pipeline, the other functions in this module The function :func:`create_lb_function` runs the whole pipeline, the other functions in this module
execute this pipeline only up to a certain step. Each function optionally also takes the result of the previous step. execute this pipeline only up to a certain step. Each function optionally also takes the result of the previous step.
...@@ -525,11 +528,12 @@ def create_lb_ast(update_rule=None, lbm_config=None, lbm_optimisation=None, conf ...@@ -525,11 +528,12 @@ def create_lb_ast(update_rule=None, lbm_config=None, lbm_optimisation=None, conf
field_types = set(fa.field.dtype for fa in update_rule.defined_symbols if isinstance(fa, Field.Access)) field_types = set(fa.field.dtype for fa in update_rule.defined_symbols if isinstance(fa, Field.Access))
config = replace(config, data_type=collate_types(field_types), ghost_layers=1) config = replace(config, data_type=collate_types(field_types), ghost_layers=1)
res = create_kernel(update_rule, config=config) ast = create_kernel(update_rule, config=config)
res.method = update_rule.method ast.method = update_rule.method
res.update_rule = update_rule ast.update_rule = update_rule
return res lbm_config.ast = ast
return ast
@disk_cache_no_fallback @disk_cache_no_fallback
...@@ -568,8 +572,9 @@ def create_lb_update_rule(collision_rule=None, lbm_config=None, lbm_optimisation ...@@ -568,8 +572,9 @@ def create_lb_update_rule(collision_rule=None, lbm_config=None, lbm_optimisation
dst_field = src_field.new_field_with_different_name(lbm_config.temporary_field_name) dst_field = src_field.new_field_with_different_name(lbm_config.temporary_field_name)
kernel_type = lbm_config.kernel_type kernel_type = lbm_config.kernel_type
update_rule = None
if kernel_type == 'stream_pull_only': if kernel_type == 'stream_pull_only':
return create_stream_pull_with_output_kernel(lb_method, src_field, dst_field, lbm_config.output) update_rule = create_stream_pull_with_output_kernel(lb_method, src_field, dst_field, lbm_config.output)
else: else:
if kernel_type == 'default_stream_collide': if kernel_type == 'default_stream_collide':
if lbm_config.streaming_pattern == 'pull' and any(lbm_optimisation.builtin_periodicity): if lbm_config.streaming_pattern == 'pull' and any(lbm_optimisation.builtin_periodicity):
...@@ -582,7 +587,10 @@ def create_lb_update_rule(collision_rule=None, lbm_config=None, lbm_optimisation ...@@ -582,7 +587,10 @@ def create_lb_update_rule(collision_rule=None, lbm_config=None, lbm_optimisation
accessor = kernel_type accessor = kernel_type
else: else:
raise ValueError("Invalid value of parameter 'kernel_type'", lbm_config.kernel_type) raise ValueError("Invalid value of parameter 'kernel_type'", lbm_config.kernel_type)
return create_lbm_kernel(collision_rule, src_field, dst_field, accessor) update_rule = create_lbm_kernel(collision_rule, src_field, dst_field, accessor)
lbm_config.update_rule = update_rule
return update_rule
@disk_cache_no_fallback @disk_cache_no_fallback
...@@ -680,6 +688,7 @@ def create_lb_collision_rule(lb_method=None, lbm_config=None, lbm_optimisation=N ...@@ -680,6 +688,7 @@ def create_lb_collision_rule(lb_method=None, lbm_config=None, lbm_optimisation=N
if lbm_optimisation.cse_global: if lbm_optimisation.cse_global:
collision_rule = sympy_cse(collision_rule) collision_rule = sympy_cse(collision_rule)
lbm_config.collision_rule = collision_rule
return collision_rule return collision_rule
...@@ -756,6 +765,7 @@ def create_lb_method(lbm_config=None, **params): ...@@ -756,6 +765,7 @@ def create_lb_method(lbm_config=None, **params):
if lbm_config.entropic: if lbm_config.entropic:
method.set_conserved_moments_relaxation_rate(relaxation_rates[0]) method.set_conserved_moments_relaxation_rate(relaxation_rates[0])
lbm_config.method = method
return method return method
......
...@@ -47,6 +47,10 @@ class GenericDiscreteEquilibrium(AbstractEquilibrium): ...@@ -47,6 +47,10 @@ class GenericDiscreteEquilibrium(AbstractEquilibrium):
deviation_only=False): deviation_only=False):
super().__init__(dim=stencil.D) super().__init__(dim=stencil.D)
if len(equilibrium_pdfs) != stencil.Q:
raise ValueError(f"Wrong number of PDFs."
f"On the {stencil} stencil, exactly {stencil.Q} populations must be passed!")
self._stencil = stencil self._stencil = stencil
self._pdfs = tuple(equilibrium_pdfs) self._pdfs = tuple(equilibrium_pdfs)
self._zeroth_order_moment_symbol = zeroth_order_moment_symbol self._zeroth_order_moment_symbol = zeroth_order_moment_symbol
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment