It appears we did not support the case when the Bundle contains 'utf-8'

stix2.1
Emmanuelle Vargas-Gonzalez 2018-07-09 15:26:57 -04:00
parent 70a1e9522b
commit edd7148e3c
2 changed files with 12 additions and 10 deletions

View File

@ -3,6 +3,7 @@ import json
from medallion.filters.basic_filter import BasicFilter from medallion.filters.basic_filter import BasicFilter
import pytest import pytest
from requests.models import Response from requests.models import Response
import six
from taxii2client import Collection, _filter_kwargs_to_query_params from taxii2client import Collection, _filter_kwargs_to_query_params
import stix2 import stix2
@ -21,8 +22,8 @@ class MockTAXIICollectionEndpoint(Collection):
def add_objects(self, bundle): def add_objects(self, bundle):
self._verify_can_write() self._verify_can_write()
if isinstance(bundle, str): if isinstance(bundle, six.string_types):
bundle = json.loads(bundle) bundle = json.loads(bundle, encoding='utf-8')
for object in bundle.get("objects", []): for object in bundle.get("objects", []):
self.objects.append(object) self.objects.append(object)
@ -30,7 +31,7 @@ class MockTAXIICollectionEndpoint(Collection):
self._verify_can_read() self._verify_can_read()
query_params = _filter_kwargs_to_query_params(filter_kwargs) query_params = _filter_kwargs_to_query_params(filter_kwargs)
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
query_params = json.loads(query_params) query_params = json.loads(query_params, encoding='utf-8')
full_filter = BasicFilter(query_params or {}) full_filter = BasicFilter(query_params or {})
objs = full_filter.process_filter( objs = full_filter.process_filter(
self.objects, self.objects,
@ -44,13 +45,13 @@ class MockTAXIICollectionEndpoint(Collection):
resp.status_code = 404 resp.status_code = 404
resp.raise_for_status() resp.raise_for_status()
def get_object(self, id, version=None): def get_object(self, id, version=None, accept=''):
self._verify_can_read() self._verify_can_read()
query_params = None query_params = None
if version: if version:
query_params = _filter_kwargs_to_query_params({"version": version}) query_params = _filter_kwargs_to_query_params({"version": version})
if query_params: if query_params:
query_params = json.loads(query_params) query_params = json.loads(query_params, encoding='utf-8')
full_filter = BasicFilter(query_params or {}) full_filter = BasicFilter(query_params or {})
objs = full_filter.process_filter( objs = full_filter.process_filter(
self.objects, self.objects,

View File

@ -3,6 +3,7 @@ import json
from medallion.filters.basic_filter import BasicFilter from medallion.filters.basic_filter import BasicFilter
import pytest import pytest
from requests.models import Response from requests.models import Response
import six
from taxii2client import Collection, _filter_kwargs_to_query_params from taxii2client import Collection, _filter_kwargs_to_query_params
import stix2 import stix2
@ -21,8 +22,8 @@ class MockTAXIICollectionEndpoint(Collection):
def add_objects(self, bundle): def add_objects(self, bundle):
self._verify_can_write() self._verify_can_write()
if isinstance(bundle, str): if isinstance(bundle, six.string_types):
bundle = json.loads(bundle) bundle = json.loads(bundle, encoding='utf-8')
for object in bundle.get("objects", []): for object in bundle.get("objects", []):
self.objects.append(object) self.objects.append(object)
@ -30,7 +31,7 @@ class MockTAXIICollectionEndpoint(Collection):
self._verify_can_read() self._verify_can_read()
query_params = _filter_kwargs_to_query_params(filter_kwargs) query_params = _filter_kwargs_to_query_params(filter_kwargs)
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
query_params = json.loads(query_params) query_params = json.loads(query_params, encoding='utf-8')
full_filter = BasicFilter(query_params or {}) full_filter = BasicFilter(query_params or {})
objs = full_filter.process_filter( objs = full_filter.process_filter(
self.objects, self.objects,
@ -44,13 +45,13 @@ class MockTAXIICollectionEndpoint(Collection):
resp.status_code = 404 resp.status_code = 404
resp.raise_for_status() resp.raise_for_status()
def get_object(self, id, version=None): def get_object(self, id, version=None, accept=''):
self._verify_can_read() self._verify_can_read()
query_params = None query_params = None
if version: if version:
query_params = _filter_kwargs_to_query_params({"version": version}) query_params = _filter_kwargs_to_query_params({"version": version})
if query_params: if query_params:
query_params = json.loads(query_params) query_params = json.loads(query_params, encoding='utf-8')
full_filter = BasicFilter(query_params or {}) full_filter = BasicFilter(query_params or {})
objs = full_filter.process_filter( objs = full_filter.process_filter(
self.objects, self.objects,