Skip to content
Snippets Groups Projects
Commit 70aee528 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Rename KernelFunctionCall -> FunctionCall and regard fields_accessed in

called functions
parent 1bcdc628
No related branches found
No related tags found
No related merge requests found
......@@ -42,12 +42,10 @@ class DestructuringBindingsForFieldClass(Node):
def fields_accessed(self) -> Set['ResolvedFieldAccess']:
"""Set of Field instances: fields which are accessed inside this kernel function"""
# TODO: remove when texture support is merged into pystencils
try:
from pystencils.interpolation_astnodes import InterpolatorAccess
return set(o.field for o in self.atoms(ResolvedFieldAccess) | self.atoms(InterpolatorAccess))
except ImportError:
return set(o.field for o in self.atoms(ResolvedFieldAccess))
from pystencils.interpolation_astnodes import InterpolatorAccess
return set(o.field for o in self.atoms(ResolvedFieldAccess) | self.atoms(InterpolatorAccess)) \
| set(itertools.chain.from_iterable((k.kernel_function.fields_accessed
for k in self.atoms(FunctionCall))))
def __init__(self, body):
super(DestructuringBindingsForFieldClass, self).__init__()
......@@ -94,7 +92,7 @@ class DestructuringBindingsForFieldClass(Node):
return self.body.atoms(arg_type) | {s for s in self.symbols_defined if isinstance(s, arg_type)}
class KernelFunctionCall(Node):
class FunctionCall(Node):
"""
AST nodes representing a call of a :class:`pystencils.astnodes.KernelFunction`
"""
......@@ -104,11 +102,11 @@ class KernelFunctionCall(Node):
@property
def args(self):
return [self.kernel_function]
return [p.symbol for p in self.kernel_function.get_parameters()]
@property
def symbols_defined(self) -> Set[sp.Symbol]:
return set()
return {}
@property
def undefined_symbols(self) -> Set[sp.Symbol]:
......@@ -153,15 +151,15 @@ def generate_kernel_call(kernel_function):
CudaErrorCheck(),
*texture_uploads,
CudaErrorCheck(),
KernelFunctionCall(kernel_function),
FunctionCall(kernel_function),
CudaErrorCheck(),
])
elif kernel_function.backend == 'gpucuda':
return pystencils.astnodes.Block([CudaErrorCheck(),
KernelFunctionCall(kernel_function),
FunctionCall(kernel_function),
CudaErrorCheck()])
else:
return pystencils.astnodes.Block([KernelFunctionCall(kernel_function)])
return pystencils.astnodes.Block([FunctionCall(kernel_function)])
return block
......@@ -238,8 +236,8 @@ class CudaErrorCheckDefinition(CustomCodeNode):
function_name = 'gpuErrchk'
code = """
#ifdef __GNUC__
#define gpuErrchk(ans) { gpuAssert((ans), __PRETTY_FUNCTION__, __FILE__, __LINE__); }
# ifdef __GNUC__
# define gpuErrchk(ans) { gpuAssert((ans), __PRETTY_FUNCTION__, __FILE__, __LINE__); }
inline static void gpuAssert(cudaError_t code, const char* function, const char *file, int line, bool abort=true)
{
if (code != cudaSuccess)
......@@ -248,8 +246,8 @@ inline static void gpuAssert(cudaError_t code, const char* function, const char
if (abort) exit(code);
}
}
#else
#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }
# else
# define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline static void gpuAssert(cudaError_t code, const char* function, const char *file, int line, bool abort=true)
{
if (code != cudaSuccess)
......@@ -258,7 +256,7 @@ inline static void gpuAssert(cudaError_t code, const char* function, const char
if (abort) exit(code);
}
}
#endif
# endif
"""
headers = ['<cuda.h>']
......
......@@ -37,7 +37,7 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
kernel_code = pystencils.backends.cbackend.generate_c(node, dialect='c', with_globals=False)
return prefix + kernel_code
def _print_KernelFunctionCall(self, node):
def _print_FunctionCall(self, node):
function = node.kernel_function
parameters = function.get_parameters()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment