From 73e64be45b088e16b83891fc0d0ea8c6f18e8c2d Mon Sep 17 00:00:00 2001 From: Prabhu Subramanian Date: Thu, 2 Jan 2025 18:02:13 +0000 Subject: [PATCH 1/2] Ruby frontend Signed-off-by: Prabhu Subramanian Install ruby 3.4.0 Signed-off-by: Prabhu Subramanian --- .github/workflows/containers.yml | 1 + .github/workflows/master.yml | 1 + .github/workflows/release.yml | 1 + build.sbt | 8 +- ci/Dockerfile | 40 +- codemeta.json | 2 +- console/build.sbt | 6 +- dataflowengineoss/build.sbt | 2 +- .../semanticsloader/Parser.scala | 2 +- meta.yaml | 2 +- platform/frontends/c2cpg/build.sbt | 8 +- .../c2cpg/passes/AstCreationPass.scala | 4 +- platform/frontends/javasrc2cpg/build.sbt | 4 +- .../javasrc2cpg/util/SourceParser.scala | 9 +- platform/frontends/jimple2cpg/AUTHORS | 2 - platform/frontends/jimple2cpg/build.sbt | 5 +- .../test/resources/project/build.properties | 2 +- .../appthreat/php2atom/parser/PhpParser.scala | 12 - platform/frontends/pysrc2cpg/build.sbt | 2 +- .../pysrc2cpg/Py2CpgOnFileSystem.scala | 81 +- platform/frontends/ruby2atom/.gitignore | 5 + platform/frontends/ruby2atom/build.sbt | 23 + .../ruby2atom/src/main/resources/log4j2.xml | 13 + .../scala/io/appthreat/ruby2atom/Main.scala | 54 + .../io/appthreat/ruby2atom/Ruby2Atom.scala | 75 ++ .../ruby2atom/astcreation/AstCreator.scala | 145 +++ .../astcreation/AstCreatorHelper.scala | 262 ++++ .../AstForControlStructuresCreator.scala | 365 ++++++ .../AstForExpressionsCreator.scala | 1137 +++++++++++++++++ .../astcreation/AstForFunctionsCreator.scala | 635 +++++++++ .../astcreation/AstForStatementsCreator.scala | 444 +++++++ .../astcreation/AstForTypesCreator.scala | 316 +++++ .../astcreation/RubyIntermediateAst.scala | 646 ++++++++++ .../datastructures/RubyProgramSummary.scala | 199 +++ .../ruby2atom/datastructures/RubyScope.scala | 417 ++++++ .../datastructures/ScopeElement.scala | 73 ++ .../ruby2atom/parser/RubyAstGenRunner.scala | 88 ++ .../ruby2atom/parser/RubyJsonAst.scala | 207 +++ 
.../ruby2atom/parser/RubyJsonHelpers.scala | 415 ++++++ .../ruby2atom/parser/RubyJsonParser.scala | 17 + .../parser/RubyJsonToNodeCreator.scala | 1090 ++++++++++++++++ .../ruby2atom/passes/AstCreationPass.scala | 41 + .../passes/ConfigFileCreationPass.scala | 31 + .../appthreat/ruby2atom/passes/Defines.scala | 279 ++++ .../ruby2atom/utils/FreshNameGenerator.scala | 11 + platform/frontends/x2cpg/build.sbt | 4 +- .../main/scala/io/appthreat/x2cpg/Ast.scala | 17 +- .../io/appthreat/x2cpg/AstNodeBuilder.scala | 2 +- .../io/appthreat/x2cpg/SourceFiles.scala | 330 +++-- .../appthreat/x2cpg/astgen/AstGenConfig.scala | 30 + .../x2cpg/astgen/AstGenNodeBuilder.scala | 23 + .../appthreat/x2cpg/astgen/AstGenRunner.scala | 155 +++ .../io/appthreat/x2cpg/astgen/package.scala | 41 + .../x2cpg/datastructures/ProgramSummary.scala | 456 +++++++ .../x2cpg/frontendspecific/package.scala | 16 + .../ruby2atom/Constants.scala | 92 ++ .../ruby2atom/ImplicitRequirePass.scala | 195 +++ .../ruby2atom/ImportsPass.scala | 24 + .../x2cpg/typestub/TypeStubConfig.scala | 46 + .../x2cpg/typestub/TypeStubUtil.scala | 27 + .../x2cpg/utils/ConcurrentTaskUtil.scala | 65 + .../appthreat/x2cpg/utils/Environment.scala | 9 +- .../x2cpg/utils/ExternalCommand.scala | 76 ++ .../{joern => appthreat}/x2cpg/AstTests.scala | 0 .../x2cpg/SourceFilesTests.scala | 0 .../x2cpg/X2CpgTests.scala | 0 .../x2cpg/layers/DumpAstTests.scala | 0 .../x2cpg/layers/DumpCdgTests.scala | 0 .../x2cpg/layers/DumpCfgTests.scala | 0 .../passes/CfgDominatorFrontierTests.scala | 0 .../x2cpg/passes/CfgDominatorPassTests.scala | 0 .../x2cpg/passes/ContainsEdgePassTest.scala | 0 .../passes/MemberAccessLinkerTests.scala | 0 .../passes/MethodDecoratorPassTests.scala | 0 .../x2cpg/passes/NamespaceCreatorTests.scala | 0 .../x2cpg/testfixtures/CfgTestFixture.scala | 0 .../x2cpg/testfixtures/Code2CpgFixture.scala | 0 .../x2cpg/testfixtures/DefaultTestCpg.scala | 0 .../testfixtures/EmptyGraphFixture.scala | 0 
.../x2cpg/testfixtures/LanguageFrontend.scala | 0 .../x2cpg/testfixtures/TestCpg.scala | 0 .../x2cpg/utils/ExternalCommandTest.scala | 0 .../x2cpg/utils/HashUtilsTest.scala | 0 .../x2cpg/utils/IgnoreInWindows.scala | 0 .../dependency/DependencyResolverTests.scala | 0 .../dependency/MavenCoordinatesTests.scala | 0 .../src/universal/schema-extender/build.sbt | 2 +- .../schema-extender/project/build.properties | 2 +- project/DownloadHelper.scala | 48 + project/Projects.scala | 3 +- project/Versions.scala | 6 +- project/build.properties | 2 +- pyproject.toml | 2 +- 93 files changed, 8662 insertions(+), 193 deletions(-) delete mode 100644 platform/frontends/jimple2cpg/AUTHORS create mode 100644 platform/frontends/ruby2atom/.gitignore create mode 100644 platform/frontends/ruby2atom/build.sbt create mode 100755 platform/frontends/ruby2atom/src/main/resources/log4j2.xml create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/Main.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/Ruby2Atom.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstCreator.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstCreatorHelper.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForControlStructuresCreator.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForExpressionsCreator.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForFunctionsCreator.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForStatementsCreator.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForTypesCreator.scala create mode 100644 
platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/RubyIntermediateAst.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/datastructures/RubyProgramSummary.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/datastructures/RubyScope.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/datastructures/ScopeElement.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyAstGenRunner.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyJsonAst.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyJsonHelpers.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyJsonParser.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyJsonToNodeCreator.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/passes/AstCreationPass.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/passes/ConfigFileCreationPass.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/passes/Defines.scala create mode 100644 platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/utils/FreshNameGenerator.scala create mode 100644 platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/astgen/AstGenConfig.scala create mode 100644 platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/astgen/AstGenNodeBuilder.scala create mode 100644 platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/astgen/AstGenRunner.scala create mode 100644 platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/astgen/package.scala create mode 100644 
platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/ProgramSummary.scala create mode 100644 platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/frontendspecific/package.scala create mode 100644 platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/frontendspecific/ruby2atom/Constants.scala create mode 100644 platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/frontendspecific/ruby2atom/ImplicitRequirePass.scala create mode 100644 platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/frontendspecific/ruby2atom/ImportsPass.scala create mode 100644 platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/typestub/TypeStubConfig.scala create mode 100644 platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/typestub/TypeStubUtil.scala create mode 100644 platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ConcurrentTaskUtil.scala rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/AstTests.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/SourceFilesTests.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/X2CpgTests.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/layers/DumpAstTests.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/layers/DumpCdgTests.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/layers/DumpCfgTests.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/passes/CfgDominatorFrontierTests.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/passes/CfgDominatorPassTests.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/passes/ContainsEdgePassTest.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => 
appthreat}/x2cpg/passes/MemberAccessLinkerTests.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/passes/MethodDecoratorPassTests.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/passes/NamespaceCreatorTests.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/testfixtures/CfgTestFixture.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/testfixtures/Code2CpgFixture.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/testfixtures/DefaultTestCpg.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/testfixtures/EmptyGraphFixture.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/testfixtures/LanguageFrontend.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/testfixtures/TestCpg.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/utils/ExternalCommandTest.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/utils/HashUtilsTest.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/utils/IgnoreInWindows.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/utils/dependency/DependencyResolverTests.scala (100%) rename platform/frontends/x2cpg/src/test/scala/io/{joern => appthreat}/x2cpg/utils/dependency/MavenCoordinatesTests.scala (100%) create mode 100644 project/DownloadHelper.scala diff --git a/.github/workflows/containers.yml b/.github/workflows/containers.yml index 98af7e17..4ecca8a0 100644 --- a/.github/workflows/containers.yml +++ b/.github/workflows/containers.yml @@ -48,6 +48,7 @@ jobs: uses: actions/setup-node@v4 with: node-version: '22.x' + - uses: oras-project/setup-oras@v1 - name: Trim CI agent run: | chmod +x 
ci/free_disk_space.sh diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml index 130c8881..7c0a56d7 100644 --- a/.github/workflows/master.yml +++ b/.github/workflows/master.yml @@ -34,6 +34,7 @@ jobs: uses: actions/setup-node@v4 with: node-version: '22.x' + - uses: oras-project/setup-oras@v1 - name: Delete `.rustup` directory run: rm -rf /home/runner/.rustup # to save disk space if: runner.os == 'Linux' diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 27456a48..dd283a3c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -40,6 +40,7 @@ jobs: uses: actions/setup-node@v4 with: node-version: '22.x' + - uses: oras-project/setup-oras@v1 - name: Delete `.rustup` directory run: rm -rf /home/runner/.rustup # to save disk space if: runner.os == 'Linux' diff --git a/build.sbt b/build.sbt index ab051b2b..25153b9c 100644 --- a/build.sbt +++ b/build.sbt @@ -1,7 +1,7 @@ name := "chen" ThisBuild / organization := "io.appthreat" -ThisBuild / version := "2.2.3" -ThisBuild / scalaVersion := "3.5.2" +ThisBuild / version := "2.3.0" +ThisBuild / scalaVersion := "3.6.2" val cpgVersion = "1.0.1" @@ -17,6 +17,7 @@ lazy val jssrc2cpg = Projects.jssrc2cpg lazy val javasrc2cpg = Projects.javasrc2cpg lazy val jimple2cpg = Projects.jimple2cpg lazy val php2atom = Projects.php2atom +lazy val ruby2atom = Projects.ruby2atom lazy val aggregatedProjects: Seq[ProjectReference] = Seq( platform, @@ -30,7 +31,8 @@ lazy val aggregatedProjects: Seq[ProjectReference] = Seq( jssrc2cpg, javasrc2cpg, jimple2cpg, - php2atom + php2atom, + ruby2atom, ) ThisBuild / libraryDependencies ++= Seq( diff --git a/ci/Dockerfile b/ci/Dockerfile index 36e49d7f..e814ae67 100644 --- a/ci/Dockerfile +++ b/ci/Dockerfile @@ -1,10 +1,10 @@ -FROM almalinux:9.4-minimal +FROM ghcr.io/appthreat/base:main LABEL maintainer="appthreat" \ org.opencontainers.image.authors="Team AppThreat " \ 
org.opencontainers.image.source="https://github.com/appthreat/chen" \ org.opencontainers.image.url="https://github.com/appthreat/chen" \ - org.opencontainers.image.version="2.2.x" \ + org.opencontainers.image.version="2.3.x" \ org.opencontainers.image.vendor="appthreat" \ org.opencontainers.image.licenses="Apache-2.0" \ org.opencontainers.image.title="chen" \ @@ -25,15 +25,20 @@ ENV JAVA_VERSION=$JAVA_VERSION \ PYTHON_CMD=python3 \ PYTHONUNBUFFERED=1 \ PYTHONIOENCODING="utf-8" \ - JAVA_OPTS="-XX:+UseG1GC -XX:+ExplicitGCInvokesConcurrent -XX:+ParallelRefProcEnabled -XX:+UseStringDeduplication -XX:+UnlockExperimentalVMOptions -XX:G1NewSizePercent=20 -XX:+UnlockDiagnosticVMOptions -XX:G1SummarizeRSetStatsPeriod=1" \ + JAVA_OPTS="-XX:+UseG1GC -XX:+ExplicitGCInvokesConcurrent -XX:+ParallelRefProcEnabled -XX:+UseStringDeduplication -XX:+UnlockExperimentalVMOptions -XX:G1NewSizePercent=20 -XX:+UnlockDiagnosticVMOptions -XX:G1SummarizeRSetStatsPeriod=1 -Dorg.jline.terminal.disableDeprecatedProviderWarning=true" \ CHEN_DATAFLOW_TRACKED_WIDTH=128 \ SCALAPY_PYTHON_LIBRARY=python3.12 \ ANDROID_HOME=/opt/android-sdk-linux \ CHEN_INSTALL_DIR=/opt/workspace \ PHP_PARSER_BIN=/opt/vendor/bin/php-parse \ CDXGEN_NO_BANNER=true \ - COMPOSER_ALLOW_SUPERUSER=1 -ENV PATH=/opt/miniconda3/bin:${PATH}:/opt/platform:${JAVA_HOME}/bin:${MAVEN_HOME}/bin:${GRADLE_HOME}/bin:/usr/local/bin/:/root/.local/bin:${ANDROID_HOME}/cmdline-tools/latest/bin:${ANDROID_HOME}/tools:${ANDROID_HOME}/tools/bin:${ANDROID_HOME}/platform-tools: + COMPOSER_ALLOW_SUPERUSER=1 \ + MALLOC_CONF="dirty_decay_ms:2000,narenas:2,background_thread:true" \ + RUBY_CONFIGURE_OPTS="--with-jemalloc --enable-yjit" \ + RUBYOPT="--yjit" \ + RUBY_BUILD_BUILD_PATH="/tmp/rbenv" \ + RUBY_BUILD_HTTP_CLIENT=curl +ENV 
PATH=/opt/miniconda3/bin:${PATH}:/opt/platform:${JAVA_HOME}/bin:${MAVEN_HOME}/bin:${GRADLE_HOME}/bin:/usr/local/bin/:/root/.local/bin:${ANDROID_HOME}/cmdline-tools/latest/bin:${ANDROID_HOME}/tools:${ANDROID_HOME}/tools/bin:${ANDROID_HOME}/platform-tools:/root/.rbenv/bin: WORKDIR /opt COPY ./ci/conda-install.sh /opt/ @@ -56,8 +61,22 @@ RUN set -e; \ *) echo >&2 "error: unsupported architecture: '$ARCH_NAME'"; exit 1 ;; \ esac; \ echo -e "[nodejs]\nname=nodejs\nstream=20\nprofiles=\nstate=enabled\n" > /etc/dnf/modules.d/nodejs.module \ - && microdnf install -y gcc git-core php php-cli php-curl php-zip php-bcmath php-json php-pear php-mbstring php-devel make wget bash graphviz graphviz-gd \ - pcre2 findutils which tar gzip zip unzip sudo nodejs ncurses sqlite-devel glibc-common glibc-all-langpacks \ + && microdnf install --nodocs -y gcc git-core php php-cli php-curl php-zip php-bcmath php-json php-pear php-mbstring php-devel make wget bash graphviz graphviz-gd \ + openssl-devel libffi-devel readline-devel libyaml zlib-devel ncurses ncurses-devel rust \ + pcre2 findutils which tar gzip zip unzip sudo nodejs sqlite-devel glibc-common glibc-all-langpacks \ + && microdnf install --nodocs -y epel-release \ + && microdnf install --nodocs --enablerepo=crb -y libyaml-devel jemalloc-devel \ + && git clone https://github.com/rbenv/rbenv.git --depth=1 ~/.rbenv \ + && echo 'export PATH="/root/.rbenv/bin:$PATH"' >> ~/.bashrc \ + && echo 'eval "$(/root/.rbenv/bin/rbenv init - bash)"' >> ~/.bashrc \ + && source ~/.bashrc \ + && mkdir -p "$(rbenv root)/plugins" \ + && git clone https://github.com/rbenv/ruby-build.git --depth=1 "$(rbenv root)/plugins/ruby-build" \ + && MAKE_OPTS=-j2 rbenv install 3.4.0 \ + && rbenv global 3.4.0 \ + && ruby --version \ + && which ruby \ + && rm -rf /root/.rbenv/cache $RUBY_BUILD_BUILD_PATH \ && mkdir -p /opt/miniconda3 /opt/workspace \ && wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${ARCH_NAME}.sh -O /opt/miniconda3/miniconda.sh 
\ && bash /opt/miniconda3/miniconda.sh -b -u -p /opt/miniconda3 \ @@ -74,15 +93,14 @@ RUN set -e; \ && rpm -ivh graphviz-devel-2.44.0-26.el9.${ARCH_NAME}.rpm \ && rm graphviz-devel-2.44.0-26.el9.${ARCH_NAME}.rpm \ && curl -s "https://get.sdkman.io" | bash \ - && source "$HOME/.sdkman/bin/sdkman-init.sh" \ - && echo -e "sdkman_auto_answer=true\nsdkman_selfupdate_feature=false\nsdkman_auto_env=true\nsdkman_curl_connect_timeout=20\nsdkman_curl_max_time=0" >> $HOME/.sdkman/etc/config \ + && source "/root/.sdkman/bin/sdkman-init.sh" \ + && echo -e "sdkman_auto_answer=true\nsdkman_selfupdate_feature=false\nsdkman_auto_env=true\nsdkman_curl_connect_timeout=20\nsdkman_curl_max_time=0" >> /root/.sdkman/etc/config \ && sdk install java $JAVA_VERSION \ && sdk install maven $MAVEN_VERSION \ && sdk install gradle $GRADLE_VERSION \ && sdk offline enable \ && mv /root/.sdkman/candidates/* /opt/ \ && rm -rf /root/.sdkman \ - && microdnf install -y epel-release \ && mkdir -p ${ANDROID_HOME}/cmdline-tools \ && curl -L https://dl.google.com/android/repository/commandlinetools-linux-11076708_latest.zip -o ${ANDROID_HOME}/cmdline-tools/android_tools.zip \ && unzip ${ANDROID_HOME}/cmdline-tools/android_tools.zip -d ${ANDROID_HOME}/cmdline-tools/ \ @@ -92,7 +110,7 @@ RUN set -e; \ && /opt/android-sdk-linux/cmdline-tools/latest/bin/sdkmanager 'platform-tools' --sdk_root=/opt/android-sdk-linux \ && /opt/android-sdk-linux/cmdline-tools/latest/bin/sdkmanager 'platforms;android-34' --sdk_root=/opt/android-sdk-linux \ && /opt/android-sdk-linux/cmdline-tools/latest/bin/sdkmanager 'build-tools;34.0.0' --sdk_root=/opt/android-sdk-linux \ - && sudo npm install -g @appthreat/atom @cyclonedx/cdxgen --omit=optional \ + && npm install -g @appthreat/atom @cyclonedx/cdxgen --omit=optional \ && php -r "copy('https://getcomposer.org/installer', 'composer-setup.php');" && php composer-setup.php \ && mv composer.phar /usr/local/bin/composer ENV LC_ALL=en_US.UTF-8 \ diff --git a/codemeta.json b/codemeta.json 
index 9a56dbd6..271dacb3 100644 --- a/codemeta.json +++ b/codemeta.json @@ -7,7 +7,7 @@ "downloadUrl": "https://github.com/AppThreat/chen", "issueTracker": "https://github.com/AppThreat/chen/issues", "name": "chen", - "version": "2.2.3", + "version": "2.3.0", "description": "Code Hierarchy Exploration Net (chen) is an advanced exploration toolkit for your application source code and its dependency hierarchy.", "applicationCategory": "code-analysis", "keywords": [ diff --git a/console/build.sbt b/console/build.sbt index a399eb06..b4a4ca4a 100644 --- a/console/build.sbt +++ b/console/build.sbt @@ -3,7 +3,7 @@ name := "console" enablePlugins(JavaAppPackaging) val ScoptVersion = "4.1.0" -val CaskVersion = "0.10.1" +val CaskVersion = "0.10.2" val CirceVersion = "0.14.10" val ZeroturnaroundVersion = "1.17" @@ -28,9 +28,9 @@ libraryDependencies ++= Seq( "com.lihaoyi" %% "pprint" % "0.9.0", "com.lihaoyi" %% "cask" % CaskVersion, "dev.scalapy" %% "scalapy-core" % "0.5.3", - "org.scala-lang.modules" % "scala-asm" % "9.7.0-scala-2", + "org.scala-lang.modules" % "scala-asm" % "9.7.1-scala-1", "org.scalatest" %% "scalatest" % Versions.scalatest % Test, - "org.scala-lang" %% "scala3-compiler" % "3.5.2" + "org.scala-lang" %% "scala3-compiler" % "3.6.2" ) diff --git a/dataflowengineoss/build.sbt b/dataflowengineoss/build.sbt index 12d54133..623c4a63 100644 --- a/dataflowengineoss/build.sbt +++ b/dataflowengineoss/build.sbt @@ -8,7 +8,7 @@ libraryDependencies ++= Seq( "io.circe" %% "circe-generic" % Versions.circe, "io.circe" %% "circe-parser" % Versions.circe, "org.scalatest" %% "scalatest" % Versions.scalatest % Test, - "org.scala-lang.modules" %% "scala-parallel-collections" % "1.0.4" + "org.scala-lang.modules" %% "scala-parallel-collections" % "1.1.0" ) enablePlugins(Antlr4Plugin) diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/semanticsloader/Parser.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/semanticsloader/Parser.scala 
index 820da78b..57cf670c 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/semanticsloader/Parser.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/semanticsloader/Parser.scala @@ -22,7 +22,7 @@ object Semantics: def empty: Semantics = fromList(List()) -class Semantics private (methodToSemantic: mutable.Map[String, FlowSemantic]): +class Semantics(methodToSemantic: mutable.Map[String, FlowSemantic]): /** The map below keeps a mapping between results of a regex and the regex string it matches. e.g. * diff --git a/meta.yaml b/meta.yaml index c0a4f3db..9f8b2b36 100644 --- a/meta.yaml +++ b/meta.yaml @@ -1,4 +1,4 @@ -{% set version = "2.2.3" %} +{% set version = "2.3.0" %} package: name: chen diff --git a/platform/frontends/c2cpg/build.sbt b/platform/frontends/c2cpg/build.sbt index 1ecc1e77..cb69de8e 100644 --- a/platform/frontends/c2cpg/build.sbt +++ b/platform/frontends/c2cpg/build.sbt @@ -3,14 +3,14 @@ name := "c2cpg" dependsOn(Projects.semanticcpg, Projects.dataflowengineoss % Test, Projects.x2cpg % "compile->compile;test->test") libraryDependencies ++= Seq( - "org.scala-lang.modules" %% "scala-parallel-collections" % "1.0.4", - "org.eclipse.platform" % "org.eclipse.equinox.common" % "3.19.100", - "org.eclipse.platform" % "org.eclipse.core.resources" % "3.21.0" excludeAll( + "org.scala-lang.modules" %% "scala-parallel-collections" % "1.1.0", + "org.eclipse.platform" % "org.eclipse.equinox.common" % "3.19.200", + "org.eclipse.platform" % "org.eclipse.core.resources" % "3.22.0" excludeAll( ExclusionRule(organization = "com.ibm.icu", name = "icu4j"), ExclusionRule(organization = "org.eclipse.platform", name = "org.eclipse.jface"), ExclusionRule(organization = "org.eclipse.platform", name = "org.eclipse.jface.text") ), - "org.jline" % "jline" % "3.27.1", + "org.jline" % "jline" % "3.28.0", "org.scalatest" %% "scalatest" % Versions.scalatest % Test ) diff --git 
a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/AstCreationPass.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/AstCreationPass.scala index d3bc1e28..b79e0c88 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/AstCreationPass.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/AstCreationPass.scala @@ -31,7 +31,9 @@ class AstCreationPass(cpg: Cpg, config: Config, report: Report = new Report()) .determine( config.inputPath, FileDefaults.SOURCE_FILE_EXTENSIONS ++ FileDefaults.HEADER_FILE_EXTENSIONS, - config.withDefaultIgnoredFilesRegex(DefaultIgnoredFolders) + ignoredDefaultRegex = Option(DefaultIgnoredFolders), + ignoredFilesRegex = Option(config.ignoredFilesRegex), + ignoredFilesPath = Option(config.ignoredFiles) ) .sortWith(_.compareToIgnoreCase(_) > 0) .toArray diff --git a/platform/frontends/javasrc2cpg/build.sbt b/platform/frontends/javasrc2cpg/build.sbt index de1c235d..66bb34cf 100644 --- a/platform/frontends/javasrc2cpg/build.sbt +++ b/platform/frontends/javasrc2cpg/build.sbt @@ -4,10 +4,10 @@ dependsOn(Projects.dataflowengineoss, Projects.x2cpg % "compile->compile;test->t libraryDependencies ++= Seq( "io.appthreat" %% "cpg2" % Versions.cpg, - "com.github.javaparser" % "javaparser-symbol-solver-core" % "3.26.2", + "com.github.javaparser" % "javaparser-symbol-solver-core" % "3.26.3", "org.scalatest" %% "scalatest" % Versions.scalatest % Test, "org.projectlombok" % "lombok" % "1.18.36", - "org.scala-lang.modules" %% "scala-parallel-collections" % "1.0.4", + "org.scala-lang.modules" %% "scala-parallel-collections" % "1.1.0", "org.scala-lang.modules" %% "scala-parser-combinators" % "2.4.0", "net.lingala.zip4j" % "zip4j" % "2.11.5" ) diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceParser.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceParser.scala index 04c539a9..224fe526 
100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceParser.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceParser.scala @@ -90,8 +90,13 @@ object SourceParser: config: Config, sourcesOverride: Option[List[String]] = None ): Array[String] = - val inputPaths = sourcesOverride.getOrElse(config.inputPath :: Nil).toSet - SourceFiles.determine(inputPaths, JavaSrc2Cpg.sourceFileExtensions, config).toArray + SourceFiles.determine( + config.inputPath, + JavaSrc2Cpg.sourceFileExtensions, + ignoredDefaultRegex = Option(JavaSrc2Cpg.DefaultIgnoredFilesRegex), + ignoredFilesRegex = Option(config.ignoredFilesRegex), + ignoredFilesPath = Option(config.ignoredFiles) + ).toArray /** Implements the logic described in the option description for the "delombok-mode" option: * - no-delombok: do not run delombok. diff --git a/platform/frontends/jimple2cpg/AUTHORS b/platform/frontends/jimple2cpg/AUTHORS deleted file mode 100644 index b6d17d7d..00000000 --- a/platform/frontends/jimple2cpg/AUTHORS +++ /dev/null @@ -1,2 +0,0 @@ -David Baker Effendi -Fabian Yamaguchi diff --git a/platform/frontends/jimple2cpg/build.sbt b/platform/frontends/jimple2cpg/build.sbt index faaf57b6..00d53375 100644 --- a/platform/frontends/jimple2cpg/build.sbt +++ b/platform/frontends/jimple2cpg/build.sbt @@ -4,13 +4,14 @@ dependsOn(Projects.dataflowengineoss, Projects.x2cpg % "compile->compile;test->t libraryDependencies ++= Seq( "io.appthreat" %% "cpg2" % Versions.cpg, - "commons-io" % "commons-io" % "2.17.0", + "commons-io" % "commons-io" % "2.18.0", "org.soot-oss" % "soot" % "4.6.0", - "org.scala-lang.modules" % "scala-asm" % "9.7.0-scala-2", + "org.scala-lang.modules" % "scala-asm" % "9.7.1-scala-1", "org.ow2.asm" % "asm" % "9.7.1", "org.ow2.asm" % "asm-analysis" % "9.7.1", "org.ow2.asm" % "asm-util" % "9.7.1", "org.ow2.asm" % "asm-tree" % "9.7.1", + "io.circe" %% "circe-core" % Versions.circe, "org.scalatest" %% "scalatest" 
% Versions.scalatest % Test ) diff --git a/platform/frontends/jimple2cpg/src/test/resources/project/build.properties b/platform/frontends/jimple2cpg/src/test/resources/project/build.properties index 19479ba4..73df629a 100644 --- a/platform/frontends/jimple2cpg/src/test/resources/project/build.properties +++ b/platform/frontends/jimple2cpg/src/test/resources/project/build.properties @@ -1 +1 @@ -sbt.version=1.5.2 +sbt.version=1.10.7 diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/parser/PhpParser.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/parser/PhpParser.scala index adad75a3..c0a4ff7e 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/parser/PhpParser.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/parser/PhpParser.scala @@ -4,15 +4,12 @@ import better.files.File import io.appthreat.php2atom.Config import io.appthreat.php2atom.parser.Domain.PhpFile import io.appthreat.x2cpg.utils.ExternalCommand -import org.slf4j.LoggerFactory import java.nio.file.Paths import scala.util.{Failure, Success, Try} class PhpParser private (phpParserPath: String, phpIniPath: String): - private val logger = LoggerFactory.getLogger(this.getClass) - private def phpParseCommand(filename: String): String = val phpParserCommands = "--with-recovery --resolve-names -P --json-dump" phpParserPath match @@ -33,7 +30,6 @@ class PhpParser private (phpParserPath: String, phpIniPath: String): processParserOutput(output, inputFilePath) case Failure(exception) => - logger.debug(s"Failure running php-parser with $command", exception.getMessage) None private def processParserOutput(output: Seq[String], filename: String): Option[PhpFile] = @@ -48,17 +44,11 @@ class PhpParser private (phpParserPath: String, phpIniPath: String): case Success(Some(value)) => Some(value) case Success(None) => - logger.debug(s"Parsing json string for $filename resulted in null return value") None case 
Failure(exception) => - logger.debug( - s"Parsing json string for $filename failed with exception", - exception - ) None else - logger.debug(s"No JSON output for $filename") None private def jsonValueToPhpFile(json: ujson.Value, filename: String): Option[PhpFile] = @@ -66,12 +56,10 @@ class PhpParser private (phpParserPath: String, phpIniPath: String): case Success(phpFile) => Some(phpFile) case Failure(e) => - logger.debug(s"Failed to generate intermediate AST for $filename", e) None end PhpParser object PhpParser: - private val logger = LoggerFactory.getLogger(this.getClass()) val PhpParserBinEnvVar = "PHP_PARSER_BIN" diff --git a/platform/frontends/pysrc2cpg/build.sbt b/platform/frontends/pysrc2cpg/build.sbt index fc5000e6..076023bf 100644 --- a/platform/frontends/pysrc2cpg/build.sbt +++ b/platform/frontends/pysrc2cpg/build.sbt @@ -4,7 +4,7 @@ dependsOn(Projects.dataflowengineoss, Projects.x2cpg % "compile->compile;test->t libraryDependencies ++= Seq( "io.appthreat" %% "cpg2" % Versions.cpg, - "org.scala-lang.modules" %% "scala-parallel-collections" % "1.0.4", + "org.scala-lang.modules" %% "scala-parallel-collections" % "1.1.0", "org.scalatest" %% "scalatest" % Versions.scalatest % Test ) diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2CpgOnFileSystem.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2CpgOnFileSystem.scala index a5238cbe..c5666e67 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2CpgOnFileSystem.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2CpgOnFileSystem.scala @@ -4,7 +4,6 @@ import io.appthreat.x2cpg.passes.frontend.TypeRecoveryParserConfig import io.appthreat.x2cpg.{SourceFiles, X2Cpg, X2CpgConfig, X2CpgFrontend} import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.utils.IOUtils -import org.slf4j.LoggerFactory import java.nio.file.* import scala.util.Try @@ -35,48 +34,50 @@ case class 
Py2CpgOnFileSystemConfig( end Py2CpgOnFileSystemConfig class Py2CpgOnFileSystem extends X2CpgFrontend[Py2CpgOnFileSystemConfig]: - private val logger = LoggerFactory.getLogger(getClass) /** Entry point for files system based cpg generation from python code. * @param config * Configuration for cpg generation. */ override def createCpg(config: Py2CpgOnFileSystemConfig): Try[Cpg] = - logConfiguration(config) - - X2Cpg.withNewEmptyCpg(config.outputPath, config) { (cpg, _) => - val venvIgnorePath = - if config.ignoreVenvDir then - config.venvDir :: Nil - else - Nil - val inputPath = Path.of(config.inputPath) - val ignoreDirNamesSet = config.ignoreDirNames.toSet - val absoluteIgnorePaths = (config.ignorePaths ++ venvIgnorePath).map { path => - inputPath.resolve(path) - } - - val inputFiles = SourceFiles - .determine(config.inputPath, Set(".py"), config) - .map(x => Path.of(x)) - .filter { file => filterIgnoreDirNames(file, inputPath, ignoreDirNamesSet) } - .filter { file => - !absoluteIgnorePaths.exists(ignorePath => file.startsWith(ignorePath)) - } - - val inputProviders = inputFiles.map { inputFile => () => - val content = IOUtils.readLinesInFile(inputFile).mkString("\n") - Py2Cpg.InputPair(content, inputPath.relativize(inputFile).toString) + X2Cpg.withNewEmptyCpg(config.outputPath, config) { (cpg, _) => + val venvIgnorePath = + if config.ignoreVenvDir then + config.venvDir :: Nil + else + Nil + val inputPath = Path.of(config.inputPath) + val ignoreDirNamesSet = config.ignoreDirNames.toSet + val absoluteIgnorePaths = (config.ignorePaths ++ venvIgnorePath).map { path => + inputPath.resolve(path) + } + + val inputFiles = SourceFiles + .determine( + config.inputPath, + Set(".py"), + ignoredFilesRegex = Option(config.ignoredFilesRegex), + ignoredFilesPath = Option(config.ignoredFiles) + ) + .map(x => Path.of(x)) + .filter { file => filterIgnoreDirNames(file, inputPath, ignoreDirNamesSet) } + .filter { file => + !absoluteIgnorePaths.exists(ignorePath => 
file.startsWith(ignorePath)) + } + + val inputProviders = inputFiles.map { inputFile => () => + val content = IOUtils.readLinesInFile(inputFile).mkString("\n") + Py2Cpg.InputPair(content, inputPath.relativize(inputFile).toString) + } + val py2Cpg = new Py2Cpg( + inputProviders, + cpg, + config.inputPath, + config.requirementsTxt, + config.schemaValidation + ) + py2Cpg.buildCpg() } - val py2Cpg = new Py2Cpg( - inputProviders, - cpg, - config.inputPath, - config.requirementsTxt, - config.schemaValidation - ) - py2Cpg.buildCpg() - } end createCpg private def filterIgnoreDirNames( @@ -93,12 +94,4 @@ class Py2CpgOnFileSystem extends X2CpgFrontend[Py2CpgOnFileSystemConfig]: val aPartIsInIgnoreSet = parts.exists(part => ignoreDirNamesSet.contains(part.toString)) !aPartIsInIgnoreSet - private def logConfiguration(config: Py2CpgOnFileSystemConfig): Unit = - logger.debug(s"Output file: ${config.outputPath}") - logger.debug(s"Input directory: ${config.inputPath}") - logger.debug(s"Venv directory: ${config.venvDir}") - logger.debug(s"IgnoreVenvDir: ${config.ignoreVenvDir}") - logger.debug(s"IgnorePaths: ${config.ignorePaths.mkString(", ")}") - logger.debug(s"IgnoreDirNames: ${config.ignoreDirNames.mkString(", ")}") - logger.debug(s"No dummy types: ${config.disableDummyTypes}") end Py2CpgOnFileSystem diff --git a/platform/frontends/ruby2atom/.gitignore b/platform/frontends/ruby2atom/.gitignore new file mode 100644 index 00000000..a086a83c --- /dev/null +++ b/platform/frontends/ruby2atom/.gitignore @@ -0,0 +1,5 @@ +# Created by IntelliJ's ANTLR plugin +gen/ +*.tokens +type_stubs +src/main/resources/ruby_ast_gen diff --git a/platform/frontends/ruby2atom/build.sbt b/platform/frontends/ruby2atom/build.sbt new file mode 100644 index 00000000..eb3e589d --- /dev/null +++ b/platform/frontends/ruby2atom/build.sbt @@ -0,0 +1,23 @@ +name := "ruby2atom" + +dependsOn(Projects.dataflowengineoss % "compile->compile;test->test", Projects.x2cpg % "compile->compile;test->test") + 
+libraryDependencies ++= Seq( + "io.appthreat" %% "cpg2" % Versions.cpg, + "com.lihaoyi" %% "upickle" % Versions.upickle, + "org.scalatest" %% "scalatest" % Versions.scalatest % Test +) + +enablePlugins(JavaAppPackaging, LauncherJarPlugin) +Global / onChangedBuildSource := ReloadOnSourceChanges +Universal / packageName := name.value +Universal / topLevelDirectory := None +githubOwner := "appthreat" +githubRepository := "chen" +credentials += + Credentials( + "GitHub Package Registry", + "maven.pkg.github.com", + "appthreat", + sys.env.getOrElse("GITHUB_TOKEN", "N/A") + ) diff --git a/platform/frontends/ruby2atom/src/main/resources/log4j2.xml b/platform/frontends/ruby2atom/src/main/resources/log4j2.xml new file mode 100755 index 00000000..ab3b2d88 --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/resources/log4j2.xml @@ -0,0 +1,13 @@ + + + + + + + + + + + + + diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/Main.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/Main.scala new file mode 100644 index 00000000..6a766d03 --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/Main.scala @@ -0,0 +1,54 @@ +package io.appthreat.ruby2atom + +import io.appthreat.ruby2atom.Frontend.* +import io.appthreat.x2cpg.astgen.AstGenConfig +import io.appthreat.x2cpg.passes.frontend.{ + TypeRecoveryParserConfig, + XTypeRecovery, + XTypeRecoveryConfig +} +import io.appthreat.x2cpg.typestub.TypeStubConfig +import io.appthreat.x2cpg.{X2CpgConfig, X2CpgMain} +import scopt.OParser + +import java.nio.file.Paths + +final case class Config(downloadDependencies: Boolean = false, useTypeStubs: Boolean = true) + extends X2CpgConfig[Config] + with TypeRecoveryParserConfig[Config] + with TypeStubConfig[Config] + with AstGenConfig[Config]: + + override val astGenProgramName: String = "ruby_ast_gen" + override val astGenConfigPrefix: String = "ruby2atom" + override val multiArchitectureBuilds: Boolean = true 
+ + this.defaultIgnoredFilesRegex = + List("spec", "tests?", "vendor", "db(\\\\|/)([\\w_]*)migrate([_\\w]*)").flatMap { + directory => + List( + s"(^|\\\\)$directory($$|\\\\)".r.unanchored, + s"(^|/)$directory($$|/)".r.unanchored + ) + } + + override def withTypeStubs(value: Boolean): Config = + copy(useTypeStubs = value).withInheritedFields(this) +end Config + +private object Frontend: + + implicit val defaultConfig: Config = Config() + + val cmdLineParser: OParser[Unit, Config] = + val builder = OParser.builder[Config] + import builder.* + OParser.sequence( + programName("ruby2atom"), + TypeStubConfig.parserOptions + ) + +object Main extends X2CpgMain(cmdLineParser, new Ruby2Atom()): + + def run(config: Config, rubySrc2Cpg: Ruby2Atom): Unit = + rubySrc2Cpg.run(config) diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/Ruby2Atom.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/Ruby2Atom.scala new file mode 100644 index 00000000..6fa160f0 --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/Ruby2Atom.scala @@ -0,0 +1,75 @@ +package io.appthreat.ruby2atom + +import better.files.File +import io.appthreat.ruby2atom.astcreation.AstCreator +import io.appthreat.ruby2atom.astcreation.RubyIntermediateAst.StatementList +import io.appthreat.ruby2atom.datastructures.RubyProgramSummary +import io.appthreat.ruby2atom.parser.* +import io.appthreat.ruby2atom.passes.{AstCreationPass, ConfigFileCreationPass} +import io.appthreat.x2cpg.X2Cpg.withNewEmptyCpg +import io.appthreat.x2cpg.frontendspecific.ruby2atom.* +import io.appthreat.x2cpg.passes.base.AstLinkerPass +import io.appthreat.x2cpg.passes.callgraph.NaiveCallLinker +import io.appthreat.x2cpg.passes.frontend.{MetaDataPass, TypeNodePass, XTypeRecoveryConfig} +import io.appthreat.x2cpg.utils.{ConcurrentTaskUtil, ExternalCommand} +import io.appthreat.x2cpg.{SourceFiles, X2CpgFrontend} +import io.shiftleft.codepropertygraph.generated.{Cpg, 
Languages} +import io.shiftleft.passes.CpgPassBase +import io.shiftleft.semanticcpg.language.* +import upickle.default.* + +import java.nio.file.{Files, Paths} +import scala.util.matching.Regex +import scala.util.{Failure, Success, Try, Using} + +class Ruby2Atom extends X2CpgFrontend[Config]: + + override def createCpg(config: Config): Try[Cpg] = + withNewEmptyCpg(config.outputPath, config: Config) { (cpg, config) => + new MetaDataPass(cpg, Languages.RUBYSRC, config.inputPath).createAndApply() + new ConfigFileCreationPass(cpg).createAndApply() + createCpgAction(cpg, config) + } + + private def createCpgAction(cpg: Cpg, config: Config): Unit = + File.usingTemporaryDirectory("ruby2atomOut") { tmpDir => + val astGenResult = RubyAstGenRunner(config).execute(tmpDir) + + val astCreators = ConcurrentTaskUtil + .runUsingThreadPool( + Ruby2Atom.processAstGenRunnerResults( + astGenResult.parsedFiles, + config, + cpg.metaData.root.headOption + ) + ) + .flatMap { + case Failure(exception) => None + case Success(astCreator) => Option(astCreator) + } + AstCreationPass(cpg, astCreators).createAndApply() + TypeNodePass.withTypesFromCpg(cpg).createAndApply() + } +end Ruby2Atom + +object Ruby2Atom: + + /** Parses the generated AST Gen files in parallel and produces AstCreators from each. 
+ */ + def processAstGenRunnerResults( + astFiles: List[String], + config: Config, + projectRoot: Option[String] + ): Iterator[() => AstCreator] = + astFiles.map { fileName => () => + val parserResult = RubyJsonParser.readFile(Paths.get(fileName)) + val rubyProgram = new RubyJsonToNodeCreator().visitProgram(parserResult.json) + val sourceFileName = parserResult.fullPath + new AstCreator( + sourceFileName, + projectRoot, + enableFileContents = false, + rootNode = rubyProgram + )(config.schemaValidation) + }.iterator +end Ruby2Atom diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstCreator.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstCreator.scala new file mode 100644 index 00000000..46304bec --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstCreator.scala @@ -0,0 +1,145 @@ +package io.appthreat.ruby2atom.astcreation + +import io.appthreat.ruby2atom.astcreation.RubyIntermediateAst.* +import io.appthreat.ruby2atom.datastructures.{ + BlockScope, + NamespaceScope, + RubyProgramSummary, + RubyScope +} +import io.appthreat.ruby2atom.passes.Defines +import io.appthreat.ruby2atom.utils.FreshNameGenerator +import io.appthreat.x2cpg.utils.NodeBuilders.{newModifierNode, newThisParameterNode} +import io.appthreat.x2cpg.{Ast, AstCreatorBase, AstNodeBuilder, ValidationMode} +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.ModifierTypes +import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal +import overflowdb.BatchedUpdate.DiffGraphBuilder + +import java.util.regex.Matcher + +class AstCreator( + val fileName: String, + protected val projectRoot: Option[String] = None, + protected val programSummary: RubyProgramSummary = RubyProgramSummary(), + val enableFileContents: Boolean = false, + val rootNode: StatementList +)(implicit withSchemaValidation: ValidationMode) + 
+ extends AstCreatorBase(fileName) + with AstCreatorHelper + with AstForStatementsCreator + with AstForExpressionsCreator + with AstForControlStructuresCreator + with AstForFunctionsCreator + with AstForTypesCreator + with AstNodeBuilder[RubyExpression, AstCreator]: + /* Generators for synthetic names. The counter `i` supplied by FreshNameGenerator (defined elsewhere; presumably an increasing index) must appear in the template, otherwise every call to `fresh` hands out the same string. */ + val tmpGen: FreshNameGenerator[String] = FreshNameGenerator(i => s"<tmp-$i>") + val procParamGen: FreshNameGenerator[Left[String, Nothing]] = + FreshNameGenerator(i => Left(s"<proc-param-$i>")) + + /* Used to track variable names and their LOCAL nodes. + */ + protected val scope: RubyScope = new RubyScope(programSummary, projectRoot) + + protected var fileNode: Option[NewFile] = None + + protected var parseLevel: AstParseLevel = AstParseLevel.FULL_AST + + override protected def offset(node: RubyExpression): Option[(Integer, Integer)] = node.offset + /* fileName with the projectRoot prefix and one leading path separator removed; fileName unchanged when no projectRoot is given. */ + protected val relativeFileName: String = + projectRoot + .map(fileName.stripPrefix) + .map(_.stripPrefix(java.io.File.separator)) + .getOrElse(fileName) + + private def internalLineAndColNum: Option[Integer] = Option(1) + + /** The relative file name, in a unix path delimited format. + */ + private def relativeUnixStyleFileName = + relativeFileName.replaceAll(Matcher.quoteReplacement(java.io.File.separator), "/") + /* Builds the AST for rootNode, stores it in diffGraph and returns that builder. */ + override def createAst(): DiffGraphBuilder = + val ast = astForRubyFile(rootNode) + Ast.storeInDiffGraph(ast, diffGraph) + diffGraph + + /* A Ruby file has the following AST hierarchy: FILE -> NAMESPACE_BLOCK -> METHOD. + * The (parsed) contents of the file are put under that fictitious METHOD node, thus + * allowing for a straightforward representation of out-of-method statements. 
+ */ + protected def astForRubyFile(rootStatements: StatementList): Ast = + fileNode = Option(NewFile().name(relativeFileName)) + val fullName = + s"$relativeUnixStyleFileName:${NamespaceTraversal.globalNamespaceName}".stripPrefix("/") + + val namespaceBlock = NewNamespaceBlock() + .filename(relativeFileName) + .name(NamespaceTraversal.globalNamespaceName) + .fullName(fullName) + /* The namespace scope is pushed only for the duration of building the fake method. */ + scope.pushNewScope(NamespaceScope(fullName)) + val rubyFakeMethodAst = astInFakeMethod(rootStatements) + scope.popScope() + /* Resulting hierarchy: FILE -> NAMESPACE_BLOCK -> fake method. */ + Ast(fileNode.get).withChild(Ast(namespaceBlock).withChild(rubyFakeMethodAst)) + /* Wraps the statements of the file in a synthetic method named Defines.Main. */ + private def astInFakeMethod(rootNode: StatementList): Ast = + val name = Defines.Main + // From the Defines.Main method onwards, we do not embed the global namespace name (NamespaceTraversal.globalNamespaceName) in the full names
+ val fullName = + s"${scope.surroundingScopeFullName.head.stripSuffix(NamespaceTraversal.globalNamespaceName)}$name" + val code = rootNode.text + val methodNode_ = methodNode( + node = rootNode, + name = name, + code = code, + fullName = fullName, + signature = None, + fileName = relativeFileName + ) + val thisParameterNode = newThisParameterNode( + name = Defines.Self, + code = Defines.Self, + typeFullName = Defines.Any, + line = methodNode_.lineNumber, + column = methodNode_.columnNumber + ) + val thisParameterAst = Ast(thisParameterNode) + scope.addToScope(Defines.Self, thisParameterNode) + val methodReturn = methodReturnNode(rootNode, Defines.Any) + /* When scope.newProgramScope yields nothing, the getOrElse below returns an empty Ast. */ + scope.newProgramScope + .map { moduleScope => + scope.pushNewScope(moduleScope) + val block = blockNode(rootNode) + scope.pushNewScope(BlockScope(block)) + val statementAsts = rootNode.statements.flatMap(astsForStatement) + scope.popScope() /* leaves the BlockScope */ + val bodyAst = blockAst(block, statementAsts) + scope.popScope() /* leaves moduleScope */ + methodAst( + methodNode_, + thisParameterAst :: Nil, + bodyAst, + methodReturn, + newModifierNode("MODULE") :: newModifierNode(ModifierTypes.VIRTUAL) :: Nil + ) + } + .getOrElse(Ast()) + end astInFakeMethod +end AstCreator + +/** Determines the depth to which the AST creator parses. + */ +enum AstParseLevel: + + /** This level will parse all type and method signatures, but exclude method bodies. + */ + case SIGNATURES + + /** This level will parse the full AST. 
+ */ + case FULL_AST diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstCreatorHelper.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstCreatorHelper.scala new file mode 100644 index 00000000..a629beef --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstCreatorHelper.scala @@ -0,0 +1,262 @@ +package io.appthreat.ruby2atom.astcreation +import io.appthreat.ruby2atom.astcreation.RubyIntermediateAst.{ + ClassFieldIdentifier, + ControlFlowStatement, + DummyNode, + IfExpression, + InstanceFieldIdentifier, + MemberAccess, + RubyExpression, + RubyFieldIdentifier, + SingleAssignment, + StatementList, + TextSpan, + UnaryExpression +} +import io.appthreat.ruby2atom.datastructures.{BlockScope, FieldDecl} +import io.appthreat.ruby2atom.passes.Defines +import io.appthreat.ruby2atom.passes.GlobalTypes +import io.appthreat.ruby2atom.passes.GlobalTypes.{kernelFunctions, kernelPrefix} +import io.appthreat.x2cpg.{Ast, ValidationMode} +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, Operators} + +import scala.collection.mutable + +trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + private val usedFullNames = mutable.Set.empty[String] + + /** Ensures a unique full name is assigned based on the current scope. + * @param name + * the name of the entity. + * @param counter + * an optional counter, used to create unique instances in the case of redefinitions. + * @param useSurroundingTypeFullName + * flag for whether the fullName is for accessor-like method lowering + * @return + * a unique full name. 
+ */ + protected def computeFullName( + name: String, + counter: Option[Int] = None, + useSurroundingTypeFullName: Boolean = false + ): String = + val surroundingName = + if useSurroundingTypeFullName then scope.surroundingTypeFullName.head + else scope.surroundingScopeFullName.head + val candidate = counter match + case Some(cnt) => s"$surroundingName.$name$cnt" + case None => s"$surroundingName.$name" + if usedFullNames.contains(candidate) then + computeFullName(name, counter.map(_ + 1).orElse(Option(0)), useSurroundingTypeFullName) + else + usedFullNames.add(candidate) + candidate + + override def column(node: RubyExpression): Option[Integer] = node.column + override def columnEnd(node: RubyExpression): Option[Integer] = node.columnEnd + override def line(node: RubyExpression): Option[Integer] = node.line + override def lineEnd(node: RubyExpression): Option[Integer] = node.lineEnd + + override def code(node: RubyExpression): String = shortenCode(node.text) + + protected def isBuiltin(x: String): Boolean = kernelFunctions.contains(x) + protected def prefixAsKernelDefined(x: String): String = s"$kernelPrefix$pathSep$x" + protected def prefixAsBundledType(x: String): String = s"${GlobalTypes.builtinPrefix}.$x" + protected def isBundledClass(x: String): Boolean = GlobalTypes.bundledClasses.contains(x) + protected def pathSep = "." 
+ + private def astForFieldInstance(name: String, node: RubyExpression & RubyFieldIdentifier): Ast = + val identName = node match + case _: InstanceFieldIdentifier => Defines.Self + case _: ClassFieldIdentifier => + scope.surroundingTypeFullName.map(_.split("[.]").last).getOrElse(Defines.Any) + + astForFieldAccess( + MemberAccess( + DummyNode(identifierNode(node, identName, identName, Defines.Any))( + node.span.spanStart(identName) + ), + ".", + name + )(node.span) + ) + + protected def handleVariableOccurrence(node: RubyExpression): Ast = + val name = code(node) + val identifier = identifierNode(node, name, name, Defines.Any) + val typeRef = scope.tryResolveTypeReference(name) + + node match + case fieldVariable: RubyFieldIdentifier => + scope.findFieldInScope(name) match + case None => + scope.pushField(FieldDecl(name, Defines.Any, false, false, fieldVariable)) + astForFieldInstance(name, fieldVariable) + case Some(field) => + astForFieldInstance(name, field.node) + case _ => + scope.lookupVariable(name) match + case None if typeRef.isDefined => + Ast(identifier.typeFullName(typeRef.get.name)) + case None => + val local = localNode(node, name, name, Defines.Any) + scope.addToScope(name, local) match + case BlockScope(block) => diffGraph.addEdge(block, local, EdgeTypes.AST) + case _ => + Ast(identifier).withRefEdge(identifier, local) + case Some(local) => + local match + case x: NewLocal => identifier.dynamicTypeHintFullName(x.dynamicTypeHintFullName) + case x: NewMethodParameterIn => + identifier.dynamicTypeHintFullName(x.dynamicTypeHintFullName) + Ast(identifier).withRefEdge(identifier, local) + end match + end handleVariableOccurrence + + protected def astForAssignment( + lhs: NewNode, + rhs: NewNode, + lineNumber: Option[Integer], + columnNumber: Option[Integer] + ): Ast = + astForAssignment(Ast(lhs), Ast(rhs), lineNumber, columnNumber) + + protected def astForAssignment( + lhs: Ast, + rhs: Ast, + lineNumber: Option[Integer], + columnNumber: Option[Integer], + 
code: Option[String] = None + ): Ast = + val _code = + code.getOrElse( + Seq(lhs, rhs).flatMap(_.root).collect { case x: ExpressionNew => x.code }.mkString(" = ") + ) + val assignment = NewCall() + .name(Operators.assignment) + .methodFullName(Operators.assignment) + .code(_code) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .lineNumber(lineNumber) + .columnNumber(columnNumber) + + callAst(assignment, Seq(lhs, rhs)) + end astForAssignment + + protected def memberForMethod( + method: NewMethod, + astParentType: Option[String] = None, + astParentFullName: Option[String] = None + ): NewMember = + NewMember().name(method.name).code(method.name).dynamicTypeHintFullName(Seq(method.fullName)) + + /** Lowers the `||=` and `&&=` assignment operators to the respective `.nil?` checks + */ + def lowerAssignmentOperator( + lhs: RubyExpression, + rhs: RubyExpression, + op: String, + span: TextSpan + ): RubyExpression & + ControlFlowStatement = + val condition = nilCheckCondition(lhs, op, "nil?", span) + val thenClause = nilCheckThenClause(lhs, rhs, span) + nilCheckIfStatement(condition, thenClause, span) + + /** Generates the required `.nil?` check condition used in the lowering of `||=` and `&&=` + */ + private def nilCheckCondition( + lhs: RubyExpression, + op: String, + memberName: String, + span: TextSpan + ): RubyExpression = + val memberAccess = + MemberAccess(lhs, op = ".", memberName = "nil?")(span.spanStart(s"${lhs.span.text}.nil?")) + if op == "||=" then memberAccess + else + UnaryExpression(op = "!", expression = memberAccess)( + span.spanStart(s"!${memberAccess.span.text}") + ) + + /** Generates the assignment and the `thenClause` used in the lowering of `||=` and `&&=` + */ + private def nilCheckThenClause( + lhs: RubyExpression, + rhs: RubyExpression, + span: TextSpan + ): RubyExpression = + StatementList(List( + SingleAssignment(lhs, "=", rhs)(span.spanStart(s"${lhs.span.text} = ${rhs.span.text}")) + ))( + span.spanStart(s"${lhs.span.text} = ${rhs.span.text}") 
+ ) + + /** Generates the if statement for the lowering of `||=` and `&&=` + */ + private def nilCheckIfStatement( + condition: RubyExpression, + thenClause: RubyExpression, + span: TextSpan + ): RubyExpression & ControlFlowStatement = + IfExpression( + condition = condition, + thenClause = thenClause, + elsifClauses = List.empty, + elseClause = None + )( + span.spanStart(s"if ${condition.span.text} then ${thenClause.span.text} end") + ) + + protected val UnaryOperatorNames: Map[String, String] = Map( + "!" -> Operators.logicalNot, + "not" -> Operators.logicalNot, + "~" -> Operators.not, + "+" -> Operators.plus, + "-" -> Operators.minus + ) + + protected val BinaryOperatorNames: Map[String, String] = + Map( + "+" -> Operators.addition, + "-" -> Operators.subtraction, + "*" -> Operators.multiplication, + "/" -> Operators.division, + "%" -> Operators.modulo, + "**" -> Operators.exponentiation, + "==" -> Operators.equals, + "===" -> Operators.equals, + "!=" -> Operators.notEquals, + "<" -> Operators.lessThan, + "<=" -> Operators.lessEqualsThan, + ">" -> Operators.greaterThan, + ">=" -> Operators.greaterEqualsThan, + "<=>" -> Operators.compare, + "&&" -> Operators.logicalAnd, + "and" -> Operators.logicalAnd, + "or" -> Operators.logicalOr, + "||" -> Operators.logicalOr, + "&" -> Operators.and, + "|" -> Operators.or, + "^" -> Operators.xor, +// "<<" -> Operators.shiftLeft, Note: Generally Ruby abstracts this as an append operator based on the LHS + ">>" -> Operators.arithmeticShiftRight + ) + + protected val AssignmentOperatorNames: Map[String, String] = Map( + "=" -> Operators.assignment, + "+=" -> Operators.assignmentPlus, + "-=" -> Operators.assignmentMinus, + "*=" -> Operators.assignmentMultiplication, + "/=" -> Operators.assignmentDivision, + "%=" -> Operators.assignmentModulo, + "**=" -> Operators.assignmentExponentiation, + "|=" -> Operators.assignmentOr, + "&=" -> Operators.assignmentAnd, + "<<=" -> Operators.assignmentShiftLeft, + ">>=" -> 
Operators.assignmentArithmeticShiftRight + ) +end AstCreatorHelper diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForControlStructuresCreator.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForControlStructuresCreator.scala new file mode 100644 index 00000000..4c5a60f3 --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForControlStructuresCreator.scala @@ -0,0 +1,365 @@ +package io.appthreat.ruby2atom.astcreation + +import io.appthreat.ruby2atom.astcreation.RubyIntermediateAst.{ + ArrayPattern, + BinaryExpression, + BreakExpression, + CaseExpression, + ControlFlowStatement, + DoWhileExpression, + ElseClause, + ForExpression, + IfExpression, + InClause, + MatchVariable, + MemberCall, + NextExpression, + OperatorAssignment, + RescueExpression, + ReturnExpression, + RubyExpression, + SimpleCall, + SimpleIdentifier, + SingleAssignment, + SplattingRubyNode, + StatementList, + UnaryExpression, + Unknown, + UnlessExpression, + UntilExpression, + WhenClause, + WhileExpression +} +import io.appthreat.ruby2atom.parser.RubyJsonHelpers +import io.appthreat.ruby2atom.passes.Defines +import io.appthreat.ruby2atom.passes.Defines.RubyOperators +import io.appthreat.x2cpg.{Ast, ValidationMode} +import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators} +import io.shiftleft.codepropertygraph.generated.nodes.{ + NewBlock, + NewFieldIdentifier, + NewIdentifier, + NewLiteral, + NewLocal +} + +trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + protected def astForControlStructureExpression(node: ControlFlowStatement): Ast = node match + case node: WhileExpression => astForWhileStatement(node) + case node: DoWhileExpression => astForDoWhileStatement(node) + case node: UntilExpression => astForUntilStatement(node) + case node: CaseExpression => 
blockAst(NewBlock(), astsForCaseExpression(node).toList) + case node: IfExpression => astForIfExpression(node) + case node: UnlessExpression => astForUnlessStatement(node) + case node: ForExpression => astForForExpression(node) + case node: RescueExpression => astForRescueExpression(node) + case node: NextExpression => astForNextExpression(node) + case node: BreakExpression => astForBreakExpression(node) + case node: OperatorAssignment => astForOperatorAssignmentExpression(node) + + private def astForWhileStatement(node: WhileExpression): Ast = + val conditionAst = astForExpression(node.condition) + val bodyAsts = astsForStatement(node.body) + whileAst(Some(conditionAst), bodyAsts, Option(code(node)), line(node), column(node)) + + private def astForDoWhileStatement(node: DoWhileExpression): Ast = + val conditionAst = astForExpression(node.condition) + val bodyAsts = astsForStatement(node.body) + doWhileAst(Some(conditionAst), bodyAsts, Option(code(node)), line(node), column(node)) + + // `until T do B` is lowered as `while !T do B` + private def astForUntilStatement(node: UntilExpression): Ast = + val notCondition = astForExpression(UnaryExpression("!", node.condition)(node.condition.span)) + val bodyAsts = astsForStatement(node.body) + whileAst(Some(notCondition), bodyAsts, Option(code(node)), line(node), column(node)) + + // Recursively lowers into a ternary conditional call + private def astForIfExpression(node: IfExpression): Ast = + def builder(node: IfExpression, conditionAst: Ast, thenAst: Ast, elseAsts: List[Ast]): Ast = + // We want to make sure there's always an «else» clause in a ternary operator. + // The default value is a `nil` literal. 
+ val elseAsts_ = if elseAsts.isEmpty then + List(astForNilBlock) + else + elseAsts + + val call = callNode( + node, + code(node), + Operators.conditional, + Operators.conditional, + DispatchTypes.STATIC_DISPATCH + ) + callAst(call, conditionAst :: thenAst :: elseAsts_) + + // TODO: Remove or modify the builder pattern when we are no longer using ANTLR + node.elseClause match + case Some(elseClause) => + elseClause match + case _: IfExpression => astForJsonIfStatement(node) + case _ => foldIfExpression(builder)(node) + case None => + foldIfExpression(builder)(node) + end astForIfExpression + + private def astForJsonIfStatement(node: IfExpression): Ast = + val conditionAst = astForExpression(node.condition) + val thenAst = astForThenClause(node.thenClause) + val elseAsts = node.elseClause + .map { + case x: IfExpression => + val wrappedBlock = blockNode(x) + Ast(wrappedBlock).withChild(astForJsonIfStatement(x)) + case x => + astForElseClause(x) + } + .getOrElse(Ast()) + + val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code(node)) + controlStructureAst(ifNode, Some(conditionAst), thenAst :: elseAsts :: Nil) + + // `unless T do B` is lowered as `if !T then B` + private def astForUnlessStatement(node: UnlessExpression): Ast = + val notConditionAst = + astForExpression(UnaryExpression("!", node.condition)(node.condition.span)) + val thenAst = node.trueBranch match + case stmtList: StatementList => astForStatementList(stmtList) + case _ => astForStatementList(StatementList(List(node.trueBranch))(node.trueBranch.span)) + val elseAsts = node.falseBranch.map(astForElseClause).toList + val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code(node)) + controlStructureAst(ifNode, Some(notConditionAst), thenAst :: elseAsts) + + protected def astForElseClause(node: RubyExpression): Ast = + node match + case elseNode: ElseClause => + elseNode.thenClause match + case stmtList: StatementList => astForStatementList(stmtList) + case node => + 
astForUnknown(node) + case elseNode => + astForUnknown(elseNode) + + private def astForForExpression(node: ForExpression): Ast = + val forEachNode = controlStructureNode(node, ControlStructureTypes.FOR, code(node)) + + def collectionAst = astForExpression(node.iterableVariable) + val collectionNode = node.iterableVariable + + val iterIdentifier = + identifierNode( + node = node.forVariable, + name = node.forVariable.span.text, + code = node.forVariable.span.text, + typeFullName = Defines.Any + ) + val iterVarLocal = NewLocal().name(node.forVariable.span.text).code(node.forVariable.span.text) + scope.addToScope(node.forVariable.span.text, iterVarLocal) + + val idxName = "_idx_" + val idxLocal = + NewLocal().name(idxName).code(idxName).typeFullName(Defines.getBuiltInType(Defines.Integer)) + val idxIdenAtAssign = identifierNode( + node = collectionNode, + name = idxName, + code = idxName, + typeFullName = Defines.getBuiltInType(Defines.Integer) + ) + + val idxAssignment = + callNode( + node, + s"$idxName = 0", + Operators.assignment, + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + val idxAssignmentArgs = + List( + Ast(idxIdenAtAssign), + Ast(NewLiteral().code("0").typeFullName(Defines.getBuiltInType(Defines.Integer))) + ) + val idxAssignmentAst = callAst(idxAssignment, idxAssignmentArgs) + + val idxIdAtCond = idxIdenAtAssign.copy + val collectionCountAccess = callNode( + node, + s"${node.iterableVariable.span.text}.length", + Operators.fieldAccess, + Operators.fieldAccess, + DispatchTypes.STATIC_DISPATCH + ) + val fieldAccessAst = callAst( + collectionCountAccess, + collectionAst :: Ast(NewFieldIdentifier().canonicalName("length").code("length")) :: Nil + ) + + val idxLt = callNode( + node, + s"$idxName < ${node.iterableVariable.span.text}.length", + Operators.lessThan, + Operators.lessThan, + DispatchTypes.STATIC_DISPATCH + ) + val idxLtArgs = List(Ast(idxIdAtCond), fieldAccessAst) + val ltCallCond = callAst(idxLt, idxLtArgs) + + val idxIdAtCollAccess = 
idxIdenAtAssign.copy + val collectionIdxAccess = callNode( + node, + s"${node.iterableVariable.span.text}[$idxName++]", + Operators.indexAccess, + Operators.indexAccess, + DispatchTypes.STATIC_DISPATCH + ) + val postIncrAst = callAst( + callNode( + node, + s"$idxName++", + Operators.postIncrement, + Operators.postIncrement, + DispatchTypes.STATIC_DISPATCH + ), + Ast(idxIdAtCollAccess) :: Nil + ) + + val indexAccessAst = callAst(collectionIdxAccess, collectionAst :: postIncrAst :: Nil) + val iteratorAssignmentNode = callNode( + node, + s"${node.forVariable.span.text} = ${node.iterableVariable.span.text}[$idxName++]", + Operators.assignment, + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + val iteratorAssignmentArgs = List(Ast(iterIdentifier), indexAccessAst) + val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) + val doBodyAst = astsForStatement(node.doBlock) + + val locals = Ast(idxLocal) + .withRefEdge(idxIdenAtAssign, idxLocal) + .withRefEdge(idxIdAtCond, idxLocal) + .withRefEdge(idxIdAtCollAccess, idxLocal) :: Ast(iterVarLocal).withRefEdge( + iterIdentifier, + iterVarLocal + ) :: Nil + + val conditionAsts = ltCallCond :: Nil + val initAsts = idxAssignmentAst :: Nil + val updateAsts = iteratorAssignmentAst :: Nil + + forAst( + forNode = forEachNode, + locals = locals, + initAsts = initAsts, + conditionAsts = conditionAsts, + updateAsts = updateAsts, + bodyAsts = doBodyAst + ) + end astForForExpression + + protected def astsForCaseExpression(node: CaseExpression): Seq[Ast] = + // TODO: Clean up the below + def goCase(expr: Option[SimpleIdentifier]): List[RubyExpression] = + val elseThenClause: Option[RubyExpression] = + node.elseClause.map(_.asInstanceOf[ElseClause].thenClause) + val whenClauses = node.matchClauses.collect { case x: WhenClause => x } + val inClauses = node.matchClauses.collect { case x: InClause => x } + + val ifElseChain = if whenClauses.nonEmpty then + 
whenClauses.foldRight[Option[RubyExpression]](elseThenClause) { + (whenClause: WhenClause, restClause: Option[RubyExpression]) => + // We translate multiple match expressions into an or expression. + // + // A single match expression is compared using `.===` to the case target expression if it is present + // otherwise it is treated as a conditional. + // + // There may be a splat as the last match expression, + // `case y when *x then c end` or + // `case when *x then c end` + // which is translated to `x.include? y` and `x.any?` conditions respectively + + val conditions = whenClause.matchExpressions.map { mExpr => + expr.map(e => BinaryExpression(mExpr, "===", e)(mExpr.span)).getOrElse(mExpr) + } ++ whenClause.matchSplatExpression.iterator.flatMap { + case splat @ SplattingRubyNode(exprList) => + expr + .map { e => + List(MemberCall(exprList, ".", "include?", List(e))(splat.span)) + } + .getOrElse { + List(MemberCall(exprList, ".", "any?", List())(splat.span)) + } + case e => + List(Unknown()(e.span)) + } + // There is always at least one match expression or a splat + // will become an unknown in condition at the end + val condition = conditions.init.foldRight(conditions.last) { (cond, condAcc) => + BinaryExpression(cond, "||", condAcc)(whenClause.span) + } + val conditional = IfExpression( + condition, + whenClause.thenClause.asStatementList, + List(), + restClause.map { els => ElseClause(els.asStatementList)(els.span) } + )(node.span) + Some(conditional) + } + else + inClauses.foldRight[Option[RubyExpression]](elseThenClause) { + (inClause: InClause, restClause: Option[RubyExpression]) => + val (condition, body) = inClause.pattern match + case x: ArrayPattern => + val condition = expr.map(e => BinaryExpression(x, "===", e)(x.span)).getOrElse( + inClause.pattern + ) + val body = inClause.body + + val variables = x.children.collect { case x: MatchVariable => + x + } + + val conditionBody = if variables.nonEmpty then + StatementList(variables.map { x => + val lhs 
= SimpleIdentifier()(x.span) + SingleAssignment(lhs, "=", x)( + inClause.span + .spanStart( + s"${lhs.span.text} = ${RubyOperators.arrayPatternMatch}(${lhs.span.text})" + ) + ) + } :+ body)(body.span) + else + body + + (condition, conditionBody) + case x => (x, inClause.body) + + val conditional = IfExpression( + condition, + body, + List.empty, + restClause.map { els => ElseClause(els.asStatementList)(els.span) } + )(node.span) + Some(conditional) + } + ifElseChain.iterator.toList + end goCase + + def generatedNode: StatementList = node.expression + .map { e => + val tmp = SimpleIdentifier(None)(e.span.spanStart(this.tmpGen.fresh)) + StatementList( + List(SingleAssignment(tmp, "=", e)(e.span)) ++ + goCase(Some(tmp)) + )(node.span) + } + .getOrElse(StatementList(goCase(None))(node.span)) + astsForStatement(generatedNode) + end astsForCaseExpression + + private def astForOperatorAssignmentExpression(node: OperatorAssignment): Ast = + val loweredAssignment = lowerAssignmentOperator(node.lhs, node.rhs, node.op, node.span) + astForControlStructureExpression(loweredAssignment) +end AstForControlStructuresCreator diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForExpressionsCreator.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForExpressionsCreator.scala new file mode 100644 index 00000000..d62ba463 --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForExpressionsCreator.scala @@ -0,0 +1,1137 @@ +package io.appthreat.ruby2atom.astcreation + +import io.appthreat.ruby2atom.astcreation.RubyIntermediateAst.{Unknown, Block as RubyBlock, *} +import io.appthreat.ruby2atom.datastructures.BlockScope +import io.appthreat.ruby2atom.parser.RubyJsonHelpers +import io.appthreat.ruby2atom.passes.Defines +import io.appthreat.ruby2atom.passes.GlobalTypes +import io.appthreat.ruby2atom.passes.Defines.{RubyOperators, getBuiltInType} +import 
io.appthreat.x2cpg.{Ast, ValidationMode, Defines as XDefines} +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.{ + ControlStructureTypes, + DispatchTypes, + EdgeTypes, + NodeTypes, + Operators, + PropertyNames +} + +import scala.collection.mutable + +trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + /** For tracking aliased calls that occur on the LHS of a member access or call. + */ + protected val baseAstCache = mutable.Map.empty[RubyExpression, String] + + protected def astForExpression(node: RubyExpression): Ast = node match + case node: ControlFlowStatement => astForControlStructureExpression(node) + case node: StaticLiteral => astForStaticLiteral(node) + case node: HereDocNode => astForHereDoc(node) + case node: DynamicLiteral => astForDynamicLiteral(node) + case node: UnaryExpression => astForUnary(node) + case node: BinaryExpression => astForBinary(node) + case node: MemberAccess => astForMemberAccess(node) + case node: MemberCall => astForMemberCall(node) + case node: ObjectInstantiation => astForObjectInstantiation(node) + case node: IndexAccess => astForIndexAccess(node) + case node: SingleAssignment => astForSingleAssignment(node) + case node: AttributeAssignment => astForAttributeAssignment(node) + case node: TypeIdentifier => astForTypeIdentifier(node) + case node: RubyIdentifier => astForSimpleIdentifier(node) + case node: SimpleCall => astForSimpleCall(node) + case node: RequireCall => astForRequireCall(node) + case node: IncludeCall => astForIncludeCall(node) + case node: RaiseCall => astForRaiseCall(node) + case node: YieldExpr => astForYield(node) + case node: RangeExpression => astForRange(node) + case node: ArrayLiteral => astForArrayLiteral(node) + case node: HashLike => astForHashLiteral(node) + case node: Association => astForAssociation(node) + case node: MandatoryParameter => astForMandatoryParameter(node) + case node: 
SplattingRubyNode => astForSplattingRubyNode(node) + case node: AnonymousTypeDeclaration => astForAnonymousTypeDeclaration(node) + case node: ProcOrLambdaExpr => astForProcOrLambdaExpr(node) + case node: SingletonObjectMethodDeclaration => astForSingletonObjectMethodDeclaration(node) + case node: RubyCallWithBlock[?] => astForCallWithBlock(node) + case node: SelfIdentifier => astForSelfIdentifier(node) + case node: StatementList => astForStatementList(node) + case node: MultipleAssignment => blockAst(blockNode(node), astsForStatement(node).toList) + case node: ReturnExpression => astForReturnExpression(node) + case node: AccessModifier => astForSimpleIdentifier(node.toSimpleIdentifier) + case node: ArrayPattern => astForArrayPattern(node) + case node: MatchVariable => astForMatchVariable(node) + case node: DummyNode => Ast(node.node) + case node: Unknown => astForUnknown(node) + case x => + astForUnknown(node) + + protected def astForStaticLiteral(node: StaticLiteral): Ast = + Ast(literalNode(node, code(node), node.typeFullName)) + + protected def astForHereDoc(node: HereDocNode): Ast = + Ast(literalNode(node, code(node), getBuiltInType("String"))) + + // Helper for nil literals to put in empty clauses + protected def astForNilLiteral: Ast = + Ast(NewLiteral().code("nil").typeFullName(getBuiltInType(Defines.NilClass))) + + protected def astForNilBlock: Ast = blockAst(NewBlock(), List(astForNilLiteral)) + + protected def astForDynamicLiteral(node: DynamicLiteral): Ast = + val fmtValueAsts = node.expressions.map { + case stmtList: StatementList if stmtList.size == 1 => + val expressionAst = astForExpression(stmtList.statements.head) + val call = callNode( + node = stmtList, + code = stmtList.text, + name = Operators.formattedValue, + methodFullName = Operators.formattedValue, + dispatchType = DispatchTypes.STATIC_DISPATCH, + signature = None, + typeFullName = Some(node.typeFullName) + ) + callAst(call, Seq(expressionAst)) + case stmtList: StatementList if 
stmtList.size > 1 => + astForUnknown(stmtList) + case node => + val call = callNode( + node = node, + code = node.text, + name = Operators.formattedValue, + methodFullName = Operators.formattedValue, + dispatchType = DispatchTypes.STATIC_DISPATCH, + signature = None, + typeFullName = Option(Defines.Any) + ) + callAst(call, Seq(astForExpression(node))) + } + callAst( + callNode( + node = node, + code = code(node), + name = Operators.formatString, + methodFullName = Operators.formatString, + dispatchType = DispatchTypes.STATIC_DISPATCH, + signature = None, + typeFullName = Some(node.typeFullName) + ), + fmtValueAsts + ) + end astForDynamicLiteral + + protected def astForUnary(node: UnaryExpression): Ast = + getUnaryOperatorName(node.op) match + case None => + astForUnknown(node) + case Some(op) => + val expressionAst = astForExpression(node.expression) + val call = callNode(node, code(node), op, op, DispatchTypes.STATIC_DISPATCH) + callAst(call, Seq(expressionAst)) + + protected def astForBinary(node: BinaryExpression): Ast = + getBinaryOperatorName(node.op) match + case None => + astForMemberCall(MemberCall(node.lhs, ".", node.op, List(node.rhs))(node.span)) + case Some(op) => + val lhsAst = astForExpression(node.lhs) + val rhsAst = astForExpression(node.rhs) + val call = callNode(node, code(node), op, op, DispatchTypes.STATIC_DISPATCH) + callAst(call, Seq(lhsAst, rhsAst)) + + // Member accesses are checked in RubyNodeCreator, i.e. `x.y` is the call of `y` of `x` without any arguments. + // where x.Y is considered a constant access as Y is capitalized. 
+ protected def astForMemberAccess(node: MemberAccess): Ast = + node.target match + case x: SimpleIdentifier => + val newTarget = scope.getSurroundingType(x.text).map(_.fullName) match + case Some(surroundingType) => + val typeName = surroundingType.split('.').last + TypeIdentifier(s"$surroundingType")(x.span.spanStart(typeName)) + case None => x + astForFieldAccess(node.copy(target = newTarget)(node.span)) + case _ => astForFieldAccess(node) + + /** Attempts to extract a type from the base of a member call. + */ + protected def typeFromCallTarget(baseNode: RubyExpression): Option[String] = + baseNode match + case literal: LiteralExpr => Option(literal.typeFullName) + case _ => + scope.lookupVariable(baseNode.text) match + // fixme: This should be under type recovery logic + case Some(decl: NewLocal) if decl.typeFullName != Defines.Any => + Option(decl.typeFullName) + case Some(decl: NewMethodParameterIn) if decl.typeFullName != Defines.Any => + Option(decl.typeFullName) + case Some(decl: NewLocal) if decl.dynamicTypeHintFullName.nonEmpty => + decl.dynamicTypeHintFullName.headOption + case Some(decl: NewMethodParameterIn) if decl.dynamicTypeHintFullName.nonEmpty => + decl.dynamicTypeHintFullName.headOption + case _ => None + + private def astForTypeIdentifier(node: TypeIdentifier): Ast = + Ast(typeRefNode(node, code(node), node.typeFullName)) + + protected def astForMemberCall(node: MemberCall, isStatic: Boolean = false): Ast = + + def createMemberCall(n: MemberCall): Ast = + val receiverAst = astForFieldAccess( + MemberAccess(n.target, ".", n.methodName)(n.span), + stripLeadingAt = true + ) + val (baseAst, baseCode) = astForMemberAccessTarget(n.target) + val builtinType = n.target match + case MemberAccess(_: SelfIdentifier, _, memberName) if isBundledClass(memberName) => + Option(prefixAsBundledType(memberName)) + case x: TypeIdentifier if x.isBuiltin => Option(x.typeFullName) + case _ => None + val methodFullName = receiverAst.nodes + .collectFirst { + case _ if 
builtinType.isDefined => s"${builtinType.get}.${n.methodName}" + case x: NewMethodRef => x.methodFullName + case _ => + (n.target match + case ma: MemberAccess => + scope.tryResolveTypeReference(ma.memberName).map(_.name) + case _ => typeFromCallTarget(n.target) + ).map(x => s"$x.${n.methodName}") + .getOrElse(XDefines.DynamicCallUnknownFullName) + } + .getOrElse(XDefines.DynamicCallUnknownFullName) + val argumentAsts = n.arguments.map(astForMethodCallArgument) + val dispatchType = + if isStatic then DispatchTypes.STATIC_DISPATCH else DispatchTypes.DYNAMIC_DISPATCH + + val callCode = if baseCode.contains(" target + case x: SimpleIdentifier => + scope.getSurroundingType(x.text).map(_.fullName) match + case Some(surroundingType) => + val typeName = surroundingType.split('.').last + TypeIdentifier(s"$surroundingType")(x.span.spanStart(typeName)) + case None if scope.lookupVariable(x.text).isDefined => x + case None if x.text.charAt(0).isUpper => // calls have lower-case first character + MemberAccess(SelfIdentifier()(x.span.spanStart(Defines.Self)), ".", x.text)(x.span) + case None => MemberCall( + SelfIdentifier()(x.span.spanStart(Defines.Self)), + ".", + x.text, + Nil + )(x.span) + case x @ MemberAccess(ma, _, _) => x.copy(target = determineMemberAccessBase(ma))(x.span) + case _ => target + + node.target match + case _: LiteralExpr => + createMemberCall(node) + case x: SimpleIdentifier if isBundledClass(x.text) => + createMemberCall( + node.copy(target = TypeIdentifier(prefixAsBundledType(x.text))(x.span))(node.span) + ) + case x: SimpleIdentifier => + createMemberCall(node.copy(target = determineMemberAccessBase(x))(node.span)) + case memAccess: MemberAccess => + createMemberCall(node.copy(target = determineMemberAccessBase(memAccess))(node.span)) + case _ => createMemberCall(node) + end astForMemberCall + + protected def astForFieldAccess(node: MemberAccess, stripLeadingAt: Boolean = false): Ast = + val (memberName, memberCode) = node.target match + case _ if 
node.memberName == Defines.Initialize => Defines.Initialize -> Defines.Initialize + case _ if stripLeadingAt => node.memberName -> node.memberName.stripPrefix("@") + case _: TypeIdentifier => node.memberName -> node.memberName + case _ if !node.memberName.startsWith("@") && node.memberName.headOption.exists(_.isLower) => + s"@${node.memberName}" -> node.memberName + case _ => node.memberName -> node.memberName + + val fieldIdentifierAst = Ast(fieldIdentifierNode(node, memberName, memberCode)) + val (targetAst, _code) = astForMemberAccessTarget(node.target) + val code = s"$_code${node.op}$memberCode" + val memberType = typeFromCallTarget(node.target) + .flatMap(scope.tryResolveTypeReference) + .map(_.fields) + .getOrElse(List.empty) + .collectFirst { + case x if x.name == memberName => + scope.tryResolveTypeReference(x.typeName).map(_.name).getOrElse(Defines.Any) + } + .orElse(Option(Defines.Any)) + val fieldAccess = callNode( + node, + code, + Operators.fieldAccess, + Operators.fieldAccess, + DispatchTypes.STATIC_DISPATCH, + signature = None, + typeFullName = Option(Defines.Any) + ).possibleTypes(IndexedSeq(memberType.get)) + callAst(fieldAccess, Seq(targetAst, fieldIdentifierAst)) + end astForFieldAccess + + private def astForMemberAccessTarget(target: RubyExpression): (Ast, String) = + target match + case simpleLhs: (LiteralExpr | SimpleIdentifier | SelfIdentifier | TypeIdentifier) => + astForExpression(simpleLhs) -> code(target) + case target: MemberAccess => + handleTmpGen(target, astForFieldAccess(target, stripLeadingAt = true)) + case target => handleTmpGen(target, astForExpression(target)) + + private def handleTmpGen(target: RubyExpression, rhs: Ast): (Ast, String) = + // Check cache + val createAssignmentToTmp = !baseAstCache.contains(target) + val tmpName = baseAstCache + .updateWith(target) { + case Some(tmpName) => + // TODO: Type ref nodes are automatically committed on creation, so if we have found a suitable cached AST, + // we want to clean this 
creation up. + Option(tmpName) + case None => + val tmpName = this.tmpGen.fresh + val tmpGenLocal = NewLocal().name(tmpName).code(tmpName).typeFullName(Defines.Any) + scope.addToScope(tmpName, tmpGenLocal) match + case BlockScope(block) => diffGraph.addEdge(block, tmpGenLocal, EdgeTypes.AST) + case _ => + Option(tmpName) + } + .get + val tmpIden = NewIdentifier().name(tmpName).code(tmpName).typeFullName(Defines.Any) + val tmpIdenAst = + scope.lookupVariable(tmpName).map(x => Ast(tmpIden).withRefEdge(tmpIden, x)).getOrElse(Ast( + tmpIden + )) + val code = s"$tmpName = ${target.text}" + if createAssignmentToTmp then + astForAssignment(tmpIdenAst, rhs, target.line, target.column, Option(code)) -> s"($code)" + else + tmpIdenAst -> s"($code)" + end handleTmpGen + + protected def astForIndexAccess(node: IndexAccess): Ast = + // Array::[] and Hash::[] look like an index access to the parser; some other methods may have this name too + lazy val defaultBehaviour = + val indexAsts = node.indices.map(astForExpression) + val targetAst = astForExpression(node.target) + val call = + callNode( + node, + code(node), + Operators.indexAccess, + Operators.indexAccess, + DispatchTypes.STATIC_DISPATCH + ) + callAst(call, targetAst +: indexAsts) + scope.tryResolveTypeReference(node.target.text).map(_.name) match + case Some(typeReference) => + scope + .tryResolveMethodInvocation("[]", typeFullName = Option(typeReference)) + .map { m => + val expr = + astForExpression(MemberCall(node.target, ".", "[]", node.indices)(node.span)) + expr.root.collect { case x: NewCall => + x.methodFullName(s"$typeReference.${m.name}") + scope.tryResolveTypeReference(m.returnType).map(_.name).foreach( + x.typeFullName(_) + ) + } + expr + } + .getOrElse(defaultBehaviour) + case None if node.indices.isEmpty => + astForExpression(MemberCall(node.target, ".", "[]", node.indices)(node.span)) + case None => defaultBehaviour + end astForIndexAccess + + /* `foo() do end` is lowered as a METHOD node shaped like so: 
+ * ``` + * = def 0() + * + * end + * foo(, ) + * ``` + */ + protected def astForCallWithBlock[C <: RubyCall](node: RubyExpression & RubyCallWithBlock[C]) + : Ast = + val Seq(typeRef, _) = astForDoBlock(node.block): @unchecked + val typeRefDummyNode = typeRef.root.map(DummyNode(_)(node.span)).toList + + // Create call with argument referencing the MethodRef + val callWithLambdaArg = node.withoutBlock match + case x: SimpleCall => + astForSimpleCall(x.copy(arguments = x.arguments ++ typeRefDummyNode)(x.span)) + case x: MemberCall => + astForMemberCall(x.copy(arguments = x.arguments ++ typeRefDummyNode)(x.span)) + case x => + Ast() + + callWithLambdaArg + + protected def astForObjectInstantiation(node: RubyExpression & ObjectInstantiation): Ast = + /* + We short-cut the call edge from `new` call to `initialize` method, however we keep the modelling of the receiver + as referring to the singleton class. + */ + val (receiverTypeFullName, fullName) = node.target match + case x: (SimpleIdentifier | MemberAccess) => + scope.tryResolveTypeReference(x.text) match + case Some(typeMetaData) => + s"${typeMetaData.name}" -> s"${typeMetaData.name}.${Defines.Initialize}" + case None => XDefines.Any -> XDefines.DynamicCallUnknownFullName + case _ => XDefines.Any -> XDefines.DynamicCallUnknownFullName + /* + Similarly to some other frontends, we lower the constructor into two operations, e.g., + `return Bar.new`, lowered to + `return {Bar tmp = Bar.(); tmp.(); tmp}` + */ + val block = blockNode(node) + scope.pushNewScope(BlockScope(block)) + + val tmpName = this.tmpGen.fresh + val tmpTypeHint = receiverTypeFullName.stripSuffix("") + val tmp = SimpleIdentifier(None)(node.span.spanStart(tmpName)) + val tmpLocal = NewLocal().name(tmpName).code(tmpName).dynamicTypeHintFullName(Seq(tmpTypeHint)) + scope.addToScope(tmpName, tmpLocal) + + def tmpIdentifier = + val tmpAst = astForSimpleIdentifier(tmp) + tmpAst.root.collect { case x: NewIdentifier => x.typeFullName(tmpTypeHint) } + tmpAst + 
+ // Assign tmp to + val allocCall = + callNode(node, code(node), Operators.alloc, Operators.alloc, DispatchTypes.STATIC_DISPATCH) + val allocAst = callAst(allocCall, Seq.empty) + val assignmentCall = callNode( + node, + s"${tmp.text} = ${code(node.target)}.${Defines.Initialize}", + Operators.assignment, + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + val tmpAssignment = callAst(assignmentCall, Seq(tmpIdentifier, allocAst)) + + // Call constructor + val argumentAsts = node match + case x: SimpleObjectInstantiation => x.arguments.map(astForMethodCallArgument) + case x: ObjectInstantiationWithBlock => + val Seq(typeRef, _) = astForDoBlock(x.block): @unchecked + x.arguments.map(astForMethodCallArgument) :+ typeRef + + val constructorCall = + callNode( + node, + code(node), + Defines.Initialize, + XDefines.DynamicCallUnknownFullName, + DispatchTypes.DYNAMIC_DISPATCH + ) + if fullName != XDefines.DynamicCallUnknownFullName then + constructorCall.dynamicTypeHintFullName(Seq(fullName)) + val constructorRecv = + astForExpression(MemberAccess(node.target, ".", Defines.Initialize)(node.span)) + val constructorCallAst = + callAst(constructorCall, argumentAsts, Option(tmpIdentifier), Option(constructorRecv)) + val retIdentifierAst = tmpIdentifier + scope.popScope() + + // Assemble statements + blockAst(block, Ast(tmpLocal) :: tmpAssignment :: constructorCallAst :: retIdentifierAst :: Nil) + end astForObjectInstantiation + + protected def astForSingleAssignment(node: SingleAssignment): Ast = + node.rhs match + case x: Unknown if x.span.text == Defines.Undefined => + // If the RHS is undefined, then this variable is not defined/placed in the variable table/registry + Ast() + case _ => + getAssignmentOperatorName(node.op) match + case None => + astForUnknown(node) + case Some(op) => + node.rhs match + case cfNode: ControlFlowStatement => + def elseAssignNil(span: TextSpan) = Option { + ElseClause( + StatementList( + SingleAssignment( + node.lhs, + node.op, + 
StaticLiteral(getBuiltInType(Defines.NilClass))( + span.spanStart("nil") + ) + )(span.spanStart(s"${node.lhs.span.text} ${node.op} nil")) :: Nil + )(span.spanStart(s"${node.lhs.span.text} ${node.op} nil")) + )(span.spanStart(s"else\n\t${node.lhs.span.text} ${node.op} nil\nend")) + } + + def transform(e: RubyExpression & ControlFlowStatement): RubyExpression = + transformLastRubyNodeInControlFlowExpressionBody( + e, + x => reassign(node.lhs, node.op, x, transform), + elseAssignNil + ) + + cfNode match + case x @ OperatorAssignment(lhs, op, rhs) => + val loweredNode = lowerAssignmentOperator(lhs, rhs, op, x.span) + astForExpression(transform(loweredNode)) + case x => + astForExpression(transform(cfNode)) + + case _ => + val rhsAst = astForExpression(node.rhs) + // If the LHS defines a new variable, put the local variable into scope + val lhsAst = node.lhs match + case x: SimpleIdentifier if scope.lookupVariable(code(x)).isEmpty => + val name = code(x) + val local = localNode(x, name, name, Defines.Any) + scope.addToScope(name, local) match + case BlockScope(block) => + diffGraph.addEdge(block, local, EdgeTypes.AST) + case _ => + astForExpression(node.lhs) + case SplattingRubyNode(nameNode: SimpleIdentifier) + if scope.lookupVariable(code(nameNode)).isEmpty => + val name = code(nameNode) + val local = localNode(nameNode, name, name, Defines.Any) + scope.addToScope(name, local) match + case BlockScope(block) => + diffGraph.addEdge(block, local, EdgeTypes.AST) + case _ => + astForExpression(node.lhs) + case x: GroupedParameter => + val asts = astsForStatement(x.multipleAssignment) + val call = + callNode(node, code(node), op, op, DispatchTypes.STATIC_DISPATCH) + return callAst(call, asts :+ rhsAst) + case _ => astForExpression(node.lhs) + + // If this is a simple object instantiation assignment, we can give the LHS variable a type hint + if node.rhs.isInstanceOf[ObjectInstantiation] && lhsAst.root.exists( + _.isInstanceOf[NewIdentifier] + ) + then + 
rhsAst.nodes.collectFirst { + case tmp: NewIdentifier + if tmp.name.startsWith(" + lhsAst.root.collectFirst { case i: NewIdentifier => + scope.lookupVariable(i.name).foreach { + case x: NewLocal => + x.dynamicTypeHintFullName( + x.dynamicTypeHintFullName :+ tmp.typeFullName + ) + case x: NewMethodParameterIn => + x.dynamicTypeHintFullName( + x.dynamicTypeHintFullName :+ tmp.typeFullName + ) + } + i.dynamicTypeHintFullName( + i.dynamicTypeHintFullName :+ tmp.typeFullName + ) + } + } + end if + + val call = callNode(node, code(node), op, op, DispatchTypes.STATIC_DISPATCH) + callAst(call, Seq(lhsAst, rhsAst)) + + private def reassign( + lhs: RubyExpression, + op: String, + rhs: RubyExpression, + transform: (RubyExpression & ControlFlowStatement) => RubyExpression + ): RubyExpression = + def stmtListAssigningLastExpression(stmts: List[RubyExpression]): List[RubyExpression] = + stmts match + case (head: ControlFlowClause) :: Nil => clauseAssigningLastExpression(head) :: Nil + case (head: ControlFlowStatement) :: Nil => transform(head) :: Nil + case head :: Nil => + SingleAssignment(lhs, op, head)( + rhs.span.spanStart(s"${lhs.span.text} $op ${head.span.text}") + ) :: Nil + case Nil => List.empty + case head :: tail => head :: stmtListAssigningLastExpression(tail) + + def clauseAssigningLastExpression(x: RubyExpression & ControlFlowClause): RubyExpression = + x match + case RescueClause(exceptionClassList, assignment, thenClause) => + RescueClause( + exceptionClassList, + assignment, + reassign(lhs, op, thenClause, transform) + )(x.span) + case EnsureClause(thenClause) => + EnsureClause(reassign(lhs, op, thenClause, transform))(x.span) + case ElsIfClause(condition, thenClause) => + ElsIfClause(condition, reassign(lhs, op, thenClause, transform))(x.span) + case ElseClause(thenClause) => + ElseClause(reassign(lhs, op, thenClause, transform))(x.span) + case WhenClause(matchExpressions, matchSplatExpression, thenClause) => + WhenClause( + matchExpressions, + 
matchSplatExpression, + reassign(lhs, op, thenClause, transform) + )(x.span) + + rhs match + case StatementList(statements) => + StatementList(stmtListAssigningLastExpression(statements))(rhs.span) + case clause: ControlFlowClause => clauseAssigningLastExpression(clause) + case expr: ControlFlowStatement => transform(expr) + case _ => + SingleAssignment(lhs, op, rhs)( + rhs.span.spanStart(s"${lhs.span.text} $op ${rhs.span.text}") + ) + end reassign + + // `x.y = 1` is approximated as `x.y = 1`, i.e. as calling `x.y =` assignment with argument `1` + // This has the benefit of avoiding unnecessary call resolution + protected def astForAttributeAssignment(node: AttributeAssignment): Ast = + val memberAccess = MemberAccess(node.target, ".", s"@${node.attributeName}")( + node.span.spanStart(s"${node.target.text}.${node.attributeName}") + ) + + val assignmentOp = AssignmentOperatorNames(node.assignmentOperator) + + val lhsAst = astForFieldAccess(memberAccess, stripLeadingAt = true) + val rhsAst = astForExpression(node.rhs) + val call = callNode(node, code(node), assignmentOp, assignmentOp, DispatchTypes.STATIC_DISPATCH) + callAst(call, Seq(lhsAst, rhsAst)) + + protected def astForSimpleIdentifier(node: RubyExpression & RubyIdentifier): Ast = + val name = code(node) + if isBundledClass(name) then + val typeFullName = prefixAsBundledType(name) + Ast(typeRefNode(node, typeFullName, typeFullName)) + else + scope.lookupVariable(name) match + case Some(_) => handleVariableOccurrence(node) + case None if scope.tryResolveMethodInvocation(node.text).isDefined => + astForSimpleCall(SimpleCall(node, List())(node.span)) + case None => + astForMemberAccess( + MemberAccess(SelfIdentifier()(node.span.spanStart(Defines.Self)), ".", node.text)( + node.span + ) + ) + + protected def astForArrayPattern(node: ArrayPattern): Ast = + val callNode_ = + callNode( + node, + code(node), + Operators.arrayInitializer, + Operators.arrayInitializer, + DispatchTypes.STATIC_DISPATCH + ) + val 
childrenAst = node.children.map(astForExpression) + + callAst(callNode_, childrenAst) + + protected def astForMatchVariable(node: MatchVariable): Ast = + val nodeCode = shortenCode(s"${RubyOperators.arrayPatternMatch}(${node.span.text})") + val callNode_ = callNode( + node, + nodeCode, + RubyOperators.arrayPatternMatch, + RubyOperators.arrayPatternMatch, + DispatchTypes.STATIC_DISPATCH + ) + val identAst = astForExpression(SimpleIdentifier()(node.span)) + + callAst(callNode_, identAst :: Nil) + + protected def astForMandatoryParameter(node: RubyExpression): Ast = handleVariableOccurrence(node) + + protected def astForSimpleCall(node: SimpleCall): Ast = + node.target match + case targetNode: SimpleIdentifier => astForMethodCallWithoutBlock(node, targetNode) + case targetNode: RubyFieldIdentifier => + astForMemberCallWithoutBlock(node, targetNode.toMemberAccess) + case targetNode: MemberAccess => astForMemberCallWithoutBlock(node, targetNode) + case targetNode => + astForUnknown(targetNode) + + protected def astForRequireCall(node: RequireCall): Ast = + val pathOpt = node.argument match + case arg: StaticLiteral if arg.isString => Option(arg.innerText) + case _ => None + pathOpt.foreach(path => + scope.addRequire(projectRoot.get, fileName, path, node.isRelative, node.isWildCard) + ) + + val callName = node.target.text + val requireCallNode = NewCall() + .name(node.target.text) + .code(code(node)) + .methodFullName(getBuiltInType(callName)) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .typeFullName(Defines.Any) + val arguments = astForExpression(node.argument) :: Nil + callAst(requireCallNode, arguments) + + protected def astForIncludeCall(node: IncludeCall): Ast = + scope.addInclude( + node.argument.text.replaceAll("::", ".") + ) // Maybe generate ast and get name in a more structured approach instead + astForSimpleCall(node.asSimpleCall) + + protected def astForRaiseCall(node: RaiseCall): Ast = + val throwControlStruct = controlStructureNode(node, 
ControlStructureTypes.THROW, code(node)) + val args = node.arguments.map(astForExpression) + Ast(throwControlStruct).withChildren(args) + + /** A yield in Ruby calls an explicit (or implicit) proc parameter and returns its value. This can + * be lowered as block.call(), which is effectively how one invokes a proc parameter in any case. + */ + protected def astForYield(node: YieldExpr): Ast = + scope.useProcParam match + case Some(param) => + // We do not know if we necessarily have an explicit proc param here, or if we need to create a new one + if scope.lookupVariable(param).isEmpty then + scope.anonProcParam.map { param => + val paramNode = ProcParameter(param)(node.span.spanStart(s"&$param")) + astForParameter(paramNode, -1) + } + val loweredCall = + MemberCall( + SimpleIdentifier()(node.span.spanStart(param)), + ".", + "call", + node.arguments + )(node.span) + astForExpression(loweredCall) + case None => + astForUnknown(node) + + protected def astForRange(node: RangeExpression): Ast = + val lbAst = astForExpression(node.lowerBound) + val ubAst = astForExpression(node.upperBound) + val call = + callNode(node, code(node), Operators.range, Operators.range, DispatchTypes.STATIC_DISPATCH) + callAst(call, Seq(lbAst, ubAst)) + + protected def astForArrayLiteral(node: ArrayLiteral): Ast = + val arguments = if node.text.startsWith("%") then + val argumentsType = + if node.isStringArray then getBuiltInType(Defines.String) + else getBuiltInType(Defines.Symbol) + node.elements.map { + case element @ StaticLiteral(_) => StaticLiteral(argumentsType)(element.span) + case element @ DynamicLiteral(_, expressions) => + DynamicLiteral(argumentsType, expressions)(element.span) + case element => element + } + else + node.elements + val argumentAsts = arguments.map(astForExpression) + + val call = + callNode( + node, + code(node), + Operators.arrayInitializer, + Operators.arrayInitializer, + DispatchTypes.STATIC_DISPATCH + ) + callAst(call, argumentAsts) + end astForArrayLiteral + + 
/** Lowers a hash literal into a block of the shape `tmp = <hash-init>; tmp[k1] = v1; ...; tmp`.
  *
  * A fresh temporary name is registered as a local in a new block scope, one AST is produced per
  * element, and the temporary is the last expression of the block.
  *
  * @param node
  *   the hash-like literal to lower.
  * @return
  *   the block AST containing the initialisation, the per-element ASTs and the trailing reference
  *   to the temporary.
  */
protected def astForHashLiteral(node: HashLike): Ast =
  val tmp = this.tmpGen.fresh

  // Identifier for the temporary, positioned at `tmpNode` when given, otherwise at the literal itself
  def tmpAst(tmpNode: Option[RubyExpression] = None) = astForSimpleIdentifier(
    SimpleIdentifier()(tmpNode.map(_.span).getOrElse(node.span).spanStart(tmp))
  )

  val block = blockNode(node)
  scope.pushNewScope(BlockScope(block))
  val tmpLocal = NewLocal().name(tmp).code(tmp)
  scope.addToScope(tmp, tmpLocal)

  val argumentAsts = node.elements.flatMap(elem =>
    elem match
      case associationNode: Association => astForAssociationHash(associationNode, tmp)
      case splattingRubyNode: SplattingRubyNode =>
        astForSplattingRubyNode(splattingRubyNode) :: Nil
      case other =>
        astForUnknown(other) :: Nil
  )

  val hashInitCall = callNode(
    node,
    code(node),
    RubyOperators.hashInitializer,
    RubyOperators.hashInitializer,
    DispatchTypes.STATIC_DISPATCH
  )

  val assignment =
    callNode(
      node,
      code(node),
      Operators.assignment,
      Operators.assignment,
      DispatchTypes.STATIC_DISPATCH
    )
  val tmpAssignment = callAst(assignment, tmpAst() :: Ast(hashInitCall) :: Nil)
  val tmpRetAst = tmpAst(node.elements.lastOption)

  scope.popScope()
  blockAst(block, tmpAssignment +: argumentAsts :+ tmpRetAst)
end astForHashLiteral

/** Creates the `tmp[key] = value` assignment(s) for one hash association.
  *
  * Access-modifier and plain identifier keys are first rewritten (to a simple identifier and to a
  * symbol literal respectively) and re-processed. A range key is expanded into one assignment per
  * generated literal when [[generateStaticLiteralsForRange]] can expand it; otherwise the range
  * itself is used as the single key.
  *
  * @param node
  *   the association to lower.
  * @param tmp
  *   the name of the temporary holding the hash being built.
  * @return
  *   one AST per generated assignment (never empty).
  */
protected def astForAssociationHash(node: Association, tmp: String): List[Ast] =
  node.key match
    case mod: AccessModifier =>
      // Modifiers aren't allowed here, will be shadowed by a simple identifier
      astForAssociationHash(node.copy(key = mod.toSimpleIdentifier)(node.span), tmp)
    case iden: SimpleIdentifier =>
      // An identifier here will always be interpreted as a symbol
      val sym =
        StaticLiteral(getBuiltInType(Defines.Symbol))(iden.span.spanStart(s":${iden.text}"))
      astForAssociationHash(node.copy(key = sym)(node.span), tmp)
    case rangeExpr: RangeExpression =>
      val expandedList = generateStaticLiteralsForRange(rangeExpr).map { x =>
        astForSingleKeyValue(x, node.value, tmp)
      }

      if expandedList.nonEmpty then
        expandedList
      else
        astForSingleKeyValue(node.key, node.value, tmp) :: Nil
    case _ => astForSingleKeyValue(node.key, node.value, tmp) :: Nil

/** Expands a range whose bounds are both static literals into the literals it covers.
  *
  * Only two shapes are expanded: integer bounds whose text parses as a plain decimal `Int`, and
  * string bounds that are exactly one character long once quotes are stripped. Every generated
  * literal reuses the lower bound's type and position.
  *
  * @param node
  *   the range expression to expand.
  * @return
  *   the generated literals, or an empty list when the range cannot be expanded statically
  *   (non-literal bounds, unsupported literal types, unparsable integers, empty or multi-character
  *   strings).
  */
protected def generateStaticLiteralsForRange(node: RangeExpression): List[StaticLiteral] =
  (node.lowerBound, node.upperBound) match
    case (lb: StaticLiteral, ub: StaticLiteral) =>
      // Literal of the lower bound's type, positioned at the lower bound, carrying `text` as code
      def literalAt(text: String): StaticLiteral =
        StaticLiteral(lb.typeFullName)(
          TextSpan(lb.line, lb.column, lb.lineEnd, lb.columnEnd, None, text)
        )

      (lb.typeFullName, ub.typeFullName) match
        case (
              s"${GlobalTypes.`kernelPrefix`}.Integer",
              s"${GlobalTypes.`kernelPrefix`}.Integer"
            ) =>
          // The literal text is not guaranteed to be a plain decimal (e.g. `1_000`, `0x10`), so
          // parse defensively instead of throwing NumberFormatException
          (lb.span.text.toIntOption, ub.span.text.toIntOption) match
            case (Some(lower), Some(upper)) =>
              generateRange(lower, upper, node.rangeOperator.exclusive)
                .map(x => literalAt(x.toString))
                .toList
            case _ =>
              List.empty
        case (
              s"${GlobalTypes.`kernelPrefix`}.String",
              s"${GlobalTypes.`kernelPrefix`}.String"
            ) =>
          val lbVal = lb.span.text.replaceAll("['\"]", "")
          val ubVal = ub.span.text.replaceAll("['\"]", "")

          // TODO: Also might need to check if one is upper case and other is lower, since in Ruby this would not
          //  create any range but it might with this impl of using ASCII values.
          if lbVal.length != 1 || ubVal.length != 1 then
            // Not simulating the case where we have something like "ab"..."ad"; an empty bound
            // is rejected here as well since there is no character to read
            List.empty
          else
            generateRange(lbVal(0).toInt, ubVal(0).toInt, node.rangeOperator.exclusive)
              .map(x => literalAt(s"\'${x.toChar.toString}\'"))
              .toList
        case _ =>
          List.empty
    case _ =>
      List.empty

/** Builds the integer range `lhs..rhs`, leaving out `rhs` when `exclusive` is set. */
private def generateRange(lhs: Int, rhs: Int, exclusive: Boolean): Range =
  if exclusive then lhs until rhs
  else lhs to rhs

/** Creates a statically dispatched `RubyOperators.association` call with the key and value ASTs as
  * its two arguments.
  */
protected def astForAssociation(node: Association): Ast =
  val key   = astForExpression(node.key)
  val value = astForExpression(node.value)
  val call =
    callNode(
      node,
      code(node),
      RubyOperators.association,
      RubyOperators.association,
      DispatchTypes.STATIC_DISPATCH
    )
  callAst(call, Seq(key, value))

/** Creates the AST for the synthetic assignment `tmp[key] = value`.
  *
  * @param keyNode
  *   the key expression; also provides the position of every synthetic node.
  * @param valueNode
  *   the value expression being assigned.
  * @param tmp
  *   the name of the temporary holding the hash being built.
  */
protected def astForSingleKeyValue(
  keyNode: RubyExpression,
  valueNode: RubyExpression,
  tmp: String
): Ast =
  // Every synthetic node sits at the key's position; only the code text differs
  def spanAtKey(text: String): TextSpan =
    TextSpan(keyNode.line, keyNode.column, keyNode.lineEnd, keyNode.columnEnd, None, text)

  val indexText = s"$tmp[${keyNode.span.text}]"
  astForExpression(
    SingleAssignment(
      IndexAccess(SimpleIdentifier()(spanAtKey(tmp)), List(keyNode))(spanAtKey(indexText)),
      "=",
      valueNode
    )(spanAtKey(s"$indexText = ${valueNode.span.text}"))
  )

/** Lowers a `begin/rescue/else/ensure` expression into a try/catch structure.
  *
  * Each rescue clause becomes a `CATCH` control structure; the rescued exception class names are
  * recorded as dynamic type hints on the locals/parameters that the clause binds. The `else` clause
  * is appended after the rescue clauses and the `ensure` clause becomes the `FINALLY` part.
  */
protected def astForRescueExpression(node: RescueExpression): Ast =
  val tryAst = astForStatementList(node.body.asStatementList)
  val rescueAsts = node.rescueClauses
    .map { x =>
      // Prefer the resolved type name, falling back to the class name as written
      val classes =
        x.exceptionClassList.map(e =>
          scope.tryResolveTypeReference(e.text).map(_.name).getOrElse(e.text)
        ).toSeq
      val variables = x.variables
        .flatMap { v =>
          // Registers the variable first so that the lookup below can find it
          handleVariableOccurrence(v)
          scope.lookupVariable(v.text)
        }
        .collect {
          case x: NewLocal             => Ast(x.dynamicTypeHintFullName(classes))
          case x: NewMethodParameterIn => Ast(x.dynamicTypeHintFullName(classes))
        }
        .toList
      val rescueNode = controlStructureNode(x.thenClause.asStatementList, "CATCH", "catch")
      Ast(rescueNode).withChild(
        astForStatementList(x.thenClause.asStatementList).withChildren(variables)
      )
    }
  val elseAst = node.elseClause.map { x =>
    val astForClause =
      controlStructureNode(x.thenClause.asStatementList, ControlStructureTypes.ELSE, "else")
    Ast(astForClause).withChild(astForStatementList(x.thenClause.asStatementList))
  }

  val ensureAst = node.ensureClause.map { x =>
    val astForEnsureClause =
      controlStructureNode(x.thenClause.asStatementList, "FINALLY", "finally")
    Ast(astForEnsureClause).withChild(astForStatementList(x.thenClause.asStatementList))
  }

  val tryNode = controlStructureNode(node.body.asStatementList, ControlStructureTypes.TRY, "try")
  tryCatchAst(tryNode, tryAst, rescueAsts ++ elseAst, ensureAst)
end astForRescueExpression

/** Creates an identifier for `self`, typed with the surrounding type's full name (or `Any` when
  * there is none) and linked by a REF edge to the `self` variable when one is in scope.
  */
private def astForSelfIdentifier(node: SelfIdentifier): Ast =
  val thisIdentifier =
    identifierNode(
      node,
      Defines.Self,
      code(node),
      scope.surroundingTypeFullName.getOrElse(Defines.Any)
    )

  scope
    .lookupVariable(Defines.Self)
    .map(selfParam => Ast(thisIdentifier).withRefEdge(thisIdentifier, selfParam))
    .getOrElse(Ast(thisIdentifier))

/** Fallback for expressions without a dedicated lowering: wraps the node's code in an unknown node.
  */
protected def astForUnknown(node: RubyExpression): Ast =
  Ast(unknownNode(node, code(node)))

/** Creates a dynamically dispatched call for `receiver.member(args)` without a block argument.
  *
  * The receiver is the field access described by `memberAccess`; the method full name is left as
  * the unknown dynamic call placeholder.
  */
private def astForMemberCallWithoutBlock(node: SimpleCall, memberAccess: MemberAccess): Ast =
  val receiverAst    = astForFieldAccess(memberAccess)
  val methodName     = memberAccess.memberName
  val methodFullName = XDefines.DynamicCallUnknownFullName
  val argumentAsts   = node.arguments.map(astForMethodCallArgument)
  val call =
    callNode(node, code(node), methodName, methodFullName, DispatchTypes.DYNAMIC_DISPATCH)

  callAst(call, argumentAsts, Some(receiverAst))

private def astForMethodCallWithoutBlock( + node: SimpleCall, + methodIdentifier: SimpleIdentifier + ): Ast = + val methodName = methodIdentifier.text + lazy val defaultResult = Defines.Any -> XDefines.DynamicCallUnknownFullName + + val (receiverType, methodFullNameHint) = + scope + .tryResolveMethodInvocation( + methodName, + typeFullName = scope.surroundingTypeFullName + ) // Check if this is a method invocation of a method define within this scope + .orElse( + scope.tryResolveMethodInvocation(methodName) + ) // Check if this is a method invocation of a member imported into scope + match + case Some(m) => + scope.typeForMethod(m).map(t => t.name -> s"${t.name}.${m.name}").getOrElse( + defaultResult + ) + case None => defaultResult + + val argumentAst = node.arguments.map(astForMethodCallArgument) + val (dispatchType, methodFullName) = + if receiverType.startsWith(GlobalTypes.builtinPrefix) then + (DispatchTypes.STATIC_DISPATCH, methodFullNameHint) + else (DispatchTypes.DYNAMIC_DISPATCH, XDefines.DynamicCallUnknownFullName) + + val call = callNode(node, code(node), methodName, methodFullName, dispatchType) + + if methodFullName != methodFullNameHint then call.possibleTypes(IndexedSeq(methodFullNameHint)) + + val receiverAst = astForFieldAccess( + MemberAccess(SelfIdentifier()(node.span.spanStart(Defines.Self)), ".", call.name)(node.span), + stripLeadingAt = true + ) + val selfIdentifier = identifierNode(node, Defines.Self, Defines.Self, receiverType) + val baseAst = scope + .lookupVariable(Defines.Self) + .map(selfParam => Ast(selfIdentifier).withRefEdge(selfIdentifier, selfParam)) + .getOrElse(Ast(selfIdentifier)) + callAst(call, argumentAst, Option(baseAst), Option(receiverAst)) + end astForMethodCallWithoutBlock + + private def astForProcOrLambdaExpr(node: ProcOrLambdaExpr): Ast = + val Seq(typeRef, _) = astForDoBlock(node.block): @unchecked + typeRef + + private def astForSingletonObjectMethodDeclaration(node: SingletonObjectMethodDeclaration): Ast = + val 
methodAstsWithRefs = astForMethodDeclaration(node, isSingletonObjectMethod = true) + + // Set span contents + methodAstsWithRefs.flatMap(_.nodes).foreach { + case m: NewMethodRef => DummyNode(m.copy)(node.body.span.spanStart(m.code)) + case _ => + } + + val Seq(typeRef, _) = methodAstsWithRefs + + typeRef + + private def astForMethodCallArgument(node: RubyExpression): Ast = + node match + // Associations in method calls are keyword arguments + case assoc: Association => astForKeywordArgument(assoc) + case block: RubyBlock => + val Seq(methodDecl, typeDecl, typeRef, _) = astForDoBlock(block) + Ast.storeInDiffGraph(methodDecl, diffGraph) + Ast.storeInDiffGraph(typeDecl, diffGraph) + + typeRef + case selfMethod: SingletonMethodDeclaration => + // Last element is the method declaration, the prefix methods would be `foo = def foo (...)` pointers in other + // contexts, but this would be empty as a method call argument + val Seq(_, methodDeclAst) = astForSingletonMethodDeclaration(selfMethod) + scope.surroundingTypeFullName.foreach { tfn => + methodDeclAst.root.collect { case m: NewMethod => + m.astParentType(NodeTypes.TYPE_DECL).astParentFullName(s"$tfn") + } + } + Ast.storeInDiffGraph(methodDeclAst, diffGraph) + scope.surroundingScopeFullName + .map(s => + Ast(methodRefNode( + node, + selfMethod.span.text, + s"$s.${selfMethod.methodName}", + Defines.Any + )) + ) + .getOrElse(Ast()) + case _ => astForExpression(node) + + private def astForKeywordArgument(assoc: Association): Ast = + + def setArgumentName(argumentAst: Ast, name: String): Ast = + argumentAst.root.collectFirst { case x: ExpressionNew => + x.argumentName_=(Option(name)) + x.argumentIndex_=(-1) + } + argumentAst + + val value = astForExpression(assoc.value) + assoc.key match + case keyIdentifier: SimpleIdentifier => setArgumentName(value, keyIdentifier.text) + case symbol @ StaticLiteral(typ) if typ == getBuiltInType(Defines.Symbol) => + setArgumentName(value, symbol.text.stripPrefix(":")) + case _: 
(LiteralExpr | RubyCall | ProcOrLambdaExpr | MemberAccess | IndexAccess) => + astForExpression(assoc) + case x => + astForExpression(assoc) + + protected def astForSplattingRubyNode(node: SplattingRubyNode): Ast = + val splattingCall = + callNode( + node, + code(node), + RubyOperators.splat, + RubyOperators.splat, + DispatchTypes.STATIC_DISPATCH + ) + val argumentAst = astsForStatement(node.target) + callAst(splattingCall, argumentAst) + + private def getBinaryOperatorName(op: String): Option[String] = BinaryOperatorNames.get(op) + + private def getUnaryOperatorName(op: String): Option[String] = UnaryOperatorNames.get(op) + + private def getAssignmentOperatorName(op: String): Option[String] = + AssignmentOperatorNames.get(op) +end AstForExpressionsCreator diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForFunctionsCreator.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForFunctionsCreator.scala new file mode 100644 index 00000000..56887aed --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForFunctionsCreator.scala @@ -0,0 +1,635 @@ +package io.appthreat.ruby2atom.astcreation + +import io.appthreat.ruby2atom.astcreation.RubyIntermediateAst.* +import io.appthreat.ruby2atom.datastructures.{ConstructorScope, MethodScope} +import io.appthreat.ruby2atom.passes.Defines +import io.appthreat.x2cpg.utils.NodeBuilders.{ + newBindingNode, + newClosureBindingNode, + newLocalNode, + newModifierNode, + newThisParameterNode +} +import io.appthreat.x2cpg.{Ast, AstEdge, ValidationMode, Defines as XDefines} +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.{ + DispatchTypes, + EdgeTypes, + EvaluationStrategies, + ModifierTypes, + NodeTypes, + Operators +} +import io.appthreat.ruby2atom.utils.FreshNameGenerator + +import scala.collection.mutable + +trait AstForFunctionsCreator(implicit 
withSchemaValidation: ValidationMode): + this: AstCreator => + + /** As expressions may be discarded, we cannot store closure ASTs in the diffgraph at the point of + * creation. So we assume every reference to this map means that the closure AST was successfully + * propagated. + */ + protected val closureToRefs = mutable.Map.empty[RubyExpression, Seq[NewNode]] + + /** Creates method declaration related structures. + * @param node + * the node to create the AST structure from. + * @param isClosure + * if true, will generate a type decl, type ref, and method ref, as well as add the `c` + * modifier. + * @return + * a method declaration with additional refs and types if specified. + */ + protected def astForMethodDeclaration( + node: RubyExpression & ProcedureDeclaration, + isClosure: Boolean = false, + isSingletonObjectMethod: Boolean = false, + useSurroundingTypeFullName: Boolean = false + ): Seq[Ast] = + val isInTypeDecl = scope.surroundingAstLabel.contains(NodeTypes.TYPE_DECL) + val isConstructor = (node.methodName == Defines.Initialize) && isInTypeDecl + val methodName = node.methodName + + val fullName = + node match + case x: SingletonObjectMethodDeclaration => + computeFullName( + s"class<<${x.baseClass.span.text}.$methodName", + useSurroundingTypeFullName = useSurroundingTypeFullName + ) + case _ => + computeFullName(methodName, useSurroundingTypeFullName = useSurroundingTypeFullName) + + val astParentType = + if useSurroundingTypeFullName || shouldUseSurroundingTypeFullName then + Some(NodeTypes.TYPE_DECL) + else scope.surroundingAstLabel + + val astParentFullName = + if useSurroundingTypeFullName || shouldUseSurroundingTypeFullName then + scope.surroundingTypeFullName + else scope.surroundingScopeFullName + + val method = methodNode( + node = node, + name = methodName, + fullName = fullName, + code = code(node), + signature = None, + fileName = relativeFileName, + astParentType = astParentType, + astParentFullName = astParentFullName + ) + + val 
isSurroundedByProgramScope = scope.isSurroundedByProgramScope + if isConstructor then scope.pushNewScope(ConstructorScope(fullName, this.procParamGen.fresh)) + else scope.pushNewScope(MethodScope(fullName, this.procParamGen.fresh)) + + val thisParameterNode = newThisParameterNode( + name = Defines.Self, + code = Defines.Self, + typeFullName = scope.surroundingTypeFullName.getOrElse(Defines.Any), + line = method.lineNumber, + column = method.columnNumber + ) + val thisParameterAst = Ast(thisParameterNode) + scope.addToScope(Defines.Self, thisParameterNode) + val parameterAsts = thisParameterAst :: astForParameters(node.parameters) + + val optionalStatementList = statementListForOptionalParams(node.parameters) + + val methodReturn = methodReturnNode(node, Defines.Any) + + val refs = + val typeRef = + if isClosure then typeRefNode(node, s"$methodName&Proc", s"$fullName&Proc") + else typeRefNode(node, methodName, fullName) + List(typeRef, methodRefNode(node, methodName, fullName, fullName)).map(Ast.apply) + + // Consider which variables are captured from the outer scope + val stmtBlockAst = if isClosure || isSingletonObjectMethod then + val baseStmtBlockAst = astForMethodBody(node.body, optionalStatementList) + transformAsClosureBody(refs, baseStmtBlockAst) + else if methodName == Defines.TypeDeclBody then + val stmtList = node.body.asInstanceOf[StatementList] + astForStatementList( + StatementList(stmtList.statements ++ optionalStatementList.statements)(stmtList.span) + ) + else if methodName != Defines.Initialize then + astForMethodBody(node.body, optionalStatementList) + else + astForConstructorMethodBody(node.body, optionalStatementList) + + // For yield statements where there isn't an explicit proc parameter + val anonProcParam = scope.procParamName.map { p => + val nextIndex = + parameterAsts.flatMap(_.root).lastOption.map { case m: NewMethodParameterIn => + m.index + 1 + }.getOrElse(0) + + Ast(p.index(nextIndex)) + } + + scope.popScope() + + val 
methodTypeDeclAst = + val typeDeclNode_ = typeDeclNode(node, methodName, fullName, relativeFileName, code(node)) + astParentType.foreach(typeDeclNode_.astParentType(_)) + astParentFullName.foreach(typeDeclNode_.astParentFullName(_)) + createMethodTypeBindings(method, typeDeclNode_) + if isClosure then Ast(typeDeclNode_).withChild(Ast(newModifierNode("LAMBDA"))) + else Ast(typeDeclNode_) + + // Due to lambdas being invoked by `call()`, this additional type ref holding that member is created. + val lambdaTypeDeclAst = if isClosure then + val typeDeclNode_ = + typeDeclNode(node, s"$methodName&Proc", s"$fullName&Proc", relativeFileName, code(node)) + astParentType.foreach(typeDeclNode_.astParentType(_)) + astParentFullName.foreach(typeDeclNode_.astParentFullName(_)) + Ast(typeDeclNode_) + .withChild( + // This member refers back to itself, as itself is the type decl bound to the respective method + Ast(NewMember().name("call").code("call").dynamicTypeHintFullName(Seq(fullName)) + .typeFullName(Defines.Any)) + ) + else Ast() + + val accessModifier = + // Initialize is guaranteed `private` by the Ruby interpreter (we include our method here) + if methodName == Defines.Initialize || methodName == Defines.TypeDeclBody then + ModifierTypes.PRIVATE + //
functions are private functions on the Object class + else if isSurroundedByProgramScope then ModifierTypes.PRIVATE + // Else, use whatever modifier has been user-defined (or is default for current scope) + else currentAccessModifier + val modifiers = mutable.Buffer(ModifierTypes.VIRTUAL, accessModifier) + if isClosure then modifiers.addOne("LAMBDA") + if isConstructor then modifiers.addOne(ModifierTypes.CONSTRUCTOR) + + val prefixMemberAst = + if isClosure || isSingletonObjectMethod || isSurroundedByProgramScope then + Ast() // program scope members are set elsewhere + else + // Singleton constructors that initialize @@ fields should have their members linked under the singleton class + val methodMember = scope.surroundingTypeFullName match + case Some(astParentTfn) => + memberForMethod(method, Option(NodeTypes.TYPE_DECL), Option(astParentTfn)) + case None => + memberForMethod(method, scope.surroundingAstLabel, scope.surroundingScopeFullName) + Ast(memberForMethod(method, Option(NodeTypes.TYPE_DECL), astParentFullName)) + // For closures, we also want the method/type refs for upstream use + val methodAst_ = + val mAst = methodAst( + method, + parameterAsts ++ anonProcParam, + stmtBlockAst, + methodReturn, + modifiers.map(newModifierNode).toSeq + ) + mAst + + // Each of these ASTs are linked via AstLinker as per the astParent* properties + (prefixMemberAst :: methodAst_ :: methodTypeDeclAst :: lambdaTypeDeclAst :: Nil) + .foreach(Ast.storeInDiffGraph(_, diffGraph)) + // In the case of a closure, we expect this method to return a method ref, otherwise, we bind a pointer to a + // method ref, e.g. self.foo = def foo(...) 
+ if isClosure || isSingletonObjectMethod then refs else createMethodRefPointer(method) :: Nil + end astForMethodDeclaration + + protected def astForMethodAccessModifier(node: MethodAccessModifier): Seq[Ast] = + val originalAccessModifier = currentAccessModifier + popAccessModifier() + + node match + case _: PrivateMethodModifier => + pushAccessModifier(ModifierTypes.PRIVATE) + case _: PublicMethodModifier => + pushAccessModifier(ModifierTypes.PUBLIC) + + val methodAst = node.method match + case m: ProcedureDeclaration => astsForStatement(m) + case x => + // Not sure how we should represent dynamically setting access modifiers based on method refs + Nil + + popAccessModifier() + pushAccessModifier(originalAccessModifier) + + methodAst + end astForMethodAccessModifier + + private def transformAsClosureBody(refs: List[Ast], baseStmtBlockAst: Ast) = + // Determine which locals are captured + val capturedLocalNodes = baseStmtBlockAst.nodes + .collect { + case x: NewIdentifier if x.name != Defines.Self => x + } // Self identifiers are handled separately + .distinctBy(_.name) + .map(i => scope.lookupVariableInOuterScope(i.name)) + .filter(_.nonEmpty) + .flatten + .toSet + + val capturedIdentifiers = baseStmtBlockAst.nodes.collect { + case i: NewIdentifier if capturedLocalNodes.map(_.name).contains(i.name) => i + } + // Copy AST block detaching the REF nodes between parent locals/params and identifiers, with the closures' one + val capturedBlockAst = baseStmtBlockAst.copy(refEdges = baseStmtBlockAst.refEdges.filterNot { + case AstEdge(_: NewIdentifier, dst: DeclarationNew) => capturedLocalNodes.contains(dst) + case _ => false + }) + + val typeRefOption = refs.flatMap(_.nodes).collectFirst { case x: NewTypeRef => x } + + val astChildren = mutable.Buffer.empty[NewNode] + val refEdges = mutable.Buffer.empty[(NewNode, NewNode)] + val captureEdges = mutable.Buffer.empty[(NewNode, NewNode)] + capturedLocalNodes + .collect { + case local: NewLocal => + val closureBindingId = + 
scope.variableScopeFullName(local.name).map(x => s"$x.${local.name}") + (local, local.name, local.code, closureBindingId) + case param: NewMethodParameterIn => + val closureBindingId = + scope.variableScopeFullName(param.name).map(x => s"$x.${param.name}") + (param, param.name, param.code, closureBindingId) + } + .collect { case (capturedLocal, name, code, Some(closureBindingId)) => + val capturingLocal = + newLocalNode( + name = name, + typeFullName = Defines.Any, + closureBindingId = Option(closureBindingId) + ) + + val closureBinding = newClosureBindingNode( + closureBindingId = closureBindingId, + originalName = name, + evaluationStrategy = EvaluationStrategies.BY_REFERENCE + ) + + // Create new local node for lambda, with corresponding REF edges to identifiers and closure binding + val _refEdges = + capturedIdentifiers.filter(_.name == name).map(i => i -> capturingLocal) :+ ( + closureBinding, + capturedLocal + ) + + astChildren.addOne(capturingLocal) + refEdges.addAll(_refEdges.toList) + captureEdges.addAll(typeRefOption.map(typeRef => typeRef -> closureBinding).toList) + } + + val astWithAstChildren = astChildren.foldLeft(capturedBlockAst) { case (ast, child) => + ast.withChild(Ast(child)) + } + val astWithRefEdges = refEdges.foldLeft(astWithAstChildren) { case (ast, (src, dst)) => + ast.withRefEdge(src, dst) + } + captureEdges.foldLeft(astWithRefEdges) { case (ast, (src, dst)) => + ast.withCaptureEdge(src, dst) + } + end transformAsClosureBody + + /** Creates the bindings between the method and its types. This is useful for resolving function + * pointers and imports. 
+ */ + protected def createMethodTypeBindings(method: NewMethod, typeDecl: NewTypeDecl): Unit = + val bindingNode = newBindingNode("", "", method.fullName) + diffGraph.addEdge(typeDecl, bindingNode, EdgeTypes.BINDS) + diffGraph.addEdge(bindingNode, method, EdgeTypes.REF) + + // TODO: remaining cases + protected def astForParameter(node: RubyExpression, index: Int): Ast = + node match + case node: (MandatoryParameter | OptionalParameter) => + val parameterIn = parameterInNode( + node = node, + name = node.name, + code = code(node), + index = index, + isVariadic = false, + evaluationStrategy = EvaluationStrategies.BY_REFERENCE, + typeFullName = None + ) + scope.addToScope(node.name, parameterIn) + Ast(parameterIn) + case node: ProcParameter => + val parameterIn = parameterInNode( + node = node, + name = node.name, + code = code(node), + index = index, + isVariadic = false, + evaluationStrategy = EvaluationStrategies.BY_REFERENCE, + typeFullName = None + ) + scope.setProcParam(node.name, parameterIn) + Ast() // The proc parameter is retrieved later under method AST creation + case node: CollectionParameter => + val typeFullName = node match + case ArrayParameter(_) => prefixAsKernelDefined("Array") + case HashParameter(_) => prefixAsKernelDefined("Hash") + val parameterIn = parameterInNode( + node = node, + name = node.name, + code = code(node), + index = index, + isVariadic = true, + evaluationStrategy = EvaluationStrategies.BY_REFERENCE, + typeFullName = Option(typeFullName) + ) + scope.addToScope(node.name, parameterIn) + Ast(parameterIn) + case node: GroupedParameter => + val parameterIn = parameterInNode( + node = node.tmpParam, + name = node.name, + code = code(node.tmpParam), + index = index, + isVariadic = false, + evaluationStrategy = EvaluationStrategies.BY_REFERENCE, + typeFullName = None + ) + scope.addToScope(node.name, parameterIn) + Ast(parameterIn) + case node => + astForUnknown(node) + + private def generateTextSpan(node: RubyExpression, text: 
String): TextSpan = + TextSpan( + node.span.line, + node.span.column, + node.span.lineEnd, + node.span.columnEnd, + node.span.offset, + text + ) + + protected def statementForOptionalParam(node: OptionalParameter): RubyExpression = + val defaultExprNode = node.defaultExpression + + IfExpression( + UnaryExpression( + "!", + SimpleCall( + SimpleIdentifier(None)(generateTextSpan(defaultExprNode, "defined?")), + List(SimpleIdentifier(None)(generateTextSpan(defaultExprNode, node.name))) + )(generateTextSpan(defaultExprNode, s"defined?(${node.name})")) + )(generateTextSpan(defaultExprNode, s"!defined?(${node.name})")), + StatementList( + List( + SingleAssignment( + SimpleIdentifier(None)(generateTextSpan(defaultExprNode, node.name)), + "=", + node.defaultExpression + )(generateTextSpan(defaultExprNode, s"${node.name}=${node.defaultExpression.span.text}")) + ) + )(generateTextSpan(defaultExprNode, "")), + List.empty, + None + )( + generateTextSpan( + defaultExprNode, + s"if !defined?(${node.name}) \t${node.name}=${node.defaultExpression.span.text}\n end" + ) + ) + end statementForOptionalParam + + protected def astForAnonymousTypeDeclaration(node: AnonymousTypeDeclaration): Ast = + + // This will link the type decl to the surrounding context via base overlays + val Seq(typeRefAst) = astForClassDeclaration(node).take(1) + + typeRefAst.nodes + .collectFirst { case typRef: NewTypeRef => + val typeIdentifier = SimpleIdentifier()(node.span.spanStart(typRef.code)) + // Takes the `Class.new` before the block starts or any other keyword + val newSpanText = typRef.code + astForMemberCall( + MemberCall(typeIdentifier, ".", "new", List.empty)(node.span.spanStart(newSpanText)) + ) + } + .getOrElse(Ast()) + + protected def astForSingletonMethodDeclaration(node: SingletonMethodDeclaration): Seq[Ast] = + node.target match + case targetNode: SingletonMethodIdentifier => + val fullName = computeFullName(node.methodName) + + val (astParentType, astParentFullName, thisParamCode, addEdge) = 
targetNode match + case _: SelfIdentifier => + (scope.surroundingAstLabel, scope.surroundingScopeFullName, Defines.Self, false) + case _: SimpleIdentifier => + val baseType = node.target.span.text + scope.surroundingTypeFullName.map(_.split("[.]").last) match + case Some(typ) if typ == baseType => + (scope.surroundingAstLabel, scope.surroundingScopeFullName, baseType, false) + case Some(typ) => + scope.tryResolveTypeReference(baseType) match + case Some(typ) => + (Option(NodeTypes.TYPE_DECL), Option(typ.name), baseType, true) + case None => (None, None, Defines.Self, false) + case None => (None, None, Defines.Self, false) + + scope.pushNewScope(MethodScope(fullName, this.procParamGen.fresh)) + val method = methodNode( + node = node, + name = node.methodName, + fullName = fullName, + code = code(node), + signature = None, + fileName = relativeFileName + ) + val methodTypeDecl_ = + typeDeclNode(node, node.methodName, fullName, relativeFileName, code(node)) + val methodTypeDeclAst = Ast(methodTypeDecl_) + + createMethodTypeBindings(method, methodTypeDecl_) + + val thisNodeTypeFullName = astParentFullName match + case Some(fn) => s"$fn" + case None => Defines.Any + + val thisNode = newThisParameterNode( + name = Defines.Self, + code = thisParamCode, + typeFullName = thisNodeTypeFullName, + line = method.lineNumber, + column = method.columnNumber + ) + val thisParameterAst = Ast(thisNode) + scope.addToScope(Defines.Self, thisNode) + + val parameterAsts = thisParameterAst :: astForParameters(node.parameters) + val optionalStatementList = statementListForOptionalParams(node.parameters) + val stmtBlockAst = astForMethodBody(node.body, optionalStatementList) + + val anonProcParam = scope.procParamName.map { p => + val nextIndex = + parameterAsts.flatMap(_.root).lastOption.map { case m: NewMethodParameterIn => + m.index + 1 + }.getOrElse(0) + + Ast(p.index(nextIndex)) + } + + scope.popScope() + + astParentType.orElse(scope.surroundingAstLabel).foreach { t => + 
methodTypeDecl_.astParentType(t) + method.astParentType(t) + } + astParentFullName.orElse(scope.surroundingScopeFullName).foreach { fn => + methodTypeDecl_.astParentFullName(fn) + method.astParentFullName(fn) + } + + // The member for these types refers to the singleton class + val member = memberForMethod( + method, + Option(NodeTypes.TYPE_DECL), + astParentFullName.map(x => s"$x") + ) + diffGraph.addNode(member) + + val _methodAst = + methodAst( + method, + parameterAsts ++ anonProcParam, + stmtBlockAst, + methodReturnNode(node, Defines.Any), + newModifierNode(ModifierTypes.VIRTUAL) :: newModifierNode( + currentAccessModifier + ) :: Nil + ) + + (_methodAst :: methodTypeDeclAst :: Nil).foreach(Ast.storeInDiffGraph(_, diffGraph)) + if addEdge then + Nil + else + createMethodRefPointer(method) :: Nil + case targetNode => + astForUnknown(node) :: Nil + + private def createMethodRefPointer(method: NewMethod): Ast = + if scope.isSurroundedByProgramScope then + val methodRefNode = Ast( + NewMethodRef() + .code(s"def ${method.name} (...)") + .methodFullName(method.fullName) + .typeFullName(method.fullName) + .lineNumber(method.lineNumber) + .columnNumber(method.columnNumber) + ) + + val methodRefIdent = + val self = NewIdentifier().name(Defines.Self).code(Defines.Self).typeFullName(Defines.Any) + val fi = NewFieldIdentifier() + .code(method.name) + .canonicalName(method.name) + .lineNumber(method.lineNumber) + .columnNumber(method.columnNumber) + val fieldAccess = NewCall() + .name(Operators.fieldAccess) + .code(s"${Defines.Self}.${method.name}") + .methodFullName(Operators.fieldAccess) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .typeFullName(Defines.Any) + val selfAst = scope + .lookupVariable(Defines.Self) + .map(selfParam => Ast(self).withRefEdge(self, selfParam)) + .getOrElse(Ast(self)) + callAst(fieldAccess, Seq(selfAst, Ast(fi))) + + astForAssignment(methodRefIdent, methodRefNode, method.lineNumber, method.columnNumber) + else + Ast() + + private def 
astForParameters(parameters: List[RubyExpression]): List[Ast] = + parameters.zipWithIndex.map { case (parameterNode, index) => + astForParameter(parameterNode, index + 1) + } + + private def statementListForOptionalParams(params: List[RubyExpression]): StatementList = + StatementList( + params + .collect { case x: OptionalParameter => + x + } + .map(statementForOptionalParam) + )(TextSpan(None, None, None, None, None, "")) + + private def astForMethodBody( + body: RubyExpression, + optionalStatementList: StatementList, + returnLastExpression: Boolean = true + ): Ast = + if this.parseLevel == AstParseLevel.SIGNATURES then + Ast() + else + body match + case stmtList: StatementList => + val combinedStmtList = + StatementList(optionalStatementList.statements ++ stmtList.statements)( + stmtList.span + ) + if returnLastExpression then + astForStatementListReturningLastExpression(combinedStmtList) + else astForStatementList(combinedStmtList) + case rescueExpr: RescueExpression => + astForRescueExpression(rescueExpr) + case _: (StaticLiteral | BinaryExpression | SingleAssignment | SimpleIdentifier | ArrayLiteral | HashLiteral | + SimpleCall | MemberAccess | MemberCall) => + val combinedStmtList = + StatementList(optionalStatementList.statements ++ List(body))(body.span) + if returnLastExpression then + astForStatementListReturningLastExpression(combinedStmtList) + else astForStatementList(combinedStmtList) + case body => + astForUnknown(body) + + private def astForConstructorMethodBody( + body: RubyExpression, + optionalStatementList: StatementList + ): Ast = + if this.parseLevel == AstParseLevel.SIGNATURES then + Ast() + else + body match + case stmtList: StatementList => + astForStatementList(StatementList( + optionalStatementList.statements ++ stmtList.statements + )(stmtList.span)) + case _: (StaticLiteral | BinaryExpression | SingleAssignment | SimpleIdentifier | ArrayLiteral | HashLiteral | + SimpleCall | MemberAccess | MemberCall) => + astForStatementList( + 
StatementList(optionalStatementList.statements ++ List(body))(body.span) + ) + case body => + astForUnknown(body) + + private val accessModifierStack: mutable.Stack[String] = mutable.Stack.empty + + protected def currentAccessModifier: String = + accessModifierStack.headOption.getOrElse(ModifierTypes.PUBLIC) + + protected def pushAccessModifier(name: String): Unit = + accessModifierStack.push(name) + + protected def popAccessModifier(): Unit = + if accessModifierStack.nonEmpty then accessModifierStack.pop() + + private def shouldUseSurroundingTypeFullName: Boolean = + val inBodyMethodScope = + scope.surroundingScopeFullName.exists(x => + x.split("[.]").takeRight(1).contains(Defines.TypeDeclBody) + ) + + scope.surroundingAstLabel match + case Some(NodeTypes.METHOD) => inBodyMethodScope + case _ => false +end AstForFunctionsCreator diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForStatementsCreator.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForStatementsCreator.scala new file mode 100644 index 00000000..8f0c6664 --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForStatementsCreator.scala @@ -0,0 +1,444 @@ +package io.appthreat.ruby2atom.astcreation + +import io.appthreat.ruby2atom.astcreation.RubyIntermediateAst.{RubyStatement, *} +import io.appthreat.ruby2atom.datastructures.BlockScope +import io.appthreat.ruby2atom.parser.RubyJsonHelpers +import io.appthreat.ruby2atom.passes.Defines +import io.appthreat.ruby2atom.passes.Defines.getBuiltInType +import io.appthreat.x2cpg.datastructures.MethodLike +import io.appthreat.x2cpg.{Ast, ValidationMode} +import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewControlStructure} +import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, ModifierTypes, NodeTypes} + +trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode): + this: 
AstCreator => + + protected def astsForStatement(node: RubyExpression): Seq[Ast] = + baseAstCache.clear() // A safe approximation on where to reset the cache + node match + case node: IfExpression => astForIfStatement(node) + case node: OperatorAssignment => astForOperatorAssignment(node) + case node: CaseExpression => astsForCaseExpression(node) + case node: StatementList => astForStatementList(node) :: Nil + case node: ReturnExpression => astForReturnExpression(node) :: Nil + case node: AnonymousTypeDeclaration => astForAnonymousTypeDeclaration(node) :: Nil + case node: TypeDeclaration => astForClassDeclaration(node) + case node: FieldsDeclaration => astsForFieldDeclarations(node) + case node: AccessModifier => astForAccessModifier(node) + case node: MethodDeclaration => astForMethodDeclaration(node) + case node: MethodAccessModifier => astForMethodAccessModifier(node) + case node: SingletonMethodDeclaration => astForSingletonMethodDeclaration(node) + case node: MultipleAssignment => node.assignments.map(astForExpression) + case node: BreakExpression => astForBreakExpression(node) :: Nil + case node: SingletonStatementList => astForSingletonStatementList(node) + case node: AliasStatement => astForAliasStatement(node) + case _ => astForExpression(node) :: Nil + end astsForStatement + + private def astForIfStatement(node: IfExpression): Seq[Ast] = + def builder(node: IfExpression, conditionAst: Ast, thenAst: Ast, elseAsts: List[Ast]): Ast = + val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code(node)) + controlStructureAst(ifNode, Some(conditionAst), thenAst :: elseAsts) + + // TODO: Remove or modify the builder pattern when we are no longer using ANTLR + node.elseClause match + case Some(elseClause) => + elseClause match + case _: IfExpression => astForJsonIfStatement(node) + case _ => foldIfExpression(builder)(node) :: Nil + case None => + foldIfExpression(builder)(node) :: Nil + + private def astForOperatorAssignment(node: OperatorAssignment): 
Seq[Ast] = + val loweredAssignment = lowerAssignmentOperator(node.lhs, node.rhs, node.op, node.span) + astsForStatement(loweredAssignment) + + private def astForJsonIfStatement(node: IfExpression): Seq[Ast] = + val conditionAst = astForExpression(node.condition) + val thenAst = astForThenClause(node.thenClause) + val elseAsts = node.elseClause + .map { + case x: IfExpression => + val wrappedBlock = blockNode(x) + Ast(wrappedBlock).withChildren(astForJsonIfStatement(x)) :: Nil + case x => + astForElseClause(x) :: Nil + } + .getOrElse(Ast() :: Nil) + + val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code(node)) + controlStructureAst(ifNode, Some(conditionAst), thenAst +: elseAsts) :: Nil + + private def astForAccessModifier(node: AccessModifier): Seq[Ast] = + scope.surroundingAstLabel match + case Some(x) if x == NodeTypes.METHOD => + val simpleIdent = node.toSimpleIdentifier + astForSimpleCall(SimpleCall(simpleIdent, List.empty)(simpleIdent.span)) :: Nil + case _ => + registerAccessModifier(node) + + /** Registers the currently set access modifier for the current type (until it is reset later). + */ + private def registerAccessModifier(node: AccessModifier): Seq[Ast] = + val modifier = node match + case PrivateModifier() => ModifierTypes.PRIVATE + case ProtectedModifier() => ModifierTypes.PROTECTED + case PublicModifier() => ModifierTypes.PUBLIC + popAccessModifier() // pop off the current modifier in scope + pushAccessModifier(modifier) // push new one on + Nil + + // Rewrites a nested `if T_1 then E_1 elsif T_2 then E_2 elsif ... 
elsif T_n then E_n else E_{n+1}` + // as `B(T_1, E_1, B(T_2, E_2, ..., B(T_n, E_n, E_{n+1})..)` + protected def foldIfExpression(builder: (IfExpression, Ast, Ast, List[Ast]) => Ast)( + node: IfExpression + ): Ast = + val conditionAst = astForExpression(node.condition) + val thenAst = astForThenClause(node.thenClause) + val elseAsts = astsForElseClauses(node.elsifClauses, node.elseClause, foldIfExpression(builder)) + builder(node, conditionAst, thenAst, elseAsts) + + protected def astForThenClause(node: RubyExpression): Ast = + astForStatementList(node.asStatementList) + + private def astsForElseClauses( + elsIfClauses: List[RubyExpression], + elseClause: Option[RubyExpression], + astForIf: IfExpression => Ast + ): List[Ast] = + elsIfClauses match + case Nil => elseClause.map(astForElseClause).toList + case elsIfNode :: rest => + elsIfNode match + case elsIfNode: ElsIfClause => + val newIf = IfExpression( + elsIfNode.condition, + elsIfNode.thenClause, + rest, + elseClause + )(elsIfNode.span) + val wrappingBlock = blockNode(elsIfNode) + val wrappedAst = Ast(wrappingBlock).withChild(astForIf(newIf)) + wrappedAst :: Nil + case elsIfNode => + Nil + + protected def astForStatementList(node: StatementList): Ast = + val block = blockNode(node) + scope.pushNewScope(BlockScope(block)) + val statementAsts = node.statements.flatMap(astsForStatement) + scope.popScope() + blockAst(block, statementAsts) + + protected def astForDoBlock(block: Block & RubyExpression): Seq[Ast] = + if closureToRefs.contains(block) then + closureToRefs(block).map(x => Ast(x.copy)) + else + val methodName = nextClosureName() + // Create closure structures: [TypeRef, MethodRef] + val methodRefAsts = block.body match + case x: Block => + astForMethodDeclaration( + x.toMethodDeclaration(methodName, Option(block.parameters)), + isClosure = true + ) + case _ => + astForMethodDeclaration( + block.toMethodDeclaration(methodName, Option(block.parameters)), + isClosure = true + ) + closureToRefs.put(block, 
methodRefAsts.flatMap(_.root)) + methodRefAsts + + protected def astForReturnExpression(node: ReturnExpression): Ast = + val argumentAsts = node.expressions.map(astForExpression) + val returnNode_ = returnNode(node, code(node)) + returnAst(returnNode_, argumentAsts) + + protected def astForNextExpression(node: NextExpression): Ast = + val nextNode = NewControlStructure() + .controlStructureType(ControlStructureTypes.CONTINUE) + .lineNumber(line(node)) + .columnNumber(column(node)) + .code(code(node)) + Ast(nextNode) + + protected def astForStatementListReturningLastExpression(node: StatementList): Ast = + val block = blockNode(node) + scope.pushNewScope(BlockScope(block)) + + val stmtAsts = node.statements.size match + case 0 => List() + case n => + val (headStmts, lastStmt) = node.statements.splitAt(n - 1) + headStmts.flatMap(astsForStatement) ++ lastStmt.flatMap(astsForImplicitReturnStatement) + + scope.popScope() + blockAst(block, stmtAsts) + + private def astsForImplicitReturnStatement(node: RubyExpression): Seq[Ast] = + def elseReturnNil(span: TextSpan) = Option { + ElseClause( + StatementList( + ReturnExpression( + StaticLiteral(getBuiltInType(Defines.NilClass))(span.spanStart("nil")) :: Nil + )( + span.spanStart("return nil") + ) :: Nil + )(span.spanStart("return nil")) + )(span.spanStart("else\n\treturn nil\nend")) + } + + node match + case expr: ControlFlowStatement => + def transform(e: RubyExpression & ControlFlowStatement): RubyExpression = + transformLastRubyNodeInControlFlowExpressionBody( + e, + returnLastNode(_, transform), + elseReturnNil + ) + + expr match + case x @ OperatorAssignment(lhs, op, rhs) => + val loweredAssignment = lowerAssignmentOperator(lhs, rhs, op, x.span) + astsForStatement(transform(loweredAssignment)) + case x => + astsForStatement(transform(expr)) + case node: MemberCallWithBlock => returnAstForRubyCall(node) + case node: SimpleCallWithBlock => returnAstForRubyCall(node) + case _: (LiteralExpr | BinaryExpression | 
UnaryExpression | SimpleIdentifier | SelfIdentifier | IndexAccess |
+          Association | YieldExpr | RubyCall | RubyFieldIdentifier | HereDocNode | Unknown) =>
+        astForReturnExpression(ReturnExpression(List(node))(node.span)) :: Nil
+      case node: SingleAssignment =>
+        astForSingleAssignment(node) :: List(
+          astForReturnExpression(ReturnExpression(List(node.lhs))(node.span))
+        )
+      case node: DefaultMultipleAssignment =>
+        astsForStatement(node) ++ astsForImplicitReturnStatement(
+          ArrayLiteral(node.assignments.map(_.lhs))(node.span)
+        )
+      case node: GroupedParameterDesugaring =>
+        // If the desugaring is the last expression, then we should return nil.
+        // The synthetic nil uses the built-in type name, exactly as `elseReturnNil` does, so that
+        // every implicit `return nil` literal carries the same type full name.
+        val nilReturnSpan = node.span.spanStart("return nil")
+        val nilReturnLiteral = StaticLiteral(getBuiltInType(Defines.NilClass))(nilReturnSpan)
+        astsForStatement(node) ++ astsForImplicitReturnStatement(nilReturnLiteral)
+      case node: AttributeAssignment =>
+        List(
+          astForAttributeAssignment(node),
+          astForReturnFieldAccess(
+            MemberAccess(node.target, node.op, node.attributeName)(node.span)
+          )
+        )
+      case node: MemberAccess => astForReturnMemberCall(node) :: Nil
+      case ret: ReturnExpression => astForReturnExpression(ret) :: Nil
+      case node: (MethodDeclaration | SingletonMethodDeclaration) =>
+        (astsForStatement(node) :+ astForReturnMethodDeclarationSymbolName(node)).toList
+      case stmtList: StatementList
+          if stmtList.statements.lastOption.exists(_.isInstanceOf[ReturnExpression]) =>
+        stmtList.statements.map(astForExpression)
+      case StatementList(stmts) =>
+        val nilReturnSpan = node.span.spanStart("return nil")
+        val nilReturnLiteral = StaticLiteral(getBuiltInType(Defines.NilClass))(nilReturnSpan)
+        stmts.map(astForExpression) ++ astsForImplicitReturnStatement(nilReturnLiteral)
+      case x: RangeExpression =>
+        astForReturnRangeExpression(x) :: Nil
+      case node: AccessModifier =>
+        val simpleIdent = node.toSimpleIdentifier
+        val simpleCall = SimpleCall(simpleIdent, List.empty)(simpleIdent.span)
+        astForReturnExpression(ReturnExpression(List(simpleCall))(node.span)) :: Nil
+      case node: MethodAccessModifier =>
+        val simpleIdent = node.toSimpleIdentifier
+
+        val methodIdentName = node.method match
+          case x: StaticLiteral => x.span.text
+          case x: MethodDeclaration => x.methodName
+          case x =>
+            x.span.text
+
+        val methodIdent = SimpleIdentifier(None)(simpleIdent.span.spanStart(methodIdentName))
+
+        val simpleCall = SimpleCall(simpleIdent, List(methodIdent))(
+          simpleIdent.span.spanStart(s"${simpleIdent.span.text} ${methodIdent.span.text}")
+        )
+        astForReturnExpression(ReturnExpression(List(simpleCall))(node.span)) :: Nil
+      case node: FieldsDeclaration =>
+        val nilReturnSpan = node.span.spanStart("return nil")
+        val nilReturnLiteral = StaticLiteral(getBuiltInType(Defines.NilClass))(nilReturnSpan)
+        astsForFieldDeclarations(node) ++ astsForImplicitReturnStatement(nilReturnLiteral)
+      case node: SingletonClassDeclaration =>
+        // NOTE(review): the Ast returned here is discarded and only `return nil` is emitted;
+        // presumably the call is kept for its side effects - confirm against its definition.
+        astForAnonymousTypeDeclaration(node)
+        val nilReturnSpan = node.span.spanStart("return nil")
+        val nilReturnLiteral = StaticLiteral(getBuiltInType(Defines.NilClass))(nilReturnSpan)
+        astsForImplicitReturnStatement(nilReturnLiteral)
+      case node =>
+        astsForStatement(node).toList
+    end match
+  end astsForImplicitReturnStatement
+
+  // Lowers a call-with-block via `astForCallWithBlock` and wraps the result in a RETURN node
+  // whose code is the code of the call.
+  private def returnAstForRubyCall[C <: RubyCall](node: RubyExpression & RubyCallWithBlock[C])
+      : Seq[Ast] =
+    val callAst = astForCallWithBlock(node)
+    returnAst(returnNode(node, code(node)), List(callAst)) :: Nil
+
+  // Wraps the field access built by `astForFieldAccess` in a RETURN node.
+  private def astForReturnFieldAccess(node: MemberAccess): Ast =
+    returnAst(returnNode(node, code(node)), List(astForFieldAccess(node)))
+
+  // The evaluation of a MethodDeclaration returns its name in symbol form.
+  // E.g.
`def f = 0` ===> `:f` + private def astForReturnMethodDeclarationSymbolName(node: RubyExpression & ProcedureDeclaration) + : Ast = + val literalNode_ = literalNode(node, s":${node.methodName}", getBuiltInType(Defines.Symbol)) + val returnNode_ = returnNode(node, literalNode_.code) + returnAst(returnNode_, Seq(Ast(literalNode_))) + + private def astForReturnRangeExpression(node: RangeExpression): Ast = + returnAst(returnNode(node, code(node)), List(astForRange(node))) + + private def astForReturnMemberCall(node: MemberAccess): Ast = + returnAst(returnNode(node, code(node)), List(astForMemberAccess(node))) + + protected def astForBreakExpression(node: BreakExpression): Ast = + val _node = NewControlStructure() + .controlStructureType(ControlStructureTypes.BREAK) + .lineNumber(line(node)) + .columnNumber(column(node)) + .code(code(node)) + Ast(_node) + + protected def astForSingletonStatementList(list: SingletonStatementList): Seq[Ast] = + list.statements.map(astForExpression) + + /** Wraps the last RubyNode with a ReturnExpression. + * @param x + * the node to wrap a return around. If a StatementList is given, then the ReturnExpression + * will wrap around the final element. 
+ * @return + * the RubyNode with an explicit expression + */ + private def returnLastNode( + x: RubyExpression, + transform: (RubyExpression & ControlFlowStatement) => RubyExpression + ): RubyExpression = + def statementListReturningLastExpression(stmts: List[RubyExpression]): List[RubyExpression] = + stmts match + case (head: ControlFlowClause) :: Nil => clauseReturningLastExpression(head) :: Nil + case (head: ControlFlowStatement) :: Nil => transform(head) :: Nil + case (head: ReturnExpression) :: Nil => head :: Nil + case head :: Nil => ReturnExpression(head :: Nil)(head.span) :: Nil + case Nil => List.empty + case head :: tail => head :: statementListReturningLastExpression(tail) + + def clauseReturningLastExpression(x: RubyExpression & ControlFlowClause): RubyExpression = + x match + case RescueClause(exceptionClassList, assignment, thenClause) => + RescueClause(exceptionClassList, assignment, returnLastNode(thenClause, transform))( + x.span + ) + case EnsureClause(thenClause) => + EnsureClause(returnLastNode(thenClause, transform))(x.span) + case ElsIfClause(condition, thenClause) => + ElsIfClause(condition, returnLastNode(thenClause, transform))(x.span) + case ElseClause(thenClause) => ElseClause(returnLastNode(thenClause, transform))(x.span) + case WhenClause(matchExpressions, matchSplatExpression, thenClause) => + WhenClause( + matchExpressions, + matchSplatExpression, + returnLastNode(thenClause, transform) + )(x.span) + case InClause(pattern, body) => InClause(pattern, returnLastNode(body, transform))(x.span) + + x match + case StatementList(statements) => + StatementList(statementListReturningLastExpression(statements))(x.span) + case clause: ControlFlowClause => clauseReturningLastExpression(clause) + case node: ControlFlowStatement => transform(node) + case node: ReturnExpression => node + case _ => ReturnExpression(x :: Nil)(x.span) + end returnLastNode + + /** @param node + * \- Control Flow Expression RubyNode + * @param transform + * \- RubyNode 
=> RubyNode function for transformation on the clauses of the + * ControlFlowExpression + * @return + * RubyNode with transform function applied + */ + protected def transformLastRubyNodeInControlFlowExpressionBody( + node: RubyExpression & ControlFlowStatement, + transform: RubyExpression => RubyExpression, + defaultElseBranch: TextSpan => Option[ElseClause] + ): RubyExpression = + node match + case RescueExpression(body, rescueClauses, elseClause, ensureClause) => + // Ensure never returns a value, only the main body, rescue & else clauses + RescueExpression( + transform(body), + rescueClauses.map(transform).collect { case x: RescueClause => x }, + elseClause.map(transform).orElse(defaultElseBranch(node.span)).collect { + case x: ElseClause => x + }, + ensureClause + )(node.span) + case WhileExpression(condition, body) => + WhileExpression(condition, transform(body))(node.span) + case DoWhileExpression(condition, body) => + DoWhileExpression(condition, transform(body))(node.span) + case UntilExpression(condition, body) => + UntilExpression(condition, transform(body))(node.span) + case OperatorAssignment(lhs, op, rhs) => + val loweredNode = lowerAssignmentOperator(lhs, rhs, op, node.span) + transformLastRubyNodeInControlFlowExpressionBody( + loweredNode, + transform, + defaultElseBranch + ) + case IfExpression(condition, thenClause, elsifClauses, elseClause) => + IfExpression( + condition, + transform(thenClause), + elsifClauses.map(transform), + elseClause.map(transform).orElse(defaultElseBranch(node.span)) + )(node.span) + case UnlessExpression(condition, trueBranch, falseBranch) => + UnlessExpression( + condition, + transform(trueBranch), + falseBranch.map(transform).orElse(defaultElseBranch(node.span)) + )(node.span) + case ForExpression(forVariable, iterableVariable, doBlock) => + ForExpression(forVariable, iterableVariable, transform(doBlock))(node.span) + case CaseExpression(expression, whenClauses, elseClause) => + CaseExpression( + expression, + 
whenClauses.map(transform), + elseClause.map(transform).orElse(defaultElseBranch(node.span)) + )(node.span) + case next: NextExpression => next + case break: BreakExpression => break + + protected def astForAliasStatement(statement: AliasStatement): Seq[Ast] = + val aliasMethodDecl = generateAliasMethodDecl(statement) + // alias should always be lifted to the class decl + astForMethodDeclaration(aliasMethodDecl, useSurroundingTypeFullName = true) + + private def generateAliasMethodDecl(alias: AliasStatement): MethodDeclaration = + val span = alias.span + val forwardingCallTarget = SimpleIdentifier(None)(span.spanStart(alias.oldName)) + val forwardedArgs = + SplattingRubyNode(SimpleIdentifier()(span.spanStart("args")))(span.spanStart("*args")) + val forwardedBlock = SimpleIdentifier()(span.spanStart("&block")) + val forwardingCall = SimpleCall(forwardingCallTarget, forwardedArgs :: forwardedBlock :: Nil)( + span.spanStart(s"${alias.oldName}(*args, &block)") + ) + + val aliasMethodBody = StatementList(forwardingCall :: Nil)(forwardingCall.span) + val aliasingMethodParams = + ArrayParameter("*args")(span.spanStart("*args")) :: ProcParameter("&block")( + span.spanStart("&block") + ) :: Nil + + MethodDeclaration(alias.newName, aliasingMethodParams, aliasMethodBody)( + alias.span.spanStart(s"def ${alias.newName}(*args, &block)") + ) +end AstForStatementsCreator diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForTypesCreator.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForTypesCreator.scala new file mode 100644 index 00000000..e20d553b --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/AstForTypesCreator.scala @@ -0,0 +1,316 @@ +package io.appthreat.ruby2atom.astcreation + +import io.appthreat.ruby2atom.astcreation.RubyIntermediateAst.{TypeDeclaration, *} +import io.appthreat.ruby2atom.datastructures.{ + BlockScope, + MethodScope, + 
ModuleScope, + NamespaceScope, + TypeScope +} +import io.appthreat.ruby2atom.passes.Defines +import io.appthreat.x2cpg.utils.NodeBuilders.newModifierNode +import io.appthreat.x2cpg.{Ast, ValidationMode} +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.{ + DispatchTypes, + EdgeTypes, + EvaluationStrategies, + ModifierTypes, + NodeTypes, + Operators +} + +import scala.collection.immutable.List +import scala.collection.mutable + +trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + protected def astForClassDeclaration(node: RubyExpression & TypeDeclaration): Seq[Ast] = + node.name match + case name: SimpleIdentifier => astForSimpleNamedClassDeclaration(node, name) + case name => + astForUnknown(node) :: Nil + + private def getBaseClassName(node: RubyExpression): String = + node match + case simpleIdentifier: SimpleIdentifier => + simpleIdentifier.text + case _: SelfIdentifier => + Defines.Self + case qualifiedBaseClass: MemberAccess => + qualifiedBaseClass.text.replace("::", ".") + case qualifiedBaseClass: MemberCall => + qualifiedBaseClass.text.replace("::", ".") + case x => + x.text + + private def astForSimpleNamedClassDeclaration( + node: RubyExpression & TypeDeclaration, + nameIdentifier: SimpleIdentifier + ): Seq[Ast] = + val className = nameIdentifier.text + val inheritsFrom = node.baseClass.map(getBaseClassName).toList + pushAccessModifier(ModifierTypes.PUBLIC) + + /** Pushes new NamespaceScope onto scope stack and populates AST_PARENT_FULL_NAME and + * AST_PARENT_TYPE for TypeDecls that are declared in a namespace + * @param typeDecl + * \- TypeDecl node + * @param astParentFullName + * \- Fullname of AstParent + * @return + * typeDecl node with updated fields + */ + def populateAstParentValues(typeDecl: NewTypeDecl, astParentFullName: String): NewTypeDecl = + val namespaceBlockFullName = + 
s"${scope.surroundingScopeFullName.getOrElse("")}.$astParentFullName" + scope.pushNewScope(NamespaceScope(namespaceBlockFullName)) + + val namespaceBlock = + NewNamespaceBlock().name(astParentFullName).fullName(astParentFullName).filename( + relativeFileName + ) + + diffGraph.addNode(namespaceBlock) + + fileNode.foreach(diffGraph.addEdge(_, namespaceBlock, EdgeTypes.AST)) + + typeDecl.astParentFullName(astParentFullName) + typeDecl.astParentType(NodeTypes.NAMESPACE_BLOCK) + + typeDecl.fullName(computeFullName(className)) + typeDecl + + val (typeDecl, classFullName, shouldPopAdditionalScope) = node match + case x: NamespaceDeclaration if x.namespaceParts.isDefined => + val className = nameIdentifier.text + val typeDeclTemp = typeDeclNode( + node = node, + name = className, + fullName = Defines.Any, + filename = relativeFileName, + code = code(node), + inherits = inheritsFrom, + alias = None + ) + populateAstParentValues(typeDeclTemp, x.namespaceParts.get.mkString(".")) + val classFullName = typeDeclTemp.fullName + + (typeDeclTemp, classFullName, true) + case _ => + val classFullName = computeFullName(className) + val typeDeclTemp = typeDeclNode( + node = node, + name = className, + fullName = classFullName, + filename = relativeFileName, + code = code(node), + inherits = inheritsFrom, + alias = None + ) + scope.surroundingAstLabel.foreach(typeDeclTemp.astParentType(_)) + scope.surroundingScopeFullName.foreach(typeDeclTemp.astParentFullName(_)) + (typeDeclTemp, classFullName, false) + + /* + In Ruby, there are semantic differences between the ordinary class and singleton class (think "meta" class in + Python). Similar to how Java allows both static and dynamic methods/fields/etc. within the same type declaration, + Ruby allows `self` methods and @@ fields to be defined alongside ordinary methods and @ fields. However, both + classes are more dynamic and have separate behaviours in Ruby and we model it as such. + + To signify the singleton type, we add the tag. 
+     */
+    // The singleton counterpart must not share the ordinary class' name/full name, otherwise two
+    // TYPE_DECL nodes with identical full names are emitted. It is told apart by the `<class>` suffix.
+    val singletonTypeDecl = typeDecl.copy
+      .name(s"$className<class>")
+      .fullName(s"$classFullName<class>")
+      .inheritsFromTypeFullName(inheritsFrom.map(x => s"$x<class>"))
+
+    val (classModifiers, singletonModifiers) = node match
+      case _: ModuleDeclaration =>
+        scope.pushNewScope(ModuleScope(classFullName))
+        (
+          ModifierTypes.VIRTUAL :: Nil map newModifierNode map Ast.apply,
+          ModifierTypes.VIRTUAL :: ModifierTypes.FINAL :: Nil map newModifierNode map Ast.apply
+        )
+      case _: TypeDeclaration =>
+        scope.pushNewScope(TypeScope(classFullName, List.empty))
+        (
+          ModifierTypes.VIRTUAL :: Nil map newModifierNode map Ast.apply,
+          ModifierTypes.VIRTUAL :: Nil map newModifierNode map Ast.apply
+        )
+
+    val classBody =
+      node.body.asInstanceOf[StatementList] // for now (bodyStatement is a superset of stmtList)
+
+    val statementsToForwardUpTheAst = mutable.ArrayBuffer.empty[Ast]
+    def separateStatementsFromBody(ss: List[RubyExpression]) =
+      // There may be additional expression nodes introduced from nodes such as type decls, so we must
+      // re-distribute these back into the method
+      ss.flatMap {
+        case t: TypeDeclaration =>
+          val (typeDeclAsts, other) =
+            astsForStatement(t).partition(_.root.exists(_.isInstanceOf[NewTypeDecl]))
+          statementsToForwardUpTheAst.addAll(other)
+          typeDeclAsts
+        case n => astsForStatement(n)
+      }
+
+    val classBodyAsts =
+      val bodyAsts = separateStatementsFromBody(classBody.statements)
+      if scope.shouldGenerateDefaultConstructor && this.parseLevel == AstParseLevel.FULL_AST then
+        val bodyStart = classBody.span.spanStart()
+        val initBody = StatementList(List())(bodyStart)
+        val methodDecl = astForMethodDeclaration(MethodDeclaration(
+          Defines.Initialize,
+          List(),
+          initBody
+        )(bodyStart))
+        methodDecl ++ bodyAsts
+      else
+        bodyAsts
+
+    val fields = node match
+      case classDecl: ClassDeclaration => classDecl.fields
+      case moduleDecl: ModuleDeclaration => moduleDecl.fields
+      case _ => Seq.empty
+    // Instance fields become members of the ordinary type, all others of the singleton type.
+    val (fieldTypeMemberNodes, fieldSingletonMemberNodes) = fields
+      .map { x =>
+        val name = code(x)
+        x.isInstanceOf[InstanceFieldIdentifier] -> Ast(memberNode(x, name, name, Defines.Any))
+      }
+      .partition(_._1)
+
+    scope.popScope()
+
+    if scope.surroundingAstLabel.contains(NodeTypes.TYPE_DECL) then
+      // A nested type is exposed as a member whose type hint points at the nested singleton type.
+      val typeDeclMember = NewMember()
+        .name(className)
+        .code(className)
+        .dynamicTypeHintFullName(Seq(s"$classFullName<class>"))
+      diffGraph.addNode(typeDeclMember)
+
+    val prefixAst = createTypeRefPointer(typeDecl)
+    val typeDeclAst = Ast(typeDecl)
+      .withChildren(classModifiers)
+      .withChildren(fieldTypeMemberNodes.map(_._2))
+      .withChildren(classBodyAsts)
+    val singletonTypeDeclAst =
+      Ast(singletonTypeDecl)
+        .withChildren(singletonModifiers)
+        .withChildren(fieldSingletonMemberNodes.map(_._2))
+    val bodyMemberCallAst =
+      node.bodyMemberCall match
+        case Some(bodyMemberCall) => astForTypeDeclBodyCall(bodyMemberCall, classFullName)
+        case None => Ast()
+
+    (typeDeclAst :: singletonTypeDeclAst :: Nil).foreach(Ast.storeInDiffGraph(_, diffGraph))
+
+    if shouldPopAdditionalScope then scope.popScope()
+    popAccessModifier()
+    prefixAst :: bodyMemberCallAst :: statementsToForwardUpTheAst.toList
+  end astForSimpleNamedClassDeclaration
+
+  // Lowers the type-declaration body call as a static member call and rewrites the method full
+  // name of the first call node named `Defines.TypeDeclBody` to `typeFullName.TypeDeclBody`.
+  private def astForTypeDeclBodyCall(node: TypeDeclBodyCall, typeFullName: String): Ast =
+    val callAst = astForMemberCall(node.toMemberCall, isStatic = true)
+    callAst.nodes.collectFirst {
+      case c: NewCall if c.name == Defines.TypeDeclBody =>
+        c.methodFullName(s"$typeFullName.${Defines.TypeDeclBody}")
+    }
+    callAst
+
+  // When surrounded by the program scope, builds the assignment `self.TypeName = TYPE_REF`;
+  // otherwise yields an empty Ast.
+  private def createTypeRefPointer(typeDecl: NewTypeDecl): Ast =
+    if scope.isSurroundedByProgramScope then
+      // We aim to preserve whether it's a `class` or `module` in the `code` property
+      val typeRefCode = s"${typeDecl.code.strip().takeWhile(_ != ' ')} ${typeDecl.name} (...)"
+      val typeRefNode = Ast(
+        NewTypeRef()
+          .code(typeRefCode)
+          .typeFullName(
+            s"${typeDecl.fullName}<class>"
+          ) // Everything will be dispatched on the singleton
+          .lineNumber(typeDecl.lineNumber) +
.columnNumber(typeDecl.columnNumber) + ) + + val typeRefIdent = + val self = NewIdentifier().name(Defines.Self).code(Defines.Self).typeFullName(Defines.Any) + val fi = NewFieldIdentifier() + .code(typeDecl.name) + .canonicalName(typeDecl.name) + .lineNumber(typeDecl.lineNumber) + .columnNumber(typeDecl.columnNumber) + val fieldAccess = NewCall() + .name(Operators.fieldAccess) + .code(s"${Defines.Self}.${typeDecl.name}") + .methodFullName(Operators.fieldAccess) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .typeFullName(Defines.Any) + val selfAst = scope + .lookupVariable(Defines.Self) + .map(selfParam => Ast(self).withRefEdge(self, selfParam)) + .getOrElse(Ast(self)) + callAst(fieldAccess, Seq(selfAst, Ast(fi))) + astForAssignment(typeRefIdent, typeRefNode, typeDecl.lineNumber, typeDecl.columnNumber) + else + Ast() + + protected def astsForFieldDeclarations(node: FieldsDeclaration): Seq[Ast] = + node.fieldNames.flatMap(astsForSingleFieldDeclaration(node, _)) + + private def astsForSingleFieldDeclaration( + node: FieldsDeclaration, + nameNode: RubyExpression + ): Seq[Ast] = + nameNode match + case nameAsSymbol: StaticLiteral if nameAsSymbol.isSymbol => + val fieldName = nameAsSymbol.innerText.prepended('@') + val memberNode_ = memberNode(nameAsSymbol, fieldName, code(node), Defines.Any) + val memberAst = Ast(memberNode_) + val getterAst = + Option.when(node.hasGetter)(astForGetterMethod(node, fieldName)).getOrElse(Nil) + val setterAst = + Option.when(node.hasSetter)(astForSetterMethod(node, fieldName)).getOrElse(Nil) + Seq(memberAst) ++ getterAst ++ setterAst + case nameAsIdent: SimpleIdentifier => + val fieldName = nameAsIdent.span.text.prepended('@') + val memberNode_ = memberNode(nameAsIdent, fieldName, code(node), Defines.Any) + val memberAst = Ast(memberNode_) + val getterAst = + Option.when(node.hasGetter)(astForGetterMethod(node, fieldName)).getOrElse(Nil) + val setterAst = + Option.when(node.hasSetter)(astForSetterMethod(node, fieldName)).getOrElse(Nil) + 
Seq(memberAst) ++ getterAst ++ setterAst
+      case _ =>
+        Seq()
+
+  // creates a `def <name>() { return <fieldName> }` METHOD, where <fieldName> = @<name>.
+  private def astForGetterMethod(node: FieldsDeclaration, fieldName: String): Seq[Ast] =
+    val name = fieldName.drop(1)
+    val code = s"def $name (...)"
+    val methodDecl = MethodDeclaration(
+      name,
+      Nil,
+      StatementList(InstanceFieldIdentifier()(node.span.spanStart(fieldName)) :: Nil)(
+        node.span.spanStart(s"return $fieldName")
+      )
+    )(node.span.spanStart(code))
+    astForMethodDeclaration(methodDecl, useSurroundingTypeFullName = true)
+
+  // creates a `def <name>=(x) { <fieldName> = x }` METHOD, where <fieldName> = @<name>.
+  // NOTE(review): the body's span text below is `return <fieldName>` although the body is an
+  // assignment - looks carried over from the getter; confirm whether `<fieldName> = x` was meant.
+  private def astForSetterMethod(node: FieldsDeclaration, fieldName: String): Seq[Ast] =
+    val name = fieldName.drop(1) + "="
+    val code = s"def $name (...)"
+    val assignment = SingleAssignment(
+      InstanceFieldIdentifier()(node.span.spanStart(fieldName)),
+      "=",
+      SimpleIdentifier()(node.span.spanStart("x"))
+    )(node.span.spanStart(s"$fieldName = x"))
+    val methodDecl = MethodDeclaration(
+      name,
+      MandatoryParameter("x")(node.span.spanStart("x")) :: Nil,
+      StatementList(assignment :: Nil)(node.span.spanStart(s"return $fieldName"))
+    )(node.span.spanStart(code))
+    astForMethodDeclaration(methodDecl, useSurroundingTypeFullName = true)
+end AstForTypesCreator
diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/RubyIntermediateAst.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/RubyIntermediateAst.scala
new file mode 100644
index 00000000..55f3c7fa
--- /dev/null
+++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/astcreation/RubyIntermediateAst.scala
@@ -0,0 +1,646 @@
+package io.appthreat.ruby2atom.astcreation
+
+import io.appthreat.ruby2atom.passes.{Defines, GlobalTypes}
+import io.shiftleft.codepropertygraph.generated.nodes.NewNode
+
+import java.util.Objects
+
+object RubyIntermediateAst:
+
+  case class TextSpan(
+      line: Option[Integer],
+      column: Option[Integer], +
lineEnd: Option[Integer], + columnEnd: Option[Integer], + offset: Option[(Integer, Integer)], + text: String + ): + def spanStart(newText: String = ""): TextSpan = + TextSpan(line, column, line, column, offset, newText) + + /** Most-if-not-all constructs in Ruby evaluate to some value, so we name the base class + * `RubyExpression`. + */ + sealed class RubyExpression(val span: TextSpan): + def line: Option[Integer] = span.line + + def column: Option[Integer] = span.column + + def lineEnd: Option[Integer] = span.lineEnd + + def columnEnd: Option[Integer] = span.columnEnd + + def offset: Option[(Integer, Integer)] = span.offset + + def text: String = span.text + + override def hashCode(): Int = Objects.hash(span) + + override def equals(obj: Any): Boolean = + obj match + case o: RubyExpression => o.span == span + case _ => false + + /** Ruby statements evaluate to some value (and thus are expressions), but also perform some + * operation, e.g., assignments, method definitions, etc. + */ + sealed trait RubyStatement extends RubyExpression + + implicit class RubyExpressionHelper(node: RubyExpression): + def asStatementList: StatementList = node match + case stmtList: StatementList => stmtList + case _ => StatementList(List(node))(node.span) + + final case class Unknown()(span: TextSpan) extends RubyExpression(span) + + final case class StatementList(statements: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) + with RubyStatement: + override def text: String = statements.size match + case 0 | 1 => span.text + case _ => "(...)" + + def size: Int = statements.size + + final case class SingletonStatementList(statements: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) + with RubyStatement: + override def text: String = statements.size match + case 0 | 1 => span.text + case _ => "(...)" + + def size: Int = statements.size + + sealed trait AllowedTypeDeclarationChild + + sealed trait TypeDeclaration extends AllowedTypeDeclarationChild 
with RubyStatement: + def name: RubyExpression + def baseClass: Option[RubyExpression] + def body: RubyExpression + def bodyMemberCall: Option[TypeDeclBodyCall] + + sealed trait NamespaceDeclaration extends RubyStatement: + def namespaceParts: Option[List[String]] + + final case class ModuleDeclaration( + name: RubyExpression, + body: RubyExpression, + fields: List[RubyExpression & RubyFieldIdentifier], + bodyMemberCall: Option[TypeDeclBodyCall], + namespaceParts: Option[List[String]] + )(span: TextSpan) + extends RubyExpression(span) + with TypeDeclaration + with NamespaceDeclaration: + def baseClass: Option[RubyExpression] = None + + final case class ClassDeclaration( + name: RubyExpression, + baseClass: Option[RubyExpression], + body: RubyExpression, + fields: List[RubyExpression & RubyFieldIdentifier], + bodyMemberCall: Option[TypeDeclBodyCall], + namespaceParts: Option[List[String]] + )(span: TextSpan) + extends RubyExpression(span) + with TypeDeclaration + with NamespaceDeclaration + + sealed trait AnonymousTypeDeclaration extends RubyExpression with TypeDeclaration + + final case class AnonymousClassDeclaration( + name: RubyExpression, + baseClass: Option[RubyExpression], + body: RubyExpression, + bodyMemberCall: Option[TypeDeclBodyCall] = None + )(span: TextSpan) + extends RubyExpression(span) + with AnonymousTypeDeclaration + + final case class SingletonClassDeclaration( + name: RubyExpression, + baseClass: Option[RubyExpression], + body: RubyExpression, + bodyMemberCall: Option[TypeDeclBodyCall] = None + )(span: TextSpan) + extends RubyExpression(span) + with AnonymousTypeDeclaration + + final case class FieldsDeclaration(fieldNames: List[RubyExpression], accessType: String)( + span: TextSpan + ) extends RubyExpression(span) + with AllowedTypeDeclarationChild: + def hasGetter: Boolean = text.startsWith("attr_reader") || text.startsWith("attr_accessor") + def hasSetter: Boolean = text.startsWith("attr_writer") || text.startsWith("attr_accessor") + + def 
isSplattingFieldDecl: Boolean = + fieldNames.length == 1 && fieldNames.head.isInstanceOf[SplattingRubyNode] + + sealed trait ProcedureDeclaration extends RubyStatement: + def methodName: String + def parameters: List[RubyExpression] + def body: RubyExpression + + final case class MethodDeclaration( + methodName: String, + parameters: List[RubyExpression], + body: RubyExpression + )( + span: TextSpan + ) extends RubyExpression(span) + with ProcedureDeclaration + with AllowedTypeDeclarationChild + + final case class SingletonMethodDeclaration( + target: RubyExpression, + methodName: String, + parameters: List[RubyExpression], + body: RubyExpression + )(span: TextSpan) + extends RubyExpression(span) + with ProcedureDeclaration + with AllowedTypeDeclarationChild + + final case class SingletonObjectMethodDeclaration( + methodName: String, + parameters: List[RubyExpression], + body: RubyExpression, + baseClass: RubyExpression + )(span: TextSpan) + extends RubyExpression(span) + with ProcedureDeclaration + + sealed trait MethodParameter: + def name: String + + final case class MandatoryParameter(name: String)(span: TextSpan) extends RubyExpression(span) + with MethodParameter: + def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier()(span) + + final case class OptionalParameter(name: String, defaultExpression: RubyExpression)( + span: TextSpan + ) extends RubyExpression(span) + with MethodParameter + + final case class GroupedParameter( + name: String, + tmpParam: RubyExpression, + multipleAssignment: GroupedParameterDesugaring + )(span: TextSpan) + extends RubyExpression(span) + with MethodParameter + + sealed trait CollectionParameter extends MethodParameter + + final case class ArrayParameter(name: String)(span: TextSpan) extends RubyExpression(span) + with CollectionParameter + + final case class HashParameter(name: String)(span: TextSpan) extends RubyExpression(span) + with CollectionParameter + + final case class ProcParameter(name: String)(span: TextSpan) 
extends RubyExpression(span) + with MethodParameter + + final case class SingleAssignment(lhs: RubyExpression, op: String, rhs: RubyExpression)( + span: TextSpan + ) extends RubyExpression(span) + with RubyStatement + + trait MultipleAssignment extends RubyStatement: + def assignments: List[SingleAssignment] + + final case class OperatorAssignment(lhs: RubyExpression, op: String, rhs: RubyExpression)( + span: TextSpan + ) extends RubyExpression(span) + with RubyStatement + with ControlFlowStatement + + final case class DefaultMultipleAssignment(assignments: List[SingleAssignment])(span: TextSpan) + extends RubyExpression(span) + with MultipleAssignment + + final case class GroupedParameterDesugaring(assignments: List[SingleAssignment])(span: TextSpan) + extends RubyExpression(span) + with MultipleAssignment + + final case class SplattingRubyNode(target: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + + final case class AttributeAssignment( + target: RubyExpression, + op: String, + attributeName: String, + assignmentOperator: String, + rhs: RubyExpression + )(span: TextSpan) + extends RubyExpression(span) + + /** Any structure that conditionally modifies the control flow of the program. These also behave + * as statements. + */ + sealed trait ControlFlowStatement extends RubyStatement + + /** A control structure's clause, which may contain additional control structures. + */ + sealed trait ControlFlowClause + + /** Any structure that is an Identifier, except self. e.g. 
`a`, `@a`, `@@a` + */ + sealed trait RubyIdentifier extends RubyExpression: + override def toString: String = span.text + + /** Ruby Instance or Class Variable Identifiers: `@a`, `@@a` + */ + sealed trait RubyFieldIdentifier extends RubyIdentifier: + def toMemberAccess: MemberAccess = + MemberAccess(SelfIdentifier()(span), ".", span.text)(span) + + sealed trait SingletonMethodIdentifier + + final case class RescueExpression( + body: RubyExpression, + rescueClauses: List[RescueClause], + elseClause: Option[ElseClause], + ensureClause: Option[EnsureClause] + )(span: TextSpan) + extends RubyExpression(span) + with ControlFlowStatement + + final case class RescueClause( + exceptionClassList: Option[RubyExpression], + variables: Option[RubyExpression], + thenClause: RubyExpression + )(span: TextSpan) + extends RubyExpression(span) + with ControlFlowClause + + final case class EnsureClause(thenClause: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with ControlFlowClause + + final case class WhileExpression(condition: RubyExpression, body: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with ControlFlowStatement + + final case class DoWhileExpression(condition: RubyExpression, body: RubyExpression)( + span: TextSpan + ) extends RubyExpression(span) + with ControlFlowStatement + + final case class UntilExpression(condition: RubyExpression, body: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with ControlFlowStatement + + final case class IfExpression( + condition: RubyExpression, + thenClause: RubyExpression, + elsifClauses: List[RubyExpression], + elseClause: Option[RubyExpression] + )(span: TextSpan) + extends RubyExpression(span) + with ControlFlowStatement + with RubyStatement + + final case class ElsIfClause(condition: RubyExpression, thenClause: RubyExpression)( + span: TextSpan + ) extends RubyExpression(span) + with ControlFlowClause + + final case class ElseClause(thenClause: RubyExpression)(span: TextSpan) + 
extends RubyExpression(span) + with ControlFlowClause + + final case class UnlessExpression( + condition: RubyExpression, + trueBranch: RubyExpression, + falseBranch: Option[RubyExpression] + )(span: TextSpan) + extends RubyExpression(span) + with ControlFlowStatement + + final case class ForExpression( + forVariable: RubyExpression, + iterableVariable: RubyExpression, + doBlock: RubyExpression + )(span: TextSpan) + extends RubyExpression(span) + with ControlFlowStatement + + final case class CaseExpression( + expression: Option[RubyExpression], + matchClauses: List[RubyExpression], + elseClause: Option[RubyExpression] + )(span: TextSpan) + extends RubyExpression(span) + with ControlFlowStatement + + final case class WhenClause( + matchExpressions: List[RubyExpression], + matchSplatExpression: Option[RubyExpression], + thenClause: RubyExpression + )(span: TextSpan) + extends RubyExpression(span) + with ControlFlowClause + + final case class InClause(pattern: RubyExpression, body: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with ControlFlowClause + + final case class ArrayPattern(children: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) + + final case class MatchVariable()(span: TextSpan) extends RubyExpression(span) + + final case class NextExpression()(span: TextSpan) extends RubyExpression(span) + with ControlFlowStatement + + final case class BreakExpression()(span: TextSpan) extends RubyExpression(span) + with ControlFlowStatement + + final case class ReturnExpression(expressions: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) + with RubyStatement + + /** Represents an unqualified identifier e.g. `X`, `x`, `@@x`, `$x`, `$<`, etc. 
*/ + final case class SimpleIdentifier(typeFullName: Option[String] = None)(span: TextSpan) + extends RubyExpression(span) + with RubyIdentifier + with SingletonMethodIdentifier: + override def toString: String = s"SimpleIdentifier(${span.text}, $typeFullName)" + + /** Represents a type reference successfully determined, e.g. module A; end; A + */ + final case class TypeIdentifier(typeFullName: String)(span: TextSpan) + extends RubyExpression(span) + with RubyIdentifier: + def isBuiltin: Boolean = typeFullName.startsWith(GlobalTypes.builtinPrefix) + override def toString: String = s"TypeIdentifier(${span.text}, $typeFullName)" + + /** Represents an instance field identifier, e.g. `@x` */ + final case class InstanceFieldIdentifier()(span: TextSpan) extends RubyExpression(span) + with RubyFieldIdentifier + + /** Represents a class field identifier, e.g. `@@x` */ + final case class ClassFieldIdentifier()(span: TextSpan) extends RubyExpression(span) + with RubyFieldIdentifier + + final case class SelfIdentifier()(span: TextSpan) extends RubyExpression(span) + with SingletonMethodIdentifier + + /** Represents some kind of literal expression. + */ + sealed trait LiteralExpr: + def typeFullName: String + + /** Represents a non-interpolated literal. 
*/ + final case class StaticLiteral(typeFullName: String)(span: TextSpan) extends RubyExpression(span) + with LiteralExpr: + def isSymbol: Boolean = text.startsWith(":") + + def isString: Boolean = text.startsWith("\"") || text.startsWith("'") + + def innerText: String = + val strRegex = "['\"]([./:]{0,3}[\\w\\d_-]+)(?:\\.rb)?['\"]".r + text match + case s":'$content'" => content + case s":$symbol" => symbol + case strRegex(content) if content != null => content + case s => s + + final case class DynamicLiteral(typeFullName: String, expressions: List[RubyExpression])( + span: TextSpan + ) extends RubyExpression(span) + with LiteralExpr + + final case class RangeExpression( + lowerBound: RubyExpression, + upperBound: RubyExpression, + rangeOperator: RangeOperator + )(span: TextSpan) + extends RubyExpression(span) + + final case class RangeOperator(exclusive: Boolean)(span: TextSpan) extends RubyExpression(span) + + final case class ArrayLiteral(elements: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) + with LiteralExpr: + def isSymbolArray: Boolean = text.take(2).toLowerCase.startsWith("%i") + + def isStringArray: Boolean = text.take(2).toLowerCase.startsWith("%w") + + def isDynamic: Boolean = text.take(2).startsWith("%I") || text.take(2).startsWith("%W") + + def isStatic: Boolean = !isDynamic + + def typeFullName: String = Defines.getBuiltInType(Defines.Array) + + sealed trait HashLike extends RubyExpression with LiteralExpr: + def elements: List[RubyExpression] + def typeFullName: String = Defines.getBuiltInType(Defines.Hash) + + final case class HashLiteral(elements: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) + with HashLike + + final case class Association(key: RubyExpression, value: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + + final case class AssociationList(elements: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) + with HashLike + + /** Represents a call. 
+ */ + sealed trait RubyCall extends RubyExpression: + def target: RubyExpression + def arguments: List[RubyExpression] + def withBlock(block: Block): RubyCallWithBlock[?] = + SimpleCallWithBlock(target, arguments, block)(span) + + /** Represents traditional calls, e.g. `foo`, `foo x, y`, `foo(x,y)` */ + final case class SimpleCall(target: RubyExpression, arguments: List[RubyExpression])( + span: TextSpan + ) extends RubyExpression(span) + with RubyCall + + final case class RequireCall( + target: RubyExpression, + argument: RubyExpression, + isRelative: Boolean = false, + isWildCard: Boolean = false + )(span: TextSpan) + extends RubyExpression(span) + with RubyCall: + def arguments: List[RubyExpression] = List(argument) + def asSimpleCall: SimpleCall = SimpleCall(target, arguments)(span) + + final case class IncludeCall(target: RubyExpression, argument: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with RubyCall: + def arguments: List[RubyExpression] = List(argument) + def asSimpleCall: SimpleCall = SimpleCall(target, arguments)(span) + + final case class RaiseCall(target: RubyExpression, arguments: List[RubyExpression])( + span: TextSpan + ) extends RubyExpression(span) + with RubyCall + + sealed trait AccessModifier extends AllowedTypeDeclarationChild: + def toSimpleIdentifier: SimpleIdentifier + + sealed trait MethodAccessModifier extends AllowedTypeDeclarationChild: + def toSimpleIdentifier: SimpleIdentifier + def method: RubyExpression + + final case class PublicModifier()(span: TextSpan) extends RubyExpression(span) + with AccessModifier: + override def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier(None)(span) + + final case class PrivateModifier()(span: TextSpan) extends RubyExpression(span) + with AccessModifier: + override def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier(None)(span) + + final case class ProtectedModifier()(span: TextSpan) extends RubyExpression(span) + with AccessModifier: + override def 
toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier(None)(span) + + final case class PrivateMethodModifier(method: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with MethodAccessModifier: + override def toSimpleIdentifier: SimpleIdentifier = + SimpleIdentifier(None)(span.spanStart("private_class_method")) + + final case class PublicMethodModifier(method: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with MethodAccessModifier: + override def toSimpleIdentifier: SimpleIdentifier = + SimpleIdentifier(None)(span.spanStart("public_class_method")) + + /** Represents standalone `proc { ... }` or `lambda { ... }` expressions + */ + final case class ProcOrLambdaExpr(block: Block)(span: TextSpan) extends RubyExpression(span) + + final case class YieldExpr(arguments: List[RubyExpression])(span: TextSpan) + extends RubyExpression(span) + + /** Represents a call with a block argument. + */ + sealed trait RubyCallWithBlock[C <: RubyCall] extends RubyCall: + + def block: Block + + def withoutBlock: RubyExpression & C + + final case class SimpleCallWithBlock( + target: RubyExpression, + arguments: List[RubyExpression], + block: Block + )( + span: TextSpan + ) extends RubyExpression(span) + with RubyCallWithBlock[SimpleCall]: + def withoutBlock: SimpleCall = SimpleCall(target, arguments)(span) + + /** Represents member calls, e.g. `x.y(z,w)` */ + final case class MemberCall( + target: RubyExpression, + op: String, + methodName: String, + arguments: List[RubyExpression] + )( + span: TextSpan + ) extends RubyExpression(span) + with RubyCall: + override def withBlock(block: Block): RubyCallWithBlock[?] = + MemberCallWithBlock(target, op, methodName, arguments, block)(span) + + /** Special class for calls to the `Defines.TypeDeclBody` method of type decls. 
+ */ + final case class TypeDeclBodyCall(target: RubyExpression, typeName: String)(span: TextSpan) + extends RubyExpression(span) + with RubyCall: + + def toMemberCall: MemberCall = MemberCall(target, op, Defines.TypeDeclBody, arguments)(span) + + def arguments: List[RubyExpression] = Nil + + def op: String = "::" + + final case class MemberCallWithBlock( + target: RubyExpression, + op: String, + methodName: String, + arguments: List[RubyExpression], + block: Block + )(span: TextSpan) + extends RubyExpression(span) + with RubyCallWithBlock[MemberCall]: + def withoutBlock: MemberCall = MemberCall(target, op, methodName, arguments)(span) + + /** Represents index accesses, e.g. `x[0]`, `self.x.y[1, 2]` */ + final case class IndexAccess(target: RubyExpression, indices: List[RubyExpression])( + span: TextSpan + ) extends RubyExpression(span) + + final case class MemberAccess(target: RubyExpression, op: String, memberName: String)( + span: TextSpan + ) extends RubyExpression(span): + override def toString: String = s"${target.text}${op}$memberName" + + /** A Ruby node that instantiates objects. + */ + sealed trait ObjectInstantiation extends RubyCall + + final case class SimpleObjectInstantiation( + target: RubyExpression, + arguments: List[RubyExpression] + )(span: TextSpan) + extends RubyExpression(span) + with ObjectInstantiation: + override def withBlock(block: Block): RubyCallWithBlock[SimpleObjectInstantiation] = + ObjectInstantiationWithBlock(target, arguments, block)(span) + + final case class ObjectInstantiationWithBlock( + target: RubyExpression, + arguments: List[RubyExpression], + block: Block + )( + span: TextSpan + ) extends RubyExpression(span) + with ObjectInstantiation + with RubyCallWithBlock[SimpleObjectInstantiation]: + def withoutBlock: SimpleObjectInstantiation = SimpleObjectInstantiation(target, arguments)(span) + + /** Represents a `do` or `{ .. }` (braces) block. 
*/ + final case class Block(parameters: List[RubyExpression], body: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + with RubyStatement: + + def toStatementList: StatementList = StatementList(body :: Nil)(span) + + def toMethodDeclaration( + name: String, + parameters: Option[List[RubyExpression]] + ): MethodDeclaration = + parameters match + case Some(givenParameters) => MethodDeclaration(name, givenParameters, body)(span) + case None => MethodDeclaration(name, this.parameters, body)(span) + + /** A dummy class for wrapping around `NewNode` and allowing it to integrate with RubyNode + * classes. + */ + final case class DummyNode(node: NewNode)(span: TextSpan) extends RubyExpression(span) + + final case class UnaryExpression(op: String, expression: RubyExpression)(span: TextSpan) + extends RubyExpression(span) + + final case class BinaryExpression(lhs: RubyExpression, op: String, rhs: RubyExpression)( + span: TextSpan + ) extends RubyExpression(span) + + final case class HereDocNode(content: String)(span: TextSpan) extends RubyExpression(span) + + final case class AliasStatement(oldName: String, newName: String)(span: TextSpan) + extends RubyExpression(span) + with AllowedTypeDeclarationChild +end RubyIntermediateAst diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/datastructures/RubyProgramSummary.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/datastructures/RubyProgramSummary.scala new file mode 100644 index 00000000..785aac73 --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/datastructures/RubyProgramSummary.scala @@ -0,0 +1,199 @@ +package io.appthreat.ruby2atom.datastructures + +import better.files.File +import io.appthreat.x2cpg.Defines as XDefines +import io.appthreat.x2cpg.datastructures.{ + FieldLike, + MethodLike, + ProgramSummary, + StubbedType, + TypeLike +} +import io.appthreat.x2cpg.typestub.{TypeStubMetaData, TypeStubUtil} +import 
io.appthreat.ruby2atom.passes.Defines +import upickle.default.* + +import java.io.{ByteArrayInputStream, InputStream} +import java.util.zip.ZipInputStream +import scala.annotation.targetName +import scala.collection.mutable +import scala.collection.mutable.ListBuffer +import scala.util.{Failure, Success, Try} + +type NamespaceToTypeMap = mutable.Map[String, mutable.Set[RubyType]] + +class RubyProgramSummary( + initialNamespaceMap: NamespaceToTypeMap = mutable.Map.empty, + initialPathMap: NamespaceToTypeMap = mutable.Map.empty +) extends ProgramSummary[RubyType, RubyMethod, RubyField]: + + override val namespaceToType: NamespaceToTypeMap = initialNamespaceMap + val pathToType: NamespaceToTypeMap = initialPathMap + + @targetName("appendAll") + def ++=(other: RubyProgramSummary): RubyProgramSummary = + RubyProgramSummary( + ProgramSummary.merge(this.namespaceToType, other.namespaceToType), + ProgramSummary.merge(this.pathToType, other.pathToType) + ) + +object RubyProgramSummary: + + def BuiltinTypes(implicit typeStubMetaData: TypeStubMetaData): NamespaceToTypeMap = + val typeStubDir = File(typeStubMetaData.packagePath) + if !typeStubDir.exists || !typeStubDir.isDirectory then + mutable.Map.empty + else if typeStubMetaData.useTypeStubs then + mpkZipToInitialMapping(mergeBuiltinMpkZip) match + case Failure(exception) => mutable.Map.empty + case Success(mapping) => mapping + else + mutable.Map.empty + + private def mpkZipToInitialMapping(inputStream: InputStream): Try[NamespaceToTypeMap] = + Try(readBinary[NamespaceToTypeMap](inputStream.readAllBytes())) + + private def mergeBuiltinMpkZip(implicit typeStubMetaData: TypeStubMetaData): InputStream = + val classLoader = getClass.getClassLoader + val typeStubDir = TypeStubUtil.typeStubDir + + val typeStubFiles: Seq[File] = + typeStubDir + .walk() + .filter(f => + f.isRegularFile && f.name.startsWith("rubysrc") && f.`extension`.contains(".zip") + ) + .toSeq + + if typeStubFiles.isEmpty then + InputStream.nullInputStream() + 
else + val mergedMpksObj = ListBuffer[collection.mutable.Map[String, Set[RubyStubbedType]]]() + typeStubFiles.foreach { f => + f.fileInputStream { fis => + val zis = new ZipInputStream(fis) + + LazyList.continually(zis.getNextEntry).takeWhile(_ != null).foreach { file => + val mpkObj = + upickle.default.readBinary[collection.mutable.Map[ + String, + Set[RubyStubbedType] + ]](zis.readAllBytes()) + mergedMpksObj.addOne(mpkObj) + } + } + } + + val mergedMpks = mergedMpksObj + .reduceOption((prev, curr) => + curr.keys.foreach(key => + prev.updateWith(key) { + case Some(x) => + Option(x ++ curr(key)) + case None => + Option(curr(key)) + } + ) + prev + ) + .getOrElse(collection.mutable.Map[String, Set[RubyStubbedType]]()) + + new ByteArrayInputStream(upickle.default.writeBinary(mergedMpks)) + end if + end mergeBuiltinMpkZip +end RubyProgramSummary + +case class RubyMethod( + name: String, + parameterTypes: List[(String, String)], + returnType: String, + baseTypeFullName: Option[String] +) extends MethodLike + +object RubyMethod: + implicit val rubyMethodRwJson: ReadWriter[RubyMethod] = readwriter[ujson.Value].bimap[RubyMethod]( + x => ujson.Obj("name" -> x.name), + json => + RubyMethod( + name = json("name").str, + parameterTypes = List.empty, + returnType = XDefines.Any, + baseTypeFullName = Option(json("name").str.split("\\.").dropRight(1).mkString(".")) + ) + ) + +case class RubyField(name: String, typeName: String) extends FieldLike derives ReadWriter + +class RubyStubbedType(name: String, methods: List[RubyMethod], fields: List[RubyField]) + extends RubyType(name, methods, fields) + with StubbedType[RubyMethod, RubyField] + +object RubyStubbedType: + implicit val rubyTypeRw: ReadWriter[RubyStubbedType] = + readwriter[ujson.Value].bimap[RubyStubbedType]( + x => + ujson.Obj( + "name" -> x.name, + "methods" -> x.methods.map { method => + ujson.Obj("name" -> method.name) + }, + "fields" -> x.fields.map { field => write[RubyField](field) } + ), + json => + 
RubyStubbedType( + name = json("name").str, + methods = json.obj.get("methods") match + case Some(jsonMethods) => + val methodsList = read[List[RubyMethod]](jsonMethods) + + methodsList.map { func => + val baseTypeFullName = json("name").str + + func.copy(name = func.name, baseTypeFullName = Option(baseTypeFullName)) + } + case None => Nil + , + fields = json.obj.get("fields").map(read[List[RubyField]](_)).getOrElse(Nil) + ) + ) +end RubyStubbedType + +case class RubyType(name: String, methods: List[RubyMethod], fields: List[RubyField]) + extends TypeLike[RubyMethod, RubyField]: + + @targetName("add") + override def +(o: TypeLike[RubyMethod, RubyField]): TypeLike[RubyMethod, RubyField] = + this.copy(methods = mergeMethods(o), fields = mergeFields(o)) + + def hasConstructor: Boolean = + methods.exists(_.name == Defines.Initialize) + +object RubyType: + implicit val rubyTypeRw: ReadWriter[RubyType] = readwriter[ujson.Value].bimap[RubyType]( + x => + ujson.Obj( + "name" -> x.name, + "methods" -> x.methods.map { method => + ujson.Obj("name" -> method.name) + }, + "fields" -> x.fields.map { field => write[RubyField](field) } + ), + json => + RubyType( + name = json("name").str, + methods = json.obj.get("methods") match + case Some(jsonMethods) => + val methodsList = read[List[RubyMethod]](jsonMethods) + + methodsList.map { func => + val splitName = func.name.split("\\.") + val baseTypeFullName = splitName.dropRight(1).mkString(".") + + func.copy(name = func.name, baseTypeFullName = Option(baseTypeFullName)) + } + case None => Nil + , + fields = json.obj.get("fields").map(read[List[RubyField]](_)).getOrElse(Nil) + ) + ) +end RubyType diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/datastructures/RubyScope.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/datastructures/RubyScope.scala new file mode 100644 index 00000000..26b4a87a --- /dev/null +++ 
b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/datastructures/RubyScope.scala @@ -0,0 +1,417 @@ +package io.appthreat.ruby2atom.datastructures + +import better.files.File +import io.appthreat.ruby2atom.passes.GlobalTypes +import io.appthreat.ruby2atom.passes.GlobalTypes.builtinPrefix +import io.appthreat.x2cpg.Defines +import io.appthreat.x2cpg.datastructures.{TypedScopeElement, *} +import io.shiftleft.codepropertygraph.generated.NodeTypes +import io.shiftleft.codepropertygraph.generated.nodes.{ + DeclarationNew, + NewLocal, + NewMethodParameterIn +} +import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal + +import java.io.File as JFile +import scala.collection.mutable +import scala.reflect.ClassTag +import scala.util.Try + +class RubyScope(summary: RubyProgramSummary, projectRoot: Option[String]) + extends Scope[String, DeclarationNew, TypedScopeElement] + with TypedScope[RubyMethod, RubyField, RubyType](summary): + + private val builtinMethods = GlobalTypes.kernelFunctions + .map(m => RubyMethod(m, List.empty, Defines.Any, Some(GlobalTypes.kernelPrefix))) + .toList + + override val typesInScope: mutable.Set[RubyType] = + mutable.Set(RubyType(GlobalTypes.kernelPrefix, builtinMethods, List.empty)) + + // Add some built-in methods that are significant + // TODO: Perhaps create an offline pre-built list of methods + typesInScope.addAll( + Seq( + RubyType( + s"$builtinPrefix.Array", + List(RubyMethod( + "[]", + List.empty, + s"$builtinPrefix.Array", + Option(s"$builtinPrefix.Array") + )), + List.empty + ), + RubyType( + s"$builtinPrefix.Hash", + List(RubyMethod("[]", List.empty, s"$builtinPrefix.Hash", Option(s"$builtinPrefix.Hash"))), + List.empty + ) + ) + ) + + override val membersInScope: mutable.Set[MemberLike] = mutable.Set(builtinMethods*) + + /** @return + * using the stack, will initialize a new module scope object. 
+ */ + def newProgramScope: Option[ProgramScope] = + surroundingScopeFullName.map(_.stripSuffix(NamespaceTraversal.globalNamespaceName)).map( + ProgramScope.apply + ) + + /** @return + * true if the top of the stack is the program/module. + */ + def isSurroundedByProgramScope: Boolean = + stack + .take(2) + .filterNot { + case ScopeElement(BlockScope(_), _) => true + case _ => false + } + .headOption match + case Some(ScopeElement(ProgramScope(_), _)) => true + case _ => false + + def pushField(field: FieldDecl): Unit = + popScope().foreach { + case TypeScope(fullName, fields) => + pushNewScope(TypeScope(fullName, fields :+ field)) + case x => + pushField(field) + pushNewScope(x) + } + + def getFieldsInScope: List[FieldDecl] = + stack.collect { case ScopeElement(TypeScope(_, fields), _) => fields }.flatten + + def findFieldInScope(fieldName: String): Option[FieldDecl] = + getFieldsInScope.find(_.name == fieldName) + + override def pushNewScope(scopeNode: TypedScopeElement): Unit = + // Use the summary to determine if there is a constructor present + val mappedScopeNode = scopeNode match + case n: NamespaceLikeScope => + typesInScope.addAll(summary.typesUnderNamespace(n.fullName)) + n + case n: ProgramScope => + typesInScope.addAll(summary.typesUnderNamespace(n.fullName)) + n + case TypeScope(name, _) => + typesInScope.addAll(summary.matchingTypes(name)) + scopeNode + case _ => scopeNode + + super.pushNewScope(mappedScopeNode) + + /** Variables entering children scope persist into parent scopes, so the variables should be + * transferred to the top-level method, returning the next block so that locals can be attached. 
+    */
+  override def addToScope(identifier: String, variable: DeclarationNew): TypedScopeElement =
+    // Position of the innermost method-like or program scope; it owns every non-parameter variable.
+    lazy val ownerIdx = stack.indexWhere {
+      case ScopeElement(_: MethodLikeScope | _: ProgramScope, _) => true
+      case _ => false
+    }
+    if variable.isInstanceOf[NewMethodParameterIn] || ownerIdx < 0 then
+      super.addToScope(identifier, variable)
+    else
+      val updatedOwner = stack(ownerIdx).addVariable(identifier, variable)
+      stack = stack.updated(ownerIdx, updatedOwner)
+      // Hand back the scope sitting directly on top of the owner, or the owner itself if none.
+      if ownerIdx > 0 then stack(ownerIdx - 1).scopeNode else updatedOwner.scopeNode
+
+  /** @return
+    *   the declarations bound to `identifier` in every scope except the innermost one.
+    */
+  def lookupVariableInOuterScope(identifier: String): List[DeclarationNew] =
+    stack.drop(1).flatMap(_.variables.get(identifier))
+
+  /** Brings the types defined at a required path into scope, falling back to a gem-name prefix
+    * match when the summary has no types for that path.
+    *
+    * @param projectRoot
+    *   root directory; stripped from resolved paths.
+    * @param currentFilePath
+    *   the file containing the require call, used to resolve relative requires.
+    * @param requiredPath
+    *   the path as written in the require call.
+    * @param isRelative
+    *   whether `requiredPath` is relative to `currentFilePath`.
+    * @param isWildCard
+    *   whether every entry of the resolved directory should be imported.
+    */
+  def addRequire(
+    projectRoot: String,
+    currentFilePath: String,
+    requiredPath: String,
+    isRelative: Boolean,
+    isWildCard: Boolean = false
+  ): Unit =
+    val rootPrefix = s"$projectRoot${JFile.separator}"
+    // Sometimes the require call provides a processed path
+    val path = requiredPath.stripSuffix(":")
+    // We assume the project root is the sole LOAD_PATH of the project sources
+    // NB: Tracking whatever has been added to $LOADER is dynamic and requires post-processing step!
+    val resolvedPath =
+      if !isRelative then path
+      else
+        Try((File(currentFilePath).parent / path).pathAsString)
+          .map(_.stripPrefix(rootPrefix))
+          .getOrElse(path)
+
+    val pathsToImport: List[String] =
+      if !isWildCard then List(resolvedPath)
+      else
+        val dir = File(projectRoot) / resolvedPath
+        if !dir.isDirectory then Nil
+        else
+          dir.list.map { entry =>
+            entry.pathAsString.stripPrefix(rootPrefix).stripSuffix(".rb").replaceAll("\\\\", "/")
+          }.toList
+
+    for pathName <- pathsToImport do
+      // Pull in type / module defs
+      val definedTypes = summary.pathToType.getOrElse(pathName, Set())
+      if definedTypes.nonEmpty then
+        definedTypes.foreach(ty => addImportedTypeOrModule(ty.name))
+        addImportedFunctions(pathName)
+      else addRequireGem(path)
+  end addRequire
+
+  /** Adds every summarised type whose name starts with `importName` to the types in scope.
+    */
+  def addImportedFunctions(importName: String): Unit =
+    val imported =
+      summary.namespaceToType.values.flatMap(_.filter(_.name.startsWith(importName)))
+    typesInScope.addAll(imported)
+
+  def addInclude(typeOrModule: String): Unit =
+    addImportedMember(typeOrModule)
+
+  /** Adds every summarised type whose name starts with `gemName` to the types in scope.
+    */
+  def addRequireGem(gemName: String): Unit =
+    typesInScope.addAll(
+      for
+        types <- summary.namespaceToType.values
+        ty <- types
+        if ty.name.startsWith(gemName)
+      yield ty
+    )
+
+  /** @return
+    *   the full name of the surrounding scope.
+    */
+  def surroundingScopeFullName: Option[String] =
+    stack.iterator.map(_.scopeNode).collectFirst {
+      case ns: NamespaceLikeScope => ns.fullName
+      case ty: TypeLikeScope => ty.fullName
+      case method: MethodLikeScope => method.fullName
+    }
+
+  /** Locates a position in the stack matching a partial function, modifies it and emits a result
+    * @param pf
+    *   Tests ScopeElements of the stack.
If they match, return the new value and the result to emit
+    * @return
+    *   the emitted result if the position was found and modified
+    */
+  def updateSurrounding[T](
+    pf: PartialFunction[
+      ScopeElement[String, DeclarationNew, TypedScopeElement],
+      (ScopeElement[String, DeclarationNew, TypedScopeElement], T)
+    ]
+  ): Option[T] =
+    val idx = stack.indexWhere(pf.isDefinedAt)
+    Option.when(idx >= 0) {
+      val (replacement, result) = pf(stack(idx))
+      stack = stack.updated(idx, replacement)
+      result
+    }
+
+  /** Get the name of the implicit or explicit proc param and mark the method scope as using the
+    * proc param
+    */
+  def useProcParam: Option[String] = updateSurrounding {
+    case ScopeElement(scope: MethodScope, variables) =>
+      (ScopeElement(scope.copy(hasYield = true), variables), scope.procParam.merge)
+    case ScopeElement(scope: ConstructorScope, variables) =>
+      (ScopeElement(scope.copy(hasYield = true), variables), scope.procParam.merge)
+  }
+
+  /** Get the name of the implicit (`Left`) proc param of the nearest method-like scope that still
+    * has one
+    */
+  def anonProcParam: Option[String] =
+    stack.iterator
+      .collect { case ScopeElement(method: MethodLikeScope, _) => method.procParam }
+      .collectFirst { case Left(name) => name }
+
+  /** Set the name of explicit proc param */
+  def setProcParam(param: String, paramNode: NewMethodParameterIn): Unit = updateSurrounding {
+    case ScopeElement(scope: MethodScope, variables) =>
+      (
+        ScopeElement(
+          scope.copy(procParam = Right(param), hasYield = true),
+          variables + (paramNode.name -> paramNode)
+        ),
+        ()
+      )
+    case ScopeElement(scope: ConstructorScope, variables) =>
+      (
+        ScopeElement(
+          scope.copy(procParam = Right(param), hasYield = true),
+          variables + (paramNode.name -> paramNode)
+        ),
+        ()
+      )
+  }
+
+  /** If a proc param is used, provides the node to add to the AST.
+    */
+  def procParamName: Option[NewMethodParameterIn] =
+    val yieldingScope = stack.iterator.map(_.scopeNode).collectFirst {
+      case method: MethodLikeScope if method.hasYield => method
+    }
+    yieldingScope.flatMap { method =>
+      lookupVariable(method.procParam.merge).collect { case param: NewMethodParameterIn => param }
+    }
+
+  /** @return
+    *   the full name of the nearest enclosing type-like scope.
+    */
+  def surroundingTypeFullName: Option[String] =
+    stack.iterator.map(_.scopeNode).collectFirst { case ty: TypeLikeScope => ty.fullName }
+
+  /** Searches the surrounding classes for a class that matches the given value. Returns it if
+    * found.
+    */
+  def getSurroundingType(value: String): Option[TypeLikeScope] =
+    val wanted = value.split('.').toSeq
+    stack.iterator.map(_.scopeNode).collectFirst {
+      case ty: TypeLikeScope if ty.fullName.split('.').toSeq.endsWith(wanted) => ty
+    }
+
+  /** @return
+    *   the corresponding node label according to the scope element.
+    */
+  def surroundingAstLabel: Option[String] =
+    stack.iterator.map(_.scopeNode).collectFirst {
+      case _: NamespaceLikeScope => NodeTypes.NAMESPACE_BLOCK
+      // Must precede the TypeLikeScope case: a program scope is type-like but labelled a method.
+      case _: ProgramScope => NodeTypes.METHOD
+      case _: TypeLikeScope => NodeTypes.TYPE_DECL
+      case _: MethodLikeScope => NodeTypes.METHOD
+    }
+
+  /** @return
+    *   the nearest scope node of type `T` on the stack.
+    */
+  def surrounding[T <: TypedScopeElement](implicit tag: ClassTag[T]): Option[T] =
+    stack.iterator.map(_.scopeNode).collectFirst { case elem: T => elem }
+
+  /** @return
+    *   true if one should still generate a default constructor for the enclosing type decl.
+    */
+  def shouldGenerateDefaultConstructor: Boolean =
+    // Only the innermost scope is consulted.
+    stack.headOption.exists {
+      case ScopeElement(_: ModuleScope, _) => false
+      case ScopeElement(ty: TypeLikeScope, _) =>
+        !typesInScope.find(_.name == ty.fullName).exists(_.hasConstructor)
+      case _ => false
+    }
+
+  /** When a singleton class is introduced into the scope, the base variable will now have the
+    * singleton's functionality mixed in. This method finds base variable and appends the singleton
+    * type.
+    *
+    * @param singletonClassName
+    *   the singleton type full name.
+    * @param variableName
+    *   the base variable
+    */
+  def pushSingletonClassDeclaration(singletonClassName: String, variableName: String): Unit =
+    lookupVariable(variableName).foreach {
+      case local: NewLocal =>
+        local.possibleTypes(local.possibleTypes :+ singletonClassName)
+      case param: NewMethodParameterIn =>
+        param.possibleTypes(param.possibleTypes :+ singletonClassName)
+      case _ =>
+    }
+
+  /** Prefers a type in scope whose name equals the method's base type full name, deferring to the
+    * parent implementation otherwise.
+    */
+  override def typeForMethod(m: RubyMethod): Option[RubyType] =
+    typesInScope.find(t => Option(t.name) == m.baseTypeFullName).orElse { super.typeForMethod(m) }
+
+  /** Resolves a (possibly shorthand) type name. Lookup order: non-stubbed types in scope, stubbed
+    * types in scope, the parent implementation, then the kernel-function and bundled-class tables.
+    *
+    * @param typeName
+    *   the type name as written, `::` or `.` separated.
+    * @return
+    *   the type meta-data if found
+    */
+  override def tryResolveTypeReference(typeName: String): Option[RubyType] =
+    val normalizedTypeName = typeName.replaceAll("::", ".")
+    // Computed once: the dot-separated segments a candidate's name has to end with.
+    // NOTE(review): the in-scope lookups use the raw `typeName`, not the normalized one, so a
+    // `::`-separated name presumably only resolves via the parent lookup — confirm the intent.
+    val wantedSegments = typeName.split("[.]").toSeq
+
+    /** Attempts to resolve the full name using types currently in scope.
+      * @param stubbed
+      *   whether to consider only stubbed types (true) or only non-stubbed types (false).
+      */
+    def resolveInScope(stubbed: Boolean): Option[RubyType] =
+      typesInScope.collectFirst {
+        case typ
+            if typ.isInstanceOf[RubyStubbedType] == stubbed &&
+              typ.name.split("[.]").toSeq.endsWith(wantedSegments) => typ
+      }
+
+    // TODO: While we find better ways to understand how the implicit class loading works,
+    //  we can approximate that all types are in scope in the mean time.
+    resolveInScope(stubbed = false)
+      .orElse(resolveInScope(stubbed = true))
+      .orElse(super.tryResolveTypeReference(normalizedTypeName))
+      .orElse {
+        if GlobalTypes.kernelFunctions.contains(normalizedTypeName) then
+          Option(RubyType(
+            s"${GlobalTypes.kernelPrefix}.$normalizedTypeName",
+            List.empty,
+            List.empty
+          ))
+        else if GlobalTypes.bundledClasses.contains(normalizedTypeName) then
+          Option(RubyType(
+            s"${GlobalTypes.builtinPrefix}.$normalizedTypeName",
+            List.empty,
+            List.empty
+          ))
+        else None
+      }
+  end tryResolveTypeReference
+
+  /** @param identifier
+    *   the name of the variable.
+    * @return
+    *   the full name of the variable's scope, if available. `None` when the variable is unknown or
+    *   when the declaring scope carries no full name (e.g. a block scope).
+    */
+  def variableScopeFullName(identifier: String): Option[String] =
+    stack
+      .find(_.variables.contains(identifier))
+      .map(_.scopeNode)
+      // `collect` rather than `map`: a scope kind without a full name yields None, not a MatchError.
+      .collect {
+        case x: NamespaceLikeScope => x.fullName
+        case x: TypeLikeScope => x.fullName
+        case x: MethodLikeScope => x.fullName
+      }
+end RubyScope
diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/datastructures/ScopeElement.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/datastructures/ScopeElement.scala
new file mode 100644
index 00000000..2769b5ed
--- /dev/null
+++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/datastructures/ScopeElement.scala
@@ -0,0 +1,73 @@
+package io.appthreat.ruby2atom.datastructures
+
+import io.appthreat.ruby2atom.astcreation.RubyIntermediateAst.{RubyFieldIdentifier, RubyExpression}
+import io.appthreat.ruby2atom.passes.Defines
+import io.appthreat.x2cpg.datastructures.{NamespaceLikeScope, TypedScopeElement}
+import io.shiftleft.codepropertygraph.generated.nodes.NewBlock
+
+/** The namespace.
+  * @param fullName
+  *   the namespace path.
+  */
+case class NamespaceScope(fullName: String) extends NamespaceLikeScope
+
+/** A field declaration recorded on a [[TypeScope]].
+  *
+  * @param name
+  *   the field name; this is what field lookups in the scope compare against.
+  * @param typeFullName
+  *   the full name of the field's type.
+  * @param isStatic
+  *   presumably true for class-level (as opposed to instance-level) fields — confirm with callers.
+  * @param isInitialized
+  *   presumably whether an initial value was seen for the field — confirm with callers.
+  * @param node
+  *   the AST expression identifying the field.
+  */
+case class FieldDecl(
+  name: String,
+  typeFullName: String,
+  isStatic: Boolean,
+  isInitialized: Boolean,
+  node: RubyExpression & RubyFieldIdentifier
+) extends TypedScopeElement
+
+/** A type-like scope with a full name.
+  */
+trait TypeLikeScope extends TypedScopeElement:
+
+  /** @return
+    *   the full name of the type-like.
+    */
+  def fullName: String
+
+/** A file-level module. Its full name is the file name with `Defines.Main` appended.
+  *
+  * @param fileName
+  *   the relative file name.
+  */
+case class ProgramScope(fileName: String) extends TypeLikeScope:
+  override def fullName: String = s"$fileName${Defines.Main}"
+
+/** A Ruby module/abstract class.
+  * @param fullName
+  *   the type full name.
+  */
+case class ModuleScope(fullName: String) extends TypeLikeScope
+
+/** A class or interface.
+  *
+  * @param fullName
+  *   the type full name.
+  * @param fields
+  *   the fields declared on the type so far.
+  */
+case class TypeScope(fullName: String, fields: List[FieldDecl]) extends TypeLikeScope
+
+/** Represents scope objects that map to a method node.
+  */
+trait MethodLikeScope extends TypedScopeElement:
+  /** The full name of the method. */
+  def fullName: String
+
+  /** The proc (block) parameter name: `Left` while it is implicit, `Right` once an explicit name
+    * has been set (see `RubyScope.anonProcParam` and `RubyScope.setProcParam`).
+    */
+  def procParam: Either[String, String]
+
+  /** Whether the proc parameter has been marked as used; `RubyScope.useProcParam` and
+    * `RubyScope.setProcParam` set this to true.
+    */
+  def hasYield: Boolean
+
+/** The scope of a method.
+  */
+case class MethodScope(
+  fullName: String,
+  procParam: Either[String, String],
+  hasYield: Boolean = false
+) extends MethodLikeScope
+
+/** The scope of a constructor; carries the same data as [[MethodScope]].
+  */
+case class ConstructorScope(
+  fullName: String,
+  procParam: Either[String, String],
+  hasYield: Boolean = false
+) extends MethodLikeScope
+
+/** Represents scope objects that map to a block node.
+  */
+case class BlockScope(block: NewBlock) extends TypedScopeElement
diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyAstGenRunner.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyAstGenRunner.scala
new file mode 100644
index 00000000..ef7e9c6e
--- /dev/null
+++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyAstGenRunner.scala
@@ -0,0 +1,88 @@
+package io.appthreat.ruby2atom.parser
+
+import better.files.File
+import io.appthreat.ruby2atom.Config
+import io.appthreat.x2cpg.SourceFiles
+import io.appthreat.x2cpg.astgen.AstGenRunner.{
+  AstGenProgramMetaData,
+  AstGenRunnerResult,
+  DefaultAstGenRunnerResult
+}
+import io.appthreat.x2cpg.astgen.AstGenRunnerBase
+import io.appthreat.x2cpg.utils.{Environment, ExternalCommand}
+
+import java.io.File.separator
+import java.io.{ByteArrayOutputStream, InputStream, PrintStream}
+import java.nio.file.{Files, Path, Paths, StandardCopyOption}
+import java.util
+import java.util.jar.JarFile
+import scala.collection.mutable
+import scala.jdk.CollectionConverters.*
+import scala.util.{Failure, Success, Try, Using}
+
+/** Runs the external `rbastgen` tool over the configured input path and collects the JSON ASTs it
+  * writes.
+  */
+class RubyAstGenRunner(config: Config) extends AstGenRunnerBase(config):
+
+  /** Decides whether a generated JSON file is kept. The output path is mapped back onto the input
+    * path before the user and default ignore rules are applied.
+    */
+  override def fileFilter(file: String, out: File): Boolean =
+    file.stripSuffix(".json").replace(out.pathAsString, config.inputPath) match
+      case filePath if isIgnoredByUserConfig(filePath) => false
+      case filePath if isIgnoredByDefaultRegex(filePath) => false
+      case _ => true
+
+  private def isIgnoredByDefaultRegex(filePath: String): Boolean =
+    config.defaultIgnoredFilesRegex.exists(_.matches(filePath))
+
+  /** Invokes `rbastgen` and gathers the `.json` files it produced under `out`.
+    *
+    * @param in
+    *   the input directory; also passed to `ExternalCommand.run` as its second argument.
+    * @param out
+    *   the directory the tool writes to.
+    * @param exclude
+    *   an exclusion regex handed to the tool via `-e`; omitted when empty.
+    * @param include
+    *   unused by this runner.
+    * @return
+    *   the filtered files on success, an empty result if the command failed.
+    */
+  override def runAstGenNative(in: String, out: File, exclude: String, include: String)(implicit
+    metaData: AstGenProgramMetaData
+  ): AstGenRunnerResult =
+    // NOTE(review): the command line is assembled by string interpolation, so paths containing
+    // spaces or an exclude regex containing a single quote may be mis-tokenised — confirm how
+    // ExternalCommand.run splits its argument.
+    val command = s"rbastgen -i $in -o ${out.pathAsString}"
+    val excludeArgs = if exclude.isEmpty then "" else s" -e '$exclude'"
+    ExternalCommand.run(s"$command$excludeArgs", in, true) match
+      case Success(_) =>
+        val srcFiles = SourceFiles.determine(
+          out.pathAsString,
+          Set(".json"),
+          ignoredDefaultRegex = Option(config.defaultIgnoredFilesRegex),
+          ignoredFilesRegex = Option(config.ignoredFilesRegex),
+          ignoredFilesPath = Option(config.ignoredFiles)
+        )
+        val parsed = filterFiles(srcFiles, out)
+        DefaultAstGenRunnerResult(parsed, List.empty)
+      case Failure(_) =>
+        // Best effort: a failed tool run yields an empty result rather than an exception.
+        DefaultAstGenRunnerResult()
+
+  /** Combines the user and default ignore patterns into one exclusion regex and runs the tool.
+    */
+  override def execute(out: File): AstGenRunnerResult =
+    implicit val metaData: AstGenProgramMetaData = config.astGenMetaData
+    val userRegex = config.ignoredFilesRegex.toString()
+    // Join the default patterns before testing for emptiness: `toString` on the collection itself
+    // is never empty, which used to produce `((user)|())` when there were no default patterns.
+    val defaultRegex = config.defaultIgnoredFilesRegex.mkString("|")
+    val combineIgnoreRegex =
+      if userRegex.isEmpty then defaultRegex
+      else if defaultRegex.isEmpty then userRegex
+      else s"(($userRegex)|($defaultRegex))"
+
+    runAstGenNative(config.inputPath, out, combineIgnoreRegex, "")
+
+  private sealed trait ExecutionEnvironment extends AutoCloseable:
+    def path: Path
+
+    def close(): Unit = {}
+
+  /** A directory that is deleted recursively when closed.
+    */
+  private case class TempDir(path: Path) extends ExecutionEnvironment:
+
+    override def close(): Unit =
+      def cleanUpDir(f: Path): Unit =
+        if Files.isDirectory(f) then
+          // Files.list keeps a directory handle open until the stream is closed, so snapshot the
+          // entries and release the stream before recursing.
+          val children = Using.resource(Files.list(f))(_.iterator.asScala.toList)
+          children.foreach(cleanUpDir)
+        Files.deleteIfExists(f)
+
+      cleanUpDir(path)
+
+  /** A directory that is left in place when closed.
+    */
+  private case class LocalDir(path: Path) extends ExecutionEnvironment
+end RubyAstGenRunner
diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyJsonAst.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyJsonAst.scala
new file mode 100644
index 00000000..0bbed7d6
--- /dev/null
+++ 
b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyJsonAst.scala
@@ -0,0 +1,207 @@
+package io.appthreat.ruby2atom.parser
+
+import io.shiftleft.codepropertygraph.generated.Operators
+import upickle.default.*
+
+/** The JSON key values, in roughly alphabetical order.
+  */
+object ParserKeys:
+  val Alias = "alias"
+  val Arguments = "arguments"
+  val As = "as"
+  val Base = "base"
+  val Body = "body"
+  val Bodies = "bodies"
+  val Call = "call"
+  val CallName = "call_name"
+  val CaseExpression = "case_expression"
+  val Children = "children"
+  val Code = "code"
+  val Collection = "collection"
+  val Condition = "condition"
+  val Conditions = "conditions"
+  val Def = "def"
+  val ElseClause = "else_clause"
+  val ElseBranch = "else_branch"
+  val End = "end"
+  val ExecList = "exec_list"
+  val ExecVar = "exec_var"
+  val FilePath = "file_path"
+  val Guard = "guard"
+  val Key = "key"
+  val Left = "left"
+  val Lhs = "lhs"
+  val MetaData = "meta_data"
+  val Name = "name"
+  val Op = "op"
+  val ParamIdx = "param_idx"
+  val Pattern = "pattern"
+  val RelFilePath = "rel_file_path"
+  val Receiver = "receiver"
+  val Right = "right"
+  val Rhs = "rhs"
+  val Statement = "statement"
+  val Start = "start"
+  val SuperClass = "superclass"
+  val ThenBranch = "then_branch"
+  val Type = "type"
+  val Value = "value"
+  val Values = "values"
+  val Variable = "variable"
+  val WhenClauses = "when_clauses"
+end ParserKeys
+
+/** The node types of the JSON AST.
+  *
+  * @param name
+  *   the type tag as it appears in the JSON; `AstType.fromString` looks cases up by this value.
+  */
+enum AstType(val name: String):
+  case Alias extends AstType("alias")
+  case And extends AstType("and")
+  case AndAssign extends AstType("and_asgn")
+  case Arg extends AstType("arg")
+  case Args extends AstType("args")
+  case Array extends AstType("array")
+  case ArrayPattern extends AstType("array_pattern")
+  case ArrayPatternWithTail extends AstType("array_pattern_with_tail")
+  case BackRef extends AstType("back_ref")
+  case Begin extends AstType("begin")
+  case Block extends AstType("block")
+  case BlockArg extends AstType("blockarg")
+  case BlockPass extends AstType("block_pass")
+  case BlockWithNumberedParams extends AstType("numblock")
+  case Break extends AstType("break")
+  case CaseExpression extends AstType("case")
+  case CaseMatchStatement extends AstType("case_match")
+  case ClassDefinition extends AstType("class")
+  case ClassVariable extends AstType("cvar")
+  case ClassVariableAssign extends AstType("cvasgn")
+  case ConstVariableAssign extends AstType("casgn")
+  case ConditionalSend extends AstType("csend")
+  case Defined extends AstType("defined?")
+  case DynamicString extends AstType("dstr")
+  case DynamicSymbol extends AstType("dsym")
+  case Ensure extends AstType("ensure")
+  case ExclusiveFlipFlop extends AstType("eflipflop")
+  case ExclusiveRange extends AstType("erange")
+  case ExecutableString extends AstType("xstr")
+  case False extends AstType("false")
+  case FindPattern extends AstType("find_pattern")
+  case Float extends AstType("float")
+  case ForStatement extends AstType("for")
+  case ForPostStatement extends AstType("for_post")
+  case ForwardArg extends AstType("forward_arg")
+  case ForwardArgs extends AstType("forward_args")
+  case ForwardedArgs extends AstType("forwarded_args")
+  case GlobalVariable extends AstType("gvar")
+  case GlobalVariableAssign extends AstType("gvasgn")
+  case Hash extends AstType("hash")
+  case HashPattern extends AstType("hash_pattern")
+  case Identifier extends AstType("ident")
+  case IfGuard extends AstType("if_guard")
+  case IfStatement extends AstType("if")
+  case InclusiveFlipFlop extends AstType("iflipflop")
+  case InclusiveRange extends AstType("irange")
+  case InPattern extends AstType("in_pattern")
+  case Int extends AstType("int")
+  case InstanceVariable extends AstType("ivar")
+  case InstanceVariableAssign extends AstType("ivasgn")
+  case KwArg extends AstType("kwarg")
+  case KwBegin extends AstType("kwbegin")
+  case KwNilArg extends AstType("kwnilarg")
+  case KwOptArg extends AstType("kwoptarg")
+  case KwRestArg extends AstType("kwrestarg")
+  case KwSplat extends AstType("kwsplat")
+  case LocalVariable extends AstType("lvar")
+  case LocalVariableAssign extends AstType("lvasgn")
+  case MatchAlt extends AstType("match_alt")
+  case MatchAs extends AstType("match_as")
+  case MatchNilPattern extends AstType("match_nil_pattern")
+  case MatchPattern extends AstType("match_pattern")
+  case MatchPatternP extends AstType("match_pattern_p")
+  case MatchRest extends AstType("match_rest")
+  case MatchVariable extends AstType("match_var")
+  case MatchWithLocalVariableAssign extends AstType("match_with_lvasgn")
+  case MethodDefinition extends AstType("def")
+  case ModuleDefinition extends AstType("module")
+  case MultipleAssignment extends AstType("masgn")
+  case MultipleLeftHandSide extends AstType("mlhs")
+  case Next extends AstType("next")
+  case Nil extends AstType("nil")
+  case NthRef extends AstType("nth_ref")
+  case OperatorAssign extends AstType("op_asgn")
+  case OptionalArgument extends AstType("optarg")
+  case Or extends AstType("or")
+  case OrAssign extends AstType("or_asgn")
+  case Pair extends AstType("pair")
+  case PostExpression extends AstType("postexe")
+  case PreExpression extends AstType("preexe")
+  case ProcArgument extends AstType("procarg0")
+  case Rational extends AstType("rational")
+  case Redo extends AstType("redo")
+  case Retry extends AstType("retry")
+  case Return extends AstType("return")
+  case RegexExpression extends AstType("regexp")
+  case RegexOption extends AstType("regopt")
+  case ResBody extends AstType("resbody")
+  case RestArg extends AstType("restarg")
+  case RescueStatement extends AstType("rescue")
+  case ScopedConstant extends AstType("const")
+  case Self extends AstType("self")
+  case Send extends AstType("send")
+  case ShadowArg extends AstType("shadowarg")
+  case SingletonMethodDefinition extends AstType("defs")
+  case SingletonClassDefinition extends AstType("sclass")
+  case Splat extends AstType("splat")
+  case StaticString extends AstType("str")
+  case StaticSymbol extends AstType("sym")
+  case Super extends AstType("super")
+  case SuperNoArgs extends AstType("zsuper")
+  case TopLevelConstant extends AstType("cbase")
+  case True extends AstType("true")
+  case UnDefine extends AstType("undef")
+  case UnlessExpression extends AstType("unless")
+  case UnlessGuard extends AstType("unless_guard")
+  case UntilExpression extends AstType("until")
+  case UntilPostExpression extends AstType("until_post")
+  case WhenStatement extends AstType("when")
+  case WhileStatement extends AstType("while")
+  case WhilePostStatement extends AstType("while_post")
+  case Yield extends AstType("yield")
+end AstType
+
+object AstType:
+  // Built on first use (lazy, so it never runs before the enum cases are initialised). Replaces a
+  // linear scan over all cases on every lookup.
+  private lazy val byName: Map[String, AstType] =
+    AstType.values.iterator.map(astType => astType.name -> astType).toMap
+
+  /** @param input
+    *   the type tag from the JSON.
+    * @return
+    *   the case whose `name` equals `input`, or None if there is no such case.
+    */
+  def fromString(input: String): Option[AstType] = byName.get(input)
+
+object BinaryOperators:
+  private val BinaryOperators: Set[String] =
+    Set(
+      "+",
+      "-",
+      "*",
+      "/",
+      "%",
+      "**",
+      "==",
+      "===",
+      "!=",
+      "<",
+      "<=",
+      ">",
+      ">=",
+      "<=>",
+      "&&",
+      "and",
+      "or",
+      "||",
+      "&",
+      "|",
+      "^",
+      // "<<" -> Operators.shiftLeft, Note: Generally Ruby abstracts this as an append operator based on the LHS
+      ">>"
+    )
+
+  /** @return
+    *   true if `op` is one of the names treated as a binary operator.
+    */
+  def isBinaryOperatorName(op: String): Boolean = BinaryOperators.contains(op)
+end BinaryOperators
+
+object UnaryOperators:
+  private val UnaryOperators: Set[String] =
+    Set("!", "not", "~", "+", "-")
+
+  /** @return
+    *   true if `op` is one of the names treated as a unary operator.
+    */
+  def isUnaryOperatorName(op: String): Boolean = UnaryOperators.contains(op)
diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyJsonHelpers.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyJsonHelpers.scala
new file mode 100644
index 00000000..f0ec4608
--- /dev/null
+++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyJsonHelpers.scala
@@ -0,0 +1,415 @@
+package io.appthreat.ruby2atom.parser
+
+import io.appthreat.ruby2atom.astcreation.RubyIntermediateAst.{
+  AllowedTypeDeclarationChild,
+  ArrayLiteral,
+  ClassFieldIdentifier,
+  DefaultMultipleAssignment,
+  FieldsDeclaration,
+  
MemberAccess, + MethodDeclaration, + ProcedureDeclaration, + RubyExpression, + RubyFieldIdentifier, + SelfIdentifier, + SimpleCall, + SimpleIdentifier, + SingleAssignment, + SingletonClassDeclaration, + SingletonMethodDeclaration, + SplattingRubyNode, + StatementList, + StaticLiteral, + TextSpan, + TypeDeclBodyCall, + UnaryExpression +} +import io.appthreat.ruby2atom.passes.Defines +import io.appthreat.ruby2atom.passes.Defines.getBuiltInType +import upickle.core.* +import upickle.default.* + +object RubyJsonHelpers: + + implicit class JsonObjHelper(o: ujson.Obj): + + def toTextSpan: TextSpan = + val metaData = + if o.obj.contains(ParserKeys.MetaData) then read[MetaData](o(ParserKeys.MetaData)) + else read[MetaData](o) + + val offset = Option(metaData.offsetStart) -> Option(metaData.offsetEnd) match + case (Some(start), Some(end)) => Option(start -> end) + case _ => None + + TextSpan( + line = Option(metaData.lineNumber).filterNot(_ == -1), + column = Option(metaData.columnNumber).filterNot(_ == -1), + lineEnd = Option(metaData.lineNumberEnd).filterNot(_ == -1), + columnEnd = Option(metaData.columnNumberEnd).filterNot(_ == -1), + offset = offset, + text = metaData.code + ) + + def visitOption(key: String)(implicit + visit: ujson.Value => RubyExpression + ): Option[RubyExpression] = + if contains(key) then Option(visit(o(key))) else None + + def visitArray(key: String)(implicit + visit: ujson.Value => RubyExpression + ): List[RubyExpression] = + o(key).arr.map(visit).toList + + def contains(key: String): Boolean = o.obj.get(key).exists(x => x != null && x != ujson.Null) + end JsonObjHelper + + protected def nilLiteral(span: TextSpan): StaticLiteral = + StaticLiteral(getBuiltInType(Defines.NilClass))(span.spanStart("nil")) + + def createClassBodyAndFields( + obj: ujson.Obj + )(implicit + visit: ujson.Value => RubyExpression + ): (StatementList, List[RubyExpression & RubyFieldIdentifier]) = + + def bodyMethod(fieldStatements: List[RubyExpression]): MethodDeclaration = 
+ val body = fieldStatements + .map { + case field: SimpleIdentifier => + val assignmentSpan = field.span.spanStart(s"${field.span.text} = nil") + SingleAssignment(ClassFieldIdentifier()(field.span), "=", nilLiteral(field.span))( + assignmentSpan + ) + case field: RubyFieldIdentifier => + val assignmentSpan = field.span.spanStart(s"${field.span.text} = nil") + SingleAssignment(field, "=", nilLiteral(field.span))(assignmentSpan) + case assignment @ SingleAssignment(_: RubyFieldIdentifier, _, _) => assignment + case assignment @ SingleAssignment(lhs: SimpleIdentifier, _, _) => + assignment.copy(lhs = ClassFieldIdentifier()(lhs.span))(assignment.span) + case otherExpr => otherExpr + } + .distinctBy { + case _ @SingleAssignment(lhs: RubyFieldIdentifier, _, _) => lhs.text + case x => x + } + + MethodDeclaration( + Defines.TypeDeclBody, + Nil, + StatementList(body)(obj.toTextSpan.spanStart(s"(...)")) + )( + obj.toTextSpan.spanStart(s"def ; (...); end") + ) + end bodyMethod + + /** @param expr + * An expression that is a direct child to a class or module. + * @return + * true if the expression constitutes field-related behaviour, false if otherwise. + */ + def isFieldStmt(expr: RubyExpression): Boolean = + expr match + case _: SingleAssignment => true + case _: SimpleIdentifier => true + case _: RubyFieldIdentifier => true + case _ => false + + /** @param expr + * An expression that is a direct child to a class or module. + * @return + * true if the expression is a Splatting Field Declaration (`attr_x(*foo)`), false otherwise. + */ + def isSplattingField(expr: RubyExpression): Boolean = + expr match + case x: FieldsDeclaration if x.isSplattingFieldDecl => true + case _: AllowedTypeDeclarationChild => false + case _ => false + + /** Extracts a field from the expression. + * @param expr + * An expression that is a direct child to a class or module. 
+ */ + def getFields( + expr: RubyExpression, + typeDeclChildStatements: Boolean = true + ): List[RubyExpression & RubyFieldIdentifier] = + expr match + case field: SimpleIdentifier if typeDeclChildStatements => + ClassFieldIdentifier()(field.span) :: Nil + case field: RubyFieldIdentifier if typeDeclChildStatements => field :: Nil + case _ @SingleAssignment(lhs: RubyFieldIdentifier, _, _) => lhs :: Nil + case _ @SingleAssignment(lhs: SimpleIdentifier, _, _) if typeDeclChildStatements => + ClassFieldIdentifier()(lhs.span) :: Nil + case proc: ProcedureDeclaration => getFields(proc.body, false) + case _ @StatementList(stmts) => + stmts.flatMap(x => getFields(x, typeDeclChildStatements)).distinctBy(_.text) + case _ => Nil + + /** Attempts to evaluate and parse the collection associated with the splattingField, generating + * FieldDeclarations for each of the elements. + * @param fieldStmts + * List of all the field statements + * @param splattingFields + * List of splatting fields + * @return + * List of: + * - Some(_) => if splattingField either is evaluated to a list of FieldDeclarations, + * otherwise a SimpleCall + * - None => if splattingField cannot be evaluated to either FieldsDeclaration or SimpleCall + */ + def lowerSplattingFieldDecl( + fieldStmts: List[RubyExpression], + splattingFields: List[RubyExpression] + ): List[Option[RubyExpression]] = + splattingFields.flatMap { + case x @ FieldsDeclaration(fieldName :: Nil, accessType) + if fieldName.isInstanceOf[SplattingRubyNode] => + fieldStmts.map { + case _ @SingleAssignment(lhs: SimpleIdentifier, _, rhs: MemberAccess) + if rhs.memberName == "freeze" && lhs.span.text == fieldName.span.text + .stripPrefix("*") => + rhs.target match + case y: ArrayLiteral => + Some(FieldsDeclaration(y.elements, accessType)(x.span)) + case _ => None + case _ @SingleAssignment(_: SimpleIdentifier, _, rhs: ArrayLiteral) => + Some(FieldsDeclaration(rhs.elements, accessType)(x.span)) + case _ => + Some( + SimpleCall( + 
SimpleIdentifier()(x.span.spanStart(accessType)), + List(fieldName) + )( + x.span.spanStart(s"$accessType(${fieldName.span.text})") + ) + ) + } + case _ => None + } + + obj.visitOption(ParserKeys.Body).map(lowerSingletonClassDecls) match + case Some(stmtList @ StatementList(expression :: Nil)) + if expression.isInstanceOf[AllowedTypeDeclarationChild] => + if isSplattingField(expression) then + val splattingField = expression.asInstanceOf[FieldsDeclaration] + splattingField.fieldNames.headOption match + case Some(splattingFieldName) => + val nonExpandedSplattingFieldCall = + SimpleCall( + SimpleIdentifier()(expression.span.spanStart(splattingField.accessType)), + List(splattingFieldName) + )(expression.span.spanStart( + s"${splattingField.accessType}(${splattingFieldName.span.text})" + )) + ( + StatementList(bodyMethod(List(nonExpandedSplattingFieldCall)) :: Nil)( + stmtList.span + ), + getFields(expression) + ) + case None => + ( + StatementList(bodyMethod(Nil) :: expression :: Nil)(stmtList.span), + getFields(expression) + ) + end match + else + ( + StatementList(bodyMethod(Nil) :: expression :: Nil)(stmtList.span), + getFields(expression) + ) + case Some(stmtList @ StatementList(expression :: Nil)) if isFieldStmt(expression) => + ( + StatementList(bodyMethod(expression :: Nil) :: Nil)(stmtList.span), + getFields(expression) + ) + case Some(stmtList: StatementList) => + val (fieldStmts, otherStmts) = stmtList.statements.partition(isFieldStmt) + val (typeDeclStmts, bodyStmts) = + otherStmts.partition(_.isInstanceOf[AllowedTypeDeclarationChild]) + val (splattingFields, otherTypeDeclStmts) = typeDeclStmts.partition(isSplattingField) + val (expandedSplattingFields, nonExpandedSplattingFieldsCalls) = + lowerSplattingFieldDecl(fieldStmts, splattingFields) + .filter(_.isDefined) + .map(_.get) + .partition(_.isInstanceOf[FieldsDeclaration]) + + val fields = + (fieldStmts.flatMap(x => getFields(x)) ++ otherTypeDeclStmts.flatMap(x => + getFields(x) + )) + 
.distinctBy(_.text) + val body = stmtList.copy(statements = + bodyMethod( + fieldStmts ++ otherTypeDeclStmts.flatMap(x => + getFields(x) + ) ++ bodyStmts ++ nonExpandedSplattingFieldsCalls + ) +: (otherTypeDeclStmts ++ expandedSplattingFields) + )(stmtList.span) + + (body, fields) + case None => (StatementList(bodyMethod(Nil) :: Nil)(obj.toTextSpan.spanStart("")), Nil) + end match + end createClassBodyAndFields + + def createBodyMemberCall(name: String, textSpan: TextSpan): TypeDeclBodyCall = + TypeDeclBodyCall( + MemberAccess(SelfIdentifier()(textSpan.spanStart(Defines.Self)), "::", name)( + textSpan.spanStart(s"${Defines.Self}::$name") + ), + name + )(textSpan.spanStart(s"${Defines.Self}::$name::${Defines.TypeDeclBody}")) + + def getParts(memberAccess: MemberAccess): List[String] = + memberAccess.target match + case targetMemberAccess: MemberAccess => + getParts(targetMemberAccess) :+ memberAccess.memberName + case expr => expr.text :: memberAccess.memberName :: Nil + + def lowerMultipleAssignment( + obj: ujson.Obj, + lhsNodes: List[RubyExpression], + rhsNodes: List[RubyExpression], + defaultResult: () => RubyExpression, + nilResult: () => RubyExpression + ): RubyExpression = + + /** Recursively expand and duplicate splatting nodes so that they line up with what they + * consume. + * + * @param nodes + * the splat nodes. + * @param expandSize + * how many more duplicates to create. 
+ */ + def slurp(nodes: List[RubyExpression], expandSize: Int): List[RubyExpression] = nodes match + case (head: SplattingRubyNode) :: tail if expandSize > 0 => + head :: slurp(head :: tail, expandSize - 1) + case head :: tail => head :: slurp(tail, expandSize) + case Nil => List.empty + val op = "=" + lazy val defaultAssignments = lhsNodes + .zipAll(rhsNodes, defaultResult(), nilResult()) + .map { case (lhs, rhs) => SingleAssignment(lhs, op, rhs)(obj.toTextSpan) } + + val assignments = if (lhsNodes ++ rhsNodes).exists(_.isInstanceOf[SplattingRubyNode]) then + rhsNodes.size - lhsNodes.size match + // Handle slurping the RHS values + case x if x > 0 => + val slurpedLhs = slurp(lhsNodes, x) + + slurpedLhs + .zip(rhsNodes) + .groupBy(_._1) + .toSeq + .map { case (lhsNode, xs) => lhsNode -> xs.map(_._2) } + .sortBy { x => + slurpedLhs.indexOf(x._1) + } // groupBy produces a map which discards insertion order + .map { + case (SplattingRubyNode(lhs), rhss) => + SingleAssignment(lhs, op, ArrayLiteral(rhss)(obj.toTextSpan))( + obj.toTextSpan + ) + case (lhs, rhs :: Nil) => SingleAssignment(lhs, op, rhs)(obj.toTextSpan) + case (lhs, rhss) => SingleAssignment( + lhs, + op, + ArrayLiteral(rhss)(obj.toTextSpan) + )(obj.toTextSpan) + } + .toList + // Handle splitting the RHS values + case x if x < 0 => + val slurpedRhs = slurp(rhsNodes, Math.abs(x)) + + lhsNodes + .zip(slurpedRhs) + .groupBy(_._2) + .toSeq + .map { case (rhsNode, xs) => rhsNode -> xs.map(_._1) } + .sortBy { x => + slurpedRhs.indexOf(x._1) + } // groupBy produces a map which discards insertion order + .flatMap { + case (SplattingRubyNode(rhs), lhss) => + lhss.map( + SingleAssignment(_, op, SplattingRubyNode(rhs)(rhs.span))(obj.toTextSpan) + ) + case (rhs, lhs :: Nil) => Seq(SingleAssignment(lhs, op, rhs)(obj.toTextSpan)) + case (rhs, lhss) => lhss.map(SingleAssignment( + _, + op, + SplattingRubyNode(rhs)(rhs.span) + )(obj.toTextSpan)) + } + .toList + case _ => defaultAssignments + else + val diff = rhsNodes.size 
- lhsNodes.size + if diff < 0 then defaultAssignments.dropRight(Math.abs(diff)) else defaultAssignments + DefaultMultipleAssignment(assignments)(obj.toTextSpan) + end lowerMultipleAssignment + + def infinityUpperBound(obj: ujson.Obj): MemberAccess = + MemberAccess( + SimpleIdentifier(Option(getBuiltInType(Defines.Float)))(obj.toTextSpan.spanStart("Float")), + "::", + "INFINITY" + )(obj.toTextSpan.spanStart("Float::INFINITY")) + + def infinityLowerBound(obj: ujson.Obj): UnaryExpression = + UnaryExpression( + "-", + MemberAccess( + SimpleIdentifier(Option(getBuiltInType(Defines.Float)))( + obj.toTextSpan.spanStart("Float") + ), + "::", + "INFINITY" + )(obj.toTextSpan.spanStart("Float::INFINITY")) + )(obj.toTextSpan.spanStart("-Float::INFINITY")) + + def lowerSingletonClassDecls(classBody: RubyExpression): StatementList = + val loweredStmts = classBody match + case x: StatementList => lowerSingletonClassDeclarations(x) + case x => lowerSingletonClassDeclarations(StatementList(List(x))(x.span)) + + val stmts = loweredStmts match + case StatementList(stmts) => stmts + case x => List(x) + + StatementList(stmts)(classBody.span) + + private def lowerSingletonClassDeclarations(classBody: RubyExpression): RubyExpression = + classBody match + case stmtList: StatementList => + StatementList(stmtList.statements.flatMap { + case _ @SingletonClassDeclaration( + _, + baseClass: Some[RubyExpression], + body: StatementList, + _ + ) => + body.statements.map { + case method @ MethodDeclaration(methodName, parameters, body) => + SingletonMethodDeclaration(baseClass.get, methodName, parameters, body)( + method.span + ) + case nonMethodStatement => nonMethodStatement + } + case nonStmtListBody => nonStmtListBody :: Nil + })(stmtList.span) + case nonStmtList => nonStmtList + + private case class MetaData( + code: String, + @upickle.implicits.key("start_line") lineNumber: Integer, + @upickle.implicits.key("start_column") columnNumber: Integer, + @upickle.implicits.key("end_line") 
lineNumberEnd: Integer, + @upickle.implicits.key("end_column") columnNumberEnd: Integer, + @upickle.implicits.key("offset_start") offsetStart: Integer, + @upickle.implicits.key("offset_end") offsetEnd: Integer + ) derives ReadWriter +end RubyJsonHelpers diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyJsonParser.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyJsonParser.scala new file mode 100644 index 00000000..a9d5c3d9 --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyJsonParser.scala @@ -0,0 +1,17 @@ +package io.appthreat.ruby2atom.parser + +import io.appthreat.x2cpg.astgen.ParserResult +import io.shiftleft.utils.IOUtils + +import java.nio.file.{Path, Paths} + +object RubyJsonParser: + + def readFile(file: Path): ParserResult = + val jsonContent = IOUtils.readLinesInFile(file).mkString + val json = ujson.read(jsonContent) + val fullFilePath = json(ParserKeys.FilePath).str + val filePath = Paths.get(fullFilePath) + val relFilePath = json(ParserKeys.RelFilePath).str + val sourceFileContent = IOUtils.readEntireFile(filePath) + ParserResult(relFilePath, filePath.toString, json, sourceFileContent) diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyJsonToNodeCreator.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyJsonToNodeCreator.scala new file mode 100644 index 00000000..ce0e63fe --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/parser/RubyJsonToNodeCreator.scala @@ -0,0 +1,1090 @@ +package io.appthreat.ruby2atom.parser + +import io.appthreat.ruby2atom.astcreation.RubyIntermediateAst.{RubyExpression, *} +import io.appthreat.ruby2atom.parser.RubyJsonHelpers.* +import io.appthreat.ruby2atom.passes.Defines +import io.appthreat.ruby2atom.passes.Defines.{NilClass, RubyOperators, getBuiltInType} +import 
io.appthreat.ruby2atom.passes.GlobalTypes.builtinPrefix +import io.appthreat.ruby2atom.utils.FreshNameGenerator +import io.appthreat.x2cpg.frontendspecific.ruby2atom.ImportsPass +import io.appthreat.x2cpg.frontendspecific.ruby2atom.ImportsPass.ImportCallNames +import org.slf4j.LoggerFactory +import ujson.* + +class RubyJsonToNodeCreator( + variableNameGen: FreshNameGenerator[String] = FreshNameGenerator(id => s""), + procParamGen: FreshNameGenerator[Left[String, Nothing]] = + FreshNameGenerator(id => Left(s"")), + fileName: String = "" +): + + private val logger = LoggerFactory.getLogger(getClass) + private val classNameGen = FreshNameGenerator(id => s"") + + private implicit val implVisit: ujson.Value => RubyExpression = (x: ujson.Value) => visit(x) + + protected def freshClassName(span: TextSpan): SimpleIdentifier = + SimpleIdentifier(None)(span.spanStart(classNameGen.fresh)) + + private def defaultTextSpan(code: String = ""): TextSpan = + TextSpan(None, None, None, None, None, code) + + private def defaultResult(span: Option[TextSpan] = None): RubyExpression = + Unknown()(span.getOrElse(defaultTextSpan())) + + private def visit(v: ujson.Value): RubyExpression = + v match + case obj: ujson.Obj => visit(obj) + case ujson.Null => StatementList(Nil)(defaultTextSpan()) + case ujson.Str(x) => StaticLiteral(getBuiltInType(Defines.String))(defaultTextSpan(x)) + case x => + logger.warn(s"Unhandled ujson type ${x.getClass}") + defaultResult() + + /** Main entrypoint of JSON deserialization. 
+ */ + def visitProgram(obj: ujson.Value): StatementList = + visit(obj.obj) match + case x: StatementList => x + case x => StatementList(x :: Nil)(x.span) + + private def visit(obj: ujson.Obj): RubyExpression = + + def visitAstType(typ: AstType): RubyExpression = + typ match + case AstType.Alias => visitAlias(obj) + case AstType.And => visitAnd(obj) + case AstType.AndAssign => visitAndAssign(obj) + case AstType.Arg => visitArg(obj) + case AstType.Args => visitArgs(obj) + case AstType.Array => visitArray(obj) + case AstType.ArrayPattern => visitArrayPattern(obj) + case AstType.ArrayPatternWithTail => visitArrayPatternWithTail(obj) + case AstType.BackRef => visitBackRef(obj) + case AstType.Begin => visitBegin(obj) + case AstType.Block => visitBlock(obj) + case AstType.BlockArg => visitBlockArg(obj) + case AstType.BlockPass => visitBlockPass(obj) + case AstType.BlockWithNumberedParams => visitBlockWithNumberedParams(obj) + case AstType.Break => visitBreak(obj) + case AstType.CaseExpression => visitCaseExpression(obj) + case AstType.CaseMatchStatement => visitCaseMatchStatement(obj) + case AstType.ClassDefinition => visitClassDefinition(obj) + case AstType.ClassVariable => visitClassVariable(obj) + case AstType.ClassVariableAssign => visitSingleAssignment(obj) + case AstType.ConstVariableAssign => visitSingleAssignment(obj) + case AstType.ConditionalSend => visitSend(obj, isConditional = true) + case AstType.Defined => visitDefined(obj) + case AstType.DynamicString => visitDynamicString(obj) + case AstType.DynamicSymbol => visitDynamicSymbol(obj) + case AstType.Ensure => visitEnsure(obj) + case AstType.ExclusiveFlipFlop => visitExclusiveFlipFlop(obj) + case AstType.ExclusiveRange => visitExclusiveRange(obj) + case AstType.ExecutableString => visitExecutableString(obj) + case AstType.False => visitFalse(obj) + case AstType.FindPattern => visitFindPattern(obj) + case AstType.Float => visitFloat(obj) + case AstType.ForStatement => visitForStatement(obj) + case 
AstType.ForPostStatement => visitForStatement(obj) + case AstType.ForwardArg => visitForwardArg(obj) + case AstType.ForwardArgs => visitForwardArgs(obj) + case AstType.ForwardedArgs => visitForwardedArgs(obj) + case AstType.GlobalVariable => visitGlobalVariable(obj) + case AstType.GlobalVariableAssign => visitGlobalVariableAssign(obj) + case AstType.Hash => visitHash(obj) + case AstType.HashPattern => visitHashPattern(obj) + case AstType.Identifier => visitIdentifier(obj) + case AstType.IfGuard => visitIfGuard(obj) + case AstType.IfStatement => visitIfStatement(obj) + case AstType.InclusiveFlipFlop => visitInclusiveFlipFlop(obj) + case AstType.InclusiveRange => visitInclusiveRange(obj) + case AstType.InPattern => visitInPattern(obj) + case AstType.Int => visitInt(obj) + case AstType.InstanceVariable => visitInstanceVariable(obj) + case AstType.InstanceVariableAssign => visitSingleAssignment(obj) + case AstType.KwArg => visitKwArg(obj) + case AstType.KwBegin => visitKwBegin(obj) + case AstType.KwNilArg => visitKwNilArg(obj) + case AstType.KwOptArg => visitKwOptArg(obj) + case AstType.KwRestArg => visitKwRestArg(obj) + case AstType.KwSplat => visitKwSplat(obj) + case AstType.LocalVariable => visitLocalVariable(obj) + case AstType.LocalVariableAssign => visitSingleAssignment(obj) + case AstType.MatchAlt => visitMatchAlt(obj) + case AstType.MatchAs => visitMatchAs(obj) + case AstType.MatchNilPattern => visitMatchNilPattern(obj) + case AstType.MatchPattern => visitMatchPattern(obj) + case AstType.MatchPatternP => visitMatchPatternP(obj) + case AstType.MatchRest => visitMatchRest(obj) + case AstType.MatchVariable => visitMatchVariable(obj) + case AstType.MatchWithLocalVariableAssign => visitMatchWithLocalVariableAssign(obj) + case AstType.MethodDefinition => visitMethodDefinition(obj) + case AstType.ModuleDefinition => visitModuleDefinition(obj) + case AstType.MultipleAssignment => visitMultipleAssignment(obj) + case AstType.MultipleLeftHandSide => 
visitMultipleLeftHandSide(obj) + case AstType.Next => visitNext(obj) + case AstType.Nil => visitNil(obj) + case AstType.NthRef => visitNthRef(obj) + case AstType.OperatorAssign => visitOperatorAssign(obj) + case AstType.OptionalArgument => visitOptionalArgument(obj) + case AstType.Or => visitOr(obj) + case AstType.OrAssign => visitOrAssign(obj) + case AstType.Pair => visitPair(obj) + case AstType.PostExpression => visitPostExpression(obj) + case AstType.PreExpression => visitPreExpression(obj) + case AstType.ProcArgument => visitProcArgument(obj) + case AstType.Rational => visitRational(obj) + case AstType.Redo => visitRedo(obj) + case AstType.Retry => visitRetry(obj) + case AstType.Return => visitReturn(obj) + case AstType.RegexExpression => visitRegexExpression(obj) + case AstType.RegexOption => visitRegexOption(obj) + case AstType.ResBody => visitResBody(obj) + case AstType.RestArg => visitRestArg(obj) + case AstType.RescueStatement => visitRescueStatement(obj) + case AstType.ScopedConstant => visitScopedConstant(obj) + case AstType.Self => visitSelf(obj) + case AstType.Send => visitSend(obj) + case AstType.ShadowArg => visitShadowArg(obj) + case AstType.SingletonMethodDefinition => visitSingletonMethodDefinition(obj) + case AstType.SingletonClassDefinition => visitSingletonClassDefinition(obj) + case AstType.Splat => visitSplat(obj) + case AstType.StaticString => visitStaticString(obj) + case AstType.StaticSymbol => visitStaticSymbol(obj) + case AstType.Super => visitSuper(obj) + case AstType.SuperNoArgs => visitSuperNoArgs(obj) + case AstType.TopLevelConstant => visitTopLevelConstant(obj) + case AstType.True => visitTrue(obj) + case AstType.UnDefine => visitUnDefine(obj) + case AstType.UnlessExpression => visitUnlessExpression(obj) + case AstType.UnlessGuard => visitUnlessGuard(obj) + case AstType.UntilExpression => visitUntilExpression(obj) + case AstType.UntilPostExpression => visitUntilPostExpression(obj) + case AstType.WhenStatement => 
visitWhenStatement(obj) + case AstType.WhileStatement => visitWhileStatement(obj) + case AstType.WhilePostStatement => visitWhileStatement(obj) + case AstType.Yield => visitYield(obj) + + val astTypeStr = obj(ParserKeys.Type).str + AstType.fromString(astTypeStr) match + case Some(typ) => visitAstType(typ) + case _ => + logger.warn(s"Unhandled `parser` type '$astTypeStr'") + defaultResult() + end visit + + private def visitAccessModifier(obj: Obj): RubyExpression = + obj(ParserKeys.Name).str match + case "public" => PublicModifier()(obj.toTextSpan) + case "private" => PrivateModifier()(obj.toTextSpan) + case "protected" => ProtectedModifier()(obj.toTextSpan) + case modifierName => + logger.warn(s"Unknown modifier type $modifierName") + defaultResult(Option(obj.toTextSpan)) + + private def visitAlias(obj: Obj): RubyExpression = + val name = visit(obj(ParserKeys.Name)).text.stripPrefix(":") + val alias = visit(obj(ParserKeys.Alias)).text.stripPrefix(":") + AliasStatement(alias, name)(obj.toTextSpan) + + private def visitAnd(obj: Obj): RubyExpression = + val op = "&&" + val lhs = visit(obj(ParserKeys.Lhs)) + val rhs = visit(obj(ParserKeys.Rhs)) + BinaryExpression(lhs, op, rhs)(obj.toTextSpan) + + private def visitAndAssign(obj: Obj): RubyExpression = + val lhs = visit(obj(ParserKeys.Lhs)) match + case param: MandatoryParameter => param.toSimpleIdentifier + case x => x + val rhs = visit(obj(ParserKeys.Rhs)) + OperatorAssignment(lhs, "&&=", rhs)(obj.toTextSpan) + + private def visitArg(obj: Obj): RubyExpression = + MandatoryParameter(obj(ParserKeys.Value).str)(obj.toTextSpan) + + private def visitArgs(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitArray(obj: Obj): RubyExpression = + val children = obj.visitArray(ParserKeys.Children).flatMap { + case x: AssociationList => x.elements + case x => x :: Nil + } + + ArrayLiteral(children)(obj.toTextSpan) + + private def visitArrayPattern(obj: Obj): RubyExpression = + val children = 
obj.visitArray(ParserKeys.Children) + ArrayPattern(children)(obj.toTextSpan) + + private def visitArrayPatternWithTail(obj: Obj): RubyExpression = + defaultResult(Option(obj.toTextSpan)) + + private def visitBackRef(obj: Obj): RubyExpression = SimpleIdentifier()(obj.toTextSpan) + + private def visitBegin(obj: Obj): RubyExpression = + StatementList(obj.visitArray(ParserKeys.Body))(obj.toTextSpan) + + private def visitGroupedParameter(arrayParam: ArrayLiteral): RubyExpression = + val freshTmpVar = variableNameGen.fresh + val tmpMandatoryParam = MandatoryParameter(freshTmpVar)(arrayParam.span.spanStart(freshTmpVar)) + + val singleAssignments = arrayParam.elements.map { param => + val rhsSplattingNode = + SplattingRubyNode(tmpMandatoryParam)(arrayParam.span.spanStart(s"*$freshTmpVar")) + val lhs = param match + case x: SimpleIdentifier => SimpleIdentifier()(x.span) + case x: ArrayParameter => + SplattingRubyNode( + SimpleIdentifier()(arrayParam.span.spanStart(x.span.text.stripPrefix("*"))) + )( + arrayParam.span.spanStart(x.span.text) + ) + case x: ArrayLiteral => + visitGroupedParameter(x) + case x => + logger.warn( + s"Invalid parameter type in grouped parameter list: ${x.getClass} (code: ${arrayParam.span.text})" + ) + defaultResult(Option(arrayParam.span)) + SingleAssignment(lhs, "=", rhsSplattingNode)( + arrayParam.span.spanStart(s"${lhs.span.text} = ${rhsSplattingNode.span.text}") + ) + } + + GroupedParameter( + tmpMandatoryParam.span.text, + tmpMandatoryParam, + GroupedParameterDesugaring(singleAssignments)(arrayParam.span) + )(arrayParam.span) + end visitGroupedParameter + + private def visitBlock(obj: Obj): RubyExpression = + val parameters = + obj(ParserKeys.Arguments).asInstanceOf[ujson.Obj].visitArray(ParserKeys.Children).map { + case x: ArrayLiteral => visitGroupedParameter(x) + case x => x + } + + val assignments = parameters.collect { case x: GroupedParameter => + x.multipleAssignment + } + + val body = obj.visitOption(ParserKeys.Body) match + case 
Some(stmt: StatementList) => stmt.copy(stmt.statements ++ assignments)(stmt.span) + case Some(expr) => StatementList(expr +: assignments)(expr.span) + case None => StatementList(Nil)(obj.toTextSpan) + + val block = Block(parameters, body)(body.span.spanStart(obj.toTextSpan.text)) + visit(obj(ParserKeys.CallName)) match + case classNew: ObjectInstantiation if classNew.span.text == "Class.new" => + AnonymousClassDeclaration(freshClassName(obj.toTextSpan), None, block.toStatementList)( + obj.toTextSpan + ) + case objNew: ObjectInstantiation => objNew.withBlock(block) + case lambda: SimpleIdentifier if lambda.text == "lambda" => + ProcOrLambdaExpr(block)(obj.toTextSpan) + case ident: SimpleIdentifier if ident.span.text == "loop" => + val trueLiteral = + StaticLiteral(Defines.getBuiltInType(Defines.TrueClass))(ident.span.spanStart("true")) + DoWhileExpression(trueLiteral, body)(ident.span) + case simpleIdentifier: SimpleIdentifier => + SimpleCall(simpleIdentifier, Nil)(obj.toTextSpan).withBlock(block) + case simpleCall: RubyCall => simpleCall.withBlock(block) + case memberAccess @ MemberAccess(target, op, memberName) => + val memberCall = MemberCall(target, op, memberName, List.empty)(memberAccess.span) + memberCall.withBlock(block) + case x: ProtectedModifier => + SimpleCall(x.toSimpleIdentifier, Nil)(obj.toTextSpan).withBlock(block) + case x => + logger.warn(s"Unexpected call type used for block ${x.getClass}, ignoring block") + x + end match + end visitBlock + + private def visitBlockArg(obj: Obj): RubyExpression = + val span = obj.toTextSpan + val name = obj(ParserKeys.Value).strOpt.filterNot(_ == "&").getOrElse(procParamGen.fresh.value) + ProcParameter(name)(span) + + private def visitBlockPass(obj: Obj): RubyExpression = + lazy val default = SimpleIdentifier()(obj.toTextSpan.spanStart(procParamGen.current.value)) + obj.visitOption(ParserKeys.Value).getOrElse(default) + + private def visitBlockWithNumberedParams(obj: Obj): RubyExpression = + 
SimpleIdentifier()(obj.toTextSpan) + + private def visitBracketAssignmentAsSend(obj: Obj): RubyExpression = + val lhsBase = visit(obj(ParserKeys.Receiver)) + val args = obj.visitArray(ParserKeys.Arguments) + + val lhs = + IndexAccess(lhsBase, List(args.head))( + obj.toTextSpan.spanStart(s"${lhsBase.span.text}[${args.head.span.text}]") + ) + + val rhs = + if args.size == 2 then args(1) + else SimpleIdentifier()(obj.toTextSpan.spanStart("*")) + + SingleAssignment(lhs, "=", rhs)(obj.toTextSpan) + + private def visitBreak(obj: Obj): RubyExpression = BreakExpression()(obj.toTextSpan) + + private def visitCaseExpression(obj: Obj): RubyExpression = + val expression = obj.visitOption(ParserKeys.CaseExpression) + val whenClauses = obj.visitArray(ParserKeys.WhenClauses) + + val elseClause = obj.visitOption(ParserKeys.ElseClause) match + case Some(elseClause) => Some(ElseClause(elseClause)(elseClause.span)) + case None => None + + CaseExpression(expression, whenClauses, elseClause)(obj.toTextSpan) + + private def visitCaseMatchStatement(obj: Obj): RubyExpression = + val expression = visit(obj(ParserKeys.Statement)) + val inClauses = obj.visitArray(ParserKeys.Bodies) + val elseClause = obj.visitOption(ParserKeys.ElseClause).map(x => ElseClause(x)(x.span)) + + CaseExpression(Some(expression), inClauses, elseClause)(obj.toTextSpan) + + private def visitClassDefinition(obj: Obj): RubyExpression = + val (name, namespaceParts) = visit(obj(ParserKeys.Name)) match + case memberAccess: MemberAccess => + val memberIdentifier = + SimpleIdentifier()(memberAccess.span.spanStart(memberAccess.memberName)) + (memberIdentifier, Option(getParts(memberAccess).dropRight(1))) + case identifier => (identifier, None) + val baseClass = obj.visitOption(ParserKeys.SuperClass) + val (body, fields) = createClassBodyAndFields(obj) + val bodyMemberCall = createBodyMemberCall(name.text, obj.toTextSpan) + ClassDeclaration( + name = name, + baseClass = baseClass, + body = body, + fields = fields, + 
bodyMemberCall = Option(bodyMemberCall), + namespaceParts = namespaceParts + )(obj.toTextSpan) + + private def visitClassVariable(obj: Obj): RubyExpression = ClassFieldIdentifier()(obj.toTextSpan) + + private def visitCollectionAliasSend(obj: Obj): RubyExpression = + // Modify this `obj` to conform to what the AstCreator would expect i.e, Array [1,2,3] would be an Array::[] call + val collectionName = obj(ParserKeys.Name).str + val metaData = obj(ParserKeys.MetaData) + metaData.obj.put(ParserKeys.Code, collectionName) + val receiver = ujson.Obj( + ParserKeys.Type -> ujson.Str(AstType.ScopedConstant.name), + ParserKeys.MetaData -> metaData, + ParserKeys.Base -> ujson.Null, + ParserKeys.Name -> ujson.Str(collectionName) + ) + val arguments = obj(ParserKeys.Arguments).arr.headOption + .flatMap { + case x: ujson.Obj => AstType.fromString(x(ParserKeys.Type).str).map(t => t -> x) + case _ => None + } + .map { + case (AstType.Array, o) => + o.visitArray(ParserKeys.Children).flatMap { + case x: AssociationList => x.elements + case x => x :: Nil + } + case (_, o) => + visit(o) :: Nil + } + .getOrElse(Nil) + + val textSpan = + obj.toTextSpan.spanStart(s"$collectionName [${arguments.map(_.span.text).mkString(", ")}]") + + IndexAccess(visit(receiver), arguments)(textSpan) + end visitCollectionAliasSend + + private def visitDefined(obj: Obj): RubyExpression = + val name = SimpleIdentifier(Option(getBuiltInType(Defines.Defined)))( + obj.toTextSpan.spanStart(Defines.Defined) + ) + val arguments = obj.visitArray(ParserKeys.Arguments) + SimpleCall(name, arguments)(obj.toTextSpan) + + private def visitDynamicString(obj: Obj): RubyExpression = + val typeFullName = getBuiltInType(Defines.String) + val expressions = obj.visitArray(ParserKeys.Children) + DynamicLiteral(typeFullName, expressions)(obj.toTextSpan) + + private def visitDynamicSymbol(obj: Obj): RubyExpression = + val typeFullName = getBuiltInType(Defines.Symbol) + val expressions = obj.visitArray(ParserKeys.Children) + 
DynamicLiteral(typeFullName, expressions)(obj.toTextSpan) + + private def visitEnsure(obj: Obj): RubyExpression = + val ensureClause = EnsureClause(visit(obj(ParserKeys.Body)))(obj.toTextSpan) + visit(obj(ParserKeys.Statement)) match + case rescueExpression: RescueExpression => + rescueExpression.copy( + rescueExpression.body, + rescueExpression.rescueClauses, + rescueExpression.elseClause, + Some(ensureClause) + )(obj.toTextSpan) + case x => + RescueExpression(x, List.empty, Option.empty, Some(ensureClause))(obj.toTextSpan) + + private def visitExclusiveFlipFlop(obj: Obj): RubyExpression = + defaultResult(Option(obj.toTextSpan)) + + private def visitExclusiveRange(obj: Obj): RubyExpression = + val start = visit(obj(ParserKeys.Start)) + val end = visit(obj(ParserKeys.End)) + val op = RangeOperator(true)(obj.toTextSpan.spanStart("...")) + RangeExpression(start, end, op)(obj.toTextSpan) + + private def visitExecutableString(obj: Obj): RubyExpression = + val operatorName = RubyOperators.backticks + val callName = SimpleIdentifier(Option(getBuiltInType(operatorName)))( + obj.toTextSpan.spanStart(operatorName) + ) + val arguments = obj.visitArray(ParserKeys.Arguments) + SimpleCall(callName, arguments)(obj.toTextSpan) + + private def visitFalse(obj: Obj): RubyExpression = + StaticLiteral(getBuiltInType(Defines.FalseClass))(obj.toTextSpan) + + private def visitFieldDeclaration(obj: Obj): RubyExpression = + val arguments = obj.visitArray(ParserKeys.Arguments) + val accessType = obj(ParserKeys.Name).str + FieldsDeclaration(arguments, accessType)(obj.toTextSpan) + + private def visitFindPattern(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitFieldAssignmentSend(obj: Obj, fieldName: String): RubyExpression = + val span = obj.toTextSpan + val receiver = visit(obj(ParserKeys.Receiver)) + val memberAccess = MemberAccess(receiver, ".", fieldName)( + receiver.span.spanStart(s"${receiver.text}.@$fieldName") + ) + val argument = obj + 
.visitArray(ParserKeys.Arguments) + .headOption + .getOrElse(StaticLiteral(getBuiltInType(Defines.NilClass))(span.spanStart("nil"))) + SingleAssignment(memberAccess, "=", argument)(span) + + private def visitFloat(obj: Obj): RubyExpression = + StaticLiteral(getBuiltInType(Defines.Float))(obj.toTextSpan) + + private def visitForStatement(obj: Obj): RubyExpression = + val forVariable = visit(obj(ParserKeys.Variable)) + val iterableVariable = visit(obj(ParserKeys.Collection)) + val doBlock = visit(obj(ParserKeys.Body)) match + case stmtList: StatementList => stmtList + case other => StatementList(List(other))(other.span) + + ForExpression(forVariable, iterableVariable, doBlock)(obj.toTextSpan) + + private def visitForwardArg(obj: Obj): RubyExpression = + logger.warn("Forward arg unhandled") + defaultResult(Option(obj.toTextSpan)) + + // Note: Forward args should probably be handled more explicitly, but this should preserve flows if the same + // identifier is used in latter forwarding + private def visitForwardArgs(obj: Obj): RubyExpression = MandatoryParameter("...")(obj.toTextSpan) + + private def visitForwardedArgs(obj: Obj): RubyExpression = SimpleIdentifier()(obj.toTextSpan) + + private def visitGlobalVariable(obj: Obj): RubyExpression = + val span = obj.toTextSpan + val name = obj(ParserKeys.Value).str + val selfBase = SelfIdentifier()(span.spanStart("self")) + MemberAccess(selfBase, ".", name)(span) + + private def visitGlobalVariableAssign(obj: Obj): RubyExpression = + val span = obj.toTextSpan + + val selfBase = SelfIdentifier()(span.spanStart("self")) + val lhsName = obj(ParserKeys.Lhs).str + val lhs = + MemberAccess(selfBase, ".", lhsName)(span.spanStart(s"${selfBase.span.text}.$lhsName")) + + val rhs = visit(obj(ParserKeys.Rhs)) + val op = "=" + + SingleAssignment(lhs, op, rhs)(obj.toTextSpan) + + private def visitHash(obj: Obj): RubyExpression = + val isHashLiteral = obj.toTextSpan.text.stripMargin.startsWith("{") + + obj.visitArray(ParserKeys.Children) 
match + case (assoc: Association) :: Nil => + if isHashLiteral then HashLiteral(List(assoc))(obj.toTextSpan) + else assoc // 2 => 1 is interpreted as {2: 1}, so we lower this for now + case children => + if isHashLiteral then HashLiteral(children)(obj.toTextSpan) + else AssociationList(children)(obj.toTextSpan) + + private def visitHashPattern(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitIdentifier(obj: Obj): RubyExpression = SimpleIdentifier()(obj.toTextSpan) + + private def visitIfGuard(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitIfStatement(obj: Obj): RubyExpression = + val condition = visit(obj(ParserKeys.Condition)) + + val elseClause = obj.visitOption(ParserKeys.ElseBranch).map { + case x: IfExpression => x + case x => ElseClause(StatementList(List(x))(x.span))(x.span) + } + + obj.visitOption(ParserKeys.ThenBranch) match + case Some(thenBranch) => + IfExpression(condition, thenBranch, elsifClauses = List.empty, elseClause)(obj.toTextSpan) + case None => + val nilBlock = ReturnExpression( + List(StaticLiteral(Defines.getBuiltInType(Defines.NilClass))( + obj.toTextSpan.spanStart("nil") + )) + )(obj.toTextSpan.spanStart("return nil")) + IfExpression(condition, nilBlock, elsifClauses = List.empty, elseClause)(obj.toTextSpan) + + private def visitInclude(obj: Obj): RubyExpression = + val callName = obj(ParserKeys.Name).str + val target = SimpleIdentifier()(obj.toTextSpan.spanStart(callName)) + val argument = obj.visitArray(ParserKeys.Arguments).head + + IncludeCall(target, argument)(obj.toTextSpan) + + private def visitInclusiveFlipFlop(obj: Obj): RubyExpression = + defaultResult(Option(obj.toTextSpan)) + + private def visitInclusiveRange(obj: Obj): RubyExpression = + val start = obj.visitOption(ParserKeys.Start) match + case Some(expr) => expr + case None => infinityLowerBound(obj) + val end = obj.visitOption(ParserKeys.End) match + case Some(expr) => expr + case None => 
infinityUpperBound(obj) + val op = RangeOperator(false)(obj.toTextSpan.spanStart("..")) + RangeExpression(start, end, op)(obj.toTextSpan) + + private def visitIndexAccessAsSend(obj: Obj): RubyExpression = + val target = visit(obj(ParserKeys.Receiver)) + val indices = obj.visitArray(ParserKeys.Arguments) + IndexAccess(target, indices)(obj.toTextSpan) + + private def visitInPattern(obj: Obj): RubyExpression = + val patternType = visit(obj(ParserKeys.Pattern)) + val patternBody = visit(obj(ParserKeys.Body)) + + InClause(patternType, patternBody)(obj.toTextSpan) + + private def visitInt(obj: Obj): RubyExpression = + val typeFullName = getBuiltInType(Defines.Integer) + StaticLiteral(typeFullName)(obj.toTextSpan) + + private def visitInstanceVariable(obj: Obj): RubyExpression = + InstanceFieldIdentifier()(obj.toTextSpan) + + private def visitKwArg(obj: Obj): RubyExpression = + val name = obj(ParserKeys.Key).str + val default = obj + .visitOption(ParserKeys.Value) + .getOrElse(StaticLiteral(getBuiltInType(Defines.NilClass))(obj.toTextSpan.spanStart("nil"))) + OptionalParameter(name, default)(obj.toTextSpan) + + private def visitKwBegin(obj: Obj): RubyExpression = + val stmts = obj(ParserKeys.Body) match + case o: Obj => visit(o) :: Nil + case _: Arr => obj.visitArray(ParserKeys.Body) + case _ => + val span = obj.toTextSpan + logger.warn(s"Unhandled JSON body type for `KwBegin`: ${span.text}") + defaultResult(Option(span)) :: Nil + StatementList(stmts)(obj.toTextSpan) + + private def visitKwNilArg(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitKwOptArg(obj: Obj): RubyExpression = visitKwArg(obj) + + private def visitKwRestArg(obj: Obj): RubyExpression = + val name = + if obj.contains(ParserKeys.Value) then obj(ParserKeys.Value).str else obj.toTextSpan.text + HashParameter(name)(obj.toTextSpan) + + private def visitKwSplat(obj: Obj): RubyExpression = + val values = visit(obj(ParserKeys.Value)) match + case x: StatementList => 
x.statements.head + case x => x + SplattingRubyNode(values)(obj.toTextSpan) + + private def visitLocalVariable(obj: Obj): RubyExpression = SimpleIdentifier()(obj.toTextSpan) + + private def visitMatchAlt(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitMatchAs(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitMatchNilPattern(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitMatchPattern(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitMatchPatternP(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitMatchRest(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitMatchVariable(obj: Obj): RubyExpression = MatchVariable()(obj.toTextSpan) + + private def visitMatchWithLocalVariableAssign(obj: Obj): RubyExpression = + val lhs = visit(obj(ParserKeys.Lhs)) + val rhs = visit(obj(ParserKeys.Rhs)) + MemberCall(lhs, ".", RubyOperators.regexpMatch, rhs :: Nil)(obj.toTextSpan) + + private def visitMethodAccessModifier(obj: Obj): RubyExpression = + val body = obj.visitArray(ParserKeys.Arguments) match + case head :: Nil => head + case xs => xs.head + + obj(ParserKeys.Name).str match + case "public_class_method" => + PublicMethodModifier(body)(obj.toTextSpan) + case "private_class_method" => + PrivateMethodModifier(body)(obj.toTextSpan) + case modifierName => + logger.warn(s"Unknown modifier type $modifierName") + defaultResult(Option(obj.toTextSpan)) + + private def visitMethodDefinition(obj: Obj): RubyExpression = + val name = obj(ParserKeys.Name).str + val parameters = visitMethodParameters(obj(ParserKeys.Arguments).asInstanceOf[ujson.Obj]) + val body = obj + .visitOption(ParserKeys.Body) + .map { + case x: StatementList => x + case x => StatementList(List(x))(x.span) + } + .getOrElse(StatementList(Nil)(obj.toTextSpan.spanStart(""))) + MethodDeclaration(name, 
parameters, body)(obj.toTextSpan) + + private def visitModuleDefinition(obj: Obj): RubyExpression = + val (name, namespaceParts) = visit(obj(ParserKeys.Name)) match + case memberAccess: MemberAccess => + val memberIdentifier = + SimpleIdentifier()(memberAccess.span.spanStart(memberAccess.memberName)) + (memberIdentifier, Option(getParts(memberAccess).dropRight(1))) + case identifier => (identifier, None) + val (body, fields) = createClassBodyAndFields(obj) + val bodyMemberCall = createBodyMemberCall(name.text, obj.toTextSpan) + ModuleDeclaration( + name = name, + body = body, + fields = fields, + bodyMemberCall = Option(bodyMemberCall), + namespaceParts = namespaceParts + )(obj.toTextSpan) + + private def visitMultipleAssignment(obj: Obj): RubyExpression = + val lhs = visit(obj(ParserKeys.Lhs)) match + case _ @ArrayLiteral(elements) => elements + case expr => expr :: Nil + val rhs = visit(obj(ParserKeys.Rhs)) match + case _ @ArrayLiteral(elements) => elements + case expr => expr :: Nil + lowerMultipleAssignment( + obj, + lhs, + rhs, + () => defaultResult(), + () => StaticLiteral(getBuiltInType(Defines.NilClass))(obj.toTextSpan) + ) + + private def visitMultipleLeftHandSide(obj: Obj): RubyExpression = + val arr = visitArray(obj).asInstanceOf[ArrayLiteral] + arr.copy(elements = arr.elements.map { + case param: MandatoryParameter => param.toSimpleIdentifier + case expr => expr + })(arr.span) + + private def visitNext(obj: Obj): RubyExpression = NextExpression()(obj.toTextSpan) + + private def visitNil(obj: Obj): RubyExpression = + StaticLiteral(getBuiltInType(Defines.NilClass))(obj.toTextSpan) + + private def visitNthRef(obj: Obj): RubyExpression = + val span = obj.toTextSpan + val name = obj(ParserKeys.Value).num.toInt + val selfBase = SelfIdentifier()(span.spanStart("self")) + MemberAccess(selfBase, ".", s"$$$name")(span) + + private def visitObjectInstantiation(obj: Obj): RubyExpression = + // The receiver is the target with the JSON parser + val receiver = 
visit(obj(ParserKeys.Receiver)) + val arguments = obj.visitArray(ParserKeys.Arguments) + SimpleObjectInstantiation(receiver, arguments)(obj.toTextSpan) + + private def visitOperatorAssign(obj: Obj): RubyExpression = + val lhs = visit(obj(ParserKeys.Lhs)) match + case param: MandatoryParameter => param.toSimpleIdentifier + case x => x + val op = s"${obj(ParserKeys.Op).str}=" + val rhs = visit(obj(ParserKeys.Rhs)) + SingleAssignment(lhs, op, rhs)(obj.toTextSpan) + + private def visitOptionalArgument(obj: Obj): RubyExpression = + val name = obj(ParserKeys.Key).str + val default = visit(obj(ParserKeys.Value)) + OptionalParameter(name, default)(obj.toTextSpan) + + private def visitOr(obj: Obj): RubyExpression = + val op = "||" + val lhs = visit(obj(ParserKeys.Lhs)) + val rhs = visit(obj(ParserKeys.Rhs)) + BinaryExpression(lhs, op, rhs)(obj.toTextSpan) + + private def visitOrAssign(obj: Obj): RubyExpression = + val lhs = visit(obj(ParserKeys.Lhs)) match + case param: MandatoryParameter => param.toSimpleIdentifier + case x => x + val rhs = visit(obj(ParserKeys.Rhs)) + OperatorAssignment(lhs, "||=", rhs)(obj.toTextSpan) + + private def visitPair(obj: Obj): RubyExpression = + val key = visit(obj(ParserKeys.Key)) + val value = visit(obj(ParserKeys.Value)) + Association(key, value)(obj.toTextSpan) + + private def visitMethodParameters(paramsNode: Obj): List[RubyExpression] = + AstType.fromString(paramsNode(ParserKeys.Type).str) match + case Some(AstType.Args) => paramsNode.visitArray(ParserKeys.Children) + case Some(AstType.ForwardArgs) => visit(paramsNode) :: Nil + case Some(x) => + logger.warn(s"Not explicitly handled parameter type '$x', no special handling applied") + visit(paramsNode) :: Nil + case _ => + logger.error( + s"Unknown JSON type used as method parameter ${paramsNode(ParserKeys.Type).str}" + ) + defaultResult(Option(paramsNode.toTextSpan)) :: Nil + + private def visitPostExpression(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private 
def visitPreExpression(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitProcArgument(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitRaise(obj: Obj): RubyExpression = + val callName = obj(ParserKeys.Name).str + val target = SimpleIdentifier()(obj.toTextSpan.spanStart(callName)) + + obj.visitArray(ParserKeys.Arguments) match + case Nil => RaiseCall(target, List.empty)(obj.toTextSpan) + case (argument: StaticLiteral) :: Nil => + val simpleErrorId = + SimpleIdentifier(Option(s"$builtinPrefix.StandardError"))( + argument.span.spanStart("StandardError") + ) + val implicitSimpleErrInst = SimpleObjectInstantiation(simpleErrorId, argument :: Nil)( + argument.span.spanStart(s"StandardError.new(${argument.text})") + ) + RaiseCall(target, implicitSimpleErrInst :: Nil)(obj.toTextSpan) + case argument :: Nil => + RaiseCall(target, List(argument))(obj.toTextSpan) + case arguments => + RaiseCall(target, arguments)(obj.toTextSpan) + + private def visitRational(obj: Obj): RubyExpression = + StaticLiteral(getBuiltInType(Defines.Rational))(obj.toTextSpan) + + private def visitRedo(obj: Obj): RubyExpression = + val callTarget = SimpleIdentifier()(obj.toTextSpan.spanStart("redo")) + SimpleCall(callTarget, Nil)(obj.toTextSpan) + + private def visitRetry(obj: Obj): RubyExpression = + val callTarget = SimpleIdentifier()(obj.toTextSpan.spanStart("retry")) + SimpleCall(callTarget, Nil)(obj.toTextSpan) + + private def visitReturn(obj: Obj): RubyExpression = + if obj.contains(ParserKeys.Values) then + val returnExpressions = obj.visitArray(ParserKeys.Values) + ReturnExpression(returnExpressions)(obj.toTextSpan) + else if obj.contains(ParserKeys.Value) then + ReturnExpression(visit(obj(ParserKeys.Value)) :: Nil)(obj.toTextSpan) + else + ReturnExpression(List.empty)(obj.toTextSpan) + + private def visitRegexExpression(obj: Obj): RubyExpression = + obj.visitOption(ParserKeys.Value) match + case Some(_ 
@StatementList(stmts)) => + DynamicLiteral(Defines.getBuiltInType(Defines.Regexp), stmts)(obj.toTextSpan) + case _ => StaticLiteral(Defines.getBuiltInType(Defines.Regexp))(obj.toTextSpan) + + private def visitRegexOption(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitResBody(obj: Obj): RubyExpression = + val exceptionClassList = obj.visitOption(ParserKeys.ExecList) + val variables = obj.visitOption(ParserKeys.ExecVar) + val body = obj.visitOption(ParserKeys.Body) match + case Some(stmt: StatementList) => stmt + case Some(expr) => StatementList(expr :: Nil)(expr.span) + case None => StatementList(Nil)(obj.toTextSpan) + RescueClause(exceptionClassList, variables, body)(obj.toTextSpan) + + private def visitRestArg(obj: Obj): RubyExpression = + obj(ParserKeys.Value) match + case ujson.Null => ArrayParameter("*")(obj.toTextSpan) + case ujson.Str(name) => ArrayParameter(name)(obj.toTextSpan) + case x => + logger.warn(s"Unhandled `restarg` JSON type '$x'") + defaultResult(Option(obj.toTextSpan)) + + private def visitRescueStatement(obj: Obj): RubyExpression = + val stmt = visit(obj(ParserKeys.Statement)) + val rescueClauses = obj.visitArray(ParserKeys.Bodies).asInstanceOf[List[RescueClause]] + val elseClause = obj.visitOption(ParserKeys.ElseClause) match + case Some(body) => Option(ElseClause(body)(body.span)) + case None => Option.empty + + RescueExpression(stmt, rescueClauses, elseClause, Option.empty)(obj.toTextSpan) + + private def visitRequireLike(obj: Obj): RubyExpression = + val callName = obj(ParserKeys.Name).str + val target = SimpleIdentifier()(obj.toTextSpan.spanStart(callName)) + val argument = obj + .visitArray(ParserKeys.Arguments) + .headOption + .getOrElse(StaticLiteral(getBuiltInType(Defines.NilClass))(obj.toTextSpan.spanStart("nil"))) + val isRelative = callName == "require_relative" || callName == "require_all" + val isWildcard = callName == "require_all" + RequireCall(target, argument, isRelative, 
isWildcard)(obj.toTextSpan) + + private def visitScopedConstant(obj: Obj): RubyExpression = + val identifier = obj(ParserKeys.Name).str + if obj.contains(ParserKeys.Base) then + val target = visit(obj(ParserKeys.Base)) + val op = if obj.toTextSpan.text.contains("::") then "::" else "." + MemberAccess(target, op, identifier)(obj.toTextSpan) + else + SimpleIdentifier()(obj.toTextSpan) + + private def visitSelf(obj: Obj): RubyExpression = SelfIdentifier()(obj.toTextSpan) + + private def visitSend(obj: Obj, isConditional: Boolean = false): RubyExpression = + val callName = obj(ParserKeys.Name).str + val hasReceiver = obj.contains(ParserKeys.Receiver) + callName match + case "new" => visitObjectInstantiation(obj) + case "Array" | "Hash" => visitCollectionAliasSend(obj) + case "[]" => visitIndexAccessAsSend(obj) + case "[]=" => visitBracketAssignmentAsSend(obj) + case "raise" => visitRaise(obj) + case "include" => visitInclude(obj) + case "attr_reader" | "attr_writer" | "attr_accessor" => visitFieldDeclaration(obj) + case "private" | "public" | "protected" => visitAccessModifier(obj) + case "private_class_method" | "public_class_method" => visitMethodAccessModifier(obj) + case requireLike if ImportCallNames.contains(requireLike) && !hasReceiver => + visitRequireLike(obj) + case _ if BinaryOperators.isBinaryOperatorName(callName) => + val lhs = visit(obj(ParserKeys.Receiver)) + val rhs = obj.visitArray(ParserKeys.Arguments).head + BinaryExpression(lhs, callName, rhs)(obj.toTextSpan) + case _ if UnaryOperators.isUnaryOperatorName(callName) => + UnaryExpression(callName, visit(obj(ParserKeys.Receiver)))(obj.toTextSpan) + case s"$name=" if hasReceiver => visitFieldAssignmentSend(obj, name) + case _ => + val target = SimpleIdentifier()(obj.toTextSpan.spanStart(callName)) + val argumentArr = obj.visitArray(ParserKeys.Arguments) + val arguments = argumentArr.flatMap { + case hashLiteral: HashLiteral => + hashLiteral.elements // a hash is likely named arguments + case assocList: 
AssociationList => assocList.elements // same as above + case x => x :: Nil + } + val objSpan = obj.toTextSpan + val hasArguments = arguments.nonEmpty + val usesParenthesis = objSpan.text.endsWith(")") + if obj.contains(ParserKeys.Receiver) then + val base = visit(obj(ParserKeys.Receiver)) + val isMemberCall = usesParenthesis || callName == "<<" || hasArguments + val op = + val dot = if objSpan.text.stripPrefix(base.text).startsWith("::") then "::" else "." + if isConditional then s"&$dot" else dot + if isMemberCall then MemberCall(base, op, callName, arguments)(obj.toTextSpan) + else MemberAccess(base, op, callName)(obj.toTextSpan) + else if hasArguments || usesParenthesis then + SimpleCall(target, arguments)(obj.toTextSpan) + else + // The following allows the AstCreator to approximate when an identifier could be a call or not - puts less + // strain on data-flow tracking for externally inherited accessor calls such as `params` in RubyOnRails + SimpleIdentifier()(obj.toTextSpan.spanStart(callName)) + end match + end visitSend + + private def visitShadowArg(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitSingletonMethodDefinition(obj: Obj): RubyExpression = + val base = visit(obj(ParserKeys.Base)) + val name = obj(ParserKeys.Name).str + val parameters = visitMethodParameters(obj(ParserKeys.Arguments).asInstanceOf[ujson.Obj]) + val body = + obj.visitOption(ParserKeys.Body).getOrElse( + StatementList(Nil)(obj.toTextSpan.spanStart("")) + ) match + case stmtList: StatementList => stmtList + case expr => StatementList(expr :: Nil)(expr.span) + SingletonMethodDeclaration(base, name, parameters, body)(obj.toTextSpan) + + private def visitSingletonClassDefinition(obj: Obj): RubyExpression = + val name = visit(obj(ParserKeys.Name)) + val baseClass = obj.visitOption(ParserKeys.SuperClass) + val body = obj.visitOption(ParserKeys.Body).getOrElse( + StatementList(Nil)(obj.toTextSpan.spanStart("")) + ) + + obj.visitOption(ParserKeys.Def) 
match + case Some(body) => + name match + case _: SelfIdentifier => + val bodyList = body match + case stmtList: StatementList => stmtList + case expr => StatementList(expr :: Nil)(expr.span) + + val base = baseClass match + case Some(baseClass) => baseClass + case None => SelfIdentifier()(obj.toTextSpan.spanStart("self")) + + SingletonClassDeclaration(freshClassName(obj.toTextSpan), Some(base), bodyList)( + obj.toTextSpan + ) + case _ => + def mapDefBody(defBody: RubyExpression): RubyExpression = defBody match + case method @ MethodDeclaration(methodName, parameters, body) => + val memberAccess = + MemberAccess(name, ".", methodName)( + method.span.spanStart(s"${name.span.text}.${methodName}") + ) + val singletonBlockMethod = + SingletonObjectMethodDeclaration(methodName, parameters, body, name)( + method.span + ) + SingleAssignment(memberAccess, "=", singletonBlockMethod)( + method.span.spanStart(s"${memberAccess.span.text} = ${method.span.text}") + ) + case expr => expr + + val stmts = body match + case _ @StatementList(stmts) => stmts.map(mapDefBody) + case expr => mapDefBody(expr) :: Nil + SingletonStatementList(stmts)(obj.toTextSpan) + + case None => + val anonName = freshClassName(obj.toTextSpan) + SingletonClassDeclaration(name = anonName, baseClass = baseClass, body = body)( + obj.toTextSpan + ) + end match + end visitSingletonClassDefinition + + private def visitSingleAssignment(obj: Obj): RubyExpression = + val lhsSpan = obj.toTextSpan.spanStart(obj(ParserKeys.Lhs).str) + val lhs = obj(ParserKeys.Lhs).str match + case s"@@$_" => ClassFieldIdentifier()(lhsSpan) + case s"@$_" => InstanceFieldIdentifier()(lhsSpan) + case _ => SimpleIdentifier()(lhsSpan) + obj.visitOption(ParserKeys.Rhs) match + case Some(rhs) => + SingleAssignment(lhs, "=", rhs)(obj.toTextSpan) + case None => + if AstType.fromString(obj(ParserKeys.Type).str) == AstType.LocalVariableAssign then + // `lvasgn` is used in exec_var for rescueExpr, which only has LHS + 
MandatoryParameter(lhs.span.text)(lhs.span) + else + lhs + + private def visitSplat(obj: Obj): RubyExpression = + obj.visitOption(ParserKeys.Value) match + case Some(x) => SplattingRubyNode(x)(obj.toTextSpan) + case None => + val emptyStar = SimpleIdentifier()(obj.toTextSpan.spanStart("_")) + SplattingRubyNode(emptyStar)(obj.toTextSpan) + + private def visitStaticString(obj: Obj): RubyExpression = + val typeFullName = getBuiltInType(Defines.String) + val originalSpan = obj.toTextSpan + val value = obj(ParserKeys.Value).str + // In general, we want the quotations, unless it is a HEREDOC string, then we'd prefer the value + val span = + if !originalSpan.text.contains(value) then originalSpan.spanStart(value) else originalSpan + StaticLiteral(typeFullName)(span) + + private def visitStaticSymbol(obj: Obj): RubyExpression = + val typeFullName = getBuiltInType(Defines.Symbol) + val objTextSpan = obj.toTextSpan + + if objTextSpan.text.startsWith(":") then StaticLiteral(typeFullName)(obj.toTextSpan) + else StaticLiteral(typeFullName)(objTextSpan.spanStart(s":${objTextSpan.text}")) + + private def visitSuper(obj: Obj): RubyExpression = + val name = SimpleIdentifier(Option(getBuiltInType(Defines.Super)))( + obj.toTextSpan.spanStart(Defines.Super) + ) + val arguments = obj.visitArray(ParserKeys.Arguments) + SimpleCall(name, arguments)(obj.toTextSpan) + + private def visitSuperNoArgs(obj: Obj): RubyExpression = + val name = SimpleIdentifier(Option(getBuiltInType(Defines.Super)))( + obj.toTextSpan.spanStart(Defines.Super) + ) + SimpleCall(name, Nil)(obj.toTextSpan) + + private def visitTopLevelConstant(obj: Obj): RubyExpression = + if obj.contains(ParserKeys.Name) then + val identifier = obj(ParserKeys.Name).str + SimpleIdentifier()(obj.toTextSpan.spanStart(identifier)) + else + SelfIdentifier()(obj.toTextSpan.spanStart("self")) + + private def visitTrue(obj: Obj): RubyExpression = + StaticLiteral(getBuiltInType(Defines.TrueClass))(obj.toTextSpan) + + private def 
visitUnDefine(obj: Obj): RubyExpression = + defaultResult(Option(obj.toTextSpan)) + + private def visitUnlessExpression(obj: Obj): RubyExpression = + defaultResult(Option(obj.toTextSpan)) + + private def visitUnlessGuard(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan)) + + private def visitUntilExpression(obj: Obj): RubyExpression = + val condition = visit(obj(ParserKeys.Condition)) + val body = visit(obj(ParserKeys.Body)) + + UntilExpression(condition, body)(obj.toTextSpan) + + private def visitUntilPostExpression(obj: Obj): RubyExpression = + val condition = visit(obj(ParserKeys.Condition)) + val body = visit(obj(ParserKeys.Body)) + + DoWhileExpression(condition, body)(obj.toTextSpan) + + private def visitWhenStatement(obj: Obj): RubyExpression = + val (matchCondition, matchSplatCondition) = obj.visitArray(ParserKeys.Conditions).partition { + case x: SplattingRubyNode => false + case x => true + } + + val thenClause = visit(obj(ParserKeys.ThenBranch)) + + WhenClause(matchCondition, matchSplatCondition.headOption, thenClause)(obj.toTextSpan) + + private def visitWhileStatement(obj: Obj): RubyExpression = + val condition = visit(obj(ParserKeys.Condition)) match + case x: StatementList => x.statements.head + case x => x + + val body = visit(obj(ParserKeys.Body)) + + WhileExpression(condition, body)(obj.toTextSpan) + + private def visitYield(obj: Obj): RubyExpression = + val arguments = obj.visitArray(ParserKeys.Arguments) + YieldExpr(arguments)(obj.toTextSpan) +end RubyJsonToNodeCreator diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/passes/AstCreationPass.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/passes/AstCreationPass.scala new file mode 100644 index 00000000..24a11a7c --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/passes/AstCreationPass.scala @@ -0,0 +1,41 @@ +package io.appthreat.ruby2atom.passes + +import 
io.appthreat.ruby2atom.astcreation.AstCreator +import io.shiftleft.codepropertygraph.generated.Cpg +import io.shiftleft.codepropertygraph.generated.NodeTypes +import io.shiftleft.codepropertygraph.generated.nodes.NewTypeDecl +import io.shiftleft.passes.ForkJoinParallelCpgPass +import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal +import overflowdb.BatchedUpdate + +/** Parallel pass that runs every given `AstCreator` and absorbs the resulting AST into the diff graph. */ class AstCreationPass(cpg: Cpg, astCreators: List[AstCreator]) + extends ForkJoinParallelCpgPass[AstCreator](cpg): + + override def generateParts(): Array[AstCreator] = astCreators.toArray + + override def init(): Unit = + // The first entry will be the type, which is often found on fieldAccess nodes + // (which may be receivers to calls) + val diffGraph = new DiffGraphBuilder + val emptyType = + NewTypeDecl() + .astParentType(NodeTypes.NAMESPACE_BLOCK) + .astParentFullName(NamespaceTraversal.globalNamespaceName) + .isExternal(true) + val anyType = + NewTypeDecl() + .name(Defines.Any) + .fullName(Defines.Any) + .astParentType(NodeTypes.NAMESPACE_BLOCK) + .astParentFullName(NamespaceTraversal.globalNamespaceName) + .isExternal(true) + diffGraph.addNode(emptyType).addNode(anyType) + BatchedUpdate.applyDiff(cpg.graph, diffGraph) + + /** Creates the AST for one creator and absorbs it into `diffGraph`; an exception is caught here so it does not propagate out of this part. */ override def runOnPart(diffGraph: DiffGraphBuilder, astCreator: AstCreator): Unit = + try + val ast = astCreator.createAst() + diffGraph.absorb(ast) + catch + case ex: Exception => /* stay best-effort, but leave a trace instead of dropping the failure silently */ org.slf4j.LoggerFactory.getLogger(getClass).debug("AST creation failed; skipping this part", ex) +end AstCreationPass diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/passes/ConfigFileCreationPass.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/passes/ConfigFileCreationPass.scala new file mode 100644 index 00000000..333238a5 --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/passes/ConfigFileCreationPass.scala @@ -0,0 +1,31 @@ +package io.appthreat.ruby2atom.passes + +import better.files.File +import io.appthreat.x2cpg.passes.frontend.XConfigFileCreationPass +import 
io.shiftleft.codepropertygraph.generated.Cpg +import io.shiftleft.semanticcpg.language.* + +import scala.util.Try + +/** Creates the CONFIGURATION layer from any `Gemfile` or `Gemfile.lock` found at root level, plus + * every file accepted by the `.ini`, `.yaml`, `.yml`, `.xml` and `.erb` extension filters below. + */ +class ConfigFileCreationPass(cpg: Cpg) extends XConfigFileCreationPass(cpg): + + /* Candidate Gemfile paths directly under the root recorded in the CPG metadata; empty when that root cannot be turned into a `File`. */ private val validGemfilePaths = + Try(File(cpg.metaData.root.headOption.getOrElse(""))).toOption match + case Some(rootPath) => Seq("Gemfile", "Gemfile.lock").map(rootPath / _) + case None => Seq() + + override protected val configFileFilters: List[File => Boolean] = List( + // Gemfiles + validGemfilePaths.contains, + extensionFilter(".ini"), + // YAML files + extensionFilter(".yaml"), + extensionFilter(".yml"), + // XML files + extensionFilter(".xml"), + // ERB files + extensionFilter(".erb") + ) diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/passes/Defines.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/passes/Defines.scala new file mode 100644 index 00000000..aca5a157 --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/passes/Defines.scala @@ -0,0 +1,279 @@ +package io.appthreat.ruby2atom.passes + +object Defines: + + val Any: String = "ANY" + val Defined: String = "defined" + val Undefined: String = "Undefined" + val Object: String = "Object" + val NilClass: String = "NilClass" + val TrueClass: String = "TrueClass" + val FalseClass: String = "FalseClass" + val Numeric: String = "Numeric" + val New: String = "new" + val Integer: String = "Integer" + val Float: String = "Float" + val String: String = "String" + val Symbol: String = "Symbol" + val Array: String = "Array" + val Hash: String = "Hash" + val Encoding: String = "Encoding" + val Regexp: String = "Regexp" + val Lambda: String = "lambda" + val Proc: String = "proc" + val Loop: String = "loop" + val Self: String = "self" + val Super: String = "super" + val Rational: String = "Rational" + val 
Initialize: String = "initialize" + val TypeDeclBody: String = "" + + val Main: String = "
" + + val Resolver: String = "" + + def getBuiltInType(typeInString: String) = s"${GlobalTypes.kernelPrefix}.$typeInString" + + object RubyOperators: + val backticks: String = ".backticks" + val hashInitializer = ".hashInitializer" + val association = ".association" + val splat = ".splat" + val arrayPatternMatch = ".arrayPatternMatch" + val regexpMatch = "=~" + val regexpNotMatch = "!~" +end Defines + +object GlobalTypes: + val Kernel = "Kernel" + val builtinPrefix = "__core" + val kernelPrefix = s"$builtinPrefix.$Kernel" + + /** Source: https://ruby-doc.org/docs/ruby-doc-bundle/Manual/man-1.4/function.html + */ + val bundledClasses: Set[String] = Set( + "ARGF", + "ArgumentError", + "Array", + "BasicObject", + "Binding", + "Class", + "ClosedQueueError", + "Comparable", + "Complex", + "ConditionVariable", + "Continuation", + "Dir", + "ENV", + "EOFError", + "Encoding", + "Encoding.CompatibilityError", + "Encoding.Converter", + "Encoding.ConverterNotFoundError", + "Encoding.InvalidByteSequenceError", + "Encoding.UndefinedConversionError", + "EncodingError", + "Enumerable", + "Enumerator", + "Enumerator.ArithmeticSequence", + "Enumerator.Chain", + "Enumerator.Generator", + "Enumerator.Lazy", + "Enumerator.Producer", + "Enumerator.Yielder", + "Errno", + "Exception", + "FalseClass", + "Fiber", + "Fiber.SchedulerInterface", + "FiberError", + "File", + "File.Constants", + "File.Stat", + "FileTest", + "Float", + "FloatDomainError", + "FrozenError", + "GC", + "GC.Profiler", + "Hash", + "IO", + "IO.EAGAINWaitReadable", + "IO.EAGAINWaitWritable", + "IO.EINPROGRESSWaitReadable", + "IO.EINPROGRESSWaitWritable", + "IO.EWOULDBLOCKWaitReadable", + "IO.EWOULDBLOCKWaitWritable", + "IO.WaitReadable", + "IO.WaitWritable", + "IOError", + "IndexError", + "Integer", + "Interrupt", + Kernel, + "KeyError", + "LoadError", + "LocalJumpError", + "Marshal", + "MatchData", + "Math", + "Math.DomainError", + "Method", + "Module", + "Mutex", + "NameError", + "NilClass", + "NoMatchingPatternError", 
+ "NoMemoryError", + "NoMethodError", + "NotImplementedError", + "Numeric", + "Object", + "ObjectSpace", + "ObjectSpace.WeakMap", + "Pool", + "Proc", + "Process", + "Process.GID", + "Process.Status", + "Process.Sys", + "Process.UID", + "Queue", + "Ractor", + "Ractor.ClosedError", + "Ractor.Error", + "Ractor.IsolationError", + "Ractor.MovedError", + "Ractor.MovedObject", + "Ractor.RemoteError", + "Ractor.UnsafeError", + "Random", + "Random.Formatter", + "Range", + "RangeError", + "Rational", + "Regexp", + "RegexpError", + "Ripper", + "RubyVM", + "RubyVM.AbstractSyntaxTree", + "RubyVM.AbstractSyntaxTree.Node", + "RubyVM.InstructionSequence", + "RubyVM.MJIT", + "RuntimeError", + "ScriptError", + "SecurityError", + "Signal", + "SignalException", + "SizedQueue", + "StandardError", + "StopIteration", + "String", + "Struct", + "Symbol", + "SyntaxError", + "SystemCallError", + "SystemExit", + "SystemStackError", + "Thread", + "Thread.Backtrace", + "Thread.Backtrace.Location", + "ThreadError", + "ThreadGroup", + "Time", + "TracePoint", + "TrueClass", + "TypeError", + "UnboundMethod", + "UncaughtThrowError", + "UnicodeNormalize", + "Warning", + "ZeroDivisionError", + "fatal", + "unknown" + ) + + /* Source: https://ruby-doc.org/3.2.2/Kernel.html + * + * We comment-out methods that require an explicit "receiver" (target of member access) and those that may be commonly + * shadowed. 
+ */ + val kernelFunctions: Set[String] = Set( + "Array", + "Complex", + "Float", + "Hash", + "Integer", + "Rational", + "String", + "__callee__", + "__dir__", + "__method__", + "abort", + "at_exit", + "autoload", + "autoload?", + "binding", + "block_given?", + "callcc", + "caller", + "caller_locations", + "catch", + "chomp", + "chomp!", + "chop", + "chop!", + // "class", + // "clone", + "eval", + "exec", + "exit", + "exit!", + "fail", + "fork", + "format", + // "frozen?", + "gets", + "global_variables", + "gsub", + "gsub!", + "iterator?", + "lambda", + "load", + "local_variables", + "loop", + "open", + "p", + "print", + "printf", + "proc", + "putc", + "puts", + "raise", + "rand", + "readline", + "readlines", + "require", + "require_all", + "require_relative", +// "select", + "set_trace_func", + "sleep", + "spawn", + "sprintf", + "srand", + "sub", + "sub!", + "syscall", + "system", + "tap", + "test", + // "then", + "throw", + "trace_var", + // "trap", + "untrace_var", + "warn" + // "yield_self", + ) +end GlobalTypes diff --git a/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/utils/FreshNameGenerator.scala b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/utils/FreshNameGenerator.scala new file mode 100644 index 00000000..e3feb963 --- /dev/null +++ b/platform/frontends/ruby2atom/src/main/scala/io/appthreat/ruby2atom/utils/FreshNameGenerator.scala @@ -0,0 +1,11 @@ +package io.appthreat.ruby2atom.utils + +/** Produces values from `template`, feeding it a counter that starts at 0 and grows by one per `fresh` call. */ class FreshNameGenerator[T](template: Int => T): + private var counter: Int = 0 + /** Returns `template` applied to the current counter value, then increments the counter. */ def fresh: T = + val name = template(counter) + counter += 1 + name + + /** Re-evaluates `template` for the index most recently handed out by `fresh`; before the first `fresh` call this is `template(-1)`. */ def current: T = + template(counter - 1) diff --git a/platform/frontends/x2cpg/build.sbt b/platform/frontends/x2cpg/build.sbt index e9587457..11eb162b 100644 --- a/platform/frontends/x2cpg/build.sbt +++ b/platform/frontends/x2cpg/build.sbt @@ -3,7 +3,9 @@ name := "x2cpg" dependsOn(Projects.semanticcpg) libraryDependencies ++= Seq( - "io.circe" %% "circe-core" % Versions.circe, 
+ "com.lihaoyi" %% "upickle" % Versions.upickle, + "com.typesafe" % "config" % Versions.typeSafeConfig, + "com.michaelpollmeier" % "versionsort" % Versions.versionSort, "io.circe" %% "circe-generic" % Versions.circe, "io.circe" %% "circe-parser" % Versions.circe, "org.scalatest" %% "scalatest" % Versions.scalatest % Test diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Ast.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Ast.scala index af39267f..876a7b15 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Ast.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Ast.scala @@ -3,7 +3,6 @@ package io.appthreat.x2cpg import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.nodes.AstNode.PropertyDefaults -import org.slf4j.LoggerFactory import overflowdb.BatchedUpdate.DiffGraphBuilder import overflowdb.SchemaViolationException @@ -14,8 +13,6 @@ enum ValidationMode: object Ast: - private val logger = LoggerFactory.getLogger(getClass) - def apply(node: NewNode)(implicit withSchemaValidation: ValidationMode): Ast = Ast(Vector.empty :+ node) def apply()(implicit withSchemaValidation: ValidationMode): Ast = new Ast(Vector.empty) @@ -49,6 +46,9 @@ object Ast: ast.bindsEdges.foreach { edge => diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.BINDS) } + ast.captureEdges.foreach { edge => + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.CAPTURE) + } end storeInDiffGraph def neighbourValidation(src: NewNode, dst: NewNode, edge: String)(implicit @@ -90,7 +90,8 @@ case class Ast( refEdges: collection.Seq[AstEdge] = Vector.empty, bindsEdges: collection.Seq[AstEdge] = Vector.empty, receiverEdges: collection.Seq[AstEdge] = Vector.empty, - argEdges: collection.Seq[AstEdge] = Vector.empty + argEdges: collection.Seq[AstEdge] = Vector.empty, + captureEdges: collection.Seq[AstEdge] = Vector.empty )(implicit 
withSchemaValidation: ValidationMode = ValidationMode.Disabled): def root: Option[NewNode] = nodes.headOption @@ -200,6 +201,14 @@ case class Ast( dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.RECEIVER)) this.copy(receiverEdges = receiverEdges ++ dsts.map(AstEdge(src, _))) + def withCaptureEdge(src: NewNode, dst: NewNode): Ast = + Ast.neighbourValidation(src, dst, EdgeTypes.CAPTURE) + this.copy(captureEdges = captureEdges ++ List(AstEdge(src, dst))) + + def withCaptureEdges(src: NewNode, dsts: Seq[NewNode]): Ast = + dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.CAPTURE)) + this.copy(captureEdges = captureEdges ++ dsts.map(AstEdge(src, _))) + /** Returns a deep copy of the sub tree rooted in `node`. If `order` is set, then the `order` and * `argumentIndex` fields of the new root node are set to `order`. If `replacementNode` is set, * then this replaces `node` in the new copy. diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstNodeBuilder.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstNodeBuilder.scala index a1050657..f5c73635 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstNodeBuilder.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstNodeBuilder.scala @@ -42,7 +42,7 @@ trait AstNodeBuilder[Node, NodeProcessor]: protected def shortenCode(code: String): String = StringUtils.abbreviate(code, math.max(MinCodeLength, MaxCodeLength)) - protected def offset(node: Node): Option[(Int, Int)] = None + protected def offset(node: Node): Option[(Integer, Integer)] = None protected def unknownNode(node: Node, code: String): NewUnknown = NewUnknown() diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/SourceFiles.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/SourceFiles.scala index 3785b60b..1c3910ea 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/SourceFiles.scala +++ 
b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/SourceFiles.scala @@ -1,20 +1,102 @@ package io.appthreat.x2cpg -import better.files.File.VisitOptions import better.files.* +import better.files.File.VisitOptions import org.slf4j.LoggerFactory import java.io.FileNotFoundException +import java.nio.file.FileVisitor +import java.nio.file.FileVisitResult +import java.nio.file.Path import java.nio.file.Paths +import java.nio.file.attribute.BasicFileAttributes +import java.nio.file.Files +import scala.jdk.CollectionConverters.SetHasAsJava +import scala.util.matching.Regex object SourceFiles: private val logger = LoggerFactory.getLogger(getClass) - private def isIgnoredByFileList(filePath: String, config: X2CpgConfig[?]): Boolean = - val isInIgnoredFiles = config.ignoredFiles.exists { - case ignorePath if File(ignorePath).isDirectory => filePath.startsWith(ignorePath) - case ignorePath => filePath == ignorePath + /** A failsafe implementation of a [[FileVisitor]] that continues iterating through files even if + * an [[IOException]] occurs during traversal. + * + * This visitor determines during traversal whether a given file should be excluded based on + * several criteria, such as matching default ignore patterns, specific file name patterns, or + * explicit file paths to ignore. It does not descend into folders matching such ignore patterns. Only files whose extension is accepted for `sourceFileExtensions` are collected. + * + * This class is useful in scenarios where file traversal must be resilient to errors, such as + * accessing files with restricted permissions or encountering corrupted file entries. + * + * @param inputPath + * The root path from which the file traversal starts. + * @param ignoredDefaultRegex + * Optional sequence of regular expressions to filter out default ignored file patterns. + * @param ignoredFilesRegex + * Optional regular expression to filter out specific files based on their names. + * @param ignoredFilesPath + * Optional sequence of file paths to exclude from traversal explicitly. 
+ */ + private final class FailsafeFileVisitor( + inputPath: String, + sourceFileExtensions: Set[String], + ignoredDefaultRegex: Option[Seq[Regex]] = None, + ignoredFilesRegex: Option[Regex] = None, + ignoredFilesPath: Option[Seq[String]] = None + ) extends FileVisitor[Path]: + + private val seenFiles = scala.collection.mutable.ArrayBuffer.empty[Path] + + def files(): Array[File] = seenFiles.map(File(_)).toArray + + override def preVisitDirectory(dir: Path, attrs: BasicFileAttributes): FileVisitResult = + if filterFile( + dir.toString, + inputPath, + ignoredDefaultRegex, + ignoredFilesRegex, + ignoredFilesPath + ) + then + FileVisitResult.CONTINUE + else + FileVisitResult.SKIP_SUBTREE + + override def visitFile(file: Path, attrs: BasicFileAttributes): FileVisitResult = + if + hasSourceFileExtension(file, sourceFileExtensions) && + filterFile( + file.toString, + inputPath, + ignoredDefaultRegex, + ignoredFilesRegex, + ignoredFilesPath + ) + then seenFiles.addOne(file) + FileVisitResult.CONTINUE + + override def visitFileFailed(file: Path, exc: java.io.IOException): FileVisitResult = + exc match + case _: java.nio.file.FileSystemLoopException => + logger.warn(s"Ignoring '$file' (cyclic symlink)") + case other => logger.warn(s"Ignoring '$file'", other) + FileVisitResult.CONTINUE + + override def postVisitDirectory(dir: Path, exc: java.io.IOException): FileVisitResult = + FileVisitResult.CONTINUE + end FailsafeFileVisitor + + private def isIgnoredByFileList(filePath: String, ignoredFiles: Seq[String]): Boolean = + val filePathFile = File(filePath) + if !filePathFile.exists || !filePathFile.isReadable then + logger.debug(s"'$filePath' ignored (not readable or broken symlink)") + return true + val isInIgnoredFiles = ignoredFiles.exists { ignorePath => + val ignorePathFile = File(ignorePath) + ignorePathFile.exists && + (ignorePathFile.contains(filePathFile, strict = false) || ignorePathFile.isSameFileAs( + filePathFile + )) } if isInIgnoredFiles then 
logger.debug(s"'$filePath' ignored (--exclude)") @@ -22,106 +104,190 @@ object SourceFiles: else false - private def isIgnoredByDefault(filePath: String, config: X2CpgConfig[?]): Boolean = - val relPath = toRelativePath(filePath, config.inputPath) - if config.defaultIgnoredFilesRegex.exists(_.matches(relPath)) || File( - filePath - ).isSymbolicLink - then + private def isIgnoredByDefaultRegex( + filePath: String, + inputPath: String, + ignoredDefaultRegex: Seq[Regex] + ): Boolean = + val relPath = toRelativePath(filePath, inputPath) + if ignoredDefaultRegex.exists(_.matches(relPath)) then logger.debug(s"'$relPath' ignored by default") true else false - private def isIgnoredByRegex(filePath: String, config: X2CpgConfig[?]): Boolean = - val relPath = toRelativePath(filePath, config.inputPath) - val isInIgnoredFilesRegex = config.ignoredFilesRegex.matches(relPath) + private def isIgnoredByRegex( + filePath: String, + inputPath: String, + ignoredFilesRegex: Regex + ): Boolean = + val relPath = toRelativePath(filePath, inputPath) + val isInIgnoredFilesRegex = ignoredFilesRegex.matches(relPath) if isInIgnoredFilesRegex then logger.debug(s"'$relPath' ignored (--exclude-regex)") true else false - private def filterFiles(files: List[String], config: X2CpgConfig[?]): List[String] = - files.filter { - case filePath if isIgnoredByDefault(filePath, config) => false - case filePath if isIgnoredByFileList(filePath, config) => false - case filePath if isIgnoredByRegex(filePath, config) => false - case _ => true - } + /** Filters a file based on the provided ignore rules. + * + * This method determines whether a given file should be excluded from processing based on + * several criteria, such as matching default ignore patterns, specific file name patterns, or + * explicit file paths to ignore. + * + * @param file + * The file name or path to evaluate. + * @param inputPath + * The root input path for the file traversal. 
+ * @param ignoredDefaultRegex + * Optional sequence of regular expressions defining default file patterns to ignore. + * @param ignoredFilesRegex + * Optional regular expression defining specific file name patterns to ignore. + * @param ignoredFilesPath + * Optional sequence of file paths to explicitly exclude. + * @return + * `true` if the file is accepted, i.e., does not match any of the ignore criteria, `false` + * otherwise. + */ + def filterFile( + file: String, + inputPath: String, + ignoredDefaultRegex: Option[Seq[Regex]] = None, + ignoredFilesRegex: Option[Regex] = None, + ignoredFilesPath: Option[Seq[String]] = None + ): Boolean = + !ignoredDefaultRegex.exists(isIgnoredByDefaultRegex(file, inputPath, _)) + && !ignoredFilesRegex.exists(isIgnoredByRegex(file, inputPath, _)) + && !ignoredFilesPath.exists(isIgnoredByFileList(file, _)) - /** For a given input path, determine all source files by inspecting filename extensions. + /** Filters a list of files based on the provided ignore rules. + * + * This method applies [[filterFile]] to each file in the input list, returning only those files + * that do not match any of the ignore criteria. + * + * @param files + * The list of file names or paths to evaluate. + * @param inputPath + * The root input path for the file traversal. + * @param ignoredDefaultRegex + * Optional sequence of regular expressions defining default file patterns to ignore. + * @param ignoredFilesRegex + * Optional regular expression defining specific file name patterns to ignore. + * @param ignoredFilesPath + * Optional sequence of file paths to explicitly exclude. + * @return + * A filtered list of files that do not match the ignore criteria. 
*/ - def determine(inputPath: String, sourceFileExtensions: Set[String]): List[String] = - determine(Set(inputPath), sourceFileExtensions) + def filterFiles( + files: List[String], + inputPath: String, + ignoredDefaultRegex: Option[Seq[Regex]] = None, + ignoredFilesRegex: Option[Regex] = None, + ignoredFilesPath: Option[Seq[String]] = None + ): List[String] = files.filter(filterFile( + _, + inputPath, + ignoredDefaultRegex, + ignoredFilesRegex, + ignoredFilesPath + )) + + private def hasSourceFileExtension(file: File, sourceFileExtensions: Set[String]): Boolean = + sourceFileExtensions.exists(ext => file.pathAsString.endsWith(ext)) - /** For a given input path, determine all source files by inspecting filename extensions and - * filter the result according to the given config (by its ignoredFilesRegex and ignoredFiles). + /** Determines a sorted list of file paths in a directory that match the specified criteria. + * + * @param inputPath + * The root directory to search for files. + * @param sourceFileExtensions + * A set of file extensions to include in the search. + * @param ignoredDefaultRegex + * An optional sequence of regular expressions for default files to ignore. + * @param ignoredFilesRegex + * An optional regular expression for additional files to ignore. + * @param ignoredFilesPath + * An optional sequence of specific file paths to ignore. + * @param visitOptions + * Implicit parameter defining the options for visiting the file tree. Defaults to + * `VisitOptions.follow`, which follows symbolic links. + * @return + * A sorted `List[String]` of file paths matching the criteria. + * + * This function traverses the file tree starting at the given `inputPath` and collects file + * paths that: + * - Have extensions specified in `sourceFileExtensions`. + * - Are not ignored based on `ignoredDefaultRegex`, `ignoredFilesRegex`, or + * `ignoredFilesPath`. 
+ * + * It uses a custom `FailsafeFileVisitor` to handle the filtering logic and `Files.walkFileTree` + * to perform the traversal. + * + * Example usage: + * {{{ + * val files = determine( + * inputPath = "/path/to/dir", + * sourceFileExtensions = Set(".scala", ".java"), + * ignoredDefaultRegex = Some(Seq(".*\\.tmp".r)), + * ignoredFilesRegex = Some(".*_backup\\.scala".r), + * ignoredFilesPath = Some(Seq("/path/to/dir/ignore_me.scala")) + * ) + * println(files) + * }}} + * @throws java.io.FileNotFoundException + * if the `inputPath` does not exist or is not readable. + * @see + * [[FailsafeFileVisitor]] for details on the visitor used to process files. */ def determine( inputPath: String, sourceFileExtensions: Set[String], - config: X2CpgConfig[?] - ): List[String] = - determine(Set(inputPath), sourceFileExtensions, config) + ignoredDefaultRegex: Option[Seq[Regex]] = None, + ignoredFilesRegex: Option[Regex] = None, + ignoredFilesPath: Option[Seq[String]] = None + )(implicit visitOptions: VisitOptions = VisitOptions.follow): List[String] = + val dir = File(inputPath) + assertExists(dir) + val visitor = new FailsafeFileVisitor( + dir.pathAsString, + sourceFileExtensions, + ignoredDefaultRegex, + ignoredFilesRegex, + ignoredFilesPath + ) + Files.walkFileTree(dir.path, visitOptions.toSet.asJava, Int.MaxValue, visitor) + val matchingFiles = visitor.files().map(_.pathAsString) + matchingFiles.toList.sorted - /** For given input paths, determine all source files by inspecting filename extensions and filter - * the result according to the given config (by its ignoredFilesRegex and ignoredFiles). - */ - def determine( - inputPaths: Set[String], + def determineWithConfig( + inputPath: String, sourceFileExtensions: Set[String], config: X2CpgConfig[?] - ): List[String] = - filterFiles(determine(inputPaths, sourceFileExtensions), config) - - /** For a given array of input paths, determine all source files by inspecting filename - * extensions. 
- */ - def determine(inputPaths: Set[String], sourceFileExtensions: Set[String]): List[String] = - def hasSourceFileExtension(file: File): Boolean = - file.extension.exists(sourceFileExtensions.contains) - - val inputFiles = inputPaths.map(File(_)) - assertAllExist(inputFiles) - - val (dirs, files) = inputFiles.partition(_.isDirectory) - - val matchingFiles = files.filter(hasSourceFileExtension).map(_.toString) - val matchingFilesFromDirs = dirs - .flatMap(_.listRecursively(VisitOptions.default)) - .filter(hasSourceFileExtension) - .map(_.pathAsString) - - (matchingFiles ++ matchingFilesFromDirs).toList.sorted - - /** Attempting to analyse source paths that do not exist is a hard error. Terminate execution - * early to avoid unexpected and hard-to-debug issues in the results. - */ - private def assertAllExist(files: Set[File]): Unit = - val (existant, nonExistant) = files.partition(_.isReadable) - val nonReadable = existant.filterNot(_.isReadable) - - if nonExistant.nonEmpty || nonReadable.nonEmpty then - logErrorWithPaths("Source input paths do not exist", nonExistant.map(_.canonicalPath)) - - logErrorWithPaths( - "Source input paths exist, but are not readable", - nonReadable.map(_.canonicalPath) + )(implicit visitOptions: VisitOptions = VisitOptions.follow): List[String] = + determine( + inputPath, + sourceFileExtensions, + ignoredDefaultRegex = Option(config.defaultIgnoredFilesRegex), + ignoredFilesRegex = Option(config.ignoredFilesRegex), + ignoredFilesPath = Option(config.ignoredFiles) ) - throw FileNotFoundException("Invalid source paths provided") - - private def logErrorWithPaths(message: String, paths: Iterable[String]): Unit = - val pathsArray = paths.toArray.sorted - - pathsArray.lengthCompare(1) match - case cmp if cmp < 0 => // pathsArray is empty, so don't log anything - case cmp if cmp == 0 => logger.debug(s"$message: ${paths.head}") - - case cmp => - val errorMessage = (message +: pathsArray.map(path => s"- $path")).mkString("\n") - 
logger.debug(errorMessage) + /** Asserts that a given file exists and is readable. + * + * This method validates the existence and readability of the specified file. If the file does + * not exist or is not readable, it logs an error and throws a [[FileNotFoundException]]. + * + * @param file + * The file to validate. + * @throws FileNotFoundException + * if the file does not exist or is not readable. + */ + private def assertExists(file: File): Unit = + if !file.exists then + logger.error(s"Source input path does not exist: ${file.pathAsString}") + throw FileNotFoundException("Invalid source path provided!") + if !file.isReadable then + logger.error(s"Source input path exists, but is not readable: ${file.pathAsString}") + throw FileNotFoundException("Invalid source path provided!") /** Constructs an absolute path against rootPath. If the given path is already absolute this path * is returned unaltered. Otherwise, "rootPath / path" is returned. diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/astgen/AstGenConfig.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/astgen/AstGenConfig.scala new file mode 100644 index 00000000..d5c5bc04 --- /dev/null +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/astgen/AstGenConfig.scala @@ -0,0 +1,30 @@ +package io.appthreat.x2cpg.astgen + +import io.appthreat.x2cpg.X2CpgConfig +import io.appthreat.x2cpg.astgen.AstGenRunner.AstGenProgramMetaData + +trait AstGenConfig[R <: X2CpgConfig[R]]: + this: R => + + /** The prefix/name of the AST Gen binary. + */ + protected val astGenProgramName: String + + /** The root prefix in application.conf that concerns this frontend. + */ + protected val astGenConfigPrefix: String + + /** Indicates that a single binary handles all architectures. + */ + protected val multiArchitectureBuilds: Boolean = false + + /** Creates a meta-data class of information about the AST Gen executable. 
+ */ + def astGenMetaData: AstGenProgramMetaData = + AstGenProgramMetaData( + astGenProgramName, + astGenConfigPrefix, + multiArchitectureBuilds, + getClass.getProtectionDomain.getCodeSource.getLocation + ) +end AstGenConfig diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/astgen/AstGenNodeBuilder.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/astgen/AstGenNodeBuilder.scala new file mode 100644 index 00000000..d9182e2f --- /dev/null +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/astgen/AstGenNodeBuilder.scala @@ -0,0 +1,23 @@ +package io.appthreat.x2cpg.astgen + +import io.appthreat.x2cpg.AstNodeBuilder +import io.shiftleft.codepropertygraph.generated.nodes.AstNode.PropertyDefaults + +/** An extension of AstNodeBuilder that is able to provide useful defaults from the more specialized + * node type that AstGen-based frontends use. + */ +trait AstGenNodeBuilder[NodeProcessor] extends AstNodeBuilder[BaseNodeInfo[?], NodeProcessor]: + this: NodeProcessor => + + override def code(node: BaseNodeInfo[?]): String = + Option(node).map(_.code).getOrElse(PropertyDefaults.Code) + + override def line(node: BaseNodeInfo[?]): Option[Integer] = Option(node).flatMap(_.lineNumber) + + override def lineEnd(node: BaseNodeInfo[?]): Option[Integer] = + Option(node).flatMap(_.lineNumberEnd) + + override def column(node: BaseNodeInfo[?]): Option[Integer] = Option(node).flatMap(_.columnNumber) + + override def columnEnd(node: BaseNodeInfo[?]): Option[Integer] = + Option(node).flatMap(_.columnNumberEnd) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/astgen/AstGenRunner.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/astgen/AstGenRunner.scala new file mode 100644 index 00000000..daa4eac0 --- /dev/null +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/astgen/AstGenRunner.scala @@ -0,0 +1,155 @@ +package io.appthreat.x2cpg.astgen + +import better.files.File +import 
com.typesafe.config.ConfigFactory +import io.appthreat.x2cpg.utils.Environment.ArchitectureType.ArchitectureType +import io.appthreat.x2cpg.utils.Environment.OperatingSystemType.OperatingSystemType +import io.appthreat.x2cpg.utils.{Environment, ExternalCommand} +import io.appthreat.x2cpg.{SourceFiles, X2CpgConfig} +import versionsort.VersionHelper + +import java.net.URL +import java.nio.file.Paths +import scala.util.{Failure, Success, Try} + +object AstGenRunner: + + trait AstGenRunnerResult: + def parsedFiles: List[String] + def skippedFiles: List[String] + + /** @param parsedFiles + * the files parsed by the runner. + * @param skippedFiles + * the files skipped by the runner. + */ + case class DefaultAstGenRunnerResult( + parsedFiles: List[String] = List.empty, + skippedFiles: List[String] = List.empty + ) extends AstGenRunnerResult + + /** @param name + * the name of the AST gen executable, e.g., goastgen, dotnetastgen, swiftastgen, etc. + * @param configPrefix + * the prefix of the executable's respective configuration path. + * @param multiArchitectureBuilds + * whether there is a binary for specific architectures or not. + * @param packagePath + * the code path for the frontend. 
+ */ + case class AstGenProgramMetaData( + name: String, + configPrefix: String, + multiArchitectureBuilds: Boolean, + packagePath: URL + ) + + def executableDir(implicit metaData: AstGenProgramMetaData): String = + ExternalCommand + .executableDir(Paths.get(metaData.packagePath.toURI)) + .resolve("astgen") + .toString + + def hasCompatibleAstGenVersion(compatibleVersion: String)(implicit + metaData: AstGenProgramMetaData + ): Boolean = + ExternalCommand.runWithResult(Seq(metaData.name, "-version"), ".").successOption.map( + _.mkString.strip() + ) match + case Some(installedVersion) + if installedVersion != "unknown" && + Try(VersionHelper.compare(installedVersion, compatibleVersion)).toOption.getOrElse( + -1 + ) >= 0 => + true + case Some(installedVersion) => + false + case _ => + false +end AstGenRunner + +trait AstGenRunnerBase(config: X2CpgConfig[?] & AstGenConfig[?]): + + import io.appthreat.x2cpg.astgen.AstGenRunner.* + + // Suffixes for the binary based on OS & architecture + protected val WinX86 = "win.exe" + protected val WinArm = "win-arm.exe" + protected val LinuxX86 = "linux" + protected val LinuxArm = "linux-arm" + protected val MacX86 = "macos" + protected val MacArm = "macos-arm" + + /** All the supported combinations of architectures. + */ + protected val SupportedBinaries: Set[(OperatingSystemType, ArchitectureType)] = Set( + Environment.OperatingSystemType.Windows -> Environment.ArchitectureType.X86, + Environment.OperatingSystemType.Windows -> Environment.ArchitectureType.ARMv8, + Environment.OperatingSystemType.Linux -> Environment.ArchitectureType.X86, + Environment.OperatingSystemType.Linux -> Environment.ArchitectureType.ARMv8, + Environment.OperatingSystemType.Mac -> Environment.ArchitectureType.X86, + Environment.OperatingSystemType.Mac -> Environment.ArchitectureType.ARMv8 + ) + + /** Determines the name of the executable to run, based on the host system. Usually, AST GEN + * binaries support three operating systems, and two architectures. 
Some binaries are + * multiplatform, in which case the suffix for x86 is used for both architectures. + */ + protected def executableName(implicit metaData: AstGenProgramMetaData): String = + if !SupportedBinaries.contains(Environment.operatingSystem -> Environment.architecture) then + throw new UnsupportedOperationException( + s"No compatible binary of ${metaData.name} for your operating system!" + ) + else + Environment.operatingSystem match + case Environment.OperatingSystemType.Windows => executableName(WinX86, WinArm) + case Environment.OperatingSystemType.Linux => executableName(LinuxX86, LinuxArm) + case Environment.OperatingSystemType.Mac => executableName(MacX86, MacArm) + case Environment.OperatingSystemType.Unknown => + executableName(LinuxX86, LinuxArm) + + protected def executableName(x86Suffix: String, armSuffix: String)(implicit + metaData: AstGenProgramMetaData + ): String = + if metaData.multiArchitectureBuilds then + s"${metaData.name}-$x86Suffix" + else + Environment.architecture match + case Environment.ArchitectureType.X86 => s"${metaData.name}-$x86Suffix" + case Environment.ArchitectureType.ARMv8 => s"${metaData.name}-$armSuffix" + + protected def isIgnoredByUserConfig(filePath: String): Boolean = + lazy val isInIgnoredFiles = config.ignoredFiles.exists { + case ignorePath if File(ignorePath).isDirectory => filePath.startsWith(ignorePath) + case ignorePath => filePath == ignorePath + } + lazy val isInIgnoredFileRegex = config.ignoredFilesRegex.matches(filePath) + if isInIgnoredFiles || isInIgnoredFileRegex then + true + else + false + + protected def filterFiles(files: List[String], out: File): List[String] = + files.filter(fileFilter(_, out)) + + protected def fileFilter(file: String, out: File): Boolean = + file.stripSuffix(".json").replace(out.pathAsString, config.inputPath) match + case filePath if isIgnoredByUserConfig(filePath) => false + case _ => true + + protected def runAstGenNative(in: String, out: File, exclude: String, include: 
String)(implicit + metaData: AstGenProgramMetaData + ): AstGenRunnerResult + + protected def astGenCommand(implicit metaData: AstGenProgramMetaData): String = + val conf = ConfigFactory.load + val astGenVersion = conf.getString(s"${metaData.configPrefix}.${metaData.name}_version") + if hasCompatibleAstGenVersion(astGenVersion) then + metaData.name + else + s"$executableDir/$executableName" + + def execute(out: File): AstGenRunnerResult = + implicit val metaData: AstGenProgramMetaData = config.astGenMetaData + runAstGenNative(config.inputPath, out, config.ignoredFilesRegex.toString(), "") +end AstGenRunnerBase diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/astgen/package.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/astgen/package.scala new file mode 100644 index 00000000..232da51a --- /dev/null +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/astgen/package.scala @@ -0,0 +1,41 @@ +package io.appthreat.x2cpg + +import ujson.Value + +import scala.Option + +package object astgen: + + /** The base components of a JSON node. + */ + trait BaseNodeInfo[T]: + def node: T + def json: Value + def code: String + def lineNumber: Option[Integer] + def columnNumber: Option[Integer] + def lineNumberEnd: Option[Integer] + def columnNumberEnd: Option[Integer] + + /** The basic components of the results from parsing the JSON AST. + */ + trait BaseParserResult: + def filename: String + def fullPath: String + def json: Value + def fileContent: String + + /** The default parser result. A minimal implementation of BaseParserResult + * + * @param filename + * the relative filename + * @param fullPath + * the absolute file path + * @param json + * the deserialized JSON content. + * @param fileContent + * the raw file contents. 
+ */ + case class ParserResult(filename: String, fullPath: String, json: Value, var fileContent: String) + extends BaseParserResult +end astgen diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/ProgramSummary.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/ProgramSummary.scala new file mode 100644 index 00000000..707c2dcf --- /dev/null +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/ProgramSummary.scala @@ -0,0 +1,456 @@ +package io.appthreat.x2cpg.datastructures + +import io.shiftleft.codepropertygraph.generated.nodes.DeclarationNew + +import scala.annotation.targetName +import scala.collection.mutable +import scala.reflect.ClassTag + +/** A hierarchical data-structure that stores the result of types and their respective members. + * These types can be sourced from pre-parsing the application, or pre-computed stubs of common + * libraries. + * + * The utility of this object is in assisting resolving shorthand types during AST creation. + * + * @tparam T + * the type/class meta data class. + * @tparam M + * the function/method meta data class. + * @tparam F + * the field/property/member meta data class. + */ +trait ProgramSummary[T <: TypeLike[M, F], M <: MethodLike, F <: FieldLike]: + + /** A mapping between a namespace/directory and the containing types. + */ + protected val namespaceToType: mutable.Map[String, mutable.Set[T]] + + /** For the given namespace, returns the declared types. + */ + def typesUnderNamespace(namespace: String): Set[T] = + namespaceToType.getOrElse(namespace, Set.empty).toSet + + /** For a type, will search for the associated namespace. + */ + def namespaceFor(clazz: T): Option[String] = namespaceToType.find { case (_, v) => + v.contains(clazz) + }.map(_._1) + + /** @param typeName + * the type name or full name. Can be partially qualified. + * @return + * the set of matching types' meta data. 
+ */ + def matchingTypes(typeName: String): List[T] = + namespaceToType.values.flatten.filter(t => t.name.split('.').endsWith(typeName.split('.'))).toList + + /** Absorbs the given program summary information into this program summary. + * @param o + * the program summary to absorb. + * @return + * this program summary. + */ + def absorb(o: ProgramSummary[T, M, F]): ProgramSummary[T, M, F] = + ProgramSummary.merge(this.namespaceToType, o.namespaceToType) + this +end ProgramSummary + +object ProgramSummary: + + def merge[T <: TypeLike[M, F], M <: MethodLike, F <: FieldLike]( + a: mutable.Map[String, mutable.Set[T]], + b: mutable.Map[String, mutable.Set[T]] + ): mutable.Map[String, mutable.Set[T]] = + + def dedupTypesInPlace(m: mutable.Map[String, mutable.Set[T]]): Unit = + val newMap = m + .map { case (namespace, ts) => namespace -> ts.groupBy(_.name) } + .map { case (namespace, typMap) => + val dedupedTypes = mutable.Set.from( + typMap + .map { case (name, ts) => name -> ts.reduce((u, v) => (u + v).asInstanceOf[T]) } + .values + .toSet + ) + m.put(namespace, dedupedTypes) + namespace -> dedupedTypes + } + .toMap + assert(m.flatMap { case (name, ts) => ts.groupBy(_.name).map(_._2.size) }.forall(_ == 1)) + + // Handle duplicate types sharing the same namespace. This can be introduced from serialized type stubs. 
+ dedupTypesInPlace(a) + dedupTypesInPlace(b) + + b.foreach { case (namespace, bts) => + a.updateWith(namespace) { + case Some(ats: mutable.Set[T]) => + // Assert that we can simply reduce the grouped values to a simple key-value mapping for fast look-ups + assert(ats.groupBy(_.name).values.forall(_.sizeIs == 1)) + val atsMap = ats.groupBy(_.name).map { case (name, ts) => name -> ts.head } + + bts.foreach { bt => + atsMap.get(bt.name) match + case Some(at) => + ats.remove(at) + ats.add((at + bt).asInstanceOf[T]) + case None => + ats.add(bt) + } + Option(ats) + case None => b.get(namespace) + } + } + a + end merge +end ProgramSummary + +/** Extends the capability of the scope object to track types in scope and provide type resolution. + * + * @tparam M + * the method/function meta data class. + * @tparam F + * the field/object property meta data class. + * @tparam T + * the type/class meta data class. + */ +trait TypedScope[M <: MethodLike, F <: FieldLike, T <: TypeLike[M, F]](summary: ProgramSummary[ + T, + M, + F +]): + this: Scope[?, ?, TypedScopeElement] => + + /** Tracks the types that are visible to this scope. + */ + protected val typesInScope = mutable.Set.empty[T] + + /** Tracks the members visible to this scope. In languages like JavaScript or Python, where + * members can be directly imported and accessed without an explicit base, they are kept here. + */ + protected val membersInScope = mutable.Set.empty[MemberLike] + + /** Tracks any types or modules imported under alternative names to their type full names. + */ + protected val aliasedTypes = mutable.HashMap.empty[String, String] + + /** Given a type name or alias, attempts to resolve its full name using the types currently in + * scope. + * + * @param typeName + * the shorthand name. + * @return + * the type meta-data if found.
+ */ + def tryResolveTypeReference(typeName: String): Option[T] = + typesInScope + .collectFirst { + // Handle partially qualified names + case typ if typ.name.split("[.]").endsWith(typeName.split("[.]")) => typ + case typ if aliasedTypes.contains(typeName) && typ.name == aliasedTypes(typeName) => + typ + } + + protected def matchingM(callName: String)(implicit + tag: ClassTag[M] + ): PartialFunction[MemberLike, M] = { + case m: M if m.name == callName => m + } + + /** Given the type full name and call name, will attempt to find the matching entry. + * + * @param typeFullName + * the base type full name. If none, will refer to loosely imported member or functions. + * @param callName + * the call name. + * @param argTypes + * the observed argument types. Only relevant for languages that implement overloading. + * @return + * the method meta data if found. + */ + def tryResolveMethodInvocation( + callName: String, + argTypes: List[String] = Nil, + typeFullName: Option[String] = None + )( + implicit tag: ClassTag[M] + ): Option[M] = typeFullName match + case None => + // TODO: The typesInScope part is to imprecisely solve the unimplemented polymorphism limitation + membersInScope.collectFirst(matchingM(callName)).orElse { + typesInScope.flatMap(_.methods).collectFirst(matchingM(callName)) + } + case Some(tfn) => + tryResolveTypeReference(tfn).flatMap { t => + t.methods.find(m => m.name == callName) + } + + /** Given the type full name and field name, will attempt to find the matching entry. + * @param typeFullName + * the base type full name. If none, will refer to loosely imported member or functions. + * @param fieldName + * the field/object property/module variable name. + * @return + * the field/object property/module variable's meta data. 
+ */ + def tryResolveFieldAccess(fieldName: String, typeFullName: Option[String] = None): Option[F] = + typeFullName match + case None => membersInScope.collectFirst { + case f: FieldLike if f.name == fieldName => f.asInstanceOf[F] + } + case Some(tfn) => + tryResolveTypeReference(tfn).flatMap { t => + t.fields.find { f => f.name == fieldName } + } + + /** Appends known types imported into the scope. + * @param namespace + * the fully qualified imported namespace. + */ + def addImportedNamespace(namespace: String): Unit = + val knownTypesFromNamespace = summary.typesUnderNamespace(namespace) + typesInScope.addAll(knownTypesFromNamespace) + + /** Appends known types imported into the scope. + * @param typeOrModule + * the type name or full name. + */ + def addImportedTypeOrModule(typeOrModule: String): Unit = + val matchingTypes = summary.matchingTypes(typeOrModule) + typesInScope.addAll(matchingTypes) + + /** Appends known members to the scope. + * @param typeOrModule + * the type name or full name. + * @param memberNames + * the names of the members, or, if empty, imports all members from the type. + */ + def addImportedMember(typeOrModule: String, memberNames: String*): Unit = + val matchingTypes = summary.matchingTypes(typeOrModule) + val matchingMembers = matchingTypes.flatMap(t => t.fields ++ t.methods) + memberNames match + case Nil => membersInScope.addAll(matchingMembers) + case names => + val nameSet = names.toSet // Cast to set for O(1) membership query + val filteredMembers = matchingMembers.filter(member => nameSet.contains(member.name)) + membersInScope.addAll(filteredMembers) + + /** Given a method, will attempt to find the associated type with preference to the types in + * scope. + * @param m + * the method meta data. + * @return + * the type meta data, if found. + */ + def typeForMethod(m: M): Option[T] = + typesInScope.find(t => t.methods.contains(m)) +end TypedScope + +trait OverloadableScope[M <: OverloadableMethod]: + this: TypedScope[M, ?, ?] 
=> + override def tryResolveMethodInvocation( + callName: String, + argTypes: List[String], + typeFullName: Option[String] = None + )(implicit tag: ClassTag[M]): Option[M] = typeFullName match + case None => + // TODO: The typesInScope part is to imprecisely solve the unimplemented polymorphism limitation + membersInScope.collectFirst(matchingM(callName)).orElse { + typesInScope.flatMap(_.methods).collectFirst(matchingM(callName)) + } + case Some(tfn) => + val methodsWithEqualArgs = tryResolveTypeReference(tfn).flatMap { t => + // TODO: Investigate using `isOverloadedBy` here + Option( + t.methods.filter(m => + m.name == callName && m.parameterTypes.filterNot(_._1 == "this").size == argTypes.size + ) + ) + } + + methodsWithEqualArgs + .getOrElse(List.empty[M]) + .find(isOverloadedBy(_, argTypes)) + .orElse(methodsWithEqualArgs.getOrElse(List.empty[M]).headOption) + + /** Determines if, by observing the given argument types, that the method's signature is a + * plausible match to the observed arguments. + * + * The default implementation only considers that the same number of arguments are added and does + * not account for variadic arguments nor polymorphism. + * + * @param method + * the method meta data. + * @param argTypes + * the observed arguments from the call-site. + * @return + * true if the method could be overloaded by a call with these argument types. + */ + protected def isOverloadedBy(method: M, argTypes: List[String]): Boolean = + method.parameterTypes.size == argTypes.size +end OverloadableScope + +/** An implementation of combining the typed scoping structures to manage the available type + * information at namespace levels. + * + * @tparam M + * the method/function meta data class. + * @tparam F + * the field/object property meta data class. + * @tparam T + * the type/class meta data class. + * @param summary + * the program summary. 
+  */
+class DefaultTypedScope[M <: MethodLike, F <: FieldLike, T <: TypeLike[M, F]](
+  summary: ProgramSummary[T, M, F]
+) extends Scope[String, DeclarationNew, TypedScopeElement]
+    with TypedScope[M, F, T](summary):
+
+  /** Pushes a new scope; a namespace-like scope first adds the types under its namespace to scope.
+    */
+  override def pushNewScope(scopeNode: TypedScopeElement): Unit =
+    scopeNode match
+      case n: NamespaceLikeScope => typesInScope.addAll(summary.typesUnderNamespace(n.fullName))
+      case _ =>
+    super.pushNewScope(scopeNode)
+
+  /** Pops the scope; a namespace-like scope removes the types under its namespace from scope.
+    */
+  override def popScope(): Option[TypedScopeElement] =
+    super.popScope().map {
+      case n: NamespaceLikeScope =>
+        summary.typesUnderNamespace(n.fullName).foreach(typesInScope.remove)
+        n
+      case x => x
+    }
+end DefaultTypedScope
+
+/*
+   Traits related to scoping classes
+ */
+
+/** A scope element designed for the TypedScope.
+  */
+trait TypedScopeElement
+
+/** A namespace scope to synchronise types entering and exiting scopes.
+  */
+trait NamespaceLikeScope extends TypedScopeElement:
+
+  /** @return
+    *   the namespace full name.
+    */
+  def fullName: String
+
+/*
+   Traits related to meta-data classes
+ */
+
+/** A type declaration or module. Holds methods and field entities.
+  *
+  * @tparam M
+  *   the method/function meta data class.
+  * @tparam F
+  *   the field/object property meta data class.
+  */
+trait TypeLike[M <: MethodLike, F <: FieldLike]:
+
+  /** @return
+    *   the type full name.
+    */
+  def name: String
+
+  /** @return
+    *   the methods declared directly under the type declaration.
+    */
+  def methods: List[M]
+
+  /** @return
+    *   the fields/properties declared directly under the type declaration.
+    */
+  def fields: List[F]
+
+  /** Adds the contents of the two types and produces a new type.
+    * @param o
+    *   the other type-like.
+    * @return
+    *   a type-like that is the combination of the two, with precedence to colliding contents to
+    *   this type (LHS).
+    */
+  @targetName("add")
+  def +(o: TypeLike[M, F]): TypeLike[M, F]
+
+  /** Helper method for creating the sum of two type-likes' methods, while preferring this type's
+    * methods on collisions (a collision is a method of the same name).
+    * @param o
+    *   the other type-like.
+    * @return
+    *   the combination of the two type-likes' methods.
+    */
+  protected def mergeMethods(o: TypeLike[M, ?]): List[M] =
+    val methodNames = methods.map(_.name).toSet
+    methods ++ o.methods.filterNot(m => methodNames.contains(m.name))
+
+  /** Helper method for creating the sum of two type-likes' fields, while preferring this type's
+    * fields on collisions (a collision is a field of the same name).
+    *
+    * @param o
+    *   the other type-like.
+    * @return
+    *   the combination of the two type-likes' fields.
+    */
+  protected def mergeFields(o: TypeLike[?, F]): List[F] =
+    val fieldNames = fields.map(_.name).toSet
+    fields ++ o.fields.filterNot(f => fieldNames.contains(f.name))
+end TypeLike
+
+/** An entity that is a member of some type or module.
+  */
+trait MemberLike:
+
+  /** @return
+    *   the name of the member.
+    */
+  def name: String
+
+/** A member that behaves like a field/property/module variable.
+  */
+trait FieldLike extends MemberLike:
+
+  /** @return
+    *   the name of the field.
+    */
+  def name: String
+
+  /** @return
+    *   the type declared (not necessarily resolved)
+    */
+  def typeName: String
+
+/** A function or procedure.
+  */
+trait MethodLike extends MemberLike:
+
+  /** @return
+    *   the name of the method.
+    */
+  def name: String
+
+  /** Stores the return type name.
+    *
+    * @return
+    *   the return type name.
+    */
+  def returnType: String
+
+trait OverloadableMethod extends MethodLike:
+
+  /** Stores a tuple of the parameter name and type name.
+    *
+    * @return
+    *   the names and type names of the parameters.
+    */
+  def parameterTypes: List[(String, String)]
+
+trait StubbedType[M <: MethodLike, F <: FieldLike] extends TypeLike[M, F]
diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/frontendspecific/package.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/frontendspecific/package.scala
new file mode 100644
index 00000000..34a676a1
--- /dev/null
+++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/frontendspecific/package.scala
@@ -0,0 +1,16 @@
+package io.appthreat.x2cpg
+
+/** This package solely exists to extract some code from the frontends that is shared between joern
+  * and the frontends, e.g. for parsing commandline arguments and running postprocessing passes.
+  *
+  * If this code were to be in the frontend's subproject, joern would need to have
+  * classpath-dependencies on the frontends, inheriting all their transitive dependencies. As
+  * discussed in e.g. https://github.com/joernio/joern/issues/4625#issuecomment-2166427270 we want
+  * to avoid having classpath dependencies on the frontends and instead invoke the frontends as
+  * external processes (i.e. execute their start script). Otherwise we'll end in jar hell with
+  * various incompatible versions of many different dependencies, and complex issues with things
+  * like OSGI and JPMS.
+  */
+package object frontendspecific:
+  // Special string used to separate joern-parse opts from frontend-specific opts
+  val FrontendArgsDelimitor = "--frontend-args"
diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/frontendspecific/ruby2atom/Constants.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/frontendspecific/ruby2atom/Constants.scala
new file mode 100644
index 00000000..aec62b31
--- /dev/null
+++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/frontendspecific/ruby2atom/Constants.scala
@@ -0,0 +1,92 @@
+package io.appthreat.x2cpg.frontendspecific.ruby2atom
+
+object Constants:
+
+  val builtinPrefix = "__core"
+  val kernelPrefix = s"$builtinPrefix.Kernel"
+  val Initialize = "initialize"
+  // NOTE(review): the literal was an unterminated string in this copy of the patch (its
+  // angle-bracketed text was stripped); restored to the conventional "<main>" — confirm.
+  val Main = "<main>"
+
+  /* Source: https://ruby-doc.org/3.2.2/Kernel.html
+   *
+   * We comment-out methods that require an explicit "receiver" (target of member access.)
+   */
+  val kernelFunctions: Set[String] = Set(
+    "Array",
+    "Complex",
+    "Float",
+    "Hash",
+    "Integer",
+    "Rational",
+    "String",
+    "__callee__",
+    "__dir__",
+    "__method__",
+    "abort",
+    "at_exit",
+    "autoload",
+    "autoload?",
+    "binding",
+    "block_given?",
+    "callcc",
+    "caller",
+    "caller_locations",
+    "catch",
+    "chomp",
+    "chomp!",
+    "chop",
+    "chop!",
+    // "class",
+    // "clone",
+    "eval",
+    "exec",
+    "exit",
+    "exit!",
+    "fail",
+    "fork",
+    "format",
+    // "frozen?",
+    "gets",
+    "global_variables",
+    "gsub",
+    "gsub!",
+    "iterator?",
+    "lambda",
+    "load",
+    "local_variables",
+    "loop",
+    "open",
+    "p",
+    "print",
+    "printf",
+    "proc",
+    "putc",
+    "puts",
+    "raise",
+    "rand",
+    "readline",
+    "readlines",
+    "require",
+    "require_all",
+    "require_relative",
+    "select",
+    "set_trace_func",
+    "sleep",
+    "spawn",
+    "sprintf",
+    "srand",
+    "sub",
+    "sub!",
+    "syscall",
+    "system",
+    "tap",
+    "test",
+    // "then",
+    "throw",
+    "trace_var",
+    // "trap",
+    "untrace_var",
+    "warn"
+    // "yield_self",
+  )
+end Constants
diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/frontendspecific/ruby2atom/ImplicitRequirePass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/frontendspecific/ruby2atom/ImplicitRequirePass.scala
new file mode 100644
index 00000000..3c3ede36
--- /dev/null
+++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/frontendspecific/ruby2atom/ImplicitRequirePass.scala
@@ -0,0 +1,195 @@
+package io.appthreat.x2cpg.frontendspecific.ruby2atom
+
+import io.appthreat.x2cpg.Defines
+import io.appthreat.x2cpg.frontendspecific.ruby2atom.Constants.*
+import io.shiftleft.codepropertygraph.generated.nodes.*
+import io.shiftleft.codepropertygraph.generated.{Cpg, DispatchTypes, EdgeTypes, Operators}
+import io.shiftleft.passes.ForkJoinParallelCpgPass
+import
io.shiftleft.semanticcpg.language.*
+import io.shiftleft.semanticcpg.language.operatorextension.OpNodes
+import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.{Assignment, FieldAccess}
+
+import java.util.regex.Pattern
+import scala.annotation.tailrec
+import scala.collection.mutable
+
+/** A tuple holding the (name, importPath) for types in the analysis.
+  */
+case class TypeImportInfo(name: String, importPath: String)
+
+/** In some Ruby frameworks, it is common to have an autoloader library that implicitly loads
+  * requirements onto the stack. This pass makes these imports explicit. The most popular one is
+  * Zeitwerk which we check in `Gemfile.lock` to enable this pass.
+  *
+  * @param externalTypes
+  *   a list of additional types to consider that may be importable but are not in the CPG.
+  */
+class ImplicitRequirePass(cpg: Cpg, externalTypes: Seq[TypeImportInfo] = Nil)
+    extends ForkJoinParallelCpgPass[Method](cpg):
+
+  /** A tuple holding information about the type import info, additionally with a boolean indicating
+    * if it is external or not.
+    */
+  private case class TypeImportInfoWithProvidence(info: TypeImportInfo, isExternal: Boolean)
+  private val typeNameToImportInfo = mutable.Map.empty[String, Seq[TypeImportInfoWithProvidence]]
+
+  private val Require: String = "require"
+  private val Self: String = "self"
+  private val Initialize: String = "initialize"
+  // NOTE(review): this literal was empty in this copy of the patch (angle-bracketed text was
+  // stripped). An empty suffix makes `endsWith(Clazz)` true for every string, so `runOnPart`
+  // would discard all inherited types. Restored to "<class>" — confirm against the frontend.
+  private val Clazz: String = "<class>"
+
+  /** Builds the lookup from type name to the import candidates that can provide it. */
+  override def init(): Unit =
+    val importableTypeInfo = cpg.typeDecl
+      .isExternal(false)
+      .filter { typeDecl =>
+        // zeitwerk will match types that share the name of the path.
+        // This match is insensitive to camel case, i.e, foo_bar will match type FooBar.
+        val fileName = typeDecl.filename.split(Array('/', '\\')).last
+        val typeName = typeDecl.name
+        ImplicitRequirePass.isAutoloadable(typeName, fileName)
+      }
+      .map { typeDecl =>
+        val typeImportInfo =
+          TypeImportInfo(typeDecl.name, ImplicitRequirePass.normalizePath(typeDecl.filename))
+        TypeImportInfoWithProvidence(typeImportInfo, typeDecl.isExternal)
+      }
+      .l
+    // Group internal and external candidates together by symbol for quicker retrieval later.
+    // They are merged before grouping so that an external type sharing its name with an internal
+    // one is added to the candidate list instead of replacing it; `runOnPart` sorts candidates
+    // by `isExternal` and therefore picks the internal entry first.
+    val externalTypeInfo = externalTypes.map(TypeImportInfoWithProvidence(_, true))
+    typeNameToImportInfo.addAll((importableTypeInfo ++ externalTypeInfo).groupBy(_.info.name))
+  end init
+
+  /** Returns the first segment of a `::`- or `.`-separated access path. */
+  private def getFieldBaseFromString(fieldAccessString: String): String =
+    val normalizedFieldAccessString = fieldAccessString.replace("::", ".")
+    normalizedFieldAccessString.split('.').headOption.getOrElse(normalizedFieldAccessString)
+
+  /** Selects the methods that have no direct `require` call among their AST children. */
+  override def generateParts(): Array[Method] =
+    cpg.method.whereNot(_.astChildren.isCall.nameExact(Require)).toArray
+
+  /** Collects methods within a module.
+    */
+  private def findMethodsViaAstChildren(module: Method): Iterator[Method] =
+    // TODO For now we have to go via the full name regex because the AST is not yet linked
+    // at the execution time of this pass.
+    // Iterator(module) ++ module.astChildren.flatMap {
+    //   case x: TypeDecl => x.method.flatMap(findMethodsViaAstChildren)
+    //   case x: Method => Iterator(x) ++ x.astChildren.collectAll[Method].flatMap(findMethodsViaAstChildren)
+    //   case _ => Iterator.empty
+    // }
+    cpg.method.fullName(Pattern.quote(module.fullName) + ".*")
+
+  /** Adds a `require` call to the module's block for every symbol that the module appears to
+    * reference and that is provided by a file other than the module's own.
+    */
+  override def runOnPart(builder: DiffGraphBuilder, moduleMethod: Method): Unit =
+    val possiblyImportedSymbols = mutable.ArrayBuffer.empty[String]
+    val currPath = ImplicitRequirePass.normalizePath(moduleMethod.filename)
+
+    // Parent types of the type declarations whose full name starts with the module's full name
+    val typeDecl = cpg.typeDecl.fullName(Pattern.quote(moduleMethod.fullName) + ".*").l
+    typeDecl.inheritsFromTypeFullName
+      .filterNot(_.endsWith(Clazz))
+      .map(getFieldBaseFromString)
+      .foreach(possiblyImportedSymbols.append)
+
+    val methodsOfModule = findMethodsViaAstChildren(moduleMethod).toList
+    val callsOfModule = methodsOfModule.ast.isCall.toList
+
+    // Receivers of `initialize` calls and bases of field accesses name candidate symbols
+    val symbolsGatheredFromCalls = callsOfModule
+      .flatMap {
+        case x if x.name == Initialize =>
+          x.receiver.headOption.flatMap {
+            case x: TypeRef => Option(getFieldBaseFromString(x.code))
+            case x: Identifier => Option(x.name)
+            case x: Call if x.name == Operators.fieldAccess =>
+              Option(fieldAccessBase(x.asInstanceOf[FieldAccess]))
+            case _ => None
+          }.iterator
+        case x if x.methodFullName == Operators.fieldAccess =>
+          fieldAccessBase(x.asInstanceOf[FieldAccess]) :: Nil
+        case _ =>
+          Iterator.empty
+      }
+      .filterNot(_.isBlank)
+
+    possiblyImportedSymbols.appendAll(symbolsGatheredFromCalls)
+
+    // New calls are appended after the block's existing AST children
+    var currOrder = moduleMethod.block.astChildren.size
+    possiblyImportedSymbols.distinct
+      .flatMap { identifierName =>
+        typeNameToImportInfo
+          .getOrElse(identifierName, Seq.empty)
+          .sortBy { case TypeImportInfoWithProvidence(_, isExternal) =>
+            isExternal // sorting booleans puts false (internal) first
+          }
+          .collectFirst {
+            // ignore an import to a file that defines the type
+            case TypeImportInfoWithProvidence(TypeImportInfo(_, importPath), _)
+                if importPath != currPath =>
+              importPath
+          }
+      }
+      .distinct
+      .foreach { importPath =>
+        val requireCall = createRequireCall(builder, importPath)
+        requireCall.order(currOrder)
+        builder.addEdge(moduleMethod.block, requireCall, EdgeTypes.AST)
+        currOrder += 1
+      }
+  end runOnPart
+
+  /** Adds a statically dispatched `require '<path>'` call node with its literal argument to the
+    * diff graph and returns the call node.
+    */
+  private def createRequireCall(builder: DiffGraphBuilder, path: String): NewCall =
+    val requireCallNode = NewCall()
+      .name(Require)
+      .code(s"$Require '$path'")
+      .methodFullName(s"$kernelPrefix.$Require")
+      .dispatchType(DispatchTypes.STATIC_DISPATCH)
+      .typeFullName(Defines.Any)
+    builder.addNode(requireCallNode)
+    // Create literal argument
+    val pathLiteralNode =
+      NewLiteral().code(s"'$path'").typeFullName(s"$kernelPrefix.String").argumentIndex(1).order(
+        2
+      )
+    builder.addEdge(requireCallNode, pathLiteralNode, EdgeTypes.AST)
+    builder.addEdge(requireCallNode, pathLiteralNode, EdgeTypes.ARGUMENT)
+    requireCallNode
+
+  /** Returns the first recovered part of the access, or the code of its first argument. */
+  private def fieldAccessBase(fa: FieldAccess): String =
+    fieldAccessParts(fa).headOption.getOrElse(fa.argument(1).code)
+
+  @tailrec
+  private def fieldAccessParts(fa: FieldAccess): Seq[String] =
+    fa.argument(1) match
+      case subFa: Call if subFa.name == Operators.fieldAccess =>
+        fieldAccessParts(subFa.asInstanceOf[FieldAccess])
+      case self: Identifier if self.name == Self => fa.fieldIdentifier.map(_.canonicalName).toSeq
+      case assignCall: Call if assignCall.name == Operators.assignment =>
+        val assign = assignCall.asInstanceOf[Assignment]
+        // Handle the tmp var assign of qualified names
+        (assign.target, assign.source) match
+          // NOTE(review): this guard was truncated in this copy of the patch (everything from
+          // the opening quote up to the case arrow was lost). Reconstructed as "temporary
+          // variable assigned from a field access", which is what the cast below requires —
+          // confirm the "<tmp-" prefix against the frontend's temp-variable naming.
+          case (lhs: Identifier, rhs: Call)
+              if lhs.name.startsWith("<tmp-") && rhs.name == Operators.fieldAccess =>
+            fieldAccessParts(rhs.asInstanceOf[FieldAccess])
+          case _ => Seq.empty
+      case _ => Seq.empty
+end ImplicitRequirePass
+
+object ImplicitRequirePass:
+
+  /** Determines if the given type name and its corresponding parent file name allow for the type to
+    * be autoloaded by zeitwerk.
+    * @return
+    *   true if the type is autoloadable from the given filename.
+    */
+  def isAutoloadable(typeName: String, fileName: String): Boolean =
+    // We use lowercase as something like `openssl` and `OpenSSL` don't give obvious clues where capitalisation occurs
+    val strippedFileName = normalizePath(fileName).toLowerCase
+    val lowerCaseTypeName = typeName.toLowerCase
+    // A snake_case file name must match its CamelCase constant (`foo_bar` <-> `FooBar`), so the
+    // underscores are dropped before comparing. The hyphenated variant is kept only so that
+    // every name accepted before this change is still accepted.
+    lowerCaseTypeName == strippedFileName ||
+    lowerCaseTypeName == strippedFileName.replace("_", "") ||
+    lowerCaseTypeName == strippedFileName.replace("_", "-")
+
+  private def normalizePath(path: String): String = path.replace("\\", "/").stripSuffix(".rb")
diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/frontendspecific/ruby2atom/ImportsPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/frontendspecific/ruby2atom/ImportsPass.scala
new file mode 100644
index 00000000..a637361f
--- /dev/null
+++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/frontendspecific/ruby2atom/ImportsPass.scala
@@ -0,0 +1,24 @@
+package io.appthreat.x2cpg.frontendspecific.ruby2atom
+
+import io.appthreat.x2cpg.Imports.createImportNodeAndLink
+import io.appthreat.x2cpg.X2Cpg.stripQuotes
+import io.shiftleft.codepropertygraph.generated.Cpg
+import io.shiftleft.codepropertygraph.generated.nodes.*
+import io.shiftleft.passes.ForkJoinParallelCpgPass
+import io.shiftleft.semanticcpg.language.*
+
+class ImportsPass(cpg: Cpg) extends ForkJoinParallelCpgPass[Call](cpg):
+
+  override def generateParts(): Array[Call] =
+    cpg.call.nameExact(ImportsPass.ImportCallNames.toSeq*).isStatic.toArray
+
+  override def runOnPart(diffGraph: DiffGraphBuilder, call: Call): Unit =
+    val importedEntity = stripQuotes(call.argument.isLiteral.code.l match
+      case s :: _ => s
+      case _ => ""
+    )
+    val importNode = createImportNodeAndLink(importedEntity, importedEntity, Some(call), diffGraph)
+    if call.name == "require_all" then importNode.isWildcard(true)
+
+object ImportsPass:
+  val ImportCallNames: Set[String] = Set("require", "load", "require_relative", "require_all")
diff --git
a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/typestub/TypeStubConfig.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/typestub/TypeStubConfig.scala
new file mode 100644
index 00000000..51e0282f
--- /dev/null
+++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/typestub/TypeStubConfig.scala
@@ -0,0 +1,46 @@
+package io.appthreat.x2cpg.typestub
+
+import io.appthreat.x2cpg.X2CpgConfig
+import scopt.OParser
+
+import java.net.URL
+
+/** Extends the config to download type stubs to help resolve type full names.
+  */
+trait TypeStubConfig[R <: X2CpgConfig[R]]:
+  this: R =>
+
+  /** Whether to use type stubs to help resolve type information or not. Using type stubs may
+    * increase memory consumption.
+    */
+  def useTypeStubs: Boolean
+
+  /** The entrypoint to load the type stubs into the config.
+    */
+  def withTypeStubs(value: Boolean): R
+
+  /** Creates a meta-data class of information about the type stub management.
+    */
+  def typeStubMetaData: TypeStubMetaData =
+    TypeStubMetaData(useTypeStubs, getClass.getProtectionDomain.getCodeSource.getLocation)
+
+/** The meta data around managing type stub resources for this frontend.
+  * @param useTypeStubs
+  *   a flag to indicate whether types stubs should be used.
+  * @param packagePath
+  *   the code path for the frontend.
+  */
+case class TypeStubMetaData(useTypeStubs: Boolean, packagePath: URL)
+
+object TypeStubConfig:
+
+  def parserOptions[R <: X2CpgConfig[R] & TypeStubConfig[R]]: OParser[?, R] =
+    val builder = OParser.builder[R]
+    import builder.*
+    OParser.sequence(
+      opt[Unit]("disable-type-stubs")
+        .text(
+          "Disables the use of type stubs for type information recovery. Using type stubs may increase memory consumption."
+        )
+        .action((_, c) => c.withTypeStubs(false))
+    )
diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/typestub/TypeStubUtil.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/typestub/TypeStubUtil.scala
new file mode 100644
index 00000000..6c6688f1
--- /dev/null
+++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/typestub/TypeStubUtil.scala
@@ -0,0 +1,27 @@
+package io.appthreat.x2cpg.typestub
+
+import better.files.File
+
+import java.nio.file.Paths
+
+object TypeStubUtil:
+
+  /** Obtains the type stub dir for this frontend.
+    * @param metaData
+    *   meta data describing the loaded type stubs.
+    * @return
+    *   the directory where type stubs are.
+    */
+  def typeStubDir(implicit metaData: TypeStubMetaData): File =
+    // The base directory is whatever precedes the last "lib" (else the last "target") in the
+    // textual URL, with the leading "file:" dropped; "." is used when neither occurs.
+    // NOTE(review): the URL text is not decoded, so e.g. "%20" would stay in the path — confirm.
+    val dir = metaData.packagePath.toString
+    val indexOfLib = dir.lastIndexOf("lib")
+    val fixedDir = if indexOfLib != -1 then
+      new java.io.File(dir.substring("file:".length, indexOfLib)).toString
+    else
+      val indexOfTarget = dir.lastIndexOf("target")
+      if indexOfTarget != -1 then
+        new java.io.File(dir.substring("file:".length, indexOfTarget)).toString
+      else
+        "."
+    File(Paths.get(fixedDir, "/type_stubs").toAbsolutePath)
+end TypeStubUtil
diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ConcurrentTaskUtil.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ConcurrentTaskUtil.scala
new file mode 100644
index 00000000..bf042ef8
--- /dev/null
+++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ConcurrentTaskUtil.scala
@@ -0,0 +1,65 @@
+package io.appthreat.x2cpg.utils
+
+import java.util
+import java.util.concurrent.{Callable, Executors}
+import java.util.stream.{Collectors, StreamSupport}
+import java.util.{Collections, Spliterator, Spliterators}
+import scala.jdk.CollectionConverters.*
+import scala.util.Try
+
+/** A utility for providing out-of-the-box basic concurrent execution for a collection of Scala
+  * functions.
+  */
+object ConcurrentTaskUtil:
+
+  private val MAX_POOL_SIZE = Runtime.getRuntime.availableProcessors()
+
+  /** Runs the tasks on a fixed-size thread pool, so that at most `maxPoolSize` tasks execute at
+    * any given point. This is effective when tasks may require large amounts of memory, or single
+    * tasks are too short lived.
+    *
+    * @param tasks
+    *   the tasks to parallelize.
+    * @param maxPoolSize
+    *   the max pool size to allow for active threads.
+    * @tparam V
+    *   the output type of each task.
+    * @return
+    *   a list with the outcome of each task, in task order, as either a success or failure.
+    */
+  def runUsingThreadPool[V](
+    tasks: Iterator[() => V],
+    maxPoolSize: Int = MAX_POOL_SIZE
+  ): List[Try[V]] =
+    val pool = Executors.newFixedThreadPool(maxPoolSize)
+    try
+      // Drain the iterator up front and wrap every task as a Callable
+      val jobs = tasks.map[Callable[V]](task => () => task()).toList.asJava
+      // `invokeAll` returns once all jobs are done; `get` surfaces a failed job as a Failure
+      pool.invokeAll(jobs).asScala.map(future => Try(future.get())).toList
+    finally
+      pool.shutdown()
+
+  /** Runs the tasks through a parallel Java stream, where any number of threads may be alive at
+    * any point. This is useful for running a large number of tasks with low memory consumption.
+    * Parallel streams run on ForkJoinPool.commonPool() by default.
+    *
+    * @param tasks
+    *   the tasks to parallelize.
+    * @tparam V
+    *   the output type of each task.
+    * @return
+    *   a sequence of the executed tasks as either a success or failure.
+ */ + def runUsingSpliterator[V](tasks: Iterator[() => V]): Seq[Try[V]] = + scala.collection.immutable.ArraySeq + .ofRef( + java.util.Arrays + .stream(tasks.toArray) + .parallel() + .map(task => Try(task.apply())) + .toArray + ) + .asInstanceOf[scala.collection.immutable.ArraySeq[Try[V]]] +end ConcurrentTaskUtil diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Environment.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Environment.scala index eeaf456d..3dd39f60 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Environment.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Environment.scala @@ -1,7 +1,5 @@ package io.appthreat.x2cpg.utils -import org.slf4j.LoggerFactory - import java.nio.file.Paths object Environment: @@ -14,7 +12,7 @@ object Environment: object ArchitectureType extends Enumeration: type ArchitectureType = Value - val X86, ARM = Value + val X86, ARMv8 = Value lazy val operatingSystem: OperatingSystemType.OperatingSystemType = if scala.util.Properties.isMac then OperatingSystemType.Mac @@ -23,16 +21,13 @@ object Environment: else OperatingSystemType.Unknown lazy val architecture: ArchitectureType.ArchitectureType = - if scala.util.Properties.propOrNone("os.arch").contains("aarch64") then ArchitectureType.ARM + if scala.util.Properties.propOrNone("os.arch").contains("aarch64") then ArchitectureType.ARMv8 // We do not distinguish between x86 and x64. E.g, a 64 bit Windows will always lie about // this and will report x86 anyway for backwards compatibility with 32 bit software. 
else ArchitectureType.X86 - private val logger = LoggerFactory.getLogger(getClass) - def pathExists(path: String): Boolean = if !Paths.get(path).toFile.exists() then - logger.debug(s"Input path '$path' does not exist!") false else true diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ExternalCommand.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ExternalCommand.scala index 42b97526..11daf165 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ExternalCommand.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ExternalCommand.scala @@ -1,10 +1,14 @@ package io.appthreat.x2cpg.utils +import io.shiftleft.utils.IOUtils +import java.io.File +import java.nio.file.{Path, Paths} import java.util.concurrent.ConcurrentLinkedQueue import org.apache.commons.lang.StringUtils import scala.sys.process.{Process, ProcessLogger} import scala.util.{Failure, Success, Try} import scala.jdk.CollectionConverters.* +import scala.util.control.NonFatal object ExternalCommand: @@ -14,6 +18,21 @@ object ExternalCommand: private val shellPrefix: Seq[String] = if IS_WIN then "cmd" :: "/c" :: Nil else "sh" :: "-c" :: Nil + case class ExternalCommandResult(exitCode: Int, stdOut: Seq[String], stdErr: Seq[String]): + def successOption: Option[Seq[String]] = exitCode match + case 0 => Some(stdOut) + case _ => None + + def toTry: Try[Seq[String]] = exitCode match + case 0 => Success(stdOut) + case nonZeroExitCode => + val allOutput = stdOut ++ stdErr + val message = + s"""Process exited with code $nonZeroExitCode. 
Output: + |${allOutput.mkString(System.lineSeparator())} + |""".stripMargin + Failure(new RuntimeException(message)) + def run(command: String, cwd: String, separateStdErr: Boolean = false): Try[Seq[String]] = val stdOutOutput = new ConcurrentLinkedQueue[String] val stdErrOutput = @@ -52,4 +71,61 @@ object ExternalCommand: val allOutput = stdOutOutput.asScala ++ stdErrOutput.asScala Failure(new RuntimeException(allOutput.mkString(System.lineSeparator()))) end runMultiple + + def runWithResult( + command: Seq[String], + cwd: String, + mergeStdErrInStdOut: Boolean = false, + extraEnv: Map[String, String] = Map.empty + ): ExternalCommandResult = + val builder = new ProcessBuilder() + .command(command.toArray*) + .directory(new File(cwd)) + .redirectErrorStream(mergeStdErrInStdOut) + builder.environment().putAll(extraEnv.asJava) + + val stdOutFile = File.createTempFile("x2cpg", "stdout") + val stdErrFile = Option.when(!mergeStdErrInStdOut)(File.createTempFile("x2cpg", "stderr")) + + try + builder.redirectOutput(stdOutFile) + stdErrFile.foreach(f => builder.redirectError(f)) + + val process = builder.start() + val returnValue = process.waitFor() + + val stdOut = IOUtils.readLinesInFile(stdOutFile.toPath) + val stdErr = stdErrFile.map(f => IOUtils.readLinesInFile(f.toPath)).getOrElse(Seq.empty) + ExternalCommandResult(returnValue, stdOut, stdErr) + catch + case NonFatal(exception) => + ExternalCommandResult(1, Seq.empty, stdErr = Seq(exception.getMessage)) + finally + stdOutFile.delete() + stdErrFile.foreach(_.delete()) + end runWithResult + + /** Finds the absolute path to the executable directory (e.g. `/path/to/javasrc2cpg/bin`). Based + * on the package path of a loaded classfile based on some (potentially flakey?) filename + * heuristics. Context: we want to be able to invoke the x2cpg frontends from any directory, not + * just their install directory, and then invoke other executables, like astgen, php-parser et + * al. 
+ */ + def executableDir(packagePath: Path): Path = + val packagePathAbsolute = packagePath.toAbsolutePath + val fixedDir = + if packagePathAbsolute.toString.contains("lib") then + var dir = packagePathAbsolute + while dir.toString.contains("lib") do + dir = dir.getParent + dir + else if packagePathAbsolute.toString.contains("target") then + var dir = packagePathAbsolute + while dir.toString.contains("target") do + dir = dir.getParent + dir + else + Paths.get(".") + + fixedDir.resolve("bin/").toAbsolutePath end ExternalCommand diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/AstTests.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/AstTests.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/AstTests.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/AstTests.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/SourceFilesTests.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/SourceFilesTests.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/SourceFilesTests.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/SourceFilesTests.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/X2CpgTests.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/X2CpgTests.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/X2CpgTests.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/X2CpgTests.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/layers/DumpAstTests.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/layers/DumpAstTests.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/layers/DumpAstTests.scala rename to 
platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/layers/DumpAstTests.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/layers/DumpCdgTests.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/layers/DumpCdgTests.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/layers/DumpCdgTests.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/layers/DumpCdgTests.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/layers/DumpCfgTests.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/layers/DumpCfgTests.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/layers/DumpCfgTests.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/layers/DumpCfgTests.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/CfgDominatorFrontierTests.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/passes/CfgDominatorFrontierTests.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/CfgDominatorFrontierTests.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/passes/CfgDominatorFrontierTests.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/CfgDominatorPassTests.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/passes/CfgDominatorPassTests.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/CfgDominatorPassTests.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/passes/CfgDominatorPassTests.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/ContainsEdgePassTest.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/passes/ContainsEdgePassTest.scala similarity index 100% rename from 
platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/ContainsEdgePassTest.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/passes/ContainsEdgePassTest.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/MemberAccessLinkerTests.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/passes/MemberAccessLinkerTests.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/MemberAccessLinkerTests.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/passes/MemberAccessLinkerTests.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/MethodDecoratorPassTests.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/passes/MethodDecoratorPassTests.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/MethodDecoratorPassTests.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/passes/MethodDecoratorPassTests.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/NamespaceCreatorTests.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/passes/NamespaceCreatorTests.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/passes/NamespaceCreatorTests.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/passes/NamespaceCreatorTests.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/CfgTestFixture.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/testfixtures/CfgTestFixture.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/CfgTestFixture.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/testfixtures/CfgTestFixture.scala diff --git 
a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/Code2CpgFixture.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/testfixtures/Code2CpgFixture.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/Code2CpgFixture.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/testfixtures/Code2CpgFixture.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/DefaultTestCpg.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/testfixtures/DefaultTestCpg.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/DefaultTestCpg.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/testfixtures/DefaultTestCpg.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/EmptyGraphFixture.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/testfixtures/EmptyGraphFixture.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/EmptyGraphFixture.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/testfixtures/EmptyGraphFixture.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/LanguageFrontend.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/testfixtures/LanguageFrontend.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/LanguageFrontend.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/testfixtures/LanguageFrontend.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/TestCpg.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/testfixtures/TestCpg.scala similarity index 100% rename from 
platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/testfixtures/TestCpg.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/testfixtures/TestCpg.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/ExternalCommandTest.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/utils/ExternalCommandTest.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/ExternalCommandTest.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/utils/ExternalCommandTest.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/HashUtilsTest.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/utils/HashUtilsTest.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/HashUtilsTest.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/utils/HashUtilsTest.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/IgnoreInWindows.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/utils/IgnoreInWindows.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/IgnoreInWindows.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/utils/IgnoreInWindows.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/dependency/DependencyResolverTests.scala b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/utils/dependency/DependencyResolverTests.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/dependency/DependencyResolverTests.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/utils/dependency/DependencyResolverTests.scala diff --git a/platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/dependency/MavenCoordinatesTests.scala 
b/platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/utils/dependency/MavenCoordinatesTests.scala similarity index 100% rename from platform/frontends/x2cpg/src/test/scala/io/joern/x2cpg/utils/dependency/MavenCoordinatesTests.scala rename to platform/frontends/x2cpg/src/test/scala/io/appthreat/x2cpg/utils/dependency/MavenCoordinatesTests.scala diff --git a/platform/src/universal/schema-extender/build.sbt b/platform/src/universal/schema-extender/build.sbt index 33de8603..f976ecf0 100644 --- a/platform/src/universal/schema-extender/build.sbt +++ b/platform/src/universal/schema-extender/build.sbt @@ -1,6 +1,6 @@ name := "schema-extender" -ThisBuild / scalaVersion := "3.5.2" +ThisBuild / scalaVersion := "3.6.2" val cpgVersion = IO.read(file("cpg-version")) diff --git a/platform/src/universal/schema-extender/project/build.properties b/platform/src/universal/schema-extender/project/build.properties index 40b3b8e7..73df629a 100644 --- a/platform/src/universal/schema-extender/project/build.properties +++ b/platform/src/universal/schema-extender/project/build.properties @@ -1 +1 @@ -sbt.version=1.9.0 +sbt.version=1.10.7 diff --git a/project/DownloadHelper.scala b/project/DownloadHelper.scala new file mode 100644 index 00000000..9dba958c --- /dev/null +++ b/project/DownloadHelper.scala @@ -0,0 +1,48 @@ +import java.io.File +import java.net.URI +import java.nio.file.{Files, Path, Paths} + +object DownloadHelper { + val LocalStorageDir = Paths.get(".local/source-urls") + + /** Downloads the remote file from the given url if either + * - the localFile is not available, + * - or the url is different from the previously downloaded file + * - or we don't have the original url from the previously downloaded file + * We store the information about the previously downloaded urls and the localFile in `.local` + */ + def ensureIsAvailable(url: String, localFile: File): Unit = { + if (!localFile.exists() || Option(url) != previousUrlForLocalFile(localFile)) { + val localPath = 
localFile.toPath + Files.deleteIfExists(localPath) + + println(s"[INFO] downloading $url to $localFile") + sbt.io.Using.urlInputStream(new URI(url).toURL) { inputStream => + sbt.IO.transfer(inputStream, localFile) + } + + // persist url in local storage + val storageFile = storageInfoFileFor(localFile) + Files.createDirectories(storageFile.getParent) + Files.writeString(storageFile, url) + } + } + + private def relativePathToProjectRoot(path: Path): String = + Paths + .get("") + .toAbsolutePath + .normalize() + .relativize(path.toAbsolutePath) + .toString + + private def previousUrlForLocalFile(localFile: File): Option[String] = { + Option(storageInfoFileFor(localFile)) + .filter(Files.exists(_)) + .map(Files.readString) + .filter(_.nonEmpty) + } + + private def storageInfoFileFor(localFile: File): Path = + LocalStorageDir.resolve(relativePathToProjectRoot(localFile.toPath)) +} diff --git a/project/Projects.scala b/project/Projects.scala index 1a8d642c..391e3085 100644 --- a/project/Projects.scala +++ b/project/Projects.scala @@ -15,5 +15,6 @@ object Projects { lazy val jssrc2cpg = project.in(frontendsRoot / "jssrc2cpg") lazy val javasrc2cpg = project.in(frontendsRoot / "javasrc2cpg") lazy val jimple2cpg = project.in(frontendsRoot / "jimple2cpg") - lazy val php2atom = project.in(frontendsRoot / "php2atom") + lazy val php2atom = project.in(frontendsRoot / "php2atom") + lazy val ruby2atom = project.in(frontendsRoot / "ruby2atom") } diff --git a/project/Versions.scala b/project/Versions.scala index 09740fa9..4b8b602e 100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -3,13 +3,17 @@ object Versions { val cpg = parseVersion("cpgVersion") val antlr = "4.13.2" val scalatest = "3.2.19" - val cats = "3.5.5" + val cats = "3.5.7" val json4s = "4.0.7" val gradleTooling = "8.10.1" val circe = "0.14.10" val requests = "0.9.0" val upickle = "4.0.2" val scalaReplPP = "0.1.85" + val commonsCompress = "1.27.1" + val jRuby = "9.4.9.0" + val typeSafeConfig = "1.4.3" + 
val versionSort = "1.0.11" private def parseVersion(key: String): String = { val versionRegexp = s""".*val $key[ ]+=[ ]?"(.*?)"""".r diff --git a/project/build.properties b/project/build.properties index db1723b0..73df629a 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.10.5 +sbt.version=1.10.7 diff --git a/pyproject.toml b/pyproject.toml index 756969f5..a50506a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "appthreat-chen" -version = "2.2.3" +version = "2.3.0" description = "Code Hierarchy Exploration Net (chen)" authors = ["Team AppThreat "] license = "Apache-2.0" From 9523be2bac37fb83f2c3ecf8ddb555ca7fccecd8 Mon Sep 17 00:00:00 2001 From: Prabhu Subramanian Date: Sat, 4 Jan 2025 21:14:44 +0000 Subject: [PATCH 2/2] Install ruby 3.4.0 Signed-off-by: Prabhu Subramanian --- ci/Dockerfile | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ci/Dockerfile b/ci/Dockerfile index e814ae67..a6e90054 100644 --- a/ci/Dockerfile +++ b/ci/Dockerfile @@ -14,10 +14,12 @@ LABEL maintainer="appthreat" \ ARG JAVA_VERSION=23.0.1-tem ARG MAVEN_VERSION=3.9.9 ARG GRADLE_VERSION=8.11 +ARG RUBY_VERSION=3.4.0 ENV JAVA_VERSION=$JAVA_VERSION \ MAVEN_VERSION=$MAVEN_VERSION \ GRADLE_VERSION=$GRADLE_VERSION \ + RUBY_VERSION=$RUBY_VERSION \ GRADLE_OPTS="-Dorg.gradle.daemon=false" \ JAVA_HOME="/opt/java/${JAVA_VERSION}" \ MAVEN_HOME="/opt/maven/${MAVEN_VERSION}" \ @@ -61,19 +63,19 @@ RUN set -e; \ *) echo >&2 "error: unsupported architecture: '$ARCH_NAME'"; exit 1 ;; \ esac; \ echo -e "[nodejs]\nname=nodejs\nstream=20\nprofiles=\nstate=enabled\n" > /etc/dnf/modules.d/nodejs.module \ - && microdnf install --nodocs -y gcc git-core php php-cli php-curl php-zip php-bcmath php-json php-pear php-mbstring php-devel make wget bash graphviz graphviz-gd \ + && microdnf install -y gcc git-core php php-cli php-curl php-zip php-bcmath php-json php-pear php-mbstring php-devel make wget bash 
graphviz graphviz-gd \ openssl-devel libffi-devel readline-devel libyaml zlib-devel ncurses ncurses-devel rust \ pcre2 findutils which tar gzip zip unzip sudo nodejs sqlite-devel glibc-common glibc-all-langpacks \ - && microdnf install --nodocs -y epel-release \ - && microdnf install --nodocs --enablerepo=crb -y libyaml-devel jemalloc-devel \ + && microdnf install -y epel-release \ + && microdnf install --enablerepo=crb -y libyaml-devel jemalloc-devel \ && git clone https://github.com/rbenv/rbenv.git --depth=1 ~/.rbenv \ && echo 'export PATH="/root/.rbenv/bin:$PATH"' >> ~/.bashrc \ && echo 'eval "$(/root/.rbenv/bin/rbenv init - bash)"' >> ~/.bashrc \ && source ~/.bashrc \ && mkdir -p "$(rbenv root)/plugins" \ && git clone https://github.com/rbenv/ruby-build.git --depth=1 "$(rbenv root)/plugins/ruby-build" \ - && MAKE_OPTS=-j2 rbenv install 3.4.0 \ - && rbenv global 3.4.0 \ + && MAKE_OPTS=-j2 rbenv install ${RUBY_VERSION} \ + && rbenv global ${RUBY_VERSION} \ && ruby --version \ && which ruby \ && rm -rf /root/.rbenv/cache $RUBY_BUILD_BUILD_PATH \