Skip to content
Snippets Groups Projects
Commit fd37cb09 authored by Markus Holzer's avatar Markus Holzer
Browse files

Merge branch 'mr_json_serializer' into 'master'

JSON Serializer for pystencils config

See merge request pycodegen/pystencils!338
parents 32de591c d948ebdd
Branches
Tags release/1.3.1
No related merge requests found
...@@ -65,6 +65,7 @@ try: ...@@ -65,6 +65,7 @@ try:
except ImportError: except ImportError:
add_path_to_ignore('pystencils/runhelper') add_path_to_ignore('pystencils/runhelper')
collect_ignore += [os.path.join(SCRIPT_FOLDER, "pystencils_tests/test_parameterstudy.py")] collect_ignore += [os.path.join(SCRIPT_FOLDER, "pystencils_tests/test_parameterstudy.py")]
collect_ignore += [os.path.join(SCRIPT_FOLDER, "pystencils_tests/test_json_serializer.py")]
try: try:
import islpy import islpy
......
import socket import socket
import time import time
from types import MappingProxyType
from typing import Dict, Iterator, Sequence from typing import Dict, Iterator, Sequence
import blitzdb import blitzdb
import six
from blitzdb.backends.file.backend import serializer_classes
from blitzdb.backends.file.utils import JsonEncoder
from pystencils.cpu.cpujit import get_compiler_config from pystencils.cpu.cpujit import get_compiler_config
from pystencils import CreateKernelConfig, Target, Backend, Field
import json
import sympy as sp
from pystencils.typing import BasicType
class PystencilsJsonEncoder(JsonEncoder):
def default(self, obj):
if isinstance(obj, CreateKernelConfig):
return obj.__dict__
if isinstance(obj, (sp.Float, sp.Rational)):
return float(obj)
if isinstance(obj, sp.Integer):
return int(obj)
if isinstance(obj, (BasicType, MappingProxyType)):
return str(obj)
if isinstance(obj, (Target, Backend, sp.Symbol)):
return obj.name
if isinstance(obj, Field):
return f"pystencils.Field(name = {obj.name}, field_type = {obj.field_type.name}, " \
f"dtype = {str(obj.dtype)}, layout = {obj.layout}, shape = {obj.shape}, " \
f"strides = {obj.strides})"
return JsonEncoder.default(self, obj)
class PystencilsJsonSerializer(object):
@classmethod
def serialize(cls, data):
if six.PY3:
if isinstance(data, bytes):
return json.dumps(data.decode('utf-8'), cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
else:
return json.dumps(data, cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
else:
return json.dumps(data, cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
@classmethod
def deserialize(cls, data):
if six.PY3:
return json.loads(data.decode('utf-8'))
else:
return json.loads(data.decode('utf-8'))
class Database: class Database:
...@@ -46,7 +96,7 @@ class Database: ...@@ -46,7 +96,7 @@ class Database:
class SimulationResult(blitzdb.Document): class SimulationResult(blitzdb.Document):
pass pass
def __init__(self, file: str) -> None: def __init__(self, file: str, serializer_info: tuple = None) -> None:
if file.startswith("mongo://"): if file.startswith("mongo://"):
from pymongo import MongoClient from pymongo import MongoClient
db_name = file[len("mongo://"):] db_name = file[len("mongo://"):]
...@@ -57,6 +107,10 @@ class Database: ...@@ -57,6 +107,10 @@ class Database:
self.backend.autocommit = True self.backend.autocommit = True
if serializer_info:
serializer_classes.update({serializer_info[0]: serializer_info[1]})
self.backend.load_config({'serializer_class': serializer_info[0]}, True)
def save(self, params: Dict, result: Dict, env: Dict = None, **kwargs) -> None: def save(self, params: Dict, result: Dict, env: Dict = None, **kwargs) -> None:
"""Stores a simulation result in the database. """Stores a simulation result in the database.
...@@ -146,10 +200,15 @@ class Database: ...@@ -146,10 +200,15 @@ class Database:
'cpuCompilerConfig': get_compiler_config(), 'cpuCompilerConfig': get_compiler_config(),
} }
try: try:
from git import Repo, InvalidGitRepositoryError from git import Repo
except ImportError:
return result
try:
from git import InvalidGitRepositoryError
repo = Repo(search_parent_directories=True) repo = Repo(search_parent_directories=True)
result['git_hash'] = str(repo.head.commit) result['git_hash'] = str(repo.head.commit)
except (ImportError, InvalidGitRepositoryError): except InvalidGitRepositoryError:
pass pass
return result return result
......
...@@ -9,6 +9,7 @@ from time import sleep ...@@ -9,6 +9,7 @@ from time import sleep
from typing import Any, Callable, Dict, Optional, Sequence, Tuple from typing import Any, Callable, Dict, Optional, Sequence, Tuple
from pystencils.runhelper import Database from pystencils.runhelper import Database
from pystencils.runhelper.db import PystencilsJsonSerializer
from pystencils.utils import DotDict from pystencils.utils import DotDict
ParameterDict = Dict[str, Any] ParameterDict = Dict[str, Any]
...@@ -54,10 +55,11 @@ class ParameterStudy: ...@@ -54,10 +55,11 @@ class ParameterStudy:
Run = namedtuple("Run", ['parameter_dict', 'weight']) Run = namedtuple("Run", ['parameter_dict', 'weight'])
def __init__(self, run_function: Callable[..., Dict], runs: Sequence = (), def __init__(self, run_function: Callable[..., Dict], runs: Sequence = (),
database_connector: str = './db') -> None: database_connector: str = './db',
serializer_info: tuple = ('pystencils_serializer', PystencilsJsonSerializer)) -> None:
self.runs = list(runs) self.runs = list(runs)
self.run_function = run_function self.run_function = run_function
self.db = Database(database_connector) self.db = Database(database_connector, serializer_info)
def add_run(self, parameter_dict: ParameterDict, weight: int = 1) -> None: def add_run(self, parameter_dict: ParameterDict, weight: int = 1) -> None:
"""Schedule a dictionary of parameters to run in this parameter study. """Schedule a dictionary of parameters to run in this parameter study.
......
"""
Test the pystencils-specific JSON encoder and serializer as used in the Database class.
"""
import numpy as np
import tempfile
from pystencils.config import CreateKernelConfig
from pystencils import Target, Field
from pystencils.runhelper.db import Database, PystencilsJsonSerializer
def test_json_serializer():
dtype = np.float32
index_arr = np.zeros((3,), dtype=dtype)
indexed_field = Field.create_from_numpy_array('index', index_arr)
# create pystencils config
config = CreateKernelConfig(target=Target.CPU, function_name='dummy_config', data_type=dtype,
index_fields=[indexed_field])
# create dummy database
temp_dir = tempfile.TemporaryDirectory()
db = Database(file=temp_dir.name, serializer_info=('pystencils_serializer', PystencilsJsonSerializer))
db.save(params={'config': config}, result={'test': 'dummy'})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment