From 52337596dc11a28a15b2395749cefa39d62e335b Mon Sep 17 00:00:00 2001 From: Hannah Ward Date: Tue, 13 Feb 2018 11:58:26 +0000 Subject: [PATCH] fix: Add TZInfo to poll times --- scripts/run-taxii-poll.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/scripts/run-taxii-poll.py b/scripts/run-taxii-poll.py index cf05fa2..882c344 100644 --- a/scripts/run-taxii-poll.py +++ b/scripts/run-taxii-poll.py @@ -2,6 +2,7 @@ from cabby import create_client from pyaml import yaml +import pytz import argparse import os import logging @@ -17,7 +18,9 @@ parser.add_argument("-s", "--stdout", action="store_true", help="Log to STDOUT") parser.add_argument("--start", help="Date to poll from (YYYY-MM-DD), Exclusive") parser.add_argument("--end", help="Date to poll to (YYYY-MM-DD), Inclusive") parser.add_argument("--subscription_id", help="The ID of the subscription", default=None) - +parser.add_argument("--tz", help="Your timezone, e.g Europe/London. Default utc", + default="utc") + args = parser.parse_args() # Set up a logger for logging's sake @@ -70,9 +73,28 @@ except Exception as ex: log.info("Connected") -poll_from = datetime.strptime(args.start, "%Y-%m-%d") if args.start else None -poll_to = datetime.strptime(args.end, "%Y-%m-%d") if args.end else datetime.now() subscription_id = args.subscription_id +poll_from = datetime.strptime(args.start, "%Y-%m-%dT%H:%M:%S") if args.start else None +poll_to = datetime.strptime(args.end, "%Y-%m-%dT%H:%M:%S") if args.end else datetime.now() + +timezone = args.tz +# Try to cast to pytz +try: + timezone = pytz.timezone(timezone) +except pytz.exceptions.UnknownTimeZoneError: + log.fatal("Timezone %s unknown", timezone) + log.fatal("Please select one of %s", ", ".join(pytz.all_timezones)) + log.fatal("That's case sensitive!") + sys.exit(1) + +# Add timezone info +if poll_from: + # (may not exist) + poll_from = poll_from.replace(tzinfo=pytz.timezone(args.tz)) + +poll_to = poll_to.replace(tzinfo=pytz.timezone(args.tz)) + +log.info("Set poll time to %s - %s", poll_from, poll_to) for server in config: log.info("== %s ==", server["name"])