# SPDX-License-Identifier: Apache-2.0
from typing import List, Tuple, Dict, Optional
import asyncio
import logging
import pydantic
from hipcheck_sdk import *
import hipcheck_sdk.gen as gen
from hipcheck_sdk.chunk import *
from hipcheck_sdk.error import *
logger = logging.getLogger(__name__)
def split_once(s: str, delim: str) -> Tuple[str, Optional[str]]:
"""
Split `s` at first instance of `delim` and returns the resulting substrings
:param str s: The input string.
:param str delim: The delimiter at which to The concern to be recorded.
:return: A tuple containing the substrings split at the first instance
the delimiter. If the delimiter was not found, the second element of
the tuple is `None`.
:meta private:
"""
res = s.split(delim, 1)
if len(res) != 2:
res.append(None)
return tuple(res)
def parse_target_str(target: str) -> Tuple[str, str, str]:
"""
Return a tuple of (publisher, plugin, endpoint_name) from parsed from
a query target string (e.g. "mitre/example/query"). If the string contains
only one slash return "" for the endpoint name, indicating the default
query endpoint.
:param str target: A string of the form `<PUBLISHER>/<PLUGIN>/<ENDPOINT>`,
where `/<ENDPOINT>` may be omitted if targeting a default endpoint.
:return: A tuple of strings representing the publisher, plugin, and
endpoint name
:meta private:
"""
publisher, rest = split_once(target, "/")
if rest is None:
raise InvalidQueryTargetFormat()
plugin, name = split_once(rest, "/")
if name is None:
name = ""
return (publisher, plugin, name)
def deserialize_key(json_str: str, target_type: type) -> object:
"""
Try to turn a JSON string into a given type
:param str json_str: The JSON string to be read
:param type target_type: The type to produce
:return: An instance of `target_type`
:meta private:
"""
if issubclass(target_type, pydantic.BaseModel):
return target_type.model_validate_json(json_str)
elif target_type in [int, bool, float, list, dict]:
return json.loads(json_str)
else:
return json.loads(json_str, cls=target_type)
[docs]
class QueryBuilder:
"""
An alternative to calling `PluginEngine.batch_query()`, an instance of
this class is returned by `PluginQuery.batch()`, and allows plugin
authors to build up a list of keys for a batch query over time, then call
`QueryBuilder.send()` to send them in a single message to Hipcheck.
"""
def __init__(self, engine, target: str):
"""
:meta private:
"""
self.engine = engine
# The keys that will be added to a batch query
self.keys = []
# The target endpoint that all keys will be used to query
self.target = target
[docs]
def query(self, key: object) -> int:
"""
Adds a key to the batch query being built by this object
:param object key: The key to add to the query batch.
:return: The index of the list at which the key was added. Can be used
to index the output of `send()` to get the corresponding output for
that key.
"""
l = len(self.keys)
self.keys.append(key)
return l
[docs]
async def send(self) -> List[object]:
"""
Sends all keys aggregated with `query()` in a single batch query to
Hipcheck core.
:return: The list of output objects for the keys aggregated in this object
:raises SdkError:
"""
return await self.engine.batch_query(self.target, self.keys)
MockResponses = List[Tuple[Tuple[str, object], object]]
def find_response(m: MockResponses, key: Tuple[str, object]) -> Optional[object]:
"""
Find a tuple `y` in `m` such that y[0] == key, and return y[1]. This acts
like a pseudo-dictionary for storing mock responses to deal with the fact
that lists/dicts/objects are not hashable by default.
:param MockResponses m: The list of tuples to search
:param tuple key: The key to check elements of `m` against
:return: The second element of a tuple whose first element matches `key`. If none found, returns `None`.
:meta private:
"""
i = map(lambda x: x[1], filter(lambda y: y[0] == key, m))
try:
return next(i)
except StopIteration:
return None
[docs]
class PluginEngine:
"""
Manages a particular query session.
An instance of this class invokes a query endpoint, passing a handle to
itself. This allows the query endpoint to request information from other
Hipcheck plugins as part of its logic.
"""
def __init__(
self,
session_id: int = 0,
tx: asyncio.Queue = None,
rx: asyncio.Queue = None,
drop_tx: asyncio.Queue = None,
mock_responses: Optional[MockResponses] = None,
):
"""
:meta private:
"""
nones = [v is None for v in [tx, rx, drop_tx]]
if any(nones) and not all(nones):
raise UnspecifiedConfigError(
msg="tx, rx, and drop_tx must all be None or all be asyncio.Queue objects"
)
self.id: int = session_id
self.tx: asyncio.Queue = tx
self.rx: asyncio.Queue = rx
# So that we can remove ourselves when we get dropped
self.drop_tx: asyncio.Queue = drop_tx
self.concerns: List[str] = []
# When unit testing, this enables the user to mock plugin responses to various inputs
self.mock_responses = mock_responses
[docs]
def mock(mock_responses: List[Tuple[Tuple[str, object], object]] = []):
"""
For unit testing purposes, construct a PluginEngine with a set of
mock responses
:param tuple mock_responses: A list of key, value pairs that maps
queries to mock responses. Does not use a dict since many
relevant types are not hashable and thus cannot be used as
keys. The query is a tuple of a target string and a key
object, the response is the output object for that query.
:return: An instance of `PluginEngine`
"""
# In `PluginEngine.query()` if mock_responses is None we try to query Hipcheck core
# which will obviously fail in a unit-testing context. Try to defend here against
# the user making a mistake; if `mock()` called we expect to always mock responses.
if mock_responses is None:
mock_responses = []
return PluginEngine(mock_responses=mock_responses)
# Convenience function to expose a `QueryBuilder` to make it easy to dynamically build
# up multi-key queries against a single target and send them over gRPC in as few
# gRPC calls as possible.
[docs]
def batch(self, target: str) -> QueryBuilder:
"""
Create a `QueryBuilder` instance to dynamically aggregate keys for a
batch query against `target` as opposed to using `PluginEngine.batch_query()`
which requires having all keys available immediately
:param str target: A string of the form `<PUBLISHER>/<PLUGIN>/<ENDPOINT>`,
where `/<ENDPOINT>` may be omitted if targeting a default endpoint.
Indicates the remote plugin endpoint to query.
:return: An instance of `QueryBuilder`
"""
return QueryBuilder(self, target)
[docs]
async def query_inner(self, target: str, keys: List[object]) -> List[object]:
"""
:raises: SdkError
:meta: private
"""
# If using a mock engine, look to the `mock_responses` field for the query answer
if self.mock_responses is not None:
results = []
for i in keys:
opt_val = find_response(self.mock_responses, (target, i))
if opt_val is None:
raise UnknownPluginQuery()
else:
results.append(opt_val)
return results
else:
# Normal execution, send messages to Hipcheck core to query other plugins
publisher, plugin, name = parse_target_str(target)
query = Query(
id=self.id,
direction=QueryDirection.REQUEST,
publisher=publisher,
plugin=plugin,
query=name,
key=keys,
output=[],
concerns=[],
)
await self.send(query)
resp: Query = await self.recv()
return list(resp.output)
[docs]
async def query(self, target: str, key: object) -> object:
"""
Query another Hipcheck plugin endpoint `target` with key `input`
:param str target: A string of the form `<PUBLISHER>/<PLUGIN>/<ENDPOINT>`,
where `/<ENDPOINT>` may be omitted if targeting a default endpoint.
Indicates the remote plugin endpoint to query.
:param object key: The key for the query
:return: The deserialized result
:raises: SdkError
"""
outputs = await self.query_inner(target, [key])
return outputs[0]
[docs]
async def batch_query(self, target: str, keys: List[object]) -> List[object]:
"""
Query another Hipcheck plugin endpoint `target` with a list of keys
:param str target: Indicates the remote plugin endpoint to query. Has
the same format requirements as `query()`
:param list keys: The list of keys to send as a single batch query to
Hipcheck core.
:return: A list of output values corresponding to each element of key
being applied to the target endpoint.
:raises: SdkError
"""
return await self.query_inner(target, keys)
async def recv_raw(self) -> Optional[List[gen.Query]]:
"""
:meta private:
"""
out = []
try:
first = await self.rx.get()
except Exception as e:
# Underlying gRPC channel closed
# @Todo - tighten this exception
print(f"Recv exception: {e}")
return None
out.append(first)
# If more messages in the queue, opportunistically read more
while True:
try:
msg = self.rx.get_nowait()
except asyncio.QueueEmpty:
break
except Exception as e:
# @Todo - tighten this exception
print(f"Recv exception: {e}")
break
out.append(msg)
return out
async def send_session_error(self, plugin):
"""
:meta private:
"""
query = gen.Query(
id=self.id,
state=gen.QUERY_STATE_UNSPECIFIED,
publisher_name=plugin.publisher,
plugin_name=plugin.name,
query_name="",
concern=self.take_concerns(),
split=False,
)
await self.tx.put(query)
async def recv(self) -> Optional[Query]:
"""
:raises: SdkError
:meta private:
"""
synth = QuerySynthesizer()
res: Optional[Query] = None
while res is None:
opt_msg_chunks = await self.recv_raw()
if opt_msg_chunks is None:
return None
msg_chunks = opt_msg_chunks
res = synth.add(msg_chunks)
return res
[docs]
def record_concern(self, concern: str):
"""
Records a concern that will be emitted in the final Hipcheck report.
Intended for use within a `@query`-decorated endpoint function.
:param str concern: The concern to be recorded
"""
self.concerns.append(concern)
def take_concerns(self):
"""
:meta private:
"""
out = self.concerns
self.concerns = []
return out
async def send(self, query: Query):
"""
Send a gRPC query from plugin to the hipcheck server
:raises SdkError:
:meta private:
"""
query.id = self.id # incoming id value is just a placeholder
for pq in prepare(query):
await self.tx.put(pq)
async def handle_session_fallible(self, plugin):
"""
:raises SdkError:
:meta private:
"""
query: Query = await self.recv()
if query.direction == QueryDirection.RESPONSE:
raise ReceivedReplyWhenExpectingSubmitChunk()
name = query.query
# Per RFD 0009, there should only be one query key per query
if len(query.key) != 1:
raise UnspecifiedQueryState()
key = query.key[0]
query = next((x for x in plugin.queries() if x.name == name), None)
if query is None:
raise UnknownPluginQuery()
# None key type means they used an explicit key_schema so we leave as dict
if query.key_type is not None:
# Convert query as dict to target object schema
try:
key = deserialize_key(json.dumps(key), query.key_type)
except Exception as e:
logger.error(f"{e}")
raise InvalidJsonInQueryKey()
value = await query.func(self, key)
out = Query(
id=self.id,
direction=QueryDirection.RESPONSE,
publisher=plugin.publisher,
plugin=plugin.name,
query=name,
key=[],
output=[value],
concerns=self.take_concerns(),
)
await self.send(out)
# Notify HcSessionSocket that session is closed
await self.drop_tx.put(self.id)
async def handle_session(self, plugin):
"""
:meta private:
"""
try:
await self.handle_session_fallible(plugin)
# Errors that we raise intentionally
except SdkError as e:
logger.error(f"{e}")
await self.send_session_error(plugin)
# Other errors, such as syntactical ones
except Exception as e:
logger.error(f"{e}")
await self.send_session_error(plugin)
# except asyncio.QueueShutDown:
# return