Skip to content
Snippets Groups Projects
Commit 69e8ca3f authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Sort fields in tf-mad

parent 6c536747
No related branches found
No related tags found
No related merge requests found
...@@ -206,17 +206,17 @@ Backward: ...@@ -206,17 +206,17 @@ Backward:
forward_assignments = ps.AssignmentCollection( forward_assignments = ps.AssignmentCollection(
forward_assignments, []) forward_assignments, [])
read_field_accesses = [ read_field_accesses = sorted([
a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)] a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)], key=lambda x: str(x))
write_field_accesses = [a.lhs for a in forward_assignments] write_field_accesses = sorted([a.lhs for a in forward_assignments], key=lambda x: str(x))
read_fields = {s.field for s in read_field_accesses} read_fields = {s.field for s in read_field_accesses}
write_fields = {s.field for s in write_field_accesses} write_fields = {s.field for s in write_field_accesses}
self._forward_assignments = forward_assignments self._forward_assignments = forward_assignments
self._forward_read_accesses = read_field_accesses self._forward_read_accesses = read_field_accesses
self._forward_write_accesses = write_field_accesses self._forward_write_accesses = write_field_accesses
self._forward_input_fields = list(read_fields) self._forward_input_fields = sorted(list(read_fields), key=lambda x: str(x))
self._forward_output_fields = list(write_fields) self._forward_output_fields = sorted(list(write_fields), key=lambda x: str(x))
read_field_accesses = [ read_field_accesses = [
a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)] a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)]
...@@ -373,24 +373,28 @@ Backward: ...@@ -373,24 +373,28 @@ Backward:
def forward_ast_cpu(self): def forward_ast_cpu(self):
if not self._forward_ast_cpu: if not self._forward_ast_cpu:
self._forward_ast_cpu = ps.create_kernel(self._forward_assignments, **self._kwargs) self._forward_ast_cpu = ps.create_kernel(self._forward_assignments, **self._kwargs)
self._forward_ast_cpu.function_name = self.op_name + '_forward_cpu'
return self._forward_ast_cpu return self._forward_ast_cpu
@property @property
def forward_ast_gpu(self): def forward_ast_gpu(self):
if not self._forward_ast_gpu: if not self._forward_ast_gpu:
self._forward_ast_gpu = ps.create_kernel(self._forward_assignments, target='gpu', **self._kwargs) self._forward_ast_gpu = ps.create_kernel(self._forward_assignments, target='gpu', **self._kwargs)
self._forward_ast_gpu.function_name = self.op_name + '_forward_gpu'
return self._forward_ast_gpu return self._forward_ast_gpu
@property @property
def backward_ast_cpu(self): def backward_ast_cpu(self):
if not self._backward_ast_cpu: if not self._backward_ast_cpu:
self._backward_ast_cpu = ps.create_kernel(self._backward_assignments, target='cpu', **self._kwargs) self._backward_ast_cpu = ps.create_kernel(self._backward_assignments, target='cpu', **self._kwargs)
self._backward_ast_cpu.function_name = self.op_name + '_backward_cpu'
return self._backward_ast_cpu return self._backward_ast_cpu
@property @property
def backward_ast_gpu(self): def backward_ast_gpu(self):
if not self._backward_ast_gpu: if not self._backward_ast_gpu:
self._backward_ast_gpu = ps.create_kernel(self._backward_assignments, target='gpu', **self._kwargs) self._backward_ast_gpu = ps.create_kernel(self._backward_assignments, target='gpu', **self._kwargs)
self._backward_ast_gpu.function_name = self.op_name + '_backward_gpu'
return self._backward_ast_gpu return self._backward_ast_gpu
@property @property
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment