# Source code for hipcheck_sdk.query
# SPDX-License-Identifier: Apache-2.0
import functools
import typing
from typing import Dict, Optional, Callable
from dataclasses import dataclass
import pydantic
@dataclass
class Endpoint:
    """
    Class to encapsulate information about a `@query`-decorated function, thus
    declared to be an endpoint for this plugin.

    :meta private:
    """

    # The name of the endpoint. If the default endpoint, name is ""
    name: str
    # The actual function that implements this endpoint
    func: Callable
    # The JSON Schema for the expected key
    key_schema: dict
    # The object type to convert json to. This is `None` when the key schema
    # was supplied explicitly instead of being derived from a type hint, so
    # the annotation must allow it.
    key_type: Optional[type]
    # The JSON Schema for the produced output object
    output_schema: dict

    def is_default(self) -> bool:
        """
        Report whether this entry is the plugin's default endpoint.

        :return: `True` if the endpoint name is the empty string
        """
        return self.name == ""
# A global registry of all detected `@query`-decorated functions in this plugin,
# keyed by endpoint name. The default endpoint, if one was declared, is
# additionally stored under the empty-string key "".
# Used by the default `Plugin` class implementation to implement `queries()` and
# `schemas()` functions.
query_registry: Dict[str, Endpoint] = {}
def get_json_schema_for_type(ty: type) -> dict:
    """
    Gets the JSON Schema for a Python object type. If the type is a child of
    pydantic.BaseModel, use that. Otherwise, try to derive the schema.

    :param type ty: The type for which to derive a JSON schema. May also be a
        typing construct that is not a class (e.g. `Optional[int]`).
    :return: A jsonable dict representing the schema for `ty`

    :meta private:
    """
    # `issubclass` raises TypeError when its first argument is not a class,
    # which is the case for many valid type hints (`Optional[int]`, unions,
    # parameterized generics). Those are not pydantic models, so fall through
    # to the TypeAdapter path instead of failing.
    try:
        is_model = issubclass(ty, pydantic.BaseModel)
    except TypeError:
        is_model = False

    if is_model:
        return ty.model_json_schema()

    adapter = pydantic.TypeAdapter(ty)
    return adapter.json_schema()
def register_query(func, default, key_schema, output_schema):
    """
    Add the function to `query_registry`. If `key_schema` or `output_schema`
    are None, try to derive the schema from the type hints on `func`.

    :param func: The endpoint function; must take exactly 2 positional
        arguments, the second being the query key
    :param bool default: Whether to also register `func` as the default
        endpoint under the empty-string name
    :param key_schema: Explicit JSON Schema for the key, or `None` to derive
        it from the type hint on the key parameter
    :param output_schema: Explicit JSON Schema for the output, or `None` to
        derive it from the return type hint
    :raises TypeError: `func` does not have exactly 2 positional arguments, a
        needed type hint is missing, or a default endpoint already exists

    :meta private:
    """
    # Validate that func has 2 positional args
    code = func.__code__
    if code.co_argcount != 2:
        raise TypeError("query function must have exactly 2 positional arguments")

    # Only resolve type hints when at least one schema must be derived from
    # them; with both schemas given explicitly the hints are never consulted.
    if key_schema is None or output_schema is None:
        hints = typing.get_type_hints(func)
    else:
        hints = {}

    if key_schema is None:
        # Try to derive from function. The key is the second positional
        # parameter; positional parameter names lead `co_varnames`.
        key_param = code.co_varnames[1]
        if key_param not in hints:
            raise TypeError(
                "cannot deduce query key type without type hint on key parameter"
            )
        key_type = hints[key_param]
        key_schema = get_json_schema_for_type(key_type)
    else:
        # We do not generate a class definition for key_schema due to
        # potential code execution security concerns
        key_type = None

    if output_schema is None:
        if "return" not in hints:
            raise TypeError(
                "cannot deduce query output type without return type hint on signature"
            )
        output_schema = get_json_schema_for_type(hints["return"])

    key = func.__name__

    # Create an additional entry that maps this func to the empty string so queries
    # received without an endpoint name will call it
    if default:
        if "" in query_registry:
            raise TypeError("default query already defined")
        query_registry[""] = Endpoint("", func, key_schema, key_type, output_schema)

    query_registry[key] = Endpoint(key, func, key_schema, key_type, output_schema)
# [docs]
def query(
    f_py=None,
    default: bool = False,
    key_schema: Optional[dict] = None,
    output_schema: Optional[dict] = None,
):
    """
    Decorator function for query endpoints. Endpoint functions must have the
    following signature:

    .. code-block:: python

        async def <QUERY_NAME>(engine: hipcheck_sdk.engine.PluginEngine, key: <TYPE>) -> <TYPE>

    May be applied bare (`@query`) or with keyword arguments
    (`@query(default=True)`).

    :param f_py: The function being decorated when used as a bare `@query`;
        otherwise `None`
    :param bool default: Whether this endpoint is the default for the plugin
    :param dict key_schema: A jsonable dict representing the schema for the key
        to this endpoint. If `None`, derive from type hint on key parameter of
        function. If a schema is supplied explicitly instead of having the SDK
        use the type hint, the object passed to the query func will not be
        automatically converted to a class instance.
    :param dict output_schema: A jsonable dict representing the schema for the
        return value of this endpoint. If `None`, derive from type hint on
        return value of function.
    :return: The wrapped function when applied bare, otherwise a decorator
    :raises TypeError: The first positional argument was neither a callable
        nor `None`, the function lacked type hints, or a JSON schema could
        not be derived from them
    """
    # Raise explicitly rather than `assert`: assertions are stripped under
    # `python -O`, which would let a stray positional argument be silently
    # ignored.
    if f_py is not None and not callable(f_py):
        raise TypeError(
            "query must be applied to a callable; pass options as keyword arguments"
        )

    def _decorator(func):
        # Registration stores the original `func`, not the wrapper below
        register_query(func, default, key_schema, output_schema)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper

    # Bare `@query` passes the function directly; `@query(...)` passes nothing
    # positionally and expects a decorator back.
    return _decorator(f_py) if f_py is not None else _decorator