Skip to content
Snippets Groups Projects
Commit e91b33d6 authored by Stephan Seitz
Browse files

Allow calling Torch kernels with dynamic shape

parent 32ee72e4
Branches
Tags
No related merge requests found
Pipeline #22363 failed
......@@ -134,7 +134,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
] += sp.diff(forward_assignment.rhs, ra) * diff_write_field[inverted_offset]
for index in range(diff_read_field.index_shape[0]):
if forward_read_field in self._time_constant_fields:
if True:
# Accumulate in case of time_constant_fields
assignment = ps.Assignment(
diff_read_field.center_vector[index],
......
......@@ -59,11 +59,17 @@ def create_autograd_function(autodiff_obj, use_cuda):
for field in autodiff_obj.forward_output_fields:
if field.name not in kwargs:
kwargs[field.name] = torch.zeros(
field.shape,
dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
device='cuda' if use_cuda else 'cpu') # use device of tensor
try:
kwargs[field.name] = torch.zeros(
field.shape,
dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
device=next(chain(args, kwargs.values())).device)
except:
shape = next(filter(lambda x: isinstance(x, torch.Tensor), chain(args, kwargs.values()))).shape
kwargs[field.name] = torch.zeros(
shape,
dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
device=next(chain(args, kwargs.values())).device)
output_tensors = OrderedDict({f.name:
field_to_tensor_dict.get(f, kwargs[f.name])
for f in autodiff_obj.forward_output_fields})
......
......@@ -25,6 +25,85 @@ from pystencils_autodiff.framework_integration.printer import FrameworkIntegrati
from pystencils_autodiff.framework_integration.texture_astnodes import NativeTextureBinding
class JinjaCppFile(Node):
    """AST node whose textual form is produced by rendering a Jinja2 template.

    Subclasses assign ``TEMPLATE``. The entries of ``ast_dict`` are the
    children of this node and, once printed, the variables handed to the
    template by ``__str__``.
    """

    # Template rendered by ``__str__``; ``None`` until a subclass sets it.
    TEMPLATE: jinja2.Template = None
    # Values of these types are passed to the template as-is (not printed).
    NOT_PRINT_TYPES = (pystencils.Field, pystencils.TypedSymbol, bool)

    def __init__(self, ast_dict=None):
        """Store the template variables.

        Args:
            ast_dict: mapping of template variable name to value. ``None``
                (the default) means "no variables"; a fresh empty mapping is
                used so that no dict object is shared between calls.
        """
        self.ast_dict = pystencils.utils.DotDict({} if ast_dict is None else ast_dict)
        self.printer = FrameworkIntegrationPrinter()
        Node.__init__(self)

    @property
    def args(self):
        """Returns all arguments/children of this node.

        Direct ``Node``/``sp.Expr`` values come first, followed by the
        elements of every non-string iterable value of ``ast_dict``.
        """
        values = list(self.ast_dict.values())
        direct_children = [v for v in values if isinstance(v, (Node, sp.Expr))]
        child_collections = [v for v in values
                             if isinstance(v, Iterable) and not isinstance(v, str)]
        return direct_children + list(itertools.chain.from_iterable(child_collections))

    @property
    def symbols_defined(self):
        """Set of symbols which are defined by this node (union over all child nodes)."""
        return set(itertools.chain.from_iterable(child.symbols_defined
                                                 for child in self.args
                                                 if isinstance(child, Node)))

    @property
    def undefined_symbols(self):
        """Symbols which are used but are not defined inside this node.

        Child nodes contribute their ``undefined_symbols``, plain sympy
        expressions their ``free_symbols``; everything this node defines
        itself is subtracted.
        """
        used = set(itertools.chain.from_iterable(
            child.undefined_symbols if isinstance(child, Node) else child.free_symbols
            for child in self.args
            if isinstance(child, (Node, sp.Expr))))
        return used - self.symbols_defined

    def _print(self, node):
        """Print nodes and sympy expressions with ``self.printer``; anything else with ``str``."""
        if isinstance(node, (Node, sp.Expr)):
            return self.printer(node)
        return str(node)

    def atoms(self, arg_type) -> Set[Any]:
        """Returns a set of all descendants recursively, which are an instance of the given type."""
        result = {self} if isinstance(self, arg_type) else set()
        for arg in self.args:
            if isinstance(arg, arg_type):
                result.add(arg)
            if hasattr(arg, 'atoms'):
                result.update(arg.atoms(arg_type))
        return result

    @property
    def is_cuda(self):
        """True if any ``KernelFunction`` below this node has the ``'gpucuda'`` backend."""
        return any(f.backend == 'gpucuda' for f in self.atoms(KernelFunction))

    def __str__(self):
        """Render ``TEMPLATE`` with the printed entries of ``ast_dict``.

        In addition to the entries of ``ast_dict`` the template receives
        ``headers`` and ``globals`` as collected by the pystencils C backend.
        """
        assert self.TEMPLATE, f"Template of {self.__class__} must be set"

        def printed(value):
            # ``None`` and NOT_PRINT_TYPES reach the template unchanged so
            # that templates can still test them (e.g. booleans in ``if``).
            if value is None or isinstance(value, self.NOT_PRINT_TYPES):
                return value
            return self._print(value)

        render_dict = {}
        for key, value in self.ast_dict.items():
            if isinstance(value, Iterable) and not isinstance(value, str):
                # Collections are printed element-wise and become lists.
                render_dict[key] = [printed(element) for element in value]
            else:
                render_dict[key] = printed(value)

        render_dict.update({"headers": pystencils.backends.cbackend.get_headers(self)})
        render_dict.update({"globals": sorted({
            self.printer(g) for g in pystencils.backends.cbackend.get_global_declarations(self)
        }, key=str)})

        return self.TEMPLATE.render(render_dict)

    def __repr__(self):
        """Class name followed by the template rendered with the raw ``ast_dict``."""
        if self.TEMPLATE is None:
            # repr() must not raise just because a subclass has no template yet.
            return f'{str(self.__class__)}:\n {dict(self.ast_dict)!r}'
        return f'{str(self.__class__)}:\n {self.TEMPLATE.render(self.ast_dict)}'
class DestructuringBindingsForFieldClass(Node):
"""
Defines all variables needed for describing a field (shape, pointer, strides)
......@@ -112,7 +191,15 @@ class FunctionCall(Node):
@property
def undefined_symbols(self) -> Set[sp.Symbol]:
    """Symbols this kernel call needs from its surroundings.

    These are the symbols of all kernel parameters. For kernels with the
    ``"gpucuda"`` backend the free symbols of the launch configuration
    (block and grid extents) are added, because with dynamic field shapes
    those extents are symbolic expressions.

    Returns:
        Set of symbols that must be defined before the call.
    """
    function = self.kernel_function
    symbols = {p.symbol for p in function.get_parameters()}

    if function.backend == "gpucuda":
        # The launch configuration is computed from the spatial shape of the
        # first written field.
        written_fields = function.fields_written
        shape = list(written_fields)[0].spatial_shape
        launch_config = function.indexing.call_parameters(shape)
        # Both extents matter: 'block' (threads per block) and 'grid'
        # (number of blocks). Extents may be plain ints, which carry no
        # symbols, hence the getattr fallback.
        for extent in itertools.chain(launch_config['block'], launch_config['grid']):
            symbols.update(getattr(extent, 'free_symbols', ()))

    return symbols
def subs(self, subs_dict) -> None:
for a in self.args:
......@@ -167,85 +254,6 @@ def generate_kernel_call(kernel_function):
return block
class JinjaCppFile(Node):
    """AST node whose textual form is produced by rendering a Jinja2 template.

    Subclasses assign ``TEMPLATE``. The entries of ``ast_dict`` are the
    children of this node and, once printed, the variables handed to the
    template by ``__str__``.
    """

    # Template rendered by ``__str__``; ``None`` until a subclass sets it.
    TEMPLATE: jinja2.Template = None
    # Values of these types are passed to the template as-is (not printed).
    NOT_PRINT_TYPES = (pystencils.Field, pystencils.TypedSymbol, bool)

    def __init__(self, ast_dict=None):
        """Store the template variables.

        Args:
            ast_dict: mapping of template variable name to value. ``None``
                (the default) means "no variables"; a fresh empty mapping is
                used so that no dict object is shared between calls.
        """
        self.ast_dict = pystencils.utils.DotDict({} if ast_dict is None else ast_dict)
        self.printer = FrameworkIntegrationPrinter()
        Node.__init__(self)

    @property
    def args(self):
        """Returns all arguments/children of this node.

        Direct ``Node``/``sp.Expr`` values come first, followed by the
        elements of every non-string iterable value of ``ast_dict``.
        """
        values = list(self.ast_dict.values())
        direct_children = [v for v in values if isinstance(v, (Node, sp.Expr))]
        child_collections = [v for v in values
                             if isinstance(v, Iterable) and not isinstance(v, str)]
        return direct_children + list(itertools.chain.from_iterable(child_collections))

    @property
    def symbols_defined(self):
        """Set of symbols which are defined by this node (union over all child nodes)."""
        return set(itertools.chain.from_iterable(child.symbols_defined
                                                 for child in self.args
                                                 if isinstance(child, Node)))

    @property
    def undefined_symbols(self):
        """Symbols which are used but are not defined inside this node.

        Child nodes contribute their ``undefined_symbols``, plain sympy
        expressions their ``free_symbols``; everything this node defines
        itself is subtracted.
        """
        used = set(itertools.chain.from_iterable(
            child.undefined_symbols if isinstance(child, Node) else child.free_symbols
            for child in self.args
            if isinstance(child, (Node, sp.Expr))))
        return used - self.symbols_defined

    def _print(self, node):
        """Print nodes and sympy expressions with ``self.printer``; anything else with ``str``."""
        if isinstance(node, (Node, sp.Expr)):
            return self.printer(node)
        return str(node)

    def atoms(self, arg_type) -> Set[Any]:
        """Returns a set of all descendants recursively, which are an instance of the given type."""
        result = {self} if isinstance(self, arg_type) else set()
        for arg in self.args:
            if isinstance(arg, arg_type):
                result.add(arg)
            if hasattr(arg, 'atoms'):
                result.update(arg.atoms(arg_type))
        return result

    @property
    def is_cuda(self):
        """True if any ``KernelFunction`` below this node has the ``'gpucuda'`` backend."""
        return any(f.backend == 'gpucuda' for f in self.atoms(KernelFunction))

    def __str__(self):
        """Render ``TEMPLATE`` with the printed entries of ``ast_dict``.

        In addition to the entries of ``ast_dict`` the template receives
        ``headers`` and ``globals`` as collected by the pystencils C backend.
        """
        assert self.TEMPLATE, f"Template of {self.__class__} must be set"

        def printed(value):
            # ``None`` and NOT_PRINT_TYPES reach the template unchanged so
            # that templates can still test them (e.g. booleans in ``if``).
            if value is None or isinstance(value, self.NOT_PRINT_TYPES):
                return value
            return self._print(value)

        render_dict = {}
        for key, value in self.ast_dict.items():
            if isinstance(value, Iterable) and not isinstance(value, str):
                # Collections are printed element-wise and become lists.
                render_dict[key] = [printed(element) for element in value]
            else:
                render_dict[key] = printed(value)

        render_dict.update({"headers": pystencils.backends.cbackend.get_headers(self)})
        render_dict.update({"globals": sorted({
            self.printer(g) for g in pystencils.backends.cbackend.get_global_declarations(self)
        }, key=str)})

        return self.TEMPLATE.render(render_dict)

    def __repr__(self):
        """Class name followed by the template rendered with the raw ``ast_dict``."""
        if self.TEMPLATE is None:
            # repr() must not raise just because a subclass has no template yet.
            return f'{str(self.__class__)}:\n {dict(self.ast_dict)!r}'
        return f'{str(self.__class__)}:\n {self.TEMPLATE.render(self.ast_dict)}'
class CudaErrorCheckDefinition(CustomCodeNode):
def __init__(self):
super().__init__(self.code, [], [])
......@@ -410,7 +418,7 @@ class CustomFunctionCall(JinjaCppFile):
@property
def undefined_symbols(self):
    """Return the entries of ``ast_dict.args`` as a set.

    Every argument passed to the custom function call is reported as a
    symbol the call depends on.
    """
    return set(self.ast_dict.args)
def subs(self, subs_dict):
    """Apply ``subs_dict`` to every call argument.

    Each entry of ``ast_dict.args`` is replaced by the result of its own
    ``subs(subs_dict)``; the stored argument list is rebound to a new list.
    """
    substituted = [argument.subs(subs_dict) for argument in self.ast_dict.args]
    self.ast_dict.args = substituted
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment