Skip to content

Machina Worker API

Analysis worker base class. Workers inheriting from this class receive data to analyze based on the chosen data type bindings

Source code in machina/core/worker.py
class Worker():
    """Analysis worker base class.

    Workers inheriting from this class receive data to analyze based on the
    chosen data type bindings (``types`` or ``types_blacklist``).
    """

    # Queues that receive this worker's output, enabling worker chaining.
    next_queues = []
    # Data types to bind to; set in the subclass.  '*' binds to all
    # configured types.
    types = []
    # Data types NOT to bind to; implies binding to every configured type
    # not in this list.  Mutually exclusive with ``types``.
    types_blacklist = []

    def __init__(self):
        self.cls_name = self.__class__.__name__
        self.config = self._load_configs()
        self.schema = self._load_schema()

        # Logging
        level = logging.getLevelName(self.config['worker']['log_level'])
        logging.basicConfig(level=level, format='[*] %(message)s')
        self.logger = logging.getLogger(__name__)

        # The two binding mechanisms are mutually exclusive.
        if self.types and self.types_blacklist:
            self.logger.error("both types and types_blacklist cannot be set at the same time")
            raise ValueError("both types and types_blacklist cannot be set at the same time")

        # Validate whitelist types; if '*' is present, bind to every
        # available type.
        if self.types:
            self.logger.info(f"Validating types: {pformat(self.types)}")
            if '*' in self.types:
                self.types = self.config['types']['available_types']
            else:
                types_valid, t = self._types_valid()
                if not types_valid:
                    self.logger.error(f"{t} is not configured as a type in types.json")
                    raise ValueError(f"{t} is not configured as a type in types.json")

        # Validate blacklist types.
        if self.types_blacklist:
            self.logger.info(f"Validating types blacklist: {pformat(self.types_blacklist)}")
            types_blacklist_valid, t = self._types_blacklist_valid()
            if not types_blacklist_valid:
                self.logger.error(f"{t} is not configured as a type in types.json, so cannot be blacklisted")
                raise ValueError(f"{t} is not configured as a type in types.json, so cannot be blacklisted")

            # If valid, bind to every available type except the blacklisted
            # ones.
            self.types = [avail for avail in self.config['types']['available_types']
                          if avail not in self.types_blacklist]

        # neo4j set connection (neomodel reads config.DATABASE_URL)
        _cfg = self.config['neo4j']
        config.DATABASE_URL = f"bolt://{_cfg['user']}:{_cfg['pass']}@{_cfg['host']}:{_cfg['port']}/{_cfg['db_name']}"

        # Initializer does no queue consumption, so
        # dont create a connection or queue for it
        if self.cls_name != 'Initializer':

            # RabbitMQ Connection info
            # note this is not thread-safe
            self.rmq_conn = self.get_rmq_conn()
            self.rmq_recv_channel = self.rmq_conn.channel()

            # reduce Pika logging level
            logging.getLogger('pika').setLevel(logging.ERROR)

            # The queue to bind to is the name of the class
            bind_queue = self.cls_name

            # Initialize an exchange
            self.rmq_recv_channel.exchange_declare('machina')

            # Initialize direct queue w/ subclass name
            self.logger.info(f"Binding to direct queue: {bind_queue}")
            self.rmq_recv_channel.queue_declare(self.cls_name, durable=True)

            # Ensure that the worker's queue is bound to the exchange
            self.rmq_recv_channel.queue_bind(exchange='machina',
                queue=self.cls_name)

            # multiple-bindings approach:
            # https://www.rabbitmq.com/tutorials/tutorial-four-python.html
            # Bind using type strings as routing_keys
            # Publish using only a routing_key should go to all queues
            # that are bound to that routing_key
            for t in self.types:
                self.logger.info(f'binding to type: {t}')
                self.rmq_recv_channel.queue_bind(exchange='machina',
                    queue=self.cls_name,
                    routing_key=t)

            # prefetch_count=1: hand this worker one message at a time
            self.rmq_recv_channel.basic_qos(prefetch_count=1)
            self.rmq_recv_channel.basic_consume(self.cls_name,
                on_message_callback=self._callback)

    #############################################################
    # Privates

    def _load_configs(self) -> dict:
        """load configuration files from expected path, return as dictionary

        :return: the configuration dictionary
        :rtype: dict
        """

        fdir = '/configs'

        paths_cfg_fp = Path(fdir, 'paths.json')
        with open(paths_cfg_fp, 'r') as f:
            paths_cfg = json.load(f)

        rabbitmq_cfg_fp = Path(fdir, 'rabbitmq.json')
        with open(rabbitmq_cfg_fp, 'r') as f:
            rabbitmq_cfg = json.load(f)

        neo4j_cfg_fp = Path(fdir, 'neo4j.json')
        with open(neo4j_cfg_fp, 'r') as f:
            neo4j_cfg = json.load(f)

        types_fp = Path(fdir, 'types.json')
        with open(types_fp, 'r') as f:
            types_cfg = json.load(f)

        # Base-worker configurations, will be overridden by worker-specific
        # configurations if there is overlap
        base_worker_cfg_fp = Path(fdir, 'workers', 'Worker.json')
        with open(base_worker_cfg_fp, 'r') as f:
            worker_cfg = json.load(f)

        # Worker-specific configuration
        worker_cfg_fp = Path(fdir, 'workers', self.cls_name+'.json')
        with open(worker_cfg_fp, 'r') as f:
            worker_cfg.update(json.load(f))

        return dict(paths=paths_cfg,
                    rabbitmq=rabbitmq_cfg,
                    neo4j=neo4j_cfg,
                    types=types_cfg,
                    worker=worker_cfg)

    def _load_schema(self) -> dict:
        """automatically resolve schema name based on class name

        :return: the schema dictionary
        :rtype: dict
        """
        class_schema = Path(self.config['paths']['schemas'], self.cls_name+'.json')
        with open(class_schema, 'r') as f:
            schema_data = json.load(f)
        return schema_data

    def _callback(
        self,
        ch: pika.channel.Channel,
        method: pika.spec.Basic.Deliver,
        properties: pika.spec.BasicProperties,
        body: bytes):
        """do last-second validation before handling the callback

        :param ch: pika channel
        :type ch: pika.channel.Channel
        :param method: pika method
        :type method: pika.spec.Basic.Deliver
        :param properties: pika properties
        :type properties: pika.spec.BasicProperties
        :param body: message body
        :type body: bytes
        """
        self._validate_body(body)

        self.logger.info("entering callback")
        # Run the subclass callback in a thread and keep the pika
        # connection alive (heartbeats) while it works.
        thread = threading.Thread(target=self.callback, args=(body, properties))
        thread.start()
        while thread.is_alive():
            self.rmq_recv_channel._connection.sleep(1.0)
        self.logger.info("exiting callback")

        # Ack only after the subclass callback has finished.
        self.rmq_recv_channel.basic_ack(delivery_tag=method.delivery_tag)

    def _validate_body(self, body: bytes):
        """apply subclass worker schema and validate

        :param body: message body
        :type body: bytes
        """
        self.logger.info("validating schema")
        data = json.loads(body)
        # fixed resolver to ensure base schema uri is resolved
        # e.g. https://stackoverflow.com/questions/53968770/how-to-set-up-local-file-references-in-python-jsonschema-document
        resolver = jsonschema.RefResolver(f"file:{Path(self.config['paths']['schemas'], self.cls_name+'.json')}", self.schema)

        jsonschema.validate(
            instance=data,
            schema=self.schema,
            resolver=resolver)

    def _types_valid(self) -> tuple:
        """ensure that the type to bind to is configured in types.json

        :return: tuple where first element is True if all requested type bindings are valid, or False if not.  If invalid, set the second element to the first discovered invalid type
        :rtype: tuple
        """
        for t in self.types:
            if t not in self.config['types']['available_types']:
                return False, t
        return True, None

    def _types_blacklist_valid(self) -> tuple:
        """ensure that the type to blacklist is configured in types.json

        :return: tuple where first element is True if all requested type bindings are valid, or False if not.  If invalid, set the second element to the first discovered invalid type
        :rtype: tuple
        """
        for t in self.types_blacklist:
            if t not in self.config['types']['available_types']:
                return False, t
        return True, None
    #############################################################

    #############################################################
    # RabbitMQ Helpers
    def get_rmq_conn(
        self,
        max_attempts: int = 10,
        delay_seconds: int = 1) -> pika.BlockingConnection:
        """get RabbitMQ connection instance

        :param max_attempts: max number of attempts to try to get the connection, defaults to 10
        :type max_attempts: int, optional
        :param delay_seconds: the delay between attempts to get the connection, defaults to 1
        :type delay_seconds: int, optional
        :return: the connection instance
        :rtype: pika.BlockingConnection
        """

        rabbitmq_user = self.config['rabbitmq']['rabbitmq_user']
        rabbitmq_password = self.config['rabbitmq']['rabbitmq_password']
        rabbitmq_host = self.config['rabbitmq']['rabbitmq_host']
        rabbitmq_port = self.config['rabbitmq']['rabbitmq_port']
        rabbitmq_heartbeat = self.config['rabbitmq']['rabbitmq_heartbeat']

        connection = None
        credentials = pika.PlainCredentials(rabbitmq_user, rabbitmq_password)
        parameters = pika.ConnectionParameters(rabbitmq_host,
            int(rabbitmq_port),
            '/',
            credentials,
            heartbeat=int(rabbitmq_heartbeat),
            socket_timeout=2)

        attempt = 0
        while attempt < max_attempts:
            # 1-based attempt counter for human-readable logs.
            attempt += 1
            try:
                connection = pika.BlockingConnection(parameters)
                break
            except pika.exceptions.AMQPConnectionError:
                self.logger.info(f"Attempt {attempt}/{max_attempts} to connect to RabbitMQ at {rabbitmq_host}:{rabbitmq_port}")
                # warning(), not the deprecated warn()
                self.logger.warning("Error connecting to RabbitMQ")
                time.sleep(delay_seconds)

        if not connection:
            self.logger.error('max attempts exceeded')
            sys.exit()

        return connection

    def start_consuming(self):
        """start consuming; blocks until the channel stops or errors"""
        self.logger.info(f'{self.cls_name} worker started')
        try:
            self.rmq_recv_channel.start_consuming()
        except Exception as e:
            self.logger.error(e, exc_info=True)
            self.rmq_recv_channel.stop_consuming()
        # Bug fix: previously closed nonexistent ``self.connection``; the
        # connection attribute created in __init__ is ``self.rmq_conn``.
        self.rmq_conn.close()

    def callback(self, data: bytes, properties: pika.spec.BasicProperties):
        """callback for worker, implement in subclass

        :param data: incoming string payload
        :type data: bytes
        :param properties: message properties
        :type properties: pika.spec.BasicProperties
        :raises NotImplementedError:
        """
        raise NotImplementedError

    def publish_next(self, data: bytes):
        """publish to configured next_queues

        :param data: the data to publish
        :type data: bytes
        """
        if not self.next_queues:
            # warning(), not the deprecated warn()
            self.logger.warning('attempting to publish to next queue, but no next_queues defined in worker class')
        self.publish(data, queues=self.next_queues)

    def publish(self, data: bytes, queues: list):
        """publish directly to a list of arbitrary queues

        :param data: the data to publish
        :type data: bytes
        :param queues: the list of queue names (as strings) to publish data to
        :type queues: list
        """
        rmq_conn = self.get_rmq_conn()
        for q in queues:
            self.logger.info(f"publishing directly to {q}")
            rmq_channel = rmq_conn.channel()
            rmq_channel.basic_publish(
                exchange='machina',
                routing_key=q,
                body=data)
        rmq_conn.close()

    #############################################################

    #############################################################
    # Misc
    def get_binary_path(self, ts: str, md5: str) -> str:
        """get path to a binary on disk given a timestamp and its md5

        :param ts: the timestamp of the binary
        :type ts: str
        :param md5: the md5 of the binary
        :type md5: str
        :return: the path to the binary
        :rtype: str
        """
        binary_path = Path(self.config['paths']['binaries'], ts, md5)
        return str(binary_path)

callback(data, properties) #

callback for worker, implement in subclass

Parameters:

Name Type Description Default
data bytes

incoming string payload

required
properties pika.spec.BasicProperties

message properties

required

Raises:

Type Description
NotImplementedError
Source code in machina/core/worker.py
def callback(self, data: bytes, properties: pika.spec.BasicProperties):
    """Per-message entry point; subclasses override with analysis logic.

    :param data: incoming string payload
    :type data: bytes
    :param properties: message properties
    :type properties: pika.spec.BasicProperties
    :raises NotImplementedError: always, in the base class
    """
    raise NotImplementedError

get_binary_path(ts, md5) #

get path to a binary on disk given a timestamp and its md5

Parameters:

Name Type Description Default
ts str

the timestamp of the binary

required
md5 str

the md5 of the binary

required

Returns:

Type Description
str

the path to the binary

Source code in machina/core/worker.py
def get_binary_path(self, ts: str, md5: str) -> str:
    """Resolve the on-disk location of a binary.

    :param ts: the timestamp of the binary
    :type ts: str
    :param md5: the md5 of the binary
    :type md5: str
    :return: the path to the binary
    :rtype: str
    """
    return str(Path(self.config['paths']['binaries'], ts, md5))

get_rmq_conn(max_attempts=10, delay_seconds=1) #

get RabbitMQ connection instance

Parameters:

Name Type Description Default
max_attempts int

max number of attempts to try to get the connection, defaults to 10

10
delay_seconds int

the delay between attempts to get the connection, defaults to 1

1

Returns:

Type Description
pika.BlockingConnection

the connection instance

Source code in machina/core/worker.py
def get_rmq_conn(
    self,
    max_attempts: int = 10,
    delay_seconds: int = 1) -> pika.BlockingConnection:
    """get RabbitMQ connection instance

    :param max_attempts: max number of attempts to try to get the connection, defaults to 10
    :type max_attempts: int, optional
    :param delay_seconds: the delay between attempts to get the connection, defaults to 1
    :type delay_seconds: int, optional
    :return: the connection instance
    :rtype: pika.BlockingConnection
    """

    rabbitmq_user = self.config['rabbitmq']['rabbitmq_user']
    rabbitmq_password = self.config['rabbitmq']['rabbitmq_password']
    rabbitmq_host = self.config['rabbitmq']['rabbitmq_host']
    rabbitmq_port = self.config['rabbitmq']['rabbitmq_port']
    rabbitmq_heartbeat = self.config['rabbitmq']['rabbitmq_heartbeat']

    connection = None
    credentials = pika.PlainCredentials(rabbitmq_user, rabbitmq_password)
    parameters = pika.ConnectionParameters(rabbitmq_host,
        int(rabbitmq_port),
        '/',
        credentials,
        heartbeat=int(rabbitmq_heartbeat),
        socket_timeout=2)

    attempt = 0
    while attempt < max_attempts:
        # 1-based attempt counter so logs read "Attempt 1/10", not "0/10".
        attempt += 1
        try:
            connection = pika.BlockingConnection(parameters)
            break
        except pika.exceptions.AMQPConnectionError:
            self.logger.info(f"Attempt {attempt}/{max_attempts} to connect to RabbitMQ at {rabbitmq_host}:{rabbitmq_port}")
            # warning(), not the deprecated warn() alias
            self.logger.warning("Error connecting to RabbitMQ")
            # Back off only after a failed attempt.
            time.sleep(delay_seconds)

    if not connection:
        self.logger.error('max attempts exceeded')
        sys.exit()

    return connection

publish(data, queues) #

publish directly to a list of arbitrary queues

Parameters:

Name Type Description Default
data bytes

the data to publish

required
queues list

the list of queue names (as strings) to publish data to

required
Source code in machina/core/worker.py
def publish(self, data: bytes, queues: list):
    """Publish ``data`` directly to each named queue.

    Opens a fresh RabbitMQ connection, publishes to every queue via the
    'machina' exchange, then closes the connection.

    :param data: the data to publish
    :type data: bytes
    :param queues: the list of queue names (as strings) to publish data to
    :type queues: list
    """
    conn = self.get_rmq_conn()
    for queue_name in queues:
        self.logger.info(f"publishing directly to {queue_name}")
        channel = conn.channel()
        channel.basic_publish(
            exchange='machina',
            routing_key=queue_name,
            body=data)
    conn.close()

publish_next(data) #

publish to configured next_queues

Parameters:

Name Type Description Default
data bytes

the data to publish

required
Source code in machina/core/worker.py
def publish_next(self, data: bytes):
    """publish to configured next_queues

    :param data: the data to publish
    :type data: bytes
    """
    if not self.next_queues:
        # Nothing configured downstream; warn but still call publish with
        # the empty list (a no-op) to keep behavior uniform.
        # warning(), not the deprecated warn() alias
        self.logger.warning('attempting to publish to next queue, but no next_queues defined in worker class')
    self.publish(data, queues=self.next_queues)

start_consuming() #

start consuming

Source code in machina/core/worker.py
def start_consuming(self):
    """start consuming; blocks until the channel stops or errors.

    On any exception the error is logged, consumption is stopped, and the
    underlying connection is closed.
    """
    self.logger.info(f'{self.cls_name} worker started')
    try:
        self.rmq_recv_channel.start_consuming()
    except Exception as e:
        self.logger.error(e, exc_info=True)
        self.rmq_recv_channel.stop_consuming()
    # Bug fix: previously closed nonexistent ``self.connection``; the
    # connection attribute created in __init__ is ``self.rmq_conn``.
    self.rmq_conn.close()