Skip to content
Snippets Groups Projects
Commit 69e8ca3f authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Sort fields in tf-mad

parent 6c536747
Branches
Tags
No related merge requests found
......@@ -206,17 +206,17 @@ Backward:
forward_assignments = ps.AssignmentCollection(
forward_assignments, [])
read_field_accesses = [
a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)]
write_field_accesses = [a.lhs for a in forward_assignments]
read_field_accesses = sorted([
a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)], key=lambda x: str(x))
write_field_accesses = sorted([a.lhs for a in forward_assignments], key=lambda x: str(x))
read_fields = {s.field for s in read_field_accesses}
write_fields = {s.field for s in write_field_accesses}
self._forward_assignments = forward_assignments
self._forward_read_accesses = read_field_accesses
self._forward_write_accesses = write_field_accesses
self._forward_input_fields = list(read_fields)
self._forward_output_fields = list(write_fields)
self._forward_input_fields = sorted(list(read_fields), key=lambda x: str(x))
self._forward_output_fields = sorted(list(write_fields), key=lambda x: str(x))
read_field_accesses = [
a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)]
......@@ -373,24 +373,28 @@ Backward:
def forward_ast_cpu(self):
if not self._forward_ast_cpu:
self._forward_ast_cpu = ps.create_kernel(self._forward_assignments, **self._kwargs)
self._forward_ast_cpu.function_name = self.op_name + '_forward_cpu'
return self._forward_ast_cpu
@property
def forward_ast_gpu(self):
if not self._forward_ast_gpu:
self._forward_ast_gpu = ps.create_kernel(self._forward_assignments, target='gpu', **self._kwargs)
self._forward_ast_gpu.function_name = self.op_name + '_forward_gpu'
return self._forward_ast_gpu
@property
def backward_ast_cpu(self):
if not self._backward_ast_cpu:
self._backward_ast_cpu = ps.create_kernel(self._backward_assignments, target='cpu', **self._kwargs)
self._backward_ast_cpu.function_name = self.op_name + '_backward_cpu'
return self._backward_ast_cpu
@property
def backward_ast_gpu(self):
if not self._backward_ast_gpu:
self._backward_ast_gpu = ps.create_kernel(self._backward_assignments, target='gpu', **self._kwargs)
self._backward_ast_gpu.function_name = self.op_name + '_backward_gpu'
return self._backward_ast_gpu
@property
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment