Add some simple context-sensitive constant canonicalization, used

as part of canonicalizing comparison expressions.  This
required adding a new comparison expression transformer callback
for leaf-node comparison expression objects, and updating all
existing comparison transformers to work with it (this affected all or
most of them).  The observation expression transformer that actually
does the comparison canonicalization was updated to also perform
this special canonicalization step.
pull/1/head
Michael Chisholm 2020-08-12 19:28:35 -04:00
parent 311fe38cea
commit 5d6c7d8c8a
3 changed files with 310 additions and 47 deletions

View File

@ -4,6 +4,9 @@ Transformation utilities for STIX pattern comparison expressions.
import functools
import itertools
from stix2.equivalence.patterns.transform import Transformer
from stix2.equivalence.patterns.transform.specials import (
windows_reg_key, ipv4_addr, ipv6_addr
)
from stix2.patterns import (
_BooleanExpression, _ComparisonExpression, AndBooleanExpression,
OrBooleanExpression, ParentheticalExpression
@ -57,6 +60,7 @@ class ComparisonExpressionTransformer(Transformer):
Specifically, subclasses can implement methods:
"transform_or" for OR nodes
"transform_and" for AND nodes
"transform_comparison" for plain comparison nodes (<prop> <op> <value>)
"transform_default" for both types of nodes
"transform_default" is a fallback, if a type-specific callback is not
@ -69,8 +73,7 @@ class ComparisonExpressionTransformer(Transformer):
method: a 2-tuple with the transformed AST and a boolean for change
detection. See doc for the superclass' method.
This process currently silently drops parenthetical nodes, and "leaf"
comparison expression nodes are left unchanged.
This process currently silently drops parenthetical nodes.
"""
def transform(self, ast):
@ -88,9 +91,7 @@ class ComparisonExpressionTransformer(Transformer):
changed = True
elif isinstance(ast, _ComparisonExpression):
# Terminates recursion; we don't change these nodes
result = ast
changed = False
result, changed = self.__dispatch_transform(ast)
elif isinstance(ast, ParentheticalExpression):
# Drop these
@ -116,6 +117,11 @@ class ComparisonExpressionTransformer(Transformer):
elif isinstance(ast, OrBooleanExpression):
meth = getattr(self, "transform_or", self.transform_default)
elif isinstance(ast, _ComparisonExpression):
meth = getattr(
self, "transform_comparison", self.transform_default
)
else:
meth = self.transform_default
@ -142,7 +148,7 @@ class OrderDedupeTransformer(
A or A => A
"""
def transform_default(self, ast):
def __transform(self, ast):
"""
Sort/dedupe children. AND and OR can be treated identically.
@ -172,6 +178,12 @@ class OrderDedupeTransformer(
return ast, changed
def transform_or(self, ast):
return self.__transform(ast)
def transform_and(self, ast):
return self.__transform(ast)
class FlattenTransformer(ComparisonExpressionTransformer):
"""
@ -182,7 +194,7 @@ class FlattenTransformer(ComparisonExpressionTransformer):
(A) => A
"""
def transform_default(self, ast):
def __transform(self, ast):
"""
Flatten children. AND and OR can be treated mostly identically. The
little difference is that we can absorb AND children if we're an AND
@ -192,14 +204,14 @@ class FlattenTransformer(ComparisonExpressionTransformer):
:return: The same AST node, but with flattened children
"""
if isinstance(ast, _BooleanExpression) and len(ast.operands) == 1:
changed = False
if len(ast.operands) == 1:
# Replace an AND/OR with one child, with the child itself.
ast = ast.operands[0]
changed = True
else:
flat_operands = []
changed = False
for operand in ast.operands:
if isinstance(operand, _BooleanExpression) \
and ast.operator == operand.operator:
@ -213,6 +225,12 @@ class FlattenTransformer(ComparisonExpressionTransformer):
return ast, changed
def transform_or(self, ast):
return self.__transform(ast)
def transform_and(self, ast):
return self.__transform(ast)
class AbsorptionTransformer(
ComparisonExpressionTransformer
@ -224,57 +242,62 @@ class AbsorptionTransformer(
A or (A and B) = A
"""
def transform_default(self, ast):
def __transform(self, ast):
changed = False
if isinstance(ast, _BooleanExpression):
secondary_op = "AND" if ast.operator == "OR" else "OR"
secondary_op = "AND" if ast.operator == "OR" else "OR"
to_delete = set()
to_delete = set()
# Check i (child1) against j to see if we can delete j.
for i, child1 in enumerate(ast.operands):
if i in to_delete:
# Check i (child1) against j to see if we can delete j.
for i, child1 in enumerate(ast.operands):
if i in to_delete:
continue
for j, child2 in enumerate(ast.operands):
if i == j or j in to_delete:
continue
for j, child2 in enumerate(ast.operands):
if i == j or j in to_delete:
continue
# We're checking if child1 is contained in child2, so
# child2 has to be a compound object, not just a simple
# comparison expression. We also require the right operator
# for child2: "AND" if ast is "OR" and vice versa.
if not isinstance(child2, _BooleanExpression) \
or child2.operator != secondary_op:
continue
# We're checking if child1 is contained in child2, so
# child2 has to be a compound object, not just a simple
# comparison expression. We also require the right operator
# for child2: "AND" if ast is "OR" and vice versa.
if not isinstance(child2, _BooleanExpression) \
or child2.operator != secondary_op:
continue
# The simple check: is child1 contained in child2?
if iter_in(
child1, child2.operands, comparison_expression_cmp
):
to_delete.add(j)
# The simple check: is child1 contained in child2?
if iter_in(
child1, child2.operands, comparison_expression_cmp
# A more complicated check: does child1 occur in child2
# in a "flattened" form?
elif child1.operator == child2.operator:
if all(
iter_in(
child1_operand, child2.operands,
comparison_expression_cmp
)
for child1_operand in child1.operands
):
to_delete.add(j)
# A more complicated check: does child1 occur in child2
# in a "flattened" form?
elif child1.operator == child2.operator:
if all(
iter_in(
child1_operand, child2.operands,
comparison_expression_cmp
)
for child1_operand in child1.operands
):
to_delete.add(j)
if to_delete:
changed = True
if to_delete:
changed = True
for i in reversed(sorted(to_delete)):
del ast.operands[i]
for i in reversed(sorted(to_delete)):
del ast.operands[i]
return ast, changed
def transform_or(self, ast):
return self.__transform(ast)
def transform_and(self, ast):
return self.__transform(ast)
class DNFTransformer(ComparisonExpressionTransformer):
"""
@ -329,3 +352,26 @@ class DNFTransformer(ComparisonExpressionTransformer):
result = ast
return result, changed
class SpecialValueCanonicalization(ComparisonExpressionTransformer):
    """
    Try to find particular leaf-node comparison expressions whose rhs (i.e.
    the constant) can be canonicalized.  This is an idiosyncratic
    transformation based on some ideas people had for context-sensitive
    semantic equivalence in constant values.
    """

    def transform_comparison(self, ast):
        """
        Canonicalize the rhs constant of the given leaf comparison node,
        if a special canonicalizer exists for its object type.

        :param ast: A _ComparisonExpression node
        :return: 2-tuple of the (side-effected) node and a change flag
        """
        # Dispatch on the SCO type name of the comparison's lhs.
        canonicalizers = {
            "windows-registry-key": windows_reg_key,
            "ipv4-addr": ipv4_addr,
            "ipv6-addr": ipv6_addr,
        }

        canonicalizer = canonicalizers.get(ast.lhs.object_type_name)
        if canonicalizer is not None:
            canonicalizer(ast)

        # Hard-code False here since this particular canonicalization is
        # never worth doing more than once.  I think it's okay to pretend
        # nothing has changed.
        return ast, False

View File

@ -15,7 +15,8 @@ from stix2.equivalence.patterns.transform.comparison import (
FlattenTransformer as CFlattenTransformer,
OrderDedupeTransformer as COrderDedupeTransformer,
AbsorptionTransformer as CAbsorptionTransformer,
DNFTransformer as CDNFTransformer
DNFTransformer as CDNFTransformer,
SpecialValueCanonicalization
)
from stix2.equivalence.patterns.compare import iter_lex_cmp, iter_in
from stix2.equivalence.patterns.compare.observation import observation_expression_cmp
@ -473,9 +474,10 @@ class CanonicalizeComparisonExpressionsTransformer(
simplify = ChainTransformer(comp_flatten, comp_order, comp_absorb)
settle_simplify = SettleTransformer(simplify)
comp_special = SpecialValueCanonicalization()
comp_dnf = CDNFTransformer()
self.__comp_canonicalize = ChainTransformer(
settle_simplify, comp_dnf, settle_simplify
comp_special, settle_simplify, comp_dnf, settle_simplify
)
def transform_observation(self, ast):

View File

@ -0,0 +1,215 @@
"""
Some simple comparison expression canonicalization functions.
"""
import socket
from stix2.equivalence.patterns.compare.comparison import (
object_path_to_raw_values
)
# Values we can use as wildcards in path patterns
_ANY_IDX = object()
_ANY_KEY = object()
_ANY = object()
def _path_is(object_path, path_pattern):
    """
    Compare an object path against a pattern.  This enables simple path
    recognition based on a pattern, which is slightly more flexible than
    exact equality: it supports some simple wildcards.

    The path pattern must be an iterable of values: strings for key path
    steps, ints or "*" for index path steps, or wildcards.  Exact matches
    are required for non-wildcards in the pattern.  For the wildcards,
    _ANY_IDX matches any index path step; _ANY_KEY matches any key path
    step, and _ANY matches any path step.

    :param object_path: An ObjectPath instance
    :param path_pattern: An iterable giving the pattern path steps
    :return: True if the path matches the pattern; False if not
    """
    path_iter = iter(object_path_to_raw_values(object_path))
    patt_iter = iter(path_pattern)

    while True:
        # None doubles as the end-of-sequence marker; real path steps are
        # strings or ints, never None.
        path_step = next(path_iter, None)
        patt_step = next(patt_iter, None)

        if path_step is None and patt_step is None:
            # Both sequences exhausted together: every step matched.
            return True

        if path_step is None or patt_step is None:
            # unequal length sequences
            return False

        if patt_step is _ANY:
            continue

        if patt_step is _ANY_IDX:
            # Index steps are ints or the "*" wildcard step.
            if not isinstance(path_step, int) and path_step != "*":
                return False
        elif patt_step is _ANY_KEY:
            if not isinstance(path_step, str):
                return False
        elif patt_step != path_step:
            return False
def _mask_bytes(ip_bytes, prefix_size):
"""
Retain the high-order 'prefix_size' bits from ip_bytes, and zero out the
remaining low-order bits. This side-effects ip_bytes.
:param ip_bytes: A mutable byte sequence (e.g. a bytearray)
:param prefix_size: An integer prefix size
"""
addr_size_bytes = len(ip_bytes)
addr_size_bits = 8 * addr_size_bytes
assert 0 <= prefix_size <= addr_size_bits
num_fixed_bytes = prefix_size // 8
num_zero_bytes = (addr_size_bits - prefix_size) // 8
if num_zero_bytes > 0:
ip_bytes[addr_size_bytes - num_zero_bytes:] = b"\x00" * num_zero_bytes
if num_fixed_bytes + num_zero_bytes != addr_size_bytes:
# The address boundary doesn't fall on a byte boundary.
# So we have a byte for which we have to zero out some
# bits.
num_1_bits = prefix_size % 8
mask = ((1 << num_1_bits) - 1) << (8 - num_1_bits)
ip_bytes[num_fixed_bytes] &= mask
def windows_reg_key(comp_expr):
    """
    Lower-cases the rhs, depending on the windows-registry-key property
    being compared.  This enables case-insensitive comparisons between two
    patterns, for those values.  This side-effects the given AST.

    :param comp_expr: A _ComparisonExpression object whose type is
        windows-registry-key
    """
    # The two property paths whose values compare case-insensitively.
    case_insensitive_paths = (
        ("key",),
        ("values", _ANY_IDX, "name"),
    )

    if any(
        _path_is(comp_expr.lhs, path) for path in case_insensitive_paths
    ):
        comp_expr.rhs.value = comp_expr.rhs.value.lower()
def ipv4_addr(comp_expr):
    """
    Canonicalizes a CIDR IPv4 address by zeroing out low-order bits,
    according to the prefix size.  This affects the rhs when the "value"
    property of an ipv4-addr is being compared.  If the prefix size is 32,
    the size suffix is simply dropped since it's redundant.  If the value
    is not a valid CIDR address, then no change is made.  This also runs
    the address through the platform's IPv4 address processing functions
    (inet_aton() and inet_ntoa()), which can adjust the format.

    This side-effects the given AST.

    :param comp_expr: A _ComparisonExpression object whose type is
        ipv4-addr.
    """
    if not _path_is(comp_expr.lhs, ("value",)):
        return

    value = comp_expr.rhs.value
    addr_str, slash, prefix_str = value.partition("/")
    if not slash or not prefix_str:
        # No "/<prefix>" suffix; leave the value alone.
        return

    try:
        addr_bytes = socket.inet_aton(addr_str)
    except OSError:
        # illegal IPv4 address string
        return

    try:
        prefix_size = int(prefix_str)
    except ValueError:
        # illegal prefix size
        return

    if not 0 <= prefix_size <= 32:
        # illegal prefix size
        return

    if prefix_size == 32:
        # Drop the "32" since it's redundant.  Run the address bytes
        # through inet_ntoa() in case it would adjust the format (e.g.
        # drop leading zeros: 1.2.3.004 => 1.2.3.4).
        comp_expr.rhs.value = socket.inet_ntoa(addr_bytes)
    else:
        # inet_aton() gives an immutable 'bytes' value; we need a value
        # we can change.
        masked_bytes = bytearray(addr_bytes)
        _mask_bytes(masked_bytes, prefix_size)
        comp_expr.rhs.value = (
            socket.inet_ntoa(masked_bytes) + "/" + str(prefix_size)
        )
def ipv6_addr(comp_expr):
    """
    Canonicalizes a CIDR IPv6 address by zeroing out low-order bits,
    according to the prefix size.  This affects the rhs when the "value"
    property of an ipv6-addr is being compared.  If the prefix size is 128,
    the size suffix is simply dropped since it's redundant.  If the value
    is not a valid CIDR address, then no change is made.  This also runs
    the address through the platform's IPv6 address processing functions
    (inet_pton() and inet_ntop()), which can adjust the format.

    This side-effects the given AST.

    :param comp_expr: A _ComparisonExpression object whose type is
        ipv6-addr.
    """
    if not _path_is(comp_expr.lhs, ("value",)):
        return

    value = comp_expr.rhs.value
    addr_str, slash, prefix_str = value.partition("/")
    if not slash or not prefix_str:
        # No "/<prefix>" suffix; leave the value alone.
        return

    try:
        addr_bytes = socket.inet_pton(socket.AF_INET6, addr_str)
    except OSError:
        # illegal IPv6 address string
        return

    try:
        prefix_size = int(prefix_str)
    except ValueError:
        # illegal prefix size
        return

    if not 0 <= prefix_size <= 128:
        # illegal prefix size
        return

    if prefix_size == 128:
        # Drop the "128" since it's redundant.  Run the IP address
        # through inet_ntop() so it can reformat with the double-colons
        # (and make any other adjustments) if necessary.
        comp_expr.rhs.value = socket.inet_ntop(socket.AF_INET6, addr_bytes)
    else:
        # inet_pton() gives an immutable 'bytes' value; we need a value
        # we can change.
        masked_bytes = bytearray(addr_bytes)
        _mask_bytes(masked_bytes, prefix_size)
        comp_expr.rhs.value = (
            socket.inet_ntop(socket.AF_INET6, masked_bytes)
            + "/" + str(prefix_size)
        )