From 9b8ec60ec4e230789cf846da5797556570b7b336 Mon Sep 17 00:00:00 2001
From: Philipp Suffa <philipp.suffa@fau.de>
Date: Fri, 27 Sep 2024 15:07:28 +0200
Subject: [PATCH] Small change to reduce code duplication

---
 src/lbmpy/macroscopic_value_kernels.py | 54 +++++++++-----------------
 1 file changed, 19 insertions(+), 35 deletions(-)

diff --git a/src/lbmpy/macroscopic_value_kernels.py b/src/lbmpy/macroscopic_value_kernels.py
index 0f86da9a..ae99bbb3 100644
--- a/src/lbmpy/macroscopic_value_kernels.py
+++ b/src/lbmpy/macroscopic_value_kernels.py
@@ -11,25 +11,29 @@ from lbmpy.relaxationrates import get_shear_relaxation_rate
 from lbmpy.utils import second_order_moment_tensor
 
 
-def pdf_initialization_assignments(lb_method, density, velocity, pdfs,
-                                   streaming_pattern='pull', previous_timestep=Timestep.BOTH,
-                                   set_pre_collision_pdfs=False):
-    """Assignments to initialize the pdf field with equilibrium"""
+def get_field_accesses(lb_method, pdfs, streaming_pattern, previous_timestep, pre_collision_pdfs):
     if isinstance(pdfs, Field):
         accessor = get_accessor(streaming_pattern, previous_timestep)
-        if set_pre_collision_pdfs:
+        if pre_collision_pdfs:
             field_accesses = accessor.read(pdfs, lb_method.stencil)
         else:
             field_accesses = accessor.write(pdfs, lb_method.stencil)
-    elif streaming_pattern == 'pull' and not set_pre_collision_pdfs:
+    elif streaming_pattern == 'pull' and not pre_collision_pdfs:
         field_accesses = pdfs
     else:
         raise ValueError("Invalid value of pdfs: A PDF field reference is required to derive "
                          + f"initialization assignments for streaming pattern {streaming_pattern}.")
+    return field_accesses
+
+
+def pdf_initialization_assignments(lb_method, density, velocity, pdfs,
+                                   streaming_pattern='pull', previous_timestep=Timestep.BOTH,
+                                   set_pre_collision_pdfs=False):
+    """Assignments to initialize the pdf field with equilibrium"""
+    field_accesses = get_field_accesses(lb_method, pdfs, streaming_pattern, previous_timestep, set_pre_collision_pdfs)
 
     if isinstance(density, Field):
         density = density.center
-
     if isinstance(velocity, Field):
         velocity = velocity.center_vector
 
@@ -44,17 +48,9 @@ def pdf_initialization_assignments(lb_method, density, velocity, pdfs,
 def macroscopic_values_getter(lb_method, density, velocity, pdfs,
                               streaming_pattern='pull', previous_timestep=Timestep.BOTH,
                               use_pre_collision_pdfs=False):
-    if isinstance(pdfs, Field):
-        accessor = get_accessor(streaming_pattern, previous_timestep)
-        if use_pre_collision_pdfs:
-            field_accesses = accessor.read(pdfs, lb_method.stencil)
-        else:
-            field_accesses = accessor.write(pdfs, lb_method.stencil)
-    elif streaming_pattern == 'pull' and not use_pre_collision_pdfs:
-        field_accesses = pdfs
-    else:
-        raise ValueError("Invalid value of pdfs: A PDF field reference is required to derive "
-                         + f"getter assignments for streaming pattern {streaming_pattern}.")
+
+    field_accesses = get_field_accesses(lb_method, pdfs, streaming_pattern, previous_timestep, use_pre_collision_pdfs)
+
     cqc = lb_method.conserved_quantity_computation
     assert not (velocity is None and density is None)
     output_spec = {}
@@ -68,21 +64,10 @@ def macroscopic_values_getter(lb_method, density, velocity, pdfs,
 macroscopic_values_setter = pdf_initialization_assignments
 
 
-
-def strain_rate_tensor_getter(lb_method, strain_rate_tensor, pdfs,streaming_pattern='pull',
+def strain_rate_tensor_getter(lb_method, strain_rate_tensor, pdfs, streaming_pattern='pull',
                               previous_timestep=Timestep.BOTH, use_pre_collision_pdfs=False):
 
-    if isinstance(pdfs, Field):
-        accessor = get_accessor(streaming_pattern, previous_timestep)
-        if use_pre_collision_pdfs:
-            field_accesses = accessor.read(pdfs, lb_method.stencil)
-        else:
-            field_accesses = accessor.write(pdfs, lb_method.stencil)
-    elif streaming_pattern == 'pull' and not use_pre_collision_pdfs:
-        field_accesses = pdfs
-    else:
-        raise ValueError("Invalid value of pdfs: A PDF field reference is required to derive "
-                         + f"getter assignments for streaming pattern {streaming_pattern}.")
+    field_accesses = get_field_accesses(lb_method, pdfs, streaming_pattern, previous_timestep, use_pre_collision_pdfs)
 
     if isinstance(strain_rate_tensor, Field):
         strain_rate_tensor = strain_rate_tensor.center_vector
@@ -91,16 +76,15 @@ def strain_rate_tensor_getter(lb_method, strain_rate_tensor, pdfs,streaming_patt
     equilibrium = lb_method.equilibrium_distribution
     rho = equilibrium.density if equilibrium.compressible else equilibrium.background_density
 
-
     f_neq = sp.Matrix([field_accesses[i] for i in range(lb_method.stencil.Q)]) - lb_method.get_equilibrium_terms()
     pi = second_order_moment_tensor(f_neq, lb_method.stencil)
-    strain_rate_tensor_equ = - 1.5 * (omega_s/rho) * pi
+    strain_rate_tensor_equ = - 1.5 * (omega_s / rho) * pi
 
-    assignments = [Assignment(strain_rate_tensor[i * lb_method.stencil.D + j], strain_rate_tensor_equ[i , j]) for i in range(lb_method.stencil.D) for j in range(lb_method.stencil.D)]
+    assignments = [Assignment(strain_rate_tensor[i * lb_method.stencil.D + j], strain_rate_tensor_equ[i, j])
+                   for i in range(lb_method.stencil.D) for j in range(lb_method.stencil.D)]
     return assignments
 
 
-
 def compile_macroscopic_values_getter(lb_method, output_quantities, pdf_arr=None,
                                       ghost_layers=1, iteration_slice=None,
                                       field_layout='numpy', target=Target.CPU,
-- 
GitLab