Skip to content

Commit

Permalink
Add traces to logs through logging.Filter (#2067)
Browse files Browse the repository at this point in the history
Added a new class/decorator to automatically insert traces into logs in
API.
Re:
#2055 (review)
  • Loading branch information
michaelkedar authored Mar 20, 2024
1 parent f658b6d commit 70da8e8
Showing 1 changed file with 56 additions and 45 deletions.
101 changes: 56 additions & 45 deletions gcp/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import functools
import logging
import os
import threading
import time
from typing import Callable, List

Expand Down Expand Up @@ -102,37 +103,55 @@ def wrapper(*args, **kwargs):
return wrapper


def trace_log_fields(context: grpc.ServicerContext) -> dict:
"""Makes the json_field needed to associate a log with the request trace."""
fields = {}
trace_context = dict(
context.invocation_metadata()).get('x-cloud-trace-context')
if trace_context is None:
return fields

# Trace context header example:
# "X-Cloud-Trace-Context: TRACE_ID/SPAN_ID;o=TRACE_TRUE"
parts = trace_context.split('/')
trace_id = parts[0]
# We don't set the GOOGLE_CLOUD_PROJECT env var explicitly, and I can't find
# any confirmation on whether Cloud Run will set automatically.
# Grab the project name from the (undocumented?) field on ndb.Client().
# The most correct way to do this would be to use the instance metadata server
# https://cloud.google.com/run/docs/container-contract#metadata-server
project = getattr(_ndb_client, 'project', 'oss-vdb') # fall back to oss-vdb
fields['trace'] = f'projects/{project}/traces/{trace_id}'
if len(parts) > 1:
span_id = parts[1].split(';')[0]
fields['span_id'] = span_id

return fields
class LogTraceFilter:
"""Class for adding the trace information from the grpc requests into logs."""

def __init__(self):
self.thread_local = threading.local()

def log_trace(self, func):
"""Wrapper for grpc method to capture trace from header metadata"""

@functools.wraps(func)
def wrapper(s, r, context: grpc.ServicerContext):
self.thread_local.trace = dict(
context.invocation_metadata()).get('x-cloud-trace-context')
return func(s, r, context)

return wrapper

def filter(self, record: logging.LogRecord) -> bool:
"""logging.Filter method to add trace into log data."""
trace = getattr(self.thread_local, 'trace', None)
if not trace:
return True

# Trace context header example:
# "X-Cloud-Trace-Context: TRACE_ID/SPAN_ID;o=TRACE_TRUE"
parts = trace.split('/')
trace_id = parts[0]
# We don't set the GOOGLE_CLOUD_PROJECT env var explicitly, and I can't find
# any confirmation on whether Cloud Run will set automatically.
# Grab the project name from the (undocumented?) field on ndb.Client().
# Most correct way to do this would be to use the instance metadata server
# https://cloud.google.com/run/docs/container-contract#metadata-server
project = getattr(_ndb_client, 'project', 'oss-vdb') # fall back to oss-vdb
record.trace = f'projects/{project}/traces/{trace_id}'
if len(parts) > 1:
record.span_id = parts[1].split(';')[0]

return True


trace_filter = LogTraceFilter()


class OSVServicer(osv_service_v1_pb2_grpc.OSVServicer,
health_pb2_grpc.HealthServicer):
"""V1 OSV servicer."""

@ndb_context
@trace_filter.log_trace
def GetVulnById(self, request, context: grpc.ServicerContext):
"""Return a `Vulnerability` object for a given OSV ID."""
bug: osv.Bug = osv.Bug.get_by_id(request.id)
Expand All @@ -147,14 +166,13 @@ def GetVulnById(self, request, context: grpc.ServicerContext):
return bug_to_response(bug, include_alias=True)

@ndb_context
@trace_filter.log_trace
def QueryAffected(self, request, context: grpc.ServicerContext):
"""Query vulnerabilities for a particular project at a given commit or
version.
"""

# Log some information about the query with structured logging
logging_trace = trace_log_fields(context)
qtype, ecosystem, versioned = query_info(request.query)
if ecosystem is not None:
logging.info(
Expand All @@ -167,18 +185,17 @@ def QueryAffected(self, request, context: grpc.ServicerContext):
'ecosystem': ecosystem,
'versioned': versioned == 'versioned'
}
},
**logging_trace
}
})
else:
logging.info('QueryAffected for %s', qtype, extra=logging_trace)
logging.info('QueryAffected for %s', qtype)

page_token = None
if request.query.page_token:
try:
page_token = ndb.Cursor(urlsafe=request.query.page_token)
except ValueError as e:
logging.warning(e, extra=logging_trace)
logging.warning(e)
context.abort(grpc.StatusCode.INVALID_ARGUMENT, 'Invalid page token.')

query_context = QueryContext(
Expand All @@ -205,12 +222,12 @@ def QueryAffected(self, request, context: grpc.ServicerContext):
return None

@ndb_context
@trace_filter.log_trace
def QueryAffectedBatch(self, request, context: grpc.ServicerContext):
"""Query vulnerabilities (batch)."""
batch_results = []
futures = []

logging_trace = trace_log_fields(context)
# Log some information about the query with structured logging e.g.
# "message": "QueryAffectedBatch with 15 queries",
# "details": {
Expand Down Expand Up @@ -251,12 +268,9 @@ def QueryAffectedBatch(self, request, context: grpc.ServicerContext):
logging.info(
'QueryAffectedBatch with %d queries',
len(request.query.queries),
extra={
'json_fields': {
'details': query_details
},
**logging_trace
})
extra={'json_fields': {
'details': query_details
}})

if len(request.query.queries) > _MAX_BATCH_QUERY:
context.abort(grpc.StatusCode.INVALID_ARGUMENT, 'Too many queries.')
Expand All @@ -270,7 +284,7 @@ def QueryAffectedBatch(self, request, context: grpc.ServicerContext):
try:
page_token = ndb.Cursor(urlsafe=query.page_token)
except ValueError as e:
logging.warning(e, extra=logging_trace)
logging.warning(e)
context.abort(grpc.StatusCode.INVALID_ARGUMENT,
f'Invalid page token at index: {i}.')
query_context = QueryContext(
Expand Down Expand Up @@ -300,6 +314,7 @@ def QueryAffectedBatch(self, request, context: grpc.ServicerContext):
return osv_service_v1_pb2.BatchVulnerabilityList(results=batch_results)

@ndb_context
@trace_filter.log_trace
def DetermineVersion(self, request, context: grpc.ServicerContext):
"""Determine the version of the provided hashes."""
res = determine_version(request.query, context).result()
Expand Down Expand Up @@ -710,10 +725,7 @@ def to_response(b):

if next_page_token:
next_page_token = next_page_token.urlsafe()
logging.warning(
'Page size limit hit, response size: %s',
len(bugs),
extra=trace_log_fields(context.service_context))
logging.warning('Page size limit hit, response size: %s', len(bugs))

return bugs, next_page_token

Expand Down Expand Up @@ -1057,9 +1069,7 @@ def query_by_version(context: QueryContext,
context, query, package_name, ecosystem, purl, version)

else:
logging.warning(
"Package query without ecosystem specified",
extra=trace_log_fields(context.service_context))
logging.warning("Package query without ecosystem specified")
# Unspecified ecosystem. Try semver first.

# TODO: Remove after testing how many consumers are
Expand Down Expand Up @@ -1172,6 +1182,7 @@ def main():
"""Entrypoint."""
if is_cloud_run():
setup_gcp_logging('api-backend')
logging.getLogger().addFilter(trace_filter)

logging.getLogger().setLevel(logging.INFO)

Expand Down

0 comments on commit 70da8e8

Please sign in to comment.