diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
index f42af328e..bc4a5324a 100644
--- a/.devcontainer/Dockerfile
+++ b/.devcontainer/Dockerfile
@@ -1,4 +1,22 @@
-ARG IMAGE=ghcr.io/newrelic-experimental/pyenv-devcontainer:latest
-
 # To target other architectures, change the --platform directive in the Dockerfile.
-FROM --platform=linux/amd64 ${IMAGE}
+ARG IMAGE_TAG=latest
+FROM ghcr.io/newrelic/newrelic-python-agent-ci:${IMAGE_TAG}
+
+# Setup non-root user
+USER root
+ARG UID=1000
+ARG GID=$UID
+ENV HOME /home/vscode
+RUN mkdir -p ${HOME} && \
+    groupadd --gid ${GID} vscode && \
+    useradd --uid ${UID} --gid ${GID} --home ${HOME} vscode && \
+    chown -R ${UID}:${GID} /home/vscode
+
+# Move pyenv installation
+ENV PYENV_ROOT="${HOME}/.pyenv"
+ENV PATH="$PYENV_ROOT/bin:$PYENV_ROOT/shims:${PATH}"
+RUN mv /root/.pyenv /home/vscode/.pyenv && \
+    chown -R vscode:vscode /home/vscode/.pyenv
+
+# Set user
+USER ${UID}:${GID}
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 92a8cdee4..fbefff476 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -5,7 +5,7 @@
         // To target other architectures, change the --platform directive in the Dockerfile.
         "dockerfile": "Dockerfile",
         "args": {
-            "IMAGE": "ghcr.io/newrelic-experimental/pyenv-devcontainer:latest"
+            "IMAGE_TAG": "latest"
         }
     },
     "remoteUser": "vscode",
diff --git a/.github/actions/setup-python-matrix/action.yml b/.github/actions/setup-python-matrix/action.yml
deleted file mode 100644
index 299dd2e7b..000000000
--- a/.github/actions/setup-python-matrix/action.yml
+++ /dev/null
@@ -1,50 +0,0 @@
-name: "setup-python-matrix"
-description: "Sets up all versions of python required for matrix testing in this repo."
-runs:
-  using: "composite"
-  steps:
-    - uses: actions/setup-python@v4
-      with:
-        python-version: "pypy-3.7"
-        architecture: x64
-
-    - uses: actions/setup-python@v4
-      with:
-        python-version: "pypy-2.7"
-        architecture: x64
-
-    - uses: actions/setup-python@v4
-      with:
-        python-version: "3.7"
-        architecture: x64
-
-    - uses: actions/setup-python@v4
-      with:
-        python-version: "3.8"
-        architecture: x64
-
-    - uses: actions/setup-python@v4
-      with:
-        python-version: "3.9"
-        architecture: x64
-
-    - uses: actions/setup-python@v4
-      with:
-        python-version: "3.10"
-        architecture: x64
-
-    - uses: actions/setup-python@v4
-      with:
-        python-version: "3.11"
-        architecture: x64
-
-    - uses: actions/setup-python@v4
-      with:
-        python-version: "2.7"
-        architecture: x64
-
-    - name: Install Dependencies
-      shell: bash
-      run: |
-        python3.10 -m pip install -U pip
-        python3.10 -m pip install -U wheel setuptools tox 'virtualenv<20.22.0'
diff --git a/.github/actions/update-rpm-config/action.yml b/.github/actions/update-rpm-config/action.yml
new file mode 100644
index 000000000..9d19ebba0
--- /dev/null
+++ b/.github/actions/update-rpm-config/action.yml
@@ -0,0 +1,109 @@
+name: "update-rpm-config"
+description: "Set current version of agent in rpm config using API."
+inputs:
+  agent-language:
+    description: "Language agent to configure (eg. python)"
+    required: true
+    default: "python"
+  target-system:
+    description: "Target System: prod|staging|all"
+    required: true
+    default: "all"
+  agent-version:
+    description: "3-4 digit agent version number (eg. 1.2.3) with optional leading v (ignored)"
+    required: true
+  dry-run:
+    description: "Dry Run"
+    required: true
+    default: "false"
+  production-api-key:
+    description: "API key for New Relic Production"
+    required: false
+  staging-api-key:
+    description: "API key for New Relic Staging"
+    required: false
+
+runs:
+  using: "composite"
+  steps:
+    - name: Trim potential leading v from agent version
+      shell: bash
+      run: |
+        AGENT_VERSION=${{ inputs.agent-version }}
+        echo "AGENT_VERSION=${AGENT_VERSION#"v"}" >> $GITHUB_ENV
+
+    - name: Generate Payload
+      shell: bash
+      run: |
+        echo "PAYLOAD='{ \"system_configuration\": { \"key\": \"${{ inputs.agent-language }}_agent_version\", \"value\": \"${{ env.AGENT_VERSION }}\" } }'" >> $GITHUB_ENV
+
+    - name: Generate Content-Type
+      shell: bash
+      run: |
+        echo "CONTENT_TYPE='Content-Type: application/json'" >> $GITHUB_ENV
+
+    - name: Update Staging system configuration page
+      shell: bash
+      if: ${{ inputs.dry-run == 'false' && (inputs.target-system == 'staging' || inputs.target-system == 'all') }}
+      run: |
+        curl -X POST 'https://staging-api.newrelic.com/v2/system_configuration.json' \
+          -H "X-Api-Key:${{ inputs.staging-api-key }}" -i \
+          -H ${{ env.CONTENT_TYPE }} \
+          -d ${{ env.PAYLOAD }}
+
+    - name: Update Production system configuration page
+      shell: bash
+      if: ${{ inputs.dry-run == 'false' && (inputs.target-system == 'prod' || inputs.target-system == 'all') }}
+      run: |
+        curl -X POST 'https://api.newrelic.com/v2/system_configuration.json' \
+          -H "X-Api-Key:${{ inputs.production-api-key }}" -i \
+          -H ${{ env.CONTENT_TYPE }} \
+          -d ${{ env.PAYLOAD }}
+
+    - name: Verify Staging system configuration update
+      shell: bash
+      if: ${{ inputs.dry-run == 'false' && (inputs.target-system == 'staging' || inputs.target-system == 'all') }}
+      run: |
+        STAGING_VERSION=$(curl -X GET 'https://staging-api.newrelic.com/v2/system_configuration.json' \
+          -H "X-Api-Key:${{ inputs.staging-api-key }}" \
+          -H "${{ env.CONTENT_TYPE }}" | jq ".system_configurations | from_entries | .${{inputs.agent-language}}_agent_version")
+
+        if [ "${{ env.AGENT_VERSION }}" != "$STAGING_VERSION" ]; then
+          echo "Staging version mismatch: $STAGING_VERSION"
+          exit 1
+        fi
+
+    - name: Verify Production system configuration update
+      shell: bash
+      if: ${{ inputs.dry-run == 'false' && (inputs.target-system == 'prod' || inputs.target-system == 'all') }}
+      run: |
+        PROD_VERSION=$(curl -X GET 'https://api.newrelic.com/v2/system_configuration.json' \
+          -H "X-Api-Key:${{ inputs.production-api-key }}" \
+          -H "${{ env.CONTENT_TYPE }}" | jq ".system_configurations | from_entries | .${{inputs.agent-language}}_agent_version")
+
+        if [ "${{ env.AGENT_VERSION }}" != "$PROD_VERSION" ]; then
+          echo "Production version mismatch: $PROD_VERSION"
+          exit 1
+        fi
+
+    - name: (dry-run) Update Staging system configuration page
+      shell: bash
+      if: ${{ inputs.dry-run != 'false' && (inputs.target-system == 'staging' || inputs.target-system == 'all') }}
+      run: |
+        cat << EOF
+        curl -X POST 'https://staging-api.newrelic.com/v2/system_configuration.json' \
+          -H "X-Api-Key:**REDACTED**" -i \
+          -H ${{ env.CONTENT_TYPE }} \
+          -d ${{ env.PAYLOAD }}
+        EOF
+
+    - name: (dry-run) Update Production system configuration page
+      shell: bash
+      if: ${{ inputs.dry-run != 'false' && (inputs.target-system == 'prod' || inputs.target-system == 'all') }}
+      run: |
+        cat << EOF
+        curl -X POST 'https://api.newrelic.com/v2/system_configuration.json' \
+          -H "X-Api-Key:**REDACTED**" -i \
+          -H ${{ env.CONTENT_TYPE }} \
+          -d ${{ env.PAYLOAD }}
+        EOF
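+
+# Illustrative use from a workflow step (a sketch only; it mirrors the step
+# added to deploy-python.yml later in this diff):
+#
+#   - uses: ./.github/actions/update-rpm-config
+#     with:
+#       agent-language: "python"
+#       agent-version: "v1.2.3"
+#       dry-run: "true"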
diff --git a/.github/containers/Dockerfile b/.github/containers/Dockerfile
new file mode 100644
index 000000000..57d8c234c
--- /dev/null
+++ b/.github/containers/Dockerfile
@@ -0,0 +1,107 @@
+
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+FROM ubuntu:20.04
+
+# Install OS packages
+RUN export DEBIAN_FRONTEND=noninteractive && \
+    apt-get update && \
+    apt-get install -y \
+        bash \
+        build-essential \
+        curl \
+        expat \
+        fish \
+        fontconfig \
+        freetds-common \
+        freetds-dev \
+        gcc \
+        git \
+        libbz2-dev \
+        libcurl4-openssl-dev \
+        libffi-dev \
+        libgmp-dev \
+        libkrb5-dev \
+        liblzma-dev \
+        libmpfr-dev \
+        libncurses-dev \
+        libpq-dev \
+        libreadline-dev \
+        libsqlite3-dev \
+        libssl-dev \
+        locales \
+        make \
+        odbc-postgresql \
+        openssl \
+        python2-dev \
+        python3-dev \
+        python3-pip \
+        sudo \
+        tzdata \
+        unixodbc-dev \
+        unzip \
+        vim \
+        wget \
+        zip \
+        zlib1g \
+        zlib1g-dev \
+        zsh && \
+    rm -rf /var/lib/apt/lists/*
+
+# Build librdkafka from source
+ARG LIBRDKAFKA_VERSION=2.1.1
+RUN cd /tmp && \
+    wget https://github.com/confluentinc/librdkafka/archive/refs/tags/v${LIBRDKAFKA_VERSION}.zip -O ./librdkafka.zip && \
+    unzip ./librdkafka.zip && \
+    rm ./librdkafka.zip && \
+    cd ./librdkafka-${LIBRDKAFKA_VERSION} && \
+    ./configure && \
+    make all install && \
+    cd /tmp && \
+    rm -rf ./librdkafka-${LIBRDKAFKA_VERSION}
+
+# Setup ODBC config
+RUN sed -i 's|Driver=psqlodbca.so|Driver=/usr/lib/x86_64-linux-gnu/odbc/psqlodbca.so|g' /etc/odbcinst.ini && \
+    sed -i 's|Driver=psqlodbcw.so|Driver=/usr/lib/x86_64-linux-gnu/odbc/psqlodbcw.so|g' /etc/odbcinst.ini && \
+    sed -i 's|Setup=libodbcpsqlS.so|Setup=/usr/lib/x86_64-linux-gnu/odbc/libodbcpsqlS.so|g' /etc/odbcinst.ini
+
+# Set the locale
+RUN locale-gen --no-purge en_US.UTF-8
+ENV LANG=en_US.UTF-8 \
+    LANGUAGE=en_US:en \
+    LC_ALL=en_US.UTF-8
+ENV TZ="Etc/UTC"
+RUN ln -fs "/usr/share/zoneinfo/${TZ}" /etc/localtime && \
+    dpkg-reconfigure -f noninteractive tzdata
+
+# Use root user
+ENV HOME /root
+WORKDIR "${HOME}"
+
+# Install pyenv
+ENV PYENV_ROOT="${HOME}/.pyenv"
+RUN curl https://pyenv.run/ | /bin/bash
+ENV PATH="$PYENV_ROOT/bin:$PYENV_ROOT/shims:${PATH}"
+RUN echo 'eval "$(pyenv init -)"' >>$HOME/.bashrc && \
+    pyenv update
+
+# Install Python
+ARG PYTHON_VERSIONS="3.11 3.10 3.9 3.8 3.7 3.12 2.7 pypy2.7-7.3.12 pypy3.8-7.3.11"
+COPY --chown=1000:1000 --chmod=+x ./install-python.sh /tmp/install-python.sh
+RUN /tmp/install-python.sh && \
+    rm /tmp/install-python.sh
+
+# Install dependencies for main python installation
+COPY ./requirements.txt /tmp/requirements.txt
+RUN pyenv exec pip install --upgrade -r /tmp/requirements.txt && \
+    rm /tmp/requirements.txt
\ No newline at end of file
diff --git a/.github/containers/Makefile b/.github/containers/Makefile
new file mode 100644
index 000000000..97b4e7256
--- /dev/null
+++ b/.github/containers/Makefile
@@ -0,0 +1,71 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Override constants
+PLATFORM_OVERRIDE:=
+PYTHON_VERSIONS_OVERRIDE:=
+
+# Computed variables
+IMAGE_NAME:=ghcr.io/newrelic/newrelic-python-agent-ci
+MAKEFILE_DIR:=$(dir $(realpath $(firstword ${MAKEFILE_LIST})))
+REPO_ROOT:=$(realpath ${MAKEFILE_DIR}../../)
+UNAME_P:=$(shell uname -p)
+PLATFORM_AUTOMATIC:=$(if $(findstring arm,${UNAME_P}),linux/arm64,linux/amd64)
+PLATFORM:=$(if ${PLATFORM_OVERRIDE},${PLATFORM_OVERRIDE},${PLATFORM_AUTOMATIC})
+PYTHON_VERSIONS_AUTOMATIC:=3.10 2.7
+PYTHON_VERSIONS:=$(if ${PYTHON_VERSIONS_OVERRIDE},${PYTHON_VERSIONS_OVERRIDE},${PYTHON_VERSIONS_AUTOMATIC})
+
+.PHONY: default
+default: test
+
+.PHONY: build
+build:
+	@docker build ${MAKEFILE_DIR} \
+		--platform=${PLATFORM} \
+		-t ${IMAGE_NAME}:local \
+		--build-arg='PYTHON_VERSIONS=${PYTHON_VERSIONS}'
+
+# Run the local tag as a container.
+.PHONY: run
+run: run.local
+
+# Run a specific tag as a container.
+# Usage: make run.<tag>
+# Defaults to run.local, but can instead be run.latest or any other tag.
+.PHONY: run.%
+run.%:
+# Build image if local was specified, else pull latest
+	@if [[ "$*" = "local" ]]; then cd ${MAKEFILE_DIR} && $(MAKE) build; else docker pull ${IMAGE_NAME}:$*; fi
+	@docker run --rm -it \
+		--platform=${PLATFORM} \
+		--mount type=bind,source="${REPO_ROOT}",target=/home/github/python-agent \
+		--workdir=/home/github/python-agent \
+		--add-host=host.docker.internal:host-gateway \
+		-e NEW_RELIC_HOST="${NEW_RELIC_HOST}" \
+		-e NEW_RELIC_LICENSE_KEY="${NEW_RELIC_LICENSE_KEY}" \
+		-e NEW_RELIC_DEVELOPER_MODE="${NEW_RELIC_DEVELOPER_MODE}" \
+		-e GITHUB_ACTIONS="true" \
+		${IMAGE_NAME}:$* /bin/bash
+
+# Ensure python versions are usable. Cannot be used automatically with PYTHON_VERSIONS_OVERRIDE.
+.PHONY: test
+test: build
+	@docker run --rm \
+		--platform=${PLATFORM} \
+		${IMAGE_NAME}:local \
+		/bin/bash -c '\
+			python3.10 --version && \
+			python2.7 --version && \
+			touch tox.ini && tox --version && \
+			echo "Success! Python versions installed."'
diff --git a/.github/containers/install-python.sh b/.github/containers/install-python.sh
new file mode 100755
index 000000000..f9da0a003
--- /dev/null
+++ b/.github/containers/install-python.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
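+#
+# Illustrative usage (a sketch, not part of the original script): pyenv must be
+# on PATH and PYTHON_VERSIONS must hold a space-separated version list, as the
+# Dockerfile above arranges:
+#
+#   PYTHON_VERSIONS="3.10 2.7" ./install-python.sh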
+
+set -eo pipefail
+
+main() {
+    # Coerce space separated string to array
+    if [[ ${#PYTHON_VERSIONS[@]} -eq 1 ]]; then
+        PYTHON_VERSIONS=($PYTHON_VERSIONS)
+    fi
+
+    if [[ -z "${PYTHON_VERSIONS[@]}" ]]; then
+        echo "No python versions specified. Make sure PYTHON_VERSIONS is set." 1>&2
+        exit 1
+    fi
+
+    # Find all latest pyenv supported versions for requested python versions
+    PYENV_VERSIONS=()
+    for v in "${PYTHON_VERSIONS[@]}"; do
+        LATEST=$(pyenv latest -k "$v" || pyenv latest -k "$v-dev")
+        if [[ -z "$LATEST" ]]; then
+            echo "Latest version could not be found for ${v}." 1>&2
+            exit 1
+        fi
+        PYENV_VERSIONS+=($LATEST)
+    done
+
+    # Install each specific version
+    for v in "${PYENV_VERSIONS[@]}"; do
+        pyenv install "$v" &
+    done
+    wait
+
+    # Set all installed versions as globally accessible
+    pyenv global ${PYENV_VERSIONS[@]}
+}
+
+main
diff --git a/.github/containers/requirements.txt b/.github/containers/requirements.txt
new file mode 100644
index 000000000..68bdfe4fe
--- /dev/null
+++ b/.github/containers/requirements.txt
@@ -0,0 +1,9 @@
+bandit
+black
+flake8
+isort
+pip
+setuptools
+tox
+virtualenv<20.22.0
+wheel
\ No newline at end of file
diff --git a/.github/scripts/retry.sh b/.github/scripts/retry.sh
index 1cb17836e..079798a72 100755
--- a/.github/scripts/retry.sh
+++ b/.github/scripts/retry.sh
@@ -1,4 +1,18 @@
 #!/bin/bash
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Time in seconds to backoff after the initial attempt.
 INITIAL_BACKOFF=10
 
@@ -25,4 +39,4 @@ for i in $(seq 1 $retries); do
 done
 
 # Exit with status code of wrapped command
-exit $?
+exit $result
diff --git a/.github/stale.yml b/.github/stale.yml
index 9d84541db..39e994219 100644
--- a/.github/stale.yml
+++ b/.github/stale.yml
@@ -13,7 +13,7 @@
 # limitations under the License.
 #
 # Number of days of inactivity before an issue becomes stale
-daysUntilStale: 60
+daysUntilStale: 365
 # Number of days of inactivity before a stale issue is closed
 # Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale.
 daysUntilClose: false
diff --git a/.github/workflows/build-ci-image.yml b/.github/workflows/build-ci-image.yml
new file mode 100644
index 000000000..9d60cea8e
--- /dev/null
+++ b/.github/workflows/build-ci-image.yml
@@ -0,0 +1,68 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
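+#
+# Illustrative trigger for this manual-only workflow via the GitHub CLI
+# (a sketch; any tool that fires workflow_dispatch works):
+#
+#   gh workflow run build-ci-image.yml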
+
+name: Build CI Image
+
+on:
+  workflow_dispatch: # Allow manual trigger
+
+concurrency:
+  group: ${{ github.ref || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          persist-credentials: false
+          fetch-depth: 0
+
+      - name: Set up Docker Buildx
+        id: buildx
+        uses: docker/setup-buildx-action@v2
+
+      - name: Generate Docker Metadata (Tags and Labels)
+        id: meta
+        uses: docker/metadata-action@v4
+        with:
+          images: ghcr.io/${{ github.repository }}-ci
+          flavor: |
+            prefix=
+            suffix=
+            latest=false
+          tags: |
+            type=raw,value=latest,enable={{is_default_branch}}
+            type=schedule,pattern={{date 'YYYY-MM-DD'}}
+            type=sha,format=short,prefix=sha-
+            type=sha,format=long,prefix=sha-
+
+      - name: Login to GitHub Container Registry
+        if: github.event_name != 'pull_request'
+        uses: docker/login-action@v2
+        with:
+          registry: ghcr.io
+          username: ${{ github.repository_owner }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Build and Publish Image
+        uses: docker/build-push-action@v3
+        with:
+          push: ${{ github.event_name != 'pull_request' }}
+          context: .github/containers
+          platforms: ${{ (format('refs/heads/{0}', github.event.repository.default_branch) == github.ref) && 'linux/amd64,linux/arm64' || 'linux/amd64' }}
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}
diff --git a/.github/workflows/deploy-python.yml b/.github/workflows/deploy-python.yml
index fe16ee485..ca908b825 100644
--- a/.github/workflows/deploy-python.yml
+++ b/.github/workflows/deploy-python.yml
@@ -80,3 +80,13 @@ jobs:
         env:
           TWINE_USERNAME: __token__
           TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
+
+      - name: Update RPM Config
+        uses: ./.github/actions/update-rpm-config
+        with:
+          agent-language: "python"
+          target-system: "all"
+          agent-version: "${{ github.ref_name }}"
+          dry-run: "false"
+          production-api-key: ${{ secrets.NEW_RELIC_API_KEY_PRODUCTION }}
+          staging-api-key: ${{ secrets.NEW_RELIC_API_KEY_STAGING }}
diff --git a/.github/workflows/get-envs.py b/.github/workflows/get-envs.py
index 576cbeb5c..4fcba6aa7 100755
--- a/.github/workflows/get-envs.py
+++ b/.github/workflows/get-envs.py
@@ -1,4 +1,18 @@
 #!/usr/bin/env python3.8
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
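+#
+# Illustrative invocation (a sketch; the GROUP_NUMBER/TOTAL_GROUPS semantics
+# are assumed from how tests.yml below drives this script):
+#
+#   tox -l | grep '^postgres-' | GROUP_NUMBER=1 TOTAL_GROUPS=2 ./get-envs.py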
+
 import fileinput
 import os
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index dc73168eb..402d0c629 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -38,14 +38,15 @@ jobs:
       - elasticsearchserver08
       - gearman
       - grpc
-      #- kafka
-      - libcurl
+      - kafka
       - memcached
       - mongodb
+      - mssql
      - mysql
       - postgres
       - rabbitmq
       - redis
+      - rediscluster
       - solr
 
     steps:
@@ -117,16 +118,24 @@ jobs:
       ]
 
     runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
     timeout-minutes: 30
 
     steps:
      - uses: actions/checkout@v3
-      - uses: ./.github/actions/setup-python-matrix
+
+      - name: Fetch git tags
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
 
       - name: Get Environments
         id: get-envs
         run: |
-          echo "::set-output name=envs::$(tox -l | grep "^${{ github.job }}\-" | ./.github/workflows/get-envs.py)"
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
         env:
           GROUP_NUMBER: ${{ matrix.group-number }}
 
@@ -154,16 +163,24 @@ jobs:
       group-number: [1]
 
     runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
     timeout-minutes: 30
 
     steps:
       - uses: actions/checkout@v3
-      - uses: ./.github/actions/setup-python-matrix
+
+      - name: Fetch git tags
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
 
       - name: Get Environments
         id: get-envs
         run: |
-          echo "::set-output name=envs::$(tox -l | grep "^${{ github.job }}\-" | ./.github/workflows/get-envs.py)"
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
         env:
           GROUP_NUMBER: ${{ matrix.group-number }}
 
@@ -181,32 +198,49 @@ jobs:
         path: ./**/.coverage.*
         retention-days: 1
 
-  libcurl:
+  postgres:
     env:
-      TOTAL_GROUPS: 1
+      TOTAL_GROUPS: 2
 
     strategy:
       fail-fast: false
       matrix:
-        group-number: [1]
+        group-number: [1, 2]
 
     runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
     timeout-minutes: 30
 
+    services:
+      postgres:
+        image: postgres:9
+        env:
+          POSTGRES_PASSWORD: postgres
+        ports:
+          - 8080:5432
+          - 8081:5432
+        # Set health checks to wait until postgres has started
+        options: >-
+          --health-cmd pg_isready
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+
     steps:
       - uses: actions/checkout@v3
-      - uses: ./.github/actions/setup-python-matrix
 
-      # Special case packages
-      - name: Install libcurl-dev
-        run: |
-          sudo apt-get update
-          sudo apt-get install libcurl4-openssl-dev
+      - name: Fetch git tags
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
 
       - name: Get Environments
         id: get-envs
         run: |
-          echo "::set-output name=envs::$(tox -l | grep "^${{ github.job }}\-" | ./.github/workflows/get-envs.py)"
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
         env:
           GROUP_NUMBER: ${{ matrix.group-number }}
 
@@ -224,49 +258,52 @@ jobs:
         path: ./**/.coverage.*
         retention-days: 1
 
-  postgres:
+  mssql:
     env:
-      TOTAL_GROUPS: 2
+      TOTAL_GROUPS: 1
 
     strategy:
       fail-fast: false
       matrix:
-        group-number: [1, 2]
+        group-number: [1]
 
     runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
     timeout-minutes: 30
 
     services:
-      postgres:
-        image: postgres:9
+      mssql:
+        image: mcr.microsoft.com/azure-sql-edge:latest
         env:
-          POSTGRES_PASSWORD: postgres
+          MSSQL_USER: python_agent
+          MSSQL_PASSWORD: python_agent
+          MSSQL_SA_PASSWORD: "python_agent#1234"
+          ACCEPT_EULA: "Y"
         ports:
-          - 8080:5432
-          - 8081:5432
-        # Set health checks to wait until postgres has started
+          - 8080:1433
+          - 8081:1433
+        # Set health checks to wait until mssql has started
         options: >-
-          --health-cmd pg_isready
+          --health-cmd "/opt/mssql-tools/bin/sqlcmd -U SA -P $MSSQL_SA_PASSWORD -Q 'SELECT 1'"
           --health-interval 10s
           --health-timeout 5s
           --health-retries 5
 
     steps:
       - uses: actions/checkout@v3
-      - uses: ./.github/actions/setup-python-matrix
-      - name: Install odbc driver for postgresql
+
+      - name: Fetch git tags
         run: |
-          sudo apt-get update
-          sudo sudo apt-get install odbc-postgresql
-          sudo sed -i 's/Driver=psqlodbca.so/Driver=\/usr\/lib\/x86_64-linux-gnu\/odbc\/psqlodbca.so/g' /etc/odbcinst.ini
-          sudo sed -i 's/Driver=psqlodbcw.so/Driver=\/usr\/lib\/x86_64-linux-gnu\/odbc\/psqlodbcw.so/g' /etc/odbcinst.ini
-          sudo sed -i 's/Setup=libodbcpsqlS.so/Setup=\/usr\/lib\/x86_64-linux-gnu\/odbc\/libodbcpsqlS.so/g' /etc/odbcinst.ini
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
 
       - name: Get Environments
         id: get-envs
         run: |
-          echo "::set-output name=envs::$(tox -l | grep "^${{ github.job }}\-" | ./.github/workflows/get-envs.py)"
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
         env:
           GROUP_NUMBER: ${{ matrix.group-number }}
 
@@ -294,6 +331,10 @@ jobs:
         group-number: [1, 2]
 
     runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
     timeout-minutes: 30
 
     services:
@@ -316,12 +357,115 @@ jobs:
     steps:
       - uses: actions/checkout@v3
-      - uses: ./.github/actions/setup-python-matrix
+
+      - name: Fetch git tags
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
+
+      - name: Get Environments
+        id: get-envs
+        run: |
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
+        env:
+          GROUP_NUMBER: ${{ matrix.group-number }}
+
+      - name: Test
+        run: |
+          tox -vv -e ${{ steps.get-envs.outputs.envs }} -p auto
+        env:
+          TOX_PARALLEL_NO_SPINNER: 1
+          PY_COLORS: 0
+
+      - name: Upload Coverage Artifacts
+        uses: actions/upload-artifact@v3
+        with:
+          name: coverage-${{ github.job }}-${{ strategy.job-index }}
+          path: ./**/.coverage.*
+          retention-days: 1
+
+  rediscluster:
+    env:
+      TOTAL_GROUPS: 1
+
+    strategy:
+      fail-fast: false
+      matrix:
+        group-number: [1]
+
+    runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
+    timeout-minutes: 30
+
+    services:
+      redis1:
+        image: hmstepanek/redis-cluster-node:1.0.0
+        ports:
+          - 6379:6379
+          - 16379:16379
+        options: >-
+          --add-host=host.docker.internal:host-gateway
+
+      redis2:
+        image: hmstepanek/redis-cluster-node:1.0.0
+        ports:
+          - 6380:6379
+          - 16380:16379
+        options: >-
+          --add-host=host.docker.internal:host-gateway
+
+      redis3:
+        image: hmstepanek/redis-cluster-node:1.0.0
+        ports:
+          - 6381:6379
+          - 16381:16379
+        options: >-
+          --add-host=host.docker.internal:host-gateway
+
+      redis4:
+        image: hmstepanek/redis-cluster-node:1.0.0
+        ports:
+          - 6382:6379
+          - 16382:16379
+        options: >-
+          --add-host=host.docker.internal:host-gateway
+
+      redis5:
+        image: hmstepanek/redis-cluster-node:1.0.0
+        ports:
+          - 6383:6379
+          - 16383:16379
+        options: >-
+          --add-host=host.docker.internal:host-gateway
+
+      redis6:
+        image: hmstepanek/redis-cluster-node:1.0.0
+        ports:
+          - 6384:6379
+          - 16384:16379
+        options: >-
+          --add-host=host.docker.internal:host-gateway
+
+      cluster-setup:
+        image: hmstepanek/redis-cluster:1.0.0
+        options: >-
+          --add-host=host.docker.internal:host-gateway
+
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Fetch git tags
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
 
       - name: Get Environments
         id: get-envs
         run: |
-          echo "::set-output name=envs::$(tox -l | grep "^${{ github.job }}\-" | ./.github/workflows/get-envs.py)"
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
         env:
           GROUP_NUMBER: ${{ matrix.group-number }}
 
@@ -349,6 +493,10 @@ jobs:
         group-number: [1, 2]
 
     runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
     timeout-minutes: 30
 
     services:
@@ -366,12 +514,16 @@ jobs:
     steps:
       - uses: actions/checkout@v3
-      - uses: ./.github/actions/setup-python-matrix
+
+      - name: Fetch git tags
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
 
       - name: Get Environments
         id: get-envs
         run: |
-          echo "::set-output name=envs::$(tox -l | grep "^${{ github.job }}\-" | ./.github/workflows/get-envs.py)"
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
         env:
           GROUP_NUMBER: ${{ matrix.group-number }}
 
@@ -399,6 +551,10 @@ jobs:
         group-number: [1]
 
     runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
     timeout-minutes: 30
 
     services:
@@ -418,12 +574,16 @@ jobs:
     steps:
       - uses: actions/checkout@v3
-      - uses: ./.github/actions/setup-python-matrix
+
+      - name: Fetch git tags
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
 
       - name: Get Environments
         id: get-envs
         run: |
-          echo "::set-output name=envs::$(tox -l | grep "^${{ github.job }}\-" | ./.github/workflows/get-envs.py)"
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
         env:
           GROUP_NUMBER: ${{ matrix.group-number }}
 
@@ -451,6 +611,10 @@ jobs:
         group-number: [1, 2]
 
     runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
     timeout-minutes: 30
 
     services:
@@ -468,12 +632,16 @@ jobs:
     steps:
       - uses: actions/checkout@v3
-      - uses: ./.github/actions/setup-python-matrix
+
+      - name: Fetch git tags
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
 
       - name: Get Environments
         id: get-envs
         run: |
-          echo "::set-output name=envs::$(tox -l | grep "^${{ github.job }}\-" | ./.github/workflows/get-envs.py)"
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
         env:
           GROUP_NUMBER: ${{ matrix.group-number }}
 
@@ -501,6 +669,10 @@ jobs:
         group-number: [1]
 
     runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
     timeout-minutes: 30
 
     services:
@@ -519,12 +691,16 @@ jobs:
     steps:
       - uses: actions/checkout@v3
-      - uses: ./.github/actions/setup-python-matrix
+
+      - name: Fetch git tags
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
 
       - name: Get Environments
         id: get-envs
         run: |
-          echo "::set-output name=envs::$(tox -l | grep "^${{ github.job }}\-" | ./.github/workflows/get-envs.py)"
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
         env:
           GROUP_NUMBER: ${{ matrix.group-number }}
 
@@ -542,77 +718,75 @@ jobs:
         path: ./**/.coverage.*
         retention-days: 1
 
-  #kafka:
-  #  env:
-  #    TOTAL_GROUPS: 4
-
-  #  strategy:
-  #    fail-fast: false
-  #    matrix:
-  #      group-number: [1, 2, 3, 4]
-
-  #  runs-on: ubuntu-20.04
-  #  timeout-minutes: 30
-
-  #  services:
-  #    zookeeper:
-  #      image: bitnami/zookeeper:3.7
-  #      env:
-  #        ALLOW_ANONYMOUS_LOGIN: yes
-
-  #      ports:
-  #        - 2181:2181
-
-  #    kafka:
-  #      image: bitnami/kafka:3.2
-  #      ports:
-  #        - 8080:8080
-  #        - 8081:8081
-  #      env:
-  #        ALLOW_PLAINTEXT_LISTENER: yes
-  #        KAFKA_ZOOKEEPER_CONNECT: zookeeper:2181
-  #        KAFKA_CFG_AUTO_CREATE_TOPICS_ENABLE: true
-  #        KAFKA_CFG_LISTENERS: L1://:8080,L2://:8081
-  #        KAFKA_CFG_ADVERTISED_LISTENERS: L1://127.0.0.1:8080,L2://kafka:8081,
-  #        KAFKA_CFG_LISTENER_SECURITY_PROTOCOL_MAP: L1:PLAINTEXT,L2:PLAINTEXT
-  #        KAFKA_CFG_INTER_BROKER_LISTENER_NAME: L2
-
-  #  steps:
-  #    - uses: actions/checkout@v3
-  #    - uses: ./.github/actions/setup-python-matrix
-
-  #    # Special case packages
-  #    - name: Install librdkafka-dev
-  #      run: |
-  #        # Use lsb-release to find the codename of Ubuntu to use to install the correct library name
-  #        sudo apt-get update
-  #        sudo ln -fs /usr/share/zoneinfo/America/Los_Angeles /etc/localtime
-  #        sudo apt-get install -y wget gnupg2 software-properties-common
-  #        sudo wget -qO - https://packages.confluent.io/deb/7.2/archive.key | sudo apt-key add -
-  #        sudo add-apt-repository "deb https://packages.confluent.io/clients/deb $(lsb_release -cs) main"
-  #        sudo apt-get update
-  #        sudo apt-get install -y librdkafka-dev/$(lsb_release -c | cut -f 2)
-
-  #    - name: Get Environments
-  #      id: get-envs
-  #      run: |
-  #        echo "::set-output name=envs::$(tox -l | grep "^${{ github.job }}\-" | ./.github/workflows/get-envs.py)"
-  #      env:
-  #        GROUP_NUMBER: ${{ matrix.group-number }}
-
-  #    - name: Test
-  #      run: |
-  #        tox -vv -e ${{ steps.get-envs.outputs.envs }}
-  #      env:
-  #        TOX_PARALLEL_NO_SPINNER: 1
-  #        PY_COLORS: 0
-
-  #    - name: Upload Coverage Artifacts
-  #      uses: actions/upload-artifact@v3
-  #      with:
-  #        name: coverage-${{ github.job }}-${{ strategy.job-index }}
-  #        path: ./**/.coverage.*
-  #        retention-days: 1
+  kafka:
+    env:
+      TOTAL_GROUPS: 4
+
+    strategy:
+      fail-fast: false
+      matrix:
+        group-number: [1, 2, 3, 4]
+
+    runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
+    timeout-minutes: 30
+
+    services:
+      zookeeper:
+        image: bitnami/zookeeper:3.7
+        env:
+          ALLOW_ANONYMOUS_LOGIN: yes
+
+        ports:
+          - 2181:2181
+
+      kafka:
+        image: bitnami/kafka:3.2
+        ports:
+          - 8080:8080
+          - 8082:8082
+          - 8083:8083
+        env:
+          KAFKA_ENABLE_KRAFT: no
+          ALLOW_PLAINTEXT_LISTENER: yes
+          KAFKA_ZOOKEEPER_CONNECT: zookeeper:2181
+          KAFKA_CFG_AUTO_CREATE_TOPICS_ENABLE: true
+          KAFKA_CFG_LISTENERS: L1://:8082,L2://:8083,L3://:8080
+          KAFKA_CFG_ADVERTISED_LISTENERS: L1://host.docker.internal:8082,L2://host.docker.internal:8083,L3://kafka:8080
+          KAFKA_CFG_LISTENER_SECURITY_PROTOCOL_MAP: L1:PLAINTEXT,L2:PLAINTEXT,L3:PLAINTEXT
+          KAFKA_CFG_INTER_BROKER_LISTENER_NAME: L3
+
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Fetch git tags
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
+
+      - name: Get Environments
+        id: get-envs
+        run: |
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
+        env:
+          GROUP_NUMBER: ${{ matrix.group-number }}
+
+      - name: Test
+        run: |
+          tox -vv -e ${{ steps.get-envs.outputs.envs }} -p auto
+        env:
+          TOX_PARALLEL_NO_SPINNER: 1
+          PY_COLORS: 0
+
+      - name: Upload Coverage Artifacts
+        uses: actions/upload-artifact@v3
+        with:
+          name: coverage-${{ github.job }}-${{ strategy.job-index }}
+          path: ./**/.coverage.*
+          retention-days: 1
 
   mongodb:
     env:
@@ -624,6 +798,10 @@ jobs:
         group-number: [1]
 
     runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
     timeout-minutes: 30
 
     services:
@@ -641,12 +819,16 @@ jobs:
     steps:
       - uses: actions/checkout@v3
-      - uses: ./.github/actions/setup-python-matrix
+
+      - name: Fetch git tags
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
 
       - name: Get Environments
         id: get-envs
         run: |
-          echo "::set-output name=envs::$(tox -l | grep "^${{ github.job }}\-" | ./.github/workflows/get-envs.py)"
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
         env:
           GROUP_NUMBER: ${{ matrix.group-number }}
 
@@ -674,10 +856,14 @@ jobs:
         group-number: [1]
 
     runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
     timeout-minutes: 30
 
     services:
-      es07:
+      elasticsearch:
         image: elasticsearch:7.17.8
         env:
           "discovery.type": "single-node"
@@ -693,12 +879,16 @@ jobs:
     steps:
       - uses: actions/checkout@v3
-      - uses: ./.github/actions/setup-python-matrix
+
+      - name: Fetch git tags
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
 
       - name: Get Environments
         id: get-envs
         run: |
-          echo "::set-output name=envs::$(tox -l | grep "^${{ github.job }}\-" | ./.github/workflows/get-envs.py)"
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
         env:
           GROUP_NUMBER: ${{ matrix.group-number }}
 
@@ -726,10 +916,14 @@ jobs:
         group-number: [1]
 
     runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
     timeout-minutes: 30
 
     services:
-      es08:
+      elasticsearch:
         image: elasticsearch:8.6.0
         env:
           "xpack.security.enabled": "false"
@@ -746,12 +940,16 @@ jobs:
     steps:
       - uses: actions/checkout@v3
-      - uses: ./.github/actions/setup-python-matrix
+
+      - name: Fetch git tags
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
 
       - name: Get Environments
         id: get-envs
         run: |
-          echo "::set-output name=envs::$(tox -l | grep "^${{ github.job }}\-" | ./.github/workflows/get-envs.py)"
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
         env:
           GROUP_NUMBER: ${{ matrix.group-number }}
 
@@ -779,13 +977,17 @@ jobs:
         group-number: [1]
 
     runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
     timeout-minutes: 30
 
     services:
       gearman:
         image: artefactual/gearmand
         ports:
-          - 4730:4730
+          - 8080:4730
         # Set health checks to wait until gearman has started
         options: >-
           --health-cmd "(echo status ; sleep 0.1) | nc 127.0.0.1 4730 -w 1"
           --health-interval 10s
           --health-timeout 5s
           --health-retries 5
 
@@ -795,12 +997,81 @@ jobs:
     steps:
       - uses: actions/checkout@v3
-      - uses: ./.github/actions/setup-python-matrix
+
+      - name: Fetch git tags
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
+
+      - name: Get Environments
+        id: get-envs
+        run: |
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
+        env:
+          GROUP_NUMBER: ${{ matrix.group-number }}
+
+      - name: Test
+        run: |
+          tox -vv -e ${{ steps.get-envs.outputs.envs }} -p auto
+        env:
+          TOX_PARALLEL_NO_SPINNER: 1
+          PY_COLORS: 0
+
+      - name: Upload Coverage Artifacts
+        uses: actions/upload-artifact@v3
+        with:
+          name: coverage-${{ github.job }}-${{ strategy.job-index }}
+          path: ./**/.coverage.*
+          retention-days: 1
+
+  firestore:
+    env:
+      TOTAL_GROUPS: 1
+
+    strategy:
+      fail-fast: false
+      matrix:
+        group-number: [1]
+
+    runs-on: ubuntu-20.04
+    container:
+      image: ghcr.io/newrelic/newrelic-python-agent-ci:latest
+      options: >-
+        --add-host=host.docker.internal:host-gateway
+    timeout-minutes: 30
+
+    services:
+      firestore:
+        # Image set here MUST be repeated down below in options. See comment below.
+        image: gcr.io/google.com/cloudsdktool/google-cloud-cli:437.0.1-emulators
+        ports:
+          - 8080:8080
+        # Set health checks to wait 5 seconds in lieu of an actual healthcheck
+        options: >-
+          --health-cmd "echo success"
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+          --health-start-period 5s
+          gcr.io/google.com/cloudsdktool/google-cloud-cli:437.0.1-emulators /bin/bash -c "gcloud emulators firestore start --host-port=0.0.0.0:8080" ||
+        # This is a very hacky solution. GitHub Actions doesn't provide APIs for setting commands on services, but allows adding arbitrary options.
+        # --entrypoint won't work as it only accepts an executable and not the [] syntax.
+        # Instead, we specify the image again and the command afterwards like a call to docker create. The result is a few environment variables
+        # and the original command being appended to our hijacked docker create command. We can avoid any issues by adding || to prevent that
+        # from ever being executed as bash commands.
+
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Fetch git tags
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+          git fetch --tags origin
 
       - name: Get Environments
         id: get-envs
         run: |
-          echo "::set-output name=envs::$(tox -l | grep "^${{ github.job }}\-" | ./.github/workflows/get-envs.py)"
+          echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT
         env:
           GROUP_NUMBER: ${{ matrix.group-number }}
 
diff --git a/MANIFEST.in b/MANIFEST.in
index 0a75ce752..0759bce87 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -8,3 +8,5 @@ include newrelic/common/cacert.pem
 include newrelic/packages/wrapt/LICENSE
 include newrelic/packages/wrapt/README
 include newrelic/packages/urllib3/LICENSE.txt
+include newrelic/packages/isort/LICENSE
+include newrelic/packages/opentelemetry_proto/LICENSE.txt
diff --git a/THIRD_PARTY_NOTICES.md b/THIRD_PARTY_NOTICES.md
index 3662484f6..7c4242cc2 100644
--- a/THIRD_PARTY_NOTICES.md
+++ b/THIRD_PARTY_NOTICES.md
@@ -14,7 +14,25 @@ Copyright (c) Django Software Foundation and individual contributors.
 
 Distributed under the following license(s):
 
-  * [The BSD 3-Clause License](https://opensource.org/licenses/BSD-3-Clause)
+* [The BSD 3-Clause License](https://opensource.org/licenses/BSD-3-Clause)
+
+
+## [isort](https://pypi.org/project/isort)
+
+Copyright (c) 2013 Timothy Edmund Crosley
+
+Distributed under the following license(s):
+
+* [The MIT License](http://opensource.org/licenses/MIT)
+
+
+## [opentelemetry-proto](https://pypi.org/project/opentelemetry-proto)
+
+Copyright (c) The OpenTelemetry Authors
+
+Distributed under the following license(s):
+
+* [The Apache License, Version 2.0](https://opensource.org/license/apache-2-0/)
 
 
 ## [six](https://pypi.org/project/six)
@@ -23,7 +41,7 @@ Copyright (c) 2010-2013 Benjamin Peterson
 
 Distributed under the following license(s):
 
-  * [The MIT License](http://opensource.org/licenses/MIT)
+* [The MIT License](http://opensource.org/licenses/MIT)
 
 
 ## [time.monotonic](newrelic/common/_monotonic.c)
@@ -32,7 +50,7 @@ Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011,
 
 Distributed under the following license(s):
 
-  * [Python Software Foundation](https://docs.python.org/3/license.html)
+* [Python Software Foundation](https://docs.python.org/3/license.html)
 
 
 ## [urllib3](https://pypi.org/project/urllib3)
@@ -41,7 +59,7 @@ Copyright (c) 2008-2019 Andrey Petrov and contributors (see CONTRIBUTORS.txt)
 
 Distributed under the following license(s):
 
-  * [The MIT License](http://opensource.org/licenses/MIT)
+* [The MIT License](http://opensource.org/licenses/MIT)
 
 
 ## [wrapt](https://pypi.org/project/wrapt)
@@ -51,5 +69,5 @@ All rights reserved.
 
 Distributed under the following license(s):
 
-  * [The BSD 2-Clause License](http://opensource.org/licenses/BSD-2-Clause)
+* [The BSD 2-Clause License](http://opensource.org/licenses/BSD-2-Clause)
 
diff --git a/codecov.yml b/codecov.yml
index ec600226a..6ca30f640 100644
--- a/codecov.yml
+++ b/codecov.yml
@@ -1,21 +1,25 @@
 ignore:
-  - "newrelic/packages/**/*"
-  - "newrelic/packages/*"
-  - "newrelic/hooks/adapter_meinheld.py"
+  - "newrelic/hooks/component_sentry.py"
+  - "newrelic/admin/*"
+  - "newrelic/console.py"
   - "newrelic/hooks/adapter_flup.py"
+  - "newrelic/hooks/adapter_meinheld.py"
+  - "newrelic/hooks/adapter_paste.py"
   - "newrelic/hooks/component_piston.py"
+  - "newrelic/hooks/database_oursql.py"
+  - "newrelic/hooks/database_psycopg2ct.py"
+  - "newrelic/hooks/datastore_aioredis.py"
+  - "newrelic/hooks/datastore_aredis.py"
+  - "newrelic/hooks/datastore_motor.py"
   - "newrelic/hooks/datastore_pyelasticsearch.py"
-  - "newrelic/hooks/external_pywapi.py"
+  - "newrelic/hooks/datastore_umemcache.py"
   - "newrelic/hooks/external_dropbox.py"
   - "newrelic/hooks/external_facepy.py"
+  - "newrelic/hooks/external_pywapi.py"
   - "newrelic/hooks/external_xmlrpclib.py"
   - "newrelic/hooks/framework_pylons.py"
   - "newrelic/hooks/framework_web2py.py"
-  - "newrelic/hooks/middleware_weberror.py"
   - "newrelic/hooks/framework_webpy.py"
-  - "newrelic/hooks/database_oursql.py"
-  - "newrelic/hooks/database_psycopg2ct.py"
-  - "newrelic/hooks/datastore_umemcache.py"
-  # Temporarily disable kafka
-  - "newrelic/hooks/messagebroker_kafkapython.py"
-  - "newrelic/hooks/messagebroker_confluentkafka.py"
+  - "newrelic/hooks/middleware_weberror.py"
+  - "newrelic/packages/*"
+  - "newrelic/packages/**/*"
diff --git a/newrelic/agent.py b/newrelic/agent.py
index 95a540780..2c7f0fb85 100644
--- a/newrelic/agent.py
+++ b/newrelic/agent.py
@@ -59,6 +59,7 @@ from newrelic.api.transaction import record_custom_metric as __record_custom_metric
from newrelic.api.transaction import record_custom_metrics as __record_custom_metrics from newrelic.api.transaction import record_log_event as __record_log_event +from newrelic.api.transaction import record_ml_event as __record_ml_event from newrelic.api.transaction import set_background_task as __set_background_task from newrelic.api.transaction import set_transaction_name as __set_transaction_name from newrelic.api.transaction import suppress_apdex_metric as __suppress_apdex_metric @@ -152,6 +153,7 @@ def __asgi_application(*args, **kwargs): from newrelic.api.message_transaction import ( wrap_message_transaction as __wrap_message_transaction, ) +from newrelic.api.ml_model import wrap_mlmodel as __wrap_mlmodel from newrelic.api.profile_trace import ProfileTraceWrapper as __ProfileTraceWrapper from newrelic.api.profile_trace import profile_trace as __profile_trace from newrelic.api.profile_trace import wrap_profile_trace as __wrap_profile_trace @@ -206,11 +208,6 @@ def __asgi_application(*args, **kwargs): # EXPERIMENTAL - Generator traces are currently experimental and may not # exist in this form in future versions of the agent. - -# EXPERIMENTAL - Profile traces are currently experimental and may not -# exist in this form in future versions of the agent. - - initialize = __initialize extra_settings = __wrap_api_call(__extra_settings, "extra_settings") global_settings = __wrap_api_call(__global_settings, "global_settings") @@ -248,6 +245,7 @@ def __asgi_application(*args, **kwargs): record_custom_metrics = __wrap_api_call(__record_custom_metrics, "record_custom_metrics") record_custom_event = __wrap_api_call(__record_custom_event, "record_custom_event") record_log_event = __wrap_api_call(__record_log_event, "record_log_event") +record_ml_event = __wrap_api_call(__record_ml_event, "record_ml_event") accept_distributed_trace_payload = __wrap_api_call( __accept_distributed_trace_payload, "accept_distributed_trace_payload" ) @@ -341,3 +339,4 @@ def __asgi_application(*args, **kwargs): wrap_out_function = __wrap_api_call(__wrap_out_function, "wrap_out_function") insert_html_snippet = __wrap_api_call(__insert_html_snippet, "insert_html_snippet") verify_body_exists = __wrap_api_call(__verify_body_exists, "verify_body_exists") +wrap_mlmodel = __wrap_api_call(__wrap_mlmodel, "wrap_mlmodel") diff --git a/newrelic/api/application.py b/newrelic/api/application.py index ea57829f2..e2e7be139 100644 --- a/newrelic/api/application.py +++ b/newrelic/api/application.py @@ -142,10 +142,22 @@ def record_custom_metrics(self, metrics): if self.active and metrics: self._agent.record_custom_metrics(self._name, metrics) + def record_dimensional_metric(self, name, value, tags=None): + if self.active: + self._agent.record_dimensional_metric(self._name, name, value, tags) + + def record_dimensional_metrics(self, metrics): + if self.active and metrics: + self._agent.record_dimensional_metrics(self._name, metrics) + def record_custom_event(self, event_type, params): if self.active: self._agent.record_custom_event(self._name, event_type, params) + def record_ml_event(self, event_type, params): + if self.active: + self._agent.record_ml_event(self._name, event_type, params) + def record_transaction(self, data): if self.active: self._agent.record_transaction(self._name, data) diff --git a/newrelic/api/background_task.py b/newrelic/api/background_task.py index a4a9e8e6a..4cdcd8a0d 100644 --- a/newrelic/api/background_task.py +++ b/newrelic/api/background_task.py @@ -13,19 +13,16 @@ # limitations under the License. 
import functools -import sys from newrelic.api.application import Application, application_instance from newrelic.api.transaction import Transaction, current_transaction -from newrelic.common.async_proxy import async_proxy, TransactionContext +from newrelic.common.async_proxy import TransactionContext, async_proxy from newrelic.common.object_names import callable_name from newrelic.common.object_wrapper import FunctionWrapper, wrap_object class BackgroundTask(Transaction): - def __init__(self, application, name, group=None, source=None): - # Initialise the common transaction base class. super(BackgroundTask, self).__init__(application, source=source) @@ -53,7 +50,6 @@ def __init__(self, application, name, group=None, source=None): def BackgroundTaskWrapper(wrapped, application=None, name=None, group=None): - def wrapper(wrapped, instance, args, kwargs): if callable(name): if instance is not None: @@ -107,39 +103,19 @@ def create_transaction(transaction): manager = create_transaction(current_transaction(active_only=False)) + # This means that a transaction already exists, so we want to return if not manager: return wrapped(*args, **kwargs) - success = True - - try: - manager.__enter__() - try: - return wrapped(*args, **kwargs) - except: - success = False - if not manager.__exit__(*sys.exc_info()): - raise - finally: - if success and manager._ref_count == 0: - manager._is_finalized = True - manager.__exit__(None, None, None) - else: - manager._request_handler_finalize = True - manager._server_adapter_finalize = True - old_transaction = current_transaction() - if old_transaction is not None: - old_transaction.drop_transaction() + with manager: + return wrapped(*args, **kwargs) return FunctionWrapper(wrapped, wrapper) def background_task(application=None, name=None, group=None): - return functools.partial(BackgroundTaskWrapper, - application=application, name=name, group=group) + return functools.partial(BackgroundTaskWrapper, application=application, name=name, group=group) -def wrap_background_task(module, object_path, application=None, - name=None, group=None): - wrap_object(module, object_path, BackgroundTaskWrapper, - (application, name, group)) +def wrap_background_task(module, object_path, application=None, name=None, group=None): + wrap_object(module, object_path, BackgroundTaskWrapper, (application, name, group)) diff --git a/newrelic/api/database_trace.py b/newrelic/api/database_trace.py index 2bc497688..8990a1ef4 100644 --- a/newrelic/api/database_trace.py +++ b/newrelic/api/database_trace.py @@ -16,7 +16,7 @@ import logging from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.database_node import DatabaseNode from newrelic.core.stack_trace import current_stack @@ -44,11 +44,6 @@ def register_database_client( dbapi2_module._nr_explain_query = explain_query dbapi2_module._nr_explain_stmts = explain_stmts dbapi2_module._nr_instance_info = instance_info - dbapi2_module._nr_datastore_instance_feature_flag = False - - -def enable_datastore_instance_feature(dbapi2_module): - dbapi2_module._nr_datastore_instance_feature_flag = True class DatabaseTrace(TimeTrace): @@ -153,12 +148,7 @@ def finalize_data(self, transaction, exc=None, value=None, tb=None): if instance_enabled or db_name_enabled: - if ( - self.dbapi2_module - and self.connect_params - and 
self.dbapi2_module._nr_datastore_instance_feature_flag - and self.dbapi2_module._nr_instance_info is not None - ): + if self.dbapi2_module and self.connect_params and self.dbapi2_module._nr_instance_info is not None: instance_info = self.dbapi2_module._nr_instance_info(*self.connect_params) @@ -244,9 +234,9 @@ def create_node(self): ) -def DatabaseTraceWrapper(wrapped, sql, dbapi2_module=None): +def DatabaseTraceWrapper(wrapped, sql, dbapi2_module=None, async_wrapper=None): def _nr_database_trace_wrapper_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -273,9 +263,9 @@ def _nr_database_trace_wrapper_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_database_trace_wrapper_) -def database_trace(sql, dbapi2_module=None): - return functools.partial(DatabaseTraceWrapper, sql=sql, dbapi2_module=dbapi2_module) +def database_trace(sql, dbapi2_module=None, async_wrapper=None): + return functools.partial(DatabaseTraceWrapper, sql=sql, dbapi2_module=dbapi2_module, async_wrapper=async_wrapper) -def wrap_database_trace(module, object_path, sql, dbapi2_module=None): - wrap_object(module, object_path, DatabaseTraceWrapper, (sql, dbapi2_module)) +def wrap_database_trace(module, object_path, sql, dbapi2_module=None, async_wrapper=None): + wrap_object(module, object_path, DatabaseTraceWrapper, (sql, dbapi2_module, async_wrapper)) diff --git a/newrelic/api/datastore_trace.py b/newrelic/api/datastore_trace.py index fb40abcab..0401c79ea 100644 --- a/newrelic/api/datastore_trace.py +++ b/newrelic/api/datastore_trace.py @@ -15,7 +15,7 @@ import functools from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.datastore_node import DatastoreNode @@ -82,6 +82,9 @@ def __enter__(self): self.product = transaction._intern_string(self.product) self.target = transaction._intern_string(self.target) self.operation = transaction._intern_string(self.operation) + self.host = transaction._intern_string(self.host) + self.port_path_or_id = transaction._intern_string(self.port_path_or_id) + self.database_name = transaction._intern_string(self.database_name) datastore_tracer_settings = transaction.settings.datastore_tracer self.instance_reporting_enabled = datastore_tracer_settings.instance_reporting.enabled @@ -92,7 +95,14 @@ def __repr__(self): return "<%s object at 0x%x %s>" % ( self.__class__.__name__, id(self), - dict(product=self.product, target=self.target, operation=self.operation), + dict( + product=self.product, + target=self.target, + operation=self.operation, + host=self.host, + port_path_or_id=self.port_path_or_id, + database_name=self.database_name, + ), ) def finalize_data(self, transaction, exc=None, value=None, tb=None): @@ -125,7 +135,7 @@ def create_node(self): ) -def DatastoreTraceWrapper(wrapped, product, target, operation): +def DatastoreTraceWrapper(wrapped, product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None): """Wraps a method to time datastore queries. :param wrapped: The function to apply the trace to. 
@@ -140,6 +150,16 @@ def DatastoreTraceWrapper(wrapped, product, target, operation):
             or the name of any API function/method in the client
             library.
     :type operation: str or callable
+    :param host: The name of the server hosting the actual datastore.
+    :type host: str
+    :param port_path_or_id: The value passed in can represent either the port,
+        path, or id of the datastore being connected to.
+    :type port_path_or_id: str
+    :param database_name: The name of the database where the current query is
+        being executed.
+    :type database_name: str
+    :param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper.
+    :type async_wrapper: callable or None
     :rtype: :class:`newrelic.common.object_wrapper.FunctionWrapper`
 
     This is typically used to wrap datastore queries such as calls to Redis or
@@ -155,7 +175,7 @@ def DatastoreTraceWrapper(wrapped, product, target, operation):
     """
 
     def _nr_datastore_trace_wrapper_(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -187,7 +207,33 @@ def _nr_datastore_trace_wrapper_(wrapped, instance, args, kwargs):
         else:
             _operation = operation
 
-        trace = DatastoreTrace(_product, _target, _operation, parent=parent, source=wrapped)
+        if callable(host):
+            if instance is not None:
+                _host = host(instance, *args, **kwargs)
+            else:
+                _host = host(*args, **kwargs)
+        else:
+            _host = host
+
+        if callable(port_path_or_id):
+            if instance is not None:
+                _port_path_or_id = port_path_or_id(instance, *args, **kwargs)
+            else:
+                _port_path_or_id = port_path_or_id(*args, **kwargs)
+        else:
+            _port_path_or_id = port_path_or_id
+
+        if callable(database_name):
+            if instance is not None:
+                _database_name = database_name(instance, *args, **kwargs)
+            else:
+                _database_name = database_name(*args, **kwargs)
+        else:
+            _database_name = database_name
+
+        trace = DatastoreTrace(
+            _product, _target, _operation, _host, _port_path_or_id, _database_name, parent=parent, source=wrapped
+        )
 
         if wrapper:  # pylint: disable=W0125,W0126
             return wrapper(wrapped, trace)(*args, **kwargs)
@@ -198,7 +244,7 @@ def _nr_datastore_trace_wrapper_(wrapped, instance, args, kwargs):
     return FunctionWrapper(wrapped, _nr_datastore_trace_wrapper_)
 
 
-def datastore_trace(product, target, operation):
+def datastore_trace(product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None):
     """Decorator allows datastore query to be timed.
 
     :param product: The name of the vendor.
@@ -211,6 +257,16 @@ def datastore_trace(product, target, operation):
             or the name of any API function/method in the client
             library.
     :type operation: str
+    :param host: The name of the server hosting the actual datastore.
+    :type host: str
+    :param port_path_or_id: The value passed in can represent either the port,
+        path, or id of the datastore being connected to.
+    :type port_path_or_id: str
+    :param database_name: The name of the database where the current query is
+        being executed.
+    :type database_name: str
+    :param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper.
+    :type async_wrapper: callable or None
 
     This is typically used to decorate datastore queries such as calls to
     Redis or ElasticSearch.
@@ -224,10 +280,21 @@ def datastore_trace(product, target, operation):
         ...     time.sleep(*args, **kwargs)
     """
-    return functools.partial(DatastoreTraceWrapper, product=product, target=target, operation=operation)
-
-
-def wrap_datastore_trace(module, object_path, product, target, operation):
+    return functools.partial(
+        DatastoreTraceWrapper,
+        product=product,
+        target=target,
+        operation=operation,
+        host=host,
+        port_path_or_id=port_path_or_id,
+        database_name=database_name,
+        async_wrapper=async_wrapper,
+    )
+
+
+def wrap_datastore_trace(
+    module, object_path, product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None
+):
     """Method applies custom timing to datastore query.
 
     :param module: Module containing the method to be instrumented.
@@ -244,6 +311,16 @@ def wrap_datastore_trace(module, object_path, product, target, operation):
             or the name of any API function/method in the client
             library.
     :type operation: str
+    :param host: The name of the server hosting the actual datastore.
+    :type host: str
+    :param port_path_or_id: The value passed in can represent either the port,
+        path, or id of the datastore being connected to.
+    :type port_path_or_id: str
+    :param database_name: The name of the database where the current query is
+        being executed.
+    :type database_name: str
+    :param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper.
+    :type async_wrapper: callable or None
 
     This is typically used to time database query method calls such as Redis
     GET.
@@ -256,4 +333,6 @@ def wrap_datastore_trace(module, object_path, product, target, operation):
         ...     'sleep')
     """
-    wrap_object(module, object_path, DatastoreTraceWrapper, (product, target, operation))
+    wrap_object(
+        module, object_path, DatastoreTraceWrapper, (product, target, operation, host, port_path_or_id, database_name, async_wrapper)
+    )
diff --git a/newrelic/api/external_trace.py b/newrelic/api/external_trace.py
index c43c560c6..2e147df45 100644
--- a/newrelic/api/external_trace.py
+++ b/newrelic/api/external_trace.py
@@ -16,7 +16,7 @@
 
 from newrelic.api.cat_header_mixin import CatHeaderMixin
 from newrelic.api.time_trace import TimeTrace, current_trace
-from newrelic.common.async_wrapper import async_wrapper
+from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper
 from newrelic.common.object_wrapper import FunctionWrapper, wrap_object
 from newrelic.core.external_node import ExternalNode
 
@@ -66,9 +66,9 @@ def create_node(self):
         )
 
 
-def ExternalTraceWrapper(wrapped, library, url, method=None):
+def ExternalTraceWrapper(wrapped, library, url, method=None, async_wrapper=None):
     def dynamic_wrapper(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -103,7 +103,7 @@ def dynamic_wrapper(wrapped, instance, args, kwargs):
             return wrapped(*args, **kwargs)
 
     def literal_wrapper(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -125,9 +125,9 @@ def literal_wrapper(wrapped, instance, args, kwargs):
     return FunctionWrapper(wrapped, literal_wrapper)
 
 
-def external_trace(library, url, method=None):
-    return functools.partial(ExternalTraceWrapper, library=library, url=url, method=method)
+def external_trace(library, url, method=None, async_wrapper=None):
+    return functools.partial(ExternalTraceWrapper, library=library, url=url, method=method, async_wrapper=async_wrapper)
 
 
-def wrap_external_trace(module, object_path, library, url, method=None):
-    wrap_object(module, object_path, ExternalTraceWrapper, (library, url, method))
+def wrap_external_trace(module, object_path, library, url, method=None, async_wrapper=None):
+    wrap_object(module, object_path, ExternalTraceWrapper, (library, url, method, async_wrapper))
diff --git a/newrelic/api/function_trace.py b/newrelic/api/function_trace.py
index 474c1b226..85d7617b6 100644
--- a/newrelic/api/function_trace.py
+++ b/newrelic/api/function_trace.py
@@ -15,7 +15,7 @@
 import functools
 
 from newrelic.api.time_trace import TimeTrace, current_trace
-from newrelic.common.async_wrapper import async_wrapper
+from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper
 from newrelic.common.object_names import callable_name
 from newrelic.common.object_wrapper import FunctionWrapper, wrap_object
 from newrelic.core.function_node import FunctionNode
 
@@ -89,9 +89,9 @@ def create_node(self):
         )
 
 
-def FunctionTraceWrapper(wrapped, name=None, group=None, label=None, params=None, terminal=False, rollup=None):
+def FunctionTraceWrapper(wrapped, name=None, group=None, label=None, params=None, terminal=False, rollup=None, async_wrapper=None):
     def dynamic_wrapper(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -147,7 +147,7 @@ def dynamic_wrapper(wrapped, instance, args, kwargs):
             return wrapped(*args, **kwargs)
 
     def literal_wrapper(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -171,13 +171,13 @@ def literal_wrapper(wrapped, instance, args, kwargs):
     return FunctionWrapper(wrapped, literal_wrapper)
 
 
-def function_trace(name=None, group=None, label=None, params=None, terminal=False, rollup=None):
+def function_trace(name=None, group=None, label=None, params=None, terminal=False, rollup=None, async_wrapper=None):
     return functools.partial(
-        FunctionTraceWrapper, name=name, group=group, label=label, params=params, terminal=terminal, rollup=rollup
+        FunctionTraceWrapper, name=name, group=group, label=label, params=params, terminal=terminal, rollup=rollup, async_wrapper=async_wrapper
     )
 
 
 def wrap_function_trace(
-    module, object_path, name=None, group=None, label=None, params=None, terminal=False, rollup=None
+    module, object_path, name=None, group=None, label=None, params=None, terminal=False, rollup=None, async_wrapper=None
 ):
-    return wrap_object(module, object_path, FunctionTraceWrapper, (name, group, label, params, terminal, rollup))
+    return wrap_object(module, object_path, FunctionTraceWrapper, (name, group, label, params, terminal, rollup, async_wrapper))
diff --git a/newrelic/api/graphql_trace.py b/newrelic/api/graphql_trace.py
index 7a2c9ec02..e8803fa68 100644
--- a/newrelic/api/graphql_trace.py
+++ b/newrelic/api/graphql_trace.py
@@ -16,7 +16,7 @@
 
 from newrelic.api.time_trace import TimeTrace, current_trace
 from newrelic.api.transaction import current_transaction
-from newrelic.common.async_wrapper import async_wrapper
+from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper
 from newrelic.common.object_wrapper import FunctionWrapper, wrap_object
 from newrelic.core.graphql_node import GraphQLOperationNode, GraphQLResolverNode
 
@@ -109,9 +109,9 @@ def set_transaction_name(self, priority=None):
         transaction.set_transaction_name(name, "GraphQL", priority=priority)
 
 
-def GraphQLOperationTraceWrapper(wrapped):
+def GraphQLOperationTraceWrapper(wrapped, async_wrapper=None):
     def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -130,16 +130,16 @@ def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs):
     return FunctionWrapper(wrapped, _nr_graphql_trace_wrapper_)
 
 
-def graphql_operation_trace():
-    return functools.partial(GraphQLOperationTraceWrapper)
+def graphql_operation_trace(async_wrapper=None):
+    return functools.partial(GraphQLOperationTraceWrapper, async_wrapper=async_wrapper)
 
 
-def wrap_graphql_operation_trace(module, object_path):
-    wrap_object(module, object_path, GraphQLOperationTraceWrapper)
+def wrap_graphql_operation_trace(module, object_path, async_wrapper=None):
+    wrap_object(module, object_path, GraphQLOperationTraceWrapper, (async_wrapper,))
 
 
 class GraphQLResolverTrace(TimeTrace):
-    def __init__(self, field_name=None, **kwargs):
+    def __init__(self, field_name=None, field_parent_type=None, field_return_type=None, field_path=None, **kwargs):
         parent = kwargs.pop("parent", None)
         source = kwargs.pop("source", None)
         if kwargs:
@@ -148,6 +148,9 @@ def __init__(self, field_name=None, **kwargs):
         super(GraphQLResolverTrace, self).__init__(parent=parent, source=source)
 
         self.field_name = field_name
+        self.field_parent_type = field_parent_type
+        self.field_return_type = field_return_type
+        self.field_path = field_path
         self._product = None
 
     def __repr__(self):
@@ -175,6 +178,9 @@ def product(self):
 
     def finalize_data(self, *args, **kwargs):
         self._add_agent_attribute("graphql.field.name", self.field_name)
+        self._add_agent_attribute("graphql.field.parentType", self.field_parent_type)
+        self._add_agent_attribute("graphql.field.returnType", self.field_return_type)
+        self._add_agent_attribute("graphql.field.path", self.field_path)
 
         return super(GraphQLResolverTrace, self).finalize_data(*args, **kwargs)
 
@@ -193,9 +199,9 @@ def create_node(self):
         )
 
 
-def GraphQLResolverTraceWrapper(wrapped):
+def GraphQLResolverTraceWrapper(wrapped, async_wrapper=None):
     def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -214,9 +220,9 @@ def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs):
     return FunctionWrapper(wrapped, _nr_graphql_trace_wrapper_)
 
 
-def graphql_resolver_trace():
-    return functools.partial(GraphQLResolverTraceWrapper)
+def graphql_resolver_trace(async_wrapper=None):
+    return functools.partial(GraphQLResolverTraceWrapper, async_wrapper=async_wrapper)
 
 
-def wrap_graphql_resolver_trace(module, object_path):
-    wrap_object(module, object_path, GraphQLResolverTraceWrapper)
+def wrap_graphql_resolver_trace(module, object_path, async_wrapper=None):
+    wrap_object(module, object_path, GraphQLResolverTraceWrapper, (async_wrapper,))
diff --git a/newrelic/api/memcache_trace.py b/newrelic/api/memcache_trace.py
index 6657a9ce2..87f12f9fc 100644
--- a/newrelic/api/memcache_trace.py
+++ b/newrelic/api/memcache_trace.py
@@ -15,7 +15,7 @@
 import functools
 
 from newrelic.api.time_trace import TimeTrace, current_trace
-from newrelic.common.async_wrapper import async_wrapper
+from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper
 from newrelic.common.object_wrapper import FunctionWrapper, wrap_object
 from newrelic.core.memcache_node import MemcacheNode
 
@@ -51,9 +51,9 @@ def create_node(self):
         )
 
 
-def MemcacheTraceWrapper(wrapped, command):
+def MemcacheTraceWrapper(wrapped, command, async_wrapper=None):
     def _nr_wrapper_memcache_trace_(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -80,9 +80,9 @@ def _nr_wrapper_memcache_trace_(wrapped, instance, args, kwargs):
     return FunctionWrapper(wrapped, _nr_wrapper_memcache_trace_)
 
 
-def memcache_trace(command):
-    return functools.partial(MemcacheTraceWrapper, command=command)
+def memcache_trace(command, async_wrapper=None):
+    return functools.partial(MemcacheTraceWrapper, command=command, async_wrapper=async_wrapper)
 
 
-def wrap_memcache_trace(module, object_path, command):
-    wrap_object(module, object_path, MemcacheTraceWrapper, (command,))
+def wrap_memcache_trace(module, object_path, command, async_wrapper=None):
+    wrap_object(module, object_path, MemcacheTraceWrapper, (command, async_wrapper))
diff --git a/newrelic/api/message_trace.py b/newrelic/api/message_trace.py
index be819d704..f564c41cb 100644
--- a/newrelic/api/message_trace.py
+++ b/newrelic/api/message_trace.py
@@ -16,7 +16,7 @@
 
 from newrelic.api.cat_header_mixin import CatHeaderMixin
 from newrelic.api.time_trace import TimeTrace, current_trace
-from newrelic.common.async_wrapper import async_wrapper
+from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper
 from newrelic.common.object_wrapper import FunctionWrapper, wrap_object
 from newrelic.core.message_node import MessageNode
 
@@ -91,9 +91,9 @@ def create_node(self):
         )
 
 
-def MessageTraceWrapper(wrapped, library, operation, destination_type, destination_name, params={}, terminal=True):
+def MessageTraceWrapper(wrapped, library, operation, destination_type, destination_name, params={}, terminal=True, async_wrapper=None):
     def _nr_message_trace_wrapper_(wrapped, instance, args, kwargs):
-        wrapper = async_wrapper(wrapped)
+        wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped)
         if not wrapper:
             parent = current_trace()
             if not parent:
@@ -144,7 +144,7 @@ def _nr_message_trace_wrapper_(wrapped, instance, args, kwargs):
     return FunctionWrapper(wrapped, _nr_message_trace_wrapper_)
 
 
-def message_trace(library, operation, destination_type, destination_name, params={}, terminal=True):
+def message_trace(library, operation, destination_type, destination_name, params={}, terminal=True, async_wrapper=None):
     return functools.partial(
         MessageTraceWrapper,
         library=library,
@@ -153,10 +153,11 @@ def message_trace(library, operation, destination_type, destination_name, params
         destination_name=destination_name,
         params=params,
         terminal=terminal,
+        async_wrapper=async_wrapper,
     )
 
 
-def wrap_message_trace(module, object_path, library, operation, destination_type, destination_name, params={}, terminal=True):
+def wrap_message_trace(module, object_path, library, operation, destination_type, destination_name, params={}, terminal=True, async_wrapper=None):
     wrap_object(
-        module, object_path, MessageTraceWrapper, (library, operation, destination_type, destination_name, params, terminal)
+        module, object_path, MessageTraceWrapper, (library, operation, destination_type, destination_name, params, terminal, async_wrapper)
     )
diff --git a/newrelic/api/message_transaction.py b/newrelic/api/message_transaction.py
index 291a3897e..54a71f6ef 100644
--- a/newrelic/api/message_transaction.py
+++ b/newrelic/api/message_transaction.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import functools
-import sys
 
 from newrelic.api.application import Application, application_instance
 from newrelic.api.background_task import BackgroundTask
@@ -39,7 +38,6 @@ def __init__(
         transport_type="AMQP",
         source=None,
     ):
-
         name, group = self.get_transaction_name(library, destination_type, destination_name)
 
         super(MessageTransaction, self).__init__(application, name, group=group, source=source)
@@ -218,30 +216,12 @@ def create_transaction(transaction):
 
     manager = create_transaction(current_transaction(active_only=False))
 
+    # This means that a transaction already exists, so return early.
     if not manager:
         return wrapped(*args, **kwargs)
 
-    success = True
-
-    try:
-        manager.__enter__()
-        try:
-            return wrapped(*args, **kwargs)
-        except:  # Catch all
-            success = False
-            if not manager.__exit__(*sys.exc_info()):
-                raise
-    finally:
-        if success and manager._ref_count == 0:
-            manager._is_finalized = True
-            manager.__exit__(None, None, None)
-        else:
-            manager._request_handler_finalize = True
-            manager._server_adapter_finalize = True
-
-            old_transaction = current_transaction()
-            if old_transaction is not None:
-                old_transaction.drop_transaction()
+    with manager:
+        return wrapped(*args, **kwargs)
 
     return FunctionWrapper(wrapped, wrapper)
diff --git a/newrelic/api/ml_model.py b/newrelic/api/ml_model.py
new file mode 100644
index 000000000..edbcaf340
--- /dev/null
+++ b/newrelic/api/ml_model.py
@@ -0,0 +1,35 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+
+from newrelic.common.object_names import callable_name
+from newrelic.hooks.mlmodel_sklearn import _nr_instrument_model
+
+
+def wrap_mlmodel(model, name=None, version=None, feature_names=None, label_names=None, metadata=None):
+    model_callable_name = callable_name(model)
+    _class = model.__class__.__name__
+    module = sys.modules[model_callable_name.split(":")[0]]
+    _nr_instrument_model(module, _class)
+    if name:
+        model._nr_wrapped_name = name
+    if version:
+        model._nr_wrapped_version = version
+    if feature_names:
+        model._nr_wrapped_feature_names = feature_names
+    if label_names:
+        model._nr_wrapped_label_names = label_names
+    if metadata:
+        model._nr_wrapped_metadata = metadata
diff --git a/newrelic/api/profile_trace.py b/newrelic/api/profile_trace.py
index 28113b1d8..93aa191a4 100644
--- a/newrelic/api/profile_trace.py
+++ b/newrelic/api/profile_trace.py
@@ -13,31 +13,27 @@
 # limitations under the License.
 import functools
-import sys
 import os
+import sys
 
-from newrelic.packages import six
-
-from newrelic.api.time_trace import current_trace
+from newrelic import __file__ as AGENT_PACKAGE_FILE
 from newrelic.api.function_trace import FunctionTrace
-from newrelic.common.object_wrapper import FunctionWrapper, wrap_object
+from newrelic.api.time_trace import current_trace
 from newrelic.common.object_names import callable_name
+from newrelic.common.object_wrapper import FunctionWrapper, wrap_object
+from newrelic.packages import six
 
-from newrelic import __file__ as AGENT_PACKAGE_FILE
-AGENT_PACKAGE_DIRECTORY = os.path.dirname(AGENT_PACKAGE_FILE) + '/'
+AGENT_PACKAGE_DIRECTORY = os.path.dirname(AGENT_PACKAGE_FILE) + "/"
 
 
 class ProfileTrace(object):
-
     def __init__(self, depth):
         self.function_traces = []
         self.maximum_depth = depth
         self.current_depth = 0
 
-    def __call__(self, frame, event, arg):
-
-        if event not in ['call', 'c_call', 'return', 'c_return',
-                'exception', 'c_exception']:
+    def __call__(self, frame, event, arg):  # pragma: no cover
+        if event not in ["call", "c_call", "return", "c_return", "exception", "c_exception"]:
             return
 
         parent = current_trace()
@@ -49,8 +45,7 @@ def __call__(self, frame, event, arg):
         # coroutine systems based on greenlets so don't run
         # if we detect may be using greenlets.
 
-        if (hasattr(sys, '_current_frames') and
-                parent.thread_id not in sys._current_frames()):
+        if hasattr(sys, "_current_frames") and parent.thread_id not in sys._current_frames():
             return
 
         co = frame.f_code
@@ -84,7 +79,7 @@ def _callable():
             except Exception:
                 pass
 
-        if event in ['call', 'c_call']:
+        if event in ["call", "c_call"]:
 
             # Skip the outermost as we catch that with the root
             # function traces for the profile trace.
@@ -100,19 +95,17 @@ def _callable():
                 self.function_traces.append(None)
                 return
 
-            if event == 'call':
+            if event == "call":
                 func = _callable()
                 if func:
                     name = callable_name(func)
                 else:
-                    name = '%s:%s#%s' % (func_filename, func_name,
-                            func_line_no)
+                    name = "%s:%s#%s" % (func_filename, func_name, func_line_no)
             else:
                 func = arg
                 name = callable_name(arg)
                 if not name:
-                    name = '%s:@%s#%s' % (func_filename, func_name,
-                            func_line_no)
+                    name = "%s:@%s#%s" % (func_filename, func_name, func_line_no)
 
             function_trace = FunctionTrace(name=name, parent=parent)
             function_trace.__enter__()
@@ -127,7 +120,7 @@ def _callable():
             self.function_traces.append(function_trace)
             self.current_depth += 1
 
-        elif event in ['return', 'c_return', 'c_exception']:
+        elif event in ["return", "c_return", "c_exception"]:
             if not self.function_traces:
                 return
 
@@ -143,9 +136,7 @@ def _callable():
             self.current_depth -= 1
 
 
-def ProfileTraceWrapper(wrapped, name=None, group=None, label=None,
-        params=None, depth=3):
-
+def ProfileTraceWrapper(wrapped, name=None, group=None, label=None, params=None, depth=3):
     def wrapper(wrapped, instance, args, kwargs):
         parent = current_trace()
 
@@ -192,7 +183,7 @@ def wrapper(wrapped, instance, args, kwargs):
             _params = params
 
         with FunctionTrace(_name, _group, _label, _params, parent=parent, source=wrapped):
-            if not hasattr(sys, 'getprofile'):
+            if not hasattr(sys, "getprofile"):
                 return wrapped(*args, **kwargs)
 
             profiler = sys.getprofile()
@@ -212,11 +203,8 @@ def wrapper(wrapped, instance, args, kwargs):
 
 
 def profile_trace(name=None, group=None, label=None, params=None, depth=3):
-    return functools.partial(ProfileTraceWrapper, name=name,
-            group=group, label=label, params=params, depth=depth)
+    return functools.partial(ProfileTraceWrapper, name=name, group=group, label=label, params=params, depth=depth)
 
 
-def wrap_profile_trace(module, object_path, name=None,
-        group=None, label=None, params=None, depth=3):
-    return wrap_object(module, object_path, ProfileTraceWrapper,
-            (name, group, label, params, depth))
+def wrap_profile_trace(module, object_path, name=None, group=None, label=None, params=None, depth=3):
+    return wrap_object(module, object_path, ProfileTraceWrapper, (name, group, label, params, depth))
diff --git a/newrelic/api/transaction.py b/newrelic/api/transaction.py
index f04bcba84..988b56be6 100644
--- a/newrelic/api/transaction.py
+++ b/newrelic/api/transaction.py
@@ -60,11 +60,15 @@
     DST_NONE,
     DST_TRANSACTION_TRACER,
 )
-from newrelic.core.config import CUSTOM_EVENT_RESERVOIR_SIZE, LOG_EVENT_RESERVOIR_SIZE
+from newrelic.core.config import (
+    CUSTOM_EVENT_RESERVOIR_SIZE,
+    LOG_EVENT_RESERVOIR_SIZE,
+    ML_EVENT_RESERVOIR_SIZE,
+)
 from newrelic.core.custom_event import create_custom_event
 from newrelic.core.log_event_node import LogEventNode
 from newrelic.core.stack_trace import exception_stack
-from newrelic.core.stats_engine import CustomMetrics, SampledDataSet
+from newrelic.core.stats_engine import CustomMetrics, DimensionalMetrics, SampledDataSet
 from newrelic.core.thread_utilization import utilization_tracker
 from newrelic.core.trace_cache import (
     TraceCacheActiveTraceError,
@@ -159,13 +163,11 @@ def path(self):
 
 
 class Transaction(object):
-
     STATE_PENDING = 0
     STATE_RUNNING = 1
     STATE_STOPPED = 2
 
     def __init__(self, application, enabled=None, source=None):
-
         self._application = application
 
         self._source = source
@@ -307,6 +309,7 @@ def __init__(self, application, enabled=None, source=None):
         self.synthetics_header = None
 
         self._custom_metrics = CustomMetrics()
+        self._dimensional_metrics = DimensionalMetrics()
 
         global_settings = application.global_settings
 
@@ -330,12 +333,14 @@ def __init__(self, application, enabled=None, source=None):
             self._custom_events = SampledDataSet(
                 capacity=self._settings.event_harvest_config.harvest_limits.custom_event_data
             )
+            self._ml_events = SampledDataSet(capacity=self._settings.event_harvest_config.harvest_limits.ml_event_data)
             self._log_events = SampledDataSet(
                 capacity=self._settings.event_harvest_config.harvest_limits.log_event_data
             )
         else:
             self._custom_events = SampledDataSet(capacity=CUSTOM_EVENT_RESERVOIR_SIZE)
             self._log_events = SampledDataSet(capacity=LOG_EVENT_RESERVOIR_SIZE)
+            self._ml_events = SampledDataSet(capacity=ML_EVENT_RESERVOIR_SIZE)
 
     def __del__(self):
         self._dead = True
@@ -343,7 +348,6 @@ def __del__(self):
             self.__exit__(None, None, None)
 
     def __enter__(self):
-
         assert self._state == self.STATE_PENDING
 
         # Bail out if the transaction is not enabled.
@@ -403,7 +407,6 @@ def __enter__(self):
         return self
 
     def __exit__(self, exc, value, tb):
-
         # Bail out if the transaction is not enabled.
 
         if not self.enabled:
@@ -584,10 +587,12 @@ def __exit__(self, exc, value, tb):
             errors=tuple(self._errors),
             slow_sql=tuple(self._slow_sql),
             custom_events=self._custom_events,
+            ml_events=self._ml_events,
             log_events=self._log_events,
             apdex_t=self.apdex,
             suppress_apdex=self.suppress_apdex,
             custom_metrics=self._custom_metrics,
+            dimensional_metrics=self._dimensional_metrics,
            guid=self.guid,
             cpu_time=self._cpu_user_time_value,
             suppress_transaction_trace=self.suppress_transaction_trace,
@@ -636,7 +641,6 @@ def __exit__(self, exc, value, tb):
         # new samples can cause an error.
         if not self.ignore_transaction:
-
             self._application.record_transaction(node)
 
     @property
@@ -929,9 +933,7 @@ def filter_request_parameters(self, params):
     @property
     def request_parameters(self):
         if (self.capture_params is None) or self.capture_params:
-
             if self._request_params:
-
                 r_attrs = {}
 
                 for k, v in self._request_params.items():
@@ -1037,7 +1039,9 @@ def _create_distributed_trace_data(self):
         settings = self._settings
 
         account_id = settings.account_id
-        trusted_account_key = settings.trusted_account_key
+        trusted_account_key = settings.trusted_account_key or (
+            self._settings.serverless_mode.enabled and self._settings.account_id
+        )
         application_id = settings.primary_application_id
 
         if not (account_id and application_id and trusted_account_key and settings.distributed_tracing.enabled):
@@ -1095,7 +1099,6 @@ def _generate_distributed_trace_headers(self, data=None):
         try:
             data = data or self._create_distributed_trace_data()
             if data:
-
                 traceparent = W3CTraceParent(data).text()
                 yield ("traceparent", traceparent)
 
@@ -1129,7 +1132,10 @@ def _can_accept_distributed_trace_headers(self):
             return False
 
         settings = self._settings
-        if not (settings.distributed_tracing.enabled and settings.trusted_account_key):
+        trusted_account_key = settings.trusted_account_key or (
+            self._settings.serverless_mode.enabled and self._settings.account_id
+        )
+        if not (settings.distributed_tracing.enabled and trusted_account_key):
             return False
 
         if self._distributed_trace_state:
@@ -1175,10 +1181,13 @@ def _accept_distributed_trace_payload(self, payload, transport_type="HTTP"):
 
             settings = self._settings
             account_id = data.get("ac")
+            trusted_account_key = settings.trusted_account_key or (
+                self._settings.serverless_mode.enabled and self._settings.account_id
+            )
 
             # If trust key doesn't exist in the payload, use account_id
             received_trust_key = data.get("tk", account_id)
-            if settings.trusted_account_key != received_trust_key:
+            if trusted_account_key != received_trust_key:
                 self._record_supportability("Supportability/DistributedTrace/AcceptPayload/Ignored/UntrustedAccount")
                 if settings.debug.log_untrusted_distributed_trace_keys:
                     _logger.debug(
@@ -1192,11 +1201,10 @@ def _accept_distributed_trace_payload(self, payload, transport_type="HTTP"):
             except:
                 return False
 
-            if "pr" in data:
-                try:
-                    data["pr"] = float(data["pr"])
-                except:
-                    data["pr"] = None
+            try:
+                data["pr"] = float(data["pr"])
+            except Exception:
+                data["pr"] = None
 
             self._accept_distributed_trace_data(data, transport_type)
             self._record_supportability("Supportability/DistributedTrace/AcceptPayload/Success")
@@ -1288,8 +1296,10 @@ def accept_distributed_trace_headers(self, headers, transport_type="HTTP"):
             tracestate = ensure_str(tracestate)
             try:
                 vendors = W3CTraceState.decode(tracestate)
-                tk = self._settings.trusted_account_key
-                payload = vendors.pop(tk + "@nr", "")
+                trusted_account_key = self._settings.trusted_account_key or (
+                    self._settings.serverless_mode.enabled and self._settings.account_id
+                )
+                payload = vendors.pop(trusted_account_key + "@nr", "")
                 self.tracing_vendors = ",".join(vendors.keys())
                 self.tracestate = vendors.text(limit=31)
             except:
@@ -1298,7 +1308,7 @@ def accept_distributed_trace_headers(self, headers, transport_type="HTTP"):
             # Remove trusted new relic header if available and parse
             if payload:
                 try:
-                    tracestate_data = NrTraceState.decode(payload, tk)
+                    tracestate_data = NrTraceState.decode(payload, trusted_account_key)
                 except:
                     tracestate_data = None
                 if tracestate_data:
@@ -1382,7 +1392,6 @@ def _generate_response_headers(self, read_length=None):
         # process web external calls.
 
         if self.client_cross_process_id is not None:
-
             # Need to work out queueing time and duration up to this
             # point for inclusion in metrics and response header. If the
             # recording of the transaction had been prematurely stopped
@@ -1426,11 +1435,17 @@ def _generate_response_headers(self, read_length=None):
 
         return nr_headers
 
-    def get_response_metadata(self):
+    # This function is CAT related and has been deprecated.
+    # Eventually, this will be removed. Until then, coverage
+    # does not need to factor this function into its analysis.
+    def get_response_metadata(self):  # pragma: no cover
         nr_headers = dict(self._generate_response_headers())
         return convert_to_cat_metadata_value(nr_headers)
 
-    def process_request_metadata(self, cat_linking_value):
+    # This function is CAT related and has been deprecated.
+    # Eventually, this will be removed. Until then, coverage
+    # does not need to factor this function into its analysis.
+    def process_request_metadata(self, cat_linking_value):  # pragma: no cover
         try:
             payload = base64_decode(cat_linking_value)
         except:
@@ -1447,7 +1462,6 @@ def process_request_metadata(self, cat_linking_value):
         return self._process_incoming_cat_headers(encoded_cross_process_id, encoded_txn_header)
 
     def set_transaction_name(self, name, group=None, priority=None):
-
         # Always perform this operation even if the transaction
         # is not active at the time as will be called from
         # constructor. If path has been frozen do not allow
@@ -1517,7 +1531,9 @@ def record_log_event(self, message, level=None, timestamp=None, priority=None):
 
             self._log_events.add(event, priority=priority)
 
-    def record_exception(self, exc=None, value=None, tb=None, params=None, ignore_errors=None):
+    # This function has been deprecated (and will be removed eventually)
+    # and therefore does not need to be included in coverage analysis
+    def record_exception(self, exc=None, value=None, tb=None, params=None, ignore_errors=None):  # pragma: no cover
         # Deprecation Warning
         warnings.warn(
             ("The record_exception function is deprecated. Please use the new api named notice_error instead."),
@@ -1600,6 +1616,16 @@ def record_custom_metrics(self, metrics):
             for name, value in metrics:
                 self._custom_metrics.record_custom_metric(name, value)
 
+    def record_dimensional_metric(self, name, value, tags=None):
+        self._dimensional_metrics.record_dimensional_metric(name, value, tags)
+
+    def record_dimensional_metrics(self, metrics):
+        for metric in metrics:
+            name, value = metric[:2]
+            tags = metric[2] if len(metric) >= 3 else None
+
+            self._dimensional_metrics.record_dimensional_metric(name, value, tags)
+
     def record_custom_event(self, event_type, params):
         settings = self._settings
 
@@ -1613,6 +1639,19 @@ def record_custom_event(self, event_type, params):
         if event:
             self._custom_events.add(event, priority=self.priority)
 
+    def record_ml_event(self, event_type, params):
+        settings = self._settings
+
+        if not settings:
+            return
+
+        if not settings.ml_insights_events.enabled:
+            return
+
+        event = create_custom_event(event_type, params)
+        if event:
+            self._ml_events.add(event, priority=self.priority)
+
     def _intern_string(self, value):
         return self._string_cache.setdefault(value, value)
 
@@ -1684,7 +1723,9 @@ def add_custom_attributes(self, items):
 
         return result
 
-    def add_custom_parameter(self, name, value):
+    # This function has been deprecated (and will be removed eventually)
+    # and therefore does not need to be included in coverage analysis
+    def add_custom_parameter(self, name, value):  # pragma: no cover
        # Deprecation warning
         warnings.warn(
             ("The add_custom_parameter API has been deprecated. " "Please use the add_custom_attribute API."),
@@ -1692,7 +1733,9 @@ def add_custom_parameter(self, name, value):
         )
         return self.add_custom_attribute(name, value)
 
-    def add_custom_parameters(self, items):
+    # This function has been deprecated (and will be removed eventually)
+    # and therefore does not need to be included in coverage analysis
+    def add_custom_parameters(self, items):  # pragma: no cover
         # Deprecation warning
         warnings.warn(
             ("The add_custom_parameters API has been deprecated. " "Please use the add_custom_attributes API."),
@@ -1796,19 +1839,23 @@ def add_custom_attributes(items):
     return False
 
 
-def add_custom_parameter(key, value):
+# This function has been deprecated (and will be removed eventually)
+# and therefore does not need to be included in coverage analysis
+def add_custom_parameter(key, value):  # pragma: no cover
     # Deprecation warning
     warnings.warn(
-        ("The add_custom_parameter API has been deprecated. " "Please use the add_custom_attribute API."),
+        ("The add_custom_parameter API has been deprecated. Please use the add_custom_attribute API."),
         DeprecationWarning,
     )
     return add_custom_attribute(key, value)
 
 
-def add_custom_parameters(items):
+# This function has been deprecated (and will be removed eventually)
+# and therefore does not need to be included in coverage analysis
+def add_custom_parameters(items):  # pragma: no cover
     # Deprecation warning
     warnings.warn(
-        ("The add_custom_parameters API has been deprecated. " "Please use the add_custom_attributes API."),
+        ("The add_custom_parameters API has been deprecated. Please use the add_custom_attributes API."),
         DeprecationWarning,
     )
     return add_custom_attributes(items)
@@ -1898,6 +1945,44 @@ def record_custom_metrics(metrics, application=None):
         application.record_custom_metrics(metrics)
 
 
+def record_dimensional_metric(name, value, tags=None, application=None):
+    if application is None:
+        transaction = current_transaction()
+        if transaction:
+            transaction.record_dimensional_metric(name, value, tags)
+        else:
+            _logger.debug(
+                "record_dimensional_metric has been called but no "
+                "transaction was running. As a result, the following metric "
+                "has not been recorded. Name: %r Value: %r Tags: %r. To correct this "
+                "problem, supply an application object as a parameter to this "
+                "record_dimensional_metric call.",
+                name,
+                value,
+                tags,
+            )
+    elif application.enabled:
+        application.record_dimensional_metric(name, value, tags)
+
+
+def record_dimensional_metrics(metrics, application=None):
+    if application is None:
+        transaction = current_transaction()
+        if transaction:
+            transaction.record_dimensional_metrics(metrics)
+        else:
+            _logger.debug(
+                "record_dimensional_metrics has been called but no "
+                "transaction was running. As a result, the following metrics "
+                "have not been recorded: %r. To correct this problem, "
+                "supply an application object as a parameter to this "
+                "record_dimensional_metrics call.",
+                list(metrics),
+            )
+    elif application.enabled:
+        application.record_dimensional_metrics(metrics)
+
+
 def record_custom_event(event_type, params, application=None):
     """Record a custom event.
 
@@ -1926,6 +2011,34 @@ def record_custom_event(event_type, params, application=None):
         application.record_custom_event(event_type, params)
 
 
+def record_ml_event(event_type, params, application=None):
+    """Record a machine learning custom event.
+
+    Args:
+        event_type (str): The type (name) of the ml event.
+        params (dict): Attributes to add to the event.
+        application (newrelic.api.Application): Application instance.
+
+    """
+
+    if application is None:
+        transaction = current_transaction()
+        if transaction:
+            transaction.record_ml_event(event_type, params)
+        else:
+            _logger.debug(
+                "record_ml_event has been called but no "
+                "transaction was running. As a result, the following event "
+                "has not been recorded. event_type: %r params: %r. To correct "
+                "this problem, supply an application object as a parameter to "
+                "this record_ml_event call.",
+                event_type,
+                params,
+            )
+    elif application.enabled:
+        application.record_ml_event(event_type, params)
+
+
 def record_log_event(message, level=None, timestamp=None, application=None, priority=None):
     """Record a log event.
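Usage sketch (illustrative; not part of the diff): the transaction.py changes above expose dimensional metrics and ML events as module-level APIs. Assuming the agent is initialized and a transaction is running (otherwise supply an application object, as the log messages above note), usage might look like:

    from newrelic.api.transaction import (
        record_dimensional_metric,
        record_dimensional_metrics,
        record_ml_event,
    )

    # Single dimensional metric; tags may be a dict or an iterable of tuples.
    # Metric name and tag values here are illustrative.
    record_dimensional_metric("Custom/ModelScore", 0.87, tags={"model": "spam_filter"})

    # Batch form; each entry is (name, value) or (name, value, tags).
    record_dimensional_metrics([
        ("Custom/Predictions", 1),
        ("Custom/PredictionLatency", 0.012, {"model": "spam_filter"}),
    ])

    # ML event; gated by the ml_insights_events.enabled setting whose config
    # parsing is added in newrelic/config.py later in this diff.
    record_ml_event("InferenceEvent", {"model_name": "spam_filter", "score": 0.87})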
diff --git a/newrelic/common/agent_http.py b/newrelic/common/agent_http.py
index e9d9a00aa..89876a60c 100644
--- a/newrelic/common/agent_http.py
+++ b/newrelic/common/agent_http.py
@@ -92,6 +92,7 @@ def __init__(
         compression_method="gzip",
         max_payload_size_in_bytes=1000000,
         audit_log_fp=None,
+        default_content_encoding_header="Identity",
     ):
         self._audit_log_fp = audit_log_fp
 
@@ -112,9 +113,7 @@ def _supportability_request(params, payload, body, compression_time):
         pass
 
     @classmethod
-    def log_request(
-        cls, fp, method, url, params, payload, headers, body=None, compression_time=None
-    ):
+    def log_request(cls, fp, method, url, params, payload, headers, body=None, compression_time=None):
         cls._supportability_request(params, payload, body, compression_time)
 
         if not fp:
@@ -126,7 +125,8 @@ def log_request(
         cls.AUDIT_LOG_ID += 1
 
         print(
-            "TIME: %r" % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), file=fp,
+            "TIME: %r" % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
+            file=fp,
         )
         print(file=fp)
         print("ID: %r" % cls.AUDIT_LOG_ID, file=fp)
@@ -178,9 +178,7 @@ def log_response(cls, fp, log_id, status, headers, data, connection="direct"):
             except Exception:
                 result = data
 
-        print(
-            "TIME: %r" % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), file=fp
-        )
+        print("TIME: %r" % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), file=fp)
         print(file=fp)
         print("ID: %r" % log_id, file=fp)
         print(file=fp)
@@ -219,9 +217,7 @@ def send_request(
 class HttpClient(BaseClient):
     CONNECTION_CLS = urllib3.HTTPSConnectionPool
     PREFIX_SCHEME = "https://"
-    BASE_HEADERS = urllib3.make_headers(
-        keep_alive=True, accept_encoding=True, user_agent=USER_AGENT
-    )
+    BASE_HEADERS = urllib3.make_headers(keep_alive=True, accept_encoding=True, user_agent=USER_AGENT)
 
     def __init__(
         self,
@@ -240,6 +236,7 @@ def __init__(
         compression_method="gzip",
         max_payload_size_in_bytes=1000000,
         audit_log_fp=None,
+        default_content_encoding_header="Identity",
     ):
         self._host = host
         port = self._port = port
@@ -248,6 +245,7 @@ def __init__(
         self._compression_method = compression_method
         self._max_payload_size_in_bytes = max_payload_size_in_bytes
         self._audit_log_fp = audit_log_fp
+        self._default_content_encoding_header = default_content_encoding_header
 
         self._prefix = ""
 
@@ -263,11 +261,9 @@ def __init__(
         # If there is no resolved cafile, assume the bundled certs are
         # required and report this condition as a supportability metric.
-            if not verify_path.cafile:
+            if not verify_path.cafile and not verify_path.capath:
                 ca_bundle_path = certs.where()
-                internal_metric(
-                    "Supportability/Python/Certificate/BundleRequired", 1
-                )
+                internal_metric("Supportability/Python/Certificate/BundleRequired", 1)
 
             if ca_bundle_path:
                 if os.path.isdir(ca_bundle_path):
@@ -279,11 +275,13 @@ def __init__(
             connection_kwargs["cert_reqs"] = "NONE"
 
         proxy = self._parse_proxy(
-            proxy_scheme, proxy_host, proxy_port, proxy_user, proxy_pass,
-        )
-        proxy_headers = (
-            proxy and proxy.auth and urllib3.make_headers(proxy_basic_auth=proxy.auth)
+            proxy_scheme,
+            proxy_host,
+            proxy_port,
+            proxy_user,
+            proxy_pass,
         )
+        proxy_headers = proxy and proxy.auth and urllib3.make_headers(proxy_basic_auth=proxy.auth)
 
         if proxy:
             if self.CONNECTION_CLS.scheme == "https" and proxy.scheme != "https":
@@ -343,15 +341,9 @@ def _connection(self):
         if self._connection_attr:
             return self._connection_attr
 
-        retries = urllib3.Retry(
-            total=False, connect=None, read=None, redirect=0, status=None
-        )
+        retries = urllib3.Retry(total=False, connect=None, read=None, redirect=0, status=None)
         self._connection_attr = self.CONNECTION_CLS(
-            self._host,
-            self._port,
-            strict=True,
-            retries=retries,
-            **self._connection_kwargs
+            self._host, self._port, strict=True, retries=retries, **self._connection_kwargs
         )
         return self._connection_attr
 
@@ -374,9 +366,7 @@ def log_request(
         if not self._prefix:
             url = self.CONNECTION_CLS.scheme + "://" + self._host + url
 
-        return super(HttpClient, self).log_request(
-            fp, method, url, params, payload, headers, body, compression_time
-        )
+        return super(HttpClient, self).log_request(fp, method, url, params, payload, headers, body, compression_time)
 
     @staticmethod
     def _compress(data, method="gzip", level=None):
@@ -419,11 +409,9 @@ def send_request(
                 method=self._compression_method,
                 level=self._compression_level,
             )
-            content_encoding = self._compression_method
-        else:
-            content_encoding = "Identity"
-
-        merged_headers["Content-Encoding"] = content_encoding
+            merged_headers["Content-Encoding"] = self._compression_method
+        elif self._default_content_encoding_header:
+            merged_headers["Content-Encoding"] = self._default_content_encoding_header
 
         request_id = self.log_request(
             self._audit_log_fp,
@@ -441,16 +429,16 @@ def send_request(
 
         try:
             response = self._connection.request_encode_url(
-                method,
-                path,
-                fields=params,
-                body=body,
-                headers=merged_headers,
-                **self._urlopen_kwargs
+                method, path, fields=params, body=body, headers=merged_headers, **self._urlopen_kwargs
             )
         except urllib3.exceptions.HTTPError as e:
             self.log_response(
-                self._audit_log_fp, request_id, 0, None, None, connection,
+                self._audit_log_fp,
+                request_id,
+                0,
+                None,
+                None,
+                connection,
             )
             # All urllib3 HTTP errors should be treated as a network
             # interface exception.
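Condensed restatement (illustrative; not part of the diff) of the new Content-Encoding selection in HttpClient.send_request above; `payload_was_compressed` is a stand-in name, the other names are from the diff:

    # If the payload was compressed, advertise the compression method used.
    if payload_was_compressed:
        merged_headers["Content-Encoding"] = self._compression_method  # e.g. "gzip"
    # Otherwise fall back to the configurable default ("Identity" unless overridden).
    elif self._default_content_encoding_header:
        merged_headers["Content-Encoding"] = self._default_content_encoding_header
    # Constructing the client with default_content_encoding_header=None now
    # omits the header entirely for uncompressed payloads, which the old
    # hard-coded "Identity" fallback did not allow.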
@@ -489,6 +477,7 @@ def __init__(
         compression_method="gzip",
         max_payload_size_in_bytes=1000000,
         audit_log_fp=None,
+        default_content_encoding_header="Identity",
     ):
         proxy = self._parse_proxy(proxy_scheme, proxy_host, None, None, None)
         if proxy and proxy.scheme == "https":
@@ -515,6 +504,7 @@ def __init__(
             compression_method,
             max_payload_size_in_bytes,
             audit_log_fp,
+            default_content_encoding_header,
         )
 
 
@@ -536,9 +526,7 @@ def _supportability_request(params, payload, body, compression_time):
                 "Supportability/Python/Collector/%s/ZLIB/Bytes" % agent_method,
                 len(body),
             )
-            internal_metric(
-                "Supportability/Python/Collector/ZLIB/Bytes", len(body)
-            )
+            internal_metric("Supportability/Python/Collector/ZLIB/Bytes", len(body))
             internal_metric(
                 "Supportability/Python/Collector/%s/ZLIB/Compress" % agent_method,
                 compression_time,
@@ -548,28 +536,21 @@ def _supportability_request(params, payload, body, compression_time):
             len(payload),
         )
         # Top level metric to aggregate overall bytes being sent
-        internal_metric(
-            "Supportability/Python/Collector/Output/Bytes", len(payload)
-        )
+        internal_metric("Supportability/Python/Collector/Output/Bytes", len(payload))
 
     @staticmethod
     def _supportability_response(status, exc, connection="direct"):
         if exc or not 200 <= status < 300:
             internal_count_metric("Supportability/Python/Collector/Failures", 1)
-            internal_count_metric(
-                "Supportability/Python/Collector/Failures/%s" % connection, 1
-            )
+            internal_count_metric("Supportability/Python/Collector/Failures/%s" % connection, 1)
 
             if exc:
                 internal_count_metric(
-                    "Supportability/Python/Collector/Exception/"
-                    "%s" % callable_name(exc),
+                    "Supportability/Python/Collector/Exception/" "%s" % callable_name(exc),
                     1,
                 )
             else:
-                internal_count_metric(
-                    "Supportability/Python/Collector/HTTPError/%d" % status, 1
-                )
+                internal_count_metric("Supportability/Python/Collector/HTTPError/%d" % status, 1)
 
 
 class ApplicationModeClient(SupportabilityMixin, HttpClient):
@@ -578,33 +559,31 @@ class ApplicationModeClient(SupportabilityMixin, HttpClient):
 
 class DeveloperModeClient(SupportabilityMixin, BaseClient):
     RESPONSES = {
-        "preconnect": {u"redirect_host": u"fake-collector.newrelic.com"},
+        "preconnect": {"redirect_host": "fake-collector.newrelic.com"},
         "agent_settings": [],
         "connect": {
-            u"js_agent_loader": u"",
-            u"js_agent_file": u"fake-js-agent.newrelic.com/nr-0.min.js",
-            u"browser_key": u"1234567890",
-            u"browser_monitoring.loader_version": u"0",
-            u"beacon": u"fake-beacon.newrelic.com",
-            u"error_beacon": u"fake-jserror.newrelic.com",
-            u"apdex_t": 0.5,
-            u"encoding_key": u"1111111111111111111111111111111111111111",
-            u"entity_guid": u"DEVELOPERMODEENTITYGUID",
-            u"agent_run_id": u"1234567",
-            u"product_level": 50,
-            u"trusted_account_ids": [12345],
-            u"trusted_account_key": u"12345",
-            u"url_rules": [],
-            u"collect_errors": True,
-            u"account_id": u"12345",
-            u"cross_process_id": u"12345#67890",
-            u"messages": [
-                {u"message": u"Reporting to fake collector", u"level": u"INFO"}
-            ],
-            u"sampling_rate": 0,
-            u"collect_traces": True,
-            u"collect_span_events": True,
-            u"data_report_period": 60,
+            "js_agent_loader": "",
+            "js_agent_file": "fake-js-agent.newrelic.com/nr-0.min.js",
+            "browser_key": "1234567890",
+            "browser_monitoring.loader_version": "0",
+            "beacon": "fake-beacon.newrelic.com",
+            "error_beacon": "fake-jserror.newrelic.com",
+            "apdex_t": 0.5,
+            "encoding_key": "1111111111111111111111111111111111111111",
+            "entity_guid": "DEVELOPERMODEENTITYGUID",
+            "agent_run_id": "1234567",
+            "product_level": 50,
+            "trusted_account_ids": [12345],
+            "trusted_account_key": "12345",
+            "url_rules": [],
+            "collect_errors": True,
+            "account_id": "12345",
+            "cross_process_id": "12345#67890",
+            "messages": [{"message": "Reporting to fake collector", "level": "INFO"}],
+            "sampling_rate": 0,
+            "collect_traces": True,
+            "collect_span_events": True,
+            "data_report_period": 60,
         },
         "metric_data": None,
         "get_agent_commands": [],
@@ -648,7 +627,11 @@ def send_request(
             payload = {"return_value": result}
             response_data = json_encode(payload).encode("utf-8")
             self.log_response(
-                self._audit_log_fp, request_id, 200, {}, response_data,
+                self._audit_log_fp,
+                request_id,
+                200,
+                {},
+                response_data,
             )
             return 200, response_data
 
diff --git a/newrelic/common/async_wrapper.py b/newrelic/common/async_wrapper.py
index c5f95308d..2d3db2b4b 100644
--- a/newrelic/common/async_wrapper.py
+++ b/newrelic/common/async_wrapper.py
@@ -18,7 +18,9 @@
     is_coroutine_callable,
     is_asyncio_coroutine,
     is_generator_function,
+    is_async_generator_function,
 )
+from newrelic.packages import six
 
 
 def evaluate_wrapper(wrapper_string, wrapped, trace):
@@ -29,7 +31,6 @@ def evaluate_wrapper(wrapper_string, wrapped, trace):
 
 
 def coroutine_wrapper(wrapped, trace):
-
     WRAPPER = textwrap.dedent("""
     @functools.wraps(wrapped)
     async def wrapper(*args, **kwargs):
@@ -61,29 +62,76 @@ def wrapper(*args, **kwargs):
         return wrapped
 
 
-def generator_wrapper(wrapped, trace):
-    @functools.wraps(wrapped)
-    def wrapper(*args, **kwargs):
-        g = wrapped(*args, **kwargs)
-        value = None
-        with trace:
-            while True:
+if six.PY3:
+    def generator_wrapper(wrapped, trace):
+        WRAPPER = textwrap.dedent("""
+        @functools.wraps(wrapped)
+        def wrapper(*args, **kwargs):
+            with trace:
+                result = yield from wrapped(*args, **kwargs)
+                return result
+        """)
+
+        try:
+            return evaluate_wrapper(WRAPPER, wrapped, trace)
+        except:
+            return wrapped
+else:
+    def generator_wrapper(wrapped, trace):
+        @functools.wraps(wrapped)
+        def wrapper(*args, **kwargs):
+            g = wrapped(*args, **kwargs)
+            with trace:
                 try:
-                    yielded = g.send(value)
+                    yielded = g.send(None)
+                    while True:
+                        try:
+                            sent = yield yielded
+                        except GeneratorExit as e:
+                            g.close()
+                            raise
+                        except BaseException as e:
+                            yielded = g.throw(e)
+                        else:
+                            yielded = g.send(sent)
                 except StopIteration:
-                    break
+                    return
 
-                try:
-                    value = yield yielded
-                except BaseException as e:
-                    value = yield g.throw(type(e), e)
         return wrapper
 
 
+def async_generator_wrapper(wrapped, trace):
+    WRAPPER = textwrap.dedent("""
+    @functools.wraps(wrapped)
+    async def wrapper(*args, **kwargs):
+        g = wrapped(*args, **kwargs)
+        with trace:
+            try:
+                yielded = await g.asend(None)
+                while True:
+                    try:
+                        sent = yield yielded
+                    except GeneratorExit as e:
+                        await g.aclose()
+                        raise
+                    except BaseException as e:
+                        yielded = await g.athrow(e)
+                    else:
+                        yielded = await g.asend(sent)
+            except StopAsyncIteration:
+                return
+    """)
+
+    try:
+        return evaluate_wrapper(WRAPPER, wrapped, trace)
+    except:
+        return wrapped
 
 
 def async_wrapper(wrapped):
     if is_coroutine_callable(wrapped):
         return coroutine_wrapper
+    elif is_async_generator_function(wrapped):
+        return async_generator_wrapper
     elif is_generator_function(wrapped):
         if is_asyncio_coroutine(wrapped):
             return awaitable_generator_wrapper
diff --git a/newrelic/common/coroutine.py b/newrelic/common/coroutine.py
index cf4c91f85..33a4922f5 100644
--- a/newrelic/common/coroutine.py
+++ b/newrelic/common/coroutine.py
@@ -43,3 +43,11 @@ def _iscoroutinefunction_tornado(fn):
 
 def is_coroutine_callable(wrapped):
     return is_coroutine_function(wrapped) or is_coroutine_function(getattr(wrapped, "__call__", None))
+
+
+if hasattr(inspect, 'isasyncgenfunction'):
+    def is_async_generator_function(wrapped):
+        return inspect.isasyncgenfunction(wrapped)
+else:
+    def is_async_generator_function(wrapped):
+        return False
diff --git a/newrelic/common/metric_utils.py b/newrelic/common/metric_utils.py
new file mode 100644
index 000000000..ebffe8332
--- /dev/null
+++ b/newrelic/common/metric_utils.py
@@ -0,0 +1,35 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This module implements functions for creating a unique identity from a name and set of tags for use in dimensional metrics.
+"""
+
+from newrelic.core.attribute import process_user_attribute
+
+
+def create_metric_identity(name, tags=None):
+    if tags:
+        # Convert dicts to an iterable of tuples, other iterables should already be in this form
+        if isinstance(tags, dict):
+            tags = tags.items()
+
+        # Apply attribute system sanitization.
+        # process_user_attribute returns (None, None) for results that fail sanitization.
+        # The filter removes these results from the iterable before creating the frozenset.
+        tags = frozenset(filter(lambda args: args[0] is not None, map(lambda args: process_user_attribute(*args), tags)))
+
+    tags = tags or None  # Set empty iterables after filtering to None
+
+    return (name, tags)
diff --git a/newrelic/common/package_version_utils.py b/newrelic/common/package_version_utils.py
index 13b816878..aa736a94e 100644
--- a/newrelic/common/package_version_utils.py
+++ b/newrelic/common/package_version_utils.py
@@ -13,6 +13,45 @@
 # limitations under the License.
 
 import sys
+import warnings
+
+try:
+    from functools import cache as _cache_package_versions
+except ImportError:
+    from functools import wraps
+    from threading import Lock
+
+    _package_version_cache = {}
+    _package_version_cache_lock = Lock()
+
+    def _cache_package_versions(wrapped):
+        """
+        Threadsafe implementation of caching for _get_package_version.
+
+        Python 2.7 does not have the @functools.cache decorator, so
+        caching must be reimplemented with support for clearing the cache.
+ """ + + @wraps(wrapped) + def _wrapper(name): + if name in _package_version_cache: + return _package_version_cache[name] + + with _package_version_cache_lock: + if name in _package_version_cache: + return _package_version_cache[name] + + version = _package_version_cache[name] = wrapped(name) + return version + + def cache_clear(): + """Cache clear function to mimic @functools.cache""" + with _package_version_cache_lock: + _package_version_cache.clear() + + _wrapper.cache_clear = cache_clear + return _wrapper + # Need to account for 4 possible variations of version declaration specified in (rejected) PEP 396 VERSION_ATTRS = ("__version__", "version", "__version_tuple__", "version_tuple") # nosec @@ -67,23 +106,39 @@ def int_or_str(value): return version +@_cache_package_versions def _get_package_version(name): module = sys.modules.get(name, None) version = None - for attr in VERSION_ATTRS: - try: - version = getattr(module, attr, None) - # Cast any version specified as a list into a tuple. - version = tuple(version) if isinstance(version, list) else version - if version not in NULL_VERSIONS: - return version - except Exception: - pass + + with warnings.catch_warnings(record=True): + for attr in VERSION_ATTRS: + try: + version = getattr(module, attr, None) + + # In certain cases like importlib_metadata.version, version is a callable + # function. + if callable(version): + continue + + # Cast any version specified as a list into a tuple. + version = tuple(version) if isinstance(version, list) else version + if version not in NULL_VERSIONS: + return version + except Exception: + pass # importlib was introduced into the standard library starting in Python3.8. if "importlib" in sys.modules and hasattr(sys.modules["importlib"], "metadata"): try: - version = sys.modules["importlib"].metadata.version(name) # pylint: disable=E1101 + # In Python3.10+ packages_distribution can be checked for as well + if hasattr(sys.modules["importlib"].metadata, "packages_distributions"): # pylint: disable=E1101 + distributions = sys.modules["importlib"].metadata.packages_distributions() # pylint: disable=E1101 + distribution_name = distributions.get(name, name) + else: + distribution_name = name + + version = sys.modules["importlib"].metadata.version(distribution_name) # pylint: disable=E1101 if version not in NULL_VERSIONS: return version except Exception: @@ -95,4 +150,4 @@ def _get_package_version(name): if version not in NULL_VERSIONS: return version except Exception: - pass \ No newline at end of file + pass diff --git a/newrelic/config.py b/newrelic/config.py index dc0093c42..608d59fc3 100644 --- a/newrelic/config.py +++ b/newrelic/config.py @@ -102,6 +102,14 @@ _cache_object = [] + +def _reset_config_parser(): + global _config_object + global _cache_object + _config_object = ConfigParser.RawConfigParser() + _cache_object = [] + + # Mechanism for extracting settings from the configuration for use in # instrumentation modules and extensions. 
@@ -320,6 +328,8 @@ def _process_configuration(section):
     _process_setting(section, "api_key", "get", None)
     _process_setting(section, "host", "get", None)
     _process_setting(section, "port", "getint", None)
+    _process_setting(section, "otlp_host", "get", None)
+    _process_setting(section, "otlp_port", "getint", None)
     _process_setting(section, "ssl", "getboolean", None)
     _process_setting(section, "proxy_scheme", "get", None)
     _process_setting(section, "proxy_host", "get", None)
@@ -440,6 +450,7 @@ def _process_configuration(section):
     )
     _process_setting(section, "custom_insights_events.enabled", "getboolean", None)
     _process_setting(section, "custom_insights_events.max_samples_stored", "getint", None)
+    _process_setting(section, "ml_insights_events.enabled", "getboolean", None)
     _process_setting(section, "distributed_tracing.enabled", "getboolean", None)
     _process_setting(section, "distributed_tracing.exclude_newrelic_header", "getboolean", None)
     _process_setting(section, "span_events.enabled", "getboolean", None)
@@ -499,6 +510,7 @@ def _process_configuration(section):
     _process_setting(section, "debug.disable_certificate_validation", "getboolean", None)
     _process_setting(section, "debug.disable_harvest_until_shutdown", "getboolean", None)
     _process_setting(section, "debug.connect_span_stream_in_developer_mode", "getboolean", None)
+    _process_setting(section, "debug.otlp_content_encoding", "get", None)
     _process_setting(section, "cross_application_tracer.enabled", "getboolean", None)
     _process_setting(section, "message_tracer.segment_parameters_enabled", "getboolean", None)
     _process_setting(section, "process_host.display_name", "get", None)
@@ -533,6 +545,7 @@ def _process_configuration(section):
         None,
     )
     _process_setting(section, "event_harvest_config.harvest_limits.custom_event_data", "getint", None)
+    _process_setting(section, "event_harvest_config.harvest_limits.ml_event_data", "getint", None)
     _process_setting(section, "event_harvest_config.harvest_limits.span_event_data", "getint", None)
     _process_setting(section, "event_harvest_config.harvest_limits.error_event_data", "getint", None)
     _process_setting(section, "event_harvest_config.harvest_limits.log_event_data", "getint", None)
@@ -549,6 +562,9 @@ def _process_configuration(section):
     _process_setting(section, "application_logging.metrics.enabled", "getboolean", None)
     _process_setting(section, "application_logging.local_decorating.enabled", "getboolean", None)
 
+    _process_setting(section, "machine_learning.enabled", "getboolean", None)
+    _process_setting(section, "machine_learning.inference_events_value.enabled", "getboolean", None)
+
 
 # Loading of configuration from specified file and for specified
 # deployment environment. Can also indicate whether configuration
@@ -557,6 +573,11 @@
 _configuration_done = False
 
 
+def _reset_configuration_done():
+    global _configuration_done
+    _configuration_done = False
+
+
 def _process_app_name_setting():
     # Do special processing to handle the case where the application
     # name was actually a semicolon separated list of names. In this
@@ -875,6 +896,10 @@ def apply_local_high_security_mode_setting(settings):
         settings.custom_insights_events.enabled = False
         _logger.info(log_template, "custom_insights_events.enabled", True, False)
 
+    if settings.ml_insights_events.enabled:
+        settings.ml_insights_events.enabled = False
+        _logger.info(log_template, "ml_insights_events.enabled", True, False)
+
     if settings.message_tracer.segment_parameters_enabled:
         settings.message_tracer.segment_parameters_enabled = False
         _logger.info(log_template, "message_tracer.segment_parameters_enabled", True, False)
@@ -883,6 +908,10 @@ def apply_local_high_security_mode_setting(settings):
         settings.application_logging.forwarding.enabled = False
         _logger.info(log_template, "application_logging.forwarding.enabled", True, False)
 
+    if settings.machine_learning.inference_events_value.enabled:
+        settings.machine_learning.inference_events_value.enabled = False
+        _logger.info(log_template, "machine_learning.inference_events_value.enabled", True, False)
+
     return settings
 
 
@@ -1245,7 +1274,6 @@ def _process_wsgi_application_configuration():
     for section in _config_object.sections():
         if not section.startswith("wsgi-application:"):
             continue
-
         enabled = False
 
         try:
@@ -2264,6 +2292,87 @@ def _process_module_builtin_defaults():
         "instrument_graphql_validate",
     )
 
+    _process_module_definition(
+        "google.cloud.firestore_v1.base_client",
+        "newrelic.hooks.datastore_firestore",
+        "instrument_google_cloud_firestore_v1_base_client",
+    )
+    _process_module_definition(
+        "google.cloud.firestore_v1.client",
+        "newrelic.hooks.datastore_firestore",
+        "instrument_google_cloud_firestore_v1_client",
+    )
+    _process_module_definition(
+        "google.cloud.firestore_v1.async_client",
+        "newrelic.hooks.datastore_firestore",
+        "instrument_google_cloud_firestore_v1_async_client",
+    )
+    _process_module_definition(
+        "google.cloud.firestore_v1.document",
+        "newrelic.hooks.datastore_firestore",
+        "instrument_google_cloud_firestore_v1_document",
+    )
+    _process_module_definition(
+        "google.cloud.firestore_v1.async_document",
+        "newrelic.hooks.datastore_firestore",
+        "instrument_google_cloud_firestore_v1_async_document",
+    )
+    _process_module_definition(
+        "google.cloud.firestore_v1.collection",
+        "newrelic.hooks.datastore_firestore",
+        "instrument_google_cloud_firestore_v1_collection",
+    )
+    _process_module_definition(
+        "google.cloud.firestore_v1.async_collection",
+        "newrelic.hooks.datastore_firestore",
+        "instrument_google_cloud_firestore_v1_async_collection",
+    )
+    _process_module_definition(
+        "google.cloud.firestore_v1.query",
+        "newrelic.hooks.datastore_firestore",
+        "instrument_google_cloud_firestore_v1_query",
+    )
+    _process_module_definition(
+        "google.cloud.firestore_v1.async_query",
+        "newrelic.hooks.datastore_firestore",
+        "instrument_google_cloud_firestore_v1_async_query",
+    )
+    _process_module_definition(
+        "google.cloud.firestore_v1.aggregation",
+        "newrelic.hooks.datastore_firestore",
+        "instrument_google_cloud_firestore_v1_aggregation",
+    )
+    _process_module_definition(
+        "google.cloud.firestore_v1.async_aggregation",
+        "newrelic.hooks.datastore_firestore",
+        "instrument_google_cloud_firestore_v1_async_aggregation",
+    )
+    _process_module_definition(
+        "google.cloud.firestore_v1.batch",
+        "newrelic.hooks.datastore_firestore",
+        "instrument_google_cloud_firestore_v1_batch",
+    )
+    _process_module_definition(
+        "google.cloud.firestore_v1.async_batch",
+        "newrelic.hooks.datastore_firestore",
+        "instrument_google_cloud_firestore_v1_async_batch",
+    )
+    _process_module_definition(
"google.cloud.firestore_v1.bulk_batch", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_bulk_batch", + ) + _process_module_definition( + "google.cloud.firestore_v1.transaction", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_transaction", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_transaction", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_transaction", + ) + _process_module_definition( "ariadne.asgi", "newrelic.hooks.framework_ariadne", @@ -2346,11 +2455,6 @@ def _process_module_builtin_defaults(): "newrelic.hooks.messagebroker_kafkapython", "instrument_kafka_heartbeat", ) - _process_module_definition( - "kafka.consumer.group", - "newrelic.hooks.messagebroker_kafkapython", - "instrument_kafka_consumer_group", - ) _process_module_definition( "logging", @@ -2368,6 +2472,11 @@ def _process_module_builtin_defaults(): "newrelic.hooks.logger_loguru", "instrument_loguru_logger", ) + _process_module_definition( + "structlog._base", + "newrelic.hooks.logger_structlog", + "instrument_structlog__base", + ) _process_module_definition( "paste.httpserver", @@ -2669,18 +2778,6 @@ def _process_module_builtin_defaults(): "aioredis.connection", "newrelic.hooks.datastore_aioredis", "instrument_aioredis_connection" ) - _process_module_definition( - "redis.asyncio.client", "newrelic.hooks.datastore_aioredis", "instrument_aioredis_client" - ) - - _process_module_definition( - "redis.asyncio.commands", "newrelic.hooks.datastore_aioredis", "instrument_aioredis_client" - ) - - _process_module_definition( - "redis.asyncio.connection", "newrelic.hooks.datastore_aioredis", "instrument_aioredis_connection" - ) - # v7 and below _process_module_definition( "elasticsearch.client", @@ -2837,6 +2934,21 @@ def _process_module_builtin_defaults(): "instrument_pymongo_collection", ) + # Redis v4.2+ + _process_module_definition( + "redis.asyncio.client", "newrelic.hooks.datastore_redis", "instrument_asyncio_redis_client" + ) + + # Redis v4.2+ + _process_module_definition( + "redis.asyncio.commands", "newrelic.hooks.datastore_redis", "instrument_asyncio_redis_client" + ) + + # Redis v4.2+ + _process_module_definition( + "redis.asyncio.connection", "newrelic.hooks.datastore_redis", "instrument_asyncio_redis_connection" + ) + _process_module_definition( "redis.connection", "newrelic.hooks.datastore_redis", @@ -2844,6 +2956,10 @@ def _process_module_builtin_defaults(): ) _process_module_definition("redis.client", "newrelic.hooks.datastore_redis", "instrument_redis_client") + _process_module_definition( + "redis.commands.cluster", "newrelic.hooks.datastore_redis", "instrument_redis_commands_cluster" + ) + _process_module_definition( "redis.commands.core", "newrelic.hooks.datastore_redis", "instrument_redis_commands_core" ) @@ -2890,6 +3006,756 @@ def _process_module_builtin_defaults(): ) _process_module_definition("tastypie.api", "newrelic.hooks.component_tastypie", "instrument_tastypie_api") + _process_module_definition( + "sklearn.metrics", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_metrics", + ) + + _process_module_definition( + "sklearn.tree._classes", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_tree_models", + ) + # In scikit-learn < 0.21 the model classes are in tree.py instead of _classes.py. 
+ _process_module_definition( + "sklearn.tree.tree", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_tree_models", + ) + + _process_module_definition( + "sklearn.compose._column_transformer", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_compose_models", + ) + + _process_module_definition( + "sklearn.compose._target", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_compose_models", + ) + + _process_module_definition( + "sklearn.covariance._empirical_covariance", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_covariance_models", + ) + + _process_module_definition( + "sklearn.covariance.empirical_covariance_", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_covariance_models", + ) + + _process_module_definition( + "sklearn.covariance.shrunk_covariance_", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_covariance_shrunk_models", + ) + + _process_module_definition( + "sklearn.covariance._shrunk_covariance", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_covariance_shrunk_models", + ) + + _process_module_definition( + "sklearn.covariance.robust_covariance_", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_covariance_models", + ) + + _process_module_definition( + "sklearn.covariance._robust_covariance", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_covariance_models", + ) + + _process_module_definition( + "sklearn.covariance.graph_lasso_", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_covariance_graph_models", + ) + + _process_module_definition( + "sklearn.covariance._graph_lasso", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_covariance_graph_models", + ) + + _process_module_definition( + "sklearn.covariance.elliptic_envelope", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_covariance_models", + ) + + _process_module_definition( + "sklearn.covariance._elliptic_envelope", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_covariance_models", + ) + + _process_module_definition( + "sklearn.ensemble._bagging", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_ensemble_bagging_models", + ) + + _process_module_definition( + "sklearn.ensemble.bagging", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_ensemble_bagging_models", + ) + + _process_module_definition( + "sklearn.ensemble._forest", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_ensemble_forest_models", + ) + + _process_module_definition( + "sklearn.ensemble.forest", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_ensemble_forest_models", + ) + + _process_module_definition( + "sklearn.ensemble._iforest", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_ensemble_iforest_models", + ) + + _process_module_definition( + "sklearn.ensemble.iforest", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_ensemble_iforest_models", + ) + + _process_module_definition( + "sklearn.ensemble._weight_boosting", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_ensemble_weight_boosting_models", + ) + + _process_module_definition( + "sklearn.ensemble.weight_boosting", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_ensemble_weight_boosting_models", + ) + + _process_module_definition( + "sklearn.ensemble._gb", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_ensemble_gradient_boosting_models", + ) + + _process_module_definition( + "sklearn.ensemble.gradient_boosting", + "newrelic.hooks.mlmodel_sklearn", + 
"instrument_sklearn_ensemble_gradient_boosting_models", + ) + + _process_module_definition( + "sklearn.ensemble._voting", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_ensemble_voting_models", + ) + + _process_module_definition( + "sklearn.ensemble.voting_classifier", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_ensemble_voting_models", + ) + + _process_module_definition( + "sklearn.ensemble._stacking", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_ensemble_stacking_models", + ) + + _process_module_definition( + "sklearn.ensemble._hist_gradient_boosting.gradient_boosting", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_ensemble_hist_models", + ) + + _process_module_definition( + "sklearn.linear_model._base", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_models", + ) + + _process_module_definition( + "sklearn.linear_model.base", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_models", + ) + + _process_module_definition( + "sklearn.linear_model._bayes", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_bayes_models", + ) + + _process_module_definition( + "sklearn.linear_model.bayes", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_bayes_models", + ) + + _process_module_definition( + "sklearn.linear_model._least_angle", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_least_angle_models", + ) + + _process_module_definition( + "sklearn.linear_model.least_angle", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_least_angle_models", + ) + + _process_module_definition( + "sklearn.linear_model.coordinate_descent", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_coordinate_descent_models", + ) + + _process_module_definition( + "sklearn.linear_model._coordinate_descent", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_coordinate_descent_models", + ) + + _process_module_definition( + "sklearn.linear_model._glm", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_GLM_models", + ) + + _process_module_definition( + "sklearn.linear_model._huber", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_models", + ) + + _process_module_definition( + "sklearn.linear_model.huber", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_models", + ) + + _process_module_definition( + "sklearn.linear_model._stochastic_gradient", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_stochastic_gradient_models", + ) + + _process_module_definition( + "sklearn.linear_model.stochastic_gradient", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_stochastic_gradient_models", + ) + + _process_module_definition( + "sklearn.linear_model._ridge", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_ridge_models", + ) + + _process_module_definition( + "sklearn.linear_model.ridge", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_ridge_models", + ) + + _process_module_definition( + "sklearn.linear_model._logistic", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_logistic_models", + ) + + _process_module_definition( + "sklearn.linear_model.logistic", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_logistic_models", + ) + + _process_module_definition( + "sklearn.linear_model._omp", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_OMP_models", + ) + + _process_module_definition( + "sklearn.linear_model.omp", + 
"newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_OMP_models", + ) + + _process_module_definition( + "sklearn.linear_model._passive_aggressive", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_passive_aggressive_models", + ) + + _process_module_definition( + "sklearn.linear_model.passive_aggressive", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_passive_aggressive_models", + ) + + _process_module_definition( + "sklearn.linear_model._perceptron", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_models", + ) + + _process_module_definition( + "sklearn.linear_model.perceptron", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_models", + ) + + _process_module_definition( + "sklearn.linear_model._quantile", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_models", + ) + + _process_module_definition( + "sklearn.linear_model._ransac", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_models", + ) + + _process_module_definition( + "sklearn.linear_model.ransac", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_models", + ) + + _process_module_definition( + "sklearn.linear_model._theil_sen", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_models", + ) + + _process_module_definition( + "sklearn.linear_model.theil_sen", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_linear_models", + ) + + _process_module_definition( + "sklearn.cross_decomposition._pls", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cross_decomposition_models", + ) + + _process_module_definition( + "sklearn.cross_decomposition.pls_", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cross_decomposition_models", + ) + + _process_module_definition( + "sklearn.discriminant_analysis", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_discriminant_analysis_models", + ) + + _process_module_definition( + "sklearn.gaussian_process._gpc", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_gaussian_process_models", + ) + + _process_module_definition( + "sklearn.gaussian_process.gpc", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_gaussian_process_models", + ) + + _process_module_definition( + "sklearn.gaussian_process._gpr", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_gaussian_process_models", + ) + + _process_module_definition( + "sklearn.gaussian_process.gpr", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_gaussian_process_models", + ) + + _process_module_definition( + "sklearn.dummy", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_dummy_models", + ) + + _process_module_definition( + "sklearn.feature_selection._rfe", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_feature_selection_rfe_models", + ) + + _process_module_definition( + "sklearn.feature_selection.rfe", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_feature_selection_rfe_models", + ) + + _process_module_definition( + "sklearn.feature_selection._variance_threshold", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_feature_selection_models", + ) + + _process_module_definition( + "sklearn.feature_selection.variance_threshold", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_feature_selection_models", + ) + + _process_module_definition( + "sklearn.feature_selection._from_model", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_feature_selection_models", + ) + + _process_module_definition( + 
"sklearn.feature_selection.from_model", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_feature_selection_models", + ) + + _process_module_definition( + "sklearn.feature_selection._sequential", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_feature_selection_models", + ) + + _process_module_definition( + "sklearn.kernel_ridge", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_kernel_ridge_models", + ) + + _process_module_definition( + "sklearn.neural_network._multilayer_perceptron", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neural_network_models", + ) + + _process_module_definition( + "sklearn.neural_network.multilayer_perceptron", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neural_network_models", + ) + + _process_module_definition( + "sklearn.neural_network._rbm", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neural_network_models", + ) + + _process_module_definition( + "sklearn.neural_network.rbm", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neural_network_models", + ) + + _process_module_definition( + "sklearn.calibration", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_calibration_models", + ) + + _process_module_definition( + "sklearn.cluster._affinity_propagation", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_models", + ) + + _process_module_definition( + "sklearn.cluster.affinity_propagation_", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_models", + ) + + _process_module_definition( + "sklearn.cluster._agglomerative", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_agglomerative_models", + ) + + _process_module_definition( + "sklearn.cluster.hierarchical", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_agglomerative_models", + ) + + _process_module_definition( + "sklearn.cluster._birch", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_models", + ) + + _process_module_definition( + "sklearn.cluster.birch", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_models", + ) + + _process_module_definition( + "sklearn.cluster._bisect_k_means", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_kmeans_models", + ) + + _process_module_definition( + "sklearn.cluster._dbscan", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_models", + ) + + _process_module_definition( + "sklearn.cluster.dbscan_", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_models", + ) + + _process_module_definition( + "sklearn.cluster._feature_agglomeration", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_models", + ) + + _process_module_definition( + "sklearn.cluster._kmeans", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_kmeans_models", + ) + + _process_module_definition( + "sklearn.cluster.k_means_", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_kmeans_models", + ) + + _process_module_definition( + "sklearn.cluster._mean_shift", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_models", + ) + + _process_module_definition( + "sklearn.cluster.mean_shift_", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_models", + ) + + _process_module_definition( + "sklearn.cluster._optics", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_models", + ) + + _process_module_definition( + "sklearn.cluster._spectral", + "newrelic.hooks.mlmodel_sklearn", + 
"instrument_sklearn_cluster_clustering_models", + ) + + _process_module_definition( + "sklearn.cluster.spectral", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_clustering_models", + ) + + _process_module_definition( + "sklearn.cluster._bicluster", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_clustering_models", + ) + + _process_module_definition( + "sklearn.cluster.bicluster", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_cluster_clustering_models", + ) + + _process_module_definition( + "sklearn.multiclass", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_multiclass_models", + ) + + _process_module_definition( + "sklearn.multioutput", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_multioutput_models", + ) + + _process_module_definition( + "sklearn.naive_bayes", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_naive_bayes_models", + ) + + _process_module_definition( + "sklearn.model_selection._search", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_model_selection_models", + ) + + _process_module_definition( + "sklearn.mixture._bayesian_mixture", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_mixture_models", + ) + + _process_module_definition( + "sklearn.mixture.bayesian_mixture", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_mixture_models", + ) + + _process_module_definition( + "sklearn.mixture._gaussian_mixture", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_mixture_models", + ) + + _process_module_definition( + "sklearn.mixture.gaussian_mixture", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_mixture_models", + ) + + _process_module_definition( + "sklearn.pipeline", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_pipeline_models", + ) + + _process_module_definition( + "sklearn.semi_supervised._label_propagation", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_semi_supervised_models", + ) + + _process_module_definition( + "sklearn.semi_supervised._self_training", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_semi_supervised_models", + ) + + _process_module_definition( + "sklearn.semi_supervised.label_propagation", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_semi_supervised_models", + ) + + _process_module_definition( + "sklearn.svm._classes", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_svm_models", + ) + + _process_module_definition( + "sklearn.svm.classes", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_svm_models", + ) + + _process_module_definition( + "sklearn.neighbors._classification", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neighbors_KRadius_models", + ) + + _process_module_definition( + "sklearn.neighbors.classification", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neighbors_KRadius_models", + ) + + _process_module_definition( + "sklearn.neighbors._graph", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neighbors_KRadius_models", + ) + + _process_module_definition( + "sklearn.neighbors._kde", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neighbors_models", + ) + + _process_module_definition( + "sklearn.neighbors.kde", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neighbors_models", + ) + + _process_module_definition( + "sklearn.neighbors._lof", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neighbors_models", + ) + + _process_module_definition( + "sklearn.neighbors.lof", + 
"newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neighbors_models", + ) + + _process_module_definition( + "sklearn.neighbors._nca", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neighbors_models", + ) + + _process_module_definition( + "sklearn.neighbors._nearest_centroid", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neighbors_models", + ) + + _process_module_definition( + "sklearn.neighbors.nearest_centroid", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neighbors_models", + ) + + _process_module_definition( + "sklearn.neighbors._regression", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neighbors_KRadius_models", + ) + + _process_module_definition( + "sklearn.neighbors.regression", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neighbors_KRadius_models", + ) + + _process_module_definition( + "sklearn.neighbors._unsupervised", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neighbors_models", + ) + + _process_module_definition( + "sklearn.neighbors.unsupervised", + "newrelic.hooks.mlmodel_sklearn", + "instrument_sklearn_neighbors_models", + ) + _process_module_definition( "rest_framework.views", "newrelic.hooks.component_djangorestframework", @@ -2922,9 +3788,7 @@ def _process_module_builtin_defaults(): "newrelic.hooks.application_celery", "instrument_celery_worker", ) - # _process_module_definition('celery.loaders.base', - # 'newrelic.hooks.application_celery', - # 'instrument_celery_loaders_base') + _process_module_definition( "celery.execute.trace", "newrelic.hooks.application_celery", @@ -3113,6 +3977,11 @@ def _process_module_entry_points(): _instrumentation_done = False +def _reset_instrumentation_done(): + global _instrumentation_done + _instrumentation_done = False + + def _setup_instrumentation(): global _instrumentation_done diff --git a/newrelic/core/agent.py b/newrelic/core/agent.py index 6ab9571a4..9d9aadab1 100644 --- a/newrelic/core/agent.py +++ b/newrelic/core/agent.py @@ -524,6 +524,33 @@ def record_custom_metrics(self, app_name, metrics): application.record_custom_metrics(metrics) + def record_dimensional_metric(self, app_name, name, value, tags=None): + """Records a basic metric for the named application. If there has + been no prior request to activate the application, the metric is + discarded. + + """ + + application = self._applications.get(app_name, None) + if application is None or not application.active: + return + + application.record_dimensional_metric(name, value, tags) + + def record_dimensional_metrics(self, app_name, metrics): + """Records the metrics for the named application. If there has + been no prior request to activate the application, the metric is + discarded. The metrics should be an iterable yielding tuples + consisting of the name and value. 
+ + """ + + application = self._applications.get(app_name, None) + if application is None or not application.active: + return + + application.record_dimensional_metrics(metrics) + def record_custom_event(self, app_name, event_type, params): application = self._applications.get(app_name, None) if application is None or not application.active: @@ -531,6 +558,13 @@ def record_custom_event(self, app_name, event_type, params): application.record_custom_event(event_type, params) + def record_ml_event(self, app_name, event_type, params): + application = self._applications.get(app_name, None) + if application is None or not application.active: + return + + application.record_ml_event(event_type, params) + def record_log_event(self, app_name, message, level=None, timestamp=None, priority=None): application = self._applications.get(app_name, None) if application is None or not application.active: diff --git a/newrelic/core/agent_protocol.py b/newrelic/core/agent_protocol.py index ba277d4de..dd4dc264f 100644 --- a/newrelic/core/agent_protocol.py +++ b/newrelic/core/agent_protocol.py @@ -38,6 +38,7 @@ global_settings_dump, ) from newrelic.core.internal_metrics import internal_count_metric +from newrelic.core.otlp_utils import OTLP_CONTENT_TYPE, otlp_encode from newrelic.network.exceptions import ( DiscardDataForRequest, ForceAgentDisconnect, @@ -143,7 +144,9 @@ class AgentProtocol(object): "transaction_tracer.record_sql", "strip_exception_messages.enabled", "custom_insights_events.enabled", + "ml_insights_events.enabled", "application_logging.forwarding.enabled", + "machine_learning.inference_events_value.enabled", ) LOGGER_FUNC_MAPPING = { @@ -215,11 +218,16 @@ def __exit__(self, exc, value, tb): def close_connection(self): self.client.close_connection() - def send(self, method, payload=()): + def send( + self, + method, + payload=(), + path="/agent_listener/invoke_raw_method", + ): params, headers, payload = self._to_http(method, payload) try: - response = self.client.send_request(params=params, headers=headers, payload=payload) + response = self.client.send_request(path=path, params=params, headers=headers, payload=payload) except NetworkInterfaceException: # All HTTP errors are currently retried raise RetryDataForRequest @@ -251,7 +259,10 @@ def send(self, method, payload=()): exception = self.STATUS_CODE_RESPONSE.get(status, DiscardDataForRequest) raise exception if status == 200: - return json_decode(data.decode("utf-8"))["return_value"] + return self.decode_response(data) + + def decode_response(self, response): + return json_decode(response.decode("utf-8"))["return_value"] def _to_http(self, method, payload=()): params = dict(self._params) @@ -514,3 +525,77 @@ def connect( # can be modified later settings.aws_lambda_metadata = aws_lambda_metadata return cls(settings, client_cls=client_cls) + + +class OtlpProtocol(AgentProtocol): + def __init__(self, settings, host=None, client_cls=ApplicationModeClient): + if settings.audit_log_file: + audit_log_fp = open(settings.audit_log_file, "a") + else: + audit_log_fp = None + + self.client = client_cls( + host=host or settings.otlp_host, + port=settings.otlp_port or 4318, + proxy_scheme=settings.proxy_scheme, + proxy_host=settings.proxy_host, + proxy_port=settings.proxy_port, + proxy_user=settings.proxy_user, + proxy_pass=settings.proxy_pass, + timeout=settings.agent_limits.data_collector_timeout, + ca_bundle_path=settings.ca_bundle_path, + disable_certificate_validation=settings.debug.disable_certificate_validation, + 
compression_threshold=settings.agent_limits.data_compression_threshold, + compression_level=settings.agent_limits.data_compression_level, + compression_method=settings.compressed_content_encoding, + max_payload_size_in_bytes=1000000, + audit_log_fp=audit_log_fp, + default_content_encoding_header=None, + ) + + self._params = {} + self._headers = { + "api-key": settings.license_key, + } + + # In Python 2, the JSON is loaded with unicode keys and values; + # however, the header name must be a non-unicode value when given to + # the HTTP library. This code converts the header name from unicode to + # non-unicode. + if settings.request_headers_map: + for k, v in settings.request_headers_map.items(): + if not isinstance(k, str): + k = k.encode("utf-8") + self._headers[k] = v + + # Content-Type should be protobuf, but falls back to JSON if protobuf is not installed. + self._headers["Content-Type"] = OTLP_CONTENT_TYPE + self._run_token = settings.agent_run_id + + # Logging + self._proxy_host = settings.proxy_host + self._proxy_port = settings.proxy_port + self._proxy_user = settings.proxy_user + + # Do not access configuration anywhere inside the class + self.configuration = settings + + @classmethod + def connect( + cls, + app_name, + linked_applications, + environment, + settings, + client_cls=ApplicationModeClient, + ): + with cls(settings, client_cls=client_cls) as protocol: + pass + + return protocol + + def _to_http(self, method, payload=()): + return {}, self._headers, otlp_encode(payload) + + def decode_response(self, response): + return response.decode("utf-8") diff --git a/newrelic/core/application.py b/newrelic/core/application.py index 7be217428..82cdf8a9a 100644 --- a/newrelic/core/application.py +++ b/newrelic/core/application.py @@ -510,6 +510,9 @@ def connect_to_data_collector(self, activate_agent): with self._stats_custom_lock: self._stats_custom_engine.reset_stats(configuration) + with self._stats_lock: + self._stats_engine.reset_stats(configuration) + # Record an initial start time for the reporting period and # clear record of last transaction processed. @@ -860,6 +863,50 @@ def record_custom_metrics(self, metrics): self._global_events_account += 1 self._stats_custom_engine.record_custom_metric(name, value) + def record_dimensional_metric(self, name, value, tags=None): + """Record a dimensional metric against the application independent + of a specific transaction. + + NOTE that this will require locking of the stats engine for + dimensional metrics and so under heavy use will have performance + issues. It is better to record the dimensional metric against an + active transaction as they will then be aggregated at the end of + the transaction when all other metrics are aggregated and so no + additional locking will be required. + + """ + + if not self._active_session: + return + + with self._stats_lock: + self._global_events_account += 1 + self._stats_engine.record_dimensional_metric(name, value, tags) + + def record_dimensional_metrics(self, metrics): + """Record a set of dimensional metrics against the application + independent of a specific transaction. + + NOTE that this will require locking of the stats engine for + dimensional metrics and so under heavy use will have performance + issues. It is better to record the dimensional metric against an + active transaction as they will then be aggregated at the end of + the transaction when all other metrics are aggregated and so no + additional locking will be required. 
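+        Each metric is a (name, value) or (name, value, tags) tuple. As an
+        illustrative sketch only, a pre-aggregated count may also be passed
+        as a dict value, e.g. ("Custom/ThingSeen", {"count": 3}, {"source": "ml"}),
+        which the stats engine records as a CountStats entry.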
+ + """ + + if not self._active_session: + return + + with self._stats_lock: + for metric in metrics: + name, value = metric[:2] + tags = metric[2] if len(metric) >= 3 else None + + self._global_events_account += 1 + self._stats_engine.record_dimensional_metric(name, value, tags) + def record_custom_event(self, event_type, params): if not self._active_session: return @@ -876,6 +923,22 @@ def record_custom_event(self, event_type, params): self._global_events_account += 1 self._stats_engine.record_custom_event(event) + def record_ml_event(self, event_type, params): + if not self._active_session: + return + + settings = self._stats_engine.settings + + if settings is None or not settings.ml_insights_events.enabled: + return + + event = create_custom_event(event_type, params) + + if event: + with self._stats_custom_lock: + self._global_events_account += 1 + self._stats_engine.record_ml_event(event) + def record_log_event(self, message, level=None, timestamp=None, priority=None): if not self._active_session: return @@ -1335,6 +1398,26 @@ def harvest(self, shutdown=False, flexible=False): stats.reset_custom_events() + # Send machine learning events + + if configuration.ml_insights_events.enabled: + ml_events = stats.ml_events + + if ml_events: + if ml_events.num_samples > 0: + ml_event_samples = list(ml_events) + + _logger.debug("Sending machine learning event data for harvest of %r.", self._app_name) + + self._active_session.send_ml_events(ml_events.sampling_info, ml_event_samples) + ml_event_samples = None + + # As per spec + internal_count_metric("Supportability/Events/Customer/Seen", ml_events.num_seen) + internal_count_metric("Supportability/Events/Customer/Sent", ml_events.num_samples) + + stats.reset_ml_events() + # Send log events if ( @@ -1416,11 +1499,14 @@ def harvest(self, shutdown=False, flexible=False): _logger.debug("Normalizing metrics for harvest of %r.", self._app_name) metric_data = stats.metric_data(metric_normalizer) + dimensional_metric_data = stats.dimensional_metric_data(metric_normalizer) _logger.debug("Sending metric data for harvest of %r.", self._app_name) # Send metrics self._active_session.send_metric_data(self._period_start, period_end, metric_data) + if dimensional_metric_data: + self._active_session.send_dimensional_metric_data(self._period_start, period_end, dimensional_metric_data) _logger.debug("Done sending data for harvest of %r.", self._app_name) diff --git a/newrelic/core/attribute.py b/newrelic/core/attribute.py index 372711369..10ae8e459 100644 --- a/newrelic/core/attribute.py +++ b/newrelic/core/attribute.py @@ -180,7 +180,6 @@ def create_user_attributes(attr_dict, attribute_filter): def truncate(text, maxsize=MAX_ATTRIBUTE_LENGTH, encoding="utf-8", ending=None): - # Truncate text so that its byte representation # is no longer than maxsize bytes. @@ -225,7 +224,6 @@ def check_max_int(value, max_int=MAX_64_BIT_INT): def process_user_attribute(name, value, max_length=MAX_ATTRIBUTE_LENGTH, ending=None): - # Perform all necessary checks on a potential attribute. # # Returns: @@ -245,23 +243,22 @@ def process_user_attribute(name, value, max_length=MAX_ATTRIBUTE_LENGTH, ending= value = sanitize(value) except NameIsNotStringException: - _logger.debug("Attribute name must be a string. Dropping " "attribute: %r=%r", name, value) + _logger.debug("Attribute name must be a string. Dropping attribute: %r=%r", name, value) return FAILED_RESULT except NameTooLongException: - _logger.debug("Attribute name exceeds maximum length. 
Dropping " "attribute: %r=%r", name, value) + _logger.debug("Attribute name exceeds maximum length. Dropping attribute: %r=%r", name, value) return FAILED_RESULT except IntTooLargeException: - _logger.debug("Attribute value exceeds maximum integer value. " "Dropping attribute: %r=%r", name, value) + _logger.debug("Attribute value exceeds maximum integer value. Dropping attribute: %r=%r", name, value) return FAILED_RESULT except CastingFailureException: - _logger.debug("Attribute value cannot be cast to a string. " "Dropping attribute: %r=%r", name, value) + _logger.debug("Attribute value cannot be cast to a string. Dropping attribute: %r=%r", name, value) return FAILED_RESULT else: - # Check length after casting valid_types_text = (six.text_type, six.binary_type) @@ -270,7 +267,7 @@ def process_user_attribute(name, value, max_length=MAX_ATTRIBUTE_LENGTH, ending= trunc_value = truncate(value, maxsize=max_length, ending=ending) if value != trunc_value: _logger.debug( - "Attribute value exceeds maximum length " "(%r bytes). Truncating value: %r=%r.", + "Attribute value exceeds maximum length (%r bytes). Truncating value: %r=%r.", max_length, name, trunc_value, @@ -282,15 +279,31 @@ def process_user_attribute(name, value, max_length=MAX_ATTRIBUTE_LENGTH, ending= def sanitize(value): + """ + Return value unchanged, if it's a valid type that is supported by + Insights. Otherwise, convert value to a string. - # Return value unchanged, if it's a valid type that is supported by - # Insights. Otherwise, convert value to a string. - # - # Raise CastingFailureException, if str(value) somehow fails. + Raise CastingFailureException, if str(value) somehow fails. + """ valid_value_types = (six.text_type, six.binary_type, bool, float, six.integer_types) - if not isinstance(value, valid_value_types): + # When working with numpy, note that numpy has its own `int`s, `str`s, + # et cetera. `numpy.str_` and `numpy.float_` inherit from Python's native + # `str` and `float`, respectively. However, some types, such as `numpy.int_` + # and `numpy.bool_`, do not inherit from `int` and `bool` (respectively). + # In those cases, the valid_value_types check fails and it will try to + # convert these to string, which is not the desired behavior. Checking for + # `type` in lieu of `isinstance` has the potential to impact performance. + + # numpy values have an attribute "item" that returns the closest + # equivalent Python native type. Ex: numpy.int64 -> int + # This is important to utilize in cases like int and bool where + # numpy does not inherit from those classes. This logic is + # determining whether or not the value is a valid_value_type (or + # inherited from one of those types) AND whether it is a numpy + # type (by determining if it has the attribute "item"). + if not isinstance(value, valid_value_types) and not hasattr(value, "item"): original = value try: @@ -298,8 +311,6 @@ def sanitize(value): except Exception: raise CastingFailureException() else: - _logger.debug( - "Attribute value is of type: %r. Casting %r to " "string: %s", type(original), original, value - ) + _logger.debug("Attribute value is of type: %r. Casting %r to string: %s", type(original), original, value) return value diff --git a/newrelic/core/config.py b/newrelic/core/config.py index 03758fc74..8b366f7d7 100644 --- a/newrelic/core/config.py +++ b/newrelic/core/config.py @@ -51,11 +51,14 @@ # By default, Transaction Events and Custom Events have the same size # reservoir. Error Events have a different default size. 
+# Slow harvest (Every 60 seconds) DEFAULT_RESERVOIR_SIZE = 1200 -CUSTOM_EVENT_RESERVOIR_SIZE = 3600 ERROR_EVENT_RESERVOIR_SIZE = 100 SPAN_EVENT_RESERVOIR_SIZE = 2000 +# Fast harvest (Every 5 seconds, so divide by 12 to get average per minute value) +CUSTOM_EVENT_RESERVOIR_SIZE = 3600 LOG_EVENT_RESERVOIR_SIZE = 10000 +ML_EVENT_RESERVOIR_SIZE = 100000 # settings that should be completely ignored if set server side IGNORED_SERVER_SIDE_SETTINGS = [ @@ -101,6 +104,7 @@ def create_settings(nested): class TopLevelSettings(Settings): _host = None + _otlp_host = None @property def host(self): @@ -112,6 +116,16 @@ def host(self): def host(self, value): self._host = value + @property + def otlp_host(self): + if self._otlp_host: + return self._otlp_host + return default_otlp_host(self.host) + + @otlp_host.setter + def otlp_host(self, value): + self._otlp_host = value + class AttributesSettings(Settings): pass @@ -121,6 +135,14 @@ class GCRuntimeMetricsSettings(Settings): enabled = False +class MachineLearningSettings(Settings): + pass + + +class MachineLearningInferenceEventsValueSettings(Settings): + pass + + class CodeLevelMetricsSettings(Settings): pass @@ -199,6 +221,10 @@ class CustomInsightsEventsSettings(Settings): pass +class MlInsightsEventsSettings(Settings): + pass + + class ProcessHostSettings(Settings): pass @@ -394,6 +420,8 @@ class EventHarvestConfigHarvestLimitSettings(Settings): _settings.application_logging.forwarding = ApplicationLoggingForwardingSettings() _settings.application_logging.local_decorating = ApplicationLoggingLocalDecoratingSettings() _settings.application_logging.metrics = ApplicationLoggingMetricsSettings() +_settings.machine_learning = MachineLearningSettings() +_settings.machine_learning.inference_events_value = MachineLearningInferenceEventsValueSettings() _settings.attributes = AttributesSettings() _settings.browser_monitoring = BrowserMonitorSettings() _settings.browser_monitoring.attributes = BrowserMonitorAttributesSettings() @@ -401,6 +429,7 @@ class EventHarvestConfigHarvestLimitSettings(Settings): _settings.console = ConsoleSettings() _settings.cross_application_tracer = CrossApplicationTracerSettings() _settings.custom_insights_events = CustomInsightsEventsSettings() +_settings.ml_insights_events = MlInsightsEventsSettings() _settings.datastore_tracer = DatastoreTracerSettings() _settings.datastore_tracer.database_name_reporting = DatastoreTracerDatabaseNameReportingSettings() _settings.datastore_tracer.instance_reporting = DatastoreTracerInstanceReportingSettings() @@ -571,6 +600,24 @@ def default_host(license_key): return host +def default_otlp_host(host): + HOST_MAP = { + "collector.newrelic.com": "otlp.nr-data.net", + "collector.eu.newrelic.com": "otlp.eu01.nr-data.net", + "gov-collector.newrelic.com": "gov-otlp.nr-data.net", + "staging-collector.newrelic.com": "staging-otlp.nr-data.net", + "staging-collector.eu.newrelic.com": "staging-otlp.eu01.nr-data.net", + "staging-gov-collector.newrelic.com": "staging-gov-otlp.nr-data.net", + "fake-collector.newrelic.com": "fake-otlp.nr-data.net", + } + otlp_host = HOST_MAP.get(host, None) + if not otlp_host: + default = HOST_MAP["collector.newrelic.com"] + _logger.warn("Unable to find corresponding OTLP host using default %s" % default) + otlp_host = default + return otlp_host + + _LOG_LEVEL = { "CRITICAL": logging.CRITICAL, "ERROR": logging.ERROR, @@ -596,7 +643,9 @@ def default_host(license_key): _settings.ssl = _environ_as_bool("NEW_RELIC_SSL", True) _settings.host = os.environ.get("NEW_RELIC_HOST") 
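A quick sketch of the fallback behavior of the default_otlp_host() helper defined above; the first two hosts come from its HOST_MAP, while the third is a hypothetical unknown collector:

    default_otlp_host("collector.newrelic.com")     # -> "otlp.nr-data.net"
    default_otlp_host("collector.eu.newrelic.com")  # -> "otlp.eu01.nr-data.net"
    default_otlp_host("collector.example.com")      # -> "otlp.nr-data.net", after logging a warning

TopLevelSettings.otlp_host only consults this mapping when no explicit OTLP host has been configured, for example via the NEW_RELIC_OTLP_HOST environment variable read just below.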
+_settings.otlp_host = os.environ.get("NEW_RELIC_OTLP_HOST") _settings.port = int(os.environ.get("NEW_RELIC_PORT", "0")) +_settings.otlp_port = int(os.environ.get("NEW_RELIC_OTLP_PORT", "0")) _settings.agent_run_id = None _settings.entity_guid = None @@ -697,6 +746,7 @@ def default_host(license_key): _settings.transaction_events.attributes.include = [] _settings.custom_insights_events.enabled = True +_settings.ml_insights_events.enabled = False _settings.distributed_tracing.enabled = _environ_as_bool("NEW_RELIC_DISTRIBUTED_TRACING_ENABLED", default=True) _settings.distributed_tracing.exclude_newrelic_header = False @@ -789,6 +839,10 @@ def default_host(license_key): "NEW_RELIC_CUSTOM_INSIGHTS_EVENTS_MAX_SAMPLES_STORED", CUSTOM_EVENT_RESERVOIR_SIZE ) +_settings.event_harvest_config.harvest_limits.ml_event_data = _environ_as_int( + "NEW_RELIC_ML_INSIGHTS_EVENTS_MAX_SAMPLES_STORED", ML_EVENT_RESERVOIR_SIZE +) + _settings.event_harvest_config.harvest_limits.span_event_data = _environ_as_int( "NEW_RELIC_SPAN_EVENTS_MAX_SAMPLES_STORED", SPAN_EVENT_RESERVOIR_SIZE ) @@ -826,6 +880,7 @@ def default_host(license_key): _settings.debug.log_untrusted_distributed_trace_keys = False _settings.debug.disable_harvest_until_shutdown = False _settings.debug.connect_span_stream_in_developer_mode = False +_settings.debug.otlp_content_encoding = None _settings.message_tracer.segment_parameters_enabled = True @@ -868,6 +923,10 @@ def default_host(license_key): _settings.application_logging.local_decorating.enabled = _environ_as_bool( "NEW_RELIC_APPLICATION_LOGGING_LOCAL_DECORATING_ENABLED", default=False ) +_settings.machine_learning.enabled = _environ_as_bool("NEW_RELIC_MACHINE_LEARNING_ENABLED", default=False) +_settings.machine_learning.inference_events_value.enabled = _environ_as_bool( + "NEW_RELIC_MACHINE_LEARNING_INFERENCE_EVENT_VALUE_ENABLED", default=False +) _settings.security.agent.enabled = _environ_as_bool("NEW_RELIC_SECURITY_AGENT_ENABLED", False) _settings.security.enabled = _environ_as_bool("NEW_RELIC_SECURITY_ENABLED", False) @@ -1122,8 +1181,8 @@ def apply_server_side_settings(server_side_config=None, settings=_settings): apply_config_setting(settings_snapshot, name, value) # Overlay with global server side configuration settings. - # global server side configuration always takes precedence over the global - # server side configuration settings. + # global server side configuration always takes precedence over the local + # agent configuration settings. for name, value in server_side_config.items(): apply_config_setting(settings_snapshot, name, value) @@ -1140,6 +1199,16 @@ def apply_server_side_settings(server_side_config=None, settings=_settings): settings_snapshot, "event_harvest_config.harvest_limits.span_event_data", span_event_harvest_limit ) + # Since the server does not override this setting as it's an OTLP setting, + # we must override it here manually by converting it into a per harvest cycle + # value. 
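+    # As an arithmetic sketch using the defaults defined earlier in this file:
+    # ML_EVENT_RESERVOIR_SIZE is 100000 events per 60s window, so each 5s
+    # harvest cycle is allotted 100000 / 12, i.e. roughly 8333 events.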
+    apply_config_setting(
+        settings_snapshot,
+        "event_harvest_config.harvest_limits.ml_event_data",
+        # Convert the per-minute ml_event_data limit into a per-harvest-cycle
+        # value: there are 60s / 5s = 12 fast harvest cycles per minute.
+        settings_snapshot.event_harvest_config.harvest_limits.ml_event_data / 12,
+    )
+
     # This will be removed at some future point
     # Special case for account_id which will be sent instead of
     # cross_process_id in the future
diff --git a/newrelic/core/data_collector.py b/newrelic/core/data_collector.py
index 985e37240..269139664 100644
--- a/newrelic/core/data_collector.py
+++ b/newrelic/core/data_collector.py
@@ -25,21 +25,30 @@
     DeveloperModeClient,
     ServerlessModeClient,
 )
-from newrelic.core.agent_protocol import AgentProtocol, ServerlessModeProtocol
+from newrelic.core.agent_protocol import (
+    AgentProtocol,
+    OtlpProtocol,
+    ServerlessModeProtocol,
+)
 from newrelic.core.agent_streaming import StreamingRpc
 from newrelic.core.config import global_settings
+from newrelic.core.otlp_utils import encode_metric_data, encode_ml_event_data

 _logger = logging.getLogger(__name__)


 class Session(object):
     PROTOCOL = AgentProtocol
+    OTLP_PROTOCOL = OtlpProtocol
     CLIENT = ApplicationModeClient

     def __init__(self, app_name, linked_applications, environment, settings):
         self._protocol = self.PROTOCOL.connect(
             app_name, linked_applications, environment, settings, client_cls=self.CLIENT
         )
+        self._otlp_protocol = self.OTLP_PROTOCOL.connect(
+            app_name, linked_applications, environment, settings, client_cls=self.CLIENT
+        )
         self._rpc = None

     @property
@@ -112,6 +121,11 @@ def send_custom_events(self, sampling_info, custom_event_data):
         payload = (self.agent_run_id, sampling_info, custom_event_data)
         return self._protocol.send("custom_event_data", payload)

+    def send_ml_events(self, sampling_info, custom_event_data):
+        """Called to submit sample set for machine learning events."""
+        payload = encode_ml_event_data(custom_event_data, str(self.agent_run_id))
+        return self._otlp_protocol.send("ml_event_data", payload, path="/v1/logs")
+
     def send_span_events(self, sampling_info, span_event_data):
         """Called to submit sample set for span events."""

@@ -128,6 +142,20 @@ def send_metric_data(self, start_time, end_time, metric_data):
         payload = (self.agent_run_id, start_time, end_time, metric_data)
         return self._protocol.send("metric_data", payload)

+    def send_dimensional_metric_data(self, start_time, end_time, metric_data):
+        """Called to submit dimensional metric data for a specified period of time.
+        Time values are seconds since UNIX epoch as returned by the
+        time.time() function. The metric data should be an iterable of
+        specific metrics.
+
+        NOTE: This data is not sent to the normal agent endpoints but is sent
+        to the OTLP API endpoints to keep the entity separate. This is for use
+        with the machine learning integration only.
+        """
+
+        payload = encode_metric_data(metric_data, start_time, end_time)
+        return self._otlp_protocol.send("dimensional_metric_data", payload, path="/v1/metrics")
+
     def send_log_events(self, sampling_info, log_event_data):
         """Called to submit sample set for log events."""
diff --git a/newrelic/core/environment.py b/newrelic/core/environment.py
index 1306816ef..9bca085a3 100644
--- a/newrelic/core/environment.py
+++ b/newrelic/core/environment.py
@@ -17,10 +17,10 @@
 """

+import logging
 import os
 import platform
 import sys
-import sysconfig

 import newrelic
 from newrelic.common.package_version_utils import get_package_version
@@ -29,12 +29,15 @@
     physical_processor_count,
     total_physical_memory,
 )
+from newrelic.packages.isort import stdlibs as isort_stdlibs

 try:
     import newrelic.core._thread_utilization
 except ImportError:
     pass

+_logger = logging.getLogger(__name__)
+

 def environment_settings():
     """Returns an array of arrays of environment settings"""
@@ -195,8 +198,7 @@
     env.extend(dispatcher)

     # Module information.
-    purelib = sysconfig.get_path("purelib")
-    platlib = sysconfig.get_path("platlib")
+    stdlib_builtin_module_names = _get_stdlib_builtin_module_names()

     plugins = []

@@ -208,29 +210,58 @@
     # list
     for name, module in sys.modules.copy().items():
         # Exclude lib.sub_paths as independent modules except for newrelic.hooks.
-        if "." in name and not name.startswith("newrelic.hooks."):
+        nr_hook = name.startswith("newrelic.hooks.")
+        if "." in name and not nr_hook or name.startswith("_"):
             continue
+
         # If the module isn't actually loaded (such as failed relative imports
         # in Python 2.7), the module will be None and should not be reported.
-        if not module:
+        try:
+            if not module:
+                continue
+        except Exception:
+            # if the application uses generalimport to manage optional dependencies,
+            # it's possible that generalimport.MissingOptionalDependency is raised.
+            # In this case, we should not report the module as it is not actually loaded and
+            # is not a runtime dependency of the application.
+            #
             continue
+
         # Exclude standard library/built-in modules.
-        # Third-party modules can be installed in either purelib or platlib directories.
-        # See https://docs.python.org/3/library/sysconfig.html#installation-paths.
-        if (
-            not hasattr(module, "__file__")
-            or not module.__file__
-            or not module.__file__.startswith(purelib)
-            or not module.__file__.startswith(platlib)
-        ):
+        if name in stdlib_builtin_module_names:
             continue

         try:
             version = get_package_version(name)
-            plugins.append("%s (%s)" % (name, version))
         except Exception:
-            plugins.append(name)
+            version = None
+
+        # If it has no version it's likely not a real package, so don't report it
+        # unless it's a New Relic hook.
+        if version or nr_hook:
+            plugins.append("%s (%s)" % (name, version))

     env.append(("Plugin List", plugins))

     return env
+
+
+def _get_stdlib_builtin_module_names():
+    builtins = set(sys.builtin_module_names)
+    # Since sys.stdlib_module_names is not available in versions of python below 3.10,
+    # use isort's hardcoded stdlibs instead.
+    python_version = sys.version_info[0:2]
+    if python_version < (3,):
+        stdlibs = isort_stdlibs.py27.stdlib
+    elif (3, 7) <= python_version < (3, 8):
+        stdlibs = isort_stdlibs.py37.stdlib
+    elif python_version < (3, 9):
+        stdlibs = isort_stdlibs.py38.stdlib
+    elif python_version < (3, 10):
+        stdlibs = isort_stdlibs.py39.stdlib
+    elif python_version >= (3, 10):
+        stdlibs = sys.stdlib_module_names
+    else:
+        _logger.warn("Unsupported Python version. Unable to determine stdlibs.")
+        return builtins
+    return builtins | stdlibs
Unable to determine stdlibs.") + return builtins + return builtins | stdlibs diff --git a/newrelic/core/otlp_utils.py b/newrelic/core/otlp_utils.py new file mode 100644 index 000000000..e78a63603 --- /dev/null +++ b/newrelic/core/otlp_utils.py @@ -0,0 +1,243 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module provides common utilities for interacting with OTLP protocol buffers. + +The serialization implemented here attempts to use protobuf as an encoding, but falls +back to JSON when encoutering exceptions unless the content type is explicitly set in debug settings. +""" + +import logging + +from newrelic.common.encoding_utils import json_encode +from newrelic.core.config import global_settings +from newrelic.core.stats_engine import CountStats, TimeStats + +_logger = logging.getLogger(__name__) + +_settings = global_settings() +otlp_content_setting = _settings.debug.otlp_content_encoding +if not otlp_content_setting or otlp_content_setting == "protobuf": + try: + from newrelic.packages.opentelemetry_proto.common_pb2 import AnyValue, KeyValue + from newrelic.packages.opentelemetry_proto.logs_pb2 import ( + LogsData, + ResourceLogs, + ScopeLogs, + ) + from newrelic.packages.opentelemetry_proto.metrics_pb2 import ( + AggregationTemporality, + Metric, + MetricsData, + NumberDataPoint, + ResourceMetrics, + ScopeMetrics, + Sum, + Summary, + SummaryDataPoint, + ) + from newrelic.packages.opentelemetry_proto.resource_pb2 import Resource + + ValueAtQuantile = SummaryDataPoint.ValueAtQuantile + AGGREGATION_TEMPORALITY_DELTA = AggregationTemporality.AGGREGATION_TEMPORALITY_DELTA + OTLP_CONTENT_TYPE = "application/x-protobuf" + + otlp_content_setting = "protobuf" # Explicitly set to overwrite None values + except Exception: + if otlp_content_setting == "protobuf": + raise # Reraise exception if content type explicitly set + # Fallback to JSON + otlp_content_setting = "json" + + +if otlp_content_setting == "json": + AnyValue = dict + KeyValue = dict + Metric = dict + MetricsData = dict + NumberDataPoint = dict + Resource = dict + ResourceMetrics = dict + ScopeMetrics = dict + Sum = dict + Summary = dict + SummaryDataPoint = dict + ValueAtQuantile = dict + ResourceLogs = dict + ScopeLogs = dict + LogsData = dict + + AGGREGATION_TEMPORALITY_DELTA = 1 + OTLP_CONTENT_TYPE = "application/json" + + +def otlp_encode(payload): + if type(payload) is dict: # pylint: disable=C0123 + _logger.warning( + "Using OTLP integration while protobuf is not installed. This may result in larger payload sizes and data loss." 
+ ) + return json_encode(payload).encode("utf-8") + return payload.SerializeToString() + + +def create_key_value(key, value): + if isinstance(value, bool): + return KeyValue(key=key, value=AnyValue(bool_value=value)) + elif isinstance(value, int): + return KeyValue(key=key, value=AnyValue(int_value=value)) + elif isinstance(value, float): + return KeyValue(key=key, value=AnyValue(double_value=value)) + elif isinstance(value, str): + return KeyValue(key=key, value=AnyValue(string_value=value)) + # Technically AnyValue accepts array, kvlist, and bytes however, since + # those are not valid custom attribute types according to our api spec, + # we will not bother to support them here either. + else: + _logger.warning("Unsupported attribute value type %s: %s." % (key, value)) + + +def create_key_values_from_iterable(iterable): + if not iterable: + return None + elif isinstance(iterable, dict): + iterable = iterable.items() + + # The create_key_value list may return None if the value is an unsupported type + # so filter None values out before returning. + return list( + filter( + lambda i: i is not None, + (create_key_value(key, value) for key, value in iterable), + ) + ) + + +def create_resource(attributes=None): + attributes = attributes or {"instrumentation.provider": "newrelic-opentelemetry-python-ml"} + return Resource(attributes=create_key_values_from_iterable(attributes)) + + +def TimeStats_to_otlp_data_point(self, start_time, end_time, attributes=None): + data = SummaryDataPoint( + time_unix_nano=int(end_time * 1e9), # Time of current harvest + start_time_unix_nano=int(start_time * 1e9), # Time of last harvest + attributes=attributes, + count=int(self[0]), + sum=float(self[1]), + quantile_values=[ + ValueAtQuantile(quantile=0.0, value=float(self[3])), # Min Value + ValueAtQuantile(quantile=1.0, value=float(self[4])), # Max Value + ], + ) + return data + + +def CountStats_to_otlp_data_point(self, start_time, end_time, attributes=None): + data = NumberDataPoint( + time_unix_nano=int(end_time * 1e9), # Time of current harvest + start_time_unix_nano=int(start_time * 1e9), # Time of last harvest + attributes=attributes, + as_int=int(self[0]), + ) + return data + + +def stats_to_otlp_metrics(metric_data, start_time, end_time): + """ + Generator producing protos for Summary and Sum metrics, for CountStats and TimeStats respectively. + + Individual Metric protos must be entirely one type of metric data point. For mixed metric types we have to + separate the types and report multiple metrics, one for each type. + """ + for name, metric_container in metric_data: + # Types are checked here using type() instead of isinstance, as CountStats is a subclass of TimeStats. + # Imporperly checking with isinstance will lead to count metrics being encoded and reported twice. + if any(type(metric) is CountStats for metric in metric_container.values()): # pylint: disable=C0123 + # Metric contains Sum metric data points. + yield Metric( + name=name, + sum=Sum( + aggregation_temporality=AGGREGATION_TEMPORALITY_DELTA, + is_monotonic=True, + data_points=[ + CountStats_to_otlp_data_point( + value, + start_time=start_time, + end_time=end_time, + attributes=create_key_values_from_iterable(tags), + ) + for tags, value in metric_container.items() + if type(value) is CountStats # pylint: disable=C0123 + ], + ), + ) + if any(type(metric) is TimeStats for metric in metric_container.values()): # pylint: disable=C0123 + # Metric contains Summary metric data points. 
+ yield Metric( + name=name, + summary=Summary( + data_points=[ + TimeStats_to_otlp_data_point( + value, + start_time=start_time, + end_time=end_time, + attributes=create_key_values_from_iterable(tags), + ) + for tags, value in metric_container.items() + if type(value) is TimeStats # pylint: disable=C0123 + ] + ), + ) + + +def encode_metric_data(metric_data, start_time, end_time, resource=None, scope=None): + resource = resource or create_resource() + return MetricsData( + resource_metrics=[ + ResourceMetrics( + resource=resource, + scope_metrics=[ + ScopeMetrics( + scope=scope, + metrics=list(stats_to_otlp_metrics(metric_data, start_time, end_time)), + ) + ], + ) + ] + ) + + +def encode_ml_event_data(custom_event_data, agent_run_id): + resource = create_resource() + ml_events = [] + for event in custom_event_data: + event_info, event_attrs = event + event_attrs.update( + { + "real_agent_id": agent_run_id, + "event.domain": "newrelic.ml_events", + "event.name": event_info["type"], + } + ) + ml_attrs = create_key_values_from_iterable(event_attrs) + unix_nano_timestamp = event_info["timestamp"] * 1e6 + ml_events.append( + { + "time_unix_nano": int(unix_nano_timestamp), + "attributes": ml_attrs, + } + ) + + return LogsData(resource_logs=[ResourceLogs(resource=resource, scope_logs=[ScopeLogs(log_records=ml_events)])]) diff --git a/newrelic/core/rules_engine.py b/newrelic/core/rules_engine.py index fccc5e5e1..62ecce3fe 100644 --- a/newrelic/core/rules_engine.py +++ b/newrelic/core/rules_engine.py @@ -22,6 +22,27 @@ class NormalizationRule(_NormalizationRule): + def __new__( + cls, + match_expression="", + replacement="", + ignore=False, + eval_order=0, + terminate_chain=False, + each_segment=False, + replace_all=False, + ): + return _NormalizationRule.__new__( + cls, + match_expression=match_expression, + replacement=replacement, + ignore=ignore, + eval_order=eval_order, + terminate_chain=terminate_chain, + each_segment=each_segment, + replace_all=replace_all, + ) + def __init__(self, *args, **kwargs): self.match_expression_re = re.compile(self.match_expression, re.IGNORECASE) diff --git a/newrelic/core/stats_engine.py b/newrelic/core/stats_engine.py index 203e3e796..ebebe7dbe 100644 --- a/newrelic/core/stats_engine.py +++ b/newrelic/core/stats_engine.py @@ -35,6 +35,7 @@ from newrelic.api.settings import STRIP_EXCEPTION_MESSAGE from newrelic.api.time_trace import get_linking_metadata from newrelic.common.encoding_utils import json_encode +from newrelic.common.metric_utils import create_metric_identity from newrelic.common.object_names import parse_exc_info from newrelic.common.streaming_utils import StreamBuffer from newrelic.core.attribute import ( @@ -61,7 +62,7 @@ "reset_synthetics_events", ), "span_event_data": ("reset_span_events",), - "custom_event_data": ("reset_custom_events",), + "custom_event_data": ("reset_custom_events", "reset_ml_events"), "error_event_data": ("reset_error_events",), "log_event_data": ("reset_log_events",), } @@ -180,6 +181,11 @@ def merge_custom_metric(self, value): self.merge_raw_time_metric(value) + def merge_dimensional_metric(self, value): + """Merge data value.""" + + self.merge_raw_time_metric(value) + class CountStats(TimeStats): def merge_stats(self, other): @@ -235,6 +241,99 @@ def reset_metric_stats(self): self.__stats_table = {} +class DimensionalMetrics(object): + + """Nested dictionary table for collecting a set of metrics broken down by tags.""" + + def __init__(self): + self.__stats_table = {} + + def __contains__(self, key): + if isinstance(key, 
tuple): + if not isinstance(key[1], frozenset): + # Convert tags dict to a frozen set for proper comparisons + name, tags = create_metric_identity(*key) + else: + name, tags = key + + # Check that both metric name and tags are already present. + stats_container = self.__stats_table.get(name) + return stats_container and tags in stats_container + else: + # Only look for metric name + return key in self.__stats_table + + def record_dimensional_metric(self, name, value, tags=None): + """Record a single value metric, merging the data with any data + from prior value metrics with the same name and tags. + """ + name, tags = create_metric_identity(name, tags) + + if isinstance(value, dict): + if len(value) == 1 and "count" in value: + new_stats = CountStats(call_count=value["count"]) + else: + new_stats = TimeStats(*c2t(**value)) + else: + new_stats = TimeStats(1, value, value, value, value, value**2) + + stats_container = self.__stats_table.get(name) + if stats_container is None: + # No existing metrics with this name. Set up new stats container. + self.__stats_table[name] = {tags: new_stats} + else: + # Existing metric container found. + stats = stats_container.get(tags) + if stats is None: + # No data points for this set of tags. Add new data. + stats_container[tags] = new_stats + else: + # Existing data points found, merge stats. + stats.merge_stats(new_stats) + + return (name, tags) + + def metrics(self): + """Returns an iterator over the set of value metrics. + The items returned are tuples consisting of the metric name and a + dictionary mapping each set of tags to the accumulated stats for + that metric. + """ + + return six.iteritems(self.__stats_table) + + def metrics_count(self): + """Returns a count of the number of unique metric name and tag + combinations currently recorded. + """ + + return sum(len(metric) for metric in self.__stats_table.values()) + + def reset_metric_stats(self): + """Resets the accumulated statistics back to initial state for + metric data.
+ """ + self.__stats_table = {} + + def get(self, key, default=None): + return self.__stats_table.get(key, default) + + def __setitem__(self, key, value): + self.__stats_table[key] = value + + def __getitem__(self, key): + return self.__stats_table[key] + + def __str__(self): + return str(self.__stats_table) + + def __repr__(self): + return "%s(%s)" % (__class__.__name__, repr(self.__stats_table)) + + def items(self): + return self.metrics() + + class SlowSqlStats(list): def __init__(self): super(SlowSqlStats, self).__init__([0, 0, 0, 0, None]) @@ -433,9 +532,11 @@ class StatsEngine(object): def __init__(self): self.__settings = None self.__stats_table = {} + self.__dimensional_stats_table = DimensionalMetrics() self._transaction_events = SampledDataSet() self._error_events = SampledDataSet() self._custom_events = SampledDataSet() + self._ml_events = SampledDataSet() self._span_events = SampledDataSet() self._log_events = SampledDataSet() self._span_stream = None @@ -456,6 +557,10 @@ def settings(self): def stats_table(self): return self.__stats_table + @property + def dimensional_stats_table(self): + return self.__dimensional_stats_table + @property def transaction_events(self): return self._transaction_events @@ -464,6 +569,10 @@ def transaction_events(self): def custom_events(self): return self._custom_events + @property + def ml_events(self): + return self._ml_events + @property def span_events(self): return self._span_events @@ -494,7 +603,7 @@ def metrics_count(self): """ - return len(self.__stats_table) + return len(self.__stats_table) + self.__dimensional_stats_table.metrics_count() def record_apdex_metric(self, metric): """Record a single apdex metric, merging the data with any data @@ -716,7 +825,6 @@ def notice_error(self, error=None, attributes=None, expected=None, ignore=None, user_attributes = create_user_attributes(custom_attributes, settings.attribute_filter) - # Extract additional details about the exception as agent attributes agent_attributes = {} @@ -728,28 +836,37 @@ def notice_error(self, error=None, attributes=None, expected=None, ignore=None, error_group_name = None try: # Call callback to obtain error group name - error_group_name_raw = settings.error_collector.error_group_callback(value, { - "traceback": tb, - "error.class": exc, - "error.message": message_raw, - "error.expected": is_expected, - "custom_params": attributes, - # Transaction specific items should be set to None - "transactionName": None, - "response.status": None, - "request.method": None, - "request.uri": None, - }) + error_group_name_raw = settings.error_collector.error_group_callback( + value, + { + "traceback": tb, + "error.class": exc, + "error.message": message_raw, + "error.expected": is_expected, + "custom_params": attributes, + # Transaction specific items should be set to None + "transactionName": None, + "response.status": None, + "request.method": None, + "request.uri": None, + }, + ) if error_group_name_raw: _, error_group_name = process_user_attribute("error.group.name", error_group_name_raw) if error_group_name is None or not isinstance(error_group_name, six.string_types): - raise ValueError("Invalid attribute value for error.group.name. Expected string, got: %s" % repr(error_group_name_raw)) + raise ValueError( + "Invalid attribute value for error.group.name. 
Expected string, got: %s" + % repr(error_group_name_raw) + ) else: agent_attributes["error.group.name"] = error_group_name except Exception: - _logger.error("Encountered error when calling error group callback:\n%s", "".join(traceback.format_exception(*sys.exc_info()))) - + _logger.error( + "Encountered error when calling error group callback:\n%s", + "".join(traceback.format_exception(*sys.exc_info())), + ) + agent_attributes = create_agent_attributes(agent_attributes, settings.attribute_filter) # Record the exception details. @@ -774,7 +891,7 @@ def notice_error(self, error=None, attributes=None, expected=None, ignore=None, for attr in agent_attributes: if attr.destinations & DST_ERROR_COLLECTOR: attributes["agentAttributes"][attr.name] = attr.value - + error_details = TracedError( start_time=time.time(), path="Exception", message=message, type=fullname, parameters=attributes ) @@ -829,6 +946,15 @@ def record_custom_event(self, event): if settings.collect_custom_events and settings.custom_insights_events.enabled: self._custom_events.add(event) + def record_ml_event(self, event): + settings = self.__settings + + if not settings: + return + + if settings.ml_insights_events.enabled: + self._ml_events.add(event) + def record_custom_metric(self, name, value): """Record a single value metric, merging the data with any data from prior value metrics with the same name. @@ -865,6 +991,28 @@ def record_custom_metrics(self, metrics): for name, value in metrics: self.record_custom_metric(name, value) + def record_dimensional_metric(self, name, value, tags=None): + """Record a single value metric, merging the data with any data + from prior value metrics with the same name and tags. + """ + return self.__dimensional_stats_table.record_dimensional_metric(name, value, tags) + + def record_dimensional_metrics(self, metrics): + """Record the value metrics supplied by the iterable, merging + the data with any data from prior value metrics with the same + name. + + """ + + if not self.__settings: + return + + for metric in metrics: + name, value = metric[:2] + tags = metric[2] if len(metric) >= 3 else None + + self.record_dimensional_metric(name, value, tags) + def record_slow_sql_node(self, node): """Record a single sql metric, merging the data with any data from prior sql metrics for the same sql key. @@ -975,6 +1123,8 @@ def record_transaction(self, transaction): self.merge_custom_metrics(transaction.custom_metrics.metrics()) + self.merge_dimensional_metrics(transaction.dimensional_metrics.metrics()) + self.record_time_metrics(transaction.time_metrics(self)) # Capture any errors if error collection is enabled. 
@@ -1042,6 +1192,11 @@ def record_transaction(self, transaction): if settings.collect_custom_events and settings.custom_insights_events.enabled: self.custom_events.merge(transaction.custom_events) + # Merge in machine learning events + + if settings.ml_insights_events.enabled: + self.ml_events.merge(transaction.ml_events) + # Merge in span events if settings.distributed_tracing.enabled and settings.span_events.enabled and settings.collect_span_events: @@ -1129,7 +1284,11 @@ def metric_data(self, normalizer=None): if normalizer is not None: for key, value in six.iteritems(self.__stats_table): - key = (normalizer(key[0])[0], key[1]) + normalized_name, ignored = normalizer(key[0]) + if ignored: + continue + + key = (normalized_name, key[1]) stats = normalized_stats.get(key) if stats is None: normalized_stats[key] = copy.copy(value) @@ -1159,6 +1318,66 @@ def metric_data_count(self): return len(self.__stats_table) + def dimensional_metric_data(self, normalizer=None): + """Returns a list containing the low level metric data for + sending to the core application pertaining to the reporting + period. This consists of tuple pairs where the first is the + metric name and the second is a dictionary mapping each set of + tags to the accumulated stats for that metric. + + """ + + if not self.__settings: + return [] + + result = [] + normalized_stats = {} + + # Metric Renaming and Re-Aggregation. After applying the metric + # renaming rules, the metrics are re-aggregated to collapse the + # metrics with the same names after the renaming. + + if self.__settings.debug.log_raw_metric_data: + _logger.info( + "Raw dimensional metric data for harvest of %r is %r.", + self.__settings.app_name, + list(self.__dimensional_stats_table.metrics()), + ) + + if normalizer is not None: + for key, value in self.__dimensional_stats_table.metrics(): + key = normalizer(key)[0] + stats = normalized_stats.get(key) + if stats is None: + normalized_stats[key] = copy.copy(value) + else: + stats.merge_stats(value) + else: + normalized_stats = self.__dimensional_stats_table + + if self.__settings.debug.log_normalized_metric_data: + _logger.info( + "Normalized metric data for harvest of %r is %r.", + self.__settings.app_name, + list(normalized_stats.metrics()), + ) + + for key, value in normalized_stats.items(): + result.append((key, value)) + + return result + + def dimensional_metric_data_count(self): + """Returns a count of the number of unique metrics.""" + + if not self.__settings: + return 0 + + return self.__dimensional_stats_table.metrics_count() + def error_data(self): """Returns a list containing any errors collected during the reporting period.
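For context on the normalizer handling above: the normalizer returns a (normalized_name, ignored) pair, after which metrics whose names collide post-rename are re-aggregated and ignored metrics are dropped. A rough sketch of that pass under a hypothetical rule set, with plain counters standing in for the stats objects:

    # Hypothetical normalizer with one ignore rule and one rename rule,
    # mimicking the (normalized_name, ignored) contract used above.
    def normalizer(name):
        if name.startswith("Internal/"):
            return name, True  # ignore rule matched; caller drops the metric
        # Rename rule: collapse numeric path segments onto a single name.
        return "/".join("*" if p.isdigit() else p for p in name.split("/")), False

    raw = {"Endpoint/123/show": 3, "Endpoint/456/show": 2, "Internal/ping": 9}

    normalized = {}
    for name, count in raw.items():
        new_name, ignored = normalizer(name)
        if ignored:
            continue
        # Re-aggregation: entries sharing a normalized name merge together.
        normalized[new_name] = normalized.get(new_name, 0) + count

    print(normalized)  # {'Endpoint/*/show': 5} -- renamed entries merged, Internal/ping dropped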
@@ -1436,7 +1655,6 @@ def reset_stats(self, settings, reset_stream=False): """ self.__settings = settings - self.__stats_table = {} self.__sql_stats_table = {} self.__slow_transaction = None self.__slow_transaction_map = {} @@ -1444,9 +1662,11 @@ self.__transaction_errors = [] self.__synthetics_transactions = [] + self.reset_metric_stats() self.reset_transaction_events() self.reset_error_events() self.reset_custom_events() + self.reset_ml_events() self.reset_span_events() self.reset_log_events() self.reset_synthetics_events() @@ -1463,6 +1683,7 @@ def reset_metric_stats(self): """ self.__stats_table = {} + self.__dimensional_stats_table.reset_metric_stats() def reset_transaction_events(self): """Resets the accumulated statistics back to initial state for @@ -1489,6 +1710,12 @@ def reset_custom_events(self): else: self._custom_events = SampledDataSet() + def reset_ml_events(self): + if self.__settings is not None: + self._ml_events = SampledDataSet(self.__settings.event_harvest_config.harvest_limits.ml_event_data) + else: + self._ml_events = SampledDataSet() + def reset_span_events(self): if self.__settings is not None: self._span_events = SampledDataSet(self.__settings.event_harvest_config.harvest_limits.span_event_data) @@ -1622,6 +1849,7 @@ def merge(self, snapshot): self._merge_error_events(snapshot) self._merge_error_traces(snapshot) self._merge_custom_events(snapshot) + self._merge_ml_events(snapshot) self._merge_span_events(snapshot) self._merge_log_events(snapshot) self._merge_sql(snapshot) @@ -1647,6 +1875,7 @@ def rollback(self, snapshot): self._merge_synthetics_events(snapshot, rollback=True) self._merge_error_events(snapshot) self._merge_custom_events(snapshot, rollback=True) + self._merge_ml_events(snapshot, rollback=True) self._merge_span_events(snapshot, rollback=True) self._merge_log_events(snapshot, rollback=True) @@ -1716,6 +1945,12 @@ def _merge_custom_events(self, snapshot, rollback=False): return self._custom_events.merge(events) + def _merge_ml_events(self, snapshot, rollback=False): + events = snapshot.ml_events + if not events: + return + self._ml_events.merge(events) + def _merge_span_events(self, snapshot, rollback=False): events = snapshot.span_events if not events: @@ -1785,6 +2020,29 @@ def merge_custom_metrics(self, metrics): else: stats.merge_stats(other) + def merge_dimensional_metrics(self, metrics): + """ + Merges in a set of dimensional metrics. The metrics should be + provided as an iterable where each item is a tuple of the metric + name and a dictionary mapping each attribute filtered frozenset + of tags to its accumulated stats.
+ """ + + if not self.__settings: + return + + for key, other in metrics: + stats_container = self.__dimensional_stats_table.get(key) + if not stats_container: + self.__dimensional_stats_table[key] = other + else: + for tags, other_value in other.items(): + stats = stats_container.get(tags) + if not stats: + stats_container[tags] = other_value + else: + stats.merge_stats(other_value) + def _snapshot(self): copy = object.__new__(StatsEngineSnapshot) copy.__dict__.update(self.__dict__) @@ -1798,6 +2056,9 @@ def reset_transaction_events(self): def reset_custom_events(self): self._custom_events = None + def reset_ml_events(self): + self._ml_events = None + def reset_span_events(self): self._span_events = None diff --git a/newrelic/core/transaction_node.py b/newrelic/core/transaction_node.py index 0faae3790..d63d7f9b6 100644 --- a/newrelic/core/transaction_node.py +++ b/newrelic/core/transaction_node.py @@ -60,10 +60,12 @@ "errors", "slow_sql", "custom_events", + "ml_events", "log_events", "apdex_t", "suppress_apdex", "custom_metrics", + "dimensional_metrics", "guid", "cpu_time", "suppress_transaction_trace", diff --git a/newrelic/hooks/component_graphqlserver.py b/newrelic/hooks/component_graphqlserver.py index 29004c11f..ebc62a34d 100644 --- a/newrelic/hooks/component_graphqlserver.py +++ b/newrelic/hooks/component_graphqlserver.py @@ -1,19 +1,18 @@ -from newrelic.api.asgi_application import wrap_asgi_application from newrelic.api.error_trace import ErrorTrace from newrelic.api.graphql_trace import GraphQLOperationTrace from newrelic.api.transaction import current_transaction -from newrelic.api.transaction_name import TransactionNameWrapper from newrelic.common.object_names import callable_name from newrelic.common.object_wrapper import wrap_function_wrapper +from newrelic.common.package_version_utils import get_package_version from newrelic.core.graphql_utils import graphql_statement from newrelic.hooks.framework_graphql import ( - framework_version as graphql_framework_version, + GRAPHQL_VERSION, + ignore_graphql_duplicate_exception, ) -from newrelic.hooks.framework_graphql import ignore_graphql_duplicate_exception -def framework_details(): - import graphql_server - return ("GraphQLServer", getattr(graphql_server, "__version__", None)) +GRAPHQL_SERVER_VERSION = get_package_version("graphql-server") +graphql_server_major_version = int(GRAPHQL_SERVER_VERSION.split(".")[0]) + def bind_query(schema, params, *args, **kwargs): return getattr(params, "query", None) @@ -30,9 +29,8 @@ def wrap_get_response(wrapped, instance, args, kwargs): except TypeError: return wrapped(*args, **kwargs) - framework = framework_details() - transaction.add_framework_info(name=framework[0], version=framework[1]) - transaction.add_framework_info(name="GraphQL", version=graphql_framework_version()) + transaction.add_framework_info(name="GraphQLServer", version=GRAPHQL_SERVER_VERSION) + transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION) if hasattr(query, "body"): query = query.body @@ -45,5 +43,8 @@ def wrap_get_response(wrapped, instance, args, kwargs): with ErrorTrace(ignore=ignore_graphql_duplicate_exception): return wrapped(*args, **kwargs) + def instrument_graphqlserver(module): - wrap_function_wrapper(module, "get_response", wrap_get_response) + if graphql_server_major_version <= 2: + return + wrap_function_wrapper(module, "get_response", wrap_get_response) diff --git a/newrelic/hooks/database_asyncpg.py b/newrelic/hooks/database_asyncpg.py index 0d03e9139..d6ca62ef3 100644 --- 
a/newrelic/hooks/database_asyncpg.py +++ b/newrelic/hooks/database_asyncpg.py @@ -12,11 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from newrelic.api.database_trace import ( - DatabaseTrace, - enable_datastore_instance_feature, - register_database_client, -) +from newrelic.api.database_trace import DatabaseTrace, register_database_client from newrelic.api.datastore_trace import DatastoreTrace from newrelic.common.object_wrapper import ObjectProxy, wrap_function_wrapper @@ -43,7 +39,6 @@ def instance_info(cls, args, kwargs): quoting_style="single+dollar", instance_info=PostgresApi.instance_info, ) -enable_datastore_instance_feature(PostgresApi) class ProtocolProxy(ObjectProxy): @@ -94,9 +89,7 @@ async def query(self, query, *args, **kwargs): async def prepare(self, stmt_name, query, *args, **kwargs): with DatabaseTrace( - "PREPARE {stmt_name} FROM '{query}'".format( - stmt_name=stmt_name, query=query - ), + "PREPARE {stmt_name} FROM '{query}'".format(stmt_name=stmt_name, query=query), dbapi2_module=PostgresApi, connect_params=getattr(self, "_nr_connect_params", None), source=self.__wrapped__.prepare, @@ -131,9 +124,7 @@ def proxy_protocol(wrapped, instance, args, kwargs): def wrap_connect(wrapped, instance, args, kwargs): host = port = database_name = None if "addr" in kwargs: - host, port, database_name = PostgresApi._instance_info( - kwargs["addr"], None, kwargs.get("params") - ) + host, port, database_name = PostgresApi._instance_info(kwargs["addr"], None, kwargs.get("params")) with DatastoreTrace( PostgresApi._nr_database_product, diff --git a/newrelic/hooks/database_mysqldb.py b/newrelic/hooks/database_mysqldb.py index 31dd6bc19..c36d91d40 100644 --- a/newrelic/hooks/database_mysqldb.py +++ b/newrelic/hooks/database_mysqldb.py @@ -14,54 +14,69 @@ import os -from newrelic.api.database_trace import (enable_datastore_instance_feature, - DatabaseTrace, register_database_client) +from newrelic.api.database_trace import DatabaseTrace, register_database_client from newrelic.api.function_trace import FunctionTrace from newrelic.api.transaction import current_transaction from newrelic.common.object_names import callable_name from newrelic.common.object_wrapper import wrap_object +from newrelic.hooks.database_dbapi2 import ConnectionFactory as DBAPI2ConnectionFactory +from newrelic.hooks.database_dbapi2 import ConnectionWrapper as DBAPI2ConnectionWrapper -from newrelic.hooks.database_dbapi2 import (ConnectionWrapper as - DBAPI2ConnectionWrapper, ConnectionFactory as DBAPI2ConnectionFactory) class ConnectionWrapper(DBAPI2ConnectionWrapper): - def __enter__(self): transaction = current_transaction() name = callable_name(self.__wrapped__.__enter__) with FunctionTrace(name, source=self.__wrapped__.__enter__): - cursor = self.__wrapped__.__enter__() + cursor = self.__wrapped__.__enter__() # The __enter__() method of original connection object returns # a new cursor instance for use with 'as' assignment. We need # to wrap that in a cursor wrapper otherwise we will not track # any queries done via it. 
- return self.__cursor_wrapper__(cursor, self._nr_dbapi2_module, - self._nr_connect_params, None) + return self.__cursor_wrapper__(cursor, self._nr_dbapi2_module, self._nr_connect_params, None) def __exit__(self, exc, value, tb): transaction = current_transaction() name = callable_name(self.__wrapped__.__exit__) with FunctionTrace(name, source=self.__wrapped__.__exit__): if exc is None: - with DatabaseTrace('COMMIT', self._nr_dbapi2_module, self._nr_connect_params, source=self.__wrapped__.__exit__): + with DatabaseTrace( + "COMMIT", self._nr_dbapi2_module, self._nr_connect_params, source=self.__wrapped__.__exit__ + ): return self.__wrapped__.__exit__(exc, value, tb) else: - with DatabaseTrace('ROLLBACK', self._nr_dbapi2_module, self._nr_connect_params, source=self.__wrapped__.__exit__): + with DatabaseTrace( + "ROLLBACK", self._nr_dbapi2_module, self._nr_connect_params, source=self.__wrapped__.__exit__ + ): return self.__wrapped__.__exit__(exc, value, tb) + class ConnectionFactory(DBAPI2ConnectionFactory): __connection_wrapper__ = ConnectionWrapper + def instance_info(args, kwargs): - def _bind_params(host=None, user=None, passwd=None, db=None, port=None, - unix_socket=None, conv=None, connect_timeout=None, compress=None, - named_pipe=None, init_command=None, read_default_file=None, - read_default_group=None, *args, **kwargs): - return (host, port, db, unix_socket, - read_default_file, read_default_group) + def _bind_params( + host=None, + user=None, + passwd=None, + db=None, + port=None, + unix_socket=None, + conv=None, + connect_timeout=None, + compress=None, + named_pipe=None, + init_command=None, + read_default_file=None, + read_default_group=None, + *args, + **kwargs + ): + return (host, port, db, unix_socket, read_default_file, read_default_group) params = _bind_params(*args, **kwargs) host, port, db, unix_socket, read_default_file, read_default_group = params @@ -69,38 +84,38 @@ def _bind_params(host=None, user=None, passwd=None, db=None, port=None, port_path_or_id = None if read_default_file or read_default_group: - host = host or 'default' - port_path_or_id = 'unknown' + host = host or "default" + port_path_or_id = "unknown" elif not host: - host = 'localhost' + host = "localhost" - if host == 'localhost': + if host == "localhost": # precedence: explicit -> cnf (if used) -> env -> 'default' - port_path_or_id = (unix_socket or - port_path_or_id or - os.getenv('MYSQL_UNIX_PORT', 'default')) + port_path_or_id = unix_socket or port_path_or_id or os.getenv("MYSQL_UNIX_PORT", "default") elif explicit_host: # only reach here if host is explicitly passed in port = port and str(port) # precedence: explicit -> cnf (if used) -> env -> '3306' - port_path_or_id = (port or - port_path_or_id or - os.getenv('MYSQL_TCP_PORT', '3306')) + port_path_or_id = port or port_path_or_id or os.getenv("MYSQL_TCP_PORT", "3306") # There is no default database if omitted from the connect params # In this case, we should report unknown - db = db or 'unknown' + db = db or "unknown" return (host, port_path_or_id, db) -def instrument_mysqldb(module): - register_database_client(module, database_product='MySQL', - quoting_style='single+double', explain_query='explain', - explain_stmts=('select',), instance_info=instance_info) - enable_datastore_instance_feature(module) +def instrument_mysqldb(module): + register_database_client( + module, + database_product="MySQL", + quoting_style="single+double", + explain_query="explain", + explain_stmts=("select",), + instance_info=instance_info, + ) - wrap_object(module, 
'connect', ConnectionFactory, (module,)) + wrap_object(module, "connect", ConnectionFactory, (module,)) # The connect() function is actually aliased with Connect() and # Connection, the latter actually being the Connection type object. @@ -108,5 +123,5 @@ def instrument_mysqldb(module): # interferes with direct type usage. If people are using the # Connection object directly, they should really be using connect(). - if hasattr(module, 'Connect'): - wrap_object(module, 'Connect', ConnectionFactory, (module,)) + if hasattr(module, "Connect"): + wrap_object(module, "Connect", ConnectionFactory, (module,)) diff --git a/newrelic/hooks/database_psycopg2.py b/newrelic/hooks/database_psycopg2.py index 970909a33..bbed13184 100644 --- a/newrelic/hooks/database_psycopg2.py +++ b/newrelic/hooks/database_psycopg2.py @@ -15,17 +15,19 @@ import inspect import os -from newrelic.api.database_trace import (enable_datastore_instance_feature, - register_database_client, DatabaseTrace) +from newrelic.api.database_trace import DatabaseTrace, register_database_client from newrelic.api.function_trace import FunctionTrace from newrelic.api.transaction import current_transaction from newrelic.common.object_names import callable_name -from newrelic.common.object_wrapper import (wrap_object, ObjectProxy, - wrap_function_wrapper) - -from newrelic.hooks.database_dbapi2 import (ConnectionWrapper as - DBAPI2ConnectionWrapper, ConnectionFactory as DBAPI2ConnectionFactory, - CursorWrapper as DBAPI2CursorWrapper, DEFAULT) +from newrelic.common.object_wrapper import ( + ObjectProxy, + wrap_function_wrapper, + wrap_object, +) +from newrelic.hooks.database_dbapi2 import DEFAULT +from newrelic.hooks.database_dbapi2 import ConnectionFactory as DBAPI2ConnectionFactory +from newrelic.hooks.database_dbapi2 import ConnectionWrapper as DBAPI2ConnectionWrapper +from newrelic.hooks.database_dbapi2 import CursorWrapper as DBAPI2CursorWrapper try: from urllib import unquote @@ -43,33 +45,27 @@ # used. If the default connection and cursor are used without any unknown # arguments, we can safely drop all cursor parameters to generate explain # plans. Explain plans do not work with named cursors.
-def _bind_connect( - dsn=None, connection_factory=None, cursor_factory=None, - *args, **kwargs): +def _bind_connect(dsn=None, connection_factory=None, cursor_factory=None, *args, **kwargs): return bool(connection_factory or cursor_factory) -def _bind_cursor( - name=None, cursor_factory=None, scrollable=None, - withhold=False, *args, **kwargs): +def _bind_cursor(name=None, cursor_factory=None, scrollable=None, withhold=False, *args, **kwargs): return bool(cursor_factory or args or kwargs) class CursorWrapper(DBAPI2CursorWrapper): - def execute(self, sql, parameters=DEFAULT, *args, **kwargs): - if hasattr(sql, 'as_string'): + if hasattr(sql, "as_string"): sql = sql.as_string(self) - return super(CursorWrapper, self).execute(sql, parameters, *args, - **kwargs) + return super(CursorWrapper, self).execute(sql, parameters, *args, **kwargs) def __enter__(self): self.__wrapped__.__enter__() return self def executemany(self, sql, seq_of_parameters): - if hasattr(sql, 'as_string'): + if hasattr(sql, "as_string"): sql = sql.as_string(self) return super(CursorWrapper, self).executemany(sql, seq_of_parameters) @@ -83,7 +79,7 @@ def __enter__(self): transaction = current_transaction() name = callable_name(self.__wrapped__.__enter__) with FunctionTrace(name, source=self.__wrapped__.__enter__): - self.__wrapped__.__enter__() + self.__wrapped__.__enter__() # Must return a reference to self as otherwise will be # returning the inner connection object. If 'as' is used @@ -98,19 +94,20 @@ def __exit__(self, exc, value, tb): name = callable_name(self.__wrapped__.__exit__) with FunctionTrace(name, source=self.__wrapped__.__exit__): if exc is None: - with DatabaseTrace('COMMIT', - self._nr_dbapi2_module, self._nr_connect_params, source=self.__wrapped__.__exit__): + with DatabaseTrace( + "COMMIT", self._nr_dbapi2_module, self._nr_connect_params, source=self.__wrapped__.__exit__ + ): return self.__wrapped__.__exit__(exc, value, tb) else: - with DatabaseTrace('ROLLBACK', - self._nr_dbapi2_module, self._nr_connect_params, source=self.__wrapped__.__exit__): + with DatabaseTrace( + "ROLLBACK", self._nr_dbapi2_module, self._nr_connect_params, source=self.__wrapped__.__exit__ + ): return self.__wrapped__.__exit__(exc, value, tb) # This connection wrapper does not save cursor parameters for explain plans. It # is only used for the default connection class. class ConnectionWrapper(ConnectionSaveParamsWrapper): - def cursor(self, *args, **kwargs): # If any unknown cursor params are detected or a cursor factory is # used, store params for explain plans later. 
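For context on the as_string() handling in CursorWrapper.execute() and executemany() above: psycopg2's sql module builds queries as Composable objects rather than plain strings, so the wrapper stringifies them, passing itself as the quoting context, before the usual DB-API tracing path records the SQL. A small illustration, assuming psycopg2 is installed; conn is a placeholder for a live connection:

    from psycopg2 import sql

    # Composed query objects are not str instances, which is why the wrapper
    # checks hasattr(sql, "as_string") before handing the query off for tracing.
    query = sql.SQL("select {field} from {table}").format(
        field=sql.Identifier("name"),
        table=sql.Identifier("users"),
    )
    print(isinstance(query, str))  # False: it's a sql.Composed object

    # as_string() needs a connection or cursor for identifier quoting, which is
    # what the wrapper supplies by passing itself:
    # print(query.as_string(conn))  # -> 'select "name" from "users"'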
@@ -119,9 +116,9 @@ def cursor(self, *args, **kwargs): else: cursor_params = None - return self.__cursor_wrapper__(self.__wrapped__.cursor( - *args, **kwargs), self._nr_dbapi2_module, - self._nr_connect_params, cursor_params) + return self.__cursor_wrapper__( + self.__wrapped__.cursor(*args, **kwargs), self._nr_dbapi2_module, self._nr_connect_params, cursor_params + ) class ConnectionFactory(DBAPI2ConnectionFactory): @@ -144,15 +141,13 @@ def instance_info(args, kwargs): def _parse_connect_params(args, kwargs): - def _bind_params(dsn=None, *args, **kwargs): return dsn dsn = _bind_params(*args, **kwargs) try: - if dsn and (dsn.startswith('postgres://') or - dsn.startswith('postgresql://')): + if dsn and (dsn.startswith("postgres://") or dsn.startswith("postgresql://")): # Parse dsn as URI # @@ -166,53 +161,52 @@ def _bind_params(dsn=None, *args, **kwargs): # ipv6 brackets [] are contained in the URI hostname # and should be removed - host = host and host.strip('[]') + host = host and host.strip("[]") port = parsed_uri.port db_name = parsed_uri.path - db_name = db_name and db_name.lstrip('/') + db_name = db_name and db_name.lstrip("/") db_name = db_name or None - query = parsed_uri.query or '' + query = parsed_uri.query or "" qp = dict(parse_qsl(query)) # Query parameters override hierarchical values in URI. - host = qp.get('host') or host or None - hostaddr = qp.get('hostaddr') - port = qp.get('port') or port - db_name = qp.get('dbname') or db_name + host = qp.get("host") or host or None + hostaddr = qp.get("hostaddr") + port = qp.get("port") or port + db_name = qp.get("dbname") or db_name elif dsn: # Parse dsn as a key-value connection string - kv = dict([pair.split('=', 2) for pair in dsn.split()]) - host = kv.get('host') - hostaddr = kv.get('hostaddr') - port = kv.get('port') - db_name = kv.get('dbname') + kv = dict([pair.split("=", 2) for pair in dsn.split()]) + host = kv.get("host") + hostaddr = kv.get("hostaddr") + port = kv.get("port") + db_name = kv.get("dbname") else: # No dsn, so get the instance info from keyword arguments. - host = kwargs.get('host') - hostaddr = kwargs.get('hostaddr') - port = kwargs.get('port') - db_name = kwargs.get('database') + host = kwargs.get("host") + hostaddr = kwargs.get("hostaddr") + port = kwargs.get("port") + db_name = kwargs.get("database") # Ensure non-None values are strings. - (host, hostaddr, port, db_name) = [str(s) if s is not None else s - for s in (host, hostaddr, port, db_name)] + (host, hostaddr, port, db_name) = [str(s) if s is not None else s for s in (host, hostaddr, port, db_name)] except Exception: - host = 'unknown' - hostaddr = 'unknown' - port = 'unknown' - db_name = 'unknown' + host = "unknown" + hostaddr = "unknown" + port = "unknown" + db_name = "unknown" return (host, hostaddr, port, db_name) @@ -221,37 +215,39 @@ def _add_defaults(parsed_host, parsed_hostaddr, parsed_port, parsed_database): # ENV variables set the default values - parsed_host = parsed_host or os.environ.get('PGHOST') - parsed_hostaddr = parsed_hostaddr or os.environ.get('PGHOSTADDR') - parsed_port = parsed_port or os.environ.get('PGPORT') - database = parsed_database or os.environ.get('PGDATABASE') or 'default' + parsed_host = parsed_host or os.environ.get("PGHOST") + parsed_hostaddr = parsed_hostaddr or os.environ.get("PGHOSTADDR") + parsed_port = parsed_port or os.environ.get("PGPORT") + database = parsed_database or os.environ.get("PGDATABASE") or "default" # If hostaddr is present, we use that, since host is used for auth only. 
parsed_host = parsed_hostaddr or parsed_host if parsed_host is None: - host = 'localhost' - port = 'default' - elif parsed_host.startswith('/'): - host = 'localhost' - port = '%s/.s.PGSQL.%s' % (parsed_host, parsed_port or '5432') + host = "localhost" + port = "default" + elif parsed_host.startswith("/"): + host = "localhost" + port = "%s/.s.PGSQL.%s" % (parsed_host, parsed_port or "5432") else: host = parsed_host - port = parsed_port or '5432' + port = parsed_port or "5432" return (host, port, database) def instrument_psycopg2(module): - register_database_client(module, database_product='Postgres', - quoting_style='single+dollar', explain_query='explain', - explain_stmts=('select', 'insert', 'update', 'delete'), - instance_info=instance_info) - - enable_datastore_instance_feature(module) + register_database_client( + module, + database_product="Postgres", + quoting_style="single+dollar", + explain_query="explain", + explain_stmts=("select", "insert", "update", "delete"), + instance_info=instance_info, + ) - wrap_object(module, 'connect', ConnectionFactory, (module,)) + wrap_object(module, "connect", ConnectionFactory, (module,)) def wrapper_psycopg2_register_type(wrapped, instance, args, kwargs): @@ -277,7 +273,7 @@ def _bind_params(context, *args, **kwargs): # Unwrap the context for string conversion since psycopg2 uses duck typing # and a TypeError will be raised if a wrapper is used. - if hasattr(context, '__wrapped__'): + if hasattr(context, "__wrapped__"): context = context.__wrapped__ return wrapped(context, *_args, **_kwargs) @@ -289,36 +285,31 @@ def _bind_params(context, *args, **kwargs): # In doing that we need to make sure it has not already been monkey # patched by checking to see if it is already an ObjectProxy. def instrument_psycopg2__psycopg2(module): - if hasattr(module, 'register_type'): + if hasattr(module, "register_type"): if not isinstance(module.register_type, ObjectProxy): - wrap_function_wrapper(module, 'register_type', - wrapper_psycopg2_register_type) + wrap_function_wrapper(module, "register_type", wrapper_psycopg2_register_type) def instrument_psycopg2_extensions(module): - if hasattr(module, 'register_type'): + if hasattr(module, "register_type"): if not isinstance(module.register_type, ObjectProxy): - wrap_function_wrapper(module, 'register_type', - wrapper_psycopg2_register_type) + wrap_function_wrapper(module, "register_type", wrapper_psycopg2_register_type) def instrument_psycopg2__json(module): - if hasattr(module, 'register_type'): + if hasattr(module, "register_type"): if not isinstance(module.register_type, ObjectProxy): - wrap_function_wrapper(module, 'register_type', - wrapper_psycopg2_register_type) + wrap_function_wrapper(module, "register_type", wrapper_psycopg2_register_type) def instrument_psycopg2__range(module): - if hasattr(module, 'register_type'): + if hasattr(module, "register_type"): if not isinstance(module.register_type, ObjectProxy): - wrap_function_wrapper(module, 'register_type', - wrapper_psycopg2_register_type) + wrap_function_wrapper(module, "register_type", wrapper_psycopg2_register_type) def instrument_psycopg2_sql(module): - if (hasattr(module, 'Composable') and - hasattr(module.Composable, 'as_string')): + if hasattr(module, "Composable") and hasattr(module.Composable, "as_string"): for name, cls in inspect.getmembers(module): if not inspect.isclass(cls): continue @@ -326,5 +317,4 @@ def instrument_psycopg2_sql(module): if not issubclass(cls, module.Composable): continue - wrap_function_wrapper(module, name + '.as_string', - 
wrapper_psycopg2_as_string) + wrap_function_wrapper(module, name + ".as_string", wrapper_psycopg2_as_string) diff --git a/newrelic/hooks/datastore_aioredis.py b/newrelic/hooks/datastore_aioredis.py index 9bd5b17b0..e27f8d7a9 100644 --- a/newrelic/hooks/datastore_aioredis.py +++ b/newrelic/hooks/datastore_aioredis.py @@ -11,17 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from newrelic.api.datastore_trace import DatastoreTrace from newrelic.api.time_trace import current_trace from newrelic.api.transaction import current_transaction -from newrelic.common.object_wrapper import wrap_function_wrapper, function_wrapper +from newrelic.common.object_wrapper import function_wrapper, wrap_function_wrapper +from newrelic.common.package_version_utils import get_package_version_tuple from newrelic.hooks.datastore_redis import ( _redis_client_methods, _redis_multipart_commands, _redis_operation_re, ) -from newrelic.common.package_version_utils import get_package_version_tuple + +AIOREDIS_VERSION = get_package_version_tuple("aioredis") def _conn_attrs_to_dict(connection): @@ -39,14 +40,13 @@ def _conn_attrs_to_dict(connection): def _instance_info(kwargs): host = kwargs.get("host") or "localhost" - port_path_or_id = str(kwargs.get("port") or kwargs.get("path", 6379)) + port_path_or_id = str(kwargs.get("path") or kwargs.get("port", 6379)) db = str(kwargs.get("db") or 0) return (host, port_path_or_id, db) def _wrap_AioRedis_method_wrapper(module, instance_class_name, operation): - @function_wrapper async def _nr_wrapper_AioRedis_async_method_(wrapped, instance, args, kwargs): transaction = current_transaction() @@ -55,32 +55,35 @@ async def _nr_wrapper_AioRedis_async_method_(wrapped, instance, args, kwargs): with DatastoreTrace(product="Redis", target=None, operation=operation): return await wrapped(*args, **kwargs) - + def _nr_wrapper_AioRedis_method_(wrapped, instance, args, kwargs): # Check for transaction and return early if found. # Method will return synchronously without executing, # it will be added to the command stack and run later. - aioredis_version = get_package_version_tuple("aioredis") - if aioredis_version and aioredis_version < (2,): + + # This conditional is for versions of aioredis that are outside + # New Relic's supportability window but will still work. New + # Relic does not provide testing/support for this. In order to + # keep functionality without affecting coverage metrics, this + # segment is excluded from coverage analysis. + if AIOREDIS_VERSION and AIOREDIS_VERSION < (2,): # pragma: no cover # AioRedis v1 uses a RedisBuffer instead of a real connection for queueing up pipeline commands from aioredis.commands.transaction import _RedisBuffer + if isinstance(instance._pool_or_conn, _RedisBuffer): # Method will return synchronously without executing, # it will be added to the command stack and run later. return wrapped(*args, **kwargs) else: # AioRedis v2 uses a Pipeline object for a client and internally queues up pipeline commands - if aioredis_version: + if AIOREDIS_VERSION: from aioredis.client import Pipeline - else: - from redis.asyncio.client import Pipeline if isinstance(instance, Pipeline): return wrapped(*args, **kwargs) # Method should be run when awaited, therefore we wrap in an async wrapper. 
return _nr_wrapper_AioRedis_async_method_(wrapped)(*args, **kwargs) - name = "%s.%s" % (instance_class_name, operation) wrap_function_wrapper(module, name, _nr_wrapper_AioRedis_method_) @@ -109,7 +112,9 @@ async def wrap_Connection_send_command(wrapped, instance, args, kwargs): # If it's not a multi part command, there's no need to trace it, so # we can return early. - if operation.split()[0] not in _redis_multipart_commands: # Set the datastore info on the DatastoreTrace containing this function call. + if ( + operation.split()[0] not in _redis_multipart_commands + ): # Set the datastore info on the DatastoreTrace containing this function call. trace = current_trace() # Find DatastoreTrace no matter how many other traces are in between @@ -136,7 +141,12 @@ async def wrap_Connection_send_command(wrapped, instance, args, kwargs): return await wrapped(*args, **kwargs) -def wrap_RedisConnection_execute(wrapped, instance, args, kwargs): +# This wrapper is for versions of aioredis that are outside +# New Relic's supportability window but will still work. New +# Relic does not provide testing/support for this. In order to +# keep functionality without affecting coverage metrics, this +# segment is excluded from coverage analysis. +def wrap_RedisConnection_execute(wrapped, instance, args, kwargs): # pragma: no cover # RedisConnection in aioredis v1 returns a future instead of using coroutines transaction = current_transaction() if not transaction: @@ -161,7 +171,9 @@ def wrap_RedisConnection_execute(wrapped, instance, args, kwargs): # If it's not a multi part command, there's no need to trace it, so # we can return early. - if operation.split()[0] not in _redis_multipart_commands: # Set the datastore info on the DatastoreTrace containing this function call. + if ( + operation.split()[0] not in _redis_multipart_commands + ): # Set the datastore info on the DatastoreTrace containing this function call. trace = current_trace() # Find DatastoreTrace no matter how many other traces are in between @@ -202,6 +214,11 @@ def instrument_aioredis_connection(module): if hasattr(module.Connection, "send_command"): wrap_function_wrapper(module, "Connection.send_command", wrap_Connection_send_command) - if hasattr(module, "RedisConnection"): + # This conditional is for versions of aioredis that are outside + # New Relic's supportability window but will still work. New + # Relic does not provide testing/support for this. In order to + # keep functionality without affecting coverage metrics, this + # segment is excluded from coverage analysis. + if hasattr(module, "RedisConnection"): # pragma: no cover if hasattr(module.RedisConnection, "execute"): wrap_function_wrapper(module, "RedisConnection.execute", wrap_RedisConnection_execute) diff --git a/newrelic/hooks/datastore_firestore.py b/newrelic/hooks/datastore_firestore.py new file mode 100644 index 000000000..6d3196a7c --- /dev/null +++ b/newrelic/hooks/datastore_firestore.py @@ -0,0 +1,473 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +from newrelic.api.datastore_trace import wrap_datastore_trace +from newrelic.api.function_trace import wrap_function_trace +from newrelic.common.async_wrapper import generator_wrapper, async_generator_wrapper + + +def _conn_str_to_host(getter): + """Safely transform a getter that can retrieve a connection string into the resulting host.""" + + def closure(obj, *args, **kwargs): + try: + return getter(obj, *args, **kwargs).split(":")[0] + except Exception: + return None + + return closure + + +def _conn_str_to_port(getter): + """Safely transform a getter that can retrieve a connection string into the resulting port.""" + + def closure(obj, *args, **kwargs): + try: + return getter(obj, *args, **kwargs).split(":")[1] + except Exception: + return None + + return closure + + +# Default Target ID and Instance Info +_get_object_id = lambda obj, *args, **kwargs: getattr(obj, "id", None) +_get_client_database_string = lambda obj, *args, **kwargs: getattr( + getattr(obj, "_client", None), "_database_string", None +) +_get_client_target = lambda obj, *args, **kwargs: obj._client._target +_get_client_target_host = _conn_str_to_host(_get_client_target) +_get_client_target_port = _conn_str_to_port(_get_client_target) + +# Client Instance Info +_get_database_string = lambda obj, *args, **kwargs: getattr(obj, "_database_string", None) +_get_target = lambda obj, *args, **kwargs: obj._target +_get_target_host = _conn_str_to_host(_get_target) +_get_target_port = _conn_str_to_port(_get_target) + +# Query Target ID +_get_parent_id = lambda obj, *args, **kwargs: getattr(getattr(obj, "_parent", None), "id", None) + +# AggregationQuery Target ID +_get_collection_ref_id = lambda obj, *args, **kwargs: getattr(getattr(obj, "_collection_ref", None), "id", None) + + +def instrument_google_cloud_firestore_v1_base_client(module): + rollup = ("Datastore/all", "Datastore/Firestore/all") + wrap_function_trace( + module, "BaseClient.__init__", name="%s:BaseClient.__init__" % module.__name__, terminal=True, rollup=rollup + ) + + +def instrument_google_cloud_firestore_v1_client(module): + if hasattr(module, "Client"): + class_ = module.Client + for method in ("collections", "get_all"): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "Client.%s" % method, + operation=method, + product="Firestore", + target=None, + host=_get_target_host, + port_path_or_id=_get_target_port, + database_name=_get_database_string, + async_wrapper=generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_async_client(module): + if hasattr(module, "AsyncClient"): + class_ = module.AsyncClient + for method in ("collections", "get_all"): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncClient.%s" % method, + operation=method, + product="Firestore", + target=None, + host=_get_target_host, + port_path_or_id=_get_target_port, + database_name=_get_database_string, + async_wrapper=async_generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_collection(module): + if hasattr(module, "CollectionReference"): + class_ = module.CollectionReference + for method in ("add", "get"): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "CollectionReference.%s" % method, + product="Firestore", + target=_get_object_id, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + for method in 
("stream", "list_documents"): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "CollectionReference.%s" % method, + operation=method, + product="Firestore", + target=_get_object_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_async_collection(module): + if hasattr(module, "AsyncCollectionReference"): + class_ = module.AsyncCollectionReference + for method in ("add", "get"): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncCollectionReference.%s" % method, + product="Firestore", + target=_get_object_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + operation=method, + ) + + for method in ("stream", "list_documents"): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncCollectionReference.%s" % method, + operation=method, + product="Firestore", + target=_get_object_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=async_generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_document(module): + if hasattr(module, "DocumentReference"): + class_ = module.DocumentReference + for method in ("create", "delete", "get", "set", "update"): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "DocumentReference.%s" % method, + product="Firestore", + target=_get_object_id, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + for method in ("collections",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "DocumentReference.%s" % method, + operation=method, + product="Firestore", + target=_get_object_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_async_document(module): + if hasattr(module, "AsyncDocumentReference"): + class_ = module.AsyncDocumentReference + for method in ("create", "delete", "get", "set", "update"): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncDocumentReference.%s" % method, + product="Firestore", + target=_get_object_id, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + for method in ("collections",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncDocumentReference.%s" % method, + operation=method, + product="Firestore", + target=_get_object_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=async_generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_query(module): + if hasattr(module, "Query"): + class_ = module.Query + for method in ("get",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "Query.%s" % method, + product="Firestore", + target=_get_parent_id, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + for method in ("stream",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "Query.%s" % method, + operation=method, + 
product="Firestore", + target=_get_parent_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=generator_wrapper, + ) + + if hasattr(module, "CollectionGroup"): + class_ = module.CollectionGroup + for method in ("get_partitions",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "CollectionGroup.%s" % method, + operation=method, + product="Firestore", + target=_get_parent_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_async_query(module): + if hasattr(module, "AsyncQuery"): + class_ = module.AsyncQuery + for method in ("get",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncQuery.%s" % method, + product="Firestore", + target=_get_parent_id, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + for method in ("stream",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncQuery.%s" % method, + operation=method, + product="Firestore", + target=_get_parent_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=async_generator_wrapper, + ) + + if hasattr(module, "AsyncCollectionGroup"): + class_ = module.AsyncCollectionGroup + for method in ("get_partitions",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncCollectionGroup.%s" % method, + operation=method, + product="Firestore", + target=_get_parent_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=async_generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_aggregation(module): + if hasattr(module, "AggregationQuery"): + class_ = module.AggregationQuery + for method in ("get",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AggregationQuery.%s" % method, + product="Firestore", + target=_get_collection_ref_id, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + for method in ("stream",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AggregationQuery.%s" % method, + operation=method, + product="Firestore", + target=_get_collection_ref_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_async_aggregation(module): + if hasattr(module, "AsyncAggregationQuery"): + class_ = module.AsyncAggregationQuery + for method in ("get",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncAggregationQuery.%s" % method, + product="Firestore", + target=_get_collection_ref_id, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + for method in ("stream",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncAggregationQuery.%s" % method, + operation=method, + product="Firestore", + target=_get_collection_ref_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + 
async_wrapper=async_generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_batch(module): + if hasattr(module, "WriteBatch"): + class_ = module.WriteBatch + for method in ("commit",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "WriteBatch.%s" % method, + product="Firestore", + target=None, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + +def instrument_google_cloud_firestore_v1_async_batch(module): + if hasattr(module, "AsyncWriteBatch"): + class_ = module.AsyncWriteBatch + for method in ("commit",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncWriteBatch.%s" % method, + product="Firestore", + target=None, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + +def instrument_google_cloud_firestore_v1_bulk_batch(module): + if hasattr(module, "BulkWriteBatch"): + class_ = module.BulkWriteBatch + for method in ("commit",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "BulkWriteBatch.%s" % method, + product="Firestore", + target=None, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + +def instrument_google_cloud_firestore_v1_transaction(module): + if hasattr(module, "Transaction"): + class_ = module.Transaction + for method in ("_commit", "_rollback"): + if hasattr(class_, method): + operation = method[1:] # Trim leading underscore + wrap_datastore_trace( + module, + "Transaction.%s" % method, + product="Firestore", + target=None, + operation=operation, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + +def instrument_google_cloud_firestore_v1_async_transaction(module): + if hasattr(module, "AsyncTransaction"): + class_ = module.AsyncTransaction + for method in ("_commit", "_rollback"): + if hasattr(class_, method): + operation = method[1:] # Trim leading underscore + wrap_datastore_trace( + module, + "AsyncTransaction.%s" % method, + product="Firestore", + target=None, + operation=operation, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) diff --git a/newrelic/hooks/datastore_redis.py b/newrelic/hooks/datastore_redis.py index b32c848b3..bbc517586 100644 --- a/newrelic/hooks/datastore_redis.py +++ b/newrelic/hooks/datastore_redis.py @@ -14,14 +14,71 @@ import re -from newrelic.api.datastore_trace import DatastoreTrace +from newrelic.api.datastore_trace import DatastoreTrace, DatastoreTraceWrapper, wrap_datastore_trace +from newrelic.api.time_trace import current_trace from newrelic.api.transaction import current_transaction from newrelic.common.object_wrapper import wrap_function_wrapper +from newrelic.common.async_wrapper import coroutine_wrapper, async_generator_wrapper, generator_wrapper -_redis_client_methods = { +_redis_client_sync_methods = { + "acl_dryrun", + "auth", + "bgrewriteaof", + "bitfield", + "blmpop", + "bzmpop", + "client", + "command", + "command_docs", + "command_getkeysandflags", + "command_info", + "debug_segfault", + "expiretime", + "failover", + "hello", + "latency_doctor", + "latency_graph", + "latency_histogram", + "lcs", + "lpop", + "lpos", + "memory_doctor", + "memory_help", + "monitor", + "pexpiretime", + "psetex", + "psync", + "pubsub", + 
"renamenx", + "rpop", + "script_debug", + "sentinel_ckquorum", + "sentinel_failover", + "sentinel_flushconfig", + "sentinel_get_master_addr_by_name", + "sentinel_master", + "sentinel_masters", + "sentinel_monitor", + "sentinel_remove", + "sentinel_reset", + "sentinel_sentinels", + "sentinel_set", + "sentinel_slaves", + "shutdown", + "sort", + "sort_ro", + "spop", + "srandmember", + "unwatch", + "watch", + "zlexcount", + "zrevrangebyscore", +} + + +_redis_client_async_methods = { "acl_cat", "acl_deluser", - "acl_dryrun", "acl_genpass", "acl_getuser", "acl_help", @@ -50,11 +107,8 @@ "arrlen", "arrpop", "arrtrim", - "auth", - "bgrewriteaof", "bgsave", "bitcount", - "bitfield", "bitfield_ro", "bitop_and", "bitop_not", @@ -63,13 +117,11 @@ "bitop", "bitpos", "blmove", - "blmpop", "blpop", "brpop", "brpoplpush", "byrank", "byrevrank", - "bzmpop", "bzpopmax", "bzpopmin", "card", @@ -85,12 +137,12 @@ "client_no_evict", "client_pause", "client_reply", + "client_setinfo", "client_setname", "client_tracking", "client_trackinginfo", "client_unblock", "client_unpause", - "client", "cluster_add_slots", "cluster_addslots", "cluster_count_failure_report", @@ -117,10 +169,7 @@ "cluster_slots", "cluster", "command_count", - "command_docs", "command_getkeys", - "command_getkeysandflags", - "command_info", "command_list", "command", "commit", @@ -136,7 +185,6 @@ "createrule", "dbsize", "debug_object", - "debug_segfault", "debug_sleep", "debug", "decr", @@ -159,10 +207,8 @@ "exists", "expire", "expireat", - "expiretime", "explain_cli", "explain", - "failover", "fcall_ro", "fcall", "flushall", @@ -176,6 +222,7 @@ "function_load", "function_restore", "function_stats", + "gears_refresh_cluster", "geoadd", "geodist", "geohash", @@ -191,7 +238,6 @@ "getrange", "getset", "hdel", - "hello", "hexists", "hget", "hgetall", @@ -203,7 +249,7 @@ "hmset_dict", "hmset", "hrandfield", - "hscan_inter", + "hscan_iter", "hscan", "hset", "hsetnx", @@ -219,13 +265,9 @@ "insertnx", "keys", "lastsave", - "latency_doctor", - "latency_graph", - "latency_histogram", "latency_history", "latency_latest", "latency_reset", - "lcs", "lindex", "linsert", "list", @@ -234,8 +276,6 @@ "lmpop", "loadchunk", "lolwut", - "lpop", - "lpos", "lpush", "lpushx", "lrange", @@ -244,8 +284,6 @@ "ltrim", "madd", "max", - "memory_doctor", - "memory_help", "memory_malloc_stats", "memory_purge", "memory_stats", @@ -260,7 +298,6 @@ "module_load", "module_loadex", "module_unload", - "monitor", "move", "mrange", "mrevrange", @@ -276,21 +313,19 @@ "persist", "pexpire", "pexpireat", - "pexpiretime", "pfadd", "pfcount", "pfmerge", "ping", "profile", - "psetex", "psubscribe", - "psync", "pttl", "publish", "pubsub_channels", "pubsub_numpat", "pubsub_numsub", - "pubsub", + "pubsub_shardchannels", + "pubsub_shardnumsub", "punsubscribe", "quantile", "query", @@ -302,7 +337,6 @@ "readonly", "readwrite", "rename", - "renamenx", "replicaof", "reserve", "reset", @@ -311,7 +345,6 @@ "revrange", "revrank", "role", - "rpop", "rpoplpush", "rpush", "rpushx", @@ -321,7 +354,6 @@ "scan", "scandump", "scard", - "script_debug", "script_exists", "script_flush", "script_kill", @@ -330,24 +362,11 @@ "sdiffstore", "search", "select", - "sentinel_ckquorum", - "sentinel_failover", - "sentinel_flushconfig", - "sentinel_get_master_addr_by_name", - "sentinel_master", - "sentinel_masters", - "sentinel_monitor", - "sentinel_remove", - "sentinel_reset", - "sentinel_sentinels", - "sentinel_set", - "sentinel_slaves", "set", "setbit", "setex", "setnx", "setrange", - "shutdown", "sinter", 
"sintercard", "sinterstore", @@ -360,11 +379,8 @@ "smembers", "smismember", "smove", - "sort_ro", - "sort", "spellcheck", - "spop", - "srandmember", + "spublish", "srem", "sscan_iter", "sscan", @@ -384,6 +400,11 @@ "syndump", "synupdate", "tagvals", + "tfcall_async", + "tfcall", + "tfunction_delete", + "tfunction_list", + "tfunction_load", "time", "toggle", "touch", @@ -392,9 +413,8 @@ "type", "unlink", "unsubscribe", - "unwatch", "wait", - "watch", + "waitaof", "xack", "xadd", "xautoclaim", @@ -430,7 +450,6 @@ "zinter", "zintercard", "zinterstore", - "zlexcount", "zmpop", "zmscore", "zpopmax", @@ -447,7 +466,6 @@ "zremrangebyscore", "zrevrange", "zrevrangebylex", - "zrevrangebyscore", "zrevrank", "zscan_iter", "zscan", @@ -456,6 +474,15 @@ "zunionstore", } +_redis_client_gen_methods = { + "scan_iter", + "hscan_iter", + "sscan_iter", + "zscan_iter", +} + +_redis_client_methods = _redis_client_sync_methods.union(_redis_client_async_methods) + _redis_multipart_commands = set(["client", "cluster", "command", "config", "debug", "sentinel", "slowlog", "script"]) _redis_operation_re = re.compile(r"[-\s]+") @@ -479,28 +506,85 @@ def _instance_info(kwargs): def _wrap_Redis_method_wrapper_(module, instance_class_name, operation): - def _nr_wrapper_Redis_method_(wrapped, instance, args, kwargs): - transaction = current_transaction() + name = "%s.%s" % (instance_class_name, operation) + if operation in _redis_client_gen_methods: + async_wrapper = generator_wrapper + else: + async_wrapper = None + + wrap_datastore_trace(module, name, product="Redis", target=None, operation=operation, async_wrapper=async_wrapper) - if transaction is None: + +def _wrap_asyncio_Redis_method_wrapper(module, instance_class_name, operation): + def _nr_wrapper_asyncio_Redis_method_(wrapped, instance, args, kwargs): + from redis.asyncio.client import Pipeline + + if isinstance(instance, Pipeline): return wrapped(*args, **kwargs) - dt = DatastoreTrace(product="Redis", target=None, operation=operation, source=wrapped) + # Method should be run when awaited or iterated, therefore we wrap in an async wrapper. + return DatastoreTraceWrapper(wrapped, product="Redis", target=None, operation=operation, async_wrapper=async_wrapper)(*args, **kwargs) + + name = "%s.%s" % (instance_class_name, operation) + if operation in _redis_client_gen_methods: + async_wrapper = async_generator_wrapper + else: + async_wrapper = coroutine_wrapper + + wrap_function_wrapper(module, name, _nr_wrapper_asyncio_Redis_method_) - transaction._nr_datastore_instance_info = (None, None, None) - with dt: - result = wrapped(*args, **kwargs) +async def wrap_async_Connection_send_command(wrapped, instance, args, kwargs): + transaction = current_transaction() + if not transaction: + return await wrapped(*args, **kwargs) - host, port_path_or_id, db = transaction._nr_datastore_instance_info - dt.host = host - dt.port_path_or_id = port_path_or_id - dt.database_name = db + host, port_path_or_id, db = (None, None, None) - return result + try: + dt = transaction.settings.datastore_tracer + if dt.instance_reporting.enabled or dt.database_name_reporting.enabled: + conn_kwargs = _conn_attrs_to_dict(instance) + host, port_path_or_id, db = _instance_info(conn_kwargs) + except Exception: + pass - name = "%s.%s" % (instance_class_name, operation) - wrap_function_wrapper(module, name, _nr_wrapper_Redis_method_) + # Older Redis clients would when sending multi part commands pass + # them in as separate arguments to send_command(). 
Need to therefore + detect those and grab the next argument from the set of arguments. + + operation = args[0].strip().lower() + + # If it's not a multi part command, there's no need to trace it, so + # we can return early. + + if ( + operation.split()[0] not in _redis_multipart_commands + ): # Set the datastore info on the DatastoreTrace containing this function call. + trace = current_trace() + + # Find DatastoreTrace no matter how many other traces are in between + while trace is not None and not isinstance(trace, DatastoreTrace): + trace = getattr(trace, "parent", None) + + if trace is not None: + trace.host = host + trace.port_path_or_id = port_path_or_id + trace.database_name = db + + return await wrapped(*args, **kwargs) + + # Convert multi args to single arg string + + if operation in _redis_multipart_commands and len(args) > 1: + operation = "%s %s" % (operation, args[1].strip().lower()) + + operation = _redis_operation_re.sub("_", operation) + + with DatastoreTrace( + product="Redis", target=None, operation=operation, host=host, port_path_or_id=port_path_or_id, database_name=db + ): + return await wrapped(*args, **kwargs) def _nr_Connection_send_command_wrapper_(wrapped, instance, args, kwargs): @@ -519,7 +603,15 @@ def _nr_Connection_send_command_wrapper_(wrapped, instance, args, kwargs): except: pass - transaction._nr_datastore_instance_info = (host, port_path_or_id, db) + # Find DatastoreTrace no matter how many other traces are in between + trace = current_trace() + while trace is not None and not isinstance(trace, DatastoreTrace): + trace = getattr(trace, "parent", None) + + if trace is not None: + trace.host = host + trace.port_path_or_id = port_path_or_id + trace.database_name = db # Older Redis clients would when sending multi part commands pass # them in as separate arguments to send_command().
Need to therefore @@ -564,6 +656,13 @@ def instrument_redis_client(module): _wrap_Redis_method_wrapper_(module, "Redis", name) +def instrument_asyncio_redis_client(module): + if hasattr(module, "Redis"): + class_ = getattr(module, "Redis") + for operation in _redis_client_async_methods: + if hasattr(class_, operation): + _wrap_asyncio_Redis_method_wrapper(module, "Redis", operation) + def instrument_redis_commands_core(module): _instrument_redis_commands_module(module, "CoreCommands") @@ -596,6 +695,10 @@ def instrument_redis_commands_bf_commands(module): _instrument_redis_commands_module(module, "TOPKCommands") +def instrument_redis_commands_cluster(module): + _instrument_redis_commands_module(module, "RedisClusterCommands") + + def _instrument_redis_commands_module(module, class_name): for name in _redis_client_methods: if hasattr(module, class_name): @@ -605,4 +708,12 @@ def _instrument_redis_commands_module(module, class_name): def instrument_redis_connection(module): - wrap_function_wrapper(module, "Connection.send_command", _nr_Connection_send_command_wrapper_) + if hasattr(module, "Connection"): + if hasattr(module.Connection, "send_command"): + wrap_function_wrapper(module, "Connection.send_command", _nr_Connection_send_command_wrapper_) + + +def instrument_asyncio_redis_connection(module): + if hasattr(module, "Connection"): + if hasattr(module.Connection, "send_command"): + wrap_function_wrapper(module, "Connection.send_command", wrap_async_Connection_send_command) diff --git a/newrelic/hooks/external_botocore.py b/newrelic/hooks/external_botocore.py index 7d49fbd03..2f2b8a113 100644 --- a/newrelic/hooks/external_botocore.py +++ b/newrelic/hooks/external_botocore.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
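For orientation, a minimal sketch (illustration only, not part of the patch; select_wrapper is a hypothetical name) of the wrapper-selection rule the Redis hunks above implement: the *scan_iter commands yield results lazily and need generator-aware wrappers that keep the datastore trace open across iteration, while other commands on the asyncio client are plain coroutines.

    from newrelic.common.async_wrapper import (
        async_generator_wrapper,
        coroutine_wrapper,
        generator_wrapper,
    )

    _redis_client_gen_methods = {"scan_iter", "hscan_iter", "sscan_iter", "zscan_iter"}

    def select_wrapper(operation, asyncio_client):
        # Lazy iterators keep the trace open until exhausted; other asyncio
        # commands are traced as coroutines; sync non-generator commands
        # need no async wrapper at all (async_wrapper=None).
        if operation in _redis_client_gen_methods:
            return async_generator_wrapper if asyncio_client else generator_wrapper
        return coroutine_wrapper if asyncio_client else None

    assert select_wrapper("scan_iter", asyncio_client=False) is generator_wrapper
    assert select_wrapper("get", asyncio_client=True) is coroutine_wrapper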
-from newrelic.api.message_trace import message_trace from newrelic.api.datastore_trace import datastore_trace from newrelic.api.external_trace import ExternalTrace +from newrelic.api.message_trace import message_trace from newrelic.common.object_wrapper import wrap_function_wrapper def extract_sqs(*args, **kwargs): - queue_value = kwargs.get('QueueUrl', 'Unknown') - return queue_value.rsplit('/', 1)[-1] + queue_value = kwargs.get("QueueUrl", "Unknown") + return queue_value.rsplit("/", 1)[-1] def extract(argument_names, default=None): @@ -41,42 +41,27 @@ def extractor_string(*args, **kwargs): CUSTOM_TRACE_POINTS = { - ('sns', 'publish'): message_trace( - 'SNS', 'Produce', 'Topic', - extract(('TopicArn', 'TargetArn'), 'PhoneNumber')), - ('dynamodb', 'put_item'): datastore_trace( - 'DynamoDB', extract('TableName'), 'put_item'), - ('dynamodb', 'get_item'): datastore_trace( - 'DynamoDB', extract('TableName'), 'get_item'), - ('dynamodb', 'update_item'): datastore_trace( - 'DynamoDB', extract('TableName'), 'update_item'), - ('dynamodb', 'delete_item'): datastore_trace( - 'DynamoDB', extract('TableName'), 'delete_item'), - ('dynamodb', 'create_table'): datastore_trace( - 'DynamoDB', extract('TableName'), 'create_table'), - ('dynamodb', 'delete_table'): datastore_trace( - 'DynamoDB', extract('TableName'), 'delete_table'), - ('dynamodb', 'query'): datastore_trace( - 'DynamoDB', extract('TableName'), 'query'), - ('dynamodb', 'scan'): datastore_trace( - 'DynamoDB', extract('TableName'), 'scan'), - ('sqs', 'send_message'): message_trace( - 'SQS', 'Produce', 'Queue', extract_sqs), - ('sqs', 'send_message_batch'): message_trace( - 'SQS', 'Produce', 'Queue', extract_sqs), - ('sqs', 'receive_message'): message_trace( - 'SQS', 'Consume', 'Queue', extract_sqs), + ("sns", "publish"): message_trace("SNS", "Produce", "Topic", extract(("TopicArn", "TargetArn"), "PhoneNumber")), + ("dynamodb", "put_item"): datastore_trace("DynamoDB", extract("TableName"), "put_item"), + ("dynamodb", "get_item"): datastore_trace("DynamoDB", extract("TableName"), "get_item"), + ("dynamodb", "update_item"): datastore_trace("DynamoDB", extract("TableName"), "update_item"), + ("dynamodb", "delete_item"): datastore_trace("DynamoDB", extract("TableName"), "delete_item"), + ("dynamodb", "create_table"): datastore_trace("DynamoDB", extract("TableName"), "create_table"), + ("dynamodb", "delete_table"): datastore_trace("DynamoDB", extract("TableName"), "delete_table"), + ("dynamodb", "query"): datastore_trace("DynamoDB", extract("TableName"), "query"), + ("dynamodb", "scan"): datastore_trace("DynamoDB", extract("TableName"), "scan"), + ("sqs", "send_message"): message_trace("SQS", "Produce", "Queue", extract_sqs), + ("sqs", "send_message_batch"): message_trace("SQS", "Produce", "Queue", extract_sqs), + ("sqs", "receive_message"): message_trace("SQS", "Consume", "Queue", extract_sqs), } -def bind__create_api_method(py_operation_name, operation_name, service_model, - *args, **kwargs): +def bind__create_api_method(py_operation_name, operation_name, service_model, *args, **kwargs): return (py_operation_name, service_model) def _nr_clientcreator__create_api_method_(wrapped, instance, args, kwargs): - (py_operation_name, service_model) = \ - bind__create_api_method(*args, **kwargs) + (py_operation_name, service_model) = bind__create_api_method(*args, **kwargs) service_name = service_model.service_name.lower() tracer = CUSTOM_TRACE_POINTS.get((service_name, py_operation_name)) @@ -95,30 +80,27 @@ def _bind_make_request_params(operation_model, 
request_dict, *args, **kwargs): def _nr_endpoint_make_request_(wrapped, instance, args, kwargs): operation_model, request_dict = _bind_make_request_params(*args, **kwargs) - url = request_dict.get('url', '') - method = request_dict.get('method', None) - - with ExternalTrace(library='botocore', url=url, method=method, source=wrapped) as trace: + url = request_dict.get("url", "") + method = request_dict.get("method", None) + with ExternalTrace(library="botocore", url=url, method=method, source=wrapped) as trace: try: - trace._add_agent_attribute('aws.operation', operation_model.name) + trace._add_agent_attribute("aws.operation", operation_model.name) except: pass result = wrapped(*args, **kwargs) try: - request_id = result[1]['ResponseMetadata']['RequestId'] - trace._add_agent_attribute('aws.requestId', request_id) + request_id = result[1]["ResponseMetadata"]["RequestId"] + trace._add_agent_attribute("aws.requestId", request_id) except: pass return result def instrument_botocore_endpoint(module): - wrap_function_wrapper(module, 'Endpoint.make_request', - _nr_endpoint_make_request_) + wrap_function_wrapper(module, "Endpoint.make_request", _nr_endpoint_make_request_) def instrument_botocore_client(module): - wrap_function_wrapper(module, 'ClientCreator._create_api_method', - _nr_clientcreator__create_api_method_) + wrap_function_wrapper(module, "ClientCreator._create_api_method", _nr_clientcreator__create_api_method_) diff --git a/newrelic/hooks/framework_ariadne.py b/newrelic/hooks/framework_ariadne.py index 498c662c4..4927abe0b 100644 --- a/newrelic/hooks/framework_ariadne.py +++ b/newrelic/hooks/framework_ariadne.py @@ -21,17 +21,12 @@ from newrelic.api.wsgi_application import wrap_wsgi_application from newrelic.common.object_names import callable_name from newrelic.common.object_wrapper import wrap_function_wrapper +from newrelic.common.package_version_utils import get_package_version from newrelic.core.graphql_utils import graphql_statement -from newrelic.hooks.framework_graphql import ( - framework_version as graphql_framework_version, -) -from newrelic.hooks.framework_graphql import ignore_graphql_duplicate_exception +from newrelic.hooks.framework_graphql import GRAPHQL_VERSION, ignore_graphql_duplicate_exception - -def framework_details(): - import ariadne - - return ("Ariadne", getattr(ariadne, "__version__", None)) +ARIADNE_VERSION = get_package_version("ariadne") +ariadne_version_tuple = tuple(map(int, ARIADNE_VERSION.split("."))) def bind_graphql(schema, data, *args, **kwargs): @@ -49,9 +44,8 @@ def wrap_graphql_sync(wrapped, instance, args, kwargs): except TypeError: return wrapped(*args, **kwargs) - framework = framework_details() - transaction.add_framework_info(name=framework[0], version=framework[1]) # No version info available on ariadne - transaction.add_framework_info(name="GraphQL", version=graphql_framework_version()) + transaction.add_framework_info(name="Ariadne", version=ARIADNE_VERSION) + transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION) query = data["query"] if hasattr(query, "body"): @@ -83,9 +77,8 @@ async def wrap_graphql(wrapped, instance, args, kwargs): result = await result return result - framework = framework_details() - transaction.add_framework_info(name=framework[0], version=framework[1]) # No version info available on ariadne - transaction.add_framework_info(name="GraphQL", version=graphql_framework_version()) + transaction.add_framework_info(name="Ariadne", version=ARIADNE_VERSION) + transaction.add_framework_info(name="GraphQL", 
version=GRAPHQL_VERSION) query = data["query"] if hasattr(query, "body"): @@ -104,6 +97,9 @@ async def wrap_graphql(wrapped, instance, args, kwargs): def instrument_ariadne_execute(module): + # v0.9.0 is the version where ariadne started using graphql-core v3 + if ariadne_version_tuple < (0, 9): + return if hasattr(module, "graphql"): wrap_function_wrapper(module, "graphql", wrap_graphql) @@ -112,10 +108,14 @@ async def wrap_graphql(wrapped, instance, args, kwargs): def instrument_ariadne_asgi(module): + if ariadne_version_tuple < (0, 9): + return if hasattr(module, "GraphQL"): - wrap_asgi_application(module, "GraphQL.__call__", framework=framework_details()) + wrap_asgi_application(module, "GraphQL.__call__", framework=("Ariadne", ARIADNE_VERSION)) def instrument_ariadne_wsgi(module): + if ariadne_version_tuple < (0, 9): + return if hasattr(module, "GraphQL"): - wrap_wsgi_application(module, "GraphQL.__call__", framework=framework_details()) + wrap_wsgi_application(module, "GraphQL.__call__", framework=("Ariadne", ARIADNE_VERSION)) diff --git a/newrelic/hooks/framework_bottle.py b/newrelic/hooks/framework_bottle.py index 99caa844f..5635fa782 100644 --- a/newrelic/hooks/framework_bottle.py +++ b/newrelic/hooks/framework_bottle.py @@ -18,14 +18,21 @@ import functools -from newrelic.api.function_trace import (FunctionTrace, FunctionTraceWrapper, - wrap_function_trace) +from newrelic.api.function_trace import ( + FunctionTrace, + FunctionTraceWrapper, + wrap_function_trace, +) from newrelic.api.transaction import current_transaction from newrelic.api.wsgi_application import wrap_wsgi_application from newrelic.common.object_names import callable_name -from newrelic.common.object_wrapper import (wrap_out_function, - function_wrapper, ObjectProxy, wrap_object_attribute, - wrap_function_wrapper) +from newrelic.common.object_wrapper import ( + ObjectProxy, + function_wrapper, + wrap_function_wrapper, + wrap_object_attribute, + wrap_out_function, +) module_bottle = None @@ -34,17 +41,17 @@ def status_code(exc, value, tb): # The HTTPError class derives from HTTPResponse and so we do not # need to check for it separately as isinstance() will pick it up.
- if isinstance(value, module_bottle.HTTPResponse): - if hasattr(value, 'status_code'): + if isinstance(value, module_bottle.HTTPResponse): # pragma: no cover + if hasattr(value, "status_code"): return value.status_code - elif hasattr(value, 'status'): + elif hasattr(value, "status"): return value.status - elif hasattr(value, 'http_status_code'): + elif hasattr(value, "http_status_code"): return value.http_status_code def should_ignore(exc, value, tb): - if hasattr(module_bottle, 'RouteReset'): + if hasattr(module_bottle, "RouteReset"): if isinstance(value, module_bottle.RouteReset): return True @@ -113,8 +120,7 @@ def get(self, status, default=None): transaction.set_transaction_name(name, priority=1) handler = FunctionTraceWrapper(handler, name=name) else: - transaction.set_transaction_name(str(status), - group='StatusCode', priority=1) + transaction.set_transaction_name(str(status), group="StatusCode", priority=1) return handler or default @@ -140,43 +146,39 @@ def instrument_bottle(module): global module_bottle module_bottle = module - framework_details = ('Bottle', getattr(module, '__version__')) - - if hasattr(module.Bottle, 'wsgi'): # version >= 0.9 - wrap_wsgi_application(module, 'Bottle.wsgi', - framework=framework_details) - elif hasattr(module.Bottle, '__call__'): # version < 0.9 - wrap_wsgi_application(module, 'Bottle.__call__', - framework=framework_details) - - if (hasattr(module, 'Route') and - hasattr(module.Route, '_make_callback')): # version >= 0.10 - wrap_out_function(module, 'Route._make_callback', - output_wrapper_Route_make_callback) - elif hasattr(module.Bottle, '_match'): # version >= 0.9 - wrap_out_function(module, 'Bottle._match', - output_wrapper_Bottle_match) - elif hasattr(module.Bottle, 'match_url'): # version < 0.9 - wrap_out_function(module, 'Bottle.match_url', - output_wrapper_Bottle_match) - - wrap_object_attribute(module, 'Bottle.error_handler', - proxy_Bottle_error_handler) - - if hasattr(module, 'auth_basic'): - wrap_function_wrapper(module, 'auth_basic', wrapper_auth_basic) - - if hasattr(module, 'SimpleTemplate'): - wrap_function_trace(module, 'SimpleTemplate.render') - - if hasattr(module, 'MakoTemplate'): - wrap_function_trace(module, 'MakoTemplate.render') - - if hasattr(module, 'CheetahTemplate'): - wrap_function_trace(module, 'CheetahTemplate.render') - - if hasattr(module, 'Jinja2Template'): - wrap_function_trace(module, 'Jinja2Template.render') - - if hasattr(module, 'SimpleTALTemplate'): - wrap_function_trace(module, 'SimpleTALTemplate.render') + framework_details = ("Bottle", getattr(module, "__version__")) + # version >= 0.9 + if hasattr(module.Bottle, "wsgi"): # pragma: no cover + wrap_wsgi_application(module, "Bottle.wsgi", framework=framework_details) + # version < 0.9 + elif hasattr(module.Bottle, "__call__"): # pragma: no cover + wrap_wsgi_application(module, "Bottle.__call__", framework=framework_details) + # version >= 0.10 + if hasattr(module, "Route") and hasattr(module.Route, "_make_callback"): # pragma: no cover + wrap_out_function(module, "Route._make_callback", output_wrapper_Route_make_callback) + # version >= 0.9 + elif hasattr(module.Bottle, "_match"): # pragma: no cover + wrap_out_function(module, "Bottle._match", output_wrapper_Bottle_match) + # version < 0.9 + elif hasattr(module.Bottle, "match_url"): # pragma: no cover + wrap_out_function(module, "Bottle.match_url", output_wrapper_Bottle_match) + + wrap_object_attribute(module, "Bottle.error_handler", proxy_Bottle_error_handler) + + if hasattr(module, "auth_basic"): + 
wrap_function_wrapper(module, "auth_basic", wrapper_auth_basic) + + if hasattr(module, "SimpleTemplate"): + wrap_function_trace(module, "SimpleTemplate.render") + + if hasattr(module, "MakoTemplate"): + wrap_function_trace(module, "MakoTemplate.render") + + if hasattr(module, "CheetahTemplate"): + wrap_function_trace(module, "CheetahTemplate.render") + + if hasattr(module, "Jinja2Template"): + wrap_function_trace(module, "Jinja2Template.render") + + if hasattr(module, "SimpleTALTemplate"): # pragma: no cover + wrap_function_trace(module, "SimpleTALTemplate.render") diff --git a/newrelic/hooks/framework_django.py b/newrelic/hooks/framework_django.py index 005f28279..3d9f448cc 100644 --- a/newrelic/hooks/framework_django.py +++ b/newrelic/hooks/framework_django.py @@ -12,48 +12,60 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools +import logging import sys import threading -import logging -import functools - -from newrelic.packages import six from newrelic.api.application import register_application from newrelic.api.background_task import BackgroundTaskWrapper from newrelic.api.error_trace import wrap_error_trace -from newrelic.api.function_trace import (FunctionTrace, wrap_function_trace, - FunctionTraceWrapper) +from newrelic.api.function_trace import ( + FunctionTrace, + FunctionTraceWrapper, + wrap_function_trace, +) from newrelic.api.html_insertion import insert_html_snippet -from newrelic.api.transaction import current_transaction from newrelic.api.time_trace import notice_error +from newrelic.api.transaction import current_transaction from newrelic.api.transaction_name import wrap_transaction_name from newrelic.api.wsgi_application import WSGIApplicationWrapper - -from newrelic.common.object_wrapper import (FunctionWrapper, wrap_in_function, - wrap_post_function, wrap_function_wrapper, function_wrapper) +from newrelic.common.coroutine import is_asyncio_coroutine, is_coroutine_function from newrelic.common.object_names import callable_name +from newrelic.common.object_wrapper import ( + FunctionWrapper, + function_wrapper, + wrap_function_wrapper, + wrap_in_function, + wrap_post_function, +) from newrelic.config import extra_settings from newrelic.core.config import global_settings -from newrelic.common.coroutine import is_coroutine_function, is_asyncio_coroutine +from newrelic.packages import six if six.PY3: from newrelic.hooks.framework_django_py3 import ( - _nr_wrapper_BaseHandler_get_response_async_, _nr_wrap_converted_middleware_async_, + _nr_wrapper_BaseHandler_get_response_async_, ) _logger = logging.getLogger(__name__) _boolean_states = { - '1': True, 'yes': True, 'true': True, 'on': True, - '0': False, 'no': False, 'false': False, 'off': False + "1": True, + "yes": True, + "true": True, + "on": True, + "0": False, + "no": False, + "false": False, + "off": False, } def _setting_boolean(value): if value.lower() not in _boolean_states: - raise ValueError('Not a boolean: %s' % value) + raise ValueError("Not a boolean: %s" % value) return _boolean_states[value.lower()] @@ -62,21 +74,20 @@ def _setting_set(value): _settings_types = { - 'browser_monitoring.auto_instrument': _setting_boolean, - 'instrumentation.templates.inclusion_tag': _setting_set, - 'instrumentation.background_task.startup_timeout': float, - 'instrumentation.scripts.django_admin': _setting_set, + "browser_monitoring.auto_instrument": _setting_boolean, + "instrumentation.templates.inclusion_tag": _setting_set, + 
"instrumentation.background_task.startup_timeout": float, + "instrumentation.scripts.django_admin": _setting_set, } _settings_defaults = { - 'browser_monitoring.auto_instrument': True, - 'instrumentation.templates.inclusion_tag': set(), - 'instrumentation.background_task.startup_timeout': 10.0, - 'instrumentation.scripts.django_admin': set(), + "browser_monitoring.auto_instrument": True, + "instrumentation.templates.inclusion_tag": set(), + "instrumentation.background_task.startup_timeout": 10.0, + "instrumentation.scripts.django_admin": set(), } -django_settings = extra_settings('import-hook:django', - types=_settings_types, defaults=_settings_defaults) +django_settings = extra_settings("import-hook:django", types=_settings_types, defaults=_settings_defaults) def should_add_browser_timing(response, transaction): @@ -92,7 +103,7 @@ def should_add_browser_timing(response, transaction): # do RUM insertion, need to move to a WSGI middleware and # deal with how to update the content length. - if hasattr(response, 'streaming_content'): + if hasattr(response, "streaming_content"): return False # Need to be running within a valid web transaction. @@ -121,21 +132,21 @@ def should_add_browser_timing(response, transaction): # a user may want to also perform insertion for # 'application/xhtml+xml'. - ctype = response.get('Content-Type', '').lower().split(';')[0] + ctype = response.get("Content-Type", "").lower().split(";")[0] if ctype not in transaction.settings.browser_monitoring.content_type: return False # Don't risk it if content encoding already set. - if response.has_header('Content-Encoding'): + if response.has_header("Content-Encoding"): return False # Don't risk it if content is actually within an attachment. - cdisposition = response.get('Content-Disposition', '').lower() + cdisposition = response.get("Content-Disposition", "").lower() - if cdisposition.split(';')[0].strip().lower() == 'attachment': + if cdisposition.split(";")[0].strip().lower() == "attachment": return False return True @@ -144,6 +155,7 @@ def should_add_browser_timing(response, transaction): # Response middleware for automatically inserting RUM header and # footer into HTML response returned by application + def browser_timing_insertion(response, transaction): # No point continuing if header is empty. This can occur if @@ -175,14 +187,15 @@ def html_to_be_inserted(): if result is not None: if transaction.settings.debug.log_autorum_middleware: - _logger.debug('RUM insertion from Django middleware ' - 'triggered. Bytes added was %r.', - len(result) - len(response.content)) + _logger.debug( + "RUM insertion from Django middleware triggered. Bytes added was %r.", + len(result) - len(response.content), + ) response.content = result - if response.get('Content-Length', None): - response['Content-Length'] = str(len(response.content)) + if response.get("Content-Length", None): + response["Content-Length"] = str(len(response.content)) return response @@ -192,18 +205,19 @@ def html_to_be_inserted(): # 'newrelic' will be automatically inserted into set of tag # libraries when performing step to instrument the middleware. 
+ def newrelic_browser_timing_header(): from django.utils.safestring import mark_safe transaction = current_transaction() - return transaction and mark_safe(transaction.browser_timing_header()) or '' + return transaction and mark_safe(transaction.browser_timing_header()) or "" # nosec def newrelic_browser_timing_footer(): from django.utils.safestring import mark_safe transaction = current_transaction() - return transaction and mark_safe(transaction.browser_timing_footer()) or '' + return transaction and mark_safe(transaction.browser_timing_footer()) or "" # nosec # Addition of instrumentation for middleware. Can only do this @@ -256,9 +270,14 @@ def wrapper(wrapped, instance, args, kwargs): yield wrapper(wrapped) -def wrap_view_middleware(middleware): +# Because this is not being used in any version of Django that is +# within New Relic's support window, no tests will be added +# for this. However, there is value in keeping backwards-compatible +# functionality, so instead of removing this instrumentation, it +# will be excluded from the coverage analysis. +def wrap_view_middleware(middleware): # pragma: no cover - # XXX This is no longer being used. The changes to strip the + # This is no longer being used. The changes to strip the # wrapper from the view handler when passed into the function # urlresolvers.reverse() solve most of the problems. To back # that up, the object wrapper now proxies various special @@ -293,7 +312,7 @@ def wrapper(wrapped, instance, args, kwargs): def _wrapped(request, view_func, view_args, view_kwargs): # This strips the view handler wrapper before call. - if hasattr(view_func, '_nr_last_object'): + if hasattr(view_func, "_nr_last_object"): view_func = view_func._nr_last_object return wrapped(request, view_func, view_args, view_kwargs) @@ -370,37 +389,28 @@ def insert_and_wrap_middleware(handler, *args, **kwargs): # priority than that for view handler so view handler # name always takes precedence.
- if hasattr(handler, '_request_middleware'): - handler._request_middleware = list( - wrap_leading_middleware( - handler._request_middleware)) + if hasattr(handler, "_request_middleware"): + handler._request_middleware = list(wrap_leading_middleware(handler._request_middleware)) - if hasattr(handler, '_view_middleware'): - handler._view_middleware = list( - wrap_leading_middleware( - handler._view_middleware)) + if hasattr(handler, "_view_middleware"): + handler._view_middleware = list(wrap_leading_middleware(handler._view_middleware)) - if hasattr(handler, '_template_response_middleware'): + if hasattr(handler, "_template_response_middleware"): handler._template_response_middleware = list( - wrap_trailing_middleware( - handler._template_response_middleware)) + wrap_trailing_middleware(handler._template_response_middleware) + ) - if hasattr(handler, '_response_middleware'): - handler._response_middleware = list( - wrap_trailing_middleware( - handler._response_middleware)) + if hasattr(handler, "_response_middleware"): + handler._response_middleware = list(wrap_trailing_middleware(handler._response_middleware)) - if hasattr(handler, '_exception_middleware'): - handler._exception_middleware = list( - wrap_trailing_middleware( - handler._exception_middleware)) + if hasattr(handler, "_exception_middleware"): + handler._exception_middleware = list(wrap_trailing_middleware(handler._exception_middleware)) finally: lock.release() -def _nr_wrapper_GZipMiddleware_process_response_(wrapped, instance, args, - kwargs): +def _nr_wrapper_GZipMiddleware_process_response_(wrapped, instance, args, kwargs): transaction = current_transaction() @@ -433,36 +443,33 @@ def _nr_wrapper_BaseHandler_get_response_(wrapped, instance, args, kwargs): request = _bind_get_response(*args, **kwargs) - if hasattr(request, '_nr_exc_info'): + if hasattr(request, "_nr_exc_info"): notice_error(error=request._nr_exc_info, status_code=response.status_code) - delattr(request, '_nr_exc_info') + delattr(request, "_nr_exc_info") return response # Post import hooks for modules. + def instrument_django_core_handlers_base(module): # Attach a post function to load_middleware() method of # BaseHandler to trigger insertion of browser timing # middleware and wrapping of middleware for timing etc. 
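A brief aside on the helper invoked just below, as a hedged sketch (Handler and after_load are hypothetical names, not part of the patch): wrap_post_function() registers a callback that runs after the wrapped method returns, with the same arguments, which is how insert_and_wrap_middleware() receives the freshly configured handler instance.

    import sys

    from newrelic.common.object_wrapper import wrap_post_function

    class Handler:
        def load_middleware(self):
            self.loaded = True

    def after_load(handler, *args, **kwargs):
        # Post function: runs after load_middleware() returns, with the
        # original call arguments, so the initialized instance is visible.
        print("middleware ready, loaded=%r" % handler.loaded)

    wrap_post_function(sys.modules[__name__], "Handler.load_middleware", after_load)
    Handler().load_middleware()  # prints: middleware ready, loaded=True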
- wrap_post_function(module, 'BaseHandler.load_middleware', - insert_and_wrap_middleware) + wrap_post_function(module, "BaseHandler.load_middleware", insert_and_wrap_middleware) - if six.PY3 and hasattr(module.BaseHandler, 'get_response_async'): - wrap_function_wrapper(module, 'BaseHandler.get_response_async', - _nr_wrapper_BaseHandler_get_response_async_) + if six.PY3 and hasattr(module.BaseHandler, "get_response_async"): + wrap_function_wrapper(module, "BaseHandler.get_response_async", _nr_wrapper_BaseHandler_get_response_async_) - wrap_function_wrapper(module, 'BaseHandler.get_response', - _nr_wrapper_BaseHandler_get_response_) + wrap_function_wrapper(module, "BaseHandler.get_response", _nr_wrapper_BaseHandler_get_response_) def instrument_django_gzip_middleware(module): - wrap_function_wrapper(module, 'GZipMiddleware.process_response', - _nr_wrapper_GZipMiddleware_process_response_) + wrap_function_wrapper(module, "GZipMiddleware.process_response", _nr_wrapper_GZipMiddleware_process_response_) def wrap_handle_uncaught_exception(middleware): @@ -506,10 +513,9 @@ def instrument_django_core_handlers_wsgi(module): import django - framework = ('Django', django.get_version()) + framework = ("Django", django.get_version()) - module.WSGIHandler.__call__ = WSGIApplicationWrapper( - module.WSGIHandler.__call__, framework=framework) + module.WSGIHandler.__call__ = WSGIApplicationWrapper(module.WSGIHandler.__call__, framework=framework) # Wrap handle_uncaught_exception() of WSGIHandler so that # can capture exception details of any exception which @@ -519,10 +525,10 @@ def instrument_django_core_handlers_wsgi(module): # exception, so last chance to do this as exception will not # propagate up to the WSGI application. - if hasattr(module.WSGIHandler, 'handle_uncaught_exception'): - module.WSGIHandler.handle_uncaught_exception = ( - wrap_handle_uncaught_exception( - module.WSGIHandler.handle_uncaught_exception)) + if hasattr(module.WSGIHandler, "handle_uncaught_exception"): + module.WSGIHandler.handle_uncaught_exception = wrap_handle_uncaught_exception( + module.WSGIHandler.handle_uncaught_exception + ) def wrap_view_handler(wrapped, priority=3): @@ -532,7 +538,7 @@ def wrap_view_handler(wrapped, priority=3): # called recursively. We flag that view handler was wrapped # using the '_nr_django_view_handler' attribute. 
- if hasattr(wrapped, '_nr_django_view_handler'): + if hasattr(wrapped, "_nr_django_view_handler"): return wrapped if hasattr(wrapped, "view_class"): @@ -584,7 +590,7 @@ def wrapper(wrapped, instance, args, kwargs): if transaction is None: return wrapped(*args, **kwargs) - if hasattr(transaction, '_nr_django_url_resolver'): + if hasattr(transaction, "_nr_django_url_resolver"): return wrapped(*args, **kwargs) # Tag the transaction so we know when we are in the top @@ -602,8 +608,7 @@ def _wrapped(path): if type(result) is tuple: callback, callback_args, callback_kwargs = result - result = (wrap_view_handler(callback, priority=5), - callback_args, callback_kwargs) + result = (wrap_view_handler(callback, priority=5), callback_args, callback_kwargs) else: result.func = wrap_view_handler(result.func, priority=5) @@ -636,8 +641,7 @@ def wrapper(wrapped, instance, args, kwargs): return wrap_view_handler(result, priority=priority) else: callback, param_dict = result - return (wrap_view_handler(callback, priority=priority), - param_dict) + return (wrap_view_handler(callback, priority=priority), param_dict) return FunctionWrapper(wrapped, wrapper) @@ -653,9 +657,10 @@ def wrap_url_reverse(wrapped): def wrapper(wrapped, instance, args, kwargs): def execute(viewname, *args, **kwargs): - if hasattr(viewname, '_nr_last_object'): + if hasattr(viewname, "_nr_last_object"): viewname = viewname._nr_last_object return wrapped(viewname, *args, **kwargs) + return execute(*args, **kwargs) return FunctionWrapper(wrapped, wrapper) @@ -672,20 +677,19 @@ def instrument_django_core_urlresolvers(module): # lost. We thus intercept it here so can capture that # traceback which is otherwise lost. - wrap_error_trace(module, 'get_callable') + wrap_error_trace(module, "get_callable") # Wrap methods which resolves a request to a view handler. # This can be called against a resolver initialised against # a custom URL conf associated with a specific request, or a # resolver which uses the default URL conf. - if hasattr(module, 'RegexURLResolver'): + if hasattr(module, "RegexURLResolver"): urlresolver = module.RegexURLResolver else: urlresolver = module.URLResolver - urlresolver.resolve = wrap_url_resolver( - urlresolver.resolve) + urlresolver.resolve = wrap_url_resolver(urlresolver.resolve) # Wrap methods which resolve error handlers. For 403 and 404 # we give these higher naming priority over any prior @@ -695,26 +699,22 @@ def instrument_django_core_urlresolvers(module): # handler in place so error details identify the correct # transaction. 
- if hasattr(urlresolver, 'resolve403'): - urlresolver.resolve403 = wrap_url_resolver_nnn( - urlresolver.resolve403, priority=3) + if hasattr(urlresolver, "resolve403"): + urlresolver.resolve403 = wrap_url_resolver_nnn(urlresolver.resolve403, priority=3) - if hasattr(urlresolver, 'resolve404'): - urlresolver.resolve404 = wrap_url_resolver_nnn( - urlresolver.resolve404, priority=3) + if hasattr(urlresolver, "resolve404"): + urlresolver.resolve404 = wrap_url_resolver_nnn(urlresolver.resolve404, priority=3) - if hasattr(urlresolver, 'resolve500'): - urlresolver.resolve500 = wrap_url_resolver_nnn( - urlresolver.resolve500, priority=1) + if hasattr(urlresolver, "resolve500"): + urlresolver.resolve500 = wrap_url_resolver_nnn(urlresolver.resolve500, priority=1) - if hasattr(urlresolver, 'resolve_error_handler'): - urlresolver.resolve_error_handler = wrap_url_resolver_nnn( - urlresolver.resolve_error_handler, priority=1) + if hasattr(urlresolver, "resolve_error_handler"): + urlresolver.resolve_error_handler = wrap_url_resolver_nnn(urlresolver.resolve_error_handler, priority=1) # Wrap function for performing reverse URL lookup to strip any # instrumentation wrapper when view handler is passed in. - if hasattr(module, 'reverse'): + if hasattr(module, "reverse"): module.reverse = wrap_url_reverse(module.reverse) @@ -723,7 +723,7 @@ def instrument_django_urls_base(module): # Wrap function for performing reverse URL lookup to strip any # instrumentation wrapper when view handler is passed in. - if hasattr(module, 'reverse'): + if hasattr(module, "reverse"): module.reverse = wrap_url_reverse(module.reverse) @@ -742,17 +742,15 @@ def instrument_django_template(module): def template_name(template, *args): return template.name - if hasattr(module.Template, '_render'): - wrap_function_trace(module, 'Template._render', - name=template_name, group='Template/Render') + if hasattr(module.Template, "_render"): + wrap_function_trace(module, "Template._render", name=template_name, group="Template/Render") else: - wrap_function_trace(module, 'Template.render', - name=template_name, group='Template/Render') + wrap_function_trace(module, "Template.render", name=template_name, group="Template/Render") # Django 1.8 no longer has module.libraries. As automatic way is not # preferred we can just skip this now. - if not hasattr(module, 'libraries'): + if not hasattr(module, "libraries"): return # Register template tags used for manual insertion of RUM @@ -766,12 +764,12 @@ def template_name(template, *args): library.simple_tag(newrelic_browser_timing_header) library.simple_tag(newrelic_browser_timing_footer) - module.libraries['django.templatetags.newrelic'] = library + module.libraries["django.templatetags.newrelic"] = library def wrap_template_block(wrapped): def wrapper(wrapped, instance, args, kwargs): - return FunctionTraceWrapper(wrapped, name=instance.name, group='Template/Block')(*args, **kwargs) + return FunctionTraceWrapper(wrapped, name=instance.name, group="Template/Block")(*args, **kwargs) return FunctionWrapper(wrapped, wrapper) @@ -812,11 +810,15 @@ def instrument_django_core_servers_basehttp(module): # instrumentation of the wsgiref module or some other means. 
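Another aside on the helper used in the hunk just below, as a hedged sketch (greet and shout_name_in_function are hypothetical names): wrap_in_function() installs an "in function" that receives the original call arguments and returns the (args, kwargs) pair the real callable is invoked with, which is how ServerHandler.run has its application argument swapped for a wrapped WSGI application.

    import sys

    from newrelic.common.object_wrapper import wrap_in_function

    def greet(name):
        return "Hello, %s" % name

    def shout_name_in_function(name, **kwargs):
        # Rewrite the arguments before the real greet() runs, mirroring how
        # wrap_wsgi_application_entry_point() below swaps in the wrapped app.
        return (name.upper(),), kwargs

    wrap_in_function(sys.modules[__name__], "greet", shout_name_in_function)
    assert greet("django") == "Hello, DJANGO"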
def wrap_wsgi_application_entry_point(server, application, **kwargs): - return ((server, WSGIApplicationWrapper(application, - framework='Django'),), kwargs) + return ( + ( + server, + WSGIApplicationWrapper(application, framework="Django"), + ), + kwargs, + ) - if (not hasattr(module, 'simple_server') and - hasattr(module.ServerHandler, 'run')): + if not hasattr(module, "simple_server") and hasattr(module.ServerHandler, "run"): # Patch the server to make it work properly. @@ -833,11 +835,10 @@ def run(self, application): def close(self): if self.result is not None: try: - self.request_handler.log_request( - self.status.split(' ', 1)[0], self.bytes_sent) + self.request_handler.log_request(self.status.split(" ", 1)[0], self.bytes_sent) finally: try: - if hasattr(self.result, 'close'): + if hasattr(self.result, "close"): self.result.close() finally: self.result = None @@ -855,17 +856,16 @@ def close(self): # Now wrap it with our instrumentation. - wrap_in_function(module, 'ServerHandler.run', - wrap_wsgi_application_entry_point) + wrap_in_function(module, "ServerHandler.run", wrap_wsgi_application_entry_point) def instrument_django_contrib_staticfiles_views(module): - if not hasattr(module.serve, '_nr_django_view_handler'): + if not hasattr(module.serve, "_nr_django_view_handler"): module.serve = wrap_view_handler(module.serve, priority=3) def instrument_django_contrib_staticfiles_handlers(module): - wrap_transaction_name(module, 'StaticFilesHandler.serve') + wrap_transaction_name(module, "StaticFilesHandler.serve") def instrument_django_views_debug(module): @@ -878,10 +878,8 @@ def instrument_django_views_debug(module): # from a middleware or view handler in place so error # details identify the correct transaction. - module.technical_404_response = wrap_view_handler( - module.technical_404_response, priority=3) - module.technical_500_response = wrap_view_handler( - module.technical_500_response, priority=1) + module.technical_404_response = wrap_view_handler(module.technical_404_response, priority=3) + module.technical_500_response = wrap_view_handler(module.technical_500_response, priority=1) def resolve_view_handler(view, request): @@ -890,8 +888,7 @@ def resolve_view_handler(view, request): # duplicate the lookup mechanism. 
if request.method.lower() in view.http_method_names: - handler = getattr(view, request.method.lower(), - view.http_method_not_allowed) + handler = getattr(view, request.method.lower(), view.http_method_not_allowed) else: handler = view.http_method_not_allowed @@ -936,7 +933,7 @@ def _args(request, *args, **kwargs): priority = 4 - if transaction.group == 'Function': + if transaction.group == "Function": if transaction.name == callable_name(view): priority = 5 @@ -953,22 +950,22 @@ def instrument_django_views_generic_base(module): def instrument_django_http_multipartparser(module): - wrap_function_trace(module, 'MultiPartParser.parse') + wrap_function_trace(module, "MultiPartParser.parse") def instrument_django_core_mail(module): - wrap_function_trace(module, 'mail_admins') - wrap_function_trace(module, 'mail_managers') - wrap_function_trace(module, 'send_mail') + wrap_function_trace(module, "mail_admins") + wrap_function_trace(module, "mail_managers") + wrap_function_trace(module, "send_mail") def instrument_django_core_mail_message(module): - wrap_function_trace(module, 'EmailMessage.send') + wrap_function_trace(module, "EmailMessage.send") def _nr_wrapper_BaseCommand___init___(wrapped, instance, args, kwargs): instance.handle = FunctionTraceWrapper(instance.handle) - if hasattr(instance, 'handle_noargs'): + if hasattr(instance, "handle_noargs"): instance.handle_noargs = FunctionTraceWrapper(instance.handle_noargs) return wrapped(*args, **kwargs) @@ -982,29 +979,25 @@ def _args(argv, *args, **kwargs): subcommand = _argv[1] commands = django_settings.instrumentation.scripts.django_admin - startup_timeout = \ - django_settings.instrumentation.background_task.startup_timeout + startup_timeout = django_settings.instrumentation.background_task.startup_timeout if subcommand not in commands: return wrapped(*args, **kwargs) application = register_application(timeout=startup_timeout) - return BackgroundTaskWrapper(wrapped, application, subcommand, 'Django')(*args, **kwargs) + return BackgroundTaskWrapper(wrapped, application, subcommand, "Django")(*args, **kwargs) def instrument_django_core_management_base(module): - wrap_function_wrapper(module, 'BaseCommand.__init__', - _nr_wrapper_BaseCommand___init___) - wrap_function_wrapper(module, 'BaseCommand.run_from_argv', - _nr_wrapper_BaseCommand_run_from_argv_) + wrap_function_wrapper(module, "BaseCommand.__init__", _nr_wrapper_BaseCommand___init___) + wrap_function_wrapper(module, "BaseCommand.run_from_argv", _nr_wrapper_BaseCommand_run_from_argv_) @function_wrapper -def _nr_wrapper_django_inclusion_tag_wrapper_(wrapped, instance, - args, kwargs): +def _nr_wrapper_django_inclusion_tag_wrapper_(wrapped, instance, args, kwargs): - name = hasattr(wrapped, '__name__') and wrapped.__name__ + name = hasattr(wrapped, "__name__") and wrapped.__name__ if name is None: return wrapped(*args, **kwargs) @@ -1013,16 +1006,14 @@ def _nr_wrapper_django_inclusion_tag_wrapper_(wrapped, instance, tags = django_settings.instrumentation.templates.inclusion_tag - if '*' not in tags and name not in tags and qualname not in tags: + if "*" not in tags and name not in tags and qualname not in tags: return wrapped(*args, **kwargs) - return FunctionTraceWrapper(wrapped, name=name, group='Template/Tag')(*args, **kwargs) + return FunctionTraceWrapper(wrapped, name=name, group="Template/Tag")(*args, **kwargs) @function_wrapper -def _nr_wrapper_django_inclusion_tag_decorator_(wrapped, instance, - args, kwargs): - +def _nr_wrapper_django_inclusion_tag_decorator_(wrapped, instance, 
args, kwargs): def _bind_params(func, *args, **kwargs): return func, args, kwargs @@ -1033,63 +1024,56 @@ def _bind_params(func, *args, **kwargs): return wrapped(func, *_args, **_kwargs) -def _nr_wrapper_django_template_base_Library_inclusion_tag_(wrapped, - instance, args, kwargs): +def _nr_wrapper_django_template_base_Library_inclusion_tag_(wrapped, instance, args, kwargs): - return _nr_wrapper_django_inclusion_tag_decorator_( - wrapped(*args, **kwargs)) + return _nr_wrapper_django_inclusion_tag_decorator_(wrapped(*args, **kwargs)) @function_wrapper -def _nr_wrapper_django_template_base_InclusionNode_render_(wrapped, - instance, args, kwargs): +def _nr_wrapper_django_template_base_InclusionNode_render_(wrapped, instance, args, kwargs): if wrapped.__self__ is None: return wrapped(*args, **kwargs) - file_name = getattr(wrapped.__self__, '_nr_file_name', None) + file_name = getattr(wrapped.__self__, "_nr_file_name", None) if file_name is None: return wrapped(*args, **kwargs) name = wrapped.__self__._nr_file_name - return FunctionTraceWrapper(wrapped, name=name, group='Template/Include')(*args, **kwargs) + return FunctionTraceWrapper(wrapped, name=name, group="Template/Include")(*args, **kwargs) -def _nr_wrapper_django_template_base_generic_tag_compiler_(wrapped, instance, - args, kwargs): +def _nr_wrapper_django_template_base_generic_tag_compiler_(wrapped, instance, args, kwargs): if wrapped.__code__.co_argcount > 6: # Django > 1.3. - def _bind_params(parser, token, params, varargs, varkw, defaults, - name, takes_context, node_class, *args, **kwargs): + def _bind_params( + parser, token, params, varargs, varkw, defaults, name, takes_context, node_class, *args, **kwargs + ): return node_class + else: # Django <= 1.3. - def _bind_params(params, defaults, name, node_class, parser, token, - *args, **kwargs): + def _bind_params(params, defaults, name, node_class, parser, token, *args, **kwargs): return node_class node_class = _bind_params(*args, **kwargs) - if node_class.__name__ == 'InclusionNode': + if node_class.__name__ == "InclusionNode": result = wrapped(*args, **kwargs) - result.render = ( - _nr_wrapper_django_template_base_InclusionNode_render_( - result.render)) + result.render = _nr_wrapper_django_template_base_InclusionNode_render_(result.render) return result return wrapped(*args, **kwargs) -def _nr_wrapper_django_template_base_Library_tag_(wrapped, instance, - args, kwargs): - +def _nr_wrapper_django_template_base_Library_tag_(wrapped, instance, args, kwargs): def _bind_params(name=None, compile_function=None, *args, **kwargs): return compile_function @@ -1105,14 +1089,16 @@ def _get_node_class(compile_function): # Django >= 1.4 uses functools.partial if isinstance(compile_function, functools.partial): - node_class = compile_function.keywords.get('node_class') + node_class = compile_function.keywords.get("node_class") # Django < 1.4 uses their home-grown "curry" function, # not functools.partial. - if (hasattr(compile_function, 'func_closure') and - hasattr(compile_function, '__name__') and - compile_function.__name__ == '_curried'): + if ( + hasattr(compile_function, "func_closure") + and hasattr(compile_function, "__name__") + and compile_function.__name__ == "_curried" + ): # compile_function here is generic_tag_compiler(), which has been # curried. 
To get node_class, we first get the function obj, args, @@ -1121,19 +1107,20 @@ def _get_node_class(compile_function): # is not consistent from platform to platform, so we need to map # them to the variables in compile_function.__code__.co_freevars. - cells = dict(zip(compile_function.__code__.co_freevars, - (c.cell_contents for c in compile_function.func_closure))) + cells = dict( + zip(compile_function.__code__.co_freevars, (c.cell_contents for c in compile_function.func_closure)) + ) # node_class is the 4th arg passed to generic_tag_compiler() - if 'args' in cells and len(cells['args']) > 3: - node_class = cells['args'][3] + if "args" in cells and len(cells["args"]) > 3: + node_class = cells["args"][3] return node_class node_class = _get_node_class(compile_function) - if node_class is None or node_class.__name__ != 'InclusionNode': + if node_class is None or node_class.__name__ != "InclusionNode": return wrapped(*args, **kwargs) # Climb stack to find the file_name of the include template. @@ -1146,9 +1133,8 @@ def _get_node_class(compile_function): for i in range(1, stack_levels + 1): frame = sys._getframe(i) - if ('generic_tag_compiler' in frame.f_code.co_names and - 'file_name' in frame.f_code.co_freevars): - file_name = frame.f_locals.get('file_name') + if "generic_tag_compiler" in frame.f_code.co_names and "file_name" in frame.f_code.co_freevars: + file_name = frame.f_locals.get("file_name") if file_name is None: return wrapped(*args, **kwargs) @@ -1167,22 +1153,22 @@ def instrument_django_template_base(module): settings = global_settings() - if 'django.instrumentation.inclusion-tags.r1' in settings.feature_flag: + if "django.instrumentation.inclusion-tags.r1" in settings.feature_flag: - if hasattr(module, 'generic_tag_compiler'): - wrap_function_wrapper(module, 'generic_tag_compiler', - _nr_wrapper_django_template_base_generic_tag_compiler_) + if hasattr(module, "generic_tag_compiler"): + wrap_function_wrapper( + module, "generic_tag_compiler", _nr_wrapper_django_template_base_generic_tag_compiler_ + ) - if hasattr(module, 'Library'): - wrap_function_wrapper(module, 'Library.tag', - _nr_wrapper_django_template_base_Library_tag_) + if hasattr(module, "Library"): + wrap_function_wrapper(module, "Library.tag", _nr_wrapper_django_template_base_Library_tag_) - wrap_function_wrapper(module, 'Library.inclusion_tag', - _nr_wrapper_django_template_base_Library_inclusion_tag_) + wrap_function_wrapper( + module, "Library.inclusion_tag", _nr_wrapper_django_template_base_Library_inclusion_tag_ + ) def _nr_wrap_converted_middleware_(middleware, name): - @function_wrapper def _wrapper(wrapped, instance, args, kwargs): transaction = current_transaction() @@ -1197,9 +1183,7 @@ def _wrapper(wrapped, instance, args, kwargs): return _wrapper(middleware) -def _nr_wrapper_convert_exception_to_response_(wrapped, instance, args, - kwargs): - +def _nr_wrapper_convert_exception_to_response_(wrapped, instance, args, kwargs): def _bind_params(original_middleware, *args, **kwargs): return original_middleware @@ -1214,21 +1198,19 @@ def _bind_params(original_middleware, *args, **kwargs): def instrument_django_core_handlers_exception(module): - if hasattr(module, 'convert_exception_to_response'): - wrap_function_wrapper(module, 'convert_exception_to_response', - _nr_wrapper_convert_exception_to_response_) + if hasattr(module, "convert_exception_to_response"): + wrap_function_wrapper(module, "convert_exception_to_response", _nr_wrapper_convert_exception_to_response_) - if hasattr(module, 
'handle_uncaught_exception'): - module.handle_uncaught_exception = ( - wrap_handle_uncaught_exception( - module.handle_uncaught_exception)) + if hasattr(module, "handle_uncaught_exception"): + module.handle_uncaught_exception = wrap_handle_uncaught_exception(module.handle_uncaught_exception) def instrument_django_core_handlers_asgi(module): import django - framework = ('Django', django.get_version()) + framework = ("Django", django.get_version()) - if hasattr(module, 'ASGIHandler'): + if hasattr(module, "ASGIHandler"): from newrelic.api.asgi_application import wrap_asgi_application - wrap_asgi_application(module, 'ASGIHandler.__call__', framework=framework) + + wrap_asgi_application(module, "ASGIHandler.__call__", framework=framework) diff --git a/newrelic/hooks/framework_flask.py b/newrelic/hooks/framework_flask.py index c0540a60d..6ef45e6af 100644 --- a/newrelic/hooks/framework_flask.py +++ b/newrelic/hooks/framework_flask.py @@ -166,7 +166,7 @@ def _nr_wrapper_error_handler_(wrapped, instance, args, kwargs): return FunctionTraceWrapper(wrapped, name=name)(*args, **kwargs) -def _nr_wrapper_Flask__register_error_handler_(wrapped, instance, args, kwargs): +def _nr_wrapper_Flask__register_error_handler_(wrapped, instance, args, kwargs): # pragma: no cover def _bind_params(key, code_or_exception, f): return key, code_or_exception, f @@ -189,7 +189,6 @@ def _bind_params(code_or_exception, f): def _nr_wrapper_Flask_try_trigger_before_first_request_functions_(wrapped, instance, args, kwargs): - transaction = current_transaction() if transaction is None: @@ -355,7 +354,6 @@ def _nr_wrapper_Blueprint_endpoint_(wrapped, instance, args, kwargs): @function_wrapper def _nr_wrapper_Blueprint_before_request_wrapped_(wrapped, instance, args, kwargs): - transaction = current_transaction() if transaction is None: diff --git a/newrelic/hooks/framework_graphql.py b/newrelic/hooks/framework_graphql.py index d261b2e9f..df86e6984 100644 --- a/newrelic/hooks/framework_graphql.py +++ b/newrelic/hooks/framework_graphql.py @@ -13,7 +13,10 @@ # limitations under the License. 
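Ahead of the GraphQL hunks, a short usage note on the Django hooks above (a sketch assuming Django >= 3.0 is installed, so django.core.handlers.asgi exists): each instrument_* function takes the target module as its only argument and guards with hasattr(), so version-dependent branches can be exercised directly.

    import django.core.handlers.asgi as asgi_module

    from newrelic.hooks.framework_django import instrument_django_core_handlers_asgi

    # If the module lacked ASGIHandler the hasattr() guard would make this a
    # no-op; otherwise ASGIHandler.__call__ is wrapped as an ASGI application
    # tagged with the ("Django", <version>) framework tuple.
    instrument_django_core_handlers_asgi(asgi_module)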
import logging +import sys +import time from collections import deque +from inspect import isawaitable from newrelic.api.error_trace import ErrorTrace from newrelic.api.function_trace import FunctionTrace @@ -22,7 +25,14 @@ from newrelic.api.transaction import current_transaction, ignore_transaction from newrelic.common.object_names import callable_name, parse_exc_info from newrelic.common.object_wrapper import function_wrapper, wrap_function_wrapper +from newrelic.common.package_version_utils import get_package_version from newrelic.core.graphql_utils import graphql_statement +from newrelic.hooks.framework_graphql_py3 import ( + nr_coro_execute_name_wrapper, + nr_coro_graphql_impl_wrapper, + nr_coro_resolver_error_wrapper, + nr_coro_resolver_wrapper, +) _logger = logging.getLogger(__name__) @@ -32,23 +42,8 @@ VERSION = None -def framework_version(): - """Framework version string.""" - global VERSION - if VERSION is None: - from graphql import __version__ as version - - VERSION = version - - return VERSION - - -def graphql_version(): - """Minor version tuple.""" - version = framework_version() - - # Take first two values in version to avoid ValueErrors with pre-releases (ex: 3.2.0a0) - return tuple(int(v) for v in version.split(".")[:2]) +GRAPHQL_VERSION = get_package_version("graphql-core") +major_version = int(GRAPHQL_VERSION.split(".")[0]) def ignore_graphql_duplicate_exception(exc, val, tb): @@ -98,10 +93,6 @@ def bind_operation_v3(operation, root_value): return operation -def bind_operation_v2(exe_context, operation, root_value): - return operation - - def wrap_execute_operation(wrapped, instance, args, kwargs): transaction = current_transaction() trace = current_trace() @@ -118,15 +109,9 @@ def wrap_execute_operation(wrapped, instance, args, kwargs): try: operation = bind_operation_v3(*args, **kwargs) except TypeError: - try: - operation = bind_operation_v2(*args, **kwargs) - except TypeError: - return wrapped(*args, **kwargs) + return wrapped(*args, **kwargs) - if graphql_version() < (3, 0): - execution_context = args[0] - else: - execution_context = instance + execution_context = instance trace.operation_name = get_node_value(operation, "name") or "" @@ -145,12 +130,17 @@ def wrap_execute_operation(wrapped, instance, args, kwargs): transaction.set_transaction_name(callable_name(wrapped), "GraphQL", priority=11) result = wrapped(*args, **kwargs) - if not execution_context.errors: - if hasattr(trace, "set_transaction_name"): + + def set_name(value=None): + if not execution_context.errors and hasattr(trace, "set_transaction_name"): # Operation trace sets transaction name trace.set_transaction_name(priority=14) + return value - return result + if isawaitable(result): + return nr_coro_execute_name_wrapper(wrapped, result, set_name) + else: + return set_name(result) def get_node_value(field, attr, subattr="value"): @@ -161,39 +151,25 @@ def get_node_value(field, attr, subattr="value"): def is_fragment_spread_node(field): - # Resolve version specific imports - try: - from graphql.language.ast import FragmentSpread - except ImportError: - from graphql import FragmentSpreadNode as FragmentSpread + from graphql.language.ast import FragmentSpreadNode - return isinstance(field, FragmentSpread) + return isinstance(field, FragmentSpreadNode) def is_fragment(field): - # Resolve version specific imports - try: - from graphql.language.ast import FragmentSpread, InlineFragment - except ImportError: - from graphql import FragmentSpreadNode as FragmentSpread - from graphql import InlineFragmentNode as 
InlineFragment - - _fragment_types = (InlineFragment, FragmentSpread) + from graphql.language.ast import FragmentSpreadNode, InlineFragmentNode + _fragment_types = (InlineFragmentNode, FragmentSpreadNode) return isinstance(field, _fragment_types) def is_named_fragment(field): - # Resolve version specific imports - try: - from graphql.language.ast import NamedType - except ImportError: - from graphql import NamedTypeNode as NamedType + from graphql.language.ast import NamedTypeNode return ( is_fragment(field) and getattr(field, "type_condition", None) is not None - and isinstance(field.type_condition, NamedType) + and isinstance(field.type_condition, NamedTypeNode) ) @@ -321,12 +297,25 @@ def wrap_resolver(wrapped, instance, args, kwargs): if transaction is None: return wrapped(*args, **kwargs) - name = callable_name(wrapped) + base_resolver = getattr(wrapped, "_nr_base_resolver", wrapped) + + name = callable_name(base_resolver) transaction.set_transaction_name(name, "GraphQL", priority=13) + trace = FunctionTrace(name, source=base_resolver) - with FunctionTrace(name, source=wrapped): - with ErrorTrace(ignore=ignore_graphql_duplicate_exception): - return wrapped(*args, **kwargs) + with ErrorTrace(ignore=ignore_graphql_duplicate_exception): + sync_start_time = time.time() + result = wrapped(*args, **kwargs) + + if isawaitable(result): + # Grab any async resolvers and wrap with traces + return nr_coro_resolver_error_wrapper( + wrapped, name, trace, ignore_graphql_duplicate_exception, result, transaction + ) + else: + with trace: + trace.start_time = sync_start_time + return result def wrap_error_handler(wrapped, instance, args, kwargs): @@ -368,19 +357,12 @@ def bind_resolve_field_v3(parent_type, source, field_nodes, path): return parent_type, field_nodes, path -def bind_resolve_field_v2(exe_context, parent_type, source, field_asts, parent_info, field_path): - return parent_type, field_asts, field_path - - def wrap_resolve_field(wrapped, instance, args, kwargs): transaction = current_transaction() if transaction is None: return wrapped(*args, **kwargs) - if graphql_version() < (3, 0): - bind_resolve_field = bind_resolve_field_v2 - else: - bind_resolve_field = bind_resolve_field_v3 + bind_resolve_field = bind_resolve_field_v3 try: parent_type, field_asts, field_path = bind_resolve_field(*args, **kwargs) @@ -390,18 +372,34 @@ def wrap_resolve_field(wrapped, instance, args, kwargs): field_name = field_asts[0].name.value field_def = parent_type.fields.get(field_name) field_return_type = str(field_def.type) if field_def else "" + if isinstance(field_path, list): + field_path = field_path[0] + else: + field_path = field_path.key - with GraphQLResolverTrace(field_name) as trace: - with ErrorTrace(ignore=ignore_graphql_duplicate_exception): - trace._add_agent_attribute("graphql.field.parentType", parent_type.name) - trace._add_agent_attribute("graphql.field.returnType", field_return_type) + trace = GraphQLResolverTrace( + field_name, field_parent_type=parent_type.name, field_return_type=field_return_type, field_path=field_path ) + start_time = time.time() - if isinstance(field_path, list): - trace._add_agent_attribute("graphql.field.path", field_path[0]) - else: - trace._add_agent_attribute("graphql.field.path", field_path.key) + try: + result = wrapped(*args, **kwargs) + except Exception: + # Synchronous resolver with exception raised + with trace: + trace.start_time = start_time + notice_error(ignore=ignore_graphql_duplicate_exception) + raise - return wrapped(*args, **kwargs) + if
isawaitable(result): + # Asynchronous resolvers (returned coroutines from non-coroutine functions) + # Return a coroutine that handles wrapping in a resolver trace + return nr_coro_resolver_wrapper(wrapped, trace, ignore_graphql_duplicate_exception, result) + else: + # Synchronous resolver with no exception raised + with trace: + trace.start_time = start_time + return result def bind_graphql_impl_query(schema, source, *args, **kwargs): @@ -428,11 +426,8 @@ def wrap_graphql_impl(wrapped, instance, args, kwargs): if not transaction: return wrapped(*args, **kwargs) - transaction.add_framework_info(name="GraphQL", version=framework_version()) - if graphql_version() < (3, 0): - bind_query = bind_execute_graphql_query - else: - bind_query = bind_graphql_impl_query + transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION) + bind_query = bind_graphql_impl_query try: schema, query = bind_query(*args, **kwargs) @@ -444,17 +439,34 @@ def wrap_graphql_impl(wrapped, instance, args, kwargs): transaction.set_transaction_name(callable_name(wrapped), "GraphQL", priority=10) - with GraphQLOperationTrace() as trace: - trace.statement = graphql_statement(query) + trace = GraphQLOperationTrace() + + trace.statement = graphql_statement(query) - # Handle Schemas created from frameworks - if hasattr(schema, "_nr_framework"): - framework = schema._nr_framework - trace.product = framework[0] - transaction.add_framework_info(name=framework[0], version=framework[1]) + # Handle Schemas created from frameworks + if hasattr(schema, "_nr_framework"): + framework = schema._nr_framework + trace.product = framework[0] + transaction.add_framework_info(name=framework[0], version=framework[1]) + # Trace must be manually started and stopped to ensure it exists prior to and during the entire duration of the query. + # Otherwise subsequent instrumentation will not be able to find an operation trace and will have issues. + trace.__enter__() + try: with ErrorTrace(ignore=ignore_graphql_duplicate_exception): result = wrapped(*args, **kwargs) + except Exception: + # Execution finished synchronously, exit immediately. + trace.__exit__(*sys.exc_info()) + raise + else: + if isawaitable(result): + # Asynchronous implementations + # Return a coroutine that handles closing the operation trace + return nr_coro_graphql_impl_wrapper(wrapped, trace, ignore_graphql_duplicate_exception, result) + else: + # Execution finished synchronously, exit immediately.
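+ # Passing None for the exception triple closes the operation trace cleanly; any error on the
+ # synchronous path was already recorded and re-raised in the except branch above.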
+ trace.__exit__(None, None, None) return result @@ -480,11 +492,15 @@ def instrument_graphql_execute(module): def instrument_graphql_execution_utils(module): + if major_version == 2: + return if hasattr(module, "ExecutionContext"): wrap_function_wrapper(module, "ExecutionContext.__init__", wrap_executor_context_init) def instrument_graphql_execution_middleware(module): + if major_version == 2: + return if hasattr(module, "get_middleware_resolvers"): wrap_function_wrapper(module, "get_middleware_resolvers", wrap_get_middleware_resolvers) if hasattr(module, "MiddlewareManager"): @@ -492,20 +508,26 @@ def instrument_graphql_execution_middleware(module): def instrument_graphql_error_located_error(module): + if major_version == 2: + return if hasattr(module, "located_error"): wrap_function_wrapper(module, "located_error", wrap_error_handler) def instrument_graphql_validate(module): + if major_version == 2: + return wrap_function_wrapper(module, "validate", wrap_validate) def instrument_graphql(module): + if major_version == 2: + return if hasattr(module, "graphql_impl"): wrap_function_wrapper(module, "graphql_impl", wrap_graphql_impl) - if hasattr(module, "execute_graphql"): - wrap_function_wrapper(module, "execute_graphql", wrap_graphql_impl) def instrument_graphql_parser(module): + if major_version == 2: + return wrap_function_wrapper(module, "parse", wrap_parse) diff --git a/newrelic/hooks/framework_graphql_py3.py b/newrelic/hooks/framework_graphql_py3.py new file mode 100644 index 000000000..3931aa6ed --- /dev/null +++ b/newrelic/hooks/framework_graphql_py3.py @@ -0,0 +1,68 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
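+# The helpers in this module are the async halves of the GraphQL hooks: each one awaits the
+# framework's coroutine while keeping the already-started New Relic trace open across suspension
+# points, so timing and errors land on the correct span.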
+import functools +import sys + +from newrelic.api.error_trace import ErrorTrace +from newrelic.api.function_trace import FunctionTrace + + +def nr_coro_execute_name_wrapper(wrapped, result, set_name): + @functools.wraps(wrapped) + async def _nr_coro_execute_name_wrapper(): + result_ = await result + set_name() + return result_ + + return _nr_coro_execute_name_wrapper() + + +def nr_coro_resolver_error_wrapper(wrapped, name, trace, ignore, result, transaction): + @functools.wraps(wrapped) + async def _nr_coro_resolver_error_wrapper(): + with trace: + with ErrorTrace(ignore=ignore): + try: + return await result + except Exception: + transaction.set_transaction_name(name, "GraphQL", priority=15) + raise + + return _nr_coro_resolver_error_wrapper() + + +def nr_coro_resolver_wrapper(wrapped, trace, ignore, result): + @functools.wraps(wrapped) + async def _nr_coro_resolver_wrapper(): + with trace: + with ErrorTrace(ignore=ignore): + return await result + + return _nr_coro_resolver_wrapper() + +def nr_coro_graphql_impl_wrapper(wrapped, trace, ignore, result): + @functools.wraps(wrapped) + async def _nr_coro_graphql_impl_wrapper(): + try: + with ErrorTrace(ignore=ignore): + result_ = await result + except: + trace.__exit__(*sys.exc_info()) + raise + else: + trace.__exit__(None, None, None) + return result_ + + + return _nr_coro_graphql_impl_wrapper() \ No newline at end of file diff --git a/newrelic/hooks/framework_pyramid.py b/newrelic/hooks/framework_pyramid.py index efe3c4468..996ebb372 100644 --- a/newrelic/hooks/framework_pyramid.py +++ b/newrelic/hooks/framework_pyramid.py @@ -48,8 +48,11 @@ from newrelic.api.transaction import current_transaction from newrelic.api.wsgi_application import wrap_wsgi_application from newrelic.common.object_names import callable_name -from newrelic.common.object_wrapper import (FunctionWrapper, wrap_out_function, - wrap_function_wrapper) +from newrelic.common.object_wrapper import ( + FunctionWrapper, + wrap_function_wrapper, + wrap_out_function, +) def instrument_pyramid_router(module): @@ -57,16 +60,17 @@ def instrument_pyramid_router(module): try: import pkg_resources - pyramid_version = pkg_resources.get_distribution('pyramid').version + + pyramid_version = pkg_resources.get_distribution("pyramid").version except Exception: pass - wrap_wsgi_application( - module, 'Router.__call__', framework=('Pyramid', pyramid_version)) + wrap_wsgi_application(module, "Router.__call__", framework=("Pyramid", pyramid_version)) def status_code(exc, value, tb): from pyramid.httpexceptions import HTTPException + # Ignore certain exceptions based on HTTP status codes. if isinstance(value, HTTPException): @@ -75,6 +79,7 @@ def status_code(exc, value, tb): def should_ignore(exc, value, tb): from pyramid.exceptions import PredicateMismatch + # Always ignore PredicateMismatch as it is raised by views to force # subsequent views to be consulted when multi views are being used. 
# It isn't therefore strictly an error as such as a subsequent view @@ -100,9 +105,7 @@ def view_handler_wrapper(wrapped, instance, args, kwargs): # set exception views to priority=1 so they won't take precedence over # the original view callable - transaction.set_transaction_name( - name, - priority=1 if args and isinstance(args[0], Exception) else 2) + transaction.set_transaction_name(name, priority=1 if args and isinstance(args[0], Exception) else 2) with FunctionTrace(name, source=view_callable) as trace: try: @@ -114,7 +117,7 @@ def view_handler_wrapper(wrapped, instance, args, kwargs): def wrap_view_handler(mapped_view): - if hasattr(mapped_view, '_nr_wrapped'): + if hasattr(mapped_view, "_nr_wrapped"): # pragma: no cover return mapped_view else: wrapped = FunctionWrapper(mapped_view, view_handler_wrapper) @@ -157,7 +160,7 @@ def _wrapper(context, request): return wrapper(context, request) finally: attr = instance.attr - inst = getattr(request, '__view__', None) + inst = getattr(request, "__view__", None) if inst is not None: if attr: handler = getattr(inst, attr) @@ -166,7 +169,7 @@ def _wrapper(context, request): tracer.name = name tracer.add_code_level_metrics(handler) else: - method = getattr(inst, '__call__') + method = getattr(inst, "__call__") if method: name = callable_name(method) transaction.set_transaction_name(name, priority=2) @@ -180,22 +183,21 @@ def instrument_pyramid_config_views(module): # Location of the ViewDeriver class changed from pyramid.config to # pyramid.config.views so check if present before trying to update. - if hasattr(module, 'ViewDeriver'): - wrap_out_function(module, 'ViewDeriver.__call__', wrap_view_handler) - elif hasattr(module, 'Configurator'): - wrap_out_function(module, 'Configurator._derive_view', - wrap_view_handler) + if hasattr(module, "ViewDeriver"): # pragma: no cover + wrap_out_function(module, "ViewDeriver.__call__", wrap_view_handler) + elif hasattr(module, "Configurator"): + wrap_out_function(module, "Configurator._derive_view", wrap_view_handler) - if hasattr(module, 'DefaultViewMapper'): + if hasattr(module, "DefaultViewMapper"): module.DefaultViewMapper.map_class_requestonly = FunctionWrapper( - module.DefaultViewMapper.map_class_requestonly, - default_view_mapper_wrapper) + module.DefaultViewMapper.map_class_requestonly, default_view_mapper_wrapper + ) module.DefaultViewMapper.map_class_native = FunctionWrapper( - module.DefaultViewMapper.map_class_native, - default_view_mapper_wrapper) + module.DefaultViewMapper.map_class_native, default_view_mapper_wrapper + ) def instrument_pyramid_config_tweens(module): - wrap_function_wrapper(module, 'Tweens.add_explicit', wrap_add_tween) + wrap_function_wrapper(module, "Tweens.add_explicit", wrap_add_tween) - wrap_function_wrapper(module, 'Tweens.add_implicit', wrap_add_tween) + wrap_function_wrapper(module, "Tweens.add_implicit", wrap_add_tween) diff --git a/newrelic/hooks/framework_strawberry.py b/newrelic/hooks/framework_strawberry.py index 92a0ea8b4..e6d06bb04 100644 --- a/newrelic/hooks/framework_strawberry.py +++ b/newrelic/hooks/framework_strawberry.py @@ -16,20 +16,14 @@ from newrelic.api.error_trace import ErrorTrace from newrelic.api.graphql_trace import GraphQLOperationTrace from newrelic.api.transaction import current_transaction -from newrelic.api.transaction_name import TransactionNameWrapper from newrelic.common.object_names import callable_name from newrelic.common.object_wrapper import wrap_function_wrapper +from newrelic.common.package_version_utils import 
get_package_version from newrelic.core.graphql_utils import graphql_statement -from newrelic.hooks.framework_graphql import ( - framework_version as graphql_framework_version, -) -from newrelic.hooks.framework_graphql import ignore_graphql_duplicate_exception +from newrelic.hooks.framework_graphql import GRAPHQL_VERSION, ignore_graphql_duplicate_exception - -def framework_details(): - import strawberry - - return ("Strawberry", getattr(strawberry, "__version__", None)) +STRAWBERRY_GRAPHQL_VERSION = get_package_version("strawberry-graphql") +strawberry_version_tuple = tuple(map(int, STRAWBERRY_GRAPHQL_VERSION.split("."))) def bind_execute(query, *args, **kwargs): @@ -47,9 +41,8 @@ def wrap_execute_sync(wrapped, instance, args, kwargs): except TypeError: return wrapped(*args, **kwargs) - framework = framework_details() - transaction.add_framework_info(name=framework[0], version=framework[1]) - transaction.add_framework_info(name="GraphQL", version=graphql_framework_version()) + transaction.add_framework_info(name="Strawberry", version=STRAWBERRY_GRAPHQL_VERSION) + transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION) if hasattr(query, "body"): query = query.body @@ -74,9 +67,8 @@ async def wrap_execute(wrapped, instance, args, kwargs): except TypeError: return await wrapped(*args, **kwargs) - framework = framework_details() - transaction.add_framework_info(name=framework[0], version=framework[1]) - transaction.add_framework_info(name="GraphQL", version=graphql_framework_version()) + transaction.add_framework_info(name="Strawberry", version=STRAWBERRY_GRAPHQL_VERSION) + transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION) if hasattr(query, "body"): query = query.body @@ -98,19 +90,20 @@ def wrap_from_resolver(wrapped, instance, args, kwargs): result = wrapped(*args, **kwargs) try: - field = bind_from_resolver(*args, **kwargs) + field = bind_from_resolver(*args, **kwargs) except TypeError: pass else: if hasattr(field, "base_resolver"): if hasattr(field.base_resolver, "wrapped_func"): - resolver_name = callable_name(field.base_resolver.wrapped_func) - result = TransactionNameWrapper(result, resolver_name, "GraphQL", priority=13) + result._nr_base_resolver = field.base_resolver.wrapped_func return result def instrument_strawberry_schema(module): + if strawberry_version_tuple < (0, 23, 3): + return if hasattr(module, "Schema"): if hasattr(module.Schema, "execute"): wrap_function_wrapper(module, "Schema.execute", wrap_execute) @@ -119,11 +112,15 @@ def instrument_strawberry_schema(module): def instrument_strawberry_asgi(module): + if strawberry_version_tuple < (0, 23, 3): + return if hasattr(module, "GraphQL"): - wrap_asgi_application(module, "GraphQL.__call__", framework=framework_details()) + wrap_asgi_application(module, "GraphQL.__call__", framework=("Strawberry", STRAWBERRY_GRAPHQL_VERSION)) def instrument_strawberry_schema_converter(module): + if strawberry_version_tuple < (0, 23, 3): + return if hasattr(module, "GraphQLCoreConverter"): if hasattr(module.GraphQLCoreConverter, "from_resolver"): wrap_function_wrapper(module, "GraphQLCoreConverter.from_resolver", wrap_from_resolver) diff --git a/newrelic/hooks/logger_loguru.py b/newrelic/hooks/logger_loguru.py index 9e7ed3eae..dc9843b20 100644 --- a/newrelic/hooks/logger_loguru.py +++ b/newrelic/hooks/logger_loguru.py @@ -134,7 +134,7 @@ def patch_loguru_logger(logger): if not hasattr(logger._core, "_nr_instrumented"): logger.add(_nr_log_forwarder, format="{message}") logger._core._nr_instrumented = True - 
elif not hasattr(logger, "_nr_instrumented"): + elif not hasattr(logger, "_nr_instrumented"): # pragma: no cover for _, handler in six.iteritems(logger._handlers): if handler._writer is _nr_log_forwarder: logger._nr_instrumented = True diff --git a/newrelic/hooks/logger_structlog.py b/newrelic/hooks/logger_structlog.py new file mode 100644 index 000000000..e652a795c --- /dev/null +++ b/newrelic/hooks/logger_structlog.py @@ -0,0 +1,86 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from newrelic.common.object_wrapper import wrap_function_wrapper +from newrelic.api.transaction import current_transaction, record_log_event +from newrelic.core.config import global_settings +from newrelic.api.application import application_instance +from newrelic.hooks.logger_logging import add_nr_linking_metadata + + +def normalize_level_name(method_name): + # Look up level number for method name, using result to look up level name for that level number. + # Convert result to upper case, and default to UNKNOWN in case of errors or missing values. + try: + from structlog._log_levels import _LEVEL_TO_NAME, _NAME_TO_LEVEL + return _LEVEL_TO_NAME[_NAME_TO_LEVEL[method_name]].upper() + except Exception: + return "UNKNOWN" + + +def bind_process_event(method_name, event, event_kw): + return method_name, event, event_kw + + +def wrap__process_event(wrapped, instance, args, kwargs): + try: + method_name, event, event_kw = bind_process_event(*args, **kwargs) + except TypeError: + return wrapped(*args, **kwargs) + + original_message = event # Save original undecorated message + + transaction = current_transaction() + + if transaction: + settings = transaction.settings + else: + settings = global_settings() + + # Only proceed when application logging is enabled; otherwise fall through to the plain wrapped call at the end. + if settings and settings.application_logging and settings.application_logging.enabled: + if settings.application_logging.local_decorating and settings.application_logging.local_decorating.enabled: + event = add_nr_linking_metadata(event) + + # Send log to processors for filtering, allowing any DropEvent exceptions that occur to prevent instrumentation from recording the log event.
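+ # If a processor raises structlog.DropEvent, it propagates out of the call below, skipping the
+ # metric and forwarding code that follows; structlog's own caller then discards the line.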
+ result = wrapped(method_name, event, event_kw) + + level_name = normalize_level_name(method_name) + + if settings.application_logging.metrics and settings.application_logging.metrics.enabled: + if transaction: + transaction.record_custom_metric("Logging/lines", {"count": 1}) + transaction.record_custom_metric("Logging/lines/%s" % level_name, {"count": 1}) + else: + application = application_instance(activate=False) + if application and application.enabled: + application.record_custom_metric("Logging/lines", {"count": 1}) + application.record_custom_metric("Logging/lines/%s" % level_name, {"count": 1}) + + if settings.application_logging.forwarding and settings.application_logging.forwarding.enabled: + try: + record_log_event(original_message, level_name) + + except Exception: + pass + + # Return the result from wrapped after we've recorded the resulting log event. + return result + + return wrapped(*args, **kwargs) + + +def instrument_structlog__base(module): + if hasattr(module, "BoundLoggerBase") and hasattr(module.BoundLoggerBase, "_process_event"): + wrap_function_wrapper(module, "BoundLoggerBase._process_event", wrap__process_event) diff --git a/newrelic/hooks/messagebroker_pika.py b/newrelic/hooks/messagebroker_pika.py index cecc1b934..d6120c10d 100644 --- a/newrelic/hooks/messagebroker_pika.py +++ b/newrelic/hooks/messagebroker_pika.py @@ -36,7 +36,6 @@ def _add_consume_rabbitmq_trace(transaction, method, properties, nr_start_time, queue_name=None): - routing_key = None if hasattr(method, "routing_key"): routing_key = method.routing_key @@ -197,7 +196,7 @@ def _wrap_basic_get_Channel(wrapper, queue, callback, *args, **kwargs): return queue, args, kwargs -def _wrap_basic_get_Channel_old(wrapper, callback=None, queue="", *args, **kwargs): +def _wrap_basic_get_Channel_old(wrapper, callback=None, queue="", *args, **kwargs): # pragma: no cover if callback is not None: callback = wrapper(callback) args = (callback, queue) + args @@ -368,7 +367,6 @@ def callback_wrapper(wrapped, instance, args, kwargs): correlation_id=correlation_id, source=wrapped, ) as mt: - # Improve transaction naming _new_txn_name = "RabbitMQ/Exchange/%s/%s" % (exchange, name) mt.set_transaction_name(_new_txn_name, group="Message") @@ -404,7 +402,7 @@ def instrument_pika_adapters(module): version = tuple(int(num) for num in pika.__version__.split(".", 1)[0]) - if version[0] < 1: + if version[0] < 1: # pragma: no cover wrap_consume = _wrap_basic_consume_BlockingChannel_old else: wrap_consume = _wrap_basic_consume_Channel @@ -426,7 +424,7 @@ def instrument_pika_channel(module): version = tuple(int(num) for num in pika.__version__.split(".", 1)[0]) - if version[0] < 1: + if version[0] < 1: # pragma: no cover wrap_consume = _wrap_basic_consume_Channel_old wrap_get = _wrap_basic_get_Channel_old else: diff --git a/newrelic/hooks/mlmodel_sklearn.py b/newrelic/hooks/mlmodel_sklearn.py new file mode 100644 index 000000000..bdfeccfc8 --- /dev/null +++ b/newrelic/hooks/mlmodel_sklearn.py @@ -0,0 +1,781 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys +import uuid + +from newrelic.api.function_trace import FunctionTrace +from newrelic.api.time_trace import current_trace +from newrelic.api.transaction import current_transaction +from newrelic.common.object_wrapper import ObjectProxy, wrap_function_wrapper +from newrelic.core.config import global_settings + +METHODS_TO_WRAP = ("predict", "fit", "fit_predict", "predict_log_proba", "predict_proba", "transform", "score") +METRIC_SCORERS = ( + "accuracy_score", + "balanced_accuracy_score", + "f1_score", + "precision_score", + "recall_score", + "roc_auc_score", + "r2_score", +) +PY2 = sys.version_info[0] == 2 +_logger = logging.getLogger(__name__) + + +def isnumeric(column): + import numpy as np + + try: + column.astype(np.float64) + return [True] * len(column) + except: + pass + return [False] * len(column) + + +class PredictReturnTypeProxy(ObjectProxy): + def __init__(self, wrapped, model_name, training_step): + super(ObjectProxy, self).__init__(wrapped) + self._nr_model_name = model_name + self._nr_training_step = training_step + + +def _wrap_method_trace(module, class_, method, name=None, group=None): + def _nr_wrapper_method(wrapped, instance, args, kwargs): + transaction = current_transaction() + trace = current_trace() + + if transaction is None: + return wrapped(*args, **kwargs) + + settings = transaction.settings if transaction.settings is not None else global_settings() + + if settings and not settings.machine_learning.enabled: + return wrapped(*args, **kwargs) + + wrapped_attr_name = "_nr_wrapped_%s" % method + + # If the method has already been wrapped do not wrap it again. This happens + # when one class inherits from another and they both implement the method. + if getattr(trace, wrapped_attr_name, False): + return wrapped(*args, **kwargs) + + trace = FunctionTrace(name=name, group=group, source=wrapped) + + try: + # Set the _nr_wrapped attribute to denote that this method is being wrapped. + setattr(trace, wrapped_attr_name, True) + + with trace: + return_val = wrapped(*args, **kwargs) + finally: + # Set the _nr_wrapped attribute to denote that this method is no longer wrapped. + setattr(trace, wrapped_attr_name, False) + + # If this is the fit method, increment the training_step counter. + if method in ("fit", "fit_predict"): + training_step = getattr(instance, "_nr_wrapped_training_step", -1) + setattr(instance, "_nr_wrapped_training_step", training_step + 1) + + # If this is the predict method, wrap the return type in an nr type with + # _nr_wrapped attrs that will attach model info to the data. + if method in ("predict", "fit_predict"): + training_step = getattr(instance, "_nr_wrapped_training_step", "Unknown") + create_prediction_event(transaction, class_, instance, args, kwargs, return_val) + return PredictReturnTypeProxy(return_val, model_name=class_, training_step=training_step) + return return_val + + wrap_function_wrapper(module, "%s.%s" % (class_, method), _nr_wrapper_method) + + +def _calc_prediction_feature_stats(prediction_input, class_, feature_column_names, tags): + import numpy as np + + # Drop any feature columns that are not numeric since we can't compute stats + # on non-numeric columns. 
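+ # isnumeric() answers per element, so the mask from np.apply_along_axis has the same shape as x
+ # with whole columns True or False; indexing x with it keeps only numeric values, flattened to 1D.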
+ x = np.array(prediction_input) + isnumeric_features = np.apply_along_axis(isnumeric, 0, x) + numeric_features = x[isnumeric_features] + + # Drop any feature column names that are not numeric since we can't compute stats + # on non-numeric columns. + feature_column_names = feature_column_names[isnumeric_features[0]] + + # Only compute stats for features if we have any feature columns left after dropping + # non-numeric columns. + num_cols = len(feature_column_names) + if num_cols > 0: + # Boolean selection of numpy array values reshapes the array to a single + # dimension so we have to reshape it back into a 2D array. + features = np.reshape(numeric_features, (len(numeric_features) // num_cols, num_cols)) + features = features.astype(dtype=np.float64) + + _record_stats(features, feature_column_names, class_, "Feature", tags) + + +def _record_stats(data, column_names, class_, column_type, tags): + import numpy as np + + mean = np.mean(data, axis=0) + percentile25 = np.percentile(data, q=25, axis=0) + percentile50 = np.percentile(data, q=50, axis=0) + percentile75 = np.percentile(data, q=75, axis=0) + standard_deviation = np.std(data, axis=0) + _min = np.min(data, axis=0) + _max = np.max(data, axis=0) + _count = data.shape[0] + + transaction = current_transaction() + + # Currently record_metric only supports a subset of these stats so we have + # to upload them one at a time instead of as a dictionary of stats per + # feature column. + for index, col_name in enumerate(column_names): + metric_name = "MLModel/Sklearn/Named/%s/Predict/%s/%s" % (class_, column_type, col_name) + + transaction.record_dimensional_metrics( + [ + ("%s/%s" % (metric_name, "Mean"), float(mean[index]), tags), + ("%s/%s" % (metric_name, "Percentile25"), float(percentile25[index]), tags), + ("%s/%s" % (metric_name, "Percentile50"), float(percentile50[index]), tags), + ("%s/%s" % (metric_name, "Percentile75"), float(percentile75[index]), tags), + ("%s/%s" % (metric_name, "StandardDeviation"), float(standard_deviation[index]), tags), + ("%s/%s" % (metric_name, "Min"), float(_min[index]), tags), + ("%s/%s" % (metric_name, "Max"), float(_max[index]), tags), + ("%s/%s" % (metric_name, "Count"), _count, tags), + ] + ) + + +def _calc_prediction_label_stats(labels, class_, label_column_names, tags): + import numpy as np + + labels = np.array(labels, dtype=np.float64) + _record_stats(labels, label_column_names, class_, "Label", tags) + + +def _get_label_names(user_defined_label_names, prediction_array): + import numpy as np + + if user_defined_label_names is None: + return np.array(range(prediction_array.shape[1])) + if user_defined_label_names and len(user_defined_label_names) != prediction_array.shape[1]: + _logger.warning( + "The number of label names passed to the ml_model wrapper function is not equal to the number of predictions in the data set. Please supply the correct number of label names." + ) + return np.array(range(prediction_array.shape[1])) + else: + return user_defined_label_names + + +def find_type_category(data_set, row_index, column_index): + # If pandas DataFrame, return type of column. + pd = sys.modules.get("pandas", None) + if pd and isinstance(data_set, pd.DataFrame): + value_type = data_set.iloc[:, column_index].dtype.name + if value_type == "category": + return "categorical" + categorized_value_type = categorize_data_type(value_type) + return categorized_value_type + # If it's not a pandas DataFrame then it is a list or numpy array.
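+ # str(type(...)) yields e.g. "<class 'numpy.float64'>", which categorize_data_type() matches
+ # by substring ("int", "float", "bool", "str").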
+ python_type = str(type(data_set[column_index][row_index])) + return categorize_data_type(python_type) + + +def categorize_data_type(python_type): + if "int" in python_type or "float" in python_type or "complex" in python_type: + return "numeric" + if "bool" in python_type: + return "bool" + if "str" in python_type or "unicode" in python_type: + return "str" + else: + return python_type + + +def _get_feature_column_names(user_provided_feature_names, features): + import numpy as np + + num_feature_columns = np.array(features).shape[1] + + # If the user provided feature names are the correct size, return the user provided feature + # names. + if user_provided_feature_names and len(user_provided_feature_names) == num_feature_columns: + return np.array(user_provided_feature_names) + + # If the user provided feature names aren't the correct size, log a warning and do not use the user provided feature names. + if user_provided_feature_names: + _logger.warning( + "The number of feature names passed to the ml_model wrapper function is not equal to the number of columns in the data set. Please supply the correct number of feature names." + ) + + # If the user doesn't provide the feature names or they were provided but the size was incorrect and the features are a pandas data frame, return the column names from the pandas data frame. + pd = sys.modules.get("pandas", None) + if pd and isinstance(features, pd.DataFrame): + return features.columns + + # If the user doesn't provide the feature names or they were provided but the size was incorrect and the features are not a pandas data frame, return the column indexes as the feature names. + return np.array(range(num_feature_columns)) + + +def bind_predict(X, *args, **kwargs): + return X + + +def create_prediction_event(transaction, class_, instance, args, kwargs, return_val): + import numpy as np + + data_set = bind_predict(*args, **kwargs) + model_name = getattr(instance, "_nr_wrapped_name", class_) + model_version = getattr(instance, "_nr_wrapped_version", "0.0.0") + user_provided_feature_names = getattr(instance, "_nr_wrapped_feature_names", None) + label_names = getattr(instance, "_nr_wrapped_label_names", None) + metadata = getattr(instance, "_nr_wrapped_metadata", {}) + settings = transaction.settings if transaction.settings is not None else global_settings() + + prediction_id = uuid.uuid4() + + labels = [] + if return_val is not None: + if not hasattr(return_val, "__iter__"): + labels = np.array([return_val]) + else: + labels = np.array(return_val) + if len(labels.shape) == 1: + labels = np.reshape(labels, (len(labels) // 1, 1)) + + label_names_list = _get_label_names(label_names, labels) + _calc_prediction_label_stats( + labels, + class_, + label_names_list, + tags={ + "prediction_id": prediction_id, + "model_version": model_version, + # The following are used for entity synthesis. + "modelName": model_name, + }, + ) + + final_feature_names = _get_feature_column_names(user_provided_feature_names, data_set) + np_casted_data_set = np.array(data_set) + _calc_prediction_feature_stats( + data_set, + class_, + final_feature_names, + tags={ + "prediction_id": prediction_id, + "model_version": model_version, + # The following are used for entity synthesis. 
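+ # Note the camelCase key, in contrast to the snake_case tag names above.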
+ "modelName": model_name, + }, + ) + features, predictions = np_casted_data_set.shape + for prediction_index, prediction in enumerate(np_casted_data_set): + inference_id = uuid.uuid4() + + event = { + "inference_id": inference_id, + "prediction_id": prediction_id, + "model_version": model_version, + "new_relic_data_schema_version": 2, + # The following are used for entity synthesis. + "modelName": model_name, + } + if metadata and isinstance(metadata, dict): + event.update(metadata) + # Don't include the raw value when inference_event_value is disabled. + if settings and settings.machine_learning and settings.machine_learning.inference_events_value.enabled: + event.update( + { + "feature.%s" % str(final_feature_names[feature_col_index]): value + for feature_col_index, value in enumerate(prediction) + } + ) + event.update( + { + "label.%s" % str(label_names_list[index]): str(value) + for index, value in enumerate(labels[prediction_index]) + } + ) + transaction.record_ml_event("InferenceData", event) + + +def _nr_instrument_model(module, model_class): + for method_name in METHODS_TO_WRAP: + if hasattr(getattr(module, model_class), method_name): + # Function/MLModel/Sklearn/Named/. + name = "MLModel/Sklearn/Named/%s.%s" % (model_class, method_name) + _wrap_method_trace(module, model_class, method_name, name=name) + + +def _instrument_sklearn_models(module, model_classes): + for model_cls in model_classes: + if hasattr(module, model_cls): + _nr_instrument_model(module, model_cls) + + +def _bind_scorer(y_true, y_pred, *args, **kwargs): + return y_true, y_pred, args, kwargs + + +def wrap_metric_scorer(wrapped, instance, args, kwargs): + transaction = current_transaction() + # If there is no transaction, do not wrap anything. + if not transaction: + return wrapped(*args, **kwargs) + + settings = transaction.settings if transaction.settings is not None else global_settings() + + if settings and not settings.machine_learning.enabled: + return wrapped(*args, **kwargs) + + score = wrapped(*args, **kwargs) + + y_true, y_pred, args, kwargs = _bind_scorer(*args, **kwargs) + model_name = "Unknown" + training_step = "Unknown" + if hasattr(y_pred, "_nr_model_name"): + model_name = y_pred._nr_model_name + if hasattr(y_pred, "_nr_training_step"): + training_step = y_pred._nr_training_step + # Attribute values must be int, float, str, or boolean. If it's not one of these + # types and an iterable add the values as separate attributes. 
+ if not isinstance(score, (str, int, float, bool)): + if hasattr(score, "__iter__"): + for i, s in enumerate(score): + transaction._add_agent_attribute( + "%s/TrainingStep/%s/%s[%s]" % (model_name, training_step, wrapped.__name__, i), s + ) + else: + transaction._add_agent_attribute("%s/TrainingStep/%s/%s" % (model_name, training_step, wrapped.__name__), score) + return score + + +def instrument_sklearn_tree_models(module): + model_classes = ( + "DecisionTreeClassifier", + "DecisionTreeRegressor", + "ExtraTreeClassifier", + "ExtraTreeRegressor", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_ensemble_bagging_models(module): + model_classes = ( + "BaggingClassifier", + "BaggingRegressor", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_ensemble_forest_models(module): + model_classes = ( + "ExtraTreesClassifier", + "ExtraTreesRegressor", + "RandomForestClassifier", + "RandomForestRegressor", + "RandomTreesEmbedding", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_ensemble_iforest_models(module): + model_classes = ("IsolationForest",) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_ensemble_weight_boosting_models(module): + model_classes = ( + "AdaBoostClassifier", + "AdaBoostRegressor", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_ensemble_gradient_boosting_models(module): + model_classes = ( + "GradientBoostingClassifier", + "GradientBoostingRegressor", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_ensemble_voting_models(module): + model_classes = ( + "VotingClassifier", + "VotingRegressor", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_ensemble_stacking_models(module): + module_classes = ( + "StackingClassifier", + "StackingRegressor", + ) + _instrument_sklearn_models(module, module_classes) + + +def instrument_sklearn_ensemble_hist_models(module): + model_classes = ( + "HistGradientBoostingClassifier", + "HistGradientBoostingRegressor", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_linear_coordinate_descent_models(module): + model_classes = ( + "Lasso", + "LassoCV", + "ElasticNet", + "ElasticNetCV", + "MultiTaskLasso", + "MultiTaskLassoCV", + "MultiTaskElasticNet", + "MultiTaskElasticNetCV", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_compose_models(module): + model_classes = ( + "ColumnTransformer", + "TransformedTargetRegressor", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_covariance_shrunk_models(module): + model_classes = ( + "ShrunkCovariance", + "LedoitWolf", + "OAS", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_cross_decomposition_models(module): + model_classes = ( + "PLSRegression", + "PLSSVD", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_covariance_graph_models(module): + model_classes = ( + "GraphicalLasso", + "GraphicalLassoCV", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_discriminant_analysis_models(module): + model_classes = ( + "LinearDiscriminantAnalysis", + "QuadraticDiscriminantAnalysis", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_covariance_models(module): + model_classes = ( + "EmpiricalCovariance", + "MinCovDet", + "EllipticEnvelope", + ) + 
_instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_gaussian_process_models(module): + model_classes = ( + "GaussianProcessClassifier", + "GaussianProcessRegressor", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_dummy_models(module): + model_classes = ( + "DummyClassifier", + "DummyRegressor", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_feature_selection_rfe_models(module): + model_classes = ( + "RFE", + "RFECV", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_kernel_ridge_models(module): + model_classes = ("KernelRidge",) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_calibration_models(module): + model_classes = ("CalibratedClassifierCV",) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_cluster_models(module): + model_classes = ( + "AffinityPropagation", + "Birch", + "DBSCAN", + "MeanShift", + "OPTICS", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_linear_least_angle_models(module): + model_classes = ( + "Lars", + "LarsCV", + "LassoLars", + "LassoLarsCV", + "LassoLarsIC", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_feature_selection_models(module): + model_classes = ( + "VarianceThreshold", + "SelectFromModel", + "SequentialFeatureSelector", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_cluster_agglomerative_models(module): + model_classes = ( + "AgglomerativeClustering", + "FeatureAgglomeration", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_linear_GLM_models(module): + model_classes = ( + "PoissonRegressor", + "GammaRegressor", + "TweedieRegressor", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_cluster_clustering_models(module): + model_classes = ( + "SpectralBiclustering", + "SpectralCoclustering", + "SpectralClustering", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_linear_stochastic_gradient_models(module): + model_classes = ( + "SGDClassifier", + "SGDRegressor", + "SGDOneClassSVM", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_linear_ridge_models(module): + model_classes = ( + "Ridge", + "RidgeCV", + "RidgeClassifier", + "RidgeClassifierCV", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_linear_logistic_models(module): + model_classes = ( + "LogisticRegression", + "LogisticRegressionCV", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_linear_OMP_models(module): + model_classes = ( + "OrthogonalMatchingPursuit", + "OrthogonalMatchingPursuitCV", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_linear_passive_aggressive_models(module): + model_classes = ( + "PassiveAggressiveClassifier", + "PassiveAggressiveRegressor", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_linear_bayes_models(module): + model_classes = ( + "ARDRegression", + "BayesianRidge", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_linear_models(module): + model_classes = ( + "HuberRegressor", + "LinearRegression", + "Perceptron", + "QuantileRegressor", + "TheilSenRegressor", + "RANSACRegressor", + ) + _instrument_sklearn_models(module, model_classes) + + +def 
instrument_sklearn_cluster_kmeans_models(module): + model_classes = ( + "BisectingKMeans", + "KMeans", + "MiniBatchKMeans", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_multiclass_models(module): + model_classes = ( + "OneVsRestClassifier", + "OneVsOneClassifier", + "OutputCodeClassifier", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_multioutput_models(module): + model_classes = ( + "MultiOutputEstimator", + "MultiOutputClassifier", + "ClassifierChain", + "RegressorChain", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_naive_bayes_models(module): + model_classes = ( + "GaussianNB", + "MultinomialNB", + "ComplementNB", + "BernoulliNB", + "CategoricalNB", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_model_selection_models(module): + model_classes = ( + "GridSearchCV", + "RandomizedSearchCV", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_mixture_models(module): + model_classes = ( + "GaussianMixture", + "BayesianGaussianMixture", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_neural_network_models(module): + model_classes = ( + "BernoulliRBM", + "MLPClassifier", + "MLPRegressor", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_neighbors_KRadius_models(module): + model_classes = ( + "KNeighborsClassifier", + "RadiusNeighborsClassifier", + "KNeighborsTransformer", + "RadiusNeighborsTransformer", + "KNeighborsRegressor", + "RadiusNeighborsRegressor", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_svm_models(module): + model_classes = ( + "LinearSVC", + "LinearSVR", + "SVC", + "NuSVC", + "SVR", + "NuSVR", + "OneClassSVM", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_semi_supervised_models(module): + model_classes = ( + "LabelPropagation", + "LabelSpreading", + "SelfTrainingClassifier", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_pipeline_models(module): + model_classes = ( + "Pipeline", + "FeatureUnion", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_neighbors_models(module): + model_classes = ( + "KernelDensity", + "LocalOutlierFactor", + "NeighborhoodComponentsAnalysis", + "NearestCentroid", + "NearestNeighbors", + ) + _instrument_sklearn_models(module, model_classes) + + +def instrument_sklearn_metrics(module): + for scorer in METRIC_SCORERS: + if hasattr(module, scorer): + wrap_function_wrapper(module, scorer, wrap_metric_scorer) diff --git a/newrelic/packages/isort/LICENSE b/newrelic/packages/isort/LICENSE new file mode 100644 index 000000000..b5083a50d --- /dev/null +++ b/newrelic/packages/isort/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2013 Timothy Edmund Crosley + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/newrelic/packages/isort/__init__.py b/newrelic/packages/isort/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/newrelic/packages/isort/stdlibs/__init__.py b/newrelic/packages/isort/stdlibs/__init__.py new file mode 100644 index 000000000..3394a7eda --- /dev/null +++ b/newrelic/packages/isort/stdlibs/__init__.py @@ -0,0 +1,2 @@ +from . import all as _all +from . import py2, py3, py27, py36, py37, py38, py39, py310, py311 diff --git a/newrelic/packages/isort/stdlibs/all.py b/newrelic/packages/isort/stdlibs/all.py new file mode 100644 index 000000000..08a365e19 --- /dev/null +++ b/newrelic/packages/isort/stdlibs/all.py @@ -0,0 +1,3 @@ +from . import py2, py3 + +stdlib = py2.stdlib | py3.stdlib diff --git a/newrelic/packages/isort/stdlibs/py2.py b/newrelic/packages/isort/stdlibs/py2.py new file mode 100644 index 000000000..74af019e4 --- /dev/null +++ b/newrelic/packages/isort/stdlibs/py2.py @@ -0,0 +1,3 @@ +from . import py27 + +stdlib = py27.stdlib diff --git a/newrelic/packages/isort/stdlibs/py27.py b/newrelic/packages/isort/stdlibs/py27.py new file mode 100644 index 000000000..a9bc99d0c --- /dev/null +++ b/newrelic/packages/isort/stdlibs/py27.py @@ -0,0 +1,301 @@ +""" +File contains the standard library of Python 2.7. + +DO NOT EDIT. If the standard library changes, a new list should be created +using the mkstdlibs.py script. 
+""" + +stdlib = { + "AL", + "BaseHTTPServer", + "Bastion", + "CGIHTTPServer", + "Carbon", + "ColorPicker", + "ConfigParser", + "Cookie", + "DEVICE", + "DocXMLRPCServer", + "EasyDialogs", + "FL", + "FrameWork", + "GL", + "HTMLParser", + "MacOS", + "MimeWriter", + "MiniAEFrame", + "Nav", + "PixMapWrapper", + "Queue", + "SUNAUDIODEV", + "ScrolledText", + "SimpleHTTPServer", + "SimpleXMLRPCServer", + "SocketServer", + "StringIO", + "Tix", + "Tkinter", + "UserDict", + "UserList", + "UserString", + "W", + "__builtin__", + "_ast", + "_winreg", + "abc", + "aepack", + "aetools", + "aetypes", + "aifc", + "al", + "anydbm", + "applesingle", + "argparse", + "array", + "ast", + "asynchat", + "asyncore", + "atexit", + "audioop", + "autoGIL", + "base64", + "bdb", + "binascii", + "binhex", + "bisect", + "bsddb", + "buildtools", + "bz2", + "cPickle", + "cProfile", + "cStringIO", + "calendar", + "cd", + "cfmfile", + "cgi", + "cgitb", + "chunk", + "cmath", + "cmd", + "code", + "codecs", + "codeop", + "collections", + "colorsys", + "commands", + "compileall", + "compiler", + "contextlib", + "cookielib", + "copy", + "copy_reg", + "crypt", + "csv", + "ctypes", + "curses", + "datetime", + "dbhash", + "dbm", + "decimal", + "difflib", + "dircache", + "dis", + "distutils", + "dl", + "doctest", + "dumbdbm", + "dummy_thread", + "dummy_threading", + "email", + "encodings", + "ensurepip", + "errno", + "exceptions", + "fcntl", + "filecmp", + "fileinput", + "findertools", + "fl", + "flp", + "fm", + "fnmatch", + "formatter", + "fpectl", + "fpformat", + "fractions", + "ftplib", + "functools", + "future_builtins", + "gc", + "gdbm", + "gensuitemodule", + "getopt", + "getpass", + "gettext", + "gl", + "glob", + "grp", + "gzip", + "hashlib", + "heapq", + "hmac", + "hotshot", + "htmlentitydefs", + "htmllib", + "httplib", + "ic", + "icopen", + "imageop", + "imaplib", + "imgfile", + "imghdr", + "imp", + "importlib", + "imputil", + "inspect", + "io", + "itertools", + "jpeg", + "json", + "keyword", + "lib2to3", + "linecache", + "locale", + "logging", + "macerrors", + "macostools", + "macpath", + "macresource", + "mailbox", + "mailcap", + "marshal", + "math", + "md5", + "mhlib", + "mimetools", + "mimetypes", + "mimify", + "mmap", + "modulefinder", + "msilib", + "msvcrt", + "multifile", + "multiprocessing", + "mutex", + "netrc", + "new", + "nis", + "nntplib", + "ntpath", + "numbers", + "operator", + "optparse", + "os", + "ossaudiodev", + "parser", + "pdb", + "pickle", + "pickletools", + "pipes", + "pkgutil", + "platform", + "plistlib", + "popen2", + "poplib", + "posix", + "posixfile", + "posixpath", + "pprint", + "profile", + "pstats", + "pty", + "pwd", + "py_compile", + "pyclbr", + "pydoc", + "quopri", + "random", + "re", + "readline", + "resource", + "rexec", + "rfc822", + "rlcompleter", + "robotparser", + "runpy", + "sched", + "select", + "sets", + "sgmllib", + "sha", + "shelve", + "shlex", + "shutil", + "signal", + "site", + "smtpd", + "smtplib", + "sndhdr", + "socket", + "spwd", + "sqlite3", + "sre", + "sre_compile", + "sre_constants", + "sre_parse", + "ssl", + "stat", + "statvfs", + "string", + "stringprep", + "struct", + "subprocess", + "sunau", + "sunaudiodev", + "symbol", + "symtable", + "sys", + "sysconfig", + "syslog", + "tabnanny", + "tarfile", + "telnetlib", + "tempfile", + "termios", + "test", + "textwrap", + "thread", + "threading", + "time", + "timeit", + "token", + "tokenize", + "trace", + "traceback", + "ttk", + "tty", + "turtle", + "types", + "unicodedata", + "unittest", + "urllib", + "urllib2", + "urlparse", + 
"user", + "uu", + "uuid", + "videoreader", + "warnings", + "wave", + "weakref", + "webbrowser", + "whichdb", + "winsound", + "wsgiref", + "xdrlib", + "xml", + "xmlrpclib", + "zipfile", + "zipimport", + "zlib", +} diff --git a/newrelic/packages/isort/stdlibs/py3.py b/newrelic/packages/isort/stdlibs/py3.py new file mode 100644 index 000000000..988254385 --- /dev/null +++ b/newrelic/packages/isort/stdlibs/py3.py @@ -0,0 +1,3 @@ +from . import py36, py37, py38, py39, py310, py311 + +stdlib = py36.stdlib | py37.stdlib | py38.stdlib | py39.stdlib | py310.stdlib | py311.stdlib diff --git a/newrelic/packages/isort/stdlibs/py310.py b/newrelic/packages/isort/stdlibs/py310.py new file mode 100644 index 000000000..f45cf50a3 --- /dev/null +++ b/newrelic/packages/isort/stdlibs/py310.py @@ -0,0 +1,222 @@ +""" +File contains the standard library of Python 3.10. + +DO NOT EDIT. If the standard library changes, a new list should be created +using the mkstdlibs.py script. +""" + +stdlib = { + "_ast", + "_thread", + "abc", + "aifc", + "argparse", + "array", + "ast", + "asynchat", + "asyncio", + "asyncore", + "atexit", + "audioop", + "base64", + "bdb", + "binascii", + "binhex", + "bisect", + "builtins", + "bz2", + "cProfile", + "calendar", + "cgi", + "cgitb", + "chunk", + "cmath", + "cmd", + "code", + "codecs", + "codeop", + "collections", + "colorsys", + "compileall", + "concurrent", + "configparser", + "contextlib", + "contextvars", + "copy", + "copyreg", + "crypt", + "csv", + "ctypes", + "curses", + "dataclasses", + "datetime", + "dbm", + "decimal", + "difflib", + "dis", + "distutils", + "doctest", + "email", + "encodings", + "ensurepip", + "enum", + "errno", + "faulthandler", + "fcntl", + "filecmp", + "fileinput", + "fnmatch", + "fractions", + "ftplib", + "functools", + "gc", + "getopt", + "getpass", + "gettext", + "glob", + "graphlib", + "grp", + "gzip", + "hashlib", + "heapq", + "hmac", + "html", + "http", + "idlelib", + "imaplib", + "imghdr", + "imp", + "importlib", + "inspect", + "io", + "ipaddress", + "itertools", + "json", + "keyword", + "lib2to3", + "linecache", + "locale", + "logging", + "lzma", + "mailbox", + "mailcap", + "marshal", + "math", + "mimetypes", + "mmap", + "modulefinder", + "msilib", + "msvcrt", + "multiprocessing", + "netrc", + "nis", + "nntplib", + "ntpath", + "numbers", + "operator", + "optparse", + "os", + "ossaudiodev", + "pathlib", + "pdb", + "pickle", + "pickletools", + "pipes", + "pkgutil", + "platform", + "plistlib", + "poplib", + "posix", + "posixpath", + "pprint", + "profile", + "pstats", + "pty", + "pwd", + "py_compile", + "pyclbr", + "pydoc", + "queue", + "quopri", + "random", + "re", + "readline", + "reprlib", + "resource", + "rlcompleter", + "runpy", + "sched", + "secrets", + "select", + "selectors", + "shelve", + "shlex", + "shutil", + "signal", + "site", + "smtpd", + "smtplib", + "sndhdr", + "socket", + "socketserver", + "spwd", + "sqlite3", + "sre", + "sre_compile", + "sre_constants", + "sre_parse", + "ssl", + "stat", + "statistics", + "string", + "stringprep", + "struct", + "subprocess", + "sunau", + "symtable", + "sys", + "sysconfig", + "syslog", + "tabnanny", + "tarfile", + "telnetlib", + "tempfile", + "termios", + "test", + "textwrap", + "threading", + "time", + "timeit", + "tkinter", + "token", + "tokenize", + "trace", + "traceback", + "tracemalloc", + "tty", + "turtle", + "turtledemo", + "types", + "typing", + "unicodedata", + "unittest", + "urllib", + "uu", + "uuid", + "venv", + "warnings", + "wave", + "weakref", + "webbrowser", + "winreg", + "winsound", + 
"wsgiref", + "xdrlib", + "xml", + "xmlrpc", + "zipapp", + "zipfile", + "zipimport", + "zlib", + "zoneinfo", +} diff --git a/newrelic/packages/isort/stdlibs/py311.py b/newrelic/packages/isort/stdlibs/py311.py new file mode 100644 index 000000000..6fa42e995 --- /dev/null +++ b/newrelic/packages/isort/stdlibs/py311.py @@ -0,0 +1,222 @@ +""" +File contains the standard library of Python 3.11. + +DO NOT EDIT. If the standard library changes, a new list should be created +using the mkstdlibs.py script. +""" + +stdlib = { + "_ast", + "_thread", + "abc", + "aifc", + "argparse", + "array", + "ast", + "asynchat", + "asyncio", + "asyncore", + "atexit", + "audioop", + "base64", + "bdb", + "binascii", + "bisect", + "builtins", + "bz2", + "cProfile", + "calendar", + "cgi", + "cgitb", + "chunk", + "cmath", + "cmd", + "code", + "codecs", + "codeop", + "collections", + "colorsys", + "compileall", + "concurrent", + "configparser", + "contextlib", + "contextvars", + "copy", + "copyreg", + "crypt", + "csv", + "ctypes", + "curses", + "dataclasses", + "datetime", + "dbm", + "decimal", + "difflib", + "dis", + "distutils", + "doctest", + "email", + "encodings", + "ensurepip", + "enum", + "errno", + "faulthandler", + "fcntl", + "filecmp", + "fileinput", + "fnmatch", + "fractions", + "ftplib", + "functools", + "gc", + "getopt", + "getpass", + "gettext", + "glob", + "graphlib", + "grp", + "gzip", + "hashlib", + "heapq", + "hmac", + "html", + "http", + "idlelib", + "imaplib", + "imghdr", + "imp", + "importlib", + "inspect", + "io", + "ipaddress", + "itertools", + "json", + "keyword", + "lib2to3", + "linecache", + "locale", + "logging", + "lzma", + "mailbox", + "mailcap", + "marshal", + "math", + "mimetypes", + "mmap", + "modulefinder", + "msilib", + "msvcrt", + "multiprocessing", + "netrc", + "nis", + "nntplib", + "ntpath", + "numbers", + "operator", + "optparse", + "os", + "ossaudiodev", + "pathlib", + "pdb", + "pickle", + "pickletools", + "pipes", + "pkgutil", + "platform", + "plistlib", + "poplib", + "posix", + "posixpath", + "pprint", + "profile", + "pstats", + "pty", + "pwd", + "py_compile", + "pyclbr", + "pydoc", + "queue", + "quopri", + "random", + "re", + "readline", + "reprlib", + "resource", + "rlcompleter", + "runpy", + "sched", + "secrets", + "select", + "selectors", + "shelve", + "shlex", + "shutil", + "signal", + "site", + "smtpd", + "smtplib", + "sndhdr", + "socket", + "socketserver", + "spwd", + "sqlite3", + "sre", + "sre_compile", + "sre_constants", + "sre_parse", + "ssl", + "stat", + "statistics", + "string", + "stringprep", + "struct", + "subprocess", + "sunau", + "symtable", + "sys", + "sysconfig", + "syslog", + "tabnanny", + "tarfile", + "telnetlib", + "tempfile", + "termios", + "test", + "textwrap", + "threading", + "time", + "timeit", + "tkinter", + "token", + "tokenize", + "tomllib", + "trace", + "traceback", + "tracemalloc", + "tty", + "turtle", + "turtledemo", + "types", + "typing", + "unicodedata", + "unittest", + "urllib", + "uu", + "uuid", + "venv", + "warnings", + "wave", + "weakref", + "webbrowser", + "winreg", + "winsound", + "wsgiref", + "xdrlib", + "xml", + "xmlrpc", + "zipapp", + "zipfile", + "zipimport", + "zlib", + "zoneinfo", +} diff --git a/newrelic/packages/isort/stdlibs/py36.py b/newrelic/packages/isort/stdlibs/py36.py new file mode 100644 index 000000000..59ebd24cb --- /dev/null +++ b/newrelic/packages/isort/stdlibs/py36.py @@ -0,0 +1,224 @@ +""" +File contains the standard library of Python 3.6. + +DO NOT EDIT. 
If the standard library changes, a new list should be created +using the mkstdlibs.py script. +""" + +stdlib = { + "_ast", + "_dummy_thread", + "_thread", + "abc", + "aifc", + "argparse", + "array", + "ast", + "asynchat", + "asyncio", + "asyncore", + "atexit", + "audioop", + "base64", + "bdb", + "binascii", + "binhex", + "bisect", + "builtins", + "bz2", + "cProfile", + "calendar", + "cgi", + "cgitb", + "chunk", + "cmath", + "cmd", + "code", + "codecs", + "codeop", + "collections", + "colorsys", + "compileall", + "concurrent", + "configparser", + "contextlib", + "copy", + "copyreg", + "crypt", + "csv", + "ctypes", + "curses", + "datetime", + "dbm", + "decimal", + "difflib", + "dis", + "distutils", + "doctest", + "dummy_threading", + "email", + "encodings", + "ensurepip", + "enum", + "errno", + "faulthandler", + "fcntl", + "filecmp", + "fileinput", + "fnmatch", + "formatter", + "fpectl", + "fractions", + "ftplib", + "functools", + "gc", + "getopt", + "getpass", + "gettext", + "glob", + "grp", + "gzip", + "hashlib", + "heapq", + "hmac", + "html", + "http", + "imaplib", + "imghdr", + "imp", + "importlib", + "inspect", + "io", + "ipaddress", + "itertools", + "json", + "keyword", + "lib2to3", + "linecache", + "locale", + "logging", + "lzma", + "macpath", + "mailbox", + "mailcap", + "marshal", + "math", + "mimetypes", + "mmap", + "modulefinder", + "msilib", + "msvcrt", + "multiprocessing", + "netrc", + "nis", + "nntplib", + "ntpath", + "numbers", + "operator", + "optparse", + "os", + "ossaudiodev", + "parser", + "pathlib", + "pdb", + "pickle", + "pickletools", + "pipes", + "pkgutil", + "platform", + "plistlib", + "poplib", + "posix", + "posixpath", + "pprint", + "profile", + "pstats", + "pty", + "pwd", + "py_compile", + "pyclbr", + "pydoc", + "queue", + "quopri", + "random", + "re", + "readline", + "reprlib", + "resource", + "rlcompleter", + "runpy", + "sched", + "secrets", + "select", + "selectors", + "shelve", + "shlex", + "shutil", + "signal", + "site", + "smtpd", + "smtplib", + "sndhdr", + "socket", + "socketserver", + "spwd", + "sqlite3", + "sre", + "sre_compile", + "sre_constants", + "sre_parse", + "ssl", + "stat", + "statistics", + "string", + "stringprep", + "struct", + "subprocess", + "sunau", + "symbol", + "symtable", + "sys", + "sysconfig", + "syslog", + "tabnanny", + "tarfile", + "telnetlib", + "tempfile", + "termios", + "test", + "textwrap", + "threading", + "time", + "timeit", + "tkinter", + "token", + "tokenize", + "trace", + "traceback", + "tracemalloc", + "tty", + "turtle", + "turtledemo", + "types", + "typing", + "unicodedata", + "unittest", + "urllib", + "uu", + "uuid", + "venv", + "warnings", + "wave", + "weakref", + "webbrowser", + "winreg", + "winsound", + "wsgiref", + "xdrlib", + "xml", + "xmlrpc", + "zipapp", + "zipfile", + "zipimport", + "zlib", +} diff --git a/newrelic/packages/isort/stdlibs/py37.py b/newrelic/packages/isort/stdlibs/py37.py new file mode 100644 index 000000000..e0ad1228a --- /dev/null +++ b/newrelic/packages/isort/stdlibs/py37.py @@ -0,0 +1,225 @@ +""" +File contains the standard library of Python 3.7. + +DO NOT EDIT. If the standard library changes, a new list should be created +using the mkstdlibs.py script. 
+""" + +stdlib = { + "_ast", + "_dummy_thread", + "_thread", + "abc", + "aifc", + "argparse", + "array", + "ast", + "asynchat", + "asyncio", + "asyncore", + "atexit", + "audioop", + "base64", + "bdb", + "binascii", + "binhex", + "bisect", + "builtins", + "bz2", + "cProfile", + "calendar", + "cgi", + "cgitb", + "chunk", + "cmath", + "cmd", + "code", + "codecs", + "codeop", + "collections", + "colorsys", + "compileall", + "concurrent", + "configparser", + "contextlib", + "contextvars", + "copy", + "copyreg", + "crypt", + "csv", + "ctypes", + "curses", + "dataclasses", + "datetime", + "dbm", + "decimal", + "difflib", + "dis", + "distutils", + "doctest", + "dummy_threading", + "email", + "encodings", + "ensurepip", + "enum", + "errno", + "faulthandler", + "fcntl", + "filecmp", + "fileinput", + "fnmatch", + "formatter", + "fractions", + "ftplib", + "functools", + "gc", + "getopt", + "getpass", + "gettext", + "glob", + "grp", + "gzip", + "hashlib", + "heapq", + "hmac", + "html", + "http", + "imaplib", + "imghdr", + "imp", + "importlib", + "inspect", + "io", + "ipaddress", + "itertools", + "json", + "keyword", + "lib2to3", + "linecache", + "locale", + "logging", + "lzma", + "macpath", + "mailbox", + "mailcap", + "marshal", + "math", + "mimetypes", + "mmap", + "modulefinder", + "msilib", + "msvcrt", + "multiprocessing", + "netrc", + "nis", + "nntplib", + "ntpath", + "numbers", + "operator", + "optparse", + "os", + "ossaudiodev", + "parser", + "pathlib", + "pdb", + "pickle", + "pickletools", + "pipes", + "pkgutil", + "platform", + "plistlib", + "poplib", + "posix", + "posixpath", + "pprint", + "profile", + "pstats", + "pty", + "pwd", + "py_compile", + "pyclbr", + "pydoc", + "queue", + "quopri", + "random", + "re", + "readline", + "reprlib", + "resource", + "rlcompleter", + "runpy", + "sched", + "secrets", + "select", + "selectors", + "shelve", + "shlex", + "shutil", + "signal", + "site", + "smtpd", + "smtplib", + "sndhdr", + "socket", + "socketserver", + "spwd", + "sqlite3", + "sre", + "sre_compile", + "sre_constants", + "sre_parse", + "ssl", + "stat", + "statistics", + "string", + "stringprep", + "struct", + "subprocess", + "sunau", + "symbol", + "symtable", + "sys", + "sysconfig", + "syslog", + "tabnanny", + "tarfile", + "telnetlib", + "tempfile", + "termios", + "test", + "textwrap", + "threading", + "time", + "timeit", + "tkinter", + "token", + "tokenize", + "trace", + "traceback", + "tracemalloc", + "tty", + "turtle", + "turtledemo", + "types", + "typing", + "unicodedata", + "unittest", + "urllib", + "uu", + "uuid", + "venv", + "warnings", + "wave", + "weakref", + "webbrowser", + "winreg", + "winsound", + "wsgiref", + "xdrlib", + "xml", + "xmlrpc", + "zipapp", + "zipfile", + "zipimport", + "zlib", +} diff --git a/newrelic/packages/isort/stdlibs/py38.py b/newrelic/packages/isort/stdlibs/py38.py new file mode 100644 index 000000000..3d89fd26b --- /dev/null +++ b/newrelic/packages/isort/stdlibs/py38.py @@ -0,0 +1,224 @@ +""" +File contains the standard library of Python 3.8. + +DO NOT EDIT. If the standard library changes, a new list should be created +using the mkstdlibs.py script. 
+""" + +stdlib = { + "_ast", + "_dummy_thread", + "_thread", + "abc", + "aifc", + "argparse", + "array", + "ast", + "asynchat", + "asyncio", + "asyncore", + "atexit", + "audioop", + "base64", + "bdb", + "binascii", + "binhex", + "bisect", + "builtins", + "bz2", + "cProfile", + "calendar", + "cgi", + "cgitb", + "chunk", + "cmath", + "cmd", + "code", + "codecs", + "codeop", + "collections", + "colorsys", + "compileall", + "concurrent", + "configparser", + "contextlib", + "contextvars", + "copy", + "copyreg", + "crypt", + "csv", + "ctypes", + "curses", + "dataclasses", + "datetime", + "dbm", + "decimal", + "difflib", + "dis", + "distutils", + "doctest", + "dummy_threading", + "email", + "encodings", + "ensurepip", + "enum", + "errno", + "faulthandler", + "fcntl", + "filecmp", + "fileinput", + "fnmatch", + "formatter", + "fractions", + "ftplib", + "functools", + "gc", + "getopt", + "getpass", + "gettext", + "glob", + "grp", + "gzip", + "hashlib", + "heapq", + "hmac", + "html", + "http", + "imaplib", + "imghdr", + "imp", + "importlib", + "inspect", + "io", + "ipaddress", + "itertools", + "json", + "keyword", + "lib2to3", + "linecache", + "locale", + "logging", + "lzma", + "mailbox", + "mailcap", + "marshal", + "math", + "mimetypes", + "mmap", + "modulefinder", + "msilib", + "msvcrt", + "multiprocessing", + "netrc", + "nis", + "nntplib", + "ntpath", + "numbers", + "operator", + "optparse", + "os", + "ossaudiodev", + "parser", + "pathlib", + "pdb", + "pickle", + "pickletools", + "pipes", + "pkgutil", + "platform", + "plistlib", + "poplib", + "posix", + "posixpath", + "pprint", + "profile", + "pstats", + "pty", + "pwd", + "py_compile", + "pyclbr", + "pydoc", + "queue", + "quopri", + "random", + "re", + "readline", + "reprlib", + "resource", + "rlcompleter", + "runpy", + "sched", + "secrets", + "select", + "selectors", + "shelve", + "shlex", + "shutil", + "signal", + "site", + "smtpd", + "smtplib", + "sndhdr", + "socket", + "socketserver", + "spwd", + "sqlite3", + "sre", + "sre_compile", + "sre_constants", + "sre_parse", + "ssl", + "stat", + "statistics", + "string", + "stringprep", + "struct", + "subprocess", + "sunau", + "symbol", + "symtable", + "sys", + "sysconfig", + "syslog", + "tabnanny", + "tarfile", + "telnetlib", + "tempfile", + "termios", + "test", + "textwrap", + "threading", + "time", + "timeit", + "tkinter", + "token", + "tokenize", + "trace", + "traceback", + "tracemalloc", + "tty", + "turtle", + "turtledemo", + "types", + "typing", + "unicodedata", + "unittest", + "urllib", + "uu", + "uuid", + "venv", + "warnings", + "wave", + "weakref", + "webbrowser", + "winreg", + "winsound", + "wsgiref", + "xdrlib", + "xml", + "xmlrpc", + "zipapp", + "zipfile", + "zipimport", + "zlib", +} diff --git a/newrelic/packages/isort/stdlibs/py39.py b/newrelic/packages/isort/stdlibs/py39.py new file mode 100644 index 000000000..4b7dd5954 --- /dev/null +++ b/newrelic/packages/isort/stdlibs/py39.py @@ -0,0 +1,224 @@ +""" +File contains the standard library of Python 3.9. + +DO NOT EDIT. If the standard library changes, a new list should be created +using the mkstdlibs.py script. 
+""" + +stdlib = { + "_ast", + "_thread", + "abc", + "aifc", + "argparse", + "array", + "ast", + "asynchat", + "asyncio", + "asyncore", + "atexit", + "audioop", + "base64", + "bdb", + "binascii", + "binhex", + "bisect", + "builtins", + "bz2", + "cProfile", + "calendar", + "cgi", + "cgitb", + "chunk", + "cmath", + "cmd", + "code", + "codecs", + "codeop", + "collections", + "colorsys", + "compileall", + "concurrent", + "configparser", + "contextlib", + "contextvars", + "copy", + "copyreg", + "crypt", + "csv", + "ctypes", + "curses", + "dataclasses", + "datetime", + "dbm", + "decimal", + "difflib", + "dis", + "distutils", + "doctest", + "email", + "encodings", + "ensurepip", + "enum", + "errno", + "faulthandler", + "fcntl", + "filecmp", + "fileinput", + "fnmatch", + "formatter", + "fractions", + "ftplib", + "functools", + "gc", + "getopt", + "getpass", + "gettext", + "glob", + "graphlib", + "grp", + "gzip", + "hashlib", + "heapq", + "hmac", + "html", + "http", + "imaplib", + "imghdr", + "imp", + "importlib", + "inspect", + "io", + "ipaddress", + "itertools", + "json", + "keyword", + "lib2to3", + "linecache", + "locale", + "logging", + "lzma", + "mailbox", + "mailcap", + "marshal", + "math", + "mimetypes", + "mmap", + "modulefinder", + "msilib", + "msvcrt", + "multiprocessing", + "netrc", + "nis", + "nntplib", + "ntpath", + "numbers", + "operator", + "optparse", + "os", + "ossaudiodev", + "parser", + "pathlib", + "pdb", + "pickle", + "pickletools", + "pipes", + "pkgutil", + "platform", + "plistlib", + "poplib", + "posix", + "posixpath", + "pprint", + "profile", + "pstats", + "pty", + "pwd", + "py_compile", + "pyclbr", + "pydoc", + "queue", + "quopri", + "random", + "re", + "readline", + "reprlib", + "resource", + "rlcompleter", + "runpy", + "sched", + "secrets", + "select", + "selectors", + "shelve", + "shlex", + "shutil", + "signal", + "site", + "smtpd", + "smtplib", + "sndhdr", + "socket", + "socketserver", + "spwd", + "sqlite3", + "sre", + "sre_compile", + "sre_constants", + "sre_parse", + "ssl", + "stat", + "statistics", + "string", + "stringprep", + "struct", + "subprocess", + "sunau", + "symbol", + "symtable", + "sys", + "sysconfig", + "syslog", + "tabnanny", + "tarfile", + "telnetlib", + "tempfile", + "termios", + "test", + "textwrap", + "threading", + "time", + "timeit", + "tkinter", + "token", + "tokenize", + "trace", + "traceback", + "tracemalloc", + "tty", + "turtle", + "turtledemo", + "types", + "typing", + "unicodedata", + "unittest", + "urllib", + "uu", + "uuid", + "venv", + "warnings", + "wave", + "weakref", + "webbrowser", + "winreg", + "winsound", + "wsgiref", + "xdrlib", + "xml", + "xmlrpc", + "zipapp", + "zipfile", + "zipimport", + "zlib", + "zoneinfo", +} diff --git a/newrelic/packages/opentelemetry_proto/LICENSE.txt b/newrelic/packages/opentelemetry_proto/LICENSE.txt new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/newrelic/packages/opentelemetry_proto/LICENSE.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/newrelic/packages/opentelemetry_proto/__init__.py b/newrelic/packages/opentelemetry_proto/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/newrelic/packages/opentelemetry_proto/common_pb2.py b/newrelic/packages/opentelemetry_proto/common_pb2.py new file mode 100644 index 000000000..a38431a58 --- /dev/null +++ b/newrelic/packages/opentelemetry_proto/common_pb2.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: opentelemetry/proto/common/v1/common.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n*opentelemetry/proto/common/v1/common.proto\x12\x1dopentelemetry.proto.common.v1\"\x8c\x02\n\x08\x41nyValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x14\n\nbool_value\x18\x02 \x01(\x08H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x03H\x00\x12\x16\n\x0c\x64ouble_value\x18\x04 \x01(\x01H\x00\x12@\n\x0b\x61rray_value\x18\x05 \x01(\x0b\x32).opentelemetry.proto.common.v1.ArrayValueH\x00\x12\x43\n\x0ckvlist_value\x18\x06 \x01(\x0b\x32+.opentelemetry.proto.common.v1.KeyValueListH\x00\x12\x15\n\x0b\x62ytes_value\x18\x07 \x01(\x0cH\x00\x42\x07\n\x05value\"E\n\nArrayValue\x12\x37\n\x06values\x18\x01 \x03(\x0b\x32\'.opentelemetry.proto.common.v1.AnyValue\"G\n\x0cKeyValueList\x12\x37\n\x06values\x18\x01 \x03(\x0b\x32\'.opentelemetry.proto.common.v1.KeyValue\"O\n\x08KeyValue\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x36\n\x05value\x18\x02 \x01(\x0b\x32\'.opentelemetry.proto.common.v1.AnyValue\";\n\x16InstrumentationLibrary\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t:\x02\x18\x01\"5\n\x14InstrumentationScope\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\tB[\n io.opentelemetry.proto.common.v1B\x0b\x43ommonProtoP\x01Z(go.opentelemetry.io/proto/otlp/common/v1b\x06proto3') + + + +_ANYVALUE = DESCRIPTOR.message_types_by_name['AnyValue'] +_ARRAYVALUE = DESCRIPTOR.message_types_by_name['ArrayValue'] +_KEYVALUELIST = DESCRIPTOR.message_types_by_name['KeyValueList'] +_KEYVALUE = DESCRIPTOR.message_types_by_name['KeyValue'] +_INSTRUMENTATIONLIBRARY = DESCRIPTOR.message_types_by_name['InstrumentationLibrary'] +_INSTRUMENTATIONSCOPE = DESCRIPTOR.message_types_by_name['InstrumentationScope'] +AnyValue = _reflection.GeneratedProtocolMessageType('AnyValue', (_message.Message,), { + 'DESCRIPTOR' : _ANYVALUE, + '__module__' : 'opentelemetry.proto.common.v1.common_pb2' + # @@protoc_insertion_point(class_scope:opentelemetry.proto.common.v1.AnyValue) + }) +_sym_db.RegisterMessage(AnyValue) + +ArrayValue = _reflection.GeneratedProtocolMessageType('ArrayValue', (_message.Message,), { + 'DESCRIPTOR' : _ARRAYVALUE, + '__module__' : 'opentelemetry.proto.common.v1.common_pb2' + # @@protoc_insertion_point(class_scope:opentelemetry.proto.common.v1.ArrayValue) + }) +_sym_db.RegisterMessage(ArrayValue) + +KeyValueList = _reflection.GeneratedProtocolMessageType('KeyValueList', (_message.Message,), { + 'DESCRIPTOR' : _KEYVALUELIST, + '__module__' : 'opentelemetry.proto.common.v1.common_pb2' + # 
@@protoc_insertion_point(class_scope:opentelemetry.proto.common.v1.KeyValueList) + }) +_sym_db.RegisterMessage(KeyValueList) + +KeyValue = _reflection.GeneratedProtocolMessageType('KeyValue', (_message.Message,), { + 'DESCRIPTOR' : _KEYVALUE, + '__module__' : 'opentelemetry.proto.common.v1.common_pb2' + # @@protoc_insertion_point(class_scope:opentelemetry.proto.common.v1.KeyValue) + }) +_sym_db.RegisterMessage(KeyValue) + +InstrumentationLibrary = _reflection.GeneratedProtocolMessageType('InstrumentationLibrary', (_message.Message,), { + 'DESCRIPTOR' : _INSTRUMENTATIONLIBRARY, + '__module__' : 'opentelemetry.proto.common.v1.common_pb2' + # @@protoc_insertion_point(class_scope:opentelemetry.proto.common.v1.InstrumentationLibrary) + }) +_sym_db.RegisterMessage(InstrumentationLibrary) + +InstrumentationScope = _reflection.GeneratedProtocolMessageType('InstrumentationScope', (_message.Message,), { + 'DESCRIPTOR' : _INSTRUMENTATIONSCOPE, + '__module__' : 'opentelemetry.proto.common.v1.common_pb2' + # @@protoc_insertion_point(class_scope:opentelemetry.proto.common.v1.InstrumentationScope) + }) +_sym_db.RegisterMessage(InstrumentationScope) + +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n io.opentelemetry.proto.common.v1B\013CommonProtoP\001Z(go.opentelemetry.io/proto/otlp/common/v1' + _INSTRUMENTATIONLIBRARY._options = None + _INSTRUMENTATIONLIBRARY._serialized_options = b'\030\001' + _ANYVALUE._serialized_start=78 + _ANYVALUE._serialized_end=346 + _ARRAYVALUE._serialized_start=348 + _ARRAYVALUE._serialized_end=417 + _KEYVALUELIST._serialized_start=419 + _KEYVALUELIST._serialized_end=490 + _KEYVALUE._serialized_start=492 + _KEYVALUE._serialized_end=571 + _INSTRUMENTATIONLIBRARY._serialized_start=573 + _INSTRUMENTATIONLIBRARY._serialized_end=632 + _INSTRUMENTATIONSCOPE._serialized_start=634 + _INSTRUMENTATIONSCOPE._serialized_end=687 +# @@protoc_insertion_point(module_scope) diff --git a/newrelic/packages/opentelemetry_proto/logs_pb2.py b/newrelic/packages/opentelemetry_proto/logs_pb2.py new file mode 100644 index 000000000..bb6a55d66 --- /dev/null +++ b/newrelic/packages/opentelemetry_proto/logs_pb2.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: opentelemetry/proto/logs/v1/logs.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from . import common_pb2 as opentelemetry_dot_proto_dot_common_dot_v1_dot_common__pb2 +from . 
import resource_pb2 as opentelemetry_dot_proto_dot_resource_dot_v1_dot_resource__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&opentelemetry/proto/logs/v1/logs.proto\x12\x1bopentelemetry.proto.logs.v1\x1a*opentelemetry/proto/common/v1/common.proto\x1a.opentelemetry/proto/resource/v1/resource.proto\"L\n\x08LogsData\x12@\n\rresource_logs\x18\x01 \x03(\x0b\x32).opentelemetry.proto.logs.v1.ResourceLogs\"\xff\x01\n\x0cResourceLogs\x12;\n\x08resource\x18\x01 \x01(\x0b\x32).opentelemetry.proto.resource.v1.Resource\x12:\n\nscope_logs\x18\x02 \x03(\x0b\x32&.opentelemetry.proto.logs.v1.ScopeLogs\x12\x62\n\x1cinstrumentation_library_logs\x18\xe8\x07 \x03(\x0b\x32\x37.opentelemetry.proto.logs.v1.InstrumentationLibraryLogsB\x02\x18\x01\x12\x12\n\nschema_url\x18\x03 \x01(\t\"\xa0\x01\n\tScopeLogs\x12\x42\n\x05scope\x18\x01 \x01(\x0b\x32\x33.opentelemetry.proto.common.v1.InstrumentationScope\x12;\n\x0blog_records\x18\x02 \x03(\x0b\x32&.opentelemetry.proto.logs.v1.LogRecord\x12\x12\n\nschema_url\x18\x03 \x01(\t\"\xc9\x01\n\x1aInstrumentationLibraryLogs\x12V\n\x17instrumentation_library\x18\x01 \x01(\x0b\x32\x35.opentelemetry.proto.common.v1.InstrumentationLibrary\x12;\n\x0blog_records\x18\x02 \x03(\x0b\x32&.opentelemetry.proto.logs.v1.LogRecord\x12\x12\n\nschema_url\x18\x03 \x01(\t:\x02\x18\x01\"\xef\x02\n\tLogRecord\x12\x16\n\x0etime_unix_nano\x18\x01 \x01(\x06\x12\x1f\n\x17observed_time_unix_nano\x18\x0b \x01(\x06\x12\x44\n\x0fseverity_number\x18\x02 \x01(\x0e\x32+.opentelemetry.proto.logs.v1.SeverityNumber\x12\x15\n\rseverity_text\x18\x03 \x01(\t\x12\x35\n\x04\x62ody\x18\x05 \x01(\x0b\x32\'.opentelemetry.proto.common.v1.AnyValue\x12;\n\nattributes\x18\x06 \x03(\x0b\x32\'.opentelemetry.proto.common.v1.KeyValue\x12 \n\x18\x64ropped_attributes_count\x18\x07 \x01(\r\x12\r\n\x05\x66lags\x18\x08 \x01(\x07\x12\x10\n\x08trace_id\x18\t \x01(\x0c\x12\x0f\n\x07span_id\x18\n \x01(\x0cJ\x04\x08\x04\x10\x05*\xc3\x05\n\x0eSeverityNumber\x12\x1f\n\x1bSEVERITY_NUMBER_UNSPECIFIED\x10\x00\x12\x19\n\x15SEVERITY_NUMBER_TRACE\x10\x01\x12\x1a\n\x16SEVERITY_NUMBER_TRACE2\x10\x02\x12\x1a\n\x16SEVERITY_NUMBER_TRACE3\x10\x03\x12\x1a\n\x16SEVERITY_NUMBER_TRACE4\x10\x04\x12\x19\n\x15SEVERITY_NUMBER_DEBUG\x10\x05\x12\x1a\n\x16SEVERITY_NUMBER_DEBUG2\x10\x06\x12\x1a\n\x16SEVERITY_NUMBER_DEBUG3\x10\x07\x12\x1a\n\x16SEVERITY_NUMBER_DEBUG4\x10\x08\x12\x18\n\x14SEVERITY_NUMBER_INFO\x10\t\x12\x19\n\x15SEVERITY_NUMBER_INFO2\x10\n\x12\x19\n\x15SEVERITY_NUMBER_INFO3\x10\x0b\x12\x19\n\x15SEVERITY_NUMBER_INFO4\x10\x0c\x12\x18\n\x14SEVERITY_NUMBER_WARN\x10\r\x12\x19\n\x15SEVERITY_NUMBER_WARN2\x10\x0e\x12\x19\n\x15SEVERITY_NUMBER_WARN3\x10\x0f\x12\x19\n\x15SEVERITY_NUMBER_WARN4\x10\x10\x12\x19\n\x15SEVERITY_NUMBER_ERROR\x10\x11\x12\x1a\n\x16SEVERITY_NUMBER_ERROR2\x10\x12\x12\x1a\n\x16SEVERITY_NUMBER_ERROR3\x10\x13\x12\x1a\n\x16SEVERITY_NUMBER_ERROR4\x10\x14\x12\x19\n\x15SEVERITY_NUMBER_FATAL\x10\x15\x12\x1a\n\x16SEVERITY_NUMBER_FATAL2\x10\x16\x12\x1a\n\x16SEVERITY_NUMBER_FATAL3\x10\x17\x12\x1a\n\x16SEVERITY_NUMBER_FATAL4\x10\x18*X\n\x0eLogRecordFlags\x12\x1f\n\x1bLOG_RECORD_FLAG_UNSPECIFIED\x10\x00\x12%\n LOG_RECORD_FLAG_TRACE_FLAGS_MASK\x10\xff\x01\x42U\n\x1eio.opentelemetry.proto.logs.v1B\tLogsProtoP\x01Z&go.opentelemetry.io/proto/otlp/logs/v1b\x06proto3') + +_SEVERITYNUMBER = DESCRIPTOR.enum_types_by_name['SeverityNumber'] +SeverityNumber = enum_type_wrapper.EnumTypeWrapper(_SEVERITYNUMBER) +_LOGRECORDFLAGS = DESCRIPTOR.enum_types_by_name['LogRecordFlags'] +LogRecordFlags = 
enum_type_wrapper.EnumTypeWrapper(_LOGRECORDFLAGS) +SEVERITY_NUMBER_UNSPECIFIED = 0 +SEVERITY_NUMBER_TRACE = 1 +SEVERITY_NUMBER_TRACE2 = 2 +SEVERITY_NUMBER_TRACE3 = 3 +SEVERITY_NUMBER_TRACE4 = 4 +SEVERITY_NUMBER_DEBUG = 5 +SEVERITY_NUMBER_DEBUG2 = 6 +SEVERITY_NUMBER_DEBUG3 = 7 +SEVERITY_NUMBER_DEBUG4 = 8 +SEVERITY_NUMBER_INFO = 9 +SEVERITY_NUMBER_INFO2 = 10 +SEVERITY_NUMBER_INFO3 = 11 +SEVERITY_NUMBER_INFO4 = 12 +SEVERITY_NUMBER_WARN = 13 +SEVERITY_NUMBER_WARN2 = 14 +SEVERITY_NUMBER_WARN3 = 15 +SEVERITY_NUMBER_WARN4 = 16 +SEVERITY_NUMBER_ERROR = 17 +SEVERITY_NUMBER_ERROR2 = 18 +SEVERITY_NUMBER_ERROR3 = 19 +SEVERITY_NUMBER_ERROR4 = 20 +SEVERITY_NUMBER_FATAL = 21 +SEVERITY_NUMBER_FATAL2 = 22 +SEVERITY_NUMBER_FATAL3 = 23 +SEVERITY_NUMBER_FATAL4 = 24 +LOG_RECORD_FLAG_UNSPECIFIED = 0 +LOG_RECORD_FLAG_TRACE_FLAGS_MASK = 255 + + +_LOGSDATA = DESCRIPTOR.message_types_by_name['LogsData'] +_RESOURCELOGS = DESCRIPTOR.message_types_by_name['ResourceLogs'] +_SCOPELOGS = DESCRIPTOR.message_types_by_name['ScopeLogs'] +_INSTRUMENTATIONLIBRARYLOGS = DESCRIPTOR.message_types_by_name['InstrumentationLibraryLogs'] +_LOGRECORD = DESCRIPTOR.message_types_by_name['LogRecord'] +LogsData = _reflection.GeneratedProtocolMessageType('LogsData', (_message.Message,), { + 'DESCRIPTOR' : _LOGSDATA, + '__module__' : 'opentelemetry.proto.logs.v1.logs_pb2' + # @@protoc_insertion_point(class_scope:opentelemetry.proto.logs.v1.LogsData) + }) +_sym_db.RegisterMessage(LogsData) + +ResourceLogs = _reflection.GeneratedProtocolMessageType('ResourceLogs', (_message.Message,), { + 'DESCRIPTOR' : _RESOURCELOGS, + '__module__' : 'opentelemetry.proto.logs.v1.logs_pb2' + # @@protoc_insertion_point(class_scope:opentelemetry.proto.logs.v1.ResourceLogs) + }) +_sym_db.RegisterMessage(ResourceLogs) + +ScopeLogs = _reflection.GeneratedProtocolMessageType('ScopeLogs', (_message.Message,), { + 'DESCRIPTOR' : _SCOPELOGS, + '__module__' : 'opentelemetry.proto.logs.v1.logs_pb2' + # @@protoc_insertion_point(class_scope:opentelemetry.proto.logs.v1.ScopeLogs) + }) +_sym_db.RegisterMessage(ScopeLogs) + +InstrumentationLibraryLogs = _reflection.GeneratedProtocolMessageType('InstrumentationLibraryLogs', (_message.Message,), { + 'DESCRIPTOR' : _INSTRUMENTATIONLIBRARYLOGS, + '__module__' : 'opentelemetry.proto.logs.v1.logs_pb2' + # @@protoc_insertion_point(class_scope:opentelemetry.proto.logs.v1.InstrumentationLibraryLogs) + }) +_sym_db.RegisterMessage(InstrumentationLibraryLogs) + +LogRecord = _reflection.GeneratedProtocolMessageType('LogRecord', (_message.Message,), { + 'DESCRIPTOR' : _LOGRECORD, + '__module__' : 'opentelemetry.proto.logs.v1.logs_pb2' + # @@protoc_insertion_point(class_scope:opentelemetry.proto.logs.v1.LogRecord) + }) +_sym_db.RegisterMessage(LogRecord) + +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\036io.opentelemetry.proto.logs.v1B\tLogsProtoP\001Z&go.opentelemetry.io/proto/otlp/logs/v1' + _RESOURCELOGS.fields_by_name['instrumentation_library_logs']._options = None + _RESOURCELOGS.fields_by_name['instrumentation_library_logs']._serialized_options = b'\030\001' + _INSTRUMENTATIONLIBRARYLOGS._options = None + _INSTRUMENTATIONLIBRARYLOGS._serialized_options = b'\030\001' + _SEVERITYNUMBER._serialized_start=1237 + _SEVERITYNUMBER._serialized_end=1944 + _LOGRECORDFLAGS._serialized_start=1946 + _LOGRECORDFLAGS._serialized_end=2034 + _LOGSDATA._serialized_start=163 + _LOGSDATA._serialized_end=239 + _RESOURCELOGS._serialized_start=242 + _RESOURCELOGS._serialized_end=497 + 
_SCOPELOGS._serialized_start=500 + _SCOPELOGS._serialized_end=660 + _INSTRUMENTATIONLIBRARYLOGS._serialized_start=663 + _INSTRUMENTATIONLIBRARYLOGS._serialized_end=864 + _LOGRECORD._serialized_start=867 + _LOGRECORD._serialized_end=1234 +# @@protoc_insertion_point(module_scope) diff --git a/newrelic/packages/opentelemetry_proto/metrics_pb2.py b/newrelic/packages/opentelemetry_proto/metrics_pb2.py new file mode 100644 index 000000000..dea77c7de --- /dev/null +++ b/newrelic/packages/opentelemetry_proto/metrics_pb2.py @@ -0,0 +1,217 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: opentelemetry/proto/metrics/v1/metrics.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from . import common_pb2 as opentelemetry_dot_proto_dot_common_dot_v1_dot_common__pb2 +from . import resource_pb2 as opentelemetry_dot_proto_dot_resource_dot_v1_dot_resource__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n,opentelemetry/proto/metrics/v1/metrics.proto\x12\x1eopentelemetry.proto.metrics.v1\x1a*opentelemetry/proto/common/v1/common.proto\x1a.opentelemetry/proto/resource/v1/resource.proto\"X\n\x0bMetricsData\x12I\n\x10resource_metrics\x18\x01 \x03(\x0b\x32/.opentelemetry.proto.metrics.v1.ResourceMetrics\"\x94\x02\n\x0fResourceMetrics\x12;\n\x08resource\x18\x01 \x01(\x0b\x32).opentelemetry.proto.resource.v1.Resource\x12\x43\n\rscope_metrics\x18\x02 \x03(\x0b\x32,.opentelemetry.proto.metrics.v1.ScopeMetrics\x12k\n\x1finstrumentation_library_metrics\x18\xe8\x07 \x03(\x0b\x32=.opentelemetry.proto.metrics.v1.InstrumentationLibraryMetricsB\x02\x18\x01\x12\x12\n\nschema_url\x18\x03 \x01(\t\"\x9f\x01\n\x0cScopeMetrics\x12\x42\n\x05scope\x18\x01 \x01(\x0b\x32\x33.opentelemetry.proto.common.v1.InstrumentationScope\x12\x37\n\x07metrics\x18\x02 \x03(\x0b\x32&.opentelemetry.proto.metrics.v1.Metric\x12\x12\n\nschema_url\x18\x03 \x01(\t\"\xc8\x01\n\x1dInstrumentationLibraryMetrics\x12V\n\x17instrumentation_library\x18\x01 \x01(\x0b\x32\x35.opentelemetry.proto.common.v1.InstrumentationLibrary\x12\x37\n\x07metrics\x18\x02 \x03(\x0b\x32&.opentelemetry.proto.metrics.v1.Metric\x12\x12\n\nschema_url\x18\x03 \x01(\t:\x02\x18\x01\"\x92\x03\n\x06Metric\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x0c\n\x04unit\x18\x03 \x01(\t\x12\x36\n\x05gauge\x18\x05 \x01(\x0b\x32%.opentelemetry.proto.metrics.v1.GaugeH\x00\x12\x32\n\x03sum\x18\x07 \x01(\x0b\x32#.opentelemetry.proto.metrics.v1.SumH\x00\x12>\n\thistogram\x18\t \x01(\x0b\x32).opentelemetry.proto.metrics.v1.HistogramH\x00\x12U\n\x15\x65xponential_histogram\x18\n \x01(\x0b\x32\x34.opentelemetry.proto.metrics.v1.ExponentialHistogramH\x00\x12:\n\x07summary\x18\x0b \x01(\x0b\x32\'.opentelemetry.proto.metrics.v1.SummaryH\x00\x42\x06\n\x04\x64\x61taJ\x04\x08\x04\x10\x05J\x04\x08\x06\x10\x07J\x04\x08\x08\x10\t\"M\n\x05Gauge\x12\x44\n\x0b\x64\x61ta_points\x18\x01 \x03(\x0b\x32/.opentelemetry.proto.metrics.v1.NumberDataPoint\"\xba\x01\n\x03Sum\x12\x44\n\x0b\x64\x61ta_points\x18\x01 
\x03(\x0b\x32/.opentelemetry.proto.metrics.v1.NumberDataPoint\x12W\n\x17\x61ggregation_temporality\x18\x02 \x01(\x0e\x32\x36.opentelemetry.proto.metrics.v1.AggregationTemporality\x12\x14\n\x0cis_monotonic\x18\x03 \x01(\x08\"\xad\x01\n\tHistogram\x12G\n\x0b\x64\x61ta_points\x18\x01 \x03(\x0b\x32\x32.opentelemetry.proto.metrics.v1.HistogramDataPoint\x12W\n\x17\x61ggregation_temporality\x18\x02 \x01(\x0e\x32\x36.opentelemetry.proto.metrics.v1.AggregationTemporality\"\xc3\x01\n\x14\x45xponentialHistogram\x12R\n\x0b\x64\x61ta_points\x18\x01 \x03(\x0b\x32=.opentelemetry.proto.metrics.v1.ExponentialHistogramDataPoint\x12W\n\x17\x61ggregation_temporality\x18\x02 \x01(\x0e\x32\x36.opentelemetry.proto.metrics.v1.AggregationTemporality\"P\n\x07Summary\x12\x45\n\x0b\x64\x61ta_points\x18\x01 \x03(\x0b\x32\x30.opentelemetry.proto.metrics.v1.SummaryDataPoint\"\x86\x02\n\x0fNumberDataPoint\x12;\n\nattributes\x18\x07 \x03(\x0b\x32\'.opentelemetry.proto.common.v1.KeyValue\x12\x1c\n\x14start_time_unix_nano\x18\x02 \x01(\x06\x12\x16\n\x0etime_unix_nano\x18\x03 \x01(\x06\x12\x13\n\tas_double\x18\x04 \x01(\x01H\x00\x12\x10\n\x06\x61s_int\x18\x06 \x01(\x10H\x00\x12;\n\texemplars\x18\x05 \x03(\x0b\x32(.opentelemetry.proto.metrics.v1.Exemplar\x12\r\n\x05\x66lags\x18\x08 \x01(\rB\x07\n\x05valueJ\x04\x08\x01\x10\x02\"\xe6\x02\n\x12HistogramDataPoint\x12;\n\nattributes\x18\t \x03(\x0b\x32\'.opentelemetry.proto.common.v1.KeyValue\x12\x1c\n\x14start_time_unix_nano\x18\x02 \x01(\x06\x12\x16\n\x0etime_unix_nano\x18\x03 \x01(\x06\x12\r\n\x05\x63ount\x18\x04 \x01(\x06\x12\x10\n\x03sum\x18\x05 \x01(\x01H\x00\x88\x01\x01\x12\x15\n\rbucket_counts\x18\x06 \x03(\x06\x12\x17\n\x0f\x65xplicit_bounds\x18\x07 \x03(\x01\x12;\n\texemplars\x18\x08 \x03(\x0b\x32(.opentelemetry.proto.metrics.v1.Exemplar\x12\r\n\x05\x66lags\x18\n \x01(\r\x12\x10\n\x03min\x18\x0b \x01(\x01H\x01\x88\x01\x01\x12\x10\n\x03max\x18\x0c \x01(\x01H\x02\x88\x01\x01\x42\x06\n\x04_sumB\x06\n\x04_minB\x06\n\x04_maxJ\x04\x08\x01\x10\x02\"\xb5\x04\n\x1d\x45xponentialHistogramDataPoint\x12;\n\nattributes\x18\x01 \x03(\x0b\x32\'.opentelemetry.proto.common.v1.KeyValue\x12\x1c\n\x14start_time_unix_nano\x18\x02 \x01(\x06\x12\x16\n\x0etime_unix_nano\x18\x03 \x01(\x06\x12\r\n\x05\x63ount\x18\x04 \x01(\x06\x12\x0b\n\x03sum\x18\x05 \x01(\x01\x12\r\n\x05scale\x18\x06 \x01(\x11\x12\x12\n\nzero_count\x18\x07 \x01(\x06\x12W\n\x08positive\x18\x08 \x01(\x0b\x32\x45.opentelemetry.proto.metrics.v1.ExponentialHistogramDataPoint.Buckets\x12W\n\x08negative\x18\t \x01(\x0b\x32\x45.opentelemetry.proto.metrics.v1.ExponentialHistogramDataPoint.Buckets\x12\r\n\x05\x66lags\x18\n \x01(\r\x12;\n\texemplars\x18\x0b \x03(\x0b\x32(.opentelemetry.proto.metrics.v1.Exemplar\x12\x10\n\x03min\x18\x0c \x01(\x01H\x00\x88\x01\x01\x12\x10\n\x03max\x18\r \x01(\x01H\x01\x88\x01\x01\x1a\x30\n\x07\x42uckets\x12\x0e\n\x06offset\x18\x01 \x01(\x11\x12\x15\n\rbucket_counts\x18\x02 \x03(\x04\x42\x06\n\x04_minB\x06\n\x04_max\"\xc5\x02\n\x10SummaryDataPoint\x12;\n\nattributes\x18\x07 \x03(\x0b\x32\'.opentelemetry.proto.common.v1.KeyValue\x12\x1c\n\x14start_time_unix_nano\x18\x02 \x01(\x06\x12\x16\n\x0etime_unix_nano\x18\x03 \x01(\x06\x12\r\n\x05\x63ount\x18\x04 \x01(\x06\x12\x0b\n\x03sum\x18\x05 \x01(\x01\x12Y\n\x0fquantile_values\x18\x06 \x03(\x0b\x32@.opentelemetry.proto.metrics.v1.SummaryDataPoint.ValueAtQuantile\x12\r\n\x05\x66lags\x18\x08 \x01(\r\x1a\x32\n\x0fValueAtQuantile\x12\x10\n\x08quantile\x18\x01 \x01(\x01\x12\r\n\x05value\x18\x02 
\x01(\x01J\x04\x08\x01\x10\x02\"\xc1\x01\n\x08\x45xemplar\x12\x44\n\x13\x66iltered_attributes\x18\x07 \x03(\x0b\x32\'.opentelemetry.proto.common.v1.KeyValue\x12\x16\n\x0etime_unix_nano\x18\x02 \x01(\x06\x12\x13\n\tas_double\x18\x03 \x01(\x01H\x00\x12\x10\n\x06\x61s_int\x18\x06 \x01(\x10H\x00\x12\x0f\n\x07span_id\x18\x04 \x01(\x0c\x12\x10\n\x08trace_id\x18\x05 \x01(\x0c\x42\x07\n\x05valueJ\x04\x08\x01\x10\x02*\x8c\x01\n\x16\x41ggregationTemporality\x12\'\n#AGGREGATION_TEMPORALITY_UNSPECIFIED\x10\x00\x12!\n\x1d\x41GGREGATION_TEMPORALITY_DELTA\x10\x01\x12&\n\"AGGREGATION_TEMPORALITY_CUMULATIVE\x10\x02*;\n\x0e\x44\x61taPointFlags\x12\r\n\tFLAG_NONE\x10\x00\x12\x1a\n\x16\x46LAG_NO_RECORDED_VALUE\x10\x01\x42^\n!io.opentelemetry.proto.metrics.v1B\x0cMetricsProtoP\x01Z)go.opentelemetry.io/proto/otlp/metrics/v1b\x06proto3') + +_AGGREGATIONTEMPORALITY = DESCRIPTOR.enum_types_by_name['AggregationTemporality'] +AggregationTemporality = enum_type_wrapper.EnumTypeWrapper(_AGGREGATIONTEMPORALITY) +_DATAPOINTFLAGS = DESCRIPTOR.enum_types_by_name['DataPointFlags'] +DataPointFlags = enum_type_wrapper.EnumTypeWrapper(_DATAPOINTFLAGS) +AGGREGATION_TEMPORALITY_UNSPECIFIED = 0 +AGGREGATION_TEMPORALITY_DELTA = 1 +AGGREGATION_TEMPORALITY_CUMULATIVE = 2 +FLAG_NONE = 0 +FLAG_NO_RECORDED_VALUE = 1 + + +_METRICSDATA = DESCRIPTOR.message_types_by_name['MetricsData'] +_RESOURCEMETRICS = DESCRIPTOR.message_types_by_name['ResourceMetrics'] +_SCOPEMETRICS = DESCRIPTOR.message_types_by_name['ScopeMetrics'] +_INSTRUMENTATIONLIBRARYMETRICS = DESCRIPTOR.message_types_by_name['InstrumentationLibraryMetrics'] +_METRIC = DESCRIPTOR.message_types_by_name['Metric'] +_GAUGE = DESCRIPTOR.message_types_by_name['Gauge'] +_SUM = DESCRIPTOR.message_types_by_name['Sum'] +_HISTOGRAM = DESCRIPTOR.message_types_by_name['Histogram'] +_EXPONENTIALHISTOGRAM = DESCRIPTOR.message_types_by_name['ExponentialHistogram'] +_SUMMARY = DESCRIPTOR.message_types_by_name['Summary'] +_NUMBERDATAPOINT = DESCRIPTOR.message_types_by_name['NumberDataPoint'] +_HISTOGRAMDATAPOINT = DESCRIPTOR.message_types_by_name['HistogramDataPoint'] +_EXPONENTIALHISTOGRAMDATAPOINT = DESCRIPTOR.message_types_by_name['ExponentialHistogramDataPoint'] +_EXPONENTIALHISTOGRAMDATAPOINT_BUCKETS = _EXPONENTIALHISTOGRAMDATAPOINT.nested_types_by_name['Buckets'] +_SUMMARYDATAPOINT = DESCRIPTOR.message_types_by_name['SummaryDataPoint'] +_SUMMARYDATAPOINT_VALUEATQUANTILE = _SUMMARYDATAPOINT.nested_types_by_name['ValueAtQuantile'] +_EXEMPLAR = DESCRIPTOR.message_types_by_name['Exemplar'] +MetricsData = _reflection.GeneratedProtocolMessageType('MetricsData', (_message.Message,), { + 'DESCRIPTOR' : _METRICSDATA, + '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2' + # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.MetricsData) + }) +_sym_db.RegisterMessage(MetricsData) + +ResourceMetrics = _reflection.GeneratedProtocolMessageType('ResourceMetrics', (_message.Message,), { + 'DESCRIPTOR' : _RESOURCEMETRICS, + '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2' + # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.ResourceMetrics) + }) +_sym_db.RegisterMessage(ResourceMetrics) + +ScopeMetrics = _reflection.GeneratedProtocolMessageType('ScopeMetrics', (_message.Message,), { + 'DESCRIPTOR' : _SCOPEMETRICS, + '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2' + # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.ScopeMetrics) + }) +_sym_db.RegisterMessage(ScopeMetrics) + +InstrumentationLibraryMetrics = 
_reflection.GeneratedProtocolMessageType('InstrumentationLibraryMetrics', (_message.Message,), {
+  'DESCRIPTOR' : _INSTRUMENTATIONLIBRARYMETRICS,
+  '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2'
+  # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.InstrumentationLibraryMetrics)
+  })
+_sym_db.RegisterMessage(InstrumentationLibraryMetrics)
+
+Metric = _reflection.GeneratedProtocolMessageType('Metric', (_message.Message,), {
+  'DESCRIPTOR' : _METRIC,
+  '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2'
+  # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.Metric)
+  })
+_sym_db.RegisterMessage(Metric)
+
+Gauge = _reflection.GeneratedProtocolMessageType('Gauge', (_message.Message,), {
+  'DESCRIPTOR' : _GAUGE,
+  '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2'
+  # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.Gauge)
+  })
+_sym_db.RegisterMessage(Gauge)
+
+Sum = _reflection.GeneratedProtocolMessageType('Sum', (_message.Message,), {
+  'DESCRIPTOR' : _SUM,
+  '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2'
+  # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.Sum)
+  })
+_sym_db.RegisterMessage(Sum)
+
+Histogram = _reflection.GeneratedProtocolMessageType('Histogram', (_message.Message,), {
+  'DESCRIPTOR' : _HISTOGRAM,
+  '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2'
+  # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.Histogram)
+  })
+_sym_db.RegisterMessage(Histogram)
+
+ExponentialHistogram = _reflection.GeneratedProtocolMessageType('ExponentialHistogram', (_message.Message,), {
+  'DESCRIPTOR' : _EXPONENTIALHISTOGRAM,
+  '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2'
+  # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.ExponentialHistogram)
+  })
+_sym_db.RegisterMessage(ExponentialHistogram)
+
+Summary = _reflection.GeneratedProtocolMessageType('Summary', (_message.Message,), {
+  'DESCRIPTOR' : _SUMMARY,
+  '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2'
+  # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.Summary)
+  })
+_sym_db.RegisterMessage(Summary)
+
+NumberDataPoint = _reflection.GeneratedProtocolMessageType('NumberDataPoint', (_message.Message,), {
+  'DESCRIPTOR' : _NUMBERDATAPOINT,
+  '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2'
+  # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.NumberDataPoint)
+  })
+_sym_db.RegisterMessage(NumberDataPoint)
+
+HistogramDataPoint = _reflection.GeneratedProtocolMessageType('HistogramDataPoint', (_message.Message,), {
+  'DESCRIPTOR' : _HISTOGRAMDATAPOINT,
+  '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2'
+  # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.HistogramDataPoint)
+  })
+_sym_db.RegisterMessage(HistogramDataPoint)
+
+ExponentialHistogramDataPoint = _reflection.GeneratedProtocolMessageType('ExponentialHistogramDataPoint', (_message.Message,), {
+
+  'Buckets' : _reflection.GeneratedProtocolMessageType('Buckets', (_message.Message,), {
+    'DESCRIPTOR' : _EXPONENTIALHISTOGRAMDATAPOINT_BUCKETS,
+    '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2'
+    # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.ExponentialHistogramDataPoint.Buckets)
+    })
+  ,
+  'DESCRIPTOR' : _EXPONENTIALHISTOGRAMDATAPOINT,
+  '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2'
+  # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.ExponentialHistogramDataPoint)
+  })
+_sym_db.RegisterMessage(ExponentialHistogramDataPoint)
+_sym_db.RegisterMessage(ExponentialHistogramDataPoint.Buckets)
+
+SummaryDataPoint = _reflection.GeneratedProtocolMessageType('SummaryDataPoint', (_message.Message,), {
+
+  'ValueAtQuantile' : _reflection.GeneratedProtocolMessageType('ValueAtQuantile', (_message.Message,), {
+    'DESCRIPTOR' : _SUMMARYDATAPOINT_VALUEATQUANTILE,
+    '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2'
+    # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.SummaryDataPoint.ValueAtQuantile)
+    })
+  ,
+  'DESCRIPTOR' : _SUMMARYDATAPOINT,
+  '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2'
+  # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.SummaryDataPoint)
+  })
+_sym_db.RegisterMessage(SummaryDataPoint)
+_sym_db.RegisterMessage(SummaryDataPoint.ValueAtQuantile)
+
+Exemplar = _reflection.GeneratedProtocolMessageType('Exemplar', (_message.Message,), {
+  'DESCRIPTOR' : _EXEMPLAR,
+  '__module__' : 'opentelemetry.proto.metrics.v1.metrics_pb2'
+  # @@protoc_insertion_point(class_scope:opentelemetry.proto.metrics.v1.Exemplar)
+  })
+_sym_db.RegisterMessage(Exemplar)
+
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+  DESCRIPTOR._options = None
+  DESCRIPTOR._serialized_options = b'\n!io.opentelemetry.proto.metrics.v1B\014MetricsProtoP\001Z)go.opentelemetry.io/proto/otlp/metrics/v1'
+  _RESOURCEMETRICS.fields_by_name['instrumentation_library_metrics']._options = None
+  _RESOURCEMETRICS.fields_by_name['instrumentation_library_metrics']._serialized_options = b'\030\001'
+  _INSTRUMENTATIONLIBRARYMETRICS._options = None
+  _INSTRUMENTATIONLIBRARYMETRICS._serialized_options = b'\030\001'
+  _AGGREGATIONTEMPORALITY._serialized_start=3754
+  _AGGREGATIONTEMPORALITY._serialized_end=3894
+  _DATAPOINTFLAGS._serialized_start=3896
+  _DATAPOINTFLAGS._serialized_end=3955
+  _METRICSDATA._serialized_start=172
+  _METRICSDATA._serialized_end=260
+  _RESOURCEMETRICS._serialized_start=263
+  _RESOURCEMETRICS._serialized_end=539
+  _SCOPEMETRICS._serialized_start=542
+  _SCOPEMETRICS._serialized_end=701
+  _INSTRUMENTATIONLIBRARYMETRICS._serialized_start=704
+  _INSTRUMENTATIONLIBRARYMETRICS._serialized_end=904
+  _METRIC._serialized_start=907
+  _METRIC._serialized_end=1309
+  _GAUGE._serialized_start=1311
+  _GAUGE._serialized_end=1388
+  _SUM._serialized_start=1391
+  _SUM._serialized_end=1577
+  _HISTOGRAM._serialized_start=1580
+  _HISTOGRAM._serialized_end=1753
+  _EXPONENTIALHISTOGRAM._serialized_start=1756
+  _EXPONENTIALHISTOGRAM._serialized_end=1951
+  _SUMMARY._serialized_start=1953
+  _SUMMARY._serialized_end=2033
+  _NUMBERDATAPOINT._serialized_start=2036
+  _NUMBERDATAPOINT._serialized_end=2298
+  _HISTOGRAMDATAPOINT._serialized_start=2301
+  _HISTOGRAMDATAPOINT._serialized_end=2659
+  _EXPONENTIALHISTOGRAMDATAPOINT._serialized_start=2662
+  _EXPONENTIALHISTOGRAMDATAPOINT._serialized_end=3227
+  _EXPONENTIALHISTOGRAMDATAPOINT_BUCKETS._serialized_start=3163
+  _EXPONENTIALHISTOGRAMDATAPOINT_BUCKETS._serialized_end=3211
+  _SUMMARYDATAPOINT._serialized_start=3230
+  _SUMMARYDATAPOINT._serialized_end=3555
+  _SUMMARYDATAPOINT_VALUEATQUANTILE._serialized_start=3499
+  _SUMMARYDATAPOINT_VALUEATQUANTILE._serialized_end=3549
+  _EXEMPLAR._serialized_start=3558
+  _EXEMPLAR._serialized_end=3751
+# @@protoc_insertion_point(module_scope)
\ No newline at end of file
diff --git a/newrelic/packages/opentelemetry_proto/resource_pb2.py b/newrelic/packages/opentelemetry_proto/resource_pb2.py
new file mode 100644
index 000000000..8cc64e352
--- /dev/null
+++ b/newrelic/packages/opentelemetry_proto/resource_pb2.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler.  DO NOT EDIT!
+# source: opentelemetry/proto/resource/v1/resource.proto
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+from . import common_pb2 as opentelemetry_dot_proto_dot_common_dot_v1_dot_common__pb2
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n.opentelemetry/proto/resource/v1/resource.proto\x12\x1fopentelemetry.proto.resource.v1\x1a*opentelemetry/proto/common/v1/common.proto\"i\n\x08Resource\x12;\n\nattributes\x18\x01 \x03(\x0b\x32\'.opentelemetry.proto.common.v1.KeyValue\x12 \n\x18\x64ropped_attributes_count\x18\x02 \x01(\rBa\n\"io.opentelemetry.proto.resource.v1B\rResourceProtoP\x01Z*go.opentelemetry.io/proto/otlp/resource/v1b\x06proto3')
+
+
+
+_RESOURCE = DESCRIPTOR.message_types_by_name['Resource']
+Resource = _reflection.GeneratedProtocolMessageType('Resource', (_message.Message,), {
+  'DESCRIPTOR' : _RESOURCE,
+  '__module__' : 'opentelemetry.proto.resource.v1.resource_pb2'
+  # @@protoc_insertion_point(class_scope:opentelemetry.proto.resource.v1.Resource)
+  })
+_sym_db.RegisterMessage(Resource)
+
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+  DESCRIPTOR._options = None
+  DESCRIPTOR._serialized_options = b'\n\"io.opentelemetry.proto.resource.v1B\rResourceProtoP\001Z*go.opentelemetry.io/proto/otlp/resource/v1'
+  _RESOURCE._serialized_start=127
+  _RESOURCE._serialized_end=232
+# @@protoc_insertion_point(module_scope)
diff --git a/setup.cfg b/setup.cfg
index 453a10eeb..006265c36 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -5,4 +5,4 @@ license_files =
 
 [flake8]
 max-line-length = 120
-extend-ignore = E122,E126,E127,E128,E203,E501,E722,F841,W504
+extend-ignore = E122,E126,E127,E128,E203,E501,E722,F841,W504,E731
diff --git a/setup.py b/setup.py
index 587afeaf9..b351ae06d 100644
--- a/setup.py
+++ b/setup.py
@@ -102,6 +102,8 @@ def build_extension(self, ext):
     "newrelic.hooks",
     "newrelic.network",
     "newrelic/packages",
+    "newrelic/packages/isort",
+    "newrelic/packages/isort/stdlibs",
     "newrelic/packages/urllib3",
     "newrelic/packages/urllib3/util",
     "newrelic/packages/urllib3/contrib",
@@ -109,6 +111,7 @@
     "newrelic/packages/urllib3/packages",
     "newrelic/packages/urllib3/packages/backports",
     "newrelic/packages/wrapt",
+    "newrelic/packages/opentelemetry_proto",
     "newrelic.samplers",
 ]
diff --git a/tests/agent_features/_test_async_coroutine_trace.py b/tests/agent_features/_test_async_coroutine_trace.py
index 51b81f5f6..1250b8c25 100644
--- a/tests/agent_features/_test_async_coroutine_trace.py
+++ b/tests/agent_features/_test_async_coroutine_trace.py
@@ -28,6 +28,7 @@
 from newrelic.api.datastore_trace import datastore_trace
 from newrelic.api.external_trace import external_trace
 from newrelic.api.function_trace import function_trace
+from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace
 from newrelic.api.memcache_trace import memcache_trace
 from newrelic.api.message_trace import message_trace
 
@@ -41,6 +42,8 @@
         (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"),
         (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"),
         (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"),
+        (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"),
+        (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"),
     ],
 )
 def test_awaitable_timing(event_loop, trace, metric):
@@ -79,6 +82,8 @@ def _test():
         (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"),
         (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"),
         (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"),
+        (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"),
+        (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"),
     ],
 )
 @pytest.mark.parametrize("yield_from", [True, False])
diff --git a/tests/agent_features/_test_async_generator_trace.py b/tests/agent_features/_test_async_generator_trace.py
new file mode 100644
index 000000000..30b970c37
--- /dev/null
+++ b/tests/agent_features/_test_async_generator_trace.py
@@ -0,0 +1,548 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
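+
+# The tests below exercise the agent's trace decorators against native async
+# generators; test_async_generator_trace.py imports this module only on
+# Python 3.7+.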
+
+import functools
+import sys
+import time
+
+import pytest
+from testing_support.fixtures import capture_transaction_metrics, validate_tt_parenting
+from testing_support.validators.validate_transaction_errors import (
+    validate_transaction_errors,
+)
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.api.database_trace import database_trace
+from newrelic.api.datastore_trace import datastore_trace
+from newrelic.api.external_trace import external_trace
+from newrelic.api.function_trace import function_trace
+from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace
+from newrelic.api.memcache_trace import memcache_trace
+from newrelic.api.message_trace import message_trace
+
+asyncio = pytest.importorskip("asyncio")
+
+
+@pytest.mark.parametrize(
+    "trace,metric",
+    [
+        (functools.partial(function_trace, name="simple_gen"), "Function/simple_gen"),
+        (functools.partial(external_trace, library="lib", url="http://foo.com"), "External/foo.com/lib/"),
+        (functools.partial(database_trace, "select * from foo"), "Datastore/statement/None/foo/select"),
+        (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"),
+        (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"),
+        (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"),
+        (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"),
+        (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"),
+    ],
+)
+def test_async_generator_timing(event_loop, trace, metric):
+    @trace()
+    async def simple_gen():
+        time.sleep(0.1)
+        yield
+        time.sleep(0.1)
+
+    metrics = []
+    full_metrics = {}
+
+    @capture_transaction_metrics(metrics, full_metrics)
+    @validate_transaction_metrics(
+        "test_async_generator_timing", background_task=True, scoped_metrics=[(metric, 1)], rollup_metrics=[(metric, 1)]
+    )
+    @background_task(name="test_async_generator_timing")
+    def _test_async_generator_timing():
+        async def _test():
+            async for _ in simple_gen():
+                pass
+
+        event_loop.run_until_complete(_test())
+    _test_async_generator_timing()
+
+    # Check that async generators time the total call time (including pauses)
+    metric_key = (metric, "")
+    assert full_metrics[metric_key].total_call_time >= 0.2
+
+
+class MyException(Exception):
+    pass
+
+
+@validate_transaction_metrics(
+    "test_async_generator_error",
+    background_task=True,
+    scoped_metrics=[("Function/agen", 1)],
+    rollup_metrics=[("Function/agen", 1)],
+)
+@validate_transaction_errors(errors=["_test_async_generator_trace:MyException"])
+def test_async_generator_error(event_loop):
+    @function_trace(name="agen")
+    async def agen():
+        yield
+
+    @background_task(name="test_async_generator_error")
+    async def _test():
+        gen = agen()
+        await gen.asend(None)
+        await gen.athrow(MyException)
+
+    with pytest.raises(MyException):
+        event_loop.run_until_complete(_test())
+
+
+@validate_transaction_metrics(
+    "test_async_generator_caught_exception",
+    background_task=True,
+    scoped_metrics=[("Function/agen", 1)],
+    rollup_metrics=[("Function/agen", 1)],
+)
+@validate_transaction_errors(errors=[])
+def test_async_generator_caught_exception(event_loop):
+    @function_trace(name="agen")
+    async def agen():
+        for _ in range(2):
+            time.sleep(0.1)
+            try:
+                yield
+            except ValueError:
+                pass
+
+    metrics = []
+    full_metrics = {}
+
+    @capture_transaction_metrics(metrics, full_metrics)
+    @background_task(name="test_async_generator_caught_exception")
+    def _test_async_generator_caught_exception():
+        async def _test():
+            gen = agen()
+            # kickstart the generator (the try/except logic is inside the
+            # generator)
+            await gen.asend(None)
+            await gen.athrow(ValueError)
+
+            # consume the generator
+            async for _ in gen:
+                pass
+
+        # The ValueError should not be reraised
+        event_loop.run_until_complete(_test())
+    _test_async_generator_caught_exception()
+
+    assert full_metrics[("Function/agen", "")].total_call_time >= 0.2
+
+
+@validate_transaction_metrics(
+    "test_async_generator_handles_terminal_nodes",
+    background_task=True,
+    scoped_metrics=[("Function/parent", 1), ("Function/agen", None)],
+    rollup_metrics=[("Function/parent", 1), ("Function/agen", None)],
+)
+def test_async_generator_handles_terminal_nodes(event_loop):
+    # sometimes async generators can be called underneath terminal nodes
+    # In this case, the trace shouldn't actually be created and we also
+    # shouldn't get any errors
+
+    @function_trace(name="agen")
+    async def agen():
+        yield
+        time.sleep(0.1)
+
+    @function_trace(name="parent", terminal=True)
+    async def parent():
+        # parent calls child
+        async for _ in agen():
+            pass
+
+    metrics = []
+    full_metrics = {}
+
+    @capture_transaction_metrics(metrics, full_metrics)
+    @background_task(name="test_async_generator_handles_terminal_nodes")
+    def _test_async_generator_handles_terminal_nodes():
+        async def _test():
+            await parent()
+
+        event_loop.run_until_complete(_test())
+    _test_async_generator_handles_terminal_nodes()
+
+    metric_key = ("Function/parent", "")
+    assert full_metrics[metric_key].total_exclusive_call_time >= 0.1
+
+
+@validate_transaction_metrics(
+    "test_async_generator_close_ends_trace",
+    background_task=True,
+    scoped_metrics=[("Function/agen", 1)],
+    rollup_metrics=[("Function/agen", 1)],
+)
+def test_async_generator_close_ends_trace(event_loop):
+    @function_trace(name="agen")
+    async def agen():
+        yield
+
+    @background_task(name="test_async_generator_close_ends_trace")
+    async def _test():
+        gen = agen()
+
+        # kickstart the async generator
+        await gen.asend(None)
+
+        # trace should be ended/recorded by aclose
+        await gen.aclose()
+
+        # We may call gen.aclose as many times as we want
+        await gen.aclose()
+
+    event_loop.run_until_complete(_test())
+
+@validate_tt_parenting(
+    (
+        "TransactionNode",
+        [
+            (
+                "FunctionNode",
+                [
+                    ("FunctionNode", []),
+                ],
+            ),
+        ],
+    )
+)
+@validate_transaction_metrics(
+    "test_async_generator_parents",
+    background_task=True,
+    scoped_metrics=[("Function/child", 1), ("Function/parent", 1)],
+    rollup_metrics=[("Function/child", 1), ("Function/parent", 1)],
+)
+def test_async_generator_parents(event_loop):
+    @function_trace(name="child")
+    async def child():
+        yield
+        time.sleep(0.1)
+        yield
+
+    @function_trace(name="parent")
+    async def parent():
+        time.sleep(0.1)
+        yield
+        async for _ in child():
+            pass
+
+    metrics = []
+    full_metrics = {}
+
+    @capture_transaction_metrics(metrics, full_metrics)
+    @background_task(name="test_async_generator_parents")
+    def _test_async_generator_parents():
+        async def _test():
+            async for _ in parent():
+                pass
+
+        event_loop.run_until_complete(_test())
+    _test_async_generator_parents()
+
+    # Check that the child time is subtracted from the parent time (parenting
+    # relationship is correctly established)
+    key = ("Function/parent", "")
+    assert full_metrics[key].total_exclusive_call_time < 0.2
+
+
+@validate_transaction_metrics(
+    "test_asend_receives_a_value",
+    background_task=True,
+    scoped_metrics=[("Function/agen", 1)],
+    rollup_metrics=[("Function/agen", 1)],
+)
+def test_asend_receives_a_value(event_loop):
+    _received = []
+    @function_trace(name="agen")
+    async def agen():
+        value = yield
+        _received.append(value)
+        yield value
+
+    @background_task(name="test_asend_receives_a_value")
+    async def _test():
+        gen = agen()
+
+        # kickstart the async generator
+        await gen.asend(None)
+
+        assert await gen.asend("foobar") == "foobar"
+        assert _received and _received[0] == "foobar"
+
+        # finish consumption of the async generator if necessary
+        async for _ in gen:
+            pass
+
+    event_loop.run_until_complete(_test())
+
+
+@validate_transaction_metrics(
+    "test_athrow_yields_a_value",
+    background_task=True,
+    scoped_metrics=[("Function/agen", 1)],
+    rollup_metrics=[("Function/agen", 1)],
+)
+def test_athrow_yields_a_value(event_loop):
+    @function_trace(name="agen")
+    async def agen():
+        for _ in range(2):
+            try:
+                yield
+            except MyException:
+                yield "foobar"
+
+    @background_task(name="test_athrow_yields_a_value")
+    async def _test():
+        gen = agen()
+
+        # kickstart the async generator
+        await gen.asend(None)
+
+        assert await gen.athrow(MyException) == "foobar"
+
+        # finish consumption of the async generator if necessary
+        async for _ in gen:
+            pass
+
+    event_loop.run_until_complete(_test())
+
+
+@validate_transaction_metrics(
+    "test_multiple_throws_yield_a_value",
+    background_task=True,
+    scoped_metrics=[("Function/agen", 1)],
+    rollup_metrics=[("Function/agen", 1)],
+)
+def test_multiple_throws_yield_a_value(event_loop):
+    @function_trace(name="agen")
+    async def agen():
+        value = None
+        for _ in range(4):
+            try:
+                yield value
+                value = "bar"
+            except MyException:
+                value = "foo"
+
+
+    @background_task(name="test_multiple_throws_yield_a_value")
+    async def _test():
+        gen = agen()
+
+        # kickstart the async generator
+        assert await gen.asend(None) is None
+        assert await gen.athrow(MyException) == "foo"
+        assert await gen.athrow(MyException) == "foo"
+        assert await gen.asend(None) == "bar"
+
+        # finish consumption of the async generator if necessary
+        async for _ in gen:
+            pass
+
+    event_loop.run_until_complete(_test())
+
+
+@validate_transaction_metrics(
+    "test_athrow_does_not_yield_a_value",
+    background_task=True,
+    scoped_metrics=[("Function/agen", 1)],
+    rollup_metrics=[("Function/agen", 1)],
+)
+def test_athrow_does_not_yield_a_value(event_loop):
+    @function_trace(name="agen")
+    async def agen():
+        for _ in range(2):
+            try:
+                yield
+            except MyException:
+                return
+
+    @background_task(name="test_athrow_does_not_yield_a_value")
+    async def _test():
+        gen = agen()
+
+        # kickstart the async generator
+        await gen.asend(None)
+
+        # async generator will raise StopAsyncIteration
+        with pytest.raises(StopAsyncIteration):
+            await gen.athrow(MyException)
+
+
+    event_loop.run_until_complete(_test())
+
+
+@pytest.mark.parametrize(
+    "trace",
+    [
+        function_trace(name="simple_gen"),
+        external_trace(library="lib", url="http://foo.com"),
+        database_trace("select * from foo"),
+        datastore_trace("lib", "foo", "bar"),
+        message_trace("lib", "op", "typ", "name"),
+        memcache_trace("cmd"),
+    ],
+)
+def test_async_generator_functions_outside_of_transaction(event_loop, trace):
+    @trace
+    async def agen():
+        for _ in range(2):
+            yield "foo"
+
+    async def _test():
+        assert [_ async for _ in agen()] == ["foo", "foo"]
+
+    event_loop.run_until_complete(_test())
+
+
+@validate_transaction_metrics(
+    "test_catching_generator_exit_causes_runtime_error",
+    background_task=True,
+    scoped_metrics=[("Function/agen", 1)],
+    rollup_metrics=[("Function/agen", 1)],
+)
+def test_catching_generator_exit_causes_runtime_error(event_loop):
+    @function_trace(name="agen")
+    async def agen():
+        try:
+            yield
+        except GeneratorExit:
+            yield
+
+    @background_task(name="test_catching_generator_exit_causes_runtime_error")
+    async def _test():
+        gen = agen()
+
+        # kickstart the async generator (we're inside the try now)
+        await gen.asend(None)
+
+        # Generators cannot catch generator exit exceptions (which are injected by
+        # aclose). This will result in a runtime error.
+        with pytest.raises(RuntimeError):
+            await gen.aclose()
+
+    event_loop.run_until_complete(_test())
+
+
+@validate_transaction_metrics(
+    "test_async_generator_time_excludes_creation_time",
+    background_task=True,
+    scoped_metrics=[("Function/agen", 1)],
+    rollup_metrics=[("Function/agen", 1)],
+)
+def test_async_generator_time_excludes_creation_time(event_loop):
+    @function_trace(name="agen")
+    async def agen():
+        yield
+
+    metrics = []
+    full_metrics = {}
+
+    @capture_transaction_metrics(metrics, full_metrics)
+    @background_task(name="test_async_generator_time_excludes_creation_time")
+    def _test_async_generator_time_excludes_creation_time():
+        async def _test():
+            gen = agen()
+            time.sleep(0.1)
+            async for _ in gen:
+                pass
+
+        event_loop.run_until_complete(_test())
+    _test_async_generator_time_excludes_creation_time()
+
+    # check that the trace does not include the time between creation and
+    # consumption
+    assert full_metrics[("Function/agen", "")].total_call_time < 0.1
+
+
+@validate_transaction_metrics(
+    "test_complete_async_generator",
+    background_task=True,
+    scoped_metrics=[("Function/agen", 1)],
+    rollup_metrics=[("Function/agen", 1)],
+)
+@background_task(name="test_complete_async_generator")
+def test_complete_async_generator(event_loop):
+    @function_trace(name="agen")
+    async def agen():
+        for i in range(5):
+            yield i
+
+    async def _test():
+        gen = agen()
+        assert [x async for x in gen] == [x for x in range(5)]
+
+    event_loop.run_until_complete(_test())
+
+
+@pytest.mark.parametrize("nr_transaction", [True, False])
+def test_incomplete_async_generator(event_loop, nr_transaction):
+    @function_trace(name="agen")
+    async def agen():
+        for _ in range(5):
+            yield
+
+    def _test_incomplete_async_generator():
+        async def _test():
+            c = agen()
+
+            async for _ in c:
+                break
+
+        if nr_transaction:
+            _test = background_task(name="test_incomplete_async_generator")(_test)
+
+        event_loop.run_until_complete(_test())
+
+    if nr_transaction:
+        _test_incomplete_async_generator = validate_transaction_metrics(
+            "test_incomplete_async_generator",
+            background_task=True,
+            scoped_metrics=[("Function/agen", 1)],
+            rollup_metrics=[("Function/agen", 1)],
+        )(_test_incomplete_async_generator)
+
+    _test_incomplete_async_generator()
+
+
+def test_incomplete_async_generator_transaction_exited(event_loop):
+    @function_trace(name="agen")
+    async def agen():
+        for _ in range(5):
+            yield
+
+    @validate_transaction_metrics(
+        "test_incomplete_async_generator",
+        background_task=True,
+        scoped_metrics=[("Function/agen", 1)],
+        rollup_metrics=[("Function/agen", 1)],
+    )
+    def _test_incomplete_async_generator():
+        c = agen()
+        @background_task(name="test_incomplete_async_generator")
+        async def _test():
+            async for _ in c:
+                break
+
+        event_loop.run_until_complete(_test())
+
+        # Remove generator after transaction completes
+        del c
+
+    _test_incomplete_async_generator()
diff --git a/tests/agent_features/conftest.py b/tests/agent_features/conftest.py
index 57263238b..bd6aa6c2a 100644
--- a/tests/agent_features/conftest.py
+++ b/tests/agent_features/conftest.py
@@ -30,6 +30,7 @@
     "debug.record_transaction_failure": True,
     "debug.log_autorum_middleware": True,
     "agent_limits.errors_per_harvest": 100,
+    "ml_insights_events.enabled": True
 }
 
 collector_agent_registration = collector_agent_registration_fixture(
diff --git a/tests/agent_features/test_apdex_metrics.py b/tests/agent_features/test_apdex_metrics.py
index e32a96e31..c150fcf7e 100644
--- a/tests/agent_features/test_apdex_metrics.py
+++ b/tests/agent_features/test_apdex_metrics.py
@@ -13,24 +13,41 @@
 # limitations under the License.
 
 import webtest
-
-from testing_support.validators.validate_apdex_metrics import (
-    validate_apdex_metrics)
 from testing_support.sample_applications import simple_app
+from testing_support.validators.validate_apdex_metrics import validate_apdex_metrics
+
+from newrelic.api.transaction import current_transaction, suppress_apdex_metric
+from newrelic.api.wsgi_application import wsgi_application
 
 normal_application = webtest.TestApp(simple_app)
 
-
 # NOTE: This test validates that the server-side apdex_t is set to 0.5
 # If the server-side configuration changes, this test will start to fail.
 @validate_apdex_metrics(
-    name='',
-    group='Uri',
+    name="",
+    group="Uri",
     apdex_t_min=0.5,
     apdex_t_max=0.5,
 )
 def test_apdex():
-    normal_application.get('/')
+    normal_application.get("/")
+
+
+# This has to be a Web Transaction.
+# The apdex measurement only applies to Web Transactions.
+def test_apdex_suppression():
+    @wsgi_application()
+    def simple_apdex_suppression_app(environ, start_response):
+        suppress_apdex_metric()
+
+        start_response(status="200 OK", response_headers=[])
+        transaction = current_transaction()
+
+        assert transaction.suppress_apdex
+        assert transaction.apdex == 0
+        return []
+
+    apdex_suppression_app = webtest.TestApp(simple_apdex_suppression_app)
+    apdex_suppression_app.get("/")
diff --git a/tests/agent_features/test_asgi_browser.py b/tests/agent_features/test_asgi_browser.py
index 1e718e1e0..281d08b96 100644
--- a/tests/agent_features/test_asgi_browser.py
+++ b/tests/agent_features/test_asgi_browser.py
@@ -111,7 +111,7 @@ def test_footer_attributes():
 
     obfuscation_key = settings.license_key[:13]
 
-    type_transaction_data = unicode if six.PY2 else str  # noqa: F821
+    type_transaction_data = unicode if six.PY2 else str  # noqa: F821, pylint: disable=E0602
 
     assert isinstance(data["transactionName"], type_transaction_data)
 
     txn_name = deobfuscate(data["transactionName"], obfuscation_key)
diff --git a/tests/agent_features/test_async_generator_trace.py b/tests/agent_features/test_async_generator_trace.py
new file mode 100644
index 000000000..208cf1588
--- /dev/null
+++ b/tests/agent_features/test_async_generator_trace.py
@@ -0,0 +1,19 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+
+# Async Generators were introduced in Python 3.6, but some APIs weren't completely stable until Python 3.7.
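+# (The asend/athrow/aclose semantics these tests rely on did not fully settle
+# until 3.7, so the import below is gated on the interpreter version.)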
+if sys.version_info >= (3, 7):
+    from _test_async_generator_trace import *  # NOQA
diff --git a/tests/agent_features/test_async_wrapper_detection.py b/tests/agent_features/test_async_wrapper_detection.py
new file mode 100644
index 000000000..bb1fd3f1e
--- /dev/null
+++ b/tests/agent_features/test_async_wrapper_detection.py
@@ -0,0 +1,102 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+import functools
+import time
+
+from newrelic.api.background_task import background_task
+from newrelic.api.database_trace import database_trace
+from newrelic.api.datastore_trace import datastore_trace
+from newrelic.api.external_trace import external_trace
+from newrelic.api.function_trace import function_trace
+from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace
+from newrelic.api.memcache_trace import memcache_trace
+from newrelic.api.message_trace import message_trace
+
+from newrelic.common.async_wrapper import generator_wrapper
+
+from testing_support.fixtures import capture_transaction_metrics
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+trace_metric_cases = [
+    (functools.partial(function_trace, name="simple_gen"), "Function/simple_gen"),
+    (functools.partial(external_trace, library="lib", url="http://foo.com"), "External/foo.com/lib/"),
+    (functools.partial(database_trace, "select * from foo"), "Datastore/statement/None/foo/select"),
+    (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"),
+    (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"),
+    (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"),
+    (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"),
+    (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"),
+]
+
+
+@pytest.mark.parametrize("trace,metric", trace_metric_cases)
+def test_automatic_generator_trace_wrapper(trace, metric):
+    metrics = []
+    full_metrics = {}
+
+    @capture_transaction_metrics(metrics, full_metrics)
+    @validate_transaction_metrics(
+        "test_automatic_generator_trace_wrapper", background_task=True, scoped_metrics=[(metric, 1)], rollup_metrics=[(metric, 1)]
+    )
+    @background_task(name="test_automatic_generator_trace_wrapper")
+    def _test():
+        @trace()
+        def gen():
+            time.sleep(0.1)
+            yield
+            time.sleep(0.1)
+
+        for _ in gen():
+            pass
+
+    _test()
+
+    # Check that generators time the total call time (including pauses)
+    metric_key = (metric, "")
+    assert full_metrics[metric_key].total_call_time >= 0.2
+
+
+@pytest.mark.parametrize("trace,metric", trace_metric_cases)
+def test_manual_generator_trace_wrapper(trace, metric):
+    metrics = []
+    full_metrics = {}
+
+    @capture_transaction_metrics(metrics, full_metrics)
+    @validate_transaction_metrics(
+        "test_manual_generator_trace_wrapper", background_task=True, scoped_metrics=[(metric, 1)], rollup_metrics=[(metric, 1)]
+    )
+    @background_task(name="test_manual_generator_trace_wrapper")
+    def _test():
+        @trace(async_wrapper=generator_wrapper)
+        def wrapper_func():
+            """Function that returns a generator object, obscuring the automatic introspection of async_wrapper()"""
+            def gen():
+                time.sleep(0.1)
+                yield
+                time.sleep(0.1)
+            return gen()
+
+        for _ in wrapper_func():
+            pass
+
+    _test()
+
+    # Check that generators time the total call time (including pauses)
+    metric_key = (metric, "")
+    assert full_metrics[metric_key].total_call_time >= 0.2
diff --git a/tests/agent_features/test_attributes_in_action.py b/tests/agent_features/test_attributes_in_action.py
index e56994d0a..08601fccf 100644
--- a/tests/agent_features/test_attributes_in_action.py
+++ b/tests/agent_features/test_attributes_in_action.py
@@ -20,14 +20,22 @@
     override_application_settings,
     reset_core_stats_engine,
     validate_attributes,
+)
+from testing_support.validators.validate_browser_attributes import (
     validate_browser_attributes,
+)
+from testing_support.validators.validate_error_event_attributes import (
     validate_error_event_attributes,
+)
+from testing_support.validators.validate_error_event_attributes_outside_transaction import (
     validate_error_event_attributes_outside_transaction,
-    validate_error_trace_attributes_outside_transaction,
 )
 from testing_support.validators.validate_error_trace_attributes import (
     validate_error_trace_attributes,
 )
+from testing_support.validators.validate_error_trace_attributes_outside_transaction import (
+    validate_error_trace_attributes_outside_transaction,
+)
 from testing_support.validators.validate_span_events import validate_span_events
 from testing_support.validators.validate_transaction_error_trace_attributes import (
     validate_transaction_error_trace_attributes,
@@ -43,7 +51,7 @@
 from newrelic.api.background_task import background_task
 from newrelic.api.message_transaction import message_transaction
 from newrelic.api.time_trace import notice_error
-from newrelic.api.transaction import add_custom_attribute, current_transaction, set_user_id
+from newrelic.api.transaction import add_custom_attribute, set_user_id
 from newrelic.api.wsgi_application import wsgi_application
 from newrelic.common.object_names import callable_name
@@ -930,16 +938,21 @@ def test_none_type_routing_key_agent_attribute():
 _forgone_agent_attributes = []
 
 
-@pytest.mark.parametrize('input_user_id, reported_user_id, high_security',(
+@pytest.mark.parametrize(
+    "input_user_id, reported_user_id, high_security",
+    (
         ("1234", "1234", True),
-    ("a" * 260, "a" * 255, False),
-))
+        ("a" * 260, "a" * 255, False),
+    ),
+)
 def test_enduser_id_attribute_api_valid_types(input_user_id, reported_user_id, high_security):
     @reset_core_stats_engine()
     @validate_error_trace_attributes(
         callable_name(ValueError), exact_attrs={"user": {}, "intrinsic": {}, "agent": {"enduser.id": reported_user_id}}
     )
-    @validate_error_event_attributes(exact_attrs={"user": {}, "intrinsic": {}, "agent": {"enduser.id": reported_user_id}})
+    @validate_error_event_attributes(
+        exact_attrs={"user": {}, "intrinsic": {}, "agent": {"enduser.id": reported_user_id}}
+    )
     @validate_attributes("agent", _required_agent_attributes, _forgone_agent_attributes)
     @background_task()
     @override_application_settings({"high_security": high_security})
@@ -950,10 +963,11 @@ def _test():
             raise ValueError()
         except Exception:
             notice_error()
+
     _test()
 
 
-@pytest.mark.parametrize('input_user_id',(None, '', 123))
+@pytest.mark.parametrize("input_user_id", (None, "", 123))
 def test_enduser_id_attribute_api_invalid_types(input_user_id):
     @reset_core_stats_engine()
     @validate_attributes("agent", [], ["enduser.id"])
@@ -965,4 +979,5 @@ def _test():
             raise ValueError()
         except Exception:
             notice_error()
+
     _test()
diff --git a/tests/agent_features/test_collector_payloads.py b/tests/agent_features/test_collector_payloads.py
index 0c1b2367c..42510e5c7 100644
--- a/tests/agent_features/test_collector_payloads.py
+++ b/tests/agent_features/test_collector_payloads.py
@@ -14,15 +14,15 @@
 
 import pytest
 import webtest
-from testing_support.fixtures import (
-    override_application_settings,
-    validate_custom_event_collector_json,
-)
+from testing_support.fixtures import override_application_settings
 from testing_support.sample_applications import (
     simple_app,
     simple_custom_event_app,
     simple_exceptional_app,
 )
+from testing_support.validators.validate_custom_event_collector_json import (
+    validate_custom_event_collector_json,
+)
 from testing_support.validators.validate_error_event_collector_json import (
     validate_error_event_collector_json,
 )
diff --git a/tests/agent_features/test_configuration.py b/tests/agent_features/test_configuration.py
index 5df69d71e..1a311e693 100644
--- a/tests/agent_features/test_configuration.py
+++ b/tests/agent_features/test_configuration.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import collections
+import tempfile
 
 import pytest
 
@@ -21,8 +22,18 @@
 except ImportError:
     import urllib.parse as urlparse
 
+import logging
+
+from newrelic.api.exceptions import ConfigurationError
 from newrelic.common.object_names import callable_name
-from newrelic.config import delete_setting, translate_deprecated_settings
+from newrelic.config import (
+    _reset_config_parser,
+    _reset_configuration_done,
+    _reset_instrumentation_done,
+    delete_setting,
+    initialize,
+    translate_deprecated_settings,
+)
 from newrelic.core.config import (
     Settings,
     apply_config_setting,
@@ -34,6 +45,10 @@
 )
 
 
+def function_to_trace():
+    pass
+
+
 def parameterize_local_config(settings_list):
     settings_object_list = []
 
@@ -262,7 +277,6 @@ def parameterize_local_config(settings_list):
 
 @parameterize_local_config(_test_dictionary_local_config)
 def test_dict_parse(settings):
-
     assert "NR-SESSION" in settings.request_headers_map
 
     config = settings.event_harvest_config
@@ -577,9 +591,388 @@ def test_translate_deprecated_ignored_params_with_new_setting():
         ("agent_run_id", None),
         ("entity_guid", None),
         ("distributed_tracing.exclude_newrelic_header", False),
+        ("otlp_host", "otlp.nr-data.net"),
+        ("otlp_port", 0),
     ),
 )
 def test_default_values(name, expected_value):
     settings = global_settings()
     value = fetch_config_setting(settings, name)
     assert value == expected_value
+
+
+def test_initialize():
+    initialize()
+
+
+newrelic_ini_contents = b"""
+[newrelic]
+app_name = Python Agent Test (agent_features)
+"""
+
+
+def test_initialize_raises_if_config_does_not_match_previous():
+    error_message = "Configuration has already been done against differing configuration file or environment.*"
+    with pytest.raises(ConfigurationError, match=error_message):
+        with tempfile.NamedTemporaryFile() as f:
+            f.write(newrelic_ini_contents)
+            f.seek(0)
+
+            initialize(config_file=f.name)
+
+
+def test_initialize_via_config_file():
+    _reset_configuration_done()
+    with tempfile.NamedTemporaryFile() as f:
+        f.write(newrelic_ini_contents)
+        f.seek(0)
+
+        initialize(config_file=f.name)
+
+
+def test_initialize_no_config_file():
+    _reset_configuration_done()
+    initialize()
+
+
+def test_initialize_config_file_does_not_exist():
+    _reset_configuration_done()
+    error_message = "Unable to open configuration file does-not-exist."
+    with pytest.raises(ConfigurationError, match=error_message):
+        initialize(config_file="does-not-exist")
+
+
+def test_initialize_environment():
+    _reset_configuration_done()
+    with tempfile.NamedTemporaryFile() as f:
+        f.write(newrelic_ini_contents)
+        f.seek(0)
+
+        initialize(config_file=f.name, environment="development")
+
+
+def test_initialize_log_level():
+    _reset_configuration_done()
+    with tempfile.NamedTemporaryFile() as f:
+        f.write(newrelic_ini_contents)
+        f.seek(0)
+
+        initialize(config_file=f.name, log_level="debug")
+
+
+def test_initialize_log_file():
+    _reset_configuration_done()
+    with tempfile.NamedTemporaryFile() as f:
+        f.write(newrelic_ini_contents)
+        f.seek(0)
+
+        initialize(config_file=f.name, log_file="stdout")
+
+
+@pytest.mark.parametrize(
+    "feature_flag,expect_warning",
+    (
+        (["django.instrumentation.inclusion-tags.r1"], False),
+        (["noexist"], True),
+    ),
+)
+def test_initialize_config_file_feature_flag(feature_flag, expect_warning, logger):
+    settings = global_settings()
+    apply_config_setting(settings, "feature_flag", feature_flag)
+    _reset_configuration_done()
+
+    with tempfile.NamedTemporaryFile() as f:
+        f.write(newrelic_ini_contents)
+        f.seek(0)
+
+        initialize(config_file=f.name)
+
+    message = (
+        "Unknown agent feature flag 'noexist' provided. "
+        "Check agent documentation or release notes, or "
+        "contact New Relic support for clarification of "
+        "validity of the specific feature flag."
+    )
+    if expect_warning:
+        assert message in logger.caplog.records
+    else:
+        assert message not in logger.caplog.records
+
+    apply_config_setting(settings, "feature_flag", [])
+
+
+@pytest.mark.parametrize(
+    "feature_flag,expect_warning",
+    (
+        (["django.instrumentation.inclusion-tags.r1"], False),
+        (["noexist"], True),
+    ),
+)
+def test_initialize_no_config_file_feature_flag(feature_flag, expect_warning, logger):
+    settings = global_settings()
+    apply_config_setting(settings, "feature_flag", feature_flag)
+    _reset_configuration_done()
+
+    initialize()
+
+    message = (
+        "Unknown agent feature flag 'noexist' provided. "
+        "Check agent documentation or release notes, or "
+        "contact New Relic support for clarification of "
+        "validity of the specific feature flag."
+    )
+
+    if expect_warning:
+        assert message in logger.caplog.records
+    else:
+        assert message not in logger.caplog.records
+
+    apply_config_setting(settings, "feature_flag", [])
+
+
+@pytest.mark.parametrize(
+    "setting_name,setting_value,expect_error",
+    (
+        ("transaction_tracer.function_trace", [callable_name(function_to_trace)], False),
+        ("transaction_tracer.generator_trace", [callable_name(function_to_trace)], False),
+        ("transaction_tracer.function_trace", ["no_exist"], True),
+        ("transaction_tracer.generator_trace", ["no_exist"], True),
+    ),
+)
+def test_initialize_config_file_with_traces(setting_name, setting_value, expect_error, logger):
+    settings = global_settings()
+    apply_config_setting(settings, setting_name, setting_value)
+    _reset_configuration_done()
+
+    with tempfile.NamedTemporaryFile() as f:
+        f.write(newrelic_ini_contents)
+        f.seek(0)
+
+        initialize(config_file=f.name)
+
+    if expect_error:
+        assert "CONFIGURATION ERROR" in logger.caplog.records
+    else:
+        assert "CONFIGURATION ERROR" not in logger.caplog.records
+
+    apply_config_setting(settings, setting_name, [])
+
+
+func_newrelic_ini = b"""
+[function-trace:]
+enabled = True
+function = test_configuration:function_to_trace
+name = function_to_trace
+group = group
+label = label
+terminal = False
+rollup = foo/all
+"""
+
+bad_func_newrelic_ini = b"""
+[function-trace:]
+enabled = True
+function = function_to_trace
+"""
+
+func_missing_enabled_newrelic_ini = b"""
+[function-trace:]
+function = function_to_trace
+"""
+
+external_newrelic_ini = b"""
+[external-trace:]
+enabled = True
+function = test_configuration:function_to_trace
+library = "foo"
+url = localhost:80/foo
+method = GET
+"""
+
+bad_external_newrelic_ini = b"""
+[external-trace:]
+enabled = True
+function = function_to_trace
+"""
+
+external_missing_enabled_newrelic_ini = b"""
+[external-trace:]
+function = function_to_trace
+"""
+
+generator_newrelic_ini = b"""
+[generator-trace:]
+enabled = True
+function = test_configuration:function_to_trace
+name = function_to_trace
+group = group
+"""
+
+bad_generator_newrelic_ini = b"""
+[generator-trace:]
+enabled = True
+function = function_to_trace
+"""
+
+generator_missing_enabled_newrelic_ini = b"""
+[generator-trace:]
+function = function_to_trace
+"""
+
+bg_task_newrelic_ini = b"""
+[background-task:]
+enabled = True
+function = test_configuration:function_to_trace
+lambda = test_configuration:function_to_trace
+"""
+
+bad_bg_task_newrelic_ini = b"""
+[background-task:]
+enabled = True
+function = function_to_trace
+"""
+
+bg_task_missing_enabled_newrelic_ini = b"""
+[background-task:]
+function = function_to_trace
+"""
+
+db_trace_newrelic_ini = b"""
+[database-trace:]
+enabled = True
+function = test_configuration:function_to_trace
+sql = test_configuration:function_to_trace
+"""
+
+bad_db_trace_newrelic_ini = b"""
+[database-trace:]
+enabled = True
+function = function_to_trace
+"""
+
+db_trace_missing_enabled_newrelic_ini = b"""
+[database-trace:]
+function = function_to_trace
+"""
+
+wsgi_newrelic_ini = b"""
+[wsgi-application:]
+enabled = True
+function = test_configuration:function_to_trace
+application = app
+"""
+
+bad_wsgi_newrelic_ini = b"""
+[wsgi-application:]
+enabled = True
+function = function_to_trace
+application = app
+"""
+
+wsgi_missing_enabled_newrelic_ini = b"""
+[wsgi-application:]
+function = function_to_trace
+application = app
+"""
+
+wsgi_unparseable_enabled_newrelic_ini = b"""
+[wsgi-application:]
+enabled = not-a-bool
+function = function_to_trace
+application = app
+"""
+
+
+@pytest.mark.parametrize(
+    "section,expect_error",
+    (
+        (func_newrelic_ini, False),
+        (bad_func_newrelic_ini, True),
+        (func_missing_enabled_newrelic_ini, False),
+        (external_newrelic_ini, False),
+        (bad_external_newrelic_ini, True),
+        (external_missing_enabled_newrelic_ini, False),
+        (generator_newrelic_ini, False),
+        (bad_generator_newrelic_ini, True),
+        (generator_missing_enabled_newrelic_ini, False),
+        (bg_task_newrelic_ini, False),
+        (bad_bg_task_newrelic_ini, True),
+        (bg_task_missing_enabled_newrelic_ini, False),
+        (db_trace_newrelic_ini, False),
+        (bad_db_trace_newrelic_ini, True),
+        (db_trace_missing_enabled_newrelic_ini, False),
+        (wsgi_newrelic_ini, False),
+        (bad_wsgi_newrelic_ini, True),
+        (wsgi_missing_enabled_newrelic_ini, False),
+        (wsgi_unparseable_enabled_newrelic_ini, True),
+    ),
+    ids=(
+        "func_newrelic_ini",
+        "bad_func_newrelic_ini",
+        "func_missing_enabled_newrelic_ini",
+        "external_newrelic_ini",
+        "bad_external_newrelic_ini",
+        "external_missing_enabled_newrelic_ini",
+        "generator_newrelic_ini",
+        "bad_generator_newrelic_ini",
+        "generator_missing_enabled_newrelic_ini",
+        "bg_task_newrelic_ini",
+        "bad_bg_task_newrelic_ini",
+        "bg_task_missing_enabled_newrelic_ini",
+        "db_trace_newrelic_ini",
+        "bad_db_trace_newrelic_ini",
+        "db_trace_missing_enabled_newrelic_ini",
+        "wsgi_newrelic_ini",
+        "bad_wsgi_newrelic_ini",
+        "wsgi_missing_enabled_newrelic_ini",
+        "wsgi_unparseable_enabled_newrelic_ini",
+    ),
+)
+def test_initialize_developer_mode(section, expect_error, logger):
+    settings = global_settings()
+    apply_config_setting(settings, "monitor_mode", False)
+    apply_config_setting(settings, "developer_mode", True)
+    _reset_configuration_done()
+    _reset_instrumentation_done()
+    _reset_config_parser()
+
+    with tempfile.NamedTemporaryFile() as f:
+        f.write(newrelic_ini_contents)
+        f.write(section)
+        f.seek(0)
+
+        initialize(config_file=f.name)
+
+    if expect_error:
+        assert "CONFIGURATION ERROR" in logger.caplog.records
+    else:
+        assert "CONFIGURATION ERROR" not in logger.caplog.records
+
+
+@pytest.fixture
+def caplog_handler():
+    class CaplogHandler(logging.StreamHandler):
+        """
+        To prevent possible issues with pytest's monkey patching,
+        use a custom Caplog handler to capture all records
+        """
+
+        def __init__(self, *args, **kwargs):
+            self.records = []
+            super(CaplogHandler, self).__init__(*args, **kwargs)
+
+        def emit(self, record):
+            self.records.append(self.format(record))
+
+    return CaplogHandler()
+
+
+@pytest.fixture
+def logger(caplog_handler):
+    _logger = logging.getLogger("newrelic.config")
+    _logger.addHandler(caplog_handler)
+    _logger.caplog = caplog_handler
+    _logger.setLevel(logging.WARNING)
+    yield _logger
+    del caplog_handler.records[:]
+    _logger.removeHandler(caplog_handler)
diff --git a/tests/agent_features/test_coroutine_trace.py b/tests/agent_features/test_coroutine_trace.py
index 36e365bc4..2043f1326 100644
--- a/tests/agent_features/test_coroutine_trace.py
+++ b/tests/agent_features/test_coroutine_trace.py
@@ -31,6 +31,7 @@
 from newrelic.api.datastore_trace import datastore_trace
 from newrelic.api.external_trace import external_trace
 from newrelic.api.function_trace import function_trace
+from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace
 from newrelic.api.memcache_trace import memcache_trace
 from newrelic.api.message_trace import message_trace
 
@@ -47,6 +48,8 @@
         (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"),
         (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"),
         (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"),
+        (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"),
+        (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"),
     ],
 )
 def test_coroutine_timing(trace, metric):
@@ -337,6 +340,37 @@ def coro():
             pass
 
 
+@validate_transaction_metrics(
+    "test_multiple_throws_yield_a_value",
+    background_task=True,
+    scoped_metrics=[("Function/coro", 1)],
+    rollup_metrics=[("Function/coro", 1)],
+)
+@background_task(name="test_multiple_throws_yield_a_value")
+def test_multiple_throws_yield_a_value():
+    @function_trace(name="coro")
+    def coro():
+        value = None
+        for _ in range(4):
+            try:
+                yield value
+                value = "bar"
+            except MyException:
+                value = "foo"
+
+    c = coro()
+
+    # kickstart the coroutine
+    assert next(c) is None
+    assert c.throw(MyException) == "foo"
+    assert c.throw(MyException) == "foo"
+    assert next(c) == "bar"
+
+    # finish consumption of the coroutine if necessary
+    for _ in c:
+        pass
+
+
 @pytest.mark.parametrize(
     "trace",
     [
diff --git a/tests/agent_features/test_custom_metrics.py b/tests/agent_features/test_custom_metrics.py
new file mode 100644
index 000000000..21a67149a
--- /dev/null
+++ b/tests/agent_features/test_custom_metrics.py
@@ -0,0 +1,62 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from testing_support.fixtures import reset_core_stats_engine
+from testing_support.validators.validate_custom_metrics_outside_transaction import (
+    validate_custom_metrics_outside_transaction,
+)
+
+from newrelic.api.application import application_instance as application
+from newrelic.api.background_task import background_task
+from newrelic.api.transaction import (
+    current_transaction,
+    record_custom_metric,
+    record_custom_metrics,
+)
+
+
+# Testing record_custom_metric
+@reset_core_stats_engine()
+@background_task()
+def test_custom_metric_inside_transaction():
+    transaction = current_transaction()
+    record_custom_metric("CustomMetric/InsideTransaction/Count", 1)
+    for metric in transaction._custom_metrics.metrics():
+        assert metric == ("CustomMetric/InsideTransaction/Count", [1, 1, 1, 1, 1, 1])
+
+
+@reset_core_stats_engine()
+@validate_custom_metrics_outside_transaction([("CustomMetric/OutsideTransaction/Count", 1)])
+@background_task()
+def test_custom_metric_outside_transaction_with_app():
+    app = application()
+    record_custom_metric("CustomMetric/OutsideTransaction/Count", 1, application=app)
+
+
+# Testing record_custom_metricS
+@reset_core_stats_engine()
+@background_task()
+def test_custom_metrics_inside_transaction():
+    transaction = current_transaction()
+    record_custom_metrics([("CustomMetrics/InsideTransaction/Count", 1)])
+    for metric in transaction._custom_metrics.metrics():
+        assert metric == ("CustomMetrics/InsideTransaction/Count", [1, 1, 1, 1, 1, 1])
+
+
+@reset_core_stats_engine()
+@validate_custom_metrics_outside_transaction([("CustomMetrics/OutsideTransaction/Count", 1)])
+@background_task()
+def test_custom_metrics_outside_transaction_with_app():
+    app = application()
+    record_custom_metrics([("CustomMetrics/OutsideTransaction/Count", 1)], application=app)
diff --git a/tests/agent_features/test_datastore_trace.py b/tests/agent_features/test_datastore_trace.py
new file mode 100644
index 000000000..08067e040
--- /dev/null
+++ b/tests/agent_features/test_datastore_trace.py
@@ -0,0 +1,89 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
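+
+# DatastoreTrace parameters can be given as literal values, or (through
+# DatastoreTraceWrapper) as callables that are resolved when the wrapped
+# function runs; both paths are covered below.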
+
+from testing_support.validators.validate_datastore_trace_inputs import (
+    validate_datastore_trace_inputs,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.api.datastore_trace import DatastoreTrace, DatastoreTraceWrapper
+
+
+@validate_datastore_trace_inputs(
+    operation="test_operation",
+    target="test_target",
+    host="test_host",
+    port_path_or_id="test_port",
+    database_name="test_db_name",
+)
+@background_task()
+def test_dt_trace_all_args():
+    with DatastoreTrace(
+        product="Agent Features",
+        target="test_target",
+        operation="test_operation",
+        host="test_host",
+        port_path_or_id="test_port",
+        database_name="test_db_name",
+    ):
+        pass
+
+
+@validate_datastore_trace_inputs(operation=None, target=None, host=None, port_path_or_id=None, database_name=None)
+@background_task()
+def test_dt_trace_empty():
+    with DatastoreTrace(product=None, target=None, operation=None):
+        pass
+
+
+@background_task()
+def test_dt_trace_callable_args():
+    def product_callable():
+        return "Agent Features"
+
+    def target_callable():
+        return "test_target"
+
+    def operation_callable():
+        return "test_operation"
+
+    def host_callable():
+        return "test_host"
+
+    def port_path_id_callable():
+        return "test_port"
+
+    def db_name_callable():
+        return "test_db_name"
+
+    @validate_datastore_trace_inputs(
+        operation="test_operation",
+        target="test_target",
+        host="test_host",
+        port_path_or_id="test_port",
+        database_name="test_db_name",
+    )
+    def _test():
+        pass
+
+    wrapped_fn = DatastoreTraceWrapper(
+        _test,
+        product=product_callable,
+        target=target_callable,
+        operation=operation_callable,
+        host=host_callable,
+        port_path_or_id=port_path_id_callable,
+        database_name=db_name_callable,
+    )
+    wrapped_fn()
diff --git a/tests/agent_features/test_dimensional_metrics.py b/tests/agent_features/test_dimensional_metrics.py
new file mode 100644
index 000000000..ef9e98418
--- /dev/null
+++ b/tests/agent_features/test_dimensional_metrics.py
@@ -0,0 +1,224 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
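+
+# Dimensional metrics are exported over OTLP, so this module runs once per
+# content encoding (protobuf and JSON); protobuf is skipped on Python 2,
+# where the bundled OTLP protos do not import.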
+
+import pytest
+from testing_support.fixtures import reset_core_stats_engine
+from testing_support.validators.validate_dimensional_metric_payload import (
+    validate_dimensional_metric_payload,
+)
+from testing_support.validators.validate_dimensional_metrics_outside_transaction import (
+    validate_dimensional_metrics_outside_transaction,
+)
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+import newrelic.core.otlp_utils
+from newrelic.api.application import application_instance
+from newrelic.api.background_task import background_task
+from newrelic.api.transaction import (
+    record_dimensional_metric,
+    record_dimensional_metrics,
+)
+from newrelic.common.metric_utils import create_metric_identity
+from newrelic.core.config import global_settings
+from newrelic.packages import six
+
+try:
+    # python 2.x
+    reload
+except NameError:
+    # python 3.x
+    from importlib import reload
+
+
+@pytest.fixture(scope="module", autouse=True, params=["protobuf", "json"])
+def otlp_content_encoding(request):
+    if six.PY2 and request.param == "protobuf":
+        pytest.skip("OTLP protos are not compatible with Python 2.")
+
+    _settings = global_settings()
+    prev = _settings.debug.otlp_content_encoding
+    _settings.debug.otlp_content_encoding = request.param
+    reload(newrelic.core.otlp_utils)
+    assert newrelic.core.otlp_utils.otlp_content_setting == request.param, "Content encoding mismatch."
+
+    yield
+
+    _settings.debug.otlp_content_encoding = prev
+
+
+_test_tags_examples = [
+    (None, None),
+    ({}, None),
+    ([], None),
+    ({"str": "a"}, frozenset({("str", "a")})),
+    ({"int": 1}, frozenset({("int", 1)})),
+    ({"float": 1.0}, frozenset({("float", 1.0)})),
+    ({"bool": True}, frozenset({("bool", True)})),
+    ({"list": [1]}, frozenset({("list", "[1]")})),
+    ({"dict": {"subtag": 1}}, frozenset({("dict", "{'subtag': 1}")})),
+    ([("tags-as-list", 1)], frozenset({("tags-as-list", 1)})),
+]
+
+
+@pytest.mark.parametrize("tags,expected", _test_tags_examples)
+def test_create_metric_identity(tags, expected):
+    name = "Metric"
+    output_name, output_tags = create_metric_identity(name, tags=tags)
+    assert output_name == name, "Name does not match."
+    assert output_tags == expected, "Output tags do not match."
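+
+# As the identity cases above show, non-primitive tag values (lists, dicts)
+# are coerced to strings when the metric identity is formed.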
+ + +@pytest.mark.parametrize("tags,expected", _test_tags_examples) +@reset_core_stats_engine() +def test_record_dimensional_metric_inside_transaction(tags, expected): + @validate_transaction_metrics( + "test_record_dimensional_metric_inside_transaction", + background_task=True, + dimensional_metrics=[ + ("Metric", expected, 1), + ], + ) + @background_task(name="test_record_dimensional_metric_inside_transaction") + def _test(): + record_dimensional_metric("Metric", 1, tags=tags) + + _test() + + +@pytest.mark.parametrize("tags,expected", _test_tags_examples) +@reset_core_stats_engine() +def test_record_dimensional_metric_outside_transaction(tags, expected): + @validate_dimensional_metrics_outside_transaction([("Metric", expected, 1)]) + def _test(): + app = application_instance() + record_dimensional_metric("Metric", 1, tags=tags, application=app) + + _test() + + +@pytest.mark.parametrize("tags,expected", _test_tags_examples) +@reset_core_stats_engine() +def test_record_dimensional_metrics_inside_transaction(tags, expected): + @validate_transaction_metrics( + "test_record_dimensional_metrics_inside_transaction", + background_task=True, + dimensional_metrics=[("Metric.1", expected, 1), ("Metric.2", expected, 1)], + ) + @background_task(name="test_record_dimensional_metrics_inside_transaction") + def _test(): + record_dimensional_metrics([("Metric.1", 1, tags), ("Metric.2", 1, tags)]) + + _test() + + +@pytest.mark.parametrize("tags,expected", _test_tags_examples) +@reset_core_stats_engine() +def test_record_dimensional_metrics_outside_transaction(tags, expected): + @validate_dimensional_metrics_outside_transaction([("Metric.1", expected, 1), ("Metric.2", expected, 1)]) + def _test(): + app = application_instance() + record_dimensional_metrics([("Metric.1", 1, tags), ("Metric.2", 1, tags)], application=app) + + _test() + + +@reset_core_stats_engine() +def test_dimensional_metrics_different_tags(): + @validate_transaction_metrics( + "test_dimensional_metrics_different_tags", + background_task=True, + dimensional_metrics=[ + ("Metric", frozenset({("tag", 1)}), 1), + ("Metric", frozenset({("tag", 2)}), 2), + ], + ) + @background_task(name="test_dimensional_metrics_different_tags") + def _test(): + record_dimensional_metrics( + [ + ("Metric", 1, {"tag": 1}), + ("Metric", 1, {"tag": 2}), + ] + ) + record_dimensional_metric("Metric", 1, {"tag": 2}) + + _test() + + +@reset_core_stats_engine() +@validate_dimensional_metric_payload( + summary_metrics=[ + ("Metric.Summary", {"tag": 1}, 1), + ("Metric.Summary", {"tag": 2}, 1), + ("Metric.Summary", None, 1), + ("Metric.Mixed", {"tag": 1}, 1), + ("Metric.NotPresent", None, None), + ], + count_metrics=[ + ("Metric.Count", {"tag": 1}, 1), + ("Metric.Count", {"tag": 2}, 2), + ("Metric.Count", None, 3), + ("Metric.Mixed", {"tag": 2}, 2), + ("Metric.NotPresent", None, None), + ], +) +def test_dimensional_metrics_payload(): + @background_task(name="test_dimensional_metric_payload") + def _test(): + record_dimensional_metrics( + [ + ("Metric.Summary", 1, {"tag": 1}), + ("Metric.Summary", 2, {"tag": 2}), + ("Metric.Summary", 3), # No tags + ("Metric.Count", {"count": 1}, {"tag": 1}), + ("Metric.Count", {"count": 2}, {"tag": 2}), + ("Metric.Count", {"count": 3}), # No tags + ("Metric.Mixed", 1, {"tag": 1}), + ("Metric.Mixed", {"count": 2}, {"tag": 2}), + ] + ) + + _test() + app = application_instance() + core_app = app._agent.application(app.name) + core_app.harvest() + + +@reset_core_stats_engine() +@validate_dimensional_metric_payload( + summary_metrics=[ + 
("Metric.Summary", None, 1), + ("Metric.Count", None, None), # Should NOT be present + ], + count_metrics=[ + ("Metric.Count", None, 1), + ("Metric.Summary", None, None), # Should NOT be present + ], +) +def test_dimensional_metrics_no_duplicate_encodings(): + @background_task(name="test_dimensional_metric_payload") + def _test(): + record_dimensional_metrics( + [ + ("Metric.Summary", 1), + ("Metric.Count", {"count": 1}), + ] + ) + + _test() + app = application_instance() + core_app = app._agent.application(app.name) + core_app.harvest() diff --git a/tests/agent_features/test_distributed_tracing.py b/tests/agent_features/test_distributed_tracing.py index 7f795573a..263b1bdcf 100644 --- a/tests/agent_features/test_distributed_tracing.py +++ b/tests/agent_features/test_distributed_tracing.py @@ -12,71 +12,86 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import json + import pytest import webtest -import copy +from testing_support.fixtures import override_application_settings, validate_attributes +from testing_support.validators.validate_error_event_attributes import ( + validate_error_event_attributes, +) +from testing_support.validators.validate_transaction_event_attributes import ( + validate_transaction_event_attributes, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.api.application import application_instance -from newrelic.api.background_task import background_task, BackgroundTask -from newrelic.api.transaction import (current_transaction, current_trace_id, - current_span_id) +from newrelic.api.background_task import BackgroundTask, background_task +from newrelic.api.external_trace import ExternalTrace from newrelic.api.time_trace import current_trace +from newrelic.api.transaction import ( + accept_distributed_trace_headers, + accept_distributed_trace_payload, + create_distributed_trace_payload, + current_span_id, + current_trace_id, + current_transaction, +) from newrelic.api.web_transaction import WSGIWebTransaction from newrelic.api.wsgi_application import wsgi_application -from testing_support.fixtures import (override_application_settings, - validate_attributes, - validate_error_event_attributes) -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_transaction_event_attributes import validate_transaction_event_attributes - -distributed_trace_intrinsics = ['guid', 'traceId', 'priority', 'sampled'] -inbound_payload_intrinsics = ['parent.type', 'parent.app', 'parent.account', - 'parent.transportType', 'parent.transportDuration'] +distributed_trace_intrinsics = ["guid", "traceId", "priority", "sampled"] +inbound_payload_intrinsics = [ + "parent.type", + "parent.app", + "parent.account", + "parent.transportType", + "parent.transportDuration", +] payload = { - 'v': [0, 1], - 'd': { - 'ac': '1', - 'ap': '2827902', - 'id': '7d3efb1b173fecfa', - 'pa': '5e5733a911cfbc73', - 'pr': 10.001, - 'sa': True, - 'ti': 1518469636035, - 'tr': 'd6b4ba0c3a712ca', - 'ty': 'App', - } + "v": [0, 1], + "d": { + "ac": "1", + "ap": "2827902", + "id": "7d3efb1b173fecfa", + "pa": "5e5733a911cfbc73", + "pr": 10.001, + "sa": True, + "ti": 1518469636035, + "tr": "d6b4ba0c3a712ca", + "ty": "App", + }, } -parent_order = ['parent_type', 'parent_account', - 'parent_app', 'parent_transport_type'] +parent_order = ["parent_type", "parent_account", "parent_app", "parent_transport_type"] 
parent_info = { - 'parent_type': payload['d']['ty'], - 'parent_account': payload['d']['ac'], - 'parent_app': payload['d']['ap'], - 'parent_transport_type': 'HTTP' + "parent_type": payload["d"]["ty"], + "parent_account": payload["d"]["ac"], + "parent_app": payload["d"]["ap"], + "parent_transport_type": "HTTP", } @wsgi_application() def target_wsgi_application(environ, start_response): - status = '200 OK' - output = b'hello world' - response_headers = [('Content-type', 'text/html; charset=utf-8'), - ('Content-Length', str(len(output)))] + status = "200 OK" + output = b"hello world" + response_headers = [("Content-type", "text/html; charset=utf-8"), ("Content-Length", str(len(output)))] txn = current_transaction() # Make assertions on the WSGIWebTransaction object assert txn._distributed_trace_state - assert txn.parent_type == 'App' - assert txn.parent_app == '2827902' - assert txn.parent_account == '1' - assert txn.parent_span == '7d3efb1b173fecfa' - assert txn.parent_transport_type == 'HTTP' + assert txn.parent_type == "App" + assert txn.parent_app == "2827902" + assert txn.parent_account == "1" + assert txn.parent_span == "7d3efb1b173fecfa" + assert txn.parent_transport_type == "HTTP" assert isinstance(txn.parent_transport_duration, float) - assert txn._trace_id == 'd6b4ba0c3a712ca' + assert txn._trace_id == "d6b4ba0c3a712ca" assert txn.priority == 10.001 assert txn.sampled @@ -87,90 +102,75 @@ def target_wsgi_application(environ, start_response): test_application = webtest.TestApp(target_wsgi_application) _override_settings = { - 'trusted_account_key': '1', - 'distributed_tracing.enabled': True, + "trusted_account_key": "1", + "distributed_tracing.enabled": True, } _metrics = [ - ('Supportability/DistributedTrace/AcceptPayload/Success', 1), - ('Supportability/TraceContext/Accept/Success', None) + ("Supportability/DistributedTrace/AcceptPayload/Success", 1), + ("Supportability/TraceContext/Accept/Success", None), ] @override_application_settings(_override_settings) -@validate_transaction_metrics( - '', - group='Uri', - rollup_metrics=_metrics) +@validate_transaction_metrics("", group="Uri", rollup_metrics=_metrics) def test_distributed_tracing_web_transaction(): - headers = {'newrelic': json.dumps(payload)} - response = test_application.get('/', headers=headers) - assert 'X-NewRelic-App-Data' not in response.headers + headers = {"newrelic": json.dumps(payload)} + response = test_application.get("/", headers=headers) + assert "X-NewRelic-App-Data" not in response.headers -@pytest.mark.parametrize('span_events', (True, False)) -@pytest.mark.parametrize('accept_payload', (True, False)) +@pytest.mark.parametrize("span_events", (True, False)) +@pytest.mark.parametrize("accept_payload", (True, False)) def test_distributed_trace_attributes(span_events, accept_payload): if accept_payload: - _required_intrinsics = ( - distributed_trace_intrinsics + inbound_payload_intrinsics) + _required_intrinsics = distributed_trace_intrinsics + inbound_payload_intrinsics _forgone_txn_intrinsics = [] _forgone_error_intrinsics = [] _exact_intrinsics = { - 'parent.type': 'Mobile', - 'parent.app': '2827902', - 'parent.account': '1', - 'parent.transportType': 'HTTP', - 'traceId': 'd6b4ba0c3a712ca', + "parent.type": "Mobile", + "parent.app": "2827902", + "parent.account": "1", + "parent.transportType": "HTTP", + "traceId": "d6b4ba0c3a712ca", } - _exact_txn_attributes = {'agent': {}, 'user': {}, - 'intrinsic': _exact_intrinsics.copy()} - _exact_error_attributes = {'agent': {}, 'user': {}, - 'intrinsic': 
_exact_intrinsics.copy()} - _exact_txn_attributes['intrinsic']['parentId'] = '7d3efb1b173fecfa' - _exact_txn_attributes['intrinsic']['parentSpanId'] = 'c86df80de2e6f51c' - - _forgone_error_intrinsics.append('parentId') - _forgone_error_intrinsics.append('parentSpanId') - _forgone_txn_intrinsics.append('grandparentId') - _forgone_error_intrinsics.append('grandparentId') - - _required_attributes = { - 'intrinsic': _required_intrinsics, 'agent': [], 'user': []} - _forgone_txn_attributes = {'intrinsic': _forgone_txn_intrinsics, - 'agent': [], 'user': []} - _forgone_error_attributes = {'intrinsic': _forgone_error_intrinsics, - 'agent': [], 'user': []} + _exact_txn_attributes = {"agent": {}, "user": {}, "intrinsic": _exact_intrinsics.copy()} + _exact_error_attributes = {"agent": {}, "user": {}, "intrinsic": _exact_intrinsics.copy()} + _exact_txn_attributes["intrinsic"]["parentId"] = "7d3efb1b173fecfa" + _exact_txn_attributes["intrinsic"]["parentSpanId"] = "c86df80de2e6f51c" + + _forgone_error_intrinsics.append("parentId") + _forgone_error_intrinsics.append("parentSpanId") + _forgone_txn_intrinsics.append("grandparentId") + _forgone_error_intrinsics.append("grandparentId") + + _required_attributes = {"intrinsic": _required_intrinsics, "agent": [], "user": []} + _forgone_txn_attributes = {"intrinsic": _forgone_txn_intrinsics, "agent": [], "user": []} + _forgone_error_attributes = {"intrinsic": _forgone_error_intrinsics, "agent": [], "user": []} else: _required_intrinsics = distributed_trace_intrinsics - _forgone_txn_intrinsics = _forgone_error_intrinsics = \ - inbound_payload_intrinsics + ['grandparentId', 'parentId', - 'parentSpanId'] - - _required_attributes = { - 'intrinsic': _required_intrinsics, 'agent': [], 'user': []} - _forgone_txn_attributes = {'intrinsic': _forgone_txn_intrinsics, - 'agent': [], 'user': []} - _forgone_error_attributes = {'intrinsic': _forgone_error_intrinsics, - 'agent': [], 'user': []} + _forgone_txn_intrinsics = _forgone_error_intrinsics = inbound_payload_intrinsics + [ + "grandparentId", + "parentId", + "parentSpanId", + ] + + _required_attributes = {"intrinsic": _required_intrinsics, "agent": [], "user": []} + _forgone_txn_attributes = {"intrinsic": _forgone_txn_intrinsics, "agent": [], "user": []} + _forgone_error_attributes = {"intrinsic": _forgone_error_intrinsics, "agent": [], "user": []} _exact_txn_attributes = _exact_error_attributes = None _forgone_trace_intrinsics = _forgone_error_intrinsics test_settings = _override_settings.copy() - test_settings['span_events.enabled'] = span_events + test_settings["span_events.enabled"] = span_events @override_application_settings(test_settings) - @validate_transaction_event_attributes( - _required_attributes, _forgone_txn_attributes, - _exact_txn_attributes) - @validate_error_event_attributes( - _required_attributes, _forgone_error_attributes, - _exact_error_attributes) - @validate_attributes('intrinsic', - _required_intrinsics, _forgone_trace_intrinsics) - @background_task(name='test_distributed_trace_attributes') + @validate_transaction_event_attributes(_required_attributes, _forgone_txn_attributes, _exact_txn_attributes) + @validate_error_event_attributes(_required_attributes, _forgone_error_attributes, _exact_error_attributes) + @validate_attributes("intrinsic", _required_intrinsics, _forgone_trace_intrinsics) + @background_task(name="test_distributed_trace_attributes") def _test(): txn = current_transaction() @@ -183,19 +183,19 @@ def _test(): "id": "c86df80de2e6f51c", "tr": "d6b4ba0c3a712ca", "ti": 1518469636035, - 
"tx": "7d3efb1b173fecfa" - } + "tx": "7d3efb1b173fecfa", + }, } - payload['d']['pa'] = "5e5733a911cfbc73" + payload["d"]["pa"] = "5e5733a911cfbc73" if accept_payload: - result = txn.accept_distributed_trace_payload(payload) + result = accept_distributed_trace_payload(payload) assert result else: - txn._create_distributed_trace_payload() + create_distributed_trace_payload() try: - raise ValueError('cookies') + raise ValueError("cookies") except ValueError: txn.notice_error() @@ -203,33 +203,30 @@ def _test(): _forgone_attributes = { - 'agent': [], - 'user': [], - 'intrinsic': (inbound_payload_intrinsics + ['grandparentId']), + "agent": [], + "user": [], + "intrinsic": (inbound_payload_intrinsics + ["grandparentId"]), } @override_application_settings(_override_settings) -@validate_transaction_event_attributes( - {}, _forgone_attributes) -@validate_error_event_attributes( - {}, _forgone_attributes) -@validate_attributes('intrinsic', - {}, _forgone_attributes['intrinsic']) -@background_task(name='test_distributed_trace_attrs_omitted') +@validate_transaction_event_attributes({}, _forgone_attributes) +@validate_error_event_attributes({}, _forgone_attributes) +@validate_attributes("intrinsic", {}, _forgone_attributes["intrinsic"]) +@background_task(name="test_distributed_trace_attrs_omitted") def test_distributed_trace_attrs_omitted(): txn = current_transaction() try: - raise ValueError('cookies') + raise ValueError("cookies") except ValueError: txn.notice_error() # test our distributed_trace metrics by creating a transaction and then forcing # it to process a distributed trace payload -@pytest.mark.parametrize('web_transaction', (True, False)) -@pytest.mark.parametrize('gen_error', (True, False)) -@pytest.mark.parametrize('has_parent', (True, False)) +@pytest.mark.parametrize("web_transaction", (True, False)) +@pytest.mark.parametrize("gen_error", (True, False)) +@pytest.mark.parametrize("has_parent", (True, False)) def test_distributed_tracing_metrics(web_transaction, gen_error, has_parent): def _make_dt_tag(pi): return "%s/%s/%s/%s/all" % tuple(pi[x] for x in parent_order) @@ -237,11 +234,11 @@ def _make_dt_tag(pi): # figure out which metrics we'll see based on the test params # note: we'll always see DurationByCaller if the distributed # tracing flag is turned on - metrics = ['DurationByCaller'] + metrics = ["DurationByCaller"] if gen_error: - metrics.append('ErrorsByCaller') + metrics.append("ErrorsByCaller") if has_parent: - metrics.append('TransportDuration') + metrics.append("TransportDuration") tag = None dt_payload = copy.deepcopy(payload) @@ -251,15 +248,14 @@ def _make_dt_tag(pi): if has_parent: tag = _make_dt_tag(parent_info) else: - tag = _make_dt_tag(dict((x, 'Unknown') for x in parent_info.keys())) - del dt_payload['d']['tr'] + # tag = _make_dt_tag(dict((x, "Unknown") for x in parent_order)) + tag = _make_dt_tag(dict((x, "Unknown") for x in parent_info.keys())) + del dt_payload["d"]["tr"] # now run the test - transaction_name = "test_dt_metrics_%s" % '_'.join(metrics) + transaction_name = "test_dt_metrics_%s" % "_".join(metrics) _rollup_metrics = [ - ("%s/%s%s" % (x, tag, bt), 1) - for x in metrics - for bt in ['', 'Web' if web_transaction else 'Other'] + ("%s/%s%s" % (x, tag, bt), 1) for x in metrics for bt in ["", "Web" if web_transaction else "Other"] ] def _make_test_transaction(): @@ -268,16 +264,15 @@ def _make_test_transaction(): if not web_transaction: return BackgroundTask(application, transaction_name) - environ = {'REQUEST_URI': '/trace_ends_after_txn'} + environ = 
{"REQUEST_URI": "/trace_ends_after_txn"} tn = WSGIWebTransaction(application, environ) tn.set_transaction_name(transaction_name) return tn @override_application_settings(_override_settings) @validate_transaction_metrics( - transaction_name, - background_task=not(web_transaction), - rollup_metrics=_rollup_metrics) + transaction_name, background_task=not (web_transaction), rollup_metrics=_rollup_metrics + ) def _test(): with _make_test_transaction() as transaction: transaction.accept_distributed_trace_payload(dt_payload) @@ -291,62 +286,62 @@ def _test(): _test() -NEW_RELIC_ACCEPTED = \ - [('Supportability/DistributedTrace/AcceptPayload/Success', 1), - ('Supportability/TraceContext/Accept/Success', None), - ('Supportability/TraceContext/TraceParent/Accept/Success', None), - ('Supportability/TraceContext/Accept/Success', None)] -TRACE_CONTEXT_ACCEPTED = \ - [('Supportability/TraceContext/Accept/Success', 1), - ('Supportability/TraceContext/TraceParent/Accept/Success', 1), - ('Supportability/TraceContext/Accept/Success', 1), - ('Supportability/DistributedTrace/AcceptPayload/Success', None)] -NO_HEADERS_ACCEPTED = \ - [('Supportability/DistributedTrace/AcceptPayload/Success', None), - ('Supportability/TraceContext/Accept/Success', None), - ('Supportability/TraceContext/TraceParent/Accept/Success', None), - ('Supportability/TraceContext/Accept/Success', None)] -TRACEPARENT = '00-0af7651916cd43dd8448eb211c80319c-00f067aa0ba902b7-01' -TRACESTATE = 'rojo=f06a0ba902b7,congo=t61rcWkgMzE' - - -@pytest.mark.parametrize('traceparent,tracestate,newrelic,metrics', - [(False, False, False, NO_HEADERS_ACCEPTED), - (False, False, True, NEW_RELIC_ACCEPTED), - (False, True, True, NEW_RELIC_ACCEPTED), - (False, True, False, NO_HEADERS_ACCEPTED), - (True, True, True, TRACE_CONTEXT_ACCEPTED), - (True, False, False, TRACE_CONTEXT_ACCEPTED), - (True, False, True, TRACE_CONTEXT_ACCEPTED), - (True, True, False, TRACE_CONTEXT_ACCEPTED)] - ) +NEW_RELIC_ACCEPTED = [ + ("Supportability/DistributedTrace/AcceptPayload/Success", 1), + ("Supportability/TraceContext/Accept/Success", None), + ("Supportability/TraceContext/TraceParent/Accept/Success", None), + ("Supportability/TraceContext/Accept/Success", None), +] +TRACE_CONTEXT_ACCEPTED = [ + ("Supportability/TraceContext/Accept/Success", 1), + ("Supportability/TraceContext/TraceParent/Accept/Success", 1), + ("Supportability/TraceContext/Accept/Success", 1), + ("Supportability/DistributedTrace/AcceptPayload/Success", None), +] +NO_HEADERS_ACCEPTED = [ + ("Supportability/DistributedTrace/AcceptPayload/Success", None), + ("Supportability/TraceContext/Accept/Success", None), + ("Supportability/TraceContext/TraceParent/Accept/Success", None), + ("Supportability/TraceContext/Accept/Success", None), +] +TRACEPARENT = "00-0af7651916cd43dd8448eb211c80319c-00f067aa0ba902b7-01" +TRACESTATE = "rojo=f06a0ba902b7,congo=t61rcWkgMzE" + + +@pytest.mark.parametrize( + "traceparent,tracestate,newrelic,metrics", + [ + (False, False, False, NO_HEADERS_ACCEPTED), + (False, False, True, NEW_RELIC_ACCEPTED), + (False, True, True, NEW_RELIC_ACCEPTED), + (False, True, False, NO_HEADERS_ACCEPTED), + (True, True, True, TRACE_CONTEXT_ACCEPTED), + (True, False, False, TRACE_CONTEXT_ACCEPTED), + (True, False, True, TRACE_CONTEXT_ACCEPTED), + (True, True, False, TRACE_CONTEXT_ACCEPTED), + ], +) @override_application_settings(_override_settings) -def test_distributed_tracing_backwards_compatibility(traceparent, - tracestate, - newrelic, - metrics): - +def 
test_distributed_tracing_backwards_compatibility(traceparent, tracestate, newrelic, metrics): headers = [] if traceparent: - headers.append(('traceparent', TRACEPARENT)) + headers.append(("traceparent", TRACEPARENT)) if tracestate: - headers.append(('tracestate', TRACESTATE)) + headers.append(("tracestate", TRACESTATE)) if newrelic: - headers.append(('newrelic', json.dumps(payload))) + headers.append(("newrelic", json.dumps(payload))) @validate_transaction_metrics( - "test_distributed_tracing_backwards_compatibility", - background_task=True, - rollup_metrics=metrics) - @background_task(name='test_distributed_tracing_backwards_compatibility') + "test_distributed_tracing_backwards_compatibility", background_task=True, rollup_metrics=metrics + ) + @background_task(name="test_distributed_tracing_backwards_compatibility") def _test(): - transaction = current_transaction() - transaction.accept_distributed_trace_headers(headers) + accept_distributed_trace_headers(headers) _test() -@background_task(name='test_current_trace_id_api_inside_transaction') +@background_task(name="test_current_trace_id_api_inside_transaction") def test_current_trace_id_api_inside_transaction(): trace_id = current_trace_id() assert len(trace_id) == 32 @@ -358,7 +353,7 @@ def test_current_trace_id_api_outside_transaction(): assert trace_id is None -@background_task(name='test_current_span_id_api_inside_transaction') +@background_task(name="test_current_span_id_api_inside_transaction") def test_current_span_id_inside_transaction(): span_id = current_span_id() assert span_id == current_trace().guid @@ -367,3 +362,65 @@ def test_current_span_id_inside_transaction(): def test_current_span_id_outside_transaction(): span_id = current_span_id() assert span_id is None + + +@pytest.mark.parametrize("trusted_account_key", ("1", None), ids=("tk_set", "tk_unset")) +def test_outbound_dt_payload_generation(trusted_account_key): + @override_application_settings( + { + "distributed_tracing.enabled": True, + "account_id": "1", + "trusted_account_key": trusted_account_key, + "primary_application_id": "1", + } + ) + @background_task(name="test_outbound_dt_payload_generation") + def _test_outbound_dt_payload_generation(): + transaction = current_transaction() + payload = ExternalTrace.generate_request_headers(transaction) + if trusted_account_key: + assert payload + # Ensure trusted account key present as vendor + assert dict(payload)["tracestate"].startswith("1@nr=") + else: + assert not payload + + _test_outbound_dt_payload_generation() + + +@pytest.mark.parametrize("trusted_account_key", ("1", None), ids=("tk_set", "tk_unset")) +def test_inbound_dt_payload_acceptance(trusted_account_key): + @override_application_settings( + { + "distributed_tracing.enabled": True, + "account_id": "1", + "trusted_account_key": trusted_account_key, + "primary_application_id": "1", + } + ) + @background_task(name="_test_inbound_dt_payload_acceptance") + def _test_inbound_dt_payload_acceptance(): + transaction = current_transaction() + + payload = { + "v": [0, 1], + "d": { + "ty": "Mobile", + "ac": "1", + "tk": "1", + "ap": "2827902", + "pa": "5e5733a911cfbc73", + "id": "7d3efb1b173fecfa", + "tr": "d6b4ba0c3a712ca", + "ti": 1518469636035, + "tx": "8703ff3d88eefe9d", + }, + } + + result = transaction.accept_distributed_trace_payload(payload) + if trusted_account_key: + assert result + else: + assert not result + + _test_inbound_dt_payload_acceptance() diff --git a/tests/agent_features/test_error_group_callback.py 
b/tests/agent_features/test_error_group_callback.py index 742391162..2fe2fc68c 100644 --- a/tests/agent_features/test_error_group_callback.py +++ b/tests/agent_features/test_error_group_callback.py @@ -12,35 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys import threading import traceback -import sys import pytest - from testing_support.fixtures import ( override_application_settings, reset_core_stats_engine, +) +from testing_support.validators.validate_error_event_attributes import ( validate_error_event_attributes, +) +from testing_support.validators.validate_error_event_attributes_outside_transaction import ( validate_error_event_attributes_outside_transaction, - validate_error_trace_attributes_outside_transaction, ) from testing_support.validators.validate_error_trace_attributes import ( validate_error_trace_attributes, ) +from testing_support.validators.validate_error_trace_attributes_outside_transaction import ( + validate_error_trace_attributes_outside_transaction, +) from newrelic.api.application import application_instance as application from newrelic.api.background_task import background_task +from newrelic.api.settings import set_error_group_callback from newrelic.api.time_trace import notice_error from newrelic.api.transaction import current_transaction -from newrelic.api.settings import set_error_group_callback from newrelic.api.web_transaction import web_transaction from newrelic.common.object_names import callable_name - _callback_called = threading.Event() _truncated_value = "A" * 300 + def error_group_callback(exc, data): _callback_called.set() @@ -64,12 +69,9 @@ def test_clear_error_group_callback(): assert settings.error_collector.error_group_callback is None, "Failed to clear callback." 
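As context for the parametrized tests that follow: the callback registered via set_error_group_callback receives the raised exception and a data dict (transactionName, custom_params, and request/response fields, per the assertions below) and returns a group-name string. A minimal sketch under those assumptions, with an illustrative callback name:

    from newrelic.api.settings import set_error_group_callback

    def my_error_group_callback(exc, data):
        # Return a non-empty string to set error.group.name on the error;
        # other return values leave the error ungrouped.
        if isinstance(exc, ValueError):
            return "value"

    set_error_group_callback(my_error_group_callback)  # register
    set_error_group_callback(None)  # clear, as the tests' cleanup does
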
-@pytest.mark.parametrize("callback,accepted", [ - (error_group_callback, True), - (lambda x, y: None, True), - (None, False), - ("string", False) -]) +@pytest.mark.parametrize( + "callback,accepted", [(error_group_callback, True), (lambda x, y: None, True), (None, False), ("string", False)] +) def test_set_error_group_callback(callback, accepted): try: set_error_group_callback(callback) @@ -82,15 +84,19 @@ def test_set_error_group_callback(callback, accepted): set_error_group_callback(None) -@pytest.mark.parametrize("exc_class,group_name,high_security", [ - (ValueError, "value", False), - (ValueError, "value", True), - (TypeError, None, False), - (RuntimeError, None, False), - (IndexError, None, False), - (LookupError, None, False), - (ZeroDivisionError, _truncated_value[:255], False), -], ids=("standard", "high-security", "empty-string", "None-value", "list-type", "int-type", "truncated-value")) +@pytest.mark.parametrize( + "exc_class,group_name,high_security", + [ + (ValueError, "value", False), + (ValueError, "value", True), + (TypeError, None, False), + (RuntimeError, None, False), + (IndexError, None, False), + (LookupError, None, False), + (ZeroDivisionError, _truncated_value[:255], False), + ], + ids=("standard", "high-security", "empty-string", "None-value", "list-type", "int-type", "truncated-value"), +) @reset_core_stats_engine() def test_error_group_name_callback(exc_class, group_name, high_security): _callback_called.clear() @@ -102,9 +108,7 @@ def test_error_group_name_callback(exc_class, group_name, high_security): exact = None forgone = {"user": [], "intrinsic": [], "agent": ["error.group.name"]} - @validate_error_trace_attributes( - callable_name(exc_class), forgone_params=forgone, exact_attrs=exact - ) + @validate_error_trace_attributes(callable_name(exc_class), forgone_params=forgone, exact_attrs=exact) @validate_error_event_attributes(forgone_params=forgone, exact_attrs=exact) @override_application_settings({"high_security": high_security}) @background_task() @@ -124,15 +128,19 @@ def _test(): set_error_group_callback(None) -@pytest.mark.parametrize("exc_class,group_name,high_security", [ - (ValueError, "value", False), - (ValueError, "value", True), - (TypeError, None, False), - (RuntimeError, None, False), - (IndexError, None, False), - (LookupError, None, False), - (ZeroDivisionError, _truncated_value[:255], False), -], ids=("standard", "high-security", "empty-string", "None-value", "list-type", "int-type", "truncated-value")) +@pytest.mark.parametrize( + "exc_class,group_name,high_security", + [ + (ValueError, "value", False), + (ValueError, "value", True), + (TypeError, None, False), + (RuntimeError, None, False), + (IndexError, None, False), + (LookupError, None, False), + (ZeroDivisionError, _truncated_value[:255], False), + ], + ids=("standard", "high-security", "empty-string", "None-value", "list-type", "int-type", "truncated-value"), +) @reset_core_stats_engine() def test_error_group_name_callback_outside_transaction(exc_class, group_name, high_security): _callback_called.clear() @@ -155,7 +163,7 @@ def _test(): except Exception: app = application() notice_error(application=app) - + assert _callback_called.is_set() try: @@ -165,11 +173,22 @@ def _test(): set_error_group_callback(None) -@pytest.mark.parametrize("transaction_decorator", [ - background_task(name="TestBackgroundTask"), - web_transaction(name="TestWebTransaction", host="localhost", port=1234, request_method="GET", request_path="/", headers=[],), - None, -], ids=("background_task", "web_transation", 
"outside_transaction")) +@pytest.mark.parametrize( + "transaction_decorator", + [ + background_task(name="TestBackgroundTask"), + web_transaction( + name="TestWebTransaction", + host="localhost", + port=1234, + request_method="GET", + request_path="/", + headers=[], + ), + None, + ], + ids=("background_task", "web_transation", "outside_transaction"), +) @reset_core_stats_engine() def test_error_group_name_callback_attributes(transaction_decorator): callback_errors = [] @@ -178,6 +197,7 @@ def test_error_group_name_callback_attributes(transaction_decorator): def callback(error, data): def _callback(): import types + _data.append(data) txn = current_transaction() @@ -191,23 +211,23 @@ def _callback(): # All attributes should always be included, but set to None when not relevant. if txn is None: # Outside transaction assert data["transactionName"] is None - assert data["custom_params"] == {'notice_error_attribute': 1} + assert data["custom_params"] == {"notice_error_attribute": 1} assert data["response.status"] is None assert data["request.method"] is None assert data["request.uri"] is None elif txn.background_task: # Background task assert data["transactionName"] == "TestBackgroundTask" - assert data["custom_params"] == {'notice_error_attribute': 1, 'txn_attribute': 2} + assert data["custom_params"] == {"notice_error_attribute": 1, "txn_attribute": 2} assert data["response.status"] is None assert data["request.method"] is None assert data["request.uri"] is None else: # Web transaction assert data["transactionName"] == "TestWebTransaction" - assert data["custom_params"] == {'notice_error_attribute': 1, 'txn_attribute': 2} + assert data["custom_params"] == {"notice_error_attribute": 1, "txn_attribute": 2} assert data["response.status"] == 200 assert data["request.method"] == "GET" assert data["request.uri"] == "/" - + try: _callback() except Exception: @@ -225,8 +245,11 @@ def _test(): except Exception: app = application() if transaction_decorator is None else None # Only set outside transaction notice_error(application=app, attributes={"notice_error_attribute": 1}) - - assert not callback_errors, "Callback inputs failed to validate.\nerror: %s\ndata: %s" % (traceback.format_exception(*callback_errors[0]), str(_data[0])) + + assert not callback_errors, "Callback inputs failed to validate.\nerror: %s\ndata: %s" % ( + traceback.format_exception(*callback_errors[0]), + str(_data[0]), + ) if transaction_decorator is not None: _test = transaction_decorator(_test) # Manually decorate test function diff --git a/tests/agent_features/test_high_security_mode.py b/tests/agent_features/test_high_security_mode.py index dad7edc29..d2ded9308 100644 --- a/tests/agent_features/test_high_security_mode.py +++ b/tests/agent_features/test_high_security_mode.py @@ -25,7 +25,6 @@ validate_custom_event_count, validate_custom_event_in_application_stats_engine, validate_request_params_omitted, - validate_tt_segment_params, ) from testing_support.validators.validate_custom_parameters import ( validate_custom_parameters, @@ -36,6 +35,9 @@ from testing_support.validators.validate_transaction_errors import ( validate_transaction_errors, ) +from testing_support.validators.validate_tt_segment_params import ( + validate_tt_segment_params, +) from newrelic.api.application import application_instance as application from newrelic.api.background_task import background_task @@ -77,8 +79,10 @@ def test_hsm_configuration_default(): "transaction_tracer.record_sql": "raw", "strip_exception_messages.enabled": False, 
"custom_insights_events.enabled": True, + "ml_insights_events.enabled": True, "message_tracer.segment_parameters_enabled": True, "application_logging.forwarding.enabled": True, + "machine_learning.inference_events_value.enabled": True, }, { "high_security": False, @@ -86,8 +90,10 @@ def test_hsm_configuration_default(): "transaction_tracer.record_sql": "raw", "strip_exception_messages.enabled": False, "custom_insights_events.enabled": False, + "ml_insights_events.enabled": False, "message_tracer.segment_parameters_enabled": True, "application_logging.forwarding.enabled": True, + "machine_learning.inference_events_value.enabled": True, }, { "high_security": False, @@ -95,8 +101,10 @@ def test_hsm_configuration_default(): "transaction_tracer.record_sql": "obfuscated", "strip_exception_messages.enabled": True, "custom_insights_events.enabled": True, + "ml_insights_events.enabled": True, "message_tracer.segment_parameters_enabled": False, "application_logging.forwarding.enabled": False, + "machine_learning.inference_events_value.enabled": False, }, { "high_security": False, @@ -104,8 +112,10 @@ def test_hsm_configuration_default(): "transaction_tracer.record_sql": "off", "strip_exception_messages.enabled": True, "custom_insights_events.enabled": False, + "ml_insights_events.enabled": False, "message_tracer.segment_parameters_enabled": False, "application_logging.forwarding.enabled": False, + "machine_learning.inference_events_value.enabled": False, }, ] @@ -116,8 +126,10 @@ def test_hsm_configuration_default(): "transaction_tracer.record_sql": "raw", "strip_exception_messages.enabled": True, "custom_insights_events.enabled": True, + "ml_insights_events.enabled": True, "message_tracer.segment_parameters_enabled": True, "application_logging.forwarding.enabled": False, + "machine_learning.inference_events_value.enabled": False, }, { "high_security": True, @@ -125,8 +137,10 @@ def test_hsm_configuration_default(): "transaction_tracer.record_sql": "raw", "strip_exception_messages.enabled": True, "custom_insights_events.enabled": True, + "ml_insights_events.enabled": True, "message_tracer.segment_parameters_enabled": True, "application_logging.forwarding.enabled": False, + "machine_learning.inference_events_value.enabled": False, }, { "high_security": True, @@ -134,8 +148,10 @@ def test_hsm_configuration_default(): "transaction_tracer.record_sql": "raw", "strip_exception_messages.enabled": True, "custom_insights_events.enabled": True, + "ml_insights_events.enabled": True, "message_tracer.segment_parameters_enabled": True, "application_logging.forwarding.enabled": False, + "machine_learning.inference_events_value.enabled": False, }, { "high_security": True, @@ -143,8 +159,10 @@ def test_hsm_configuration_default(): "transaction_tracer.record_sql": "raw", "strip_exception_messages.enabled": True, "custom_insights_events.enabled": True, + "ml_insights_events.enabled": True, "message_tracer.segment_parameters_enabled": True, "application_logging.forwarding.enabled": True, + "machine_learning.inference_events_value.enabled": True, }, { "high_security": True, @@ -152,8 +170,10 @@ def test_hsm_configuration_default(): "transaction_tracer.record_sql": "obfuscated", "strip_exception_messages.enabled": True, "custom_insights_events.enabled": True, + "ml_insights_events.enabled": True, "message_tracer.segment_parameters_enabled": True, "application_logging.forwarding.enabled": True, + "machine_learning.inference_events_value.enabled": True, }, { "high_security": True, @@ -161,8 +181,10 @@ def 
test_hsm_configuration_default(): "transaction_tracer.record_sql": "off", "strip_exception_messages.enabled": True, "custom_insights_events.enabled": True, + "ml_insights_events.enabled": True, "message_tracer.segment_parameters_enabled": False, "application_logging.forwarding.enabled": True, + "machine_learning.inference_events_value.enabled": True, }, { "high_security": True, @@ -170,8 +192,10 @@ def test_hsm_configuration_default(): "transaction_tracer.record_sql": "raw", "strip_exception_messages.enabled": False, "custom_insights_events.enabled": False, + "ml_insights_events.enabled": False, "message_tracer.segment_parameters_enabled": False, "application_logging.forwarding.enabled": True, + "machine_learning.inference_events_value.enabled": True, }, ] @@ -194,8 +218,10 @@ def test_local_config_file_override_hsm_disabled(settings): original_record_sql = settings.transaction_tracer.record_sql original_strip_messages = settings.strip_exception_messages.enabled original_custom_events = settings.custom_insights_events.enabled + original_ml_events = settings.ml_insights_events.enabled original_message_segment_params_enabled = settings.message_tracer.segment_parameters_enabled original_application_logging_forwarding_enabled = settings.application_logging.forwarding.enabled + original_machine_learning_inference_event_value_enabled = settings.machine_learning.inference_events_value.enabled apply_local_high_security_mode_setting(settings) @@ -203,8 +229,13 @@ def test_local_config_file_override_hsm_disabled(settings): assert settings.transaction_tracer.record_sql == original_record_sql assert settings.strip_exception_messages.enabled == original_strip_messages assert settings.custom_insights_events.enabled == original_custom_events + assert settings.ml_insights_events.enabled == original_ml_events assert settings.message_tracer.segment_parameters_enabled == original_message_segment_params_enabled assert settings.application_logging.forwarding.enabled == original_application_logging_forwarding_enabled + assert ( + settings.machine_learning.inference_events_value.enabled + == original_machine_learning_inference_event_value_enabled + ) @parameterize_hsm_local_config(_hsm_local_config_file_settings_enabled) @@ -215,8 +246,10 @@ def test_local_config_file_override_hsm_enabled(settings): assert settings.transaction_tracer.record_sql in ("off", "obfuscated") assert settings.strip_exception_messages.enabled assert settings.custom_insights_events.enabled is False + assert settings.ml_insights_events.enabled is False assert settings.message_tracer.segment_parameters_enabled is False assert settings.application_logging.forwarding.enabled is False + assert settings.machine_learning.inference_events_value.enabled is False _server_side_config_settings_hsm_disabled = [ @@ -227,7 +260,9 @@ def test_local_config_file_override_hsm_enabled(settings): "transaction_tracer.record_sql": "obfuscated", "strip_exception_messages.enabled": True, "custom_insights_events.enabled": False, + "ml_insights_events.enabled": False, "application_logging.forwarding.enabled": False, + "machine_learning.inference_events_value.enabled": False, }, { "agent_config": { @@ -235,7 +270,9 @@ def test_local_config_file_override_hsm_enabled(settings): "transaction_tracer.record_sql": "raw", "strip_exception_messages.enabled": False, "custom_insights_events.enabled": True, + "ml_insights_events.enabled": True, "application_logging.forwarding.enabled": True, + "machine_learning.inference_events_value.enabled": True, }, }, ), @@ -246,7 +283,9 
@@ def test_local_config_file_override_hsm_enabled(settings): "transaction_tracer.record_sql": "raw", "strip_exception_messages.enabled": False, "custom_insights_events.enabled": True, + "ml_insights_events.enabled": True, "application_logging.forwarding.enabled": True, + "machine_learning.inference_events_value.enabled": True, }, { "agent_config": { @@ -254,7 +293,9 @@ def test_local_config_file_override_hsm_enabled(settings): "transaction_tracer.record_sql": "off", "strip_exception_messages.enabled": True, "custom_insights_events.enabled": False, + "ml_insights_events.enabled": False, "application_logging.forwarding.enabled": False, + "machine_learning.inference_events_value.enabled": False, }, }, ), @@ -268,7 +309,9 @@ def test_local_config_file_override_hsm_enabled(settings): "transaction_tracer.record_sql": "obfuscated", "strip_exception_messages.enabled": True, "custom_insights_events.enabled": False, + "ml_insights_events.enabled": False, "application_logging.forwarding.enabled": False, + "machine_learning.inference_events_value.enabled": False, }, { "high_security": True, @@ -276,13 +319,17 @@ def test_local_config_file_override_hsm_enabled(settings): "transaction_tracer.record_sql": "obfuscated", "strip_exception_messages.enabled": True, "custom_insights_events.enabled": False, + "ml_insights_events.enabled": False, "application_logging.forwarding.enabled": False, + "machine_learning.inference_events_value.enabled": False, "agent_config": { "capture_params": False, "transaction_tracer.record_sql": "obfuscated", "strip_exception_messages.enabled": True, "custom_insights_events.enabled": False, + "ml_insights_events.enabled": False, "application_logging.forwarding.enabled": False, + "machine_learning.inference_events_value.enabled": False, }, }, ), @@ -293,7 +340,9 @@ def test_local_config_file_override_hsm_enabled(settings): "transaction_tracer.record_sql": "obfuscated", "strip_exception_messages.enabled": True, "custom_insights_events.enabled": False, + "ml_insights_events.enabled": False, "application_logging.forwarding.enabled": False, + "machine_learning.inference_events_value.enabled": False, }, { "high_security": True, @@ -301,13 +350,17 @@ def test_local_config_file_override_hsm_enabled(settings): "transaction_tracer.record_sql": "obfuscated", "strip_exception_messages.enabled": True, "custom_insights_events.enabled": False, + "ml_insights_events.enabled": False, "application_logging.forwarding.enabled": False, + "machine_learning.inference_events_value.enabled": False, "agent_config": { "capture_params": True, "transaction_tracer.record_sql": "raw", "strip_exception_messages.enabled": False, "custom_insights_events.enabled": True, + "ml_insights_events.enabled": True, "application_logging.forwarding.enabled": True, + "machine_learning.inference_events_value.enabled": True, }, }, ), @@ -327,7 +380,9 @@ def test_remote_config_fixups_hsm_disabled(local_settings, server_settings): original_record_sql = agent_config["transaction_tracer.record_sql"] original_strip_messages = agent_config["strip_exception_messages.enabled"] original_custom_events = agent_config["custom_insights_events.enabled"] + original_ml_events = agent_config["ml_insights_events.enabled"] original_log_forwarding = agent_config["application_logging.forwarding.enabled"] + original_machine_learning_events = agent_config["machine_learning.inference_events_value.enabled"] _settings = global_settings() settings = override_generic_settings(_settings, local_settings)(AgentProtocol._apply_high_security_mode_fixups)( @@ 
-342,7 +397,9 @@ def test_remote_config_fixups_hsm_disabled(local_settings, server_settings): assert agent_config["transaction_tracer.record_sql"] == original_record_sql assert agent_config["strip_exception_messages.enabled"] == original_strip_messages assert agent_config["custom_insights_events.enabled"] == original_custom_events + assert agent_config["ml_insights_events.enabled"] == original_ml_events assert agent_config["application_logging.forwarding.enabled"] == original_log_forwarding + assert agent_config["machine_learning.inference_events_value.enabled"] == original_machine_learning_events @pytest.mark.parametrize("local_settings,server_settings", _server_side_config_settings_hsm_enabled) @@ -364,13 +421,17 @@ def test_remote_config_fixups_hsm_enabled(local_settings, server_settings): assert "transaction_tracer.record_sql" not in settings assert "strip_exception_messages.enabled" not in settings assert "custom_insights_events.enabled" not in settings + assert "ml_insights_events.enabled" not in settings assert "application_logging.forwarding.enabled" not in settings + assert "machine_learning.inference_events_value.enabled" not in settings assert "capture_params" not in agent_config assert "transaction_tracer.record_sql" not in agent_config assert "strip_exception_messages.enabled" not in agent_config assert "custom_insights_events.enabled" not in agent_config + assert "ml_insights_events.enabled" not in agent_config assert "application_logging.forwarding.enabled" not in agent_config + assert "machine_learning.inference_events_value.enabled" not in agent_config def test_remote_config_hsm_fixups_server_side_disabled(): @@ -395,6 +456,7 @@ def test_remote_config_hsm_fixups_server_side_disabled(): "high_security": True, "strip_exception_messages.enabled": True, "custom_insights_events.enabled": False, + "ml_insights_events.enabled": False, } diff --git a/tests/agent_features/test_ignore_expected_errors.py b/tests/agent_features/test_ignore_expected_errors.py index 93595aa35..ee26245c5 100644 --- a/tests/agent_features/test_ignore_expected_errors.py +++ b/tests/agent_features/test_ignore_expected_errors.py @@ -16,8 +16,12 @@ from testing_support.fixtures import ( override_application_settings, reset_core_stats_engine, - validate_error_event_attributes_outside_transaction, validate_error_event_sample_data, +) +from testing_support.validators.validate_error_event_attributes_outside_transaction import ( + validate_error_event_attributes_outside_transaction, +) +from testing_support.validators.validate_error_trace_attributes_outside_transaction import ( validate_error_trace_attributes_outside_transaction, ) from testing_support.validators.validate_time_metrics_outside_transaction import ( diff --git a/tests/agent_features/test_metric_normalization.py b/tests/agent_features/test_metric_normalization.py new file mode 100644 index 000000000..65f2903ae --- /dev/null +++ b/tests/agent_features/test_metric_normalization.py @@ -0,0 +1,78 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.fixtures import reset_core_stats_engine +from testing_support.validators.validate_dimensional_metric_payload import ( + validate_dimensional_metric_payload, +) +from testing_support.validators.validate_metric_payload import validate_metric_payload + +from newrelic.api.application import application_instance +from newrelic.api.background_task import background_task +from newrelic.api.transaction import record_custom_metric, record_dimensional_metric +from newrelic.core.rules_engine import NormalizationRule, RulesEngine + +RULES = [{"match_expression": "(replace)", "replacement": "expected", "ignore": False, "eval_order": 0}] +EXPECTED_TAGS = frozenset({"tag": 1}.items()) + + +def _prepare_rules(test_rules): + # ensure all keys are present, if not present set to an empty string + for rule in test_rules: + for key in NormalizationRule._fields: + rule[key] = rule.get(key, "") + return test_rules + + +@pytest.fixture(scope="session") +def core_app(collector_agent_registration): + app = collector_agent_registration + return app._agent.application(app.name) + + +@pytest.fixture(scope="function") +def rules_engine_fixture(core_app): + rules_engine = core_app._rules_engine + previous_rules = rules_engine["metric"] + + rules_engine["metric"] = RulesEngine(_prepare_rules(RULES)) + yield + rules_engine["metric"] = previous_rules # Restore after test run + + +@validate_dimensional_metric_payload(summary_metrics=[("Metric/expected", EXPECTED_TAGS, 1)]) +@validate_metric_payload([("Metric/expected", 1)]) +@reset_core_stats_engine() +def test_metric_normalization_inside_transaction(core_app, rules_engine_fixture): + @background_task(name="test_record_dimensional_metric_inside_transaction") + def _test(): + record_dimensional_metric("Metric/replace", 1, tags={"tag": 1}) + record_custom_metric("Metric/replace", 1) + + _test() + core_app.harvest() + + +@validate_dimensional_metric_payload(summary_metrics=[("Metric/expected", EXPECTED_TAGS, 1)]) +@validate_metric_payload([("Metric/expected", 1)]) +@reset_core_stats_engine() +def test_metric_normalization_outside_transaction(core_app, rules_engine_fixture): + def _test(): + app = application_instance() + record_dimensional_metric("Metric/replace", 1, tags={"tag": 1}, application=app) + record_custom_metric("Metric/replace", 1, application=app) + + _test() + core_app.harvest() diff --git a/tests/agent_features/test_ml_events.py b/tests/agent_features/test_ml_events.py new file mode 100644 index 000000000..5720224bb --- /dev/null +++ b/tests/agent_features/test_ml_events.py @@ -0,0 +1,199 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
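The new test_ml_events.py module below exercises the record_ml_event API. For orientation, a sketch of the two call shapes the tests use (assuming an initialized agent; not part of the patch):

    from newrelic.api.application import application_instance
    from newrelic.api.transaction import record_ml_event

    # Inside a transaction, the event attaches to the current transaction.
    record_ml_event("InferenceEvent", {"foo": "bar"})

    # Outside a transaction, an explicit application object is required.
    record_ml_event("InferenceEvent", {"foo": "bar"}, application=application_instance())
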
+
+import time
+
+import pytest
+from testing_support.fixtures import (
+    function_not_called,
+    override_application_settings,
+    reset_core_stats_engine,
+)
+from testing_support.validators.validate_ml_event_count import validate_ml_event_count
+from testing_support.validators.validate_ml_event_payload import (
+    validate_ml_event_payload,
+)
+from testing_support.validators.validate_ml_events import validate_ml_events
+from testing_support.validators.validate_ml_events_outside_transaction import (
+    validate_ml_events_outside_transaction,
+)
+
+import newrelic.core.otlp_utils
+from newrelic.api.application import application_instance as application
+from newrelic.api.background_task import background_task
+from newrelic.api.transaction import record_ml_event
+from newrelic.core.config import global_settings
+from newrelic.packages import six
+
+try:
+    # python 2.x
+    reload
+except NameError:
+    # python 3.x
+    from importlib import reload
+
+_now = time.time()
+
+_intrinsics = {
+    "type": "LabelEvent",
+    "timestamp": _now,
+}
+
+
+@pytest.fixture(scope="session")
+def core_app(collector_agent_registration):
+    app = collector_agent_registration
+    return app._agent.application(app.name)
+
+
+@validate_ml_event_payload(
+    [{"foo": "bar", "real_agent_id": "1234567", "event.domain": "newrelic.ml_events", "event.name": "InferenceEvent"}]
+)
+@reset_core_stats_engine()
+def test_ml_event_payload_inside_transaction(core_app):
+    @background_task(name="test_ml_event_payload_inside_transaction")
+    def _test():
+        record_ml_event("InferenceEvent", {"foo": "bar"})
+
+    _test()
+    core_app.harvest()
+
+
+@validate_ml_event_payload(
+    [{"foo": "bar", "real_agent_id": "1234567", "event.domain": "newrelic.ml_events", "event.name": "InferenceEvent"}]
+)
+@reset_core_stats_engine()
+def test_ml_event_payload_outside_transaction(core_app):
+    def _test():
+        app = application()
+        record_ml_event("InferenceEvent", {"foo": "bar"}, application=app)
+
+    _test()
+    core_app.harvest()
+
+
+@pytest.mark.parametrize(
+    "params,expected",
+    [
+        ({"foo": "bar"}, [(_intrinsics, {"foo": "bar"})]),
+        ({"foo": "bar", 123: "bad key"}, [(_intrinsics, {"foo": "bar"})]),
+        ({"foo": "bar", "*" * 256: "too long"}, [(_intrinsics, {"foo": "bar"})]),
+    ],
+    ids=["Valid key/value", "Bad key", "Value too long"],
+)
+@reset_core_stats_engine()
+def test_record_ml_event_inside_transaction(params, expected):
+    @validate_ml_events(expected)
+    @background_task()
+    def _test():
+        record_ml_event("LabelEvent", params)
+
+    _test()
+
+
+@pytest.mark.parametrize(
+    "params,expected",
+    [
+        ({"foo": "bar"}, [(_intrinsics, {"foo": "bar"})]),
+        ({"foo": "bar", 123: "bad key"}, [(_intrinsics, {"foo": "bar"})]),
+        ({"foo": "bar", "*" * 256: "too long"}, [(_intrinsics, {"foo": "bar"})]),
+    ],
+    ids=["Valid key/value", "Bad key", "Value too long"],
+)
+@reset_core_stats_engine()
+def test_record_ml_event_outside_transaction(params, expected):
+    @validate_ml_events_outside_transaction(expected)
+    def _test():
+        app = application()
+        record_ml_event("LabelEvent", params, application=app)
+
+    _test()
+
+
+@reset_core_stats_engine()
+@validate_ml_event_count(count=0)
+@background_task()
+def test_record_ml_event_inside_transaction_bad_event_type():
+    record_ml_event("!@#$%^&*()", {"foo": "bar"})
+
+
+@reset_core_stats_engine()
+@validate_ml_event_count(count=0)
+def test_record_ml_event_outside_transaction_bad_event_type():
+    app = application()
+    record_ml_event("!@#$%^&*()", {"foo": "bar"},
application=app) + + +@reset_core_stats_engine() +@validate_ml_event_count(count=0) +@background_task() +def test_record_ml_event_inside_transaction_params_not_a_dict(): + record_ml_event("ParamsListEvent", ["not", "a", "dict"]) + + +@reset_core_stats_engine() +@validate_ml_event_count(count=0) +def test_record_ml_event_outside_transaction_params_not_a_dict(): + app = application() + record_ml_event("ParamsListEvent", ["not", "a", "dict"], application=app) + + +# Tests for ML Events configuration settings + +@override_application_settings({"ml_insights_events.enabled": False}) +@reset_core_stats_engine() +@validate_ml_event_count(count=0) +@background_task() +def test_ml_event_settings_check_ml_insights_disabled(): + record_ml_event("FooEvent", {"foo": "bar"}) + + +# Test that record_ml_event() methods will short-circuit. +# +# If the ml_insights_events setting is False, verify that the +# `create_ml_event()` function is not called, in order to avoid the +# event_type and attribute processing. + + +@override_application_settings({"ml_insights_events.enabled": False}) +@reset_core_stats_engine() +@function_not_called("newrelic.api.transaction", "create_custom_event") +@background_task() +def test_transaction_create_ml_event_not_called(): + record_ml_event("FooEvent", {"foo": "bar"}) + + +@override_application_settings({"ml_insights_events.enabled": False}) +@reset_core_stats_engine() +@function_not_called("newrelic.core.application", "create_custom_event") +@background_task() +def test_application_create_ml_event_not_called(): + app = application() + record_ml_event("FooEvent", {"foo": "bar"}, application=app) + + +@pytest.fixture(scope="module", autouse=True, params=["protobuf", "json"]) +def otlp_content_encoding(request): + if six.PY2 and request.param == "protobuf": + pytest.skip("OTLP protos are not compatible with Python 2.") + + _settings = global_settings() + prev = _settings.debug.otlp_content_encoding + _settings.debug.otlp_content_encoding = request.param + reload(newrelic.core.otlp_utils) + assert newrelic.core.otlp_utils.otlp_content_setting == request.param, "Content encoding mismatch." + + yield + + _settings.debug.otlp_content_encoding = prev diff --git a/tests/agent_features/test_profile_trace.py b/tests/agent_features/test_profile_trace.py new file mode 100644 index 000000000..f696b7480 --- /dev/null +++ b/tests/agent_features/test_profile_trace.py @@ -0,0 +1,88 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task +from newrelic.api.profile_trace import ProfileTraceWrapper, profile_trace + + +def test_profile_trace_wrapper(): + def _test(): + def nested_fn(): + pass + + nested_fn() + + wrapped_test = ProfileTraceWrapper(_test) + wrapped_test() + + +@validate_transaction_metrics("test_profile_trace:test_profile_trace_empty_args", background_task=True) +@background_task() +def test_profile_trace_empty_args(): + @profile_trace() + def _test(): + pass + + _test() + + +_test_profile_trace_defined_args_scoped_metrics = [("Custom/TestTrace", 1)] + + +@validate_transaction_metrics( + "test_profile_trace:test_profile_trace_defined_args", + scoped_metrics=_test_profile_trace_defined_args_scoped_metrics, + background_task=True, +) +@background_task() +def test_profile_trace_defined_args(): + @profile_trace(name="TestTrace", group="Custom", label="Label", params={"key": "value"}, depth=7) + def _test(): + pass + + _test() + + +_test_profile_trace_callable_args_scoped_metrics = [("Function/TestProfileTrace", 1)] + + +@validate_transaction_metrics( + "test_profile_trace:test_profile_trace_callable_args", + scoped_metrics=_test_profile_trace_callable_args_scoped_metrics, + background_task=True, +) +@background_task() +def test_profile_trace_callable_args(): + def name_callable(): + return "TestProfileTrace" + + def group_callable(): + return "Function" + + def label_callable(): + return "HSM" + + def params_callable(): + return {"account_id": "12345"} + + @profile_trace(name=name_callable, group=group_callable, label=label_callable, params=params_callable, depth=0) + def _test(): + pass + + _test() diff --git a/tests/agent_features/test_serverless_mode.py b/tests/agent_features/test_serverless_mode.py index 75b5f0075..189481f70 100644 --- a/tests/agent_features/test_serverless_mode.py +++ b/tests/agent_features/test_serverless_mode.py @@ -13,7 +13,16 @@ # limitations under the License. 
import json + import pytest +from testing_support.fixtures import override_generic_settings +from testing_support.validators.validate_serverless_data import validate_serverless_data +from testing_support.validators.validate_serverless_metadata import ( + validate_serverless_metadata, +) +from testing_support.validators.validate_serverless_payload import ( + validate_serverless_payload, +) from newrelic.api.application import application_instance from newrelic.api.background_task import background_task @@ -22,23 +31,14 @@ from newrelic.api.transaction import current_transaction from newrelic.core.config import global_settings -from testing_support.fixtures import override_generic_settings -from testing_support.validators.validate_serverless_data import ( - validate_serverless_data) -from testing_support.validators.validate_serverless_payload import ( - validate_serverless_payload) -from testing_support.validators.validate_serverless_metadata import ( - validate_serverless_metadata) - -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def serverless_application(request): settings = global_settings() orig = settings.serverless_mode.enabled settings.serverless_mode.enabled = True - application_name = 'Python Agent Test (test_serverless_mode:%s)' % ( - request.node.name) + application_name = "Python Agent Test (test_serverless_mode:%s)" % (request.node.name) application = application_instance(application_name) application.activate() @@ -48,17 +48,18 @@ def serverless_application(request): def test_serverless_payload(capsys, serverless_application): - - @override_generic_settings(serverless_application.settings, { - 'distributed_tracing.enabled': True, - }) + @override_generic_settings( + serverless_application.settings, + { + "distributed_tracing.enabled": True, + }, + ) @validate_serverless_data( - expected_methods=('metric_data', 'analytic_event_data'), - forgone_methods=('preconnect', 'connect', 'get_agent_commands')) + expected_methods=("metric_data", "analytic_event_data"), + forgone_methods=("preconnect", "connect", "get_agent_commands"), + ) @validate_serverless_payload() - @background_task( - application=serverless_application, - name='test_serverless_payload') + @background_task(application=serverless_application, name="test_serverless_payload") def _test(): transaction = current_transaction() assert transaction.settings.serverless_mode.enabled @@ -75,17 +76,15 @@ def _test(): def test_no_cat_headers(serverless_application): - @background_task( - application=serverless_application, - name='test_cat_headers') + @background_task(application=serverless_application, name="test_cat_headers") def _test_cat_headers(): transaction = current_transaction() payload = ExternalTrace.generate_request_headers(transaction) assert not payload - trace = ExternalTrace('testlib', 'http://example.com') - response_headers = [('X-NewRelic-App-Data', 'Cookies')] + trace = ExternalTrace("testlib", "http://example.com") + response_headers = [("X-NewRelic-App-Data", "Cookies")] with trace: trace.process_response_headers(response_headers) @@ -94,61 +93,66 @@ def _test_cat_headers(): _test_cat_headers() -def test_dt_outbound(serverless_application): - @override_generic_settings(serverless_application.settings, { - 'distributed_tracing.enabled': True, - 'account_id': '1', - 'trusted_account_key': '1', - 'primary_application_id': '1', - }) - @background_task( - application=serverless_application, - name='test_dt_outbound') - def _test_dt_outbound(): +@pytest.mark.parametrize("trusted_account_key", 
("1", None), ids=("tk_set", "tk_unset")) +def test_outbound_dt_payload_generation(serverless_application, trusted_account_key): + @override_generic_settings( + serverless_application.settings, + { + "distributed_tracing.enabled": True, + "account_id": "1", + "trusted_account_key": trusted_account_key, + "primary_application_id": "1", + }, + ) + @background_task(application=serverless_application, name="test_outbound_dt_payload_generation") + def _test_outbound_dt_payload_generation(): transaction = current_transaction() payload = ExternalTrace.generate_request_headers(transaction) assert payload - - _test_dt_outbound() - - -def test_dt_inbound(serverless_application): - @override_generic_settings(serverless_application.settings, { - 'distributed_tracing.enabled': True, - 'account_id': '1', - 'trusted_account_key': '1', - 'primary_application_id': '1', - }) - @background_task( - application=serverless_application, - name='test_dt_inbound') - def _test_dt_inbound(): + # Ensure trusted account key or account ID present as vendor + assert dict(payload)["tracestate"].startswith("1@nr=") + + _test_outbound_dt_payload_generation() + + +@pytest.mark.parametrize("trusted_account_key", ("1", None), ids=("tk_set", "tk_unset")) +def test_inbound_dt_payload_acceptance(serverless_application, trusted_account_key): + @override_generic_settings( + serverless_application.settings, + { + "distributed_tracing.enabled": True, + "account_id": "1", + "trusted_account_key": trusted_account_key, + "primary_application_id": "1", + }, + ) + @background_task(application=serverless_application, name="test_inbound_dt_payload_acceptance") + def _test_inbound_dt_payload_acceptance(): transaction = current_transaction() payload = { - 'v': [0, 1], - 'd': { - 'ty': 'Mobile', - 'ac': '1', - 'tk': '1', - 'ap': '2827902', - 'pa': '5e5733a911cfbc73', - 'id': '7d3efb1b173fecfa', - 'tr': 'd6b4ba0c3a712ca', - 'ti': 1518469636035, - 'tx': '8703ff3d88eefe9d', - } + "v": [0, 1], + "d": { + "ty": "Mobile", + "ac": "1", + "tk": "1", + "ap": "2827902", + "pa": "5e5733a911cfbc73", + "id": "7d3efb1b173fecfa", + "tr": "d6b4ba0c3a712ca", + "ti": 1518469636035, + "tx": "8703ff3d88eefe9d", + }, } result = transaction.accept_distributed_trace_payload(payload) assert result - _test_dt_inbound() + _test_inbound_dt_payload_acceptance() -@pytest.mark.parametrize('arn_set', (True, False)) +@pytest.mark.parametrize("arn_set", (True, False)) def test_payload_metadata_arn(serverless_application, arn_set): - # If the session object gathers the arn from the settings object before the # lambda handler records it there, then this test will fail. 
@@ -157,17 +161,17 @@ def test_payload_metadata_arn(serverless_application, arn_set):
 
     arn = None
     if arn_set:
-        arn = 'arrrrrrrrrrRrrrrrrrn'
+        arn = "arrrrrrrrrrRrrrrrrrn"
 
-    settings.aws_lambda_metadata.update({'arn': arn, 'function_version': '$LATEST'})
+    settings.aws_lambda_metadata.update({"arn": arn, "function_version": "$LATEST"})
 
     class Context(object):
         invoked_function_arn = arn
 
-    @validate_serverless_metadata(exact_metadata={'arn': arn})
+    @validate_serverless_metadata(exact_metadata={"arn": arn})
     @lambda_handler(application=serverless_application)
     def handler(event, context):
-        assert settings.aws_lambda_metadata['arn'] == arn
+        assert settings.aws_lambda_metadata["arn"] == arn
         return {}
 
     try:
diff --git a/tests/agent_features/test_span_events.py b/tests/agent_features/test_span_events.py
index 655efee8c..05e375ff3 100644
--- a/tests/agent_features/test_span_events.py
+++ b/tests/agent_features/test_span_events.py
@@ -19,7 +19,6 @@
     dt_enabled,
     function_not_called,
     override_application_settings,
-    validate_tt_segment_params,
 )
 from testing_support.validators.validate_span_events import validate_span_events
 from testing_support.validators.validate_transaction_event_attributes import (
@@ -28,6 +27,9 @@
 from testing_support.validators.validate_transaction_metrics import (
     validate_transaction_metrics,
 )
+from testing_support.validators.validate_tt_segment_params import (
+    validate_tt_segment_params,
+)
 
 from newrelic.api.background_task import background_task
 from newrelic.api.database_trace import DatabaseTrace
@@ -139,7 +141,6 @@ def test_each_span_type(trace_type, args):
     )
     @background_task(name="test_each_span_type")
     def _test():
-
         transaction = current_transaction()
         transaction._sampled = True
@@ -305,7 +306,6 @@ def _test():
     }
 )
 def test_external_span_limits(kwarg_override, attr_override):
-
     exact_intrinsics = {
         "type": "Span",
         "sampled": True,
@@ -362,7 +362,6 @@ def _test():
     }
 )
 def test_datastore_span_limits(kwarg_override, attribute_override):
-
     exact_intrinsics = {
         "type": "Span",
         "sampled": True,
@@ -414,10 +413,6 @@ def _test():
 @pytest.mark.parametrize("span_events_enabled", (False, True))
 def test_collect_span_events_override(collect_span_events, span_events_enabled):
     spans_expected = collect_span_events and span_events_enabled
-    # if collect_span_events and span_events_enabled:
-    #     spans_expected = True
-    # else:
-    #     spans_expected = False
 
     span_count = 2 if spans_expected else 0
@@ -507,7 +502,6 @@ def __exit__(self, *args):
 )
 @pytest.mark.parametrize("exclude_attributes", (True, False))
 def test_span_event_user_attributes(trace_type, args, exclude_attributes):
-
     _settings = {
         "distributed_tracing.enabled": True,
         "span_events.enabled": True,
@@ -624,7 +618,6 @@ def _test():
     ),
 )
 def test_span_event_error_attributes_notice_error(trace_type, args):
-
     _settings = {
         "distributed_tracing.enabled": True,
         "span_events.enabled": True,
@@ -672,7 +665,6 @@ def _test():
     ),
 )
 def test_span_event_error_attributes_observed(trace_type, args):
-
     error = ValueError("whoops")
 
     exact_agents = {
@@ -725,7 +717,7 @@ def test_span_event_notice_error_overrides_observed(trace_type, args):
             raise ERROR
         except Exception:
             notice_error()
-            raise ValueError  # pylint: disable
+            raise ValueError  # pylint: disable (Py2/Py3 compatibility)
     except ValueError:
         pass
diff --git a/tests/agent_features/test_transaction_trace_segments.py b/tests/agent_features/test_transaction_trace_segments.py
index b205afc3c..8318c0fca 100644
--- a/tests/agent_features/test_transaction_trace_segments.py
+++ b/tests/agent_features/test_transaction_trace_segments.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 
 import pytest
-from testing_support.fixtures import (
-    override_application_settings,
+from testing_support.fixtures import override_application_settings
+from testing_support.validators.validate_tt_segment_params import (
     validate_tt_segment_params,
 )
diff --git a/tests/agent_features/test_wsgi_attributes.py b/tests/agent_features/test_wsgi_attributes.py
index e90410b6d..db9fc807a 100644
--- a/tests/agent_features/test_wsgi_attributes.py
+++ b/tests/agent_features/test_wsgi_attributes.py
@@ -13,12 +13,11 @@
 # limitations under the License.
 
 import webtest
-from testing_support.fixtures import (
-    dt_enabled,
-    override_application_settings,
+from testing_support.fixtures import dt_enabled, override_application_settings
+from testing_support.sample_applications import fully_featured_app
+from testing_support.validators.validate_error_event_attributes import (
     validate_error_event_attributes,
 )
-from testing_support.sample_applications import fully_featured_app
 from testing_support.validators.validate_transaction_error_trace_attributes import (
     validate_transaction_error_trace_attributes,
 )
diff --git a/tests/agent_unittests/test_agent_protocol.py b/tests/agent_unittests/test_agent_protocol.py
index ba75358ab..1f0401439 100644
--- a/tests/agent_unittests/test_agent_protocol.py
+++ b/tests/agent_unittests/test_agent_protocol.py
@@ -565,6 +565,7 @@ def test_ca_bundle_path(monkeypatch, ca_bundle_path):
     # Pretend CA certificates are not available
     class DefaultVerifyPaths(object):
         cafile = None
+        capath = None
 
         def __init__(self, *args, **kwargs):
             pass
diff --git a/tests/agent_unittests/test_harvest_loop.py b/tests/agent_unittests/test_harvest_loop.py
index 305622107..15b67a81e 100644
--- a/tests/agent_unittests/test_harvest_loop.py
+++ b/tests/agent_unittests/test_harvest_loop.py
@@ -32,7 +32,7 @@
 from newrelic.core.function_node import FunctionNode
 from newrelic.core.log_event_node import LogEventNode
 from newrelic.core.root_node import RootNode
-from newrelic.core.stats_engine import CustomMetrics, SampledDataSet
+from newrelic.core.stats_engine import CustomMetrics, SampledDataSet, DimensionalMetrics
 from newrelic.core.transaction_node import TransactionNode
 from newrelic.network.exceptions import RetryDataForRequest
@@ -49,6 +49,11 @@ def transaction_node(request):
         event = create_custom_event("Custom", {})
         custom_events.add(event)
 
+    ml_events = SampledDataSet(capacity=num_events)
+    for _ in range(num_events):
+        event = create_custom_event("Custom", {})
+        ml_events.add(event)
+
     log_events = SampledDataSet(capacity=num_events)
     for _ in range(num_events):
         event = LogEventNode(1653609717, "WARNING", "A", {})
@@ -122,10 +127,12 @@ def transaction_node(request):
         errors=errors,
         slow_sql=(),
         custom_events=custom_events,
+        ml_events=ml_events,
         log_events=log_events,
         apdex_t=0.5,
         suppress_apdex=False,
         custom_metrics=CustomMetrics(),
+        dimensional_metrics=DimensionalMetrics(),
         guid="4485b89db608aece",
         cpu_time=0.0,
         suppress_transaction_trace=False,
@@ -818,6 +825,7 @@ def test_flexible_events_harvested(allowlist_event):
     app._stats_engine.log_events.add(LogEventNode(1653609717, "WARNING", "A", {}))
     app._stats_engine.span_events.add("span event")
     app._stats_engine.record_custom_metric("CustomMetric/Int", 1)
+    app._stats_engine.record_dimensional_metric("DimensionalMetric/Int", 1, tags={"tag": "tag"})
 
     assert app._stats_engine.transaction_events.num_seen == 1
     assert app._stats_engine.error_events.num_seen == 1
@@ -825,6 +833,7 @@ def test_flexible_events_harvested(allowlist_event):
     assert app._stats_engine.log_events.num_seen == 1
     assert app._stats_engine.span_events.num_seen == 1
     assert app._stats_engine.record_custom_metric("CustomMetric/Int", 1)
+    assert app._stats_engine.record_dimensional_metric("DimensionalMetric/Int", 1, tags={"tag": "tag"})
 
     app.harvest(flexible=True)
@@ -844,7 +853,8 @@ def test_flexible_events_harvested(allowlist_event):
     assert app._stats_engine.span_events.num_seen == num_seen
 
     assert ("CustomMetric/Int", "") in app._stats_engine.stats_table
-    assert app._stats_engine.metrics_count() > 1
+    assert ("DimensionalMetric/Int", frozenset({("tag", "tag")})) in app._stats_engine.dimensional_stats_table
+    assert app._stats_engine.metrics_count() > 3
 
 
 @pytest.mark.parametrize(
diff --git a/tests/agent_unittests/test_http_client.py b/tests/agent_unittests/test_http_client.py
index a5c340d6a..df409f932 100644
--- a/tests/agent_unittests/test_http_client.py
+++ b/tests/agent_unittests/test_http_client.py
@@ -325,7 +325,7 @@ def test_http_payload_compression(server, client_cls, method, threshold):
         # Verify the compressed payload length is recorded
         assert internal_metrics["Supportability/Python/Collector/method1/ZLIB/Bytes"][:2] == [1, payload_byte_len]
         assert internal_metrics["Supportability/Python/Collector/ZLIB/Bytes"][:2] == [2, payload_byte_len*2]
-
+
         assert len(internal_metrics) == 8
     else:
         # Verify no ZLIB compression metrics were sent
@@ -366,11 +366,14 @@ def test_cert_path(server):
 def test_default_cert_path(monkeypatch, system_certs_available):
     if system_certs_available:
         cert_file = "foo"
+        ca_path = "/usr/certs"
     else:
         cert_file = None
+        ca_path = None
 
     class DefaultVerifyPaths(object):
         cafile = cert_file
+        capath = ca_path
 
         def __init__(self, *args, **kwargs):
             pass
diff --git a/tests/agent_unittests/test_package_version_utils.py b/tests/agent_unittests/test_package_version_utils.py
index d80714d77..376c8c7e0 100644
--- a/tests/agent_unittests/test_package_version_utils.py
+++ b/tests/agent_unittests/test_package_version_utils.py
@@ -13,22 +13,33 @@
 # limitations under the License.
 
 import sys
+import warnings
 
 import pytest
+import six
 from testing_support.validators.validate_function_called import validate_function_called
 
 from newrelic.common.package_version_utils import (
     NULL_VERSIONS,
     VERSION_ATTRS,
+    _get_package_version,
     get_package_version,
     get_package_version_tuple,
 )
 
+# Notes:
+# importlib.metadata was a provisional addition to the std library in PY38 and PY39
+# while pkg_resources was deprecated.
+# importlib.metadata is no longer provisional in PY310+. It added some attributes
+# such as distribution_packages and removed pkg_resources.
+
 IS_PY38_PLUS = sys.version_info[:2] >= (3, 8)
+IS_PY310_PLUS = sys.version_info[:2] >= (3, 10)
 
 SKIP_IF_NOT_IMPORTLIB_METADATA = pytest.mark.skipif(not IS_PY38_PLUS, reason="importlib.metadata is not supported.")
 SKIP_IF_IMPORTLIB_METADATA = pytest.mark.skipif(
     IS_PY38_PLUS, reason="importlib.metadata is preferred over pkg_resources."
 )
+SKIP_IF_NOT_PY310_PLUS = pytest.mark.skipif(not IS_PY310_PLUS, reason="These features were added in 3.10+")
 
 
 @pytest.fixture(scope="function", autouse=True)
@@ -40,6 +51,14 @@ def patched_pytest_module(monkeypatch):
     yield pytest
 
 
+@pytest.fixture(scope="function", autouse=True)
+def cleared_package_version_cache():
+    """Ensure cache is empty before every test to exercise code paths."""
+    _get_package_version.cache_clear()
+
+
+# This test only works on Python 3.7
+@SKIP_IF_IMPORTLIB_METADATA
 @pytest.mark.parametrize(
     "attr,value,expected_value",
     (
@@ -49,15 +68,29 @@ def patched_pytest_module(monkeypatch):
         ("version_tuple", [3, 1, "0b2"], "3.1.0b2"),
     ),
 )
-def test_get_package_version(attr, value, expected_value):
+def test_get_package_version(monkeypatch, attr, value, expected_value):
     # There is no file/module here, so we monkeypatch
     # pytest instead for our purposes
-    setattr(pytest, attr, value)
+    monkeypatch.setattr(pytest, attr, value, raising=False)
     version = get_package_version("pytest")
     assert version == expected_value
-    delattr(pytest, attr)
+
+
+# This test only works on Python 3.7
+@SKIP_IF_IMPORTLIB_METADATA
+def test_skips_version_callables(monkeypatch):
+    # There is no file/module here, so we monkeypatch
+    # pytest instead for our purposes
+    monkeypatch.setattr(pytest, "version", lambda x: "1.2.3.4", raising=False)
+    monkeypatch.setattr(pytest, "version_tuple", [3, 1, "0b2"], raising=False)
+
+    version = get_package_version("pytest")
+
+    assert version == "3.1.0b2"
+
+
+# This test only works on Python 3.7
+@SKIP_IF_IMPORTLIB_METADATA
 @pytest.mark.parametrize(
     "attr,value,expected_value",
     (
@@ -67,13 +100,12 @@ def test_get_package_version(attr, value, expected_value):
         ("version_tuple", [3, 1, "0b2"], (3, 1, "0b2")),
     ),
 )
-def test_get_package_version_tuple(attr, value, expected_value):
+def test_get_package_version_tuple(monkeypatch, attr, value, expected_value):
     # There is no file/module here, so we monkeypatch
     # pytest instead for our purposes
-    setattr(pytest, attr, value)
+    monkeypatch.setattr(pytest, attr, value, raising=False)
     version = get_package_version_tuple("pytest")
     assert version == expected_value
-    delattr(pytest, attr)
 
 
 @SKIP_IF_NOT_IMPORTLIB_METADATA
@@ -83,8 +115,46 @@ def test_importlib_metadata():
     assert version not in NULL_VERSIONS, version
 
 
+@SKIP_IF_NOT_PY310_PLUS
+@validate_function_called("importlib.metadata", "packages_distributions")
+def test_mapping_import_to_distribution_packages():
+    version = get_package_version("pytest")
+    assert version not in NULL_VERSIONS, version
+
+
 @SKIP_IF_IMPORTLIB_METADATA
 @validate_function_called("pkg_resources", "get_distribution")
 def test_pkg_resources_metadata():
     version = get_package_version("pytest")
     assert version not in NULL_VERSIONS, version
+
+
+def _getattr_deprecation_warning(attr):
+    if attr == "__version__":
+        warnings.warn("Testing deprecation warnings.", DeprecationWarning)
+        return "3.2.1"
+    else:
+        raise NotImplementedError()
+
+
+@pytest.mark.skipif(six.PY2, reason="Can't add Deprecation in __version__ in Python 2.")
+def test_deprecation_warning_suppression(monkeypatch, recwarn):
+    # Add fake module to be deleted later
+    monkeypatch.setattr(pytest, "__getattr__", _getattr_deprecation_warning, raising=False)
+
+    assert get_package_version("pytest") == "3.2.1"
+
+    assert not recwarn.list, "Warnings not suppressed."
+
+
+def test_version_caching(monkeypatch):
+    # Add fake module to be deleted later
+    sys.modules["mymodule"] = sys.modules["pytest"]
+    monkeypatch.setattr(pytest, "__version__", "1.0.0", raising=False)
+    version = get_package_version("mymodule")
+    assert version not in NULL_VERSIONS, version
+
+    # Ensure after deleting that the call to _get_package_version still completes because of caching
+    del sys.modules["mymodule"]
+    version = get_package_version("mymodule")
+    assert version not in NULL_VERSIONS, version
diff --git a/tests/agent_unittests/test_utilization_settings.py b/tests/agent_unittests/test_utilization_settings.py
index 8af4bcbf1..96cf47669 100644
--- a/tests/agent_unittests/test_utilization_settings.py
+++ b/tests/agent_unittests/test_utilization_settings.py
@@ -118,6 +118,22 @@ def reset(wrapped, instance, args, kwargs):
     return reset
 
 
+@reset_agent_config(INI_FILE_WITHOUT_UTIL_CONF, ENV_WITHOUT_UTIL_CONF)
+def test_otlp_host_port_default():
+    settings = global_settings()
+    assert settings.otlp_host == "otlp.nr-data.net"
+    assert settings.otlp_port == 0
+
+
+@reset_agent_config(
+    INI_FILE_WITHOUT_UTIL_CONF, {"NEW_RELIC_OTLP_HOST": "custom-otlp.nr-data.net", "NEW_RELIC_OTLP_PORT": 443}
+)
+def test_otlp_port_override():
+    settings = global_settings()
+    assert settings.otlp_host == "custom-otlp.nr-data.net"
+    assert settings.otlp_port == 443
+
+
 @reset_agent_config(INI_FILE_WITHOUT_UTIL_CONF, ENV_WITHOUT_UTIL_CONF)
 def test_heroku_default():
     settings = global_settings()
diff --git a/tests/application_gearman/test_gearman.py b/tests/application_gearman/test_gearman.py
index 7ddc13fdc..5dda4ef47 100644
--- a/tests/application_gearman/test_gearman.py
+++ b/tests/application_gearman/test_gearman.py
@@ -20,14 +20,16 @@
 import gearman
 
 from newrelic.api.background_task import background_task
+from testing_support.db_settings import gearman_settings
 
 worker_thread = None
 worker_event = threading.Event()
 
 gm_client = None
 
-GEARMAND_HOST = os.environ.get("GEARMAND_PORT_4730_TCP_ADDR", "localhost")
-GEARMAND_PORT = os.environ.get("GEARMAND_PORT_4730_TCP_PORT", "4730")
+GEARMAND_SETTINGS = gearman_settings()[0]
+GEARMAND_HOST = GEARMAND_SETTINGS["host"]
+GEARMAND_PORT = GEARMAND_SETTINGS["port"]
 
 GEARMAND_ADDR = "%s:%s" % (GEARMAND_HOST, GEARMAND_PORT)
diff --git a/tests/component_djangorestframework/test_application.py b/tests/component_djangorestframework/test_application.py
index 9ed60aa33..29861dca8 100644
--- a/tests/component_djangorestframework/test_application.py
+++ b/tests/component_djangorestframework/test_application.py
@@ -12,190 +12,168 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import django
 import pytest
 import webtest
+from testing_support.fixtures import function_not_called, override_generic_settings
+from testing_support.validators.validate_code_level_metrics import (
+    validate_code_level_metrics,
+)
+from testing_support.validators.validate_transaction_errors import (
+    validate_transaction_errors,
+)
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
 
-from newrelic.packages import six
 from newrelic.core.config import global_settings
+from newrelic.packages import six
 
-from testing_support.fixtures import (
-    override_generic_settings,
-    function_not_called)
-from testing_support.validators.validate_transaction_errors import validate_transaction_errors
-from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics
-from testing_support.validators.validate_code_level_metrics import validate_code_level_metrics
-import django
-
-DJANGO_VERSION = tuple(map(int, django.get_version().split('.')[:2]))
-
+DJANGO_VERSION = tuple(map(int, django.get_version().split(".")[:2]))
 
-@pytest.fixture(scope='module')
+
+@pytest.fixture(scope="module")
 def target_application():
     from wsgi import application
+
     test_application = webtest.TestApp(application)
     return test_application
 
 
 if DJANGO_VERSION >= (1, 10):
-    url_module_path = 'django.urls.resolvers'
+    url_module_path = "django.urls.resolvers"
 
     # Django 1.10 new style middleware removed individual process_* methods.
     # All middleware in Django 1.10+ is called through the __call__ methods on
     # middlwares.
-    process_request_method = ''
-    process_view_method = ''
-    process_response_method = ''
+    process_request_method = ""
+    process_view_method = ""
+    process_response_method = ""
 else:
-    url_module_path = 'django.core.urlresolvers'
-    process_request_method = '.process_request'
-    process_view_method = '.process_view'
-    process_response_method = '.process_response'
+    url_module_path = "django.core.urlresolvers"
+    process_request_method = ".process_request"
+    process_view_method = ".process_view"
+    process_response_method = ".process_response"
 
 if DJANGO_VERSION >= (2, 0):
-    url_resolver_cls = 'URLResolver'
+    url_resolver_cls = "URLResolver"
 else:
-    url_resolver_cls = 'RegexURLResolver'
+    url_resolver_cls = "RegexURLResolver"
 
 _scoped_metrics = [
-    ('Function/django.core.handlers.wsgi:WSGIHandler.__call__', 1),
-    ('Python/WSGI/Application', 1),
-    ('Python/WSGI/Response', 1),
-    ('Python/WSGI/Finalize', 1),
-    (('Function/django.middleware.common:'
-            'CommonMiddleware' + process_request_method), 1),
-    (('Function/django.contrib.sessions.middleware:'
-            'SessionMiddleware' + process_request_method), 1),
-    (('Function/django.contrib.auth.middleware:'
-            'AuthenticationMiddleware' + process_request_method), 1),
-    (('Function/django.contrib.messages.middleware:'
-            'MessageMiddleware' + process_request_method), 1),
-    (('Function/%s:' % url_module_path +
-            '%s.resolve' % url_resolver_cls), 1),
-    (('Function/django.middleware.csrf:'
-            'CsrfViewMiddleware' + process_view_method), 1),
-    (('Function/django.contrib.messages.middleware:'
-            'MessageMiddleware' + process_response_method), 1),
-    (('Function/django.middleware.csrf:'
-            'CsrfViewMiddleware' + process_response_method), 1),
-    (('Function/django.contrib.sessions.middleware:'
-            'SessionMiddleware' + process_response_method), 1),
-    (('Function/django.middleware.common:'
-            'CommonMiddleware' + process_response_method), 1),
+    ("Function/django.core.handlers.wsgi:WSGIHandler.__call__", 1),
+    ("Python/WSGI/Application", 1),
+    ("Python/WSGI/Response", 1),
+    ("Python/WSGI/Finalize", 1),
+    (("Function/django.middleware.common:CommonMiddleware%s" % process_request_method), 1),
+    (("Function/django.contrib.sessions.middleware:SessionMiddleware%s" % process_request_method), 1),
+    (("Function/django.contrib.auth.middleware:AuthenticationMiddleware%s" % process_request_method), 1),
+    (("Function/django.contrib.messages.middleware:MessageMiddleware%s" % process_request_method), 1),
+    (("Function/%s:%s.resolve" % (url_module_path, url_resolver_cls)), 1),
+    (("Function/django.middleware.csrf:CsrfViewMiddleware%s" % process_view_method), 1),
+    (("Function/django.contrib.messages.middleware:MessageMiddleware%s" % process_response_method), 1),
+    (("Function/django.middleware.csrf:CsrfViewMiddleware%s" % process_response_method), 1),
+    (("Function/django.contrib.sessions.middleware:SessionMiddleware%s" % process_response_method), 1),
+    (("Function/django.middleware.common:CommonMiddleware%s" % process_response_method), 1),
 ]
 
 _test_application_index_scoped_metrics = list(_scoped_metrics)
-_test_application_index_scoped_metrics.append(('Function/views:index', 1))
+_test_application_index_scoped_metrics.append(("Function/views:index", 1))
 
 if DJANGO_VERSION >= (1, 5):
-    _test_application_index_scoped_metrics.extend([
-        ('Function/django.http.response:HttpResponse.close', 1)])
+    _test_application_index_scoped_metrics.extend([("Function/django.http.response:HttpResponse.close", 1)])
 
 
 @validate_transaction_errors(errors=[])
-@validate_transaction_metrics('views:index',
-        scoped_metrics=_test_application_index_scoped_metrics)
+@validate_transaction_metrics("views:index", scoped_metrics=_test_application_index_scoped_metrics)
 @validate_code_level_metrics("views", "index")
 def test_application_index(target_application):
-    response = target_application.get('')
-    response.mustcontain('INDEX RESPONSE')
+    response = target_application.get("")
+    response.mustcontain("INDEX RESPONSE")
 
 
 _test_application_view_scoped_metrics = list(_scoped_metrics)
-_test_application_view_scoped_metrics.append(('Function/urls:View.get', 1))
+_test_application_view_scoped_metrics.append(("Function/urls:View.get", 1))
 
 if DJANGO_VERSION >= (1, 5):
-    _test_application_view_scoped_metrics.extend([
-        ('Function/rest_framework.response:Response.close', 1)])
+    _test_application_view_scoped_metrics.extend([("Function/rest_framework.response:Response.close", 1)])
 
 
 @validate_transaction_errors(errors=[])
-@validate_transaction_metrics('urls:View.get',
-        scoped_metrics=_test_application_view_scoped_metrics)
+@validate_transaction_metrics("urls:View.get", scoped_metrics=_test_application_view_scoped_metrics)
 @validate_code_level_metrics("urls.View", "get")
 def test_application_view(target_application):
-    response = target_application.get('/view/')
+    response = target_application.get("/view/")
     assert response.status_int == 200
-    response.mustcontain('restframework view response')
+    response.mustcontain("restframework view response")
 
 
 _test_application_view_error_scoped_metrics = list(_scoped_metrics)
-_test_application_view_error_scoped_metrics.append(
-    ('Function/urls:ViewError.get', 1))
+_test_application_view_error_scoped_metrics.append(("Function/urls:ViewError.get", 1))
 
 
-@validate_transaction_errors(errors=['urls:Error'])
-@validate_transaction_metrics('urls:ViewError.get',
-        scoped_metrics=_test_application_view_error_scoped_metrics)
+@validate_transaction_errors(errors=["urls:Error"])
+@validate_transaction_metrics("urls:ViewError.get", scoped_metrics=_test_application_view_error_scoped_metrics)
 @validate_code_level_metrics("urls.ViewError", "get")
 def test_application_view_error(target_application):
-    target_application.get('/view_error/', status=500)
+    target_application.get("/view_error/", status=500)
 
 
 _test_application_view_handle_error_scoped_metrics = list(_scoped_metrics)
-_test_application_view_handle_error_scoped_metrics.append(
-    ('Function/urls:ViewHandleError.get', 1))
+_test_application_view_handle_error_scoped_metrics.append(("Function/urls:ViewHandleError.get", 1))
 
 
-@pytest.mark.parametrize('status,should_record', [(418, True), (200, False)])
-@pytest.mark.parametrize('use_global_exc_handler', [True, False])
+@pytest.mark.parametrize("status,should_record", [(418, True), (200, False)])
+@pytest.mark.parametrize("use_global_exc_handler", [True, False])
 @validate_code_level_metrics("urls.ViewHandleError", "get")
-def test_application_view_handle_error(status, should_record,
-        use_global_exc_handler, target_application):
-    errors = ['urls:Error'] if should_record else []
+def test_application_view_handle_error(status, should_record, use_global_exc_handler, target_application):
+    errors = ["urls:Error"] if should_record else []
 
     @validate_transaction_errors(errors=errors)
-    @validate_transaction_metrics('urls:ViewHandleError.get',
-            scoped_metrics=_test_application_view_handle_error_scoped_metrics)
+    @validate_transaction_metrics(
+        "urls:ViewHandleError.get", scoped_metrics=_test_application_view_handle_error_scoped_metrics
+    )
     def _test():
-        response = target_application.get(
-            '/view_handle_error/%s/%s/' % (status, use_global_exc_handler),
-            status=status)
+        response = target_application.get("/view_handle_error/%s/%s/" % (status, use_global_exc_handler), status=status)
         if use_global_exc_handler:
-            response.mustcontain('exception was handled global')
+            response.mustcontain("exception was handled global")
         else:
-            response.mustcontain('exception was handled not global')
+            response.mustcontain("exception was handled not global")
 
     _test()
 
 
-_test_api_view_view_name_get = 'urls:wrapped_view.get'
+_test_api_view_view_name_get = "urls:wrapped_view.get"
 _test_api_view_scoped_metrics_get = list(_scoped_metrics)
-_test_api_view_scoped_metrics_get.append(
-    ('Function/%s' % _test_api_view_view_name_get, 1))
+_test_api_view_scoped_metrics_get.append(("Function/%s" % _test_api_view_view_name_get, 1))
 
 
 @validate_transaction_errors(errors=[])
-@validate_transaction_metrics(_test_api_view_view_name_get,
-        scoped_metrics=_test_api_view_scoped_metrics_get)
-@validate_code_level_metrics("urls.WrappedAPIView" if six.PY3 else "urls", "wrapped_view")
+@validate_transaction_metrics(_test_api_view_view_name_get, scoped_metrics=_test_api_view_scoped_metrics_get)
+@validate_code_level_metrics("urls.WrappedAPIView", "wrapped_view", py2_namespace="urls")
 def test_api_view_get(target_application):
-    response = target_application.get('/api_view/')
-    response.mustcontain('wrapped_view response')
+    response = target_application.get("/api_view/")
+    response.mustcontain("wrapped_view response")
 
 
-_test_api_view_view_name_post = 'urls:wrapped_view.http_method_not_allowed'
+_test_api_view_view_name_post = "urls:wrapped_view.http_method_not_allowed"
 _test_api_view_scoped_metrics_post = list(_scoped_metrics)
-_test_api_view_scoped_metrics_post.append(
-    ('Function/%s' % _test_api_view_view_name_post, 1))
+_test_api_view_scoped_metrics_post.append(("Function/%s" % _test_api_view_view_name_post, 1))
 
 
-@validate_transaction_errors(
-    errors=['rest_framework.exceptions:MethodNotAllowed'])
-@validate_transaction_metrics(_test_api_view_view_name_post,
-        scoped_metrics=_test_api_view_scoped_metrics_post)
+@validate_transaction_errors(errors=["rest_framework.exceptions:MethodNotAllowed"])
+@validate_transaction_metrics(_test_api_view_view_name_post, scoped_metrics=_test_api_view_scoped_metrics_post)
 def test_api_view_method_not_allowed(target_application):
-    target_application.post('/api_view/', status=405)
+    target_application.post("/api_view/", status=405)
 
 
 def test_application_view_agent_disabled(target_application):
     settings = global_settings()
 
-    @override_generic_settings(settings, {'enabled': False})
-    @function_not_called('newrelic.core.stats_engine',
-            'StatsEngine.record_transaction')
+    @override_generic_settings(settings, {"enabled": False})
+    @function_not_called("newrelic.core.stats_engine", "StatsEngine.record_transaction")
     def _test():
-        response = target_application.get('/view/')
+        response = target_application.get("/view/")
         assert response.status_int == 200
-        response.mustcontain('restframework view response')
+        response.mustcontain("restframework view response")
 
     _test()
diff --git a/tests/component_flask_rest/test_application.py b/tests/component_flask_rest/test_application.py
index d463a0205..67d4825a1 100644
--- a/tests/component_flask_rest/test_application.py
+++ b/tests/component_flask_rest/test_application.py
@@ -31,8 +31,6 @@
 from newrelic.core.config import global_settings
 from newrelic.packages import six
 
-TEST_APPLICATION_PREFIX = "_test_application.create_app.<locals>" if six.PY3 else "_test_application"
-
 
 @pytest.fixture(params=["flask_restful", "flask_restx"])
 def application(request):
@@ -62,7 +60,7 @@ def application(request):
 ]
 
 
-@validate_code_level_metrics(TEST_APPLICATION_PREFIX + ".IndexResource", "get")
+@validate_code_level_metrics("_test_application.create_app.<locals>.IndexResource", "get", py2_namespace="_test_application.IndexResource")
 @validate_transaction_errors(errors=[])
 @validate_transaction_metrics("_test_application:index", scoped_metrics=_test_application_index_scoped_metrics)
 def test_application_index(application):
@@ -88,7 +86,7 @@ def test_application_index(application):
     ],
 )
 def test_application_raises(exception, status_code, ignore_status_code, propagate_exceptions, application):
-    @validate_code_level_metrics(TEST_APPLICATION_PREFIX + ".ExceptionResource", "get")
+    @validate_code_level_metrics("_test_application.create_app.<locals>.ExceptionResource", "get", py2_namespace="_test_application.ExceptionResource")
     @validate_transaction_metrics("_test_application:exception", scoped_metrics=_test_application_raises_scoped_metrics)
     def _test():
         try:
@@ -118,4 +116,4 @@ def test_application_outside_transaction(application):
     def _test():
         application.get("/exception/werkzeug.exceptions:HTTPException/404", status=404)
 
-    _test()
+    _test()
\ No newline at end of file
diff --git a/tests/component_graphqlserver/__init__.py b/tests/component_graphqlserver/__init__.py
new file mode 100644
index 000000000..8030baccf
--- /dev/null
+++ b/tests/component_graphqlserver/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/component_graphqlserver/_target_schema_async.py b/tests/component_graphqlserver/_target_schema_async.py
new file mode 100644
index 000000000..aff587bc8
--- /dev/null
+++ b/tests/component_graphqlserver/_target_schema_async.py
@@ -0,0 +1,155 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from graphql import (
+    GraphQLArgument,
+    GraphQLField,
+    GraphQLInt,
+    GraphQLList,
+    GraphQLNonNull,
+    GraphQLObjectType,
+    GraphQLSchema,
+    GraphQLString,
+    GraphQLUnionType,
+)
+
+from ._target_schema_sync import books, libraries, magazines
+
+storage = []
+
+
+async def resolve_library(parent, info, index):
+    return libraries[index]
+
+
+async def resolve_storage_add(parent, info, string):
+    storage.append(string)
+    return string
+
+
+async def resolve_storage(parent, info):
+    return [storage.pop()]
+
+
+async def resolve_search(parent, info, contains):
+    search_books = [b for b in books if contains in b["name"]]
+    search_magazines = [m for m in magazines if contains in m["name"]]
+    return search_books + search_magazines
+
+
+Author = GraphQLObjectType(
+    "Author",
+    {
+        "first_name": GraphQLField(GraphQLString),
+        "last_name": GraphQLField(GraphQLString),
+    },
+)
+
+Book = GraphQLObjectType(
+    "Book",
+    {
+        "id": GraphQLField(GraphQLInt),
+        "name": GraphQLField(GraphQLString),
+        "isbn": GraphQLField(GraphQLString),
+        "author": GraphQLField(Author),
+        "branch": GraphQLField(GraphQLString),
+    },
+)
+
+Magazine = GraphQLObjectType(
+    "Magazine",
+    {
+        "id": GraphQLField(GraphQLInt),
+        "name": GraphQLField(GraphQLString),
+        "issue": GraphQLField(GraphQLInt),
+        "branch": GraphQLField(GraphQLString),
+    },
+)
+
+
+Library = GraphQLObjectType(
+    "Library",
+    {
+        "id": GraphQLField(GraphQLInt),
+        "branch": GraphQLField(GraphQLString),
+        "book": GraphQLField(GraphQLList(Book)),
+        "magazine": GraphQLField(GraphQLList(Magazine)),
+    },
+)
+
+Storage = GraphQLList(GraphQLString)
+
+
+async def resolve_hello(root, info):
+    return "Hello!"
+
+
+async def resolve_echo(root, info, echo):
+    return echo
+
+
+async def resolve_error(root, info):
+    raise RuntimeError("Runtime Error!")
+
+
+hello_field = GraphQLField(GraphQLString, resolver=resolve_hello)
+library_field = GraphQLField(
+    Library,
+    resolver=resolve_library,
+    args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))},
+)
+search_field = GraphQLField(
+    GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)),
+    args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))},
+)
+echo_field = GraphQLField(
+    GraphQLString,
+    resolver=resolve_echo,
+    args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))},
+)
+storage_field = GraphQLField(
+    Storage,
+    resolver=resolve_storage,
+)
+storage_add_field = GraphQLField(
+    GraphQLString,
+    resolver=resolve_storage_add,
+    args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))},
+)
+error_field = GraphQLField(GraphQLString, resolver=resolve_error)
+error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolver=resolve_error)
+error_middleware_field = GraphQLField(GraphQLString, resolver=resolve_hello)
+
+query = GraphQLObjectType(
+    name="Query",
+    fields={
+        "hello": hello_field,
+        "library": library_field,
+        "search": search_field,
+        "echo": echo_field,
+        "storage": storage_field,
+        "error": error_field,
+        "error_non_null": error_non_null_field,
+        "error_middleware": error_middleware_field,
+    },
+)
+
+mutation = GraphQLObjectType(
+    name="Mutation",
+    fields={
+        "storage_add": storage_add_field,
+    },
+)
+
+target_schema = GraphQLSchema(query=query, mutation=mutation)
diff --git a/tests/component_graphqlserver/_test_graphql.py b/tests/component_graphqlserver/_test_graphql.py
index 50b5621f9..7a29b3a8f 100644
--- a/tests/component_graphqlserver/_test_graphql.py
+++ b/tests/component_graphqlserver/_test_graphql.py
@@ -12,15 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from flask import Flask
+from sanic import Sanic
 import json
-
 import webtest
-from flask import Flask
-from framework_graphql._target_application import _target_application as schema
+
+from testing_support.asgi_testing import AsgiTest
+from framework_graphql._target_schema_sync import target_schema as schema
 from graphql_server.flask import GraphQLView as FlaskView
 from graphql_server.sanic import GraphQLView as SanicView
-from sanic import Sanic
-from testing_support.asgi_testing import AsgiTest
+
+# Sanic
+target_application = dict()
 
 
 def set_middlware(middleware, view_middleware):
@@ -95,5 +98,4 @@ def flask_execute(query, middleware=None):
 
     return response
 
-
 target_application["Flask"] = flask_execute
diff --git a/tests/component_graphqlserver/test_graphql.py b/tests/component_graphqlserver/test_graphql.py
index e5566047e..098f50970 100644
--- a/tests/component_graphqlserver/test_graphql.py
+++ b/tests/component_graphqlserver/test_graphql.py
@@ -12,16 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import importlib
 
 import pytest
 from testing_support.fixtures import dt_enabled
-from testing_support.validators.validate_transaction_errors import validate_transaction_errors
-from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics
 from testing_support.validators.validate_span_events import validate_span_events
 from testing_support.validators.validate_transaction_count import (
     validate_transaction_count,
 )
+from testing_support.validators.validate_transaction_errors import (
+    validate_transaction_errors,
+)
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
 
 from newrelic.common.object_names import callable_name
@@ -36,7 +41,7 @@ def is_graphql_2():
 
 @pytest.fixture(scope="session", params=("Sanic", "Flask"))
 def target_application(request):
-    import _test_graphql
+    from . import _test_graphql
 
     framework = request.param
     version = importlib.import_module(framework.lower()).__version__
@@ -186,7 +191,7 @@ def test_middleware(target_application):
     _test_middleware_metrics = [
         ("GraphQL/operation/GraphQLServer/query/<anonymous>/hello", 1),
         ("GraphQL/resolve/GraphQLServer/hello", 1),
-        ("Function/test_graphql:example_middleware", 1),
+        ("Function/component_graphqlserver.test_graphql:example_middleware", 1),
     ]
 
     # Base span count 6: Transaction, View, Operation, Middleware, and 1 Resolver and Resolver function
@@ -220,7 +225,7 @@ def test_exception_in_middleware(target_application):
     _test_exception_rollup_metrics = [
         ("Errors/all", 1),
         ("Errors/allWeb", 1),
-        ("Errors/WebTransaction/GraphQL/test_graphql:error_middleware", 1),
+        ("Errors/WebTransaction/GraphQL/component_graphqlserver.test_graphql:error_middleware", 1),
     ] + _test_exception_scoped_metrics
 
     # Attributes
@@ -237,7 +242,7 @@ def test_exception_in_middleware(target_application):
     }
 
     @validate_transaction_metrics(
-        "test_graphql:error_middleware",
+        "component_graphqlserver.test_graphql:error_middleware",
         "GraphQL",
         scoped_metrics=_test_exception_scoped_metrics,
         rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics,
@@ -257,7 +262,7 @@ def test_exception_in_resolver(target_application, field):
     framework, version, target_application = target_application
     query = "query MyQuery { %s }" % field
 
-    txn_name = "framework_graphql._target_application:resolve_error"
+    txn_name = "framework_graphql._target_schema_sync:resolve_error"
 
     # Metrics
     _test_exception_scoped_metrics = [
@@ -488,7 +493,7 @@ def _test():
 def test_deepest_unique_path(target_application, query, expected_path):
     framework, version, target_application = target_application
     if expected_path == "/error":
-        txn_name = "framework_graphql._target_application:resolve_error"
+        txn_name = "framework_graphql._target_schema_sync:resolve_error"
     else:
         txn_name = "query/%s" % expected_path
diff --git a/tests/cross_agent/test_agent_attributes.py b/tests/cross_agent/test_agent_attributes.py
index c254be772..527b31a75 100644
--- a/tests/cross_agent/test_agent_attributes.py
+++ b/tests/cross_agent/test_agent_attributes.py
@@ -40,7 +40,8 @@ def _default_settings():
         'browser_monitoring.attributes.exclude': [],
     }
 
-FIXTURE = os.path.join(os.curdir, 'fixtures', 'attribute_configuration.json')
+CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
+FIXTURE = os.path.join(CURRENT_DIR, 'fixtures', 'attribute_configuration.json')
 
 def _load_tests():
     with open(FIXTURE, 'r') as fh:
diff --git a/tests/cross_agent/test_cat_map.py b/tests/cross_agent/test_cat_map.py
index 67c5ab815..6e7ac63d6 100644
--- a/tests/cross_agent/test_cat_map.py
+++ b/tests/cross_agent/test_cat_map.py
@@ -18,42 +18,58 @@
 can be found in test/framework_tornado_r3/test_cat_map.py
 """
 
-import webtest
-import pytest
 import json
 import os
 
+import pytest
+import webtest
+
 try:
     from urllib2 import urlopen  # Py2.X
 except ImportError:
-    from urllib.request import urlopen   # Py3.X
-
-from newrelic.packages import six
+    from urllib.request import urlopen  # Py3.X
+
+from testing_support.fixtures import (
+    make_cross_agent_headers,
+    override_application_name,
+    override_application_settings,
+    validate_analytics_catmap_data,
+)
+from testing_support.mock_external_http_server import (
+    MockExternalHTTPHResponseHeadersServer,
+)
+from testing_support.validators.validate_tt_parameters import validate_tt_parameters
 
 from newrelic.api.external_trace import ExternalTrace
-from newrelic.api.transaction import (get_browser_timing_header,
-        set_transaction_name, get_browser_timing_footer, set_background_task,
-        current_transaction)
+from newrelic.api.transaction import (
+    current_transaction,
+    get_browser_timing_footer,
+    get_browser_timing_header,
+    set_background_task,
+    set_transaction_name,
+)
 from newrelic.api.wsgi_application import wsgi_application
-from newrelic.common.encoding_utils import obfuscate, json_encode
-
-from testing_support.fixtures import (override_application_settings,
-        override_application_name, validate_tt_parameters,
-        make_cross_agent_headers, validate_analytics_catmap_data)
-from testing_support.mock_external_http_server import (
-        MockExternalHTTPHResponseHeadersServer)
+from newrelic.common.encoding_utils import json_encode, obfuscate
+from newrelic.packages import six
 
-ENCODING_KEY = '1234567890123456789012345678901234567890'
+ENCODING_KEY = "1234567890123456789012345678901234567890"
 CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
-JSON_DIR = os.path.normpath(os.path.join(CURRENT_DIR, 'fixtures'))
+JSON_DIR = os.path.normpath(os.path.join(CURRENT_DIR, "fixtures"))
 OUTBOUD_REQUESTS = {}
 
-_parameters_list = ["name", "appName", "transactionName", "transactionGuid",
-        "inboundPayload", "outboundRequests", "expectedIntrinsicFields",
-        "nonExpectedIntrinsicFields"]
+_parameters_list = [
+    "name",
+    "appName",
+    "transactionName",
+    "transactionGuid",
+    "inboundPayload",
+    "outboundRequests",
+    "expectedIntrinsicFields",
+    "nonExpectedIntrinsicFields",
+]
 
 
-@pytest.fixture(scope='module')
+@pytest.fixture(scope="module")
 def server():
     with MockExternalHTTPHResponseHeadersServer() as _server:
         yield _server
@@ -61,8 +77,8 @@ def server():
 
 def load_tests():
     result = []
-    path = os.path.join(JSON_DIR, 'cat_map.json')
-    with open(path, 'r') as fh:
+    path = os.path.join(JSON_DIR, "cat_map.json")
+    with open(path, "r") as fh:
         tests = json.load(fh)
 
     for test in tests:
@@ -77,57 +93,52 @@ def load_tests():
 
 @wsgi_application()
 def target_wsgi_application(environ, start_response):
-    status = '200 OK'
+    status = "200 OK"
 
-    txn_name = environ.get('txn')
+    txn_name = environ.get("txn")
     if six.PY2:
-        txn_name = txn_name.decode('UTF-8')
-    txn_name = txn_name.split('/', 3)
+        txn_name = txn_name.decode("UTF-8")
+    txn_name = txn_name.split("/", 3)
 
-    guid = environ.get('guid')
-    old_cat = environ.get('old_cat') == 'True'
+    guid = environ.get("guid")
+    old_cat = environ.get("old_cat") == "True"
 
     txn = current_transaction()
 
     txn.guid = guid
     for req in OUTBOUD_REQUESTS:
         # Change the transaction name before making an outbound call.
-        outgoing_name = req['outboundTxnName'].split('/', 3)
-        if outgoing_name[0] != 'WebTransaction':
+        outgoing_name = req["outboundTxnName"].split("/", 3)
+        if outgoing_name[0] != "WebTransaction":
             set_background_task(True)
 
         set_transaction_name(outgoing_name[2], group=outgoing_name[1])
 
-        expected_outbound_header = obfuscate(
-                json_encode(req['expectedOutboundPayload']), ENCODING_KEY)
-        generated_outbound_header = dict(
-                ExternalTrace.generate_request_headers(txn))
+        expected_outbound_header = obfuscate(json_encode(req["expectedOutboundPayload"]), ENCODING_KEY)
+        generated_outbound_header = dict(ExternalTrace.generate_request_headers(txn))
 
         # A 500 error is returned because 'assert' statements in the wsgi app
        # are ignored.
 
        if old_cat:
-            if (expected_outbound_header !=
-                    generated_outbound_header['X-NewRelic-Transaction']):
-                status = '500 Outbound Headers Check Failed.'
+            if expected_outbound_header != generated_outbound_header["X-NewRelic-Transaction"]:
+                status = "500 Outbound Headers Check Failed."
         else:
-            if 'X-NewRelic-Transaction' in generated_outbound_header:
-                status = '500 Outbound Headers Check Failed.'
-        r = urlopen(environ['server_url'])
+            if "X-NewRelic-Transaction" in generated_outbound_header:
+                status = "500 Outbound Headers Check Failed."
+        r = urlopen(environ["server_url"])  # nosec B310
         r.read(10)
 
     # Set the final transaction name.
 
-    if txn_name[0] != 'WebTransaction':
+    if txn_name[0] != "WebTransaction":
         set_background_task(True)
     set_transaction_name(txn_name[2], group=txn_name[1])
 
-    text = '<html><head>%s</head><body><p>RESPONSE</p>%s</body></html>'
+    text = "<html><head>%s</head><body><p>RESPONSE</p>%s</body></html>"
 
-    output = (text % (get_browser_timing_header(),
-            get_browser_timing_footer())).encode('UTF-8')
+    output = (text % (get_browser_timing_header(), get_browser_timing_footer())).encode("UTF-8")
 
-    response_headers = [('Content-type', 'text/html; charset=utf-8'),
-            ('Content-Length', str(len(output)))]
+    response_headers = [("Content-type", "text/html; charset=utf-8"), ("Content-Length", str(len(output)))]
     start_response(status, response_headers)
 
     return [output]
@@ -137,26 +148,35 @@ def target_wsgi_application(environ, start_response):
 
 
 @pytest.mark.parametrize(_parameters, load_tests())
-@pytest.mark.parametrize('old_cat', (True, False))
-def test_cat_map(name, appName, transactionName, transactionGuid,
-        inboundPayload, outboundRequests, expectedIntrinsicFields,
-        nonExpectedIntrinsicFields, old_cat, server):
+@pytest.mark.parametrize("old_cat", (True, False))
+def test_cat_map(
+    name,
+    appName,
+    transactionName,
+    transactionGuid,
+    inboundPayload,
+    outboundRequests,
+    expectedIntrinsicFields,
+    nonExpectedIntrinsicFields,
+    old_cat,
+    server,
+):
     global OUTBOUD_REQUESTS
     OUTBOUD_REQUESTS = outboundRequests or {}
 
     _custom_settings = {
-        'cross_process_id': '1#1',
-        'encoding_key': ENCODING_KEY,
-        'trusted_account_ids': [1],
-        'cross_application_tracer.enabled': True,
-        'distributed_tracing.enabled': not old_cat,
-        'transaction_tracer.transaction_threshold': 0.0,
+        "cross_process_id": "1#1",
+        "encoding_key": ENCODING_KEY,
+        "trusted_account_ids": [1],
+        "cross_application_tracer.enabled": True,
+        "distributed_tracing.enabled": not old_cat,
+        "transaction_tracer.transaction_threshold": 0.0,
     }
 
     if expectedIntrinsicFields and old_cat:
         _external_node_params = {
-            'path_hash': expectedIntrinsicFields['nr.pathHash'],
-            'trip_id': expectedIntrinsicFields['nr.tripId'],
+            "path_hash": expectedIntrinsicFields["nr.pathHash"],
+            "trip_id": expectedIntrinsicFields["nr.tripId"],
         }
     else:
         _external_node_params = []
@@ -167,16 +187,16 @@ def test_cat_map(name, appName, transactionName, transactionGuid,
         expectedIntrinsicFields = {}
 
     @validate_tt_parameters(required_params=_external_node_params)
-    @validate_analytics_catmap_data(transactionName,
-            expected_attributes=expectedIntrinsicFields,
-            non_expected_attributes=nonExpectedIntrinsicFields)
+    @validate_analytics_catmap_data(
+        transactionName, expected_attributes=expectedIntrinsicFields, non_expected_attributes=nonExpectedIntrinsicFields
+    )
     @override_application_settings(_custom_settings)
     @override_application_name(appName)
     def run_cat_test():
 
         if six.PY2:
-            txn_name = transactionName.encode('UTF-8')
-            guid = transactionGuid.encode('UTF-8')
+            txn_name = transactionName.encode("UTF-8")
+            guid = transactionGuid.encode("UTF-8")
         else:
             txn_name = transactionName
             guid = transactionGuid
@@ -185,20 +205,26 @@ def run_cat_test():
         # are properly ignoring these headers when the agent is using better
         # cat.
 
-        headers = make_cross_agent_headers(inboundPayload, ENCODING_KEY, '1#1')
-        response = target_application.get('/', headers=headers,
-                extra_environ={'txn': txn_name, 'guid': guid,
-                    'old_cat': str(old_cat),
-                    'server_url': 'http://localhost:%d' % server.port})
+        headers = make_cross_agent_headers(inboundPayload, ENCODING_KEY, "1#1")
+        response = target_application.get(
+            "/",
+            headers=headers,
+            extra_environ={
+                "txn": txn_name,
+                "guid": guid,
+                "old_cat": str(old_cat),
+                "server_url": "http://localhost:%d" % server.port,
+            },
+        )
 
         # Validation of analytic data happens in the decorator.
- assert response.status == '200 OK' + assert response.status == "200 OK" content = response.html.html.body.p.string # Validate actual body content as sanity check. - assert content == 'RESPONSE' + assert content == "RESPONSE" run_cat_test() diff --git a/tests/cross_agent/test_datstore_instance.py b/tests/cross_agent/test_datastore_instance.py similarity index 52% rename from tests/cross_agent/test_datstore_instance.py rename to tests/cross_agent/test_datastore_instance.py index aa095400f..79a95e0be 100644 --- a/tests/cross_agent/test_datstore_instance.py +++ b/tests/cross_agent/test_datastore_instance.py @@ -14,34 +14,40 @@ import json import os + import pytest from newrelic.api.background_task import background_task -from newrelic.api.database_trace import (register_database_client, - enable_datastore_instance_feature) +from newrelic.api.database_trace import register_database_client from newrelic.api.transaction import current_transaction from newrelic.core.database_node import DatabaseNode from newrelic.core.stats_engine import StatsEngine -FIXTURE = os.path.join(os.curdir, - 'fixtures', 'datastores', 'datastore_instances.json') +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +FIXTURE = os.path.join(CURRENT_DIR, "fixtures", "datastores", "datastore_instances.json") -_parameters_list = ['name', 'system_hostname', 'db_hostname', - 'product', 'port', 'unix_socket', 'database_path', - 'expected_instance_metric'] +_parameters_list = [ + "name", + "system_hostname", + "db_hostname", + "product", + "port", + "unix_socket", + "database_path", + "expected_instance_metric", +] -_parameters = ','.join(_parameters_list) +_parameters = ",".join(_parameters_list) def _load_tests(): - with open(FIXTURE, 'r') as fh: + with open(FIXTURE, "r") as fh: js = fh.read() return json.loads(js) def _parametrize_test(test): - return tuple([test.get(f, None if f != 'db_hostname' else 'localhost') - for f in _parameters_list]) + return tuple([test.get(f, None if f != "db_hostname" else "localhost") for f in _parameters_list]) _datastore_tests = [_parametrize_test(t) for t in _load_tests()] @@ -49,45 +55,44 @@ def _parametrize_test(test): @pytest.mark.parametrize(_parameters, _datastore_tests) @background_task() -def test_datastore_instance(name, system_hostname, db_hostname, - product, port, unix_socket, database_path, - expected_instance_metric, monkeypatch): +def test_datastore_instance( + name, system_hostname, db_hostname, product, port, unix_socket, database_path, expected_instance_metric, monkeypatch +): - monkeypatch.setattr('newrelic.common.system_info.gethostname', - lambda: system_hostname) + monkeypatch.setattr("newrelic.common.system_info.gethostname", lambda: system_hostname) - class FakeModule(): + class FakeModule: pass register_database_client(FakeModule, product) - enable_datastore_instance_feature(FakeModule) port_path_or_id = port or database_path or unix_socket - node = DatabaseNode(dbapi2_module=FakeModule, - sql='', - children=[], - start_time=0, - end_time=1, - duration=1, - exclusive=1, - stack_trace=None, - sql_format='obfuscated', - connect_params=None, - cursor_params=None, - sql_parameters=None, - execute_params=None, - host=db_hostname, - port_path_or_id=port_path_or_id, - database_name=database_path, - guid=None, - agent_attributes={}, - user_attributes={}, + node = DatabaseNode( + dbapi2_module=FakeModule, + sql="", + children=[], + start_time=0, + end_time=1, + duration=1, + exclusive=1, + stack_trace=None, + sql_format="obfuscated", + connect_params=None, + 
cursor_params=None, + sql_parameters=None, + execute_params=None, + host=db_hostname, + port_path_or_id=port_path_or_id, + database_name=database_path, + guid=None, + agent_attributes={}, + user_attributes={}, ) empty_stats = StatsEngine() transaction = current_transaction() - unscoped_scope = '' + unscoped_scope = "" # Check 'Datastore/instance' metric to confirm that: # 1. metric name is reported correctly diff --git a/tests/cross_agent/test_distributed_tracing.py b/tests/cross_agent/test_distributed_tracing.py index 0ff46eea2..060fe8a86 100644 --- a/tests/cross_agent/test_distributed_tracing.py +++ b/tests/cross_agent/test_distributed_tracing.py @@ -14,54 +14,70 @@ import json import os + import pytest import webtest +from testing_support.fixtures import override_application_settings, validate_attributes +from testing_support.validators.validate_error_event_attributes import ( + validate_error_event_attributes, +) +from testing_support.validators.validate_span_events import validate_span_events +from testing_support.validators.validate_transaction_event_attributes import ( + validate_transaction_event_attributes, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.api.transaction import current_transaction from newrelic.api.wsgi_application import wsgi_application from newrelic.common.encoding_utils import DistributedTracePayload from newrelic.common.object_wrapper import transient_function_wrapper -from testing_support.fixtures import (override_application_settings, - validate_error_event_attributes, validate_attributes) -from testing_support.validators.validate_span_events import ( - validate_span_events) -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_transaction_event_attributes import validate_transaction_event_attributes - CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) -JSON_DIR = os.path.normpath(os.path.join(CURRENT_DIR, 'fixtures', - 'distributed_tracing')) - -_parameters_list = ['account_id', 'comment', 'expected_metrics', - 'force_sampled_true', 'inbound_payloads', 'intrinsics', - 'major_version', 'minor_version', 'outbound_payloads', - 'raises_exception', 'span_events_enabled', 'test_name', - 'transport_type', 'trusted_account_key', 'web_transaction'] -_parameters = ','.join(_parameters_list) +JSON_DIR = os.path.normpath(os.path.join(CURRENT_DIR, "fixtures", "distributed_tracing")) + +_parameters_list = [ + "account_id", + "comment", + "expected_metrics", + "force_sampled_true", + "inbound_payloads", + "intrinsics", + "major_version", + "minor_version", + "outbound_payloads", + "raises_exception", + "span_events_enabled", + "test_name", + "transport_type", + "trusted_account_key", + "web_transaction", +] +_parameters = ",".join(_parameters_list) def load_tests(): result = [] - path = os.path.join(JSON_DIR, 'distributed_tracing.json') - with open(path, 'r') as fh: + path = os.path.join(JSON_DIR, "distributed_tracing.json") + with open(path, "r") as fh: tests = json.load(fh) for test in tests: values = (test.get(param, None) for param in _parameters_list) - param = pytest.param(*values, id=test.get('test_name')) + param = pytest.param(*values, id=test.get("test_name")) result.append(param) return result def override_compute_sampled(override): - @transient_function_wrapper('newrelic.core.adaptive_sampler', - 'AdaptiveSampler.compute_sampled') + @transient_function_wrapper("newrelic.core.adaptive_sampler", 
"AdaptiveSampler.compute_sampled") def _override_compute_sampled(wrapped, instance, args, kwargs): if override: return True return wrapped(*args, **kwargs) + return _override_compute_sampled @@ -70,58 +86,54 @@ def assert_payload(payload, payload_assertions, major_version, minor_version): # flatten payload so it matches the test: # payload['d']['ac'] -> payload['d.ac'] - d = payload.pop('d') + d = payload.pop("d") for key, value in d.items(): - payload['d.%s' % key] = value + payload["d.%s" % key] = value - for expected in payload_assertions.get('expected', []): + for expected in payload_assertions.get("expected", []): assert expected in payload - for unexpected in payload_assertions.get('unexpected', []): + for unexpected in payload_assertions.get("unexpected", []): assert unexpected not in payload - for key, value in payload_assertions.get('exact', {}).items(): + for key, value in payload_assertions.get("exact", {}).items(): assert key in payload if isinstance(value, list): value = tuple(value) assert payload[key] == value - assert payload['v'][0] == major_version - assert payload['v'][1] == minor_version + assert payload["v"][0] == major_version + assert payload["v"][1] == minor_version @wsgi_application() def target_wsgi_application(environ, start_response): - status = '200 OK' - output = b'hello world' - response_headers = [('Content-type', 'text/html; charset=utf-8'), - ('Content-Length', str(len(output)))] + status = "200 OK" + output = b"hello world" + response_headers = [("Content-type", "text/html; charset=utf-8"), ("Content-Length", str(len(output)))] txn = current_transaction() - txn.set_transaction_name(test_settings['test_name']) + txn.set_transaction_name(test_settings["test_name"]) - if not test_settings['web_transaction']: + if not test_settings["web_transaction"]: txn.background_task = True - if test_settings['raises_exception']: + if test_settings["raises_exception"]: try: 1 / 0 except ZeroDivisionError: txn.notice_error() - extra_inbound_payloads = test_settings['extra_inbound_payloads'] + extra_inbound_payloads = test_settings["extra_inbound_payloads"] for payload, expected_result in extra_inbound_payloads: - result = txn.accept_distributed_trace_payload(payload, - test_settings['transport_type']) + result = txn.accept_distributed_trace_payload(payload, test_settings["transport_type"]) assert result is expected_result - outbound_payloads = test_settings['outbound_payloads'] + outbound_payloads = test_settings["outbound_payloads"] if outbound_payloads: for payload_assertions in outbound_payloads: payload = txn._create_distributed_trace_payload() - assert_payload(payload, payload_assertions, - test_settings['major_version'], - test_settings['minor_version']) + assert_payload(payload, payload_assertions, test_settings["major_version"], test_settings["minor_version"]) start_response(status, response_headers) return [output] @@ -131,14 +143,26 @@ def target_wsgi_application(environ, start_response): @pytest.mark.parametrize(_parameters, load_tests()) -def test_distributed_tracing(account_id, comment, expected_metrics, - force_sampled_true, inbound_payloads, intrinsics, major_version, - minor_version, outbound_payloads, raises_exception, - span_events_enabled, test_name, transport_type, trusted_account_key, - web_transaction): +def test_distributed_tracing( + account_id, + comment, + expected_metrics, + force_sampled_true, + inbound_payloads, + intrinsics, + major_version, + minor_version, + outbound_payloads, + raises_exception, + span_events_enabled, + test_name, + 
transport_type, + trusted_account_key, + web_transaction, +): extra_inbound_payloads = [] - if transport_type != 'HTTP': + if transport_type != "HTTP": # Since wsgi_application calls accept_distributed_trace_payload # automatically with transport_type='HTTP', we must defer this call # until we can specify the transport type. @@ -153,78 +177,68 @@ def test_distributed_tracing(account_id, comment, expected_metrics, global test_settings test_settings = { - 'test_name': test_name, - 'web_transaction': web_transaction, - 'raises_exception': raises_exception, - 'extra_inbound_payloads': extra_inbound_payloads, - 'outbound_payloads': outbound_payloads, - 'transport_type': transport_type, - 'major_version': major_version, - 'minor_version': minor_version, + "test_name": test_name, + "web_transaction": web_transaction, + "raises_exception": raises_exception, + "extra_inbound_payloads": extra_inbound_payloads, + "outbound_payloads": outbound_payloads, + "transport_type": transport_type, + "major_version": major_version, + "minor_version": minor_version, } override_settings = { - 'distributed_tracing.enabled': True, - 'span_events.enabled': span_events_enabled, - 'account_id': account_id, - 'trusted_account_key': trusted_account_key, + "distributed_tracing.enabled": True, + "span_events.enabled": span_events_enabled, + "account_id": account_id, + "trusted_account_key": trusted_account_key, } - common_required = intrinsics['common']['expected'] - common_forgone = intrinsics['common']['unexpected'] - common_exact = intrinsics['common'].get('exact', {}) - - txn_intrinsics = intrinsics.get('Transaction', {}) - txn_event_required = {'agent': [], 'user': [], - 'intrinsic': txn_intrinsics.get('expected', [])} - txn_event_required['intrinsic'].extend(common_required) - txn_event_forgone = {'agent': [], 'user': [], - 'intrinsic': txn_intrinsics.get('unexpected', [])} - txn_event_forgone['intrinsic'].extend(common_forgone) - txn_event_exact = {'agent': {}, 'user': {}, - 'intrinsic': txn_intrinsics.get('exact', {})} - txn_event_exact['intrinsic'].update(common_exact) + common_required = intrinsics["common"]["expected"] + common_forgone = intrinsics["common"]["unexpected"] + common_exact = intrinsics["common"].get("exact", {}) + + txn_intrinsics = intrinsics.get("Transaction", {}) + txn_event_required = {"agent": [], "user": [], "intrinsic": txn_intrinsics.get("expected", [])} + txn_event_required["intrinsic"].extend(common_required) + txn_event_forgone = {"agent": [], "user": [], "intrinsic": txn_intrinsics.get("unexpected", [])} + txn_event_forgone["intrinsic"].extend(common_forgone) + txn_event_exact = {"agent": {}, "user": {}, "intrinsic": txn_intrinsics.get("exact", {})} + txn_event_exact["intrinsic"].update(common_exact) headers = {} if inbound_payloads: payload = DistributedTracePayload(inbound_payloads[0]) - headers['newrelic'] = payload.http_safe() - - @validate_transaction_metrics(test_name, - rollup_metrics=expected_metrics, - background_task=not web_transaction) - @validate_transaction_event_attributes( - txn_event_required, txn_event_forgone, txn_event_exact) - @validate_attributes('intrinsic', common_required, common_forgone) + headers["newrelic"] = payload.http_safe() + + @validate_transaction_metrics(test_name, rollup_metrics=expected_metrics, background_task=not web_transaction) + @validate_transaction_event_attributes(txn_event_required, txn_event_forgone, txn_event_exact) + @validate_attributes("intrinsic", common_required, common_forgone) @override_compute_sampled(force_sampled_true) 
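
    # override_compute_sampled (defined above) temporarily patches
    # AdaptiveSampler.compute_sampled so fixtures with force_sampled_true get a
    # deterministic sampling decision. A standalone sketch of the same wrapper
    # utility (illustrative only, not part of this patch):
    #
    #     from newrelic.common.object_wrapper import transient_function_wrapper
    #
    #     @transient_function_wrapper("newrelic.core.adaptive_sampler", "AdaptiveSampler.compute_sampled")
    #     def always_sampled(wrapped, instance, args, kwargs):
    #         return True  # force every sampling decision while the wrapper is active
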
@override_application_settings(override_settings) def _test(): - response = test_application.get('/', headers=headers) - assert 'X-NewRelic-App-Data' not in response.headers + response = test_application.get("/", headers=headers) + assert "X-NewRelic-App-Data" not in response.headers - if 'Span' in intrinsics: - span_intrinsics = intrinsics.get('Span') - span_expected = span_intrinsics.get('expected', []) + if "Span" in intrinsics: + span_intrinsics = intrinsics.get("Span") + span_expected = span_intrinsics.get("expected", []) span_expected.extend(common_required) - span_unexpected = span_intrinsics.get('unexpected', []) + span_unexpected = span_intrinsics.get("unexpected", []) span_unexpected.extend(common_forgone) - span_exact = span_intrinsics.get('exact', {}) + span_exact = span_intrinsics.get("exact", {}) span_exact.update(common_exact) - _test = validate_span_events(exact_intrinsics=span_exact, - expected_intrinsics=span_expected, - unexpected_intrinsics=span_unexpected)(_test) + _test = validate_span_events( + exact_intrinsics=span_exact, expected_intrinsics=span_expected, unexpected_intrinsics=span_unexpected + )(_test) elif not span_events_enabled: _test = validate_span_events(count=0)(_test) if raises_exception: - error_event_required = {'agent': [], 'user': [], - 'intrinsic': common_required} - error_event_forgone = {'agent': [], 'user': [], - 'intrinsic': common_forgone} - error_event_exact = {'agent': {}, 'user': {}, - 'intrinsic': common_exact} - _test = validate_error_event_attributes(error_event_required, - error_event_forgone, error_event_exact)(_test) + error_event_required = {"agent": [], "user": [], "intrinsic": common_required} + error_event_forgone = {"agent": [], "user": [], "intrinsic": common_forgone} + error_event_exact = {"agent": {}, "user": {}, "intrinsic": common_exact} + _test = validate_error_event_attributes(error_event_required, error_event_forgone, error_event_exact)(_test) _test() diff --git a/tests/cross_agent/test_docker.py b/tests/cross_agent/test_docker.py index 9bc1a7363..fd919932b 100644 --- a/tests/cross_agent/test_docker.py +++ b/tests/cross_agent/test_docker.py @@ -19,7 +19,8 @@ import newrelic.common.utilization as u -DOCKER_FIXTURE = os.path.join(os.curdir, 'fixtures', 'docker_container_id') +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +DOCKER_FIXTURE = os.path.join(CURRENT_DIR, 'fixtures', 'docker_container_id') def _load_docker_test_attributes(): diff --git a/tests/cross_agent/test_labels_and_rollups.py b/tests/cross_agent/test_labels_and_rollups.py index d333ec35b..15ebb1e36 100644 --- a/tests/cross_agent/test_labels_and_rollups.py +++ b/tests/cross_agent/test_labels_and_rollups.py @@ -21,7 +21,8 @@ from testing_support.fixtures import override_application_settings -FIXTURE = os.path.join(os.curdir, 'fixtures', 'labels.json') +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +FIXTURE = os.path.join(CURRENT_DIR, 'fixtures', 'labels.json') def _load_tests(): with open(FIXTURE, 'r') as fh: diff --git a/tests/cross_agent/test_rules.py b/tests/cross_agent/test_rules.py index e37db787c..ce2983c90 100644 --- a/tests/cross_agent/test_rules.py +++ b/tests/cross_agent/test_rules.py @@ -16,23 +16,23 @@ import os import pytest -from newrelic.core.rules_engine import RulesEngine, NormalizationRule +from newrelic.api.application import application_instance +from newrelic.api.background_task import background_task +from newrelic.api.transaction import record_custom_metric +from newrelic.core.rules_engine import RulesEngine + +from 
testing_support.validators.validate_metric_payload import validate_metric_payload

 CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
 FIXTURE = os.path.normpath(os.path.join(
     CURRENT_DIR, 'fixtures', 'rules.json'))

+
 def _load_tests():
     with open(FIXTURE, 'r') as fh:
         js = fh.read()
     return json.loads(js)

-def _prepare_rules(test_rules):
-    # ensure all keys are present, if not present set to an empty string
-    for rule in test_rules:
-        for key in NormalizationRule._fields:
-            rule[key] = rule.get(key, '')
-    return test_rules

 def _make_case_insensitive(rules):
     # lowercase each rule
@@ -42,14 +42,14 @@ def _make_case_insensitive(rules):
         rule['replacement'] = rule['replacement'].lower()
     return rules

+
 @pytest.mark.parametrize('test_group', _load_tests())
 def test_rules_engine(test_group):
     # FIXME: The test fixture assumes that matching is case insensitive when it
     # is not. To avoid errors, just lowercase all rules, inputs, and expected
     # values.
-    insense_rules = _make_case_insensitive(test_group['rules'])
-    test_rules = _prepare_rules(insense_rules)
+    test_rules = _make_case_insensitive(test_group['rules'])
     rules_engine = RulesEngine(test_rules)

     for test in test_group['tests']:
@@ -66,3 +66,46 @@ def test_rules_engine(test_group):
         assert expected == ''
     else:
         assert result == expected
+
+
+@pytest.mark.parametrize('test_group', _load_tests())
+def test_rules_engine_metric_harvest(test_group):
+    # FIXME: The test fixture assumes that matching is case insensitive when it
+    # is not. To avoid errors, just lowercase all rules, inputs, and expected
+    # values.
+    test_rules = _make_case_insensitive(test_group['rules'])
+    rules_engine = RulesEngine(test_rules)
+
+    # Set rules engine on core application
+    api_application = application_instance(activate=False)
+    api_name = api_application.name
+    core_application = api_application._agent.application(api_name)
+    old_rules = core_application._rules_engine["metric"]  # save previous rules
+    core_application._rules_engine["metric"] = rules_engine
+
+    def send_metrics():
+        # Send all metrics in this test batch in one transaction, then harvest so the normalizer is run.
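
        # Because the metric rules engine was swapped in above, the custom
        # metrics recorded in _test() below are renamed by the fixture's rules
        # when core_application.harvest() runs, so the harvested payload can be
        # matched against the fixture's expected metric names. Illustrative
        # flow (metric name made up, not part of the test):
        #
        #     record_custom_metric("norm/rule/input", {"count": 1})  # raw name
        #     core_application.harvest()  # normalizer rewrites it during harvest
        #     # payload now carries the rule's expected output name instead
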
+ @background_task(name="send_metrics") + def _test(): + for test in test_group['tests']: + # lowercase each value + input_str = test['input'].lower() + record_custom_metric(input_str, {"count": 1}) + _test() + core_application.harvest() + + try: + # Create a map of all result metrics to validate after harvest + test_metrics = [] + for test in test_group['tests']: + expected = (test['expected'] or '').lower() + if expected == '': # Ignored + test_metrics.append((expected, None)) + else: + test_metrics.append((expected, 1)) + + # Harvest and validate resulting payload + validate_metric_payload(metrics=test_metrics)(send_metrics)() + finally: + # Replace original rules engine + core_application._rules_engine["metric"] = old_rules diff --git a/tests/cross_agent/test_rum_client_config.py b/tests/cross_agent/test_rum_client_config.py index c2a4a465f..5b8da4b84 100644 --- a/tests/cross_agent/test_rum_client_config.py +++ b/tests/cross_agent/test_rum_client_config.py @@ -26,10 +26,11 @@ ) from newrelic.api.wsgi_application import wsgi_application +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +FIXTURE = os.path.join(CURRENT_DIR, "fixtures", "rum_client_config.json") def _load_tests(): - fixture = os.path.join(os.curdir, "fixtures", "rum_client_config.json") - with open(fixture, "r") as fh: + with open(FIXTURE, "r") as fh: js = fh.read() return json.loads(js) diff --git a/tests/cross_agent/test_w3c_trace_context.py b/tests/cross_agent/test_w3c_trace_context.py index 05f157f7b..893274ce4 100644 --- a/tests/cross_agent/test_w3c_trace_context.py +++ b/tests/cross_agent/test_w3c_trace_context.py @@ -14,88 +14,105 @@ import json import os + import pytest import webtest -from newrelic.packages import six - -from newrelic.api.transaction import current_transaction +from testing_support.fixtures import override_application_settings, validate_attributes +from testing_support.validators.validate_span_events import validate_span_events +from testing_support.validators.validate_transaction_event_attributes import ( + validate_transaction_event_attributes, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.transaction import ( + accept_distributed_trace_headers, + current_transaction, + insert_distributed_trace_headers, +) from newrelic.api.wsgi_application import wsgi_application -from newrelic.common.object_wrapper import transient_function_wrapper -from testing_support.validators.validate_span_events import ( - validate_span_events) -from testing_support.fixtures import (override_application_settings, - validate_attributes) from newrelic.common.encoding_utils import W3CTraceState -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_transaction_event_attributes import validate_transaction_event_attributes +from newrelic.common.object_wrapper import transient_function_wrapper +from newrelic.packages import six CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) -JSON_DIR = os.path.normpath(os.path.join(CURRENT_DIR, 'fixtures', - 'distributed_tracing')) - -_parameters_list = ('test_name', 'trusted_account_key', 'account_id', - 'web_transaction', 'raises_exception', 'force_sampled_true', - 'span_events_enabled', 'transport_type', 'inbound_headers', - 'outbound_payloads', 'intrinsics', 'expected_metrics') - -_parameters = ','.join(_parameters_list) +JSON_DIR = os.path.normpath(os.path.join(CURRENT_DIR, "fixtures", 
"distributed_tracing")) + +_parameters_list = ( + "test_name", + "trusted_account_key", + "account_id", + "web_transaction", + "raises_exception", + "force_sampled_true", + "span_events_enabled", + "transport_type", + "inbound_headers", + "outbound_payloads", + "intrinsics", + "expected_metrics", +) + +_parameters = ",".join(_parameters_list) XFAIL_TESTS = [ - 'spans_disabled_root', - 'missing_traceparent', - 'missing_traceparent_and_tracestate', - 'w3c_and_newrelc_headers_present_error_parsing_traceparent' + "spans_disabled_root", + "missing_traceparent", + "missing_traceparent_and_tracestate", + "w3c_and_newrelc_headers_present_error_parsing_traceparent", ] + def load_tests(): result = [] - path = os.path.join(JSON_DIR, 'trace_context.json') - with open(path, 'r') as fh: + path = os.path.join(JSON_DIR, "trace_context.json") + with open(path, "r") as fh: tests = json.load(fh) for test in tests: values = (test.get(param, None) for param in _parameters_list) - param = pytest.param(*values, id=test.get('test_name')) + param = pytest.param(*values, id=test.get("test_name")) result.append(param) return result ATTR_MAP = { - 'traceparent.version': 0, - 'traceparent.trace_id': 1, - 'traceparent.parent_id': 2, - 'traceparent.trace_flags': 3, - 'tracestate.version': 0, - 'tracestate.parent_type': 1, - 'tracestate.parent_account_id': 2, - 'tracestate.parent_application_id': 3, - 'tracestate.span_id': 4, - 'tracestate.transaction_id': 5, - 'tracestate.sampled': 6, - 'tracestate.priority': 7, - 'tracestate.timestamp': 8, - 'tracestate.tenant_id': None, + "traceparent.version": 0, + "traceparent.trace_id": 1, + "traceparent.parent_id": 2, + "traceparent.trace_flags": 3, + "tracestate.version": 0, + "tracestate.parent_type": 1, + "tracestate.parent_account_id": 2, + "tracestate.parent_application_id": 3, + "tracestate.span_id": 4, + "tracestate.transaction_id": 5, + "tracestate.sampled": 6, + "tracestate.priority": 7, + "tracestate.timestamp": 8, + "tracestate.tenant_id": None, } def validate_outbound_payload(actual, expected, trusted_account_key): - traceparent = '' - tracestate = '' + traceparent = "" + tracestate = "" for key, value in actual: - if key == 'traceparent': - traceparent = value.split('-') - elif key == 'tracestate': + if key == "traceparent": + traceparent = value.split("-") + elif key == "tracestate": vendors = W3CTraceState.decode(value) - nr_entry = vendors.pop(trusted_account_key + '@nr', '') - tracestate = nr_entry.split('-') - exact_values = expected.get('exact', {}) - expected_attrs = expected.get('expected', []) - unexpected_attrs = expected.get('unexpected', []) - expected_vendors = expected.get('vendors', []) + nr_entry = vendors.pop(trusted_account_key + "@nr", "") + tracestate = nr_entry.split("-") + exact_values = expected.get("exact", {}) + expected_attrs = expected.get("expected", []) + unexpected_attrs = expected.get("unexpected", []) + expected_vendors = expected.get("vendors", []) for key, value in exact_values.items(): - header = traceparent if key.startswith('traceparent.') else tracestate + header = traceparent if key.startswith("traceparent.") else tracestate attr = ATTR_MAP[key] if attr is not None: if isinstance(value, bool): @@ -106,13 +123,13 @@ def validate_outbound_payload(actual, expected, trusted_account_key): assert header[attr] == str(value) for key in expected_attrs: - header = traceparent if key.startswith('traceparent.') else tracestate + header = traceparent if key.startswith("traceparent.") else tracestate attr = ATTR_MAP[key] if attr is not None: 
assert header[attr], key for key in unexpected_attrs: - header = traceparent if key.startswith('traceparent.') else tracestate + header = traceparent if key.startswith("traceparent.") else tracestate attr = ATTR_MAP[key] if attr is not None: assert not header[attr], key @@ -125,127 +142,129 @@ def validate_outbound_payload(actual, expected, trusted_account_key): def target_wsgi_application(environ, start_response): transaction = current_transaction() - if not environ['.web_transaction']: + if not environ[".web_transaction"]: transaction.background_task = True - if environ['.raises_exception']: + if environ[".raises_exception"]: try: raise ValueError("oops") except: transaction.notice_error() - if '.inbound_headers' in environ: - transaction.accept_distributed_trace_headers( - environ['.inbound_headers'], - transport_type=environ['.transport_type'], + if ".inbound_headers" in environ: + accept_distributed_trace_headers( + environ[".inbound_headers"], + transport_type=environ[".transport_type"], ) payloads = [] - for _ in range(environ['.outbound_calls']): + for _ in range(environ[".outbound_calls"]): payloads.append([]) - transaction.insert_distributed_trace_headers(payloads[-1]) + insert_distributed_trace_headers(payloads[-1]) - start_response('200 OK', [('Content-Type', 'application/json')]) - return [json.dumps(payloads).encode('utf-8')] + start_response("200 OK", [("Content-Type", "application/json")]) + return [json.dumps(payloads).encode("utf-8")] test_application = webtest.TestApp(target_wsgi_application) def override_compute_sampled(override): - @transient_function_wrapper('newrelic.core.adaptive_sampler', - 'AdaptiveSampler.compute_sampled') + @transient_function_wrapper("newrelic.core.adaptive_sampler", "AdaptiveSampler.compute_sampled") def _override_compute_sampled(wrapped, instance, args, kwargs): if override: return True return wrapped(*args, **kwargs) + return _override_compute_sampled @pytest.mark.parametrize(_parameters, load_tests()) -def test_trace_context(test_name, trusted_account_key, account_id, - web_transaction, raises_exception, force_sampled_true, - span_events_enabled, transport_type, inbound_headers, - outbound_payloads, intrinsics, expected_metrics): - +def test_trace_context( + test_name, + trusted_account_key, + account_id, + web_transaction, + raises_exception, + force_sampled_true, + span_events_enabled, + transport_type, + inbound_headers, + outbound_payloads, + intrinsics, + expected_metrics, +): if test_name in XFAIL_TESTS: pytest.xfail("Waiting on cross agent tests update.") # Prepare assertions if not intrinsics: intrinsics = {} - common = intrinsics.get('common', {}) - common_required = common.get('expected', []) - common_forgone = common.get('unexpected', []) - common_exact = common.get('exact', {}) - - txn_intrinsics = intrinsics.get('Transaction', {}) - txn_event_required = {'agent': [], 'user': [], - 'intrinsic': txn_intrinsics.get('expected', [])} - txn_event_required['intrinsic'].extend(common_required) - txn_event_forgone = {'agent': [], 'user': [], - 'intrinsic': txn_intrinsics.get('unexpected', [])} - txn_event_forgone['intrinsic'].extend(common_forgone) - txn_event_exact = {'agent': {}, 'user': {}, - 'intrinsic': txn_intrinsics.get('exact', {})} - txn_event_exact['intrinsic'].update(common_exact) + common = intrinsics.get("common", {}) + common_required = common.get("expected", []) + common_forgone = common.get("unexpected", []) + common_exact = common.get("exact", {}) + + txn_intrinsics = intrinsics.get("Transaction", {}) + 
txn_event_required = {"agent": [], "user": [], "intrinsic": txn_intrinsics.get("expected", [])} + txn_event_required["intrinsic"].extend(common_required) + txn_event_forgone = {"agent": [], "user": [], "intrinsic": txn_intrinsics.get("unexpected", [])} + txn_event_forgone["intrinsic"].extend(common_forgone) + txn_event_exact = {"agent": {}, "user": {}, "intrinsic": txn_intrinsics.get("exact", {})} + txn_event_exact["intrinsic"].update(common_exact) override_settings = { - 'distributed_tracing.enabled': True, - 'span_events.enabled': span_events_enabled, - 'account_id': account_id, - 'trusted_account_key': trusted_account_key, + "distributed_tracing.enabled": True, + "span_events.enabled": span_events_enabled, + "account_id": account_id, + "trusted_account_key": trusted_account_key, } extra_environ = { - '.web_transaction': web_transaction, - '.raises_exception': raises_exception, - '.transport_type': transport_type, - '.outbound_calls': outbound_payloads and len(outbound_payloads) or 0, + ".web_transaction": web_transaction, + ".raises_exception": raises_exception, + ".transport_type": transport_type, + ".outbound_calls": outbound_payloads and len(outbound_payloads) or 0, } inbound_headers = inbound_headers and inbound_headers[0] or None - if transport_type != 'HTTP': - extra_environ['.inbound_headers'] = inbound_headers + if transport_type != "HTTP": + extra_environ[".inbound_headers"] = inbound_headers inbound_headers = None elif six.PY2 and inbound_headers: - inbound_headers = { - k.encode('utf-8'): v.encode('utf-8') - for k, v in inbound_headers.items()} - - @validate_transaction_metrics(test_name, - group="Uri", - rollup_metrics=expected_metrics, - background_task=not web_transaction) - @validate_transaction_event_attributes( - txn_event_required, txn_event_forgone, txn_event_exact) - @validate_attributes('intrinsic', common_required, common_forgone) + inbound_headers = {k.encode("utf-8"): v.encode("utf-8") for k, v in inbound_headers.items()} + + @validate_transaction_metrics( + test_name, group="Uri", rollup_metrics=expected_metrics, background_task=not web_transaction + ) + @validate_transaction_event_attributes(txn_event_required, txn_event_forgone, txn_event_exact) + @validate_attributes("intrinsic", common_required, common_forgone) @override_application_settings(override_settings) @override_compute_sampled(force_sampled_true) def _test(): return test_application.get( - '/' + test_name, + "/" + test_name, headers=inbound_headers, extra_environ=extra_environ, ) - if 'Span' in intrinsics: - span_intrinsics = intrinsics.get('Span') - span_expected = span_intrinsics.get('expected', []) + if "Span" in intrinsics: + span_intrinsics = intrinsics.get("Span") + span_expected = span_intrinsics.get("expected", []) span_expected.extend(common_required) - span_unexpected = span_intrinsics.get('unexpected', []) + span_unexpected = span_intrinsics.get("unexpected", []) span_unexpected.extend(common_forgone) - span_exact = span_intrinsics.get('exact', {}) + span_exact = span_intrinsics.get("exact", {}) span_exact.update(common_exact) - _test = validate_span_events(exact_intrinsics=span_exact, - expected_intrinsics=span_expected, - unexpected_intrinsics=span_unexpected)(_test) + _test = validate_span_events( + exact_intrinsics=span_exact, expected_intrinsics=span_expected, unexpected_intrinsics=span_unexpected + )(_test) elif not span_events_enabled: _test = validate_span_events(count=0)(_test) response = _test() - assert response.status == '200 OK' + assert response.status == "200 OK" payloads = 
response.json if outbound_payloads: assert len(payloads) == len(outbound_payloads) diff --git a/tests/datastore_aioredis/conftest.py b/tests/datastore_aioredis/conftest.py index d50129255..e1cea4c01 100644 --- a/tests/datastore_aioredis/conftest.py +++ b/tests/datastore_aioredis/conftest.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import pytest from newrelic.common.package_version_utils import get_package_version_tuple @@ -67,3 +68,7 @@ def client(request, loop): pytest.skip("StrictRedis not implemented.") else: raise NotImplementedError() + +@pytest.fixture(scope="session") +def key(): + return "AIOREDIS-TEST-" + str(os.getpid()) diff --git a/tests/datastore_aioredis/test_execute_command.py b/tests/datastore_aioredis/test_execute_command.py index 54851a659..b600abea5 100644 --- a/tests/datastore_aioredis/test_execute_command.py +++ b/tests/datastore_aioredis/test_execute_command.py @@ -13,8 +13,6 @@ # limitations under the License. import pytest - -# import aioredis from conftest import AIOREDIS_VERSION, loop # noqa # pylint: disable=E0611,W0611 from testing_support.db_settings import redis_settings from testing_support.fixtures import override_application_settings diff --git a/tests/datastore_aioredis/test_get_and_set.py b/tests/datastore_aioredis/test_get_and_set.py index 180f32578..cbddf6091 100644 --- a/tests/datastore_aioredis/test_get_and_set.py +++ b/tests/datastore_aioredis/test_get_and_set.py @@ -64,9 +64,9 @@ _disable_rollup_metrics.append((_instance_metric_name, None)) -async def exercise_redis(client): - await client.set("key", "value") - await client.get("key") +async def exercise_redis(client, key): + await client.set(key, "value") + await client.get(key) @override_application_settings(_enable_instance_settings) @@ -77,8 +77,8 @@ async def exercise_redis(client): background_task=True, ) @background_task() -def test_redis_client_operation_enable_instance(client, loop): - loop.run_until_complete(exercise_redis(client)) +def test_redis_client_operation_enable_instance(client, loop, key): + loop.run_until_complete(exercise_redis(client, key)) @override_application_settings(_disable_instance_settings) @@ -89,5 +89,5 @@ def test_redis_client_operation_enable_instance(client, loop): background_task=True, ) @background_task() -def test_redis_client_operation_disable_instance(client, loop): - loop.run_until_complete(exercise_redis(client)) +def test_redis_client_operation_disable_instance(client, loop, key): + loop.run_until_complete(exercise_redis(client, key)) diff --git a/tests/datastore_aioredis/test_transactions.py b/tests/datastore_aioredis/test_transactions.py index 0f84ca684..ced922022 100644 --- a/tests/datastore_aioredis/test_transactions.py +++ b/tests/datastore_aioredis/test_transactions.py @@ -23,42 +23,46 @@ @background_task() @pytest.mark.parametrize("in_transaction", (True, False)) -def test_pipelines_no_harm(client, in_transaction, loop): +def test_pipelines_no_harm(client, in_transaction, loop, key): async def exercise(): if AIOREDIS_VERSION >= (2,): pipe = client.pipeline(transaction=in_transaction) else: pipe = client.pipeline() # Transaction kwarg unsupported - pipe.set("TXN", 1) + pipe.set(key, 1) return await pipe.execute() status = loop.run_until_complete(exercise()) assert status == [True] -def exercise_transaction_sync(pipe): - pipe.set("TXN", 1) +def exercise_transaction_sync(key): + def _run(pipe): + pipe.set(key, 1) + return _run -async def 
exercise_transaction_async(pipe): - await pipe.set("TXN", 1) +def exercise_transaction_async(key): + async def _run(pipe): + await pipe.set(key, 1) + return _run @SKIPIF_AIOREDIS_V1 @pytest.mark.parametrize("exercise", (exercise_transaction_sync, exercise_transaction_async)) @background_task() -def test_transactions_no_harm(client, loop, exercise): - status = loop.run_until_complete(client.transaction(exercise)) +def test_transactions_no_harm(client, loop, key, exercise): + status = loop.run_until_complete(client.transaction(exercise(key))) assert status == [True] @SKIPIF_AIOREDIS_V2 @background_task() -def test_multi_exec_no_harm(client, loop): +def test_multi_exec_no_harm(client, loop, key): async def exercise(): pipe = client.multi_exec() - pipe.set("key", "value") + pipe.set(key, "value") status = await pipe.execute() assert status == [True] @@ -67,9 +71,7 @@ async def exercise(): @SKIPIF_AIOREDIS_V1 @background_task() -def test_pipeline_immediate_execution_no_harm(client, loop): - key = "TXN_WATCH" - +def test_pipeline_immediate_execution_no_harm(client, loop, key): async def exercise(): await client.set(key, 1) @@ -94,9 +96,7 @@ async def exercise(): @SKIPIF_AIOREDIS_V1 @background_task() -def test_transaction_immediate_execution_no_harm(client, loop): - key = "TXN_WATCH" - +def test_transaction_immediate_execution_no_harm(client, loop, key): async def exercise(): async def exercise_transaction(pipe): value = int(await pipe.get(key)) @@ -119,9 +119,7 @@ async def exercise_transaction(pipe): @SKIPIF_AIOREDIS_V1 @validate_transaction_errors([]) @background_task() -def test_transaction_watch_error_no_harm(client, loop): - key = "TXN_WATCH" - +def test_transaction_watch_error_no_harm(client, loop, key): async def exercise(): async def exercise_transaction(pipe): value = int(await pipe.get(key)) diff --git a/tests/datastore_bmemcached/test_memcache.py b/tests/datastore_bmemcached/test_memcache.py index 68eee0633..2f87da113 100644 --- a/tests/datastore_bmemcached/test_memcache.py +++ b/tests/datastore_bmemcached/test_memcache.py @@ -13,83 +13,94 @@ # limitations under the License. 
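
# Both the aioredis "key" fixture above and MEMCACHED_NAMESPACE below derive
# test keys from the current process id, so parallel test runs sharing one
# datastore cannot collide on keys. A minimal sketch of the same pattern
# (the fixture name is illustrative):

import os

import pytest


@pytest.fixture(scope="session")
def test_key():
    # One key prefix per test process; parallel workers get distinct pids.
    return "DATASTORE-TEST-" + str(os.getpid())
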
import os -from testing_support.db_settings import memcached_settings + import bmemcached +from testing_support.db_settings import memcached_settings +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.api.background_task import background_task from newrelic.api.transaction import set_background_task from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.db_settings import memcached_settings - DB_SETTINGS = memcached_settings()[0] -MEMCACHED_HOST = DB_SETTINGS['host'] -MEMCACHED_PORT = DB_SETTINGS['port'] +MEMCACHED_HOST = DB_SETTINGS["host"] +MEMCACHED_PORT = DB_SETTINGS["port"] MEMCACHED_NAMESPACE = str(os.getpid()) -MEMCACHED_ADDR = '%s:%s' % (MEMCACHED_HOST, MEMCACHED_PORT) +MEMCACHED_ADDR = "%s:%s" % (MEMCACHED_HOST, MEMCACHED_PORT) _test_bt_set_get_delete_scoped_metrics = [ - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] _test_bt_set_get_delete_rollup_metrics = [ - ('Datastore/all', 3), - ('Datastore/allOther', 3), - ('Datastore/Memcached/all', 3), - ('Datastore/Memcached/allOther', 3), - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/all", 3), + ("Datastore/allOther", 3), + ("Datastore/Memcached/all", 3), + ("Datastore/Memcached/allOther", 3), + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] + @validate_transaction_metrics( - 'test_memcache:test_bt_set_get_delete', - scoped_metrics=_test_bt_set_get_delete_scoped_metrics, - rollup_metrics=_test_bt_set_get_delete_rollup_metrics, - background_task=True) + "test_memcache:test_bt_set_get_delete", + scoped_metrics=_test_bt_set_get_delete_scoped_metrics, + rollup_metrics=_test_bt_set_get_delete_rollup_metrics, + background_task=True, +) @background_task() def test_bt_set_get_delete(): set_background_task(True) client = bmemcached.Client([MEMCACHED_ADDR]) - key = MEMCACHED_NAMESPACE + 'key' + key = MEMCACHED_NAMESPACE + "key" - client.set(key, 'value') + client.set(key, "value") value = client.get(key) client.delete(key) - assert value == 'value' + assert value == "value" + _test_wt_set_get_delete_scoped_metrics = [ - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] _test_wt_set_get_delete_rollup_metrics = [ - ('Datastore/all', 3), - ('Datastore/allWeb', 3), - ('Datastore/Memcached/all', 3), - ('Datastore/Memcached/allWeb', 3), - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/all", 3), + ("Datastore/allWeb", 3), + ("Datastore/Memcached/all", 3), + ("Datastore/Memcached/allWeb", 3), + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] + @validate_transaction_metrics( - 'test_memcache:test_wt_set_get_delete', - scoped_metrics=_test_wt_set_get_delete_scoped_metrics, - 
rollup_metrics=_test_wt_set_get_delete_rollup_metrics, - background_task=False) + "test_memcache:test_wt_set_get_delete", + scoped_metrics=_test_wt_set_get_delete_scoped_metrics, + rollup_metrics=_test_wt_set_get_delete_rollup_metrics, + background_task=False, +) @background_task() def test_wt_set_get_delete(): set_background_task(False) client = bmemcached.Client([MEMCACHED_ADDR]) - key = MEMCACHED_NAMESPACE + 'key' + key = MEMCACHED_NAMESPACE + "key" - client.set(key, 'value') + client.set(key, "value") value = client.get(key) client.delete(key) - assert value == 'value' + assert value == "value" diff --git a/tests/datastore_elasticsearch/test_connection.py b/tests/datastore_elasticsearch/test_connection.py index 2e888af9b..9e8f17b4c 100644 --- a/tests/datastore_elasticsearch/test_connection.py +++ b/tests/datastore_elasticsearch/test_connection.py @@ -36,7 +36,7 @@ def test_connection_default(): else: conn = Connection(**HOST) - assert conn._nr_host_port == ("localhost", ES_SETTINGS["port"]) + assert conn._nr_host_port == (ES_SETTINGS["host"], ES_SETTINGS["port"]) @SKIP_IF_V7 diff --git a/tests/datastore_firestore/conftest.py b/tests/datastore_firestore/conftest.py new file mode 100644 index 000000000..28e138fa2 --- /dev/null +++ b/tests/datastore_firestore/conftest.py @@ -0,0 +1,124 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
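
# The bmemcached tests above validate the same operations at two rollup
# scopes: background transactions must produce Datastore/allOther, while web
# transactions must produce Datastore/allWeb. A compact sketch of that
# validation pattern (the transaction name and exercise body are illustrative):

from newrelic.api.background_task import background_task
from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics

_operation_metrics = [("Datastore/operation/Memcached/%s" % op, 1) for op in ("set", "get", "delete")]


@validate_transaction_metrics(
    "example_module:example_bt_test",  # illustrative transaction name
    scoped_metrics=_operation_metrics,
    rollup_metrics=[("Datastore/all", 3), ("Datastore/allOther", 3)] + _operation_metrics,
    background_task=True,
)
@background_task()
def example_bt_test():
    pass  # exercise set/get/delete against the client here
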
+import os
+import uuid
+
+import pytest
+
+from google.cloud.firestore import Client, AsyncClient
+
+from testing_support.db_settings import firestore_settings
+from testing_support.fixture.event_loop import event_loop as loop  # noqa: F401; pylint: disable=W0611
+from testing_support.fixtures import (  # noqa: F401; pylint: disable=W0611
+    collector_agent_registration_fixture,
+    collector_available_fixture,
+)
+
+from newrelic.api.datastore_trace import DatastoreTrace
+from newrelic.api.time_trace import current_trace
+from newrelic.common.system_info import LOCALHOST_EQUIVALENTS, gethostname
+
+DB_SETTINGS = firestore_settings()[0]
+FIRESTORE_HOST = DB_SETTINGS["host"]
+FIRESTORE_PORT = DB_SETTINGS["port"]
+
+_default_settings = {
+    "transaction_tracer.explain_threshold": 0.0,
+    "transaction_tracer.transaction_threshold": 0.0,
+    "transaction_tracer.stack_trace_threshold": 0.0,
+    "debug.log_data_collector_payloads": True,
+    "debug.record_transaction_failure": True,
+    "debug.log_explain_plan_queries": True,
+}
+
+collector_agent_registration = collector_agent_registration_fixture(
+    app_name="Python Agent Test (datastore_firestore)",
+    default_settings=_default_settings,
+    linked_applications=["Python Agent Test (datastore)"],
+)
+
+
+@pytest.fixture()
+def instance_info():
+    host = gethostname() if FIRESTORE_HOST in LOCALHOST_EQUIVALENTS else FIRESTORE_HOST
+    return {"host": host, "port_path_or_id": str(FIRESTORE_PORT), "db.instance": "projects/google-cloud-firestore-emulator/databases/(default)"}
+
+
+@pytest.fixture(scope="session")
+def client():
+    os.environ["FIRESTORE_EMULATOR_HOST"] = "%s:%d" % (FIRESTORE_HOST, FIRESTORE_PORT)
+    client = Client()
+    # Ensure connection is available
+    client.collection("healthcheck").document("healthcheck").set(
+        {}, retry=None, timeout=5
+    )
+    return client
+
+
+@pytest.fixture(scope="function")
+def collection(client):
+    collection_ = client.collection("firestore_collection_" + str(uuid.uuid4()))
+    yield collection_
+    client.recursive_delete(collection_)
+
+
+@pytest.fixture(scope="session")
+def async_client(loop):
+    os.environ["FIRESTORE_EMULATOR_HOST"] = "%s:%d" % (FIRESTORE_HOST, FIRESTORE_PORT)
+    client = AsyncClient()
+    loop.run_until_complete(client.collection("healthcheck").document("healthcheck").set({}, retry=None, timeout=5))  # Ensure connection is available
+    return client
+
+
+@pytest.fixture(scope="function")
+def async_collection(async_client, collection):
+    # Use the same collection name as the collection fixture
+    yield async_client.collection(collection.id)
+
+
+@pytest.fixture(scope="session")
+def assert_trace_for_generator():
+    def _assert_trace_for_generator(generator_func, *args, **kwargs):
+        txn = current_trace()
+        assert not isinstance(txn, DatastoreTrace)
+
+        # Check for generator trace on collections
+        _trace_check = []
+        for _ in generator_func(*args, **kwargs):
+            _trace_check.append(isinstance(current_trace(), DatastoreTrace))
+        assert _trace_check and all(_trace_check)  # All checks are True, and at least 1 is present.
+        assert current_trace() is txn  # Generator trace has exited.
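
        # In other words, the DatastoreTrace should only be current while the
        # generator is being consumed. Illustrative timeline (not executed):
        #
        #     before iteration    -> current_trace() is the outer trace
        #     inside the for loop -> current_trace() is a DatastoreTrace
        #     after exhaustion    -> current_trace() is the outer trace again
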
+ + return _assert_trace_for_generator + + +@pytest.fixture(scope="session") +def assert_trace_for_async_generator(loop): + def _assert_trace_for_async_generator(generator_func, *args, **kwargs): + _trace_check = [] + txn = current_trace() + assert not isinstance(txn, DatastoreTrace) + + async def coro(): + # Check for generator trace on collections + async for _ in generator_func(*args, **kwargs): + _trace_check.append(isinstance(current_trace(), DatastoreTrace)) + + loop.run_until_complete(coro()) + + assert _trace_check and all(_trace_check) # All checks are True, and at least 1 is present. + assert current_trace() is txn # Generator trace has exited. + + return _assert_trace_for_async_generator diff --git a/tests/datastore_firestore/test_async_batching.py b/tests/datastore_firestore/test_async_batching.py new file mode 100644 index 000000000..39e532a04 --- /dev/null +++ b/tests/datastore_firestore/test_async_batching.py @@ -0,0 +1,73 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def exercise_async_write_batch(async_client, async_collection): + async def _exercise_async_write_batch(): + docs = [async_collection.document(str(x)) for x in range(1, 4)] + async_batch = async_client.batch() + for doc in docs: + async_batch.set(doc, {}) + + await async_batch.commit() + + return _exercise_async_write_batch + + +def test_firestore_async_write_batch(loop, exercise_async_write_batch, instance_info): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 1), + ("Datastore/allOther", 1), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 1), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_write_batch", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_write_batch") + def _test(): + loop.run_until_complete(exercise_async_write_batch()) + + _test() + + +def test_firestore_async_write_batch_trace_node_datastore_params(loop, exercise_async_write_batch, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_write_batch()) + + _test() \ No newline at end of file diff --git a/tests/datastore_firestore/test_async_client.py b/tests/datastore_firestore/test_async_client.py new file mode 100644 index 000000000..1c7518bf0 --- /dev/null +++ b/tests/datastore_firestore/test_async_client.py @@ -0,0 +1,87 @@ +# 
Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def existing_document(collection): + doc = collection.document("document") + doc.set({"x": 1}) + return doc + + +@pytest.fixture() +def exercise_async_client(async_client, existing_document): + async def _exercise_async_client(): + assert len([_ async for _ in async_client.collections()]) >= 1 + doc = [_ async for _ in async_client.get_all([existing_document])][0] + assert doc.to_dict()["x"] == 1 + + return _exercise_async_client + + +def test_firestore_async_client(loop, exercise_async_client, instance_info): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/get_all", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_client", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_client") + def _test(): + loop.run_until_complete(exercise_async_client()) + + _test() + + +@background_task() +def test_firestore_async_client_generators(async_client, collection, assert_trace_for_async_generator): + doc = collection.document("test") + doc.set({}) + + assert_trace_for_async_generator(async_client.collections) + assert_trace_for_async_generator(async_client.get_all, [doc]) + + +def test_firestore_async_client_trace_node_datastore_params(loop, exercise_async_client, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_client()) + + _test() \ No newline at end of file diff --git a/tests/datastore_firestore/test_async_collections.py b/tests/datastore_firestore/test_async_collections.py new file mode 100644 index 000000000..214ee2939 --- /dev/null +++ b/tests/datastore_firestore/test_async_collections.py @@ -0,0 +1,94 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def exercise_async_collections(async_collection): + async def _exercise_async_collections(): + async_collection.document("DoesNotExist") + await async_collection.add({"capital": "Rome", "currency": "Euro", "language": "Italian"}, "Italy") + await async_collection.add({"capital": "Mexico City", "currency": "Peso", "language": "Spanish"}, "Mexico") + + documents_get = await async_collection.get() + assert len(documents_get) == 2 + documents_stream = [_ async for _ in async_collection.stream()] + assert len(documents_stream) == 2 + documents_list = [_ async for _ in async_collection.list_documents()] + assert len(documents_list) == 2 + + return _exercise_async_collections + + +def test_firestore_async_collections(loop, exercise_async_collections, async_collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/add" % async_collection.id, 2), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/add", 2), + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 5), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_collections", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_collections") + def _test(): + loop.run_until_complete(exercise_async_collections()) + + _test() + + +@background_task() +def test_firestore_async_collections_generators(collection, async_collection, assert_trace_for_async_generator): + collection.add({}) + collection.add({}) + assert len([_ for _ in collection.list_documents()]) == 2 + + assert_trace_for_async_generator(async_collection.stream) + assert_trace_for_async_generator(async_collection.list_documents) + + +def test_firestore_async_collections_trace_node_datastore_params(loop, exercise_async_collections, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_collections()) + + _test() \ No newline at end of file diff --git a/tests/datastore_firestore/test_async_documents.py b/tests/datastore_firestore/test_async_documents.py new file mode 100644 index 000000000..c90693208 --- /dev/null +++ b/tests/datastore_firestore/test_async_documents.py @@ -0,0 +1,108 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def exercise_async_documents(async_collection): + async def _exercise_async_documents(): + italy_doc = async_collection.document("Italy") + await italy_doc.set({"capital": "Rome", "currency": "Euro", "language": "Italian"}) + await italy_doc.get() + italian_cities = italy_doc.collection("cities") + await italian_cities.add({"capital": "Rome"}) + retrieved_coll = [_ async for _ in italy_doc.collections()] + assert len(retrieved_coll) == 1 + + usa_doc = async_collection.document("USA") + await usa_doc.create({"capital": "Washington D.C.", "currency": "Dollar", "language": "English"}) + await usa_doc.update({"president": "Joe Biden"}) + + await async_collection.document("USA").delete() + + return _exercise_async_documents + + +def test_firestore_async_documents(loop, exercise_async_documents, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/Italy/set", 1), + ("Datastore/statement/Firestore/Italy/get", 1), + ("Datastore/statement/Firestore/Italy/collections", 1), + ("Datastore/statement/Firestore/cities/add", 1), + ("Datastore/statement/Firestore/USA/create", 1), + ("Datastore/statement/Firestore/USA/update", 1), + ("Datastore/statement/Firestore/USA/delete", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/set", 1), + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/add", 1), + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/create", 1), + ("Datastore/operation/Firestore/update", 1), + ("Datastore/operation/Firestore/delete", 1), + ("Datastore/all", 7), + ("Datastore/allOther", 7), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 7), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_documents", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_documents") + def _test(): + loop.run_until_complete(exercise_async_documents()) + + _test() + + +@background_task() +def test_firestore_async_documents_generators( + collection, async_collection, assert_trace_for_async_generator, instance_info +): + subcollection_doc = collection.document("SubCollections") + subcollection_doc.set({}) + subcollection_doc.collection("collection1").add({}) + subcollection_doc.collection("collection2").add({}) + assert len([_ for _ in subcollection_doc.collections()]) == 2 + + async_subcollection = async_collection.document(subcollection_doc.id) + + assert_trace_for_async_generator(async_subcollection.collections) + + +def test_firestore_async_documents_trace_node_datastore_params(loop, 
exercise_async_documents, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_documents()) + + _test() \ No newline at end of file diff --git a/tests/datastore_firestore/test_async_query.py b/tests/datastore_firestore/test_async_query.py new file mode 100644 index 000000000..1bc579b7f --- /dev/null +++ b/tests/datastore_firestore/test_async_query.py @@ -0,0 +1,249 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 6): + collection.add({"x": x}) + + subcollection_doc = collection.document("subcollection") + subcollection_doc.set({}) + subcollection_doc.collection("subcollection1").add({}) + + +# ===== AsyncQuery ===== + + +@pytest.fixture() +def exercise_async_query(async_collection): + async def _exercise_async_query(): + async_query = ( + async_collection.select("x").limit(10).order_by("x").where(field_path="x", op_string="<=", value=3) + ) + assert len(await async_query.get()) == 3 + assert len([_ async for _ in async_query.stream()]) == 3 + + return _exercise_async_query + + +def test_firestore_async_query(loop, exercise_async_query, async_collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + # @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_query", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_query") + def _test(): + loop.run_until_complete(exercise_async_query()) + + _test() + + +@background_task() +def test_firestore_async_query_generators(async_collection, assert_trace_for_async_generator): + async_query = async_collection.select("x").where(field_path="x", op_string="<=", value=3) + assert_trace_for_async_generator(async_query.stream) + + +def test_firestore_async_query_trace_node_datastore_params(loop, exercise_async_query, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_query()) + + _test() + + +# ===== 
AsyncAggregationQuery ===== + + +@pytest.fixture() +def exercise_async_aggregation_query(async_collection): + async def _exercise_async_aggregation_query(): + async_aggregation_query = async_collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert (await async_aggregation_query.get())[0][0].value == 3 + assert [_ async for _ in async_aggregation_query.stream()][0][0].value == 3 + + return _exercise_async_aggregation_query + + +def test_firestore_async_aggregation_query(loop, exercise_async_aggregation_query, async_collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_aggregation_query", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_aggregation_query") + def _test(): + loop.run_until_complete(exercise_async_aggregation_query()) + + _test() + + +@background_task() +def test_firestore_async_aggregation_query_generators(async_collection, assert_trace_for_async_generator): + async_aggregation_query = async_collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert_trace_for_async_generator(async_aggregation_query.stream) + + +def test_firestore_async_aggregation_query_trace_node_datastore_params( + loop, exercise_async_aggregation_query, instance_info +): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_aggregation_query()) + + _test() + + +# ===== CollectionGroup ===== + + +@pytest.fixture() +def patch_partition_queries(monkeypatch, async_client, collection, sample_data): + """ + Partitioning is not implemented in the Firestore emulator. + + Ordinarily this method would return a coroutine that returns an async_generator of Cursor objects. + Each Cursor must point at a valid document path. To test this, we can patch the RPC to return 1 Cursor + which is pointed at any document available. The get_partitions will take that and make 2 QueryPartition + objects out of it, which should be enough to ensure we can exercise the generator's tracing. 
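
    An illustrative shape of what the patched RPC yields (the document path
    here is made up; any valid document reference works):

        Cursor(before=False, values=[Value(reference_value="projects/p/databases/(default)/documents/c/doc")])
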
+ """ + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.types.query import Cursor + + subcollection = collection.document("subcollection").collection("subcollection1") + documents = [d for d in subcollection.list_documents()] + + async def mock_partition_query(*args, **kwargs): + async def _mock_partition_query(): + yield Cursor(before=False, values=[Value(reference_value=documents[0].path)]) + + return _mock_partition_query() + + monkeypatch.setattr(async_client._firestore_api, "partition_query", mock_partition_query) + yield + + +@pytest.fixture() +def exercise_async_collection_group(async_client, async_collection): + async def _exercise_async_collection_group(): + async_collection_group = async_client.collection_group(async_collection.id) + assert len(await async_collection_group.get()) + assert len([d async for d in async_collection_group.stream()]) + + partitions = [p async for p in async_collection_group.get_partitions(1)] + assert len(partitions) == 2 + documents = [] + while partitions: + documents.extend(await partitions.pop().query().get()) + assert len(documents) == 6 + + return _exercise_async_collection_group + + +def test_firestore_async_collection_group( + loop, exercise_async_collection_group, async_collection, patch_partition_queries, instance_info +): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/get" % async_collection.id, 3), + ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/get_partitions" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 3), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/get_partitions", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 5), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_collection_group", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_collection_group") + def _test(): + loop.run_until_complete(exercise_async_collection_group()) + + _test() + + +@background_task() +def test_firestore_async_collection_group_generators( + async_client, async_collection, assert_trace_for_async_generator, patch_partition_queries +): + async_collection_group = async_client.collection_group(async_collection.id) + assert_trace_for_async_generator(async_collection_group.get_partitions, 1) + + +def test_firestore_async_collection_group_trace_node_datastore_params( + loop, exercise_async_collection_group, instance_info, patch_partition_queries +): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_collection_group()) + + _test() \ No newline at end of file diff --git a/tests/datastore_firestore/test_async_transaction.py b/tests/datastore_firestore/test_async_transaction.py new file mode 100644 index 000000000..2b8646ec5 --- /dev/null +++ b/tests/datastore_firestore/test_async_transaction.py @@ -0,0 +1,169 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 4): + collection.add({"x": x}, "doc%d" % x) + + +@pytest.fixture() +def exercise_async_transaction_commit(async_client, async_collection): + async def _exercise_async_transaction_commit(): + from google.cloud.firestore import async_transactional + + @async_transactional + async def _exercise(async_transaction): + # get a DocumentReference + with pytest.raises( + TypeError + ): # get is currently broken. It attempts to await an async_generator instead of consuming it. + [_ async for _ in async_transaction.get(async_collection.document("doc1"))] + + # get a Query + with pytest.raises( + TypeError + ): # get is currently broken. It attempts to await an async_generator instead of consuming it. + async_query = async_collection.select("x").where(field_path="x", op_string=">", value=2) + assert len([_ async for _ in async_transaction.get(async_query)]) == 1 + + # get_all on a list of DocumentReferences + with pytest.raises( + TypeError + ): # get_all is currently broken. It attempts to await an async_generator instead of consuming it. 
+ all_docs = async_transaction.get_all([async_collection.document("doc%d" % x) for x in range(1, 4)]) + assert len([_ async for _ in all_docs]) == 3 + + # set and delete methods + async_transaction.set(async_collection.document("doc2"), {"x": 0}) + async_transaction.delete(async_collection.document("doc3")) + + await _exercise(async_client.transaction()) + assert len([_ async for _ in async_collection.list_documents()]) == 2 + + return _exercise_async_transaction_commit + + +@pytest.fixture() +def exercise_async_transaction_rollback(async_client, async_collection): + async def _exercise_async_transaction_rollback(): + from google.cloud.firestore import async_transactional + + @async_transactional + async def _exercise(async_transaction): + # set and delete methods + async_transaction.set(async_collection.document("doc2"), {"x": 99}) + async_transaction.delete(async_collection.document("doc1")) + raise RuntimeError() + + with pytest.raises(RuntimeError): + await _exercise(async_client.transaction()) + assert len([_ async for _ in async_collection.list_documents()]) == 3 + + return _exercise_async_transaction_rollback + + +def test_firestore_async_transaction_commit(loop, exercise_async_transaction_commit, async_collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + # ("Datastore/operation/Firestore/get_all", 2), + # ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + # ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 2), # Should be 5 if not for broken APIs + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_transaction") + def _test(): + loop.run_until_complete(exercise_async_transaction_commit()) + + _test() + + +def test_firestore_async_transaction_rollback( + loop, exercise_async_transaction_rollback, async_collection, instance_info +): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/rollback", 1), + ("Datastore/statement/Firestore/%s/list_documents" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_transaction") + def _test(): + loop.run_until_complete(exercise_async_transaction_rollback()) + + _test() + + +def test_firestore_async_transaction_commit_trace_node_datastore_params( + loop, exercise_async_transaction_commit, instance_info +): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_transaction_commit()) + + _test() + + +def test_firestore_async_transaction_rollback_trace_node_datastore_params( + loop, 
exercise_async_transaction_rollback, instance_info +): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_transaction_rollback()) + + _test() diff --git a/tests/datastore_firestore/test_batching.py b/tests/datastore_firestore/test_batching.py new file mode 100644 index 000000000..07964338c --- /dev/null +++ b/tests/datastore_firestore/test_batching.py @@ -0,0 +1,127 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + +# ===== WriteBatch ===== + + +@pytest.fixture() +def exercise_write_batch(client, collection): + def _exercise_write_batch(): + docs = [collection.document(str(x)) for x in range(1, 4)] + batch = client.batch() + for doc in docs: + batch.set(doc, {}) + + batch.commit() + + return _exercise_write_batch + + +def test_firestore_write_batch(exercise_write_batch, instance_info): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 1), + ("Datastore/allOther", 1), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 1), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_write_batch", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_write_batch") + def _test(): + exercise_write_batch() + + _test() + + +def test_firestore_write_batch_trace_node_datastore_params(exercise_write_batch, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_write_batch() + + _test() + + +# ===== BulkWriteBatch ===== + + +@pytest.fixture() +def exercise_bulk_write_batch(client, collection): + def _exercise_bulk_write_batch(): + from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch + + docs = [collection.document(str(x)) for x in range(1, 4)] + batch = BulkWriteBatch(client) + for doc in docs: + batch.set(doc, {}) + + batch.commit() + + return _exercise_bulk_write_batch + + +def test_firestore_bulk_write_batch(exercise_bulk_write_batch, instance_info): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 1), + ("Datastore/allOther", 1), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 1), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_bulk_write_batch", + scoped_metrics=_test_scoped_metrics, + 
rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_bulk_write_batch") + def _test(): + exercise_bulk_write_batch() + + _test() + + +def test_firestore_bulk_write_batch_trace_node_datastore_params(exercise_bulk_write_batch, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_bulk_write_batch() + + _test() diff --git a/tests/datastore_firestore/test_client.py b/tests/datastore_firestore/test_client.py new file mode 100644 index 000000000..81fbd181c --- /dev/null +++ b/tests/datastore_firestore/test_client.py @@ -0,0 +1,83 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def sample_data(collection): + doc = collection.document("document") + doc.set({"x": 1}) + return doc + + +@pytest.fixture() +def exercise_client(client, sample_data): + def _exercise_client(): + assert len([_ for _ in client.collections()]) + doc = [_ for _ in client.get_all([sample_data])][0] + assert doc.to_dict()["x"] == 1 + + return _exercise_client + + +def test_firestore_client(exercise_client, instance_info): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/get_all", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_client", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_client") + def _test(): + exercise_client() + + _test() + + +@background_task() +def test_firestore_client_generators(client, sample_data, assert_trace_for_generator): + assert_trace_for_generator(client.collections) + assert_trace_for_generator(client.get_all, [sample_data]) + + +def test_firestore_client_trace_node_datastore_params(exercise_client, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_client() + + _test() \ No newline at end of file diff --git a/tests/datastore_firestore/test_collections.py b/tests/datastore_firestore/test_collections.py new file mode 100644 index 000000000..2e58bbe95 --- /dev/null +++ b/tests/datastore_firestore/test_collections.py @@ -0,0 +1,94 @@ +# Copyright 2010 New Relic, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def exercise_collections(collection): + def _exercise_collections(): + collection.document("DoesNotExist") + collection.add({"capital": "Rome", "currency": "Euro", "language": "Italian"}, "Italy") + collection.add({"capital": "Mexico City", "currency": "Peso", "language": "Spanish"}, "Mexico") + + documents_get = collection.get() + assert len(documents_get) == 2 + documents_stream = [_ for _ in collection.stream()] + assert len(documents_stream) == 2 + documents_list = [_ for _ in collection.list_documents()] + assert len(documents_list) == 2 + + return _exercise_collections + + +def test_firestore_collections(exercise_collections, collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1), + ("Datastore/statement/Firestore/%s/add" % collection.id, 2), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/add", 2), + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 5), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_collections", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_collections") + def _test(): + exercise_collections() + + _test() + + +@background_task() +def test_firestore_collections_generators(collection, assert_trace_for_generator): + collection.add({}) + collection.add({}) + assert len([_ for _ in collection.list_documents()]) == 2 + + assert_trace_for_generator(collection.stream) + assert_trace_for_generator(collection.list_documents) + + +def test_firestore_collections_trace_node_datastore_params(exercise_collections, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_collections() + + _test() diff --git a/tests/datastore_firestore/test_documents.py b/tests/datastore_firestore/test_documents.py new file mode 100644 index 000000000..ae6b94edd --- /dev/null +++ b/tests/datastore_firestore/test_documents.py @@ -0,0 +1,104 @@ +# Copyright 2010 New Relic, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def exercise_documents(collection): + def _exercise_documents(): + italy_doc = collection.document("Italy") + italy_doc.set({"capital": "Rome", "currency": "Euro", "language": "Italian"}) + italy_doc.get() + italian_cities = italy_doc.collection("cities") + italian_cities.add({"capital": "Rome"}) + retrieved_coll = [_ for _ in italy_doc.collections()] + assert len(retrieved_coll) == 1 + + usa_doc = collection.document("USA") + usa_doc.create({"capital": "Washington D.C.", "currency": "Dollar", "language": "English"}) + usa_doc.update({"president": "Joe Biden"}) + + collection.document("USA").delete() + + return _exercise_documents + + +def test_firestore_documents(exercise_documents, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/Italy/set", 1), + ("Datastore/statement/Firestore/Italy/get", 1), + ("Datastore/statement/Firestore/Italy/collections", 1), + ("Datastore/statement/Firestore/cities/add", 1), + ("Datastore/statement/Firestore/USA/create", 1), + ("Datastore/statement/Firestore/USA/update", 1), + ("Datastore/statement/Firestore/USA/delete", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/set", 1), + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/add", 1), + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/create", 1), + ("Datastore/operation/Firestore/update", 1), + ("Datastore/operation/Firestore/delete", 1), + ("Datastore/all", 7), + ("Datastore/allOther", 7), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 7), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_documents", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_documents") + def _test(): + exercise_documents() + + _test() + + +@background_task() +def test_firestore_documents_generators(collection, assert_trace_for_generator): + subcollection_doc = collection.document("SubCollections") + subcollection_doc.set({}) + subcollection_doc.collection("collection1").add({}) + subcollection_doc.collection("collection2").add({}) + assert len([_ for _ in subcollection_doc.collections()]) == 2 + + assert_trace_for_generator(subcollection_doc.collections) + + +def test_firestore_documents_trace_node_datastore_params(exercise_documents, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_documents() + + _test() diff 
--git a/tests/datastore_firestore/test_query.py b/tests/datastore_firestore/test_query.py new file mode 100644 index 000000000..6f1643c5b --- /dev/null +++ b/tests/datastore_firestore/test_query.py @@ -0,0 +1,236 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 6): + collection.add({"x": x}) + + subcollection_doc = collection.document("subcollection") + subcollection_doc.set({}) + subcollection_doc.collection("subcollection1").add({}) + + +# ===== Query ===== + + +@pytest.fixture() +def exercise_query(collection): + def _exercise_query(): + query = collection.select("x").limit(10).order_by("x").where(field_path="x", op_string="<=", value=3) + assert len(query.get()) == 3 + assert len([_ for _ in query.stream()]) == 3 + + return _exercise_query + + +def test_firestore_query(exercise_query, collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_query", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_query") + def _test(): + exercise_query() + + _test() + + +@background_task() +def test_firestore_query_generators(collection, assert_trace_for_generator): + query = collection.select("x").where(field_path="x", op_string="<=", value=3) + assert_trace_for_generator(query.stream) + + +def test_firestore_query_trace_node_datastore_params(exercise_query, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_query() + + _test() + + +# ===== AggregationQuery ===== + + +@pytest.fixture() +def exercise_aggregation_query(collection): + def _exercise_aggregation_query(): + aggregation_query = collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert aggregation_query.get()[0][0].value == 3 + assert [_ for _ in aggregation_query.stream()][0][0].value == 3 + + return _exercise_aggregation_query + + +def test_firestore_aggregation_query(exercise_aggregation_query, collection, instance_info): + _test_scoped_metrics = [ + 
("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_aggregation_query", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_aggregation_query") + def _test(): + exercise_aggregation_query() + + _test() + + +@background_task() +def test_firestore_aggregation_query_generators(collection, assert_trace_for_generator): + aggregation_query = collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert_trace_for_generator(aggregation_query.stream) + + +def test_firestore_aggregation_query_trace_node_datastore_params(exercise_aggregation_query, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_aggregation_query() + + _test() + + +# ===== CollectionGroup ===== + + +@pytest.fixture() +def patch_partition_queries(monkeypatch, client, collection, sample_data): + """ + Partitioning is not implemented in the Firestore emulator. + + Ordinarily this method would return a generator of Cursor objects. Each Cursor must point at a valid document path. + To test this, we can patch the RPC to return 1 Cursor which is pointed at any document available. + The get_partitions will take that and make 2 QueryPartition objects out of it, which should be enough to ensure + we can exercise the generator's tracing. 
+ """ + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.types.query import Cursor + + subcollection = collection.document("subcollection").collection("subcollection1") + documents = [d for d in subcollection.list_documents()] + + def mock_partition_query(*args, **kwargs): + yield Cursor(before=False, values=[Value(reference_value=documents[0].path)]) + + monkeypatch.setattr(client._firestore_api, "partition_query", mock_partition_query) + yield + + +@pytest.fixture() +def exercise_collection_group(client, collection, patch_partition_queries): + def _exercise_collection_group(): + collection_group = client.collection_group(collection.id) + assert len(collection_group.get()) + assert len([d for d in collection_group.stream()]) + + partitions = [p for p in collection_group.get_partitions(1)] + assert len(partitions) == 2 + documents = [] + while partitions: + documents.extend(partitions.pop().query().get()) + assert len(documents) == 6 + + return _exercise_collection_group + + +def test_firestore_collection_group(exercise_collection_group, client, collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/get" % collection.id, 3), + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get_partitions" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 3), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/get_partitions", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 5), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_collection_group", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_collection_group") + def _test(): + exercise_collection_group() + + _test() + + +@background_task() +def test_firestore_collection_group_generators(client, collection, assert_trace_for_generator, patch_partition_queries): + collection_group = client.collection_group(collection.id) + assert_trace_for_generator(collection_group.get_partitions, 1) + + +def test_firestore_collection_group_trace_node_datastore_params(exercise_collection_group, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_collection_group() + + _test() diff --git a/tests/datastore_firestore/test_transaction.py b/tests/datastore_firestore/test_transaction.py new file mode 100644 index 000000000..59d496a00 --- /dev/null +++ b/tests/datastore_firestore/test_transaction.py @@ -0,0 +1,153 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 4): + collection.add({"x": x}, "doc%d" % x) + + +@pytest.fixture() +def exercise_transaction_commit(client, collection): + def _exercise_transaction_commit(): + from google.cloud.firestore_v1.transaction import transactional + + @transactional + def _exercise(transaction): + # get a DocumentReference + [_ for _ in transaction.get(collection.document("doc1"))] + + # get a Query + query = collection.select("x").where(field_path="x", op_string=">", value=2) + assert len([_ for _ in transaction.get(query)]) == 1 + + # get_all on a list of DocumentReferences + all_docs = transaction.get_all([collection.document("doc%d" % x) for x in range(1, 4)]) + assert len([_ for _ in all_docs]) == 3 + + # set and delete methods + transaction.set(collection.document("doc2"), {"x": 0}) + transaction.delete(collection.document("doc3")) + + _exercise(client.transaction()) + assert len([_ for _ in collection.list_documents()]) == 2 + + return _exercise_transaction_commit + + +@pytest.fixture() +def exercise_transaction_rollback(client, collection): + def _exercise_transaction_rollback(): + from google.cloud.firestore_v1.transaction import transactional + + @transactional + def _exercise(transaction): + # set and delete methods + transaction.set(collection.document("doc2"), {"x": 99}) + transaction.delete(collection.document("doc1")) + raise RuntimeError() + + with pytest.raises(RuntimeError): + _exercise(client.transaction()) + assert len([_ for _ in collection.list_documents()]) == 3 + + return _exercise_transaction_rollback + + +def test_firestore_transaction_commit(exercise_transaction_commit, collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ("Datastore/operation/Firestore/get_all", 2), + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 5), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_transaction") + def _test(): + exercise_transaction_commit() + + _test() + + +def test_firestore_transaction_rollback(exercise_transaction_rollback, collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/rollback", 1), + ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + 
"test_firestore_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_transaction") + def _test(): + exercise_transaction_rollback() + + _test() + + +def test_firestore_transaction_commit_trace_node_datastore_params(exercise_transaction_commit, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_transaction_commit() + + _test() + + +def test_firestore_transaction_rollback_trace_node_datastore_params(exercise_transaction_rollback, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_transaction_rollback() + + _test() diff --git a/tests/datastore_mysql/test_database.py b/tests/datastore_mysql/test_database.py index 2fc8ca129..8f8641903 100644 --- a/tests/datastore_mysql/test_database.py +++ b/tests/datastore_mysql/test_database.py @@ -13,11 +13,15 @@ # limitations under the License. import mysql.connector - -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_database_trace_inputs import validate_database_trace_inputs - from testing_support.db_settings import mysql_settings +from testing_support.util import instance_hostname +from testing_support.validators.validate_database_trace_inputs import ( + validate_database_trace_inputs, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + from newrelic.api.background_task import background_task DB_SETTINGS = mysql_settings() @@ -27,80 +31,95 @@ mysql_version = tuple(int(x) for x in mysql.connector.__version__.split(".")[:3]) if mysql_version >= (8, 0, 30): - _connector_metric_name = 'Function/mysql.connector.pooling:connect' + _connector_metric_name = "Function/mysql.connector.pooling:connect" else: - _connector_metric_name = 'Function/mysql.connector:connect' + _connector_metric_name = "Function/mysql.connector:connect" _test_execute_via_cursor_scoped_metrics = [ - (_connector_metric_name, 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/select' % DB_NAMESPACE, 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/insert' % DB_NAMESPACE, 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/update' % DB_NAMESPACE, 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/delete' % DB_NAMESPACE, 1), - ('Datastore/operation/MySQL/drop', 2), - ('Datastore/operation/MySQL/create', 2), - ('Datastore/statement/MySQL/%s/call' % DB_PROCEDURE, 1), - ('Datastore/operation/MySQL/commit', 2), - ('Datastore/operation/MySQL/rollback', 1)] + (_connector_metric_name, 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/select" % DB_NAMESPACE, 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/insert" % DB_NAMESPACE, 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/update" % DB_NAMESPACE, 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/delete" % DB_NAMESPACE, 1), + ("Datastore/operation/MySQL/drop", 2), + ("Datastore/operation/MySQL/create", 2), + ("Datastore/statement/MySQL/%s/call" % DB_PROCEDURE, 1), + ("Datastore/operation/MySQL/commit", 2), + ("Datastore/operation/MySQL/rollback", 1), +] _test_execute_via_cursor_rollup_metrics = [ - ('Datastore/all', 13), - ('Datastore/allOther', 13), - ('Datastore/MySQL/all', 13), - ('Datastore/MySQL/allOther', 13), - ('Datastore/operation/MySQL/select', 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/select' % 
DB_NAMESPACE, 1), - ('Datastore/operation/MySQL/insert', 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/insert' % DB_NAMESPACE, 1), - ('Datastore/operation/MySQL/update', 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/update' % DB_NAMESPACE, 1), - ('Datastore/operation/MySQL/delete', 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/delete' % DB_NAMESPACE, 1), - ('Datastore/statement/MySQL/%s/call' % DB_PROCEDURE, 1), - ('Datastore/operation/MySQL/call', 1), - ('Datastore/operation/MySQL/drop', 2), - ('Datastore/operation/MySQL/create', 2), - ('Datastore/operation/MySQL/commit', 2), - ('Datastore/operation/MySQL/rollback', 1)] - -@validate_transaction_metrics('test_database:test_execute_via_cursor', - scoped_metrics=_test_execute_via_cursor_scoped_metrics, - rollup_metrics=_test_execute_via_cursor_rollup_metrics, - background_task=True) + ("Datastore/all", 13), + ("Datastore/allOther", 13), + ("Datastore/MySQL/all", 13), + ("Datastore/MySQL/allOther", 13), + ("Datastore/operation/MySQL/select", 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/select" % DB_NAMESPACE, 1), + ("Datastore/operation/MySQL/insert", 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/insert" % DB_NAMESPACE, 1), + ("Datastore/operation/MySQL/update", 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/update" % DB_NAMESPACE, 1), + ("Datastore/operation/MySQL/delete", 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/delete" % DB_NAMESPACE, 1), + ("Datastore/statement/MySQL/%s/call" % DB_PROCEDURE, 1), + ("Datastore/operation/MySQL/call", 1), + ("Datastore/operation/MySQL/drop", 2), + ("Datastore/operation/MySQL/create", 2), + ("Datastore/operation/MySQL/commit", 2), + ("Datastore/operation/MySQL/rollback", 1), + ("Datastore/instance/MySQL/%s/%s" % (instance_hostname(DB_SETTINGS["host"]), DB_SETTINGS["port"]), 12), +] + + +@validate_transaction_metrics( + "test_database:test_execute_via_cursor", + scoped_metrics=_test_execute_via_cursor_scoped_metrics, + rollup_metrics=_test_execute_via_cursor_rollup_metrics, + background_task=True, +) @validate_database_trace_inputs(sql_parameters_type=dict) @background_task() def test_execute_via_cursor(table_name): - connection = mysql.connector.connect(db=DB_SETTINGS['name'], - user=DB_SETTINGS['user'], passwd=DB_SETTINGS['password'], - host=DB_SETTINGS['host'], port=DB_SETTINGS['port']) + connection = mysql.connector.connect( + db=DB_SETTINGS["name"], + user=DB_SETTINGS["user"], + passwd=DB_SETTINGS["password"], + host=DB_SETTINGS["host"], + port=DB_SETTINGS["port"], + ) cursor = connection.cursor() cursor.execute("""drop table if exists `%s`""" % table_name) - cursor.execute("""create table %s """ - """(a integer, b real, c text)""" % table_name) + cursor.execute("""create table %s """ """(a integer, b real, c text)""" % table_name) - cursor.executemany("""insert into `%s` """ % table_name + - """values (%(a)s, %(b)s, %(c)s)""", [dict(a=1, b=1.0, c='1.0'), - dict(a=2, b=2.2, c='2.2'), dict(a=3, b=3.3, c='3.3')]) + cursor.executemany( + """insert into `%s` """ % table_name + """values (%(a)s, %(b)s, %(c)s)""", + [dict(a=1, b=1.0, c="1.0"), dict(a=2, b=2.2, c="2.2"), dict(a=3, b=3.3, c="3.3")], + ) cursor.execute("""select * from %s""" % table_name) - for row in cursor: pass + for row in cursor: + pass - cursor.execute("""update `%s` """ % table_name + - """set a=%(a)s, b=%(b)s, c=%(c)s where a=%(old_a)s""", - dict(a=4, b=4.0, c='4.0', old_a=1)) + cursor.execute( + """update `%s` """ % table_name + """set a=%(a)s, b=%(b)s, c=%(c)s where a=%(old_a)s""", + 
dict(a=4, b=4.0, c="4.0", old_a=1), + ) cursor.execute("""delete from `%s` where a=2""" % table_name) cursor.execute("""drop procedure if exists %s""" % DB_PROCEDURE) - cursor.execute("""CREATE PROCEDURE %s() + cursor.execute( + """CREATE PROCEDURE %s() BEGIN SELECT 'Hello World!'; - END""" % DB_PROCEDURE) + END""" + % DB_PROCEDURE + ) cursor.callproc("%s" % DB_PROCEDURE) @@ -108,76 +127,92 @@ def test_execute_via_cursor(table_name): connection.rollback() connection.commit() + _test_connect_using_alias_scoped_metrics = [ - (_connector_metric_name, 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/select' % DB_NAMESPACE, 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/insert' % DB_NAMESPACE, 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/update' % DB_NAMESPACE, 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/delete' % DB_NAMESPACE, 1), - ('Datastore/operation/MySQL/drop', 2), - ('Datastore/operation/MySQL/create', 2), - ('Datastore/statement/MySQL/%s/call' % DB_PROCEDURE, 1), - ('Datastore/operation/MySQL/commit', 2), - ('Datastore/operation/MySQL/rollback', 1)] + (_connector_metric_name, 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/select" % DB_NAMESPACE, 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/insert" % DB_NAMESPACE, 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/update" % DB_NAMESPACE, 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/delete" % DB_NAMESPACE, 1), + ("Datastore/operation/MySQL/drop", 2), + ("Datastore/operation/MySQL/create", 2), + ("Datastore/statement/MySQL/%s/call" % DB_PROCEDURE, 1), + ("Datastore/operation/MySQL/commit", 2), + ("Datastore/operation/MySQL/rollback", 1), +] _test_connect_using_alias_rollup_metrics = [ - ('Datastore/all', 13), - ('Datastore/allOther', 13), - ('Datastore/MySQL/all', 13), - ('Datastore/MySQL/allOther', 13), - ('Datastore/operation/MySQL/select', 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/select' % DB_NAMESPACE, 1), - ('Datastore/operation/MySQL/insert', 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/insert' % DB_NAMESPACE, 1), - ('Datastore/operation/MySQL/update', 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/update' % DB_NAMESPACE, 1), - ('Datastore/operation/MySQL/delete', 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/delete' % DB_NAMESPACE, 1), - ('Datastore/statement/MySQL/%s/call' % DB_PROCEDURE, 1), - ('Datastore/operation/MySQL/call', 1), - ('Datastore/operation/MySQL/drop', 2), - ('Datastore/operation/MySQL/create', 2), - ('Datastore/operation/MySQL/commit', 2), - ('Datastore/operation/MySQL/rollback', 1)] - -@validate_transaction_metrics('test_database:test_connect_using_alias', - scoped_metrics=_test_connect_using_alias_scoped_metrics, - rollup_metrics=_test_connect_using_alias_rollup_metrics, - background_task=True) + ("Datastore/all", 13), + ("Datastore/allOther", 13), + ("Datastore/MySQL/all", 13), + ("Datastore/MySQL/allOther", 13), + ("Datastore/operation/MySQL/select", 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/select" % DB_NAMESPACE, 1), + ("Datastore/operation/MySQL/insert", 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/insert" % DB_NAMESPACE, 1), + ("Datastore/operation/MySQL/update", 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/update" % DB_NAMESPACE, 1), + ("Datastore/operation/MySQL/delete", 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/delete" % DB_NAMESPACE, 1), + ("Datastore/statement/MySQL/%s/call" % DB_PROCEDURE, 1), + ("Datastore/operation/MySQL/call", 1), + ("Datastore/operation/MySQL/drop", 2), + 
("Datastore/operation/MySQL/create", 2), + ("Datastore/operation/MySQL/commit", 2), + ("Datastore/operation/MySQL/rollback", 1), + ("Datastore/instance/MySQL/%s/%s" % (instance_hostname(DB_SETTINGS["host"]), DB_SETTINGS["port"]), 12), +] + + +@validate_transaction_metrics( + "test_database:test_connect_using_alias", + scoped_metrics=_test_connect_using_alias_scoped_metrics, + rollup_metrics=_test_connect_using_alias_rollup_metrics, + background_task=True, +) @validate_database_trace_inputs(sql_parameters_type=dict) @background_task() def test_connect_using_alias(table_name): - connection = mysql.connector.connect(db=DB_SETTINGS['name'], - user=DB_SETTINGS['user'], passwd=DB_SETTINGS['password'], - host=DB_SETTINGS['host'], port=DB_SETTINGS['port']) + connection = mysql.connector.connect( + db=DB_SETTINGS["name"], + user=DB_SETTINGS["user"], + passwd=DB_SETTINGS["password"], + host=DB_SETTINGS["host"], + port=DB_SETTINGS["port"], + ) cursor = connection.cursor() cursor.execute("""drop table if exists `%s`""" % table_name) - cursor.execute("""create table %s """ - """(a integer, b real, c text)""" % table_name) + cursor.execute("""create table %s """ """(a integer, b real, c text)""" % table_name) - cursor.executemany("""insert into `%s` """ % table_name + - """values (%(a)s, %(b)s, %(c)s)""", [dict(a=1, b=1.0, c='1.0'), - dict(a=2, b=2.2, c='2.2'), dict(a=3, b=3.3, c='3.3')]) + cursor.executemany( + """insert into `%s` """ % table_name + """values (%(a)s, %(b)s, %(c)s)""", + [dict(a=1, b=1.0, c="1.0"), dict(a=2, b=2.2, c="2.2"), dict(a=3, b=3.3, c="3.3")], + ) cursor.execute("""select * from %s""" % table_name) - for row in cursor: pass + for row in cursor: + pass - cursor.execute("""update `%s` """ % table_name + - """set a=%(a)s, b=%(b)s, c=%(c)s where a=%(old_a)s""", - dict(a=4, b=4.0, c='4.0', old_a=1)) + cursor.execute( + """update `%s` """ % table_name + """set a=%(a)s, b=%(b)s, c=%(c)s where a=%(old_a)s""", + dict(a=4, b=4.0, c="4.0", old_a=1), + ) cursor.execute("""delete from `%s` where a=2""" % table_name) cursor.execute("""drop procedure if exists %s""" % DB_PROCEDURE) - cursor.execute("""CREATE PROCEDURE %s() + cursor.execute( + """CREATE PROCEDURE %s() BEGIN SELECT 'Hello World!'; - END""" % DB_PROCEDURE) + END""" + % DB_PROCEDURE + ) cursor.callproc("%s" % DB_PROCEDURE) diff --git a/tests/datastore_postgresql/conftest.py b/tests/datastore_postgresql/conftest.py index 624fb4726..4a25f2574 100644 --- a/tests/datastore_postgresql/conftest.py +++ b/tests/datastore_postgresql/conftest.py @@ -12,21 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest - -from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture # noqa: F401; pylint: disable=W0611 - +from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 + collector_agent_registration_fixture, + collector_available_fixture, +) _default_settings = { - 'transaction_tracer.explain_threshold': 0.0, - 'transaction_tracer.transaction_threshold': 0.0, - 'transaction_tracer.stack_trace_threshold': 0.0, - 'debug.log_data_collector_payloads': True, - 'debug.record_transaction_failure': True, - 'debug.log_explain_plan_queries': True + "transaction_tracer.explain_threshold": 0.0, + "transaction_tracer.transaction_threshold": 0.0, + "transaction_tracer.stack_trace_threshold": 0.0, + "debug.log_data_collector_payloads": True, + "debug.record_transaction_failure": True, + "debug.log_explain_plan_queries": True, } collector_agent_registration = collector_agent_registration_fixture( - app_name='Python Agent Test (datastore_postgresql)', - default_settings=_default_settings, - linked_applications=['Python Agent Test (datastore)']) + app_name="Python Agent Test (datastore_postgresql)", + default_settings=_default_settings, + linked_applications=["Python Agent Test (datastore)"], +) diff --git a/tests/datastore_postgresql/test_database.py b/tests/datastore_postgresql/test_database.py index 2ea930b05..cf432d174 100644 --- a/tests/datastore_postgresql/test_database.py +++ b/tests/datastore_postgresql/test_database.py @@ -13,15 +13,14 @@ # limitations under the License. import postgresql.driver.dbapi20 - - -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics - +from testing_support.db_settings import postgresql_settings +from testing_support.util import instance_hostname from testing_support.validators.validate_database_trace_inputs import ( validate_database_trace_inputs, ) - -from testing_support.db_settings import postgresql_settings +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.api.background_task import background_task @@ -41,13 +40,14 @@ ("Datastore/operation/Postgres/create", 1), ("Datastore/operation/Postgres/commit", 3), ("Datastore/operation/Postgres/rollback", 1), + ("Datastore/operation/Postgres/other", 1), ] _test_execute_via_cursor_rollup_metrics = [ - ("Datastore/all", 13), - ("Datastore/allOther", 13), - ("Datastore/Postgres/all", 13), - ("Datastore/Postgres/allOther", 13), + ("Datastore/all", 14), + ("Datastore/allOther", 14), + ("Datastore/Postgres/all", 14), + ("Datastore/Postgres/allOther", 14), ("Datastore/operation/Postgres/select", 1), ("Datastore/statement/Postgres/%s/select" % DB_SETTINGS["table_name"], 1), ("Datastore/operation/Postgres/insert", 1), @@ -63,6 +63,11 @@ ("Datastore/operation/Postgres/call", 2), ("Datastore/operation/Postgres/commit", 3), ("Datastore/operation/Postgres/rollback", 1), + ("Datastore/operation/Postgres/other", 1), + ("Datastore/instance/Postgres/%s/%s" % (instance_hostname(DB_SETTINGS["host"]), DB_SETTINGS["port"]), 13), + ("Function/postgresql.driver.dbapi20:connect", 1), + ("Function/postgresql.driver.dbapi20:Connection.__enter__", 1), + ("Function/postgresql.driver.dbapi20:Connection.__exit__", 1), ] @@ -82,30 +87,27 @@ def test_execute_via_cursor(): host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], ) as connection: - cursor = connection.cursor() cursor.execute("""drop table if exists %s""" % DB_SETTINGS["table_name"]) - cursor.execute( - """create table %s 
""" % DB_SETTINGS["table_name"] - + """(a integer, b real, c text)""" - ) + cursor.execute("""create table %s """ % DB_SETTINGS["table_name"] + """(a integer, b real, c text)""") cursor.executemany( - """insert into %s """ % DB_SETTINGS["table_name"] - + """values (%s, %s, %s)""", + """insert into %s """ % DB_SETTINGS["table_name"] + """values (%s, %s, %s)""", [(1, 1.0, "1.0"), (2, 2.2, "2.2"), (3, 3.3, "3.3")], ) cursor.execute("""select * from %s""" % DB_SETTINGS["table_name"]) - for row in cursor: - pass + cursor.execute( + """with temporaryTable (averageValue) as (select avg(b) from %s) """ % DB_SETTINGS["table_name"] + + """select * from %s,temporaryTable """ % DB_SETTINGS["table_name"] + + """where %s.b > temporaryTable.averageValue""" % DB_SETTINGS["table_name"] + ) cursor.execute( - """update %s """ % DB_SETTINGS["table_name"] - + """set a=%s, b=%s, c=%s where a=%s""", + """update %s """ % DB_SETTINGS["table_name"] + """set a=%s, b=%s, c=%s where a=%s""", (4, 4.0, "4.0", 1), ) @@ -152,7 +154,6 @@ def test_rollback_on_exception(): host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], ): - raise RuntimeError("error") except RuntimeError: diff --git a/tests/datastore_psycopg2cffi/test_database.py b/tests/datastore_psycopg2cffi/test_database.py index 54ff6ad09..939c5cabc 100644 --- a/tests/datastore_psycopg2cffi/test_database.py +++ b/tests/datastore_psycopg2cffi/test_database.py @@ -15,166 +15,190 @@ import psycopg2cffi import psycopg2cffi.extensions import psycopg2cffi.extras - -from testing_support.fixtures import validate_stats_engine_explain_plan_output_is_none -from testing_support.validators.validate_transaction_errors import validate_transaction_errors -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_transaction_slow_sql_count import \ - validate_transaction_slow_sql_count -from testing_support.validators.validate_database_trace_inputs import validate_database_trace_inputs - from testing_support.db_settings import postgresql_settings +from testing_support.fixtures import validate_stats_engine_explain_plan_output_is_none +from testing_support.util import instance_hostname +from testing_support.validators.validate_database_trace_inputs import ( + validate_database_trace_inputs, +) +from testing_support.validators.validate_transaction_errors import ( + validate_transaction_errors, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_transaction_slow_sql_count import ( + validate_transaction_slow_sql_count, +) from newrelic.api.background_task import background_task DB_SETTINGS = postgresql_settings()[0] _test_execute_via_cursor_scoped_metrics = [ - ('Function/psycopg2cffi:connect', 1), - ('Function/psycopg2cffi._impl.connection:Connection.__enter__', 1), - ('Function/psycopg2cffi._impl.connection:Connection.__exit__', 1), - ('Datastore/statement/Postgres/%s/select' % DB_SETTINGS["table_name"], 1), - ('Datastore/statement/Postgres/%s/insert' % DB_SETTINGS["table_name"], 1), - ('Datastore/statement/Postgres/%s/update' % DB_SETTINGS["table_name"], 1), - ('Datastore/statement/Postgres/%s/delete' % DB_SETTINGS["table_name"], 1), - ('Datastore/statement/Postgres/now/call', 1), - ('Datastore/statement/Postgres/pg_sleep/call', 1), - ('Datastore/operation/Postgres/drop', 1), - ('Datastore/operation/Postgres/create', 1), - ('Datastore/operation/Postgres/commit', 3), - 
('Datastore/operation/Postgres/rollback', 1)] + ("Function/psycopg2cffi:connect", 1), + ("Function/psycopg2cffi._impl.connection:Connection.__enter__", 1), + ("Function/psycopg2cffi._impl.connection:Connection.__exit__", 1), + ("Datastore/statement/Postgres/%s/select" % DB_SETTINGS["table_name"], 1), + ("Datastore/statement/Postgres/%s/insert" % DB_SETTINGS["table_name"], 1), + ("Datastore/statement/Postgres/%s/update" % DB_SETTINGS["table_name"], 1), + ("Datastore/statement/Postgres/%s/delete" % DB_SETTINGS["table_name"], 1), + ("Datastore/statement/Postgres/now/call", 1), + ("Datastore/statement/Postgres/pg_sleep/call", 1), + ("Datastore/operation/Postgres/drop", 1), + ("Datastore/operation/Postgres/create", 1), + ("Datastore/operation/Postgres/commit", 3), + ("Datastore/operation/Postgres/rollback", 1), +] _test_execute_via_cursor_rollup_metrics = [ - ('Datastore/all', 13), - ('Datastore/allOther', 13), - ('Datastore/Postgres/all', 13), - ('Datastore/Postgres/allOther', 13), - ('Datastore/operation/Postgres/select', 1), - ('Datastore/statement/Postgres/%s/select' % DB_SETTINGS["table_name"], 1), - ('Datastore/operation/Postgres/insert', 1), - ('Datastore/statement/Postgres/%s/insert' % DB_SETTINGS["table_name"], 1), - ('Datastore/operation/Postgres/update', 1), - ('Datastore/statement/Postgres/%s/update' % DB_SETTINGS["table_name"], 1), - ('Datastore/operation/Postgres/delete', 1), - ('Datastore/statement/Postgres/%s/delete' % DB_SETTINGS["table_name"], 1), - ('Datastore/operation/Postgres/drop', 1), - ('Datastore/operation/Postgres/create', 1), - ('Datastore/statement/Postgres/now/call', 1), - ('Datastore/statement/Postgres/pg_sleep/call', 1), - ('Datastore/operation/Postgres/call', 2), - ('Datastore/operation/Postgres/commit', 3), - ('Datastore/operation/Postgres/rollback', 1)] - - -@validate_transaction_metrics('test_database:test_execute_via_cursor', - scoped_metrics=_test_execute_via_cursor_scoped_metrics, - rollup_metrics=_test_execute_via_cursor_rollup_metrics, - background_task=True) + ("Datastore/all", 13), + ("Datastore/allOther", 13), + ("Datastore/Postgres/all", 13), + ("Datastore/Postgres/allOther", 13), + ("Datastore/operation/Postgres/select", 1), + ("Datastore/statement/Postgres/%s/select" % DB_SETTINGS["table_name"], 1), + ("Datastore/operation/Postgres/insert", 1), + ("Datastore/statement/Postgres/%s/insert" % DB_SETTINGS["table_name"], 1), + ("Datastore/operation/Postgres/update", 1), + ("Datastore/statement/Postgres/%s/update" % DB_SETTINGS["table_name"], 1), + ("Datastore/operation/Postgres/delete", 1), + ("Datastore/statement/Postgres/%s/delete" % DB_SETTINGS["table_name"], 1), + ("Datastore/operation/Postgres/drop", 1), + ("Datastore/operation/Postgres/create", 1), + ("Datastore/statement/Postgres/now/call", 1), + ("Datastore/statement/Postgres/pg_sleep/call", 1), + ("Datastore/operation/Postgres/call", 2), + ("Datastore/operation/Postgres/commit", 3), + ("Datastore/operation/Postgres/rollback", 1), + ("Datastore/instance/Postgres/%s/%s" % (instance_hostname(DB_SETTINGS["host"]), DB_SETTINGS["port"]), 12), +] + + +@validate_transaction_metrics( + "test_database:test_execute_via_cursor", + scoped_metrics=_test_execute_via_cursor_scoped_metrics, + rollup_metrics=_test_execute_via_cursor_rollup_metrics, + background_task=True, +) @validate_database_trace_inputs(sql_parameters_type=tuple) @background_task() def test_execute_via_cursor(): with psycopg2cffi.connect( - database=DB_SETTINGS['name'], user=DB_SETTINGS['user'], - password=DB_SETTINGS['password'], 
host=DB_SETTINGS['host'], - port=DB_SETTINGS['port']) as connection: + database=DB_SETTINGS["name"], + user=DB_SETTINGS["user"], + password=DB_SETTINGS["password"], + host=DB_SETTINGS["host"], + port=DB_SETTINGS["port"], + ) as connection: cursor = connection.cursor() psycopg2cffi.extensions.register_type(psycopg2cffi.extensions.UNICODE) - psycopg2cffi.extensions.register_type( - psycopg2cffi.extensions.UNICODE, - connection) - psycopg2cffi.extensions.register_type( - psycopg2cffi.extensions.UNICODE, - cursor) + psycopg2cffi.extensions.register_type(psycopg2cffi.extensions.UNICODE, connection) + psycopg2cffi.extensions.register_type(psycopg2cffi.extensions.UNICODE, cursor) cursor.execute("""drop table if exists %s""" % DB_SETTINGS["table_name"]) - cursor.execute("""create table %s """ % DB_SETTINGS["table_name"] + - """(a integer, b real, c text)""") + cursor.execute("""create table %s """ % DB_SETTINGS["table_name"] + """(a integer, b real, c text)""") - cursor.executemany("""insert into %s """ % DB_SETTINGS["table_name"] + - """values (%s, %s, %s)""", [(1, 1.0, '1.0'), - (2, 2.2, '2.2'), (3, 3.3, '3.3')]) + cursor.executemany( + """insert into %s """ % DB_SETTINGS["table_name"] + """values (%s, %s, %s)""", + [(1, 1.0, "1.0"), (2, 2.2, "2.2"), (3, 3.3, "3.3")], + ) cursor.execute("""select * from %s""" % DB_SETTINGS["table_name"]) for row in cursor: pass - cursor.execute("""update %s""" % DB_SETTINGS["table_name"] + """ set a=%s, b=%s, """ - """c=%s where a=%s""", (4, 4.0, '4.0', 1)) + cursor.execute( + """update %s""" % DB_SETTINGS["table_name"] + """ set a=%s, b=%s, """ """c=%s where a=%s""", + (4, 4.0, "4.0", 1), + ) cursor.execute("""delete from %s where a=2""" % DB_SETTINGS["table_name"]) connection.commit() - cursor.callproc('now') - cursor.callproc('pg_sleep', (0,)) + cursor.callproc("now") + cursor.callproc("pg_sleep", (0,)) connection.rollback() connection.commit() _test_rollback_on_exception_scoped_metrics = [ - ('Function/psycopg2cffi:connect', 1), - ('Function/psycopg2cffi._impl.connection:Connection.__enter__', 1), - ('Function/psycopg2cffi._impl.connection:Connection.__exit__', 1), - ('Datastore/operation/Postgres/rollback', 1)] + ("Function/psycopg2cffi:connect", 1), + ("Function/psycopg2cffi._impl.connection:Connection.__enter__", 1), + ("Function/psycopg2cffi._impl.connection:Connection.__exit__", 1), + ("Datastore/operation/Postgres/rollback", 1), +] _test_rollback_on_exception_rollup_metrics = [ - ('Datastore/all', 2), - ('Datastore/allOther', 2), - ('Datastore/Postgres/all', 2), - ('Datastore/Postgres/allOther', 2)] - - -@validate_transaction_metrics('test_database:test_rollback_on_exception', - scoped_metrics=_test_rollback_on_exception_scoped_metrics, - rollup_metrics=_test_rollback_on_exception_rollup_metrics, - background_task=True) + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/Postgres/all", 2), + ("Datastore/Postgres/allOther", 2), +] + + +@validate_transaction_metrics( + "test_database:test_rollback_on_exception", + scoped_metrics=_test_rollback_on_exception_scoped_metrics, + rollup_metrics=_test_rollback_on_exception_rollup_metrics, + background_task=True, +) @validate_database_trace_inputs(sql_parameters_type=tuple) @background_task() def test_rollback_on_exception(): try: with psycopg2cffi.connect( - database=DB_SETTINGS['name'], user=DB_SETTINGS['user'], - password=DB_SETTINGS['password'], host=DB_SETTINGS['host'], - port=DB_SETTINGS['port']): - - raise RuntimeError('error') + database=DB_SETTINGS["name"], + user=DB_SETTINGS["user"], + 
password=DB_SETTINGS["password"], + host=DB_SETTINGS["host"], + port=DB_SETTINGS["port"], + ): + + raise RuntimeError("error") except RuntimeError: pass _test_async_mode_scoped_metrics = [ - ('Function/psycopg2cffi:connect', 1), - ('Datastore/statement/Postgres/%s/select' % DB_SETTINGS["table_name"], 1), - ('Datastore/statement/Postgres/%s/insert' % DB_SETTINGS["table_name"], 1), - ('Datastore/operation/Postgres/drop', 1), - ('Datastore/operation/Postgres/create', 1)] + ("Function/psycopg2cffi:connect", 1), + ("Datastore/statement/Postgres/%s/select" % DB_SETTINGS["table_name"], 1), + ("Datastore/statement/Postgres/%s/insert" % DB_SETTINGS["table_name"], 1), + ("Datastore/operation/Postgres/drop", 1), + ("Datastore/operation/Postgres/create", 1), +] _test_async_mode_rollup_metrics = [ - ('Datastore/all', 5), - ('Datastore/allOther', 5), - ('Datastore/Postgres/all', 5), - ('Datastore/Postgres/allOther', 5), - ('Datastore/operation/Postgres/select', 1), - ('Datastore/statement/Postgres/%s/select' % DB_SETTINGS["table_name"], 1), - ('Datastore/operation/Postgres/insert', 1), - ('Datastore/statement/Postgres/%s/insert' % DB_SETTINGS["table_name"], 1), - ('Datastore/operation/Postgres/drop', 1), - ('Datastore/operation/Postgres/create', 1)] + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ("Datastore/Postgres/all", 5), + ("Datastore/Postgres/allOther", 5), + ("Datastore/operation/Postgres/select", 1), + ("Datastore/statement/Postgres/%s/select" % DB_SETTINGS["table_name"], 1), + ("Datastore/operation/Postgres/insert", 1), + ("Datastore/statement/Postgres/%s/insert" % DB_SETTINGS["table_name"], 1), + ("Datastore/operation/Postgres/drop", 1), + ("Datastore/operation/Postgres/create", 1), + ("Datastore/instance/Postgres/%s/%s" % (instance_hostname(DB_SETTINGS["host"]), DB_SETTINGS["port"]), 4), +] @validate_stats_engine_explain_plan_output_is_none() @validate_transaction_slow_sql_count(num_slow_sql=4) @validate_database_trace_inputs(sql_parameters_type=tuple) -@validate_transaction_metrics('test_database:test_async_mode', - scoped_metrics=_test_async_mode_scoped_metrics, - rollup_metrics=_test_async_mode_rollup_metrics, - background_task=True) +@validate_transaction_metrics( + "test_database:test_async_mode", + scoped_metrics=_test_async_mode_scoped_metrics, + rollup_metrics=_test_async_mode_rollup_metrics, + background_task=True, +) @validate_transaction_errors(errors=[]) @background_task() def test_async_mode(): @@ -182,16 +206,19 @@ def test_async_mode(): wait = psycopg2cffi.extras.wait_select kwargs = {} - version = tuple(int(_) for _ in psycopg2cffi.__version__.split('.')) + version = tuple(int(_) for _ in psycopg2cffi.__version__.split(".")) if version >= (2, 8): - kwargs['async_'] = 1 + kwargs["async_"] = 1 else: - kwargs['async'] = 1 + kwargs["async"] = 1 async_conn = psycopg2cffi.connect( - database=DB_SETTINGS['name'], user=DB_SETTINGS['user'], - password=DB_SETTINGS['password'], host=DB_SETTINGS['host'], - port=DB_SETTINGS['port'], **kwargs + database=DB_SETTINGS["name"], + user=DB_SETTINGS["user"], + password=DB_SETTINGS["password"], + host=DB_SETTINGS["host"], + port=DB_SETTINGS["port"], + **kwargs ) wait(async_conn) async_cur = async_conn.cursor() @@ -199,12 +226,10 @@ def test_async_mode(): async_cur.execute("""drop table if exists %s""" % DB_SETTINGS["table_name"]) wait(async_cur.connection) - async_cur.execute("""create table %s """ % DB_SETTINGS["table_name"] + - """(a integer, b real, c text)""") + async_cur.execute("""create table %s """ % DB_SETTINGS["table_name"] + 
"""(a integer, b real, c text)""") wait(async_cur.connection) - async_cur.execute("""insert into %s """ % DB_SETTINGS["table_name"] + - """values (%s, %s, %s)""", (1, 1.0, '1.0')) + async_cur.execute("""insert into %s """ % DB_SETTINGS["table_name"] + """values (%s, %s, %s)""", (1, 1.0, "1.0")) wait(async_cur.connection) async_cur.execute("""select * from %s""" % DB_SETTINGS["table_name"]) diff --git a/tests/datastore_pylibmc/test_memcache.py b/tests/datastore_pylibmc/test_memcache.py index 769f3b483..64da33416 100644 --- a/tests/datastore_pylibmc/test_memcache.py +++ b/tests/datastore_pylibmc/test_memcache.py @@ -12,85 +12,92 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - import pylibmc - from testing_support.db_settings import memcached_settings -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.api.background_task import background_task from newrelic.api.transaction import set_background_task - DB_SETTINGS = memcached_settings()[0] MEMCACHED_HOST = DB_SETTINGS["host"] MEMCACHED_PORT = DB_SETTINGS["port"] MEMCACHED_NAMESPACE = DB_SETTINGS["namespace"] -MEMCACHED_ADDR = '%s:%s' % (MEMCACHED_HOST, MEMCACHED_PORT) +MEMCACHED_ADDR = "%s:%s" % (MEMCACHED_HOST, MEMCACHED_PORT) _test_bt_set_get_delete_scoped_metrics = [ - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] _test_bt_set_get_delete_rollup_metrics = [ - ('Datastore/all', 3), - ('Datastore/allOther', 3), - ('Datastore/Memcached/all', 3), - ('Datastore/Memcached/allOther', 3), - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/all", 3), + ("Datastore/allOther", 3), + ("Datastore/Memcached/all", 3), + ("Datastore/Memcached/allOther", 3), + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] + @validate_transaction_metrics( - 'test_memcache:test_bt_set_get_delete', - scoped_metrics=_test_bt_set_get_delete_scoped_metrics, - rollup_metrics=_test_bt_set_get_delete_rollup_metrics, - background_task=True) + "test_memcache:test_bt_set_get_delete", + scoped_metrics=_test_bt_set_get_delete_scoped_metrics, + rollup_metrics=_test_bt_set_get_delete_rollup_metrics, + background_task=True, +) @background_task() def test_bt_set_get_delete(): set_background_task(True) client = pylibmc.Client([MEMCACHED_ADDR]) - key = MEMCACHED_NAMESPACE + 'key' + key = MEMCACHED_NAMESPACE + "key" - client.set(key, 'value') + client.set(key, "value") value = client.get(key) client.delete(key) - assert value == 'value' + assert value == "value" + _test_wt_set_get_delete_scoped_metrics = [ - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] _test_wt_set_get_delete_rollup_metrics = [ - ('Datastore/all', 3), - ('Datastore/allWeb', 3), - ('Datastore/Memcached/all', 3), - ('Datastore/Memcached/allWeb', 3), - 
('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/all", 3), + ("Datastore/allWeb", 3), + ("Datastore/Memcached/all", 3), + ("Datastore/Memcached/allWeb", 3), + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] + @validate_transaction_metrics( - 'test_memcache:test_wt_set_get_delete', - scoped_metrics=_test_wt_set_get_delete_scoped_metrics, - rollup_metrics=_test_wt_set_get_delete_rollup_metrics, - background_task=False) + "test_memcache:test_wt_set_get_delete", + scoped_metrics=_test_wt_set_get_delete_scoped_metrics, + rollup_metrics=_test_wt_set_get_delete_rollup_metrics, + background_task=False, +) @background_task() def test_wt_set_get_delete(): set_background_task(False) client = pylibmc.Client([MEMCACHED_ADDR]) - key = MEMCACHED_NAMESPACE + 'key' + key = MEMCACHED_NAMESPACE + "key" - client.set(key, 'value') + client.set(key, "value") value = client.get(key) client.delete(key) - assert value == 'value' + assert value == "value" diff --git a/tests/datastore_pymemcache/test_memcache.py b/tests/datastore_pymemcache/test_memcache.py index 9aeea4d54..3100db5b7 100644 --- a/tests/datastore_pymemcache/test_memcache.py +++ b/tests/datastore_pymemcache/test_memcache.py @@ -15,9 +15,10 @@ import os import pymemcache.client - -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics from testing_support.db_settings import memcached_settings +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.api.background_task import background_task from newrelic.api.transaction import set_background_task @@ -31,65 +32,74 @@ MEMCACHED_ADDR = (MEMCACHED_HOST, int(MEMCACHED_PORT)) _test_bt_set_get_delete_scoped_metrics = [ - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] _test_bt_set_get_delete_rollup_metrics = [ - ('Datastore/all', 3), - ('Datastore/allOther', 3), - ('Datastore/Memcached/all', 3), - ('Datastore/Memcached/allOther', 3), - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/all", 3), + ("Datastore/allOther", 3), + ("Datastore/Memcached/all", 3), + ("Datastore/Memcached/allOther", 3), + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] + @validate_transaction_metrics( - 'test_memcache:test_bt_set_get_delete', - scoped_metrics=_test_bt_set_get_delete_scoped_metrics, - rollup_metrics=_test_bt_set_get_delete_rollup_metrics, - background_task=True) + "test_memcache:test_bt_set_get_delete", + scoped_metrics=_test_bt_set_get_delete_scoped_metrics, + rollup_metrics=_test_bt_set_get_delete_rollup_metrics, + background_task=True, +) @background_task() def test_bt_set_get_delete(): set_background_task(True) client = pymemcache.client.Client(MEMCACHED_ADDR) - key = MEMCACHED_NAMESPACE + 'key' + key = MEMCACHED_NAMESPACE + "key" - client.set(key, b'value') + client.set(key, b"value") value = client.get(key) client.delete(key) - assert value == b'value' + assert value == b"value" + _test_wt_set_get_delete_scoped_metrics = [ - 
('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] _test_wt_set_get_delete_rollup_metrics = [ - ('Datastore/all', 3), - ('Datastore/allWeb', 3), - ('Datastore/Memcached/all', 3), - ('Datastore/Memcached/allWeb', 3), - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/all", 3), + ("Datastore/allWeb", 3), + ("Datastore/Memcached/all", 3), + ("Datastore/Memcached/allWeb", 3), + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] + @validate_transaction_metrics( - 'test_memcache:test_wt_set_get_delete', - scoped_metrics=_test_wt_set_get_delete_scoped_metrics, - rollup_metrics=_test_wt_set_get_delete_rollup_metrics, - background_task=False) + "test_memcache:test_wt_set_get_delete", + scoped_metrics=_test_wt_set_get_delete_scoped_metrics, + rollup_metrics=_test_wt_set_get_delete_rollup_metrics, + background_task=False, +) @background_task() def test_wt_set_get_delete(): set_background_task(False) client = pymemcache.client.Client(MEMCACHED_ADDR) - key = MEMCACHED_NAMESPACE + 'key' + key = MEMCACHED_NAMESPACE + "key" - client.set(key, b'value') + client.set(key, b"value") value = client.get(key) client.delete(key) - assert value == b'value' + assert value == b"value" diff --git a/tests/datastore_pymssql/conftest.py b/tests/datastore_pymssql/conftest.py new file mode 100644 index 000000000..a6584cdff --- /dev/null +++ b/tests/datastore_pymssql/conftest.py @@ -0,0 +1,36 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from testing_support.fixtures import ( + collector_agent_registration_fixture, + collector_available_fixture, +) # noqa: F401; pylint: disable=W0611 + + +_default_settings = { + "transaction_tracer.explain_threshold": 0.0, + "transaction_tracer.transaction_threshold": 0.0, + "transaction_tracer.stack_trace_threshold": 0.0, + "debug.log_data_collector_payloads": True, + "debug.record_transaction_failure": True, + "debug.log_explain_plan_queries": True, +} + +collector_agent_registration = collector_agent_registration_fixture( + app_name="Python Agent Test (datastore_pymssql)", + default_settings=_default_settings, + linked_applications=["Python Agent Test (datastore)"], +) diff --git a/tests/datastore_pymssql/test_database.py b/tests/datastore_pymssql/test_database.py new file mode 100644 index 000000000..bdbf75c15 --- /dev/null +++ b/tests/datastore_pymssql/test_database.py @@ -0,0 +1,115 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pymssql + +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from testing_support.validators.validate_database_trace_inputs import validate_database_trace_inputs + +from testing_support.db_settings import mssql_settings + +from newrelic.api.background_task import background_task + +DB_SETTINGS = mssql_settings()[0] +TABLE_NAME = "datastore_pymssql_" + DB_SETTINGS["namespace"] +PROCEDURE_NAME = "hello_" + DB_SETTINGS["namespace"] + + +def execute_db_calls_with_cursor(cursor): + cursor.execute("""drop table if exists %s""" % TABLE_NAME) + + cursor.execute("""create table %s """ % TABLE_NAME + """(a integer, b real, c text)""") + + cursor.executemany( + """insert into %s """ % TABLE_NAME + """values (%s, %s, %s)""", + [(1, 1.0, "1.0"), (2, 2.2, "2.2"), (3, 3.3, "3.3")], + ) + + cursor.execute("""select * from %s""" % TABLE_NAME) + + for row in cursor: + pass + + cursor.execute("""update %s""" % TABLE_NAME + """ set a=%s, b=%s, """ """c=%s where a=%s""", (4, 4.0, "4.0", 1)) + + cursor.execute("""delete from %s where a=2""" % TABLE_NAME) + cursor.execute("""drop procedure if exists %s""" % PROCEDURE_NAME) + cursor.execute( + """CREATE PROCEDURE %s AS + BEGIN + SELECT 'Hello World!'; + END""" + % PROCEDURE_NAME + ) + + cursor.callproc(PROCEDURE_NAME) + + +_test_scoped_metrics = [ + ("Function/pymssql._pymssql:connect", 1), + ("Datastore/statement/MSSQL/%s/select" % TABLE_NAME, 1), + ("Datastore/statement/MSSQL/%s/insert" % TABLE_NAME, 1), + ("Datastore/statement/MSSQL/%s/update" % TABLE_NAME, 1), + ("Datastore/statement/MSSQL/%s/delete" % TABLE_NAME, 1), + ("Datastore/operation/MSSQL/drop", 2), + ("Datastore/operation/MSSQL/create", 2), + ("Datastore/statement/MSSQL/%s/call" % PROCEDURE_NAME, 1), + ("Datastore/operation/MSSQL/commit", 2), + ("Datastore/operation/MSSQL/rollback", 1), +] + +_test_rollup_metrics = [ + ("Datastore/all", 13), + ("Datastore/allOther", 13), + ("Datastore/MSSQL/all", 13), + ("Datastore/MSSQL/allOther", 13), + ("Datastore/statement/MSSQL/%s/select" % TABLE_NAME, 1), + ("Datastore/statement/MSSQL/%s/insert" % TABLE_NAME, 1), + ("Datastore/statement/MSSQL/%s/update" % TABLE_NAME, 1), + ("Datastore/statement/MSSQL/%s/delete" % TABLE_NAME, 1), + ("Datastore/operation/MSSQL/select", 1), + ("Datastore/operation/MSSQL/insert", 1), + ("Datastore/operation/MSSQL/update", 1), + ("Datastore/operation/MSSQL/delete", 1), + ("Datastore/statement/MSSQL/%s/call" % PROCEDURE_NAME, 1), + ("Datastore/operation/MSSQL/call", 1), + ("Datastore/operation/MSSQL/drop", 2), + ("Datastore/operation/MSSQL/create", 2), + ("Datastore/operation/MSSQL/commit", 2), + ("Datastore/operation/MSSQL/rollback", 1), +] + + +@validate_transaction_metrics( + "test_database:test_execute_via_cursor_context_manager", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, +) +@validate_database_trace_inputs(sql_parameters_type=tuple) +@background_task() +def test_execute_via_cursor_context_manager(): + connection = pymssql.connect( + user=DB_SETTINGS["user"], password=DB_SETTINGS["password"], 
host=DB_SETTINGS["host"], port=DB_SETTINGS["port"] + ) + + with connection: + cursor = connection.cursor() + + with cursor: + execute_db_calls_with_cursor(cursor) + + connection.commit() + connection.rollback() + connection.commit() diff --git a/tests/datastore_pymysql/test_database.py b/tests/datastore_pymysql/test_database.py index 5943b1266..ad4db1d9c 100644 --- a/tests/datastore_pymysql/test_database.py +++ b/tests/datastore_pymysql/test_database.py @@ -13,11 +13,14 @@ # limitations under the License. import pymysql - -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_database_trace_inputs import validate_database_trace_inputs - from testing_support.db_settings import mysql_settings +from testing_support.util import instance_hostname +from testing_support.validators.validate_database_trace_inputs import ( + validate_database_trace_inputs, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.api.background_task import background_task @@ -25,76 +28,92 @@ TABLE_NAME = "datastore_pymysql_" + DB_SETTINGS["namespace"] PROCEDURE_NAME = "hello_" + DB_SETTINGS["namespace"] +HOST = instance_hostname(DB_SETTINGS["host"]) +PORT = DB_SETTINGS["port"] + def execute_db_calls_with_cursor(cursor): cursor.execute("""drop table if exists %s""" % TABLE_NAME) - cursor.execute("""create table %s """ % TABLE_NAME + - """(a integer, b real, c text)""") + cursor.execute("""create table %s """ % TABLE_NAME + """(a integer, b real, c text)""") - cursor.executemany("""insert into %s """ % TABLE_NAME + - """values (%s, %s, %s)""", [(1, 1.0, '1.0'), - (2, 2.2, '2.2'), (3, 3.3, '3.3')]) + cursor.executemany( + """insert into %s """ % TABLE_NAME + """values (%s, %s, %s)""", + [(1, 1.0, "1.0"), (2, 2.2, "2.2"), (3, 3.3, "3.3")], + ) cursor.execute("""select * from %s""" % TABLE_NAME) - for row in cursor: pass + for row in cursor: + pass - cursor.execute("""update %s""" % TABLE_NAME + """ set a=%s, b=%s, """ - """c=%s where a=%s""", (4, 4.0, '4.0', 1)) + cursor.execute("""update %s""" % TABLE_NAME + """ set a=%s, b=%s, """ """c=%s where a=%s""", (4, 4.0, "4.0", 1)) cursor.execute("""delete from %s where a=2""" % TABLE_NAME) cursor.execute("""drop procedure if exists %s""" % PROCEDURE_NAME) - cursor.execute("""CREATE PROCEDURE %s() + cursor.execute( + """CREATE PROCEDURE %s() BEGIN SELECT 'Hello World!'; - END""" % PROCEDURE_NAME) + END""" + % PROCEDURE_NAME + ) cursor.callproc(PROCEDURE_NAME) _test_execute_via_cursor_scoped_metrics = [ - ('Function/pymysql:Connect', 1), - ('Datastore/statement/MySQL/%s/select' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/insert' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/update' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/delete' % TABLE_NAME, 1), - ('Datastore/operation/MySQL/drop', 2), - ('Datastore/operation/MySQL/create', 2), - ('Datastore/statement/MySQL/%s/call' % PROCEDURE_NAME, 1), - ('Datastore/operation/MySQL/commit', 2), - ('Datastore/operation/MySQL/rollback', 1)] + ("Function/pymysql:Connect", 1), + ("Datastore/statement/MySQL/%s/select" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/insert" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/update" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/delete" % TABLE_NAME, 1), + ("Datastore/operation/MySQL/drop", 2), + ("Datastore/operation/MySQL/create", 2), + ("Datastore/statement/MySQL/%s/call" % PROCEDURE_NAME, 1), + 
("Datastore/operation/MySQL/commit", 2), + ("Datastore/operation/MySQL/rollback", 1), +] _test_execute_via_cursor_rollup_metrics = [ - ('Datastore/all', 13), - ('Datastore/allOther', 13), - ('Datastore/MySQL/all', 13), - ('Datastore/MySQL/allOther', 13), - ('Datastore/statement/MySQL/%s/select' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/insert' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/update' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/delete' % TABLE_NAME, 1), - ('Datastore/operation/MySQL/select', 1), - ('Datastore/operation/MySQL/insert', 1), - ('Datastore/operation/MySQL/update', 1), - ('Datastore/operation/MySQL/delete', 1), - ('Datastore/statement/MySQL/%s/call' % PROCEDURE_NAME, 1), - ('Datastore/operation/MySQL/call', 1), - ('Datastore/operation/MySQL/drop', 2), - ('Datastore/operation/MySQL/create', 2), - ('Datastore/operation/MySQL/commit', 2), - ('Datastore/operation/MySQL/rollback', 1)] - -@validate_transaction_metrics('test_database:test_execute_via_cursor', - scoped_metrics=_test_execute_via_cursor_scoped_metrics, - rollup_metrics=_test_execute_via_cursor_rollup_metrics, - background_task=True) + ("Datastore/all", 13), + ("Datastore/allOther", 13), + ("Datastore/MySQL/all", 13), + ("Datastore/MySQL/allOther", 13), + ("Datastore/statement/MySQL/%s/select" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/insert" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/update" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/delete" % TABLE_NAME, 1), + ("Datastore/operation/MySQL/select", 1), + ("Datastore/operation/MySQL/insert", 1), + ("Datastore/operation/MySQL/update", 1), + ("Datastore/operation/MySQL/delete", 1), + ("Datastore/statement/MySQL/%s/call" % PROCEDURE_NAME, 1), + ("Datastore/operation/MySQL/call", 1), + ("Datastore/operation/MySQL/drop", 2), + ("Datastore/operation/MySQL/create", 2), + ("Datastore/operation/MySQL/commit", 2), + ("Datastore/operation/MySQL/rollback", 1), + ("Datastore/instance/MySQL/%s/%s" % (HOST, PORT), 12), +] + + +@validate_transaction_metrics( + "test_database:test_execute_via_cursor", + scoped_metrics=_test_execute_via_cursor_scoped_metrics, + rollup_metrics=_test_execute_via_cursor_rollup_metrics, + background_task=True, +) @validate_database_trace_inputs(sql_parameters_type=tuple) @background_task() def test_execute_via_cursor(): - connection = pymysql.connect(db=DB_SETTINGS['name'], - user=DB_SETTINGS['user'], passwd=DB_SETTINGS['password'], - host=DB_SETTINGS['host'], port=DB_SETTINGS['port']) + connection = pymysql.connect( + db=DB_SETTINGS["name"], + user=DB_SETTINGS["user"], + passwd=DB_SETTINGS["password"], + host=DB_SETTINGS["host"], + port=DB_SETTINGS["port"], + ) with connection.cursor() as cursor: execute_db_calls_with_cursor(cursor) @@ -105,49 +124,57 @@ def test_execute_via_cursor(): _test_execute_via_cursor_context_mangaer_scoped_metrics = [ - ('Function/pymysql:Connect', 1), - ('Datastore/statement/MySQL/%s/select' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/insert' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/update' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/delete' % TABLE_NAME, 1), - ('Datastore/operation/MySQL/drop', 2), - ('Datastore/operation/MySQL/create', 2), - ('Datastore/statement/MySQL/%s/call' % PROCEDURE_NAME, 1), - ('Datastore/operation/MySQL/commit', 2), - ('Datastore/operation/MySQL/rollback', 1)] + ("Function/pymysql:Connect", 1), + ("Datastore/statement/MySQL/%s/select" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/insert" % TABLE_NAME, 1), + 
("Datastore/statement/MySQL/%s/update" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/delete" % TABLE_NAME, 1), + ("Datastore/operation/MySQL/drop", 2), + ("Datastore/operation/MySQL/create", 2), + ("Datastore/statement/MySQL/%s/call" % PROCEDURE_NAME, 1), + ("Datastore/operation/MySQL/commit", 2), + ("Datastore/operation/MySQL/rollback", 1), +] _test_execute_via_cursor_context_mangaer_rollup_metrics = [ - ('Datastore/all', 13), - ('Datastore/allOther', 13), - ('Datastore/MySQL/all', 13), - ('Datastore/MySQL/allOther', 13), - ('Datastore/statement/MySQL/%s/select' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/insert' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/update' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/delete' % TABLE_NAME, 1), - ('Datastore/operation/MySQL/select', 1), - ('Datastore/operation/MySQL/insert', 1), - ('Datastore/operation/MySQL/update', 1), - ('Datastore/operation/MySQL/delete', 1), - ('Datastore/statement/MySQL/%s/call' % PROCEDURE_NAME, 1), - ('Datastore/operation/MySQL/call', 1), - ('Datastore/operation/MySQL/drop', 2), - ('Datastore/operation/MySQL/create', 2), - ('Datastore/operation/MySQL/commit', 2), - ('Datastore/operation/MySQL/rollback', 1)] + ("Datastore/all", 13), + ("Datastore/allOther", 13), + ("Datastore/MySQL/all", 13), + ("Datastore/MySQL/allOther", 13), + ("Datastore/statement/MySQL/%s/select" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/insert" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/update" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/delete" % TABLE_NAME, 1), + ("Datastore/operation/MySQL/select", 1), + ("Datastore/operation/MySQL/insert", 1), + ("Datastore/operation/MySQL/update", 1), + ("Datastore/operation/MySQL/delete", 1), + ("Datastore/statement/MySQL/%s/call" % PROCEDURE_NAME, 1), + ("Datastore/operation/MySQL/call", 1), + ("Datastore/operation/MySQL/drop", 2), + ("Datastore/operation/MySQL/create", 2), + ("Datastore/operation/MySQL/commit", 2), + ("Datastore/operation/MySQL/rollback", 1), + ("Datastore/instance/MySQL/%s/%s" % (HOST, PORT), 12), +] @validate_transaction_metrics( - 'test_database:test_execute_via_cursor_context_manager', - scoped_metrics=_test_execute_via_cursor_context_mangaer_scoped_metrics, - rollup_metrics=_test_execute_via_cursor_context_mangaer_rollup_metrics, - background_task=True) + "test_database:test_execute_via_cursor_context_manager", + scoped_metrics=_test_execute_via_cursor_context_mangaer_scoped_metrics, + rollup_metrics=_test_execute_via_cursor_context_mangaer_rollup_metrics, + background_task=True, +) @validate_database_trace_inputs(sql_parameters_type=tuple) @background_task() def test_execute_via_cursor_context_manager(): - connection = pymysql.connect(db=DB_SETTINGS['name'], - user=DB_SETTINGS['user'], passwd=DB_SETTINGS['password'], - host=DB_SETTINGS['host'], port=DB_SETTINGS['port']) + connection = pymysql.connect( + db=DB_SETTINGS["name"], + user=DB_SETTINGS["user"], + passwd=DB_SETTINGS["password"], + host=DB_SETTINGS["host"], + port=DB_SETTINGS["port"], + ) cursor = connection.cursor() with cursor: diff --git a/tests/datastore_pyodbc/test_pyodbc.py b/tests/datastore_pyodbc/test_pyodbc.py index 119908e4d..5a810be5f 100644 --- a/tests/datastore_pyodbc/test_pyodbc.py +++ b/tests/datastore_pyodbc/test_pyodbc.py @@ -13,6 +13,7 @@ # limitations under the License. 
import pytest from testing_support.db_settings import postgresql_settings +from testing_support.util import instance_hostname from testing_support.validators.validate_database_trace_inputs import ( validate_database_trace_inputs, ) diff --git a/tests/datastore_pysolr/test_solr.py b/tests/datastore_pysolr/test_solr.py index a987a29ac..e17117117 100644 --- a/tests/datastore_pysolr/test_solr.py +++ b/tests/datastore_pysolr/test_solr.py @@ -13,16 +13,19 @@ # limitations under the License. from pysolr import Solr - -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics from testing_support.db_settings import solr_settings +from testing_support.util import instance_hostname +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.api.background_task import background_task DB_SETTINGS = solr_settings()[0] SOLR_HOST = DB_SETTINGS["host"] SOLR_PORT = DB_SETTINGS["port"] -SOLR_URL = 'http://%s:%s/solr/collection' % (DB_SETTINGS["host"], DB_SETTINGS["port"]) +SOLR_URL = "http://%s:%s/solr/collection" % (DB_SETTINGS["host"], DB_SETTINGS["port"]) + def _exercise_solr(solr): # Construct document names within namespace @@ -31,30 +34,36 @@ def _exercise_solr(solr): solr.add([{"id": x} for x in documents]) - solr.search('id:%s' % documents[0]) + solr.search("id:%s" % documents[0]) solr.delete(id=documents[0]) # Delete all documents. - solr.delete(q='id:*_%s' % DB_SETTINGS["namespace"]) + solr.delete(q="id:*_%s" % DB_SETTINGS["namespace"]) + _test_solr_search_scoped_metrics = [ - ('Datastore/operation/Solr/add', 1), - ('Datastore/operation/Solr/delete', 2), - ('Datastore/operation/Solr/search', 1)] + ("Datastore/operation/Solr/add", 1), + ("Datastore/operation/Solr/delete", 2), + ("Datastore/operation/Solr/search", 1), +] _test_solr_search_rollup_metrics = [ - ('Datastore/all', 4), - ('Datastore/allOther', 4), - ('Datastore/Solr/all', 4), - ('Datastore/Solr/allOther', 4), - ('Datastore/operation/Solr/add', 1), - ('Datastore/operation/Solr/search', 1), - ('Datastore/operation/Solr/delete', 2)] - -@validate_transaction_metrics('test_solr:test_solr_search', + ("Datastore/all", 4), + ("Datastore/allOther", 4), + ("Datastore/Solr/all", 4), + ("Datastore/Solr/allOther", 4), + ("Datastore/operation/Solr/add", 1), + ("Datastore/operation/Solr/search", 1), + ("Datastore/operation/Solr/delete", 2), +] + + +@validate_transaction_metrics( + "test_solr:test_solr_search", scoped_metrics=_test_solr_search_scoped_metrics, rollup_metrics=_test_solr_search_rollup_metrics, - background_task=True) + background_task=True, +) @background_task() def test_solr_search(): s = Solr(SOLR_URL) diff --git a/tests/datastore_redis/conftest.py b/tests/datastore_redis/conftest.py index 53ff2658d..6747039b4 100644 --- a/tests/datastore_redis/conftest.py +++ b/tests/datastore_redis/conftest.py @@ -15,6 +15,7 @@ import pytest from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture # noqa: F401; pylint: disable=W0611 +from testing_support.fixture.event_loop import event_loop as loop # noqa: F401; pylint: disable=W0611 _default_settings = { diff --git a/tests/datastore_redis/test_asyncio.py b/tests/datastore_redis/test_asyncio.py new file mode 100644 index 000000000..f46e8515e --- /dev/null +++ b/tests/datastore_redis/test_asyncio.py @@ -0,0 +1,160 @@ +# Copyright 2010 New Relic, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +import pytest +from testing_support.db_settings import redis_settings +from testing_support.fixture.event_loop import event_loop as loop # noqa: F401 +from testing_support.util import instance_hostname +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task +from newrelic.common.package_version_utils import get_package_version_tuple + +# Settings + +DB_SETTINGS = redis_settings()[0] +REDIS_PY_VERSION = get_package_version_tuple("redis") + +# Metrics for publish test + +datastore_all_metric_count = 5 if REDIS_PY_VERSION >= (5, 0) else 3 + +_base_scoped_metrics = [("Datastore/operation/Redis/publish", 3)] + +if REDIS_PY_VERSION >= (5, 0): + _base_scoped_metrics.append( + ("Datastore/operation/Redis/client_setinfo", 2), + ) + +_base_rollup_metrics = [ + ("Datastore/all", datastore_all_metric_count), + ("Datastore/allOther", datastore_all_metric_count), + ("Datastore/Redis/all", datastore_all_metric_count), + ("Datastore/Redis/allOther", datastore_all_metric_count), + ("Datastore/operation/Redis/publish", 3), + ( + "Datastore/instance/Redis/%s/%s" % (instance_hostname(DB_SETTINGS["host"]), DB_SETTINGS["port"]), + datastore_all_metric_count, + ), +] +if REDIS_PY_VERSION >= (5, 0): + _base_rollup_metrics.append( + ("Datastore/operation/Redis/client_setinfo", 2), + ) + + +# Metrics for connection pool test + +_base_pool_scoped_metrics = [ + ("Datastore/operation/Redis/get", 1), + ("Datastore/operation/Redis/set", 1), + ("Datastore/operation/Redis/client_list", 1), +] + +_base_pool_rollup_metrics = [ + ("Datastore/all", 3), + ("Datastore/allOther", 3), + ("Datastore/Redis/all", 3), + ("Datastore/Redis/allOther", 3), + ("Datastore/operation/Redis/get", 1), + ("Datastore/operation/Redis/set", 1), + ("Datastore/operation/Redis/client_list", 1), + ("Datastore/instance/Redis/%s/%s" % (instance_hostname(DB_SETTINGS["host"]), DB_SETTINGS["port"]), 3), +] + + +# Tests + + +@pytest.fixture() +def client(loop): # noqa + import redis.asyncio + + return loop.run_until_complete(redis.asyncio.Redis(host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], db=0)) + + +@pytest.fixture() +def client_pool(loop): # noqa + import redis.asyncio + + connection_pool = redis.asyncio.ConnectionPool(host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], db=0) + return loop.run_until_complete(redis.asyncio.Redis(connection_pool=connection_pool)) + + +@pytest.mark.skipif(REDIS_PY_VERSION < (4, 2), reason="This functionality exists in Redis 4.2+") +@validate_transaction_metrics( + "test_asyncio:test_async_connection_pool", + scoped_metrics=_base_pool_scoped_metrics, + rollup_metrics=_base_pool_rollup_metrics, + background_task=True, +) +@background_task() +def test_async_connection_pool(client_pool, loop): # noqa + async def _test_async_pool(client_pool): + await client_pool.set("key1", "value1") + await client_pool.get("key1") + await 
client_pool.execute_command("CLIENT", "LIST") + + loop.run_until_complete(_test_async_pool(client_pool)) + + +@pytest.mark.skipif(REDIS_PY_VERSION < (4, 2), reason="This functionality exists in Redis 4.2+") +@validate_transaction_metrics("test_asyncio:test_async_pipeline", background_task=True) +@background_task() +def test_async_pipeline(client, loop): # noqa + async def _test_pipeline(client): + async with client.pipeline(transaction=True) as pipe: + await pipe.set("key1", "value1") + await pipe.execute() + + loop.run_until_complete(_test_pipeline(client)) + + +@pytest.mark.skipif(REDIS_PY_VERSION < (4, 2), reason="This functionality exists in Redis 4.2+") +@validate_transaction_metrics( + "test_asyncio:test_async_pubsub", + scoped_metrics=_base_scoped_metrics, + rollup_metrics=_base_rollup_metrics, + background_task=True, +) +@background_task() +def test_async_pubsub(client, loop): # noqa + messages_received = [] + + async def reader(pubsub): + while True: + message = await pubsub.get_message(ignore_subscribe_messages=True) + if message: + messages_received.append(message["data"].decode()) + if message["data"].decode() == "NOPE": + break + + async def _test_pubsub(): + async with client.pubsub() as pubsub: + await pubsub.psubscribe("channel:*") + + future = asyncio.create_task(reader(pubsub)) + + await client.publish("channel:1", "Hello") + await client.publish("channel:2", "World") + await client.publish("channel:1", "NOPE") + + await future + + loop.run_until_complete(_test_pubsub()) + assert messages_received == ["Hello", "World", "NOPE"] diff --git a/tests/datastore_redis/test_custom_conn_pool.py b/tests/datastore_redis/test_custom_conn_pool.py index 156c9ce31..b16a77f48 100644 --- a/tests/datastore_redis/test_custom_conn_pool.py +++ b/tests/datastore_redis/test_custom_conn_pool.py @@ -12,23 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -''' The purpose of these tests is to confirm that using a non-standard +""" The purpose of these tests is to confirm that using a non-standard connection pool that does not have a `connection_kwargs` attribute will not result in an error. 
-''' +""" import pytest import redis - -from newrelic.api.background_task import background_task - -from testing_support.fixtures import override_application_settings -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics from testing_support.db_settings import redis_settings +from testing_support.fixtures import override_application_settings from testing_support.util import instance_hostname +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task +from newrelic.common.package_version_utils import get_package_version_tuple DB_SETTINGS = redis_settings()[0] -REDIS_PY_VERSION = redis.VERSION +REDIS_PY_VERSION = get_package_version_tuple("redis") class FakeConnectionPool(object): @@ -43,112 +45,120 @@ def get_connection(self, name, *keys, **options): def release(self, connection): self.connection.disconnect() + def disconnect(self): + self.connection.disconnect() + + # Settings _enable_instance_settings = { - 'datastore_tracer.instance_reporting.enabled': True, + "datastore_tracer.instance_reporting.enabled": True, } _disable_instance_settings = { - 'datastore_tracer.instance_reporting.enabled': False, + "datastore_tracer.instance_reporting.enabled": False, } # Metrics # We don't record instance metrics when using redis blaster, # so we just check for base metrics. - -_base_scoped_metrics = ( - ('Datastore/operation/Redis/get', 1), - ('Datastore/operation/Redis/set', 1), - ('Datastore/operation/Redis/client_list', 1), -) - -_base_rollup_metrics = ( - ('Datastore/all', 3), - ('Datastore/allOther', 3), - ('Datastore/Redis/all', 3), - ('Datastore/Redis/allOther', 3), - ('Datastore/operation/Redis/get', 1), - ('Datastore/operation/Redis/set', 1), - ('Datastore/operation/Redis/client_list', 1), -) - -_disable_scoped_metrics = list(_base_scoped_metrics) -_disable_rollup_metrics = list(_base_rollup_metrics) - -_enable_scoped_metrics = list(_base_scoped_metrics) -_enable_rollup_metrics = list(_base_rollup_metrics) - -_host = instance_hostname(DB_SETTINGS['host']) -_port = DB_SETTINGS['port'] - -_instance_metric_name = 'Datastore/instance/Redis/%s/%s' % (_host, _port) - -_enable_rollup_metrics.append( - (_instance_metric_name, 3) -) - -_disable_rollup_metrics.append( - (_instance_metric_name, None) -) +datastore_all_metric_count = 5 if REDIS_PY_VERSION >= (5, 0) else 3 + +_base_scoped_metrics = [ + ("Datastore/operation/Redis/get", 1), + ("Datastore/operation/Redis/set", 1), + ("Datastore/operation/Redis/client_list", 1), +] +# client_setinfo was introduced in v5.0.0 and assigns info displayed in client_list output +if REDIS_PY_VERSION >= (5, 0): + _base_scoped_metrics.append( + ("Datastore/operation/Redis/client_setinfo", 2), + ) + +_base_rollup_metrics = [ + ("Datastore/all", datastore_all_metric_count), + ("Datastore/allOther", datastore_all_metric_count), + ("Datastore/Redis/all", datastore_all_metric_count), + ("Datastore/Redis/allOther", datastore_all_metric_count), + ("Datastore/operation/Redis/get", 1), + ("Datastore/operation/Redis/set", 1), + ("Datastore/operation/Redis/client_list", 1), +] +if REDIS_PY_VERSION >= (5, 0): + _base_rollup_metrics.append( + ("Datastore/operation/Redis/client_setinfo", 2), + ) + +_host = instance_hostname(DB_SETTINGS["host"]) +_port = DB_SETTINGS["port"] + +_instance_metric_name = "Datastore/instance/Redis/%s/%s" % (_host, _port) + +instance_metric_count = 5 if REDIS_PY_VERSION >= (5, 0) else 3 + 
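# Note: list.append() mutates in place and returns None, so the per-test
# rollup variants below are built with list concatenation instead, which
# also leaves the shared _base_rollup_metrics list untouched. A minimal
# illustration (metric names here are hypothetical, for demonstration only):
#
#     base = [("Datastore/all", 3)]
#     enabled = base + [("Datastore/instance/Redis/host/6379", 3)]   # new list
#     broken = base.append(("Datastore/instance/Redis/host/6379", 3))  # None!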
+_enable_rollup_metrics = _base_rollup_metrics + [(_instance_metric_name, instance_metric_count)]
+
+_disable_rollup_metrics = _base_rollup_metrics + [(_instance_metric_name, None)]
 
 # Operations
 
+
 def exercise_redis(client):
-    client.set('key', 'value')
-    client.get('key')
-    client.execute_command('CLIENT', 'LIST', parse='LIST')
+    client.set("key", "value")
+    client.get("key")
+    client.execute_command("CLIENT", "LIST", parse="LIST")
+
 
 # Tests
 
-@pytest.mark.skipif(REDIS_PY_VERSION < (2, 7),
-    reason='Client list command introduced in 2.7')
+
+@pytest.mark.skipif(REDIS_PY_VERSION < (2, 7), reason="Client list command introduced in 2.7")
 @override_application_settings(_enable_instance_settings)
 @validate_transaction_metrics(
-    'test_custom_conn_pool:test_fake_conn_pool_enable_instance',
-    scoped_metrics=_enable_scoped_metrics,
-    rollup_metrics=_enable_rollup_metrics,
-    background_task=True)
+    "test_custom_conn_pool:test_fake_conn_pool_enable_instance",
+    scoped_metrics=_base_scoped_metrics,
+    rollup_metrics=_enable_rollup_metrics,
+    background_task=True,
+)
 @background_task()
 def test_fake_conn_pool_enable_instance():
-    client = redis.StrictRedis(host=DB_SETTINGS['host'],
-            port=DB_SETTINGS['port'], db=0)
+    client = redis.StrictRedis(host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], db=0)
 
     # Get a real connection
-    conn = client.connection_pool.get_connection('GET')
+    conn = client.connection_pool.get_connection("GET")
 
     # Replace the original connection pool with one that doesn't
     # have the `connection_kwargs` attribute.
     fake_pool = FakeConnectionPool(conn)
     client.connection_pool = fake_pool
-    assert not hasattr(client.connection_pool, 'connection_kwargs')
+    assert not hasattr(client.connection_pool, "connection_kwargs")
 
     exercise_redis(client)
 
-@pytest.mark.skipif(REDIS_PY_VERSION < (2, 7),
-    reason='Client list command introduced in 2.7')
+
+@pytest.mark.skipif(REDIS_PY_VERSION < (2, 7), reason="Client list command introduced in 2.7")
 @override_application_settings(_disable_instance_settings)
 @validate_transaction_metrics(
-    'test_custom_conn_pool:test_fake_conn_pool_disable_instance',
-    scoped_metrics=_disable_scoped_metrics,
-    rollup_metrics=_disable_rollup_metrics,
-    background_task=True)
+    "test_custom_conn_pool:test_fake_conn_pool_disable_instance",
+    scoped_metrics=_base_scoped_metrics,
+    rollup_metrics=_disable_rollup_metrics,
+    background_task=True,
+)
 @background_task()
 def test_fake_conn_pool_disable_instance():
-    client = redis.StrictRedis(host=DB_SETTINGS['host'],
-            port=DB_SETTINGS['port'], db=0)
+    client = redis.StrictRedis(host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], db=0)
 
     # Get a real connection
-    conn = client.connection_pool.get_connection('GET')
+    conn = client.connection_pool.get_connection("GET")
 
     # Replace the original connection pool with one that doesn't
     # have the `connection_kwargs` attribute.
    fake_pool = FakeConnectionPool(conn)
     client.connection_pool = fake_pool
-    assert not hasattr(client.connection_pool, 'connection_kwargs')
+    assert not hasattr(client.connection_pool, "connection_kwargs")
 
     exercise_redis(client)
diff --git a/tests/datastore_redis/test_execute_command.py b/tests/datastore_redis/test_execute_command.py
index 747588072..741bc5034 100644
--- a/tests/datastore_redis/test_execute_command.py
+++ b/tests/datastore_redis/test_execute_command.py
@@ -16,6 +16,7 @@
 import redis
 
 from newrelic.api.background_task import background_task
+from newrelic.common.package_version_utils import get_package_version_tuple
 
 from testing_support.fixtures import override_application_settings
 from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics
@@ -23,7 +24,8 @@
 from testing_support.util import instance_hostname
 
 DB_SETTINGS = redis_settings()[0]
-REDIS_PY_VERSION = redis.VERSION
+REDIS_PY_VERSION = get_package_version_tuple("redis")
+
 
 # Settings
 
@@ -36,34 +38,34 @@
 
 # Metrics
 
-_base_scoped_metrics = (
+_base_scoped_metrics = [
     ('Datastore/operation/Redis/client_list', 1),
-)
-
-_base_rollup_metrics = (
-    ('Datastore/all', 1),
-    ('Datastore/allOther', 1),
-    ('Datastore/Redis/all', 1),
-    ('Datastore/Redis/allOther', 1),
+]
+if REDIS_PY_VERSION >= (5, 0):
+    _base_scoped_metrics.append(('Datastore/operation/Redis/client_setinfo', 2),)
+
+_base_rollup_metrics = [
+    ('Datastore/all', 3),
+    ('Datastore/allOther', 3),
+    ('Datastore/Redis/all', 3),
+    ('Datastore/Redis/allOther', 3),
     ('Datastore/operation/Redis/client_list', 1),
-)
-
-_disable_scoped_metrics = list(_base_scoped_metrics)
-_disable_rollup_metrics = list(_base_rollup_metrics)
-
-_enable_scoped_metrics = list(_base_scoped_metrics)
-_enable_rollup_metrics = list(_base_rollup_metrics)
+]
+if REDIS_PY_VERSION >= (5, 0):
+    _base_rollup_metrics.append(('Datastore/operation/Redis/client_setinfo', 2),)
 
 _host = instance_hostname(DB_SETTINGS['host'])
 _port = DB_SETTINGS['port']
 
 _instance_metric_name = 'Datastore/instance/Redis/%s/%s' % (_host, _port)
 
-_enable_rollup_metrics.append(
-    (_instance_metric_name, 1)
-)
+instance_metric_count = 3 if REDIS_PY_VERSION >= (5, 0) else 1
+
+_enable_rollup_metrics = _base_rollup_metrics + [
+    (_instance_metric_name, instance_metric_count)
+]
 
-_disable_rollup_metrics.append(
-    (_instance_metric_name, None)
-)
+_disable_rollup_metrics = _base_rollup_metrics + [
+    (_instance_metric_name, None)
+]
 
@@ -76,7 +78,7 @@ def exercise_redis_single_arg(client):
 @override_application_settings(_enable_instance_settings)
 @validate_transaction_metrics(
     'test_execute_command:test_strict_redis_execute_command_two_args_enable',
-    scoped_metrics=_enable_scoped_metrics,
+    scoped_metrics=_base_scoped_metrics,
     rollup_metrics=_enable_rollup_metrics,
     background_task=True)
 @background_task()
@@ -88,7 +90,7 @@ def test_strict_redis_execute_command_two_args_enable():
 @override_application_settings(_disable_instance_settings)
 @validate_transaction_metrics(
     'test_execute_command:test_strict_redis_execute_command_two_args_disabled',
-    scoped_metrics=_disable_scoped_metrics,
+    scoped_metrics=_base_scoped_metrics,
     rollup_metrics=_disable_rollup_metrics,
     background_task=True)
 @background_task()
@@ -100,7 +102,7 @@ def test_strict_redis_execute_command_two_args_disabled():
 @override_application_settings(_enable_instance_settings)
 @validate_transaction_metrics(
     'test_execute_command:test_redis_execute_command_two_args_enable',
-    scoped_metrics=_enable_scoped_metrics,
+    scoped_metrics=_base_scoped_metrics,
rollup_metrics=_enable_rollup_metrics, background_task=True) @background_task() @@ -112,7 +114,7 @@ def test_redis_execute_command_two_args_enable(): @override_application_settings(_disable_instance_settings) @validate_transaction_metrics( 'test_execute_command:test_redis_execute_command_two_args_disabled', - scoped_metrics=_disable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_disable_rollup_metrics, background_task=True) @background_task() @@ -126,7 +128,7 @@ def test_redis_execute_command_two_args_disabled(): @override_application_settings(_enable_instance_settings) @validate_transaction_metrics( 'test_execute_command:test_strict_redis_execute_command_as_one_arg_enable', - scoped_metrics=_enable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_enable_rollup_metrics, background_task=True) @background_task() @@ -140,7 +142,7 @@ def test_strict_redis_execute_command_as_one_arg_enable(): @override_application_settings(_disable_instance_settings) @validate_transaction_metrics( 'test_execute_command:test_strict_redis_execute_command_as_one_arg_disabled', - scoped_metrics=_disable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_disable_rollup_metrics, background_task=True) @background_task() @@ -154,7 +156,7 @@ def test_strict_redis_execute_command_as_one_arg_disabled(): @override_application_settings(_enable_instance_settings) @validate_transaction_metrics( 'test_execute_command:test_redis_execute_command_as_one_arg_enable', - scoped_metrics=_enable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_enable_rollup_metrics, background_task=True) @background_task() @@ -168,7 +170,7 @@ def test_redis_execute_command_as_one_arg_enable(): @override_application_settings(_disable_instance_settings) @validate_transaction_metrics( 'test_execute_command:test_redis_execute_command_as_one_arg_disabled', - scoped_metrics=_disable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_disable_rollup_metrics, background_task=True) @background_task() diff --git a/tests/datastore_redis/test_generators.py b/tests/datastore_redis/test_generators.py new file mode 100644 index 000000000..f747838e1 --- /dev/null +++ b/tests/datastore_redis/test_generators.py @@ -0,0 +1,258 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import redis +from testing_support.db_settings import redis_settings +from testing_support.fixtures import override_application_settings +from testing_support.util import instance_hostname +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task +from newrelic.api.datastore_trace import DatastoreTrace +from newrelic.api.time_trace import current_trace +from newrelic.common.package_version_utils import get_package_version_tuple + +DB_SETTINGS = redis_settings()[0] +REDIS_PY_VERSION = get_package_version_tuple("redis") + +# Settings + +_enable_instance_settings = { + "datastore_tracer.instance_reporting.enabled": True, +} +_disable_instance_settings = { + "datastore_tracer.instance_reporting.enabled": False, +} + +# Metrics + +_base_scoped_metrics = ( + ("Datastore/operation/Redis/scan_iter", 1), + ("Datastore/operation/Redis/sscan_iter", 1), + ("Datastore/operation/Redis/zscan_iter", 1), + ("Datastore/operation/Redis/hscan_iter", 1), + ("Datastore/operation/Redis/set", 1), + ("Datastore/operation/Redis/sadd", 1), + ("Datastore/operation/Redis/zadd", 1), + ("Datastore/operation/Redis/hset", 1), +) + +_base_rollup_metrics = ( + ("Datastore/all", 8), + ("Datastore/allOther", 8), + ("Datastore/Redis/all", 8), + ("Datastore/Redis/allOther", 8), + ("Datastore/operation/Redis/scan_iter", 1), + ("Datastore/operation/Redis/sscan_iter", 1), + ("Datastore/operation/Redis/zscan_iter", 1), + ("Datastore/operation/Redis/hscan_iter", 1), + ("Datastore/operation/Redis/set", 1), + ("Datastore/operation/Redis/sadd", 1), + ("Datastore/operation/Redis/zadd", 1), + ("Datastore/operation/Redis/hset", 1), +) + +_disable_rollup_metrics = list(_base_rollup_metrics) +_enable_rollup_metrics = list(_base_rollup_metrics) + +_host = instance_hostname(DB_SETTINGS["host"]) +_port = DB_SETTINGS["port"] + +_instance_metric_name = "Datastore/instance/Redis/%s/%s" % (_host, _port) + +_enable_rollup_metrics.append((_instance_metric_name, 8)) + +_disable_rollup_metrics.append((_instance_metric_name, None)) + +# Operations + + +def exercise_redis(client): + """ + Exercise client generators by iterating on various methods and ensuring they are + non-empty, and that traces are started and stopped with the generator. 
+ """ + + # Set existing values + client.set("scan-key", "value") + client.sadd("sscan-key", "value") + client.zadd("zscan-key", {"value": 1}) + client.hset("hscan-key", "field", "value") + + # Check generators + flag = False + assert not isinstance(current_trace(), DatastoreTrace) # Assert no active DatastoreTrace + for k in client.scan_iter("scan-*"): + assert k == b"scan-key" + assert isinstance(current_trace(), DatastoreTrace) # Assert DatastoreTrace now active + flag = True + assert flag + + flag = False + assert not isinstance(current_trace(), DatastoreTrace) # Assert no active DatastoreTrace + for k in client.sscan_iter("sscan-key"): + assert k == b"value" + assert isinstance(current_trace(), DatastoreTrace) # Assert DatastoreTrace now active + flag = True + assert flag + + flag = False + assert not isinstance(current_trace(), DatastoreTrace) # Assert no active DatastoreTrace + for k, _ in client.zscan_iter("zscan-key"): + assert k == b"value" + assert isinstance(current_trace(), DatastoreTrace) # Assert DatastoreTrace now active + flag = True + assert flag + + flag = False + assert not isinstance(current_trace(), DatastoreTrace) # Assert no active DatastoreTrace + for f, v in client.hscan_iter("hscan-key"): + assert f == b"field" + assert v == b"value" + assert isinstance(current_trace(), DatastoreTrace) # Assert DatastoreTrace now active + flag = True + assert flag + + +async def exercise_redis_async(client): + """ + Exercise client generators by iterating on various methods and ensuring they are + non-empty, and that traces are started and stopped with the generator. + """ + + # Set existing values + await client.set("scan-key", "value") + await client.sadd("sscan-key", "value") + await client.zadd("zscan-key", {"value": 1}) + await client.hset("hscan-key", "field", "value") + + # Check generators + flag = False + assert not isinstance(current_trace(), DatastoreTrace) # Assert no active DatastoreTrace + async for k in client.scan_iter("scan-*"): + assert k == b"scan-key" + assert isinstance(current_trace(), DatastoreTrace) # Assert DatastoreTrace now active + flag = True + assert flag + + flag = False + assert not isinstance(current_trace(), DatastoreTrace) # Assert no active DatastoreTrace + async for k in client.sscan_iter("sscan-key"): + assert k == b"value" + assert isinstance(current_trace(), DatastoreTrace) # Assert DatastoreTrace now active + flag = True + assert flag + + flag = False + assert not isinstance(current_trace(), DatastoreTrace) # Assert no active DatastoreTrace + async for k, _ in client.zscan_iter("zscan-key"): + assert k == b"value" + assert isinstance(current_trace(), DatastoreTrace) # Assert DatastoreTrace now active + flag = True + assert flag + + flag = False + assert not isinstance(current_trace(), DatastoreTrace) # Assert no active DatastoreTrace + async for f, v in client.hscan_iter("hscan-key"): + assert f == b"field" + assert v == b"value" + assert isinstance(current_trace(), DatastoreTrace) # Assert DatastoreTrace now active + flag = True + assert flag + + +# Tests + + +@override_application_settings(_enable_instance_settings) +@validate_transaction_metrics( + "test_generators:test_strict_redis_generator_enable_instance", + scoped_metrics=_base_scoped_metrics, + rollup_metrics=_enable_rollup_metrics, + background_task=True, +) +@background_task() +def test_strict_redis_generator_enable_instance(): + client = redis.StrictRedis(host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], db=0) + exercise_redis(client) + + 
+@override_application_settings(_disable_instance_settings) +@validate_transaction_metrics( + "test_generators:test_strict_redis_generator_disable_instance", + scoped_metrics=_base_scoped_metrics, + rollup_metrics=_disable_rollup_metrics, + background_task=True, +) +@background_task() +def test_strict_redis_generator_disable_instance(): + client = redis.StrictRedis(host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], db=0) + exercise_redis(client) + + +@override_application_settings(_enable_instance_settings) +@validate_transaction_metrics( + "test_generators:test_redis_generator_enable_instance", + scoped_metrics=_base_scoped_metrics, + rollup_metrics=_enable_rollup_metrics, + background_task=True, +) +@background_task() +def test_redis_generator_enable_instance(): + client = redis.Redis(host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], db=0) + exercise_redis(client) + + +@override_application_settings(_disable_instance_settings) +@validate_transaction_metrics( + "test_generators:test_redis_generator_disable_instance", + scoped_metrics=_base_scoped_metrics, + rollup_metrics=_disable_rollup_metrics, + background_task=True, +) +@background_task() +def test_redis_generator_disable_instance(): + client = redis.Redis(host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], db=0) + exercise_redis(client) + + +@pytest.mark.skipif(REDIS_PY_VERSION < (4, 2), reason="Redis.asyncio was not added until v4.2") +@override_application_settings(_enable_instance_settings) +@validate_transaction_metrics( + "test_generators:test_redis_async_generator_enable_instance", + scoped_metrics=_base_scoped_metrics, + rollup_metrics=_enable_rollup_metrics, + background_task=True, +) +@background_task() +def test_redis_async_generator_enable_instance(loop): + client = redis.asyncio.Redis(host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], db=0) + loop.run_until_complete(exercise_redis_async(client)) + + +@pytest.mark.skipif(REDIS_PY_VERSION < (4, 2), reason="Redis.asyncio was not added until v4.2") +@override_application_settings(_disable_instance_settings) +@validate_transaction_metrics( + "test_generators:test_redis_async_generator_disable_instance", + scoped_metrics=_base_scoped_metrics, + rollup_metrics=_disable_rollup_metrics, + background_task=True, +) +@background_task() +def test_redis_async_generator_disable_instance(loop): + client = redis.asyncio.Redis(host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], db=0) + loop.run_until_complete(exercise_redis_async(client)) diff --git a/tests/datastore_redis/test_get_and_set.py b/tests/datastore_redis/test_get_and_set.py index 0e2df4bb1..720433ae3 100644 --- a/tests/datastore_redis/test_get_and_set.py +++ b/tests/datastore_redis/test_get_and_set.py @@ -48,10 +48,7 @@ ('Datastore/operation/Redis/set', 1), ) -_disable_scoped_metrics = list(_base_scoped_metrics) _disable_rollup_metrics = list(_base_rollup_metrics) - -_enable_scoped_metrics = list(_base_scoped_metrics) _enable_rollup_metrics = list(_base_rollup_metrics) _host = instance_hostname(DB_SETTINGS['host']) @@ -78,7 +75,7 @@ def exercise_redis(client): @override_application_settings(_enable_instance_settings) @validate_transaction_metrics( 'test_get_and_set:test_strict_redis_operation_enable_instance', - scoped_metrics=_enable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_enable_rollup_metrics, background_task=True) @background_task() @@ -90,7 +87,7 @@ def test_strict_redis_operation_enable_instance(): @override_application_settings(_disable_instance_settings) @validate_transaction_metrics( 
'test_get_and_set:test_strict_redis_operation_disable_instance', - scoped_metrics=_disable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_disable_rollup_metrics, background_task=True) @background_task() @@ -102,7 +99,7 @@ def test_strict_redis_operation_disable_instance(): @override_application_settings(_enable_instance_settings) @validate_transaction_metrics( 'test_get_and_set:test_redis_operation_enable_instance', - scoped_metrics=_enable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_enable_rollup_metrics, background_task=True) @background_task() @@ -114,7 +111,7 @@ def test_redis_operation_enable_instance(): @override_application_settings(_disable_instance_settings) @validate_transaction_metrics( 'test_get_and_set:test_redis_operation_disable_instance', - scoped_metrics=_disable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_disable_rollup_metrics, background_task=True) @background_task() diff --git a/tests/datastore_redis/test_instance_info.py b/tests/datastore_redis/test_instance_info.py index b3e9a0d5d..211e96169 100644 --- a/tests/datastore_redis/test_instance_info.py +++ b/tests/datastore_redis/test_instance_info.py @@ -15,9 +15,10 @@ import pytest import redis +from newrelic.common.package_version_utils import get_package_version_tuple from newrelic.hooks.datastore_redis import _conn_attrs_to_dict, _instance_info -REDIS_PY_VERSION = redis.VERSION +REDIS_PY_VERSION = get_package_version_tuple("redis") _instance_info_tests = [ ((), {}, ("localhost", "6379", "0")), diff --git a/tests/datastore_redis/test_multiple_dbs.py b/tests/datastore_redis/test_multiple_dbs.py index 15777cc38..9a5e299f0 100644 --- a/tests/datastore_redis/test_multiple_dbs.py +++ b/tests/datastore_redis/test_multiple_dbs.py @@ -16,6 +16,7 @@ import redis from newrelic.api.background_task import background_task +from newrelic.common.package_version_utils import get_package_version_tuple from testing_support.fixtures import override_application_settings from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics @@ -23,6 +24,8 @@ from testing_support.util import instance_hostname DB_MULTIPLE_SETTINGS = redis_settings() +REDIS_PY_VERSION = get_package_version_tuple("redis") + # Settings @@ -35,27 +38,31 @@ # Metrics -_base_scoped_metrics = ( +_base_scoped_metrics = [ ('Datastore/operation/Redis/get', 1), ('Datastore/operation/Redis/set', 1), ('Datastore/operation/Redis/client_list', 1), -) - -_base_rollup_metrics = ( - ('Datastore/all', 3), - ('Datastore/allOther', 3), - ('Datastore/Redis/all', 3), - ('Datastore/Redis/allOther', 3), +] +# client_setinfo was introduced in v5.0.0 and assigns info displayed in client_list output +if REDIS_PY_VERSION >= (5, 0): + _base_scoped_metrics.append(('Datastore/operation/Redis/client_setinfo', 2),) + +datastore_all_metric_count = 5 if REDIS_PY_VERSION >= (5, 0) else 3 + +_base_rollup_metrics = [ + ('Datastore/all', datastore_all_metric_count), + ('Datastore/allOther', datastore_all_metric_count), + ('Datastore/Redis/all', datastore_all_metric_count), + ('Datastore/Redis/allOther', datastore_all_metric_count), ('Datastore/operation/Redis/get', 1), ('Datastore/operation/Redis/set', 1), ('Datastore/operation/Redis/client_list', 1), -) +] -_disable_scoped_metrics = list(_base_scoped_metrics) -_disable_rollup_metrics = list(_base_rollup_metrics) +# client_setinfo was introduced in v5.0.0 and assigns info displayed in client_list output +if REDIS_PY_VERSION >= (5, 0): + 
_base_rollup_metrics.append(('Datastore/operation/Redis/client_setinfo', 2))

-_enable_scoped_metrics = list(_base_scoped_metrics)
-_enable_rollup_metrics = list(_base_rollup_metrics)
+# Copy the base rollup metrics so the instance metrics appended below do not
+# leak back into _base_rollup_metrics (list.extend mutates its target in place
+# and returns None, so the result of extend() must never be assigned), and so
+# both names are defined even when only one database is configured.
+_enable_rollup_metrics = list(_base_rollup_metrics)
+_disable_rollup_metrics = list(_base_rollup_metrics)
 
 if len(DB_MULTIPLE_SETTINGS) > 1:
     redis_1 = DB_MULTIPLE_SETTINGS[0]
@@ -70,16 +77,20 @@
     instance_metric_name_1 = 'Datastore/instance/Redis/%s/%s' % (host_1, port_1)
     instance_metric_name_2 = 'Datastore/instance/Redis/%s/%s' % (host_2, port_2)
 
-    _enable_rollup_metrics.extend([
-        (instance_metric_name_1, 2),
-        (instance_metric_name_2, 1),
+    # client_setinfo calls in redis>=5.0 change the expected per-instance counts
+    instance_metric_name_1_count = 2
+    instance_metric_name_2_count = 3 if REDIS_PY_VERSION >= (5, 0) else 1
+
+    _enable_rollup_metrics.extend([
+        (instance_metric_name_1, instance_metric_name_1_count),
+        (instance_metric_name_2, instance_metric_name_2_count),
     ])
 
     _disable_rollup_metrics.extend([
         (instance_metric_name_1, None),
         (instance_metric_name_2, None),
     ])
 
+
 def exercise_redis(client_1, client_2):
     client_1.set('key', 'value')
     client_1.get('key')
@@ -90,7 +101,7 @@ def exercise_redis(client_1, client_2):
         reason='Test environment not configured with multiple databases.')
 @override_application_settings(_enable_instance_settings)
 @validate_transaction_metrics('test_multiple_dbs:test_multiple_datastores_enabled',
-        scoped_metrics=_enable_scoped_metrics,
+        scoped_metrics=_base_scoped_metrics,
         rollup_metrics=_enable_rollup_metrics,
         background_task=True)
 @background_task()
@@ -106,7 +117,7 @@ def test_multiple_datastores_enabled():
         reason='Test environment not configured with multiple databases.')
 @override_application_settings(_disable_instance_settings)
 @validate_transaction_metrics('test_multiple_dbs:test_multiple_datastores_disabled',
-        scoped_metrics=_disable_scoped_metrics,
+        scoped_metrics=_base_scoped_metrics,
         rollup_metrics=_disable_rollup_metrics,
         background_task=True)
 @background_task()
diff --git a/tests/datastore_redis/test_rb.py b/tests/datastore_redis/test_rb.py
index 5678c2787..3b25593be 100644
--- a/tests/datastore_redis/test_rb.py
+++ b/tests/datastore_redis/test_rb.py
@@ -23,6 +23,7 @@
 import six
 
 from newrelic.api.background_task import background_task
+from newrelic.common.package_version_utils import get_package_version_tuple
 from testing_support.fixtures import override_application_settings
 from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics
@@ -30,7 +31,7 @@
 from testing_support.util import instance_hostname
 
 DB_SETTINGS = redis_settings()[0]
-REDIS_PY_VERSION = redis.VERSION
+REDIS_PY_VERSION = get_package_version_tuple("redis")
 
 # Settings
@@ -61,10 +62,7 @@
     ('Datastore/operation/Redis/set', 1),
 )
 
-_disable_scoped_metrics = list(_base_scoped_metrics)
 _disable_rollup_metrics = list(_base_rollup_metrics)
-
-_enable_scoped_metrics = list(_base_scoped_metrics)
 _enable_rollup_metrics = list(_base_rollup_metrics)
 
 _host = instance_hostname(DB_SETTINGS['host'])
@@ -80,25 +78,26 @@
     (_instance_metric_name, None)
 )
 
-# Operations
+# Operations
 
 def exercise_redis(routing_client):
     routing_client.set('key', 'value')
     routing_client.get('key')
 
+
 def exercise_fanout(cluster):
     with cluster.fanout(hosts='all') as client:
         client.execute_command('CLIENT', 'LIST')
 
-# Tests
+# Tests
 
 @pytest.mark.skipif(six.PY3, reason='Redis Blaster is Python 2 only.')
 @pytest.mark.skipif(REDIS_PY_VERSION < (2, 10, 2),
     reason='Redis Blaster requires redis>=2.10.2')
 @override_application_settings(_enable_instance_settings)
@validate_transaction_metrics( 'test_rb:test_redis_blaster_operation_enable_instance', - scoped_metrics=_enable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_enable_rollup_metrics, background_task=True) @background_task() @@ -121,7 +120,7 @@ def test_redis_blaster_operation_enable_instance(): @override_application_settings(_disable_instance_settings) @validate_transaction_metrics( 'test_rb:test_redis_blaster_operation_disable_instance', - scoped_metrics=_disable_scoped_metrics, + scoped_metrics=_base_scoped_metrics, rollup_metrics=_disable_rollup_metrics, background_task=True) @background_task() diff --git a/tests/datastore_redis/test_uninstrumented_methods.py b/tests/datastore_redis/test_uninstrumented_methods.py index ccf5a096d..c0be684b2 100644 --- a/tests/datastore_redis/test_uninstrumented_methods.py +++ b/tests/datastore_redis/test_uninstrumented_methods.py @@ -39,6 +39,7 @@ "append_no_scale", "append_values_and_weights", "append_weights", + "auto_close_connection_pool", "batch_indexer", "BatchIndexer", "bulk", @@ -55,6 +56,7 @@ "edges", "execute_command", "flush", + "from_pool", "from_url", "get_connection_kwargs", "get_encoder", @@ -63,7 +65,6 @@ "get_property", "get_relation", "get_retry", - "hscan_iter", "index_name", "labels", "list_keys", diff --git a/tests/datastore_rediscluster/conftest.py b/tests/datastore_rediscluster/conftest.py new file mode 100644 index 000000000..fe53f1fe2 --- /dev/null +++ b/tests/datastore_rediscluster/conftest.py @@ -0,0 +1,32 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 + collector_agent_registration_fixture, + collector_available_fixture, +) + +_default_settings = { + "transaction_tracer.explain_threshold": 0.0, + "transaction_tracer.transaction_threshold": 0.0, + "transaction_tracer.stack_trace_threshold": 0.0, + "debug.log_data_collector_payloads": True, + "debug.record_transaction_failure": True, +} + +collector_agent_registration = collector_agent_registration_fixture( + app_name="Python Agent Test (datastore_redis)", + default_settings=_default_settings, + linked_applications=["Python Agent Test (datastore)"], +) diff --git a/tests/datastore_rediscluster/test_uninstrumented_rediscluster_methods.py b/tests/datastore_rediscluster/test_uninstrumented_rediscluster_methods.py new file mode 100644 index 000000000..ae211aa31 --- /dev/null +++ b/tests/datastore_rediscluster/test_uninstrumented_rediscluster_methods.py @@ -0,0 +1,168 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import redis +from testing_support.db_settings import redis_cluster_settings + +DB_CLUSTER_SETTINGS = redis_cluster_settings()[0] + +# Set socket_timeout to 5s for fast fail, otherwise the default is to wait forever. +client = redis.RedisCluster(host=DB_CLUSTER_SETTINGS["host"], port=DB_CLUSTER_SETTINGS["port"], socket_timeout=5) + +IGNORED_METHODS = { + "MODULE_CALLBACKS", + "MODULE_VERSION", + "NAME", + "add_edge", + "add_node", + "append_bucket_size", + "append_capacity", + "append_error", + "append_expansion", + "append_items_and_increments", + "append_items", + "append_max_iterations", + "append_no_create", + "append_no_scale", + "append_values_and_weights", + "append_weights", + "batch_indexer", + "BatchIndexer", + "bulk", + "call_procedure", + "client_tracking_off", + "client_tracking_on", + "client", + "close", + "commandmixin", + "connection_pool", + "connection", + "debug_segfault", + "edges", + "execute_command", + "flush", + "from_url", + "get_connection_kwargs", + "get_encoder", + "get_label", + "get_params_args", + "get_property", + "get_relation", + "get_retry", + "hscan_iter", + "index_name", + "labels", + "list_keys", + "load_document", + "load_external_module", + "lock", + "name", + "nodes", + "parse_response", + "pipeline", + "property_keys", + "register_script", + "relationship_types", + "response_callbacks", + "RESPONSE_CALLBACKS", + "sentinel", + "set_file", + "set_path", + "set_response_callback", + "set_retry", + "transaction", + "version", + "ALL_NODES", + "CLUSTER_COMMANDS_RESPONSE_CALLBACKS", + "COMMAND_FLAGS", + "DEFAULT_NODE", + "ERRORS_ALLOW_RETRY", + "NODE_FLAGS", + "PRIMARIES", + "RANDOM", + "REPLICAS", + "RESULT_CALLBACKS", + "RedisClusterRequestTTL", + "SEARCH_COMMANDS", + "client_no_touch", + "cluster_addslotsrange", + "cluster_bumpepoch", + "cluster_delslotsrange", + "cluster_error_retry_attempts", + "cluster_flushslots", + "cluster_links", + "cluster_myid", + "cluster_myshardid", + "cluster_replicas", + "cluster_response_callbacks", + "cluster_setslot_stable", + "cluster_shards", + "command_flags", + "commands_parser", + "determine_slot", + "disconnect_connection_pools", + "encoder", + "get_default_node", + "get_node", + "get_node_from_key", + "get_nodes", + "get_primaries", + "get_random_node", + "get_redis_connection", + "get_replicas", + "keyslot", + "mget_nonatomic", + "monitor", + "mset_nonatomic", + "node_flags", + "nodes_manager", + "on_connect", + "pubsub", + "read_from_replicas", + "reinitialize_counter", + "reinitialize_steps", + "replace_default_node", + "result_callbacks", + "set_default_node", + "user_on_connect_func", +} + +REDIS_MODULES = { + "bf", + "cf", + "cms", + "ft", + "graph", + "json", + "tdigest", + "topk", + "ts", +} + +IGNORED_METHODS |= REDIS_MODULES + + +def test_uninstrumented_methods(): + methods = {m for m in dir(client) if not m[0] == "_"} + is_wrapped = lambda m: hasattr(getattr(client, m), "__wrapped__") + uninstrumented = {m for m in methods - IGNORED_METHODS if not is_wrapped(m)} + + for module in REDIS_MODULES: + if hasattr(client, module): + module_client = getattr(client, module)() + module_methods = {m for m in dir(module_client) if not m[0] == "_"} + is_wrapped = lambda m: hasattr(getattr(module_client, m), "__wrapped__") + uninstrumented |= {m for m in module_methods - IGNORED_METHODS if not is_wrapped(m)} + + assert not uninstrumented, "Uninstrumented methods: %s" % sorted(uninstrumented) diff --git 
a/tests/datastore_solrpy/test_solr.py b/tests/datastore_solrpy/test_solr.py index ee1a7e91e..56dcce62b 100644 --- a/tests/datastore_solrpy/test_solr.py +++ b/tests/datastore_solrpy/test_solr.py @@ -13,16 +13,19 @@ # limitations under the License. from solr import SolrConnection - -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics from testing_support.db_settings import solr_settings +from testing_support.util import instance_hostname +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.api.background_task import background_task DB_SETTINGS = solr_settings()[0] SOLR_HOST = DB_SETTINGS["host"] SOLR_PORT = DB_SETTINGS["port"] -SOLR_URL = 'http://%s:%s/solr/collection' % (DB_SETTINGS["host"], DB_SETTINGS["port"]) +SOLR_URL = "http://%s:%s/solr/collection" % (DB_SETTINGS["host"], DB_SETTINGS["port"]) + def _exercise_solr(solr): # Construct document names within namespace @@ -31,30 +34,37 @@ def _exercise_solr(solr): solr.add_many([{"id": x} for x in documents]) solr.commit() - solr.query('id:%s' % documents[0]).results - solr.delete('id:*_%s' % DB_SETTINGS["namespace"]) + solr.query("id:%s" % documents[0]).results + solr.delete("id:*_%s" % DB_SETTINGS["namespace"]) solr.commit() + _test_solr_search_scoped_metrics = [ - ('Datastore/operation/Solr/add_many', 1), - ('Datastore/operation/Solr/delete', 1), - ('Datastore/operation/Solr/commit', 2), - ('Datastore/operation/Solr/query', 1)] + ("Datastore/operation/Solr/add_many", 1), + ("Datastore/operation/Solr/delete", 1), + ("Datastore/operation/Solr/commit", 2), + ("Datastore/operation/Solr/query", 1), +] _test_solr_search_rollup_metrics = [ - ('Datastore/all', 5), - ('Datastore/allOther', 5), - ('Datastore/Solr/all', 5), - ('Datastore/Solr/allOther', 5), - ('Datastore/operation/Solr/add_many', 1), - ('Datastore/operation/Solr/query', 1), - ('Datastore/operation/Solr/commit', 2), - ('Datastore/operation/Solr/delete', 1)] - -@validate_transaction_metrics('test_solr:test_solr_search', + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ("Datastore/Solr/all", 5), + ("Datastore/Solr/allOther", 5), + ("Datastore/instance/Solr/%s/%s" % (instance_hostname(SOLR_HOST), SOLR_PORT), 3), + ("Datastore/operation/Solr/add_many", 1), + ("Datastore/operation/Solr/query", 1), + ("Datastore/operation/Solr/commit", 2), + ("Datastore/operation/Solr/delete", 1), +] + + +@validate_transaction_metrics( + "test_solr:test_solr_search", scoped_metrics=_test_solr_search_scoped_metrics, rollup_metrics=_test_solr_search_rollup_metrics, - background_task=True) + background_task=True, +) @background_task() def test_solr_search(): s = SolrConnection(SOLR_URL) diff --git a/tests/external_boto3/test_boto3_iam.py b/tests/external_boto3/test_boto3_iam.py deleted file mode 100644 index ac49214f4..000000000 --- a/tests/external_boto3/test_boto3_iam.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import sys -import uuid - -import boto3 -import moto - -from newrelic.api.background_task import background_task -from testing_support.fixtures import ( - validate_tt_segment_params, override_application_settings) -from testing_support.validators.validate_span_events import ( - validate_span_events) -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics - -MOTO_VERSION = tuple(int(v) for v in moto.__version__.split('.')[:3]) - -# patch earlier versions of moto to support py37 -if sys.version_info >= (3, 7) and MOTO_VERSION <= (1, 3, 1): - import re - moto.packages.responses.responses.re._pattern_type = re.Pattern - -AWS_ACCESS_KEY_ID = 'AAAAAAAAAAAACCESSKEY' -AWS_SECRET_ACCESS_KEY = 'AAAAAASECRETKEY' - -TEST_USER = 'python-agent-test-%s' % uuid.uuid4() - -_iam_scoped_metrics = [ - ('External/iam.amazonaws.com/botocore/POST', 3), -] - -_iam_rollup_metrics = [ - ('External/all', 3), - ('External/allOther', 3), - ('External/iam.amazonaws.com/all', 3), - ('External/iam.amazonaws.com/botocore/POST', 3), -] - - -@override_application_settings({'distributed_tracing.enabled': True}) -@validate_span_events( - exact_agents={'http.url': 'https://iam.amazonaws.com/'}, count=3) -@validate_span_events(expected_agents=('aws.requestId',), count=3) -@validate_span_events(exact_agents={'aws.operation': 'CreateUser'}, count=1) -@validate_span_events(exact_agents={'aws.operation': 'GetUser'}, count=1) -@validate_span_events(exact_agents={'aws.operation': 'DeleteUser'}, count=1) -@validate_tt_segment_params(present_params=('aws.requestId',)) -@validate_transaction_metrics( - 'test_boto3_iam:test_iam', - scoped_metrics=_iam_scoped_metrics, - rollup_metrics=_iam_rollup_metrics, - background_task=True) -@background_task() -@moto.mock_iam -def test_iam(): - iam = boto3.client( - 'iam', - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - ) - - # Create user - resp = iam.create_user(UserName=TEST_USER) - assert resp['ResponseMetadata']['HTTPStatusCode'] == 200 - - # Get the user - resp = iam.get_user(UserName=TEST_USER) - assert resp['ResponseMetadata']['HTTPStatusCode'] == 200 - assert resp['User']['UserName'] == TEST_USER - - # Delete the user - resp = iam.delete_user(UserName=TEST_USER) - assert resp['ResponseMetadata']['HTTPStatusCode'] == 200 diff --git a/tests/external_boto3/test_boto3_sns.py b/tests/external_boto3/test_boto3_sns.py deleted file mode 100644 index 3718d5292..000000000 --- a/tests/external_boto3/test_boto3_sns.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import sys -import boto3 -import moto -import pytest - -from newrelic.api.background_task import background_task -from testing_support.fixtures import ( - validate_tt_segment_params, override_application_settings) -from testing_support.validators.validate_span_events import validate_span_events -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics - -MOTO_VERSION = tuple(int(v) for v in moto.__version__.split('.')[:3]) - -# patch earlier versions of moto to support py37 -if sys.version_info >= (3, 7) and MOTO_VERSION <= (1, 3, 1): - import re - moto.packages.responses.responses.re._pattern_type = re.Pattern - -AWS_ACCESS_KEY_ID = 'AAAAAAAAAAAACCESSKEY' -AWS_SECRET_ACCESS_KEY = 'AAAAAASECRETKEY' -AWS_REGION_NAME = 'us-east-1' -SNS_URL = 'sns-us-east-1.amazonaws.com' -TOPIC = 'arn:aws:sns:us-east-1:123456789012:some-topic' -sns_metrics = [ - ('MessageBroker/SNS/Topic' - '/Produce/Named/%s' % TOPIC, 1)] -sns_metrics_phone = [ - ('MessageBroker/SNS/Topic' - '/Produce/Named/PhoneNumber', 1)] - - -@override_application_settings({'distributed_tracing.enabled': True}) -@validate_span_events(expected_agents=('aws.requestId',), count=2) -@validate_span_events(exact_agents={'aws.operation': 'CreateTopic'}, count=1) -@validate_span_events(exact_agents={'aws.operation': 'Publish'}, count=1) -@validate_tt_segment_params(present_params=('aws.requestId',)) -@pytest.mark.parametrize('topic_argument', ('TopicArn', 'TargetArn')) -@validate_transaction_metrics('test_boto3_sns:test_publish_to_sns_topic', - scoped_metrics=sns_metrics, rollup_metrics=sns_metrics, - background_task=True) -@background_task() -@moto.mock_sns -def test_publish_to_sns_topic(topic_argument): - conn = boto3.client('sns', - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - region_name=AWS_REGION_NAME) - - topic_arn = conn.create_topic(Name='some-topic')['TopicArn'] - - kwargs = {topic_argument: topic_arn} - published_message = conn.publish(Message='my msg', **kwargs) - assert 'MessageId' in published_message - - -@override_application_settings({'distributed_tracing.enabled': True}) -@validate_span_events(expected_agents=('aws.requestId',), count=3) -@validate_span_events(exact_agents={'aws.operation': 'CreateTopic'}, count=1) -@validate_span_events(exact_agents={'aws.operation': 'Subscribe'}, count=1) -@validate_span_events(exact_agents={'aws.operation': 'Publish'}, count=1) -@validate_tt_segment_params(present_params=('aws.requestId',)) -@validate_transaction_metrics('test_boto3_sns:test_publish_to_sns_phone', - scoped_metrics=sns_metrics_phone, rollup_metrics=sns_metrics_phone, - background_task=True) -@background_task() -@moto.mock_sns -def test_publish_to_sns_phone(): - conn = boto3.client('sns', - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - region_name=AWS_REGION_NAME) - - topic_arn = conn.create_topic(Name='some-topic')['TopicArn'] - conn.subscribe(TopicArn=topic_arn, Protocol='sms', Endpoint='5555555555') - - published_message = conn.publish( - PhoneNumber='5555555555', Message='my msg') - assert 'MessageId' in published_message diff --git a/tests/external_botocore/conftest.py b/tests/external_botocore/conftest.py index e5cf15533..fb703c85e 100644 --- a/tests/external_botocore/conftest.py +++ b/tests/external_botocore/conftest.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest - -from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture # noqa: F401; pylint: disable=W0611 - +from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 + collector_agent_registration_fixture, + collector_available_fixture, +) _default_settings = { - 'transaction_tracer.explain_threshold': 0.0, - 'transaction_tracer.transaction_threshold': 0.0, - 'transaction_tracer.stack_trace_threshold': 0.0, - 'debug.log_data_collector_payloads': True, - 'debug.record_transaction_failure': True, + "transaction_tracer.explain_threshold": 0.0, + "transaction_tracer.transaction_threshold": 0.0, + "transaction_tracer.stack_trace_threshold": 0.0, + "debug.log_data_collector_payloads": True, + "debug.record_transaction_failure": True, } collector_agent_registration = collector_agent_registration_fixture( - app_name='Python Agent Test (external_botocore)', - default_settings=_default_settings) + app_name="Python Agent Test (external_botocore)", default_settings=_default_settings +) diff --git a/tests/external_botocore/test_boto3_iam.py b/tests/external_botocore/test_boto3_iam.py new file mode 100644 index 000000000..3d672f375 --- /dev/null +++ b/tests/external_botocore/test_boto3_iam.py @@ -0,0 +1,89 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
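+
+# These tests drive boto3 IAM calls against moto's mocked endpoints and assert
+# that the botocore instrumentation records External metrics for
+# iam.amazonaws.com plus aws.operation / aws.requestId attributes on each span.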
+ +import sys +import uuid + +import boto3 +import moto +from testing_support.fixtures import dt_enabled +from testing_support.validators.validate_span_events import validate_span_events +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_segment_params import ( + validate_tt_segment_params, +) + +from newrelic.api.background_task import background_task + +MOTO_VERSION = tuple(int(v) for v in moto.__version__.split(".")[:3]) + +# patch earlier versions of moto to support py37 +if sys.version_info >= (3, 7) and MOTO_VERSION <= (1, 3, 1): + import re + + moto.packages.responses.responses.re._pattern_type = re.Pattern + +AWS_ACCESS_KEY_ID = "AAAAAAAAAAAACCESSKEY" +AWS_SECRET_ACCESS_KEY = "AAAAAASECRETKEY" # nosec (This is fine for testing purposes) + +TEST_USER = "python-agent-test-%s" % uuid.uuid4() + +_iam_scoped_metrics = [ + ("External/iam.amazonaws.com/botocore/POST", 3), +] + +_iam_rollup_metrics = [ + ("External/all", 3), + ("External/allOther", 3), + ("External/iam.amazonaws.com/all", 3), + ("External/iam.amazonaws.com/botocore/POST", 3), +] + + +@dt_enabled +@validate_span_events(exact_agents={"http.url": "https://iam.amazonaws.com/"}, count=3) +@validate_span_events(expected_agents=("aws.requestId",), count=3) +@validate_span_events(exact_agents={"aws.operation": "CreateUser"}, count=1) +@validate_span_events(exact_agents={"aws.operation": "GetUser"}, count=1) +@validate_span_events(exact_agents={"aws.operation": "DeleteUser"}, count=1) +@validate_tt_segment_params(present_params=("aws.requestId",)) +@validate_transaction_metrics( + "test_boto3_iam:test_iam", + scoped_metrics=_iam_scoped_metrics, + rollup_metrics=_iam_rollup_metrics, + background_task=True, +) +@background_task() +@moto.mock_iam +def test_iam(): + iam = boto3.client( + "iam", + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY, + ) + + # Create user + resp = iam.create_user(UserName=TEST_USER) + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + + # Get the user + resp = iam.get_user(UserName=TEST_USER) + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert resp["User"]["UserName"] == TEST_USER + + # Delete the user + resp = iam.delete_user(UserName=TEST_USER) + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 diff --git a/tests/external_boto3/test_boto3_s3.py b/tests/external_botocore/test_boto3_s3.py similarity index 97% rename from tests/external_boto3/test_boto3_s3.py rename to tests/external_botocore/test_boto3_s3.py index a7ecf034a..b6299d9f6 100644 --- a/tests/external_boto3/test_boto3_s3.py +++ b/tests/external_botocore/test_boto3_s3.py @@ -18,7 +18,7 @@ import boto3 import botocore import moto -from testing_support.fixtures import override_application_settings +from testing_support.fixtures import dt_enabled from testing_support.validators.validate_span_events import validate_span_events from testing_support.validators.validate_transaction_metrics import ( validate_transaction_metrics, @@ -73,7 +73,7 @@ ] -@override_application_settings({"distributed_tracing.enabled": True}) +@dt_enabled @validate_span_events(exact_agents={"aws.operation": "CreateBucket"}, count=1) @validate_span_events(exact_agents={"aws.operation": "PutObject"}, count=1) @validate_span_events(exact_agents={"aws.operation": "ListObjects"}, count=1) diff --git a/tests/external_botocore/test_boto3_sns.py b/tests/external_botocore/test_boto3_sns.py new file mode 100644 index 
000000000..5e6c7c4b4 --- /dev/null +++ b/tests/external_botocore/test_boto3_sns.py @@ -0,0 +1,103 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import boto3 +import moto +import pytest +from testing_support.fixtures import dt_enabled +from testing_support.validators.validate_span_events import validate_span_events +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_segment_params import ( + validate_tt_segment_params, +) + +from newrelic.api.background_task import background_task + +MOTO_VERSION = tuple(int(v) for v in moto.__version__.split(".")[:3]) + +# patch earlier versions of moto to support py37 +if sys.version_info >= (3, 7) and MOTO_VERSION <= (1, 3, 1): + import re + + moto.packages.responses.responses.re._pattern_type = re.Pattern + +AWS_ACCESS_KEY_ID = "AAAAAAAAAAAACCESSKEY" +AWS_SECRET_ACCESS_KEY = "AAAAAASECRETKEY" # nosec (This is fine for testing purposes) +AWS_REGION_NAME = "us-east-1" +SNS_URL = "sns-us-east-1.amazonaws.com" +TOPIC = "arn:aws:sns:us-east-1:123456789012:some-topic" +sns_metrics = [("MessageBroker/SNS/Topic" "/Produce/Named/%s" % TOPIC, 1)] +sns_metrics_phone = [("MessageBroker/SNS/Topic" "/Produce/Named/PhoneNumber", 1)] + + +@dt_enabled +@validate_span_events(expected_agents=("aws.requestId",), count=2) +@validate_span_events(exact_agents={"aws.operation": "CreateTopic"}, count=1) +@validate_span_events(exact_agents={"aws.operation": "Publish"}, count=1) +@validate_tt_segment_params(present_params=("aws.requestId",)) +@pytest.mark.parametrize("topic_argument", ("TopicArn", "TargetArn")) +@validate_transaction_metrics( + "test_boto3_sns:test_publish_to_sns_topic", + scoped_metrics=sns_metrics, + rollup_metrics=sns_metrics, + background_task=True, +) +@background_task() +@moto.mock_sns +def test_publish_to_sns_topic(topic_argument): + conn = boto3.client( + "sns", + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY, + region_name=AWS_REGION_NAME, + ) + + topic_arn = conn.create_topic(Name="some-topic")["TopicArn"] + + kwargs = {topic_argument: topic_arn} + published_message = conn.publish(Message="my msg", **kwargs) + assert "MessageId" in published_message + + +@dt_enabled +@validate_span_events(expected_agents=("aws.requestId",), count=3) +@validate_span_events(exact_agents={"aws.operation": "CreateTopic"}, count=1) +@validate_span_events(exact_agents={"aws.operation": "Subscribe"}, count=1) +@validate_span_events(exact_agents={"aws.operation": "Publish"}, count=1) +@validate_tt_segment_params(present_params=("aws.requestId",)) +@validate_transaction_metrics( + "test_boto3_sns:test_publish_to_sns_phone", + scoped_metrics=sns_metrics_phone, + rollup_metrics=sns_metrics_phone, + background_task=True, +) +@background_task() +@moto.mock_sns +def test_publish_to_sns_phone(): + conn = boto3.client( + "sns", + aws_access_key_id=AWS_ACCESS_KEY_ID, + 
aws_secret_access_key=AWS_SECRET_ACCESS_KEY, + region_name=AWS_REGION_NAME, + ) + + topic_arn = conn.create_topic(Name="some-topic")["TopicArn"] + conn.subscribe(TopicArn=topic_arn, Protocol="sms", Endpoint="5555555555") + + published_message = conn.publish(PhoneNumber="5555555555", Message="my msg") + assert "MessageId" in published_message diff --git a/tests/external_botocore/test_botocore_dynamodb.py b/tests/external_botocore/test_botocore_dynamodb.py index 44862d827..932fb1743 100644 --- a/tests/external_botocore/test_botocore_dynamodb.py +++ b/tests/external_botocore/test_botocore_dynamodb.py @@ -17,91 +17,96 @@ import botocore.session import moto +from testing_support.fixtures import dt_enabled +from testing_support.validators.validate_span_events import validate_span_events +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_segment_params import ( + validate_tt_segment_params, +) from newrelic.api.background_task import background_task -from testing_support.fixtures import ( - validate_tt_segment_params, override_application_settings) -from testing_support.validators.validate_span_events import ( - validate_span_events) -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -MOTO_VERSION = tuple(int(v) for v in moto.__version__.split('.')[:3]) +MOTO_VERSION = tuple(int(v) for v in moto.__version__.split(".")[:3]) # patch earlier versions of moto to support py37 if sys.version_info >= (3, 7) and MOTO_VERSION <= (1, 3, 1): import re + moto.packages.responses.responses.re._pattern_type = re.Pattern -AWS_ACCESS_KEY_ID = 'AAAAAAAAAAAACCESSKEY' -AWS_SECRET_ACCESS_KEY = 'AAAAAASECRETKEY' -AWS_REGION = 'us-east-1' +AWS_ACCESS_KEY_ID = "AAAAAAAAAAAACCESSKEY" +AWS_SECRET_ACCESS_KEY = "AAAAAASECRETKEY" # nosec (This is fine for testing purposes) +AWS_REGION = "us-east-1" -TEST_TABLE = 'python-agent-test-%s' % uuid.uuid4() +TEST_TABLE = "python-agent-test-%s" % uuid.uuid4() _dynamodb_scoped_metrics = [ - ('Datastore/statement/DynamoDB/%s/create_table' % TEST_TABLE, 1), - ('Datastore/statement/DynamoDB/%s/put_item' % TEST_TABLE, 1), - ('Datastore/statement/DynamoDB/%s/get_item' % TEST_TABLE, 1), - ('Datastore/statement/DynamoDB/%s/update_item' % TEST_TABLE, 1), - ('Datastore/statement/DynamoDB/%s/query' % TEST_TABLE, 1), - ('Datastore/statement/DynamoDB/%s/scan' % TEST_TABLE, 1), - ('Datastore/statement/DynamoDB/%s/delete_item' % TEST_TABLE, 1), - ('Datastore/statement/DynamoDB/%s/delete_table' % TEST_TABLE, 1), + ("Datastore/statement/DynamoDB/%s/create_table" % TEST_TABLE, 1), + ("Datastore/statement/DynamoDB/%s/put_item" % TEST_TABLE, 1), + ("Datastore/statement/DynamoDB/%s/get_item" % TEST_TABLE, 1), + ("Datastore/statement/DynamoDB/%s/update_item" % TEST_TABLE, 1), + ("Datastore/statement/DynamoDB/%s/query" % TEST_TABLE, 1), + ("Datastore/statement/DynamoDB/%s/scan" % TEST_TABLE, 1), + ("Datastore/statement/DynamoDB/%s/delete_item" % TEST_TABLE, 1), + ("Datastore/statement/DynamoDB/%s/delete_table" % TEST_TABLE, 1), ] _dynamodb_rollup_metrics = [ - ('Datastore/all', 8), - ('Datastore/allOther', 8), - ('Datastore/DynamoDB/all', 8), - ('Datastore/DynamoDB/allOther', 8), + ("Datastore/all", 8), + ("Datastore/allOther", 8), + ("Datastore/DynamoDB/all", 8), + ("Datastore/DynamoDB/allOther", 8), ] -@override_application_settings({'distributed_tracing.enabled': True}) -@validate_span_events(expected_agents=('aws.requestId',), count=8) 
-@validate_span_events(exact_agents={'aws.operation': 'PutItem'}, count=1) -@validate_span_events(exact_agents={'aws.operation': 'GetItem'}, count=1) -@validate_span_events(exact_agents={'aws.operation': 'DeleteItem'}, count=1) -@validate_span_events(exact_agents={'aws.operation': 'CreateTable'}, count=1) -@validate_span_events(exact_agents={'aws.operation': 'DeleteTable'}, count=1) -@validate_span_events(exact_agents={'aws.operation': 'Query'}, count=1) -@validate_span_events(exact_agents={'aws.operation': 'Scan'}, count=1) -@validate_tt_segment_params(present_params=('aws.requestId',)) +@dt_enabled +@validate_span_events(expected_agents=("aws.requestId",), count=8) +@validate_span_events(exact_agents={"aws.operation": "PutItem"}, count=1) +@validate_span_events(exact_agents={"aws.operation": "GetItem"}, count=1) +@validate_span_events(exact_agents={"aws.operation": "DeleteItem"}, count=1) +@validate_span_events(exact_agents={"aws.operation": "CreateTable"}, count=1) +@validate_span_events(exact_agents={"aws.operation": "DeleteTable"}, count=1) +@validate_span_events(exact_agents={"aws.operation": "Query"}, count=1) +@validate_span_events(exact_agents={"aws.operation": "Scan"}, count=1) +@validate_tt_segment_params(present_params=("aws.requestId",)) @validate_transaction_metrics( - 'test_botocore_dynamodb:test_dynamodb', - scoped_metrics=_dynamodb_scoped_metrics, - rollup_metrics=_dynamodb_rollup_metrics, - background_task=True) + "test_botocore_dynamodb:test_dynamodb", + scoped_metrics=_dynamodb_scoped_metrics, + rollup_metrics=_dynamodb_rollup_metrics, + background_task=True, +) @background_task() -@moto.mock_dynamodb2 +@moto.mock_dynamodb def test_dynamodb(): session = botocore.session.get_session() client = session.create_client( - 'dynamodb', - region_name=AWS_REGION, - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY + "dynamodb", + region_name=AWS_REGION, + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY, ) # Create table resp = client.create_table( - TableName=TEST_TABLE, - AttributeDefinitions=[ - {'AttributeName': 'Id', 'AttributeType': 'N'}, - {'AttributeName': 'Foo', 'AttributeType': 'S'}, - ], - KeySchema=[ - {'AttributeName': 'Id', 'KeyType': 'HASH'}, - {'AttributeName': 'Foo', 'KeyType': 'RANGE'}, - ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5, - }, + TableName=TEST_TABLE, + AttributeDefinitions=[ + {"AttributeName": "Id", "AttributeType": "N"}, + {"AttributeName": "Foo", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "Id", "KeyType": "HASH"}, + {"AttributeName": "Foo", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={ + "ReadCapacityUnits": 5, + "WriteCapacityUnits": 5, + }, ) - assert resp['TableDescription']['TableName'] == TEST_TABLE + assert resp["TableDescription"]["TableName"] == TEST_TABLE # moto response is ACTIVE, AWS response is CREATING # assert resp['TableDescription']['TableStatus'] == 'ACTIVE' @@ -111,73 +116,70 @@ def test_dynamodb(): # Put item resp = client.put_item( - TableName=TEST_TABLE, - Item={ - 'Id': {'N': '101'}, - 'Foo': {'S': 'hello_world'}, - 'SomeValue': {'S': 'some_random_attribute'}, - } + TableName=TEST_TABLE, + Item={ + "Id": {"N": "101"}, + "Foo": {"S": "hello_world"}, + "SomeValue": {"S": "some_random_attribute"}, + }, ) # No checking response, due to inconsistent return values. # moto returns resp['Attributes']. 
AWS returns resp['ResponseMetadata'] # Get item resp = client.get_item( - TableName=TEST_TABLE, - Key={ - 'Id': {'N': '101'}, - 'Foo': {'S': 'hello_world'}, - 'SomeValue': {'S': 'some_random_attribute'}, - } + TableName=TEST_TABLE, + Key={ + "Id": {"N": "101"}, + "Foo": {"S": "hello_world"}, + "SomeValue": {"S": "some_random_attribute"}, + }, ) - assert resp['Item']['SomeValue']['S'] == 'some_random_attribute' + assert resp["Item"]["SomeValue"]["S"] == "some_random_attribute" # Update item resp = client.update_item( - TableName=TEST_TABLE, - Key={ - 'Id': {'N': '101'}, - 'Foo': {'S': 'hello_world'}, - 'SomeValue': {'S': 'some_random_attribute'}, - }, - AttributeUpdates={ - 'Foo2': { - 'Value': {'S': 'hello_world2'}, - 'Action': 'PUT' - }, - }, - ReturnValues='ALL_NEW', + TableName=TEST_TABLE, + Key={ + "Id": {"N": "101"}, + "Foo": {"S": "hello_world"}, + "SomeValue": {"S": "some_random_attribute"}, + }, + AttributeUpdates={ + "Foo2": {"Value": {"S": "hello_world2"}, "Action": "PUT"}, + }, + ReturnValues="ALL_NEW", ) - assert resp['Attributes']['Foo2'] + assert resp["Attributes"]["Foo2"] # Query for item resp = client.query( - TableName=TEST_TABLE, - Select='ALL_ATTRIBUTES', - KeyConditionExpression='#Id = :v_id', - ExpressionAttributeNames={'#Id': 'Id'}, - ExpressionAttributeValues={':v_id': {'N': '101'}}, + TableName=TEST_TABLE, + Select="ALL_ATTRIBUTES", + KeyConditionExpression="#Id = :v_id", + ExpressionAttributeNames={"#Id": "Id"}, + ExpressionAttributeValues={":v_id": {"N": "101"}}, ) - assert len(resp['Items']) == 1 - assert resp['Items'][0]['SomeValue']['S'] == 'some_random_attribute' + assert len(resp["Items"]) == 1 + assert resp["Items"][0]["SomeValue"]["S"] == "some_random_attribute" # Scan resp = client.scan(TableName=TEST_TABLE) - assert len(resp['Items']) == 1 + assert len(resp["Items"]) == 1 # Delete item resp = client.delete_item( - TableName=TEST_TABLE, - Key={ - 'Id': {'N': '101'}, - 'Foo': {'S': 'hello_world'}, - }, + TableName=TEST_TABLE, + Key={ + "Id": {"N": "101"}, + "Foo": {"S": "hello_world"}, + }, ) # No checking response, due to inconsistent return values. # moto returns resp['Attributes']. 
AWS returns resp['ResponseMetadata'] # Delete table resp = client.delete_table(TableName=TEST_TABLE) - assert resp['TableDescription']['TableName'] == TEST_TABLE + assert resp["TableDescription"]["TableName"] == TEST_TABLE # moto response is ACTIVE, AWS response is DELETING # assert resp['TableDescription']['TableStatus'] == 'DELETING' diff --git a/tests/external_botocore/test_botocore_ec2.py b/tests/external_botocore/test_botocore_ec2.py index 0cfd09b6f..3cb83e318 100644 --- a/tests/external_botocore/test_botocore_ec2.py +++ b/tests/external_botocore/test_botocore_ec2.py @@ -17,81 +17,81 @@ import botocore.session import moto +from testing_support.fixtures import dt_enabled +from testing_support.validators.validate_span_events import validate_span_events +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_segment_params import ( + validate_tt_segment_params, +) from newrelic.api.background_task import background_task -from testing_support.fixtures import ( - validate_tt_segment_params, override_application_settings) -from testing_support.validators.validate_span_events import ( - validate_span_events) -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -MOTO_VERSION = tuple(int(v) for v in moto.__version__.split('.')[:3]) +MOTO_VERSION = tuple(int(v) for v in moto.__version__.split(".")[:3]) # patch earlier versions of moto to support py37 if sys.version_info >= (3, 7) and MOTO_VERSION <= (1, 3, 1): import re + moto.packages.responses.responses.re._pattern_type = re.Pattern -AWS_ACCESS_KEY_ID = 'AAAAAAAAAAAACCESSKEY' -AWS_SECRET_ACCESS_KEY = 'AAAAAASECRETKEY' -AWS_REGION = 'us-east-1' -UBUNTU_14_04_PARAVIRTUAL_AMI = 'ami-c65be9ae' +AWS_ACCESS_KEY_ID = "AAAAAAAAAAAACCESSKEY" +AWS_SECRET_ACCESS_KEY = "AAAAAASECRETKEY" # nosec (This is fine for testing purposes) +AWS_REGION = "us-east-1" +UBUNTU_14_04_PARAVIRTUAL_AMI = "ami-c65be9ae" -TEST_INSTANCE = 'python-agent-test-%s' % uuid.uuid4() +TEST_INSTANCE = "python-agent-test-%s" % uuid.uuid4() _ec2_scoped_metrics = [ - ('External/ec2.us-east-1.amazonaws.com/botocore/POST', 3), + ("External/ec2.us-east-1.amazonaws.com/botocore/POST", 3), ] _ec2_rollup_metrics = [ - ('External/all', 3), - ('External/allOther', 3), - ('External/ec2.us-east-1.amazonaws.com/all', 3), - ('External/ec2.us-east-1.amazonaws.com/botocore/POST', 3), + ("External/all", 3), + ("External/allOther", 3), + ("External/ec2.us-east-1.amazonaws.com/all", 3), + ("External/ec2.us-east-1.amazonaws.com/botocore/POST", 3), ] -@override_application_settings({'distributed_tracing.enabled': True}) -@validate_span_events(expected_agents=('aws.requestId',), count=3) -@validate_span_events(exact_agents={'aws.operation': 'RunInstances'}, count=1) -@validate_span_events( - exact_agents={'aws.operation': 'DescribeInstances'}, count=1) -@validate_span_events( - exact_agents={'aws.operation': 'TerminateInstances'}, count=1) -@validate_tt_segment_params(present_params=('aws.requestId',)) +@dt_enabled +@validate_span_events(expected_agents=("aws.requestId",), count=3) +@validate_span_events(exact_agents={"aws.operation": "RunInstances"}, count=1) +@validate_span_events(exact_agents={"aws.operation": "DescribeInstances"}, count=1) +@validate_span_events(exact_agents={"aws.operation": "TerminateInstances"}, count=1) +@validate_tt_segment_params(present_params=("aws.requestId",)) @validate_transaction_metrics( - 'test_botocore_ec2:test_ec2', - 
scoped_metrics=_ec2_scoped_metrics, - rollup_metrics=_ec2_rollup_metrics, - background_task=True) + "test_botocore_ec2:test_ec2", + scoped_metrics=_ec2_scoped_metrics, + rollup_metrics=_ec2_rollup_metrics, + background_task=True, +) @background_task() @moto.mock_ec2 def test_ec2(): session = botocore.session.get_session() client = session.create_client( - 'ec2', - region_name=AWS_REGION, - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY + "ec2", region_name=AWS_REGION, aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY ) # Create instance resp = client.run_instances( - ImageId=UBUNTU_14_04_PARAVIRTUAL_AMI, - InstanceType='m1.small', - MinCount=1, - MaxCount=1, + ImageId=UBUNTU_14_04_PARAVIRTUAL_AMI, + InstanceType="m1.small", + MinCount=1, + MaxCount=1, ) - assert resp['ResponseMetadata']['HTTPStatusCode'] == 200 - assert len(resp['Instances']) == 1 - instance_id = resp['Instances'][0]['InstanceId'] + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert len(resp["Instances"]) == 1 + instance_id = resp["Instances"][0]["InstanceId"] # Describe instance resp = client.describe_instances(InstanceIds=[instance_id]) - assert resp['ResponseMetadata']['HTTPStatusCode'] == 200 - assert resp['Reservations'][0]['Instances'][0]['InstanceId'] == instance_id + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert resp["Reservations"][0]["Instances"][0]["InstanceId"] == instance_id # Delete instance resp = client.terminate_instances(InstanceIds=[instance_id]) - assert resp['ResponseMetadata']['HTTPStatusCode'] == 200 - assert resp['TerminatingInstances'][0]['InstanceId'] == instance_id + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert resp["TerminatingInstances"][0]["InstanceId"] == instance_id diff --git a/tests/external_botocore/test_botocore_s3.py b/tests/external_botocore/test_botocore_s3.py index 1984d8103..ea0c22539 100644 --- a/tests/external_botocore/test_botocore_s3.py +++ b/tests/external_botocore/test_botocore_s3.py @@ -18,7 +18,7 @@ import botocore import botocore.session import moto -from testing_support.fixtures import override_application_settings +from testing_support.fixtures import dt_enabled from testing_support.validators.validate_span_events import validate_span_events from testing_support.validators.validate_transaction_metrics import ( validate_transaction_metrics, @@ -67,7 +67,7 @@ ] -@override_application_settings({"distributed_tracing.enabled": True}) +@dt_enabled @validate_span_events(exact_agents={"aws.operation": "CreateBucket"}, count=1) @validate_span_events(exact_agents={"aws.operation": "PutObject"}, count=1) @validate_span_events(exact_agents={"aws.operation": "ListObjects"}, count=1) diff --git a/tests/external_botocore/test_botocore_sqs.py b/tests/external_botocore/test_botocore_sqs.py index 3f7d8c022..63f15801b 100644 --- a/tests/external_botocore/test_botocore_sqs.py +++ b/tests/external_botocore/test_botocore_sqs.py @@ -18,7 +18,7 @@ import botocore.session import moto import pytest -from testing_support.fixtures import override_application_settings +from testing_support.fixtures import dt_enabled from testing_support.validators.validate_span_events import validate_span_events from testing_support.validators.validate_transaction_metrics import ( validate_transaction_metrics, @@ -70,7 +70,7 @@ ] -@override_application_settings({"distributed_tracing.enabled": True}) +@dt_enabled @validate_span_events(exact_agents={"aws.operation": "CreateQueue"}, count=1) 
@validate_span_events(exact_agents={"aws.operation": "SendMessage"}, count=1) @validate_span_events(exact_agents={"aws.operation": "ReceiveMessage"}, count=1) @@ -124,7 +124,7 @@ def test_sqs(): assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 -@override_application_settings({"distributed_tracing.enabled": True}) +@dt_enabled @validate_transaction_metrics( "test_botocore_sqs:test_sqs_malformed", scoped_metrics=_sqs_scoped_metrics_malformed, diff --git a/tests/external_httplib/test_httplib.py b/tests/external_httplib/test_httplib.py index c7747f8ff..f67e68dc2 100644 --- a/tests/external_httplib/test_httplib.py +++ b/tests/external_httplib/test_httplib.py @@ -23,12 +23,7 @@ cache_outgoing_headers, insert_incoming_headers, ) -from testing_support.fixtures import ( - cat_enabled, - override_application_settings, - validate_tt_segment_params, -) -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from testing_support.fixtures import cat_enabled, override_application_settings from testing_support.validators.validate_cross_process_headers import ( validate_cross_process_headers, ) @@ -36,6 +31,12 @@ validate_external_node_params, ) from testing_support.validators.validate_span_events import validate_span_events +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_segment_params import ( + validate_tt_segment_params, +) from newrelic.api.background_task import background_task from newrelic.common.encoding_utils import DistributedTracePayload @@ -104,7 +105,8 @@ def test_httplib_https_request(server): ) @background_task(name="test_httplib:test_httplib_https_request") def _test(): - connection = httplib.HTTPSConnection("localhost", server.port) + # fix HTTPSConnection: https://wiki.openstack.org/wiki/OSSN/OSSN-0033 + connection = httplib.HTTPSConnection("localhost", server.port) # nosec # It doesn't matter that a SSL exception is raised here because the # agent still records this as an external request try: diff --git a/tests/external_httpx/test_client.py b/tests/external_httpx/test_client.py index 87a1bc7d0..b4760a38f 100644 --- a/tests/external_httpx/test_client.py +++ b/tests/external_httpx/test_client.py @@ -19,7 +19,6 @@ dt_enabled, override_application_settings, override_generic_settings, - validate_tt_segment_params, ) from testing_support.mock_external_http_server import ( MockExternalHTTPHResponseHeadersServer, @@ -28,8 +27,15 @@ validate_cross_process_headers, ) from testing_support.validators.validate_span_events import validate_span_events -from testing_support.validators.validate_transaction_errors import validate_transaction_errors -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from testing_support.validators.validate_transaction_errors import ( + validate_transaction_errors, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_segment_params import ( + validate_tt_segment_params, +) from newrelic.api.background_task import background_task from newrelic.api.time_trace import current_trace diff --git a/tests/framework_aiohttp/test_server_cat.py b/tests/framework_aiohttp/test_server_cat.py index 44b5c7217..28af90d8d 100644 --- a/tests/framework_aiohttp/test_server_cat.py +++ b/tests/framework_aiohttp/test_server_cat.py @@ -37,7 +37,7 @@ def record_aiohttp1_raw_headers(raw_headers): try: 
- import aiohttp.protocol # noqa: F401 + import aiohttp.protocol # noqa: F401, pylint: disable=W0611 except ImportError: def pass_through(function): diff --git a/tests/framework_ariadne/__init__.py b/tests/framework_ariadne/__init__.py new file mode 100644 index 000000000..8030baccf --- /dev/null +++ b/tests/framework_ariadne/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/framework_ariadne/_target_application.py b/tests/framework_ariadne/_target_application.py index 94bc0710f..fef782608 100644 --- a/tests/framework_ariadne/_target_application.py +++ b/tests/framework_ariadne/_target_application.py @@ -12,140 +12,125 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - -from ariadne import ( - MutationType, - QueryType, - UnionType, - load_schema_from_path, - make_executable_schema, + +import asyncio +import json + +from framework_ariadne._target_schema_async import ( + target_asgi_application as target_asgi_application_async, +) +from framework_ariadne._target_schema_async import target_schema as target_schema_async +from framework_ariadne._target_schema_sync import ( + target_asgi_application as target_asgi_application_sync, +) +from framework_ariadne._target_schema_sync import target_schema as target_schema_sync +from framework_ariadne._target_schema_sync import ( + target_wsgi_application as target_wsgi_application_sync, ) -from ariadne.asgi import GraphQL as GraphQLASGI -from ariadne.wsgi import GraphQL as GraphQLWSGI +from framework_ariadne._target_schema_sync import ariadne_version_tuple +from graphql import MiddlewareManager -schema_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "schema.graphql") -type_defs = load_schema_from_path(schema_file) - - -authors = [ - { - "first_name": "New", - "last_name": "Relic", - }, - { - "first_name": "Bob", - "last_name": "Smith", - }, - { - "first_name": "Leslie", - "last_name": "Jones", - }, -] -books = [ - { - "id": 1, - "name": "Python Agent: The Book", - "isbn": "a-fake-isbn", - "author": authors[0], - "branch": "riverside", - }, - { - "id": 2, - "name": "Ollies for O11y: A Sk8er's Guide to Observability", - "isbn": "a-second-fake-isbn", - "author": authors[1], - "branch": "downtown", - }, - { - "id": 3, - "name": "[Redacted]", - "isbn": "a-third-fake-isbn", - "author": authors[2], - "branch": "riverside", - }, -] -magazines = [ - {"id": 1, "name": "Reli Updates Weekly", "issue": 1, "branch": "riverside"}, - {"id": 2, "name": "Reli Updates Weekly", "issue": 2, "branch": "downtown"}, - {"id": 3, "name": "Node Weekly", "issue": 1, "branch": "riverside"}, -] +def check_response(query, success, response): + if isinstance(query, str) and "error" not in query: + assert success and "errors" not in response, response + assert response.get("data", None), response + else: + assert "errors" in response, response -libraries = ["riverside", "downtown"] -libraries = [ - { - "id": i + 1, - 
"branch": branch, - "magazine": [m for m in magazines if m["branch"] == branch], - "book": [b for b in books if b["branch"] == branch], - } - for i, branch in enumerate(libraries) -] +def run_sync(schema): + def _run_sync(query, middleware=None): + from ariadne import graphql_sync -storage = [] + if ariadne_version_tuple < (0, 18): + if middleware: + middleware = MiddlewareManager(*middleware) + success, response = graphql_sync(schema, {"query": query}, middleware=middleware) + check_response(query, success, response) -mutation = MutationType() + return response.get("data", {}) + return _run_sync -@mutation.field("storage_add") -def mutate(self, info, string): - storage.append(string) - return {"string": string} +def run_async(schema): + def _run_async(query, middleware=None): + from ariadne import graphql -item = UnionType("Item") + #Later versions of ariadne directly accept a list of middleware while older versions require the MiddlewareManager + if ariadne_version_tuple < (0, 18): + if middleware: + middleware = MiddlewareManager(*middleware) + loop = asyncio.get_event_loop() + success, response = loop.run_until_complete(graphql(schema, {"query": query}, middleware=middleware)) + check_response(query, success, response) -@item.type_resolver -def resolve_type(obj, *args): - if "isbn" in obj: - return "Book" - elif "issue" in obj: # pylint: disable=R1705 - return "Magazine" + return response.get("data", {}) - return None + return _run_async -query = QueryType() +def run_wsgi(app): + def _run_asgi(query, middleware=None): + if not isinstance(query, str) or "error" in query: + expect_errors = True + else: + expect_errors = False + app.app.middleware = middleware -@query.field("library") -def resolve_library(self, info, index): - return libraries[index] + response = app.post( + "/", json.dumps({"query": query}), headers={"Content-Type": "application/json"}, expect_errors=expect_errors + ) + body = json.loads(response.body.decode("utf-8")) + if expect_errors: + assert body["errors"] + else: + assert "errors" not in body or not body["errors"] -@query.field("storage") -def resolve_storage(self, info): - return storage + return body.get("data", {}) + return _run_asgi -@query.field("search") -def resolve_search(self, info, contains): - search_books = [b for b in books if contains in b["name"]] - search_magazines = [m for m in magazines if contains in m["name"]] - return search_books + search_magazines +def run_asgi(app): + def _run_asgi(query, middleware=None): + if ariadne_version_tuple < (0, 16): + app.asgi_application.middleware = middleware -@query.field("hello") -def resolve_hello(self, info): - return "Hello!" 
+        # In ariadne v0.16.0, the middleware attribute was removed from the GraphQL class in favor of the http_handler.
+        elif ariadne_version_tuple >= (0, 16):
+            app.asgi_application.http_handler.middleware = middleware
+        response = app.make_request(
+            "POST", "/", body=json.dumps({"query": query}), headers={"Content-Type": "application/json"}
+        )
+        body = json.loads(response.body.decode("utf-8"))
-@query.field("echo")
-def resolve_echo(self, info, echo):
-    return echo
+        if not isinstance(query, str) or "error" in query:
+            try:
+                assert response.status != 200
+            except AssertionError:
+                assert body["errors"]
+        else:
+            assert response.status == 200
+            assert "errors" not in body or not body["errors"]
+
+        return body.get("data", {})
-@query.field("error_non_null")
-@query.field("error")
-def resolve_error(self, info):
-    raise RuntimeError("Runtime Error!")
+
+    return _run_asgi
-_target_application = make_executable_schema(type_defs, query, mutation, item)
-_target_asgi_application = GraphQLASGI(_target_application)
-_target_wsgi_application = GraphQLWSGI(_target_application)
+
+# Runners keyed by "<execution model>-<schema type>"; the target_application
+# fixture in test_application.py selects one of these by its request parameter.
+target_application = {
+    "sync-sync": run_sync(target_schema_sync),
+    "async-sync": run_async(target_schema_sync),
+    "async-async": run_async(target_schema_async),
+    "wsgi-sync": run_wsgi(target_wsgi_application_sync),
+    "asgi-sync": run_asgi(target_asgi_application_sync),
+    "asgi-async": run_asgi(target_asgi_application_async),
+}
diff --git a/tests/framework_ariadne/_target_schema_async.py b/tests/framework_ariadne/_target_schema_async.py
new file mode 100644
index 000000000..076475628
--- /dev/null
+++ b/tests/framework_ariadne/_target_schema_async.py
@@ -0,0 +1,94 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
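+
+# Async variant of the ariadne test schema: the same Query/Mutation surface as
+# _target_schema_sync, but with async resolvers, so the shared
+# framework_graphql tests can exercise ariadne's async execution path.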
+ +import os + +from ariadne import ( + MutationType, + QueryType, + UnionType, + load_schema_from_path, + make_executable_schema, +) +from ariadne.asgi import GraphQL as GraphQLASGI +from framework_graphql._target_schema_sync import books, magazines, libraries + +from testing_support.asgi_testing import AsgiTest + +schema_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "schema.graphql") +type_defs = load_schema_from_path(schema_file) + +storage = [] + +mutation = MutationType() + + +@mutation.field("storage_add") +async def resolve_storage_add(self, info, string): + storage.append(string) + return string + + +item = UnionType("Item") + + +@item.type_resolver +async def resolve_type(obj, *args): + if "isbn" in obj: + return "Book" + elif "issue" in obj: # pylint: disable=R1705 + return "Magazine" + + return None + + +query = QueryType() + + +@query.field("library") +async def resolve_library(self, info, index): + return libraries[index] + + +@query.field("storage") +async def resolve_storage(self, info): + return [storage.pop()] + + +@query.field("search") +async def resolve_search(self, info, contains): + search_books = [b for b in books if contains in b["name"]] + search_magazines = [m for m in magazines if contains in m["name"]] + return search_books + search_magazines + + +@query.field("hello") +@query.field("error_middleware") +async def resolve_hello(self, info): + return "Hello!" + + +@query.field("echo") +async def resolve_echo(self, info, echo): + return echo + + +@query.field("error_non_null") +@query.field("error") +async def resolve_error(self, info): + raise RuntimeError("Runtime Error!") + + +target_schema = make_executable_schema(type_defs, query, mutation, item) +target_asgi_application = AsgiTest(GraphQLASGI(target_schema)) diff --git a/tests/framework_ariadne/_target_schema_sync.py b/tests/framework_ariadne/_target_schema_sync.py new file mode 100644 index 000000000..8860e71ac --- /dev/null +++ b/tests/framework_ariadne/_target_schema_sync.py @@ -0,0 +1,106 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
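+
+# Sync variant of the ariadne test schema. GraphQLASGI is imported from a
+# version-gated location below because ariadne moved the ASGI GraphQL class
+# from ariadne.asgi to ariadne.asgi.graphql in v0.16.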
+ +import os +import webtest + +from ariadne import ( + MutationType, + QueryType, + UnionType, + load_schema_from_path, + make_executable_schema, +) +from ariadne.wsgi import GraphQL as GraphQLWSGI +from framework_graphql._target_schema_sync import books, magazines, libraries + +from testing_support.asgi_testing import AsgiTest +from framework_ariadne.test_application import ARIADNE_VERSION + +ariadne_version_tuple = tuple(map(int, ARIADNE_VERSION.split("."))) + +if ariadne_version_tuple < (0, 16): + from ariadne.asgi import GraphQL as GraphQLASGI +elif ariadne_version_tuple >= (0, 16): + from ariadne.asgi.graphql import GraphQL as GraphQLASGI + + +schema_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "schema.graphql") +type_defs = load_schema_from_path(schema_file) + +storage = [] + +mutation = MutationType() + + + +@mutation.field("storage_add") +def resolve_storage_add(self, info, string): + storage.append(string) + return string + + +item = UnionType("Item") + + +@item.type_resolver +def resolve_type(obj, *args): + if "isbn" in obj: + return "Book" + elif "issue" in obj: # pylint: disable=R1705 + return "Magazine" + + return None + + +query = QueryType() + + +@query.field("library") +def resolve_library(self, info, index): + return libraries[index] + + +@query.field("storage") +def resolve_storage(self, info): + return [storage.pop()] + + +@query.field("search") +def resolve_search(self, info, contains): + search_books = [b for b in books if contains in b["name"]] + search_magazines = [m for m in magazines if contains in m["name"]] + return search_books + search_magazines + + +@query.field("hello") +@query.field("error_middleware") +def resolve_hello(self, info): + return "Hello!" + + +@query.field("echo") +def resolve_echo(self, info, echo): + return echo + + +@query.field("error_non_null") +@query.field("error") +def resolve_error(self, info): + raise RuntimeError("Runtime Error!") + + +target_schema = make_executable_schema(type_defs, query, mutation, item) +target_asgi_application = AsgiTest(GraphQLASGI(target_schema)) +target_wsgi_application = webtest.TestApp(GraphQLWSGI(target_schema)) \ No newline at end of file diff --git a/tests/framework_ariadne/conftest.py b/tests/framework_ariadne/conftest.py index 93623a685..42b08faba 100644 --- a/tests/framework_ariadne/conftest.py +++ b/tests/framework_ariadne/conftest.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest import six -from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture # noqa: F401; pylint: disable=W0611 - +from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 + collector_agent_registration_fixture, + collector_available_fixture, +) _default_settings = { "transaction_tracer.explain_threshold": 0.0, @@ -31,12 +32,5 @@ ) -@pytest.fixture(scope="session") -def app(): - from _target_application import _target_application - - return _target_application - - if six.PY2: collect_ignore = ["test_application_async.py"] diff --git a/tests/framework_ariadne/schema.graphql b/tests/framework_ariadne/schema.graphql index 4c76e0b88..8bf64af51 100644 --- a/tests/framework_ariadne/schema.graphql +++ b/tests/framework_ariadne/schema.graphql @@ -33,7 +33,7 @@ type Magazine { } type Mutation { - storage_add(string: String!): StorageAdd + storage_add(string: String!): String } type Query { @@ -44,8 +44,5 @@ type Query { echo(echo: String!): String error: String error_non_null: String! -} - -type StorageAdd { - string: String + error_middleware: String } diff --git a/tests/framework_ariadne/test_application.py b/tests/framework_ariadne/test_application.py index cf8501a7a..0b7bf2489 100644 --- a/tests/framework_ariadne/test_application.py +++ b/tests/framework_ariadne/test_application.py @@ -11,526 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import pytest -from testing_support.fixtures import dt_enabled, override_application_settings -from testing_support.validators.validate_span_events import validate_span_events -from testing_support.validators.validate_transaction_count import ( - validate_transaction_count, -) -from testing_support.validators.validate_transaction_errors import ( - validate_transaction_errors, -) -from testing_support.validators.validate_transaction_metrics import ( - validate_transaction_metrics, -) - -from newrelic.api.background_task import background_task -from newrelic.common.object_names import callable_name -from newrelic.common.package_version_utils import get_package_version_tuple - - -@pytest.fixture(scope="session") -def is_graphql_2(): - from graphql import __version__ as version - - major_version = int(version.split(".")[0]) - return major_version == 2 - - -@pytest.fixture(scope="session") -def graphql_run(): - """Wrapper function to simulate framework_graphql test behavior.""" - - def execute(schema, query, *args, **kwargs): - from ariadne import graphql_sync - - return graphql_sync(schema, {"query": query}, *args, **kwargs) - - return execute - - -def to_graphql_source(query): - def delay_import(): - try: - from graphql import Source - except ImportError: - # Fallback if Source is not implemented - return query - - from graphql import __version__ as version - - # For graphql2, Source objects aren't acceptable input - major_version = int(version.split(".")[0]) - if major_version == 2: - return query - - return Source(query) - - return delay_import - - -def example_middleware(next, root, info, **args): # pylint: disable=W0622 - return_value = next(root, info, **args) - return return_value - - -def error_middleware(next, root, info, **args): # pylint: disable=W0622 - raise RuntimeError("Runtime Error!") - - -_runtime_error_name = callable_name(RuntimeError) -_test_runtime_error = [(_runtime_error_name, "Runtime Error!")] -_graphql_base_rollup_metrics = [ 
- ("OtherTransaction/all", 1), - ("GraphQL/all", 1), - ("GraphQL/allOther", 1), - ("GraphQL/Ariadne/all", 1), - ("GraphQL/Ariadne/allOther", 1), -] - - -def test_basic(app, graphql_run): - from graphql import __version__ as version - - FRAMEWORK_METRICS = [ - ("Python/Framework/Ariadne/None", 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - - @validate_transaction_metrics( - "query//hello", - "GraphQL", - rollup_metrics=_graphql_base_rollup_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @background_task() - def _test(): - ok, response = graphql_run(app, "{ hello }") - assert ok and not response.get("errors") - - _test() - - -@dt_enabled -def test_query_and_mutation(app, graphql_run): - from graphql import __version__ as version - - FRAMEWORK_METRICS = [ - ("Python/Framework/Ariadne/None", 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/Ariadne/storage", 1), - ("GraphQL/resolve/Ariadne/storage_add", 1), - ("GraphQL/operation/Ariadne/query//storage", 1), - ("GraphQL/operation/Ariadne/mutation//storage_add.string", 1), - ] - _test_mutation_unscoped_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 2), - ("GraphQL/Ariadne/all", 2), - ("GraphQL/allOther", 2), - ("GraphQL/Ariadne/allOther", 2), - ] + _test_mutation_scoped_metrics - - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "StorageAdd", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String]", - } - - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @background_task() - def _test(): - ok, response = graphql_run(app, 'mutation { storage_add(string: "abc") { string } }') - assert ok and not response.get("errors") - ok, response = graphql_run(app, "query { storage }") - assert ok and not response.get("errors") - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response["data"]) - assert "abc" in str(response["data"]) - - _test() - - -@dt_enabled -def test_middleware(app, graphql_run, is_graphql_2): - _test_middleware_metrics = [ - ("GraphQL/operation/Ariadne/query//hello", 1), - ("GraphQL/resolve/Ariadne/hello", 1), - ("Function/test_application:example_middleware", 1), - ] - - @validate_transaction_metrics( - "query//hello", - "GraphQL", - scoped_metrics=_test_middleware_metrics, - rollup_metrics=_test_middleware_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - # Span count 5: Transaction, Operation, Middleware, and 1 Resolver and Resolver function - 
@validate_span_events(count=5) - @background_task() - def _test(): - from graphql import MiddlewareManager - - middleware = ( - [example_middleware] - if get_package_version_tuple("ariadne") >= (0, 18) - else MiddlewareManager(example_middleware) - ) +from framework_graphql.test_application import * - ok, response = graphql_run(app, "{ hello }", middleware=middleware) - assert ok and not response.get("errors") - assert "Hello!" in str(response["data"]) +from newrelic.common.package_version_utils import get_package_version - _test() +ARIADNE_VERSION = get_package_version("ariadne") +ariadne_version_tuple = tuple(map(int, ARIADNE_VERSION.split("."))) -@dt_enabled -def test_exception_in_middleware(app, graphql_run): - query = "query MyQuery { hello }" - field = "hello" - - # Metrics - _test_exception_scoped_metrics = [ - ("GraphQL/operation/Ariadne/query/MyQuery/%s" % field, 1), - ("GraphQL/resolve/Ariadne/%s" % field, 1), - ] - _test_exception_rollup_metrics = [ - ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/test_application:error_middleware", 1), - ] + _test_exception_scoped_metrics - - # Attributes - _expected_exception_resolver_attributes = { - "graphql.field.name": field, - "graphql.field.parentType": "Query", - "graphql.field.path": field, - "graphql.field.returnType": "String", - } - _expected_exception_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "MyQuery", - "graphql.operation.query": query, - } - - @validate_transaction_metrics( - "test_application:error_middleware", - "GraphQL", - scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_exception_operation_attributes) - @validate_span_events(exact_agents=_expected_exception_resolver_attributes) - @validate_transaction_errors(errors=_test_runtime_error) - @background_task() - def _test(): - from graphql import MiddlewareManager - - middleware = ( - [error_middleware] - if get_package_version_tuple("ariadne") >= (0, 18) - else MiddlewareManager(error_middleware) - ) - - _, response = graphql_run(app, query, middleware=middleware) - assert response["errors"] - - _test() - - -@pytest.mark.parametrize("field", ("error", "error_non_null")) -@dt_enabled -def test_exception_in_resolver(app, graphql_run, field): - query = "query MyQuery { %s }" % field - txn_name = "_target_application:resolve_error" - - # Metrics - _test_exception_scoped_metrics = [ - ("GraphQL/operation/Ariadne/query/MyQuery/%s" % field, 1), - ("GraphQL/resolve/Ariadne/%s" % field, 1), - ] - _test_exception_rollup_metrics = [ - ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/%s" % txn_name, 1), - ] + _test_exception_scoped_metrics - - # Attributes - _expected_exception_resolver_attributes = { - "graphql.field.name": field, - "graphql.field.parentType": "Query", - "graphql.field.path": field, - "graphql.field.returnType": "String!" 
if "non_null" in field else "String", - } - _expected_exception_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "MyQuery", - "graphql.operation.query": query, - } - - @validate_transaction_metrics( - txn_name, - "GraphQL", - scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_exception_operation_attributes) - @validate_span_events(exact_agents=_expected_exception_resolver_attributes) - @validate_transaction_errors(errors=_test_runtime_error) - @background_task() - def _test(): - _, response = graphql_run(app, query) - assert response["errors"] - - _test() - - -@dt_enabled -@pytest.mark.parametrize( - "query,exc_class", - [ - ("query MyQuery { missing_field }", "GraphQLError"), - ("{ syntax_error ", "graphql.error.syntax_error:GraphQLSyntaxError"), - ], +@pytest.fixture( + scope="session", params=["sync-sync", "async-sync", "async-async", "wsgi-sync", "asgi-sync", "asgi-async"] ) -def test_exception_in_validation(app, graphql_run, is_graphql_2, query, exc_class): - if "syntax" in query: - txn_name = "graphql.language.parser:parse" - else: - if is_graphql_2: - txn_name = "graphql.validation.validation:validate" - else: - txn_name = "graphql.validation.validate:validate" - - # Import path differs between versions - if exc_class == "GraphQLError": - from graphql.error import GraphQLError - - exc_class = callable_name(GraphQLError) - - _test_exception_scoped_metrics = [ - ("GraphQL/operation/Ariadne///", 1), - ] - _test_exception_rollup_metrics = [ - ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/%s" % txn_name, 1), - ] + _test_exception_scoped_metrics - - # Attributes - _expected_exception_operation_attributes = { - "graphql.operation.type": "", - "graphql.operation.name": "", - "graphql.operation.query": query, - } - - @validate_transaction_metrics( - txn_name, - "GraphQL", - scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_exception_operation_attributes) - @validate_transaction_errors(errors=[exc_class]) - @background_task() - def _test(): - _, response = graphql_run(app, query) - assert response["errors"] - - _test() - - -@dt_enabled -def test_operation_metrics_and_attrs(app, graphql_run): - operation_metrics = [("GraphQL/operation/Ariadne/query/MyQuery/library", 1)] - operation_attrs = { - "graphql.operation.type": "query", - "graphql.operation.name": "MyQuery", - } - - @validate_transaction_metrics( - "query/MyQuery/library", - "GraphQL", - scoped_metrics=operation_metrics, - rollup_metrics=operation_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - # Span count 16: Transaction, Operation, and 7 Resolvers and Resolver functions - # library, library.name, library.book - # library.book.name and library.book.id for each book resolved (in this case 2) - @validate_span_events(count=16) - @validate_span_events(exact_agents=operation_attrs) - @background_task() - def _test(): - ok, response = graphql_run(app, "query MyQuery { library(index: 0) { branch, book { id, name } } }") - assert ok and not response.get("errors") - - _test() - - -@dt_enabled -def test_field_resolver_metrics_and_attrs(app, graphql_run): - field_resolver_metrics = [("GraphQL/resolve/Ariadne/hello", 1)] - graphql_attrs = { - 
"graphql.field.name": "hello", - "graphql.field.parentType": "Query", - "graphql.field.path": "hello", - "graphql.field.returnType": "String", - } - - @validate_transaction_metrics( - "query//hello", - "GraphQL", - scoped_metrics=field_resolver_metrics, - rollup_metrics=field_resolver_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - # Span count 4: Transaction, Operation, and 1 Resolver and Resolver function - @validate_span_events(count=4) - @validate_span_events(exact_agents=graphql_attrs) - @background_task() - def _test(): - ok, response = graphql_run(app, "{ hello }") - assert ok and not response.get("errors") - assert "Hello!" in str(response["data"]) - - _test() - - -_test_queries = [ - ("{ hello }", "{ hello }"), # Basic query extraction - ("{ error }", "{ error }"), # Extract query on field error - ("{ library(index: 0) { branch } }", "{ library(index: ?) { branch } }"), # Integers - ('{ echo(echo: "123") }', "{ echo(echo: ?) }"), # Strings with numerics - ('{ echo(echo: "test") }', "{ echo(echo: ?) }"), # Strings - ('{ TestEcho: echo(echo: "test") }', "{ TestEcho: echo(echo: ?) }"), # Aliases - ('{ TestEcho: echo(echo: "test") }', "{ TestEcho: echo(echo: ?) }"), # Variables - ( # Fragments - '{ ...MyFragment } fragment MyFragment on Query { echo(echo: "test") }', - "{ ...MyFragment } fragment MyFragment on Query { echo(echo: ?) }", - ), -] - - -@dt_enabled -@pytest.mark.parametrize("query,obfuscated", _test_queries) -def test_query_obfuscation(app, graphql_run, query, obfuscated): - graphql_attrs = {"graphql.operation.query": obfuscated} - - @validate_span_events(exact_agents=graphql_attrs) - @background_task() - def _test(): - ok, response = graphql_run(app, query) - if not isinstance(query, str) or "error" not in query: - assert ok and not response.get("errors") - - _test() - - -_test_queries = [ - ("{ hello }", "/hello"), # Basic query - ("{ error }", "/error"), # Extract deepest path on field error - ('{ echo(echo: "test") }', "/echo"), # Fields with arguments - ( - "{ library(index: 0) { branch, book { isbn branch } } }", - "/library", - ), # Complex Example, 1 level - ( - "{ library(index: 0) { book { author { first_name }} } }", - "/library.book.author.first_name", - ), # Complex Example, 2 levels - ("{ library(index: 0) { id, book { name } } }", "/library.book.name"), # Filtering - ('{ TestEcho: echo(echo: "test") }', "/echo"), # Aliases - ( - '{ search(contains: "A") { __typename ... on Book { name } } }', - "/search.name", - ), # InlineFragment - ( - '{ hello echo(echo: "test") }', - "", - ), # Multiple root selections. 
(need to decide on final behavior) - # FragmentSpread - ( - "{ library(index: 0) { book { ...MyFragment } } } fragment MyFragment on Book { name id }", # Fragment filtering - "/library.book.name", - ), - ( - "{ library(index: 0) { book { ...MyFragment } } } fragment MyFragment on Book { author { first_name } }", - "/library.book.author.first_name", - ), - ( - "{ library(index: 0) { book { ...MyFragment } magazine { ...MagFragment } } } fragment MyFragment on Book { author { first_name } } fragment MagFragment on Magazine { name }", - "/library", - ), -] - - -@dt_enabled -@pytest.mark.parametrize("query,expected_path", _test_queries) -def test_deepest_unique_path(app, graphql_run, query, expected_path): - if expected_path == "/error": - txn_name = "_target_application:resolve_error" - else: - txn_name = "query/%s" % expected_path - - @validate_transaction_metrics( - txn_name, - "GraphQL", - background_task=True, - ) - @background_task() - def _test(): - ok, response = graphql_run(app, query) - if "error" not in query: - assert ok and not response.get("errors") - - _test() - +def target_application(request): + from ._target_application import target_application -@pytest.mark.parametrize("capture_introspection_setting", (True, False)) -def test_introspection_transactions(app, graphql_run, capture_introspection_setting): - txn_ct = 1 if capture_introspection_setting else 0 + target_application = target_application[request.param] - @override_application_settings( - {"instrumentation.graphql.capture_introspection_queries": capture_introspection_setting} - ) - @validate_transaction_count(txn_ct) - @background_task() - def _test(): - ok, response = graphql_run(app, "{ __schema { types { name } } }") - assert ok and not response.get("errors") + param = request.param.split("-") + is_background = param[0] not in {"wsgi", "asgi"} + schema_type = param[1] + extra_spans = 4 if param[0] == "wsgi" else 0 - _test() + assert ARIADNE_VERSION is not None + return "Ariadne", ARIADNE_VERSION, target_application, is_background, schema_type, extra_spans diff --git a/tests/framework_ariadne/test_application_async.py b/tests/framework_ariadne/test_application_async.py deleted file mode 100644 index ada34ffad..000000000 --- a/tests/framework_ariadne/test_application_async.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
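The target_application fixture above returns a six-tuple that the shared framework_graphql suite consumes. A minimal sketch of a consuming test, assuming the shared tests unpack the tuple in order (the names below are illustrative, not the shared suite's actual code):

    def test_hello(target_application):
        framework, version, run, is_background, schema_type, extra_spans = target_application
        # "run" is one of the runner closures built in _target_application.py;
        # it executes the query and returns the GraphQL "data" payload.
        data = run("{ hello }")
        assert data["hello"] == "Hello!"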
- -import asyncio - -import pytest -from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_span_events import validate_span_events - -from newrelic.api.background_task import background_task - - -@pytest.fixture(scope="session") -def graphql_run_async(): - """Wrapper function to simulate framework_graphql test behavior.""" - - def execute(schema, query, *args, **kwargs): - from ariadne import graphql - - return graphql(schema, {"query": query}, *args, **kwargs) - - return execute - - -@dt_enabled -def test_query_and_mutation_async(app, graphql_run_async): - from graphql import __version__ as version - - FRAMEWORK_METRICS = [ - ("Python/Framework/Ariadne/None", 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/Ariadne/storage", 1), - ("GraphQL/resolve/Ariadne/storage_add", 1), - ("GraphQL/operation/Ariadne/query//storage", 1), - ("GraphQL/operation/Ariadne/mutation//storage_add.string", 1), - ] - _test_mutation_unscoped_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 2), - ("GraphQL/Ariadne/all", 2), - ("GraphQL/allOther", 2), - ("GraphQL/Ariadne/allOther", 2), - ] + _test_mutation_scoped_metrics - - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "StorageAdd", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String]", - } - - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @background_task() - def _test(): - async def coro(): - ok, response = await graphql_run_async(app, 'mutation { storage_add(string: "abc") { string } }') - assert ok and not response.get("errors") - ok, response = await graphql_run_async(app, "query { storage }") - assert ok and not response.get("errors") - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.get("data")) - assert "abc" in str(response.get("data")) - - loop = asyncio.new_event_loop() - loop.run_until_complete(coro()) - - _test() diff --git a/tests/framework_ariadne/test_asgi.py b/tests/framework_ariadne/test_asgi.py deleted file mode 100644 index 861f2aa93..000000000 --- a/tests/framework_ariadne/test_asgi.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -import pytest -from testing_support.asgi_testing import AsgiTest -from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_span_events import validate_span_events - - -@pytest.fixture(scope="session") -def graphql_asgi_run(): - """Wrapper function to simulate framework_graphql test behavior.""" - from _target_application import _target_asgi_application - - app = AsgiTest(_target_asgi_application) - - def execute(query): - return app.make_request( - "POST", "/", headers={"Content-Type": "application/json"}, body=json.dumps({"query": query}) - ) - - return execute - - -@dt_enabled -def test_query_and_mutation_asgi(graphql_asgi_run): - from graphql import __version__ as version - - FRAMEWORK_METRICS = [ - ("Python/Framework/Ariadne/None", 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/Ariadne/storage_add", 1), - ("GraphQL/operation/Ariadne/mutation//storage_add.string", 1), - ] - _test_query_scoped_metrics = [ - ("GraphQL/resolve/Ariadne/storage", 1), - ("GraphQL/operation/Ariadne/query//storage", 1), - ] - _test_unscoped_metrics = [ - ("WebTransaction", 1), - ("GraphQL/all", 1), - ("GraphQL/Ariadne/all", 1), - ("GraphQL/allWeb", 1), - ("GraphQL/Ariadne/allWeb", 1), - ] - _test_mutation_unscoped_metrics = _test_unscoped_metrics + _test_mutation_scoped_metrics - _test_query_unscoped_metrics = _test_unscoped_metrics + _test_query_scoped_metrics - - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "StorageAdd", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String]", - } - - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_query_scoped_metrics, - rollup_metrics=_test_query_unscoped_metrics + FRAMEWORK_METRICS, - ) - @validate_transaction_metrics( - "mutation//storage_add.string", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - index=-2, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes, index=-2) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes, index=-2) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - def _test(): - response = graphql_asgi_run('mutation { storage_add(string: "abc") { string } }') - assert response.status == 200 - response = json.loads(response.body.decode("utf-8")) - 
assert not response.get("errors") - - response = graphql_asgi_run("query { storage }") - assert response.status == 200 - response = json.loads(response.body.decode("utf-8")) - assert not response.get("errors") - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.get("data")) - assert "abc" in str(response.get("data")) - - _test() diff --git a/tests/framework_ariadne/test_wsgi.py b/tests/framework_ariadne/test_wsgi.py deleted file mode 100644 index 9ce2373d4..000000000 --- a/tests/framework_ariadne/test_wsgi.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import webtest -from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_span_events import validate_span_events - - -@pytest.fixture(scope="session") -def graphql_wsgi_run(): - """Wrapper function to simulate framework_graphql test behavior.""" - from _target_application import _target_wsgi_application - - app = webtest.TestApp(_target_wsgi_application) - - def execute(query): - return app.post_json("/", {"query": query}) - - return execute - - -@dt_enabled -def test_query_and_mutation_wsgi(graphql_wsgi_run): - from graphql import __version__ as version - - FRAMEWORK_METRICS = [ - ("Python/Framework/Ariadne/None", 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/Ariadne/storage_add", 1), - ("GraphQL/operation/Ariadne/mutation//storage_add.string", 1), - ] - _test_query_scoped_metrics = [ - ("GraphQL/resolve/Ariadne/storage", 1), - ("GraphQL/operation/Ariadne/query//storage", 1), - ] - _test_unscoped_metrics = [ - ("WebTransaction", 1), - ("Python/WSGI/Response", 1), - ("GraphQL/all", 1), - ("GraphQL/Ariadne/all", 1), - ("GraphQL/allWeb", 1), - ("GraphQL/Ariadne/allWeb", 1), - ] - _test_mutation_unscoped_metrics = _test_unscoped_metrics + _test_mutation_scoped_metrics - _test_query_unscoped_metrics = _test_unscoped_metrics + _test_query_scoped_metrics - - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "StorageAdd", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String]", - } - - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_query_scoped_metrics, - rollup_metrics=_test_query_unscoped_metrics + FRAMEWORK_METRICS, - ) - 
@validate_transaction_metrics( - "mutation//storage_add.string", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - index=-2, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes, index=-2) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes, index=-2) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - def _test(): - response = graphql_wsgi_run('mutation { storage_add(string: "abc") { string } }') - assert response.status_code == 200 - response = response.json_body - assert not response.get("errors") - - response = graphql_wsgi_run("query { storage }") - assert response.status_code == 200 - response = response.json_body - assert not response.get("errors") - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.get("data")) - assert "abc" in str(response.get("data")) - - _test() diff --git a/tests/framework_graphene/__init__.py b/tests/framework_graphene/__init__.py new file mode 100644 index 000000000..8030baccf --- /dev/null +++ b/tests/framework_graphene/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/framework_graphene/_target_application.py b/tests/framework_graphene/_target_application.py index 50acc776f..3f4b23e57 100644 --- a/tests/framework_graphene/_target_application.py +++ b/tests/framework_graphene/_target_application.py @@ -11,150 +11,45 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
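+# The graphene schema definitions now live in _target_schema_sync and
+# _target_schema_async; this module only wraps them in runner closures,
+# keyed by "<execution model>-<schema type>" for the target_application fixture.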
-from graphene import Field, Int, List -from graphene import Mutation as GrapheneMutation -from graphene import NonNull, ObjectType, Schema, String, Union +from ._target_schema_async import target_schema as target_schema_async +from ._target_schema_sync import target_schema as target_schema_sync +from framework_graphene.test_application import GRAPHENE_VERSION -class Author(ObjectType): - first_name = String() - last_name = String() +def check_response(query, response): + if isinstance(query, str) and "error" not in query: + assert not response.errors, response + assert response.data + else: + assert response.errors, response -class Book(ObjectType): - id = Int() - name = String() - isbn = String() - author = Field(Author) - branch = String() +def run_sync(schema): + def _run_sync(query, middleware=None): + response = schema.execute(query, middleware=middleware) + check_response(query, response) + return response.data -class Magazine(ObjectType): - id = Int() - name = String() - issue = Int() - branch = String() + return _run_sync -class Item(Union): - class Meta: - types = (Book, Magazine) +def run_async(schema): + import asyncio + def _run_async(query, middleware=None): + loop = asyncio.get_event_loop() + response = loop.run_until_complete(schema.execute_async(query, middleware=middleware)) + check_response(query, response) -class Library(ObjectType): - id = Int() - branch = String() - magazine = Field(List(Magazine)) - book = Field(List(Book)) + return response.data + return _run_async -Storage = List(String) +target_application = { + "sync-sync": run_sync(target_schema_sync), + "async-sync": run_async(target_schema_sync), + "async-async": run_async(target_schema_async), + } -authors = [ - Author( - first_name="New", - last_name="Relic", - ), - Author( - first_name="Bob", - last_name="Smith", - ), - Author( - first_name="Leslie", - last_name="Jones", - ), -] - -books = [ - Book( - id=1, - name="Python Agent: The Book", - isbn="a-fake-isbn", - author=authors[0], - branch="riverside", - ), - Book( - id=2, - name="Ollies for O11y: A Sk8er's Guide to Observability", - isbn="a-second-fake-isbn", - author=authors[1], - branch="downtown", - ), - Book( - id=3, - name="[Redacted]", - isbn="a-third-fake-isbn", - author=authors[2], - branch="riverside", - ), -] - -magazines = [ - Magazine(id=1, name="Reli Updates Weekly", issue=1, branch="riverside"), - Magazine(id=2, name="Reli Updates Weekly", issue=2, branch="downtown"), - Magazine(id=3, name="Node Weekly", issue=1, branch="riverside"), -] - - -libraries = ["riverside", "downtown"] -libraries = [ - Library( - id=i + 1, - branch=branch, - magazine=[m for m in magazines if m.branch == branch], - book=[b for b in books if b.branch == branch], - ) - for i, branch in enumerate(libraries) -] - -storage = [] - - -class StorageAdd(GrapheneMutation): - class Arguments: - string = String(required=True) - - string = String() - - def mutate(self, info, string): - storage.append(string) - return String(string=string) - - -class Query(ObjectType): - library = Field(Library, index=Int(required=True)) - hello = String() - search = Field(List(Item), contains=String(required=True)) - echo = Field(String, echo=String(required=True)) - storage = Storage - error = String() - - def resolve_library(self, info, index): - return libraries[index] - - def resolve_storage(self, info): - return storage - - def resolve_search(self, info, contains): - search_books = [b for b in books if contains in b.name] - search_magazines = [m for m in magazines if contains in m.name] - 
return search_books + search_magazines - - def resolve_hello(self, info): - return "Hello!" - - def resolve_echo(self, info, echo): - return echo - - def resolve_error(self, info): - raise RuntimeError("Runtime Error!") - - error_non_null = Field(NonNull(String), resolver=resolve_error) - - -class Mutation(ObjectType): - storage_add = StorageAdd.Field() - - -_target_application = Schema(query=Query, mutation=Mutation, auto_camelcase=False) diff --git a/tests/framework_graphene/_target_schema_async.py b/tests/framework_graphene/_target_schema_async.py new file mode 100644 index 000000000..39905f2f9 --- /dev/null +++ b/tests/framework_graphene/_target_schema_async.py @@ -0,0 +1,72 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from graphene import Field, Int, List +from graphene import Mutation as GrapheneMutation +from graphene import NonNull, ObjectType, Schema, String, Union + +from ._target_schema_sync import Author, Book, Magazine, Item, Library, Storage, authors, books, magazines, libraries + + +storage = [] + + +async def resolve_library(self, info, index): + return libraries[index] + +async def resolve_storage(self, info): + return [storage.pop()] + +async def resolve_search(self, info, contains): + search_books = [b for b in books if contains in b.name] + search_magazines = [m for m in magazines if contains in m.name] + return search_books + search_magazines + +async def resolve_hello(self, info): + return "Hello!" + +async def resolve_echo(self, info, echo): + return echo + +async def resolve_error(self, info): + raise RuntimeError("Runtime Error!") + +async def resolve_storage_add(self, info, string): + storage.append(string) + return StorageAdd(string=string) + + +class StorageAdd(GrapheneMutation): + class Arguments: + string = String(required=True) + + string = String() + mutate = resolve_storage_add + + +class Query(ObjectType): + library = Field(Library, index=Int(required=True), resolver=resolve_library) + hello = String(resolver=resolve_hello) + search = Field(List(Item), contains=String(required=True), resolver=resolve_search) + echo = Field(String, echo=String(required=True), resolver=resolve_echo) + storage = Field(Storage, resolver=resolve_storage) + error = String(resolver=resolve_error) + error_non_null = Field(NonNull(String), resolver=resolve_error) + error_middleware = String(resolver=resolve_hello) + + +class Mutation(ObjectType): + storage_add = StorageAdd.Field() + + +target_schema = Schema(query=Query, mutation=Mutation, auto_camelcase=False) diff --git a/tests/framework_graphene/_target_schema_sync.py b/tests/framework_graphene/_target_schema_sync.py new file mode 100644 index 000000000..b59179065 --- /dev/null +++ b/tests/framework_graphene/_target_schema_sync.py @@ -0,0 +1,162 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from graphene import Field, Int, List +from graphene import Mutation as GrapheneMutation +from graphene import NonNull, ObjectType, Schema, String, Union + + +class Author(ObjectType): + first_name = String() + last_name = String() + + +class Book(ObjectType): + id = Int() + name = String() + isbn = String() + author = Field(Author) + branch = String() + + +class Magazine(ObjectType): + id = Int() + name = String() + issue = Int() + branch = String() + + +class Item(Union): + class Meta: + types = (Book, Magazine) + + +class Library(ObjectType): + id = Int() + branch = String() + magazine = Field(List(Magazine)) + book = Field(List(Book)) + + +Storage = List(String) + + +authors = [ + Author( + first_name="New", + last_name="Relic", + ), + Author( + first_name="Bob", + last_name="Smith", + ), + Author( + first_name="Leslie", + last_name="Jones", + ), +] + +books = [ + Book( + id=1, + name="Python Agent: The Book", + isbn="a-fake-isbn", + author=authors[0], + branch="riverside", + ), + Book( + id=2, + name="Ollies for O11y: A Sk8er's Guide to Observability", + isbn="a-second-fake-isbn", + author=authors[1], + branch="downtown", + ), + Book( + id=3, + name="[Redacted]", + isbn="a-third-fake-isbn", + author=authors[2], + branch="riverside", + ), +] + +magazines = [ + Magazine(id=1, name="Reli Updates Weekly", issue=1, branch="riverside"), + Magazine(id=2, name="Reli Updates Weekly", issue=2, branch="downtown"), + Magazine(id=3, name="Node Weekly", issue=1, branch="riverside"), +] + + +libraries = ["riverside", "downtown"] +libraries = [ + Library( + id=i + 1, + branch=branch, + magazine=[m for m in magazines if m.branch == branch], + book=[b for b in books if b.branch == branch], + ) + for i, branch in enumerate(libraries) +] + +storage = [] + + +def resolve_library(self, info, index): + return libraries[index] + +def resolve_storage(self, info): + return [storage.pop()] + +def resolve_search(self, info, contains): + search_books = [b for b in books if contains in b.name] + search_magazines = [m for m in magazines if contains in m.name] + return search_books + search_magazines + +def resolve_hello(self, info): + return "Hello!" 
+ +def resolve_echo(self, info, echo): + return echo + +def resolve_error(self, info): + raise RuntimeError("Runtime Error!") + +def resolve_storage_add(self, info, string): + storage.append(string) + return StorageAdd(string=string) + + +class StorageAdd(GrapheneMutation): + class Arguments: + string = String(required=True) + + string = String() + mutate = resolve_storage_add + + +class Query(ObjectType): + library = Field(Library, index=Int(required=True), resolver=resolve_library) + hello = String(resolver=resolve_hello) + search = Field(List(Item), contains=String(required=True), resolver=resolve_search) + echo = Field(String, echo=String(required=True), resolver=resolve_echo) + storage = Field(Storage, resolver=resolve_storage) + error = String(resolver=resolve_error) + error_non_null = Field(NonNull(String), resolver=resolve_error) + error_middleware = String(resolver=resolve_hello) + + +class Mutation(ObjectType): + storage_add = StorageAdd.Field() + + +target_schema = Schema(query=Query, mutation=Mutation, auto_camelcase=False) diff --git a/tests/framework_graphene/test_application.py b/tests/framework_graphene/test_application.py index fd02d992a..838f3b515 100644 --- a/tests/framework_graphene/test_application.py +++ b/tests/framework_graphene/test_application.py @@ -13,518 +13,25 @@ # limitations under the License. import pytest -import six -from testing_support.fixtures import dt_enabled, override_application_settings -from testing_support.validators.validate_span_events import validate_span_events -from testing_support.validators.validate_transaction_count import ( - validate_transaction_count, -) -from testing_support.validators.validate_transaction_errors import ( - validate_transaction_errors, -) -from testing_support.validators.validate_transaction_metrics import ( - validate_transaction_metrics, -) -from newrelic.api.background_task import background_task -from newrelic.common.object_names import callable_name +from framework_graphql.test_application import * +from newrelic.common.package_version_utils import get_package_version +GRAPHENE_VERSION = get_package_version("graphene") -@pytest.fixture(scope="session") -def is_graphql_2(): - from graphql import __version__ as version - major_version = int(version.split(".")[0]) - return major_version == 2 +@pytest.fixture(scope="session", params=["sync-sync", "async-sync", "async-async"]) +def target_application(request): + from ._target_application import target_application + target_application = target_application.get(request.param, None) + if target_application is None: + pytest.skip("Unsupported combination.") + return -@pytest.fixture(scope="session") -def graphql_run(): - """Wrapper function to simulate framework_graphql test behavior.""" - - def execute(schema, *args, **kwargs): - return schema.execute(*args, **kwargs) - - return execute - - -def to_graphql_source(query): - def delay_import(): - try: - from graphql import Source - except ImportError: - # Fallback if Source is not implemented - return query - - from graphql import __version__ as version - - # For graphql2, Source objects aren't acceptable input - major_version = int(version.split(".")[0]) - if major_version == 2: - return query - - return Source(query) - - return delay_import - - -def example_middleware(next, root, info, **args): # pylint: disable=W0622 - return_value = next(root, info, **args) - return return_value - - -def error_middleware(next, root, info, **args): # pylint: disable=W0622 - raise RuntimeError("Runtime Error!") - - -_runtime_error_name = 
callable_name(RuntimeError) -_test_runtime_error = [(_runtime_error_name, "Runtime Error!")] -_graphql_base_rollup_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 1), - ("GraphQL/allOther", 1), - ("GraphQL/Graphene/all", 1), - ("GraphQL/Graphene/allOther", 1), -] - - -def test_basic(app, graphql_run): - from graphql import __version__ as version - - from newrelic.hooks.framework_graphene import framework_details - - FRAMEWORK_METRICS = [ - ("Python/Framework/Graphene/%s" % framework_details()[1], 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - - @validate_transaction_metrics( - "query//hello", - "GraphQL", - rollup_metrics=_graphql_base_rollup_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @background_task() - def _test(): - response = graphql_run(app, "{ hello }") - assert not response.errors - - _test() - - -@dt_enabled -def test_query_and_mutation(app, graphql_run): - from graphql import __version__ as version - - FRAMEWORK_METRICS = [ - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/Graphene/storage", 1), - ("GraphQL/resolve/Graphene/storage_add", 1), - ("GraphQL/operation/Graphene/query//storage", 1), - ("GraphQL/operation/Graphene/mutation//storage_add.string", 1), - ] - _test_mutation_unscoped_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 2), - ("GraphQL/Graphene/all", 2), - ("GraphQL/allOther", 2), - ("GraphQL/Graphene/allOther", 2), - ] + _test_mutation_scoped_metrics - - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "StorageAdd", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String]", - } - - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @background_task() - def _test(): - response = graphql_run(app, 'mutation { storage_add(string: "abc") { string } }') - assert not response.errors - response = graphql_run(app, "query { storage }") - assert not response.errors - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.data) - assert "abc" in str(response.data) - - _test() - - -@dt_enabled -def test_middleware(app, graphql_run, is_graphql_2): - _test_middleware_metrics = [ - ("GraphQL/operation/Graphene/query//hello", 1), - ("GraphQL/resolve/Graphene/hello", 1), - ("Function/test_application:example_middleware", 1), - ] - - @validate_transaction_metrics( - "query//hello", - "GraphQL", - scoped_metrics=_test_middleware_metrics, - rollup_metrics=_test_middleware_metrics + _graphql_base_rollup_metrics, - 
background_task=True, - ) - # Span count 5: Transaction, Operation, Middleware, and 1 Resolver and 1 Resolver Function - @validate_span_events(count=5) - @background_task() - def _test(): - response = graphql_run(app, "{ hello }", middleware=[example_middleware]) - assert not response.errors - assert "Hello!" in str(response.data) - - _test() - - -@dt_enabled -def test_exception_in_middleware(app, graphql_run): - query = "query MyQuery { hello }" - field = "hello" - - # Metrics - _test_exception_scoped_metrics = [ - ("GraphQL/operation/Graphene/query/MyQuery/%s" % field, 1), - ("GraphQL/resolve/Graphene/%s" % field, 1), - ] - _test_exception_rollup_metrics = [ - ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/test_application:error_middleware", 1), - ] + _test_exception_scoped_metrics - - # Attributes - _expected_exception_resolver_attributes = { - "graphql.field.name": field, - "graphql.field.parentType": "Query", - "graphql.field.path": field, - "graphql.field.returnType": "String", - } - _expected_exception_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "MyQuery", - "graphql.operation.query": query, - } - - @validate_transaction_metrics( - "test_application:error_middleware", - "GraphQL", - scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_exception_operation_attributes) - @validate_span_events(exact_agents=_expected_exception_resolver_attributes) - @validate_transaction_errors(errors=_test_runtime_error) - @background_task() - def _test(): - response = graphql_run(app, query, middleware=[error_middleware]) - assert response.errors - - _test() - - -@pytest.mark.parametrize("field", ("error", "error_non_null")) -@dt_enabled -def test_exception_in_resolver(app, graphql_run, field): - query = "query MyQuery { %s }" % field - - if six.PY2: - txn_name = "_target_application:resolve_error" - else: - txn_name = "_target_application:Query.resolve_error" - - # Metrics - _test_exception_scoped_metrics = [ - ("GraphQL/operation/Graphene/query/MyQuery/%s" % field, 1), - ("GraphQL/resolve/Graphene/%s" % field, 1), - ] - _test_exception_rollup_metrics = [ - ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/%s" % txn_name, 1), - ] + _test_exception_scoped_metrics - - # Attributes - _expected_exception_resolver_attributes = { - "graphql.field.name": field, - "graphql.field.parentType": "Query", - "graphql.field.path": field, - "graphql.field.returnType": "String!" 
if "non_null" in field else "String", - } - _expected_exception_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "MyQuery", - "graphql.operation.query": query, - } - - @validate_transaction_metrics( - txn_name, - "GraphQL", - scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_exception_operation_attributes) - @validate_span_events(exact_agents=_expected_exception_resolver_attributes) - @validate_transaction_errors(errors=_test_runtime_error) - @background_task() - def _test(): - response = graphql_run(app, query) - assert response.errors - - _test() - - -@dt_enabled -@pytest.mark.parametrize( - "query,exc_class", - [ - ("query MyQuery { missing_field }", "GraphQLError"), - ("{ syntax_error ", "graphql.error.syntax_error:GraphQLSyntaxError"), - ], -) -def test_exception_in_validation(app, graphql_run, is_graphql_2, query, exc_class): - if "syntax" in query: - txn_name = "graphql.language.parser:parse" - else: - if is_graphql_2: - txn_name = "graphql.validation.validation:validate" - else: - txn_name = "graphql.validation.validate:validate" - - # Import path differs between versions - if exc_class == "GraphQLError": - from graphql.error import GraphQLError - - exc_class = callable_name(GraphQLError) - - _test_exception_scoped_metrics = [ - ("GraphQL/operation/Graphene///", 1), - ] - _test_exception_rollup_metrics = [ - ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/%s" % txn_name, 1), - ] + _test_exception_scoped_metrics - - # Attributes - _expected_exception_operation_attributes = { - "graphql.operation.type": "", - "graphql.operation.name": "", - "graphql.operation.query": query, - } - - @validate_transaction_metrics( - txn_name, - "GraphQL", - scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_exception_operation_attributes) - @validate_transaction_errors(errors=[exc_class]) - @background_task() - def _test(): - response = graphql_run(app, query) - assert response.errors - - _test() - - -@dt_enabled -def test_operation_metrics_and_attrs(app, graphql_run): - operation_metrics = [("GraphQL/operation/Graphene/query/MyQuery/library", 1)] - operation_attrs = { - "graphql.operation.type": "query", - "graphql.operation.name": "MyQuery", - } - - @validate_transaction_metrics( - "query/MyQuery/library", - "GraphQL", - scoped_metrics=operation_metrics, - rollup_metrics=operation_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - # Span count 16: Transaction, Operation, and 7 Resolvers and Resolver functions - # library, library.name, library.book - # library.book.name and library.book.id for each book resolved (in this case 2) - @validate_span_events(count=16) - @validate_span_events(exact_agents=operation_attrs) - @background_task() - def _test(): - response = graphql_run(app, "query MyQuery { library(index: 0) { branch, book { id, name } } }") - assert not response.errors - - _test() - - -@dt_enabled -def test_field_resolver_metrics_and_attrs(app, graphql_run): - field_resolver_metrics = [("GraphQL/resolve/Graphene/hello", 1)] - graphql_attrs = { - "graphql.field.name": "hello", - "graphql.field.parentType": "Query", - "graphql.field.path": "hello", - "graphql.field.returnType": "String", - } - - 
@validate_transaction_metrics( - "query//hello", - "GraphQL", - scoped_metrics=field_resolver_metrics, - rollup_metrics=field_resolver_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - # Span count 4: Transaction, Operation, and 1 Resolver and Resolver function - @validate_span_events(count=4) - @validate_span_events(exact_agents=graphql_attrs) - @background_task() - def _test(): - response = graphql_run(app, "{ hello }") - assert not response.errors - assert "Hello!" in str(response.data) - - _test() - - -_test_queries = [ - ("{ hello }", "{ hello }"), # Basic query extraction - ("{ error }", "{ error }"), # Extract query on field error - (to_graphql_source("{ hello }"), "{ hello }"), # Extract query from Source objects - ("{ library(index: 0) { branch } }", "{ library(index: ?) { branch } }"), # Integers - ('{ echo(echo: "123") }', "{ echo(echo: ?) }"), # Strings with numerics - ('{ echo(echo: "test") }', "{ echo(echo: ?) }"), # Strings - ('{ TestEcho: echo(echo: "test") }', "{ TestEcho: echo(echo: ?) }"), # Aliases - ('{ TestEcho: echo(echo: "test") }', "{ TestEcho: echo(echo: ?) }"), # Variables - ( # Fragments - '{ ...MyFragment } fragment MyFragment on Query { echo(echo: "test") }', - "{ ...MyFragment } fragment MyFragment on Query { echo(echo: ?) }", - ), -] - - -@dt_enabled -@pytest.mark.parametrize("query,obfuscated", _test_queries) -def test_query_obfuscation(app, graphql_run, query, obfuscated): - graphql_attrs = {"graphql.operation.query": obfuscated} - - if callable(query): - query = query() - - @validate_span_events(exact_agents=graphql_attrs) - @background_task() - def _test(): - response = graphql_run(app, query) - if not isinstance(query, str) or "error" not in query: - assert not response.errors - - _test() - - -_test_queries = [ - ("{ hello }", "/hello"), # Basic query - ("{ error }", "/error"), # Extract deepest path on field error - ('{ echo(echo: "test") }', "/echo"), # Fields with arguments - ( - "{ library(index: 0) { branch, book { isbn branch } } }", - "/library", - ), # Complex Example, 1 level - ( - "{ library(index: 0) { book { author { first_name }} } }", - "/library.book.author.first_name", - ), # Complex Example, 2 levels - ("{ library(index: 0) { id, book { name } } }", "/library.book.name"), # Filtering - ('{ TestEcho: echo(echo: "test") }', "/echo"), # Aliases - ( - '{ search(contains: "A") { __typename ... on Book { name } } }', - "/search.name", - ), # InlineFragment - ( - '{ hello echo(echo: "test") }', - "", - ), # Multiple root selections. 
(need to decide on final behavior) - # FragmentSpread - ( - "{ library(index: 0) { book { ...MyFragment } } } fragment MyFragment on Book { name id }", # Fragment filtering - "/library.book.name", - ), - ( - "{ library(index: 0) { book { ...MyFragment } } } fragment MyFragment on Book { author { first_name } }", - "/library.book.author.first_name", - ), - ( - "{ library(index: 0) { book { ...MyFragment } magazine { ...MagFragment } } } fragment MyFragment on Book { author { first_name } } fragment MagFragment on Magazine { name }", - "/library", - ), -] - - -@dt_enabled -@pytest.mark.parametrize("query,expected_path", _test_queries) -def test_deepest_unique_path(app, graphql_run, query, expected_path): - if expected_path == "/error": - if six.PY2: - txn_name = "_target_application:resolve_error" - else: - txn_name = "_target_application:Query.resolve_error" - else: - txn_name = "query/%s" % expected_path - - @validate_transaction_metrics( - txn_name, - "GraphQL", - background_task=True, - ) - @background_task() - def _test(): - response = graphql_run(app, query) - if "error" not in query: - assert not response.errors - - _test() - - -@pytest.mark.parametrize("capture_introspection_setting", (True, False)) -def test_introspection_transactions(app, graphql_run, capture_introspection_setting): - txn_ct = 1 if capture_introspection_setting else 0 - - @override_application_settings( - {"instrumentation.graphql.capture_introspection_queries": capture_introspection_setting} - ) - @validate_transaction_count(txn_ct) - @background_task() - def _test(): - response = graphql_run(app, "{ __schema { types { name } } }") - assert not response.errors - - _test() + param = request.param.split("-") + is_background = param[0] not in {"wsgi", "asgi"} + schema_type = param[1] + extra_spans = 4 if param[0] == "wsgi" else 0 + assert GRAPHENE_VERSION is not None + return "Graphene", GRAPHENE_VERSION, target_application, is_background, schema_type, extra_spans diff --git a/tests/framework_graphql/__init__.py b/tests/framework_graphql/__init__.py new file mode 100644 index 000000000..8030baccf --- /dev/null +++ b/tests/framework_graphql/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/framework_graphql/_target_application.py b/tests/framework_graphql/_target_application.py index 7bef5e975..91da5d767 100644 --- a/tests/framework_graphql/_target_application.py +++ b/tests/framework_graphql/_target_application.py @@ -12,228 +12,55 @@ # See the License for the specific language governing permissions and # limitations under the License. 
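The rewritten _target_application.py below replaces the inline schema with a dict of runner callables keyed by "<executor>-<schema>" strings. A rough usage sketch, assuming graphql-core 3 is installed and the tests/ directory is on sys.path (the import path is an assumption, not part of this diff):

    from framework_graphql._target_application import target_application

    run = target_application["sync-sync"]   # sync executor over the sync schema
    data = run("{ hello }")                 # runners return response.data, not an ExecutionResult
    assert data == {"hello": "Hello!"}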
-from graphql import ( - GraphQLArgument, - GraphQLField, - GraphQLInt, - GraphQLList, - GraphQLNonNull, - GraphQLObjectType, - GraphQLSchema, - GraphQLString, - GraphQLUnionType, -) - -authors = [ - { - "first_name": "New", - "last_name": "Relic", - }, - { - "first_name": "Bob", - "last_name": "Smith", - }, - { - "first_name": "Leslie", - "last_name": "Jones", - }, -] - -books = [ - { - "id": 1, - "name": "Python Agent: The Book", - "isbn": "a-fake-isbn", - "author": authors[0], - "branch": "riverside", - }, - { - "id": 2, - "name": "Ollies for O11y: A Sk8er's Guide to Observability", - "isbn": "a-second-fake-isbn", - "author": authors[1], - "branch": "downtown", - }, - { - "id": 3, - "name": "[Redacted]", - "isbn": "a-third-fake-isbn", - "author": authors[2], - "branch": "riverside", - }, -] - -magazines = [ - {"id": 1, "name": "Reli Updates Weekly", "issue": 1, "branch": "riverside"}, - {"id": 2, "name": "Reli Updates Weekly", "issue": 2, "branch": "downtown"}, - {"id": 3, "name": "Node Weekly", "issue": 1, "branch": "riverside"}, -] - - -libraries = ["riverside", "downtown"] -libraries = [ - { - "id": i + 1, - "branch": branch, - "magazine": [m for m in magazines if m["branch"] == branch], - "book": [b for b in books if b["branch"] == branch], - } - for i, branch in enumerate(libraries) -] - -storage = [] - - -def resolve_library(parent, info, index): - return libraries[index] - - -def resolve_storage_add(parent, info, string): - storage.append(string) - return string - - -def resolve_storage(parent, info): - return storage - - -def resolve_search(parent, info, contains): - search_books = [b for b in books if contains in b["name"]] - search_magazines = [m for m in magazines if contains in m["name"]] - return search_books + search_magazines - - -Author = GraphQLObjectType( - "Author", - { - "first_name": GraphQLField(GraphQLString), - "last_name": GraphQLField(GraphQLString), - }, -) - -Book = GraphQLObjectType( - "Book", - { - "id": GraphQLField(GraphQLInt), - "name": GraphQLField(GraphQLString), - "isbn": GraphQLField(GraphQLString), - "author": GraphQLField(Author), - "branch": GraphQLField(GraphQLString), - }, -) - -Magazine = GraphQLObjectType( - "Magazine", - { - "id": GraphQLField(GraphQLInt), - "name": GraphQLField(GraphQLString), - "issue": GraphQLField(GraphQLInt), - "branch": GraphQLField(GraphQLString), - }, -) - - -Library = GraphQLObjectType( - "Library", - { - "id": GraphQLField(GraphQLInt), - "branch": GraphQLField(GraphQLString), - "book": GraphQLField(GraphQLList(Book)), - "magazine": GraphQLField(GraphQLList(Magazine)), - }, -) - -Storage = GraphQLList(GraphQLString) - - -def resolve_hello(root, info): - return "Hello!" 
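The try/except TypeError dance in the removed code below existed to straddle an API rename: graphql-core 2 named the field resolver kwarg resolver=, while graphql-core 3 renamed it to resolve=. The pattern in miniature (hypothetical field, for illustration only):

    try:
        # graphql-core 2.x spelling
        field = GraphQLField(GraphQLString, resolver=resolve_hello)
    except TypeError:
        # graphql-core 3.x renamed the kwarg, so core 3 raises TypeError above
        field = GraphQLField(GraphQLString, resolve=resolve_hello)

Dropping graphql-core 2 support is what lets the rewritten schema modules use resolve= unconditionally.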
- - -def resolve_echo(root, info, echo): - return echo - - -def resolve_error(root, info): - raise RuntimeError("Runtime Error!") - - -try: - hello_field = GraphQLField(GraphQLString, resolver=resolve_hello) - library_field = GraphQLField( - Library, - resolver=resolve_library, - args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, - ) - search_field = GraphQLField( - GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), - args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - echo_field = GraphQLField( - GraphQLString, - resolver=resolve_echo, - args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - storage_field = GraphQLField( - Storage, - resolver=resolve_storage, - ) - storage_add_field = GraphQLField( - Storage, - resolver=resolve_storage_add, - args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - error_field = GraphQLField(GraphQLString, resolver=resolve_error) - error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolver=resolve_error) - error_middleware_field = GraphQLField(GraphQLString, resolver=resolve_hello) -except TypeError: - hello_field = GraphQLField(GraphQLString, resolve=resolve_hello) - library_field = GraphQLField( - Library, - resolve=resolve_library, - args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, - ) - search_field = GraphQLField( - GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), - args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - echo_field = GraphQLField( - GraphQLString, - resolve=resolve_echo, - args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - storage_field = GraphQLField( - Storage, - resolve=resolve_storage, - ) - storage_add_field = GraphQLField( - GraphQLString, - resolve=resolve_storage_add, - args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - error_field = GraphQLField(GraphQLString, resolve=resolve_error) - error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolve=resolve_error) - error_middleware_field = GraphQLField(GraphQLString, resolve=resolve_hello) - -query = GraphQLObjectType( - name="Query", - fields={ - "hello": hello_field, - "library": library_field, - "search": search_field, - "echo": echo_field, - "storage": storage_field, - "error": error_field, - "error_non_null": error_non_null_field, - "error_middleware": error_middleware_field, - }, -) - -mutation = GraphQLObjectType( - name="Mutation", - fields={ - "storage_add": storage_add_field, - }, -) - -_target_application = GraphQLSchema(query=query, mutation=mutation) +from graphql.language.source import Source + +from ._target_schema_async import target_schema as target_schema_async +from ._target_schema_sync import target_schema as target_schema_sync + + +def check_response(query, response): + if isinstance(query, str) and "error" not in query or isinstance(query, Source) and "error" not in query.body: + assert not response.errors, response.errors + assert response.data + else: + assert response.errors + + +def run_sync(schema): + def _run_sync(query, middleware=None): + try: + from graphql import graphql_sync as graphql + except ImportError: + from graphql import graphql + + response = graphql(schema, query, middleware=middleware) + + check_response(query, response) + + return response.data + + return _run_sync + + +def run_async(schema): + import asyncio + + from graphql import graphql + + def _run_async(query, middleware=None): + coro = graphql(schema, 
query, middleware=middleware) + loop = asyncio.get_event_loop() + response = loop.run_until_complete(coro) + + check_response(query, response) + + return response.data + + return _run_async + + +target_application = { + "sync-sync": run_sync(target_schema_sync), + "async-sync": run_async(target_schema_sync), + "async-async": run_async(target_schema_async), +} diff --git a/tests/framework_graphql/_target_schema_async.py b/tests/framework_graphql/_target_schema_async.py new file mode 100644 index 000000000..aad4eb271 --- /dev/null +++ b/tests/framework_graphql/_target_schema_async.py @@ -0,0 +1,155 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from graphql import ( + GraphQLArgument, + GraphQLField, + GraphQLInt, + GraphQLList, + GraphQLNonNull, + GraphQLObjectType, + GraphQLSchema, + GraphQLString, + GraphQLUnionType, +) + +from ._target_schema_sync import books, libraries, magazines + +storage = [] + + +async def resolve_library(parent, info, index): + return libraries[index] + + +async def resolve_storage_add(parent, info, string): + storage.append(string) + return string + + +async def resolve_storage(parent, info): + return [storage.pop()] + + +async def resolve_search(parent, info, contains): + search_books = [b for b in books if contains in b["name"]] + search_magazines = [m for m in magazines if contains in m["name"]] + return search_books + search_magazines + + +Author = GraphQLObjectType( + "Author", + { + "first_name": GraphQLField(GraphQLString), + "last_name": GraphQLField(GraphQLString), + }, +) + +Book = GraphQLObjectType( + "Book", + { + "id": GraphQLField(GraphQLInt), + "name": GraphQLField(GraphQLString), + "isbn": GraphQLField(GraphQLString), + "author": GraphQLField(Author), + "branch": GraphQLField(GraphQLString), + }, +) + +Magazine = GraphQLObjectType( + "Magazine", + { + "id": GraphQLField(GraphQLInt), + "name": GraphQLField(GraphQLString), + "issue": GraphQLField(GraphQLInt), + "branch": GraphQLField(GraphQLString), + }, +) + + +Library = GraphQLObjectType( + "Library", + { + "id": GraphQLField(GraphQLInt), + "branch": GraphQLField(GraphQLString), + "book": GraphQLField(GraphQLList(Book)), + "magazine": GraphQLField(GraphQLList(Magazine)), + }, +) + +Storage = GraphQLList(GraphQLString) + + +async def resolve_hello(root, info): + return "Hello!" 
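Every resolver in this async module is a coroutine function, so graphql-core 3's graphql() entry point returns an awaitable here and the executor awaits each resolver in turn. A minimal execution sketch against the target_schema assembled at the bottom of this file (a sketch, not part of the test suite):

    import asyncio

    from graphql import graphql

    async def demo():
        result = await graphql(target_schema, "{ hello }")
        assert result.data == {"hello": "Hello!"}

    # asyncio.run(demo())  # uncomment to run standalone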
+
+
+async def resolve_echo(root, info, echo):
+    return echo
+
+
+async def resolve_error(root, info):
+    raise RuntimeError("Runtime Error!")
+
+
+hello_field = GraphQLField(GraphQLString, resolve=resolve_hello)
+library_field = GraphQLField(
+    Library,
+    resolve=resolve_library,
+    args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))},
+)
+search_field = GraphQLField(
+    GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)),
+    args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))},
+)
+echo_field = GraphQLField(
+    GraphQLString,
+    resolve=resolve_echo,
+    args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))},
+)
+storage_field = GraphQLField(
+    Storage,
+    resolve=resolve_storage,
+)
+storage_add_field = GraphQLField(
+    GraphQLString,
+    resolve=resolve_storage_add,
+    args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))},
+)
+error_field = GraphQLField(GraphQLString, resolve=resolve_error)
+error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolve=resolve_error)
+error_middleware_field = GraphQLField(GraphQLString, resolve=resolve_hello)
+
+query = GraphQLObjectType(
+    name="Query",
+    fields={
+        "hello": hello_field,
+        "library": library_field,
+        "search": search_field,
+        "echo": echo_field,
+        "storage": storage_field,
+        "error": error_field,
+        "error_non_null": error_non_null_field,
+        "error_middleware": error_middleware_field,
+    },
+)
+
+mutation = GraphQLObjectType(
+    name="Mutation",
+    fields={
+        "storage_add": storage_add_field,
+    },
+)
+
+target_schema = GraphQLSchema(query=query, mutation=mutation)
diff --git a/tests/framework_graphql/_target_schema_sync.py b/tests/framework_graphql/_target_schema_sync.py
new file mode 100644
index 000000000..302a6c66e
--- /dev/null
+++ b/tests/framework_graphql/_target_schema_sync.py
@@ -0,0 +1,210 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
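This sync module is the canonical one: it owns the authors/books/magazines/libraries data that _target_schema_async.py imports, so both schema variants serve identical fixtures. Once target_schema is assembled at the bottom of the file, a quick sanity check might look like this (a sketch, assuming graphql-core 3's graphql_sync):

    from graphql import graphql_sync

    result = graphql_sync(target_schema, "{ library(index: 0) { branch } }")
    assert result.errors is None
    assert result.data == {"library": {"branch": "riverside"}}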
+ +from graphql import ( + GraphQLArgument, + GraphQLField, + GraphQLInt, + GraphQLList, + GraphQLNonNull, + GraphQLObjectType, + GraphQLSchema, + GraphQLString, + GraphQLUnionType, +) + +authors = [ + { + "first_name": "New", + "last_name": "Relic", + }, + { + "first_name": "Bob", + "last_name": "Smith", + }, + { + "first_name": "Leslie", + "last_name": "Jones", + }, +] + +books = [ + { + "id": 1, + "name": "Python Agent: The Book", + "isbn": "a-fake-isbn", + "author": authors[0], + "branch": "riverside", + }, + { + "id": 2, + "name": "Ollies for O11y: A Sk8er's Guide to Observability", + "isbn": "a-second-fake-isbn", + "author": authors[1], + "branch": "downtown", + }, + { + "id": 3, + "name": "[Redacted]", + "isbn": "a-third-fake-isbn", + "author": authors[2], + "branch": "riverside", + }, +] + +magazines = [ + {"id": 1, "name": "Reli Updates Weekly", "issue": 1, "branch": "riverside"}, + {"id": 2, "name": "Reli Updates Weekly", "issue": 2, "branch": "downtown"}, + {"id": 3, "name": "Node Weekly", "issue": 1, "branch": "riverside"}, +] + + +libraries = ["riverside", "downtown"] +libraries = [ + { + "id": i + 1, + "branch": branch, + "magazine": [m for m in magazines if m["branch"] == branch], + "book": [b for b in books if b["branch"] == branch], + } + for i, branch in enumerate(libraries) +] + +storage = [] + + +def resolve_library(parent, info, index): + return libraries[index] + + +def resolve_storage_add(parent, info, string): + storage.append(string) + return string + + +def resolve_storage(parent, info): + return [storage.pop()] + + +def resolve_search(parent, info, contains): + search_books = [b for b in books if contains in b["name"]] + search_magazines = [m for m in magazines if contains in m["name"]] + return search_books + search_magazines + + +Author = GraphQLObjectType( + "Author", + { + "first_name": GraphQLField(GraphQLString), + "last_name": GraphQLField(GraphQLString), + }, +) + +Book = GraphQLObjectType( + "Book", + { + "id": GraphQLField(GraphQLInt), + "name": GraphQLField(GraphQLString), + "isbn": GraphQLField(GraphQLString), + "author": GraphQLField(Author), + "branch": GraphQLField(GraphQLString), + }, +) + +Magazine = GraphQLObjectType( + "Magazine", + { + "id": GraphQLField(GraphQLInt), + "name": GraphQLField(GraphQLString), + "issue": GraphQLField(GraphQLInt), + "branch": GraphQLField(GraphQLString), + }, +) + + +Library = GraphQLObjectType( + "Library", + { + "id": GraphQLField(GraphQLInt), + "branch": GraphQLField(GraphQLString), + "book": GraphQLField(GraphQLList(Book)), + "magazine": GraphQLField(GraphQLList(Magazine)), + }, +) + +Storage = GraphQLList(GraphQLString) + + +def resolve_hello(root, info): + return "Hello!" 
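One behavior worth flagging in the resolvers above: resolve_storage returns [storage.pop()], a destructive read, so each storage query consumes whatever the last storage_add mutation appended. The shared tests therefore always run the mutation before the query, roughly:

    storage.append("abc")               # what resolve_storage_add does
    assert [storage.pop()] == ["abc"]   # what resolve_storage returns
    assert storage == []                # a second bare query would raise IndexError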
+ + +def resolve_echo(root, info, echo): + return echo + + +def resolve_error(root, info): + raise RuntimeError("Runtime Error!") + + +hello_field = GraphQLField(GraphQLString, resolve=resolve_hello) +library_field = GraphQLField( + Library, + resolve=resolve_library, + args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, +) +search_field = GraphQLField( + GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), + args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +echo_field = GraphQLField( + GraphQLString, + resolve=resolve_echo, + args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +storage_field = GraphQLField( + Storage, + resolve=resolve_storage, +) +storage_add_field = GraphQLField( + GraphQLString, + resolve=resolve_storage_add, + args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +error_field = GraphQLField(GraphQLString, resolve=resolve_error) +error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolve=resolve_error) +error_middleware_field = GraphQLField(GraphQLString, resolve=resolve_hello) + +query = GraphQLObjectType( + name="Query", + fields={ + "hello": hello_field, + "library": library_field, + "search": search_field, + "echo": echo_field, + "storage": storage_field, + "error": error_field, + "error_non_null": error_non_null_field, + "error_middleware": error_middleware_field, + }, +) + +mutation = GraphQLObjectType( + name="Mutation", + fields={ + "storage_add": storage_add_field, + }, +) + +target_schema = GraphQLSchema(query=query, mutation=mutation) diff --git a/tests/framework_graphql/conftest.py b/tests/framework_graphql/conftest.py index 4d9e06758..5302da2b8 100644 --- a/tests/framework_graphql/conftest.py +++ b/tests/framework_graphql/conftest.py @@ -13,10 +13,12 @@ # limitations under the License. import pytest -import six - -from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture # noqa: F401; pylint: disable=W0611 +from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 + collector_agent_registration_fixture, + collector_available_fixture, +) +from newrelic.packages import six _default_settings = { "transaction_tracer.explain_threshold": 0.0, @@ -32,11 +34,16 @@ ) -@pytest.fixture(scope="session") -def app(): - from _target_application import _target_application +@pytest.fixture(scope="session", params=["sync-sync", "async-sync", "async-async"]) +def target_application(request): + from ._target_application import target_application + + app = target_application.get(request.param, None) + if app is None: + pytest.skip("Unsupported combination.") + return - return _target_application + return "GraphQL", None, app, True, request.param.split("-")[1], 0 if six.PY2: diff --git a/tests/framework_graphql/test_application.py b/tests/framework_graphql/test_application.py index dd49ee37f..b5d78699d 100644 --- a/tests/framework_graphql/test_application.py +++ b/tests/framework_graphql/test_application.py @@ -13,6 +13,10 @@ # limitations under the License. 
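This conftest, like the per-framework suites (see the Graphene fixture earlier in this diff), now yields a six-element tuple rather than a bare schema, and every shared test in test_application.py below starts by unpacking it:

    # (framework, version, runner, is_background, schema_type, extra_spans)
    framework, version, target_application, is_bg, schema_type, extra_spans = target_application

    # e.g. this conftest yields ("GraphQL", None, <runner>, True, "sync", 0), while
    # framework_graphene yields ("Graphene", GRAPHENE_VERSION, <runner>, True, schema_type, 0)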
import pytest +from framework_graphql.test_application_async import ( + error_middleware_async, + example_middleware_async, +) from testing_support.fixtures import dt_enabled, override_application_settings from testing_support.validators.validate_code_level_metrics import ( validate_code_level_metrics, @@ -30,24 +34,18 @@ from newrelic.api.background_task import background_task from newrelic.common.object_names import callable_name +from newrelic.common.package_version_utils import get_package_version -@pytest.fixture(scope="session") -def is_graphql_2(): - from graphql import __version__ as version - - major_version = int(version.split(".")[0]) - return major_version == 2 +graphql_version = get_package_version("graphql-core") +def conditional_decorator(decorator, condition): + def _conditional_decorator(func): + if not condition: + return func + return decorator(func) -@pytest.fixture(scope="session") -def graphql_run(): - try: - from graphql import graphql_sync as graphql - except ImportError: - from graphql import graphql - - return graphql + return _conditional_decorator def to_graphql_source(query): @@ -58,13 +56,6 @@ def delay_import(): # Fallback if Source is not implemented return query - from graphql import __version__ as version - - # For graphql2, Source objects aren't acceptable input - major_version = int(version.split(".")[0]) - if major_version == 2: - return query - return Source(query) return delay_import @@ -79,66 +70,86 @@ def error_middleware(next, root, info, **args): raise RuntimeError("Runtime Error!") -def test_no_harm_no_transaction(app, graphql_run): +def test_no_harm_no_transaction(target_application): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application + def _test(): - response = graphql_run(app, "{ __schema { types { name } } }") - assert not response.errors + response = target_application("{ __schema { types { name } } }") _test() +example_middleware = [example_middleware] +error_middleware = [error_middleware] + +example_middleware.append(example_middleware_async) +error_middleware.append(error_middleware_async) + _runtime_error_name = callable_name(RuntimeError) _test_runtime_error = [(_runtime_error_name, "Runtime Error!")] -_graphql_base_rollup_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 1), - ("GraphQL/allOther", 1), - ("GraphQL/GraphQL/all", 1), - ("GraphQL/GraphQL/allOther", 1), -] -def test_basic(app, graphql_run): - from graphql import __version__ as version +def _graphql_base_rollup_metrics(framework, version, background_task=True): + graphql_version = get_package_version("graphql-core") - FRAMEWORK_METRICS = [ - ("Python/Framework/GraphQL/%s" % version, 1), + metrics = [ + ("Python/Framework/GraphQL/%s" % graphql_version, 1), + ("GraphQL/all", 1), + ("GraphQL/%s/all" % framework, 1), ] + if background_task: + metrics.extend( + [ + ("GraphQL/allOther", 1), + ("GraphQL/%s/allOther" % framework, 1), + ] + ) + else: + metrics.extend( + [ + ("GraphQL/allWeb", 1), + ("GraphQL/%s/allWeb" % framework, 1), + ] + ) + + if framework != "GraphQL": + metrics.append(("Python/Framework/%s/%s" % (framework, version), 1)) + + return metrics + + +def test_basic(target_application): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application @validate_transaction_metrics( "query//hello", "GraphQL", - rollup_metrics=_graphql_base_rollup_metrics + FRAMEWORK_METRICS, - background_task=True, + rollup_metrics=_graphql_base_rollup_metrics(framework, version, is_bg), + 
background_task=is_bg, ) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, "{ hello }") - assert not response.errors + response = target_application("{ hello }") + assert response["hello"] == "Hello!" _test() @dt_enabled -def test_query_and_mutation(app, graphql_run, is_graphql_2): - from graphql import __version__ as version +def test_query_and_mutation(target_application): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application + + mutation_path = "storage_add" if framework != "Graphene" else "storage_add.string" + type_annotation = "!" if framework == "Strawberry" else "" - FRAMEWORK_METRICS = [ - ("Python/Framework/GraphQL/%s" % version, 1), - ] _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/GraphQL/storage", 1), - ("GraphQL/resolve/GraphQL/storage_add", 1), - ("GraphQL/operation/GraphQL/query//storage", 1), - ("GraphQL/operation/GraphQL/mutation//storage_add", 1), + ("GraphQL/resolve/%s/storage_add" % framework, 1), + ("GraphQL/operation/%s/mutation//%s" % (framework, mutation_path), 1), + ] + _test_query_scoped_metrics = [ + ("GraphQL/resolve/%s/storage" % framework, 1), + ("GraphQL/operation/%s/query//storage" % framework, 1), ] - _test_mutation_unscoped_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 2), - ("GraphQL/GraphQL/all", 2), - ("GraphQL/allOther", 2), - ("GraphQL/GraphQL/allOther", 2), - ] + _test_mutation_scoped_metrics _expected_mutation_operation_attributes = { "graphql.operation.type": "mutation", @@ -148,7 +159,7 @@ def test_query_and_mutation(app, graphql_run, is_graphql_2): "graphql.field.name": "storage_add", "graphql.field.parentType": "Mutation", "graphql.field.path": "storage_add", - "graphql.field.returnType": "[String]" if is_graphql_2 else "String", + "graphql.field.returnType": ("String" if framework != "Graphene" else "StorageAdd") + type_annotation, } _expected_query_operation_attributes = { "graphql.operation.type": "query", @@ -158,78 +169,108 @@ def test_query_and_mutation(app, graphql_run, is_graphql_2): "graphql.field.name": "storage", "graphql.field.parentType": "Query", "graphql.field.path": "storage", - "graphql.field.returnType": "[String]", + "graphql.field.returnType": "[String%s]%s" % (type_annotation, type_annotation), } - @validate_code_level_metrics("_target_application", "resolve_storage") - @validate_code_level_metrics("_target_application", "resolve_storage_add") + @validate_code_level_metrics( + "framework_%s._target_schema_%s" % (framework.lower(), schema_type), "resolve_storage_add" + ) + @validate_span_events(exact_agents=_expected_mutation_operation_attributes) + @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) @validate_transaction_metrics( - "query//storage", + "mutation//%s" % mutation_path, "GraphQL", scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - background_task=True, + rollup_metrics=_test_mutation_scoped_metrics + _graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) + @conditional_decorator(background_task(), is_bg) + def _mutation(): + if framework == "Graphene": + query = 'mutation { storage_add(string: "abc") { string } }' + else: + query = 'mutation { storage_add(string: "abc") }' + response = target_application(query) + 
assert response["storage_add"] == "abc" or response["storage_add"]["string"] == "abc" + + @validate_code_level_metrics("framework_%s._target_schema_%s" % (framework.lower(), schema_type), "resolve_storage") @validate_span_events(exact_agents=_expected_query_operation_attributes) @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @background_task() - def _test(): - response = graphql_run(app, 'mutation { storage_add(string: "abc") }') - assert not response.errors - response = graphql_run(app, "query { storage }") - assert not response.errors - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.data) - assert "abc" in str(response.data) + @validate_transaction_metrics( + "query//storage", + "GraphQL", + scoped_metrics=_test_query_scoped_metrics, + rollup_metrics=_test_query_scoped_metrics + _graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, + ) + @conditional_decorator(background_task(), is_bg) + def _query(): + response = target_application("query { storage }") + assert response["storage"] == ["abc"] - _test() + _mutation() + _query() +@pytest.mark.parametrize("middleware", example_middleware) @dt_enabled -def test_middleware(app, graphql_run, is_graphql_2): +def test_middleware(target_application, middleware): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application + + name = "%s:%s" % (middleware.__module__, middleware.__name__) + if "async" in name: + if schema_type != "async": + pytest.skip("Async middleware not supported in sync applications.") + _test_middleware_metrics = [ - ("GraphQL/operation/GraphQL/query//hello", 1), - ("GraphQL/resolve/GraphQL/hello", 1), - ("Function/test_application:example_middleware", 1), + ("GraphQL/operation/%s/query//hello" % framework, 1), + ("GraphQL/resolve/%s/hello" % framework, 1), + ("Function/%s" % name, 1), ] - @validate_code_level_metrics("test_application", "example_middleware") - @validate_code_level_metrics("_target_application", "resolve_hello") + # Span count 5: Transaction, Operation, Middleware, and 1 Resolver and Resolver Function + span_count = 5 + extra_spans + + @validate_code_level_metrics(*name.split(":")) + @validate_code_level_metrics("framework_%s._target_schema_%s" % (framework.lower(), schema_type), "resolve_hello") + @validate_span_events(count=span_count) @validate_transaction_metrics( "query//hello", "GraphQL", scoped_metrics=_test_middleware_metrics, - rollup_metrics=_test_middleware_metrics + _graphql_base_rollup_metrics, - background_task=True, + rollup_metrics=_test_middleware_metrics + _graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, ) - # Span count 5: Transaction, Operation, Middleware, and 1 Resolver and Resolver Function - @validate_span_events(count=5) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, "{ hello }", middleware=[example_middleware]) - assert not response.errors - assert "Hello!" in str(response.data) + response = target_application("{ hello }", middleware=[middleware]) + assert response["hello"] == "Hello!" 
_test() +@pytest.mark.parametrize("middleware", error_middleware) @dt_enabled -def test_exception_in_middleware(app, graphql_run): - query = "query MyQuery { hello }" - field = "hello" +def test_exception_in_middleware(target_application, middleware): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application + query = "query MyQuery { error_middleware }" + field = "error_middleware" + + name = "%s:%s" % (middleware.__module__, middleware.__name__) + if "async" in name: + if schema_type != "async": + pytest.skip("Async middleware not supported in sync applications.") # Metrics _test_exception_scoped_metrics = [ - ("GraphQL/operation/GraphQL/query/MyQuery/%s" % field, 1), - ("GraphQL/resolve/GraphQL/%s" % field, 1), + ("GraphQL/operation/%s/query/MyQuery/%s" % (framework, field), 1), + ("GraphQL/resolve/%s/%s" % (framework, field), 1), + ("Function/%s" % name, 1), ] _test_exception_rollup_metrics = [ ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/test_application:error_middleware", 1), + ("Errors/all%s" % ("Other" if is_bg else "Web"), 1), + ("Errors/%sTransaction/GraphQL/%s" % ("Other" if is_bg else "Web", name), 1), ] + _test_exception_scoped_metrics # Attributes @@ -246,39 +287,39 @@ def test_exception_in_middleware(app, graphql_run): } @validate_transaction_metrics( - "test_application:error_middleware", + name, "GraphQL", scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, + rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, ) @validate_span_events(exact_agents=_expected_exception_operation_attributes) @validate_span_events(exact_agents=_expected_exception_resolver_attributes) @validate_transaction_errors(errors=_test_runtime_error) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, query, middleware=[error_middleware]) - assert response.errors + response = target_application(query, middleware=[middleware]) _test() @pytest.mark.parametrize("field", ("error", "error_non_null")) @dt_enabled -def test_exception_in_resolver(app, graphql_run, field): +def test_exception_in_resolver(target_application, field): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application query = "query MyQuery { %s }" % field - txn_name = "_target_application:resolve_error" + txn_name = "framework_%s._target_schema_%s:resolve_error" % (framework.lower(), schema_type) # Metrics _test_exception_scoped_metrics = [ - ("GraphQL/operation/GraphQL/query/MyQuery/%s" % field, 1), - ("GraphQL/resolve/GraphQL/%s" % field, 1), + ("GraphQL/operation/%s/query/MyQuery/%s" % (framework, field), 1), + ("GraphQL/resolve/%s/%s" % (framework, field), 1), ] _test_exception_rollup_metrics = [ ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/%s" % txn_name, 1), + ("Errors/all%s" % ("Other" if is_bg else "Web"), 1), + ("Errors/%sTransaction/GraphQL/%s" % ("Other" if is_bg else "Web", txn_name), 1), ] + _test_exception_scoped_metrics # Attributes @@ -298,16 +339,15 @@ def test_exception_in_resolver(app, graphql_run, field): txn_name, "GraphQL", scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, + rollup_metrics=_test_exception_rollup_metrics + 
_graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, ) @validate_span_events(exact_agents=_expected_exception_operation_attributes) @validate_span_events(exact_agents=_expected_exception_resolver_attributes) @validate_transaction_errors(errors=_test_runtime_error) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, query) - assert response.errors + response = target_application(query) _test() @@ -316,18 +356,16 @@ def _test(): @pytest.mark.parametrize( "query,exc_class", [ - ("query MyQuery { missing_field }", "GraphQLError"), + ("query MyQuery { error_missing_field }", "GraphQLError"), ("{ syntax_error ", "graphql.error.syntax_error:GraphQLSyntaxError"), ], ) -def test_exception_in_validation(app, graphql_run, is_graphql_2, query, exc_class): +def test_exception_in_validation(target_application, query, exc_class): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application if "syntax" in query: txn_name = "graphql.language.parser:parse" else: - if is_graphql_2: - txn_name = "graphql.validation.validation:validate" - else: - txn_name = "graphql.validation.validate:validate" + txn_name = "graphql.validation.validate:validate" # Import path differs between versions if exc_class == "GraphQLError": @@ -336,12 +374,12 @@ def test_exception_in_validation(app, graphql_run, is_graphql_2, query, exc_clas exc_class = callable_name(GraphQLError) _test_exception_scoped_metrics = [ - # ('GraphQL/operation/GraphQL///', 1), + ("GraphQL/operation/%s///" % framework, 1), ] _test_exception_rollup_metrics = [ ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/%s" % txn_name, 1), + ("Errors/all%s" % ("Other" if is_bg else "Web"), 1), + ("Errors/%sTransaction/GraphQL/%s" % ("Other" if is_bg else "Web", txn_name), 1), ] + _test_exception_scoped_metrics # Attributes @@ -355,72 +393,77 @@ def test_exception_in_validation(app, graphql_run, is_graphql_2, query, exc_clas txn_name, "GraphQL", scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, + rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, ) @validate_span_events(exact_agents=_expected_exception_operation_attributes) @validate_transaction_errors(errors=[exc_class]) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, query) - assert response.errors + response = target_application(query) _test() @dt_enabled -def test_operation_metrics_and_attrs(app, graphql_run): - operation_metrics = [("GraphQL/operation/GraphQL/query/MyQuery/library", 1)] +def test_operation_metrics_and_attrs(target_application): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application + operation_metrics = [("GraphQL/operation/%s/query/MyQuery/library" % framework, 1)] operation_attrs = { "graphql.operation.type": "query", "graphql.operation.name": "MyQuery", } + # Span count 16: Transaction, Operation, and 7 Resolvers and Resolver functions + # library, library.name, library.book + # library.book.name and library.book.id for each book resolved (in this case 2) + span_count = 16 + extra_spans # WSGI may add 4 spans, other frameworks may add other amounts + @validate_transaction_metrics( "query/MyQuery/library", "GraphQL", scoped_metrics=operation_metrics, - 
rollup_metrics=operation_metrics + _graphql_base_rollup_metrics, - background_task=True, + rollup_metrics=operation_metrics + _graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, ) - # Span count 16: Transaction, Operation, and 7 Resolvers and Resolver functions - # library, library.name, library.book - # library.book.name and library.book.id for each book resolved (in this case 2) - @validate_span_events(count=16) + @validate_span_events(count=span_count) @validate_span_events(exact_agents=operation_attrs) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, "query MyQuery { library(index: 0) { branch, book { id, name } } }") - assert not response.errors + response = target_application("query MyQuery { library(index: 0) { branch, book { id, name } } }") _test() @dt_enabled -def test_field_resolver_metrics_and_attrs(app, graphql_run): - field_resolver_metrics = [("GraphQL/resolve/GraphQL/hello", 1)] +def test_field_resolver_metrics_and_attrs(target_application): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application + field_resolver_metrics = [("GraphQL/resolve/%s/hello" % framework, 1)] + + type_annotation = "!" if framework == "Strawberry" else "" graphql_attrs = { "graphql.field.name": "hello", "graphql.field.parentType": "Query", "graphql.field.path": "hello", - "graphql.field.returnType": "String", + "graphql.field.returnType": "String" + type_annotation, } + # Span count 4: Transaction, Operation, and 1 Resolver and Resolver function + span_count = 4 + extra_spans # WSGI may add 4 spans, other frameworks may add other amounts + @validate_transaction_metrics( "query//hello", "GraphQL", scoped_metrics=field_resolver_metrics, - rollup_metrics=field_resolver_metrics + _graphql_base_rollup_metrics, - background_task=True, + rollup_metrics=field_resolver_metrics + _graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, ) - # Span count 4: Transaction, Operation, and 1 Resolver and Resolver function - @validate_span_events(count=4) + @validate_span_events(count=span_count) @validate_span_events(exact_agents=graphql_attrs) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, "{ hello }") - assert not response.errors - assert "Hello!" in str(response.data) + response = target_application("{ hello }") + assert response["hello"] == "Hello!" 
_test() @@ -443,18 +486,19 @@ def _test(): @dt_enabled @pytest.mark.parametrize("query,obfuscated", _test_queries) -def test_query_obfuscation(app, graphql_run, query, obfuscated): +def test_query_obfuscation(target_application, query, obfuscated): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application graphql_attrs = {"graphql.operation.query": obfuscated} if callable(query): + if framework != "GraphQL": + pytest.skip("Source query objects not tested outside of graphql-core") query = query() @validate_span_events(exact_agents=graphql_attrs) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, query) - if not isinstance(query, str) or "error" not in query: - assert not response.errors + response = target_application(query) _test() @@ -499,28 +543,28 @@ def _test(): @dt_enabled @pytest.mark.parametrize("query,expected_path", _test_queries) -def test_deepest_unique_path(app, graphql_run, query, expected_path): +def test_deepest_unique_path(target_application, query, expected_path): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application if expected_path == "/error": - txn_name = "_target_application:resolve_error" + txn_name = "framework_%s._target_schema_%s:resolve_error" % (framework.lower(), schema_type) else: txn_name = "query/%s" % expected_path @validate_transaction_metrics( txn_name, "GraphQL", - background_task=True, + background_task=is_bg, ) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, query) - if "error" not in query: - assert not response.errors + response = target_application(query) _test() @pytest.mark.parametrize("capture_introspection_setting", (True, False)) -def test_introspection_transactions(app, graphql_run, capture_introspection_setting): +def test_introspection_transactions(target_application, capture_introspection_setting): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application txn_ct = 1 if capture_introspection_setting else 0 @override_application_settings( @@ -529,7 +573,6 @@ def test_introspection_transactions(app, graphql_run, capture_introspection_sett @validate_transaction_count(txn_ct) @background_task() def _test(): - response = graphql_run(app, "{ __schema { types { name } } }") - assert not response.errors + response = target_application("{ __schema { types { name } } }") _test() diff --git a/tests/framework_graphql/test_application_async.py b/tests/framework_graphql/test_application_async.py index 28b435c43..39c1871ef 100644 --- a/tests/framework_graphql/test_application_async.py +++ b/tests/framework_graphql/test_application_async.py @@ -12,99 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
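The async middleware helpers that replace this file's old tests guard with isawaitable() because a middleware chain can sit in front of both sync and async resolvers: next() may hand back either a plain value or an awaitable, and awaiting a plain value raises TypeError. A hedged usage sketch (graphql-core 3 accepts a middleware= list and wraps it in its MiddlewareManager):

    import asyncio

    from graphql import graphql

    async def demo(schema):
        result = await graphql(schema, "{ hello }", middleware=[example_middleware_async])
        assert result.data == {"hello": "Hello!"}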
-import asyncio +from inspect import isawaitable -import pytest -from test_application import is_graphql_2 -from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_span_events import validate_span_events -from newrelic.api.background_task import background_task +# Async Functions not allowed in Py2 +async def example_middleware_async(next, root, info, **args): + return_value = next(root, info, **args) + if isawaitable(return_value): + return await return_value + return return_value -@pytest.fixture(scope="session") -def graphql_run_async(): - from graphql import __version__ as version - from graphql import graphql - - major_version = int(version.split(".")[0]) - if major_version == 2: - - def graphql_run(*args, **kwargs): - return graphql(*args, return_promise=True, **kwargs) - - return graphql_run - else: - return graphql - - -@dt_enabled -def test_query_and_mutation_async(app, graphql_run_async, is_graphql_2): - from graphql import __version__ as version - - FRAMEWORK_METRICS = [ - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/GraphQL/storage", 1), - ("GraphQL/resolve/GraphQL/storage_add", 1), - ("GraphQL/operation/GraphQL/query//storage", 1), - ("GraphQL/operation/GraphQL/mutation//storage_add", 1), - ] - _test_mutation_unscoped_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 2), - ("GraphQL/GraphQL/all", 2), - ("GraphQL/allOther", 2), - ("GraphQL/GraphQL/allOther", 2), - ] + _test_mutation_scoped_metrics - - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "[String]" if is_graphql_2 else "String", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String]", - } - - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @background_task() - def _test(): - async def coro(): - response = await graphql_run_async(app, 'mutation { storage_add(string: "abc") }') - assert not response.errors - response = await graphql_run_async(app, "query { storage }") - assert not response.errors - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.data) - assert "abc" in str(response.data) - - loop = asyncio.new_event_loop() - loop.run_until_complete(coro()) - - _test() +async def error_middleware_async(next, root, info, **args): + raise RuntimeError("Runtime Error!") diff --git a/tests/framework_starlette/test_application.py 
b/tests/framework_starlette/test_application.py index 7d36d66cc..bd89bb9a9 100644 --- a/tests/framework_starlette/test_application.py +++ b/tests/framework_starlette/test_application.py @@ -17,13 +17,21 @@ import pytest import starlette from testing_support.fixtures import override_ignore_status_codes +from testing_support.validators.validate_code_level_metrics import ( + validate_code_level_metrics, +) +from testing_support.validators.validate_transaction_errors import ( + validate_transaction_errors, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.common.object_names import callable_name -from testing_support.validators.validate_code_level_metrics import validate_code_level_metrics -from testing_support.validators.validate_transaction_errors import validate_transaction_errors -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from newrelic.common.package_version_utils import get_package_version_tuple + +starlette_version = get_package_version_tuple("starlette")[:3] -starlette_version = tuple(int(x) for x in starlette.__version__.split(".")) @pytest.fixture(scope="session") def target_application(): @@ -78,6 +86,7 @@ def test_application_non_async(target_application, app_name): response = app.get("/non_async") assert response.status == 200 + # Starting in Starlette v0.20.1, the ExceptionMiddleware class # has been moved to the starlette.middleware.exceptions from # starlette.exceptions @@ -96,8 +105,10 @@ def test_application_non_async(target_application, app_name): ), ) + @pytest.mark.parametrize( - "app_name, transaction_name", middleware_test, + "app_name, transaction_name", + middleware_test, ) def test_application_nonexistent_route(target_application, app_name, transaction_name): @validate_transaction_metrics( @@ -117,10 +128,6 @@ def _test(): def test_exception_in_middleware(target_application, app_name): app = target_application[app_name] - from starlette import __version__ as version - - starlette_version = tuple(int(v) for v in version.split(".")) - # Starlette >=0.15 and <0.17 raises an exception group instead of reraising the ValueError # This only occurs on Python versions >=3.8 if sys.version_info[0:2] > (3, 7) and starlette_version >= (0, 15, 0) and starlette_version < (0, 17, 0): @@ -272,9 +279,8 @@ def _test(): ), ) -@pytest.mark.parametrize( - "app_name,scoped_metrics", middleware_test_exception -) + +@pytest.mark.parametrize("app_name,scoped_metrics", middleware_test_exception) def test_starlette_http_exception(target_application, app_name, scoped_metrics): @validate_transaction_errors(errors=["starlette.exceptions:HTTPException"]) @validate_transaction_metrics( diff --git a/tests/framework_starlette/test_bg_tasks.py b/tests/framework_starlette/test_bg_tasks.py index 07a70131b..9ad8fe61b 100644 --- a/tests/framework_starlette/test_bg_tasks.py +++ b/tests/framework_starlette/test_bg_tasks.py @@ -15,7 +15,6 @@ import sys import pytest -from starlette import __version__ from testing_support.validators.validate_transaction_count import ( validate_transaction_count, ) @@ -23,7 +22,9 @@ validate_transaction_metrics, ) -starlette_version = tuple(int(x) for x in __version__.split(".")) +from newrelic.common.package_version_utils import get_package_version_tuple + +starlette_version = get_package_version_tuple("starlette")[:3] try: from starlette.middleware import Middleware # noqa: F401 @@ -89,11 +90,20 @@ def _test(): # The bug was fixed in version 0.21.0 
but re-occurred in 0.23.1. # The bug was also not present on 0.20.1 to 0.23.1 if using Python3.7. - BUG_COMPLETELY_FIXED = (0, 21, 0) <= starlette_version < (0, 23, 1) or ( - (0, 20, 1) <= starlette_version < (0, 23, 1) and sys.version_info[:2] > (3, 7) + # The bug was fixed again in version 0.29.0 + BUG_COMPLETELY_FIXED = any( + ( + (0, 21, 0) <= starlette_version < (0, 23, 1), + (0, 20, 1) <= starlette_version < (0, 23, 1) and sys.version_info[:2] > (3, 7), + starlette_version >= (0, 29, 0), + ) + ) + BUG_PARTIALLY_FIXED = any( + ( + (0, 20, 1) <= starlette_version < (0, 21, 0), + (0, 23, 1) <= starlette_version < (0, 29, 0), + ) ) - BUG_PARTIALLY_FIXED = (0, 20, 1) <= starlette_version < (0, 21, 0) or starlette_version >= (0, 23, 1) - if BUG_COMPLETELY_FIXED: # Assert both web transaction and background task transactions are present. _test = validate_transaction_metrics( diff --git a/tests/framework_starlette/test_graphql.py b/tests/framework_starlette/test_graphql.py deleted file mode 100644 index 24ec3ab38..000000000 --- a/tests/framework_starlette/test_graphql.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -import pytest -from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_span_events import validate_span_events - - -def get_starlette_version(): - import starlette - - version = getattr(starlette, "__version__", "0.0.0").split(".") - return tuple(int(x) for x in version) - - -@pytest.fixture(scope="session") -def target_application(): - import _test_graphql - - return _test_graphql.target_application - - -@dt_enabled -@pytest.mark.parametrize("endpoint", ("/async", "/sync")) -@pytest.mark.skipif(get_starlette_version() >= (0, 17), reason="Starlette GraphQL support dropped in v0.17.0") -def test_graphql_metrics_and_attrs(target_application, endpoint): - from graphql import __version__ as version - - from newrelic.hooks.framework_graphene import framework_details - - FRAMEWORK_METRICS = [ - ("Python/Framework/Graphene/%s" % framework_details()[1], 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_scoped_metrics = [ - ("GraphQL/resolve/Graphene/hello", 1), - ("GraphQL/operation/Graphene/query//hello", 1), - ] - _test_unscoped_metrics = [ - ("GraphQL/all", 1), - ("GraphQL/Graphene/all", 1), - ("GraphQL/allWeb", 1), - ("GraphQL/Graphene/allWeb", 1), - ] + _test_scoped_metrics - - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - "graphql.operation.query": "{ hello }", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "hello", - "graphql.field.parentType": "Query", - "graphql.field.path": "hello", - "graphql.field.returnType": "String", - } - - @validate_span_events(exact_agents=_expected_query_operation_attributes) -
@validate_span_events(exact_agents=_expected_query_resolver_attributes) - @validate_transaction_metrics( - "query//hello", - "GraphQL", - scoped_metrics=_test_scoped_metrics, - rollup_metrics=_test_unscoped_metrics + FRAMEWORK_METRICS, - ) - def _test(): - response = target_application.make_request( - "POST", endpoint, body=json.dumps({"query": "{ hello }"}), headers={"Content-Type": "application/json"} - ) - assert response.status == 200 - assert "Hello!" in response.body.decode("utf-8") - - _test() diff --git a/tests/framework_strawberry/__init__.py b/tests/framework_strawberry/__init__.py new file mode 100644 index 000000000..8030baccf --- /dev/null +++ b/tests/framework_strawberry/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/framework_strawberry/_target_application.py b/tests/framework_strawberry/_target_application.py index e032fc27a..afba04873 100644 --- a/tests/framework_strawberry/_target_application.py +++ b/tests/framework_strawberry/_target_application.py @@ -12,185 +12,90 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union - -import strawberry.mutation -import strawberry.type -from strawberry import Schema, field -from strawberry.asgi import GraphQL -from strawberry.schema.config import StrawberryConfig -from strawberry.types.types import Optional - - -@strawberry.type -class Author: - first_name: str - last_name: str - - -@strawberry.type -class Book: - id: int - name: str - isbn: str - author: Author - branch: str - - -@strawberry.type -class Magazine: - id: int - name: str - issue: int - branch: str - -@strawberry.type -class Library: - id: int - branch: str - magazine: List[Magazine] - book: List[Book] +import asyncio +import json +import pytest +from framework_strawberry._target_schema_async import ( + target_asgi_application as target_asgi_application_async, +) +from framework_strawberry._target_schema_async import ( + target_schema as target_schema_async, +) +from framework_strawberry._target_schema_sync import ( + target_asgi_application as target_asgi_application_sync, +) +from framework_strawberry._target_schema_sync import target_schema as target_schema_sync -Item = Union[Book, Magazine] -Storage = List[str] +def run_sync(schema): + def _run_sync(query, middleware=None): + from graphql.language.source import Source -authors = [ - Author( - first_name="New", - last_name="Relic", - ), - Author( - first_name="Bob", - last_name="Smith", - ), - Author( - first_name="Leslie", - last_name="Jones", - ), -] - -books = [ - Book( - id=1, - name="Python Agent: The Book", - isbn="a-fake-isbn", - author=authors[0], - branch="riverside", - ), - Book( - id=2, - name="Ollies for O11y: A Sk8er's Guide to Observability", - isbn="a-second-fake-isbn", - author=authors[1], - branch="downtown", - ), - Book( - id=3, - name="[Redacted]", - isbn="a-third-fake-isbn", - author=authors[2], - branch="riverside", - ), -] + if 
middleware is not None: + pytest.skip("Middleware not supported in Strawberry.") -magazines = [ - Magazine(id=1, name="Reli Updates Weekly", issue=1, branch="riverside"), - Magazine(id=2, name="Reli: The Forgotten Years", issue=2, branch="downtown"), - Magazine(id=3, name="Node Weekly", issue=1, branch="riverside"), -] + response = schema.execute_sync(query) + if isinstance(query, str) and "error" not in query or isinstance(query, Source) and "error" not in query.body: + assert not response.errors + else: + assert response.errors -libraries = ["riverside", "downtown"] -libraries = [ - Library( - id=i + 1, - branch=branch, - magazine=[m for m in magazines if m.branch == branch], - book=[b for b in books if b.branch == branch], - ) - for i, branch in enumerate(libraries) -] + return response.data -storage = [] + return _run_sync -def resolve_hello(): - return "Hello!" +def run_async(schema): + def _run_async(query, middleware=None): + from graphql.language.source import Source + if middleware is not None: + pytest.skip("Middleware not supported in Strawberry.") -async def resolve_hello_async(): - return "Hello!" + loop = asyncio.get_event_loop() + response = loop.run_until_complete(schema.execute(query)) + if isinstance(query, str) and "error" not in query or isinstance(query, Source) and "error" not in query.body: + assert not response.errors + else: + assert response.errors -def resolve_echo(echo: str): - return echo + return response.data + return _run_async -def resolve_library(index: int): - return libraries[index] +def run_asgi(app): + def _run_asgi(query, middleware=None): + if middleware is not None: + pytest.skip("Middleware not supported in Strawberry.") -def resolve_storage_add(string: str): - storage.add(string) - return storage + response = app.make_request( + "POST", "/", body=json.dumps({"query": query}), headers={"Content-Type": "application/json"} + ) + body = json.loads(response.body.decode("utf-8")) + if not isinstance(query, str) or "error" in query: + try: + assert response.status != 200 + except AssertionError: + assert body["errors"] + else: + assert response.status == 200 + assert "errors" not in body or not body["errors"] -def resolve_storage(): - return storage + return body["data"] + return _run_asgi -def resolve_error(): - raise RuntimeError("Runtime Error!") - -def resolve_search(contains: str): - search_books = [b for b in books if contains in b.name] - search_magazines = [m for m in magazines if contains in m.name] - return search_books + search_magazines - - -@strawberry.type -class Query: - library: Library = field(resolver=resolve_library) - hello: str = field(resolver=resolve_hello) - hello_async: str = field(resolver=resolve_hello_async) - search: List[Item] = field(resolver=resolve_search) - echo: str = field(resolver=resolve_echo) - storage: Storage = field(resolver=resolve_storage) - error: Optional[str] = field(resolver=resolve_error) - error_non_null: str = field(resolver=resolve_error) - - def resolve_library(self, info, index): - return libraries[index] - - def resolve_storage(self, info): - return storage - - def resolve_search(self, info, contains): - search_books = [b for b in books if contains in b.name] - search_magazines = [m for m in magazines if contains in m.name] - return search_books + search_magazines - - def resolve_hello(self, info): - return "Hello!" 
- - def resolve_echo(self, info, echo): - return echo - - def resolve_error(self, info) -> str: - raise RuntimeError("Runtime Error!") - - -@strawberry.type -class Mutation: - @strawberry.mutation - def storage_add(self, string: str) -> str: - storage.append(string) - return str(string) - - -_target_application = Schema(query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False)) -_target_asgi_application = GraphQL(_target_application) +target_application = { + "sync-sync": run_sync(target_schema_sync), + "async-sync": run_async(target_schema_sync), + "asgi-sync": run_asgi(target_asgi_application_sync), + "async-async": run_async(target_schema_async), + "asgi-async": run_asgi(target_asgi_application_async), +} diff --git a/tests/framework_strawberry/_target_schema_async.py b/tests/framework_strawberry/_target_schema_async.py new file mode 100644 index 000000000..373cef537 --- /dev/null +++ b/tests/framework_strawberry/_target_schema_async.py @@ -0,0 +1,84 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import strawberry.mutation +import strawberry.type +from framework_strawberry._target_schema_sync import ( + Item, + Library, + Storage, + books, + libraries, + magazines, +) +from strawberry import Schema, field +from strawberry.asgi import GraphQL +from strawberry.schema.config import StrawberryConfig +from strawberry.types.types import Optional +from testing_support.asgi_testing import AsgiTest + +storage = [] + + +async def resolve_hello(): + return "Hello!" 
+ + +async def resolve_echo(echo: str): + return echo + + +async def resolve_library(index: int): + return libraries[index] + + +async def resolve_storage_add(string: str): + storage.append(string) + return string + + +async def resolve_storage(): + return [storage.pop()] + + +async def resolve_error(): + raise RuntimeError("Runtime Error!") + + +async def resolve_search(contains: str): + search_books = [b for b in books if contains in b.name] + search_magazines = [m for m in magazines if contains in m.name] + return search_books + search_magazines + + +@strawberry.type +class Query: + library: Library = field(resolver=resolve_library) + hello: str = field(resolver=resolve_hello) + search: List[Item] = field(resolver=resolve_search) + echo: str = field(resolver=resolve_echo) + storage: Storage = field(resolver=resolve_storage) + error: Optional[str] = field(resolver=resolve_error) + error_non_null: str = field(resolver=resolve_error) + + +@strawberry.type +class Mutation: + storage_add: str = strawberry.mutation(resolver=resolve_storage_add) + + +target_schema = Schema(query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False)) +target_asgi_application = AsgiTest(GraphQL(target_schema)) diff --git a/tests/framework_strawberry/_target_schema_sync.py b/tests/framework_strawberry/_target_schema_sync.py new file mode 100644 index 000000000..34bff75b9 --- /dev/null +++ b/tests/framework_strawberry/_target_schema_sync.py @@ -0,0 +1,169 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Union + +import strawberry.mutation +import strawberry.type +from strawberry import Schema, field +from strawberry.asgi import GraphQL +from strawberry.schema.config import StrawberryConfig +from strawberry.types.types import Optional +from testing_support.asgi_testing import AsgiTest + + +@strawberry.type +class Author: + first_name: str + last_name: str + + +@strawberry.type +class Book: + id: int + name: str + isbn: str + author: Author + branch: str + + +@strawberry.type +class Magazine: + id: int + name: str + issue: int + branch: str + + +@strawberry.type +class Library: + id: int + branch: str + magazine: List[Magazine] + book: List[Book] + + +Item = Union[Book, Magazine] +Storage = List[str] + + +authors = [ + Author( + first_name="New", + last_name="Relic", + ), + Author( + first_name="Bob", + last_name="Smith", + ), + Author( + first_name="Leslie", + last_name="Jones", + ), +] + +books = [ + Book( + id=1, + name="Python Agent: The Book", + isbn="a-fake-isbn", + author=authors[0], + branch="riverside", + ), + Book( + id=2, + name="Ollies for O11y: A Sk8er's Guide to Observability", + isbn="a-second-fake-isbn", + author=authors[1], + branch="downtown", + ), + Book( + id=3, + name="[Redacted]", + isbn="a-third-fake-isbn", + author=authors[2], + branch="riverside", + ), +] + +magazines = [ + Magazine(id=1, name="Reli Updates Weekly", issue=1, branch="riverside"), + Magazine(id=2, name="Reli: The Forgotten Years", issue=2, branch="downtown"), + Magazine(id=3, name="Node Weekly", issue=1, branch="riverside"), +] + + +libraries = ["riverside", "downtown"] +libraries = [ + Library( + id=i + 1, + branch=branch, + magazine=[m for m in magazines if m.branch == branch], + book=[b for b in books if b.branch == branch], + ) + for i, branch in enumerate(libraries) +] + +storage = [] + + +def resolve_hello(): + return "Hello!" + + +def resolve_echo(echo: str): + return echo + + +def resolve_library(index: int): + return libraries[index] + + +def resolve_storage_add(string: str): + storage.append(string) + return string + + +def resolve_storage(): + return [storage.pop()] + + +def resolve_error(): + raise RuntimeError("Runtime Error!") + + +def resolve_search(contains: str): + search_books = [b for b in books if contains in b.name] + search_magazines = [m for m in magazines if contains in m.name] + return search_books + search_magazines + + +@strawberry.type +class Query: + library: Library = field(resolver=resolve_library) + hello: str = field(resolver=resolve_hello) + search: List[Item] = field(resolver=resolve_search) + echo: str = field(resolver=resolve_echo) + storage: Storage = field(resolver=resolve_storage) + error: Optional[str] = field(resolver=resolve_error) + error_non_null: str = field(resolver=resolve_error) + + +@strawberry.type +class Mutation: + storage_add: str = strawberry.mutation(resolver=resolve_storage_add) + + +target_schema = Schema(query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False)) +target_asgi_application = AsgiTest(GraphQL(target_schema)) diff --git a/tests/framework_strawberry/conftest.py b/tests/framework_strawberry/conftest.py index 130866bcb..6345b3033 100644 --- a/tests/framework_strawberry/conftest.py +++ b/tests/framework_strawberry/conftest.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest -import six - -from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture # noqa: F401; pylint: disable=W0611 - +from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 + collector_agent_registration_fixture, + collector_available_fixture, +) _default_settings = { "transaction_tracer.explain_threshold": 0.0, @@ -30,14 +29,3 @@ app_name="Python Agent Test (framework_strawberry)", default_settings=_default_settings, ) - - -@pytest.fixture(scope="session") -def app(): - from _target_application import _target_application - - return _target_application - - -if six.PY2: - collect_ignore = ["test_application_async.py"] diff --git a/tests/framework_strawberry/test_application.py b/tests/framework_strawberry/test_application.py index ac60a33e0..5a3f579ba 100644 --- a/tests/framework_strawberry/test_application.py +++ b/tests/framework_strawberry/test_application.py @@ -11,437 +11,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import pytest -from testing_support.fixtures import dt_enabled, override_application_settings -from testing_support.validators.validate_span_events import validate_span_events +from framework_graphql.test_application import * +from testing_support.fixtures import override_application_settings from testing_support.validators.validate_transaction_count import ( validate_transaction_count, ) -from testing_support.validators.validate_transaction_errors import ( - validate_transaction_errors, -) -from testing_support.validators.validate_transaction_metrics import ( - validate_transaction_metrics, -) from newrelic.api.background_task import background_task -from newrelic.common.object_names import callable_name - - -@pytest.fixture(scope="session") -def is_graphql_2(): - from graphql import __version__ as version - - major_version = int(version.split(".")[0]) - return major_version == 2 - - -@pytest.fixture(scope="session") -def graphql_run(): - """Wrapper function to simulate framework_graphql test behavior.""" - - def execute(schema, *args, **kwargs): - return schema.execute_sync(*args, **kwargs) - - return execute - - -def to_graphql_source(query): - def delay_import(): - try: - from graphql import Source - except ImportError: - # Fallback if Source is not implemented - return query - - from graphql import __version__ as version - - # For graphql2, Source objects aren't acceptable input - major_version = int(version.split(".")[0]) - if major_version == 2: - return query - - return Source(query) - - return delay_import - - -def example_middleware(next, root, info, **args): # pylint: disable=W0622 - return_value = next(root, info, **args) - return return_value - - -def error_middleware(next, root, info, **args): # pylint: disable=W0622 - raise RuntimeError("Runtime Error!") - - -_runtime_error_name = callable_name(RuntimeError) -_test_runtime_error = [(_runtime_error_name, "Runtime Error!")] -_graphql_base_rollup_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 1), - ("GraphQL/allOther", 1), - ("GraphQL/Strawberry/all", 1), - ("GraphQL/Strawberry/allOther", 1), -] - - -def test_basic(app, graphql_run): - from graphql import __version__ as version - - from newrelic.hooks.framework_strawberry import framework_details +from newrelic.common.package_version_utils import get_package_version - FRAMEWORK_METRICS = [ - ("Python/Framework/Strawberry/%s" % 
framework_details()[1], 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] +STRAWBERRY_VERSION = get_package_version("strawberry-graphql") - @validate_transaction_metrics( - "query//hello", - "GraphQL", - rollup_metrics=_graphql_base_rollup_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @background_task() - def _test(): - response = graphql_run(app, "{ hello }") - assert not response.errors - - _test() - - -@dt_enabled -def test_query_and_mutation(app, graphql_run): - from graphql import __version__ as version - from newrelic.hooks.framework_strawberry import framework_details +@pytest.fixture(scope="session", params=["sync-sync", "async-sync", "async-async", "asgi-sync", "asgi-async"]) +def target_application(request): + from ._target_application import target_application - FRAMEWORK_METRICS = [ - ("Python/Framework/Strawberry/%s" % framework_details()[1], 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/Strawberry/storage", 1), - ("GraphQL/resolve/Strawberry/storage_add", 1), - ("GraphQL/operation/Strawberry/query//storage", 1), - ("GraphQL/operation/Strawberry/mutation//storage_add", 1), - ] - _test_mutation_unscoped_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 2), - ("GraphQL/Strawberry/all", 2), - ("GraphQL/allOther", 2), - ("GraphQL/Strawberry/allOther", 2), - ] + _test_mutation_scoped_metrics + target_application = target_application[request.param] - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "String!", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String!]!", - } + is_asgi = "asgi" in request.param + schema_type = request.param.split("-")[1] - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @background_task() - def _test(): - response = graphql_run(app, 'mutation { storage_add(string: "abc") }') - assert not response.errors - response = graphql_run(app, "query { storage }") - assert not response.errors - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.data) - assert "abc" in str(response.data) - - _test() - - -@pytest.mark.parametrize("field", ("error", "error_non_null")) -@dt_enabled -def test_exception_in_resolver(app, graphql_run, field): - query = "query MyQuery { %s }" % field - - txn_name = "_target_application:resolve_error" - - # Metrics - _test_exception_scoped_metrics = [ - ("GraphQL/operation/Strawberry/query/MyQuery/%s" % field, 1), - ("GraphQL/resolve/Strawberry/%s" % field, 1), - ] - 
_test_exception_rollup_metrics = [ - ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/%s" % txn_name, 1), - ] + _test_exception_scoped_metrics - - # Attributes - _expected_exception_resolver_attributes = { - "graphql.field.name": field, - "graphql.field.parentType": "Query", - "graphql.field.path": field, - "graphql.field.returnType": "String!" if "non_null" in field else "String", - } - _expected_exception_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "MyQuery", - "graphql.operation.query": query, - } - - @validate_transaction_metrics( - txn_name, - "GraphQL", - scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_exception_operation_attributes) - @validate_span_events(exact_agents=_expected_exception_resolver_attributes) - @validate_transaction_errors(errors=_test_runtime_error) - @background_task() - def _test(): - response = graphql_run(app, query) - assert response.errors - - _test() - - -@dt_enabled -@pytest.mark.parametrize( - "query,exc_class", - [ - ("query MyQuery { missing_field }", "GraphQLError"), - ("{ syntax_error ", "graphql.error.syntax_error:GraphQLSyntaxError"), - ], -) -def test_exception_in_validation(app, graphql_run, is_graphql_2, query, exc_class): - if "syntax" in query: - txn_name = "graphql.language.parser:parse" - else: - if is_graphql_2: - txn_name = "graphql.validation.validation:validate" - else: - txn_name = "graphql.validation.validate:validate" - - # Import path differs between versions - if exc_class == "GraphQLError": - from graphql.error import GraphQLError - - exc_class = callable_name(GraphQLError) - - _test_exception_scoped_metrics = [ - ("GraphQL/operation/Strawberry///", 1), - ] - _test_exception_rollup_metrics = [ - ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/%s" % txn_name, 1), - ] + _test_exception_scoped_metrics - - # Attributes - _expected_exception_operation_attributes = { - "graphql.operation.type": "", - "graphql.operation.name": "", - "graphql.operation.query": query, - } - - @validate_transaction_metrics( - txn_name, - "GraphQL", - scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_exception_operation_attributes) - @validate_transaction_errors(errors=[exc_class]) - @background_task() - def _test(): - response = graphql_run(app, query) - assert response.errors - - _test() - - -@dt_enabled -def test_operation_metrics_and_attrs(app, graphql_run): - operation_metrics = [("GraphQL/operation/Strawberry/query/MyQuery/library", 1)] - operation_attrs = { - "graphql.operation.type": "query", - "graphql.operation.name": "MyQuery", - } - - @validate_transaction_metrics( - "query/MyQuery/library", - "GraphQL", - scoped_metrics=operation_metrics, - rollup_metrics=operation_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - # Span count 16: Transaction, Operation, and 7 Resolvers and Resolver functions - # library, library.name, library.book - # library.book.name and library.book.id for each book resolved (in this case 2) - @validate_span_events(count=16) - @validate_span_events(exact_agents=operation_attrs) - @background_task() - def _test(): - response = graphql_run(app, "query MyQuery { library(index: 0) { branch, book { id, 
name } } }") - assert not response.errors - - _test() - - -@dt_enabled -def test_field_resolver_metrics_and_attrs(app, graphql_run): - field_resolver_metrics = [("GraphQL/resolve/Strawberry/hello", 1)] - graphql_attrs = { - "graphql.field.name": "hello", - "graphql.field.parentType": "Query", - "graphql.field.path": "hello", - "graphql.field.returnType": "String!", - } - - @validate_transaction_metrics( - "query//hello", - "GraphQL", - scoped_metrics=field_resolver_metrics, - rollup_metrics=field_resolver_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - # Span count 4: Transaction, Operation, and 1 Resolver and Resolver function - @validate_span_events(count=4) - @validate_span_events(exact_agents=graphql_attrs) - @background_task() - def _test(): - response = graphql_run(app, "{ hello }") - assert not response.errors - assert "Hello!" in str(response.data) - - _test() - - -_test_queries = [ - ("{ hello }", "{ hello }"), # Basic query extraction - ("{ error }", "{ error }"), # Extract query on field error - (to_graphql_source("{ hello }"), "{ hello }"), # Extract query from Source objects - ( - "{ library(index: 0) { branch } }", - "{ library(index: ?) { branch } }", - ), # Integers - ('{ echo(echo: "123") }', "{ echo(echo: ?) }"), # Strings with numerics - ('{ echo(echo: "test") }', "{ echo(echo: ?) }"), # Strings - ('{ TestEcho: echo(echo: "test") }', "{ TestEcho: echo(echo: ?) }"), # Aliases - ('{ TestEcho: echo(echo: "test") }', "{ TestEcho: echo(echo: ?) }"), # Variables - ( # Fragments - '{ ...MyFragment } fragment MyFragment on Query { echo(echo: "test") }', - "{ ...MyFragment } fragment MyFragment on Query { echo(echo: ?) }", - ), -] - - -@dt_enabled -@pytest.mark.parametrize("query,obfuscated", _test_queries) -def test_query_obfuscation(app, graphql_run, query, obfuscated): - graphql_attrs = {"graphql.operation.query": obfuscated} - - if callable(query): - query = query() - - @validate_span_events(exact_agents=graphql_attrs) - @background_task() - def _test(): - response = graphql_run(app, query) - if not isinstance(query, str) or "error" not in query: - assert not response.errors - - _test() - - -_test_queries = [ - ("{ hello }", "/hello"), # Basic query - ("{ error }", "/error"), # Extract deepest path on field error - ('{ echo(echo: "test") }', "/echo"), # Fields with arguments - ( - "{ library(index: 0) { branch, book { isbn branch } } }", - "/library", - ), # Complex Example, 1 level - ( - "{ library(index: 0) { book { author { first_name }} } }", - "/library.book.author.first_name", - ), # Complex Example, 2 levels - ("{ library(index: 0) { id, book { name } } }", "/library.book.name"), # Filtering - ('{ TestEcho: echo(echo: "test") }', "/echo"), # Aliases - ( - '{ search(contains: "A") { __typename ... on Book { name } } }', - "/search.name", - ), # InlineFragment - ( - '{ hello echo(echo: "test") }', - "", - ), # Multiple root selections. 
(need to decide on final behavior) - # FragmentSpread - ( - "{ library(index: 0) { book { ...MyFragment } } } fragment MyFragment on Book { name id }", # Fragment filtering - "/library.book.name", - ), - ( - "{ library(index: 0) { book { ...MyFragment } } } fragment MyFragment on Book { author { first_name } }", - "/library.book.author.first_name", - ), - ( - "{ library(index: 0) { book { ...MyFragment } magazine { ...MagFragment } } } fragment MyFragment on Book { author { first_name } } fragment MagFragment on Magazine { name }", - "/library", - ), -] - - -@dt_enabled -@pytest.mark.parametrize("query,expected_path", _test_queries) -def test_deepest_unique_path(app, graphql_run, query, expected_path): - if expected_path == "/error": - txn_name = "_target_application:resolve_error" - else: - txn_name = "query/%s" % expected_path - - @validate_transaction_metrics( - txn_name, - "GraphQL", - background_task=True, - ) - @background_task() - def _test(): - response = graphql_run(app, query) - if "error" not in query: - assert not response.errors - - _test() + assert STRAWBERRY_VERSION is not None + return "Strawberry", STRAWBERRY_VERSION, target_application, not is_asgi, schema_type, 0 @pytest.mark.parametrize("capture_introspection_setting", (True, False)) -def test_introspection_transactions(app, graphql_run, capture_introspection_setting): +def test_introspection_transactions(target_application, capture_introspection_setting): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application + txn_ct = 1 if capture_introspection_setting else 0 @override_application_settings( @@ -450,7 +49,6 @@ def test_introspection_transactions(app, graphql_run, capture_introspection_sett @validate_transaction_count(txn_ct) @background_task() def _test(): - response = graphql_run(app, "{ __schema { types { name } } }") - assert not response.errors + response = target_application("{ __schema { types { name } } }") _test() diff --git a/tests/framework_strawberry/test_application_async.py b/tests/framework_strawberry/test_application_async.py deleted file mode 100644 index 1354c4c01..000000000 --- a/tests/framework_strawberry/test_application_async.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import asyncio - -import pytest -from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_span_events import validate_span_events - -from newrelic.api.background_task import background_task - - -@pytest.fixture(scope="session") -def graphql_run_async(): - """Wrapper function to simulate framework_graphql test behavior.""" - - def execute(schema, *args, **kwargs): - return schema.execute(*args, **kwargs) - - return execute - - -_graphql_base_rollup_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 1), - ("GraphQL/allOther", 1), - ("GraphQL/Strawberry/all", 1), - ("GraphQL/Strawberry/allOther", 1), -] - - -loop = asyncio.new_event_loop() - - -def test_basic(app, graphql_run_async): - from graphql import __version__ as version - - from newrelic.hooks.framework_strawberry import framework_details - - FRAMEWORK_METRICS = [ - ("Python/Framework/Strawberry/%s" % framework_details()[1], 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - - @validate_transaction_metrics( - "query//hello_async", - "GraphQL", - rollup_metrics=_graphql_base_rollup_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @background_task() - def _test(): - async def coro(): - response = await graphql_run_async(app, "{ hello_async }") - assert not response.errors - - loop.run_until_complete(coro()) - - _test() - - -@dt_enabled -def test_query_and_mutation_async(app, graphql_run_async): - from graphql import __version__ as version - - from newrelic.hooks.framework_strawberry import framework_details - - FRAMEWORK_METRICS = [ - ("Python/Framework/Strawberry/%s" % framework_details()[1], 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/Strawberry/storage", 1), - ("GraphQL/resolve/Strawberry/storage_add", 1), - ("GraphQL/operation/Strawberry/query//storage", 1), - ("GraphQL/operation/Strawberry/mutation//storage_add", 1), - ] - _test_mutation_unscoped_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 2), - ("GraphQL/Strawberry/all", 2), - ("GraphQL/allOther", 2), - ("GraphQL/Strawberry/allOther", 2), - ] + _test_mutation_scoped_metrics - - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "String!", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String!]!", - } - - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @background_task() - def _test(): - async def coro(): - response = await graphql_run_async(app, 'mutation { 
storage_add(string: "abc") }') - assert not response.errors - response = await graphql_run_async(app, "query { storage }") - assert not response.errors - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.data) - assert "abc" in str(response.data) - - loop.run_until_complete(coro()) - - _test() diff --git a/tests/framework_strawberry/test_asgi.py b/tests/framework_strawberry/test_asgi.py deleted file mode 100644 index 8acbaedfb..000000000 --- a/tests/framework_strawberry/test_asgi.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -import pytest -from testing_support.asgi_testing import AsgiTest -from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_span_events import validate_span_events - - -@pytest.fixture(scope="session") -def graphql_asgi_run(): - """Wrapper function to simulate framework_graphql test behavior.""" - from _target_application import _target_asgi_application - - app = AsgiTest(_target_asgi_application) - - def execute(query): - return app.make_request( - "POST", - "/", - headers={"Content-Type": "application/json"}, - body=json.dumps({"query": query}), - ) - - return execute - - -@dt_enabled -def test_query_and_mutation_asgi(graphql_asgi_run): - from graphql import __version__ as version - - from newrelic.hooks.framework_strawberry import framework_details - - FRAMEWORK_METRICS = [ - ("Python/Framework/Strawberry/%s" % framework_details()[1], 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/Strawberry/storage_add", 1), - ("GraphQL/operation/Strawberry/mutation//storage_add", 1), - ] - _test_query_scoped_metrics = [ - ("GraphQL/resolve/Strawberry/storage", 1), - ("GraphQL/operation/Strawberry/query//storage", 1), - ] - _test_unscoped_metrics = [ - ("WebTransaction", 1), - ("GraphQL/all", 1), - ("GraphQL/Strawberry/all", 1), - ("GraphQL/allWeb", 1), - ("GraphQL/Strawberry/allWeb", 1), - ] - _test_mutation_unscoped_metrics = _test_unscoped_metrics + _test_mutation_scoped_metrics - _test_query_unscoped_metrics = _test_unscoped_metrics + _test_query_scoped_metrics - - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "String!", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String!]!", - } - - 
@validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_query_scoped_metrics, - rollup_metrics=_test_query_unscoped_metrics + FRAMEWORK_METRICS, - ) - @validate_transaction_metrics( - "mutation//storage_add", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - index=-2, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes, index=-2) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes, index=-2) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - def _test(): - response = graphql_asgi_run('mutation { storage_add(string: "abc") }') - assert response.status == 200 - response = json.loads(response.body.decode("utf-8")) - assert not response.get("errors") - - response = graphql_asgi_run("query { storage }") - assert response.status == 200 - response = json.loads(response.body.decode("utf-8")) - assert not response.get("errors") - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.get("data")) - assert "abc" in str(response.get("data")) - - _test() diff --git a/tests/logger_structlog/conftest.py b/tests/logger_structlog/conftest.py new file mode 100644 index 000000000..05a86d8a7 --- /dev/null +++ b/tests/logger_structlog/conftest.py @@ -0,0 +1,143 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import pytest +from structlog import DropEvent, PrintLogger +from newrelic.api.time_trace import current_trace +from newrelic.api.transaction import current_transaction +from testing_support.fixtures import ( + collector_agent_registration_fixture, + collector_available_fixture, +) + +_default_settings = { + "transaction_tracer.explain_threshold": 0.0, + "transaction_tracer.transaction_threshold": 0.0, + "transaction_tracer.stack_trace_threshold": 0.0, + "debug.log_data_collector_payloads": True, + "debug.record_transaction_failure": True, + "application_logging.enabled": True, + "application_logging.forwarding.enabled": True, + "application_logging.metrics.enabled": True, + "application_logging.local_decorating.enabled": True, + "event_harvest_config.harvest_limits.log_event_data": 100000, +} + +collector_agent_registration = collector_agent_registration_fixture( + app_name="Python Agent Test (logger_structlog)", + default_settings=_default_settings, +) + + +class StructLogCapLog(PrintLogger): + def __init__(self, caplog): + self.caplog = caplog if caplog is not None else [] + + def msg(self, event, **kwargs): + self.caplog.append(event) + return + + log = debug = info = warn = warning = msg + fatal = failure = err = error = critical = exception = msg + + def __repr__(self): + return "<StructLogCapLog %s>" % str(id(self)) + + __str__ = __repr__ + + +@pytest.fixture +def set_trace_ids(): + def _set(): + txn = current_transaction() + if txn: + txn._trace_id = "abcdefgh12345678" + trace = current_trace() + if trace: + trace.guid = "abcdefgh" + return _set + +def drop_event_processor(logger, method_name, event_dict): + if method_name == "info": + raise DropEvent + else: + return event_dict + + +@pytest.fixture(scope="function") +def structlog_caplog(): + return list() + + +@pytest.fixture(scope="function") +def logger(structlog_caplog): + import structlog + structlog.configure(processors=[], logger_factory=lambda *args, **kwargs: StructLogCapLog(structlog_caplog)) + _logger = structlog.get_logger() + return _logger + + +@pytest.fixture(scope="function") +def filtering_logger(structlog_caplog): + import structlog + structlog.configure(processors=[drop_event_processor], logger_factory=lambda *args, **kwargs: StructLogCapLog(structlog_caplog)) + _filtering_logger = structlog.get_logger() + return _filtering_logger + + +@pytest.fixture +def exercise_logging_multiple_lines(set_trace_ids, logger, structlog_caplog): + def _exercise(): + set_trace_ids() + + logger.msg("Cat", a=42) + logger.error("Dog") + logger.critical("Elephant") + + assert len(structlog_caplog) == 3 + + assert "Cat" in structlog_caplog[0] + assert "Dog" in structlog_caplog[1] + assert "Elephant" in structlog_caplog[2] + + return _exercise + + +@pytest.fixture +def exercise_filtering_logging_multiple_lines(set_trace_ids, filtering_logger, structlog_caplog): + def _exercise(): + set_trace_ids() + + filtering_logger.msg("Cat", a=42) + filtering_logger.error("Dog") + filtering_logger.critical("Elephant") + + assert len(structlog_caplog) == 2 + + assert "Cat" not in structlog_caplog[0] + assert "Dog" in structlog_caplog[0] + assert "Elephant" in structlog_caplog[1] + + return _exercise + + +@pytest.fixture +def exercise_logging_single_line(set_trace_ids, logger, structlog_caplog): + def _exercise(): + set_trace_ids() + logger.error("A", key="value") + assert len(structlog_caplog) == 1 + + return _exercise diff --git a/tests/logger_structlog/test_attribute_forwarding.py b/tests/logger_structlog/test_attribute_forwarding.py new file mode
100644 index 000000000..eb555cca1 --- /dev/null +++ b/tests/logger_structlog/test_attribute_forwarding.py @@ -0,0 +1,49 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from newrelic.api.background_task import background_task +from testing_support.fixtures import override_application_settings, reset_core_stats_engine +from testing_support.validators.validate_log_event_count import validate_log_event_count +from testing_support.validators.validate_log_event_count_outside_transaction import validate_log_event_count_outside_transaction +from testing_support.validators.validate_log_events import validate_log_events +from testing_support.validators.validate_log_events_outside_transaction import validate_log_events_outside_transaction + + +_event_attributes = {"message": "A"} + + +@override_application_settings({ + "application_logging.forwarding.context_data.enabled": True, +}) +def test_attributes_inside_transaction(exercise_logging_single_line): + @validate_log_events([_event_attributes]) + @validate_log_event_count(1) + @background_task() + def test(): + exercise_logging_single_line() + + test() + + +@reset_core_stats_engine() +@override_application_settings({ + "application_logging.forwarding.context_data.enabled": True, +}) +def test_attributes_outside_transaction(exercise_logging_single_line): + @validate_log_events_outside_transaction([_event_attributes]) + @validate_log_event_count_outside_transaction(1) + def test(): + exercise_logging_single_line() + + test() diff --git a/tests/logger_structlog/test_local_decorating.py b/tests/logger_structlog/test_local_decorating.py new file mode 100644 index 000000000..7b58d4a0c --- /dev/null +++ b/tests/logger_structlog/test_local_decorating.py @@ -0,0 +1,54 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import platform + +from newrelic.api.application import application_settings +from newrelic.api.background_task import background_task +from testing_support.fixtures import reset_core_stats_engine +from testing_support.validators.validate_log_event_count import validate_log_event_count +from testing_support.validators.validate_log_event_count_outside_transaction import validate_log_event_count_outside_transaction + + +def get_metadata_string(log_message, is_txn): + host = platform.uname()[1] + assert host + entity_guid = application_settings().entity_guid + if is_txn: + metadata_string = "".join(('NR-LINKING|', entity_guid, '|', host, '|abcdefgh12345678|abcdefgh|Python%20Agent%20Test%20%28logger_structlog%29|')) + else: + metadata_string = "".join(('NR-LINKING|', entity_guid, '|', host, '|||Python%20Agent%20Test%20%28logger_structlog%29|')) + formatted_string = log_message + " " + metadata_string + return formatted_string + + +@reset_core_stats_engine() +def test_local_log_decoration_inside_transaction(exercise_logging_single_line, structlog_caplog): + @validate_log_event_count(1) + @background_task() + def test(): + exercise_logging_single_line() + assert get_metadata_string('A', True) in structlog_caplog[0] + + test() + + +@reset_core_stats_engine() +def test_local_log_decoration_outside_transaction(exercise_logging_single_line, structlog_caplog): + @validate_log_event_count_outside_transaction(1) + def test(): + exercise_logging_single_line() + assert get_metadata_string('A', False) in structlog_caplog[0] + + test() diff --git a/tests/logger_structlog/test_log_forwarding.py b/tests/logger_structlog/test_log_forwarding.py new file mode 100644 index 000000000..e5a5e670f --- /dev/null +++ b/tests/logger_structlog/test_log_forwarding.py @@ -0,0 +1,88 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from newrelic.api.background_task import background_task +from testing_support.fixtures import override_application_settings, reset_core_stats_engine +from testing_support.validators.validate_log_event_count import validate_log_event_count +from testing_support.validators.validate_log_event_count_outside_transaction import \ + validate_log_event_count_outside_transaction +from testing_support.validators.validate_log_events import validate_log_events +from testing_support.validators.validate_log_events_outside_transaction import validate_log_events_outside_transaction + + +_common_attributes_service_linking = {"timestamp": None, "hostname": None, + "entity.name": "Python Agent Test (logger_structlog)", "entity.guid": None} + +_common_attributes_trace_linking = {"span.id": "abcdefgh", "trace.id": "abcdefgh12345678", + **_common_attributes_service_linking} + + +@reset_core_stats_engine() +@override_application_settings({"application_logging.local_decorating.enabled": False}) +def test_logging_inside_transaction(exercise_logging_multiple_lines): + @validate_log_events([ + {"message": "Cat", "level": "INFO", **_common_attributes_trace_linking}, + {"message": "Dog", "level": "ERROR", **_common_attributes_trace_linking}, + {"message": "Elephant", "level": "CRITICAL", **_common_attributes_trace_linking}, + ]) + @validate_log_event_count(3) + @background_task() + def test(): + exercise_logging_multiple_lines() + + test() + + +@reset_core_stats_engine() +@override_application_settings({"application_logging.local_decorating.enabled": False}) +def test_logging_filtering_inside_transaction(exercise_filtering_logging_multiple_lines): + @validate_log_events([ + {"message": "Dog", "level": "ERROR", **_common_attributes_trace_linking}, + {"message": "Elephant", "level": "CRITICAL", **_common_attributes_trace_linking}, + ]) + @validate_log_event_count(2) + @background_task() + def test(): + exercise_filtering_logging_multiple_lines() + + test() + + +@reset_core_stats_engine() +@override_application_settings({"application_logging.local_decorating.enabled": False}) +def test_logging_outside_transaction(exercise_logging_multiple_lines): + @validate_log_events_outside_transaction([ + {"message": "Cat", "level": "INFO", **_common_attributes_service_linking}, + {"message": "Dog", "level": "ERROR", **_common_attributes_service_linking}, + {"message": "Elephant", "level": "CRITICAL", **_common_attributes_service_linking}, + ]) + @validate_log_event_count_outside_transaction(3) + def test(): + exercise_logging_multiple_lines() + + test() + + +@reset_core_stats_engine() +@override_application_settings({"application_logging.local_decorating.enabled": False}) +def test_logging_filtering_outside_transaction(exercise_filtering_logging_multiple_lines): + @validate_log_events_outside_transaction([ + {"message": "Dog", "level": "ERROR", **_common_attributes_service_linking}, + {"message": "Elephant", "level": "CRITICAL", **_common_attributes_service_linking}, + ]) + @validate_log_event_count_outside_transaction(2) + def test(): + exercise_filtering_logging_multiple_lines() + + test() diff --git a/tests/logger_structlog/test_metrics.py b/tests/logger_structlog/test_metrics.py new file mode 100644 index 000000000..48f7204e8 --- /dev/null +++ b/tests/logger_structlog/test_metrics.py @@ -0,0 +1,73 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from newrelic.packages import six +from newrelic.api.background_task import background_task +from testing_support.fixtures import reset_core_stats_engine +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from testing_support.validators.validate_custom_metrics_outside_transaction import validate_custom_metrics_outside_transaction + + +_test_logging_unscoped_metrics = [ + ("Logging/lines", 3), + ("Logging/lines/INFO", 1), + ("Logging/lines/ERROR", 1), + ("Logging/lines/CRITICAL", 1), +] + + +@reset_core_stats_engine() +def test_logging_metrics_inside_transaction(exercise_logging_multiple_lines): + txn_name = "test_metrics:test_logging_metrics_inside_transaction.<locals>.test" if six.PY3 else "test_metrics:test" + @validate_transaction_metrics( + txn_name, + custom_metrics=_test_logging_unscoped_metrics, + background_task=True, + ) + @background_task() + def test(): + exercise_logging_multiple_lines() + + test() + + +@reset_core_stats_engine() +def test_logging_metrics_outside_transaction(exercise_logging_multiple_lines): + @validate_custom_metrics_outside_transaction(_test_logging_unscoped_metrics) + def test(): + exercise_logging_multiple_lines() + + test() + + +_test_logging_unscoped_filtering_metrics = [ + ("Logging/lines", 2), + ("Logging/lines/ERROR", 1), + ("Logging/lines/CRITICAL", 1), +] + + +@reset_core_stats_engine() +def test_filtering_logging_metrics_inside_transaction(exercise_filtering_logging_multiple_lines): + txn_name = "test_metrics:test_filtering_logging_metrics_inside_transaction.<locals>.test" if six.PY3 else "test_metrics:test" + @validate_transaction_metrics( + txn_name, + custom_metrics=_test_logging_unscoped_filtering_metrics, + background_task=True, + ) + @background_task() + def test(): + exercise_filtering_logging_multiple_lines() + + test() diff --git a/tests/messagebroker_confluentkafka/conftest.py b/tests/messagebroker_confluentkafka/conftest.py index e29596d55..fa86b6b3c 100644 --- a/tests/messagebroker_confluentkafka/conftest.py +++ b/tests/messagebroker_confluentkafka/conftest.py @@ -84,7 +84,7 @@ def producer(topic, client_type, json_serializer): @pytest.fixture(scope="function") -def consumer(topic, producer, client_type, json_deserializer): +def consumer(group_id, topic, producer, client_type, json_deserializer): from confluent_kafka import Consumer, DeserializingConsumer if client_type == "cimpl": @@ -93,7 +93,7 @@ def consumer(topic, producer, client_type, json_deserializer): "bootstrap.servers": BROKER, "auto.offset.reset": "earliest", "heartbeat.interval.ms": 1000, - "group.id": "test", + "group.id": group_id, } ) elif client_type == "serializer_function": @@ -102,7 +102,7 @@ def consumer(topic, producer, client_type, json_deserializer): "bootstrap.servers": BROKER, "auto.offset.reset": "earliest", "heartbeat.interval.ms": 1000, - "group.id": "test", + "group.id": group_id, "value.deserializer": lambda v, c: json.loads(v.decode("utf-8")), "key.deserializer": lambda v, c: json.loads(v.decode("utf-8")) if v is not None else None, } @@ -113,7 +113,7 @@ def consumer(topic, producer, client_type,
diff --git a/tests/messagebroker_confluentkafka/conftest.py b/tests/messagebroker_confluentkafka/conftest.py
index e29596d55..fa86b6b3c 100644
--- a/tests/messagebroker_confluentkafka/conftest.py
+++ b/tests/messagebroker_confluentkafka/conftest.py
@@ -84,7 +84,7 @@ def producer(topic, client_type, json_serializer):
 
 
 @pytest.fixture(scope="function")
-def consumer(topic, producer, client_type, json_deserializer):
+def consumer(group_id, topic, producer, client_type, json_deserializer):
     from confluent_kafka import Consumer, DeserializingConsumer
 
     if client_type == "cimpl":
@@ -93,7 +93,7 @@ def consumer(topic, producer, client_type, json_deserializer):
                 "bootstrap.servers": BROKER,
                 "auto.offset.reset": "earliest",
                 "heartbeat.interval.ms": 1000,
-                "group.id": "test",
+                "group.id": group_id,
             }
         )
     elif client_type == "serializer_function":
@@ -102,7 +102,7 @@ def consumer(topic, producer, client_type, json_deserializer):
                 "bootstrap.servers": BROKER,
                 "auto.offset.reset": "earliest",
                 "heartbeat.interval.ms": 1000,
-                "group.id": "test",
+                "group.id": group_id,
                 "value.deserializer": lambda v, c: json.loads(v.decode("utf-8")),
                 "key.deserializer": lambda v, c: json.loads(v.decode("utf-8")) if v is not None else None,
             }
@@ -113,7 +113,7 @@ def consumer(topic, producer, client_type, json_deserializer):
                 "bootstrap.servers": BROKER,
                 "auto.offset.reset": "earliest",
                 "heartbeat.interval.ms": 1000,
-                "group.id": "test",
+                "group.id": group_id,
                 "value.deserializer": json_deserializer,
                 "key.deserializer": json_deserializer,
             }
@@ -181,6 +181,11 @@ def topic():
         admin.delete_topics(new_topics)
 
 
+@pytest.fixture(scope="session")
+def group_id():
+    return str(uuid.uuid4())
+
+
 @pytest.fixture()
 def send_producer_message(topic, producer, serialize, client_type):
     callback_called = []
diff --git a/tests/messagebroker_confluentkafka/test_consumer.py b/tests/messagebroker_confluentkafka/test_consumer.py
index 5478b7c80..31f9478b3 100644
--- a/tests/messagebroker_confluentkafka/test_consumer.py
+++ b/tests/messagebroker_confluentkafka/test_consumer.py
@@ -14,14 +14,13 @@
 
 import pytest
 from conftest import cache_kafka_consumer_headers
-from testing_support.fixtures import (
-    reset_core_stats_engine,
-    validate_attributes,
-    validate_error_event_attributes_outside_transaction,
-)
+from testing_support.fixtures import reset_core_stats_engine, validate_attributes
 from testing_support.validators.validate_distributed_trace_accepted import (
     validate_distributed_trace_accepted,
 )
+from testing_support.validators.validate_error_event_attributes_outside_transaction import (
+    validate_error_event_attributes_outside_transaction,
+)
 from testing_support.validators.validate_transaction_count import (
     validate_transaction_count,
 )
diff --git a/tests/messagebroker_kafkapython/conftest.py b/tests/messagebroker_kafkapython/conftest.py
index becef31a0..de12f5830 100644
--- a/tests/messagebroker_kafkapython/conftest.py
+++ b/tests/messagebroker_kafkapython/conftest.py
@@ -86,7 +86,7 @@ def producer(client_type, json_serializer, json_callable_serializer):
 
 
 @pytest.fixture(scope="function")
-def consumer(topic, producer, client_type, json_deserializer, json_callable_deserializer):
+def consumer(group_id, topic, producer, client_type, json_deserializer, json_callable_deserializer):
     if client_type == "no_serializer":
         consumer = kafka.KafkaConsumer(
             topic,
@@ -94,7 +94,7 @@ def consumer(topic, producer, client_type, json_deserializer, json_callable_dese
             auto_offset_reset="earliest",
             consumer_timeout_ms=100,
             heartbeat_interval_ms=1000,
-            group_id="test",
+            group_id=group_id,
         )
     elif client_type == "serializer_function":
         consumer = kafka.KafkaConsumer(
@@ -105,7 +105,7 @@ def consumer(topic, producer, client_type, json_deserializer, json_callable_dese
             auto_offset_reset="earliest",
             consumer_timeout_ms=100,
             heartbeat_interval_ms=1000,
-            group_id="test",
+            group_id=group_id,
         )
    elif client_type == "callable_object":
         consumer = kafka.KafkaConsumer(
@@ -116,7 +116,7 @@ def consumer(topic, producer, client_type, json_deserializer, json_callable_dese
             auto_offset_reset="earliest",
             consumer_timeout_ms=100,
             heartbeat_interval_ms=1000,
-            group_id="test",
+            group_id=group_id,
         )
     elif client_type == "serializer_object":
         consumer = kafka.KafkaConsumer(
@@ -127,7 +127,7 @@ def consumer(topic, producer, client_type, json_deserializer, json_callable_dese
             auto_offset_reset="earliest",
             consumer_timeout_ms=100,
             heartbeat_interval_ms=1000,
-            group_id="test",
+            group_id=group_id,
         )
 
     yield consumer
@@ -202,6 +202,11 @@ def topic():
         admin.delete_topics([topic])
 
 
+@pytest.fixture(scope="session")
+def group_id():
+    return str(uuid.uuid4())
+
+
 @pytest.fixture()
 def send_producer_message(topic, producer, serialize):
     def _test():
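Both Kafka conftests make the same change: the hard-coded "test" consumer group becomes a session-scoped random group id, so repeated runs against a shared broker do not inherit committed offsets from earlier sessions. The fixture shape mirrors the diff exactly (the uuid import is assumed to already exist in these conftest modules):

    # Sketch: why a per-session group id isolates test runs.
    import uuid

    import pytest


    @pytest.fixture(scope="session")
    def group_id():
        # A fresh group has no committed offsets, so with offset reset set to
        # "earliest" each test session re-reads the topic from the beginning.
        return str(uuid.uuid4())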
diff --git a/tests/messagebroker_kafkapython/test_consumer.py b/tests/messagebroker_kafkapython/test_consumer.py
index 47e42d6c9..78ba086c6 100644
--- a/tests/messagebroker_kafkapython/test_consumer.py
+++ b/tests/messagebroker_kafkapython/test_consumer.py
@@ -14,14 +14,13 @@
 
 import pytest
 from conftest import cache_kafka_consumer_headers
-from testing_support.fixtures import (
-    reset_core_stats_engine,
-    validate_attributes,
-    validate_error_event_attributes_outside_transaction,
-)
+from testing_support.fixtures import reset_core_stats_engine, validate_attributes
 from testing_support.validators.validate_distributed_trace_accepted import (
     validate_distributed_trace_accepted,
 )
+from testing_support.validators.validate_error_event_attributes_outside_transaction import (
+    validate_error_event_attributes_outside_transaction,
+)
 from testing_support.validators.validate_transaction_count import (
     validate_transaction_count,
 )
diff --git a/tests/messagebroker_kafkapython/test_serialization.py b/tests/messagebroker_kafkapython/test_serialization.py
index f58d082ec..0b2bee74d 100644
--- a/tests/messagebroker_kafkapython/test_serialization.py
+++ b/tests/messagebroker_kafkapython/test_serialization.py
@@ -15,8 +15,8 @@
 
 import json
 
 import pytest
-from testing_support.fixtures import (
-    reset_core_stats_engine,
+from testing_support.fixtures import reset_core_stats_engine
+from testing_support.validators.validate_error_event_attributes_outside_transaction import (
     validate_error_event_attributes_outside_transaction,
 )
 from testing_support.validators.validate_transaction_errors import (
diff --git a/tests/messagebroker_pika/test_pika_async_connection_consume.py b/tests/messagebroker_pika/test_pika_async_connection_consume.py
index 4e44c7ed7..29b9d8ea4 100644
--- a/tests/messagebroker_pika/test_pika_async_connection_consume.py
+++ b/tests/messagebroker_pika/test_pika_async_connection_consume.py
@@ -49,20 +49,20 @@ from newrelic.api.background_task import background_task
+
 DB_SETTINGS = rabbitmq_settings()[0]
 
 _message_broker_tt_params = {
-    "queue_name": QUEUE,
-    "routing_key": QUEUE,
-    "correlation_id": CORRELATION_ID,
-    "reply_to": REPLY_TO,
-    "headers": HEADERS.copy(),
+    'queue_name': QUEUE,
+    'routing_key': QUEUE,
+    'correlation_id': CORRELATION_ID,
+    'reply_to': REPLY_TO,
+    'headers': HEADERS.copy(),
 }
 
 # Tornado's IO loop is not configurable in versions 5.x and up
 try:
-
     class MyIOLoop(tornado.ioloop.IOLoop.configured_class()):
         def handle_callback_exception(self, *args, **kwargs):
             raise
@@ -73,44 +73,38 @@ def handle_callback_exception(self, *args, **kwargs):
 
 connection_classes = [pika.SelectConnection, TornadoConnection]
 
-parametrized_connection = pytest.mark.parametrize("ConnectionClass", connection_classes)
+parametrized_connection = pytest.mark.parametrize('ConnectionClass',
+        connection_classes)
 
 
 _test_select_conn_basic_get_inside_txn_metrics = [
-    ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None),
-    ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, 1),
+    ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, 1),
 ]
 
 if six.PY3:
     _test_select_conn_basic_get_inside_txn_metrics.append(
-        (
-            (
-                "Function/test_pika_async_connection_consume:"
-                "test_async_connection_basic_get_inside_txn."
-                "<locals>.on_message"
-            ),
-            1,
-        )
-    )
+        (('Function/test_pika_async_connection_consume:'
+          'test_async_connection_basic_get_inside_txn.'
+          '<locals>.on_message'), 1))
 else:
-    _test_select_conn_basic_get_inside_txn_metrics.append(("Function/test_pika_async_connection_consume:on_message", 1))
+    _test_select_conn_basic_get_inside_txn_metrics.append(
+        ('Function/test_pika_async_connection_consume:on_message', 1))
 
 
 @parametrized_connection
-@pytest.mark.parametrize("callback_as_partial", [True, False])
-@validate_code_level_metrics(
-    "test_pika_async_connection_consume" + (".test_async_connection_basic_get_inside_txn.<locals>" if six.PY3 else ""),
-    "on_message",
-)
+@pytest.mark.parametrize('callback_as_partial', [True, False])
+@validate_code_level_metrics("test_pika_async_connection_consume.test_async_connection_basic_get_inside_txn.<locals>", "on_message", py2_namespace="test_pika_async_connection_consume")
 @validate_transaction_metrics(
-    ("test_pika_async_connection_consume:" "test_async_connection_basic_get_inside_txn"),
-    scoped_metrics=_test_select_conn_basic_get_inside_txn_metrics,
-    rollup_metrics=_test_select_conn_basic_get_inside_txn_metrics,
-    background_task=True,
-)
+    ('test_pika_async_connection_consume:'
+     'test_async_connection_basic_get_inside_txn'),
+    scoped_metrics=_test_select_conn_basic_get_inside_txn_metrics,
+    rollup_metrics=_test_select_conn_basic_get_inside_txn_metrics,
+    background_task=True)
 @validate_tt_collector_json(message_broker_params=_message_broker_tt_params)
 @background_task()
-def test_async_connection_basic_get_inside_txn(producer, ConnectionClass, callback_as_partial):
+def test_async_connection_basic_get_inside_txn(producer, ConnectionClass,
+        callback_as_partial):
     def on_message(channel, method_frame, header_frame, body):
         assert method_frame
         assert body == BODY
@@ -128,7 +122,9 @@ def on_open_channel(channel):
     def on_open_connection(connection):
         connection.channel(on_open_callback=on_open_channel)
 
-    connection = ConnectionClass(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection)
+    connection = ConnectionClass(
+        pika.ConnectionParameters(DB_SETTINGS['host']),
+        on_open_callback=on_open_connection)
 
     try:
         connection.ioloop.start()
@@ -139,8 +135,9 @@ def on_open_connection(connection):
 
 
 @parametrized_connection
-@pytest.mark.parametrize("callback_as_partial", [True, False])
-def test_select_connection_basic_get_outside_txn(producer, ConnectionClass, callback_as_partial):
+@pytest.mark.parametrize('callback_as_partial', [True, False])
+def test_select_connection_basic_get_outside_txn(producer, ConnectionClass,
+        callback_as_partial):
     metrics_list = []
 
     @capture_transaction_metrics(metrics_list)
@@ -163,8 +160,8 @@ def on_open_connection(connection):
         connection.channel(on_open_callback=on_open_channel)
 
     connection = ConnectionClass(
-        pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection
-    )
+        pika.ConnectionParameters(DB_SETTINGS['host']),
+        on_open_callback=on_open_connection)
 
     try:
         connection.ioloop.start()
@@ -181,24 +178,25 @@ def on_open_connection(connection):
 
 
 _test_select_conn_basic_get_inside_txn_no_callback_metrics = [
-    ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None),
-    ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None),
 ]
 
 
 @pytest.mark.skipif(
-    condition=pika_version_info[0] > 0, reason="pika 1.0 removed the ability to use basic_get with callback=None"
-)
+    condition=pika_version_info[0] > 0,
+    reason='pika 1.0 removed the ability to use basic_get with callback=None')
 @parametrized_connection
 @validate_transaction_metrics(
-    ("test_pika_async_connection_consume:" "test_async_connection_basic_get_inside_txn_no_callback"),
+    ('test_pika_async_connection_consume:'
+     'test_async_connection_basic_get_inside_txn_no_callback'),
     scoped_metrics=_test_select_conn_basic_get_inside_txn_no_callback_metrics,
     rollup_metrics=_test_select_conn_basic_get_inside_txn_no_callback_metrics,
-    background_task=True,
-)
+    background_task=True)
 @validate_tt_collector_json(message_broker_params=_message_broker_tt_params)
 @background_task()
-def test_async_connection_basic_get_inside_txn_no_callback(producer, ConnectionClass):
+def test_async_connection_basic_get_inside_txn_no_callback(producer,
+        ConnectionClass):
     def on_open_channel(channel):
         channel.basic_get(callback=None, queue=QUEUE)
         channel.close()
@@ -208,7 +206,9 @@ def on_open_channel(channel):
     def on_open_connection(connection):
         connection.channel(on_open_callback=on_open_channel)
 
-    connection = ConnectionClass(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection)
+    connection = ConnectionClass(
+        pika.ConnectionParameters(DB_SETTINGS['host']),
+        on_open_callback=on_open_connection)
 
     try:
         connection.ioloop.start()
@@ -219,26 +219,27 @@ def on_open_connection(connection):
 
 
 _test_async_connection_basic_get_empty_metrics = [
-    ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None),
-    ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None),
 ]
 
 
 @parametrized_connection
-@pytest.mark.parametrize("callback_as_partial", [True, False])
+@pytest.mark.parametrize('callback_as_partial', [True, False])
 @validate_transaction_metrics(
-    ("test_pika_async_connection_consume:" "test_async_connection_basic_get_empty"),
-    scoped_metrics=_test_async_connection_basic_get_empty_metrics,
-    rollup_metrics=_test_async_connection_basic_get_empty_metrics,
-    background_task=True,
-)
+    ('test_pika_async_connection_consume:'
+     'test_async_connection_basic_get_empty'),
+    scoped_metrics=_test_async_connection_basic_get_empty_metrics,
+    rollup_metrics=_test_async_connection_basic_get_empty_metrics,
+    background_task=True)
 @validate_tt_collector_json(message_broker_params=_message_broker_tt_params)
 @background_task()
-def test_async_connection_basic_get_empty(ConnectionClass, callback_as_partial):
-    QUEUE = "test_async_empty"
+def test_async_connection_basic_get_empty(ConnectionClass,
+        callback_as_partial):
+    QUEUE = 'test_async_empty'
 
     def on_message(channel, method_frame, header_frame, body):
-        assert False, body.decode("UTF-8")
+        assert False, body.decode('UTF-8')
 
     if callback_as_partial:
         on_message = functools.partial(on_message)
@@ -252,7 +253,9 @@ def on_open_channel(channel):
     def on_open_connection(connection):
         connection.channel(on_open_callback=on_open_channel)
 
-    connection = ConnectionClass(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection)
+    connection = ConnectionClass(
+        pika.ConnectionParameters(DB_SETTINGS['host']),
+        on_open_callback=on_open_connection)
 
     try:
         connection.ioloop.start()
@@ -263,42 +266,33 @@ def on_open_connection(connection):
 
 
 _test_select_conn_basic_consume_in_txn_metrics = [
-    ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None),
-    ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None),
 ]
 
 if six.PY3:
     _test_select_conn_basic_consume_in_txn_metrics.append(
-        (
-            (
-                "Function/test_pika_async_connection_consume:"
-                "test_async_connection_basic_consume_inside_txn."
-                "<locals>.on_message"
-            ),
-            1,
-        )
-    )
+        (('Function/test_pika_async_connection_consume:'
+          'test_async_connection_basic_consume_inside_txn.'
+          '<locals>.on_message'), 1))
 else:
-    _test_select_conn_basic_consume_in_txn_metrics.append(("Function/test_pika_async_connection_consume:on_message", 1))
+    _test_select_conn_basic_consume_in_txn_metrics.append(
+        ('Function/test_pika_async_connection_consume:on_message', 1))
 
 
 @parametrized_connection
 @validate_transaction_metrics(
-    ("test_pika_async_connection_consume:" "test_async_connection_basic_consume_inside_txn"),
-    scoped_metrics=_test_select_conn_basic_consume_in_txn_metrics,
-    rollup_metrics=_test_select_conn_basic_consume_in_txn_metrics,
-    background_task=True,
-)
-@validate_code_level_metrics(
-    "test_pika_async_connection_consume"
-    + (".test_async_connection_basic_consume_inside_txn.<locals>" if six.PY3 else ""),
-    "on_message",
-)
+    ('test_pika_async_connection_consume:'
+     'test_async_connection_basic_consume_inside_txn'),
+    scoped_metrics=_test_select_conn_basic_consume_in_txn_metrics,
+    rollup_metrics=_test_select_conn_basic_consume_in_txn_metrics,
+    background_task=True)
+@validate_code_level_metrics("test_pika_async_connection_consume.test_async_connection_basic_consume_inside_txn.<locals>", "on_message", py2_namespace="test_pika_async_connection_consume")
 @validate_tt_collector_json(message_broker_params=_message_broker_tt_params)
 @background_task()
 def test_async_connection_basic_consume_inside_txn(producer, ConnectionClass):
     def on_message(channel, method_frame, header_frame, body):
-        assert hasattr(method_frame, "_nr_start_time")
+        assert hasattr(method_frame, '_nr_start_time')
         assert body == BODY
         channel.basic_ack(method_frame.delivery_tag)
         channel.close()
@@ -311,7 +305,9 @@ def on_open_channel(channel):
     def on_open_connection(connection):
         connection.channel(on_open_callback=on_open_channel)
 
-    connection = ConnectionClass(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection)
+    connection = ConnectionClass(
+        pika.ConnectionParameters(DB_SETTINGS['host']),
+        on_open_callback=on_open_connection)
 
     try:
         connection.ioloop.start()
@@ -322,67 +318,46 @@ def on_open_connection(connection):
 
 
 _test_select_conn_basic_consume_two_exchanges = [
-    ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None),
-    ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None),
-    ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE_2, None),
-    ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE_2, None),
+    ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE_2, None),
+    ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE_2, None),
 ]
 
 if six.PY3:
     _test_select_conn_basic_consume_two_exchanges.append(
-        (
-            (
-                "Function/test_pika_async_connection_consume:"
-                "test_async_connection_basic_consume_two_exchanges."
-                "<locals>.on_message_1"
-            ),
-            1,
-        )
-    )
+        (('Function/test_pika_async_connection_consume:'
+          'test_async_connection_basic_consume_two_exchanges.'
+          '<locals>.on_message_1'), 1))
     _test_select_conn_basic_consume_two_exchanges.append(
-        (
-            (
-                "Function/test_pika_async_connection_consume:"
-                "test_async_connection_basic_consume_two_exchanges."
-                "<locals>.on_message_2"
-            ),
-            1,
-        )
-    )
+        (('Function/test_pika_async_connection_consume:'
+          'test_async_connection_basic_consume_two_exchanges.'
+          '<locals>.on_message_2'), 1))
 else:
     _test_select_conn_basic_consume_two_exchanges.append(
-        ("Function/test_pika_async_connection_consume:on_message_1", 1)
-    )
+        ('Function/test_pika_async_connection_consume:on_message_1', 1))
     _test_select_conn_basic_consume_two_exchanges.append(
-        ("Function/test_pika_async_connection_consume:on_message_2", 1)
-    )
+        ('Function/test_pika_async_connection_consume:on_message_2', 1))
 
 
 @parametrized_connection
 @validate_transaction_metrics(
-    ("test_pika_async_connection_consume:" "test_async_connection_basic_consume_two_exchanges"),
-    scoped_metrics=_test_select_conn_basic_consume_two_exchanges,
-    rollup_metrics=_test_select_conn_basic_consume_two_exchanges,
-    background_task=True,
-)
-@validate_code_level_metrics(
-    "test_pika_async_connection_consume"
-    + (".test_async_connection_basic_consume_two_exchanges.<locals>" if six.PY3 else ""),
-    "on_message_1",
-)
-@validate_code_level_metrics(
-    "test_pika_async_connection_consume"
-    + (".test_async_connection_basic_consume_two_exchanges.<locals>" if six.PY3 else ""),
-    "on_message_2",
-)
+    ('test_pika_async_connection_consume:'
+     'test_async_connection_basic_consume_two_exchanges'),
+    scoped_metrics=_test_select_conn_basic_consume_two_exchanges,
+    rollup_metrics=_test_select_conn_basic_consume_two_exchanges,
+    background_task=True)
+@validate_code_level_metrics("test_pika_async_connection_consume.test_async_connection_basic_consume_two_exchanges.<locals>", "on_message_1", py2_namespace="test_pika_async_connection_consume")
+@validate_code_level_metrics("test_pika_async_connection_consume.test_async_connection_basic_consume_two_exchanges.<locals>", "on_message_2", py2_namespace="test_pika_async_connection_consume")
 @background_task()
-def test_async_connection_basic_consume_two_exchanges(producer, producer_2, ConnectionClass):
+def test_async_connection_basic_consume_two_exchanges(producer, producer_2,
+        ConnectionClass):
     global events_received
     events_received = 0
 
     def on_message_1(channel, method_frame, header_frame, body):
         channel.basic_ack(method_frame.delivery_tag)
-        assert hasattr(method_frame, "_nr_start_time")
+        assert hasattr(method_frame, '_nr_start_time')
         assert body == BODY
 
         global events_received
@@ -395,7 +370,7 @@ def on_message_1(channel, method_frame, header_frame, body):
 
     def on_message_2(channel, method_frame, header_frame, body):
         channel.basic_ack(method_frame.delivery_tag)
-        assert hasattr(method_frame, "_nr_start_time")
+        assert hasattr(method_frame, '_nr_start_time')
         assert body == BODY
 
         global events_received
@@ -413,7 +388,9 @@ def on_open_channel(channel):
     def on_open_connection(connection):
         connection.channel(on_open_callback=on_open_channel)
 
-    connection = ConnectionClass(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection)
+    connection = ConnectionClass(
+        pika.ConnectionParameters(DB_SETTINGS['host']),
+        on_open_callback=on_open_connection)
 
     try:
         connection.ioloop.start()
@@ -424,11 +401,12 @@ def on_open_connection(connection):
 
 
 # This should not create a transaction
-@function_not_called("newrelic.core.stats_engine", "StatsEngine.record_transaction")
-@override_application_settings({"debug.record_transaction_failure": True})
+@function_not_called('newrelic.core.stats_engine',
+        'StatsEngine.record_transaction')
+@override_application_settings({'debug.record_transaction_failure': True})
 def test_tornado_connection_basic_consume_outside_transaction(producer):
     def on_message(channel, method_frame, header_frame, body):
-        assert hasattr(method_frame, "_nr_start_time")
+        assert hasattr(method_frame, '_nr_start_time')
         assert body == BODY
         channel.basic_ack(method_frame.delivery_tag)
         channel.close()
@@ -441,7 +419,9 @@ def on_open_channel(channel):
     def on_open_connection(connection):
         connection.channel(on_open_callback=on_open_channel)
 
-    connection = TornadoConnection(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection)
+    connection = TornadoConnection(
+        pika.ConnectionParameters(DB_SETTINGS['host']),
+        on_open_callback=on_open_connection)
 
     try:
         connection.ioloop.start()
@@ -452,44 +432,31 @@ def on_open_connection(connection):
 
 
 if six.PY3:
-    _txn_name = (
-        "test_pika_async_connection_consume:"
-        "test_select_connection_basic_consume_outside_transaction."
-        "<locals>.on_message"
-    )
+    _txn_name = ('test_pika_async_connection_consume:'
+            'test_select_connection_basic_consume_outside_transaction.'
+            '<locals>.on_message')
     _test_select_connection_consume_outside_txn_metrics = [
-        (
-            (
-                "Function/test_pika_async_connection_consume:"
-                "test_select_connection_basic_consume_outside_transaction."
-                "<locals>.on_message"
-            ),
-            None,
-        )
-    ]
+        (('Function/test_pika_async_connection_consume:'
+          'test_select_connection_basic_consume_outside_transaction.'
+          '<locals>.on_message'), None)]
 else:
-    _txn_name = "test_pika_async_connection_consume:on_message"
+    _txn_name = (
+        'test_pika_async_connection_consume:on_message')
     _test_select_connection_consume_outside_txn_metrics = [
-        ("Function/test_pika_async_connection_consume:on_message", None)
-    ]
+        ('Function/test_pika_async_connection_consume:on_message', None)]
 
 
 # This should create a transaction
 @validate_transaction_metrics(
-    _txn_name,
-    scoped_metrics=_test_select_connection_consume_outside_txn_metrics,
-    rollup_metrics=_test_select_connection_consume_outside_txn_metrics,
-    background_task=True,
-    group="Message/RabbitMQ/Exchange/%s" % EXCHANGE,
-)
-@validate_code_level_metrics(
-    "test_pika_async_connection_consume"
-    + (".test_select_connection_basic_consume_outside_transaction.<locals>" if six.PY3 else ""),
-    "on_message",
-)
+        _txn_name,
+        scoped_metrics=_test_select_connection_consume_outside_txn_metrics,
+        rollup_metrics=_test_select_connection_consume_outside_txn_metrics,
+        background_task=True,
+        group='Message/RabbitMQ/Exchange/%s' % EXCHANGE)
+@validate_code_level_metrics("test_pika_async_connection_consume.test_select_connection_basic_consume_outside_transaction.<locals>", "on_message", py2_namespace="test_pika_async_connection_consume")
 def test_select_connection_basic_consume_outside_transaction(producer):
     def on_message(channel, method_frame, header_frame, body):
-        assert hasattr(method_frame, "_nr_start_time")
+        assert hasattr(method_frame, '_nr_start_time')
         assert body == BODY
         channel.basic_ack(method_frame.delivery_tag)
         channel.close()
@@ -503,8 +470,8 @@ def on_open_connection(connection):
         connection.channel(on_open_callback=on_open_channel)
 
     connection = pika.SelectConnection(
-        pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection
-    )
+        pika.ConnectionParameters(DB_SETTINGS['host']),
+        on_open_callback=on_open_connection)
 
     try:
         connection.ioloop.start()
diff --git a/tests/messagebroker_pika/test_pika_blocking_connection_consume.py b/tests/messagebroker_pika/test_pika_blocking_connection_consume.py
index 7b41674a2..e097cfbe9 100644
--- a/tests/messagebroker_pika/test_pika_blocking_connection_consume.py
+++ b/tests/messagebroker_pika/test_pika_blocking_connection_consume.py
@@ -38,30 +38,32 @@ DB_SETTINGS = rabbitmq_settings()[0]
 
 _message_broker_tt_params = {
-    "queue_name": QUEUE,
-    "routing_key": QUEUE,
-    "correlation_id": CORRELATION_ID,
-    "reply_to": REPLY_TO,
-    "headers": HEADERS.copy(),
+    'queue_name': QUEUE,
+    'routing_key': QUEUE,
+    'correlation_id': CORRELATION_ID,
+    'reply_to': REPLY_TO,
+    'headers': HEADERS.copy(),
 }
 
 
 _test_blocking_connection_basic_get_metrics = [
-    ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None),
-    ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, 1),
-    (("Function/pika.adapters.blocking_connection:" "_CallbackResult.set_value_once"), 1),
+    ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, 1),
+    (('Function/pika.adapters.blocking_connection:'
+      '_CallbackResult.set_value_once'), 1)
 ]
 
 
 @validate_transaction_metrics(
-    ("test_pika_blocking_connection_consume:" "test_blocking_connection_basic_get"),
-    scoped_metrics=_test_blocking_connection_basic_get_metrics,
-    rollup_metrics=_test_blocking_connection_basic_get_metrics,
-    background_task=True,
-)
+    ('test_pika_blocking_connection_consume:'
+     'test_blocking_connection_basic_get'),
+    scoped_metrics=_test_blocking_connection_basic_get_metrics,
+    rollup_metrics=_test_blocking_connection_basic_get_metrics,
+    background_task=True)
 @validate_tt_collector_json(message_broker_params=_message_broker_tt_params)
 @background_task()
 def test_blocking_connection_basic_get(producer):
-    with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection:
+    with pika.BlockingConnection(
+            pika.ConnectionParameters(DB_SETTINGS['host'])) as connection:
         channel = connection.channel()
         method_frame, _, _ = channel.basic_get(QUEUE)
         assert method_frame
@@ -69,22 +71,23 @@ def test_blocking_connection_basic_get(producer):
 
 
 _test_blocking_connection_basic_get_empty_metrics = [
-    ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None),
-    ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None),
 ]
 
 
 @validate_transaction_metrics(
-    ("test_pika_blocking_connection_consume:" "test_blocking_connection_basic_get_empty"),
-    scoped_metrics=_test_blocking_connection_basic_get_empty_metrics,
-    rollup_metrics=_test_blocking_connection_basic_get_empty_metrics,
-    background_task=True,
-)
+    ('test_pika_blocking_connection_consume:'
+     'test_blocking_connection_basic_get_empty'),
+    scoped_metrics=_test_blocking_connection_basic_get_empty_metrics,
+    rollup_metrics=_test_blocking_connection_basic_get_empty_metrics,
+    background_task=True)
 @validate_tt_collector_json(message_broker_params=_message_broker_tt_params)
 @background_task()
 def test_blocking_connection_basic_get_empty():
-    QUEUE = "test_blocking_empty-%s" % os.getpid()
-    with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection:
+    QUEUE = 'test_blocking_empty-%s' % os.getpid()
+    with pika.BlockingConnection(
+            pika.ConnectionParameters(DB_SETTINGS['host'])) as connection:
         channel = connection.channel()
         channel.queue_declare(queue=QUEUE)
@@ -100,7 +103,8 @@ def test_blocking_connection_basic_get_outside_transaction(producer):
 
     @capture_transaction_metrics(metrics_list)
     def test_basic_get():
-        with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection:
+        with pika.BlockingConnection(
+                pika.ConnectionParameters(DB_SETTINGS['host'])) as connection:
             channel = connection.channel()
             channel.queue_declare(queue=QUEUE)
@@ -116,57 +120,46 @@ def test_basic_get():
 
 
 _test_blocking_conn_basic_consume_no_txn_metrics = [
-    ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None),
-    ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None),
 ]
 
 if six.PY3:
-    _txn_name = (
-        "test_pika_blocking_connection_consume:"
-        "test_blocking_connection_basic_consume_outside_transaction."
-        "<locals>.on_message"
-    )
+    _txn_name = ('test_pika_blocking_connection_consume:'
+            'test_blocking_connection_basic_consume_outside_transaction.'
+            '<locals>.on_message')
     _test_blocking_conn_basic_consume_no_txn_metrics.append(
-        (
-            (
-                "Function/test_pika_blocking_connection_consume:"
-                "test_blocking_connection_basic_consume_outside_transaction."
-                "<locals>.on_message"
-            ),
-            None,
-        )
-    )
+        (('Function/test_pika_blocking_connection_consume:'
+          'test_blocking_connection_basic_consume_outside_transaction.'
+          '<locals>.on_message'), None))
 else:
-    _txn_name = "test_pika_blocking_connection_consume:" "on_message"
+    _txn_name = ('test_pika_blocking_connection_consume:'
+            'on_message')
     _test_blocking_conn_basic_consume_no_txn_metrics.append(
-        ("Function/test_pika_blocking_connection_consume:on_message", None)
-    )
+        ('Function/test_pika_blocking_connection_consume:on_message', None))
 
 
-@pytest.mark.parametrize("as_partial", [True, False])
-@validate_code_level_metrics(
-    "test_pika_blocking_connection_consume"
-    + (".test_blocking_connection_basic_consume_outside_transaction.<locals>" if six.PY3 else ""),
-    "on_message",
-)
+@pytest.mark.parametrize('as_partial', [True, False])
+@validate_code_level_metrics("test_pika_blocking_connection_consume.test_blocking_connection_basic_consume_outside_transaction.<locals>", "on_message", py2_namespace="test_pika_blocking_connection_consume")
 @validate_transaction_metrics(
-    _txn_name,
-    scoped_metrics=_test_blocking_conn_basic_consume_no_txn_metrics,
-    rollup_metrics=_test_blocking_conn_basic_consume_no_txn_metrics,
-    background_task=True,
-    group="Message/RabbitMQ/Exchange/%s" % EXCHANGE,
-)
+        _txn_name,
+        scoped_metrics=_test_blocking_conn_basic_consume_no_txn_metrics,
+        rollup_metrics=_test_blocking_conn_basic_consume_no_txn_metrics,
+        background_task=True,
+        group='Message/RabbitMQ/Exchange/%s' % EXCHANGE)
 @validate_tt_collector_json(message_broker_params=_message_broker_tt_params)
-def test_blocking_connection_basic_consume_outside_transaction(producer, as_partial):
+def test_blocking_connection_basic_consume_outside_transaction(producer,
+        as_partial):
     def on_message(channel, method_frame, header_frame, body):
-        assert hasattr(method_frame, "_nr_start_time")
+        assert hasattr(method_frame, '_nr_start_time')
         assert body == BODY
         channel.stop_consuming()
 
     if as_partial:
         on_message = functools.partial(on_message)
 
-    with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection:
+    with pika.BlockingConnection(
+            pika.ConnectionParameters(DB_SETTINGS['host'])) as connection:
         channel = connection.channel()
         basic_consume(channel, QUEUE, on_message)
@@ -178,51 +171,41 @@ def on_message(channel, method_frame, header_frame, body):
 
 
 _test_blocking_conn_basic_consume_in_txn_metrics = [
-    ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None),
-    ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None),
 ]
 
 if six.PY3:
    _test_blocking_conn_basic_consume_in_txn_metrics.append(
-        (
-            (
-                "Function/test_pika_blocking_connection_consume:"
-                "test_blocking_connection_basic_consume_inside_txn."
-                "<locals>.on_message"
-            ),
-            1,
-        )
-    )
+        (('Function/test_pika_blocking_connection_consume:'
+          'test_blocking_connection_basic_consume_inside_txn.'
+          '<locals>.on_message'), 1))
 else:
     _test_blocking_conn_basic_consume_in_txn_metrics.append(
-        ("Function/test_pika_blocking_connection_consume:on_message", 1)
-    )
+        ('Function/test_pika_blocking_connection_consume:on_message', 1))
 
 
-@pytest.mark.parametrize("as_partial", [True, False])
-@validate_code_level_metrics(
-    "test_pika_blocking_connection_consume"
-    + (".test_blocking_connection_basic_consume_inside_txn.<locals>" if six.PY3 else ""),
-    "on_message",
-)
+@pytest.mark.parametrize('as_partial', [True, False])
+@validate_code_level_metrics("test_pika_blocking_connection_consume.test_blocking_connection_basic_consume_inside_txn.<locals>", "on_message", py2_namespace="test_pika_blocking_connection_consume")
 @validate_transaction_metrics(
-    ("test_pika_blocking_connection_consume:" "test_blocking_connection_basic_consume_inside_txn"),
-    scoped_metrics=_test_blocking_conn_basic_consume_in_txn_metrics,
-    rollup_metrics=_test_blocking_conn_basic_consume_in_txn_metrics,
-    background_task=True,
-)
+    ('test_pika_blocking_connection_consume:'
+     'test_blocking_connection_basic_consume_inside_txn'),
+    scoped_metrics=_test_blocking_conn_basic_consume_in_txn_metrics,
+    rollup_metrics=_test_blocking_conn_basic_consume_in_txn_metrics,
+    background_task=True)
 @validate_tt_collector_json(message_broker_params=_message_broker_tt_params)
 @background_task()
 def test_blocking_connection_basic_consume_inside_txn(producer, as_partial):
     def on_message(channel, method_frame, header_frame, body):
-        assert hasattr(method_frame, "_nr_start_time")
+        assert hasattr(method_frame, '_nr_start_time')
         assert body == BODY
         channel.stop_consuming()
 
     if as_partial:
         on_message = functools.partial(on_message)
 
-    with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection:
+    with pika.BlockingConnection(
+            pika.ConnectionParameters(DB_SETTINGS['host'])) as connection:
         channel = connection.channel()
         basic_consume(channel, QUEUE, on_message)
         try:
@@ -233,40 +216,33 @@ def on_message(channel, method_frame, header_frame, body):
 
 
 _test_blocking_conn_basic_consume_stopped_txn_metrics = [
-    ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None),
-    ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None),
-    ("OtherTransaction/Message/RabbitMQ/Exchange/Named/%s" % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None),
+    ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None),
+    ('OtherTransaction/Message/RabbitMQ/Exchange/Named/%s' % EXCHANGE, None),
 ]
 
 if six.PY3:
     _test_blocking_conn_basic_consume_stopped_txn_metrics.append(
-        (
-            (
-                "Function/test_pika_blocking_connection_consume:"
-                "test_blocking_connection_basic_consume_stopped_txn."
-                "<locals>.on_message"
-            ),
-            None,
-        )
-    )
+        (('Function/test_pika_blocking_connection_consume:'
+          'test_blocking_connection_basic_consume_stopped_txn.'
+          '<locals>.on_message'), None))
 else:
     _test_blocking_conn_basic_consume_stopped_txn_metrics.append(
-        ("Function/test_pika_blocking_connection_consume:on_message", None)
-    )
+        ('Function/test_pika_blocking_connection_consume:on_message', None))
 
 
-@pytest.mark.parametrize("as_partial", [True, False])
+@pytest.mark.parametrize('as_partial', [True, False])
 @validate_transaction_metrics(
-    ("test_pika_blocking_connection_consume:" "test_blocking_connection_basic_consume_stopped_txn"),
-    scoped_metrics=_test_blocking_conn_basic_consume_stopped_txn_metrics,
-    rollup_metrics=_test_blocking_conn_basic_consume_stopped_txn_metrics,
-    background_task=True,
-)
+    ('test_pika_blocking_connection_consume:'
+     'test_blocking_connection_basic_consume_stopped_txn'),
+    scoped_metrics=_test_blocking_conn_basic_consume_stopped_txn_metrics,
+    rollup_metrics=_test_blocking_conn_basic_consume_stopped_txn_metrics,
+    background_task=True)
 @validate_tt_collector_json(message_broker_params=_message_broker_tt_params)
 @background_task()
 def test_blocking_connection_basic_consume_stopped_txn(producer, as_partial):
     def on_message(channel, method_frame, header_frame, body):
-        assert hasattr(method_frame, "_nr_start_time")
+        assert hasattr(method_frame, '_nr_start_time')
         assert body == BODY
         channel.stop_consuming()
 
@@ -275,7 +251,8 @@ def on_message(channel, method_frame, header_frame, body):
     if as_partial:
         on_message = functools.partial(on_message)
 
-    with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection:
+    with pika.BlockingConnection(
+            pika.ConnectionParameters(DB_SETTINGS['host'])) as connection:
         channel = connection.channel()
         basic_consume(channel, QUEUE, on_message)
         try:
diff --git a/tests/mlmodel_sklearn/conftest.py b/tests/mlmodel_sklearn/conftest.py
new file mode 100644
index 000000000..d91eb549a
--- /dev/null
+++ b/tests/mlmodel_sklearn/conftest.py
@@ -0,0 +1,34 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from testing_support.fixtures import (  # noqa: F401, pylint: disable=W0611
+    collector_agent_registration_fixture,
+    collector_available_fixture,
+)
+
+_default_settings = {
+    "transaction_tracer.explain_threshold": 0.0,
+    "transaction_tracer.transaction_threshold": 0.0,
+    "transaction_tracer.stack_trace_threshold": 0.0,
+    "debug.log_data_collector_payloads": True,
+    "debug.record_transaction_failure": True,
+    "machine_learning.enabled": True,
+    "machine_learning.inference_events_value.enabled": True,
+    "ml_insights_events.enabled": True
+}
+collector_agent_registration = collector_agent_registration_fixture(
+    app_name="Python Agent Test (mlmodel_sklearn)",
+    default_settings=_default_settings,
+    linked_applications=["Python Agent Test (mlmodel_sklearn)"],
+)
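The new mlmodel_sklearn conftest registers a dedicated test application with the ML-monitoring flags switched on. Individual tests can presumably still flip these flags locally with the same override fixture used elsewhere in this suite; a hedged sketch of that pattern (the specific test name and behavior are illustrative, not part of this diff):

    # Sketch: per-test override of an ML setting, mirroring override_application_settings
    # usage seen in other test modules in this repo.
    from testing_support.fixtures import override_application_settings


    @override_application_settings({"machine_learning.inference_events_value.enabled": False})
    def test_without_inference_event_values():
        ...  # model calls here would still be traced, but inference values would be dropped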
diff --git a/tests/mlmodel_sklearn/test_calibration_models.py b/tests/mlmodel_sklearn/test_calibration_models.py
new file mode 100644
index 000000000..39ac34cb2
--- /dev/null
+++ b/tests/mlmodel_sklearn/test_calibration_models.py
@@ -0,0 +1,76 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.packages import six
+
+
+def test_model_methods_wrapped_in_function_trace(calibration_model_name, run_calibration_model):
+    expected_scoped_metrics = {
+        "CalibratedClassifierCV": [
+            ("Function/MLModel/Sklearn/Named/CalibratedClassifierCV.fit", 1),
+            ("Function/MLModel/Sklearn/Named/CalibratedClassifierCV.predict", 1),
+            ("Function/MLModel/Sklearn/Named/CalibratedClassifierCV.predict_proba", 2),
+        ],
+    }
+
+    expected_transaction_name = "test_calibration_models:_test"
+    if six.PY3:
+        expected_transaction_name = (
+            "test_calibration_models:test_model_methods_wrapped_in_function_trace.<locals>._test"
+        )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[calibration_model_name],
+        rollup_metrics=expected_scoped_metrics[calibration_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_calibration_model()
+
+    _test()
+
+
+@pytest.fixture(params=["CalibratedClassifierCV"])
+def calibration_model_name(request):
+    return request.param
+
+
+@pytest.fixture
+def run_calibration_model(calibration_model_name):
+    def _run():
+        import sklearn.calibration
+        from sklearn.datasets import load_iris
+        from sklearn.model_selection import train_test_split
+
+        X, y = load_iris(return_X_y=True)
+        x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
+
+        clf = getattr(sklearn.calibration, calibration_model_name)()
+
+        model = clf.fit(x_train, y_train)
+        model.predict(x_test)
+
+        model.predict_proba(x_test)
+
+        return model
+
+    return _run
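Why predict_proba is counted twice above even though the fixture calls it once: in scikit-learn, CalibratedClassifierCV.predict is (roughly) an argmax over predict_proba, so the explicit predict() call triggers a second wrapped predict_proba. A simplified sketch of that relationship (not the library's exact source):

    # Sketch: CalibratedClassifierCV.predict delegating to predict_proba (simplified).
    import numpy as np

    def predict(self, X):
        proba = self.predict_proba(X)  # the second wrapped predict_proba call comes from here
        return self.classes_[np.argmax(proba, axis=1)]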
diff --git a/tests/mlmodel_sklearn/test_cluster_models.py b/tests/mlmodel_sklearn/test_cluster_models.py
new file mode 100644
index 000000000..906995c22
--- /dev/null
+++ b/tests/mlmodel_sklearn/test_cluster_models.py
@@ -0,0 +1,186 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from sklearn import __version__  # noqa: this is needed for get_package_version
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.common.package_version_utils import get_package_version
+from newrelic.packages import six
+
+SKLEARN_VERSION = tuple(map(int, get_package_version("sklearn").split(".")))
+
+
+@pytest.mark.parametrize(
+    "cluster_model_name",
+    [
+        "AffinityPropagation",
+        "AgglomerativeClustering",
+        "Birch",
+        "DBSCAN",
+        "FeatureAgglomeration",
+        "KMeans",
+        "MeanShift",
+        "MiniBatchKMeans",
+        "SpectralBiclustering",
+        "SpectralCoclustering",
+        "SpectralClustering",
+    ],
+)
+def test_below_v1_1_model_methods_wrapped_in_function_trace(cluster_model_name, run_cluster_model):
+    expected_scoped_metrics = {
+        "AffinityPropagation": [
+            ("Function/MLModel/Sklearn/Named/AffinityPropagation.fit", 2),
+            ("Function/MLModel/Sklearn/Named/AffinityPropagation.predict", 1),
+            ("Function/MLModel/Sklearn/Named/AffinityPropagation.fit_predict", 1),
+        ],
+        "AgglomerativeClustering": [
+            ("Function/MLModel/Sklearn/Named/AgglomerativeClustering.fit", 2),
+            ("Function/MLModel/Sklearn/Named/AgglomerativeClustering.fit_predict", 1),
+        ],
+        "Birch": [
+            ("Function/MLModel/Sklearn/Named/Birch.fit", 2),
+            (
+                "Function/MLModel/Sklearn/Named/Birch.predict",
+                1 if SKLEARN_VERSION >= (1, 0, 0) else 3,
+            ),
+            ("Function/MLModel/Sklearn/Named/Birch.fit_predict", 1),
+            ("Function/MLModel/Sklearn/Named/Birch.transform", 1),
+        ],
+        "DBSCAN": [
+            ("Function/MLModel/Sklearn/Named/DBSCAN.fit", 2),
+            ("Function/MLModel/Sklearn/Named/DBSCAN.fit_predict", 1),
+        ],
+        "FeatureAgglomeration": [
+            ("Function/MLModel/Sklearn/Named/FeatureAgglomeration.fit", 1),
+            ("Function/MLModel/Sklearn/Named/FeatureAgglomeration.transform", 1),
+        ],
+        "KMeans": [
+            ("Function/MLModel/Sklearn/Named/KMeans.fit", 2),
+            ("Function/MLModel/Sklearn/Named/KMeans.predict", 1),
+            ("Function/MLModel/Sklearn/Named/KMeans.fit_predict", 1),
+            ("Function/MLModel/Sklearn/Named/KMeans.transform", 1),
+        ],
+        "MeanShift": [
+            ("Function/MLModel/Sklearn/Named/MeanShift.fit", 2),
+            ("Function/MLModel/Sklearn/Named/MeanShift.predict", 1),
+            ("Function/MLModel/Sklearn/Named/MeanShift.fit_predict", 1),
+        ],
+        "MiniBatchKMeans": [
+            ("Function/MLModel/Sklearn/Named/MiniBatchKMeans.fit", 2),
+            ("Function/MLModel/Sklearn/Named/MiniBatchKMeans.predict", 1),
+            ("Function/MLModel/Sklearn/Named/MiniBatchKMeans.fit_predict", 1),
+        ],
+        "SpectralBiclustering": [
+            ("Function/MLModel/Sklearn/Named/SpectralBiclustering.fit", 1),
+        ],
+        "SpectralCoclustering": [
+            ("Function/MLModel/Sklearn/Named/SpectralCoclustering.fit", 1),
+        ],
+        "SpectralClustering": [
+            ("Function/MLModel/Sklearn/Named/SpectralClustering.fit", 2),
+            ("Function/MLModel/Sklearn/Named/SpectralClustering.fit_predict", 1),
+        ],
+    }
+    expected_transaction_name = "test_cluster_models:_test"
+    if six.PY3:
+        expected_transaction_name = (
+            "test_cluster_models:test_below_v1_1_model_methods_wrapped_in_function_trace.<locals>._test"
+        )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[cluster_model_name],
+        rollup_metrics=expected_scoped_metrics[cluster_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_cluster_model(cluster_model_name)
+
+    _test()
+
+
+@pytest.mark.skipif(SKLEARN_VERSION < (1, 1, 0), reason="Requires sklearn >= 1.1")
+@pytest.mark.parametrize(
+    "cluster_model_name",
+    [
+        "BisectingKMeans",
+        "OPTICS",
+    ],
+)
+def test_above_v1_1_model_methods_wrapped_in_function_trace(cluster_model_name, run_cluster_model):
+    expected_scoped_metrics = {
+        "BisectingKMeans": [
+            ("Function/MLModel/Sklearn/Named/BisectingKMeans.fit", 2),
+            ("Function/MLModel/Sklearn/Named/BisectingKMeans.predict", 1),
+            ("Function/MLModel/Sklearn/Named/BisectingKMeans.fit_predict", 1),
+        ],
+        "OPTICS": [
+            ("Function/MLModel/Sklearn/Named/OPTICS.fit", 2),
+            ("Function/MLModel/Sklearn/Named/OPTICS.fit_predict", 1),
+        ],
+    }
+    expected_transaction_name = "test_cluster_models:_test"
+    if six.PY3:
+        expected_transaction_name = (
+            "test_cluster_models:test_above_v1_1_model_methods_wrapped_in_function_trace.<locals>._test"
+        )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[cluster_model_name],
+        rollup_metrics=expected_scoped_metrics[cluster_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_cluster_model(cluster_model_name)
+
+    _test()
+
+
+@pytest.fixture
+def run_cluster_model():
+    def _run(cluster_model_name):
+        import sklearn.cluster
+        from sklearn.datasets import load_iris
+        from sklearn.model_selection import train_test_split
+
+        X, y = load_iris(return_X_y=True)
+        x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
+
+        clf = getattr(sklearn.cluster, cluster_model_name)()
+
+        model = clf.fit(x_train, y_train)
+
+        if hasattr(model, "predict"):
+            model.predict(x_test)
+        if hasattr(model, "score"):
+            model.score(x_test, y_test)
+        if hasattr(model, "fit_predict"):
+            model.fit_predict(x_test)
+        if hasattr(model, "predict_log_proba"):
+            model.predict_log_proba(x_test)
+        if hasattr(model, "predict_proba"):
+            model.predict_proba(x_test)
+        if hasattr(model, "transform"):
+            model.transform(x_test)
+
+        return model
+
+    return _run
diff --git a/tests/mlmodel_sklearn/test_compose_models.py b/tests/mlmodel_sklearn/test_compose_models.py
new file mode 100644
index 000000000..eab076fc3
--- /dev/null
+++ b/tests/mlmodel_sklearn/test_compose_models.py
@@ -0,0 +1,94 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from sklearn.linear_model import LinearRegression
+from sklearn.preprocessing import Normalizer
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.packages import six
+
+
+@pytest.mark.parametrize(
+    "compose_model_name",
+    [
+        "ColumnTransformer",
+        "TransformedTargetRegressor",
+    ],
+)
+def test_model_methods_wrapped_in_function_trace(compose_model_name, run_compose_model):
+    expected_scoped_metrics = {
+        "ColumnTransformer": [
+            ("Function/MLModel/Sklearn/Named/ColumnTransformer.fit", 1),
+            ("Function/MLModel/Sklearn/Named/ColumnTransformer.transform", 1),
+        ],
+        "TransformedTargetRegressor": [
+            ("Function/MLModel/Sklearn/Named/TransformedTargetRegressor.fit", 1),
+            ("Function/MLModel/Sklearn/Named/TransformedTargetRegressor.predict", 1),
+        ],
+    }
+
+    expected_transaction_name = (
+        "test_compose_models:test_model_methods_wrapped_in_function_trace.<locals>._test"
+        if six.PY3
+        else "test_compose_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[compose_model_name],
+        rollup_metrics=expected_scoped_metrics[compose_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_compose_model(compose_model_name)
+
+    _test()
+
+
+@pytest.fixture
+def run_compose_model():
+    def _run(compose_model_name):
+        import numpy as np
+        import sklearn.compose
+
+        if compose_model_name == "TransformedTargetRegressor":
+            kwargs = {"regressor": LinearRegression()}
+            X = np.arange(4).reshape(-1, 1)
+            y = np.exp(2 * X).ravel()
+        else:
+            X = [[0.0, 1.0, 2.0, 2.0], [1.0, 1.0, 0.0, 1.0]]
+            y = None
+            kwargs = {
+                "transformers": [
+                    ("norm1", Normalizer(norm="l1"), [0, 1]),
+                    ("norm2", Normalizer(norm="l1"), slice(2, 4)),
+                ]
+            }
+
+        clf = getattr(sklearn.compose, compose_model_name)(**kwargs)
+
+        model = clf.fit(X, y)
+        if hasattr(model, "predict"):
+            model.predict(X)
+        if hasattr(model, "transform"):
+            model.transform(X)
+
+        return model
+
+    return _run
diff --git a/tests/mlmodel_sklearn/test_covariance_models.py b/tests/mlmodel_sklearn/test_covariance_models.py
new file mode 100644
index 000000000..afa5c31c2
--- /dev/null
+++ b/tests/mlmodel_sklearn/test_covariance_models.py
@@ -0,0 +1,110 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.packages import six
+
+
+@pytest.mark.parametrize(
+    "covariance_model_name",
+    [
+        "EllipticEnvelope",
+        "EmpiricalCovariance",
+        "GraphicalLasso",
+        "GraphicalLassoCV",
+        "MinCovDet",
+        "ShrunkCovariance",
+        "LedoitWolf",
+        "OAS",
+    ],
+)
+def test_model_methods_wrapped_in_function_trace(covariance_model_name, run_covariance_model):
+    expected_scoped_metrics = {
+        "EllipticEnvelope": [
+            ("Function/MLModel/Sklearn/Named/EllipticEnvelope.fit", 1),
+            ("Function/MLModel/Sklearn/Named/EllipticEnvelope.predict", 2),
+            ("Function/MLModel/Sklearn/Named/EllipticEnvelope.score", 1),
+        ],
+        "EmpiricalCovariance": [
+            ("Function/MLModel/Sklearn/Named/EmpiricalCovariance.fit", 1),
+            ("Function/MLModel/Sklearn/Named/EmpiricalCovariance.score", 1),
+        ],
+        "GraphicalLasso": [
+            ("Function/MLModel/Sklearn/Named/GraphicalLasso.fit", 1),
+        ],
+        "GraphicalLassoCV": [
+            ("Function/MLModel/Sklearn/Named/GraphicalLassoCV.fit", 1),
+        ],
+        "MinCovDet": [
+            ("Function/MLModel/Sklearn/Named/MinCovDet.fit", 1),
+        ],
+        "ShrunkCovariance": [
+            ("Function/MLModel/Sklearn/Named/ShrunkCovariance.fit", 1),
+        ],
+        "LedoitWolf": [
+            ("Function/MLModel/Sklearn/Named/LedoitWolf.fit", 1),
+        ],
+        "OAS": [
+            ("Function/MLModel/Sklearn/Named/OAS.fit", 1),
+        ],
+    }
+    expected_transaction_name = (
+        "test_covariance_models:test_model_methods_wrapped_in_function_trace.<locals>._test"
+        if six.PY3
+        else "test_covariance_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[covariance_model_name],
+        rollup_metrics=expected_scoped_metrics[covariance_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_covariance_model(covariance_model_name)
+
+    _test()
+
+
+@pytest.fixture
+def run_covariance_model():
+    def _run(covariance_model_name):
+        import sklearn.covariance
+        from sklearn.datasets import load_iris
+        from sklearn.model_selection import train_test_split
+
+        X, y = load_iris(return_X_y=True)
+        x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
+
+        kwargs = {}
+        if covariance_model_name in ["EllipticEnvelope", "MinCovDet"]:
+            kwargs = {"random_state": 0}
+
+        clf = getattr(sklearn.covariance, covariance_model_name)(**kwargs)
+
+        model = clf.fit(x_train, y_train)
+        if hasattr(model, "predict"):
+            model.predict(x_test)
+        if hasattr(model, "score"):
+            model.score(x_test, y_test)
+
+        return model
+
+    return _run
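The likely source of the second EllipticEnvelope.predict count above: the fixture calls predict once explicitly, and the outlier-detector style score() computes accuracy from self.predict(X), which triggers one more wrapped call. A simplified sketch of that relationship (not scikit-learn's exact source):

    # Sketch: score() delegating to predict(), doubling the wrapped predict count.
    from sklearn.metrics import accuracy_score

    def score(self, X, y):
        return accuracy_score(y, self.predict(X))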
diff --git a/tests/mlmodel_sklearn/test_cross_decomposition_models.py b/tests/mlmodel_sklearn/test_cross_decomposition_models.py
new file mode 100644
index 000000000..6a053350f
--- /dev/null
+++ b/tests/mlmodel_sklearn/test_cross_decomposition_models.py
@@ -0,0 +1,81 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.packages import six
+
+
+@pytest.mark.parametrize(
+    "cross_decomposition_model_name",
+    [
+        "PLSRegression",
+        "PLSSVD",
+    ],
+)
+def test_model_methods_wrapped_in_function_trace(cross_decomposition_model_name, run_cross_decomposition_model):
+    expected_scoped_metrics = {
+        "PLSRegression": [
+            ("Function/MLModel/Sklearn/Named/PLSRegression.fit", 1),
+        ],
+        "PLSSVD": [
+            ("Function/MLModel/Sklearn/Named/PLSSVD.fit", 1),
+            ("Function/MLModel/Sklearn/Named/PLSSVD.transform", 1),
+        ],
+    }
+    expected_transaction_name = (
+        "test_cross_decomposition_models:test_model_methods_wrapped_in_function_trace.<locals>._test"
+        if six.PY3
+        else "test_cross_decomposition_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[cross_decomposition_model_name],
+        rollup_metrics=expected_scoped_metrics[cross_decomposition_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_cross_decomposition_model(cross_decomposition_model_name)
+
+    _test()
+
+
+@pytest.fixture
+def run_cross_decomposition_model():
+    def _run(cross_decomposition_model_name):
+        import sklearn.cross_decomposition
+        from sklearn.datasets import load_iris
+        from sklearn.model_selection import train_test_split
+
+        X, y = load_iris(return_X_y=True)
+        x_train, x_test, y_train, _ = train_test_split(X, y, stratify=y, random_state=0)
+
+        kwargs = {}
+        if cross_decomposition_model_name == "PLSSVD":
+            kwargs = {"n_components": 1}
+        clf = getattr(sklearn.cross_decomposition, cross_decomposition_model_name)(**kwargs)
+
+        model = clf.fit(x_train, y_train)
+        if hasattr(model, "transform"):
+            model.transform(x_test)
+
+        return model
+
+    return _run
+ +import pytest +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task +from newrelic.packages import six + + +@pytest.mark.parametrize( + "discriminant_analysis_model_name", + [ + "LinearDiscriminantAnalysis", + "QuadraticDiscriminantAnalysis", + ], +) +def test_model_methods_wrapped_in_function_trace(discriminant_analysis_model_name, run_discriminant_analysis_model): + expected_scoped_metrics = { + "LinearDiscriminantAnalysis": [ + ("Function/MLModel/Sklearn/Named/LinearDiscriminantAnalysis.fit", 1), + ("Function/MLModel/Sklearn/Named/LinearDiscriminantAnalysis.predict_log_proba", 1), + ("Function/MLModel/Sklearn/Named/LinearDiscriminantAnalysis.predict_proba", 2), + ("Function/MLModel/Sklearn/Named/LinearDiscriminantAnalysis.transform", 1), + ], + "QuadraticDiscriminantAnalysis": [ + ("Function/MLModel/Sklearn/Named/QuadraticDiscriminantAnalysis.fit", 1), + ("Function/MLModel/Sklearn/Named/QuadraticDiscriminantAnalysis.predict", 1), + ("Function/MLModel/Sklearn/Named/QuadraticDiscriminantAnalysis.predict_proba", 2), + ("Function/MLModel/Sklearn/Named/QuadraticDiscriminantAnalysis.predict_log_proba", 1), + ], + } + + expected_transaction_name = ( + "test_discriminant_analysis_models:test_model_methods_wrapped_in_function_trace.._test" + if six.PY3 + else "test_discriminant_analysis_models:_test" + ) + + @validate_transaction_metrics( + expected_transaction_name, + scoped_metrics=expected_scoped_metrics[discriminant_analysis_model_name], + rollup_metrics=expected_scoped_metrics[discriminant_analysis_model_name], + background_task=True, + ) + @background_task() + def _test(): + run_discriminant_analysis_model(discriminant_analysis_model_name) + + _test() + + +@pytest.fixture +def run_discriminant_analysis_model(): + def _run(discriminant_analysis_model_name): + import sklearn.discriminant_analysis + from sklearn.datasets import load_iris + from sklearn.model_selection import train_test_split + + X, y = load_iris(return_X_y=True) + x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0) + + kwargs = {} + clf = getattr(sklearn.discriminant_analysis, discriminant_analysis_model_name)(**kwargs) + + model = clf.fit(x_train, y_train) + if hasattr(model, "predict"): + model.predict(x_test) + if hasattr(model, "predict_log_proba"): + model.predict_log_proba(x_test) + if hasattr(model, "predict_proba"): + model.predict_proba(x_test) + if hasattr(model, "transform"): + model.transform(x_test) + + return model + + return _run diff --git a/tests/mlmodel_sklearn/test_dummy_models.py b/tests/mlmodel_sklearn/test_dummy_models.py new file mode 100644 index 000000000..d1059add1 --- /dev/null +++ b/tests/mlmodel_sklearn/test_dummy_models.py @@ -0,0 +1,94 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +from sklearn import __init__ # noqa: needed for get_package_version +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task +from newrelic.common.package_version_utils import get_package_version +from newrelic.packages import six + +SKLEARN_VERSION = tuple(map(int, get_package_version("sklearn").split("."))) + + +@pytest.mark.parametrize( + "dummy_model_name", + [ + "DummyClassifier", + "DummyRegressor", + ], +) +def test_model_methods_wrapped_in_function_trace(dummy_model_name, run_dummy_model): + expected_scoped_metrics = { + "DummyClassifier": [ + ("Function/MLModel/Sklearn/Named/DummyClassifier.fit", 1), + ("Function/MLModel/Sklearn/Named/DummyClassifier.predict", 2), + ("Function/MLModel/Sklearn/Named/DummyClassifier.predict_log_proba", 1), + ("Function/MLModel/Sklearn/Named/DummyClassifier.predict_proba", 2 if SKLEARN_VERSION > (1, 0, 0) else 4), + ("Function/MLModel/Sklearn/Named/DummyClassifier.score", 1), + ], + "DummyRegressor": [ + ("Function/MLModel/Sklearn/Named/DummyRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/DummyRegressor.predict", 2), + ("Function/MLModel/Sklearn/Named/DummyRegressor.score", 1), + ], + } + + expected_transaction_name = ( + "test_dummy_models:test_model_methods_wrapped_in_function_trace.._test" + if six.PY3 + else "test_dummy_models:_test" + ) + + @validate_transaction_metrics( + expected_transaction_name, + scoped_metrics=expected_scoped_metrics[dummy_model_name], + rollup_metrics=expected_scoped_metrics[dummy_model_name], + background_task=True, + ) + @background_task() + def _test(): + run_dummy_model(dummy_model_name) + + _test() + + +@pytest.fixture +def run_dummy_model(): + def _run(dummy_model_name): + import sklearn.dummy + from sklearn.datasets import load_iris + from sklearn.model_selection import train_test_split + + X, y = load_iris(return_X_y=True) + x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0) + + clf = getattr(sklearn.dummy, dummy_model_name)() + + model = clf.fit(x_train, y_train) + if hasattr(model, "predict"): + model.predict(x_test) + if hasattr(model, "score"): + model.score(x_test, y_test) + if hasattr(model, "predict_log_proba"): + model.predict_log_proba(x_test) + if hasattr(model, "predict_proba"): + model.predict_proba(x_test) + + return model + + return _run diff --git a/tests/mlmodel_sklearn/test_ensemble_models.py b/tests/mlmodel_sklearn/test_ensemble_models.py new file mode 100644 index 000000000..4093edf76 --- /dev/null +++ b/tests/mlmodel_sklearn/test_ensemble_models.py @@ -0,0 +1,303 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task +from newrelic.common.package_version_utils import get_package_version +from newrelic.packages import six + +SKLEARN_VERSION = tuple(map(int, get_package_version("sklearn").split("."))) + + +@pytest.mark.parametrize( + "ensemble_model_name", + [ + "AdaBoostClassifier", + "AdaBoostRegressor", + "BaggingClassifier", + "BaggingRegressor", + "ExtraTreesClassifier", + "ExtraTreesRegressor", + "GradientBoostingClassifier", + "GradientBoostingRegressor", + "IsolationForest", + "RandomForestClassifier", + "RandomForestRegressor", + "RandomTreesEmbedding", + "VotingClassifier", + ], +) +def test_below_v1_0_model_methods_wrapped_in_function_trace(ensemble_model_name, run_ensemble_model): + expected_scoped_metrics = { + "AdaBoostClassifier": [ + ("Function/MLModel/Sklearn/Named/AdaBoostClassifier.fit", 1), + ("Function/MLModel/Sklearn/Named/AdaBoostClassifier.predict", 2), + ("Function/MLModel/Sklearn/Named/AdaBoostClassifier.predict_log_proba", 1), + ("Function/MLModel/Sklearn/Named/AdaBoostClassifier.predict_proba", 2), + ("Function/MLModel/Sklearn/Named/AdaBoostClassifier.score", 1), + ], + "AdaBoostRegressor": [ + ("Function/MLModel/Sklearn/Named/AdaBoostRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/AdaBoostRegressor.predict", 2), + ("Function/MLModel/Sklearn/Named/AdaBoostRegressor.score", 1), + ], + "BaggingClassifier": [ + ("Function/MLModel/Sklearn/Named/BaggingClassifier.fit", 1), + ("Function/MLModel/Sklearn/Named/BaggingClassifier.predict", 2), + ("Function/MLModel/Sklearn/Named/BaggingClassifier.score", 1), + ("Function/MLModel/Sklearn/Named/BaggingClassifier.predict_log_proba", 1), + ("Function/MLModel/Sklearn/Named/BaggingClassifier.predict_proba", 3), + ], + "BaggingRegressor": [ + ("Function/MLModel/Sklearn/Named/BaggingRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/BaggingRegressor.predict", 2), + ("Function/MLModel/Sklearn/Named/BaggingRegressor.score", 1), + ], + "ExtraTreesClassifier": [ + ("Function/MLModel/Sklearn/Named/ExtraTreesClassifier.fit", 1), + ("Function/MLModel/Sklearn/Named/ExtraTreesClassifier.predict", 2), + ("Function/MLModel/Sklearn/Named/ExtraTreesClassifier.score", 1), + ("Function/MLModel/Sklearn/Named/ExtraTreesClassifier.predict_log_proba", 1), + ("Function/MLModel/Sklearn/Named/ExtraTreesClassifier.predict_proba", 4), + ], + "ExtraTreesRegressor": [ + ("Function/MLModel/Sklearn/Named/ExtraTreesRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/ExtraTreesRegressor.predict", 2), + ("Function/MLModel/Sklearn/Named/ExtraTreesRegressor.score", 1), + ], + "GradientBoostingClassifier": [ + ("Function/MLModel/Sklearn/Named/GradientBoostingClassifier.fit", 1), + ("Function/MLModel/Sklearn/Named/GradientBoostingClassifier.predict", 2), + ("Function/MLModel/Sklearn/Named/GradientBoostingClassifier.score", 1), + ("Function/MLModel/Sklearn/Named/GradientBoostingClassifier.predict_log_proba", 1), + ("Function/MLModel/Sklearn/Named/GradientBoostingClassifier.predict_proba", 2), + ], + "GradientBoostingRegressor": [ + ("Function/MLModel/Sklearn/Named/GradientBoostingRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/GradientBoostingRegressor.predict", 2), + ("Function/MLModel/Sklearn/Named/GradientBoostingRegressor.score", 1), + ], + "IsolationForest": [ + 
("Function/MLModel/Sklearn/Named/IsolationForest.fit", 1), + ("Function/MLModel/Sklearn/Named/IsolationForest.predict", 1), + ], + "RandomForestClassifier": [ + ("Function/MLModel/Sklearn/Named/RandomForestClassifier.fit", 1), + ("Function/MLModel/Sklearn/Named/RandomForestClassifier.predict", 2), + ("Function/MLModel/Sklearn/Named/RandomForestClassifier.score", 1), + ("Function/MLModel/Sklearn/Named/RandomForestClassifier.predict_log_proba", 1), + ("Function/MLModel/Sklearn/Named/RandomForestClassifier.predict_proba", 4), + ], + "RandomForestRegressor": [ + ("Function/MLModel/Sklearn/Named/RandomForestRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/RandomForestRegressor.predict", 2), + ("Function/MLModel/Sklearn/Named/RandomForestRegressor.score", 1), + ], + "RandomTreesEmbedding": [ + ("Function/MLModel/Sklearn/Named/RandomTreesEmbedding.fit", 1), + ("Function/MLModel/Sklearn/Named/RandomTreesEmbedding.transform", 1), + ], + "VotingClassifier": [ + ("Function/MLModel/Sklearn/Named/VotingClassifier.fit", 1), + ("Function/MLModel/Sklearn/Named/VotingClassifier.predict", 2), + ("Function/MLModel/Sklearn/Named/VotingClassifier.score", 1), + ("Function/MLModel/Sklearn/Named/VotingClassifier.transform", 1), + ("Function/MLModel/Sklearn/Named/VotingClassifier.predict_proba", 3), + ], + } + + expected_transaction_name = ( + "test_ensemble_models:test_below_v1_0_model_methods_wrapped_in_function_trace.._test" + if six.PY3 + else "test_ensemble_models:_test" + ) + + @validate_transaction_metrics( + expected_transaction_name, + scoped_metrics=expected_scoped_metrics[ensemble_model_name], + rollup_metrics=expected_scoped_metrics[ensemble_model_name], + background_task=True, + ) + @background_task() + def _test(): + run_ensemble_model(ensemble_model_name) + + _test() + + +@pytest.mark.skipif(SKLEARN_VERSION < (1, 0, 0) or SKLEARN_VERSION >= (1, 1, 0), reason="Requires 1.0 <= sklearn < 1.1") +@pytest.mark.parametrize( + "ensemble_model_name", + [ + "HistGradientBoostingClassifier", + "HistGradientBoostingRegressor", + "StackingClassifier", + "StackingRegressor", + "VotingRegressor", + ], +) +def test_between_v1_0_and_v1_1_model_methods_wrapped_in_function_trace(ensemble_model_name, run_ensemble_model): + expected_scoped_metrics = { + "HistGradientBoostingClassifier": [ + ("Function/MLModel/Sklearn/Named/HistGradientBoostingClassifier.fit", 1), + ("Function/MLModel/Sklearn/Named/HistGradientBoostingClassifier.predict", 2), + ("Function/MLModel/Sklearn/Named/HistGradientBoostingClassifier.score", 1), + ("Function/MLModel/Sklearn/Named/HistGradientBoostingClassifier.predict_proba", 3), + ], + "HistGradientBoostingRegressor": [ + ("Function/MLModel/Sklearn/Named/HistGradientBoostingRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/HistGradientBoostingRegressor.predict", 2), + ("Function/MLModel/Sklearn/Named/HistGradientBoostingRegressor.score", 1), + ], + "StackingClassifier": [ + ("Function/MLModel/Sklearn/Named/StackingClassifier.fit", 1), + ], + "StackingRegressor": [ + ("Function/MLModel/Sklearn/Named/StackingRegressor.fit", 1), + ], + "VotingRegressor": [ + ("Function/MLModel/Sklearn/Named/VotingRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/VotingRegressor.predict", 2), + ("Function/MLModel/Sklearn/Named/VotingRegressor.score", 1), + ("Function/MLModel/Sklearn/Named/VotingRegressor.transform", 1), + ], + } + expected_transaction_name = ( + "test_ensemble_models:test_between_v1_0_and_v1_1_model_methods_wrapped_in_function_trace.._test" + if six.PY3 + else 
"test_ensemble_models:_test" + ) + + @validate_transaction_metrics( + expected_transaction_name, + scoped_metrics=expected_scoped_metrics[ensemble_model_name], + rollup_metrics=expected_scoped_metrics[ensemble_model_name], + background_task=True, + ) + @background_task() + def _test(): + run_ensemble_model(ensemble_model_name) + + _test() + + + +@pytest.mark.skipif(SKLEARN_VERSION < (1, 1, 0), reason="Requires sklearn >= 1.1") +@pytest.mark.parametrize( + "ensemble_model_name", + [ + "HistGradientBoostingClassifier", + "HistGradientBoostingRegressor", + "StackingClassifier", + "StackingRegressor", + "VotingRegressor", + ], +) +def test_above_v1_1_model_methods_wrapped_in_function_trace(ensemble_model_name, run_ensemble_model): + expected_scoped_metrics = { + "StackingClassifier": [ + ("Function/MLModel/Sklearn/Named/StackingClassifier.fit", 1), + ("Function/MLModel/Sklearn/Named/StackingClassifier.predict", 2), + ("Function/MLModel/Sklearn/Named/StackingClassifier.score", 1), + ("Function/MLModel/Sklearn/Named/StackingClassifier.predict_proba", 1), + ("Function/MLModel/Sklearn/Named/StackingClassifier.transform", 4), + ], + "StackingRegressor": [ + ("Function/MLModel/Sklearn/Named/StackingRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/StackingRegressor.predict", 2), + ("Function/MLModel/Sklearn/Named/StackingRegressor.score", 1), + ], + "VotingRegressor": [ + ("Function/MLModel/Sklearn/Named/VotingRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/VotingRegressor.predict", 2), + ("Function/MLModel/Sklearn/Named/VotingRegressor.score", 1), + ("Function/MLModel/Sklearn/Named/VotingRegressor.transform", 1), + ], + "HistGradientBoostingClassifier": [ + ("Function/MLModel/Sklearn/Named/HistGradientBoostingClassifier.fit", 1), + ("Function/MLModel/Sklearn/Named/HistGradientBoostingClassifier.predict", 2), + ("Function/MLModel/Sklearn/Named/HistGradientBoostingClassifier.score", 1), + ("Function/MLModel/Sklearn/Named/HistGradientBoostingClassifier.predict_proba", 3), + ], + "HistGradientBoostingRegressor": [ + ("Function/MLModel/Sklearn/Named/HistGradientBoostingRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/HistGradientBoostingRegressor.predict", 2), + ("Function/MLModel/Sklearn/Named/HistGradientBoostingRegressor.score", 1), + ], + } + expected_transaction_name = ( + "test_ensemble_models:test_above_v1_1_model_methods_wrapped_in_function_trace.._test" + if six.PY3 + else "test_ensemble_models:_test" + ) + + @validate_transaction_metrics( + expected_transaction_name, + scoped_metrics=expected_scoped_metrics[ensemble_model_name], + rollup_metrics=expected_scoped_metrics[ensemble_model_name], + background_task=True, + ) + @background_task() + def _test(): + run_ensemble_model(ensemble_model_name) + + _test() + + +@pytest.fixture +def run_ensemble_model(): + def _run(ensemble_model_name): + import sklearn.ensemble + from sklearn.datasets import load_iris + from sklearn.model_selection import train_test_split + + X, y = load_iris(return_X_y=True) + x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0) + + kwargs = {"random_state": 0} + if ensemble_model_name == "StackingClassifier": + kwargs = {"estimators": [("rf", RandomForestClassifier())], "final_estimator": RandomForestClassifier()} + elif ensemble_model_name == "VotingClassifier": + kwargs = { + "estimators": [("rf", RandomForestClassifier())], + "voting": "soft", + } + elif ensemble_model_name == "VotingRegressor": + x_train = x_test = [[1, 1]] + y_train = y_test = [0] + kwargs = 
{"estimators": [("rf", RandomForestRegressor())]} + elif ensemble_model_name == "StackingRegressor": + kwargs = {"estimators": [("rf", RandomForestRegressor())]} + clf = getattr(sklearn.ensemble, ensemble_model_name)(**kwargs) + + model = clf.fit(x_train, y_train) + if hasattr(model, "predict"): + model.predict(x_test) + if hasattr(model, "score"): + model.score(x_test, y_test) + if hasattr(model, "predict_log_proba"): + model.predict_log_proba(x_test) + if hasattr(model, "predict_proba"): + model.predict_proba(x_test) + if hasattr(model, "transform"): + model.transform(x_test) + + return model + + return _run diff --git a/tests/mlmodel_sklearn/test_feature_selection_models.py b/tests/mlmodel_sklearn/test_feature_selection_models.py new file mode 100644 index 000000000..f4d601d32 --- /dev/null +++ b/tests/mlmodel_sklearn/test_feature_selection_models.py @@ -0,0 +1,138 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from sklearn.ensemble import AdaBoostClassifier +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task +from newrelic.common.package_version_utils import get_package_version +from newrelic.packages import six + +SKLEARN_VERSION = tuple(map(int, get_package_version("sklearn").split("."))) + + +@pytest.mark.parametrize( + "feature_selection_model_name", + [ + "VarianceThreshold", + "RFE", + "RFECV", + "SelectFromModel", + ], +) +def test_below_v1_0_model_methods_wrapped_in_function_trace(feature_selection_model_name, run_feature_selection_model): + expected_scoped_metrics = { + "VarianceThreshold": [ + ("Function/MLModel/Sklearn/Named/VarianceThreshold.fit", 1), + ], + "RFE": [ + ("Function/MLModel/Sklearn/Named/RFE.fit", 1), + ("Function/MLModel/Sklearn/Named/RFE.predict", 1), + ("Function/MLModel/Sklearn/Named/RFE.score", 1), + ("Function/MLModel/Sklearn/Named/RFE.predict_log_proba", 1), + ("Function/MLModel/Sklearn/Named/RFE.predict_proba", 1), + ], + "RFECV": [ + ("Function/MLModel/Sklearn/Named/RFECV.fit", 1), + ], + "SelectFromModel": [ + ("Function/MLModel/Sklearn/Named/SelectFromModel.fit", 1), + ], + } + + expected_transaction_name = ( + "test_feature_selection_models:test_below_v1_0_model_methods_wrapped_in_function_trace.._test" + if six.PY3 + else "test_feature_selection_models:_test" + ) + + @validate_transaction_metrics( + expected_transaction_name, + scoped_metrics=expected_scoped_metrics[feature_selection_model_name], + rollup_metrics=expected_scoped_metrics[feature_selection_model_name], + background_task=True, + ) + @background_task() + def _test(): + run_feature_selection_model(feature_selection_model_name) + + _test() + + +@pytest.mark.skipif(SKLEARN_VERSION < (1, 0, 0), reason="Requires sklearn >= 1.0") +@pytest.mark.parametrize( + "feature_selection_model_name", + [ + "SequentialFeatureSelector", + ], +) +def test_above_v1_0_model_methods_wrapped_in_function_trace(feature_selection_model_name, 
run_feature_selection_model):
+    expected_scoped_metrics = {
+        "SequentialFeatureSelector": [
+            ("Function/MLModel/Sklearn/Named/SequentialFeatureSelector.fit", 1),
+        ],
+    }
+    expected_transaction_name = (
+        "test_feature_selection_models:test_above_v1_0_model_methods_wrapped_in_function_trace.<locals>._test"
+        if six.PY3
+        else "test_feature_selection_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[feature_selection_model_name],
+        rollup_metrics=expected_scoped_metrics[feature_selection_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_feature_selection_model(feature_selection_model_name)
+
+    _test()
+
+
+@pytest.fixture
+def run_feature_selection_model():
+    def _run(feature_selection_model_name):
+        import sklearn.feature_selection
+        from sklearn.datasets import load_iris
+        from sklearn.model_selection import train_test_split
+
+        X, y = load_iris(return_X_y=True)
+        x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
+
+        kwargs = {}
+        if feature_selection_model_name in ["RFE", "SequentialFeatureSelector", "SelectFromModel", "RFECV"]:
+            # This is an example of an estimator that has all the available attributes.
+            # We could have chosen any estimator that has predict, score,
+            # predict_log_proba, and predict_proba.
+            kwargs = {"estimator": AdaBoostClassifier()}
+        clf = getattr(sklearn.feature_selection, feature_selection_model_name)(**kwargs)
+
+        model = clf.fit(x_train, y_train)
+        if hasattr(model, "predict"):
+            model.predict(x_test)
+        if hasattr(model, "score"):
+            model.score(x_test, y_test)
+        if hasattr(model, "predict_log_proba"):
+            model.predict_log_proba(x_test)
+        if hasattr(model, "predict_proba"):
+            model.predict_proba(x_test)
+
+        return model
+
+    return _run
diff --git a/tests/mlmodel_sklearn/test_gaussian_process_models.py b/tests/mlmodel_sklearn/test_gaussian_process_models.py
new file mode 100644
index 000000000..7a78fc703
--- /dev/null
+++ b/tests/mlmodel_sklearn/test_gaussian_process_models.py
@@ -0,0 +1,83 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import pytest +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task +from newrelic.packages import six + + +@pytest.mark.parametrize( + "gaussian_process_model_name", + [ + "GaussianProcessClassifier", + "GaussianProcessRegressor", + ], +) +def test_model_methods_wrapped_in_function_trace(gaussian_process_model_name, run_gaussian_process_model): + expected_scoped_metrics = { + "GaussianProcessClassifier": [ + ("Function/MLModel/Sklearn/Named/GaussianProcessClassifier.fit", 1), + ("Function/MLModel/Sklearn/Named/GaussianProcessClassifier.predict", 1), + ("Function/MLModel/Sklearn/Named/GaussianProcessClassifier.predict_proba", 1), + ], + "GaussianProcessRegressor": [ + ("Function/MLModel/Sklearn/Named/GaussianProcessRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/GaussianProcessRegressor.predict", 1), + ], + } + + expected_transaction_name = ( + "test_gaussian_process_models:test_model_methods_wrapped_in_function_trace.._test" + if six.PY3 + else "test_gaussian_process_models:_test" + ) + + @validate_transaction_metrics( + expected_transaction_name, + scoped_metrics=expected_scoped_metrics[gaussian_process_model_name], + rollup_metrics=expected_scoped_metrics[gaussian_process_model_name], + background_task=True, + ) + @background_task() + def _test(): + run_gaussian_process_model(gaussian_process_model_name) + + _test() + + +@pytest.fixture +def run_gaussian_process_model(): + def _run(gaussian_process_model_name): + import sklearn.gaussian_process + from sklearn.datasets import load_iris + from sklearn.model_selection import train_test_split + + X, y = load_iris(return_X_y=True) + x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0) + + clf = getattr(sklearn.gaussian_process, gaussian_process_model_name)(random_state=0) + + model = clf.fit(x_train, y_train) + if hasattr(model, "predict"): + model.predict(x_test) + if hasattr(model, "predict_proba"): + model.predict_proba(x_test) + + return model + + return _run diff --git a/tests/mlmodel_sklearn/test_inference_events.py b/tests/mlmodel_sklearn/test_inference_events.py new file mode 100644 index 000000000..0a3677019 --- /dev/null +++ b/tests/mlmodel_sklearn/test_inference_events.py @@ -0,0 +1,429 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
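+# Each InferenceData event below represents a single prediction row:
+# feature.<column> attributes hold the inputs, label.<n> holds the
+# stringified prediction, and new_relic_data_schema_version pins the
+# event schema (currently 2).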
+
+import sys
+
+import numpy as np
+import pandas
+from testing_support.fixtures import (
+    override_application_settings,
+    reset_core_stats_engine,
+)
+from testing_support.validators.validate_ml_event_count import validate_ml_event_count
+from testing_support.validators.validate_ml_events import validate_ml_events
+
+from newrelic.api.background_task import background_task
+
+pandas_df_category_recorded_custom_events = [
+    (
+        {"type": "InferenceData"},
+        {
+            "inference_id": None,
+            "prediction_id": None,
+            "modelName": "DecisionTreeClassifier",
+            "model_version": "0.0.0",
+            "feature.col1": 2.0,
+            "feature.col2": 4.0,
+            "label.0": "27.0",
+            "new_relic_data_schema_version": 2,
+        },
+    ),
+]
+
+
+@reset_core_stats_engine()
+def test_pandas_df_categorical_feature_event():
+    @validate_ml_events(pandas_df_category_recorded_custom_events)
+    @validate_ml_event_count(count=1)
+    @background_task()
+    def _test():
+        import sklearn.tree
+
+        clf = getattr(sklearn.tree, "DecisionTreeClassifier")(random_state=0)
+        model = clf.fit(
+            pandas.DataFrame({"col1": [27.0, 24.0], "col2": [23.0, 25.0]}, dtype="category"),
+            pandas.DataFrame({"label": [27.0, 28.0]}),
+        )
+
+        labels = model.predict(pandas.DataFrame({"col1": [2.0], "col2": [4.0]}, dtype="category"))
+        return model
+
+    _test()
+
+
+label_type = "bool" if sys.version_info < (3, 8) else "numeric"
+true_label_value = "True" if sys.version_info < (3, 8) else "1.0"
+false_label_value = "False" if sys.version_info < (3, 8) else "0.0"
+pandas_df_bool_recorded_custom_events = [
+    (
+        {"type": "InferenceData"},
+        {
+            "inference_id": None,
+            "prediction_id": None,
+            "modelName": "DecisionTreeClassifier",
+            "model_version": "0.0.0",
+            "feature.col1": True,
+            "feature.col2": True,
+            "label.0": true_label_value,
+            "new_relic_data_schema_version": 2,
+        },
+    ),
+]
+
+
+@reset_core_stats_engine()
+def test_pandas_df_bool_feature_event():
+    @validate_ml_events(pandas_df_bool_recorded_custom_events)
+    @validate_ml_event_count(count=1)
+    @background_task()
+    def _test():
+        import sklearn.tree
+
+        dtype_name = "bool" if sys.version_info < (3, 8) else "boolean"
+        x_train = pandas.DataFrame({"col1": [True, False], "col2": [True, False]}, dtype=dtype_name)
+        y_train = pandas.DataFrame({"label": [True, False]}, dtype=dtype_name)
+        x_test = pandas.DataFrame({"col1": [True], "col2": [True]}, dtype=dtype_name)
+
+        clf = getattr(sklearn.tree, "DecisionTreeClassifier")(random_state=0)
+        model = clf.fit(x_train, y_train)
+
+        labels = model.predict(x_test)
+        return model
+
+    _test()
+
+
+pandas_df_float_recorded_custom_events = [
+    (
+        {"type": "InferenceData"},
+        {
+            "inference_id": None,
+            "prediction_id": None,
+            "modelName": "DecisionTreeRegressor",
+            "model_version": "0.0.0",
+            "feature.col1": 100.0,
+            "feature.col2": 300.0,
+            "label.0": "345.6",
+            "new_relic_data_schema_version": 2,
+        },
+    ),
+]
+
+
+@reset_core_stats_engine()
+def test_pandas_df_float_feature_event():
+    @validate_ml_events(pandas_df_float_recorded_custom_events)
+    @validate_ml_event_count(count=1)
+    @background_task()
+    def _test():
+        import sklearn.tree
+
+        x_train = pandas.DataFrame({"col1": [120.0, 254.0], "col2": [236.9, 234.5]}, dtype="float64")
+        y_train = pandas.DataFrame({"label": [345.6, 456.7]}, dtype="float64")
+        x_test = pandas.DataFrame({"col1": [100.0], "col2": [300.0]}, dtype="float64")
+
+        clf = getattr(sklearn.tree, "DecisionTreeRegressor")(random_state=0)
+
+        model = clf.fit(x_train, y_train)
+        labels = model.predict(x_test)
+
+        return model
+
+    _test()
+
+
+int_list_recorded_custom_events = [
+    (
+        {"type": "InferenceData"},
+        {
+            "inference_id": None,
+            "prediction_id": None,
+            "modelName": "ExtraTreeRegressor",
+            "model_version": "0.0.0",
+            "feature.0": 1,
+            "feature.1": 2,
+            "label.0": "1.0",
+            "new_relic_data_schema_version": 2,
+        },
+    ),
+]
+
+
+@reset_core_stats_engine()
+def test_int_list():
+    @validate_ml_events(int_list_recorded_custom_events)
+    @validate_ml_event_count(count=1)
+    @background_task()
+    def _test():
+        import sklearn.tree
+
+        x_train = [[0, 0], [1, 1]]
+        y_train = [0, 1]
+        x_test = [[1, 2]]
+
+        clf = getattr(sklearn.tree, "ExtraTreeRegressor")(random_state=0)
+        model = clf.fit(x_train, y_train)
+
+        labels = model.predict(x_test)
+        return model
+
+    _test()
+
+
+numpy_int_recorded_custom_events = [
+    (
+        {"type": "InferenceData"},
+        {
+            "inference_id": None,
+            "prediction_id": None,
+            "modelName": "ExtraTreeRegressor",
+            "model_version": "0.0.0",
+            "feature.0": 12,
+            "feature.1": 13,
+            "label.0": "11.0",
+            "new_relic_data_schema_version": 2,
+        },
+    ),
+]
+
+
+@reset_core_stats_engine()
+def test_numpy_int_array():
+    @validate_ml_events(numpy_int_recorded_custom_events)
+    @validate_ml_event_count(count=1)
+    @background_task()
+    def _test():
+        import sklearn.tree
+
+        x_train = np.array([[10, 10], [11, 11]], dtype="int")
+        y_train = np.array([10, 11], dtype="int")
+        x_test = np.array([[12, 13]], dtype="int")
+
+        clf = getattr(sklearn.tree, "ExtraTreeRegressor")(random_state=0)
+        model = clf.fit(x_train, y_train)
+
+        labels = model.predict(x_test)
+        return model
+
+    _test()
+
+
+numpy_str_recorded_custom_events = [
+    (
+        {"type": "InferenceData"},
+        {
+            "inference_id": None,
+            "prediction_id": None,
+            "modelName": "DecisionTreeClassifier",
+            "model_version": "0.0.0",
+            "feature.0": "20",
+            "feature.1": "21",
+            "label.0": "21",
+            "new_relic_data_schema_version": 2,
+        },
+    ),
+    (
+        {"type": "InferenceData"},
+        {
+            "inference_id": None,
+            "prediction_id": None,
+            "modelName": "DecisionTreeClassifier",
+            "model_version": "0.0.0",
+            "feature.0": "22",
+            "feature.1": "23",
+            "label.0": "21",
+            "new_relic_data_schema_version": 2,
+        },
+    ),
+]
+
+
+@reset_core_stats_engine()
+def test_numpy_str_array_multiple_features():
+    @validate_ml_events(numpy_str_recorded_custom_events)
+    @validate_ml_event_count(count=2)
+    @background_task()
+    def _test():
+        import sklearn.tree
+
+        # Numeric strings; sklearn coerces the features to floats when fitting.
+        x_train = np.array([[20, 20], [21, 21]], dtype="<U4")
+        y_train = np.array([20, 21], dtype="<U4")
+        x_test = np.array([[20, 21], [22, 23]], dtype="<U4")
+
+        clf = getattr(sklearn.tree, "DecisionTreeClassifier")(random_state=0)
+        model = clf.fit(x_train, y_train)
+
+        labels = model.predict(x_test)
+        return model
+
+    _test()
+
[... lost in extraction: remainder of tests/mlmodel_sklearn/test_inference_events.py and the head of tests/mlmodel_sklearn/test_kernel_ridge_models.py, including its expected_scoped_metrics ...]
+    expected_transaction_name = (
+        "test_kernel_ridge_models:test_model_methods_wrapped_in_function_trace.<locals>._test"
+        if six.PY3
+        else "test_kernel_ridge_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[kernel_ridge_model_name],
+        rollup_metrics=expected_scoped_metrics[kernel_ridge_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_kernel_ridge_model(kernel_ridge_model_name)
+
+    _test()
+
+
+@pytest.fixture
+def run_kernel_ridge_model():
+    def _run(kernel_ridge_model_name):
+        import sklearn.kernel_ridge
+        from sklearn.datasets import load_iris
+        from sklearn.model_selection import train_test_split
+
+        X, y = load_iris(return_X_y=True)
+        x_train, x_test, y_train, _ = train_test_split(X, y, stratify=y, random_state=0)
+
+        clf = getattr(sklearn.kernel_ridge, kernel_ridge_model_name)()
+
+        model = clf.fit(x_train, y_train)
+        model.predict(x_test)
+
+        return model
+
+    return _run
diff --git a/tests/mlmodel_sklearn/test_linear_models.py b/tests/mlmodel_sklearn/test_linear_models.py
new file mode 100644
index 000000000..582a4750e
--- /dev/null
+++ 
b/tests/mlmodel_sklearn/test_linear_models.py @@ -0,0 +1,335 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task +from newrelic.common.package_version_utils import get_package_version +from newrelic.packages import six + +SKLEARN_VERSION = tuple(map(int, get_package_version("sklearn").split("."))) + + +@pytest.mark.parametrize( + "linear_model_name", + [ + "ARDRegression", + "BayesianRidge", + "ElasticNet", + "ElasticNetCV", + "HuberRegressor", + "Lars", + "LarsCV", + "Lasso", + "LassoCV", + "LassoLars", + "LassoLarsCV", + "LassoLarsIC", + "LinearRegression", + "LogisticRegression", + "LogisticRegressionCV", + "MultiTaskElasticNet", + "MultiTaskElasticNetCV", + "MultiTaskLasso", + "MultiTaskLassoCV", + "OrthogonalMatchingPursuit", + "OrthogonalMatchingPursuitCV", + "PassiveAggressiveClassifier", + "PassiveAggressiveRegressor", + "Perceptron", + "Ridge", + "RidgeCV", + "RidgeClassifier", + "RidgeClassifierCV", + "TheilSenRegressor", + "RANSACRegressor", + ], +) +def test_model_methods_wrapped_in_function_trace(linear_model_name, run_linear_model): + expected_scoped_metrics = { + "ARDRegression": [ + ("Function/MLModel/Sklearn/Named/ARDRegression.fit", 1), + ("Function/MLModel/Sklearn/Named/ARDRegression.predict", 2), + ("Function/MLModel/Sklearn/Named/ARDRegression.score", 1), + ], + "BayesianRidge": [ + ("Function/MLModel/Sklearn/Named/BayesianRidge.fit", 1), + ("Function/MLModel/Sklearn/Named/BayesianRidge.predict", 2), + ("Function/MLModel/Sklearn/Named/BayesianRidge.score", 1), + ], + "ElasticNet": [ + ("Function/MLModel/Sklearn/Named/ElasticNet.fit", 1), + ("Function/MLModel/Sklearn/Named/ElasticNet.predict", 2), + ("Function/MLModel/Sklearn/Named/ElasticNet.score", 1), + ], + "ElasticNetCV": [ + ("Function/MLModel/Sklearn/Named/ElasticNetCV.fit", 1), + ("Function/MLModel/Sklearn/Named/ElasticNetCV.predict", 2), + ("Function/MLModel/Sklearn/Named/ElasticNetCV.score", 1), + ], + "HuberRegressor": [ + ("Function/MLModel/Sklearn/Named/HuberRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/HuberRegressor.predict", 2), + ("Function/MLModel/Sklearn/Named/HuberRegressor.score", 1), + ], + "Lars": [ + ("Function/MLModel/Sklearn/Named/Lars.fit", 1), + ("Function/MLModel/Sklearn/Named/Lars.predict", 2), + ("Function/MLModel/Sklearn/Named/Lars.score", 1), + ], + "LarsCV": [ + ("Function/MLModel/Sklearn/Named/LarsCV.fit", 1), + ("Function/MLModel/Sklearn/Named/LarsCV.predict", 2), + ("Function/MLModel/Sklearn/Named/LarsCV.score", 1), + ], + "Lasso": [ + ("Function/MLModel/Sklearn/Named/Lasso.fit", 1), + ("Function/MLModel/Sklearn/Named/Lasso.predict", 2), + ("Function/MLModel/Sklearn/Named/Lasso.score", 1), + ], + "LassoCV": [ + ("Function/MLModel/Sklearn/Named/LassoCV.fit", 1), + ("Function/MLModel/Sklearn/Named/LassoCV.predict", 2), + ("Function/MLModel/Sklearn/Named/LassoCV.score", 
1), + ], + "LassoLars": [ + ("Function/MLModel/Sklearn/Named/LassoLars.fit", 1), + ("Function/MLModel/Sklearn/Named/LassoLars.predict", 2), + ("Function/MLModel/Sklearn/Named/LassoLars.score", 1), + ], + "LassoLarsCV": [ + ("Function/MLModel/Sklearn/Named/LassoLarsCV.fit", 1), + ("Function/MLModel/Sklearn/Named/LassoLarsCV.predict", 2), + ("Function/MLModel/Sklearn/Named/LassoLarsCV.score", 1), + ], + "LassoLarsIC": [ + ("Function/MLModel/Sklearn/Named/LassoLarsIC.fit", 1), + ("Function/MLModel/Sklearn/Named/LassoLarsIC.predict", 2), + ("Function/MLModel/Sklearn/Named/LassoLarsIC.score", 1), + ], + "LinearRegression": [ + ("Function/MLModel/Sklearn/Named/LinearRegression.fit", 1), + ("Function/MLModel/Sklearn/Named/LinearRegression.predict", 2), + ("Function/MLModel/Sklearn/Named/LinearRegression.score", 1), + ], + "LogisticRegression": [ + ("Function/MLModel/Sklearn/Named/LogisticRegression.fit", 1), + ("Function/MLModel/Sklearn/Named/LogisticRegression.predict", 2), + ("Function/MLModel/Sklearn/Named/LogisticRegression.score", 1), + ], + "LogisticRegressionCV": [ + ("Function/MLModel/Sklearn/Named/LogisticRegressionCV.fit", 1), + ("Function/MLModel/Sklearn/Named/LogisticRegressionCV.predict", 2), + ("Function/MLModel/Sklearn/Named/LogisticRegressionCV.score", 1), + ], + "MultiTaskElasticNet": [ + ("Function/MLModel/Sklearn/Named/MultiTaskElasticNet.fit", 1), + ("Function/MLModel/Sklearn/Named/MultiTaskElasticNet.predict", 2), + ("Function/MLModel/Sklearn/Named/MultiTaskElasticNet.score", 1), + ], + "MultiTaskElasticNetCV": [ + ("Function/MLModel/Sklearn/Named/MultiTaskElasticNetCV.fit", 1), + ("Function/MLModel/Sklearn/Named/MultiTaskElasticNetCV.predict", 2), + ("Function/MLModel/Sklearn/Named/MultiTaskElasticNetCV.score", 1), + ], + "MultiTaskLasso": [ + ("Function/MLModel/Sklearn/Named/MultiTaskLasso.fit", 1), + ("Function/MLModel/Sklearn/Named/MultiTaskLasso.predict", 2), + ("Function/MLModel/Sklearn/Named/MultiTaskLasso.score", 1), + ], + "MultiTaskLassoCV": [ + ("Function/MLModel/Sklearn/Named/MultiTaskLassoCV.fit", 1), + ("Function/MLModel/Sklearn/Named/MultiTaskLassoCV.predict", 2), + ("Function/MLModel/Sklearn/Named/MultiTaskLassoCV.score", 1), + ], + "OrthogonalMatchingPursuit": [ + ("Function/MLModel/Sklearn/Named/OrthogonalMatchingPursuit.fit", 1), + ("Function/MLModel/Sklearn/Named/OrthogonalMatchingPursuit.predict", 2), + ("Function/MLModel/Sklearn/Named/OrthogonalMatchingPursuit.score", 1), + ], + "OrthogonalMatchingPursuitCV": [ + ("Function/MLModel/Sklearn/Named/OrthogonalMatchingPursuitCV.fit", 1), + ("Function/MLModel/Sklearn/Named/OrthogonalMatchingPursuitCV.predict", 2), + ("Function/MLModel/Sklearn/Named/OrthogonalMatchingPursuitCV.score", 1), + ], + "PassiveAggressiveClassifier": [ + ("Function/MLModel/Sklearn/Named/PassiveAggressiveClassifier.fit", 1), + ("Function/MLModel/Sklearn/Named/PassiveAggressiveClassifier.predict", 2), + ("Function/MLModel/Sklearn/Named/PassiveAggressiveClassifier.score", 1), + ], + "PassiveAggressiveRegressor": [ + ("Function/MLModel/Sklearn/Named/PassiveAggressiveRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/PassiveAggressiveRegressor.predict", 2), + ("Function/MLModel/Sklearn/Named/PassiveAggressiveRegressor.score", 1), + ], + "Perceptron": [ + ("Function/MLModel/Sklearn/Named/Perceptron.fit", 1), + ("Function/MLModel/Sklearn/Named/Perceptron.predict", 2), + ("Function/MLModel/Sklearn/Named/Perceptron.score", 1), + ], + "Ridge": [ + ("Function/MLModel/Sklearn/Named/Ridge.fit", 1), + 
("Function/MLModel/Sklearn/Named/Ridge.predict", 2), + ("Function/MLModel/Sklearn/Named/Ridge.score", 1), + ], + "RidgeCV": [ + ("Function/MLModel/Sklearn/Named/RidgeCV.fit", 1), + ("Function/MLModel/Sklearn/Named/RidgeCV.predict", 2), + ("Function/MLModel/Sklearn/Named/RidgeCV.score", 1), + ], + "RidgeClassifier": [ + ("Function/MLModel/Sklearn/Named/RidgeClassifier.fit", 1), + ("Function/MLModel/Sklearn/Named/RidgeClassifier.predict", 2), + ("Function/MLModel/Sklearn/Named/RidgeClassifier.score", 1), + ], + "RidgeClassifierCV": [ + ("Function/MLModel/Sklearn/Named/RidgeClassifierCV.fit", 1), + ("Function/MLModel/Sklearn/Named/RidgeClassifierCV.predict", 2), + ("Function/MLModel/Sklearn/Named/RidgeClassifierCV.score", 1), + ], + "TheilSenRegressor": [ + ("Function/MLModel/Sklearn/Named/TheilSenRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/TheilSenRegressor.predict", 2), + ("Function/MLModel/Sklearn/Named/TheilSenRegressor.score", 1), + ], + "RANSACRegressor": [ + ("Function/MLModel/Sklearn/Named/RANSACRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/RANSACRegressor.predict", 1), + ("Function/MLModel/Sklearn/Named/RANSACRegressor.score", 1), + ], + } + expected_transaction_name = "test_linear_models:_test" + if six.PY3: + expected_transaction_name = "test_linear_models:test_model_methods_wrapped_in_function_trace.._test" + + @validate_transaction_metrics( + expected_transaction_name, + scoped_metrics=expected_scoped_metrics[linear_model_name], + rollup_metrics=expected_scoped_metrics[linear_model_name], + background_task=True, + ) + @background_task() + def _test(): + run_linear_model(linear_model_name) + + _test() + + +@pytest.mark.skipif(SKLEARN_VERSION < (1, 1, 0), reason="Requires sklearn >= v1.1") +@pytest.mark.parametrize( + "linear_model_name", + [ + "PoissonRegressor", + "GammaRegressor", + "TweedieRegressor", + "QuantileRegressor", + "SGDClassifier", + "SGDRegressor", + "SGDOneClassSVM", + ], +) +def test_above_v1_1_model_methods_wrapped_in_function_trace(linear_model_name, run_linear_model): + expected_scoped_metrics = { + "PoissonRegressor": [ + ("Function/MLModel/Sklearn/Named/PoissonRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/PoissonRegressor.predict", 1), + ("Function/MLModel/Sklearn/Named/PoissonRegressor.score", 1), + ], + "GammaRegressor": [ + ("Function/MLModel/Sklearn/Named/GammaRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/GammaRegressor.predict", 1), + ("Function/MLModel/Sklearn/Named/GammaRegressor.score", 1), + ], + "TweedieRegressor": [ + ("Function/MLModel/Sklearn/Named/TweedieRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/TweedieRegressor.predict", 1), + ("Function/MLModel/Sklearn/Named/TweedieRegressor.score", 1), + ], + "QuantileRegressor": [ + ("Function/MLModel/Sklearn/Named/QuantileRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/QuantileRegressor.predict", 2), + ("Function/MLModel/Sklearn/Named/QuantileRegressor.score", 1), + ], + "SGDClassifier": [ + ("Function/MLModel/Sklearn/Named/SGDClassifier.fit", 1), + ("Function/MLModel/Sklearn/Named/SGDClassifier.predict", 2), + ("Function/MLModel/Sklearn/Named/SGDClassifier.score", 1), + ], + "SGDRegressor": [ + ("Function/MLModel/Sklearn/Named/SGDRegressor.fit", 1), + ("Function/MLModel/Sklearn/Named/SGDRegressor.predict", 2), + ("Function/MLModel/Sklearn/Named/SGDRegressor.score", 1), + ], + "SGDOneClassSVM": [ + ("Function/MLModel/Sklearn/Named/SGDOneClassSVM.fit", 1), + ("Function/MLModel/Sklearn/Named/SGDOneClassSVM.predict", 1), + ], + } + 
expected_transaction_name = "test_linear_models:_test" + if six.PY3: + expected_transaction_name = ( + "test_linear_models:test_above_v1_1_model_methods_wrapped_in_function_trace.._test" + ) + + @validate_transaction_metrics( + expected_transaction_name, + scoped_metrics=expected_scoped_metrics[linear_model_name], + rollup_metrics=expected_scoped_metrics[linear_model_name], + background_task=True, + ) + @background_task() + def _test(): + run_linear_model(linear_model_name) + + _test() + + +@pytest.fixture +def run_linear_model(): + def _run(linear_model_name): + import sklearn.linear_model + from sklearn.datasets import load_iris + from sklearn.model_selection import train_test_split + + X, y = load_iris(return_X_y=True) + x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0) + + if linear_model_name == "GammaRegressor": + x_train = [[1, 2], [2, 3], [3, 4], [4, 3]] + y_train = [19, 26, 33, 30] + x_test = [[1, 2], [2, 3], [3, 4], [4, 3]] + y_test = [19, 26, 33, 30] + elif linear_model_name in [ + "MultiTaskElasticNet", + "MultiTaskElasticNetCV", + "MultiTaskLasso", + "MultiTaskLassoCV", + ]: + y_train = x_train + y_test = x_test + + clf = getattr(sklearn.linear_model, linear_model_name)() + + model = clf.fit(x_train, y_train) + model.predict(x_test) + + if hasattr(model, "score"): + model.score(x_test, y_test) + + return model + + return _run diff --git a/tests/mlmodel_sklearn/test_metric_scorers.py b/tests/mlmodel_sklearn/test_metric_scorers.py new file mode 100644 index 000000000..50557b882 --- /dev/null +++ b/tests/mlmodel_sklearn/test_metric_scorers.py @@ -0,0 +1,150 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import pytest +from testing_support.fixtures import validate_attributes + +from newrelic.api.background_task import background_task +from newrelic.hooks.mlmodel_sklearn import PredictReturnTypeProxy + + +@pytest.mark.parametrize( + "metric_scorer_name", + ( + "accuracy_score", + "balanced_accuracy_score", + "f1_score", + "precision_score", + "recall_score", + "roc_auc_score", + "r2_score", + ), +) +def test_metric_scorer_attributes(metric_scorer_name, run_metric_scorer): + @validate_attributes("agent", ["DecisionTreeClassifier/TrainingStep/0/%s" % metric_scorer_name]) + @background_task() + def _test(): + run_metric_scorer(metric_scorer_name) + + _test() + + +@pytest.mark.parametrize( + "metric_scorer_name", + ( + "accuracy_score", + "balanced_accuracy_score", + "f1_score", + "precision_score", + "recall_score", + "roc_auc_score", + "r2_score", + ), +) +def test_metric_scorer_training_steps_attributes(metric_scorer_name, run_metric_scorer): + @validate_attributes( + "agent", + [ + "DecisionTreeClassifier/TrainingStep/0/%s" % metric_scorer_name, + "DecisionTreeClassifier/TrainingStep/1/%s" % metric_scorer_name, + ], + ) + @background_task() + def _test(): + run_metric_scorer(metric_scorer_name, training_steps=[0, 1]) + + _test() + + +@pytest.mark.parametrize( + "metric_scorer_name,kwargs", + [ + ("f1_score", {"average": None}), + ("precision_score", {"average": None}), + ("recall_score", {"average": None}), + ], +) +def test_metric_scorer_iterable_score_attributes(metric_scorer_name, kwargs, run_metric_scorer): + @validate_attributes( + "agent", + [ + "DecisionTreeClassifier/TrainingStep/0/%s[0]" % metric_scorer_name, + "DecisionTreeClassifier/TrainingStep/0/%s[1]" % metric_scorer_name, + ], + ) + @background_task() + def _test(): + run_metric_scorer(metric_scorer_name, kwargs) + + _test() + + +@pytest.mark.parametrize( + "metric_scorer_name", + [ + "accuracy_score", + "balanced_accuracy_score", + "f1_score", + "precision_score", + "recall_score", + "roc_auc_score", + "r2_score", + ], +) +def test_metric_scorer_attributes_unknown_model(metric_scorer_name): + @validate_attributes("agent", ["Unknown/TrainingStep/Unknown/%s" % metric_scorer_name]) + @background_task() + def _test(): + from sklearn import metrics + + y_pred = [1, 0] + y_test = [1, 0] + + getattr(metrics, metric_scorer_name)(y_test, y_pred) + + _test() + + +@pytest.mark.parametrize("data", (np.array([0, 1]), "foo", 1, 1.0, True, [0, 1], {"foo": "bar"}, (0, 1), np.str_("F"))) +def test_PredictReturnTypeProxy(data): + wrapped_data = PredictReturnTypeProxy(data, "ModelName", 0) + + assert wrapped_data._nr_model_name == "ModelName" + assert wrapped_data._nr_training_step == 0 + + +@pytest.fixture +def run_metric_scorer(): + def _run(metric_scorer_name, metric_scorer_kwargs=None, training_steps=None): + from sklearn import metrics, tree + + x_train = [[0, 0], [1, 1]] + y_train = [0, 1] + x_test = [[2.0, 2.0], [0, 0.5]] + y_test = [1, 0] + + if not training_steps: + training_steps = [0] + + clf = tree.DecisionTreeClassifier(random_state=0) + for step in training_steps: + model = clf.fit(x_train, y_train) + + labels = model.predict(x_test) + + metric_scorer_kwargs = metric_scorer_kwargs or {} + getattr(metrics, metric_scorer_name)(y_test, labels, **metric_scorer_kwargs) + + return _run diff --git a/tests/mlmodel_sklearn/test_mixture_models.py b/tests/mlmodel_sklearn/test_mixture_models.py new file mode 100644 index 000000000..7ef838126 --- /dev/null +++ b/tests/mlmodel_sklearn/test_mixture_models.py @@ -0,0 +1,85 
@@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task +from newrelic.packages import six + + +@pytest.mark.parametrize( + "mixture_model_name", + [ + "GaussianMixture", + "BayesianGaussianMixture", + ], +) +def test_model_methods_wrapped_in_function_trace(mixture_model_name, run_mixture_model): + expected_scoped_metrics = { + "GaussianMixture": [ + ("Function/MLModel/Sklearn/Named/GaussianMixture.fit", 1), + ("Function/MLModel/Sklearn/Named/GaussianMixture.predict", 1), + ("Function/MLModel/Sklearn/Named/GaussianMixture.predict_proba", 1), + ("Function/MLModel/Sklearn/Named/GaussianMixture.score", 1), + ], + "BayesianGaussianMixture": [ + ("Function/MLModel/Sklearn/Named/BayesianGaussianMixture.fit", 1), + ("Function/MLModel/Sklearn/Named/BayesianGaussianMixture.predict", 1), + ("Function/MLModel/Sklearn/Named/BayesianGaussianMixture.predict_proba", 1), + ("Function/MLModel/Sklearn/Named/BayesianGaussianMixture.score", 1), + ], + } + + expected_transaction_name = ( + "test_mixture_models:test_model_methods_wrapped_in_function_trace.._test" + if six.PY3 + else "test_mixture_models:_test" + ) + + @validate_transaction_metrics( + expected_transaction_name, + scoped_metrics=expected_scoped_metrics[mixture_model_name], + rollup_metrics=expected_scoped_metrics[mixture_model_name], + background_task=True, + ) + @background_task() + def _test(): + run_mixture_model(mixture_model_name) + + _test() + + +@pytest.fixture +def run_mixture_model(): + def _run(mixture_model_name): + import sklearn.mixture + from sklearn.datasets import load_iris + from sklearn.model_selection import train_test_split + + X, y = load_iris(return_X_y=True) + x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0) + + clf = getattr(sklearn.mixture, mixture_model_name)() + + model = clf.fit(x_train, y_train) + model.predict(x_test) + model.score(x_test, y_test) + model.predict_proba(x_test) + + return model + + return _run diff --git a/tests/mlmodel_sklearn/test_ml_model.py b/tests/mlmodel_sklearn/test_ml_model.py new file mode 100644 index 000000000..cfb8e79a6 --- /dev/null +++ b/tests/mlmodel_sklearn/test_ml_model.py @@ -0,0 +1,337 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging + +import pandas +from testing_support.fixtures import reset_core_stats_engine +from testing_support.validators.validate_ml_event_count import validate_ml_event_count +from testing_support.validators.validate_ml_events import validate_ml_events + +from newrelic.api.background_task import background_task +from newrelic.api.ml_model import wrap_mlmodel + +try: + from sklearn.tree._classes import BaseDecisionTree +except ImportError: + from sklearn.tree.tree import BaseDecisionTree + +_logger = logging.getLogger(__name__) + + +# Create custom model that isn't auto-instrumented to validate ml_model wrapper functionality +class CustomTestModel(BaseDecisionTree): + def __init__( + self, + criterion="poisson", + splitter="random", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features=None, + random_state=0, + max_leaf_nodes=None, + min_impurity_decrease=0.0, + class_weight=None, + ccp_alpha=0.0, + ): + super().__init__( + criterion=criterion, + splitter=splitter, + max_depth=max_depth, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, + min_weight_fraction_leaf=min_weight_fraction_leaf, + max_features=max_features, + max_leaf_nodes=max_leaf_nodes, + class_weight=class_weight, + random_state=random_state, + min_impurity_decrease=min_impurity_decrease, + ccp_alpha=ccp_alpha, + ) + + def fit(self, X, y, sample_weight=None, check_input=True): + if hasattr(super(CustomTestModel, self), "_fit"): + return self._fit( + X, + y, + sample_weight=sample_weight, + check_input=check_input, + ) + else: + return super(CustomTestModel, self).fit( + X, + y, + sample_weight=sample_weight, + check_input=check_input, + ) + + def predict(self, X, check_input=True): + return super(CustomTestModel, self).predict(X, check_input=check_input) + + +int_list_recorded_custom_events = [ + ( + {"type": "InferenceData"}, + { + "inference_id": None, + "prediction_id": None, + "modelName": "MyCustomModel", + "model_version": "1.2.3", + "feature.0": 1.0, + "feature.1": 2.0, + "label.0": "0.5", + "new_relic_data_schema_version": 2, + }, + ), +] + + +@reset_core_stats_engine() +def test_custom_model_int_list_no_features_and_labels(): + @validate_ml_event_count(count=1) + @validate_ml_events(int_list_recorded_custom_events) + @background_task() + def _test(): + x_train = [[0, 0], [1, 1]] + y_train = [0, 1] + x_test = [[1.0, 2.0]] + + model = CustomTestModel().fit(x_train, y_train) + wrap_mlmodel(model, name="MyCustomModel", version="1.2.3") + + labels = model.predict(x_test) + + return model + + _test() + + +int_list_recorded_custom_events_with_metadata = [ + ( + {"type": "InferenceData"}, + { + "inference_id": None, + "prediction_id": None, + "modelName": "MyCustomModel", + "model_version": "1.2.3", + "feature.0": 1.0, + "feature.1": 2.0, + "label.0": "0.5", + "new_relic_data_schema_version": 2, + "metadata1": "value1", + "metadata2": "value2", + }, + ), +] + + +@reset_core_stats_engine() +def test_custom_model_int_list_with_metadata(): + @validate_ml_event_count(count=1) + @validate_ml_events(int_list_recorded_custom_events_with_metadata) + @background_task() + def _test(): + x_train = [[0, 0], [1, 1]] + y_train = [0, 1] + x_test = [[1.0, 2.0]] + + model = CustomTestModel().fit(x_train, y_train) + wrap_mlmodel( + model, + name="MyCustomModel", + version="1.2.3", + metadata={"metadata1": "value1", "metadata2": "value2"}, + ) + + labels = model.predict(x_test) + + return model + + _test() + + +pandas_df_recorded_custom_events = [ + ( + 
{"type": "InferenceData"}, + { + "inference_id": None, + "prediction_id": None, + "modelName": "PandasTestModel", + "model_version": "1.5.0b1", + "feature.feature1": 0, + "feature.feature2": 0, + "feature.feature3": 1, + "label.label1": "0.5", + "new_relic_data_schema_version": 2, + }, + ), +] + + +@reset_core_stats_engine() +def test_wrapper_attrs_custom_model_pandas_df(): + @validate_ml_event_count(count=1) + @validate_ml_events(pandas_df_recorded_custom_events) + @background_task() + def _test(): + x_train = pandas.DataFrame({"col1": [0, 1], "col2": [0, 1], "col3": [1, 2]}, dtype="category") + y_train = [0, 1] + x_test = pandas.DataFrame({"col1": [0], "col2": [0], "col3": [1]}, dtype="category") + + model = CustomTestModel(random_state=0).fit(x_train, y_train) + wrap_mlmodel( + model, + name="PandasTestModel", + version="1.5.0b1", + feature_names=["feature1", "feature2", "feature3"], + label_names=["label1"], + ) + model.predict(x_test) + return model + + _test() + + +pandas_df_recorded_builtin_events = [ + ( + {"type": "InferenceData"}, + { + "inference_id": None, + "prediction_id": None, + "modelName": "MyDecisionTreeClassifier", + "model_version": "1.5.0b1", + "feature.feature1": 12, + "feature.feature2": 14, + "label.label1": "0", + "new_relic_data_schema_version": 2, + }, + ), +] + + +@reset_core_stats_engine() +def test_wrapper_attrs_builtin_model(): + @validate_ml_event_count(count=1) + @validate_ml_events(pandas_df_recorded_builtin_events) + @background_task() + def _test(): + import sklearn.tree + + x_train = pandas.DataFrame({"col1": [0, 0], "col2": [1, 1]}, dtype="int") + y_train = pandas.DataFrame({"label": [0, 1]}, dtype="int") + x_test = pandas.DataFrame({"col1": [12], "col2": [14]}, dtype="int") + + clf = getattr(sklearn.tree, "DecisionTreeClassifier")(random_state=0) + + model = clf.fit(x_train, y_train) + wrap_mlmodel( + model, + name="MyDecisionTreeClassifier", + version="1.5.0b1", + feature_names=["feature1", "feature2"], + label_names=["label1"], + ) + model.predict(x_test) + + return model + + _test() + + +pandas_df_mismatched_custom_events = [ + ( + {"type": "InferenceData"}, + { + "inference_id": None, + "prediction_id": None, + "modelName": "MyDecisionTreeClassifier", + "model_version": "1.5.0b1", + "feature.col1": 12, + "feature.col2": 14, + "feature.col3": 16, + "label.0": "1", + "new_relic_data_schema_version": 2, + }, + ), +] + + +@reset_core_stats_engine() +def test_wrapper_mismatched_features_and_labels_df(): + @validate_ml_event_count(count=1) + @validate_ml_events(pandas_df_mismatched_custom_events) + @background_task() + def _test(): + import sklearn.tree + + x_train = pandas.DataFrame({"col1": [7, 8], "col2": [9, 10], "col3": [24, 25]}, dtype="int") + y_train = pandas.DataFrame({"label": [0, 1]}, dtype="int") + x_test = pandas.DataFrame({"col1": [12], "col2": [14], "col3": [16]}, dtype="int") + + clf = getattr(sklearn.tree, "DecisionTreeClassifier")(random_state=0) + + model = clf.fit(x_train, y_train) + wrap_mlmodel( + model, + name="MyDecisionTreeClassifier", + version="1.5.0b1", + feature_names=["feature1", "feature2"], + label_names=["label1", "label2"], + ) + model.predict(x_test) + return model + + _test() + + +numpy_str_mismatched_custom_events = [ + ( + {"type": "InferenceData"}, + { + "inference_id": None, + "prediction_id": None, + "modelName": "MyDecisionTreeClassifier", + "model_version": "0.0.1", + "feature.0": "20", + "feature.1": "21", + "label.0": "21", + "new_relic_data_schema_version": 2, + }, + ), +] + + +@reset_core_stats_engine() 
+def test_wrapper_mismatched_features_and_labels_np_array():
+    @validate_ml_events(numpy_str_mismatched_custom_events)
+    @validate_ml_event_count(count=1)
+    @background_task()
+    def _test():
+        import numpy as np
+        import sklearn.tree
+
+        # Reconstructed from the expected events above: string-typed numpy
+        # inputs with deliberately mismatched feature/label name lists.
+        x_train = np.array([[20, 20], [21, 21]], dtype="<U4")
+        y_train = np.array([20, 21], dtype="<U4")
+        x_test = np.array([[20, 21]], dtype="<U4")
+
+        clf = getattr(sklearn.tree, "DecisionTreeClassifier")(random_state=0)
+
+        model = clf.fit(x_train, y_train)
+        wrap_mlmodel(
+            model,
+            name="MyDecisionTreeClassifier",
+            version="0.0.1",
+            feature_names=["feature1"],
+            label_names=["label1", "label2"],
+        )
+        model.predict(x_test)
+        return model
+
+    _test()
diff --git a/tests/mlmodel_sklearn/test_model_selection_models.py b/tests/mlmodel_sklearn/test_model_selection_models.py
new file mode 100644
--- /dev/null
+++ b/tests/mlmodel_sklearn/test_model_selection_models.py
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from sklearn.ensemble import AdaBoostClassifier
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.packages import six
+
+
+@pytest.mark.parametrize(
+    "model_selection_model_name",
+    [
+        "GridSearchCV",
+        "RandomizedSearchCV",
+    ],
+)
+def test_model_methods_wrapped_in_function_trace(model_selection_model_name, run_model_selection_model):
+    expected_scoped_metrics = {
+        "GridSearchCV": [
+            ("Function/MLModel/Sklearn/Named/GridSearchCV.fit", 1),
+            ("Function/MLModel/Sklearn/Named/GridSearchCV.predict", 1),
+            ("Function/MLModel/Sklearn/Named/GridSearchCV.score", 1),
+            ("Function/MLModel/Sklearn/Named/GridSearchCV.predict_log_proba", 1),
+            ("Function/MLModel/Sklearn/Named/GridSearchCV.predict_proba", 1),
+        ],
+        "RandomizedSearchCV": [
+            ("Function/MLModel/Sklearn/Named/RandomizedSearchCV.fit", 1),
+            ("Function/MLModel/Sklearn/Named/RandomizedSearchCV.predict", 1),
+            ("Function/MLModel/Sklearn/Named/RandomizedSearchCV.score", 1),
+            ("Function/MLModel/Sklearn/Named/RandomizedSearchCV.predict_log_proba", 1),
+            ("Function/MLModel/Sklearn/Named/RandomizedSearchCV.predict_proba", 1),
+        ],
+    }
+
+    expected_transaction_name = (
+        "test_model_selection_models:test_model_methods_wrapped_in_function_trace.<locals>._test"
+        if six.PY3
+        else "test_model_selection_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[model_selection_model_name],
+        rollup_metrics=expected_scoped_metrics[model_selection_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_model_selection_model(model_selection_model_name)
+
+    _test()
+
+
+@pytest.fixture
+def run_model_selection_model():
+    def _run(model_selection_model_name):
+        import sklearn.model_selection
+        from sklearn.datasets import load_iris
+        from sklearn.model_selection import train_test_split
+
+        X, y = load_iris(return_X_y=True)
+        x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
+
+        if model_selection_model_name == "GridSearchCV":
+            kwargs = {"estimator": AdaBoostClassifier(), "param_grid": {}}
+        else:
+            kwargs = {"estimator": AdaBoostClassifier(), "param_distributions": {}}
+        clf = getattr(sklearn.model_selection, model_selection_model_name)(**kwargs)
+
+        model = clf.fit(x_train, y_train)
+        if hasattr(model, "predict"):
+            model.predict(x_test)
+        if hasattr(model, "score"):
+            model.score(x_test, y_test)
+        if hasattr(model, "predict_log_proba"):
+            model.predict_log_proba(x_test)
+        if hasattr(model, "predict_proba"):
+            model.predict_proba(x_test)
+        if hasattr(model, "transform"):
+            model.transform(x_test)
+
+        return model
+
+    return _run
diff --git a/tests/mlmodel_sklearn/test_multiclass_models.py b/tests/mlmodel_sklearn/test_multiclass_models.py
new file mode 100644
index 000000000..dd10d76f1
--- /dev/null
+++ b/tests/mlmodel_sklearn/test_multiclass_models.py
@@ -0,0 +1,91 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
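+
+# The tests below verify that fit/predict/predict_proba calls on
+# sklearn.multiclass meta-estimators are wrapped in FunctionTraces and
+# recorded as scoped and rollup Function metrics on the transaction.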
+
+import pytest
+from sklearn.ensemble import AdaBoostClassifier
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.packages import six
+
+
+@pytest.mark.parametrize(
+    "multiclass_model_name",
+    [
+        "OneVsRestClassifier",
+        "OneVsOneClassifier",
+        "OutputCodeClassifier",
+    ],
+)
+def test_model_methods_wrapped_in_function_trace(multiclass_model_name, run_multiclass_model):
+    expected_scoped_metrics = {
+        "OneVsRestClassifier": [
+            ("Function/MLModel/Sklearn/Named/OneVsRestClassifier.fit", 1),
+            ("Function/MLModel/Sklearn/Named/OneVsRestClassifier.predict", 1),
+            ("Function/MLModel/Sklearn/Named/OneVsRestClassifier.predict_proba", 1),
+        ],
+        "OneVsOneClassifier": [
+            ("Function/MLModel/Sklearn/Named/OneVsOneClassifier.fit", 1),
+            ("Function/MLModel/Sklearn/Named/OneVsOneClassifier.predict", 1),
+        ],
+        "OutputCodeClassifier": [
+            ("Function/MLModel/Sklearn/Named/OutputCodeClassifier.fit", 1),
+            ("Function/MLModel/Sklearn/Named/OutputCodeClassifier.predict", 1),
+        ],
+    }
+
+    expected_transaction_name = (
+        "test_multiclass_models:test_model_methods_wrapped_in_function_trace.<locals>._test"
+        if six.PY3
+        else "test_multiclass_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[multiclass_model_name],
+        rollup_metrics=expected_scoped_metrics[multiclass_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_multiclass_model(multiclass_model_name)
+
+    _test()
+
+
+@pytest.fixture
+def run_multiclass_model():
+    def _run(multiclass_model_name):
+        import sklearn.multiclass
+        from sklearn.datasets import load_iris
+        from sklearn.model_selection import train_test_split
+
+        X, y = load_iris(return_X_y=True)
+        x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
+
+        # This is an example of a model that has all the available attributes
+        # We could have chosen any estimator that has predict, score,
+        # predict_log_proba, and predict_proba
+        clf = getattr(sklearn.multiclass, multiclass_model_name)(estimator=AdaBoostClassifier())
+
+        model = clf.fit(x_train, y_train)
+        model.predict(x_test)
+        if hasattr(model, "predict_proba"):
+            model.predict_proba(x_test)
+
+        return model
+
+    return _run
diff --git a/tests/mlmodel_sklearn/test_multioutput_models.py b/tests/mlmodel_sklearn/test_multioutput_models.py
new file mode 100644
index 000000000..392328f28
--- /dev/null
+++ b/tests/mlmodel_sklearn/test_multioutput_models.py
@@ -0,0 +1,129 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
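+
+# Some multioutput estimators (e.g. MultiOutputEstimator) only exist in
+# sklearn < 1.0, so the tests are split into below-v1.0 and v1.0+ variants
+# gated on SKLEARN_VERSION.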
+
+import pytest
+from sklearn import __init__  # noqa: Needed for get_package_version
+from sklearn.ensemble import AdaBoostClassifier
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.common.package_version_utils import get_package_version
+from newrelic.packages import six
+
+SKLEARN_VERSION = tuple(map(int, get_package_version("sklearn").split(".")))
+
+
+# Python 2 will not allow instantiation of abstract class
+# (abstract method is __init__ here)
+@pytest.mark.skipif(SKLEARN_VERSION >= (1, 0, 0) or six.PY2, reason="Requires sklearn < 1.0 and Python3")
+@pytest.mark.parametrize(
+    "multioutput_model_name",
+    [
+        "MultiOutputEstimator",
+    ],
+)
+def test_below_v1_0_model_methods_wrapped_in_function_trace(multioutput_model_name, run_multioutput_model):
+    expected_scoped_metrics = {
+        "MultiOutputEstimator": [
+            ("Function/MLModel/Sklearn/Named/MultiOutputEstimator.fit", 1),
+            ("Function/MLModel/Sklearn/Named/MultiOutputEstimator.predict", 2),
+        ],
+    }
+    expected_transaction_name = (
+        "test_multioutput_models:test_below_v1_0_model_methods_wrapped_in_function_trace.<locals>._test"
+        if six.PY3
+        else "test_multioutput_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[multioutput_model_name],
+        rollup_metrics=expected_scoped_metrics[multioutput_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_multioutput_model(multioutput_model_name)
+
+    _test()
+
+
+@pytest.mark.parametrize(
+    "multioutput_model_name",
+    [
+        "MultiOutputClassifier",
+        "ClassifierChain",
+        "RegressorChain",
+    ],
+)
+def test_above_v1_0_model_methods_wrapped_in_function_trace(multioutput_model_name, run_multioutput_model):
+    expected_scoped_metrics = {
+        "MultiOutputClassifier": [
+            ("Function/MLModel/Sklearn/Named/MultiOutputClassifier.fit", 1),
+            ("Function/MLModel/Sklearn/Named/MultiOutputClassifier.predict_proba", 1),
+            ("Function/MLModel/Sklearn/Named/MultiOutputClassifier.score", 1),
+        ],
+        "ClassifierChain": [
+            ("Function/MLModel/Sklearn/Named/ClassifierChain.fit", 1),
+            ("Function/MLModel/Sklearn/Named/ClassifierChain.predict_proba", 1),
+        ],
+        "RegressorChain": [
+            ("Function/MLModel/Sklearn/Named/RegressorChain.fit", 1),
+        ],
+    }
+    expected_transaction_name = (
+        "test_multioutput_models:test_above_v1_0_model_methods_wrapped_in_function_trace.<locals>._test"
+        if six.PY3
+        else "test_multioutput_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[multioutput_model_name],
+        rollup_metrics=expected_scoped_metrics[multioutput_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_multioutput_model(multioutput_model_name)
+
+    _test()
+
+
+@pytest.fixture
+def run_multioutput_model():
+    def _run(multioutput_model_name):
+        import sklearn.multioutput
+        from sklearn.datasets import make_multilabel_classification
+
+        X, y = make_multilabel_classification(n_classes=3, random_state=0)
+
+        kwargs = {"estimator": AdaBoostClassifier()}
+        if multioutput_model_name in ["RegressorChain", "ClassifierChain"]:
+            kwargs = {"base_estimator": AdaBoostClassifier()}
+        clf = getattr(sklearn.multioutput, multioutput_model_name)(**kwargs)
+
+        model = clf.fit(X, y)
+        if hasattr(model, "predict"):
+            model.predict(X)
+        if hasattr(model, "score"):
+            model.score(X, y)
+        if hasattr(model, "predict_proba"):
+            model.predict_proba(X)
+
+        return model
+
+    return _run
diff --git a/tests/mlmodel_sklearn/test_naive_bayes_models.py b/tests/mlmodel_sklearn/test_naive_bayes_models.py
new file mode 100644
index 000000000..22dc6db1b
--- /dev/null
+++ b/tests/mlmodel_sklearn/test_naive_bayes_models.py
@@ -0,0 +1,141 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from sklearn import __init__  # noqa: Needed for get_package_version
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.common.package_version_utils import get_package_version
+from newrelic.packages import six
+
+SKLEARN_VERSION = tuple(map(int, get_package_version("sklearn").split(".")))
+
+
+@pytest.mark.skipif(SKLEARN_VERSION < (1, 0, 0), reason="Requires sklearn >= 1.0")
+@pytest.mark.parametrize(
+    "naive_bayes_model_name",
+    [
+        "CategoricalNB",
+    ],
+)
+def test_above_v1_0_model_methods_wrapped_in_function_trace(naive_bayes_model_name, run_naive_bayes_model):
+    expected_scoped_metrics = {
+        "CategoricalNB": [
+            ("Function/MLModel/Sklearn/Named/CategoricalNB.fit", 1),
+            ("Function/MLModel/Sklearn/Named/CategoricalNB.predict", 1),
+            ("Function/MLModel/Sklearn/Named/CategoricalNB.predict_log_proba", 2),
+            ("Function/MLModel/Sklearn/Named/CategoricalNB.predict_proba", 1),
+        ],
+    }
+    expected_transaction_name = (
+        "test_naive_bayes_models:test_above_v1_0_model_methods_wrapped_in_function_trace.<locals>._test"
+        if six.PY3
+        else "test_naive_bayes_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[naive_bayes_model_name],
+        rollup_metrics=expected_scoped_metrics[naive_bayes_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_naive_bayes_model(naive_bayes_model_name)
+
+    _test()
+
+
+@pytest.mark.parametrize(
+    "naive_bayes_model_name",
+    [
+        "GaussianNB",
+        "MultinomialNB",
+        "ComplementNB",
+        "BernoulliNB",
+    ],
+)
+def test_model_methods_wrapped_in_function_trace(naive_bayes_model_name, run_naive_bayes_model):
+    expected_scoped_metrics = {
+        "GaussianNB": [
+            ("Function/MLModel/Sklearn/Named/GaussianNB.fit", 1),
+            ("Function/MLModel/Sklearn/Named/GaussianNB.predict", 1),
+            ("Function/MLModel/Sklearn/Named/GaussianNB.predict_log_proba", 2),
+            ("Function/MLModel/Sklearn/Named/GaussianNB.predict_proba", 1),
+        ],
+        "MultinomialNB": [
+            ("Function/MLModel/Sklearn/Named/MultinomialNB.fit", 1),
+            ("Function/MLModel/Sklearn/Named/MultinomialNB.predict", 1),
+            ("Function/MLModel/Sklearn/Named/MultinomialNB.predict_log_proba", 2),
+            ("Function/MLModel/Sklearn/Named/MultinomialNB.predict_proba", 1),
+        ],
+        "ComplementNB": [
+            ("Function/MLModel/Sklearn/Named/ComplementNB.fit", 1),
+            ("Function/MLModel/Sklearn/Named/ComplementNB.predict", 1),
+            ("Function/MLModel/Sklearn/Named/ComplementNB.predict_log_proba", 2),
+            ("Function/MLModel/Sklearn/Named/ComplementNB.predict_proba", 1),
+        ],
+        "BernoulliNB": [
("Function/MLModel/Sklearn/Named/BernoulliNB.fit", 1), + ("Function/MLModel/Sklearn/Named/BernoulliNB.predict", 1), + ("Function/MLModel/Sklearn/Named/BernoulliNB.predict_log_proba", 2), + ("Function/MLModel/Sklearn/Named/BernoulliNB.predict_proba", 1), + ], + } + + expected_transaction_name = ( + "test_naive_bayes_models:test_model_methods_wrapped_in_function_trace.._test" + if six.PY3 + else "test_naive_bayes_models:_test" + ) + + @validate_transaction_metrics( + expected_transaction_name, + scoped_metrics=expected_scoped_metrics[naive_bayes_model_name], + rollup_metrics=expected_scoped_metrics[naive_bayes_model_name], + background_task=True, + ) + @background_task() + def _test(): + run_naive_bayes_model(naive_bayes_model_name) + + _test() + + +@pytest.fixture +def run_naive_bayes_model(): + def _run(naive_bayes_model_name): + import sklearn.naive_bayes + from sklearn.datasets import load_iris + from sklearn.model_selection import train_test_split + + X, y = load_iris(return_X_y=True) + x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0) + + clf = getattr(sklearn.naive_bayes, naive_bayes_model_name)() + + model = clf.fit(x_train, y_train) + if hasattr(model, "predict"): + model.predict(x_test) + if hasattr(model, "predict_log_proba"): + model.predict_log_proba(x_test) + if hasattr(model, "predict_proba"): + model.predict_proba(x_test) + + return model + + return _run diff --git a/tests/mlmodel_sklearn/test_neighbors_models.py b/tests/mlmodel_sklearn/test_neighbors_models.py new file mode 100644 index 000000000..53a521157 --- /dev/null +++ b/tests/mlmodel_sklearn/test_neighbors_models.py @@ -0,0 +1,172 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import pytest
+from sklearn.neighbors import __init__  # noqa: Needed for get_package_version
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.common.package_version_utils import get_package_version
+from newrelic.packages import six
+
+SKLEARN_VERSION = tuple(map(int, get_package_version("sklearn").split(".")))
+
+
+@pytest.mark.parametrize(
+    "neighbors_model_name",
+    [
+        "KNeighborsClassifier",
+        "RadiusNeighborsClassifier",
+        "KernelDensity",
+        "LocalOutlierFactor",
+        "NearestCentroid",
+        "KNeighborsRegressor",
+        "RadiusNeighborsRegressor",
+        "NearestNeighbors",
+    ],
+)
+def test_model_methods_wrapped_in_function_trace(neighbors_model_name, run_neighbors_model):
+    expected_scoped_metrics = {
+        "KNeighborsClassifier": [
+            ("Function/MLModel/Sklearn/Named/KNeighborsClassifier.fit", 1),
+            ("Function/MLModel/Sklearn/Named/KNeighborsClassifier.predict", 2),
+            ("Function/MLModel/Sklearn/Named/KNeighborsClassifier.predict_proba", 1),
+        ],
+        "RadiusNeighborsClassifier": [
+            ("Function/MLModel/Sklearn/Named/RadiusNeighborsClassifier.fit", 1),
+            ("Function/MLModel/Sklearn/Named/RadiusNeighborsClassifier.predict", 2),
+        ],
+        "KernelDensity": [
+            ("Function/MLModel/Sklearn/Named/KernelDensity.fit", 1),
+            ("Function/MLModel/Sklearn/Named/KernelDensity.score", 1),
+        ],
+        "LocalOutlierFactor": [
+            ("Function/MLModel/Sklearn/Named/LocalOutlierFactor.fit", 1),
+            ("Function/MLModel/Sklearn/Named/LocalOutlierFactor.predict", 1),
+        ],
+        "NearestCentroid": [
+            ("Function/MLModel/Sklearn/Named/NearestCentroid.fit", 1),
+            ("Function/MLModel/Sklearn/Named/NearestCentroid.predict", 2),
+        ],
+        "KNeighborsRegressor": [
+            ("Function/MLModel/Sklearn/Named/KNeighborsRegressor.fit", 1),
+            ("Function/MLModel/Sklearn/Named/KNeighborsRegressor.predict", 2),
+        ],
+        "RadiusNeighborsRegressor": [
+            ("Function/MLModel/Sklearn/Named/RadiusNeighborsRegressor.fit", 1),
+            ("Function/MLModel/Sklearn/Named/RadiusNeighborsRegressor.predict", 2),
+        ],
+        "NearestNeighbors": [
+            ("Function/MLModel/Sklearn/Named/NearestNeighbors.fit", 1),
+        ],
+    }
+
+    expected_transaction_name = (
+        "test_neighbors_models:test_model_methods_wrapped_in_function_trace.<locals>._test"
+        if six.PY3
+        else "test_neighbors_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[neighbors_model_name],
+        rollup_metrics=expected_scoped_metrics[neighbors_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_neighbors_model(neighbors_model_name)
+
+    _test()
+
+
+@pytest.mark.skipif(SKLEARN_VERSION < (1, 0, 0), reason="Requires sklearn >= 1.0")
+@pytest.mark.parametrize(
+    "neighbors_model_name",
+    [
+        "KNeighborsTransformer",
+        "RadiusNeighborsTransformer",
+        "NeighborhoodComponentsAnalysis",
+        "RadiusNeighborsClassifier",
+    ],
+)
+def test_above_v1_0_model_methods_wrapped_in_function_trace(neighbors_model_name, run_neighbors_model):
+    expected_scoped_metrics = {
+        "KNeighborsTransformer": [
+            ("Function/MLModel/Sklearn/Named/KNeighborsTransformer.fit", 1),
+            ("Function/MLModel/Sklearn/Named/KNeighborsTransformer.transform", 1),
+        ],
+        "RadiusNeighborsTransformer": [
+            ("Function/MLModel/Sklearn/Named/RadiusNeighborsTransformer.fit", 1),
+            ("Function/MLModel/Sklearn/Named/RadiusNeighborsTransformer.transform", 1),
+        ],
+        "NeighborhoodComponentsAnalysis": [
+            ("Function/MLModel/Sklearn/Named/NeighborhoodComponentsAnalysis.fit", 1),
("Function/MLModel/Sklearn/Named/NeighborhoodComponentsAnalysis.transform", 1), + ], + "RadiusNeighborsClassifier": [ + ("Function/MLModel/Sklearn/Named/RadiusNeighborsClassifier.fit", 1), + ("Function/MLModel/Sklearn/Named/RadiusNeighborsClassifier.predict", 2), + ("Function/MLModel/Sklearn/Named/RadiusNeighborsClassifier.predict_proba", 3), # Added in v1.0 + ], + } + expected_transaction_name = ( + "test_neighbors_models:test_above_v1_0_model_methods_wrapped_in_function_trace.._test" + if six.PY3 + else "test_neighbors_models:_test" + ) + + @validate_transaction_metrics( + expected_transaction_name, + scoped_metrics=expected_scoped_metrics[neighbors_model_name], + rollup_metrics=expected_scoped_metrics[neighbors_model_name], + background_task=True, + ) + @background_task() + def _test(): + run_neighbors_model(neighbors_model_name) + + _test() + + +@pytest.fixture +def run_neighbors_model(): + def _run(neighbors_model_name): + import sklearn.neighbors + from sklearn.datasets import load_iris + from sklearn.model_selection import train_test_split + + X, y = load_iris(return_X_y=True) + x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0) + + kwargs = {} + if neighbors_model_name == "LocalOutlierFactor": + kwargs = {"novelty": True} + clf = getattr(sklearn.neighbors, neighbors_model_name)(**kwargs) + + model = clf.fit(x_train, y_train) + if hasattr(model, "predict"): + model.predict(x_test) + if hasattr(model, "score"): + model.score(x_test, y_test) + if hasattr(model, "predict_proba"): + model.predict_proba(x_test) + if hasattr(model, "transform"): + model.transform(x_test) + + return model + + return _run diff --git a/tests/mlmodel_sklearn/test_neural_network_models.py b/tests/mlmodel_sklearn/test_neural_network_models.py new file mode 100644 index 000000000..468bfb4b9 --- /dev/null +++ b/tests/mlmodel_sklearn/test_neural_network_models.py @@ -0,0 +1,96 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import pytest
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.common.package_version_utils import get_package_version
+from newrelic.packages import six
+
+SKLEARN_VERSION = tuple(map(int, get_package_version("sklearn").split(".")))
+
+
+@pytest.mark.parametrize(
+    "neural_network_model_name",
+    [
+        "MLPClassifier",
+        "MLPRegressor",
+        "BernoulliRBM",
+    ],
+)
+def test_model_methods_wrapped_in_function_trace(neural_network_model_name, run_neural_network_model):
+    expected_scoped_metrics = {
+        "MLPClassifier": [
+            ("Function/MLModel/Sklearn/Named/MLPClassifier.fit", 1),
+            ("Function/MLModel/Sklearn/Named/MLPClassifier.predict", 1),
+            ("Function/MLModel/Sklearn/Named/MLPClassifier.predict_log_proba", 1),
+            ("Function/MLModel/Sklearn/Named/MLPClassifier.predict_proba", 2),
+        ],
+        "MLPRegressor": [
+            ("Function/MLModel/Sklearn/Named/MLPRegressor.fit", 1),
+            ("Function/MLModel/Sklearn/Named/MLPRegressor.predict", 1),
+        ],
+        "BernoulliRBM": [
+            ("Function/MLModel/Sklearn/Named/BernoulliRBM.fit", 1),
+            ("Function/MLModel/Sklearn/Named/BernoulliRBM.transform", 1),
+        ],
+    }
+
+    expected_transaction_name = (
+        "test_neural_network_models:test_model_methods_wrapped_in_function_trace.<locals>._test"
+        if six.PY3
+        else "test_neural_network_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[neural_network_model_name],
+        rollup_metrics=expected_scoped_metrics[neural_network_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_neural_network_model(neural_network_model_name)
+
+    _test()
+
+
+@pytest.fixture
+def run_neural_network_model():
+    def _run(neural_network_model_name):
+        import sklearn.neural_network
+        from sklearn.datasets import load_iris
+        from sklearn.model_selection import train_test_split
+
+        X, y = load_iris(return_X_y=True)
+        x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
+
+        clf = getattr(sklearn.neural_network, neural_network_model_name)()
+
+        model = clf.fit(x_train, y_train)
+        if hasattr(model, "predict"):
+            model.predict(x_test)
+        if hasattr(model, "predict_log_proba"):
+            model.predict_log_proba(x_test)
+        if hasattr(model, "predict_proba"):
+            model.predict_proba(x_test)
+        if hasattr(model, "transform"):
+            model.transform(x_test)
+
+        return model
+
+    return _run
diff --git a/tests/mlmodel_sklearn/test_pipeline_models.py b/tests/mlmodel_sklearn/test_pipeline_models.py
new file mode 100644
index 000000000..ac9b918f4
--- /dev/null
+++ b/tests/mlmodel_sklearn/test_pipeline_models.py
@@ -0,0 +1,95 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
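+
+# Pipeline and FeatureUnion are composite estimators; these tests confirm
+# that the composite's own fit/predict/score/transform calls are traced,
+# independent of the estimators configured as steps.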
+
+import pytest
+from sklearn.decomposition import TruncatedSVD
+from sklearn.preprocessing import StandardScaler
+from sklearn.svm import SVC
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.common.package_version_utils import get_package_version
+from newrelic.packages import six
+
+SKLEARN_VERSION = tuple(map(int, get_package_version("sklearn").split(".")))
+
+
+@pytest.mark.parametrize(
+    "pipeline_model_name",
+    [
+        "Pipeline",
+        "FeatureUnion",
+    ],
+)
+def test_model_methods_wrapped_in_function_trace(pipeline_model_name, run_pipeline_model):
+    expected_scoped_metrics = {
+        "Pipeline": [
+            ("Function/MLModel/Sklearn/Named/Pipeline.fit", 1),
+            ("Function/MLModel/Sklearn/Named/Pipeline.predict", 1),
+            ("Function/MLModel/Sklearn/Named/Pipeline.score", 1),
+        ],
+        "FeatureUnion": [
+            ("Function/MLModel/Sklearn/Named/FeatureUnion.fit", 1),
+            ("Function/MLModel/Sklearn/Named/FeatureUnion.transform", 1),
+        ],
+    }
+
+    expected_transaction_name = (
+        "test_pipeline_models:test_model_methods_wrapped_in_function_trace.<locals>._test"
+        if six.PY3
+        else "test_pipeline_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[pipeline_model_name],
+        rollup_metrics=expected_scoped_metrics[pipeline_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_pipeline_model(pipeline_model_name)
+
+    _test()
+
+
+@pytest.fixture
+def run_pipeline_model():
+    def _run(pipeline_model_name):
+        import sklearn.pipeline
+        from sklearn.datasets import load_iris
+        from sklearn.model_selection import train_test_split
+
+        X, y = load_iris(return_X_y=True)
+        x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
+
+        if pipeline_model_name == "Pipeline":
+            kwargs = {"steps": [("scaler", StandardScaler()), ("svc", SVC())]}
+        else:
+            kwargs = {"transformer_list": [("scaler", StandardScaler()), ("svd", TruncatedSVD(n_components=2))]}
+        clf = getattr(sklearn.pipeline, pipeline_model_name)(**kwargs)
+
+        model = clf.fit(x_train, y_train)
+        if hasattr(model, "predict"):
+            model.predict(x_test)
+        if hasattr(model, "score"):
+            model.score(x_test, y_test)
+        if hasattr(model, "transform"):
+            model.transform(x_test)
+
+        return model
+
+    return _run
diff --git a/tests/mlmodel_sklearn/test_prediction_stats.py b/tests/mlmodel_sklearn/test_prediction_stats.py
new file mode 100644
index 000000000..5538119e7
--- /dev/null
+++ b/tests/mlmodel_sklearn/test_prediction_stats.py
@@ -0,0 +1,519 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
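+
+# These tests pin uuid.uuid4 to ML_METRIC_FORCED_UUID (via the force_uuid
+# fixture) so that the prediction_id tag on the recorded stats metrics is
+# deterministic and can be asserted against.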
+
+import uuid
+
+import numpy as np
+import pandas as pd
+import pytest
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+
+# This will act as the UUID for `prediction_id`
+ML_METRIC_FORCED_UUID = "0b59992f-2349-4a46-8de1-696d3fe1088b"
+
+
+@pytest.fixture(scope="function")
+def force_uuid(monkeypatch):
+    monkeypatch.setattr(uuid, "uuid4", lambda *a, **k: ML_METRIC_FORCED_UUID)
+
+
+_test_prediction_stats_tags = frozenset(
+    {("modelName", "DummyClassifier"), ("prediction_id", ML_METRIC_FORCED_UUID), ("model_version", "0.0.0")}
+)
+
+
+@pytest.mark.parametrize(
+    "x_train,y_train,x_test,metrics",
+    [
+        (
+            [[0, 0], [1, 1]],
+            [0, 1],
+            [[2.0, 2.0], [0, 0.5]],
+            [
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/0/Mean", _test_prediction_stats_tags, 1),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/0/Percentile25",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/0/Percentile50",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/0/Percentile75",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/0/StandardDeviation",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/0/Min", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/0/Max", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/0/Count", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/1/Mean", _test_prediction_stats_tags, 1),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/1/Percentile25",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/1/Percentile50",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/1/Percentile75",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/1/StandardDeviation",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/1/Min", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/1/Max", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/1/Count", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Label/0/Mean", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Label/0/Percentile25", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Label/0/Percentile50", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Label/0/Percentile75", _test_prediction_stats_tags, 1),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Label/0/StandardDeviation",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Label/0/Min", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Label/0/Max", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Label/0/Count", _test_prediction_stats_tags, 1),
+            ],
+        ),
+        (
+            np.array([[0, 0], [1, 1]]),
+            [0, 1],
+            np.array([[2.0, 2.0], [0, 0.5]]),
+            [
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/0/Mean",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/0/Percentile25",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/0/Percentile50",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/0/Percentile75",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/0/StandardDeviation",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/0/Min", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/0/Max", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/0/Count", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/1/Mean", _test_prediction_stats_tags, 1),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/1/Percentile25",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/1/Percentile50",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/1/Percentile75",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/1/StandardDeviation",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/1/Min", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/1/Max", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Feature/1/Count", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Label/0/Mean", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Label/0/Percentile25", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Label/0/Percentile50", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Label/0/Percentile75", _test_prediction_stats_tags, 1),
+                (
+                    "MLModel/Sklearn/Named/DummyClassifier/Predict/Label/0/StandardDeviation",
+                    _test_prediction_stats_tags,
+                    1,
+                ),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Label/0/Min", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Label/0/Max", _test_prediction_stats_tags, 1),
+                ("MLModel/Sklearn/Named/DummyClassifier/Predict/Label/0/Count", _test_prediction_stats_tags, 1),
+            ],
+        ),
+        (
+            np.array([["a", 0, 4], ["b", 1, 3]], dtype="<U4"),
diff --git a/tests/mlmodel_sklearn/test_semi_supervised_models.py b/tests/mlmodel_sklearn/test_semi_supervised_models.py
new file mode 100644
--- /dev/null
+++ b/tests/mlmodel_sklearn/test_semi_supervised_models.py
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from sklearn.ensemble import AdaBoostClassifier
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.common.package_version_utils import get_package_version
+from newrelic.packages import six
+
+SKLEARN_VERSION = tuple(map(int, get_package_version("sklearn").split(".")))
+
+
+@pytest.mark.parametrize(
+    "semi_supervised_model_name",
+    [
+        "LabelPropagation",
+        "LabelSpreading",
+    ],
+)
+def test_model_methods_wrapped_in_function_trace(semi_supervised_model_name, run_semi_supervised_model):
+    # Reconstructed metrics: the fixture below fits each model and calls
+    # predict, score, and predict_proba; predict is counted twice because
+    # score calls it internally.
+    expected_scoped_metrics = {
+        "LabelPropagation": [
+            ("Function/MLModel/Sklearn/Named/LabelPropagation.fit", 1),
+            ("Function/MLModel/Sklearn/Named/LabelPropagation.predict", 2),
+            ("Function/MLModel/Sklearn/Named/LabelPropagation.score", 1),
+            ("Function/MLModel/Sklearn/Named/LabelPropagation.predict_proba", 1),
+        ],
+        "LabelSpreading": [
+            ("Function/MLModel/Sklearn/Named/LabelSpreading.fit", 1),
+            ("Function/MLModel/Sklearn/Named/LabelSpreading.predict", 2),
+            ("Function/MLModel/Sklearn/Named/LabelSpreading.score", 1),
+            ("Function/MLModel/Sklearn/Named/LabelSpreading.predict_proba", 1),
+        ],
+    }
+    expected_transaction_name = (
+        "test_semi_supervised_models:test_model_methods_wrapped_in_function_trace.<locals>._test"
+        if six.PY3
+        else "test_semi_supervised_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[semi_supervised_model_name],
+        rollup_metrics=expected_scoped_metrics[semi_supervised_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_semi_supervised_model(semi_supervised_model_name)
+
+    _test()
+
+
+@pytest.mark.skipif(SKLEARN_VERSION < (1, 0, 0), reason="Requires sklearn >= 1.0")
+@pytest.mark.parametrize(
+    "semi_supervised_model_name",
+    [
+        "SelfTrainingClassifier",
+    ],
+)
+def test_above_v1_0_model_methods_wrapped_in_function_trace(semi_supervised_model_name, run_semi_supervised_model):
+    expected_scoped_metrics = {
+        "SelfTrainingClassifier": [
+            ("Function/MLModel/Sklearn/Named/SelfTrainingClassifier.fit", 1),
+            ("Function/MLModel/Sklearn/Named/SelfTrainingClassifier.predict", 1),
("Function/MLModel/Sklearn/Named/SelfTrainingClassifier.predict_log_proba", 1), + ("Function/MLModel/Sklearn/Named/SelfTrainingClassifier.score", 1), + ("Function/MLModel/Sklearn/Named/SelfTrainingClassifier.predict_proba", 1), + ], + } + expected_transaction_name = ( + "test_semi_supervised_models:test_above_v1_0_model_methods_wrapped_in_function_trace.._test" + if six.PY3 + else "test_semi_supervised_models:_test" + ) + + @validate_transaction_metrics( + expected_transaction_name, + scoped_metrics=expected_scoped_metrics[semi_supervised_model_name], + rollup_metrics=expected_scoped_metrics[semi_supervised_model_name], + background_task=True, + ) + @background_task() + def _test(): + run_semi_supervised_model(semi_supervised_model_name) + + _test() + + +@pytest.fixture +def run_semi_supervised_model(): + def _run(semi_supervised_model_name): + import sklearn.semi_supervised + from sklearn.datasets import load_iris + from sklearn.model_selection import train_test_split + + X, y = load_iris(return_X_y=True) + x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0) + + if semi_supervised_model_name == "SelfTrainingClassifier": + kwargs = {"base_estimator": AdaBoostClassifier()} + else: + kwargs = {} + clf = getattr(sklearn.semi_supervised, semi_supervised_model_name)(**kwargs) + + model = clf.fit(x_train, y_train) + if hasattr(model, "predict"): + model.predict(x_test) + if hasattr(model, "score"): + model.score(x_test, y_test) + if hasattr(model, "predict_log_proba"): + model.predict_log_proba(x_test) + if hasattr(model, "predict_proba"): + model.predict_proba(x_test) + + return model + + return _run diff --git a/tests/mlmodel_sklearn/test_svm_models.py b/tests/mlmodel_sklearn/test_svm_models.py new file mode 100644 index 000000000..fe95f2f46 --- /dev/null +++ b/tests/mlmodel_sklearn/test_svm_models.py @@ -0,0 +1,110 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import pytest
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.common.package_version_utils import get_package_version
+from newrelic.packages import six
+
+SKLEARN_VERSION = tuple(map(int, get_package_version("sklearn").split(".")))
+
+
+@pytest.mark.parametrize(
+    "svm_model_name",
+    [
+        "LinearSVC",
+        "LinearSVR",
+        "SVC",
+        "NuSVC",
+        "SVR",
+        "NuSVR",
+        "OneClassSVM",
+    ],
+)
+def test_model_methods_wrapped_in_function_trace(svm_model_name, run_svm_model):
+    expected_scoped_metrics = {
+        "LinearSVC": [
+            ("Function/MLModel/Sklearn/Named/LinearSVC.fit", 1),
+            ("Function/MLModel/Sklearn/Named/LinearSVC.predict", 1),
+        ],
+        "LinearSVR": [
+            ("Function/MLModel/Sklearn/Named/LinearSVR.fit", 1),
+            ("Function/MLModel/Sklearn/Named/LinearSVR.predict", 1),
+        ],
+        "SVC": [
+            ("Function/MLModel/Sklearn/Named/SVC.fit", 1),
+            ("Function/MLModel/Sklearn/Named/SVC.predict", 1),
+        ],
+        "NuSVC": [
+            ("Function/MLModel/Sklearn/Named/NuSVC.fit", 1),
+            ("Function/MLModel/Sklearn/Named/NuSVC.predict", 1),
+        ],
+        "SVR": [
+            ("Function/MLModel/Sklearn/Named/SVR.fit", 1),
+            ("Function/MLModel/Sklearn/Named/SVR.predict", 1),
+        ],
+        "NuSVR": [
+            ("Function/MLModel/Sklearn/Named/NuSVR.fit", 1),
+            ("Function/MLModel/Sklearn/Named/NuSVR.predict", 1),
+        ],
+        "OneClassSVM": [
+            ("Function/MLModel/Sklearn/Named/OneClassSVM.fit", 1),
+            ("Function/MLModel/Sklearn/Named/OneClassSVM.predict", 1),
+        ],
+    }
+
+    expected_transaction_name = (
+        "test_svm_models:test_model_methods_wrapped_in_function_trace.<locals>._test"
+        if six.PY3
+        else "test_svm_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[svm_model_name],
+        rollup_metrics=expected_scoped_metrics[svm_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_svm_model(svm_model_name)
+
+    _test()
+
+
+@pytest.fixture
+def run_svm_model():
+    def _run(svm_model_name):
+        import sklearn.svm
+        from sklearn.datasets import load_iris
+        from sklearn.model_selection import train_test_split
+
+        X, y = load_iris(return_X_y=True)
+        x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
+
+        kwargs = {"random_state": 0}
+        if svm_model_name in ["SVR", "NuSVR", "OneClassSVM"]:
+            kwargs = {}
+        clf = getattr(sklearn.svm, svm_model_name)(**kwargs)
+
+        model = clf.fit(x_train, y_train)
+        model.predict(x_test)
+
+        return model
+
+    return _run
diff --git a/tests/mlmodel_sklearn/test_tree_models.py b/tests/mlmodel_sklearn/test_tree_models.py
new file mode 100644
index 000000000..b30b7e2ea
--- /dev/null
+++ b/tests/mlmodel_sklearn/test_tree_models.py
@@ -0,0 +1,158 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
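+
+# The tree models are exercised through a shared run_tree_model fixture so
+# that the single-call and repeated-call tests validate the same sequence
+# of fit/predict/score/predict_log_proba/predict_proba invocations.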
+
+import pytest
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+from newrelic.packages import six
+
+
+def test_model_methods_wrapped_in_function_trace(tree_model_name, run_tree_model):
+    # Note: in the following expected metrics, predict and predict_proba are called by
+    # score and predict_log_proba so they are expected to be called twice instead of
+    # once like the rest of the methods.
+    expected_scoped_metrics = {
+        "ExtraTreeRegressor": [
+            ("Function/MLModel/Sklearn/Named/ExtraTreeRegressor.fit", 1),
+            ("Function/MLModel/Sklearn/Named/ExtraTreeRegressor.predict", 2),
+            ("Function/MLModel/Sklearn/Named/ExtraTreeRegressor.score", 1),
+        ],
+        "DecisionTreeClassifier": [
+            ("Function/MLModel/Sklearn/Named/DecisionTreeClassifier.fit", 1),
+            ("Function/MLModel/Sklearn/Named/DecisionTreeClassifier.predict", 2),
+            ("Function/MLModel/Sklearn/Named/DecisionTreeClassifier.score", 1),
+            ("Function/MLModel/Sklearn/Named/DecisionTreeClassifier.predict_log_proba", 1),
+            ("Function/MLModel/Sklearn/Named/DecisionTreeClassifier.predict_proba", 2),
+        ],
+        "ExtraTreeClassifier": [
+            ("Function/MLModel/Sklearn/Named/ExtraTreeClassifier.fit", 1),
+            ("Function/MLModel/Sklearn/Named/ExtraTreeClassifier.predict", 2),
+            ("Function/MLModel/Sklearn/Named/ExtraTreeClassifier.score", 1),
+            ("Function/MLModel/Sklearn/Named/ExtraTreeClassifier.predict_log_proba", 1),
+            ("Function/MLModel/Sklearn/Named/ExtraTreeClassifier.predict_proba", 2),
+        ],
+        "DecisionTreeRegressor": [
+            ("Function/MLModel/Sklearn/Named/DecisionTreeRegressor.fit", 1),
+            ("Function/MLModel/Sklearn/Named/DecisionTreeRegressor.predict", 2),
+            ("Function/MLModel/Sklearn/Named/DecisionTreeRegressor.score", 1),
+        ],
+    }
+    expected_transaction_name = (
+        "test_tree_models:test_model_methods_wrapped_in_function_trace.<locals>._test"
+        if six.PY3
+        else "test_tree_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[tree_model_name],
+        rollup_metrics=expected_scoped_metrics[tree_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        run_tree_model()
+
+    _test()
+
+
+def test_multiple_calls_to_model_methods(tree_model_name, run_tree_model):
+    # Note: in the following expected metrics, predict and predict_proba are called by
+    # score and predict_log_proba so they are expected to be called twice as often as
+    # the other methods.
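+    # fit stays at a count of 1 because only the run_tree_model fixture fits
+    # the model; every other method is invoked a second time in this test.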
+    expected_scoped_metrics = {
+        "ExtraTreeRegressor": [
+            ("Function/MLModel/Sklearn/Named/ExtraTreeRegressor.fit", 1),
+            ("Function/MLModel/Sklearn/Named/ExtraTreeRegressor.predict", 4),
+            ("Function/MLModel/Sklearn/Named/ExtraTreeRegressor.score", 2),
+        ],
+        "DecisionTreeClassifier": [
+            ("Function/MLModel/Sklearn/Named/DecisionTreeClassifier.fit", 1),
+            ("Function/MLModel/Sklearn/Named/DecisionTreeClassifier.predict", 4),
+            ("Function/MLModel/Sklearn/Named/DecisionTreeClassifier.score", 2),
+            ("Function/MLModel/Sklearn/Named/DecisionTreeClassifier.predict_log_proba", 2),
+            ("Function/MLModel/Sklearn/Named/DecisionTreeClassifier.predict_proba", 4),
+        ],
+        "ExtraTreeClassifier": [
+            ("Function/MLModel/Sklearn/Named/ExtraTreeClassifier.fit", 1),
+            ("Function/MLModel/Sklearn/Named/ExtraTreeClassifier.predict", 4),
+            ("Function/MLModel/Sklearn/Named/ExtraTreeClassifier.score", 2),
+            ("Function/MLModel/Sklearn/Named/ExtraTreeClassifier.predict_log_proba", 2),
+            ("Function/MLModel/Sklearn/Named/ExtraTreeClassifier.predict_proba", 4),
+        ],
+        "DecisionTreeRegressor": [
+            ("Function/MLModel/Sklearn/Named/DecisionTreeRegressor.fit", 1),
+            ("Function/MLModel/Sklearn/Named/DecisionTreeRegressor.predict", 4),
+            ("Function/MLModel/Sklearn/Named/DecisionTreeRegressor.score", 2),
+        ],
+    }
+    expected_transaction_name = (
+        "test_tree_models:test_multiple_calls_to_model_methods.<locals>._test" if six.PY3 else "test_tree_models:_test"
+    )
+
+    @validate_transaction_metrics(
+        expected_transaction_name,
+        scoped_metrics=expected_scoped_metrics[tree_model_name],
+        rollup_metrics=expected_scoped_metrics[tree_model_name],
+        background_task=True,
+    )
+    @background_task()
+    def _test():
+        x_test = [[2.0, 2.0], [2.0, 1.0]]
+        y_test = [1, 1]
+
+        model = run_tree_model()
+
+        model.predict(x_test)
+        model.score(x_test, y_test)
+        # Some models don't have these methods.
+        if hasattr(model, "predict_log_proba"):
+            model.predict_log_proba(x_test)
+        if hasattr(model, "predict_proba"):
+            model.predict_proba(x_test)
+
+    _test()
+
+
+@pytest.fixture(params=["ExtraTreeRegressor", "DecisionTreeClassifier", "ExtraTreeClassifier", "DecisionTreeRegressor"])
+def tree_model_name(request):
+    return request.param
+
+
+@pytest.fixture
+def run_tree_model(tree_model_name):
+    def _run():
+        import sklearn.tree
+
+        x_train = [[0, 0], [1, 1]]
+        y_train = [0, 1]
+        x_test = [[2.0, 2.0], [2.0, 1.0]]
+        y_test = [1, 1]
+
+        clf = getattr(sklearn.tree, tree_model_name)(random_state=0)
+        model = clf.fit(x_train, y_train)
+
+        labels = model.predict(x_test)
+        model.score(x_test, y_test)
+        # Some models don't have these methods.
+        if hasattr(model, "predict_log_proba"):
+            model.predict_log_proba(x_test)
+        if hasattr(model, "predict_proba"):
+            model.predict_proba(x_test)
+        return model
+
+    return _run
diff --git a/tests/external_boto3/conftest.py b/tests/template_jinja2/conftest.py
similarity index 58%
rename from tests/external_boto3/conftest.py
rename to tests/template_jinja2/conftest.py
index 90d82f007..a6922078d 100644
--- a/tests/external_boto3/conftest.py
+++ b/tests/template_jinja2/conftest.py
@@ -12,19 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import pytest
-
-from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture  # noqa: F401; pylint: disable=W0611
-
+from testing_support.fixtures import (  # noqa: F401; pylint: disable=W0611
+    collector_agent_registration_fixture,
+    collector_available_fixture,
+)
 
 _default_settings = {
-    'transaction_tracer.explain_threshold': 0.0,
-    'transaction_tracer.transaction_threshold': 0.0,
-    'transaction_tracer.stack_trace_threshold': 0.0,
-    'debug.log_data_collector_payloads': True,
-    'debug.record_transaction_failure': True,
+    "transaction_tracer.explain_threshold": 0.0,
+    "transaction_tracer.transaction_threshold": 0.0,
+    "transaction_tracer.stack_trace_threshold": 0.0,
+    "debug.log_data_collector_payloads": True,
+    "debug.record_transaction_failure": True,
 }
 
 collector_agent_registration = collector_agent_registration_fixture(
-    app_name='Python Agent Test (external_boto3)',
-    default_settings=_default_settings)
+    app_name="Python Agent Test (template_jinja2)", default_settings=_default_settings
+)
diff --git a/tests/template_jinja2/test_jinja2.py b/tests/template_jinja2/test_jinja2.py
new file mode 100644
index 000000000..c64dac923
--- /dev/null
+++ b/tests/template_jinja2/test_jinja2.py
@@ -0,0 +1,41 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from jinja2 import Template
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
+)
+
+from newrelic.api.background_task import background_task
+
+
+@validate_transaction_metrics(
+    "test_render",
+    background_task=True,
+    scoped_metrics=(
+        ("Template/Render/