diff --git a/gcp/api/server.py b/gcp/api/server.py index a54e9a1ff70..07368749fa4 100644 --- a/gcp/api/server.py +++ b/gcp/api/server.py @@ -22,6 +22,7 @@ import functools import logging import os +import threading import time from typing import Callable, List @@ -102,30 +103,47 @@ def wrapper(*args, **kwargs): return wrapper -def trace_log_fields(context: grpc.ServicerContext) -> dict: - """Makes the json_field needed to associate a log with the request trace.""" - fields = {} - trace_context = dict( - context.invocation_metadata()).get('x-cloud-trace-context') - if trace_context is None: - return fields - - # Trace context header example: - # "X-Cloud-Trace-Context: TRACE_ID/SPAN_ID;o=TRACE_TRUE" - parts = trace_context.split('/') - trace_id = parts[0] - # We don't set the GOOGLE_CLOUD_PROJECT env var explicitly, and I can't find - # any confirmation on whether Cloud Run will set automatically. - # Grab the project name from the (undocumented?) field on ndb.Client(). - # The most correct way to do this would be to use the instance metadata server - # https://cloud.google.com/run/docs/container-contract#metadata-server - project = getattr(_ndb_client, 'project', 'oss-vdb') # fall back to oss-vdb - fields['trace'] = f'projects/{project}/traces/{trace_id}' - if len(parts) > 1: - span_id = parts[1].split(';')[0] - fields['span_id'] = span_id - - return fields +class LogTraceFilter: + """Class for adding the trace information from the grpc requests into logs.""" + + def __init__(self): + self.thread_local = threading.local() + + def log_trace(self, func): + """Wrapper for grpc method to capture trace from header metadata""" + + @functools.wraps(func) + def wrapper(s, r, context: grpc.ServicerContext): + self.thread_local.trace = dict( + context.invocation_metadata()).get('x-cloud-trace-context') + return func(s, r, context) + + return wrapper + + def filter(self, record: logging.LogRecord) -> bool: + """logging.Filter method to add trace into log data.""" + trace = getattr(self.thread_local, 'trace', None) + if not trace: + return True + + # Trace context header example: + # "X-Cloud-Trace-Context: TRACE_ID/SPAN_ID;o=TRACE_TRUE" + parts = trace.split('/') + trace_id = parts[0] + # We don't set the GOOGLE_CLOUD_PROJECT env var explicitly, and I can't find + # any confirmation on whether Cloud Run will set automatically. + # Grab the project name from the (undocumented?) field on ndb.Client(). + # Most correct way to do this would be to use the instance metadata server + # https://cloud.google.com/run/docs/container-contract#metadata-server + project = getattr(_ndb_client, 'project', 'oss-vdb') # fall back to oss-vdb + record.trace = f'projects/{project}/traces/{trace_id}' + if len(parts) > 1: + record.span_id = parts[1].split(';')[0] + + return True + + +trace_filter = LogTraceFilter() class OSVServicer(osv_service_v1_pb2_grpc.OSVServicer, @@ -133,6 +151,7 @@ class OSVServicer(osv_service_v1_pb2_grpc.OSVServicer, """V1 OSV servicer.""" @ndb_context + @trace_filter.log_trace def GetVulnById(self, request, context: grpc.ServicerContext): """Return a `Vulnerability` object for a given OSV ID.""" bug: osv.Bug = osv.Bug.get_by_id(request.id) @@ -147,14 +166,13 @@ def GetVulnById(self, request, context: grpc.ServicerContext): return bug_to_response(bug, include_alias=True) @ndb_context + @trace_filter.log_trace def QueryAffected(self, request, context: grpc.ServicerContext): """Query vulnerabilities for a particular project at a given commit or version. """ - # Log some information about the query with structured logging - logging_trace = trace_log_fields(context) qtype, ecosystem, versioned = query_info(request.query) if ecosystem is not None: logging.info( @@ -167,18 +185,17 @@ def QueryAffected(self, request, context: grpc.ServicerContext): 'ecosystem': ecosystem, 'versioned': versioned == 'versioned' } - }, - **logging_trace + } }) else: - logging.info('QueryAffected for %s', qtype, extra=logging_trace) + logging.info('QueryAffected for %s', qtype) page_token = None if request.query.page_token: try: page_token = ndb.Cursor(urlsafe=request.query.page_token) except ValueError as e: - logging.warning(e, extra=logging_trace) + logging.warning(e) context.abort(grpc.StatusCode.INVALID_ARGUMENT, 'Invalid page token.') query_context = QueryContext( @@ -205,12 +222,12 @@ def QueryAffected(self, request, context: grpc.ServicerContext): return None @ndb_context + @trace_filter.log_trace def QueryAffectedBatch(self, request, context: grpc.ServicerContext): """Query vulnerabilities (batch).""" batch_results = [] futures = [] - logging_trace = trace_log_fields(context) # Log some information about the query with structured logging e.g. # "message": "QueryAffectedBatch with 15 queries", # "details": { @@ -251,12 +268,9 @@ def QueryAffectedBatch(self, request, context: grpc.ServicerContext): logging.info( 'QueryAffectedBatch with %d queries', len(request.query.queries), - extra={ - 'json_fields': { - 'details': query_details - }, - **logging_trace - }) + extra={'json_fields': { + 'details': query_details + }}) if len(request.query.queries) > _MAX_BATCH_QUERY: context.abort(grpc.StatusCode.INVALID_ARGUMENT, 'Too many queries.') @@ -270,7 +284,7 @@ def QueryAffectedBatch(self, request, context: grpc.ServicerContext): try: page_token = ndb.Cursor(urlsafe=query.page_token) except ValueError as e: - logging.warning(e, extra=logging_trace) + logging.warning(e) context.abort(grpc.StatusCode.INVALID_ARGUMENT, f'Invalid page token at index: {i}.') query_context = QueryContext( @@ -300,6 +314,7 @@ def QueryAffectedBatch(self, request, context: grpc.ServicerContext): return osv_service_v1_pb2.BatchVulnerabilityList(results=batch_results) @ndb_context + @trace_filter.log_trace def DetermineVersion(self, request, context: grpc.ServicerContext): """Determine the version of the provided hashes.""" res = determine_version(request.query, context).result() @@ -710,10 +725,7 @@ def to_response(b): if next_page_token: next_page_token = next_page_token.urlsafe() - logging.warning( - 'Page size limit hit, response size: %s', - len(bugs), - extra=trace_log_fields(context.service_context)) + logging.warning('Page size limit hit, response size: %s', len(bugs)) return bugs, next_page_token @@ -1057,9 +1069,7 @@ def query_by_version(context: QueryContext, context, query, package_name, ecosystem, purl, version) else: - logging.warning( - "Package query without ecosystem specified", - extra=trace_log_fields(context.service_context)) + logging.warning("Package query without ecosystem specified") # Unspecified ecosystem. Try semver first. # TODO: Remove after testing how many consumers are @@ -1172,6 +1182,7 @@ def main(): """Entrypoint.""" if is_cloud_run(): setup_gcp_logging('api-backend') + logging.getLogger().addFilter(trace_filter) logging.getLogger().setLevel(logging.INFO)