Skip to content
Snippets Groups Projects
Commit 39209309 authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'assertion-headers' into 'master'

Add assertion that headers follow the pattern /"..."/ or /<...>/

See merge request pycodegen/pystencils!137
parents de3489c4 b962a099
No related branches found
No related tags found
No related merge requests found
import re
from collections import namedtuple from collections import namedtuple
from typing import Set from typing import Set
...@@ -24,6 +25,9 @@ except ImportError: ...@@ -24,6 +25,9 @@ except ImportError:
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
HEADER_REGEX = re.compile(r'^[<"].*[">]$')
KERNCRAFT_NO_TERNARY_MODE = False KERNCRAFT_NO_TERNARY_MODE = False
...@@ -112,6 +116,9 @@ def get_headers(ast_node: Node) -> Set[str]: ...@@ -112,6 +116,9 @@ def get_headers(ast_node: Node) -> Set[str]:
if isinstance(g, Node): if isinstance(g, Node):
headers.update(get_headers(g)) headers.update(get_headers(g))
for h in headers:
assert HEADER_REGEX.match(h), f'header /{h}/ does not follow the pattern /"..."/ or /<...>/'
return sorted(headers) return sorted(headers)
......
"""
"""
import pytest
from pystencils.astnodes import Block
from pystencils.backends.cbackend import CustomCodeNode, get_headers
def test_headers_have_quotes_or_brackets():
class ErrorNode1(CustomCodeNode):
def __init__(self):
super().__init__("", [], [])
self.headers = ["iostream"]
class ErrorNode2(CustomCodeNode):
headers = ["<iostream>", "foo"]
def __init__(self):
super().__init__("", [], [])
self.headers = ["<iostream>", "foo"]
class OkNode3(CustomCodeNode):
def __init__(self):
super().__init__("", [], [])
self.headers = ["<iostream>", '"foo"']
with pytest.raises(AssertionError, match='.* does not follow the pattern .*'):
get_headers(Block([ErrorNode1()]))
with pytest.raises(AssertionError, match='.* does not follow the pattern .*'):
get_headers(ErrorNode2())
get_headers(OkNode3())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment