mirror of https://github.com/CIRCL/AIL-framework
329 lines
10 KiB
Python
Executable File
329 lines
10 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# -*-coding:UTF-8 -*
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import uuid
|
|
|
|
import asyncio
|
|
import http
|
|
import ssl
|
|
import websockets
|
|
|
|
sys.path.append(os.environ['AIL_BIN'])
|
|
##################################
|
|
# Import Project packages
|
|
##################################
|
|
from pubsublogger import publisher
|
|
from core import ail_2_ail
|
|
|
|
# # TODO: refactor logging
#### LOGS ####
# Every log line of this server goes through the shared `publisher` logger,
# configured here for port 6380 and the 'Sync' channel.
redis_logger = publisher
redis_logger.port, redis_logger.channel = 6380, 'Sync'

#############################

# Open connections per AIL instance: {ail_uuid: set(websocket, ...)},
# filled by register() and drained by unregister().
# # TODO: Store in redis
CONNECTED_CLIENTS = {}

#############################
|
|
|
|
# # # # # # #
|
|
# #
|
|
# UTILS #
|
|
# #
|
|
# # # # # # #
|
|
|
|
def is_valid_uuid_v4(UUID):
    """Return True if *UUID* is a well-formed version-4 UUID string.

    Dashes are optional and hex digits are accepted in either case
    (RFC 4122 hex is case-insensitive on input). Any non-string or empty
    value returns False instead of raising.
    """
    if not UUID or not isinstance(UUID, str):
        return False
    candidate = UUID.replace('-', '').lower()
    try:
        # Passing version=4 forces the version/variant bits; a UUID that is
        # not already v4 therefore comes back with a different hex string.
        uuid_test = uuid.UUID(hex=candidate, version=4)
    except ValueError:
        # Wrong length or non-hex characters.
        return False
    return uuid_test.hex == candidate
|
|
|
|
def unpack_path(path):
    """Split a request path ``/<sync_mode>/[<api>/...]<ail_uuid>[/]`` into parts.

    :param path: URL path of the websocket request
    :return: dict with keys ``sync_mode`` (first segment), ``ail_uuid`` (last
             segment) and ``api`` (list of the segments in between, possibly empty)
    :raises ValueError: if the path has fewer than two segments
    """
    parts = path.split('/')
    # Tolerate a single trailing slash ('/pull/<uuid>/').
    if not parts[-1]:
        parts = parts[:-1]
    # Validate only after stripping: otherwise '/pull/' passes the check and
    # its sync mode is also taken as the ail_uuid.
    if len(parts) < 3:
        raise ValueError('Invalid url path')

    return {
        'sync_mode': parts[1],
        'ail_uuid': parts[-1],
        'api': parts[2:-1],
    }
|
|
|
|
# # # # # # #
|
|
|
|
# # TODO: ADD more commands
|
|
async def server_controller():
    """Poll ``ail_2_ail`` for server commands every 10 seconds, forever.

    Supported command: ``{'command': 'kill', 'ail_uuid': ...}`` closes (code
    1000) every live connection of that AIL instance.
    """
    while True:
        command_dict = ail_2_ail.get_server_controller_command()
        if command_dict:
            command = command_dict.get('command')
            if command == 'kill':
                ail_uuid = command_dict.get('ail_uuid')
                # The instance may already be gone from CONNECTED_CLIENTS
                # (disconnected, or never connected). Indexing it directly
                # raised KeyError and ended this task permanently.
                # Iterate over a copy: each closed connection's handler runs
                # unregister(), which may mutate the set while we await.
                connected_clients = CONNECTED_CLIENTS.get(ail_uuid, set()).copy()
                for c_websocket in connected_clients:
                    await c_websocket.close(code=1000)
                redis_logger.info(f'Server Command Connection closed: {ail_uuid}')
                print(f'Server Command Connection closed: {ail_uuid}')

        await asyncio.sleep(10)
|
|
|
|
# # # # # # #
|
|
|
|
async def register(websocket):
    """Record a newly accepted connection.

    Adds *websocket* to ``CONNECTED_CLIENTS[websocket.ail_uuid]`` and reports
    the instance/sync mode to ``ail_2_ail.add_server_connected_client``.
    """
    ail_uuid = websocket.ail_uuid
    remote_address = websocket.remote_address
    sync_mode = websocket.sync_mode

    connected_msg = f'Client Connected: {ail_uuid} {remote_address}'
    redis_logger.info(connected_msg)
    print(connected_msg)

    CONNECTED_CLIENTS.setdefault(ail_uuid, set()).add(websocket)
    ail_2_ail.add_server_connected_client(ail_uuid, sync_mode)

    # debug dump of the connection registry
    print('Register client')
    print(CONNECTED_CLIENTS)
    print()
|
|
|
|
async def unregister(websocket):
    """Forget a closed connection and update ``ail_2_ail`` bookkeeping.

    Removes *websocket* from ``CONNECTED_CLIENTS``. If another open connection
    of the same instance uses the same sync mode, ``sync_mode=None`` is
    passed on (presumably so that mode is not cleared — confirm in
    ``ail_2_ail.remove_server_connected_client``). When no connection is
    left, the instance entry is dropped and ``is_connected=False`` is reported.
    """
    ail_uuid = websocket.ail_uuid
    sync_mode = websocket.sync_mode

    remaining = CONNECTED_CLIENTS[ail_uuid]
    remaining.remove(websocket)

    if any(other.sync_mode == sync_mode for other in remaining):
        sync_mode = None

    is_connected = bool(remaining)
    if not is_connected:
        CONNECTED_CLIENTS.pop(ail_uuid)

    ail_2_ail.remove_server_connected_client(ail_uuid, sync_mode=sync_mode, is_connected=is_connected)

    # debug dump of the connection registry
    print('Unregister client')
    print(CONNECTED_CLIENTS)
    print()
|
|
|
|
# PULL: Send data to client
# # TODO: ADD TIMEOUT ???
async def pull(websocket, ail_uuid):
    """Drain every pull sync queue of *ail_uuid* to the client.

    Each dequeued object is wrapped with ``ail_2_ail.create_ail_stream``,
    JSON-encoded and sent, with a 0.1 s pause between messages. A queue is
    done once it yields a falsy object.

    If the connection closes while a message is being sent (normal or
    abnormal close), the undelivered object is handed back to
    ``ail_2_ail.resend_object_to_sync_queue`` and the exception re-raised.
    """
    for queue_uuid in ail_2_ail.get_ail_instance_all_sync_queue(ail_uuid):
        while True:
            # get elem to send
            Obj = ail_2_ail.get_sync_queue_object_by_queue_uuid(queue_uuid, ail_uuid, push=False)
            if not Obj:
                break

            # Keep the serialized message in its own variable: the original
            # overwrote Obj with the JSON string, so the re-queue on error
            # received the string instead of the dequeued object.
            message = json.dumps(ail_2_ail.create_ail_stream(Obj))

            # send objects
            try:
                await websocket.send(message)
            except websockets.exceptions.ConnectionClosed:
                # ConnectionClosed covers both ConnectionClosedError and
                # ConnectionClosedOK; in either case the message was not sent.
                ail_2_ail.resend_object_to_sync_queue(ail_uuid, queue_uuid, Obj, push=False)
                raise
            await asyncio.sleep(0.1)

    # END PULL
    return None
|
|
|
|
|
|
# PUSH: receive data from client
# # TODO: optional queue_uuid
async def push(websocket, ail_uuid):
    """Feed every message pushed by the client into the sync importer.

    Each text message is JSON-decoded and passed to
    ``ail_2_ail.add_ail_stream_to_sync_importer``. Iterating the connection
    ends normally when the client closes cleanly; an abnormal close still
    raises ``ConnectionClosedError``. With the old ``while True: recv()``
    loop, every normal close surfaced as a handler failure.

    *ail_uuid* is currently unused.
    """
    async for ail_stream in websocket:
        # # TODO: CHECK ail_stream
        # # TODO: Close connection on junk
        # json.loads raises on malformed input, which aborts this handler.
        ail_2_ail.add_ail_stream_to_sync_importer(json.loads(ail_stream))
|
|
|
|
# API: server API
# # TODO: ADD TIMEOUT ???
async def api(websocket, ail_uuid, api):
    """Answer a single server API call and return.

    :param websocket: client connection to reply on
    :param ail_uuid: caller's AIL instance UUID (currently unused)
    :param api: path segments between sync mode and ail_uuid (from
                ``unpack_path``); only the first one is used. Supported:
                ``ping`` -> ``{"message": "pong"}`` and
                ``version`` -> ``{"version": <sync server version>}``.
                Empty or unknown calls get no reply.
    """
    # '/api/<ail_uuid>' yields an empty list; indexing it raised IndexError.
    if not api:
        return
    endpoint = api[0]

    if endpoint == 'ping':
        message = {'message': 'pong'}
    elif endpoint == 'version':
        message = {'version': ail_2_ail.get_sync_server_version()}
    else:
        return

    await websocket.send(json.dumps(message))
    # END API
|
|
|
|
async def ail_to_ail_serv(websocket, path):
    """Connection handler for an authenticated AIL-to-AIL websocket.

    Registers the connection, then dispatches on the sync mode from *path*:
    ``push`` receives data until the client leaves; ``pull`` and ``api`` run
    their exchange, then close the socket and log it. The connection is
    always unregistered afterwards.
    """
    # # TODO: save in class
    ail_uuid = websocket.ail_uuid
    remote_address = websocket.remote_address
    dict_path = unpack_path(path)
    sync_mode = dict_path['sync_mode']

    await register(websocket)
    try:
        if sync_mode == 'push':
            await push(websocket, ail_uuid)
        elif sync_mode in ('pull', 'api'):
            if sync_mode == 'pull':
                await pull(websocket, ail_uuid)
            else:
                await api(websocket, ail_uuid, dict_path['api'])
            # server-side close once the one-shot exchange is done
            await websocket.close()
            closed_msg = f'Connection closed: {ail_uuid} {remote_address}'
            redis_logger.info(closed_msg)
            print(closed_msg)
    finally:
        await unregister(websocket)
|
|
|
|
|
|
###########################################
|
|
# CHECK Authorization HEADER and URL PATH #
|
|
|
|
# # TODO: check AIL UUID (optional header)
|
|
|
|
class AIL_2_AIL_Protocol(websockets.WebSocketServerProtocol):
    """AIL_2_AIL_Protocol websockets server.

    Overrides ``process_request`` to authenticate each incoming connection
    (``Authorization`` header + URL path) before the websocket handshake
    completes; rejected requests receive a plain HTTP error response.
    """

    async def process_request(self, path, request_headers):
        """Check API key, URL path and sync mode of a connection request.

        On success sets ``self.ail_key``, ``self.ail_uuid`` and
        ``self.sync_mode`` (read later by the connection handler) and returns
        ``None`` so the handshake proceeds.

        :param path: request path, ``/<sync_mode>/[<api>/]<ail_uuid>``
        :param request_headers: handshake request headers
        :return: ``None`` to accept, otherwise a ``(status, headers, body)``
                 tuple sent back as the HTTP response
        """
        # API TOKEN
        api_key = request_headers.get('Authorization', '')
        # .get() defaults to '', so the former `api_key is None` test could
        # never fire; treat an absent or empty header as a missing token.
        if not api_key:
            redis_logger.warning(f'Missing token: {self.remote_address}')
            print(f'Missing token: {self.remote_address}')
            return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n"

        if not ail_2_ail.is_allowed_ail_instance_key(api_key):
            redis_logger.warning(f'Invalid token: {self.remote_address}')
            print(f'Invalid token: {self.remote_address}')
            return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n"

        # PATH
        try:
            dict_path = unpack_path(path)
        except Exception:
            redis_logger.warning(f'Invalid path: {self.remote_address}')
            print(f'Invalid path: {self.remote_address}')
            return http.HTTPStatus.BAD_REQUEST, [], b"Invalid path\n"

        # The key must belong to the AIL instance named in the URL.
        ail_uuid = ail_2_ail.get_ail_instance_by_key(api_key)
        if ail_uuid != dict_path['ail_uuid']:
            redis_logger.warning(f'Invalid token: {self.remote_address} {ail_uuid}')
            print(f'Invalid token: {self.remote_address} {ail_uuid}')
            return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n"

        # Cross-check in the other direction: the key stored for this instance
        # must equal the presented key. The former condition
        # `not api_key != get_ail_instance_key(api_key)` was inverted (it
        # rejected only on equality) and looked the key up by the key itself.
        # NOTE(review): assumes get_ail_instance_key() takes an instance UUID
        # and returns its key as str — confirm in core/ail_2_ail.
        if api_key != ail_2_ail.get_ail_instance_key(ail_uuid):
            redis_logger.warning(f'Invalid token: {self.remote_address} {ail_uuid}')
            print(f'Invalid token: {self.remote_address} {ail_uuid}')
            return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n"

        self.ail_key = api_key
        self.ail_uuid = ail_uuid
        self.sync_mode = dict_path['sync_mode']

        if self.sync_mode in ('pull', 'push'):
            # TODO: optional queue UUID in the path — validate it with
            # is_valid_uuid_v4() and ail_2_ail.is_ail_instance_queue().

            # SYNC MODE
            if not ail_2_ail.is_ail_instance_sync_enabled(self.ail_uuid, sync_mode=self.sync_mode):
                sync_mode = self.sync_mode
                redis_logger.warning(f'SYNC mode disabled: {self.remote_address} {ail_uuid} {sync_mode}')
                print(f'SYNC mode disabled: {self.remote_address} {ail_uuid} {sync_mode}')
                return http.HTTPStatus.FORBIDDEN, [], b"SYNC mode disabled\n"

        elif self.sync_mode == 'api':
            # # TODO: CHECK API
            pass

        else:
            # Same severity as every other rejection above.
            print(f'Invalid path: {self.remote_address}')
            redis_logger.warning(f'Invalid path: {self.remote_address}')
            return http.HTTPStatus.BAD_REQUEST, [], b"Invalid path\n"

        # Accept: let the websocket handshake continue.
        return None
|
|
|
|
###########################################
|
|
|
|
# # TODO: clean shutdown / kill all connections
|
|
# # TODO: Filter object
|
|
# # TODO: IP/uuid to block
|
|
|
|
if __name__ == '__main__':

    host = '0.0.0.0'
    port = 4443

    print('Launching Server...')
    redis_logger.info('Launching Server...')

    ail_2_ail.clear_server_connected_clients()

    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    cert_dir = os.environ['AIL_FLASK']
    ssl_context.load_cert_chain(certfile=os.path.join(cert_dir, 'server.crt'),
                                keyfile=os.path.join(cert_dir, 'server.key'))

    # Create and install the loop explicitly: implicit creation by
    # asyncio.get_event_loop() outside a running loop is deprecated. Install
    # it before websockets.serve(), which looks up the current loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    start_server = websockets.serve(ail_to_ail_serv, host, port, ssl=ssl_context, create_protocol=AIL_2_AIL_Protocol)

    # server command
    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task can be garbage-collected mid-run.
    controller_task = loop.create_task(server_controller())

    # websockets server
    loop.run_until_complete(start_server)

    # Announce only once the listening socket is actually bound.
    print(f'Server Launched: wss://{host}:{port}')
    redis_logger.info(f'Server Launched: wss://{host}:{port}')

    loop.run_forever()
|