diff --git a/stix2/environment.py b/stix2/environment.py
index 3be0bb0..9eae2d6 100644
--- a/stix2/environment.py
+++ b/stix2/environment.py
@@ -193,7 +193,7 @@ class Environment(DataStoreMixin):
             return None
 
     @staticmethod
-    def semantically_equivalent(obj1, obj2, **weight_dict):
+    def semantically_equivalent(obj1, obj2, prop_scores={}, **weight_dict):
         """This method is meant to verify if two objects of the same type are
         semantically equivalent.
 
@@ -277,17 +277,16 @@ class Environment(DataStoreMixin):
             raise ValueError('The objects to compare must be of the same spec version!')
 
         try:
-            method = weights[type1]["method"]
+            weights[type1]
         except KeyError:
+            logger.warning("'%s' type has no 'weights' dict specified in the semantic equivalence method call!", type1)
+            sum_weights = matching_score = 0
+        else:
             try:
-                weights[type1]
+                method = weights[type1]["method"]
             except KeyError:
-                logger.warning("'%s' type has no semantic equivalence method to call!", type1)
-                sum_weights = matching_score = 0
-            else:
                 matching_score = 0.0
                 sum_weights = 0.0
-                prop_scores = {}
 
                 for prop in weights[type1]:
                     if check_property_present(prop, obj1, obj2) or prop == "longitude_latitude":
@@ -310,13 +309,15 @@ class Environment(DataStoreMixin):
                 prop_scores["matching_score"] = matching_score
                 prop_scores["sum_weights"] = sum_weights
-        else:
-            logger.debug("Starting semantic equivalence process between: '%s' and '%s'", obj1["id"], obj2["id"])
-            matching_score, sum_weights = method(obj1, obj2, **weights[type1])
+            else:
+                logger.debug("Starting semantic equivalence process between: '%s' and '%s'", obj1["id"], obj2["id"])
+                try:
+                    matching_score, sum_weights = method(obj1, obj2, prop_scores, **weights[type1])
+                except TypeError:
+                    matching_score, sum_weights = method(obj1, obj2, **weights[type1])
 
         if sum_weights <= 0:
             return 0
-
         equivalence_score = (matching_score / sum_weights) * 100.0
         return equivalence_score
 
 
@@ -503,31 +504,3 @@ def partial_location_distance(lat1, long1, lat2, long2, threshold):
         (lat1, long1), (lat2, long2), threshold, result,
     )
     return result
-
-
-def _indicator_checks(obj1, obj2, **weights):
-    matching_score = 0.0
-    sum_weights = 0.0
-    if check_property_present("indicator_types", obj1, obj2):
-        w = weights["indicator_types"]
-        contributing_score = w * partial_list_based(obj1["indicator_types"], obj2["indicator_types"])
-        sum_weights += w
-        matching_score += contributing_score
-        logger.debug("'indicator_types' check -- weight: %s, contributing score: %s", w, contributing_score)
-    if check_property_present("pattern", obj1, obj2):
-        w = weights["pattern"]
-        contributing_score = w * custom_pattern_based(obj1["pattern"], obj2["pattern"])
-        sum_weights += w
-        matching_score += contributing_score
-        logger.debug("'pattern' check -- weight: %s, contributing score: %s", w, contributing_score)
-    if check_property_present("valid_from", obj1, obj2):
-        w = weights["valid_from"]
-        contributing_score = (
-            w *
-            partial_timestamp_based(obj1["valid_from"], obj2["valid_from"], weights["tdelta"])
-        )
-        sum_weights += w
-        matching_score += contributing_score
-        logger.debug("'valid_from' check -- weight: %s, contributing score: %s", w, contributing_score)
-    logger.debug("Matching Score: %s, Sum of Weights: %s", matching_score, sum_weights)
-    return matching_score, sum_weights
diff --git a/stix2/test/v21/test_environment.py b/stix2/test/v21/test_environment.py
index d057df5..8432700 100644
--- a/stix2/test/v21/test_environment.py
+++ b/stix2/test/v21/test_environment.py
@@ -622,11 +622,10 @@ def test_semantic_equivalence_zero_match():
     )
     weights = {
         "indicator": {
-            "indicator_types": 15,
-            "pattern": 80,
-            "valid_from": 0,
+            "indicator_types": (15, stix2.environment.partial_list_based),
+            "pattern": (80, stix2.environment.custom_pattern_based),
+            "valid_from": (5, stix2.environment.partial_timestamp_based),
             "tdelta": 1,  # One day interval
-            "method": stix2.environment._indicator_checks,
         },
         "_internal": {
             "ignore_spec_version": False,
@@ -645,11 +644,10 @@ def test_semantic_equivalence_different_spec_version():
     )
     weights = {
         "indicator": {
-            "indicator_types": 15,
-            "pattern": 80,
-            "valid_from": 0,
+            "indicator_types": (15, stix2.environment.partial_list_based),
+            "pattern": (80, stix2.environment.custom_pattern_based),
+            "valid_from": (5, stix2.environment.partial_timestamp_based),
             "tdelta": 1,  # One day interval
-            "method": stix2.environment._indicator_checks,
         },
         "_internal": {
             "ignore_spec_version": True,  # Disables spec_version check.
@@ -750,3 +748,81 @@ def test_non_existent_config_for_object():
     r1 = stix2.v21.Report(id=REPORT_ID, **REPORT_KWARGS)
     r2 = stix2.v21.Report(id=REPORT_ID, **REPORT_KWARGS)
     assert stix2.Environment().semantically_equivalent(r1, r2) == 0.0
+
+
+def custom_semantic_equivalence_method(obj1, obj2, **weights):
+    return 96.0, 100.0
+
+
+def test_semantic_equivalence_method_provided():
+    TOOL2_KWARGS = dict(
+        name="Random Software",
+        tool_types=["information-gathering"],
+    )
+
+    weights = {
+        "tool": {
+            "tool_types": (20, stix2.environment.partial_list_based),
+            "name": (80, stix2.environment.partial_string_based),
+            "method": custom_semantic_equivalence_method,
+        },
+    }
+
+    tool1 = stix2.v21.Tool(id=TOOL_ID, **TOOL_KWARGS)
+    tool2 = stix2.v21.Tool(id=TOOL_ID, **TOOL2_KWARGS)
+    env = stix2.Environment().semantically_equivalent(tool1, tool2, **weights)
+    assert round(env) == 96
+
+
+def test_semantic_equivalence_prop_scores():
+    TOOL2_KWARGS = dict(
+        name="Random Software",
+        tool_types=["information-gathering"],
+    )
+
+    weights = {
+        "tool": {
+            "tool_types": (20, stix2.environment.partial_list_based),
+            "name": (80, stix2.environment.partial_string_based),
+        },
+    }
+
+    prop_scores = {}
+
+    tool1 = stix2.v21.Tool(id=TOOL_ID, **TOOL_KWARGS)
+    tool2 = stix2.v21.Tool(id=TOOL_ID, **TOOL2_KWARGS)
+    stix2.Environment().semantically_equivalent(tool1, tool2, prop_scores, **weights)
+    assert len(prop_scores) == 4
+    assert round(prop_scores["matching_score"], 1) == 37.6
+    assert round(prop_scores["sum_weights"], 1) == 100.0
+
+
+def custom_semantic_equivalence_method_prop_scores(obj1, obj2, prop_scores, **weights):
+    prop_scores["matching_score"] = 96.0
+    prop_scores["sum_weights"] = 100.0
+    return 96.0, 100.0
+
+
+def test_semantic_equivalence_prop_scores_method_provided():
+    TOOL2_KWARGS = dict(
+        name="Random Software",
+        tool_types=["information-gathering"],
+    )
+
+    weights = {
+        "tool": {
+            "tool_types": (20, stix2.environment.partial_list_based),
+            "name": (80, stix2.environment.partial_string_based),
+            "method": custom_semantic_equivalence_method_prop_scores,
+        },
+    }
+
+    prop_scores = {}
+
+    tool1 = stix2.v21.Tool(id=TOOL_ID, **TOOL_KWARGS)
+    tool2 = stix2.v21.Tool(id=TOOL_ID, **TOOL2_KWARGS)
+    env = stix2.Environment().semantically_equivalent(tool1, tool2, prop_scores, **weights)
+    assert round(env) == 96
+    assert len(prop_scores) == 2
+    assert prop_scores["matching_score"] == 96.0
+    assert prop_scores["sum_weights"] == 100.0
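
A minimal usage sketch (not part of the diff) of the behaviour introduced above: tuple-style (weight, comparison function) entries, and the positional prop_scores argument that the call fills in with per-property scores plus matching_score and sum_weights. The Tool contents and weight values are illustrative only, mirroring the new tests.

import stix2
from stix2.environment import partial_list_based, partial_string_based

# Property weights for the "tool" type: (weight, comparison function) tuples.
weights = {
    "tool": {
        "tool_types": (20, partial_list_based),
        "name": (80, partial_string_based),
    },
}

prop_scores = {}  # populated by the call below

tool1 = stix2.v21.Tool(name="VNC", tool_types=["remote-access"])
tool2 = stix2.v21.Tool(name="Random Software", tool_types=["information-gathering"])

score = stix2.Environment().semantically_equivalent(tool1, tool2, prop_scores, **weights)
print(score, prop_scores["matching_score"], prop_scores["sum_weights"])

Supplying a "method" key in the per-type weights (as in test_semantic_equivalence_method_provided) replaces the built-in property-by-property scoring with the given callable, which may optionally accept prop_scores as a third positional argument.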