From e91b33d69b51aeebd401aefdb4275b78d73a5da8 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 28 Feb 2020 17:47:42 +0100
Subject: [PATCH] Allow calling Torch kernels with dynamic shape

Output tensors that are not passed as keyword arguments used to be
allocated with `field.shape` on a device selected by `use_cuda`.
They are now allocated on the device of the first passed argument.
If that allocation raises, shape and device are both taken from the
first `torch.Tensor` found among the positional and keyword arguments.

Further changes:
- `FunctionCall.undefined_symbols` additionally reports the free
  symbols of the block and grid launch parameters of `gpucuda` kernels.
- `JinjaCppFile` is moved, unchanged, in front of
  `DestructuringBindingsForFieldClass`.
- `_create_backward_assignments_tf_mad` always takes the branch that
  was previously reserved for time-constant fields.
- Whitespace-only change in `CustomFunctionCall.undefined_symbols`.
---
 src/pystencils_autodiff/_autodiff.py          |   2 +-
 .../backends/_torch_native.py                 |  16 +-
 .../framework_integration/astnodes.py         | 170 +++++++++---------
 3 files changed, 101 insertions(+), 87 deletions(-)

diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py
index 2967728..7fb78f8 100644
--- a/src/pystencils_autodiff/_autodiff.py
+++ b/src/pystencils_autodiff/_autodiff.py
@@ -134,7 +134,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
                                         ] += sp.diff(forward_assignment.rhs, ra) * diff_write_field[inverted_offset]
 
                 for index in range(diff_read_field.index_shape[0]):
-                    if forward_read_field in self._time_constant_fields:
+                    if True:
                         # Accumulate in case of time_constant_fields
                         assignment = ps.Assignment(
                             diff_read_field.center_vector[index],
diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index d43cb90..848da23 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -59,11 +59,17 @@ def create_autograd_function(autodiff_obj, use_cuda):
 
         for field in autodiff_obj.forward_output_fields:
             if field.name not in kwargs:
-                kwargs[field.name] = torch.zeros(
-                    field.shape,
-                    dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
-                    device='cuda' if use_cuda else 'cpu')  # use device of tensor
-
+                try:
+                    kwargs[field.name] = torch.zeros(
+                        field.shape,
+                        dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
+                        device=next(chain(args, kwargs.values())).device)
+                except Exception:  # best effort: mirror shape and device of a passed tensor
+                    template = next(filter(lambda x: isinstance(x, torch.Tensor), chain(args, kwargs.values())))
+                    kwargs[field.name] = torch.zeros(
+                        template.shape,
+                        dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
+                        device=template.device)
         output_tensors = OrderedDict({f.name: field_to_tensor_dict.get(f, kwargs[f.name])
                                       for f in autodiff_obj.forward_output_fields})
 
diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index 1199e9c..4ace780 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -25,6 +25,85 @@ from pystencils_autodiff.framework_integration.printer import FrameworkIntegrati
 from pystencils_autodiff.framework_integration.texture_astnodes import NativeTextureBinding
 
 
+class JinjaCppFile(Node):
+    TEMPLATE: jinja2.Template = None
+    NOT_PRINT_TYPES = (pystencils.Field, pystencils.TypedSymbol, bool)
+
+    def __init__(self, ast_dict={}):
+        self.ast_dict = pystencils.utils.DotDict(ast_dict)
+        self.printer = FrameworkIntegrationPrinter()
+        Node.__init__(self)
+
+    @property
+    def args(self):
+        """Returns all arguments/children of this node."""
+        ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, (Node, sp.Expr))]
+        iterables_of_ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, Iterable)
+                                  and not isinstance(a, str)]
+        return ast_nodes + list(itertools.chain.from_iterable(iterables_of_ast_nodes))
+
+    @property
+    def symbols_defined(self):
+        """Set of symbols which are defined by this node."""
+        return set(itertools.chain.from_iterable(a.symbols_defined
+                                                 for a in self.args
+                                                 if isinstance(a, Node)))
+
+    @property
+    def undefined_symbols(self):
+        """Symbols which are used but are not defined inside this node."""
+        return set(itertools.chain.from_iterable(a.undefined_symbols if isinstance(a, Node) else a.free_symbols
+                                                 for a in self.args
+                                                 if isinstance(a, (Node, sp.Expr)))) - self.symbols_defined
+
+    def _print(self, node):
+        if isinstance(node, (Node, sp.Expr)):
+            return self.printer(node)
+        else:
+            return str(node)
+
+    def atoms(self, arg_type) -> Set[Any]:
+        """Returns a set of all descendants recursively, which are an instance of the given type."""
+        if isinstance(self, arg_type):
+            result = {self}
+        else:
+            result = set()
+
+        for arg in self.args:
+            if isinstance(arg, arg_type):
+                result.add(arg)
+            if hasattr(arg, 'atoms'):
+                result.update(arg.atoms(arg_type))
+        return result
+
+    @property
+    def is_cuda(self):
+        return any(f.backend == 'gpucuda' for f in self.atoms(KernelFunction))
+
+    def __str__(self):
+        assert self.TEMPLATE, f"Template of {self.__class__} must be set"
+        render_dict = {k: (self._print(v)
+                           if not isinstance(v, self.NOT_PRINT_TYPES) and v is not None
+                           else v)
+                       if not isinstance(v, Iterable) or isinstance(v, str)
+                       else [(self._print(a)
+                              if not isinstance(a, self.NOT_PRINT_TYPES) and a is not None
+                              else a)
+                             for a in v]
+                       for k, v in self.ast_dict.items()}
+
+        render_dict.update({"headers": pystencils.backends.cbackend.get_headers(self)})
+        render_dict.update({"globals": sorted({
+            self.printer(g) for g in pystencils.backends.cbackend.get_global_declarations(self)
+        }, key=str)})
+        # self.TEMPLATE.environment = self.ENVIROMENT
+
+        return self.TEMPLATE.render(render_dict)
+
+    def __repr__(self):
+        return f'{str(self.__class__)}:\n {self.TEMPLATE.render(self.ast_dict)}'
+
+
 class DestructuringBindingsForFieldClass(Node):
     """
     Defines all variables needed for describing a field (shape, pointer, strides)
@@ -112,7 +191,15 @@ class FunctionCall(Node):
 
     @property
     def undefined_symbols(self) -> Set[sp.Symbol]:
-        return {p.symbol for p in self.kernel_function.get_parameters()}
+        rtn = {p.symbol for p in self.kernel_function.get_parameters()}
+        function = self.kernel_function
+        if function.backend == "gpucuda":
+            written_fields = function.fields_written
+            shape = list(written_fields)[0].spatial_shape
+            launch = function.indexing.call_parameters(shape)
+            # Block and grid sizes are derived from `shape`; report every symbol they reference.
+            rtn |= {s for i in (*launch['block'], *launch['grid']) for s in sp.sympify(i).free_symbols}
+        return rtn
 
     def subs(self, subs_dict) -> None:
         for a in self.args:
@@ -167,85 +254,6 @@ def generate_kernel_call(kernel_function):
     return block
 
 
-class JinjaCppFile(Node):
-    TEMPLATE: jinja2.Template = None
-    NOT_PRINT_TYPES = (pystencils.Field, pystencils.TypedSymbol, bool)
-
-    def __init__(self, ast_dict={}):
-        self.ast_dict = pystencils.utils.DotDict(ast_dict)
-        self.printer = FrameworkIntegrationPrinter()
-        Node.__init__(self)
-
-    @property
-    def args(self):
-        """Returns all arguments/children of this node."""
-        ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, (Node, sp.Expr))]
-        iterables_of_ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, Iterable)
-                                  and not isinstance(a, str)]
-        return ast_nodes + list(itertools.chain.from_iterable(iterables_of_ast_nodes))
-
-    @property
-    def symbols_defined(self):
-        """Set of symbols which are defined by this node."""
-        return set(itertools.chain.from_iterable(a.symbols_defined
-                                                 for a in self.args
-                                                 if isinstance(a, Node)))
-
-    @property
-    def undefined_symbols(self):
-        """Symbols which are used but are not defined inside this node."""
-        return set(itertools.chain.from_iterable(a.undefined_symbols if isinstance(a, Node) else a.free_symbols
-                                                 for a in self.args
-                                                 if isinstance(a, (Node, sp.Expr)))) - self.symbols_defined
-
-    def _print(self, node):
-        if isinstance(node, (Node, sp.Expr)):
-            return self.printer(node)
-        else:
-            return str(node)
-
-    def atoms(self, arg_type) -> Set[Any]:
-        """Returns a set of all descendants recursively, which are an instance of the given type."""
-        if isinstance(self, arg_type):
-            result = {self}
-        else:
-            result = set()
-
-        for arg in self.args:
-            if isinstance(arg, arg_type):
-                result.add(arg)
-            if hasattr(arg, 'atoms'):
-                result.update(arg.atoms(arg_type))
-        return result
-
-    @property
-    def is_cuda(self):
-        return any(f.backend == 'gpucuda' for f in self.atoms(KernelFunction))
-
-    def __str__(self):
-        assert self.TEMPLATE, f"Template of {self.__class__} must be set"
-        render_dict = {k: (self._print(v)
-                           if not isinstance(v, self.NOT_PRINT_TYPES) and v is not None
-                           else v)
-                       if not isinstance(v, Iterable) or isinstance(v, str)
-                       else [(self._print(a)
-                              if not isinstance(a, self.NOT_PRINT_TYPES) and a is not None
-                              else a)
-                             for a in v]
-                       for k, v in self.ast_dict.items()}
-
-        render_dict.update({"headers": pystencils.backends.cbackend.get_headers(self)})
-        render_dict.update({"globals": sorted({
-            self.printer(g) for g in pystencils.backends.cbackend.get_global_declarations(self)
-        }, key=str)})
-        # self.TEMPLATE.environment = self.ENVIROMENT
-
-        return self.TEMPLATE.render(render_dict)
-
-    def __repr__(self):
-        return f'{str(self.__class__)}:\n {self.TEMPLATE.render(self.ast_dict)}'
-
-
 class CudaErrorCheckDefinition(CustomCodeNode):
     def __init__(self):
         super().__init__(self.code, [], [])
@@ -410,7 +418,7 @@ class CustomFunctionCall(JinjaCppFile):
 
     @property
     def undefined_symbols(self):
-        return set(self.ast_dict.args)
+        return set(self.ast_dict.args)
 
     def subs(self, subs_dict):
         self.ast_dict.args = list(map(lambda x: x.subs(subs_dict), self.ast_dict.args))
-- 
GitLab