Merge pull request #322 from emmanvg/321-issue

add encoding option to areas where open() is used
master
Chris Lenk 2020-01-16 10:17:28 -05:00 committed by GitHub
commit 0af1f442c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 15 deletions

View File

@ -282,19 +282,21 @@ def _get_matching_dir_entries(parent_dir, auth_set, st_mode_test=None, ext=""):
return results return results
def _check_object_from_file(query, filepath, allow_custom, version): def _check_object_from_file(query, filepath, allow_custom, version, encoding):
""" """
Read a STIX object from the given file, and check it against the given Read a STIX object from the given file, and check it against the given
filters. filters.
Args: Args:
query: Iterable of filters query: Iterable of filters
filepath: Path to file to read filepath (str): Path to file to read
allow_custom: Whether to allow custom properties as well unknown allow_custom (bool): Whether to allow custom properties as well unknown
custom objects. custom objects.
version (str): If present, it forces the parser to use the version version (str): If present, it forces the parser to use the version
provided. Otherwise, the library will make the best effort based provided. Otherwise, the library will make the best effort based
on checking the "spec_version" property. on checking the "spec_version" property.
encoding (str): The encoding to use when reading a file from the
filesystem.
Returns: Returns:
The (parsed) STIX object, if the object passes the filters. If The (parsed) STIX object, if the object passes the filters. If
@ -308,7 +310,7 @@ def _check_object_from_file(query, filepath, allow_custom, version):
""" """
try: try:
with io.open(filepath, "r") as f: with io.open(filepath, "r", encoding=encoding) as f:
stix_json = json.load(f) stix_json = json.load(f)
except ValueError: # not a JSON file except ValueError: # not a JSON file
raise TypeError( raise TypeError(
@ -327,7 +329,7 @@ def _check_object_from_file(query, filepath, allow_custom, version):
return result return result
def _search_versioned(query, type_path, auth_ids, allow_custom, version): def _search_versioned(query, type_path, auth_ids, allow_custom, version, encoding):
""" """
Searches the given directory, which contains data for STIX objects of a Searches the given directory, which contains data for STIX objects of a
particular versioned type (i.e. not markings), and return any which match particular versioned type (i.e. not markings), and return any which match
@ -337,11 +339,13 @@ def _search_versioned(query, type_path, auth_ids, allow_custom, version):
query: The query to match against query: The query to match against
type_path: The directory with type-specific STIX object files type_path: The directory with type-specific STIX object files
auth_ids: Search optimization based on object ID auth_ids: Search optimization based on object ID
allow_custom: Whether to allow custom properties as well unknown allow_custom (bool): Whether to allow custom properties as well unknown
custom objects. custom objects.
version (str): If present, it forces the parser to use the version version (str): If present, it forces the parser to use the version
provided. Otherwise, the library will make the best effort based provided. Otherwise, the library will make the best effort based
on checking the "spec_version" property. on checking the "spec_version" property.
encoding (str): The encoding to use when reading a file from the
filesystem.
Returns: Returns:
A list of all matching objects A list of all matching objects
@ -375,6 +379,7 @@ def _search_versioned(query, type_path, auth_ids, allow_custom, version):
stix_obj = _check_object_from_file( stix_obj = _check_object_from_file(
query, version_path, query, version_path,
allow_custom, version, allow_custom, version,
encoding,
) )
if stix_obj: if stix_obj:
results.append(stix_obj) results.append(stix_obj)
@ -395,7 +400,7 @@ def _search_versioned(query, type_path, auth_ids, allow_custom, version):
try: try:
stix_obj = _check_object_from_file( stix_obj = _check_object_from_file(
query, id_path, allow_custom, query, id_path, allow_custom,
version, version, encoding,
) )
if stix_obj: if stix_obj:
results.append(stix_obj) results.append(stix_obj)
@ -407,7 +412,7 @@ def _search_versioned(query, type_path, auth_ids, allow_custom, version):
return results return results
def _search_markings(query, markings_path, auth_ids, allow_custom, version): def _search_markings(query, markings_path, auth_ids, allow_custom, version, encoding):
""" """
Searches the given directory, which contains markings data, and return any Searches the given directory, which contains markings data, and return any
which match the query. which match the query.
@ -416,11 +421,13 @@ def _search_markings(query, markings_path, auth_ids, allow_custom, version):
query: The query to match against query: The query to match against
markings_path: The directory with STIX markings files markings_path: The directory with STIX markings files
auth_ids: Search optimization based on object ID auth_ids: Search optimization based on object ID
allow_custom: Whether to allow custom properties as well unknown allow_custom (bool): Whether to allow custom properties as well unknown
custom objects. custom objects.
version (str): If present, it forces the parser to use the version version (str): If present, it forces the parser to use the version
provided. Otherwise, the library will make the best effort based provided. Otherwise, the library will make the best effort based
on checking the "spec_version" property. on checking the "spec_version" property.
encoding (str): The encoding to use when reading a file from the
filesystem.
Returns: Returns:
A list of all matching objects A list of all matching objects
@ -443,7 +450,7 @@ def _search_markings(query, markings_path, auth_ids, allow_custom, version):
try: try:
stix_obj = _check_object_from_file( stix_obj = _check_object_from_file(
query, id_path, allow_custom, query, id_path, allow_custom,
version, version, encoding,
) )
if stix_obj: if stix_obj:
results.append(stix_obj) results.append(stix_obj)
@ -470,13 +477,15 @@ class FileSystemStore(DataStoreMixin):
will be applied to both FileSystemSource and FileSystemSink. will be applied to both FileSystemSource and FileSystemSink.
bundlify (bool): whether to wrap objects in bundles when saving bundlify (bool): whether to wrap objects in bundles when saving
them. Default: False. them. Default: False.
encoding (str): The encoding to use when reading a file from the
filesystem.
Attributes: Attributes:
source (FileSystemSource): FileSystemSource source (FileSystemSource): FileSystemSource
sink (FileSystemSink): FileSystemSink sink (FileSystemSink): FileSystemSink
""" """
def __init__(self, stix_dir, allow_custom=None, bundlify=False): def __init__(self, stix_dir, allow_custom=None, bundlify=False, encoding='utf-8'):
if allow_custom is None: if allow_custom is None:
allow_custom_source = True allow_custom_source = True
allow_custom_sink = False allow_custom_sink = False
@ -484,7 +493,7 @@ class FileSystemStore(DataStoreMixin):
allow_custom_sink = allow_custom_source = allow_custom allow_custom_sink = allow_custom_source = allow_custom
super(FileSystemStore, self).__init__( super(FileSystemStore, self).__init__(
source=FileSystemSource(stix_dir=stix_dir, allow_custom=allow_custom_source), source=FileSystemSource(stix_dir=stix_dir, allow_custom=allow_custom_source, encoding=encoding),
sink=FileSystemSink(stix_dir=stix_dir, allow_custom=allow_custom_sink, bundlify=bundlify), sink=FileSystemSink(stix_dir=stix_dir, allow_custom=allow_custom_sink, bundlify=bundlify),
) )
@ -603,12 +612,15 @@ class FileSystemSource(DataSource):
stix_dir (str): path to directory of STIX objects stix_dir (str): path to directory of STIX objects
allow_custom (bool): Whether to allow custom STIX content to be allow_custom (bool): Whether to allow custom STIX content to be
added to the FileSystemSink. Default: True added to the FileSystemSink. Default: True
encoding (str): The encoding to use when reading a file from the
filesystem.
""" """
def __init__(self, stix_dir, allow_custom=True): def __init__(self, stix_dir, allow_custom=True, encoding='utf-8'):
super(FileSystemSource, self).__init__() super(FileSystemSource, self).__init__()
self._stix_dir = os.path.abspath(stix_dir) self._stix_dir = os.path.abspath(stix_dir)
self.allow_custom = allow_custom self.allow_custom = allow_custom
self.encoding = encoding
if not os.path.exists(self._stix_dir): if not os.path.exists(self._stix_dir):
raise ValueError("directory path for STIX data does not exist: %s" % self._stix_dir) raise ValueError("directory path for STIX data does not exist: %s" % self._stix_dir)
@ -712,11 +724,13 @@ class FileSystemSource(DataSource):
type_results = _search_markings( type_results = _search_markings(
query, type_path, auth_ids, query, type_path, auth_ids,
self.allow_custom, version, self.allow_custom, version,
self.encoding,
) )
else: else:
type_results = _search_versioned( type_results = _search_versioned(
query, type_path, auth_ids, query, type_path, auth_ids,
self.allow_custom, version, self.allow_custom, version,
self.encoding,
) )
all_data.extend(type_results) all_data.extend(type_results)

View File

@ -359,8 +359,8 @@ class MemorySource(DataSource):
return all_data return all_data
def load_from_file(self, file_path, version=None): def load_from_file(self, file_path, version=None, encoding='utf-8'):
with io.open(os.path.abspath(file_path), "r") as f: with io.open(os.path.abspath(file_path), "r", encoding=encoding) as f:
stix_data = json.load(f) stix_data = json.load(f)
_add(self, stix_data, self.allow_custom, version) _add(self, stix_data, self.allow_custom, version)