diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 6c1a6064..b0a34988 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -6,256 +6,22 @@ on: push: branches: - master + - its-happening tags: - 'v[0-9]+.[0-9]+.[0-9]+' - 'v[0-9]+.[0-9]+.[0-9]+-**' concurrency: group: ${{ github.workflow }} + cancel-in-progress: false jobs: - test: - runs-on: ubuntu-latest - permissions: - contents: read - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version: '1.21' - - name: Test - uses: ./.github/actions/test - docker: - runs-on: ubuntu-latest - needs: [ test ] - permissions: - packages: write - contents: read - steps: - - uses: actions/checkout@v4 - - uses: docker/setup-qemu-action@v3 - - uses: docker/setup-buildx-action@v3 - - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.repository_owner }} - password: ${{ secrets.GITHUB_TOKEN }} - - uses: docker/metadata-action@v5 - name: generate tags - id: meta - with: - images: ghcr.io/${{ github.repository }} - tags: | - type=ref,event=branch - type=sha,prefix= - type=semver,pattern={{version}} - - uses: docker/build-push-action@v5 - with: - context: . - file: ./docker/Dockerfile - platforms: linux/amd64,linux/arm64 - push: true - tags: ${{ steps.meta.outputs.tags }} - cache-from: type=gha - cache-to: type=gha,mode=max - build-linux: - runs-on: ubuntu-latest - needs: [ test ] - strategy: - matrix: - go-arch: [amd64, arm64] - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version: '1.21' - - name: Setup - run: | - sudo apt update - go generate ./... - if [ ${{ matrix.go-arch }} == "arm64" ]; then - sudo apt install -y gcc-aarch64-linux-gnu - echo "CC=aarch64-linux-gnu-gcc" >> $GITHUB_ENV - fi - - name: Build ${{ matrix.go-arch }} - env: - CGO_ENABLED: 1 - GOOS: linux - GOARCH: ${{ matrix.go-arch }} - run: | - mkdir -p release - ZIP_OUTPUT=release/hostd_${GOOS}_${GOARCH}.zip - go build -tags='netgo' -trimpath -o bin/ -a -ldflags '-s -w -linkmode external -extldflags "-static"' ./cmd/hostd - cp README.md LICENSE bin/ - zip -qj $ZIP_OUTPUT bin/* - - uses: actions/upload-artifact@v4 - with: - name: hostd_linux_${{ matrix.go-arch }} - path: release/* - build-mac: - runs-on: macos-latest - needs: [ test ] - strategy: - matrix: - go-arch: [amd64, arm64] - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version: '1.21' - - name: Setup - env: - APPLE_CERT_ID: ${{ secrets.APPLE_CERT_ID }} - APPLE_API_KEY: ${{ secrets.APPLE_API_KEY }} - APPLE_API_ISSUER: ${{ secrets.APPLE_API_ISSUER }} - APPLE_KEY_B64: ${{ secrets.APPLE_KEY_B64 }} - APPLE_CERT_B64: ${{ secrets.APPLE_CERT_B64 }} - APPLE_CERT_PASSWORD: ${{ secrets.APPLE_CERT_PASSWORD }} - APPLE_KEYCHAIN_PASSWORD: ${{ secrets.APPLE_KEYCHAIN_PASSWORD }} - run: | - # extract apple cert - APPLE_CERT_PATH=$RUNNER_TEMP/apple_cert.p12 - KEYCHAIN_PATH=$RUNNER_TEMP/app-signing.keychain-db - echo -n "$APPLE_CERT_B64" | base64 --decode --output $APPLE_CERT_PATH - - # extract apple key - mkdir -p ~/private_keys - APPLE_API_KEY_PATH=~/private_keys/AuthKey_$APPLE_API_KEY.p8 - echo -n "$APPLE_KEY_B64" | base64 --decode --output $APPLE_API_KEY_PATH - - # create temp keychain - security create-keychain -p "$APPLE_KEYCHAIN_PASSWORD" $KEYCHAIN_PATH - security default-keychain -s $KEYCHAIN_PATH - security set-keychain-settings -lut 21600 $KEYCHAIN_PATH - security unlock-keychain -p "$APPLE_KEYCHAIN_PASSWORD" $KEYCHAIN_PATH - - # import keychain - security import $APPLE_CERT_PATH -P $APPLE_CERT_PASSWORD -A -t cert -f pkcs12 -k $KEYCHAIN_PATH - security find-identity -v $KEYCHAIN_PATH -p codesigning - security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k $APPLE_KEYCHAIN_PASSWORD $KEYCHAIN_PATH - - # generate - go generate ./... - - # resync system clock https://github.com/actions/runner/issues/2996#issuecomment-1833103110 - sudo sntp -sS time.windows.com - - name: Build ${{ matrix.go-arch }} - env: - APPLE_CERT_ID: ${{ secrets.APPLE_CERT_ID }} - APPLE_API_KEY: ${{ secrets.APPLE_API_KEY }} - APPLE_API_ISSUER: ${{ secrets.APPLE_API_ISSUER }} - APPLE_KEY_B64: ${{ secrets.APPLE_KEY_B64 }} - APPLE_CERT_B64: ${{ secrets.APPLE_CERT_B64 }} - APPLE_CERT_PASSWORD: ${{ secrets.APPLE_CERT_PASSWORD }} - APPLE_KEYCHAIN_PASSWORD: ${{ secrets.APPLE_KEYCHAIN_PASSWORD }} - CGO_ENABLED: 1 - GOOS: darwin - GOARCH: ${{ matrix.go-arch }} - run: | - ZIP_OUTPUT=release/hostd_${GOOS}_${GOARCH}.zip - mkdir -p release - go build -tags='netgo' -trimpath -o bin/ -a -ldflags '-s -w' ./cmd/hostd - cp README.md LICENSE bin/ - /usr/bin/codesign --deep -f -v --timestamp -o runtime,library -s $APPLE_CERT_ID bin/hostd - ditto -ck bin $ZIP_OUTPUT - xcrun notarytool submit -k ~/private_keys/AuthKey_$APPLE_API_KEY.p8 -d $APPLE_API_KEY -i $APPLE_API_ISSUER --wait --timeout 10m $ZIP_OUTPUT - - uses: actions/upload-artifact@v4 - with: - name: hostd_darwin_${{ matrix.go-arch }} - path: release/* - build-windows: - runs-on: windows-latest - needs: [ test ] - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version: '1.21' - - name: Setup - shell: bash - run: | - dotnet tool install --global AzureSignTool - go generate ./... - - name: Build amd64 - env: - CGO_ENABLED: 1 - GOOS: windows - GOARCH: amd64 - shell: bash - run: | - mkdir -p release - ZIP_OUTPUT=release/hostd_${GOOS}_${GOARCH}.zip - go build -tags='netgo' -trimpath -o bin/ -a -ldflags '-s -w -linkmode external -extldflags "-static"' ./cmd/hostd - azuresigntool sign -kvu "${{ secrets.AZURE_KEY_VAULT_URI }}" -kvi "${{ secrets.AZURE_CLIENT_ID }}" -kvt "${{ secrets.AZURE_TENANT_ID }}" -kvs "${{ secrets.AZURE_CLIENT_SECRET }}" -kvc ${{ secrets.AZURE_CERT_NAME }} -tr http://timestamp.digicert.com -v bin/hostd.exe - cp README.md LICENSE bin/ - 7z a $ZIP_OUTPUT ./bin/* - - uses: actions/upload-artifact@v4 - with: - name: hostd_windows_amd64 - path: release/* - combine-release-assets: - runs-on: ubuntu-latest - needs: [ build-linux, build-mac, build-windows ] - steps: - - name: Merge Artifacts - uses: actions/upload-artifact/merge@v4 - with: - name: hostd - - dispatch-homebrew: # only runs on full releases - if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-') - needs: [ build-mac ] - runs-on: ubuntu-latest - steps: - - name: Extract Tag Name - id: get_tag - run: echo "::set-output name=tag_name::${GITHUB_REF#refs/tags/}" - - - name: Dispatch - uses: peter-evans/repository-dispatch@v3 - with: - token: ${{ secrets.PAT_REPOSITORY_DISPATCH }} - repository: siafoundation/homebrew-sia - event-type: release-tagged - client-payload: > - { - "description": "hostd: The Next-Gen Sia Host", - "tag": "${{ steps.get_tag.outputs.tag_name }}", - "project": "hostd", - "workflow_id": "${{ github.run_id }}" - } - dispatch-linux: # always runs - needs: [ build-linux ] - runs-on: ubuntu-latest - steps: - - name: Build Dispatch Payload - id: get_payload - uses: actions/github-script@v7 - with: - script: | - const isRelease = context.ref.startsWith('refs/tags/v'), - isBeta = isRelease && context.ref.includes('-beta'), - tag = isRelease ? context.ref.replace('refs/tags/', '') : 'master'; - - let component = 'nightly'; - if (isBeta) { - component = 'beta'; - } else if (isRelease) { - component = 'main'; - } - - return { - description: "hostd: The Next-Gen Sia Host", - tag: tag, - project: "hostd", - workflow_id: context.runId, - component: component - }; - - - name: Dispatch - uses: peter-evans/repository-dispatch@v3 - with: - token: ${{ secrets.PAT_REPOSITORY_DISPATCH }} - repository: siafoundation/linux - event-type: release-tagged - client-payload: ${{ steps.get_payload.outputs.result }} \ No newline at end of file + publish: + uses: SiaFoundation/workflows/.github/workflows/go-publish.yml@master + secrets: inherit + with: + linux-build-args: -tags=timetzdata -trimpath -a -ldflags '-s -w -linkmode external -extldflags "-static"' + windows-build-args: -tags=timetzdata -trimpath -a -ldflags '-s -w -linkmode external -extldflags "-static"' + macos-build-args: -tags=timetzdata -trimpath -a -ldflags '-s -w' + cgo-enabled: 1 + project: hostd diff --git a/.github/workflows/publish_testnet.yml b/.github/workflows/publish_testnet.yml deleted file mode 100644 index 14d293e7..00000000 --- a/.github/workflows/publish_testnet.yml +++ /dev/null @@ -1,205 +0,0 @@ -name: Publish - Testnet - -# Controls when the action will run. -on: - # Triggers the workflow on new SemVer tags - push: - branches: - - master - tags: - - 'v[0-9]+.[0-9]+.[0-9]+' - - 'v[0-9]+.[0-9]+.[0-9]+-**' - -concurrency: - group: ${{ github.workflow }} - -jobs: - test: - runs-on: ubuntu-latest - permissions: - contents: read - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version: '1.21' - - name: Test - uses: ./.github/actions/test - docker: - runs-on: ubuntu-latest - needs: [ test ] - permissions: - packages: write - contents: read - steps: - - uses: actions/checkout@v4 - - uses: docker/setup-qemu-action@v3 - - uses: docker/setup-buildx-action@v3 - - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.repository_owner }} - password: ${{ secrets.GITHUB_TOKEN }} - - uses: docker/metadata-action@v5 - name: generate tags - id: meta - with: - images: ghcr.io/${{ github.repository }} - flavor: | - suffix=-testnet,onlatest=true - tags: | - type=ref,event=branch - type=sha,prefix= - type=semver,pattern={{version}} - - uses: docker/build-push-action@v5 - with: - context: . - file: ./docker/Dockerfile.testnet - platforms: linux/amd64,linux/arm64 - push: true - tags: ${{ steps.meta.outputs.tags }} - cache-from: type=gha - cache-to: type=gha,mode=max - build-linux: - runs-on: ubuntu-latest - needs: [ test ] - strategy: - matrix: - go-arch: [amd64, arm64] - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version: '1.21' - - name: Setup - run: | - sudo apt update - go generate ./... - if [ ${{ matrix.go-arch }} == "arm64" ]; then - sudo apt install -y gcc-aarch64-linux-gnu - echo "CC=aarch64-linux-gnu-gcc" >> $GITHUB_ENV - fi - - name: Build ${{ matrix.go-arch }} - env: - CGO_ENABLED: 1 - GOOS: linux - GOARCH: ${{ matrix.go-arch }} - run: | - mkdir -p release - ZIP_OUTPUT=release/hostd_zen_${GOOS}_${GOARCH}.zip - go build -tags='testnet netgo' -trimpath -o bin/ -a -ldflags '-s -w -linkmode external -extldflags "-static"' ./cmd/hostd - cp README.md LICENSE bin/ - zip -qj $ZIP_OUTPUT bin/* - - uses: actions/upload-artifact@v4 - with: - name: hostd_zen_linux_${{ matrix.go-arch }} - path: release/* - build-mac: - runs-on: macos-latest - needs: [ test ] - strategy: - matrix: - go-arch: [amd64, arm64] - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version: '1.21' - - name: Setup - env: - APPLE_CERT_ID: ${{ secrets.APPLE_CERT_ID }} - APPLE_API_KEY: ${{ secrets.APPLE_API_KEY }} - APPLE_API_ISSUER: ${{ secrets.APPLE_API_ISSUER }} - APPLE_KEY_B64: ${{ secrets.APPLE_KEY_B64 }} - APPLE_CERT_B64: ${{ secrets.APPLE_CERT_B64 }} - APPLE_CERT_PASSWORD: ${{ secrets.APPLE_CERT_PASSWORD }} - APPLE_KEYCHAIN_PASSWORD: ${{ secrets.APPLE_KEYCHAIN_PASSWORD }} - run: | - # extract apple cert - APPLE_CERT_PATH=$RUNNER_TEMP/apple_cert.p12 - KEYCHAIN_PATH=$RUNNER_TEMP/app-signing.keychain-db - echo -n "$APPLE_CERT_B64" | base64 --decode --output $APPLE_CERT_PATH - - # extract apple key - mkdir -p ~/private_keys - APPLE_API_KEY_PATH=~/private_keys/AuthKey_$APPLE_API_KEY.p8 - echo -n "$APPLE_KEY_B64" | base64 --decode --output $APPLE_API_KEY_PATH - - # create temp keychain - security create-keychain -p "$APPLE_KEYCHAIN_PASSWORD" $KEYCHAIN_PATH - security default-keychain -s $KEYCHAIN_PATH - security set-keychain-settings -lut 21600 $KEYCHAIN_PATH - security unlock-keychain -p "$APPLE_KEYCHAIN_PASSWORD" $KEYCHAIN_PATH - - # import keychain - security import $APPLE_CERT_PATH -P $APPLE_CERT_PASSWORD -A -t cert -f pkcs12 -k $KEYCHAIN_PATH - security find-identity -v $KEYCHAIN_PATH -p codesigning - security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k $APPLE_KEYCHAIN_PASSWORD $KEYCHAIN_PATH - - # generate - go generate ./... - - # resync system clock https://github.com/actions/runner/issues/2996#issuecomment-1833103110 - sudo sntp -sS time.windows.com - - name: Build ${{ matrix.go-arch }} - env: - APPLE_CERT_ID: ${{ secrets.APPLE_CERT_ID }} - APPLE_API_KEY: ${{ secrets.APPLE_API_KEY }} - APPLE_API_ISSUER: ${{ secrets.APPLE_API_ISSUER }} - APPLE_KEY_B64: ${{ secrets.APPLE_KEY_B64 }} - APPLE_CERT_B64: ${{ secrets.APPLE_CERT_B64 }} - APPLE_CERT_PASSWORD: ${{ secrets.APPLE_CERT_PASSWORD }} - APPLE_KEYCHAIN_PASSWORD: ${{ secrets.APPLE_KEYCHAIN_PASSWORD }} - CGO_ENABLED: 1 - GOOS: darwin - GOARCH: ${{ matrix.go-arch }} - run: | - ZIP_OUTPUT=release/hostd_zen_${GOOS}_${GOARCH}.zip - mkdir -p release - go build -tags='testnet netgo' -trimpath -o bin/ -a -ldflags '-s -w' ./cmd/hostd - cp README.md LICENSE bin/ - /usr/bin/codesign --deep -f -v --timestamp -o runtime,library -s $APPLE_CERT_ID bin/hostd - ditto -ck bin $ZIP_OUTPUT - xcrun notarytool submit -k ~/private_keys/AuthKey_$APPLE_API_KEY.p8 -d $APPLE_API_KEY -i $APPLE_API_ISSUER --wait --timeout 10m $ZIP_OUTPUT - - uses: actions/upload-artifact@v4 - with: - name: hostd_zen_darwin_${{ matrix.go-arch }} - path: release/* - build-windows: - runs-on: windows-latest - needs: [ test ] - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version: '1.21' - - name: Setup - shell: bash - run: | - dotnet tool install --global AzureSignTool - go generate ./... - - name: Build amd64 - env: - CGO_ENABLED: 1 - GOOS: windows - GOARCH: amd64 - shell: bash - run: | - mkdir -p release - ZIP_OUTPUT=release/hostd_zen_${GOOS}_${GOARCH}.zip - go build -tags='testnet netgo' -trimpath -o bin/ -a -ldflags '-s -w -linkmode external -extldflags "-static"' ./cmd/hostd - azuresigntool sign -kvu "${{ secrets.AZURE_KEY_VAULT_URI }}" -kvi "${{ secrets.AZURE_CLIENT_ID }}" -kvt "${{ secrets.AZURE_TENANT_ID }}" -kvs "${{ secrets.AZURE_CLIENT_SECRET }}" -kvc ${{ secrets.AZURE_CERT_NAME }} -tr http://timestamp.digicert.com -v bin/hostd.exe - cp README.md LICENSE bin/ - 7z a $ZIP_OUTPUT ./bin/* - - uses: actions/upload-artifact@v4 - with: - name: hostd_zen_windows_amd64 - path: release/* - combine-release-assets: - runs-on: ubuntu-latest - needs: [ build-linux, build-mac, build-windows ] - steps: - - name: Merge Artifacts - uses: actions/upload-artifact/merge@v4 - with: - name: hostd \ No newline at end of file diff --git a/.github/workflows/publish_v2.yml b/.github/workflows/publish_v2.yml deleted file mode 100644 index 104d08a2..00000000 --- a/.github/workflows/publish_v2.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Publish - -# Controls when the action will run. -on: - # Triggers the workflow on new SemVer tags - push: - branches: - - its-happening - -concurrency: - group: ${{ github.workflow }} - cancel-in-progress: false - -jobs: - publish: - uses: SiaFoundation/workflows/.github/workflows/go-publish.yml@master - secrets: inherit - with: - linux-build-args: -tags=timetzdata -trimpath -a -ldflags '-s -w -linkmode external -extldflags "-static"' - windows-build-args: -tags=timetzdata -trimpath -a -ldflags '-s -w -linkmode external -extldflags "-static"' - macos-build-args: -tags=timetzdata -trimpath -a -ldflags '-s -w' - cgo-enabled: 1 - project: hostd diff --git a/docker/Dockerfile b/Dockerfile similarity index 87% rename from docker/Dockerfile rename to Dockerfile index f789296d..29a84790 100644 --- a/docker/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM docker.io/library/golang:1.21 AS builder +FROM docker.io/library/golang:1.23 AS builder WORKDIR /hostd @@ -13,7 +13,8 @@ RUN go generate ./... # build RUN CGO_ENABLED=1 go build -o bin/ -tags='netgo timetzdata' -trimpath -a -ldflags '-s -w -linkmode external -extldflags "-static"' ./cmd/hostd -FROM docker.io/library/alpine:3 +FROM scratch + LABEL maintainer="The Sia Foundation " \ org.opencontainers.image.description.vendor="The Sia Foundation" \ org.opencontainers.image.description="A hostd container - provide storage on the Sia network and earn Siacoin" \ @@ -24,11 +25,12 @@ ENV PUID=0 ENV PGID=0 ENV HOSTD_API_PASSWORD= -ENV HOSTD_SEED= +ENV HOSTD_WALLET_SEED= ENV HOSTD_CONFIG_FILE=/data/hostd.yml # copy binary and prepare data dir. COPY --from=builder /hostd/bin/* /usr/bin/ +COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ VOLUME [ "/data" ] # API port @@ -39,8 +41,6 @@ EXPOSE 9981/tcp EXPOSE 9982/tcp # RHP3 TCP port EXPOSE 9983/tcp -# RHP3 WebSocket port -EXPOSE 9984/tcp USER ${PUID}:${PGID} diff --git a/README.md b/README.md index 57dc1399..c1cb9dcd 100644 --- a/README.md +++ b/README.md @@ -30,14 +30,6 @@ ports, can be configured via CLI flags. To simplify more complex configurations, + `9982` - RHP2 + `9983` - RHP3 -#### Testnet -The Zen testnet version of `hostd` changes the default ports: - -+ `9880` - UI and API -+ `9881` - Sia consensus -+ `9882` - RHP2 -+ `9883` - RHP3 - ### Environment Variables + `HOSTD_API_PASSWORD` - The password for the UI and API + `HOSTD_SEED` - The recovery phrase for the wallet @@ -46,36 +38,31 @@ The Zen testnet version of `hostd` changes the default ports: + `HOSTD_CONFIG_FILE` - changes the path of the optional config file. If unset, `hostd` will check for a config file in the current directory -#### Testnet -The Zen testnet version of `hostd` changes the environment variables: - -+ `HOSTD_ZEN_SEED` - The recovery phrase for the wallet -+ `HOSTD_ZEN_API_PASSWORD` - The password for the UI and API -+ `HOSTD_ZEN_LOG_PATH` - changes the path of the log file `hostd.log`. If unset, the - log file will be created in the data directory ### CLI Flags ```sh --bootstrap - bootstrap the gateway and consensus modules -dir string - directory to store hostd metadata (default ".") --env - disable stdin prompts for environment variables (default false) + directory to store hostd metadata (default "/Users/n8maninger/Downloads/hostd-core-tmp") -http string - address to serve API on (default ":9980") + address to serve API on (default "localhost:9980") -log.level string - log level (debug, info, warn, error) (default "info") + log level (debug, info, warn, error) (default "debug") -name string a friendly name for the host, only used for display --rpc string +-network string + network name (mainnet, testnet, etc) (default "mainnet") +-openui + automatically open the web UI on startup (default true) +-syncer.address string address to listen on for peer connections (default ":9981") +-syncer.bootstrap + bootstrap the gateway and consensus modules (default true) -rhp2 string address to listen on for RHP2 connections (default ":9982") --rhp3.tcp string +-rhp3 string address to listen on for TCP RHP3 connections (default ":9983") --rhp3.ws string - address to listen on for WebSocket RHP3 connections (default ":9984") +-env + disable stdin prompts for environment variables (default false) ``` ### YAML @@ -90,14 +77,16 @@ recoveryPhrase: indicate nature buzz route rude embody engage confirm aspect pot http: address: :9980 password: sia is cool -consensus: - gatewayAddress: :9981 +syncer: + address: :9981 bootstrap: true +consensus: + network: mainnet + indexBatchSize: 100 rhp2: address: :9982 rhp3: tcp: :9983 - websocket: :9984 log: level: info # global log level stdout: @@ -121,45 +110,15 @@ go generate ./... CGO_ENABLED=1 go build -o bin/ -tags='netgo timetzdata' -trimpath -a -ldflags '-s -w' ./cmd/hostd ``` -## Testnet Builds - -`hostd` can be built to run on the Zen testnet by adding the `testnet` build -tag. - -```sh -go generate ./... -CGO_ENABLED=1 go build -o bin/ -tags='testnet netgo timetzdata' -trimpath -a -ldflags '-s -w' ./cmd/hostd -``` - -# Docker Support +# Docker `hostd` includes a `Dockerfile` which can be used for building and running hostd within a docker container. The image can also be pulled from `ghcr.io/siafoundation/hostd`. -1. Generate a wallet seed using `hostd seed`. -2. Create `hostd.yml` in the directory you want to store your `hostd` data. Replace the recovery phrase with the one you generated above. Replace the password with a secure password to unlock the UI. - -```yml -recoveryPhrase: indicate nature buzz route rude embody engage confirm aspect potato weapon bid -http: - password: sia is cool -``` - -3. Create your docker container using one of the examples below. Replace "./data" in the volume mount with the directory in which you want to store your Sia data. Replace "./storage" in the volume mount with the location of your mounted storage volumes +Be careful with port `9980` as Docker will expose it publicly by default. It is +recommended to bind it to `127.0.0.1` to prevent unauthorized access. -## Mainnet - -```sh -docker run -d \ - --name hostd \ - -p 127.0.0.1:9980:9980 \ - -p 9981-9983:9981-9983 \ - -v ./data:/data \ - -v ./storage:/storage \ - ghcr.io/siafoundation/hostd:latest -``` - -### Docker Compose +## Docker Compose ```yml version: "3.9" @@ -175,51 +134,20 @@ services: restart: unless-stopped ``` -## Testnet - -Suffix any tag with `-testnet` to use the testnet image. +## Docker Engine ```sh docker run -d \ --name hostd \ - -p 127.0.0.1:9880:9880 \ - -p 9881-9883:9881-9883 \ + -p 127.0.0.1:9980:9980 \ + -p 9981-9983:9981-9983 \ -v ./data:/data \ -v ./storage:/storage \ - -e HOSTD_ZEN_SEED="my wallet seed" \ - -e HOSTD_ZEN_API_PASSWORD=hostsarecool \ - ghcr.io/siafoundation/hostd:latest-testnet -``` - -### Docker Compose - -```yml -version: "3.9" -services: - host: - image: ghcr.io/siafoundation/hostd:latest-testnet - environment: - - HOSTD_ZEN_SEED=my wallet seed - - HOSTD_ZEN_API_PASSWORD=hostsarecool - ports: - - 127.0.0.1:9880:9880/tcp - - 9881-9883:9881-9883/tcp - volumes: - - /data:/data - - /storage:/storage - restart: unless-stopped -``` - -## Building image - -### Mainnet - -```sh -docker build -t hostd:latest -f ./docker/Dockerfile . + ghcr.io/siafoundation/hostd:latest ``` -### Testnet +## Building Image ```sh -docker build -t hostd:latest-testnet -f ./docker/Dockerfile.testnet . +docker build -t hostd . ``` diff --git a/alerts/alerts.go b/alerts/alerts.go index bb1db0fd..3a1870ca 100644 --- a/alerts/alerts.go +++ b/alerts/alerts.go @@ -36,6 +36,12 @@ type ( BroadcastEvent(event string, scope string, data any) error } + // An Alerter is an interface that registers and dismisses alerts. + Alerter interface { + Register(a Alert) + Dismiss(ids ...types.Hash256) + } + // An Alert is a dismissible message that is displayed to the user. Alert struct { // ID is a unique identifier for the alert. @@ -61,6 +67,8 @@ type ( } ) +var _ Alerter = (*Manager)(nil) + // String implements the fmt.Stringer interface. func (s Severity) String() string { switch s { @@ -142,11 +150,12 @@ func (m *Manager) Active() []Alert { } // NewManager initializes a new alerts manager. -func NewManager(er EventReporter, log *zap.Logger) *Manager { - return &Manager{ - log: log, - events: er, - +func NewManager(opts ...ManagerOption) *Manager { + m := &Manager{ alerts: make(map[types.Hash256]Alert), } + for _, opt := range opts { + opt(m) + } + return m } diff --git a/alerts/noop.go b/alerts/noop.go new file mode 100644 index 00000000..958d1398 --- /dev/null +++ b/alerts/noop.go @@ -0,0 +1,19 @@ +package alerts + +import "go.sia.tech/core/types" + +// A NoOpAlerter is an Alerter that does nothing. +type NoOpAlerter struct{} + +// Register implements the Alerter interface. +func (NoOpAlerter) Register(Alert) {} + +// Dismiss implements the Alerter interface. +func (NoOpAlerter) Dismiss(...types.Hash256) {} + +var _ Alerter = NoOpAlerter{} + +// NewNop returns a new NoOpAlerter. +func NewNop() NoOpAlerter { + return NoOpAlerter{} +} diff --git a/alerts/options.go b/alerts/options.go new file mode 100644 index 00000000..d153f874 --- /dev/null +++ b/alerts/options.go @@ -0,0 +1,22 @@ +package alerts + +import "go.uber.org/zap" + +// ManagerOption is a functional option for the alert manager. +type ManagerOption func(*Manager) error + +// WithLog sets the logger for the manager. +func WithLog(l *zap.Logger) ManagerOption { + return func(m *Manager) error { + m.log = l + return nil + } +} + +// WithEventReporter sets the event reporter for the manager. +func WithEventReporter(e EventReporter) ManagerOption { + return func(m *Manager) error { + m.events = e + return nil + } +} diff --git a/api/api.go b/api/api.go index ae1907d2..a64aedce 100644 --- a/api/api.go +++ b/api/api.go @@ -10,19 +10,19 @@ import ( rhp2 "go.sia.tech/core/rhp/v2" rhp3 "go.sia.tech/core/rhp/v3" "go.sia.tech/core/types" + "go.sia.tech/coreutils/syncer" + "go.sia.tech/coreutils/wallet" "go.sia.tech/hostd/alerts" + "go.sia.tech/hostd/explorer" "go.sia.tech/hostd/host/accounts" "go.sia.tech/hostd/host/contracts" "go.sia.tech/hostd/host/metrics" "go.sia.tech/hostd/host/settings" "go.sia.tech/hostd/host/settings/pin" "go.sia.tech/hostd/host/storage" - "go.sia.tech/hostd/internal/explorer" "go.sia.tech/hostd/rhp" - "go.sia.tech/hostd/wallet" "go.sia.tech/hostd/webhooks" "go.sia.tech/jape" - "go.sia.tech/siad/modules" "go.uber.org/zap" ) @@ -30,12 +30,19 @@ type ( // A Wallet manages Siacoins and funds transactions Wallet interface { Address() types.Address - ScanHeight() uint64 - Balance() (spendable, confirmed, unconfirmed types.Currency, err error) - UnconfirmedTransactions() ([]wallet.Transaction, error) - FundTransaction(txn *types.Transaction, amount types.Currency) (toSign []types.Hash256, release func(), err error) - SignTransaction(cs consensus.State, txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) error - Transactions(limit, offset int) ([]wallet.Transaction, error) + Balance() (balance wallet.Balance, err error) + UnconfirmedEvents() ([]wallet.Event, error) + Events(offset, limit int) ([]wallet.Event, error) + + ReleaseInputs(txns []types.Transaction, v2txns []types.V2Transaction) + + // v1 + FundTransaction(txn *types.Transaction, amount types.Currency, useUnconfirmed bool) ([]types.Hash256, error) + SignTransaction(txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) + + // v2 + FundV2Transaction(txn *types.V2Transaction, amount types.Currency, useUnconfirmed bool) (types.ChainIndex, []int, error) + SignV2Inputs(txn *types.V2Transaction, toSign []int) } // Settings updates and retrieves the host's settings @@ -49,6 +56,11 @@ type ( UpdateDDNS(force bool) error } + // An Index persists updates from the blockchain to a store + Index interface { + Tip() types.ChainIndex + } + // PinnedSettings updates and retrieves the host's pinned settings PinnedSettings interface { Update(context.Context, pin.PinnedSettings) error @@ -105,30 +117,32 @@ type ( // A Syncer can connect to other peers and synchronize the blockchain. Syncer interface { - Address() modules.NetAddress - Peers() []modules.Peer - Connect(addr modules.NetAddress) error - Disconnect(addr modules.NetAddress) error + Addr() string + Peers() []*syncer.Peer + Connect(ctx context.Context, addr string) (*syncer.Peer, error) + + BroadcastTransactionSet(txns []types.Transaction) + BroadcastV2TransactionSet(index types.ChainIndex, txns []types.V2Transaction) } // A ChainManager retrieves the current blockchain state ChainManager interface { - Synced() bool + Tip() types.ChainIndex TipState() consensus.State - } - // A TPool manages the transaction pool - TPool interface { - RecommendedFee() (fee types.Currency) - AcceptTransactionSet(txns []types.Transaction) error + RecommendedFee() types.Currency + AddPoolTransactions(txns []types.Transaction) (known bool, err error) + UnconfirmedParents(txn types.Transaction) []types.Transaction + AddV2PoolTransactions(basis types.ChainIndex, txns []types.V2Transaction) (known bool, err error) + V2TransactionSet(basis types.ChainIndex, txn types.V2Transaction) (types.ChainIndex, []types.V2Transaction, error) } - // WebHooks manages webhooks - WebHooks interface { - WebHooks() ([]webhooks.WebHook, error) - RegisterWebHook(callbackURL string, scopes []string) (webhooks.WebHook, error) - UpdateWebHook(id int64, callbackURL string, scopes []string) (webhooks.WebHook, error) - RemoveWebHook(id int64) error + // Webhooks manages webhooks + Webhooks interface { + Webhooks() ([]webhooks.Webhook, error) + RegisterWebhook(callbackURL string, scopes []string) (webhooks.Webhook, error) + UpdateWebhook(id int64, callbackURL string, scopes []string) (webhooks.Webhook, error) + RemoveWebhook(id int64) error BroadcastToWebhook(id int64, event, scope string, data interface{}) error } @@ -145,20 +159,20 @@ type ( hostKey types.PublicKey name string - log *zap.Logger + log *zap.Logger + alerts Alerts + webhooks Webhooks + sessions RHPSessionReporter - alerts Alerts - webhooks WebHooks syncer Syncer chain ChainManager - tpool TPool accounts AccountManager contracts ContractManager volumes VolumeManager wallet Wallet metrics MetricManager settings Settings - sessions RHPSessionReporter + index Index explorerDisabled bool explorer *explorer.Explorer @@ -180,11 +194,34 @@ func (a *api) requiresExplorer(h jape.Handler) jape.Handler { } // NewServer initializes the API -func NewServer(name string, hostKey types.PublicKey, opts ...ServerOption) http.Handler { +// syncer +// chain +// accounts +// contracts +// volumes +// wallet +// metrics +// settings +// index +func NewServer(name string, hostKey types.PublicKey, cm ChainManager, s Syncer, am AccountManager, c ContractManager, vm VolumeManager, wm Wallet, mm MetricManager, sm Settings, im Index, opts ...ServerOption) http.Handler { a := &api{ hostKey: hostKey, name: name, - log: zap.NewNop(), + + sessions: noopSessionReporter{}, + alerts: noopAlerts{}, + webhooks: noopWebhooks{}, + log: zap.NewNop(), + + syncer: s, + chain: cm, + accounts: am, + contracts: c, + volumes: vm, + wallet: wm, + metrics: mm, + settings: sm, + index: im, explorerDisabled: true, } @@ -202,13 +239,17 @@ func NewServer(name string, hostKey types.PublicKey, opts ...ServerOption) http. return jape.Mux(map[string]jape.Handler{ // state endpoints - "GET /state/host": a.handleGETHostState, - "GET /state/consensus": a.handleGETConsensusState, - // gateway endpoints - "GET /syncer/address": a.handleGETSyncerAddr, - "GET /syncer/peers": a.handleGETSyncerPeers, - "PUT /syncer/peers": a.handlePUTSyncerPeer, - "DELETE /syncer/peers/:address": a.handleDeleteSyncerPeer, + "GET /state": a.handleGETState, + // consensus endpoints + "GET /consensus/tip": a.handleGETConsensusTip, + "GET /consensus/tipstate": a.handleGETConsensusTipState, + "GET /consensus/network": a.handleGETConsensusNetwork, + // syncer endpoints + "GET /syncer/address": a.handleGETSyncerAddr, + "GET /syncer/peers": a.handleGETSyncerPeers, + "PUT /syncer/peers": a.handlePUTSyncerPeer, + // index endpoints + "GET /index/tip": a.handleGETIndexTip, // alerts endpoints "GET /alerts": a.handleGETAlerts, "POST /alerts/dismiss": a.handlePOSTAlertsDismiss, @@ -248,10 +289,10 @@ func NewServer(name string, hostKey types.PublicKey, opts ...ServerOption) http. // tpool endpoints "GET /tpool/fee": a.handleGETTPoolFee, // wallet endpoints - "GET /wallet": a.handleGETWallet, - "GET /wallet/transactions": a.handleGETWalletTransactions, - "GET /wallet/pending": a.handleGETWalletPending, - "POST /wallet/send": a.handlePOSTWalletSend, + "GET /wallet": a.handleGETWallet, + "GET /wallet/events": a.handleGETWalletEvents, + "GET /wallet/pending": a.handleGETWalletPending, + "POST /wallet/send": a.handlePOSTWalletSend, // system endpoints "GET /system/dir": a.handleGETSystemDir, "PUT /system/dir": a.handlePUTSystemDir, diff --git a/api/client.go b/api/client.go index 4dd7aea1..16747075 100644 --- a/api/client.go +++ b/api/client.go @@ -4,14 +4,16 @@ import ( "fmt" "net/url" "strconv" + "sync" "time" + "go.sia.tech/core/consensus" "go.sia.tech/core/types" + "go.sia.tech/coreutils/wallet" "go.sia.tech/hostd/host/contracts" "go.sia.tech/hostd/host/metrics" "go.sia.tech/hostd/host/settings" "go.sia.tech/hostd/host/storage" - "go.sia.tech/hostd/wallet" "go.sia.tech/hostd/webhooks" "go.sia.tech/jape" ) @@ -19,17 +21,45 @@ import ( // A Client is a client for the hostd API. type Client struct { c jape.Client + + mu sync.Mutex // protects the following fields + n *consensus.Network } -// Host returns the current state of the host -func (c *Client) Host() (resp HostState, err error) { - err = c.c.GET("/state/host", &resp) +// State returns the current state of the host +func (c *Client) State() (resp State, err error) { + err = c.c.GET("/state", &resp) return } -// Consensus returns the current consensus state. -func (c *Client) Consensus() (resp ConsensusState, err error) { - err = c.c.GET("/state/consensus", &resp) +// ConsensusNetwork returns the node's consensus network +func (c *Client) ConsensusNetwork() (network *consensus.Network, err error) { + err = c.c.GET("/state/consensus/network", network) + return +} + +// ConsensusTip returns the current consensus tip +func (c *Client) ConsensusTip() (tip types.ChainIndex, err error) { + err = c.c.GET("/state/consensus/tip", &tip) + return +} + +// ConsensusTipState returns the current consensus tip state +func (c *Client) ConsensusTipState() (state consensus.State, err error) { + err = c.c.GET("/state/consensus/tipstate", &state) + if err != nil { + return + } + c.mu.Lock() + if c.n == nil { + c.n, err = c.ConsensusNetwork() + if err != nil { + c.mu.Unlock() + return consensus.State{}, fmt.Errorf("failed to get consensus network: %w", err) + } + } + state.Network = c.n + c.mu.Unlock() return } @@ -184,15 +214,15 @@ func (c *Client) Wallet() (resp WalletResponse, err error) { return } -// Transactions returns the transactions of the host's wallet. -func (c *Client) Transactions(limit, offset int) (transactions []wallet.Transaction, err error) { - err = c.c.GET(fmt.Sprintf("/wallet/transactions?limit=%d&offset=%d", limit, offset), &transactions) +// Events returns the transactions of the host's wallet. +func (c *Client) Events(limit, offset int) (transactions []wallet.Event, err error) { + err = c.c.GET(fmt.Sprintf("/wallet/events?limit=%d&offset=%d", limit, offset), &transactions) return } -// PendingTransactions returns transactions that are not yet confirmed. -func (c *Client) PendingTransactions() (transactions []wallet.Transaction, err error) { - err = c.c.GET("/wallet/pending", &transactions) +// PendingEvents returns transactions that are not yet confirmed. +func (c *Client) PendingEvents() (events []wallet.Event, err error) { + err = c.c.GET("/wallet/pending", &events) return } @@ -225,8 +255,8 @@ func (c *Client) MkDir(path string) error { return c.c.PUT("/system/dir", req) } -// RegisterWebHook registers a new WebHook. -func (c *Client) RegisterWebHook(callbackURL string, scopes []string) (hook webhooks.WebHook, err error) { +// RegisterWebHook registers a new Webhook. +func (c *Client) RegisterWebHook(callbackURL string, scopes []string) (hook webhooks.Webhook, err error) { req := RegisterWebHookRequest{ CallbackURL: callbackURL, Scopes: scopes, @@ -235,8 +265,8 @@ func (c *Client) RegisterWebHook(callbackURL string, scopes []string) (hook webh return } -// UpdateWebHook updates the WebHook with the specified ID. -func (c *Client) UpdateWebHook(id int64, callbackURL string, scopes []string) (hook webhooks.WebHook, err error) { +// UpdateWebHook updates the Webhook with the specified ID. +func (c *Client) UpdateWebHook(id int64, callbackURL string, scopes []string) (hook webhooks.Webhook, err error) { req := RegisterWebHookRequest{ CallbackURL: callbackURL, Scopes: scopes, @@ -245,13 +275,13 @@ func (c *Client) UpdateWebHook(id int64, callbackURL string, scopes []string) (h return } -// DeleteWebHook deletes the WebHook with the specified ID. +// DeleteWebHook deletes the Webhook with the specified ID. func (c *Client) DeleteWebHook(id int64) error { return c.c.DELETE(fmt.Sprintf("/webhooks/%d", id)) } // WebHooks returns all registered WebHooks. -func (c *Client) WebHooks() (hooks []webhooks.WebHook, err error) { +func (c *Client) WebHooks() (hooks []webhooks.Webhook, err error) { err = c.c.GET("/webhooks", &hooks) return } diff --git a/api/endpoints.go b/api/endpoints.go index e87c4222..a0e52b6a 100644 --- a/api/endpoints.go +++ b/api/endpoints.go @@ -22,7 +22,6 @@ import ( "go.sia.tech/hostd/internal/prometheus" "go.sia.tech/hostd/webhooks" "go.sia.tech/jape" - "go.sia.tech/siad/modules" "go.uber.org/zap" ) @@ -32,17 +31,17 @@ var startTime = time.Now() // checkServerError conditionally writes an error to the response if err is not // nil. -func (a *api) checkServerError(c jape.Context, context string, err error) bool { +func (a *api) checkServerError(jc jape.Context, context string, err error) bool { if err != nil { - c.Error(err, http.StatusInternalServerError) + jc.Error(err, http.StatusInternalServerError) a.log.Warn(context, zap.Error(err)) } return err == nil } -func (a *api) writeResponse(c jape.Context, resp any) { +func (a *api) writeResponse(jc jape.Context, resp any) { var responseFormat string - if err := c.DecodeForm("response", &responseFormat); err != nil { + if err := jc.DecodeForm("response", &responseFormat); err != nil { return } @@ -52,26 +51,26 @@ func (a *api) writeResponse(c jape.Context, resp any) { v, ok := resp.(prometheus.Marshaller) if !ok { err := fmt.Errorf("response does not implement prometheus.Marshaller %T", resp) - c.Error(err, http.StatusInternalServerError) + jc.Error(err, http.StatusInternalServerError) a.log.Error("response does not implement prometheus.Marshaller", zap.Stack("stack"), zap.Error(err)) return } - enc := prometheus.NewEncoder(c.ResponseWriter) + enc := prometheus.NewEncoder(jc.ResponseWriter) if err := enc.Append(v); err != nil { a.log.Error("failed to marshal prometheus response", zap.Error(err)) return } default: - c.Encode(resp) + jc.Encode(resp) } } } -func (a *api) handleGETHostState(c jape.Context) { +func (a *api) handleGETState(jc jape.Context) { announcement, err := a.settings.LastAnnouncement() if err != nil { - c.Error(err, http.StatusInternalServerError) + jc.Error(err, http.StatusInternalServerError) return } @@ -80,10 +79,9 @@ func (a *api) handleGETHostState(c jape.Context) { baseURL = a.explorer.BaseURL() } - a.writeResponse(c, HostState{ + a.writeResponse(jc, State{ Name: a.name, PublicKey: a.hostKey, - WalletAddress: a.wallet.Address(), StartTime: startTime, LastAnnouncement: announcement, Explorer: ExplorerState{ @@ -91,7 +89,6 @@ func (a *api) handleGETHostState(c jape.Context) { URL: baseURL, }, BuildState: BuildState{ - Network: build.NetworkName(), Version: build.Version(), Commit: build.Commit(), OS: runtime.GOOS, @@ -100,170 +97,168 @@ func (a *api) handleGETHostState(c jape.Context) { }) } -func (a *api) handleGETConsensusState(c jape.Context) { - a.writeResponse(c, ConsensusState{ - Synced: a.chain.Synced(), - ChainIndex: a.chain.TipState().Index, - }) +func (a *api) handleGETConsensusTip(jc jape.Context) { + jc.Encode(a.chain.Tip()) +} +func (a *api) handleGETConsensusTipState(jc jape.Context) { + jc.Encode(a.chain.TipState()) +} +func (a *api) handleGETConsensusNetwork(jc jape.Context) { + jc.Encode(a.chain.TipState().Network) } -func (a *api) handleGETSyncerAddr(c jape.Context) { - a.writeResponse(c, SyncerAddrResp(a.syncer.Address())) +func (a *api) handleGETSyncerAddr(jc jape.Context) { + a.writeResponse(jc, SyncerAddrResp(a.syncer.Addr())) } -func (a *api) handleGETSyncerPeers(c jape.Context) { +func (a *api) handleGETSyncerPeers(jc jape.Context) { p := a.syncer.Peers() peers := make([]Peer, len(p)) for i, peer := range p { peers[i] = Peer{ - Address: string(peer.NetAddress), - Version: peer.Version, + Address: peer.ConnAddr, + Version: peer.Version(), } } - a.writeResponse(c, PeerResp(peers)) + a.writeResponse(jc, PeerResp(peers)) } -func (a *api) handlePUTSyncerPeer(c jape.Context) { +func (a *api) handlePUTSyncerPeer(jc jape.Context) { var req SyncerConnectRequest - if err := c.Decode(&req); err != nil { + if err := jc.Decode(&req); err != nil { return } - err := a.syncer.Connect(modules.NetAddress(req.Address)) - a.checkServerError(c, "failed to connect to peer", err) + _, err := a.syncer.Connect(jc.Request.Context(), req.Address) + a.checkServerError(jc, "failed to connect to peer", err) } -func (a *api) handleDeleteSyncerPeer(c jape.Context) { - var addr modules.NetAddress - if err := c.DecodeParam("address", &addr); err != nil { - return - } - err := a.syncer.Disconnect(addr) - a.checkServerError(c, "failed to disconnect from peer", err) +func (a *api) handleGETIndexTip(jc jape.Context) { + jc.Encode(a.index.Tip()) } -func (a *api) handleGETAlerts(c jape.Context) { - a.writeResponse(c, AlertResp(a.alerts.Active())) +func (a *api) handleGETAlerts(jc jape.Context) { + a.writeResponse(jc, AlertResp(a.alerts.Active())) } -func (a *api) handlePOSTAlertsDismiss(c jape.Context) { +func (a *api) handlePOSTAlertsDismiss(jc jape.Context) { var ids []types.Hash256 - if err := c.Decode(&ids); err != nil { + if err := jc.Decode(&ids); err != nil { return } else if len(ids) == 0 { - c.Error(errors.New("no alerts to dismiss"), http.StatusBadRequest) + jc.Error(errors.New("no alerts to dismiss"), http.StatusBadRequest) return } a.alerts.Dismiss(ids...) } -func (a *api) handlePOSTAnnounce(c jape.Context) { +func (a *api) handlePOSTAnnounce(jc jape.Context) { err := a.settings.Announce() - a.checkServerError(c, "failed to announce", err) + a.checkServerError(jc, "failed to announce", err) } -func (a *api) handleGETSettings(c jape.Context) { +func (a *api) handleGETSettings(jc jape.Context) { hs := HostSettings(a.settings.Settings()) - a.writeResponse(c, hs) + a.writeResponse(jc, hs) } -func (a *api) handlePATCHSettings(c jape.Context) { +func (a *api) handlePATCHSettings(jc jape.Context) { buf, err := json.Marshal(a.settings.Settings()) - if !a.checkServerError(c, "failed to marshal existing settings", err) { + if !a.checkServerError(jc, "failed to marshal existing settings", err) { return } var current map[string]any err = json.Unmarshal(buf, ¤t) - if !a.checkServerError(c, "failed to unmarshal existing settings", err) { + if !a.checkServerError(jc, "failed to unmarshal existing settings", err) { return } var req map[string]any - if err := c.Decode(&req); err != nil { + if err := jc.Decode(&req); err != nil { return } err = patchSettings(current, req) - if !a.checkServerError(c, "failed to patch settings", err) { + if !a.checkServerError(jc, "failed to patch settings", err) { return } buf, err = json.Marshal(current) - if !a.checkServerError(c, "failed to marshal patched settings", err) { + if !a.checkServerError(jc, "failed to marshal patched settings", err) { return } var settings settings.Settings if err := json.Unmarshal(buf, &settings); err != nil { - c.Error(err, http.StatusBadRequest) + jc.Error(err, http.StatusBadRequest) return } err = a.settings.UpdateSettings(settings) - if !a.checkServerError(c, "failed to update settings", err) { + if !a.checkServerError(jc, "failed to update settings", err) { return } // Resize the cache based on the updated settings a.volumes.ResizeCache(settings.SectorCacheSize) - c.Encode(a.settings.Settings()) + jc.Encode(a.settings.Settings()) } -func (a *api) handleGETPinnedSettings(c jape.Context) { - c.Encode(a.pinned.Pinned(c.Request.Context())) +func (a *api) handleGETPinnedSettings(jc jape.Context) { + jc.Encode(a.pinned.Pinned(jc.Request.Context())) } -func (a *api) handlePUTPinnedSettings(c jape.Context) { +func (a *api) handlePUTPinnedSettings(jc jape.Context) { var req pin.PinnedSettings - if err := c.Decode(&req); err != nil { + if err := jc.Decode(&req); err != nil { return } - a.checkServerError(c, "failed to update pinned settings", a.pinned.Update(c.Request.Context(), req)) + a.checkServerError(jc, "failed to update pinned settings", a.pinned.Update(jc.Request.Context(), req)) } -func (a *api) handlePUTDDNSUpdate(c jape.Context) { +func (a *api) handlePUTDDNSUpdate(jc jape.Context) { err := a.settings.UpdateDDNS(true) - a.checkServerError(c, "failed to update dynamic DNS", err) + a.checkServerError(jc, "failed to update dynamic DNS", err) } -func (a *api) handleGETMetrics(c jape.Context) { +func (a *api) handleGETMetrics(jc jape.Context) { var timestamp time.Time - if err := c.DecodeForm("timestamp", ×tamp); err != nil { + if err := jc.DecodeForm("timestamp", ×tamp); err != nil { return } else if timestamp.IsZero() { timestamp = time.Now() } metrics, err := a.metrics.Metrics(timestamp) - if !a.checkServerError(c, "failed to get metrics", err) { + if !a.checkServerError(jc, "failed to get metrics", err) { return } - a.writeResponse(c, Metrics(metrics)) + a.writeResponse(jc, Metrics(metrics)) } -func (a *api) handleGETPeriodMetrics(c jape.Context) { +func (a *api) handleGETPeriodMetrics(jc jape.Context) { var interval metrics.Interval - if err := c.DecodeParam("period", &interval); err != nil { + if err := jc.DecodeParam("period", &interval); err != nil { return } var start time.Time var periods int - if err := c.DecodeForm("start", &start); err != nil { + if err := jc.DecodeForm("start", &start); err != nil { return - } else if err := c.DecodeForm("periods", &periods); err != nil { + } else if err := jc.DecodeForm("periods", &periods); err != nil { return } else if start.IsZero() { - c.Error(errors.New("start time cannot be zero"), http.StatusBadRequest) + jc.Error(errors.New("start time cannot be zero"), http.StatusBadRequest) return } else if start.After(time.Now()) { - c.Error(errors.New("start time cannot be in the future"), http.StatusBadRequest) + jc.Error(errors.New("start time cannot be in the future"), http.StatusBadRequest) } start, err := metrics.Normalize(start, interval) if err != nil { - c.Error(err, http.StatusBadRequest) + jc.Error(err, http.StatusBadRequest) return } @@ -297,15 +292,15 @@ func (a *api) handleGETPeriodMetrics(c jape.Context) { } period, err := a.metrics.PeriodMetrics(start, periods, interval) - if !a.checkServerError(c, "failed to get metrics", err) { + if !a.checkServerError(jc, "failed to get metrics", err) { return } - c.Encode(period) + jc.Encode(period) } -func (a *api) handlePostContracts(c jape.Context) { +func (a *api) handlePostContracts(jc jape.Context) { var filter contracts.ContractFilter - if err := c.Decode(&filter); err != nil { + if err := jc.Decode(&filter); err != nil { return } @@ -314,162 +309,190 @@ func (a *api) handlePostContracts(c jape.Context) { } contracts, count, err := a.contracts.Contracts(filter) - if !a.checkServerError(c, "failed to get contracts", err) { + if !a.checkServerError(jc, "failed to get contracts", err) { return } - c.Encode(ContractsResponse{ + jc.Encode(ContractsResponse{ Contracts: contracts, Count: count, }) } -func (a *api) handleGETContract(c jape.Context) { +func (a *api) handleGETContract(jc jape.Context) { var id types.FileContractID - if err := c.DecodeParam("id", &id); err != nil { + if err := jc.DecodeParam("id", &id); err != nil { return } contract, err := a.contracts.Contract(id) if errors.Is(err, contracts.ErrNotFound) { - c.Error(err, http.StatusNotFound) + jc.Error(err, http.StatusNotFound) return - } else if !a.checkServerError(c, "failed to get contract", err) { + } else if !a.checkServerError(jc, "failed to get contract", err) { return } - c.Encode(contract) + jc.Encode(contract) } -func (a *api) handleGETVolume(c jape.Context) { +func (a *api) handleGETVolume(jc jape.Context) { var id int64 - if err := c.DecodeParam("id", &id); err != nil { + if err := jc.DecodeParam("id", &id); err != nil { return } else if id < 0 { - c.Error(errors.New("invalid volume id"), http.StatusBadRequest) + jc.Error(errors.New("invalid volume id"), http.StatusBadRequest) return } volume, err := a.volumes.Volume(id) if errors.Is(err, storage.ErrVolumeNotFound) { - c.Error(err, http.StatusNotFound) + jc.Error(err, http.StatusNotFound) return - } else if !a.checkServerError(c, "failed to get volume", err) { + } else if !a.checkServerError(jc, "failed to get volume", err) { return } - c.Encode(toJSONVolume(volume)) + jc.Encode(toJSONVolume(volume)) } -func (a *api) handlePUTVolume(c jape.Context) { +func (a *api) handlePUTVolume(jc jape.Context) { var id int64 - if err := c.DecodeParam("id", &id); err != nil { + if err := jc.DecodeParam("id", &id); err != nil { return } else if id < 0 { - c.Error(errors.New("invalid volume id"), http.StatusBadRequest) + jc.Error(errors.New("invalid volume id"), http.StatusBadRequest) return } var req UpdateVolumeRequest - if err := c.Decode(&req); err != nil { + if err := jc.Decode(&req); err != nil { return } err := a.volumes.SetReadOnly(id, req.ReadOnly) if errors.Is(err, storage.ErrVolumeNotFound) { - c.Error(err, http.StatusNotFound) + jc.Error(err, http.StatusNotFound) return } - a.checkServerError(c, "failed to update volume", err) + a.checkServerError(jc, "failed to update volume", err) } -func (a *api) handleDeleteSector(c jape.Context) { +func (a *api) handleDeleteSector(jc jape.Context) { var root types.Hash256 - if err := c.DecodeParam("root", &root); err != nil { + if err := jc.DecodeParam("root", &root); err != nil { return } err := a.volumes.RemoveSector(root) - a.checkServerError(c, "failed to remove sector", err) + a.checkServerError(jc, "failed to remove sector", err) } -func (a *api) handleGETWallet(c jape.Context) { - spendable, confirmed, unconfirmed, err := a.wallet.Balance() - if !a.checkServerError(c, "failed to get wallet", err) { +func (a *api) handleGETWallet(jc jape.Context) { + balance, err := a.wallet.Balance() + if !a.checkServerError(jc, "failed to get wallet", err) { return } - a.writeResponse(c, WalletResponse{ - ScanHeight: a.wallet.ScanHeight(), - Address: a.wallet.Address(), - Spendable: spendable, - Confirmed: confirmed, - Unconfirmed: unconfirmed, + a.writeResponse(jc, WalletResponse{ + Balance: balance, + Address: a.wallet.Address(), }) } -func (a *api) handleGETWalletTransactions(c jape.Context) { - limit, offset := parseLimitParams(c, 100, 500) +func (a *api) handleGETWalletEvents(jc jape.Context) { + limit, offset := parseLimitParams(jc, 100, 500) - transactions, err := a.wallet.Transactions(limit, offset) - if !a.checkServerError(c, "failed to get wallet transactions", err) { + transactions, err := a.wallet.Events(offset, limit) + if !a.checkServerError(jc, "failed to get events", err) { return } - a.writeResponse(c, WalletTransactionsResp(transactions)) + a.writeResponse(jc, WalletTransactionsResp(transactions)) } -func (a *api) handleGETWalletPending(c jape.Context) { - pending, err := a.wallet.UnconfirmedTransactions() - if !a.checkServerError(c, "failed to get wallet pending", err) { +func (a *api) handleGETWalletPending(jc jape.Context) { + pending, err := a.wallet.UnconfirmedEvents() + if !a.checkServerError(jc, "failed to get wallet pending", err) { return } - a.writeResponse(c, WalletPendingResp(pending)) + a.writeResponse(jc, WalletPendingResp(pending)) } -func (a *api) handlePOSTWalletSend(c jape.Context) { +func (a *api) handlePOSTWalletSend(jc jape.Context) { var req WalletSendSiacoinsRequest - if err := c.Decode(&req); err != nil { + if err := jc.Decode(&req); err != nil { return } else if req.Address == types.VoidAddress { - c.Error(errors.New("cannot send to void address"), http.StatusBadRequest) + jc.Error(errors.New("cannot send to void address"), http.StatusBadRequest) return } // estimate miner fee - feePerByte := a.tpool.RecommendedFee() + feePerByte := a.chain.RecommendedFee() minerFee := feePerByte.Mul64(stdTxnSize) if req.SubtractMinerFee { var underflow bool req.Amount, underflow = req.Amount.SubWithUnderflow(minerFee) if underflow { - c.Error(fmt.Errorf("amount must be greater than miner fee: %s", minerFee), http.StatusBadRequest) + jc.Error(fmt.Errorf("amount must be greater than miner fee: %s", minerFee), http.StatusBadRequest) return } } - // build transaction - txn := types.Transaction{ - MinerFees: []types.Currency{minerFee}, - SiacoinOutputs: []types.SiacoinOutput{ - {Address: req.Address, Value: req.Amount}, - }, - } - // fund and sign transaction - toSign, release, err := a.wallet.FundTransaction(&txn, req.Amount.Add(minerFee)) - if !a.checkServerError(c, "failed to fund transaction", err) { - return - } - defer release() - err = a.wallet.SignTransaction(a.chain.TipState(), &txn, toSign, types.CoveredFields{WholeTransaction: true}) - if !a.checkServerError(c, "failed to sign transaction", err) { - return - } - // broadcast transaction - err = a.tpool.AcceptTransactionSet([]types.Transaction{txn}) - if !a.checkServerError(c, "failed to broadcast transaction", err) { - return + state := a.chain.TipState() + // if the current height is below the v2 hardfork height, send a v1 + // transaction + if state.Index.Height < state.Network.HardforkV2.AllowHeight { + // build transaction + txn := types.Transaction{ + MinerFees: []types.Currency{minerFee}, + SiacoinOutputs: []types.SiacoinOutput{ + {Address: req.Address, Value: req.Amount}, + }, + } + toSign, err := a.wallet.FundTransaction(&txn, req.Amount.Add(minerFee), false) + if !a.checkServerError(jc, "failed to fund transaction", err) { + return + } + a.wallet.SignTransaction(&txn, toSign, types.CoveredFields{WholeTransaction: true}) + // shouldn't be necessary to get parents since the transaction is + // not using unconfirmed outputs, but good practice + txnset := append(a.chain.UnconfirmedParents(txn), txn) + // verify the transaction and add it to the transaction pool + if _, err := a.chain.AddPoolTransactions(txnset); !a.checkServerError(jc, "failed to add transaction set", err) { + a.wallet.ReleaseInputs([]types.Transaction{txn}, nil) + return + } + // broadcast the transaction + a.syncer.BroadcastTransactionSet(txnset) + jc.Encode(txn.ID()) + } else { + txn := types.V2Transaction{ + MinerFee: minerFee, + SiacoinOutputs: []types.SiacoinOutput{ + {Address: req.Address, Value: req.Amount}, + }, + } + // fund and sign transaction + basis, toSign, err := a.wallet.FundV2Transaction(&txn, req.Amount.Add(minerFee), false) + if !a.checkServerError(jc, "failed to fund transaction", err) { + return + } + a.wallet.SignV2Inputs(&txn, toSign) + basis, txnset, err := a.chain.V2TransactionSet(basis, txn) + if !a.checkServerError(jc, "failed to create transaction set", err) { + a.wallet.ReleaseInputs(nil, []types.V2Transaction{txn}) + return + } + // verify the transaction and add it to the transaction pool + if _, err := a.chain.AddV2PoolTransactions(basis, txnset); !a.checkServerError(jc, "failed to add v2 transaction set", err) { + a.wallet.ReleaseInputs(nil, []types.V2Transaction{txn}) + return + } + // broadcast the transaction + a.syncer.BroadcastV2TransactionSet(basis, txnset) + jc.Encode(txn.ID()) } - c.Encode(txn.ID()) } -func (a *api) handleGETSystemDir(c jape.Context) { +func (a *api) handleGETSystemDir(jc jape.Context) { var path string - if err := c.DecodeForm("path", &path); err != nil { + if err := jc.DecodeForm("path", &path); err != nil { return } @@ -477,10 +500,10 @@ func (a *api) handleGETSystemDir(c jape.Context) { // special handling for / on Windows if path == `/` || path == `\` { drives, err := disk.Drives() - if !a.checkServerError(c, "failed to get drives", err) { + if !a.checkServerError(jc, "failed to get drives", err) { return } - c.Encode(SystemDirResponse{ + jc.Encode(SystemDirResponse{ Path: path, Directories: drives, }) @@ -502,7 +525,7 @@ func (a *api) handleGETSystemDir(c jape.Context) { // try to get the working directory instead path, err = os.Getwd() if err != nil { - c.Error(fmt.Errorf("failed to get home dir: %w", err), http.StatusInternalServerError) + jc.Error(fmt.Errorf("failed to get home dir: %w", err), http.StatusInternalServerError) return } } @@ -510,27 +533,27 @@ func (a *api) handleGETSystemDir(c jape.Context) { var err error path, err = os.Getwd() if err != nil { - c.Error(fmt.Errorf("failed to get working dir: %w", err), http.StatusInternalServerError) + jc.Error(fmt.Errorf("failed to get working dir: %w", err), http.StatusInternalServerError) return } } path = filepath.Clean(path) if !filepath.IsAbs(path) { - c.Error(errors.New("path must be absolute"), http.StatusBadRequest) + jc.Error(errors.New("path must be absolute"), http.StatusBadRequest) return } dir, err := os.ReadDir(path) if errors.Is(err, os.ErrNotExist) { - c.Error(fmt.Errorf("path does not exist: %w", err), http.StatusNotFound) + jc.Error(fmt.Errorf("path does not exist: %w", err), http.StatusNotFound) return - } else if !a.checkServerError(c, "failed to read dir", err) { + } else if !a.checkServerError(jc, "failed to read dir", err) { return } // get disk usage free, total, err := disk.Usage(path) - if !a.checkServerError(c, "failed to get disk usage", err) { + if !a.checkServerError(jc, "failed to get disk usage", err) { return } @@ -545,111 +568,111 @@ func (a *api) handleGETSystemDir(c jape.Context) { resp.Directories = append(resp.Directories, entry.Name()) } } - c.Encode(resp) + jc.Encode(resp) } -func (a *api) handlePUTSystemDir(c jape.Context) { +func (a *api) handlePUTSystemDir(jc jape.Context) { var req CreateDirRequest - if err := c.Decode(&req); err != nil { + if err := jc.Decode(&req); err != nil { return } - a.checkServerError(c, "failed to create dir", os.MkdirAll(req.Path, 0775)) + a.checkServerError(jc, "failed to create dir", os.MkdirAll(req.Path, 0775)) } -func (a *api) handleGETTPoolFee(c jape.Context) { - a.writeResponse(c, TPoolResp(a.tpool.RecommendedFee())) +func (a *api) handleGETTPoolFee(jc jape.Context) { + a.writeResponse(jc, TPoolResp(a.chain.RecommendedFee())) } -func (a *api) handleGETAccounts(c jape.Context) { - limit, offset := parseLimitParams(c, 100, 500) +func (a *api) handleGETAccounts(jc jape.Context) { + limit, offset := parseLimitParams(jc, 100, 500) accounts, err := a.accounts.Accounts(limit, offset) - if !a.checkServerError(c, "failed to get accounts", err) { + if !a.checkServerError(jc, "failed to get accounts", err) { return } - c.Encode(accounts) + jc.Encode(accounts) } -func (a *api) handleGETAccountFunding(c jape.Context) { +func (a *api) handleGETAccountFunding(jc jape.Context) { var account rhp3.Account - if err := c.DecodeParam("account", &account); err != nil { + if err := jc.DecodeParam("account", &account); err != nil { return } funding, err := a.accounts.AccountFunding(account) - if !a.checkServerError(c, "failed to get account funding", err) { + if !a.checkServerError(jc, "failed to get account funding", err) { return } - c.Encode(funding) + jc.Encode(funding) } -func (a *api) handleGETWebhooks(c jape.Context) { - hooks, err := a.webhooks.WebHooks() +func (a *api) handleGETWebhooks(jc jape.Context) { + hooks, err := a.webhooks.Webhooks() if err != nil { - c.Error(err, http.StatusInternalServerError) + jc.Error(err, http.StatusInternalServerError) return } - c.Encode(hooks) + jc.Encode(hooks) } -func (a *api) handlePOSTWebhooks(c jape.Context) { +func (a *api) handlePOSTWebhooks(jc jape.Context) { var req RegisterWebHookRequest - if err := c.Decode(&req); err != nil { + if err := jc.Decode(&req); err != nil { return } - hook, err := a.webhooks.RegisterWebHook(req.CallbackURL, req.Scopes) + hook, err := a.webhooks.RegisterWebhook(req.CallbackURL, req.Scopes) if err != nil { - c.Error(err, http.StatusInternalServerError) + jc.Error(err, http.StatusInternalServerError) return } - c.Encode(hook) + jc.Encode(hook) } -func (a *api) handlePUTWebhooks(c jape.Context) { +func (a *api) handlePUTWebhooks(jc jape.Context) { var id int64 - if err := c.DecodeParam("id", &id); err != nil { + if err := jc.DecodeParam("id", &id); err != nil { return } var req RegisterWebHookRequest - if err := c.Decode(&req); err != nil { + if err := jc.Decode(&req); err != nil { return } - _, err := a.webhooks.UpdateWebHook(id, req.CallbackURL, req.Scopes) + _, err := a.webhooks.UpdateWebhook(id, req.CallbackURL, req.Scopes) if err != nil { - c.Error(err, http.StatusInternalServerError) + jc.Error(err, http.StatusInternalServerError) return } } -func (a *api) handlePOSTWebhooksTest(c jape.Context) { +func (a *api) handlePOSTWebhooksTest(jc jape.Context) { var id int64 - if err := c.DecodeParam("id", &id); err != nil { + if err := jc.DecodeParam("id", &id); err != nil { return } if err := a.webhooks.BroadcastToWebhook(id, "test", webhooks.ScopeTest, nil); err != nil { - c.Error(err, http.StatusInternalServerError) + jc.Error(err, http.StatusInternalServerError) return } } -func (a *api) handleDELETEWebhooks(c jape.Context) { +func (a *api) handleDELETEWebhooks(jc jape.Context) { var id int64 - if err := c.DecodeParam("id", &id); err != nil { + if err := jc.DecodeParam("id", &id); err != nil { return } - err := a.webhooks.RemoveWebHook(id) + err := a.webhooks.RemoveWebhook(id) if err != nil { - c.Error(err, http.StatusInternalServerError) + jc.Error(err, http.StatusInternalServerError) return } } -func parseLimitParams(c jape.Context, defaultLimit, maxLimit int) (limit, offset int) { - if err := c.DecodeForm("limit", &limit); err != nil { +func parseLimitParams(jc jape.Context, defaultLimit, maxLimit int) (limit, offset int) { + if err := jc.DecodeForm("limit", &limit); err != nil { return - } else if err := c.DecodeForm("offset", &offset); err != nil { + } else if err := jc.DecodeForm("offset", &offset); err != nil { return } if limit > maxLimit { diff --git a/api/options.go b/api/options.go index 61f813e5..d7e9a0c7 100644 --- a/api/options.go +++ b/api/options.go @@ -1,7 +1,11 @@ package api import ( - "go.sia.tech/hostd/internal/explorer" + "go.sia.tech/core/types" + "go.sia.tech/hostd/alerts" + "go.sia.tech/hostd/explorer" + "go.sia.tech/hostd/rhp" + "go.sia.tech/hostd/webhooks" "go.uber.org/zap" ) @@ -15,62 +19,13 @@ func ServerWithAlerts(al Alerts) ServerOption { } } -// ServerWithWebHooks sets the webhooks manager for the API server. -func ServerWithWebHooks(w WebHooks) ServerOption { +// ServerWithWebhooks sets the webhooks manager for the API server. +func ServerWithWebhooks(w Webhooks) ServerOption { return func(a *api) { a.webhooks = w } } -// ServerWithSyncer sets the syncer for the API server. -func ServerWithSyncer(g Syncer) ServerOption { - return func(a *api) { - a.syncer = g - } -} - -// ServerWithChainManager sets the chain manager for the API server. -func ServerWithChainManager(chain ChainManager) ServerOption { - return func(a *api) { - a.chain = chain - } -} - -// ServerWithTransactionPool sets the transaction pool for the API server. -func ServerWithTransactionPool(tp TPool) ServerOption { - return func(a *api) { - a.tpool = tp - } -} - -// ServerWithContractManager sets the contract manager for the API server. -func ServerWithContractManager(cm ContractManager) ServerOption { - return func(a *api) { - a.contracts = cm - } -} - -// ServerWithAccountManager sets the account manager for the API server. -func ServerWithAccountManager(am AccountManager) ServerOption { - return func(a *api) { - a.accounts = am - } -} - -// ServerWithVolumeManager sets the volume manager for the API server. -func ServerWithVolumeManager(vm VolumeManager) ServerOption { - return func(a *api) { - a.volumes = vm - } -} - -// ServerWithMetricManager sets the metric manager for the API server. -func ServerWithMetricManager(m MetricManager) ServerOption { - return func(a *api) { - a.metrics = m - } -} - // ServerWithPinnedSettings sets the pinned settings for the API server. func ServerWithPinnedSettings(p PinnedSettings) ServerOption { return func(a *api) { @@ -86,13 +41,6 @@ func ServerWithExplorer(explorer *explorer.Explorer) ServerOption { } } -// ServerWithSettings sets the settings manager for the API server. -func ServerWithSettings(s Settings) ServerOption { - return func(a *api) { - a.settings = s - } -} - // ServerWithRHPSessionReporter sets the RHP session reporter for the API server. func ServerWithRHPSessionReporter(rsr RHPSessionReporter) ServerOption { return func(a *api) { @@ -100,16 +48,32 @@ func ServerWithRHPSessionReporter(rsr RHPSessionReporter) ServerOption { } } -// ServerWithWallet sets the wallet for the API server. -func ServerWithWallet(w Wallet) ServerOption { - return func(a *api) { - a.wallet = w - } -} - // ServerWithLogger sets the logger for the API server. func ServerWithLogger(log *zap.Logger) ServerOption { return func(a *api) { a.log = log } } + +type noopWebhooks struct{} + +func (noopWebhooks) Webhooks() ([]webhooks.Webhook, error) { return nil, nil } +func (noopWebhooks) RemoveWebhook(id int64) error { return nil } +func (noopWebhooks) BroadcastToWebhook(id int64, event, scope string, data any) error { return nil } +func (noopWebhooks) RegisterWebhook(callbackURL string, scopes []string) (webhooks.Webhook, error) { + return webhooks.Webhook{}, nil +} +func (noopWebhooks) UpdateWebhook(id int64, callbackURL string, scopes []string) (webhooks.Webhook, error) { + return webhooks.Webhook{}, nil +} + +type noopAlerts struct{} + +func (noopAlerts) Active() []alerts.Alert { return nil } +func (noopAlerts) Dismiss(...types.Hash256) {} + +type noopSessionReporter struct{} + +func (noopSessionReporter) Subscribe(rhp.SessionSubscriber) {} +func (noopSessionReporter) Unsubscribe(rhp.SessionSubscriber) {} +func (noopSessionReporter) Active() []rhp.Session { return nil } diff --git a/api/prometheus.go b/api/prometheus.go index 364d91cc..065ee119 100644 --- a/api/prometheus.go +++ b/api/prometheus.go @@ -10,55 +10,33 @@ import ( ) // PrometheusMetric returns a Prometheus metric for the host state. -func (hs HostState) PrometheusMetric() []prometheus.Metric { +func (s State) PrometheusMetric() []prometheus.Metric { return []prometheus.Metric{ { Name: "hostd_host_state", Labels: map[string]any{ - "name": hs.Name, - "public_key": hs.PublicKey, - "wallet_address": hs.WalletAddress, - "network": hs.Network, - "version": hs.Version, - "commit": hs.Commit, - "os": hs.OS, - "build_time": hs.BuildTime, + "name": s.Name, + "public_key": s.PublicKey, + "version": s.Version, + "commit": s.Commit, + "os": s.OS, + "build_time": s.BuildTime, }, Value: 1, }, { Name: "hostd_start_time", - Value: float64(hs.StartTime.UTC().UnixMilli()), + Value: float64(s.StartTime.UTC().UnixMilli()), }, { Name: "hostd_runtime", - Value: float64(time.Since(hs.StartTime).Milliseconds()), + Value: float64(time.Since(s.StartTime).Milliseconds()), Timestamp: time.Now(), }, { Name: "hostd_last_announcement", - Labels: map[string]any{"address": hs.LastAnnouncement.Address, "id": hs.LastAnnouncement.Index.ID}, - Value: float64(hs.LastAnnouncement.Index.Height), - }, - } -} - -// PrometheusMetric returns a Prometheus metric for the consensus state. -func (cs ConsensusState) PrometheusMetric() []prometheus.Metric { - return []prometheus.Metric{ - { - Name: "hostd_consensus_state_synced", - Value: func() float64 { - if cs.Synced { - return 1 - } - return 0 - }(), - }, - { - Name: "hostd_consensus_state_index", - Labels: map[string]any{"id": cs.ChainIndex.ID}, - Value: float64(cs.ChainIndex.Height), + Labels: map[string]any{"address": s.LastAnnouncement.Address, "id": s.LastAnnouncement.Index.ID}, + Value: float64(s.LastAnnouncement.Index.Height), }, } } @@ -200,10 +178,6 @@ func (m Metrics) PrometheusMetric() []prometheus.Metric { Name: "hostd_metrics_pricing_collateral_multiplier", Value: m.Pricing.CollateralMultiplier, }, - { - Name: "hostd_metrics_contracts_pending", - Value: float64(m.Contracts.Pending), - }, { Name: "hostd_metrics_contracts_active", Value: float64(m.Contracts.Active), @@ -269,8 +243,12 @@ func (m Metrics) PrometheusMetric() []prometheus.Metric { Value: float64(m.Data.RHP.Egress), }, { - Name: "hostd_metrics_balance", - Value: m.Balance.Siacoins(), + Name: "hostd_metrics_wallet_balance", + Value: m.Wallet.Balance.Siacoins(), + }, + { + Name: "hostd_metrics_wallet_immature_balance", + Value: m.Wallet.ImmatureBalance.Siacoins(), }, } } @@ -278,13 +256,6 @@ func (m Metrics) PrometheusMetric() []prometheus.Metric { // PrometheusMetric returns Prometheus samples for the host wallet. func (wr WalletResponse) PrometheusMetric() []prometheus.Metric { return []prometheus.Metric{ - { - Name: "hostd_wallet_scan_height", - Labels: map[string]any{ - "address": wr.Address, - }, - Value: float64(wr.ScanHeight), - }, { Name: "hostd_wallet_spendable", Labels: map[string]any{ @@ -306,6 +277,13 @@ func (wr WalletResponse) PrometheusMetric() []prometheus.Metric { }, Value: wr.Unconfirmed.Siacoins(), }, + { + Name: "hostd_wallet_immature", + Labels: map[string]any{ + "address": wr.Address, + }, + Value: wr.Immature.Siacoins(), + }, } } @@ -438,19 +416,22 @@ func (s SyncerAddrResp) PrometheusMetric() (metrics []prometheus.Metric) { // PrometheusMetric returns Prometheus samples for the hosts transactions. func (w WalletTransactionsResp) PrometheusMetric() (metrics []prometheus.Metric) { - metricName := "hostd_wallet_transaction" + metricName := "hostd_wallet_event" for _, txn := range w { + inflow, outflow := txn.SiacoinInflow(), txn.SiacoinOutflow() + var value float64 - if txn.Inflow.Cmp(txn.Outflow) > 0 { // inflow > outflow = positive value - value = txn.Inflow.Sub(txn.Outflow).Siacoins() + if inflow.Cmp(outflow) > 0 { // inflow > outflow = positive value + value = inflow.Sub(outflow).Siacoins() } else { // inflow < outflow = negative value - value = txn.Outflow.Sub(txn.Inflow).Siacoins() * -1 + value = outflow.Sub(inflow).Siacoins() * -1 } metrics = append(metrics, prometheus.Metric{ Name: metricName, Labels: map[string]any{ "txid": strings.Split(txn.ID.String(), ":")[1], + "type": txn.Type, }, Value: value, }) @@ -460,19 +441,21 @@ func (w WalletTransactionsResp) PrometheusMetric() (metrics []prometheus.Metric) // PrometheusMetric returns Prometheus samples for the host pending transactions. func (w WalletPendingResp) PrometheusMetric() (metrics []prometheus.Metric) { - metricName := "hostd_wallet_transaction_pending" + metricName := "hostd_wallet_event_pending" for _, txn := range w { + inflow, outflow := txn.SiacoinInflow(), txn.SiacoinOutflow() var value float64 - if txn.Inflow.Cmp(txn.Outflow) > 0 { // inflow > outflow = positive value - value = txn.Inflow.Sub(txn.Outflow).Siacoins() + if inflow.Cmp(outflow) > 0 { // inflow > outflow = positive value + value = inflow.Sub(outflow).Siacoins() } else { // inflow < outflow = negative value - value = txn.Outflow.Sub(txn.Inflow).Siacoins() * -1 + value = outflow.Sub(inflow).Siacoins() * -1 } metrics = append(metrics, prometheus.Metric{ Name: metricName, Labels: map[string]any{ "txid": strings.Split(txn.ID.String(), ":")[1], + "type": txn.Type, }, Value: value, }) diff --git a/api/types.go b/api/types.go index 7508dfe7..fe7ff65d 100644 --- a/api/types.go +++ b/api/types.go @@ -8,13 +8,13 @@ import ( "time" "go.sia.tech/core/types" + "go.sia.tech/coreutils/wallet" "go.sia.tech/hostd/alerts" "go.sia.tech/hostd/host/contracts" "go.sia.tech/hostd/host/metrics" "go.sia.tech/hostd/host/settings" "go.sia.tech/hostd/host/storage" "go.sia.tech/hostd/rhp" - "go.sia.tech/hostd/wallet" ) // JSON keys for host setting fields @@ -46,7 +46,6 @@ type ( // BuildState contains static information about the build. BuildState struct { - Network string `json:"network"` Version string `json:"version"` Commit string `json:"commit"` OS string `json:"os"` @@ -59,12 +58,11 @@ type ( URL string `json:"url"` } - // HostState is the response body for the [GET] /state/host endpoint. - HostState struct { + // State is the response body for the [GET] /state endpoint. + State struct { Name string `json:"name,omitempty"` PublicKey types.PublicKey `json:"publicKey"` LastAnnouncement settings.Announcement `json:"lastAnnouncement"` - WalletAddress types.Address `json:"walletAddress"` StartTime time.Time `json:"startTime"` Explorer ExplorerState `json:"explorer"` BuildState @@ -76,12 +74,6 @@ type ( // Metrics is the response body for the [GET] /metrics endpoint. Metrics metrics.Metrics - // ConsensusState is the response body for the [GET] /consensus endpoint. - ConsensusState struct { - Synced bool `json:"synced"` - ChainIndex types.ChainIndex `json:"chainIndex"` - } - // ContractIntegrityResponse is the response body for the [POST] /contracts/:id/check endpoint. ContractIntegrityResponse struct { BadSectors []types.Hash256 `json:"badSectors"` @@ -123,11 +115,9 @@ type ( // WalletResponse is the response body for the [GET] /wallet endpoint. WalletResponse struct { - ScanHeight uint64 `json:"scanHeight"` - Address types.Address `json:"address"` - Spendable types.Currency `json:"spendable"` - Confirmed types.Currency `json:"confirmed"` - Unconfirmed types.Currency `json:"unconfirmed"` + wallet.Balance + + Address types.Address `json:"address"` } // WalletSendSiacoinsRequest is the request body for the [POST] /wallet/send endpoint. @@ -188,10 +178,10 @@ type ( SyncerAddrResp string // WalletTransactionsResp is the response body for the [GET] /wallet/transactions endpoint - WalletTransactionsResp []wallet.Transaction + WalletTransactionsResp []wallet.Event // WalletPendingResp is the response body for the [GET] /wallet/pending endpoint - WalletPendingResp []wallet.Transaction + WalletPendingResp []wallet.Event // SessionResp is the response body for the [GET] /sessions endpoint SessionResp []rhp.Session @@ -228,6 +218,16 @@ func (je *JSONErrors) UnmarshalJSON(b []byte) error { return nil } +// MarshalText implements test.Marshaler +func (tr TPoolResp) MarshalText() ([]byte, error) { + return types.Currency(tr).MarshalText() +} + +// UnmarshalText implements test.Unmarshaler +func (tr *TPoolResp) UnmarshalText(b []byte) error { + return (*types.Currency)(tr).UnmarshalText(b) +} + // SetAcceptingContracts sets the AcceptingContracts field of the request func SetAcceptingContracts(value bool) Setting { return func(v map[string]any) { diff --git a/build/build.go b/build/build.go index 0bd7ca2c..e3da5f20 100644 --- a/build/build.go +++ b/build/build.go @@ -5,19 +5,6 @@ package build import "time" -// NetworkName returns the human-readable name of the current network. -func NetworkName() string { - n, _ := Network() - switch n.Name { - case "mainnet": - return "Mainnet" - case "zen": - return "Zen Testnet" - default: - return n.Name - } -} - // Commit returns the commit hash of hostd func Commit() string { return commit diff --git a/build/meta.go b/build/meta.go index 05dd9fc6..f6408384 100644 --- a/build/meta.go +++ b/build/meta.go @@ -2,6 +2,6 @@ package build const ( commit = "?" - version = "?" + version = "2.0.0" buildTime = 0 ) diff --git a/build/network_default.go b/build/network_default.go deleted file mode 100644 index 25a1a56b..00000000 --- a/build/network_default.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build !testnet && !testing - -package build - -import ( - "go.sia.tech/core/consensus" - "go.sia.tech/core/types" - "go.sia.tech/coreutils/chain" -) - -// Network returns the Sia network consts and genesis block for the current build. -func Network() (*consensus.Network, types.Block) { - return chain.Mainnet() -} diff --git a/build/network_testing.go b/build/network_testing.go deleted file mode 100644 index 78ff4e61..00000000 --- a/build/network_testing.go +++ /dev/null @@ -1,56 +0,0 @@ -//go:build testing - -package build - -import ( - "time" - - "go.sia.tech/core/consensus" - "go.sia.tech/core/types" -) - -// Network returns the Sia network consts and genesis block for the current build. -func Network() (*consensus.Network, types.Block) { - n := &consensus.Network{ - Name: "testing", - InitialCoinbase: types.Siacoins(300000), - MinimumCoinbase: types.Siacoins(299990), - InitialTarget: types.BlockID{4: 32}, - } - - n.HardforkDevAddr.Height = 3 - n.HardforkDevAddr.OldAddress = types.Address{} - n.HardforkDevAddr.NewAddress = types.Address{} - - n.HardforkTax.Height = 10 - - n.HardforkStorageProof.Height = 10 - - n.HardforkOak.Height = 20 - n.HardforkOak.FixHeight = 23 - n.HardforkOak.GenesisTimestamp = time.Now().Add(-1e6 * time.Second) - - n.HardforkASIC.Height = 5 - n.HardforkASIC.OakTime = 10000 * time.Second - n.HardforkASIC.OakTarget = types.BlockID{255, 255} - - n.HardforkFoundation.Height = 50 - n.HardforkFoundation.PrimaryAddress = types.StandardUnlockHash(types.GeneratePrivateKey().PublicKey()) - n.HardforkFoundation.FailsafeAddress = types.StandardUnlockHash(types.GeneratePrivateKey().PublicKey()) - - // make it difficult to reach v2 in most tests - n.HardforkV2.AllowHeight = 1000 - n.HardforkV2.RequireHeight = 1020 - - return n, types.Block{ - Transactions: []types.Transaction{ - { - SiafundOutputs: []types.SiafundOutput{ - {Value: 2000, Address: types.Address{214, 166, 197, 164, 29, 201, 53, 236, 106, 239, 10, 158, 127, 131, 20, 138, 63, 221, 230, 16, 98, 247, 32, 77, 210, 68, 116, 12, 241, 89, 27, 223}}, - {Value: 7000, Address: types.Address{209, 246, 228, 60, 248, 78, 242, 110, 9, 8, 227, 248, 225, 216, 163, 52, 142, 93, 47, 176, 103, 41, 137, 80, 212, 8, 132, 58, 241, 189, 2, 17}}, - {Value: 1000, Address: types.VoidAddress}, - }, - }, - }, - } -} diff --git a/build/network_testnet.go b/build/network_testnet.go deleted file mode 100644 index 4bc62f6c..00000000 --- a/build/network_testnet.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build testnet - -package build - -import ( - "go.sia.tech/core/consensus" - "go.sia.tech/core/types" - "go.sia.tech/coreutils/chain" -) - -// Network returns the Sia network consts and genesis block for the current build. -func Network() (*consensus.Network, types.Block) { - return chain.TestnetZen() -} diff --git a/cmd/hostd/config.go b/cmd/hostd/config.go index 5bfa1502..c1285396 100644 --- a/cmd/hostd/config.go +++ b/cmd/hostd/config.go @@ -213,11 +213,11 @@ func setAdvancedConfig() { fmt.Println("It should not be exposed to the public internet without setting up a reverse proxy.") setListenAddress("HTTP Address", &cfg.HTTP.Address) - // gateway address + // syncer address fmt.Println("") - fmt.Println("The gateway address is used to exchange blocks with other nodes in the Sia network") + fmt.Println("The syncer address is used to exchange blocks with other nodes in the Sia network") fmt.Println("It should be exposed publicly to improve the host's connectivity.") - setListenAddress("Gateway Address", &cfg.Consensus.GatewayAddress) + setListenAddress("Gateway Address", &cfg.Syncer.Address) // rhp2 address fmt.Println("") @@ -264,7 +264,7 @@ func setDataDirectory() { func buildConfig() { // write the config file configPath := "hostd.yml" - if str := os.Getenv("HOSTD_CONFIG_FILE"); str != "" { + if str := os.Getenv(configFileEnvVar); str != "" { configPath = str } diff --git a/cmd/hostd/consts_default.go b/cmd/hostd/consts_default.go deleted file mode 100644 index 2539b451..00000000 --- a/cmd/hostd/consts_default.go +++ /dev/null @@ -1,20 +0,0 @@ -//go:build !testnet - -package main - -const ( - apiPasswordEnvVariable = "HOSTD_API_PASSWORD" - walletSeedEnvVariable = "HOSTD_SEED" - // logPathEnvVariable overrides the path of the log file. - // Deprecated: use logFileEnvVar instead. - logPathEnvVariable = "HOSTD_LOG_PATH" - // logFileEnvVariable overrides the location of the log file. - logFileEnvVariable = "HOSTD_LOG_FILE" - configPathEnvVariable = "HOSTD_CONFIG_FILE" - - defaultAPIAddr = "localhost:9980" - defaultGatewayAddr = ":9981" - defaultRHP2Addr = ":9982" - defaultRHP3TCPAddr = ":9983" - defaultRHP3WSAddr = ":9984" -) diff --git a/cmd/hostd/consts_testnet.go b/cmd/hostd/consts_testnet.go deleted file mode 100644 index 99f95819..00000000 --- a/cmd/hostd/consts_testnet.go +++ /dev/null @@ -1,20 +0,0 @@ -//go:build testnet - -package main - -const ( - apiPasswordEnvVariable = "HOSTD_ZEN_API_PASSWORD" - walletSeedEnvVariable = "HOSTD_ZEN_SEED" - // logPathEnvVariable overrides the path of the log file. - // Deprecated: use logFileEnvVar instead. - logPathEnvVariable = "HOSTD_ZEN_LOG_PATH" - // logFileEnvVariable overrides the location of the log file. - logFileEnvVariable = "HOSTD_ZEN_LOG_FILE" - configPathEnvVariable = "HOSTD_ZEN_CONFIG_FILE" - - defaultAPIAddr = "localhost:9880" - defaultGatewayAddr = ":9881" - defaultRHP2Addr = ":9882" - defaultRHP3TCPAddr = ":9883" - defaultRHP3WSAddr = ":9884" -) diff --git a/cmd/hostd/main.go b/cmd/hostd/main.go index 5d41bcf9..ed80ff42 100644 --- a/cmd/hostd/main.go +++ b/cmd/hostd/main.go @@ -2,67 +2,66 @@ package main import ( "context" - "errors" "flag" "fmt" - "io" - stdlog "log" - "net" - "net/http" "os" "os/exec" "os/signal" "path/filepath" "runtime" "syscall" - "time" "go.sia.tech/core/types" "go.sia.tech/coreutils/wallet" - "go.sia.tech/hostd/api" "go.sia.tech/hostd/build" "go.sia.tech/hostd/config" - "go.sia.tech/hostd/internal/explorer" "go.sia.tech/hostd/persist/sqlite" - "go.sia.tech/jape" - "go.sia.tech/web/hostd" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "golang.org/x/sys/cpu" "gopkg.in/yaml.v3" ) +const ( + walletSeedEnvVar = "HOSTD_WALLET_SEED" + apiPasswordEnvVar = "HOSTD_API_PASSWORD" + configFileEnvVar = "HOSTD_CONFIG_FILE" + logFileEnvVar = "HOSTD_LOG_FILE_PATH" +) + var ( cfg = config.Config{ - Directory: ".", // default to current directory - RecoveryPhrase: os.Getenv(walletSeedEnvVariable), // default to env variable + Directory: ".", // default to current directory + RecoveryPhrase: os.Getenv(walletSeedEnvVar), // default to env variable AutoOpenWebUI: true, HTTP: config.HTTP{ - Address: defaultAPIAddr, - Password: os.Getenv(apiPasswordEnvVariable), + Address: "127.0.0.1:9980", + Password: os.Getenv(apiPasswordEnvVar), }, Explorer: config.ExplorerData{ URL: "https://api.siascan.com", }, + Syncer: config.Syncer{ + Address: ":9981", + Bootstrap: true, + }, Consensus: config.Consensus{ - GatewayAddress: defaultGatewayAddr, - Bootstrap: true, + Network: "mainnet", + IndexBatchSize: 1000, }, RHP2: config.RHP2{ - Address: defaultRHP2Addr, + Address: ":9982", }, RHP3: config.RHP3{ - TCPAddress: defaultRHP3TCPAddr, - WebSocketAddress: defaultRHP3WSAddr, + TCPAddress: ":9983", }, Log: config.Log{ - Path: os.Getenv(logPathEnvVariable), // deprecated. included for compatibility. + Path: os.Getenv(logFileEnvVar), // deprecated. included for compatibility. Level: "info", File: config.LogFile{ Enabled: true, Format: "json", - Path: os.Getenv(logFileEnvVariable), + Path: os.Getenv(logFileEnvVar), }, StdOut: config.StdOut{ Enabled: true, @@ -75,35 +74,6 @@ var ( disableStdin bool ) -func startAPIListener(log *zap.Logger) (l net.Listener, err error) { - addr, port, err := net.SplitHostPort(cfg.HTTP.Address) - if err != nil { - return nil, fmt.Errorf("failed to parse API address: %w", err) - } - - // if the address is not localhost, listen on the address as-is - if addr != "localhost" { - return net.Listen("tcp", cfg.HTTP.Address) - } - - // localhost fails on some new installs of Windows 11, so try a few - // different addresses - tryAddresses := []string{ - net.JoinHostPort("localhost", port), // original address - net.JoinHostPort("127.0.0.1", port), // IPv4 loopback - net.JoinHostPort("::1", port), // IPv6 loopback - } - - for _, addr := range tryAddresses { - l, err = net.Listen("tcp", addr) - if err == nil { - return - } - log.Debug("failed to listen on fallback address", zap.String("address", addr), zap.Error(err)) - } - return -} - func openBrowser(url string) error { switch runtime.GOOS { case "linux": @@ -117,11 +87,11 @@ func openBrowser(url string) error { } } -// tryLoadConfig loads the config file specified by the HOSTD_CONFIG_PATH. If +// tryLoadConfig loads the config file specified by the HOSTD_CONFIG_FILE. If // the config file does not exist, it will not be loaded. func tryLoadConfig() { configPath := "hostd.yml" - if str := os.Getenv(configPathEnvVariable); str != "" { + if str := os.Getenv(configFileEnvVar); str != "" { configPath = str } @@ -149,7 +119,9 @@ func tryLoadConfig() { // jsonEncoder returns a zapcore.Encoder that encodes logs as JSON intended for // parsing. func jsonEncoder() zapcore.Encoder { - return zapcore.NewJSONEncoder(zap.NewProductionEncoderConfig()) + cfg := zap.NewProductionEncoderConfig() + cfg.EncodeTime = zapcore.RFC3339TimeEncoder + return zapcore.NewJSONEncoder(cfg) } // humanEncoder returns a zapcore.Encoder that encodes logs as human-readable @@ -198,13 +170,14 @@ func main() { flag.StringVar(&cfg.Directory, "dir", cfg.Directory, "directory to store hostd metadata") flag.BoolVar(&disableStdin, "env", false, "disable stdin prompts for environment variables (default false)") flag.BoolVar(&cfg.AutoOpenWebUI, "openui", cfg.AutoOpenWebUI, "automatically open the web UI on startup") + // syncer + flag.StringVar(&cfg.Syncer.Address, "syncer.address", cfg.Syncer.Address, "address to listen on for peer connections") + flag.BoolVar(&cfg.Syncer.Bootstrap, "syncer.bootstrap", cfg.Syncer.Bootstrap, "bootstrap the gateway and consensus modules") // consensus - flag.StringVar(&cfg.Consensus.GatewayAddress, "rpc", cfg.Consensus.GatewayAddress, "address to listen on for peer connections") - flag.BoolVar(&cfg.Consensus.Bootstrap, "bootstrap", cfg.Consensus.Bootstrap, "bootstrap the gateway and consensus modules") + flag.StringVar(&cfg.Consensus.Network, "network", cfg.Consensus.Network, "network name (mainnet, zen, etc)") // rhp flag.StringVar(&cfg.RHP2.Address, "rhp2", cfg.RHP2.Address, "address to listen on for RHP2 connections") - flag.StringVar(&cfg.RHP3.TCPAddress, "rhp3.tcp", cfg.RHP3.TCPAddress, "address to listen on for TCP RHP3 connections") - flag.StringVar(&cfg.RHP3.WebSocketAddress, "rhp3.ws", cfg.RHP3.WebSocketAddress, "address to listen on for WebSocket RHP3 connections") + flag.StringVar(&cfg.RHP3.TCPAddress, "rhp3", cfg.RHP3.TCPAddress, "address to listen on for TCP RHP3 connections") // http flag.StringVar(&cfg.HTTP.Address, "http", cfg.HTTP.Address, "address to serve API on") // log @@ -214,7 +187,6 @@ func main() { switch flag.Arg(0) { case "version": fmt.Println("hostd", build.Version()) - fmt.Println("Network", build.NetworkName()) fmt.Println("Commit:", build.Commit()) fmt.Println("Build Date:", build.Time()) return @@ -261,222 +233,125 @@ func main() { log.Fatal("failed to vacuum database", zap.Error(err)) } return - } - - ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) - defer cancel() + case "": + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() - // check that the API password is set - if cfg.HTTP.Password == "" { - if disableStdin { - stdoutFatalError("API password must be set via environment variable or config file when --env flag is set") - return + // check that the API password is set + if cfg.HTTP.Password == "" { + if disableStdin { + stdoutFatalError("API password must be set via environment variable or config file when --env flag is set") + return + } + setAPIPassword() } - setAPIPassword() - } - // check that the wallet seed is set - if cfg.RecoveryPhrase == "" { - if disableStdin { - stdoutFatalError("Wallet seed must be set via environment variable or config file when --env flag is set") - return + // check that the wallet seed is set + if cfg.RecoveryPhrase == "" { + if disableStdin { + stdoutFatalError("Wallet seed must be set via environment variable or config file when --env flag is set") + return + } + setSeedPhrase() } - setSeedPhrase() - } - - // create the data directory if it does not already exist - if err := os.MkdirAll(cfg.Directory, 0700); err != nil { - stdoutFatalError("unable to create config directory: " + err.Error()) - } - // configure the logger - if !cfg.Log.StdOut.Enabled && !cfg.Log.File.Enabled { - stdoutFatalError("At least one of stdout or file logging must be enabled") - return - } - - // normalize log level - if cfg.Log.Level == "" { - cfg.Log.Level = "info" - } - - var logCores []zapcore.Core - if cfg.Log.StdOut.Enabled { - // if no log level is set for stdout, use the global log level - if cfg.Log.StdOut.Level == "" { - cfg.Log.StdOut.Level = cfg.Log.Level + // create the data directory if it does not already exist + if err := os.MkdirAll(cfg.Directory, 0700); err != nil { + stdoutFatalError("unable to create config directory: " + err.Error()) } - var encoder zapcore.Encoder - switch cfg.Log.StdOut.Format { - case "json": - encoder = jsonEncoder() - default: // stdout defaults to human - encoder = humanEncoder(cfg.Log.StdOut.EnableANSI) + // configure the logger + if !cfg.Log.StdOut.Enabled && !cfg.Log.File.Enabled { + stdoutFatalError("At least one of stdout or file logging must be enabled") + return } - // create the stdout logger - level := parseLogLevel(cfg.Log.StdOut.Level) - logCores = append(logCores, zapcore.NewCore(encoder, zapcore.Lock(os.Stdout), level)) - } - - if cfg.Log.File.Enabled { - // if no log level is set for file, use the global log level - if cfg.Log.File.Level == "" { - cfg.Log.File.Level = cfg.Log.Level + // normalize log level + if cfg.Log.Level == "" { + cfg.Log.Level = "info" } - // normalize log path - if cfg.Log.File.Path == "" { - // If the log path is not set, try the deprecated log path. If that - // is also not set, default to hostd.log in the data directory. - if cfg.Log.Path != "" { - cfg.Log.File.Path = filepath.Join(cfg.Log.Path, "hostd.log") - } else { - cfg.Log.File.Path = filepath.Join(cfg.Directory, "hostd.log") + var logCores []zapcore.Core + if cfg.Log.StdOut.Enabled { + // if no log level is set for stdout, use the global log level + if cfg.Log.StdOut.Level == "" { + cfg.Log.StdOut.Level = cfg.Log.Level } - } - // configure file logging - var encoder zapcore.Encoder - switch cfg.Log.File.Format { - case "human": - encoder = humanEncoder(false) // disable colors in file log - default: // log file defaults to JSON - encoder = jsonEncoder() - } + var encoder zapcore.Encoder + switch cfg.Log.StdOut.Format { + case "json": + encoder = jsonEncoder() + default: // stdout defaults to human + encoder = humanEncoder(cfg.Log.StdOut.EnableANSI) + } - fileWriter, closeFn, err := zap.Open(cfg.Log.File.Path) - if err != nil { - stdoutFatalError("failed to open log file: " + err.Error()) - return + // create the stdout logger + level := parseLogLevel(cfg.Log.StdOut.Level) + logCores = append(logCores, zapcore.NewCore(encoder, zapcore.Lock(os.Stdout), level)) } - defer closeFn() - - // create the file logger - level := parseLogLevel(cfg.Log.File.Level) - logCores = append(logCores, zapcore.NewCore(encoder, zapcore.Lock(fileWriter), level)) - } - - var log *zap.Logger - if len(logCores) == 1 { - log = zap.New(logCores[0], zap.AddCaller()) - } else { - log = zap.New(zapcore.NewTee(logCores...), zap.AddCaller()) - } - defer log.Sync() - - // redirect stdlib log to zap - zap.RedirectStdLog(log.Named("stdlib")) - log.Info("hostd", zap.String("version", build.Version()), zap.String("network", build.NetworkName()), zap.String("commit", build.Commit()), zap.Time("buildDate", build.Time())) - - var seed [32]byte - if err := wallet.SeedFromPhrase(&seed, cfg.RecoveryPhrase); err != nil { - log.Fatal("failed to load wallet", zap.Error(err)) - } - walletKey := wallet.KeyFromSeed(&seed, 0) + if cfg.Log.File.Enabled { + // if no log level is set for file, use the global log level + if cfg.Log.File.Level == "" { + cfg.Log.File.Level = cfg.Log.Level + } - apiListener, err := startAPIListener(log) - if err != nil { - log.Fatal("failed to listen on API address", zap.Error(err), zap.String("address", cfg.HTTP.Address)) - } - defer apiListener.Close() + // normalize log path + if cfg.Log.File.Path == "" { + // If the log path is not set, try the deprecated log path. If that + // is also not set, default to hostd.log in the data directory. + if cfg.Log.Path != "" { + cfg.Log.File.Path = filepath.Join(cfg.Log.Path, "hostd.log") + } else { + cfg.Log.File.Path = filepath.Join(cfg.Directory, "hostd.log") + } + } - rhp3WSListener, err := net.Listen("tcp", cfg.RHP3.WebSocketAddress) - if err != nil { - log.Fatal("failed to listen on RHP3 WebSocket address", zap.Error(err), zap.String("address", cfg.RHP3.WebSocketAddress)) - } - defer rhp3WSListener.Close() + // configure file logging + var encoder zapcore.Encoder + switch cfg.Log.File.Format { + case "human": + encoder = humanEncoder(false) // disable colors in file log + default: // log file defaults to JSON + encoder = jsonEncoder() + } - var ex *explorer.Explorer - if !cfg.Explorer.Disable { - ex = explorer.New(cfg.Explorer.URL) - ctx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() + fileWriter, closeFn, err := zap.Open(cfg.Log.File.Path) + if err != nil { + stdoutFatalError("failed to open log file: " + err.Error()) + return + } + defer closeFn() - if _, err := ex.SiacoinExchangeRate(ctx, "usd"); err != nil { - log.Error("failed to get exchange rate. explorer features may not work correctly", zap.Error(err)) + // create the file logger + level := parseLogLevel(cfg.Log.File.Level) + logCores = append(logCores, zapcore.NewCore(encoder, zapcore.Lock(fileWriter), level)) } - } - - node, hostKey, err := newNode(ctx, walletKey, ex, log) - if err != nil { - log.Fatal("failed to create node", zap.Error(err)) - } - defer node.Close() - - opts := []api.ServerOption{ - api.ServerWithAlerts(node.a), - api.ServerWithWebHooks(node.wh), - api.ServerWithSyncer(node.g), - api.ServerWithChainManager(node.cm), - api.ServerWithTransactionPool(node.tp), - api.ServerWithContractManager(node.contracts), - api.ServerWithAccountManager(node.accounts), - api.ServerWithVolumeManager(node.storage), - api.ServerWithRHPSessionReporter(node.sessions), - api.ServerWithMetricManager(node.metrics), - api.ServerWithSettings(node.settings), - api.ServerWithWallet(node.w), - api.ServerWithLogger(log.Named("api")), - } - - if !cfg.Explorer.Disable { - opts = append(opts, api.ServerWithExplorer(ex)) - opts = append(opts, api.ServerWithPinnedSettings(node.pinned)) - } - - auth := jape.BasicAuth(cfg.HTTP.Password) - web := http.Server{ - Handler: webRouter{ - api: auth(api.NewServer(cfg.Name, hostKey.PublicKey(), opts...)), - ui: hostd.Handler(), - }, - ReadTimeout: 30 * time.Second, - } - defer web.Close() - rhp3WS := http.Server{ - Handler: node.rhp3.WebSocketHandler(), - ReadTimeout: 30 * time.Second, - TLSConfig: node.settings.RHP3TLSConfig(), - ErrorLog: stdlog.New(io.Discard, "", 0), - } - defer rhp3WS.Close() - - go func() { - err := rhp3WS.ServeTLS(rhp3WSListener, "", "") - if err != nil && !errors.Is(err, http.ErrServerClosed) { - log.Error("failed to serve rhp3 websocket", zap.Error(err)) + var log *zap.Logger + if len(logCores) == 1 { + log = zap.New(logCores[0], zap.AddCaller()) + } else { + log = zap.New(zapcore.NewTee(logCores...), zap.AddCaller()) } - }() + defer log.Sync() - log.Info("hostd started", zap.String("hostKey", hostKey.PublicKey().String()), zap.String("api", apiListener.Addr().String()), zap.String("p2p", string(node.g.Address())), zap.String("rhp2", node.rhp2.LocalAddr()), zap.String("rhp3", node.rhp3.LocalAddr())) - if runtime.GOARCH == "amd64" && !cpu.X86.HasAVX2 { - log.Warn("hostd is running on a system without AVX2 support, performance may be degraded") - } + // redirect stdlib log to zap + zap.RedirectStdLog(log.Named("stdlib")) + + log.Info("hostd", zap.String("version", build.Version()), zap.String("network", cfg.Consensus.Network), zap.String("commit", build.Commit()), zap.Time("buildDate", build.Time())) - go func() { - err := web.Serve(apiListener) - if err != nil && !errors.Is(err, http.ErrServerClosed) { - log.Error("failed to serve web", zap.Error(err)) + var seed [32]byte + if err := wallet.SeedFromPhrase(&seed, cfg.RecoveryPhrase); err != nil { + log.Fatal("failed to load wallet", zap.Error(err)) } - }() + walletKey := wallet.KeyFromSeed(&seed, 0) - if cfg.AutoOpenWebUI { - time.Sleep(time.Millisecond) // give the web server a chance to start - _, port, err := net.SplitHostPort(apiListener.Addr().String()) - if err != nil { - log.Debug("failed to parse API address", zap.Error(err)) - } else if err := openBrowser(fmt.Sprintf("http://127.0.0.1:%s", port)); err != nil { - log.Debug("failed to open browser", zap.Error(err)) + if err := runNode(ctx, cfg, walletKey, log); err != nil { + log.Error("failed to start node", zap.Error(err)) } + default: + stdoutFatalError("unknown command: " + flag.Arg(0)) } - - <-ctx.Done() - log.Info("shutting down...") - time.AfterFunc(5*time.Minute, func() { - log.Fatal("failed to shut down within 5 minutes") - }) } diff --git a/cmd/hostd/node.go b/cmd/hostd/node.go deleted file mode 100644 index 1a3c07bf..00000000 --- a/cmd/hostd/node.go +++ /dev/null @@ -1,249 +0,0 @@ -package main - -import ( - "context" - "fmt" - "net" - "os" - "path/filepath" - "strings" - - "go.sia.tech/core/types" - "go.sia.tech/hostd/alerts" - "go.sia.tech/hostd/host/accounts" - "go.sia.tech/hostd/host/contracts" - "go.sia.tech/hostd/host/metrics" - "go.sia.tech/hostd/host/registry" - "go.sia.tech/hostd/host/settings" - "go.sia.tech/hostd/host/settings/pin" - "go.sia.tech/hostd/host/storage" - "go.sia.tech/hostd/internal/chain" - "go.sia.tech/hostd/internal/explorer" - "go.sia.tech/hostd/persist/sqlite" - "go.sia.tech/hostd/rhp" - rhp2 "go.sia.tech/hostd/rhp/v2" - rhp3 "go.sia.tech/hostd/rhp/v3" - "go.sia.tech/hostd/wallet" - "go.sia.tech/hostd/webhooks" - "go.sia.tech/siad/modules" - "go.sia.tech/siad/modules/consensus" - "go.sia.tech/siad/modules/gateway" - "go.sia.tech/siad/modules/transactionpool" - "go.uber.org/zap" -) - -type node struct { - g modules.Gateway - a *alerts.Manager - wh *webhooks.Manager - cm *chain.Manager - tp *chain.TransactionPool - w *wallet.SingleAddressWallet - store *sqlite.Store - - metrics *metrics.MetricManager - settings *settings.ConfigManager - pinned *pin.Manager - accounts *accounts.AccountManager - contracts *contracts.ContractManager - registry *registry.Manager - storage *storage.VolumeManager - - sessions *rhp.SessionReporter - data *rhp.DataRecorder - rhp2 *rhp2.SessionHandler - rhp3 *rhp3.SessionHandler -} - -func (n *node) Close() { - n.rhp3.Close() - n.rhp2.Close() - n.data.Close() - n.registry.Close() - n.storage.Close() - n.contracts.Close() - n.w.Close() - n.tp.Close() - n.cm.Close() - n.g.Close() - n.wh.Close() - n.store.Close() -} - -func startRHP2(l net.Listener, hostKey types.PrivateKey, rhp3Addr string, cs rhp2.ChainManager, tp rhp2.TransactionPool, w rhp2.Wallet, cm rhp2.ContractManager, sr rhp2.SettingsReporter, sm rhp2.StorageManager, monitor rhp.DataMonitor, sessions *rhp.SessionReporter, log *zap.Logger) (*rhp2.SessionHandler, error) { - rhp2, err := rhp2.NewSessionHandler(l, hostKey, rhp3Addr, cs, tp, w, cm, sr, sm, monitor, sessions, log) - if err != nil { - return nil, err - } - go rhp2.Serve() - return rhp2, nil -} - -func startRHP3(l net.Listener, hostKey types.PrivateKey, cs rhp3.ChainManager, tp rhp3.TransactionPool, w rhp3.Wallet, am rhp3.AccountManager, cm rhp3.ContractManager, rm rhp3.RegistryManager, sr rhp3.SettingsReporter, sm rhp3.StorageManager, monitor rhp.DataMonitor, sessions *rhp.SessionReporter, log *zap.Logger) (*rhp3.SessionHandler, error) { - rhp3, err := rhp3.NewSessionHandler(l, hostKey, cs, tp, w, am, cm, rm, sm, sr, monitor, sessions, log) - if err != nil { - return nil, err - } - go rhp3.Serve() - return rhp3, nil -} - -func newNode(ctx context.Context, walletKey types.PrivateKey, ex *explorer.Explorer, logger *zap.Logger) (*node, types.PrivateKey, error) { - gatewayDir := filepath.Join(cfg.Directory, "gateway") - if err := os.MkdirAll(gatewayDir, 0700); err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to create gateway dir: %w", err) - } - g, err := gateway.NewCustomGateway(cfg.Consensus.GatewayAddress, cfg.Consensus.Bootstrap, false, gatewayDir, modules.ProdDependencies) - if err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to create gateway: %w", err) - } - - // connect to additional peers from the config file - go func() { - for _, peer := range cfg.Consensus.Peers { - g.Connect(modules.NetAddress(peer)) - } - }() - - consensusDir := filepath.Join(cfg.Directory, "consensus") - if err := os.MkdirAll(consensusDir, 0700); err != nil { - return nil, types.PrivateKey{}, err - } - cs, errCh := consensus.New(g, cfg.Consensus.Bootstrap, consensusDir) - select { - case err := <-errCh: - if err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to create consensus: %w", err) - } - default: - go func() { - if err := <-errCh; err != nil && !strings.Contains(err.Error(), "ThreadGroup already stopped") { - logger.Warn("consensus subscribe error", zap.Error(err)) - } - }() - } - tpoolDir := filepath.Join(cfg.Directory, "tpool") - if err := os.MkdirAll(tpoolDir, 0700); err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to create tpool dir: %w", err) - } - stp, err := transactionpool.New(cs, g, tpoolDir) - if err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to create tpool: %w", err) - } - tp := chain.NewTPool(stp) - - db, err := sqlite.OpenDatabase(filepath.Join(cfg.Directory, "hostd.db"), logger.Named("sqlite")) - if err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to create sqlite store: %w", err) - } - - // load the host identity - hostKey := db.HostKey() - - cm, err := chain.NewManager(cs, tp) - if err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to create chain manager: %w", err) - } - - w, err := wallet.NewSingleAddressWallet(walletKey, cm, db, logger.Named("wallet")) - if err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to create wallet: %w", err) - } - - webhookReporter, err := webhooks.NewManager(db, logger.Named("webhooks")) - if err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to create webhook reporter: %w", err) - } - - rhp2Listener, err := net.Listen("tcp", cfg.RHP2.Address) - if err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to listen on rhp2 addr: %w", err) - } - - rhp3Listener, err := net.Listen("tcp", cfg.RHP3.TCPAddress) - if err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to listen on rhp3 addr: %w", err) - } - - _, rhp2Port, err := net.SplitHostPort(cfg.RHP2.Address) - if err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to parse rhp2 addr: %w", err) - } - discoveredAddr := net.JoinHostPort(g.Address().Host(), rhp2Port) - logger.Debug("discovered address", zap.String("addr", discoveredAddr)) - - am := alerts.NewManager(webhookReporter, logger.Named("alerts")) - sr, err := settings.NewConfigManager(settings.WithHostKey(hostKey), - settings.WithStore(db), - settings.WithChainManager(cm), - settings.WithTransactionPool(tp), - settings.WithWallet(w), - settings.WithAlertManager(am), - settings.WithLog(logger.Named("settings"))) - if err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to create settings manager: %w", err) - } - - var pm *pin.Manager - if !cfg.Explorer.Disable { - pm, err = pin.NewManager( - pin.WithAlerts(am), - pin.WithStore(db), - pin.WithSettings(sr), - pin.WithExchangeRateRetriever(ex), - pin.WithLogger(logger.Named("pin"))) - if err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to create pin manager: %w", err) - } - go pm.Run(ctx) - } - - accountManager := accounts.NewManager(db, sr) - - sm, err := storage.NewVolumeManager(db, am, cm, logger.Named("volumes"), sr.Settings().SectorCacheSize) - if err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to create storage manager: %w", err) - } - - contractManager, err := contracts.NewManager(db, am, sm, cm, tp, w, logger.Named("contracts")) - if err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to create contract manager: %w", err) - } - registryManager := registry.NewManager(hostKey, db, logger.Named("registry")) - - sessions := rhp.NewSessionReporter() - - dm := rhp.NewDataRecorder(db, logger.Named("data")) - rhp2, err := startRHP2(rhp2Listener, hostKey, rhp3Listener.Addr().String(), cm, tp, w, contractManager, sr, sm, dm, sessions, logger.Named("rhp2")) - if err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to start rhp2: %w", err) - } - - rhp3, err := startRHP3(rhp3Listener, hostKey, cm, tp, w, accountManager, contractManager, registryManager, sr, sm, dm, sessions, logger.Named("rhp3")) - if err != nil { - return nil, types.PrivateKey{}, fmt.Errorf("failed to start rhp3: %w", err) - } - - return &node{ - g: g, - a: am, - wh: webhookReporter, - cm: cm, - tp: tp, - w: w, - store: db, - - metrics: metrics.NewManager(db), - settings: sr, - pinned: pm, - accounts: accountManager, - contracts: contractManager, - storage: sm, - registry: registryManager, - - sessions: sessions, - data: dm, - rhp2: rhp2, - rhp3: rhp3, - }, hostKey, nil -} diff --git a/cmd/hostd/run.go b/cmd/hostd/run.go new file mode 100644 index 00000000..680b261e --- /dev/null +++ b/cmd/hostd/run.go @@ -0,0 +1,340 @@ +package main + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "os" + "path/filepath" + "strconv" + "time" + + "go.sia.tech/core/consensus" + "go.sia.tech/core/gateway" + "go.sia.tech/core/types" + "go.sia.tech/coreutils" + "go.sia.tech/coreutils/chain" + "go.sia.tech/coreutils/syncer" + "go.sia.tech/coreutils/wallet" + "go.sia.tech/hostd/alerts" + "go.sia.tech/hostd/api" + "go.sia.tech/hostd/config" + "go.sia.tech/hostd/explorer" + "go.sia.tech/hostd/host/accounts" + "go.sia.tech/hostd/host/contracts" + "go.sia.tech/hostd/host/registry" + "go.sia.tech/hostd/host/settings" + "go.sia.tech/hostd/host/settings/pin" + "go.sia.tech/hostd/host/storage" + "go.sia.tech/hostd/index" + "go.sia.tech/hostd/persist/sqlite" + "go.sia.tech/hostd/rhp" + rhp2 "go.sia.tech/hostd/rhp/v2" + rhp3 "go.sia.tech/hostd/rhp/v3" + "go.sia.tech/hostd/webhooks" + "go.sia.tech/jape" + "go.sia.tech/web/hostd" + "go.uber.org/zap" + "lukechampine.com/upnp" +) + +func setupUPNP(ctx context.Context, port uint16, log *zap.Logger) (string, error) { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + d, err := upnp.Discover(ctx) + if err != nil { + return "", fmt.Errorf("couldn't discover UPnP router: %w", err) + } else if !d.IsForwarded(port, "TCP") { + if err := d.Forward(uint16(port), "TCP", "walletd"); err != nil { + log.Debug("couldn't forward port", zap.Error(err)) + } else { + log.Debug("upnp: forwarded p2p port", zap.Uint16("port", port)) + } + } + return d.ExternalIP() +} + +// migrateDB renames hostd.db to hostd.sqlite3 to be explicit about the store +func migrateDB(dir string) error { + oldPath := filepath.Join(dir, "hostd.db") + newPath := filepath.Join(dir, "hostd.sqlite3") + if _, err := os.Stat(oldPath); errors.Is(err, os.ErrNotExist) { + return nil + } else if err != nil { + return err + } + return os.Rename(oldPath, newPath) +} + +// deleteSiadData deletes the siad specific databases if they exist +func deleteSiadData(dir string) error { + paths := []string{ + filepath.Join(dir, "consensus"), + filepath.Join(dir, "gateway"), + filepath.Join(dir, "tpool"), + } + + for _, path := range paths { + if dir, err := os.Stat(path); errors.Is(err, os.ErrNotExist) { + continue + } else if err != nil { + return err + } else if !dir.IsDir() { + return fmt.Errorf("expected %s to be a directory", path) + } + + if err := os.RemoveAll(path); err != nil { + return fmt.Errorf("failed to delete %s: %w", path, err) + } + } + return nil +} + +// startLocalhostListener https://github.com/SiaFoundation/hostd/issues/202 +func startLocalhostListener(listenAddr string, log *zap.Logger) (l net.Listener, err error) { + addr, port, err := net.SplitHostPort(listenAddr) + if err != nil { + return nil, fmt.Errorf("failed to parse API address: %w", err) + } + + // if the address is not localhost, listen on the address as-is + if addr != "localhost" { + return net.Listen("tcp", listenAddr) + } + + // localhost fails on some new installs of Windows 11, so try a few + // different addresses + tryAddresses := []string{ + net.JoinHostPort("localhost", port), // original address + net.JoinHostPort("127.0.0.1", port), // IPv4 loopback + net.JoinHostPort("::1", port), // IPv6 loopback + } + + for _, addr := range tryAddresses { + l, err = net.Listen("tcp", addr) + if err == nil { + return + } + log.Debug("failed to listen on fallback address", zap.String("address", addr), zap.Error(err)) + } + return +} + +func runNode(ctx context.Context, cfg config.Config, walletKey types.PrivateKey, log *zap.Logger) error { + if err := migrateDB(cfg.Directory); err != nil { + return fmt.Errorf("failed to migrate database: %w", err) + } else if err := deleteSiadData(cfg.Directory); err != nil { + return fmt.Errorf("failed to migrate v1 consensus database: %w", err) + } + + store, err := sqlite.OpenDatabase(filepath.Join(cfg.Directory, "hostd.sqlite3"), log.Named("sqlite3")) + if err != nil { + return fmt.Errorf("failed to open database: %w", err) + } + defer store.Close() + + // load the host identity + hostKey := store.HostKey() + + var network *consensus.Network + var genesisBlock types.Block + switch cfg.Consensus.Network { + case "mainnet": + network, genesisBlock = chain.Mainnet() + if cfg.Syncer.Bootstrap { + cfg.Syncer.Peers = append(cfg.Syncer.Peers, syncer.MainnetBootstrapPeers...) + } + case "zen": + network, genesisBlock = chain.TestnetZen() + if cfg.Syncer.Bootstrap { + cfg.Syncer.Peers = append(cfg.Syncer.Peers, syncer.ZenBootstrapPeers...) + } + default: + return errors.New("invalid network: must be one of 'mainnet' or 'zen'") + } + + bdb, err := coreutils.OpenBoltChainDB(filepath.Join(cfg.Directory, "consensus.db")) + if err != nil { + return fmt.Errorf("failed to open consensus database: %w", err) + } + defer bdb.Close() + + dbstore, tipState, err := chain.NewDBStore(bdb, network, genesisBlock) + if err != nil { + return fmt.Errorf("failed to create chain store: %w", err) + } + cm := chain.NewManager(dbstore, tipState) + + httpListener, err := startLocalhostListener(cfg.HTTP.Address, log.Named("listener")) + if err != nil { + return fmt.Errorf("failed to listen on http address: %w", err) + } + defer httpListener.Close() + + syncerListener, err := net.Listen("tcp", cfg.Syncer.Address) + if err != nil { + return fmt.Errorf("failed to listen on syncer address: %w", err) + } + defer syncerListener.Close() + + rhp2Listener, err := net.Listen("tcp", cfg.RHP2.Address) + if err != nil { + return fmt.Errorf("failed to listen on rhp2 addr: %w", err) + } + defer rhp2Listener.Close() + + rhp3Listener, err := net.Listen("tcp", cfg.RHP3.TCPAddress) + if err != nil { + return fmt.Errorf("failed to listen on rhp3 addr: %w", err) + } + defer rhp3Listener.Close() + + syncerAddr := syncerListener.Addr().String() + if cfg.Syncer.EnableUPnP { + _, portStr, _ := net.SplitHostPort(cfg.Syncer.Address) + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return fmt.Errorf("failed to parse syncer port: %w", err) + } + + ip, err := setupUPNP(context.Background(), uint16(port), log) + if err != nil { + log.Warn("failed to set up UPnP", zap.Error(err)) + } else { + syncerAddr = net.JoinHostPort(ip, portStr) + } + } + // peers will reject us if our hostname is empty or unspecified, so use loopback + host, port, _ := net.SplitHostPort(syncerAddr) + if ip := net.ParseIP(host); ip == nil || ip.IsUnspecified() { + syncerAddr = net.JoinHostPort("127.0.0.1", port) + } + + ps, err := sqlite.NewPeerStore(store) + if err != nil { + return fmt.Errorf("failed to create peer store: %w", err) + } + for _, peer := range cfg.Syncer.Peers { + if err := ps.AddPeer(peer); err != nil { + log.Warn("failed to add peer", zap.String("address", peer), zap.Error(err)) + } + } + + log.Debug("starting syncer", zap.String("syncer address", syncerAddr)) + s := syncer.New(syncerListener, cm, ps, gateway.Header{ + GenesisID: genesisBlock.ID(), + UniqueID: gateway.GenerateUniqueID(), + NetAddress: syncerAddr, + }, syncer.WithLogger(log.Named("syncer"))) + go s.Run(ctx) + defer s.Close() + + wm, err := wallet.NewSingleAddressWallet(walletKey, cm, store, wallet.WithLogger(log.Named("wallet")), wallet.WithReservationDuration(3*time.Hour)) + if err != nil { + return fmt.Errorf("failed to create wallet: %w", err) + } + defer wm.Close() + + wr, err := webhooks.NewManager(store, log.Named("webhooks")) + if err != nil { + return fmt.Errorf("failed to create webhook reporter: %w", err) + } + defer wr.Close() + sr := rhp.NewSessionReporter() + + am := alerts.NewManager(alerts.WithEventReporter(wr), alerts.WithLog(log.Named("alerts"))) + + cfm, err := settings.NewConfigManager(hostKey, store, cm, s, wm, settings.WithAlertManager(am), settings.WithLog(log.Named("settings"))) + if err != nil { + return fmt.Errorf("failed to create settings manager: %w", err) + } + defer cfm.Close() + + vm, err := storage.NewVolumeManager(store, storage.WithLogger(log.Named("volumes")), storage.WithAlerter(am)) + if err != nil { + return fmt.Errorf("failed to create storage manager: %w", err) + } + defer vm.Close() + + contractManager, err := contracts.NewManager(store, vm, cm, s, wm, contracts.WithLog(log.Named("contracts")), contracts.WithAlerter(am)) + if err != nil { + return fmt.Errorf("failed to create contracts manager: %w", err) + } + defer contractManager.Close() + + index, err := index.NewManager(store, cm, contractManager, wm, cfm, vm, index.WithLog(log.Named("index")), index.WithBatchSize(cfg.Consensus.IndexBatchSize)) + if err != nil { + return fmt.Errorf("failed to create index manager: %w", err) + } + defer index.Close() + + dr := rhp.NewDataRecorder(store, log.Named("data")) + + rhp2, err := rhp2.NewSessionHandler(rhp2Listener, hostKey, rhp3Listener.Addr().String(), cm, s, wm, contractManager, cfm, vm, rhp2.WithDataMonitor(dr), rhp2.WithLog(log.Named("rhp2"))) + if err != nil { + return fmt.Errorf("failed to create rhp2 session handler: %w", err) + } + go rhp2.Serve() + defer rhp2.Close() + + registry := registry.NewManager(hostKey, store, log.Named("registry")) + accounts := accounts.NewManager(store, cfm) + rhp3, err := rhp3.NewSessionHandler(rhp3Listener, hostKey, cm, s, wm, accounts, contractManager, registry, vm, cfm, rhp3.WithDataMonitor(dr), rhp3.WithSessionReporter(sr), rhp3.WithLog(log.Named("rhp3"))) + if err != nil { + return fmt.Errorf("failed to create rhp3 session handler: %w", err) + } + go rhp3.Serve() + defer rhp3.Close() + + apiOpts := []api.ServerOption{ + api.ServerWithAlerts(am), + api.ServerWithLogger(log.Named("api")), + api.ServerWithRHPSessionReporter(sr), + api.ServerWithWebhooks(wr), + } + if !cfg.Explorer.Disable { + ex := explorer.New(cfg.Explorer.URL) + pm, err := pin.NewManager(store, cfm, ex, pin.WithLogger(log.Named("pin"))) + if err != nil { + return fmt.Errorf("failed to create pin manager: %w", err) + } + + apiOpts = append(apiOpts, api.ServerWithPinnedSettings(pm), api.ServerWithExplorer(ex)) + } + + web := http.Server{ + Handler: webRouter{ + api: jape.BasicAuth(cfg.HTTP.Password)(api.NewServer(cfg.Name, hostKey.PublicKey(), cm, s, accounts, contractManager, vm, wm, store, cfm, index, apiOpts...)), + ui: hostd.Handler(), + }, + ReadTimeout: 30 * time.Second, + } + defer web.Close() + + go func() { + log.Debug("starting http server", zap.String("address", cfg.HTTP.Address)) + if err := web.Serve(httpListener); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Error("http server failed", zap.Error(err)) + } + }() + + if cfg.AutoOpenWebUI { + time.Sleep(time.Millisecond) // give the web server a chance to start + _, port, err := net.SplitHostPort(httpListener.Addr().String()) + if err != nil { + log.Debug("failed to parse API address", zap.Error(err)) + } else if err := openBrowser(fmt.Sprintf("http://127.0.0.1:%s", port)); err != nil { + log.Debug("failed to open browser", zap.Error(err)) + } + } + + log.Info("node started", zap.String("network", cm.TipState().Network.Name), zap.String("hostKey", hostKey.PublicKey().String()), zap.String("http", httpListener.Addr().String()), zap.String("p2p", string(s.Addr())), zap.String("rhp2", rhp2.LocalAddr()), zap.String("rhp3", rhp3.LocalAddr())) + <-ctx.Done() + log.Info("shutting down...") + time.AfterFunc(5*time.Minute, func() { + log.Fatal("failed to shut down within 5 minutes") + }) + return nil +} diff --git a/config/config.go b/config/config.go index 847912e0..38576cc8 100644 --- a/config/config.go +++ b/config/config.go @@ -7,11 +7,18 @@ type ( Password string `yaml:"password,omitempty"` } + // Syncer contains the configuration for the p2p syncer. + Syncer struct { + Address string `yaml:"address,omitempty"` + Bootstrap bool `yaml:"bootstrap,omitempty"` + EnableUPnP bool `yaml:"enableUPnP,omitempty"` + Peers []string `yaml:"peers,omitempty"` + } + // Consensus contains the configuration for the consensus set. Consensus struct { - GatewayAddress string `yaml:"gatewayAddress,omitempty"` - Bootstrap bool `yaml:"bootstrap,omitempty"` - Peers []string `yaml:"peers,omitempty"` + Network string `yaml:"network,omitempty"` + IndexBatchSize int `yaml:"indexBatchSize,omitempty"` } // RHP2 contains the configuration for the RHP2 server. @@ -27,10 +34,7 @@ type ( // RHP3 contains the configuration for the RHP3 server. RHP3 struct { - TCPAddress string `yaml:"tcp,omitempty"` - WebSocketAddress string `yaml:"websocket,omitempty"` - CertPath string `yaml:"certPath,omitempty"` - KeyPath string `yaml:"keyPath,omitempty"` + TCPAddress string `yaml:"tcp,omitempty"` } // LogFile configures the file output of the logger. @@ -68,6 +72,7 @@ type ( AutoOpenWebUI bool `yaml:"autoOpenWebUI,omitempty"` HTTP HTTP `yaml:"http,omitempty"` + Syncer Syncer `yaml:"syncer,omitempty"` Consensus Consensus `yaml:"consensus,omitempty"` Explorer ExplorerData `yaml:"explorer,omitempty"` RHP2 RHP2 `yaml:"rhp2,omitempty"` diff --git a/docker/Dockerfile.testnet b/docker/Dockerfile.testnet deleted file mode 100644 index 7bda2e80..00000000 --- a/docker/Dockerfile.testnet +++ /dev/null @@ -1,46 +0,0 @@ -FROM docker.io/library/golang:1.21 AS builder - -WORKDIR /hostd - -# get dependencies -COPY go.mod go.sum ./ -RUN go mod download - -# copy source -COPY . . -# codegen -RUN go generate ./... -# build -RUN CGO_ENABLED=1 go build -o bin/ -tags='netgo timetzdata testnet' -trimpath -a -ldflags '-s -w -linkmode external -extldflags "-static"' ./cmd/hostd - -FROM docker.io/library/alpine:3 -LABEL maintainer="The Sia Foundation " \ - org.opencontainers.image.description.vendor="The Sia Foundation" \ - org.opencontainers.image.description="A hostd container - provide storage on the Sia Zen testnet and earn testnet Siacoin" \ - org.opencontainers.image.source="https://github.com/SiaFoundation/hostd" \ - org.opencontainers.image.licenses=MIT - -ENV PUID=0 -ENV PGID=0 - -ENV HOSTD_ZEN_API_PASSWORD= -ENV HOSTD_ZEN_SEED= -ENV HOSTD_ZEN_CONFIG_FILE=/data/hostd.yml - -COPY --from=builder /hostd/bin/* /usr/bin/ -VOLUME [ "/data" ] - -# API port -EXPOSE 9880/tcp -# RPC port -EXPOSE 9881/tcp -# RHP2 port -EXPOSE 9882/tcp -# RHP3 TCP port -EXPOSE 9883/tcp -# RHP3 WebSocket port -EXPOSE 9884/tcp - -USER ${PUID}:${PGID} - -ENTRYPOINT [ "hostd", "--env", "--dir", "/data", "--http", ":9880" ] \ No newline at end of file diff --git a/internal/explorer/explorer.go b/explorer/explorer.go similarity index 100% rename from internal/explorer/explorer.go rename to explorer/explorer.go diff --git a/go.mod b/go.mod index 23a5cf1a..6710c99c 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module go.sia.tech/hostd -go 1.21.8 - -toolchain go1.22.3 +go 1.23.0 require ( github.com/aws/aws-sdk-go v1.55.5 @@ -10,12 +8,10 @@ require ( github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/mattn/go-sqlite3 v1.14.22 github.com/shopspring/decimal v1.4.0 - gitlab.com/NebulousLabs/encoding v0.0.0-20200604091946-456c3dc907fe - go.sia.tech/core v0.4.4 - go.sia.tech/coreutils v0.3.0 - go.sia.tech/jape v0.12.0 - go.sia.tech/siad v1.5.10-0.20230228235644-3059c0b930ca - go.sia.tech/web/hostd v0.45.1 + go.sia.tech/core v0.4.5-0.20240831170056-a91ba45f50d6 + go.sia.tech/coreutils v0.3.1-0.20240831170413-70443a42b7b2 + go.sia.tech/jape v0.12.1 + go.sia.tech/web/hostd v0.46.0 go.uber.org/goleak v1.3.0 go.uber.org/zap v1.27.0 golang.org/x/sys v0.24.0 @@ -23,35 +19,19 @@ require ( golang.org/x/time v0.6.0 gopkg.in/yaml.v3 v3.0.1 lukechampine.com/frand v1.4.2 - nhooyr.io/websocket v1.8.17 + lukechampine.com/upnp v0.3.0 + nhooyr.io/websocket v1.8.11 ) require ( github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect - github.com/dchest/threefish v0.0.0-20120919164726-3ecf4c494abf // indirect github.com/goccy/go-json v0.10.3 // indirect github.com/google/go-querystring v1.1.0 // indirect - github.com/gorilla/websocket v1.5.2 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/julienschmidt/httprouter v1.3.0 // indirect - github.com/klauspost/cpuid/v2 v2.2.8 // indirect - github.com/klauspost/reedsolomon v1.12.1 // indirect github.com/kr/pretty v0.3.1 // indirect - github.com/pkg/errors v0.9.1 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect - gitlab.com/NebulousLabs/bolt v1.4.4 // indirect - gitlab.com/NebulousLabs/demotemutex v0.0.0-20151003192217-235395f71c40 // indirect - gitlab.com/NebulousLabs/entropy-mnemonics v0.0.0-20181018051301-7532f67e3500 // indirect - gitlab.com/NebulousLabs/errors v0.0.0-20200929122200-06c536cf6975 // indirect - gitlab.com/NebulousLabs/fastrand v0.0.0-20181126182046-603482d69e40 // indirect - gitlab.com/NebulousLabs/go-upnp v0.0.0-20211002182029-11da932010b6 // indirect - gitlab.com/NebulousLabs/log v0.0.0-20210609172545-77f6775350e2 // indirect - gitlab.com/NebulousLabs/merkletree v0.0.0-20200118113624-07fbf710afc4 // indirect - gitlab.com/NebulousLabs/monitor v0.0.0-20191205095550-2b0fd3e1012a // indirect - gitlab.com/NebulousLabs/persist v0.0.0-20200605115618-007e5e23d877 // indirect - gitlab.com/NebulousLabs/ratelimit v0.0.0-20200811080431-99b8f0768b2e // indirect - gitlab.com/NebulousLabs/siamux v0.0.2-0.20220630142132-142a1443a259 // indirect - gitlab.com/NebulousLabs/threadgroup v0.0.0-20200608151952-38921fbef213 // indirect + go.etcd.io/bbolt v1.3.11 // indirect go.sia.tech/mux v1.2.0 // indirect go.sia.tech/web v0.0.0-20240610131903-5611d44a533e // indirect go.uber.org/multierr v1.11.0 // indirect diff --git a/go.sum b/go.sum index 0114cdb2..ddf26d4a 100644 --- a/go.sum +++ b/go.sum @@ -1,306 +1,96 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= -github.com/VividCortex/ewma v1.1.1/go.mod h1:2Tkkvm3sRDVXaiyucHiACn4cqf7DpdyLvmxzcbUokwA= -github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d/go.mod h1:asat636LX7Bqt5lYEZ27JNDcqxfjdBQuJ/MM4CN/Lzo= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= -github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= -github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= -github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= -github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudflare/cloudflare-go v0.102.0 h1:+0MGbkirM/yzVLOYpWMgW7CDdKzesSbdwA2Y+rABrWI= github.com/cloudflare/cloudflare-go v0.102.0/go.mod h1:BOB41tXf31ti/qtBO9paYhyapotQbGRDbQoLOAF7pSg= -github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= -github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= -github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= -github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= -github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dchest/threefish v0.0.0-20120919164726-3ecf4c494abf h1:K5VXW9LjmJv/xhjvQcNWTdk4WOSyreil6YaubuCPeRY= -github.com/dchest/threefish v0.0.0-20120919164726-3ecf4c494abf/go.mod h1:bXVurdTuvOiJu7NHALemFe0JMvC2UmwYHW+7fcZaZ2M= -github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= -github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= -github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= -github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= -github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= -github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= -github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/gorilla/websocket v1.5.2 h1:qoW6V1GT3aZxybsbC6oLnailWnB+qTMVwMreOso9XUw= -github.com/gorilla/websocket v1.5.2/go.mod h1:0n9H61RBAcf5/38py2MCYbxzPIY9rOkpvvMT24Rqs30= -github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= -github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= -github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= -github.com/hanwen/go-fuse v1.0.0/go.mod h1:unqXarDXqzAk0rt98O2tVndEPIpUgLD9+rwFisZH3Ok= -github.com/hanwen/go-fuse/v2 v2.1.0/go.mod h1:oRyA5eK+pvJyv5otpO/DgccS8y/RvYMaO00GgRLGryc= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= -github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/inconshreveable/go-update v0.0.0-20160112193335-8152e7eb6ccf/go.mod h1:hyb9oH7vZsitZCiBt0ZvifOrB+qc8PS5IiilCIb87rg= -github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= -github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= -github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= -github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= -github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/cpuid v1.2.2/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= -github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= -github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= -github.com/klauspost/reedsolomon v1.9.3/go.mod h1:CwCi+NUr9pqSVktrkN+Ondf06rkhYZ/pcNv7fu+8Un4= -github.com/klauspost/reedsolomon v1.12.1 h1:NhWgum1efX1x58daOBGCFWcxtEhOhXKKl1HAPQUp03Q= -github.com/klauspost/reedsolomon v1.12.1/go.mod h1:nEi5Kjb6QqtbofI6s+cbG/j1da11c96IBYBSnVGtuBs= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k= -github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= -github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= -github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= -github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= -github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= -github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= -github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= -github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= -github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= -github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= -github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= -github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= -github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= -github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= -github.com/spf13/cobra v1.0.0/go.mod h1:/6GTrnGXV9HjY+aR4k0oJ5tcvakLuG6EuKReYlHNrgE= -github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= -github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= -github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= -github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= -github.com/vbauerster/mpb/v5 v5.0.3/go.mod h1:h3YxU5CSr8rZP4Q3xZPVB3jJLhWPou63lHEdr9ytH4Y= -github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= -github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= -gitlab.com/NebulousLabs/bolt v1.4.4 h1:3UhpR2qtHs87dJBE3CIzhw48GYSoUUNByJmic0cbu1w= -gitlab.com/NebulousLabs/bolt v1.4.4/go.mod h1:ZL02cwhpLNif6aruxvUMqu/Bdy0/lFY21jMFfNAA+O8= -gitlab.com/NebulousLabs/demotemutex v0.0.0-20151003192217-235395f71c40 h1:IbucNi8u1a1ErgVFVgg8pERhSyzYe5l+o8krDMnNjWA= -gitlab.com/NebulousLabs/demotemutex v0.0.0-20151003192217-235395f71c40/go.mod h1:HfnnxM8isYA7FUlqS5h34XTeiBhPtcuCquVujKsn9aw= -gitlab.com/NebulousLabs/encoding v0.0.0-20200604091946-456c3dc907fe h1:vylvMCgxVPYojpQ2p536xDooW/B3znEnw58mCxrlZow= -gitlab.com/NebulousLabs/encoding v0.0.0-20200604091946-456c3dc907fe/go.mod h1:Gi3CPCauIWmGp7YrnV/mKZ8qkD/N/LrunGNc8QmsVkU= -gitlab.com/NebulousLabs/entropy-mnemonics v0.0.0-20181018051301-7532f67e3500 h1:BUDZfLl/9IRseYl7/GW1DF+11SYCMJ6P4whCBJhtEhQ= -gitlab.com/NebulousLabs/entropy-mnemonics v0.0.0-20181018051301-7532f67e3500/go.mod h1:4koft3fRXTETovKPTeX/Aggj+ajCGWCcuuBBc598Pcs= -gitlab.com/NebulousLabs/errors v0.0.0-20171229012116-7ead97ef90b8/go.mod h1:ZkMZ0dpQyWwlENaeZVBiQRjhMEZvk6VTXquzl3FOFP8= -gitlab.com/NebulousLabs/errors v0.0.0-20200929122200-06c536cf6975 h1:L/ENs/Ar1bFzUeKx6m3XjlmBgIUlykX9dzvp5k9NGxc= -gitlab.com/NebulousLabs/errors v0.0.0-20200929122200-06c536cf6975/go.mod h1:ZkMZ0dpQyWwlENaeZVBiQRjhMEZvk6VTXquzl3FOFP8= -gitlab.com/NebulousLabs/fastrand v0.0.0-20181126182046-603482d69e40 h1:dizWJqTWjwyD8KGcMOwgrkqu1JIkofYgKkmDeNE7oAs= -gitlab.com/NebulousLabs/fastrand v0.0.0-20181126182046-603482d69e40/go.mod h1:rOnSnoRyxMI3fe/7KIbVcsHRGxe30OONv8dEgo+vCfA= -gitlab.com/NebulousLabs/go-upnp v0.0.0-20181011194642-3a71999ed0d3/go.mod h1:sleOmkovWsDEQVYXmOJhx69qheoMTmCuPYyiCFCihlg= -gitlab.com/NebulousLabs/go-upnp v0.0.0-20211002182029-11da932010b6 h1:WKij6HF8ECp9E7K0E44dew9NrRDGiNR5u4EFsXnJUx4= -gitlab.com/NebulousLabs/go-upnp v0.0.0-20211002182029-11da932010b6/go.mod h1:vhrHTGDh4YR7wK8Z+kRJ+x8SF/6RUM3Vb64Si5FD0L8= -gitlab.com/NebulousLabs/log v0.0.0-20200529173103-40b250c2d92c/go.mod h1:qOhJbQ7Vzw+F+RCVmpPZ7WAwBIM9PZv4tWKp6Kgd9CY= -gitlab.com/NebulousLabs/log v0.0.0-20200604091839-0ba4a941cdc2/go.mod h1:qOhJbQ7Vzw+F+RCVmpPZ7WAwBIM9PZv4tWKp6Kgd9CY= -gitlab.com/NebulousLabs/log v0.0.0-20210609172545-77f6775350e2 h1:ovh05+n1jw7R9KT3qa5kdK4T26fIKyVogws06goZ5+Y= -gitlab.com/NebulousLabs/log v0.0.0-20210609172545-77f6775350e2/go.mod h1:qOhJbQ7Vzw+F+RCVmpPZ7WAwBIM9PZv4tWKp6Kgd9CY= -gitlab.com/NebulousLabs/merkletree v0.0.0-20200118113624-07fbf710afc4 h1:iuNdBfBg0umjOvrEf9MxGzK+NwAyE2oCZjDqUx9zVFs= -gitlab.com/NebulousLabs/merkletree v0.0.0-20200118113624-07fbf710afc4/go.mod h1:0cjDwhA+Pv9ZQXHED7HUSS3sCvo2zgsoaMgE7MeGBWo= -gitlab.com/NebulousLabs/monitor v0.0.0-20191205095550-2b0fd3e1012a h1:fs891phmYZrVdaCVPXfHGDMpV5LWPKvnOMjx70EpJkw= -gitlab.com/NebulousLabs/monitor v0.0.0-20191205095550-2b0fd3e1012a/go.mod h1:QxXtb5hIp2xQkfb+lzBDIqQIGEj22U7AkYCXO3hkhqc= -gitlab.com/NebulousLabs/persist v0.0.0-20200605115618-007e5e23d877 h1:BGJ+na/hpeAV6WR8Pys9bJM2ynEwKmT6+qgF8pn01fM= -gitlab.com/NebulousLabs/persist v0.0.0-20200605115618-007e5e23d877/go.mod h1:KT2SgNX75xjMIQdDi3Rf3tcDWsX/D289R65Ss/7lKBg= -gitlab.com/NebulousLabs/ratelimit v0.0.0-20200811080431-99b8f0768b2e h1:sMZdmPFduUilFk8Ed1Ya/DP0gVfUbGhLlNtLG2tONYk= -gitlab.com/NebulousLabs/ratelimit v0.0.0-20200811080431-99b8f0768b2e/go.mod h1:HVrehlTxX2hYjsrL1k0WK43OZ0NGZfGvqzPL+n0/zrM= -gitlab.com/NebulousLabs/siamux v0.0.0-20200723083235-f2c35a421446/go.mod h1:B0RyynPElUG2Y2CAVIIRriIqR9qht2I+nDisi3gfKn0= -gitlab.com/NebulousLabs/siamux v0.0.2-0.20220630142132-142a1443a259 h1:94CUlkiIN8Mu+hTYgT7n36SbJ7WR6ZMM91ReaSDxUlQ= -gitlab.com/NebulousLabs/siamux v0.0.2-0.20220630142132-142a1443a259/go.mod h1:owMSVLlMMCSK6tfhfSshZhrsIFCUNvQEsiGZoWhaXcc= -gitlab.com/NebulousLabs/threadgroup v0.0.0-20200527092543-afa01960408c/go.mod h1:av52iTyGuPtGU+GMcqfGtZu2vxhIjPgrxvIwVYelEvs= -gitlab.com/NebulousLabs/threadgroup v0.0.0-20200608151952-38921fbef213 h1:owERlKtUEFTPQ897iiqWPOuWBdq7BYqPxDOCgEZnbN4= -gitlab.com/NebulousLabs/threadgroup v0.0.0-20200608151952-38921fbef213/go.mod h1:vIutAvl7lmJqLVYTCBY5WDdJomP+V74At8LCeEYoH8w= -gitlab.com/NebulousLabs/writeaheadlog v0.0.0-20200618142844-c59a90f49130/go.mod h1:SxigdS5Q1ui+OMgGAXt1E/Fg3RB6PvKXMov2O3gvIzs= -go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= -go.etcd.io/bbolt v1.3.10 h1:+BqfJTcCzTItrop8mq/lbzL8wSGtj94UO/3U31shqG0= -go.etcd.io/bbolt v1.3.10/go.mod h1:bK3UQLPJZly7IlNmV7uVHJDxfe5aK9Ll93e/74Y9oEQ= -go.sia.tech/core v0.4.4 h1:DYb0/DxgACstJUGgsRJIVtrsTC0mk6GfA6pTxQwzKV0= -go.sia.tech/core v0.4.4/go.mod h1:Zuq0Tn2aIXJyO0bjGu8cMeVWe+vwQnUfZhG1LCmjD5c= -go.sia.tech/coreutils v0.3.0 h1:TutrhfNe8hq0GxWcibSRIVZQpFpBoKId7pFjxdvDIR8= -go.sia.tech/coreutils v0.3.0/go.mod h1:8DNsiy6Xon5R9M/FnaSzAi2wcATh98EsDV3N6iGq4yI= -go.sia.tech/jape v0.12.0 h1:13fBi7c5X8zxTQ05Cd9ZsIfRJgdvGoZqbEzH861z7BU= -go.sia.tech/jape v0.12.0/go.mod h1:wU+h6Wh5olDjkPXjF0tbZ1GDgoZ6VTi4naFw91yyWC4= +go.etcd.io/bbolt v1.3.11 h1:yGEzV1wPz2yVCLsD8ZAiGHhHVlczyC9d1rP43/VCRJ0= +go.etcd.io/bbolt v1.3.11/go.mod h1:dksAq7YMXoljX0xu6VF5DMZGbhYYoLUalEiSySYAS4I= +go.sia.tech/core v0.4.5-0.20240831170056-a91ba45f50d6 h1:ad4dKiZnjumnAZvpy9x+yK+djIQIJYXjmazH5jxcweE= +go.sia.tech/core v0.4.5-0.20240831170056-a91ba45f50d6/go.mod h1:Zuq0Tn2aIXJyO0bjGu8cMeVWe+vwQnUfZhG1LCmjD5c= +go.sia.tech/coreutils v0.3.1-0.20240831170413-70443a42b7b2 h1:ibPgvT9ICWvKKqj7Avk6jpz1BgxyiFsKEJEPM+powzA= +go.sia.tech/coreutils v0.3.1-0.20240831170413-70443a42b7b2/go.mod h1:yI2eAuWhkjpSfDsgsczCkeruBkoc1guAI226ZzGqqZk= +go.sia.tech/jape v0.12.1 h1:xr+o9V8FO8ScRqbSaqYf9bjj1UJ2eipZuNcI1nYousU= +go.sia.tech/jape v0.12.1/go.mod h1:wU+h6Wh5olDjkPXjF0tbZ1GDgoZ6VTi4naFw91yyWC4= go.sia.tech/mux v1.2.0 h1:ofa1Us9mdymBbGMY2XH/lSpY8itFsKIo/Aq8zwe+GHU= go.sia.tech/mux v1.2.0/go.mod h1:Yyo6wZelOYTyvrHmJZ6aQfRoer3o4xyKQ4NmQLJrBSo= -go.sia.tech/siad v1.5.10-0.20230228235644-3059c0b930ca h1:aZMg2AKevn7jKx+wlusWQfwSM5pNU9aGtRZme29q3O4= -go.sia.tech/siad v1.5.10-0.20230228235644-3059c0b930ca/go.mod h1:h/1afFwpxzff6/gG5i1XdAgPK7dEY6FaibhK7N5F86Y= go.sia.tech/web v0.0.0-20240610131903-5611d44a533e h1:oKDz6rUExM4a4o6n/EXDppsEka2y/+/PgFOZmHWQRSI= go.sia.tech/web v0.0.0-20240610131903-5611d44a533e/go.mod h1:4nyDlycPKxTlCqvOeRO0wUfXxyzWCEE7+2BRrdNqvWk= -go.sia.tech/web/hostd v0.45.1 h1:G6E2f48U6OwsWzld58ubmUbS7fXBBpqoltHbQvvgC4g= -go.sia.tech/web/hostd v0.45.1/go.mod h1:ie6ujZp3tziLJgNQ3p8PHCuRXpWIvVznKpqFdng+x04= -go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.sia.tech/web/hostd v0.46.0 h1:gCdLkrg31H6sROdAjOaxmfyDatWRS4EpfKImb1Sf21o= +go.sia.tech/web/hostd v0.46.0/go.mod h1:ie6ujZp3tziLJgNQ3p8PHCuRXpWIvVznKpqFdng+x04= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= -go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= -golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191105034135-c7e5f84aec59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200109152110-61a87790db17/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200117160349-530e935923ad/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200311171314-f7b00557c8c4/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.0.0-20220507011949-2cf3adece122/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1/go.mod h1:9tjilg8BloeKEkVJvy7fQ90B1CfIiPueXVOjqfkSzI8= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210421210424-b80969c67360/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= -gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= -gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= lukechampine.com/frand v1.4.2 h1:RzFIpOvkMXuPMBb9maa4ND4wjBn71E1Jpf8BzJHMaVw= lukechampine.com/frand v1.4.2/go.mod h1:4S/TM2ZgrKejMcKMbeLjISpJMO+/eZ1zu3vYX9dtj3s= -nhooyr.io/websocket v1.8.17 h1:KEVeLJkUywCKVsnLIDlD/5gtayKp8VoCkksHCGGfT9Y= -nhooyr.io/websocket v1.8.17/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= +lukechampine.com/upnp v0.3.0 h1:UVCD6eD6fmJmwak6DVE3vGN+L46Fk8edTcC6XYCb6C4= +lukechampine.com/upnp v0.3.0/go.mod h1:sOuF+fGSDKjpUm6QI0mfb82ScRrhj8bsqsD78O5nK1k= +nhooyr.io/websocket v1.8.11 h1:f/qXNc2/3DpoSZkHt1DQu6rj4zGC8JmkkLkWss0MgN0= +nhooyr.io/websocket v1.8.11/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= diff --git a/host/accounts/accounts_test.go b/host/accounts/accounts_test.go index b2ee8599..4bd88504 100644 --- a/host/accounts/accounts_test.go +++ b/host/accounts/accounts_test.go @@ -1,23 +1,14 @@ package accounts_test import ( - "path/filepath" "testing" "time" "go.sia.tech/core/types" - "go.sia.tech/hostd/alerts" "go.sia.tech/hostd/host/accounts" "go.sia.tech/hostd/host/contracts" "go.sia.tech/hostd/host/settings" - "go.sia.tech/hostd/host/storage" - "go.sia.tech/hostd/internal/chain" - "go.sia.tech/hostd/persist/sqlite" - "go.sia.tech/hostd/wallet" - "go.sia.tech/hostd/webhooks" - "go.sia.tech/siad/modules/consensus" - "go.sia.tech/siad/modules/gateway" - "go.sia.tech/siad/modules/transactionpool" + "go.sia.tech/hostd/internal/testutil" "go.uber.org/zap/zaptest" "lukechampine.com/frand" ) @@ -34,61 +25,8 @@ func (s ephemeralSettings) Settings() settings.Settings { func TestCredit(t *testing.T) { log := zaptest.NewLogger(t) - dir := t.TempDir() - db, err := sqlite.OpenDatabase(filepath.Join(dir, "hostd.db"), log.Named("accounts")) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - g, err := gateway.New(":0", false, filepath.Join(dir, "gateway")) - if err != nil { - t.Fatal(err) - } - defer g.Close() - - cs, errCh := consensus.New(g, false, filepath.Join(dir, "consensus")) - if err := <-errCh; err != nil { - t.Fatal(err) - } - defer cs.Close() - - stp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - t.Fatal(err) - } - tp := chain.NewTPool(stp) - defer tp.Close() - - cm, err := chain.NewManager(cs, tp) - if err != nil { - t.Fatal(err) - } - defer cm.Close() - - w, err := wallet.NewSingleAddressWallet(types.NewPrivateKeyFromSeed(frand.Bytes(32)), cm, db, log.Named("wallet")) - if err != nil { - t.Fatal(err) - } - defer w.Close() - - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - a := alerts.NewManager(webhookReporter, log.Named("alerts")) - sm, err := storage.NewVolumeManager(db, a, cm, log.Named("storage"), 0) - if err != nil { - t.Fatal(err) - } - defer sm.Close() - - com, err := contracts.NewManager(db, a, sm, cm, tp, w, log.Named("contracts")) - if err != nil { - t.Fatal(err) - } - defer cm.Close() + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, types.GeneratePrivateKey(), network, genesis, log) rev := contracts.SignedRevision{ Revision: types.FileContractRevision{ @@ -101,11 +39,11 @@ func TestCredit(t *testing.T) { }, }, } - if err := com.AddContract(rev, []types.Transaction{{}}, types.Siacoins(1), contracts.Usage{}); err != nil { + if err := node.Contracts.AddContract(rev, []types.Transaction{{}}, types.Siacoins(1), contracts.Usage{}); err != nil { t.Fatal(err) } - am := accounts.NewManager(db, ephemeralSettings{maxBalance: types.NewCurrency64(100)}) + am := accounts.NewManager(node.Store, ephemeralSettings{maxBalance: types.NewCurrency64(100)}) accountID := frand.Entropy256() // attempt to credit the account @@ -120,7 +58,7 @@ func TestCredit(t *testing.T) { } if _, err := am.Credit(req, false); err != nil { t.Fatal("expected successful credit", err) - } else if balance, err := db.AccountBalance(accountID); err != nil { + } else if balance, err := node.Store.AccountBalance(accountID); err != nil { t.Fatal("expected successful balance", err) } else if balance.Cmp(amount) != 0 { t.Fatalf("expected balance %v to be equal to amount %v", balance, amount) @@ -134,14 +72,14 @@ func TestCredit(t *testing.T) { t.Fatalf("expected funding amount to be %v, got %v", expectedFunding, sources[0].Amount) } - contract, err := com.Contract(rev.Revision.ParentID) + contract, err := node.Contracts.Contract(rev.Revision.ParentID) if err != nil { t.Fatal(err) } else if !contract.Usage.AccountFunding.Equals(expectedFunding) { t.Fatalf("expected contract usage to be %v, got %v", expectedFunding, contract.Usage.AccountFunding) } - if m, err := db.Metrics(time.Now()); err != nil { + if m, err := node.Store.Metrics(time.Now()); err != nil { t.Fatal(err) } else if !m.Accounts.Balance.Equals(expectedFunding) { t.Fatalf("expected account balance to be %v, got %v", expectedFunding, m.Accounts.Balance) @@ -170,14 +108,14 @@ func TestCredit(t *testing.T) { t.Fatalf("expected funding amount to be %v, got %v", expectedFunding, sources[0].Amount) } - contract, err = com.Contract(rev.Revision.ParentID) + contract, err = node.Contracts.Contract(rev.Revision.ParentID) if err != nil { t.Fatal(err) } else if !contract.Usage.AccountFunding.Equals(expectedFunding) { t.Fatalf("expected contract usage to be %v, got %v", expectedFunding, contract.Usage.AccountFunding) } - if m, err := db.Metrics(time.Now()); err != nil { + if m, err := node.Store.Metrics(time.Now()); err != nil { t.Fatal(err) } else if !m.Accounts.Balance.Equals(expectedFunding) { t.Fatalf("expected account balance to be %v, got %v", expectedFunding, m.Accounts.Balance) @@ -199,14 +137,14 @@ func TestCredit(t *testing.T) { t.Fatalf("expected funding amount to be %v, got %v", expectedFunding, sources[0].Amount) } - contract, err = com.Contract(rev.Revision.ParentID) + contract, err = node.Contracts.Contract(rev.Revision.ParentID) if err != nil { t.Fatal(err) } else if !contract.Usage.AccountFunding.Equals(expectedFunding) { t.Fatalf("expected contract usage to be %v, got %v", expectedFunding, contract.Usage.AccountFunding) } - if m, err := db.Metrics(time.Now()); err != nil { + if m, err := node.Store.Metrics(time.Now()); err != nil { t.Fatal(err) } else if !m.Accounts.Balance.Equals(expectedFunding) { t.Fatalf("expected account balance to be %v, got %v", expectedFunding, m.Accounts.Balance) diff --git a/host/accounts/budget_test.go b/host/accounts/budget_test.go index bd594c7b..86945af2 100644 --- a/host/accounts/budget_test.go +++ b/host/accounts/budget_test.go @@ -3,23 +3,14 @@ package accounts_test import ( "errors" "math" - "path/filepath" "reflect" "testing" "time" "go.sia.tech/core/types" - "go.sia.tech/hostd/alerts" "go.sia.tech/hostd/host/accounts" "go.sia.tech/hostd/host/contracts" - "go.sia.tech/hostd/host/storage" - "go.sia.tech/hostd/internal/chain" - "go.sia.tech/hostd/persist/sqlite" - "go.sia.tech/hostd/wallet" - "go.sia.tech/hostd/webhooks" - "go.sia.tech/siad/modules/consensus" - "go.sia.tech/siad/modules/gateway" - "go.sia.tech/siad/modules/transactionpool" + "go.sia.tech/hostd/internal/testutil" "go.uber.org/zap/zaptest" "lukechampine.com/frand" ) @@ -71,61 +62,9 @@ func TestUsageAdd(t *testing.T) { func TestBudget(t *testing.T) { log := zaptest.NewLogger(t) - dir := t.TempDir() - db, err := sqlite.OpenDatabase(filepath.Join(dir, "hostd.db"), log.Named("accounts")) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - g, err := gateway.New(":0", false, filepath.Join(dir, "gateway")) - if err != nil { - t.Fatal(err) - } - defer g.Close() - - cs, errCh := consensus.New(g, false, filepath.Join(dir, "consensus")) - if err := <-errCh; err != nil { - t.Fatal(err) - } - defer cs.Close() - - stp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - t.Fatal(err) - } - tp := chain.NewTPool(stp) - defer tp.Close() - - cm, err := chain.NewManager(cs, tp) - if err != nil { - t.Fatal(err) - } - defer cm.Close() - - w, err := wallet.NewSingleAddressWallet(types.NewPrivateKeyFromSeed(frand.Bytes(32)), cm, db, log.Named("wallet")) - if err != nil { - t.Fatal(err) - } - defer w.Close() - - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - a := alerts.NewManager(webhookReporter, log.Named("alerts")) - sm, err := storage.NewVolumeManager(db, a, cm, log.Named("storage"), 0) - if err != nil { - t.Fatal(err) - } - defer sm.Close() - - com, err := contracts.NewManager(db, a, sm, cm, tp, w, log.Named("contracts")) - if err != nil { - t.Fatal(err) - } - defer cm.Close() + hostKey := types.GeneratePrivateKey() + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) rev := contracts.SignedRevision{ Revision: types.FileContractRevision{ @@ -139,11 +78,11 @@ func TestBudget(t *testing.T) { }, } amount := types.NewCurrency64(100) - if err := com.AddContract(rev, []types.Transaction{{}}, types.Siacoins(1), contracts.Usage{}); err != nil { + if err := node.Contracts.AddContract(rev, []types.Transaction{{}}, types.Siacoins(1), contracts.Usage{}); err != nil { t.Fatal(err) } - am := accounts.NewManager(db, ephemeralSettings{maxBalance: types.NewCurrency64(100)}) + am := accounts.NewManager(node.Store, ephemeralSettings{maxBalance: types.NewCurrency64(100)}) accountID := frand.Entropy256() expectedFunding := amount req := accounts.FundAccountWithContract{ @@ -166,7 +105,7 @@ func TestBudget(t *testing.T) { t.Fatalf("expected funding amount to be %v, got %v", expectedFunding, sources[0].Amount) } - contract, err := com.Contract(rev.Revision.ParentID) + contract, err := node.Contracts.Contract(rev.Revision.ParentID) if err != nil { t.Fatal(err) } else if !contract.Usage.AccountFunding.Equals(expectedFunding) { @@ -175,7 +114,7 @@ func TestBudget(t *testing.T) { expectedBalance := amount - if m, err := db.Metrics(time.Now()); err != nil { + if m, err := node.Store.Metrics(time.Now()); err != nil { t.Fatal(err) } else if !m.Accounts.Balance.Equals(expectedBalance) { t.Fatalf("expected account balance to be %v, got %v", expectedBalance, m.Accounts.Balance) @@ -241,7 +180,7 @@ func TestBudget(t *testing.T) { } // check that the contract's usage has been updated - contract, err = com.Contract(rev.Revision.ParentID) + contract, err = node.Contracts.Contract(rev.Revision.ParentID) if err != nil { t.Fatal(err) } else if !contract.Usage.AccountFunding.Equals(expectedFunding) { @@ -257,7 +196,7 @@ func TestBudget(t *testing.T) { t.Fatalf("expected in-memory balance to be %d, got %d", expectedBalance, balance) } - if m, err := db.Metrics(time.Now()); err != nil { + if m, err := node.Store.Metrics(time.Now()); err != nil { t.Fatal(err) } else if !m.Accounts.Balance.Equals(expectedBalance) { t.Fatalf("expected account balance to be %v, got %v", expectedBalance, m.Accounts.Balance) @@ -267,7 +206,7 @@ func TestBudget(t *testing.T) { // check that the account balance has been updated and only the spent // amount has been deducted - if balance, err := db.AccountBalance(accountID); err != nil { + if balance, err := node.Store.AccountBalance(accountID); err != nil { t.Fatal("expected successful balance", err) } else if !balance.Equals(expectedBalance) { t.Fatalf("expected balance to be equal to %d, got %d", expectedBalance, balance) @@ -291,14 +230,14 @@ func TestBudget(t *testing.T) { } // check that the contract's usage has been updated - contract, err = com.Contract(rev.Revision.ParentID) + contract, err = node.Contracts.Contract(rev.Revision.ParentID) if err != nil { t.Fatal(err) } else if !contract.Usage.AccountFunding.IsZero() { t.Fatalf("expected contract usage to be %v, got %v", types.ZeroCurrency, contract.Usage.AccountFunding) } - if m, err := db.Metrics(time.Now()); err != nil { + if m, err := node.Store.Metrics(time.Now()); err != nil { t.Fatal(err) } else if !m.Accounts.Balance.IsZero() { t.Fatalf("expected account balance to be %v, got %v", types.ZeroCurrency, m.Accounts.Balance) diff --git a/host/contracts/actions.go b/host/contracts/actions.go deleted file mode 100644 index db81701e..00000000 --- a/host/contracts/actions.go +++ /dev/null @@ -1,290 +0,0 @@ -package contracts - -import ( - "encoding/json" - "fmt" - "sync/atomic" - "time" - - rhp2 "go.sia.tech/core/rhp/v2" - "go.sia.tech/core/types" - "go.sia.tech/hostd/alerts" - "go.uber.org/zap" -) - -// An action determines what lifecycle event should be performed on a contract. -const ( - ActionBroadcastFormation = "formation" - ActionReject = "reject" - ActionBroadcastFinalRevision = "revision" - ActionBroadcastResolution = "resolve" - ActionExpire = "expire" -) - -func (cm *ContractManager) buildStorageProof(id types.FileContractID, filesize uint64, index uint64, log *zap.Logger) (types.StorageProof, error) { - if filesize == 0 { - return types.StorageProof{ - ParentID: id, - }, nil - } - - sectorIndex := index / rhp2.LeavesPerSector - segmentIndex := index % rhp2.LeavesPerSector - - roots, err := cm.getSectorRoots(id) - if err != nil { - return types.StorageProof{}, fmt.Errorf("failed to get sector roots: %w", err) - } else if uint64(len(roots)) < sectorIndex { - log.Error("failed to build storage proof. invalid root index", zap.Uint64("sectorIndex", sectorIndex), zap.Uint64("segmentIndex", segmentIndex), zap.Int("rootsLength", len(roots))) - return types.StorageProof{}, fmt.Errorf("invalid root index") - } - root := roots[sectorIndex] - sector, err := cm.storage.Read(root) - if err != nil { - log.Error("failed to build storage proof. unable to read sector data", zap.Error(err), zap.Stringer("sectorRoot", root)) - return types.StorageProof{}, fmt.Errorf("failed to read sector data") - } - segmentProof := rhp2.ConvertProofOrdering(rhp2.BuildProof(sector, segmentIndex, segmentIndex+1, nil), segmentIndex) - sectorProof := rhp2.ConvertProofOrdering(rhp2.BuildSectorRangeProof(roots, sectorIndex, sectorIndex+1), sectorIndex) - sp := types.StorageProof{ - ParentID: id, - Proof: append(segmentProof, sectorProof...), - } - copy(sp.Leaf[:], sector[segmentIndex*rhp2.LeafSize:]) - return sp, nil -} - -// processActions performs lifecycle actions on contracts. Triggered by a -// consensus change, changes are processed in the order they were received. -func (cm *ContractManager) processActions() { - for { - select { - case height := <-cm.processQueue: - err := func() error { - done, err := cm.tg.Add() - if err != nil { - return nil - } - defer done() - - err = cm.store.ContractAction(height, cm.handleContractAction) - if err != nil { - return fmt.Errorf("failed to process contract actions: %w", err) - } else if err = cm.store.ExpireContractSectors(height); err != nil { - return fmt.Errorf("failed to expire contract sectors: %w", err) - } - return nil - }() - if err != nil { - cm.log.Panic("failed to process contract actions", zap.Error(err), zap.Stack("stack")) - } - atomic.StoreUint64(&cm.blockHeight, height) - case <-cm.tg.Done(): - return - } - } -} - -// handleContractAction performs a lifecycle action on a contract. -func (cm *ContractManager) handleContractAction(id types.FileContractID, height uint64, action string) { - log := cm.log.Named("lifecycle").With(zap.String("contractID", id.String()), zap.Uint64("height", height), zap.String("action", action)) - contract, err := cm.store.Contract(id) - if err != nil { - log.Error("failed to get contract", zap.Error(err)) - return - } - log = log.With(zap.Uint64("revisionNumber", contract.Revision.RevisionNumber), zap.Uint64("size", contract.Revision.Filesize), zap.Stringer("merkleRoot", contract.Revision.FileMerkleRoot), zap.Uint64("scanHeight", cm.chain.TipState().Index.Height)) - log.Debug("performing contract action", zap.Uint64("negotiationHeight", contract.NegotiationHeight), zap.Uint64("windowStart", contract.Revision.WindowStart), zap.Uint64("windowEnd", contract.Revision.WindowEnd)) - start := time.Now() - cs := cm.chain.TipState() - - // helper to register a contract alert - registerContractAlert := func(severity alerts.Severity, message string, err error) { - data := map[string]any{ - "contractID": id, - "blockHeight": height, - } - if err != nil { - data["error"] = err.Error() - } - - cm.alerts.Register(alerts.Alert{ - ID: types.Hash256(id), - Severity: severity, - Message: message, - Data: data, - Timestamp: time.Now(), - }) - } - - switch action { - case ActionBroadcastFormation: - if (height-contract.NegotiationHeight)%3 != 0 { - // debounce formation broadcasts to prevent spamming - log.Debug("skipping rebroadcast", zap.Uint64("negotiationHeight", contract.NegotiationHeight)) - return - } - formationSet, err := cm.store.ContractFormationSet(id) - if err != nil { - log.Error("failed to get formation set", zap.Error(err)) - return - } else if err := cm.tpool.AcceptTransactionSet(formationSet); err != nil { - log.Error("failed to broadcast formation transaction", zap.Error(err)) - return - } - log.Info("rebroadcast formation transaction", zap.String("transactionID", formationSet[len(formationSet)-1].ID().String())) - case ActionBroadcastFinalRevision: - if (contract.Revision.WindowStart-height)%3 != 0 { - // debounce final revision broadcasts to prevent spamming - log.Debug("skipping revision", zap.Uint64("windowStart", contract.Revision.WindowStart)) - return - } - revisionTxn := types.Transaction{ - FileContractRevisions: []types.FileContractRevision{contract.Revision}, - Signatures: []types.TransactionSignature{ - { - ParentID: types.Hash256(contract.Revision.ParentID), - CoveredFields: types.CoveredFields{FileContractRevisions: []uint64{0}}, - Signature: contract.RenterSignature[:], - }, - { - ParentID: types.Hash256(contract.Revision.ParentID), - CoveredFields: types.CoveredFields{FileContractRevisions: []uint64{0}}, - Signature: contract.HostSignature[:], - PublicKeyIndex: 1, - }, - }, - } - - fee := cm.tpool.RecommendedFee().Mul64(1000) - revisionTxn.MinerFees = append(revisionTxn.MinerFees, fee) - toSign, release, err := cm.wallet.FundTransaction(&revisionTxn, fee) - if err != nil { - log.Error("failed to fund revision transaction", zap.Error(err)) - return - } - if err := cm.wallet.SignTransaction(cs, &revisionTxn, toSign, types.CoveredFields{WholeTransaction: true}); err != nil { - release() - log.Error("failed to sign revision transaction", zap.Error(err)) - return - } else if err := cm.tpool.AcceptTransactionSet([]types.Transaction{revisionTxn}); err != nil { - release() - log.Error("failed to broadcast revision transaction", zap.Error(err)) - return - } - log.Info("broadcast final revision", zap.Uint64("revisionNumber", contract.Revision.RevisionNumber), zap.String("transactionID", revisionTxn.ID().String())) - case ActionBroadcastResolution: - if (height-contract.Revision.WindowStart)%3 != 0 { - // debounce resolution broadcasts to prevent spamming - log.Debug("skipping resolution", zap.Uint64("windowStart", contract.Revision.WindowStart)) - return - } - validPayout, missedPayout := contract.Revision.ValidHostPayout(), contract.Revision.MissedHostPayout() - if missedPayout.Cmp(validPayout) >= 0 { - log.Info("skipping storage proof, no benefit to host", zap.String("validPayout", validPayout.ExactString()), zap.String("missedPayout", missedPayout.ExactString())) - return - } - - // get the block before the proof window starts - windowStart, err := cm.chain.IndexAtHeight(contract.Revision.WindowStart - 1) - if err != nil { - log.Error("failed to get chain index at height", zap.Uint64("height", contract.Revision.WindowStart-1), zap.Error(err)) - return - } - - // get the proof leaf index - leafIndex := cs.StorageProofLeafIndex(contract.Revision.Filesize, windowStart.ID, contract.Revision.ParentID) - sp, err := cm.buildStorageProof(contract.Revision.ParentID, contract.Revision.Filesize, leafIndex, log.Named("buildStorageProof")) - if err != nil { - log.Error("failed to build storage proof", zap.Error(err)) - registerContractAlert(alerts.SeverityError, "Failed to build storage proof", err) - return - } - - // TODO: consider cost of broadcasting the proof - fee := cm.tpool.RecommendedFee().Mul64(1000) - resolutionTxnSet := []types.Transaction{ - { - // intermediate funding transaction is required by siad because - // transactions with storage proofs cannot have change outputs - SiacoinOutputs: []types.SiacoinOutput{ - {Address: cm.wallet.Address(), Value: fee}, - }, - }, - { - MinerFees: []types.Currency{fee}, - StorageProofs: []types.StorageProof{sp}, - }, - } - intermediateToSign, release, err := cm.wallet.FundTransaction(&resolutionTxnSet[0], fee) - if err != nil { - log.Error("failed to fund resolution transaction", zap.Error(err)) - registerContractAlert(alerts.SeverityError, "Failed to fund resolution transaction", err) - return - } - // add the intermediate output to the proof transaction - resolutionTxnSet[1].SiacoinInputs = append(resolutionTxnSet[1].SiacoinInputs, types.SiacoinInput{ - ParentID: resolutionTxnSet[0].SiacoinOutputID(0), - UnlockConditions: cm.wallet.UnlockConditions(), - }) - proofToSign := []types.Hash256{types.Hash256(resolutionTxnSet[1].SiacoinInputs[0].ParentID)} - start = time.Now() - if err := cm.wallet.SignTransaction(cs, &resolutionTxnSet[0], intermediateToSign, types.CoveredFields{WholeTransaction: true}); err != nil { // sign the intermediate transaction - release() - log.Error("failed to sign resolution intermediate transaction", zap.Error(err)) - return - } else if err := cm.wallet.SignTransaction(cs, &resolutionTxnSet[1], proofToSign, types.CoveredFields{WholeTransaction: true}); err != nil { // sign the proof transaction - release() - log.Error("failed to sign resolution transaction", zap.Error(err)) - return - } else if err := cm.tpool.AcceptTransactionSet(resolutionTxnSet); err != nil { // broadcast the transaction set - release() - buf, _ := json.Marshal(resolutionTxnSet) - log.Error("failed to broadcast resolution transaction set", zap.Error(err), zap.ByteString("transactionSet", buf)) - registerContractAlert(alerts.SeverityError, "Failed to broadcast resolution transaction set", err) - return - } - cm.alerts.Dismiss(types.Hash256(id)) // dismiss any previous failure alerts - log.Info("broadcast storage proof", zap.String("transactionID", resolutionTxnSet[1].ID().String()), zap.Duration("elapsed", time.Since(start))) - case ActionReject: - if err := cm.store.ExpireContract(id, ContractStatusRejected); err != nil { - log.Error("failed to set contract status", zap.Error(err)) - } - log.Info("contract rejected", zap.Uint64("negotiationHeight", contract.NegotiationHeight)) - case ActionExpire: - validPayout, missedPayout := contract.Revision.ValidHostPayout(), contract.Revision.MissedHostPayout() - switch { - case !contract.FormationConfirmed: - // if the contract was never confirmed, nothing was ever lost or - // gained - if err := cm.store.ExpireContract(id, ContractStatusRejected); err != nil { - log.Error("failed to set contract status", zap.Error(err)) - } - case validPayout.Cmp(missedPayout) <= 0 || contract.ResolutionHeight != 0: - // if the host valid payout is less than or equal to the missed - // payout or if a resolution was confirmed, the contract was - // successful - if err := cm.store.ExpireContract(id, ContractStatusSuccessful); err != nil { - log.Error("failed to set contract status", zap.Error(err)) - } - payout := validPayout - if contract.ResolutionHeight != 0 { - payout = missedPayout - } - log.Info("contract successful", zap.String("payout", payout.ExactString())) - case validPayout.Cmp(missedPayout) > 0 && contract.ResolutionHeight == 0: - // if the host valid payout is greater than the missed payout and a - // proof was not broadcast, the contract failed - if err := cm.store.ExpireContract(id, ContractStatusFailed); err != nil { - log.Error("failed to set contract status", zap.Error(err)) - } - registerContractAlert(alerts.SeverityError, "Contract failed without storage proof", nil) - log.Error("contract failed, revenue lost", zap.Uint64("windowStart", contract.Revision.WindowStart), zap.Uint64("windowEnd", contract.Revision.WindowEnd), zap.String("validPayout", validPayout.ExactString()), zap.String("missedPayout", missedPayout.ExactString())) - default: - log.Panic("unrecognized contract state", zap.Stack("stack"), zap.String("validPayout", validPayout.ExactString()), zap.String("missedPayout", missedPayout.ExactString()), zap.Uint64("resolutionHeight", contract.ResolutionHeight), zap.Bool("formationConfirmed", contract.FormationConfirmed)) - } - default: - log.Panic("unrecognized contract action", zap.Stack("stack")) - } - log.Debug("contract action completed", zap.Duration("elapsed", time.Since(start))) -} diff --git a/host/contracts/consts_default.go b/host/contracts/consts_default.go deleted file mode 100644 index 352d7215..00000000 --- a/host/contracts/consts_default.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build !testing - -package contracts - -const ( - // RebroadcastBuffer is the number of blocks after the negotiation height to - // attempt to rebroadcast the contract. - RebroadcastBuffer = 36 // 6 hours - // RevisionSubmissionBuffer number of blocks before the proof window to - // submit a revision and prevent modification of the contract. - RevisionSubmissionBuffer = 144 // 24 hours -) diff --git a/host/contracts/consts_testing.go b/host/contracts/consts_testing.go deleted file mode 100644 index 96a06b9b..00000000 --- a/host/contracts/consts_testing.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build testing - -package contracts - -const ( - // RebroadcastBuffer is the number of blocks after the negotiation height to - // attempt to rebroadcast the contract. - RebroadcastBuffer = 12 - // RevisionSubmissionBuffer number of blocks before the proof window to - // submit a revision and prevent modification of the contract. - RevisionSubmissionBuffer = 24 -) diff --git a/host/contracts/contracts.go b/host/contracts/contracts.go index 63231eb0..1c0da26b 100644 --- a/host/contracts/contracts.go +++ b/host/contracts/contracts.go @@ -7,7 +7,6 @@ import ( "sync" "time" - lru "github.com/hashicorp/golang-lru/v2" rhp2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" "go.uber.org/zap" @@ -43,6 +42,31 @@ const ( ContractStatusFailed ) +// V2ContractStatus is an enum that indicates the current status of a v2 contract. +const ( + // V2ContractStatusPending indicates that the contract has been formed but + // has not yet been confirmed on the blockchain. The contract is still + // usable, but there is a risk that the contract will never be confirmed. + V2ContractStatusPending V2ContractStatus = "pending" + // V2ContractStatusRejected indicates that the contract formation transaction + // was never confirmed on the blockchain + V2ContractStatusRejected V2ContractStatus = "rejected" + // V2ContractStatusActive indicates that the contract has been confirmed on + // the blockchain and is currently active. + V2ContractStatusActive V2ContractStatus = "active" + // V2ContractStatusFinalized indicates that the contract has been finalized. + V2ContractStatusFinalized V2ContractStatus = "finalized" + // V2ContractStatusRenewed indicates that the contract has been renewed. + V2ContractStatusRenewed V2ContractStatus = "renewed" + // V2ContractStatusSuccessful indicates that a storage proof has been + // confirmed or the contract expired without requiring the host to burn + // Siacoin. + V2ContractStatusSuccessful V2ContractStatus = "successful" + // V2ContractStatusFailed indicates that the contract ended without a storage proof + // and the host was required to burn Siacoin. + V2ContractStatusFailed V2ContractStatus = "failed" +) + // fields that the contracts can be sorted by. const ( ContractSortStatus = "status" @@ -58,6 +82,9 @@ type ( // ContractStatus is an enum that indicates the current status of a contract. ContractStatus uint8 + // V2ContractStatus is an enum that indicates the current status of a v2 contract. + V2ContractStatus string + // A SignedRevision pairs a contract revision with the signatures of the host // and renter needed to broadcast the revision. SignedRevision struct { @@ -79,6 +106,51 @@ type ( RiskedCollateral types.Currency `json:"riskedCollateral"` } + // V2Usage tracks the usage of a contract's funds. + V2Usage struct { + RPCRevenue types.Currency `json:"rpc"` + StorageRevenue types.Currency `json:"storage"` + EgressRevenue types.Currency `json:"egress"` + IngressRevenue types.Currency `json:"ingress"` + AccountFunding types.Currency `json:"accountFunding"` + RiskedCollateral types.Currency `json:"riskedCollateral"` + } + + // A V2Contract contains metadata on the current state of a v2 file contract. + V2Contract struct { + types.V2FileContract + + ID types.FileContractID `json:"id"` + Status V2ContractStatus `json:"status"` + Usage V2Usage `json:"usage"` + + // NegotiationHeight is the height the contract was negotiated at. + NegotiationHeight uint64 `json:"negotiationHeight"` + // RevisionConfirmed is true if the contract revision transaction has + // been confirmed on the blockchain. + RevisionConfirmed bool `json:"revisionConfirmed"` + // FormationConfirmed is true if the contract formation transaction + // has been confirmed on the blockchain. + FormationIndex types.ChainIndex `json:"formationIndex"` + // ResolutionIndex is the height the resolution was confirmed + // at. If the contract has not been resolved, the field is the zero + // value. + ResolutionIndex types.ChainIndex `json:"resolutionHeight"` + // RenewedTo is the ID of the contract that renewed this contract. If + // this contract was not renewed, this field is the zero value. + RenewedTo types.FileContractID `json:"renewedTo"` + // RenewedFrom is the ID of the contract that this contract renewed. If + // this contract is not a renewal, the field is the zero value. + RenewedFrom types.FileContractID `json:"renewedFrom"` + } + + // A V2FormationTransactionSet contains the formation transaction set for a + // v2 contract. + V2FormationTransactionSet struct { + TransactionSet []types.V2Transaction + Basis types.ChainIndex + } + // A Contract contains metadata on the current state of a file contract. Contract struct { SignedRevision @@ -131,6 +203,30 @@ type ( SortDesc bool `json:"sortDesc"` } + // V2ContractFilter defines the filter criteria for a contract query. + V2ContractFilter struct { + // filters + Statuses []V2ContractStatus `json:"statuses"` + ContractIDs []types.FileContractID `json:"contractIDs"` + RenewedFrom []types.FileContractID `json:"renewedFrom"` + RenewedTo []types.FileContractID `json:"renewedTo"` + RenterKey []types.PublicKey `json:"renterKey"` + + MinNegotiationHeight uint64 `json:"minNegotiationHeight"` + MaxNegotiationHeight uint64 `json:"maxNegotiationHeight"` + + MinExpirationHeight uint64 `json:"minExpirationHeight"` + MaxExpirationHeight uint64 `json:"maxExpirationHeight"` + + // pagination + Limit int `json:"limit"` + Offset int `json:"offset"` + + // sorting + SortField string `json:"sortField"` + SortDesc bool `json:"sortDesc"` + } + // A SectorChange defines an action to be performed on a contract's sectors. SectorChange struct { Action SectorAction @@ -141,12 +237,12 @@ type ( // A ContractUpdater is used to atomically update a contract's sectors // and metadata. ContractUpdater struct { - store ContractStore - log *zap.Logger + manager *Manager + store ContractStore + log *zap.Logger - rootsCache *lru.TwoQueueCache[types.FileContractID, []types.Hash256] // reference to the cache in the contract manager - once sync.Once - done func() // done is called when the updater is closed. + once sync.Once + done func() // done is called when the updater is closed. contractID types.FileContractID sectorActions []SectorChange @@ -164,17 +260,55 @@ var ( ErrContractExists = errors.New("contract already exists") ) -// Add returns the sum of two usages. -func (u Usage) Add(b Usage) (c Usage) { +// Add returns u + b +func (a Usage) Add(b Usage) (c Usage) { return Usage{ - RPCRevenue: u.RPCRevenue.Add(b.RPCRevenue), - StorageRevenue: u.StorageRevenue.Add(b.StorageRevenue), - EgressRevenue: u.EgressRevenue.Add(b.EgressRevenue), - IngressRevenue: u.IngressRevenue.Add(b.IngressRevenue), - AccountFunding: u.AccountFunding.Add(b.AccountFunding), - RiskedCollateral: u.RiskedCollateral.Add(b.RiskedCollateral), - RegistryRead: u.RegistryRead.Add(b.RegistryRead), - RegistryWrite: u.RegistryWrite.Add(b.RegistryWrite), + RPCRevenue: a.RPCRevenue.Add(b.RPCRevenue), + StorageRevenue: a.StorageRevenue.Add(b.StorageRevenue), + EgressRevenue: a.EgressRevenue.Add(b.EgressRevenue), + IngressRevenue: a.IngressRevenue.Add(b.IngressRevenue), + AccountFunding: a.AccountFunding.Add(b.AccountFunding), + RiskedCollateral: a.RiskedCollateral.Add(b.RiskedCollateral), + RegistryRead: a.RegistryRead.Add(b.RegistryRead), + RegistryWrite: a.RegistryWrite.Add(b.RegistryWrite), + } +} + +// Sub returns a - b +func (a Usage) Sub(b Usage) (c Usage) { + return Usage{ + RPCRevenue: a.RPCRevenue.Sub(b.RPCRevenue), + StorageRevenue: a.StorageRevenue.Sub(b.StorageRevenue), + EgressRevenue: a.EgressRevenue.Sub(b.EgressRevenue), + IngressRevenue: a.IngressRevenue.Sub(b.IngressRevenue), + AccountFunding: a.AccountFunding.Sub(b.AccountFunding), + RiskedCollateral: a.RiskedCollateral.Sub(b.RiskedCollateral), + RegistryRead: a.RegistryRead.Sub(b.RegistryRead), + RegistryWrite: a.RegistryWrite.Sub(b.RegistryWrite), + } +} + +// Add returns u + b +func (a V2Usage) Add(b V2Usage) (c V2Usage) { + return V2Usage{ + RPCRevenue: a.RPCRevenue.Add(b.RPCRevenue), + StorageRevenue: a.StorageRevenue.Add(b.StorageRevenue), + EgressRevenue: a.EgressRevenue.Add(b.EgressRevenue), + IngressRevenue: a.IngressRevenue.Add(b.IngressRevenue), + AccountFunding: a.AccountFunding.Add(b.AccountFunding), + RiskedCollateral: a.RiskedCollateral.Add(b.RiskedCollateral), + } +} + +// Sub returns a - b +func (a V2Usage) Sub(b V2Usage) (c V2Usage) { + return V2Usage{ + RPCRevenue: a.RPCRevenue.Sub(b.RPCRevenue), + StorageRevenue: a.StorageRevenue.Sub(b.StorageRevenue), + EgressRevenue: a.EgressRevenue.Sub(b.EgressRevenue), + IngressRevenue: a.IngressRevenue.Sub(b.IngressRevenue), + AccountFunding: a.AccountFunding.Sub(b.AccountFunding), + RiskedCollateral: a.RiskedCollateral.Sub(b.RiskedCollateral), } } @@ -339,7 +473,7 @@ func (cu *ContractUpdater) Commit(revision SignedRevision, usage Usage) error { // clear the committed sector actions cu.sectorActions = cu.sectorActions[:0] // update the roots cache - cu.rootsCache.Add(revision.Revision.ParentID, append([]types.Hash256(nil), cu.sectorRoots...)) + cu.manager.setSectorRoots(cu.contractID, cu.sectorRoots) cu.log.Debug("contract update committed", zap.String("contractID", revision.Revision.ParentID.String()), zap.Uint64("revision", revision.Revision.RevisionNumber), zap.Duration("elapsed", time.Since(start))) return nil } diff --git a/host/contracts/contracts_test.go b/host/contracts/contracts_test.go index 25f14528..7a60a939 100644 --- a/host/contracts/contracts_test.go +++ b/host/contracts/contracts_test.go @@ -2,18 +2,14 @@ package contracts_test import ( "math" - "path/filepath" "reflect" "testing" rhp2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" - "go.sia.tech/hostd/alerts" "go.sia.tech/hostd/host/contracts" "go.sia.tech/hostd/host/storage" - "go.sia.tech/hostd/internal/test" - "go.sia.tech/hostd/persist/sqlite" - "go.sia.tech/hostd/webhooks" + "go.sia.tech/hostd/internal/testutil" "go.uber.org/zap/zaptest" "lukechampine.com/frand" ) @@ -22,49 +18,21 @@ func TestContractUpdater(t *testing.T) { const sectors = 256 hostKey := types.NewPrivateKeyFromSeed(frand.Bytes(32)) renterKey := types.NewPrivateKeyFromSeed(frand.Bytes(32)) - dir := t.TempDir() log := zaptest.NewLogger(t) - db, err := sqlite.OpenDatabase(filepath.Join(dir, "hostd.db"), log.Named("sqlite")) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - node, err := test.NewWallet(hostKey, t.TempDir(), log.Named("wallet")) - if err != nil { - t.Fatal(err) - } - defer node.Close() - - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - s, err := storage.NewVolumeManager(db, am, node.ChainManager(), log.Named("storage"), sectorCacheSize) - if err != nil { - t.Fatal(err) - } - defer s.Close() + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) // create a fake volume so disk space is not used - id, err := db.AddVolume("test", false) + id, err := node.Store.AddVolume("test", false) if err != nil { t.Fatal(err) - } else if err := db.GrowVolume(id, sectors); err != nil { + } else if err := node.Store.GrowVolume(id, sectors); err != nil { t.Fatal(err) - } else if err := db.SetAvailable(id, true); err != nil { + } else if err := node.Store.SetAvailable(id, true); err != nil { t.Fatal(err) } - c, err := contracts.NewManager(db, am, s, node.ChainManager(), node.TPool(), node, log.Named("contracts")) - if err != nil { - t.Fatal(err) - } - defer c.Close() - contractUnlockConditions := types.UnlockConditions{ PublicKeys: []types.UnlockKey{ renterKey.PublicKey().UnlockKey(), @@ -84,7 +52,7 @@ func TestContractUpdater(t *testing.T) { }, } - if err := c.AddContract(rev, []types.Transaction{}, types.ZeroCurrency, contracts.Usage{}); err != nil { + if err := node.Contracts.AddContract(rev, []types.Transaction{}, types.ZeroCurrency, contracts.Usage{}); err != nil { t.Fatal(err) } @@ -122,7 +90,7 @@ func TestContractUpdater(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - updater, err := c.ReviseContract(rev.Revision.ParentID) + updater, err := node.Contracts.ReviseContract(rev.Revision.ParentID) if err != nil { t.Fatal(err) } @@ -139,7 +107,7 @@ func TestContractUpdater(t *testing.T) { for i := 0; i < test.append; i++ { root := frand.Entropy256() - release, err := db.StoreSector(root, func(loc storage.SectorLocation, exists bool) error { return nil }) + release, err := node.Store.StoreSector(root, func(loc storage.SectorLocation, exists bool) error { return nil }) if err != nil { t.Fatal(err) } @@ -174,14 +142,14 @@ func TestContractUpdater(t *testing.T) { } // check that the sector roots are correct in the database - dbRoots, err := db.SectorRoots(rev.Revision.ParentID) + allRoots, err := node.Store.SectorRoots() if err != nil { t.Fatal(err) - } else if rhp2.MetaRoot(dbRoots) != rhp2.MetaRoot(roots) { + } else if rhp2.MetaRoot(allRoots[rev.Revision.ParentID]) != rhp2.MetaRoot(roots) { t.Fatal("wrong merkle root in database") } // check that the cache sector roots are correct - cachedRoots, err := c.SectorRoots(rev.Revision.ParentID) + cachedRoots := node.Contracts.SectorRoots(rev.Revision.ParentID) if err != nil { t.Fatal(err) } else if rhp2.MetaRoot(cachedRoots) != rhp2.MetaRoot(roots) { @@ -218,3 +186,34 @@ func TestUsageAdd(t *testing.T) { } } } + +func TestUsageSub(t *testing.T) { + var ua, ub contracts.Usage + var expected contracts.Usage + uav := reflect.ValueOf(&ua).Elem() + ubv := reflect.ValueOf(&ub).Elem() + ev := reflect.ValueOf(&expected).Elem() + + for i := 0; i < uav.NumField(); i++ { + va := types.NewCurrency(frand.Uint64n(math.MaxUint64), 0) + vb := types.NewCurrency(frand.Uint64n(math.MaxUint64), 0) + if va.Cmp(vb) < 0 { + va, vb = vb, va + } + total := va.Sub(vb) + + uav.Field(i).Set(reflect.ValueOf(va)) + ubv.Field(i).Set(reflect.ValueOf(vb)) + ev.Field(i).Set(reflect.ValueOf(total)) + } + + total := ua.Sub(ub) + tv := reflect.ValueOf(total) + for i := 0; i < tv.NumField(); i++ { + va := ev.Field(i).Interface().(types.Currency) + vb := tv.Field(i).Interface().(types.Currency) + if !va.Equals(vb) { + t.Fatalf("field %v: expected %v, got %v", tv.Type().Field(i).Name, va, vb) + } + } +} diff --git a/host/contracts/integrity.go b/host/contracts/integrity.go index f840b965..aee4e633 100644 --- a/host/contracts/integrity.go +++ b/host/contracts/integrity.go @@ -62,7 +62,7 @@ func (i *IntegrityResult) UnmarshalJSON(b []byte) error { // CheckIntegrity checks the integrity of a contract's sector roots on disk. The // result of every checked sector is sent on the returned channel. The channel is closed // when all checks are complete. -func (cm *ContractManager) CheckIntegrity(ctx context.Context, contractID types.FileContractID) (<-chan IntegrityResult, uint64, error) { +func (cm *Manager) CheckIntegrity(ctx context.Context, contractID types.FileContractID) (<-chan IntegrityResult, uint64, error) { // lock the contract to ensure it doesn't get modified before the sector // roots are retrieved. contract, err := cm.Lock(ctx, contractID) @@ -73,10 +73,8 @@ func (cm *ContractManager) CheckIntegrity(ctx context.Context, contractID types. expectedRoots := contract.Revision.Filesize / rhp2.SectorSize - roots, err := cm.getSectorRoots(contractID) - if err != nil { - return nil, 0, fmt.Errorf("failed to get sector roots: %w", err) - } else if uint64(len(roots)) != expectedRoots { + roots := cm.getSectorRoots(contractID) + if uint64(len(roots)) != expectedRoots { return nil, 0, fmt.Errorf("expected %v sector roots, got %v", expectedRoots, len(roots)) } else if calculated := rhp2.MetaRoot(roots); contract.Revision.FileMerkleRoot != calculated { return nil, 0, fmt.Errorf("expected Merkle root %v, got %v", contract.Revision.FileMerkleRoot, calculated) diff --git a/host/contracts/integrity_test.go b/host/contracts/integrity_test.go index 7efb0160..341af351 100644 --- a/host/contracts/integrity_test.go +++ b/host/contracts/integrity_test.go @@ -1,3 +1,5 @@ +//go:build ignore + package contracts_test import ( @@ -12,12 +14,7 @@ import ( rhp2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" - "go.sia.tech/hostd/alerts" "go.sia.tech/hostd/host/contracts" - "go.sia.tech/hostd/host/storage" - "go.sia.tech/hostd/internal/test" - "go.sia.tech/hostd/webhooks" - stypes "go.sia.tech/siad/types" "go.uber.org/zap/zaptest" "lukechampine.com/frand" ) @@ -72,25 +69,9 @@ func TestIntegrityResultJSON(t *testing.T) { func TestCheckIntegrity(t *testing.T) { hostKey, renterKey := types.NewPrivateKeyFromSeed(frand.Bytes(32)), types.NewPrivateKeyFromSeed(frand.Bytes(32)) - log := zaptest.NewLogger(t) dir := t.TempDir() - node, err := test.NewWallet(hostKey, dir, log) - if err != nil { - t.Fatal(err) - } - defer node.Close() - - webhookReporter, err := webhooks.NewManager(node.Store(), log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - s, err := storage.NewVolumeManager(node.Store(), am, node.ChainManager(), log.Named("storage"), 0) - if err != nil { - t.Fatal(err) - } - defer s.Close() + log := zaptest.NewLogger(t) + db, cm, _, wm, c, s := testNode(t, hostKey, log) result := make(chan error, 1) if _, err := s.AddVolume(context.Background(), filepath.Join(dir, "data.dat"), 10, result); err != nil { @@ -99,17 +80,8 @@ func TestCheckIntegrity(t *testing.T) { t.Fatal(err) } - c, err := contracts.NewManager(node.Store(), am, s, node.ChainManager(), node.TPool(), node, log.Named("contracts")) - if err != nil { - t.Fatal(err) - } - defer c.Close() - - // note: many more blocks than necessary are mined to ensure all forks have activated - if err := node.MineBlocks(node.Address(), int(stypes.MaturityDelay*4)); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) // sync time + // mine enough for the wallet to have some funds + mineAndSync(t, cm, db, wm.Address(), 150) rev, err := formContract(renterKey, hostKey, 50, 60, types.Siacoins(500), types.Siacoins(1000), c, node, node.ChainManager(), node.TPool()) if err != nil { diff --git a/host/contracts/manager.go b/host/contracts/manager.go index cab7ff52..cedb2747 100644 --- a/host/contracts/manager.go +++ b/host/contracts/manager.go @@ -1,62 +1,50 @@ package contracts import ( - "bytes" "context" "errors" "fmt" "math" - "strings" "sync" - "sync/atomic" "time" - lru "github.com/hashicorp/golang-lru/v2" - "gitlab.com/NebulousLabs/encoding" "go.sia.tech/core/consensus" rhp2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" "go.sia.tech/hostd/alerts" - "go.sia.tech/hostd/internal/chain" "go.sia.tech/hostd/internal/threadgroup" - "go.sia.tech/siad/modules" "go.uber.org/zap" ) -const ( - // sectorRootCacheSize is the number of contracts' sector roots to cache. - // Caching prevents frequently updated contracts from continuously hitting the - // DB. This is left as a hard-coded small value to limit memory usage since - // contracts can contain any number of sector roots - sectorRootCacheSize = 30 -) - type ( - contractChange struct { - id types.FileContractID - index types.ChainIndex - } - // ChainManager defines the interface required by the contract manager to // interact with the consensus set. ChainManager interface { + Tip() types.ChainIndex TipState() consensus.State - IndexAtHeight(height uint64) (types.ChainIndex, error) - Subscribe(s modules.ConsensusSetSubscriber, ccID modules.ConsensusChangeID, cancel <-chan struct{}) error + BestIndex(height uint64) (types.ChainIndex, bool) + UnconfirmedParents(txn types.Transaction) []types.Transaction + AddPoolTransactions([]types.Transaction) (known bool, err error) + AddV2PoolTransactions(types.ChainIndex, []types.V2Transaction) (known bool, err error) + RecommendedFee() types.Currency + } + + // A Syncer broadcasts transactions to its peers + Syncer interface { + BroadcastTransactionSet([]types.Transaction) + BroadcastV2TransactionSet(types.ChainIndex, []types.V2Transaction) } // A Wallet manages Siacoins and funds transactions Wallet interface { Address() types.Address UnlockConditions() types.UnlockConditions - FundTransaction(txn *types.Transaction, amount types.Currency) (toSign []types.Hash256, release func(), err error) - SignTransaction(cs consensus.State, txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) error - } + ReleaseInputs(txns []types.Transaction, v2txns []types.V2Transaction) + FundTransaction(txn *types.Transaction, amount types.Currency, useUnconfirmed bool) ([]types.Hash256, error) + SignTransaction(txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) - // A TransactionPool broadcasts transactions to the network. - TransactionPool interface { - AcceptTransactionSet([]types.Transaction) error - RecommendedFee() types.Currency + FundV2Transaction(txn *types.V2Transaction, amount types.Currency, useUnconfirmed bool) (types.ChainIndex, []int, error) + SignV2Inputs(txn *types.V2Transaction, toSign []int) } // A StorageManager stores and retrieves sectors. @@ -76,9 +64,10 @@ type ( waiters int } - // A ContractManager manages contracts' lifecycle - ContractManager struct { - blockHeight uint64 // ensure 64-bit alignment on 32-bit systems + // A Manager manages contracts' lifecycle + Manager struct { + rejectBuffer uint64 + revisionSubmissionBuffer uint64 store ContractStore tg *threadgroup.ThreadGroup @@ -87,40 +76,38 @@ type ( alerts Alerts storage StorageManager chain ChainManager - tpool TransactionPool + syncer Syncer wallet Wallet - processQueue chan uint64 // signals that the contract manager should process actions for a given block height - - // caches the sector roots of contracts to avoid hitting the DB - // for frequently accessed contracts. The cache is limited to a - // small number of contracts to limit memory usage. - rootsCache *lru.TwoQueueCache[types.FileContractID, []types.Hash256] - - mu sync.Mutex // guards the following fields - locks map[types.FileContractID]*locker // contracts must be locked while they are being modified + mu sync.Mutex // guards the following fields + // caches the sector roots of all contracts to avoid long reads from + // the store + sectorRoots map[types.FileContractID][]types.Hash256 + locks map[types.FileContractID]*locker // contracts must be locked while they are being modified } ) -func (cm *ContractManager) getSectorRoots(id types.FileContractID) ([]types.Hash256, error) { - // check the cache first - roots, ok := cm.rootsCache.Get(id) +func (cm *Manager) getSectorRoots(id types.FileContractID) []types.Hash256 { + cm.mu.Lock() + defer cm.mu.Unlock() + + roots, ok := cm.sectorRoots[id] if !ok { - var err error - // if the cache doesn't have the roots, read them from the store - roots, err = cm.store.SectorRoots(id) - if err != nil { - return nil, fmt.Errorf("failed to get sector roots: %w", err) - } - // add the roots to the cache - cm.rootsCache.Add(id, roots) + return nil } // return a deep copy of the roots - return append([]types.Hash256(nil), roots...), nil + return append([]types.Hash256(nil), roots...) +} + +func (cm *Manager) setSectorRoots(id types.FileContractID, roots []types.Hash256) { + cm.mu.Lock() + defer cm.mu.Unlock() + // deep copy the roots + cm.sectorRoots[id] = append([]types.Hash256(nil), roots...) } // Lock locks a contract for modification. -func (cm *ContractManager) Lock(ctx context.Context, id types.FileContractID) (SignedRevision, error) { +func (cm *Manager) Lock(ctx context.Context, id types.FileContractID) (SignedRevision, error) { ctx, cancel, err := cm.tg.AddContext(ctx) if err != nil { return SignedRevision{}, err @@ -132,7 +119,7 @@ func (cm *ContractManager) Lock(ctx context.Context, id types.FileContractID) (S if err != nil { cm.mu.Unlock() return SignedRevision{}, fmt.Errorf("failed to get contract: %w", err) - } else if err := isGoodForModification(contract, cm.chain.TipState().Index.Height); err != nil { + } else if err := cm.isGoodForModification(contract); err != nil { cm.mu.Unlock() return SignedRevision{}, fmt.Errorf("contract is not good for modification: %w", err) } @@ -157,7 +144,7 @@ func (cm *ContractManager) Lock(ctx context.Context, id types.FileContractID) (S contract, err := cm.store.Contract(id) if err != nil { return SignedRevision{}, fmt.Errorf("failed to get contract: %w", err) - } else if err := isGoodForModification(contract, cm.chain.TipState().Index.Height); err != nil { + } else if err := cm.isGoodForModification(contract); err != nil { return SignedRevision{}, fmt.Errorf("contract is not good for modification: %w", err) } return contract.SignedRevision, nil @@ -167,7 +154,7 @@ func (cm *ContractManager) Lock(ctx context.Context, id types.FileContractID) (S } // Unlock unlocks a locked contract. -func (cm *ContractManager) Unlock(id types.FileContractID) { +func (cm *Manager) Unlock(id types.FileContractID) { cm.mu.Lock() defer cm.mu.Unlock() lock, exists := cm.locks[id] @@ -183,18 +170,28 @@ func (cm *ContractManager) Unlock(id types.FileContractID) { // Contracts returns a paginated list of contracts matching the filter and the // total number of contracts matching the filter. -func (cm *ContractManager) Contracts(filter ContractFilter) ([]Contract, int, error) { +func (cm *Manager) Contracts(filter ContractFilter) ([]Contract, int, error) { return cm.store.Contracts(filter) } // Contract returns the contract with the given id. -func (cm *ContractManager) Contract(id types.FileContractID) (Contract, error) { +func (cm *Manager) Contract(id types.FileContractID) (Contract, error) { return cm.store.Contract(id) } +// V2Contract returns the v2 contract with the given ID. +func (cm *Manager) V2Contract(id types.FileContractID) (V2Contract, error) { + return cm.store.V2Contract(id) +} + +// V2ContractElement returns the latest v2 state element with the given ID. +func (cm *Manager) V2ContractElement(id types.FileContractID) (types.V2FileContractElement, error) { + return cm.store.V2ContractElement(id) +} + // AddContract stores the provided contract, should error if the contract // already exists. -func (cm *ContractManager) AddContract(revision SignedRevision, formationSet []types.Transaction, lockedCollateral types.Currency, initialUsage Usage) error { +func (cm *Manager) AddContract(revision SignedRevision, formationSet []types.Transaction, lockedCollateral types.Currency, initialUsage Usage) error { done, err := cm.tg.Add() if err != nil { return err @@ -209,7 +206,7 @@ func (cm *ContractManager) AddContract(revision SignedRevision, formationSet []t // RenewContract renews a contract. It is expected that the existing // contract will be cleared. -func (cm *ContractManager) RenewContract(renewal SignedRevision, existing SignedRevision, formationSet []types.Transaction, lockedCollateral types.Currency, clearingUsage, initialUsage Usage) error { +func (cm *Manager) RenewContract(renewal SignedRevision, existing SignedRevision, formationSet []types.Transaction, lockedCollateral types.Currency, clearingUsage, initialUsage Usage) error { done, err := cm.tg.Add() if err != nil { return err @@ -217,247 +214,201 @@ func (cm *ContractManager) RenewContract(renewal SignedRevision, existing Signed defer done() // sanity checks + existingRoots := cm.getSectorRoots(existing.Revision.ParentID) if existing.Revision.FileMerkleRoot != (types.Hash256{}) { return errors.New("existing contract must be cleared") } else if existing.Revision.Filesize != 0 { return errors.New("existing contract must be cleared") } else if existing.Revision.RevisionNumber != types.MaxRevisionNumber { return errors.New("existing contract must be cleared") + } else if renewal.Revision.Filesize != uint64(rhp2.SectorSize*len(existingRoots)) { + return errors.New("renewal contract must have same file size as existing contract") + } else if renewal.Revision.FileMerkleRoot != rhp2.MetaRoot(existingRoots) { + return errors.New("renewal root does not match existing roots") } if err := cm.store.RenewContract(renewal, existing, formationSet, lockedCollateral, clearingUsage, initialUsage, cm.chain.TipState().Index.Height); err != nil { return err } + cm.setSectorRoots(renewal.Revision.ParentID, existingRoots) cm.log.Debug("contract renewed", zap.Stringer("renewalID", renewal.Revision.ParentID), zap.Stringer("existingID", existing.Revision.ParentID)) return nil } -// SectorRoots returns the roots of all sectors stored by the contract. -func (cm *ContractManager) SectorRoots(id types.FileContractID) ([]types.Hash256, error) { +// ReviseV2Contract atomically updates a contract and its associated sector roots. +func (cm *Manager) ReviseV2Contract(contractID types.FileContractID, revision types.V2FileContract, roots []types.Hash256, usage Usage) error { done, err := cm.tg.Add() if err != nil { - return nil, err + return err } defer done() - return cm.getSectorRoots(id) -} + existing, err := cm.store.V2Contract(contractID) + if err != nil { + return fmt.Errorf("failed to get existing contract: %w", err) + } -// ScanHeight returns the height of the last block processed by the contract -func (cm *ContractManager) ScanHeight() uint64 { - return atomic.LoadUint64(&cm.blockHeight) + // validate the contract revision fields + switch { + case existing.RenterPublicKey != revision.RenterPublicKey: + return errors.New("renter public key does not match") + case existing.HostPublicKey != revision.HostPublicKey: + return errors.New("host public key does not match") + case existing.ProofHeight != revision.ProofHeight: + return errors.New("proof height does not match") + case existing.ExpirationHeight != revision.ExpirationHeight: + return errors.New("expiration height does not match") + case revision.Filesize != uint64(rhp2.SectorSize*len(roots)): + return errors.New("revision has incorrect file size") + } + + // validate signatures + sigHash := cm.chain.TipState().ContractSigHash(revision) + if !revision.RenterPublicKey.VerifyHash(sigHash, revision.RenterSignature) { + return errors.New("renter signature is invalid") + } else if !revision.HostPublicKey.VerifyHash(sigHash, revision.HostSignature) { + return errors.New("host signature is invalid") + } + + // validate contract Merkle root + metaRoot := rhp2.MetaRoot(roots) + if revision.FileMerkleRoot != metaRoot { + return errors.New("revision root does not match") + } else if revision.Filesize != uint64(rhp2.SectorSize*len(roots)) { + return errors.New("revision has incorrect file size") + } + + // revise the contract in the store + if err := cm.store.ReviseV2Contract(contractID, revision, roots, usage); err != nil { + return err + } + // update the sector roots cache + cm.setSectorRoots(contractID, roots) + cm.log.Debug("contract revised", zap.Stringer("contractID", contractID), zap.Uint64("revisionNumber", revision.RevisionNumber)) + return nil } -// ProcessConsensusChange applies a block update to the contract manager. -func (cm *ContractManager) ProcessConsensusChange(cc modules.ConsensusChange) { +// AddV2Contract stores the provided contract, should error if the contract +// already exists. +func (cm *Manager) AddV2Contract(formation V2FormationTransactionSet, usage V2Usage) error { done, err := cm.tg.Add() if err != nil { - return + return err } defer done() - log := cm.log.Named("consensusChange") - - // calculate the block height of the first reverted diff - blockHeight := uint64(cc.BlockHeight) - uint64(len(cc.AppliedBlocks)) + uint64(len(cc.RevertedBlocks)) + 1 - var revertedFormations, revertedResolutions []contractChange - revertedRevisions := make(map[types.FileContractID]contractChange) - for _, reverted := range cc.RevertedBlocks { - index := types.ChainIndex{ - Height: blockHeight, - ID: types.BlockID(reverted.ID()), - } - for _, transaction := range reverted.Transactions { - for i := range transaction.FileContracts { - contractID := types.FileContractID(transaction.FileContractID(uint64(i))) - revertedFormations = append(revertedFormations, contractChange{contractID, index}) - } - - for _, rev := range transaction.FileContractRevisions { - contractID := types.FileContractID(rev.ParentID) - revertedRevisions[contractID] = contractChange{types.FileContractID(rev.ParentID), index} // TODO: revert to the previous revision number, instead of setting to 0 - } - - for _, proof := range transaction.StorageProofs { - contractID := types.FileContractID(proof.ParentID) - revertedResolutions = append(revertedResolutions, contractChange{contractID, index}) - } - } - blockHeight-- + + formationSet := formation.TransactionSet + if len(formationSet) == 0 { + return errors.New("no formation transactions provided") + } else if len(formationSet[len(formationSet)-1].FileContracts) != 1 { + return errors.New("last transaction must contain one file contract") } - var appliedFormations, appliedResolutions []contractChange - appliedRevisions := make(map[types.FileContractID]types.FileContractRevision) - for _, applied := range cc.AppliedBlocks { - index := types.ChainIndex{ - Height: blockHeight, - ID: types.BlockID(applied.ID()), - } - for _, transaction := range applied.Transactions { - for i := range transaction.FileContracts { - contractID := types.FileContractID(transaction.FileContractID(uint64(i))) - appliedFormations = append(appliedFormations, contractChange{contractID, index}) - } - - for _, rev := range transaction.FileContractRevisions { - contractID := types.FileContractID(rev.ParentID) - var revision types.FileContractRevision - convertToCore(rev, &revision) - appliedRevisions[contractID] = revision - } - - for _, proof := range transaction.StorageProofs { - contractID := types.FileContractID(proof.ParentID) - appliedResolutions = append(appliedResolutions, contractChange{contractID, index}) - } - } - blockHeight++ - } - - err = cm.store.UpdateContractState(cc.ID, uint64(cc.BlockHeight), func(tx UpdateStateTransaction) error { - for _, reverted := range revertedFormations { - if relevant, err := tx.ContractRelevant(reverted.id); err != nil { - return fmt.Errorf("failed to check if contract %v is relevant: %w", reverted, err) - } else if !relevant { - continue - } else if err := tx.RevertFormation(reverted.id); err != nil { - return fmt.Errorf("failed to revert formation: %w", err) - } - - log.Warn("contract formation reverted", zap.Stringer("contractID", reverted.id), zap.Stringer("block", reverted.index)) - cm.alerts.Register(alerts.Alert{ - ID: types.Hash256(reverted.id), - Severity: alerts.SeverityWarning, - Message: "Contract formation reverted", - Data: map[string]any{ - "contractID": reverted.id, - "index": reverted.index, - }, - Timestamp: time.Now(), - }) - } + formationTxn := formationSet[len(formationSet)-1] + fc := formationTxn.FileContracts[0] + contractID := formationTxn.V2FileContractID(formationTxn.ID(), 0) - for _, reverted := range revertedRevisions { - if relevant, err := tx.ContractRelevant(reverted.id); err != nil { - return fmt.Errorf("failed to check if contract %v is relevant: %w", reverted, err) - } else if !relevant { - continue - } else if err := tx.RevertRevision(reverted.id); err != nil { - return fmt.Errorf("failed to revert revision: %w", err) - } - - log.Warn("contract revision reverted", zap.Stringer("contractID", reverted.id), zap.Stringer("block", reverted.index)) - cm.alerts.Register(alerts.Alert{ - ID: types.Hash256(reverted.id), - Severity: alerts.SeverityWarning, - Message: "Contract revision reverted", - Data: map[string]any{ - "contractID": reverted.id, - "index": reverted.index, - }, - Timestamp: time.Now(), - }) - } + contract := V2Contract{ + V2FileContract: fc, - for _, reverted := range revertedResolutions { - if relevant, err := tx.ContractRelevant(reverted.id); err != nil { - return fmt.Errorf("failed to check if contract %v is relevant: %w", reverted.id, err) - } else if !relevant { - continue - } else if err := tx.RevertResolution(reverted.id); err != nil { - return fmt.Errorf("failed to revert proof: %w", err) - } - - log.Warn("contract resolution reverted", zap.Stringer("contractID", reverted.id), zap.Stringer("block", reverted.index)) - cm.alerts.Register(alerts.Alert{ - ID: types.Hash256(reverted.id), - Severity: alerts.SeverityWarning, - Message: "Contract resolution reverted", - Data: map[string]any{ - "contractID": reverted.id, - "index": reverted.index, - }, - Timestamp: time.Now(), - }) - } + ID: contractID, + Status: V2ContractStatusPending, + NegotiationHeight: cm.chain.Tip().Height, + Usage: usage, + } - for _, applied := range appliedFormations { - if relevant, err := tx.ContractRelevant(applied.id); err != nil { - return fmt.Errorf("failed to check if contract %v is relevant: %w", applied.id, err) - } else if !relevant { - continue - } else if err := tx.ConfirmFormation(applied.id); err != nil { - return fmt.Errorf("failed to apply formation: %w", err) - } - - log.Info("contract formation confirmed", zap.Stringer("contractID", applied.id), zap.Stringer("block", applied.index)) - cm.alerts.Dismiss(types.Hash256(applied.id)) // dismiss any lifecycle alerts for this contract - } + if err := cm.store.AddV2Contract(contract, formation); err != nil { + return err + } + cm.log.Debug("contract formed", zap.Stringer("contractID", contractID)) + return nil +} - for _, applied := range appliedRevisions { - if relevant, err := tx.ContractRelevant(applied.ParentID); err != nil { - return fmt.Errorf("failed to check if contract %v is relevant: %w", applied.ParentID, err) - } else if !relevant { - continue - } else if err := tx.ConfirmRevision(applied); err != nil { - return fmt.Errorf("failed to apply revision: %w", err) - } - - log.Info("contract revision confirmed", zap.Stringer("contractID", applied.ParentID), zap.Uint64("revisionNumber", applied.RevisionNumber)) - cm.alerts.Dismiss(types.Hash256(applied.ParentID)) // dismiss any lifecycle alerts for this contract - } +// RenewV2Contract renews a contract. It is expected that the existing +// contract will be cleared. +func (cm *Manager) RenewV2Contract(renewal V2FormationTransactionSet, usage V2Usage) error { + done, err := cm.tg.Add() + if err != nil { + return err + } + defer done() - for _, applied := range appliedResolutions { - if relevant, err := tx.ContractRelevant(applied.id); err != nil { - return fmt.Errorf("failed to check if contract %v is relevant: %w", applied, err) - } else if !relevant { - continue - } else if err := tx.ConfirmResolution(applied.id, applied.index.Height); err != nil { - return fmt.Errorf("failed to apply proof: %w", err) - } - - log.Info("contract resolution confirmed", zap.Stringer("contractID", applied.id), zap.Stringer("block", applied.index)) - cm.alerts.Dismiss(types.Hash256(applied.id)) // dismiss any lifecycle alerts for this contract - } - return nil - }) + renewalSet := renewal.TransactionSet + if len(renewalSet) == 0 { + return errors.New("no renewal transactions provided") + } else if len(renewalSet[len(renewalSet)-1].FileContractResolutions) != 1 { + return errors.New("last transaction must contain one file contract resolution") + } + + resolutionTxn := renewalSet[len(renewalSet)-1] + resolution, ok := resolutionTxn.FileContractResolutions[0].Resolution.(*types.V2FileContractRenewal) + if !ok { + return fmt.Errorf("unexpected resolution type %T", resolutionTxn.FileContractResolutions[0].Resolution) + } + + parentID := resolutionTxn.FileContractResolutions[0].Parent.ID + existing, err := cm.store.V2Contract(types.FileContractID(parentID)) if err != nil { - log.Error("failed to process consensus change", zap.Error(err)) - return + return fmt.Errorf("failed to get existing contract: %w", err) } + finalRevision := resolution.FinalRevision + fc := resolution.NewContract - scanHeight := uint64(cc.BlockHeight) - log.Debug("consensus change applied", zap.Uint64("height", scanHeight), zap.String("changeID", cc.ID.String())) + // sanity checks + if finalRevision.FileMerkleRoot != (types.Hash256{}) { + return errors.New("existing contract must be cleared") + } else if finalRevision.Filesize != 0 { + return errors.New("existing contract must be cleared") + } else if finalRevision.RevisionNumber != types.MaxRevisionNumber { + return errors.New("existing contract must be cleared") + } else if fc.Filesize != existing.Filesize { + return errors.New("renewal contract must have same file size as existing contract") + } else if fc.FileMerkleRoot != existing.FileMerkleRoot { + return errors.New("renewal root does not match existing roots") + } - // if the last block is more than 3 days old, skip action processing until - // consensus is caught up - blockTime := time.Unix(int64(cc.AppliedBlocks[len(cc.AppliedBlocks)-1].Timestamp), 0) - if time.Since(blockTime) > 72*time.Hour { - return + existingID := types.FileContractID(existing.ID) + existingRoots := cm.getSectorRoots(existingID) + if fc.FileMerkleRoot != rhp2.MetaRoot(existingRoots) { + return errors.New("renewal root does not match existing roots") } - // perform actions in a separate goroutine to avoid deadlock in tpool. - // triggers the processActions goroutine to process the block - go func() { - cm.processQueue <- uint64(cc.BlockHeight) - }() + contract := V2Contract{ + V2FileContract: fc, + + ID: existingID.V2RenewalID(), + Status: V2ContractStatusPending, + NegotiationHeight: cm.chain.Tip().Height, + RenewedFrom: existingID, + Usage: usage, + } + + if err := cm.store.RenewV2Contract(contract, renewal, existingID, finalRevision); err != nil { + return err + } + cm.setSectorRoots(contract.ID, existingRoots) + cm.log.Debug("contract renewed", zap.Stringer("formedID", contract.ID), zap.Stringer("existingID", existingID)) + return nil +} + +// SectorRoots returns the roots of all sectors stored by the contract. +func (cm *Manager) SectorRoots(id types.FileContractID) []types.Hash256 { + return cm.getSectorRoots(id) } // ReviseContract initializes a new contract updater for the given contract. -func (cm *ContractManager) ReviseContract(contractID types.FileContractID) (*ContractUpdater, error) { +func (cm *Manager) ReviseContract(contractID types.FileContractID) (*ContractUpdater, error) { done, err := cm.tg.Add() if err != nil { return nil, err } - roots, err := cm.getSectorRoots(contractID) - if err != nil { - return nil, fmt.Errorf("failed to get sector roots: %w", err) - } + roots := cm.getSectorRoots(contractID) return &ContractUpdater{ - store: cm.store, - log: cm.log.Named("contractUpdater"), + manager: cm, + store: cm.store, + log: cm.log.Named("contractUpdater"), - rootsCache: cm.rootsCache, contractID: contractID, sectorRoots: roots, // roots is already a deep copy oldRoots: append([]types.Hash256(nil), roots...), @@ -467,78 +418,51 @@ func (cm *ContractManager) ReviseContract(contractID types.FileContractID) (*Con } // Close closes the contract manager. -func (cm *ContractManager) Close() error { +func (cm *Manager) Close() error { cm.tg.Stop() return nil } // isGoodForModification validates if a contract can be modified -func isGoodForModification(contract Contract, height uint64) error { +func (cm *Manager) isGoodForModification(contract Contract) error { + height := cm.chain.TipState().Index.Height switch { case contract.Status != ContractStatusActive && contract.Status != ContractStatusPending: return fmt.Errorf("contract status is %v, contract cannot be used", contract.Status) - case (height + RevisionSubmissionBuffer) > contract.Revision.WindowStart: - return fmt.Errorf("contract is too close to the proof window to be revised (%v > %v)", height+RevisionSubmissionBuffer, contract.Revision.WindowStart) + case (height + cm.revisionSubmissionBuffer) > contract.Revision.WindowStart: + return fmt.Errorf("contract is too close to the proof window to be revised (%v > %v)", height+cm.revisionSubmissionBuffer, contract.Revision.WindowStart) case contract.Revision.RevisionNumber == math.MaxUint64: return fmt.Errorf("contract has reached the maximum number of revisions") } return nil } -func convertToCore(siad encoding.SiaMarshaler, core types.DecoderFrom) { - var buf bytes.Buffer - siad.MarshalSia(&buf) - d := types.NewBufDecoder(buf.Bytes()) - core.DecodeFrom(d) - if d.Err() != nil { - panic(d.Err()) - } -} - // NewManager creates a new contract manager. -func NewManager(store ContractStore, alerts Alerts, storage StorageManager, c ChainManager, tpool TransactionPool, wallet Wallet, log *zap.Logger) (*ContractManager, error) { - cache, err := lru.New2Q[types.FileContractID, []types.Hash256](sectorRootCacheSize) - if err != nil { - return nil, fmt.Errorf("failed to create cache: %w", err) - } - cm := &ContractManager{ +func NewManager(store ContractStore, storage StorageManager, chain ChainManager, syncer Syncer, wallet Wallet, opts ...ManagerOption) (*Manager, error) { + cm := &Manager{ store: store, - tg: threadgroup.New(), - log: log, - alerts: alerts, storage: storage, - chain: c, - tpool: tpool, + chain: chain, + syncer: syncer, wallet: wallet, - rootsCache: cache, + alerts: alerts.NewNop(), + tg: threadgroup.New(), + log: zap.NewNop(), - processQueue: make(chan uint64, 100), - locks: make(map[types.FileContractID]*locker), + locks: make(map[types.FileContractID]*locker), } - changeID, err := store.LastContractChange() - if err != nil { - return nil, fmt.Errorf("failed to get last contract change: %w", err) - } - - // start the actions queue. Required to avoid a deadlock in the tpool, but - // still process consensus changes serially. - go cm.processActions() - - // subscribe to the consensus set in a separate goroutine to prevent - // blocking startup - go func() { - err := cm.chain.Subscribe(cm, changeID, cm.tg.Done()) - if errors.Is(err, chain.ErrInvalidChangeID) { - cm.log.Warn("rescanning blockchain due to unknown consensus change ID") - if err := cm.chain.Subscribe(cm, modules.ConsensusChangeBeginning, cm.tg.Done()); err != nil { - cm.log.Fatal("failed to reset consensus change subscription", zap.Error(err)) - } - } else if err != nil && !strings.Contains(err.Error(), "ThreadGroup already stopped") { - cm.log.Fatal("failed to subscribe to consensus changes", zap.Error(err)) - } - }() + for _, opt := range opts { + opt(cm) + } + start := time.Now() + roots, err := store.SectorRoots() + if err != nil { + return nil, fmt.Errorf("failed to get sector roots: %w", err) + } + cm.sectorRoots = roots + cm.log.Debug("loaded sector roots", zap.Duration("elapsed", time.Since(start))) return cm, nil } diff --git a/host/contracts/manager_test.go b/host/contracts/manager_test.go index 3126d817..47c6c65c 100644 --- a/host/contracts/manager_test.go +++ b/host/contracts/manager_test.go @@ -11,27 +11,82 @@ import ( rhp2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" - "go.sia.tech/hostd/alerts" + "go.sia.tech/coreutils/chain" + "go.sia.tech/coreutils/syncer" + "go.sia.tech/coreutils/wallet" "go.sia.tech/hostd/host/contracts" "go.sia.tech/hostd/host/storage" - "go.sia.tech/hostd/internal/test" + "go.sia.tech/hostd/internal/testutil" "go.sia.tech/hostd/persist/sqlite" - "go.sia.tech/hostd/webhooks" - stypes "go.sia.tech/siad/types" "go.uber.org/zap/zaptest" "lukechampine.com/frand" ) -const sectorCacheSize = 64 - func hashRevision(rev types.FileContractRevision) types.Hash256 { h := types.NewHasher() rev.EncodeTo(h.E) return h.Sum() } -func formContract(renterKey, hostKey types.PrivateKey, start, end uint64, renterPayout, hostPayout types.Currency, c *contracts.ContractManager, w contracts.Wallet, cm contracts.ChainManager, tp contracts.TransactionPool) (contracts.SignedRevision, error) { - contract := rhp2.PrepareContractFormation(renterKey.PublicKey(), hostKey.PublicKey(), renterPayout, hostPayout, start, rhp2.HostSettings{WindowSize: end - start}, w.Address()) +func formV2Contract(t *testing.T, cm *chain.Manager, c *contracts.Manager, w *wallet.SingleAddressWallet, s *syncer.Syncer, renterKey, hostKey types.PrivateKey, renterFunds, hostFunds types.Currency, duration uint64, broadcast bool) (types.FileContractID, types.V2FileContract) { + t.Helper() + + cs := cm.TipState() + fc := types.V2FileContract{ + RevisionNumber: 0, + Filesize: 0, + FileMerkleRoot: types.Hash256{}, + ProofHeight: cs.Index.Height + duration, + ExpirationHeight: cs.Index.Height + duration + 10, + RenterOutput: types.SiacoinOutput{ + Value: renterFunds, + Address: w.Address(), + }, + HostOutput: types.SiacoinOutput{ + Value: hostFunds, + Address: w.Address(), + }, + MissedHostValue: hostFunds, + TotalCollateral: hostFunds, + RenterPublicKey: renterKey.PublicKey(), + HostPublicKey: hostKey.PublicKey(), + } + fundAmount := cs.V2FileContractTax(fc).Add(hostFunds).Add(renterFunds) + sigHash := cs.ContractSigHash(fc) + fc.HostSignature = hostKey.SignHash(sigHash) + fc.RenterSignature = renterKey.SignHash(sigHash) + + txn := types.V2Transaction{ + FileContracts: []types.V2FileContract{fc}, + } + + basis, toSign, err := w.FundV2Transaction(&txn, fundAmount, false) + if err != nil { + t.Fatal("failed to fund transaction:", err) + } + w.SignV2Inputs(&txn, toSign) + formationSet := contracts.V2FormationTransactionSet{ + TransactionSet: []types.V2Transaction{txn}, + Basis: basis, + } + + if broadcast { + if _, err := cm.AddV2PoolTransactions(formationSet.Basis, formationSet.TransactionSet); err != nil { + t.Fatal("failed to add formation set to pool:", err) + } + s.BroadcastV2TransactionSet(formationSet.Basis, formationSet.TransactionSet) + } + + if err := c.AddV2Contract(formationSet, contracts.V2Usage{}); err != nil { + t.Fatal("failed to add contract:", err) + } + return txn.V2FileContractID(txn.ID(), 0), fc +} + +func formContract(t *testing.T, cm *chain.Manager, c *contracts.Manager, w *wallet.SingleAddressWallet, s *syncer.Syncer, renterKey, hostKey types.PrivateKey, renterFunds, hostFunds types.Currency, duration uint64, broadcast bool) contracts.SignedRevision { + t.Helper() + + contract := rhp2.PrepareContractFormation(renterKey.PublicKey(), hostKey.PublicKey(), renterFunds, hostFunds, cm.Tip().Height+duration, rhp2.HostSettings{WindowSize: 10}, w.Address()) state := cm.TipState() formationCost := rhp2.ContractFormationCost(state, contract, types.ZeroCurrency) contractUnlockConditions := types.UnlockConditions{ @@ -44,17 +99,19 @@ func formContract(renterKey, hostKey types.PrivateKey, start, end uint64, renter txn := types.Transaction{ FileContracts: []types.FileContract{contract}, } - toSign, release, err := w.FundTransaction(&txn, formationCost.Add(hostPayout)) // we're funding both sides of the payout + toSign, err := w.FundTransaction(&txn, formationCost.Add(hostFunds), true) // we're funding both sides of the payout if err != nil { - return contracts.SignedRevision{}, fmt.Errorf("failed to fund transaction: %w", err) + t.Fatal("failed to fund transaction:", err) } - if err := w.SignTransaction(state, &txn, toSign, types.CoveredFields{WholeTransaction: true}); err != nil { - release() - return contracts.SignedRevision{}, fmt.Errorf("failed to sign transaction: %w", err) - } else if err := tp.AcceptTransactionSet([]types.Transaction{txn}); err != nil { - release() - return contracts.SignedRevision{}, fmt.Errorf("failed to accept transaction set: %w", err) + w.SignTransaction(&txn, toSign, types.CoveredFields{WholeTransaction: true}) + formationSet := append(cm.UnconfirmedParents(txn), txn) + if broadcast { + if _, err := cm.AddPoolTransactions(formationSet); err != nil { + t.Fatal("failed to add formation set to pool:", err) + } + s.BroadcastTransactionSet(formationSet) } + revision := types.FileContractRevision{ ParentID: txn.FileContractID(0), UnlockConditions: contractUnlockConditions, @@ -67,48 +124,19 @@ func formContract(renterKey, hostKey types.PrivateKey, start, end uint64, renter HostSignature: hostKey.SignHash(sigHash), RenterSignature: renterKey.SignHash(sigHash), } - - if err := c.AddContract(rev, []types.Transaction{}, hostPayout, contracts.Usage{}); err != nil { - return contracts.SignedRevision{}, fmt.Errorf("failed to add contract: %w", err) + if err := c.AddContract(rev, formationSet, hostFunds, contracts.Usage{}); err != nil { + t.Fatal(err) } - return rev, nil + return rev } func TestContractLockUnlock(t *testing.T) { hostKey := types.NewPrivateKeyFromSeed(frand.Bytes(32)) renterKey := types.NewPrivateKeyFromSeed(frand.Bytes(32)) - dir := t.TempDir() log := zaptest.NewLogger(t) - db, err := sqlite.OpenDatabase(filepath.Join(dir, "hostd.db"), log.Named("sqlite")) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - node, err := test.NewWallet(hostKey, t.TempDir(), log.Named("wallet")) - if err != nil { - t.Fatal(err) - } - defer node.Close() - - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - s, err := storage.NewVolumeManager(db, am, node.ChainManager(), log.Named("storage"), sectorCacheSize) - if err != nil { - t.Fatal(err) - } - defer s.Close() - - c, err := contracts.NewManager(db, am, s, node.ChainManager(), node.TPool(), node, log.Named("contracts")) - if err != nil { - t.Fatal(err) - } - defer c.Close() + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) contractUnlockConditions := types.UnlockConditions{ PublicKeys: []types.UnlockKey{ @@ -129,127 +157,187 @@ func TestContractLockUnlock(t *testing.T) { }, } - if err := c.AddContract(rev, []types.Transaction{}, types.ZeroCurrency, contracts.Usage{}); err != nil { + if err := node.Contracts.AddContract(rev, []types.Transaction{}, types.ZeroCurrency, contracts.Usage{}); err != nil { t.Fatal(err) } - if _, err := c.Lock(context.Background(), rev.Revision.ParentID); err != nil { + if _, err := node.Contracts.Lock(context.Background(), rev.Revision.ParentID); err != nil { t.Fatal(err) } - err = func() error { + err := func() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - _, err = c.Lock(ctx, rev.Revision.ParentID) + _, err := node.Contracts.Lock(ctx, rev.Revision.ParentID) return err }() if !errors.Is(err, context.DeadlineExceeded) { t.Fatal("expected context deadline exceeded, got", err) } - c.Unlock(rev.Revision.ParentID) + node.Contracts.Unlock(rev.Revision.ParentID) var wg sync.WaitGroup for i := 0; i < 50; i++ { wg.Add(1) go func() { defer wg.Done() - if _, err := c.Lock(context.Background(), rev.Revision.ParentID); err != nil { + if _, err := node.Contracts.Lock(context.Background(), rev.Revision.ParentID); err != nil { t.Error(err) } time.Sleep(100 * time.Millisecond) - c.Unlock(rev.Revision.ParentID) + node.Contracts.Unlock(rev.Revision.ParentID) }() } wg.Wait() } func TestContractLifecycle(t *testing.T) { - t.Run("successful with proof", func(t *testing.T) { - hostKey, renterKey := types.NewPrivateKeyFromSeed(frand.Bytes(32)), types.NewPrivateKeyFromSeed(frand.Bytes(32)) + assertContractStatus := func(t *testing.T, c *contracts.Manager, contractID types.FileContractID, status contracts.ContractStatus) { + t.Helper() - dir := t.TempDir() - log := zaptest.NewLogger(t) - node, err := test.NewWallet(hostKey, dir, log.Named("wallet")) + contract, err := c.Contract(contractID) if err != nil { - t.Fatal(err) + t.Fatal("failed to get contract", err) + } else if contract.Status != status { + t.Fatalf("expected contract to be %v, got %v", status, contract.Status) } - defer node.Close() + } - webhookReporter, err := webhooks.NewManager(node.Store(), log.Named("webhooks")) + assertContractMetrics := func(t *testing.T, s *sqlite.Store, active, successful uint64, locked, risked types.Currency) { + t.Helper() + + m, err := s.Metrics(time.Now()) if err != nil { t.Fatal(err) + } else if m.Contracts.Active != active { + t.Fatalf("expected %v active contracts, got %v", active, m.Contracts.Active) + } else if m.Contracts.Successful != successful { + t.Fatalf("expected %v successful contracts, got %v", successful, m.Contracts.Successful) + } else if !m.Contracts.LockedCollateral.Equals(locked) { + t.Fatalf("expected %v locked collateral, got %v", locked, m.Contracts.LockedCollateral) + } else if !m.Contracts.RiskedCollateral.Equals(risked) { + t.Fatalf("expected %v risked collateral, got %v", risked, m.Contracts.RiskedCollateral) } + } + + t.Run("reject", func(t *testing.T) { + hostKey, renterKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() + log := zaptest.NewLogger(t) - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - s, err := storage.NewVolumeManager(node.Store(), am, node.ChainManager(), log.Named("storage"), sectorCacheSize) + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) + + cm := node.Chain + c := node.Contracts + w := node.Wallet + + renterFunds := types.Siacoins(10) + hostFunds := types.Siacoins(20) + contract := rhp2.PrepareContractFormation(renterKey.PublicKey(), hostKey.PublicKey(), renterFunds, hostFunds, cm.Tip().Height+10, rhp2.HostSettings{WindowSize: 10}, w.Address()) + state := cm.TipState() + formationCost := rhp2.ContractFormationCost(state, contract, types.ZeroCurrency) + contractUnlockConditions := types.UnlockConditions{ + PublicKeys: []types.UnlockKey{ + renterKey.PublicKey().UnlockKey(), + hostKey.PublicKey().UnlockKey(), + }, + SignaturesRequired: 2, + } + txn := types.Transaction{ + FileContracts: []types.FileContract{contract}, + } + toSign, err := w.FundTransaction(&txn, formationCost.Add(hostFunds), true) // we're funding both sides of the payout if err != nil { + t.Fatal("failed to fund transaction:", err) + } + w.SignTransaction(&txn, toSign, types.CoveredFields{WholeTransaction: true}) + formationSet := append(cm.UnconfirmedParents(txn), txn) + revision := types.FileContractRevision{ + ParentID: txn.FileContractID(0), + UnlockConditions: contractUnlockConditions, + FileContract: txn.FileContracts[0], + } + // corrupt the transaction set to simulate a rejected contract + formationSet[len(formationSet)-1].Signatures = nil + revision.RevisionNumber = 1 + sigHash := hashRevision(revision) + rev := contracts.SignedRevision{ + Revision: revision, + HostSignature: hostKey.SignHash(sigHash), + RenterSignature: renterKey.SignHash(sigHash), + } + if err := c.AddContract(rev, formationSet, hostFunds, contracts.Usage{}); err != nil { t.Fatal(err) } - defer s.Close() + + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusPending) + assertContractMetrics(t, node.Store, 0, 0, types.ZeroCurrency, types.ZeroCurrency) + + // mine until the contract is rejected + testutil.MineAndSync(t, node, types.VoidAddress, 20) + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusRejected) + assertContractMetrics(t, node.Store, 0, 0, types.ZeroCurrency, types.ZeroCurrency) + }) + + t.Run("rebroadcast", func(t *testing.T) { + hostKey, renterKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() + log := zaptest.NewLogger(t) + + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) + + rev := formContract(t, node.Chain, node.Contracts, node.Wallet, node.Syncer, renterKey, hostKey, types.Siacoins(10), types.Siacoins(20), 10, false) + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusPending) + assertContractMetrics(t, node.Store, 0, 0, types.ZeroCurrency, types.ZeroCurrency) + + // mine a block to rebroadcast the formation set + testutil.MineAndSync(t, node, types.VoidAddress, 1) + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusPending) + assertContractMetrics(t, node.Store, 0, 0, types.ZeroCurrency, types.ZeroCurrency) + + // mine another block to confirm the contract + testutil.MineAndSync(t, node, types.VoidAddress, 1) + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusActive) + assertContractMetrics(t, node.Store, 1, 0, types.Siacoins(20), types.ZeroCurrency) + + // mine until the contract is successful + testutil.MineAndSync(t, node, types.VoidAddress, int(rev.Revision.WindowEnd-node.Chain.Tip().Height)+1) + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusSuccessful) + assertContractMetrics(t, node.Store, 0, 1, types.ZeroCurrency, types.ZeroCurrency) + }) + + t.Run("successful with proof", func(t *testing.T) { + hostKey, renterKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() + + dir := t.TempDir() + log := zaptest.NewLogger(t) + + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) result := make(chan error, 1) - if _, err := s.AddVolume(context.Background(), filepath.Join(dir, "data.dat"), 10, result); err != nil { + if _, err := node.Volumes.AddVolume(context.Background(), filepath.Join(dir, "data.dat"), 10, result); err != nil { t.Fatal(err) } else if err := <-result; err != nil { t.Fatal(err) } - c, err := contracts.NewManager(node.Store(), am, s, node.ChainManager(), node.TPool(), node, log.Named("contracts")) - if err != nil { - t.Fatal(err) - } - defer c.Close() - - // note: many more blocks than necessary are mined to ensure all forks have activated - if err := node.MineBlocks(node.Address(), int(stypes.MaturityDelay*4)); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) // sync time + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) renterFunds := types.Siacoins(500) hostCollateral := types.Siacoins(1000) - rev, err := formContract(renterKey, hostKey, 50, 60, renterFunds, hostCollateral, c, node, node.ChainManager(), node.TPool()) - if err != nil { - t.Fatal(err) - } + rev := formContract(t, node.Chain, node.Contracts, node.Wallet, node.Syncer, renterKey, hostKey, renterFunds, hostCollateral, 10, true) - contract, err := c.Contract(rev.Revision.ParentID) - if err != nil { - t.Fatal(err) - } else if contract.Status != contracts.ContractStatusPending { - t.Fatal("expected contract to be pending") - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if m.Contracts.Pending != 1 { - t.Fatal("expected 1 pending contract") - } else if !m.Contracts.LockedCollateral.Equals(hostCollateral) { - t.Fatalf("expected %v locked collateral, got %v", hostCollateral, m.Contracts.LockedCollateral) - } else if !m.Contracts.RiskedCollateral.IsZero() { - t.Fatalf("expected 0 risked collateral, got %v", m.Contracts.RiskedCollateral) - } - - if err := node.MineBlocks(types.VoidAddress, 1); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) // sync time + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusPending) + // pending contracts do not contribute to metrics + assertContractMetrics(t, node.Store, 0, 0, types.ZeroCurrency, types.ZeroCurrency) - contract, err = c.Contract(rev.Revision.ParentID) - if err != nil { - t.Fatal(err) - } else if contract.Status != contracts.ContractStatusActive { - t.Fatal("expected contract to be active") - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if m.Contracts.Pending != 0 { - t.Fatal("expected 0 pending contracts") - } else if m.Contracts.Active != 1 { - t.Fatal("expected 1 active contract") - } else if !m.Contracts.LockedCollateral.Equals(hostCollateral) { - t.Fatalf("expected %v locked collateral, got %v", hostCollateral, m.Contracts.LockedCollateral) - } else if !m.Contracts.RiskedCollateral.IsZero() { - t.Fatalf("expected 0 risked collateral, got %v", m.Contracts.RiskedCollateral) - } + testutil.MineAndSync(t, node, types.VoidAddress, 1) + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusActive) + assertContractMetrics(t, node.Store, 1, 0, hostCollateral, types.ZeroCurrency) var releaseFuncs []func() error defer func() { @@ -265,7 +353,7 @@ func TestContractLifecycle(t *testing.T) { var sector [rhp2.SectorSize]byte frand.Read(sector[:256]) root := rhp2.SectorRoot(§or) - release, err := s.Write(root, §or) + release, err := node.Volumes.Write(root, §or) if err != nil { t.Fatal(err) } @@ -288,7 +376,7 @@ func TestContractLifecycle(t *testing.T) { rev.HostSignature = hostKey.SignHash(sigHash) rev.RenterSignature = renterKey.SignHash(sigHash) - updater, err := c.ReviseContract(rev.Revision.ParentID) + updater, err := node.Contracts.ReviseContract(rev.Revision.ParentID) if err != nil { t.Fatal(err) } @@ -311,184 +399,57 @@ func TestContractLifecycle(t *testing.T) { } } - if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if !m.Contracts.LockedCollateral.Equals(hostCollateral) { - t.Fatalf("expected %v locked collateral, got %v", hostCollateral, m.Contracts.LockedCollateral) - } else if !m.Contracts.RiskedCollateral.Equals(collateral) { - t.Fatalf("expected %v risked collateral, got %v", collateral, m.Contracts.RiskedCollateral) - } + assertContractMetrics(t, node.Store, 1, 0, hostCollateral, collateral) - // mine until the revision is broadcast - remainingBlocks := rev.Revision.WindowStart - node.TipState().Index.Height - contracts.RevisionSubmissionBuffer - if err := node.MineBlocks(types.VoidAddress, int(remainingBlocks)); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) // sync time - // confirm the revision - if err := node.MineBlocks(types.VoidAddress, 1); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) // sync time + // mine until right before the proof window so the revision is broadcast + // and confirmed + remainingBlocks := rev.Revision.WindowStart - node.Chain.Tip().Height - 1 + testutil.MineAndSync(t, node, types.VoidAddress, int(remainingBlocks)) - contract, err = c.Contract(rev.Revision.ParentID) + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusActive) + contract, err := node.Contracts.Contract(rev.Revision.ParentID) if err != nil { t.Fatal(err) - } else if contract.Status != contracts.ContractStatusActive { - t.Fatal("expected contract to be active") } else if !contract.RevisionConfirmed { t.Fatal("expected revision to be confirmed") } - // mine until the proof window - remainingBlocks = rev.Revision.WindowStart - node.TipState().Index.Height - if err := node.MineBlocks(types.VoidAddress, int(remainingBlocks)); err != nil { - t.Fatal(err) - } - time.Sleep(time.Second) // sync time - // confirm the proof - if err := node.MineBlocks(types.VoidAddress, 1); err != nil { - t.Fatal(err) - } - time.Sleep(time.Second) // sync time - proofHeight := rev.Revision.WindowStart + 1 - - contract, err = c.Contract(rev.Revision.ParentID) - if err != nil { - t.Fatal(err) - } else if contract.Status != contracts.ContractStatusActive { - t.Fatal("expected contract to be active") - } else if contract.ResolutionHeight != proofHeight { - t.Fatalf("expected resolution height %v, got %v", proofHeight, contract.ResolutionHeight) - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if m.Contracts.Active != 1 { - t.Fatal("expected 1 active contracts") - } else if m.Contracts.Successful != 0 { - t.Fatal("expected 0 successful contract") - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if !m.Contracts.LockedCollateral.Equals(hostCollateral) { - t.Fatalf("expected %v locked collateral, got %v", hostCollateral, m.Contracts.LockedCollateral) - } else if !m.Contracts.RiskedCollateral.Equals(collateral) { - t.Fatalf("expected %v risked collateral, got %v", collateral, m.Contracts.RiskedCollateral) - } + // mine into the proof window + testutil.MineAndSync(t, node, types.VoidAddress, 2) - // mine until the end of the proof window - remainingBlocks = rev.Revision.WindowEnd - node.TipState().Index.Height + 1 - if err := node.MineBlocks(types.VoidAddress, int(remainingBlocks)); err != nil { - t.Fatal(err) - } - time.Sleep(time.Second) // sync time - - // check that the contract was marked successful - contract, err = c.Contract(rev.Revision.ParentID) - if err != nil { - t.Fatal(err) - } else if contract.Status != contracts.ContractStatusSuccessful { - t.Fatal("expected contract to be successful") - } else if contract.ResolutionHeight != proofHeight { - t.Fatalf("expected resolution height %v, got %v", proofHeight, contract.ResolutionHeight) - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if m.Contracts.Active != 0 { - t.Fatal("expected 0 active contracts") - } else if m.Contracts.Successful != 1 { - t.Fatal("expected 1 successful contract") - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if !m.Contracts.LockedCollateral.IsZero() { - t.Fatalf("expected %v locked collateral, got %v", types.ZeroCurrency, m.Contracts.LockedCollateral) - } else if !m.Contracts.RiskedCollateral.IsZero() { - t.Fatalf("expected %v risked collateral, got %v", types.ZeroCurrency, m.Contracts.RiskedCollateral) - } + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusSuccessful) + assertContractMetrics(t, node.Store, 0, 1, types.ZeroCurrency, types.ZeroCurrency) }) t.Run("successful no proof", func(t *testing.T) { - hostKey, renterKey := types.NewPrivateKeyFromSeed(frand.Bytes(32)), types.NewPrivateKeyFromSeed(frand.Bytes(32)) + hostKey, renterKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() dir := t.TempDir() log := zaptest.NewLogger(t) - node, err := test.NewWallet(hostKey, dir, log.Named("wallet")) - if err != nil { - t.Fatal(err) - } - defer node.Close() - webhookReporter, err := webhooks.NewManager(node.Store(), log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - s, err := storage.NewVolumeManager(node.Store(), am, node.ChainManager(), log.Named("storage"), sectorCacheSize) - if err != nil { - t.Fatal(err) - } - defer s.Close() + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) result := make(chan error, 1) - if _, err := s.AddVolume(context.Background(), filepath.Join(dir, "data.dat"), 10, result); err != nil { + if _, err := node.Volumes.AddVolume(context.Background(), filepath.Join(dir, "data.dat"), 10, result); err != nil { t.Fatal(err) } else if err := <-result; err != nil { t.Fatal(err) } - c, err := contracts.NewManager(node.Store(), am, s, node.ChainManager(), node.TPool(), node, log.Named("contracts")) - if err != nil { - t.Fatal(err) - } - defer c.Close() - - // note: many more blocks than necessary are mined to ensure all forks have activated - if err := node.MineBlocks(node.Address(), int(stypes.MaturityDelay*4)); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) // sync time + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) renterFunds := types.Siacoins(500) hostCollateral := types.Siacoins(1000) - rev, err := formContract(renterKey, hostKey, 50, 60, renterFunds, hostCollateral, c, node, node.ChainManager(), node.TPool()) - if err != nil { - t.Fatal(err) - } + rev := formContract(t, node.Chain, node.Contracts, node.Wallet, node.Syncer, renterKey, hostKey, renterFunds, hostCollateral, 10, true) - contract, err := c.Contract(rev.Revision.ParentID) - if err != nil { - t.Fatal(err) - } else if contract.Status != contracts.ContractStatusPending { - t.Fatal("expected contract to be pending") - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if m.Contracts.Pending != 1 { - t.Fatal("expected 1 pending contract") - } else if !m.Contracts.LockedCollateral.Equals(hostCollateral) { - t.Fatalf("expected %v locked collateral, got %v", hostCollateral, m.Contracts.LockedCollateral) - } else if !m.Contracts.RiskedCollateral.IsZero() { - t.Fatalf("expected 0 risked collateral, got %v", m.Contracts.RiskedCollateral) - } + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusPending) + assertContractMetrics(t, node.Store, 0, 0, types.ZeroCurrency, types.ZeroCurrency) - if err := node.MineBlocks(types.VoidAddress, 1); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) // sync time - - contract, err = c.Contract(rev.Revision.ParentID) - if err != nil { - t.Fatal(err) - } else if contract.Status != contracts.ContractStatusActive { - t.Fatal("expected contract to be active") - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if m.Contracts.Pending != 0 { - t.Fatal("expected 0 pending contracts") - } else if m.Contracts.Active != 1 { - t.Fatal("expected 1 active contract") - } else if !m.Contracts.LockedCollateral.Equals(hostCollateral) { - t.Fatalf("expected %v locked collateral, got %v", hostCollateral, m.Contracts.LockedCollateral) - } else if !m.Contracts.RiskedCollateral.IsZero() { - t.Fatalf("expected 0 risked collateral, got %v", m.Contracts.RiskedCollateral) - } + // confirm the contract + testutil.MineAndSync(t, node, types.VoidAddress, 1) + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusActive) + assertContractMetrics(t, node.Store, 1, 0, hostCollateral, types.ZeroCurrency) // create a revision that transfers funds to the host, simulating // account funding @@ -502,7 +463,7 @@ func TestContractLifecycle(t *testing.T) { rev.HostSignature = hostKey.SignHash(sigHash) rev.RenterSignature = renterKey.SignHash(sigHash) - updater, err := c.ReviseContract(rev.Revision.ParentID) + updater, err := node.Contracts.ReviseContract(rev.Revision.ParentID) if err != nil { t.Fatal(err) } @@ -513,27 +474,15 @@ func TestContractLifecycle(t *testing.T) { }) if err != nil { t.Fatal(err) - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if !m.Contracts.LockedCollateral.Equals(hostCollateral) { - t.Fatalf("expected %v locked collateral, got %v", hostCollateral, m.Contracts.LockedCollateral) - } else if !m.Contracts.RiskedCollateral.Equals(types.ZeroCurrency) { - t.Fatalf("expected %v risked collateral, got %v", types.ZeroCurrency, m.Contracts.RiskedCollateral) } - // mine until the revision is broadcast - remainingBlocks := rev.Revision.WindowStart - node.TipState().Index.Height - contracts.RevisionSubmissionBuffer - if err := node.MineBlocks(types.VoidAddress, int(remainingBlocks)); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) // sync time - // confirm the revision - if err := node.MineBlocks(types.VoidAddress, 1); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) // sync time + assertContractMetrics(t, node.Store, 1, 0, hostCollateral, types.ZeroCurrency) - contract, err = c.Contract(rev.Revision.ParentID) + // mine until right before the proof window so the revision is broadcast + remainingBlocks := rev.Revision.WindowStart - node.Chain.Tip().Height - 1 + testutil.MineAndSync(t, node, types.VoidAddress, int(remainingBlocks)) + + contract, err := node.Contracts.Contract(rev.Revision.ParentID) if err != nil { t.Fatal(err) } else if contract.Status != contracts.ContractStatusActive { @@ -544,135 +493,59 @@ func TestContractLifecycle(t *testing.T) { // mine until the end of the proof window -- contract should still be // active since no proof is required. - remainingBlocks = rev.Revision.WindowEnd - node.TipState().Index.Height - 1 - if err := node.MineBlocks(types.VoidAddress, int(remainingBlocks)); err != nil { - t.Fatal(err) - } - time.Sleep(time.Second) // sync time + remainingBlocks = rev.Revision.WindowEnd - node.Chain.Tip().Height - 1 + testutil.MineAndSync(t, node, types.VoidAddress, int(remainingBlocks)) - contract, err = c.Contract(rev.Revision.ParentID) - if err != nil { - t.Fatal(err) - } else if contract.Status != contracts.ContractStatusActive { - t.Fatal("expected contract to be active") - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if m.Contracts.Pending != 0 { - t.Fatal("expected 0 pending contracts") - } else if m.Contracts.Active != 1 { - t.Fatal("expected 1 active contract") - } else if !m.Contracts.LockedCollateral.Equals(hostCollateral) { - t.Fatalf("expected %v locked collateral, got %v", hostCollateral, m.Contracts.LockedCollateral) - } else if !m.Contracts.RiskedCollateral.IsZero() { - t.Fatalf("expected 0 risked collateral, got %v", m.Contracts.RiskedCollateral) - } + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusActive) + assertContractMetrics(t, node.Store, 1, 0, hostCollateral, types.ZeroCurrency) - // mine until after the proof window -- contract should be successful - remainingBlocks = rev.Revision.WindowEnd - node.TipState().Index.Height + 1 - if err := node.MineBlocks(types.VoidAddress, int(remainingBlocks)); err != nil { - t.Fatal(err) - } - time.Sleep(time.Second) // sync time + // mine after the proof window ends -- contract should be successful + testutil.MineAndSync(t, node, types.VoidAddress, 10) + + assertContractMetrics(t, node.Store, 0, 1, types.ZeroCurrency, types.ZeroCurrency) + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusSuccessful) - contract, err = c.Contract(rev.Revision.ParentID) + contract, err = node.Contracts.Contract(rev.Revision.ParentID) if err != nil { t.Fatal(err) } else if contract.Status != contracts.ContractStatusSuccessful { t.Fatal("expected contract to be successful") - } else if contract.ResolutionHeight != 0 { - t.Fatalf("expected resolution height %v, got %v", 0, contract.ResolutionHeight) - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if m.Contracts.Active != 0 { - t.Fatal("expected 0 active contracts") - } else if m.Contracts.Successful != 1 { - t.Fatal("expected 1 successful contract") - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if !m.Contracts.LockedCollateral.IsZero() { - t.Fatalf("expected %v locked collateral, got %v", types.ZeroCurrency, m.Contracts.LockedCollateral) - } else if !m.Contracts.RiskedCollateral.IsZero() { - t.Fatalf("expected %v risked collateral, got %v", types.ZeroCurrency, m.Contracts.RiskedCollateral) + } else if contract.ResolutionHeight != contract.Revision.WindowEnd { + t.Fatalf("expected resolution height %v, got %v", contract.Revision.WindowEnd, contract.ResolutionHeight) } }) t.Run("0 filesize contract", func(t *testing.T) { - hostKey, renterKey := types.NewPrivateKeyFromSeed(frand.Bytes(32)), types.NewPrivateKeyFromSeed(frand.Bytes(32)) + hostKey, renterKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() dir := t.TempDir() log := zaptest.NewLogger(t) - node, err := test.NewWallet(hostKey, dir, log.Named("wallet")) - if err != nil { - t.Fatal(err) - } - defer node.Close() - - webhookReporter, err := webhooks.NewManager(node.Store(), log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - s, err := storage.NewVolumeManager(node.Store(), am, node.ChainManager(), log.Named("storage"), sectorCacheSize) - if err != nil { - t.Fatal(err) - } - defer s.Close() + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) result := make(chan error, 1) - if _, err := s.AddVolume(context.Background(), filepath.Join(dir, "data.dat"), 10, result); err != nil { + if _, err := node.Volumes.AddVolume(context.Background(), filepath.Join(dir, "data.dat"), 10, result); err != nil { t.Fatal(err) } else if err := <-result; err != nil { t.Fatal(err) } - c, err := contracts.NewManager(node.Store(), am, s, node.ChainManager(), node.TPool(), node, log.Named("contracts")) - if err != nil { - t.Fatal(err) - } - defer c.Close() - - // note: mine enough blocks to ensure all forks have activated - if err := node.MineBlocks(node.Address(), int(stypes.MaturityDelay*4)); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) // sync time + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) - rev, err := formContract(renterKey, hostKey, 50, 60, types.Siacoins(500), types.Siacoins(1000), c, node, node.ChainManager(), node.TPool()) - if err != nil { - t.Fatal(err) - } + renterFunds := types.Siacoins(500) + hostCollateral := types.Siacoins(1000) + rev := formContract(t, node.Chain, node.Contracts, node.Wallet, node.Syncer, renterKey, hostKey, renterFunds, hostCollateral, 10, true) - contract, err := c.Contract(rev.Revision.ParentID) - if err != nil { - t.Fatal(err) - } else if contract.Status != contracts.ContractStatusPending { - t.Fatal("expected contract to be pending") - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if m.Contracts.Pending != 1 { - t.Fatal("expected 1 pending contract") - } + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusPending) + assertContractMetrics(t, node.Store, 0, 0, types.ZeroCurrency, types.ZeroCurrency) - if err := node.MineBlocks(types.VoidAddress, 1); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) // sync time + // confirm the contract + testutil.MineAndSync(t, node, types.VoidAddress, 1) + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusActive) + assertContractMetrics(t, node.Store, 1, 0, hostCollateral, types.ZeroCurrency) - contract, err = c.Contract(rev.Revision.ParentID) - if err != nil { - t.Fatal(err) - } else if contract.Status != contracts.ContractStatusActive { - t.Fatal("expected contract to be active") - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if m.Contracts.Pending != 0 { - t.Fatal("expected 0 pending contracts") - } else if m.Contracts.Active != 1 { - t.Fatal("expected 1 active contract") - } - - // create a revision that adds sectors and transfers funds to the host + // create a revision that transfers funds to the host with out adding any sectors amount := types.NewCurrency64(100) rev.Revision.RevisionNumber++ rev.Revision.ValidProofOutputs[0].Value = rev.Revision.ValidProofOutputs[0].Value.Sub(amount) @@ -681,7 +554,7 @@ func TestContractLifecycle(t *testing.T) { rev.HostSignature = hostKey.SignHash(sigHash) rev.RenterSignature = renterKey.SignHash(sigHash) - updater, err := c.ReviseContract(rev.Revision.ParentID) + updater, err := node.Contracts.ReviseContract(rev.Revision.ParentID) if err != nil { t.Fatal(err) } @@ -691,19 +564,12 @@ func TestContractLifecycle(t *testing.T) { t.Fatal(err) } - // mine until the revision is broadcast - remainingBlocks := rev.Revision.WindowStart - node.TipState().Index.Height - contracts.RevisionSubmissionBuffer - if err := node.MineBlocks(types.VoidAddress, int(remainingBlocks)); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) // sync time + // mine until right before the proof window starts to broadcast and // confirm the revision - if err := node.MineBlocks(types.VoidAddress, 1); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) // sync time + remainingBlocks := rev.Revision.WindowStart - node.Chain.Tip().Height - 1 + testutil.MineAndSync(t, node, types.VoidAddress, int(remainingBlocks)) - contract, err = c.Contract(rev.Revision.ParentID) + contract, err := node.Contracts.Contract(rev.Revision.ParentID) if err != nil { t.Fatal(err) } else if contract.Status != contracts.ContractStatusActive { @@ -712,158 +578,61 @@ func TestContractLifecycle(t *testing.T) { t.Fatal("expected revision to be confirmed") } - // mine until the proof window - remainingBlocks = rev.Revision.WindowStart - node.TipState().Index.Height + 1 - if err := node.MineBlocks(types.VoidAddress, int(remainingBlocks)); err != nil { - t.Fatal(err) - } - time.Sleep(time.Second) // sync time - // confirm the proof - if err := node.MineBlocks(types.VoidAddress, 1); err != nil { - t.Fatal(err) - } - time.Sleep(time.Second) // sync time + // mine until just before the end of the proof window to broadcast the + // proof and confirm the resolution + remainingBlocks = rev.Revision.WindowEnd - node.Chain.Tip().Height - 1 + testutil.MineAndSync(t, node, types.VoidAddress, int(remainingBlocks)) - contract, err = c.Contract(rev.Revision.ParentID) + contract, err = node.Contracts.Contract(rev.Revision.ParentID) if err != nil { t.Fatal(err) - } else if contract.Status != contracts.ContractStatusActive { - t.Fatalf("expected contract to be active, got %v", contract.Status) } else if contract.ResolutionHeight == 0 { - t.Fatal("expected contract to have resolution") - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if m.Contracts.Active != 1 { - t.Fatal("expected 1 active contracts") - } else if m.Contracts.Successful != 0 { - t.Fatal("expected 0 successful contracts") + t.Fatalf("expected contract to have resolution got %v", contract.ResolutionHeight) } - // mine until the proof window ends -- contract should be successful - remainingBlocks = rev.Revision.WindowEnd - node.TipState().Index.Height + 1 - if err := node.MineBlocks(types.VoidAddress, int(remainingBlocks)); err != nil { - t.Fatal(err) - } - time.Sleep(time.Second) // sync time - - contract, err = c.Contract(rev.Revision.ParentID) - if err != nil { - t.Fatal(err) - } else if contract.Status != contracts.ContractStatusSuccessful { - t.Fatalf("expected contract to be active, got %v", contract.Status) - } else if contract.ResolutionHeight == 0 { - t.Fatal("expected contract to have resolution") - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if m.Contracts.Active != 0 { - t.Fatal("expected 0 active contracts") - } else if m.Contracts.Successful != 1 { - t.Fatal("expected 1 successful contracts") - } + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusSuccessful) + assertContractMetrics(t, node.Store, 0, 1, types.ZeroCurrency, types.ZeroCurrency) }) t.Run("failed corrupt sector", func(t *testing.T) { - hostKey, renterKey := types.NewPrivateKeyFromSeed(frand.Bytes(32)), types.NewPrivateKeyFromSeed(frand.Bytes(32)) + hostKey, renterKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() dir := t.TempDir() log := zaptest.NewLogger(t) - node, err := test.NewWallet(hostKey, dir, log.Named("wallet")) - if err != nil { - t.Fatal(err) - } - defer node.Close() - cm := node.ChainManager() - webhookReporter, err := webhooks.NewManager(node.Store(), log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - s, err := storage.NewVolumeManager(node.Store(), am, node.ChainManager(), log.Named("storage"), sectorCacheSize) - if err != nil { - t.Fatal(err) - } - defer s.Close() + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) result := make(chan error, 1) - if _, err := s.AddVolume(context.Background(), filepath.Join(dir, "data.dat"), 10, result); err != nil { + if _, err := node.Volumes.AddVolume(context.Background(), filepath.Join(dir, "data.dat"), 10, result); err != nil { t.Fatal(err) } else if err := <-result; err != nil { t.Fatal(err) } - c, err := contracts.NewManager(node.Store(), am, s, node.ChainManager(), node.TPool(), node, log.Named("contracts")) - if err != nil { - t.Fatal(err) - } - defer c.Close() - - // waitForScan is a helper func to wait for the contract manager - // to catch up with chain manager - waitForScan := func() { - for cm.TipState().Index.Height != c.ScanHeight() { - time.Sleep(100 * time.Millisecond) - } - } - - // note: many more blocks than necessary are mined to ensure all forks have activated - if err := node.MineBlocks(node.Address(), int(stypes.MaturityDelay*4)); err != nil { - t.Fatal(err) - } - waitForScan() + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) renterFunds := types.Siacoins(500) hostCollateral := types.Siacoins(1000) - rev, err := formContract(renterKey, hostKey, 50, 60, renterFunds, hostCollateral, c, node, node.ChainManager(), node.TPool()) - if err != nil { - t.Fatal(err) - } + rev := formContract(t, node.Chain, node.Contracts, node.Wallet, node.Syncer, renterKey, hostKey, renterFunds, hostCollateral, 10, true) - contract, err := c.Contract(rev.Revision.ParentID) - if err != nil { - t.Fatal(err) - } else if contract.Status != contracts.ContractStatusPending { - t.Fatal("expected contract to be pending") - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if m.Contracts.Pending != 1 { - t.Fatal("expected 1 pending contract") - } else if !m.Contracts.LockedCollateral.Equals(hostCollateral) { - t.Fatalf("expected %v locked collateral, got %v", hostCollateral, m.Contracts.LockedCollateral) - } else if !m.Contracts.RiskedCollateral.IsZero() { - t.Fatalf("expected 0 risked collateral, got %v", m.Contracts.RiskedCollateral) - } + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusPending) + assertContractMetrics(t, node.Store, 0, 0, types.ZeroCurrency, types.ZeroCurrency) - if err := node.MineBlocks(types.VoidAddress, 1); err != nil { - t.Fatal(err) - } - waitForScan() + // confirm the contract + testutil.MineAndSync(t, node, types.VoidAddress, 1) - contract, err = c.Contract(rev.Revision.ParentID) - if err != nil { - t.Fatal(err) - } else if contract.Status != contracts.ContractStatusActive { - t.Fatal("expected contract to be active") - } else if m, err := node.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if m.Contracts.Pending != 0 { - t.Fatal("expected 0 pending contracts") - } else if m.Contracts.Active != 1 { - t.Fatal("expected 1 active contract") - } else if !m.Contracts.LockedCollateral.Equals(hostCollateral) { - t.Fatalf("expected %v locked collateral, got %v", hostCollateral, m.Contracts.LockedCollateral) - } else if !m.Contracts.RiskedCollateral.IsZero() { - t.Fatalf("expected 0 risked collateral, got %v", m.Contracts.RiskedCollateral) - } + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusActive) + assertContractMetrics(t, node.Store, 1, 0, hostCollateral, types.ZeroCurrency) + // add sectors to the volume manager var releaseFuncs []func() error var roots []types.Hash256 for i := 0; i < 5; i++ { var sector [rhp2.SectorSize]byte frand.Read(sector[:]) root := rhp2.SectorRoot(§or) - release, err := s.Write(root, §or) + release, err := node.Volumes.Write(root, §or) if err != nil { t.Fatal(err) } @@ -886,7 +655,7 @@ func TestContractLifecycle(t *testing.T) { rev.HostSignature = hostKey.SignHash(sigHash) rev.RenterSignature = renterKey.SignHash(sigHash) - updater, err := c.ReviseContract(rev.Revision.ParentID) + updater, err := node.Contracts.ReviseContract(rev.Revision.ParentID) if err != nil { t.Fatal(err) } @@ -911,142 +680,535 @@ func TestContractLifecycle(t *testing.T) { } } - if m, err := node.Store().Metrics(time.Now()); err != nil { + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusActive) + assertContractMetrics(t, node.Store, 1, 0, hostCollateral, collateral) + + // mine until right before the proof window so the revision is broadcast + // and confirmed + remainingBlocks := rev.Revision.WindowStart - node.Chain.Tip().Height - 1 + testutil.MineAndSync(t, node, types.VoidAddress, int(remainingBlocks)) + + contract, err := node.Contracts.Contract(rev.Revision.ParentID) + if err != nil { t.Fatal(err) - } else if !m.Contracts.LockedCollateral.Equals(hostCollateral) { - t.Fatalf("expected %v locked collateral, got %v", hostCollateral, m.Contracts.LockedCollateral) - } else if !m.Contracts.RiskedCollateral.Equals(collateral) { - t.Fatalf("expected %v risked collateral, got %v", collateral, m.Contracts.RiskedCollateral) + } else if contract.Status != contracts.ContractStatusActive { + t.Fatal("expected contract to be active") + } else if !contract.RevisionConfirmed { + t.Fatal("expected revision to be confirmed") } - // mine until the revision is broadcast - remainingBlocks := rev.Revision.WindowStart - node.TipState().Index.Height - contracts.RevisionSubmissionBuffer - if err := node.MineBlocks(types.VoidAddress, int(remainingBlocks)); err != nil { - t.Fatal(err) + // mine until after the proof window + remainingBlocks = rev.Revision.WindowEnd - node.Chain.Tip().Height + 1 + testutil.MineAndSync(t, node, types.VoidAddress, int(remainingBlocks)) + + assertContractStatus(t, node.Contracts, rev.Revision.ParentID, contracts.ContractStatusFailed) + assertContractMetrics(t, node.Store, 0, 0, types.ZeroCurrency, types.ZeroCurrency) + }) +} + +func TestV2ContractLifecycle(t *testing.T) { + hostKey, renterKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() + + dir := t.TempDir() + log := zaptest.NewLogger(t) + + network, genesis := testutil.V2Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) + + result := make(chan error, 1) + if _, err := node.Volumes.AddVolume(context.Background(), filepath.Join(dir, "data.dat"), 10, result); err != nil { + t.Fatal(err) + } else if err := <-result; err != nil { + t.Fatal(err) + } + + // fund the wallet + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) + + assertContractStatus := func(t *testing.T, contractID types.FileContractID, status contracts.V2ContractStatus) { + t.Helper() + + contract, err := node.Contracts.V2Contract(contractID) + if err != nil { + t.Fatal("failed to get contract:", err) + } else if contract.Status != status { + t.Fatalf("expected contract to be %v, got %v", status, contract.Status) } - waitForScan() + } - // confirm the revision - if err := node.MineBlocks(types.VoidAddress, 1); err != nil { + // tracks statuses between subtests + expectedStatuses := make(map[contracts.V2ContractStatus]uint64) + assertContractMetrics := func(t *testing.T, locked, risked types.Currency) { + t.Helper() + + m, err := node.Store.Metrics(time.Now()) + if err != nil { t.Fatal(err) + } else if m.Contracts.Active != expectedStatuses[contracts.V2ContractStatusActive] { + t.Fatalf("expected %v active contracts, got %v", expectedStatuses[contracts.V2ContractStatusActive], m.Contracts.Active) + } else if m.Contracts.Successful != expectedStatuses[contracts.V2ContractStatusSuccessful] { + t.Fatalf("expected %v successful contracts, got %v", expectedStatuses[contracts.V2ContractStatusSuccessful], m.Contracts.Successful) + } else if m.Contracts.Renewed != expectedStatuses[contracts.V2ContractStatusRenewed] { + t.Fatalf("expected %v renewed contracts, got %v", expectedStatuses[contracts.V2ContractStatusRenewed], m.Contracts.Renewed) + } else if m.Contracts.Finalized != expectedStatuses[contracts.V2ContractStatusFinalized] { + t.Fatalf("expected %v finalized contracts, got %v", expectedStatuses[contracts.V2ContractStatusFinalized], m.Contracts.Finalized) + } else if m.Contracts.Failed != expectedStatuses[contracts.V2ContractStatusFailed] { + t.Fatalf("expected %v failed contracts, got %v", expectedStatuses[contracts.V2ContractStatusFailed], m.Contracts.Failed) + } else if !m.Contracts.LockedCollateral.Equals(locked) { + t.Fatalf("expected %v locked collateral, got %v", locked, m.Contracts.LockedCollateral) + } else if !m.Contracts.RiskedCollateral.Equals(risked) { + t.Fatalf("expected %v risked collateral, got %v", risked, m.Contracts.RiskedCollateral) + } + } + + assertStorageMetrics := func(t *testing.T, contractSectors, physicalSectors uint64) { + t.Helper() + + m, err := node.Store.Metrics(time.Now()) + if err != nil { + t.Fatal("failed to get metrics:", err) + } else if m.Storage.ContractSectors != contractSectors { + t.Fatalf("expected %v contract sectors, got %v", contractSectors, m.Storage.ContractSectors) + } else if m.Storage.PhysicalSectors != physicalSectors { + t.Fatalf("expected %v physical sectors, got %v", physicalSectors, m.Storage.PhysicalSectors) + } + + vols, err := node.Volumes.Volumes() + if err != nil { + t.Fatal("failed to get volumes:", err) + } + var volumeSectors uint64 + for _, vol := range vols { + volumeSectors += vol.UsedSectors + } + if volumeSectors != physicalSectors { + t.Fatalf("expected %v physical sectors, got %v", physicalSectors, volumeSectors) } - waitForScan() + } + + t.Run("rebroadcast", func(t *testing.T) { + assertStorageMetrics(t, 0, 0) + + contractID, fc := formV2Contract(t, node.Chain, node.Contracts, node.Wallet, node.Syncer, renterKey, hostKey, types.Siacoins(10), types.Siacoins(20), 10, false) + assertContractStatus(t, contractID, contracts.V2ContractStatusPending) + assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) + + // mine a block to rebroadcast the formation set + testutil.MineAndSync(t, node, types.VoidAddress, 1) + assertContractStatus(t, contractID, contracts.V2ContractStatusPending) + assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) + + // mine another block to confirm the contract + testutil.MineAndSync(t, node, types.VoidAddress, 1) + expectedStatuses[contracts.V2ContractStatusActive]++ + assertContractMetrics(t, types.Siacoins(20), types.ZeroCurrency) + + // mine until the contract is successful + testutil.MineAndSync(t, node, types.VoidAddress, int(fc.ExpirationHeight-node.Chain.Tip().Height)+1) + assertContractStatus(t, contractID, contracts.V2ContractStatusSuccessful) + expectedStatuses[contracts.V2ContractStatusActive]-- + expectedStatuses[contracts.V2ContractStatusSuccessful]++ + assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) + }) - contract, err = c.Contract(rev.Revision.ParentID) + t.Run("successful empty contract", func(t *testing.T) { + assertStorageMetrics(t, 0, 0) + + contractID, fc := formV2Contract(t, node.Chain, node.Contracts, node.Wallet, node.Syncer, renterKey, hostKey, types.Siacoins(10), types.Siacoins(20), 10, true) + assertContractStatus(t, contractID, contracts.V2ContractStatusPending) + assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) + + // mine a block to confirm the contract + testutil.MineAndSync(t, node, types.VoidAddress, 1) + assertContractStatus(t, contractID, contracts.V2ContractStatusActive) + expectedStatuses[contracts.V2ContractStatusActive]++ + assertContractMetrics(t, types.Siacoins(20), types.ZeroCurrency) + + // mine until the contract is successful + testutil.MineAndSync(t, node, types.VoidAddress, int(fc.ExpirationHeight-node.Chain.Tip().Height)+1) + assertContractStatus(t, contractID, contracts.V2ContractStatusSuccessful) + expectedStatuses[contracts.V2ContractStatusActive]-- + expectedStatuses[contracts.V2ContractStatusSuccessful]++ + assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) + }) + + t.Run("storage proof", func(t *testing.T) { + assertStorageMetrics(t, 0, 0) + + contractID, fc := formV2Contract(t, node.Chain, node.Contracts, node.Wallet, node.Syncer, renterKey, hostKey, types.Siacoins(10), types.Siacoins(20), 10, true) + assertContractStatus(t, contractID, contracts.V2ContractStatusPending) + assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) + + // add a root to the contract + var sector [rhp2.SectorSize]byte + frand.Read(sector[:]) + root := rhp2.SectorRoot(§or) + roots := []types.Hash256{root} + + release, err := node.Volumes.Write(root, §or) if err != nil { t.Fatal(err) - } else if contract.Status != contracts.ContractStatusActive { - t.Fatal("expected contract to be active") - } else if !contract.RevisionConfirmed { - t.Fatal("expected revision to be confirmed") } + defer release() + + fc.Filesize = rhp2.SectorSize + fc.FileMerkleRoot = rhp2.MetaRoot(roots) + fc.RevisionNumber++ + // transfer some funds from the renter to the host + cost, collateral := types.Siacoins(1), types.Siacoins(2) + fc.RenterOutput.Value = fc.RenterOutput.Value.Sub(cost) + fc.HostOutput.Value = fc.HostOutput.Value.Add(cost) + fc.MissedHostValue = fc.MissedHostValue.Sub(collateral) + sigHash := node.Chain.TipState().ContractSigHash(fc) + fc.HostSignature = hostKey.SignHash(sigHash) + fc.RenterSignature = renterKey.SignHash(sigHash) - // mine until the end proof window - remainingBlocks = rev.Revision.WindowEnd - node.TipState().Index.Height - 1 - if err := node.MineBlocks(types.VoidAddress, int(remainingBlocks)); err != nil { + err = node.Contracts.ReviseV2Contract(contractID, fc, roots, contracts.Usage{ + StorageRevenue: cost, + RiskedCollateral: collateral, + }) + if err != nil { + t.Fatal(err) + } else if err := release(); err != nil { t.Fatal(err) } - waitForScan() + // metrics should not have been updated, contract is still pending + assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) + assertStorageMetrics(t, 1, 1) + + // mine to confirm the contract + testutil.MineAndSync(t, node, types.VoidAddress, 1) + expectedStatuses[contracts.V2ContractStatusActive]++ + assertContractMetrics(t, types.Siacoins(20), collateral) + assertStorageMetrics(t, 1, 1) + + // mine through the expiration height + testutil.MineAndSync(t, node, types.VoidAddress, int(fc.ExpirationHeight-node.Chain.Tip().Height)+1) + assertContractStatus(t, contractID, contracts.V2ContractStatusSuccessful) + expectedStatuses[contracts.V2ContractStatusActive]-- + expectedStatuses[contracts.V2ContractStatusSuccessful]++ + assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) + assertStorageMetrics(t, 0, 0) + }) + + t.Run("failed storage proof", func(t *testing.T) { + assertStorageMetrics(t, 0, 0) + + contractID, fc := formV2Contract(t, node.Chain, node.Contracts, node.Wallet, node.Syncer, renterKey, hostKey, types.Siacoins(10), types.Siacoins(20), 10, true) + assertContractStatus(t, contractID, contracts.V2ContractStatusPending) + assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) + + // add a root to the contract + var sector [rhp2.SectorSize]byte + frand.Read(sector[:256]) + root := frand.Entropy256() // random root + roots := []types.Hash256{root} - // check that the contract is still active - contract, err = c.Contract(rev.Revision.ParentID) + release, err := node.Volumes.Write(root, §or) if err != nil { t.Fatal(err) - } else if contract.Status != contracts.ContractStatusActive { - t.Fatal("expected contract to be successful") - } else if contract.ResolutionHeight != 0 { - t.Fatalf("expected resolution height %v, got %v", 0, contract.ResolutionHeight) - } else if m, err := node.Store().Metrics(time.Now()); err != nil { + } + defer release() + + fc.Filesize = rhp2.SectorSize + fc.FileMerkleRoot = rhp2.MetaRoot(roots) + fc.RevisionNumber++ + // transfer some funds from the renter to the host + cost, collateral := types.Siacoins(1), types.Siacoins(2) + fc.RenterOutput.Value = fc.RenterOutput.Value.Sub(cost) + fc.HostOutput.Value = fc.HostOutput.Value.Add(cost) + fc.MissedHostValue = fc.MissedHostValue.Sub(collateral) + sigHash := node.Chain.TipState().ContractSigHash(fc) + fc.HostSignature = hostKey.SignHash(sigHash) + fc.RenterSignature = renterKey.SignHash(sigHash) + + err = node.Contracts.ReviseV2Contract(contractID, fc, roots, contracts.Usage{ + StorageRevenue: cost, + RiskedCollateral: collateral, + }) + if err != nil { t.Fatal(err) - } else if m.Contracts.Active != 1 { - t.Fatal("expected 1 active contracts") - } else if m.Contracts.Successful != 0 { - t.Fatal("expected 0 successful contracts") - } else if m, err := node.Store().Metrics(time.Now()); err != nil { + } else if err := release(); err != nil { t.Fatal(err) - } else if !m.Contracts.LockedCollateral.Equals(hostCollateral) { - t.Fatalf("expected %v locked collateral, got %v", types.ZeroCurrency, m.Contracts.LockedCollateral) - } else if !m.Contracts.RiskedCollateral.Equals(collateral) { - t.Fatalf("expected %v risked collateral, got %v", types.ZeroCurrency, m.Contracts.RiskedCollateral) } + // metrics should not have been updated, contract is still pending + assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) + assertStorageMetrics(t, 1, 1) - // mine until after the proof window - remainingBlocks = rev.Revision.WindowEnd - node.TipState().Index.Height + 1 - if err := node.MineBlocks(types.VoidAddress, int(remainingBlocks)); err != nil { + // mine to confirm the contract + testutil.MineAndSync(t, node, types.VoidAddress, 1) + expectedStatuses[contracts.V2ContractStatusActive]++ + assertContractMetrics(t, types.Siacoins(20), collateral) + assertStorageMetrics(t, 1, 1) + + // mine through the expiration height + testutil.MineAndSync(t, node, types.VoidAddress, int(fc.ExpirationHeight-node.Chain.Tip().Height)+1) + assertContractStatus(t, contractID, contracts.V2ContractStatusFailed) + expectedStatuses[contracts.V2ContractStatusActive]-- + expectedStatuses[contracts.V2ContractStatusFailed]++ + assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) + assertStorageMetrics(t, 0, 0) + }) + + t.Run("renewal", func(t *testing.T) { + assertStorageMetrics(t, 0, 0) + + contractID, fc := formV2Contract(t, node.Chain, node.Contracts, node.Wallet, node.Syncer, renterKey, hostKey, types.Siacoins(10), types.Siacoins(20), 10, true) + assertContractStatus(t, contractID, contracts.V2ContractStatusPending) + assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) + + // add a root to the contract + var sector [rhp2.SectorSize]byte + frand.Read(sector[:]) + root := rhp2.SectorRoot(§or) + roots := []types.Hash256{root} + + release, err := node.Volumes.Write(root, §or) + if err != nil { t.Fatal(err) } - time.Sleep(time.Second) // sync time + defer release() + + fc.Filesize = rhp2.SectorSize + fc.FileMerkleRoot = rhp2.MetaRoot(roots) + fc.RevisionNumber++ + // transfer some funds from the renter to the host + cost, collateral := types.Siacoins(1), types.Siacoins(2) + fc.RenterOutput.Value = fc.RenterOutput.Value.Sub(cost) + fc.HostOutput.Value = fc.HostOutput.Value.Add(cost) + fc.MissedHostValue = fc.MissedHostValue.Sub(collateral) + sigHash := node.Chain.TipState().ContractSigHash(fc) + fc.HostSignature = hostKey.SignHash(sigHash) + fc.RenterSignature = renterKey.SignHash(sigHash) - // check that the contract is now failed - contract, err = c.Contract(rev.Revision.ParentID) + err = node.Contracts.ReviseV2Contract(contractID, fc, roots, contracts.Usage{ + StorageRevenue: cost, + RiskedCollateral: collateral, + }) if err != nil { t.Fatal(err) - } else if contract.Status != contracts.ContractStatusFailed { - t.Fatalf("expected contract to be failed, got %q", contract.Status) - } else if contract.ResolutionHeight != 0 { - t.Fatalf("expected resolution height %v, got %v", 0, contract.ResolutionHeight) - } else if m, err := node.Store().Metrics(time.Now()); err != nil { + } else if err := release(); err != nil { + t.Fatal(err) + } + + // mine to confirm the contract + testutil.MineAndSync(t, node, types.VoidAddress, 1) + // ensure the metrics were updated + expectedStatuses[contracts.V2ContractStatusActive]++ + assertContractStatus(t, contractID, contracts.V2ContractStatusActive) + assertContractMetrics(t, types.Siacoins(20), collateral) + assertStorageMetrics(t, 1, 1) + + // renew the contract + com := node.Contracts + cm := node.Chain + + cs := cm.TipState() + final := fc + final.RevisionNumber = types.MaxRevisionNumber + final.FileMerkleRoot = types.Hash256{} + final.Filesize = 0 + final.HostSignature = types.Signature{} + final.RenterSignature = types.Signature{} + final.RevisionNumber = types.MaxRevisionNumber + + additionalCollateral := types.Siacoins(2) + renewal := types.V2FileContractRenewal{ + FinalRevision: final, + NewContract: types.V2FileContract{ + RevisionNumber: 0, + Filesize: fc.Filesize, + FileMerkleRoot: fc.FileMerkleRoot, + ProofHeight: final.ProofHeight + 10, + ExpirationHeight: final.ExpirationHeight + 10, + RenterOutput: final.RenterOutput, + HostOutput: types.SiacoinOutput{ + Address: final.HostOutput.Address, + Value: final.HostOutput.Value.Add(additionalCollateral), + }, + MissedHostValue: final.MissedHostValue.Add(additionalCollateral), + TotalCollateral: final.TotalCollateral.Add(additionalCollateral), + RenterPublicKey: renterKey.PublicKey(), + HostPublicKey: hostKey.PublicKey(), + }, + HostRollover: final.HostOutput.Value, + RenterRollover: final.RenterOutput.Value, + } + renewalSigHash := cs.RenewalSigHash(renewal) + renewal.HostSignature = hostKey.SignHash(renewalSigHash) + renewal.RenterSignature = renterKey.SignHash(renewalSigHash) + + fce, err := com.V2ContractElement(contractID) + if err != nil { t.Fatal(err) - } else if m.Contracts.Active != 0 { - t.Fatal("expected 0 active contracts") - } else if m.Contracts.Failed != 1 { - t.Fatal("expected 1 failed contract") - } else if m, err := node.Store().Metrics(time.Now()); err != nil { + } + + fundAmount := cs.V2FileContractTax(renewal.NewContract).Add(additionalCollateral) + setupTxn := types.V2Transaction{ + SiacoinOutputs: []types.SiacoinOutput{ + {Value: fundAmount, Address: fc.HostOutput.Address}, + }, + } + basis, toSign, err := node.Wallet.FundV2Transaction(&setupTxn, fundAmount, false) + if err != nil { + t.Fatal("failed to fund transaction:", err) + } + node.Wallet.SignV2Inputs(&setupTxn, toSign) + + renewalTxn := types.V2Transaction{ + SiacoinInputs: []types.V2SiacoinInput{ + { + Parent: setupTxn.EphemeralSiacoinOutput(0), + }, + }, + FileContractResolutions: []types.V2FileContractResolution{ + { + Parent: fce, + Resolution: &renewal, + }, + }, + } + node.Wallet.SignV2Inputs(&renewalTxn, []int{0}) + renewalTxnSet := contracts.V2FormationTransactionSet{ + Basis: basis, + TransactionSet: []types.V2Transaction{setupTxn, renewalTxn}, + } + if _, err := cm.AddV2PoolTransactions(renewalTxnSet.Basis, renewalTxnSet.TransactionSet); err != nil { + t.Fatal("failed to add renewal to pool:", err) + } + node.Syncer.BroadcastV2TransactionSet(renewalTxnSet.Basis, renewalTxnSet.TransactionSet) + + err = com.RenewV2Contract(renewalTxnSet, contracts.V2Usage{ + RiskedCollateral: renewal.NewContract.TotalCollateral.Sub(renewal.NewContract.MissedHostValue), + }) + if err != nil { t.Fatal(err) - } else if !m.Contracts.LockedCollateral.IsZero() { - t.Fatalf("expected %v locked collateral, got %v", types.ZeroCurrency, m.Contracts.LockedCollateral) - } else if !m.Contracts.RiskedCollateral.IsZero() { - t.Fatalf("expected %v risked collateral, got %v", types.ZeroCurrency, m.Contracts.RiskedCollateral) } + + renewalID := contractID.V2RenewalID() + + // metrics should not have changed + assertContractStatus(t, renewalID, contracts.V2ContractStatusPending) + assertContractStatus(t, contractID, contracts.V2ContractStatusActive) + assertContractMetrics(t, types.Siacoins(20), collateral) + assertStorageMetrics(t, 1, 1) + + // mine to confirm the renewal + testutil.MineAndSync(t, node, types.VoidAddress, 1) + // new contract pending -> active, old contract active -> renewed + expectedStatuses[contracts.V2ContractStatusRenewed]++ + expectedStatuses[contracts.V2ContractStatusActive] += 0 // no change + assertContractStatus(t, contractID, contracts.V2ContractStatusRenewed) + assertContractStatus(t, renewalID, contracts.V2ContractStatusActive) + // metrics should reflect the new contract + assertContractMetrics(t, types.Siacoins(22), collateral) + assertStorageMetrics(t, 1, 1) + // mine until the renewed contract is successful and the sectors have + // been pruned + testutil.MineAndSync(t, node, types.VoidAddress, int(renewal.NewContract.ExpirationHeight-node.Chain.Tip().Height)+1) + expectedStatuses[contracts.V2ContractStatusActive]-- + expectedStatuses[contracts.V2ContractStatusSuccessful]++ + assertContractStatus(t, renewalID, contracts.V2ContractStatusSuccessful) + assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) + assertStorageMetrics(t, 0, 0) + }) + + t.Run("reject", func(t *testing.T) { + cm := node.Chain + c := node.Contracts + w := node.Wallet + + renterFunds, hostFunds := types.Siacoins(10), types.Siacoins(20) + duration := uint64(10) + cs := cm.TipState() + fc := types.V2FileContract{ + RevisionNumber: 0, + Filesize: 0, + FileMerkleRoot: types.Hash256{}, + ProofHeight: cs.Index.Height + duration, + ExpirationHeight: cs.Index.Height + duration + 10, + RenterOutput: types.SiacoinOutput{ + Value: renterFunds, + Address: w.Address(), + }, + HostOutput: types.SiacoinOutput{ + Value: hostFunds, + Address: w.Address(), + }, + MissedHostValue: hostFunds, + TotalCollateral: hostFunds, + RenterPublicKey: renterKey.PublicKey(), + HostPublicKey: hostKey.PublicKey(), + } + fundAmount := cs.V2FileContractTax(fc).Add(hostFunds).Add(renterFunds) + sigHash := cs.ContractSigHash(fc) + fc.HostSignature = hostKey.SignHash(sigHash) + fc.RenterSignature = renterKey.SignHash(sigHash) + + txn := types.V2Transaction{ + FileContracts: []types.V2FileContract{fc}, + } + + basis, toSign, err := w.FundV2Transaction(&txn, fundAmount, false) + if err != nil { + t.Fatal("failed to fund transaction:", err) + } + w.SignV2Inputs(&txn, toSign) + formationSet := contracts.V2FormationTransactionSet{ + TransactionSet: []types.V2Transaction{txn}, + Basis: basis, + } + contractID := txn.V2FileContractID(txn.ID(), 0) + // corrupt the formation set to trigger a rejection + formationSet.TransactionSet[len(formationSet.TransactionSet)-1].SiacoinInputs[0].SatisfiedPolicy.Signatures[0] = types.Signature{} + if err := c.AddV2Contract(formationSet, contracts.V2Usage{}); err != nil { + t.Fatal("failed to add contract:", err) + } + + expectedStatuses[contracts.V2ContractStatusPending]++ + assertContractStatus(t, contractID, contracts.V2ContractStatusPending) + // metrics should not have changed + assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) + assertStorageMetrics(t, 0, 0) + + // mine until the contract is rejected + testutil.MineAndSync(t, node, types.VoidAddress, 20) + expectedStatuses[contracts.V2ContractStatusRejected]++ + assertContractStatus(t, contractID, contracts.V2ContractStatusRejected) + // metrics should not have changed + assertContractMetrics(t, types.ZeroCurrency, types.ZeroCurrency) + assertStorageMetrics(t, 0, 0) }) } func TestSectorRoots(t *testing.T) { + log := zaptest.NewLogger(t) + const sectors = 256 - hostKey := types.NewPrivateKeyFromSeed(frand.Bytes(32)) - renterKey := types.NewPrivateKeyFromSeed(frand.Bytes(32)) + hostKey, renterKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() dir := t.TempDir() - log := zaptest.NewLogger(t) - db, err := sqlite.OpenDatabase(filepath.Join(dir, "hostd.db"), log.Named("sqlite")) - if err != nil { - t.Fatal(err) - } - defer db.Close() + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) - node, err := test.NewWallet(hostKey, t.TempDir(), log.Named("wallet")) - if err != nil { + result := make(chan error, 1) + if _, err := node.Volumes.AddVolume(context.Background(), filepath.Join(dir, "data.dat"), 10, result); err != nil { t.Fatal(err) - } - defer node.Close() - - webhookReporter, err := webhooks.NewManager(node.Store(), log.Named("webhooks")) - if err != nil { + } else if err := <-result; err != nil { t.Fatal(err) } - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - s, err := storage.NewVolumeManager(db, am, node.ChainManager(), log.Named("storage"), sectorCacheSize) - if err != nil { - t.Fatal(err) - } - defer s.Close() + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) // create a fake volume so disk space is not used - id, err := db.AddVolume("test", false) + id, err := node.Store.AddVolume("test", false) if err != nil { t.Fatal(err) - } else if err := db.GrowVolume(id, sectors); err != nil { - t.Fatal(err) - } else if err := db.SetAvailable(id, true); err != nil { + } else if err := node.Store.GrowVolume(id, sectors); err != nil { t.Fatal(err) - } - - c, err := contracts.NewManager(db, am, s, node.ChainManager(), node.TPool(), node, log.Named("contracts")) - if err != nil { + } else if err := node.Store.SetAvailable(id, true); err != nil { t.Fatal(err) } - defer c.Close() contractUnlockConditions := types.UnlockConditions{ PublicKeys: []types.UnlockKey{ @@ -1067,7 +1229,7 @@ func TestSectorRoots(t *testing.T) { }, } - if err := c.AddContract(rev, []types.Transaction{}, types.ZeroCurrency, contracts.Usage{}); err != nil { + if err := node.Contracts.AddContract(rev, []types.Transaction{}, types.ZeroCurrency, contracts.Usage{}); err != nil { t.Fatal(err) } @@ -1075,19 +1237,24 @@ func TestSectorRoots(t *testing.T) { for i := 0; i < sectors; i++ { root, err := func() (types.Hash256, error) { root := frand.Entropy256() - release, err := db.StoreSector(root, func(loc storage.SectorLocation, exists bool) error { return nil }) + release, err := node.Store.StoreSector(root, func(loc storage.SectorLocation, exists bool) error { return nil }) if err != nil { return types.Hash256{}, fmt.Errorf("failed to store sector: %w", err) } defer release() - // use the database method directly to avoid the sector cache - err = db.ReviseContract(rev, roots, contracts.Usage{}, []contracts.SectorChange{ - {Action: contracts.SectorActionAppend, Root: root}, - }) + updater, err := node.Contracts.ReviseContract(rev.Revision.ParentID) if err != nil { return types.Hash256{}, fmt.Errorf("failed to revise contract: %w", err) } + defer updater.Close() + + updater.AppendSector(root) + + if err := updater.Commit(rev, contracts.Usage{}); err != nil { + return types.Hash256{}, fmt.Errorf("failed to commit revision: %w", err) + } + return root, nil }() if err != nil { @@ -1096,8 +1263,8 @@ func TestSectorRoots(t *testing.T) { roots = append(roots, root) } - // check that the sector roots are correct - check, err := c.SectorRoots(rev.Revision.ParentID) + // check that the cached sector roots are correct + check := node.Contracts.SectorRoots(rev.Revision.ParentID) if err != nil { t.Fatal(err) } else if len(check) != len(roots) { @@ -1109,11 +1276,12 @@ func TestSectorRoots(t *testing.T) { } } - // check that the cached sector roots are correct - check, err = c.SectorRoots(rev.Revision.ParentID) + dbRoots, err := node.Store.SectorRoots() if err != nil { t.Fatal(err) - } else if len(check) != len(roots) { + } + check = dbRoots[rev.Revision.ParentID] + if len(check) != len(roots) { t.Fatalf("expected %v sector roots, got %v", len(roots), len(check)) } for i := range check { @@ -1122,3 +1290,33 @@ func TestSectorRoots(t *testing.T) { } } } + +func TestChainIndexElementsDeepReorg(t *testing.T) { + log := zaptest.NewLogger(t) + network, genesis := testutil.V2Network() + n1 := testutil.NewConsensusNode(t, network, genesis, log.Named("node1")) + + h1 := testutil.NewHostNode(t, types.GeneratePrivateKey(), network, genesis, log.Named("host")) + + if _, err := h1.Syncer.Connect(context.Background(), n1.Syncer.Addr()); err != nil { + t.Fatal(err) + } + + mineAndSync := func(t *testing.T, cn *testutil.ConsensusNode, addr types.Address, n int) { + t.Helper() + + for i := 0; i < n; i++ { + testutil.MineBlocks(t, cn, addr, 1) + testutil.WaitForSync(t, cn.Chain, h1.Indexer) + } + } + + mineAndSync(t, n1, types.VoidAddress, 145) + n2 := testutil.NewConsensusNode(t, network, genesis, log.Named("node2")) + testutil.MineBlocks(t, n2, types.VoidAddress, 200) + + if _, err := h1.Syncer.Connect(context.Background(), n2.Syncer.Addr()); err != nil { + t.Fatal(err) + } + testutil.WaitForSync(t, n2.Chain, h1.Indexer) +} diff --git a/host/contracts/options.go b/host/contracts/options.go new file mode 100644 index 00000000..b769804c --- /dev/null +++ b/host/contracts/options.go @@ -0,0 +1,36 @@ +package contracts + +import "go.uber.org/zap" + +// A ManagerOption sets options on a Manager. +type ManagerOption func(*Manager) + +// WithRejectAfter sets the number of blocks before a contract will be considered +// rejected +func WithRejectAfter(rejectBuffer uint64) ManagerOption { + return func(m *Manager) { + m.rejectBuffer = rejectBuffer + } +} + +// WithRevisionSubmissionBuffer sets the number of blocks before the proof window +// to broadcast the final revision and prevent modification of the contract. +func WithRevisionSubmissionBuffer(revisionSubmissionBuffer uint64) ManagerOption { + return func(m *Manager) { + m.revisionSubmissionBuffer = revisionSubmissionBuffer + } +} + +// WithAlerter sets the alerts for the Manager. +func WithAlerter(a Alerts) ManagerOption { + return func(m *Manager) { + m.alerts = a + } +} + +// WithLog sets the logger for the Manager. +func WithLog(l *zap.Logger) ManagerOption { + return func(m *Manager) { + m.log = l + } +} diff --git a/host/contracts/persist.go b/host/contracts/persist.go index e8bda305..ba73ae18 100644 --- a/host/contracts/persist.go +++ b/host/contracts/persist.go @@ -2,57 +2,57 @@ package contracts import ( "go.sia.tech/core/types" - "go.sia.tech/siad/modules" ) type ( - // UpdateStateTransaction atomically updates the contract manager's state. - UpdateStateTransaction interface { - ContractRelevant(types.FileContractID) (bool, error) - - ConfirmFormation(types.FileContractID) error - ConfirmRevision(types.FileContractRevision) error - ConfirmResolution(id types.FileContractID, height uint64) error - - RevertFormation(types.FileContractID) error - RevertRevision(types.FileContractID) error - RevertResolution(types.FileContractID) error - } - // A ContractStore stores contracts for the host. It also updates stored // contracts and determines which contracts need lifecycle actions. ContractStore interface { - LastContractChange() (id modules.ConsensusChangeID, err error) + // ContractActions returns the lifecycle actions for the contract at the + // given index. + ContractActions(index types.ChainIndex, revisionBroadcastHeight uint64) (LifecycleActions, error) + // ContractChainIndexElement returns the chain index element for the given height. + ContractChainIndexElement(types.ChainIndex) (types.ChainIndexElement, error) + + // SectorRoots returns the sector roots for a contract. If limit is 0, all roots + // are returned. + SectorRoots() (map[types.FileContractID][]types.Hash256, error) + // Contracts returns a paginated list of contracts sorted by expiration // asc. Contracts(ContractFilter) ([]Contract, int, error) // Contract returns the contract with the given ID. Contract(types.FileContractID) (Contract, error) - // ContractFormationSet returns the formation transaction set for the - // contract with the given ID. - ContractFormationSet(types.FileContractID) ([]types.Transaction, error) - // ExpireContract is used to mark a contract as complete. It should only - // be used on active or pending contracts. - ExpireContract(types.FileContractID, ContractStatus) error - // Add stores the provided contract, should error if the contract + // AddContract stores the provided contract, should error if the contract // already exists in the store. AddContract(revision SignedRevision, formationSet []types.Transaction, lockedCollateral types.Currency, initialUsage Usage, negotationHeight uint64) error // RenewContract renews a contract. It is expected that the existing // contract will be cleared. RenewContract(renewal SignedRevision, existing SignedRevision, formationSet []types.Transaction, lockedCollateral types.Currency, clearingUsage, initialUsage Usage, negotationHeight uint64) error - // SectorRoots returns the sector roots for a contract. If limit is 0, all roots - // are returned. - SectorRoots(id types.FileContractID) ([]types.Hash256, error) - // ContractAction calls contractFn on every contract in the store that - // needs a lifecycle action performed. - ContractAction(height uint64, contractFn func(types.FileContractID, uint64, string)) error // ReviseContract atomically updates a contract and its associated // sector roots. ReviseContract(revision SignedRevision, oldRoots []types.Hash256, usage Usage, sectorChanges []SectorChange) error - // UpdateContractState atomically updates the contract manager's state. - UpdateContractState(modules.ConsensusChangeID, uint64, func(UpdateStateTransaction) error) error + // ExpireContractSectors removes sector roots for any contracts that are - // past their proof window. + // rejected or past their proof window. ExpireContractSectors(height uint64) error + + // V2ContractElement returns the latest v2 state element with the given ID. + V2ContractElement(types.FileContractID) (types.V2FileContractElement, error) + // V2Contract returns the v2 contract with the given ID. + V2Contract(types.FileContractID) (V2Contract, error) + // AddV2Contract stores the provided contract, should error if the contract + // already exists in the store. + AddV2Contract(V2Contract, V2FormationTransactionSet) error + // RenewV2Contract renews a contract. It is expected that the existing + // contract will be cleared. + RenewV2Contract(renewal V2Contract, renewalSet V2FormationTransactionSet, renewedID types.FileContractID, finalRevision types.V2FileContract) error + // ReviseV2Contract atomically updates a contract and its associated + // sector roots. + ReviseV2Contract(id types.FileContractID, revision types.V2FileContract, roots []types.Hash256, usage Usage) error + + // ExpireV2ContractSectors removes sector roots for any v2 contracts that are + // rejected or past their proof window. + ExpireV2ContractSectors(height uint64) error } ) diff --git a/host/contracts/revenue_test.go b/host/contracts/revenue_test.go index b8c28a1e..5ada7032 100644 --- a/host/contracts/revenue_test.go +++ b/host/contracts/revenue_test.go @@ -1,8 +1,9 @@ +//go:build ignore + package contracts_test import ( "context" - "fmt" "reflect" "testing" "time" @@ -23,12 +24,12 @@ func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } -func checkRevenueConsistency(s *sqlite.Store, potential, earned metrics.Revenue) error { +func assertRevenue(t *testing.T, s *sqlite.Store, potential, earned metrics.Revenue) { time.Sleep(time.Second) // commit time m, err := s.Metrics(time.Now()) if err != nil { - return fmt.Errorf("failed to get metrics: %v", err) + t.Fatalf("failed to get revenue metrics: %v", err) } actualPotentialValue := reflect.ValueOf(m.Revenue.Potential) @@ -39,7 +40,7 @@ func checkRevenueConsistency(s *sqlite.Store, potential, earned metrics.Revenue) av, ev := fa.Interface().(types.Currency), fe.Interface().(types.Currency) if !av.Equals(ev) { - return fmt.Errorf("potential revenue field %q does not match. expected %d, got %d", name, ev, av) + t.Fatalf("potential revenue field %q does not match. expected %d, got %d", name, ev, av) } } @@ -51,11 +52,9 @@ func checkRevenueConsistency(s *sqlite.Store, potential, earned metrics.Revenue) av, ev := fa.Interface().(types.Currency), fe.Interface().(types.Currency) if !av.Equals(ev) { - return fmt.Errorf("earned revenue field %q does not match. expected %d, got %d", name, ev, av) + t.Fatalf("earned revenue field %q does not match. expected %d, got %d", name, ev, av) } } - - return nil } func TestRevenueMetrics(t *testing.T) { diff --git a/host/contracts/update.go b/host/contracts/update.go new file mode 100644 index 00000000..49f5d9ba --- /dev/null +++ b/host/contracts/update.go @@ -0,0 +1,659 @@ +package contracts + +import ( + "fmt" + + "go.sia.tech/core/consensus" + rhp2 "go.sia.tech/core/rhp/v2" + "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" + "go.uber.org/zap" +) + +// chainIndexBuffer is the number of chain index elements to store and update +// in the database. Older elements will be deleted. The number of elements +// corresponds to the default proof window. +// +// This is less complex than storing an element per contract or +// tracking each contract's proof window. +const chainIndexBuffer = 144 + +type ( + stateUpdater interface { + ForEachFileContractElement(func(types.FileContractElement, bool, *types.FileContractElement, bool, bool)) + ForEachV2FileContractElement(func(types.V2FileContractElement, bool, *types.V2FileContractElement, types.V2FileContractResolutionType)) + } + + // LifecycleActions contains the actions that need to be taken to maintain + // the lifecycle of active contracts. + LifecycleActions struct { + RebroadcastFormation [][]types.Transaction + BroadcastRevision []SignedRevision + BroadcastProof []SignedRevision + + // V2 actions + RebroadcastV2Formation []V2FormationTransactionSet + BroadcastV2Revision []types.V2FileContractRevision + BroadcastV2Proof []types.V2FileContractElement + BroadcastV2Expiration []types.V2FileContractElement + } + + // StateChanges contains the changes to the state of contracts on the + // blockchain + StateChanges struct { + Confirmed []types.FileContractElement + Revised []types.FileContractElement + Successful []types.FileContractID + Failed []types.FileContractID + + // V2 changes + ConfirmedV2 []types.V2FileContractElement + RevisedV2 []types.V2FileContractElement + SuccessfulV2 []types.FileContractID + FinalizedV2 []types.FileContractID + RenewedV2 []types.FileContractID + FailedV2 []types.FileContractID + } + + // An UpdateStateTx atomically updates the state of contracts in the contract + // store. + UpdateStateTx interface { + // ContractStateElements returns all state elements from the contract + // store + ContractStateElements() ([]types.StateElement, error) + // UpdateContractStateElements updates the state elements in the host + // contract store + UpdateContractStateElements([]types.StateElement) error + // ContractRelevant returns whether the contract with the provided id is + // relevant to the host + ContractRelevant(id types.FileContractID) (bool, error) + // V2ContractRelevant returns whether the v2 contract with the + // provided id is relevant to the host + V2ContractRelevant(id types.FileContractID) (bool, error) + // ApplyContracts applies relevant contract changes to the contract + // store + ApplyContracts(types.ChainIndex, StateChanges) error + // RevertContracts reverts relevant contract changes from the contract + // store + RevertContracts(types.ChainIndex, StateChanges) error + // RejectContracts sets the status of any v1 and v2 contracts with a + // negotiation height before the provided height and that have not + // been confirmed to rejected + RejectContracts(height uint64) (v1, v2 []types.FileContractID, err error) + + // ContractChainIndexElements returns all chain index elements from the + // contract store + ContractChainIndexElements() (elements []types.ChainIndexElement, err error) + // ApplyContractChainIndexElements adds or updates the merkle proof of + // chain index state elements + ApplyContractChainIndexElements(elements []types.ChainIndexElement) error + // RevertContractChainIndexElements removes chain index state elements + // that were reverted + RevertContractChainIndexElement(types.ChainIndex) error + // DeleteExpiredContractChainIndexElements deletes chain index state + // elements that are no long necessary + DeleteExpiredContractChainIndexElements(height uint64) error + } +) + +func (cm *Manager) buildStorageProof(revision types.FileContractRevision, index uint64, log *zap.Logger) (types.StorageProof, error) { + if revision.Filesize == 0 { + return types.StorageProof{ + ParentID: revision.ParentID, + }, nil + } + + sectorIndex := index / rhp2.LeavesPerSector + segmentIndex := index % rhp2.LeavesPerSector + + roots := cm.getSectorRoots(revision.ParentID) + contractRoot := rhp2.MetaRoot(roots) + if contractRoot != revision.FileMerkleRoot { + log.Error("unexpected contract merkle root", zap.Stringer("expectedRoot", revision.FileMerkleRoot), zap.Stringer("actualRoot", contractRoot)) + return types.StorageProof{}, fmt.Errorf("merkle root mismatch") + } else if uint64(len(roots)) < sectorIndex { + log.Error("unexpected proof index", zap.Uint64("sectorIndex", sectorIndex), zap.Uint64("segmentIndex", segmentIndex), zap.Int("rootsLength", len(roots))) + return types.StorageProof{}, fmt.Errorf("invalid root index") + } + + sectorRoot := roots[sectorIndex] + sector, err := cm.storage.Read(sectorRoot) + if err != nil { + log.Error("failed to read sector data", zap.Error(err), zap.Stringer("sectorRoot", sectorRoot)) + return types.StorageProof{}, fmt.Errorf("failed to read sector data") + } else if rhp2.SectorRoot(sector) != sectorRoot { + log.Error("sector data corrupt", zap.Stringer("expectedRoot", sectorRoot), zap.Stringer("actualRoot", rhp2.SectorRoot(sector))) + return types.StorageProof{}, fmt.Errorf("invalid sector root") + } + segmentProof := rhp2.ConvertProofOrdering(rhp2.BuildProof(sector, segmentIndex, segmentIndex+1, nil), segmentIndex) + sectorProof := rhp2.ConvertProofOrdering(rhp2.BuildSectorRangeProof(roots, sectorIndex, sectorIndex+1), sectorIndex) + sp := types.StorageProof{ + ParentID: revision.ParentID, + Proof: append(segmentProof, sectorProof...), + } + copy(sp.Leaf[:], sector[segmentIndex*rhp2.LeafSize:]) + return sp, nil +} + +func (cm *Manager) buildV2StorageProof(cs consensus.State, fce types.V2FileContractElement, pi types.ChainIndexElement, log *zap.Logger) (types.V2StorageProof, error) { + if fce.V2FileContract.Filesize == 0 { + return types.V2StorageProof{ + ProofIndex: pi, + }, nil + } + + revision := fce.V2FileContract + contractID := types.FileContractID(fce.ID) + + leafIndex := cs.StorageProofLeafIndex(fce.V2FileContract.Filesize, types.BlockID(pi.ID), contractID) + sectorIndex := leafIndex / rhp2.LeavesPerSector + segmentIndex := leafIndex % rhp2.LeavesPerSector + + roots := cm.getSectorRoots(contractID) + contractRoot := rhp2.MetaRoot(roots) + if contractRoot != revision.FileMerkleRoot { + log.Error("unexpected contract root", zap.Stringer("expectedRoot", revision.FileMerkleRoot), zap.Stringer("actualRoot", contractRoot)) + return types.V2StorageProof{}, fmt.Errorf("merkle root mismatch") + } else if uint64(len(roots)) < sectorIndex { + log.Error("unexpected root index", zap.Uint64("sectorIndex", sectorIndex), zap.Uint64("segmentIndex", segmentIndex), zap.Int("rootsLength", len(roots))) + return types.V2StorageProof{}, fmt.Errorf("invalid root index") + } + + sectorRoot := roots[sectorIndex] + sector, err := cm.storage.Read(sectorRoot) + if err != nil { + log.Error("failed to read sector data", zap.Error(err), zap.Stringer("sectorRoot", sectorRoot)) + return types.V2StorageProof{}, fmt.Errorf("failed to read sector data") + } else if rhp2.SectorRoot(sector) != sectorRoot { + log.Error("sector data corrupt", zap.Stringer("expectedRoot", sectorRoot), zap.Stringer("actualRoot", rhp2.SectorRoot(sector))) + return types.V2StorageProof{}, fmt.Errorf("invalid sector root") + } + segmentProof := rhp2.ConvertProofOrdering(rhp2.BuildProof(sector, segmentIndex, segmentIndex+1, nil), segmentIndex) + sectorProof := rhp2.ConvertProofOrdering(rhp2.BuildSectorRangeProof(roots, sectorIndex, sectorIndex+1), sectorIndex) + sp := types.V2StorageProof{ + ProofIndex: pi, + Proof: append(segmentProof, sectorProof...), + } + copy(sp.Leaf[:], sector[segmentIndex*rhp2.LeafSize:]) + return sp, nil +} + +// ProcessActions processes additional lifecycle actions after a new block is +// added to the index. +func (cm *Manager) ProcessActions(index types.ChainIndex) error { + log := cm.log.Named("lifecycle").With(zap.Stringer("index", index)) + + revisionBroadcastHeight := index.Height + cm.revisionSubmissionBuffer + actions, err := cm.store.ContractActions(index, revisionBroadcastHeight) + if err != nil { + return fmt.Errorf("failed to get contract actions: %w", err) + } + + for _, formationSet := range actions.RebroadcastFormation { + if len(formationSet) == 0 { + continue + } else if _, err := cm.chain.AddPoolTransactions(formationSet); err != nil { + log.Error("failed to add formation transaction to pool", zap.Error(err)) + continue + } + cm.syncer.BroadcastTransactionSet(formationSet) + log.Debug("rebroadcast formation transaction", zap.String("transactionID", formationSet[len(formationSet)-1].ID().String())) + } + + for _, revision := range actions.BroadcastRevision { + log := log.Named("broadcastRevision").With(zap.Stringer("contractID", revision.Revision.ParentID), zap.Uint64("windowStart", revision.Revision.WindowStart), zap.Uint64("revisionNumber", revision.Revision.RevisionNumber)) + revisionTxn := types.Transaction{ + FileContractRevisions: []types.FileContractRevision{revision.Revision}, + Signatures: []types.TransactionSignature{ + { + ParentID: types.Hash256(revision.Revision.ParentID), + CoveredFields: types.CoveredFields{FileContractRevisions: []uint64{0}}, + Signature: revision.RenterSignature[:], + }, + { + ParentID: types.Hash256(revision.Revision.ParentID), + CoveredFields: types.CoveredFields{FileContractRevisions: []uint64{0}}, + Signature: revision.HostSignature[:], + PublicKeyIndex: 1, + }, + }, + } + + fee := cm.chain.RecommendedFee().Mul64(1000) + revisionTxn.MinerFees = append(revisionTxn.MinerFees, fee) + toSign, err := cm.wallet.FundTransaction(&revisionTxn, fee, true) + if err != nil { + log.Error("failed to fund revision transaction", zap.Error(err)) + continue + } + cm.wallet.SignTransaction(&revisionTxn, toSign, types.CoveredFields{WholeTransaction: true}) + revisionTxnSet := append(cm.chain.UnconfirmedParents(revisionTxn), revisionTxn) + if _, err := cm.chain.AddPoolTransactions(revisionTxnSet); err != nil { + cm.wallet.ReleaseInputs(revisionTxnSet, nil) + log.Error("failed to add revision transaction to pool", zap.Error(err)) + continue + } + cm.syncer.BroadcastTransactionSet(revisionTxnSet) + log.Debug("broadcast revision transaction", zap.String("transactionID", revisionTxn.ID().String())) + } + + cs := cm.chain.TipState() + for _, revision := range actions.BroadcastProof { + log := log.Named("proof").With(zap.Stringer("contractID", revision.Revision.ParentID)) + validPayout, missedPayout := revision.Revision.ValidHostPayout(), revision.Revision.MissedHostPayout() + if missedPayout.Cmp(validPayout) >= 0 { + log.Debug("skipping storage proof, no benefit to host", zap.String("validPayout", validPayout.ExactString()), zap.String("missedPayout", missedPayout.ExactString())) + continue + } + + proofIndex, ok := cm.chain.BestIndex(revision.Revision.WindowStart - 1) + if !ok { + log.Error("proof index not found", zap.Uint64("windowStart", revision.Revision.WindowStart)) + continue + } + + leafIndex := cs.StorageProofLeafIndex(revision.Revision.Filesize, proofIndex.ID, revision.Revision.ParentID) + sp, err := cm.buildStorageProof(revision.Revision, leafIndex, log) + if err != nil { + log.Error("failed to build storage proof", zap.Error(err)) + continue + } + + fee := cm.chain.RecommendedFee().Mul64(2000) + resolutionTxnSet := []types.Transaction{ + { + // intermediate funding transaction is required by v1 because + // transactions with storage proofs cannot have change outputs + SiacoinOutputs: []types.SiacoinOutput{ + {Address: cm.wallet.Address(), Value: fee}, + }, + }, + { + MinerFees: []types.Currency{fee}, + StorageProofs: []types.StorageProof{sp}, + }, + } + + intermediateToSign, err := cm.wallet.FundTransaction(&resolutionTxnSet[0], fee, true) + if err != nil { + log.Error("failed to fund resolution transaction", zap.Error(err)) + continue + } + cm.wallet.SignTransaction(&resolutionTxnSet[0], intermediateToSign, types.CoveredFields{WholeTransaction: true}) + resolutionTxnSet[1].SiacoinInputs = append(resolutionTxnSet[1].SiacoinInputs, types.SiacoinInput{ + ParentID: resolutionTxnSet[0].SiacoinOutputID(0), + UnlockConditions: cm.wallet.UnlockConditions(), + }) + proofToSign := []types.Hash256{types.Hash256(resolutionTxnSet[1].SiacoinInputs[0].ParentID)} + cm.wallet.SignTransaction(&resolutionTxnSet[1], proofToSign, types.CoveredFields{WholeTransaction: true}) + resolutionTxnSet = append(cm.chain.UnconfirmedParents(resolutionTxnSet[0]), resolutionTxnSet...) + if _, err := cm.chain.AddPoolTransactions(resolutionTxnSet); err != nil { + cm.wallet.ReleaseInputs(resolutionTxnSet, nil) + log.Error("failed to add resolution transaction to pool", zap.Error(err)) + continue + } + cm.syncer.BroadcastTransactionSet(resolutionTxnSet) + log.Debug("broadcast transaction", zap.String("transactionID", resolutionTxnSet[1].ID().String())) + } + + for _, formationSet := range actions.RebroadcastV2Formation { + if len(formationSet.TransactionSet) == 0 { + continue + } + formationTxn := formationSet.TransactionSet[len(formationSet.TransactionSet)-1] + if len(formationTxn.FileContracts) == 0 { + continue + } + + contractID := formationTxn.V2FileContractID(formationTxn.ID(), 0) + log := log.Named("v2 formation").With(zap.Stringer("basis", formationSet.Basis), zap.Stringer("contractID", contractID)) + + if _, err := cm.chain.AddV2PoolTransactions(formationSet.Basis, formationSet.TransactionSet); err != nil { + log.Error("failed to add formation transaction to pool", zap.Error(err)) + continue + } + cm.syncer.BroadcastV2TransactionSet(formationSet.Basis, formationSet.TransactionSet) + log.Debug("broadcast transaction", zap.String("transactionID", formationSet.TransactionSet[len(formationSet.TransactionSet)-1].ID().String())) + } + + for _, fcr := range actions.BroadcastV2Revision { + log := log.Named("v2 revision").With(zap.Stringer("contractID", fcr.Parent.ID)) + + fee := cm.chain.RecommendedFee().Mul64(1000) + revisionTxn := types.V2Transaction{ + MinerFee: fee, + FileContractRevisions: []types.V2FileContractRevision{fcr}, + } + basis, toSign, err := cm.wallet.FundV2Transaction(&revisionTxn, fee, false) // TODO: true + if err != nil { + log.Error("failed to fund transaction", zap.Error(err)) + continue + } + cm.wallet.SignV2Inputs(&revisionTxn, toSign) + + revisionTxnSet := []types.V2Transaction{revisionTxn} + if _, err := cm.chain.AddV2PoolTransactions(basis, revisionTxnSet); err != nil { + log.Error("failed to add transaction set to pool", zap.Error(err)) + continue + } + cm.syncer.BroadcastV2TransactionSet(basis, revisionTxnSet) + log.Debug("broadcast transaction", zap.Stringer("transactionID", revisionTxn.ID())) + } + + for _, fce := range actions.BroadcastV2Proof { + log := log.Named("v2 proof").With(zap.Stringer("contractID", fce.ID)) + proofIndex, ok := cm.chain.BestIndex(fce.V2FileContract.ProofHeight) + if !ok { + log.Error("proof index not found", zap.Uint64("proofHeight", fce.V2FileContract.ProofHeight)) + continue + } + proofElement, err := cm.store.ContractChainIndexElement(proofIndex) + if err != nil { + log.Error("failed to get proof index element", zap.Stringer("proofIndex", proofIndex), zap.Error(err)) + continue + } + + sp, err := cm.buildV2StorageProof(cs, fce, proofElement, log.Named("proof")) + if err != nil { + log.Error("failed to build storage proof", zap.Error(err)) + continue + } + + resolution := types.V2FileContractResolution{ + Parent: fce, + Resolution: &sp, + } + + fee := cm.chain.RecommendedFee().Mul64(2000) + setupTxn := types.V2Transaction{ + SiacoinOutputs: []types.SiacoinOutput{ + {Address: cm.wallet.Address(), Value: fee}, + }, + } + basis, toSign, err := cm.wallet.FundV2Transaction(&setupTxn, fee, false) // TODO: true + if err != nil { + log.Error("failed to fund resolution transaction", zap.Error(err)) + continue + } + cm.wallet.SignV2Inputs(&setupTxn, toSign) + resolutionTxn := types.V2Transaction{ + MinerFee: fee, + SiacoinInputs: []types.V2SiacoinInput{{Parent: setupTxn.EphemeralSiacoinOutput(0)}}, + FileContractResolutions: []types.V2FileContractResolution{resolution}, + } + cm.wallet.SignV2Inputs(&resolutionTxn, []int{0}) + resolutionTxnSet := []types.V2Transaction{setupTxn, resolutionTxn} + if _, err := cm.chain.AddV2PoolTransactions(basis, resolutionTxnSet); err != nil { + log.Error("failed to add resolution transaction to pool", zap.Error(err)) + continue + } + cm.syncer.BroadcastV2TransactionSet(basis, resolutionTxnSet) + log.Debug("broadcast transaction", zap.String("transactionID", resolutionTxn.ID().String())) + } + + for _, fce := range actions.BroadcastV2Expiration { + log := log.Named("v2 expiration").With(zap.Stringer("contractID", fce.ID)) + + fee := cm.chain.RecommendedFee().Mul64(1000) + setupTxn := types.V2Transaction{ + SiacoinOutputs: []types.SiacoinOutput{ + {Address: cm.wallet.Address(), Value: fee}, + }, + } + basis, toSign, err := cm.wallet.FundV2Transaction(&setupTxn, fee, false) // TODO: true + if err != nil { + log.Error("failed to fund resolution transaction", zap.Error(err)) + continue + } + cm.wallet.SignV2Inputs(&setupTxn, toSign) + resolutionTxn := types.V2Transaction{ + MinerFee: fee, + SiacoinInputs: []types.V2SiacoinInput{ + { + Parent: setupTxn.EphemeralSiacoinOutput(0), + }, + }, + FileContractResolutions: []types.V2FileContractResolution{ + { + Parent: fce, + Resolution: &types.V2FileContractExpiration{}, + }, + }, + } + cm.wallet.SignV2Inputs(&resolutionTxn, []int{0}) + + resolutionTxnSet := []types.V2Transaction{setupTxn, resolutionTxn} + if _, err := cm.chain.AddV2PoolTransactions(basis, resolutionTxnSet); err != nil { + cm.wallet.ReleaseInputs(nil, resolutionTxnSet) + log.Error("failed to add resolution transaction to pool", zap.Error(err)) + continue + } + cm.syncer.BroadcastV2TransactionSet(basis, resolutionTxnSet) + log.Debug("broadcast transaction", zap.String("transactionID", resolutionTxn.ID().String())) + } + + if err := cm.store.ExpireContractSectors(index.Height); err != nil { + return fmt.Errorf("failed to expire contract sectors: %w", err) + } else if err := cm.store.ExpireV2ContractSectors(index.Height); err != nil { + return fmt.Errorf("failed to expire v2 contract sectors: %w", err) + } + return nil +} + +// buildContractState helper to build state changes from a state update +func buildContractState(tx UpdateStateTx, u stateUpdater, revert bool, log *zap.Logger) (state StateChanges) { + u.ForEachFileContractElement(func(fce types.FileContractElement, created bool, rev *types.FileContractElement, resolved, valid bool) { + log := log.With(zap.Stringer("contractID", fce.ID)) + if relevant, err := tx.ContractRelevant(types.FileContractID(fce.ID)); err != nil { + log.Fatal("failed to check contract relevance", zap.Error(err)) + } else if !relevant { + return + } + + switch { + case created: + state.Confirmed = append(state.Confirmed, fce) + log.Debug("confirmed contract") + case rev != nil: + if revert { + state.Revised = append(state.Revised, fce) + log.Debug("revised contract", zap.Uint64("current", rev.FileContract.RevisionNumber), zap.Uint64("revised", fce.FileContract.RevisionNumber)) + } else { + state.Revised = append(state.Revised, *rev) + log.Debug("revised contract", zap.Uint64("current", fce.FileContract.RevisionNumber), zap.Uint64("revised", fce.FileContract.RevisionNumber)) + } + case resolved && valid: + state.Successful = append(state.Successful, types.FileContractID(fce.ID)) + log.Debug("successful contract") + case resolved && !valid: + successful := fce.FileContract.MissedHostPayout().Cmp(fce.FileContract.ValidHostPayout()) >= 0 + if successful { + state.Successful = append(state.Successful, types.FileContractID(fce.ID)) + } else { + state.Failed = append(state.Failed, types.FileContractID(fce.ID)) + } + log.Debug("expired contract", zap.Bool("successful", successful)) + default: + log.Fatal("unexpected contract state", zap.Bool("resolved", resolved), zap.Bool("valid", valid), zap.Bool("created", created), zap.Bool("revised", rev != nil), zap.Stringer("contractID", fce.ID)) + } + }) + + u.ForEachV2FileContractElement(func(fce types.V2FileContractElement, created bool, rev *types.V2FileContractElement, res types.V2FileContractResolutionType) { + log := log.With(zap.Stringer("contractID", fce.ID)) + + if relevant, err := tx.V2ContractRelevant(types.FileContractID(fce.ID)); err != nil { + log.Fatal("failed to check contract relevance", zap.Error(err)) + } else if !relevant { + return + } + + switch { + case created: + state.ConfirmedV2 = append(state.ConfirmedV2, fce) + log.Debug("confirmed v2 contract", zap.Stringer("contractID", fce.ID)) + case rev != nil: + if revert { + state.RevisedV2 = append(state.RevisedV2, fce) + log.Debug("revised contract", zap.Uint64("current", rev.V2FileContract.RevisionNumber), zap.Uint64("revised", fce.V2FileContract.RevisionNumber)) + } else { + state.RevisedV2 = append(state.RevisedV2, *rev) + log.Debug("revised contract", zap.Uint64("current", fce.V2FileContract.RevisionNumber), zap.Uint64("revised", rev.V2FileContract.RevisionNumber)) + } + case res != nil: + switch res := res.(type) { + case *types.V2FileContractFinalization: + state.FinalizedV2 = append(state.FinalizedV2, types.FileContractID(fce.ID)) + log.Debug("finalized v2 contract", zap.Stringer("contractID", fce.ID)) + case *types.V2FileContractRenewal: + state.RenewedV2 = append(state.RenewedV2, types.FileContractID(fce.ID)) + log.Debug("renewed v2 contract", zap.Stringer("contractID", fce.ID)) + case *types.V2FileContractExpiration: + fc := fce.V2FileContract + successful := fc.MissedHostValue.Cmp(fc.HostOutput.Value) >= 0 + if successful { + state.SuccessfulV2 = append(state.SuccessfulV2, types.FileContractID(fce.ID)) + } else { + state.FailedV2 = append(state.FailedV2, types.FileContractID(fce.ID)) + } + log.Debug("expired v2 contract", zap.Stringer("contractID", fce.ID), zap.Bool("successful", successful)) + case *types.V2StorageProof: + state.SuccessfulV2 = append(state.SuccessfulV2, types.FileContractID(fce.ID)) + log.Debug("successful v2 contract", zap.Stringer("contractID", fce.ID)) + default: + panic(fmt.Sprintf("unexpected contract resolution type: %T", res)) + } + default: + log.Fatal("unexpected v2 contract state", zap.Bool("resolved", res != nil), zap.Bool("created", created), zap.Bool("revised", rev != nil), zap.Stringer("contractID", fce.ID)) + } + }) + return +} + +// UpdateChainState updates the state of the contracts on chain +func (cm *Manager) UpdateChainState(tx UpdateStateTx, reverted []chain.RevertUpdate, applied []chain.ApplyUpdate) error { + log := cm.log.Named("updateChainState") + + chainElements, err := tx.ContractChainIndexElements() + if err != nil { + return fmt.Errorf("failed to get chain index state elements: %w", err) + } + + v2ContractStateElements, err := tx.ContractStateElements() + if err != nil { + return fmt.Errorf("failed to get contract state elements: %w", err) + } + v2ContractElementMap := make(map[types.Hash256]*types.StateElement, len(v2ContractStateElements)) + for _, ele := range v2ContractStateElements { + v2ContractElementMap[ele.ID] = &ele + } + + for _, cru := range reverted { + revertedIndex := types.ChainIndex{ + ID: cru.Block.ID(), + Height: cru.State.Index.Height + 1, + } + + // revert contract state changes + state := buildContractState(tx, cru, true, log.Named("revert").With(zap.Stringer("index", revertedIndex))) + if err := tx.RevertContracts(revertedIndex, state); err != nil { + return fmt.Errorf("failed to revert contracts: %w", err) + } + + // delete reverted contract state elements from the map + for _, reverted := range state.ConfirmedV2 { + delete(v2ContractElementMap, reverted.ID) + } + // update remaining contract state elements + for key := range v2ContractElementMap { + cru.UpdateElementProof(v2ContractElementMap[key]) + } + + // revert contract chain index element + if err := tx.RevertContractChainIndexElement(revertedIndex); err != nil { + return fmt.Errorf("failed to revert chain index state element: %w", err) + } + + // update chain state elements + if len(chainElements) > 0 { + last := chainElements[len(chainElements)-1] + if last.ChainIndex != revertedIndex { + panic(fmt.Errorf("unexpected chain index: %v != %v", last.ChainIndex, revertedIndex)) // developer error + } + chainElements = chainElements[:len(chainElements)-1] + for i := range chainElements { + cru.UpdateElementProof(&chainElements[i].StateElement) + } + } + } + + for _, cau := range applied { + state := buildContractState(tx, cau, false, log.Named("apply").With(zap.Stringer("index", cau.State.Index))) + // apply state changes + if err := tx.ApplyContracts(cau.State.Index, state); err != nil { + return fmt.Errorf("failed to revert contracts: %w", err) + } + + // update existing contract state elements + for id := range v2ContractElementMap { + cau.UpdateElementProof(v2ContractElementMap[id]) + } + // add new contract state elements + for _, applied := range state.ConfirmedV2 { + v2ContractElementMap[applied.ID] = &applied.StateElement + } + + // update existing chain index elements proofs + for i := range chainElements { + cau.UpdateElementProof(&chainElements[i].StateElement) + } + // add new chain index element + chainElements = append(chainElements, cau.ChainIndexElement()) + if len(chainElements) > chainIndexBuffer { + chainElements = chainElements[len(chainElements)-chainIndexBuffer:] + } + + // reject any contracts that have not been confirmed after the reject buffer + index := cau.State.Index + if index.Height >= cm.rejectBuffer { + minNegotiationHeight := index.Height - cm.rejectBuffer + rejectedV1, rejectedV2, err := tx.RejectContracts(minNegotiationHeight) + if err != nil { + return fmt.Errorf("failed to reject contracts: %w", err) + } + + if len(rejectedV1) > 0 { + log.Debug("rejected contracts", zap.Int("count", len(rejectedV1))) + } + if len(rejectedV2) > 0 { + log.Debug("rejected v2 contracts", zap.Int("count", len(rejectedV2))) + } + } + + // delete any chain index elements outside of the proof window buffer + if cau.State.Index.Height > chainIndexBuffer { + minHeight := cau.State.Index.Height - chainIndexBuffer + if err := tx.DeleteExpiredContractChainIndexElements(minHeight); err != nil { + return fmt.Errorf("failed to delete expired chain index elements: %w", err) + } + } + } + + // update chain index state elements + if len(chainElements) > 0 { + if err := tx.ApplyContractChainIndexElements(chainElements); err != nil { + return fmt.Errorf("failed to update chain index state elements: %w", err) + } + } + + // update contract state elements + if len(v2ContractElementMap) > 0 { + contractStateElements := make([]types.StateElement, 0, len(v2ContractElementMap)) + for _, ele := range v2ContractElementMap { + contractStateElements = append(contractStateElements, *ele) + } + if err := tx.UpdateContractStateElements(contractStateElements); err != nil { + return fmt.Errorf("failed to update contract state elements: %w", err) + } + } + return nil +} diff --git a/host/metrics/types.go b/host/metrics/types.go index 3535b0b9..aec608bc 100644 --- a/host/metrics/types.go +++ b/host/metrics/types.go @@ -44,10 +44,11 @@ type ( // Contracts is a collection of metrics related to contracts. Contracts struct { - Pending uint64 `json:"pending"` Active uint64 `json:"active"` Rejected uint64 `json:"rejected"` Failed uint64 `json:"failed"` + Renewed uint64 `json:"renewed"` + Finalized uint64 `json:"finalized"` Successful uint64 `json:"successful"` LockedCollateral types.Currency `json:"lockedCollateral"` @@ -101,6 +102,12 @@ type ( Earned Revenue `json:"earned"` } + // WalletMetrics is a collection of metrics related to the wallet + WalletMetrics struct { + Balance types.Currency `json:"balance"` + ImmatureBalance types.Currency `json:"immatureBalance"` + } + // Metrics is a collection of metrics for the host. Metrics struct { Accounts Accounts `json:"accounts"` @@ -110,7 +117,7 @@ type ( Storage Storage `json:"storage"` Registry Registry `json:"registry"` Data DataMetrics `json:"data"` - Balance types.Currency `json:"balance"` + Wallet WalletMetrics `json:"wallet"` Timestamp time.Time `json:"timestamp"` } diff --git a/host/settings/announce.go b/host/settings/announce.go index 8d745498..9d7f5041 100644 --- a/host/settings/announce.go +++ b/host/settings/announce.go @@ -1,217 +1,119 @@ package settings import ( - "crypto/ed25519" + "errors" "fmt" - "time" + "net" + "strconv" "go.sia.tech/core/types" - "go.sia.tech/hostd/alerts" - "go.sia.tech/siad/modules" - stypes "go.sia.tech/siad/types" + "go.sia.tech/coreutils/chain" "go.uber.org/zap" - "lukechampine.com/frand" -) - -const ( - announcementDebounce = 18 // blocks ) type ( - // An Announcement contains the host's announced netaddress and public key + // An Announcement contains the host's announced netaddress Announcement struct { - Index types.ChainIndex `json:"index"` - PublicKey types.PublicKey `json:"publicKey"` - Address string `json:"address"` + Index types.ChainIndex `json:"index"` + Address string `json:"address"` } ) -// constant to overwrite announcement alerts instead of registering new ones -var alertAnnouncementID = frand.Entropy256() - // Announce announces the host to the network func (m *ConfigManager) Announce() error { // get the current settings settings := m.Settings() - // if no netaddress is set, override the field with the auto-discovered one - if settings.NetAddress == "" { - settings.NetAddress = m.discoveredRHPAddr - } - - if err := validateNetAddress(settings.NetAddress); err != nil { - return err - } - - // create a transaction with an announcement - minerFee := m.tp.RecommendedFee().Mul64(announcementTxnSize) - txn := types.Transaction{ - ArbitraryData: [][]byte{ - createAnnouncement(m.hostKey, settings.NetAddress), - }, - MinerFees: []types.Currency{minerFee}, - } - - // fund the transaction - toSign, release, err := m.wallet.FundTransaction(&txn, minerFee) - if err != nil { - return fmt.Errorf("failed to fund transaction: %w", err) - } - // sign the transaction - err = m.wallet.SignTransaction(m.cm.TipState(), &txn, toSign, types.CoveredFields{WholeTransaction: true}) - if err != nil { - release() - return fmt.Errorf("failed to sign transaction: %w", err) - } - // broadcast the transaction - err = m.tp.AcceptTransactionSet([]types.Transaction{txn}) - if err != nil { - release() - return fmt.Errorf("failed to broadcast transaction: %w", err) - } - m.log.Debug("broadcast announcement", zap.String("transactionID", txn.ID().String()), zap.String("netaddress", settings.NetAddress), zap.String("cost", minerFee.ExactString())) - return nil -} -// ProcessConsensusChange implements modules.ConsensusSetSubscriber. -func (cm *ConfigManager) ProcessConsensusChange(cc modules.ConsensusChange) { - done, err := cm.tg.Add() - if err != nil { - return + if m.validateNetAddress { + if err := validateNetAddress(settings.NetAddress); err != nil { + return fmt.Errorf("failed to validate net address %q: %w", settings.NetAddress, err) + } } - defer done() - log := cm.log.Named("consensusChange") + minerFee := m.chain.RecommendedFee().Mul64(announcementTxnSize) - lastAnnouncement, err := cm.store.LastAnnouncement() - if err != nil { - log.Fatal("failed to get last announcement", zap.Error(err)) + ha := chain.HostAnnouncement{ + NetAddress: settings.NetAddress, } - hostPub := cm.hostKey.PublicKey() - - // check if the host key changed (should not happen) - if lastAnnouncement.PublicKey != (types.PublicKey{}) && lastAnnouncement.PublicKey != hostPub { - log.Error("resetting announcement due to host key change", zap.Stringer("oldKey", lastAnnouncement.PublicKey), zap.Stringer("newKey", hostPub)) - if err = cm.store.RevertLastAnnouncement(); err != nil { - log.Fatal("failed to reset announcements", zap.Error(err)) + cs := m.chain.TipState() + if cs.Index.Height < cs.Network.HardforkV2.AllowHeight { + // create a transaction with an announcement + txn := types.Transaction{ + ArbitraryData: [][]byte{ + ha.ToArbitraryData(m.hostKey), + }, + MinerFees: []types.Currency{minerFee}, } - } - // check if a block containing the announcement was reverted - blockHeight := uint64(cc.BlockHeight) - for _, block := range cc.RevertedBlocks { - if types.BlockID(block.ID()) == lastAnnouncement.Index.ID { - log.Info("resetting announcement due to block revert", zap.Uint64("height", blockHeight), zap.String("address", lastAnnouncement.Address), zap.Stringer("publicKey", hostPub)) - if err = cm.store.RevertLastAnnouncement(); err != nil { - log.Fatal("failed to revert last announcement", zap.Error(err)) - } + // fund the transaction + toSign, err := m.wallet.FundTransaction(&txn, minerFee, true) + if err != nil { + return fmt.Errorf("failed to fund transaction: %w", err) } - blockHeight-- - } - - // check for new announcements - for _, block := range cc.AppliedBlocks { - blockID := types.BlockID(block.ID()) - for _, txn := range block.Transactions { - for _, arb := range txn.ArbitraryData { - address, pubkey, err := modules.DecodeAnnouncement(arb) - if err != nil || pubkey.Algorithm != stypes.SignatureEd25519 || len(pubkey.Key) != ed25519.PublicKeySize || len(address) == 0 { - continue - } - announcement := Announcement{ - PublicKey: types.PublicKey(pubkey.Key), - Address: string(address), - Index: types.ChainIndex{ - ID: blockID, - Height: blockHeight, - }, - } - - if announcement.PublicKey != hostPub { - continue - } - - // update the announcement - if err := cm.store.UpdateLastAnnouncement(announcement); err != nil { - log.Fatal("failed to update last announcement", zap.Error(err)) - } - cm.a.Register(alerts.Alert{ - ID: alertAnnouncementID, - Severity: alerts.SeverityInfo, - Message: "Announcement confirmed", - Data: map[string]any{ - "address": announcement.Address, - "height": blockHeight, - }, - Timestamp: time.Now(), - }) - log.Info("announcement confirmed", zap.String("address", announcement.Address), zap.Uint64("height", blockHeight)) - } + m.wallet.SignTransaction(&txn, toSign, types.CoveredFields{WholeTransaction: true}) + txnset := append(m.chain.UnconfirmedParents(txn), txn) + if _, err := m.chain.AddPoolTransactions(txnset); err != nil { + m.wallet.ReleaseInputs([]types.Transaction{txn}, nil) + return fmt.Errorf("failed to add transaction to pool: %w", err) } - blockHeight++ + m.syncer.BroadcastTransactionSet(txnset) + m.log.Debug("broadcast announcement", zap.String("transactionID", txn.ID().String()), zap.String("netaddress", settings.NetAddress), zap.String("cost", minerFee.ExactString())) + } else { + // create a v2 transaction with an announcement + txn := types.V2Transaction{ + Attestations: []types.Attestation{ + ha.ToAttestation(cs, m.hostKey), + }, + MinerFee: minerFee, + } + basis, toSign, err := m.wallet.FundV2Transaction(&txn, minerFee, true) + if err != nil { + return fmt.Errorf("failed to fund transaction: %w", err) + } + m.wallet.SignV2Inputs(&txn, toSign) + basis, txnset, err := m.chain.V2TransactionSet(basis, txn) + if err != nil { + m.wallet.ReleaseInputs(nil, []types.V2Transaction{txn}) + return fmt.Errorf("failed to create transaction set: %w", err) + } else if _, err := m.chain.AddV2PoolTransactions(basis, txnset); err != nil { + m.wallet.ReleaseInputs(nil, []types.V2Transaction{txn}) + return fmt.Errorf("failed to add transaction to pool: %w", err) + } + m.syncer.BroadcastV2TransactionSet(cs.Index, txnset) + m.log.Debug("broadcast v2 announcement", zap.String("transactionID", txn.ID().String()), zap.String("netaddress", settings.NetAddress), zap.String("cost", minerFee.ExactString())) } + return nil +} - // get the last announcement again, in case it was updated - lastAnnouncement, err = cm.store.LastAnnouncement() +func validateNetAddress(netaddress string) error { + host, port, err := net.SplitHostPort(netaddress) if err != nil { - log.Fatal("failed to get last announcement", zap.Error(err)) + return fmt.Errorf("failed to split net address: %w", err) } - // get the current net address - cm.mu.Lock() - defer cm.mu.Unlock() - - if err := validateNetAddress(cm.settings.NetAddress); err != nil { - log.Debug("skipping auto announcement for invalid net address", zap.Error(err)) - return - } - - currentNetAddress := cm.settings.NetAddress - cm.scanHeight = uint64(cc.BlockHeight) - timestamp := time.Unix(int64(cc.AppliedBlocks[len(cc.AppliedBlocks)-1].Timestamp), 0) - nextAnnounceHeight := lastAnnouncement.Index.Height + autoAnnounceInterval - - log = log.With(zap.Uint64("currentHeight", cm.scanHeight), zap.Uint64("lastHeight", lastAnnouncement.Index.Height), zap.Uint64("nextHeight", nextAnnounceHeight), zap.String("currentAddress", currentNetAddress), zap.String("oldAddress", lastAnnouncement.Address)) - - // if the address hasn't changed, don't reannounce - if cm.scanHeight < nextAnnounceHeight && currentNetAddress == lastAnnouncement.Address { - log.Debug("skipping announcement for unchanged address") - return + // Check that the host is not empty or localhost. + if host == "" { + return errors.New("empty net address") + } else if host == "localhost" { + return errors.New("net address cannot be localhost") } - // debounce announcements - if cm.scanHeight < cm.lastAnnounceAttempt+announcementDebounce || time.Since(timestamp) > 3*time.Hour { - return + // Check that the port is a valid number. + n, err := strconv.Atoi(port) + if err != nil { + return fmt.Errorf("failed to parse port: %w", err) + } else if n < 1 || n > 65535 { + return errors.New("port must be between 1 and 65535") } - log.Debug("announcing host") - cm.lastAnnounceAttempt = cm.scanHeight - - // in go-routine to prevent deadlock with TPool - go func() { - if err := cm.Announce(); err != nil { - log.Error("failed to announce host", zap.Error(err)) - cm.a.Register(alerts.Alert{ - ID: alertAnnouncementID, - Severity: alerts.SeverityWarning, - Message: "Announcement failed", - Data: map[string]any{ - "error": err.Error(), - }, - Timestamp: time.Now(), - }) - return + // If the host is an IP address, check that it is a public IP address. + ip := net.ParseIP(host) + if ip != nil { + if ip.IsLoopback() || ip.IsPrivate() || !ip.IsGlobalUnicast() { + return errors.New("only public IP addresses allowed") } - log.Info("announced host") - cm.a.Register(alerts.Alert{ - ID: alertAnnouncementID, - Severity: alerts.SeverityInfo, - Message: "Announcement broadcast", - Data: map[string]any{ - "address": currentNetAddress, - "height": lastAnnouncement.Index.Height, - }, - Timestamp: time.Now(), - }) - }() + return nil + } + return nil } diff --git a/host/settings/announce_test.go b/host/settings/announce_test.go index 446a8dbe..6727964a 100644 --- a/host/settings/announce_test.go +++ b/host/settings/announce_test.go @@ -1,121 +1,216 @@ package settings_test import ( - "path/filepath" "testing" - "time" "go.sia.tech/core/types" - "go.sia.tech/hostd/alerts" + "go.sia.tech/coreutils/wallet" + "go.sia.tech/hostd/host/contracts" "go.sia.tech/hostd/host/settings" - "go.sia.tech/hostd/internal/test" - "go.sia.tech/hostd/persist/sqlite" - "go.sia.tech/hostd/webhooks" + "go.sia.tech/hostd/host/storage" + "go.sia.tech/hostd/index" + "go.sia.tech/hostd/internal/testutil" "go.uber.org/zap/zaptest" - "lukechampine.com/frand" ) func TestAutoAnnounce(t *testing.T) { - hostKey := types.NewPrivateKeyFromSeed(frand.Bytes(32)) - dir := t.TempDir() log := zaptest.NewLogger(t) - node, err := test.NewWallet(hostKey, dir, log.Named("wallet")) + network, genesisBlock := testutil.V1Network() + hostKey := types.GeneratePrivateKey() + + node := testutil.NewConsensusNode(t, network, genesisBlock, log) + + // TODO: its unfortunate that all these managers need to be created just to + // test the auto-announce feature. + wm, err := wallet.NewSingleAddressWallet(hostKey, node.Chain, node.Store) if err != nil { - t.Fatal(err) + t.Fatal("failed to create wallet:", err) } - defer node.Close() + defer wm.Close() - // fund the wallet - if err := node.MineBlocks(node.Address(), 99); err != nil { - t.Fatal(err) + vm, err := storage.NewVolumeManager(node.Store, storage.WithLogger(log.Named("storage"))) + if err != nil { + t.Fatal("failed to create volume manager:", err) } + defer vm.Close() - db, err := sqlite.OpenDatabase(filepath.Join(dir, "hostd.db"), log.Named("sqlite")) + contracts, err := contracts.NewManager(node.Store, vm, node.Chain, node.Syncer, wm, contracts.WithRejectAfter(10), contracts.WithRevisionSubmissionBuffer(5), contracts.WithLog(log)) if err != nil { - t.Fatal(err) + t.Fatal("failed to create contracts manager:", err) } - defer db.Close() + defer contracts.Close() - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) + sm, err := settings.NewConfigManager(hostKey, node.Store, node.Chain, node.Syncer, wm, settings.WithLog(log.Named("settings")), settings.WithAnnounceInterval(50), settings.WithValidateNetAddress(false)) if err != nil { t.Fatal(err) } + defer sm.Close() - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - manager, err := settings.NewConfigManager(settings.WithHostKey(hostKey), - settings.WithStore(db), - settings.WithChainManager(node.ChainManager()), - settings.WithTransactionPool(node.TPool()), - settings.WithWallet(node), - settings.WithAlertManager(am), - settings.WithLog(log.Named("settings"))) + idx, err := index.NewManager(node.Store, node.Chain, contracts, wm, sm, vm, index.WithLog(log.Named("index")), index.WithBatchSize(0)) // off-by-one if err != nil { - t.Fatal(err) + t.Fatal("failed to create index manager:", err) } - defer manager.Close() + defer idx.Close() + + // fund the wallet + testutil.MineBlocks(t, node, wm.Address(), 150) + testutil.WaitForSync(t, node.Chain, idx) settings := settings.DefaultSettings settings.NetAddress = "foo.bar:1234" - manager.UpdateSettings(settings) - - // trigger an auto-announce - if err := node.MineBlocks(node.Address(), 1); err != nil { - t.Fatal(err) + sm.UpdateSettings(settings) + + assertAnnouncement := func(t *testing.T, expectedAddr string, height uint64) { + t.Helper() + + index, ok := node.Chain.BestIndex(height) + if !ok { + t.Fatal("failed to get index") + } + + ann, err := sm.LastAnnouncement() + if err != nil { + t.Fatal(err) + } else if ann.Address != expectedAddr { + t.Fatalf("expected address %q, got %q", expectedAddr, ann.Address) + } else if ann.Index != index { + t.Fatalf("expected index %q, got %q", index, ann.Index) + } } - time.Sleep(time.Second) - // confirm the announcement - if err := node.MineBlocks(node.Address(), 5); err != nil { - t.Fatal(err) + // helper that mines blocks and waits for them to be processed before mining + // the next one. This is necessary because test blocks can be extremely fast + // and the host may not have time to process the broadcast before the next + // block is mined. + mineAndSync := func(t *testing.T, numBlocks int) { + t.Helper() + + // waits for each block to be processed before mining the next one + for i := 0; i < numBlocks; i++ { + testutil.MineBlocks(t, node, wm.Address(), 1) + testutil.WaitForSync(t, node.Chain, idx) + } } - time.Sleep(time.Second) - lastAnnouncement, err := manager.LastAnnouncement() + // trigger an auto-announce + mineAndSync(t, 2) + assertAnnouncement(t, "foo.bar:1234", 152) + // mine until the next announcement and confirm it + mineAndSync(t, 51) + assertAnnouncement(t, "foo.bar:1234", 203) // 152 (first confirm) + 50 (interval) + 1 (confirmation) + + // change the address + settings.NetAddress = "baz.qux:5678" + sm.UpdateSettings(settings) + + // trigger and confirm the new announcement + mineAndSync(t, 2) + assertAnnouncement(t, "baz.qux:5678", 205) + + // mine until the v2 hardfork activates. The host should re-announce with a + // v2 attestation. + n := node.Chain.TipState().Network + mineAndSync(t, int(n.HardforkV2.AllowHeight-node.Chain.Tip().Height)+1) + assertAnnouncement(t, "baz.qux:5678", n.HardforkV2.AllowHeight+1) + + // mine a few more blocks to ensure the host doesn't re-announce + mineAndSync(t, 10) + assertAnnouncement(t, "baz.qux:5678", n.HardforkV2.AllowHeight+1) +} + +func TestAutoAnnounceV2(t *testing.T) { + log := zaptest.NewLogger(t) + network, genesisBlock := testutil.V2Network() + network.HardforkV2.AllowHeight = 2 + network.HardforkV2.RequireHeight = 3 + hostKey := types.GeneratePrivateKey() + + node := testutil.NewConsensusNode(t, network, genesisBlock, log) + + // TODO: its unfortunate that all these managers need to be created just to + // test the auto-announce feature. + wm, err := wallet.NewSingleAddressWallet(hostKey, node.Chain, node.Store) if err != nil { - t.Fatal(err) - } else if lastAnnouncement.Index.Height == 0 { - t.Fatalf("expected an announcement, got %v", lastAnnouncement.Index.Height) - } else if lastAnnouncement.Address != "foo.bar:1234" { - t.Fatal("announcement not updated") + t.Fatal("failed to create wallet:", err) } - lastHeight := lastAnnouncement.Index.Height + defer wm.Close() - remainingBlocks := lastHeight + 100 - node.ChainManager().TipState().Index.Height - t.Log("remaining blocks:", remainingBlocks) + vm, err := storage.NewVolumeManager(node.Store, storage.WithLogger(log.Named("storage"))) + if err != nil { + t.Fatal("failed to create volume manager:", err) + } + defer vm.Close() - // mine until right before the next announcement to ensure that the - // announcement is not triggered early - if err := node.MineBlocks(node.Address(), int(remainingBlocks-1)); err != nil { - t.Fatal(err) + contracts, err := contracts.NewManager(node.Store, vm, node.Chain, node.Syncer, wm, contracts.WithRejectAfter(10), contracts.WithRevisionSubmissionBuffer(5), contracts.WithLog(log)) + if err != nil { + t.Fatal("failed to create contracts manager:", err) } - time.Sleep(time.Second) + defer contracts.Close() - lastAnnouncement, err = manager.LastAnnouncement() + sm, err := settings.NewConfigManager(hostKey, node.Store, node.Chain, node.Syncer, wm, settings.WithLog(log.Named("settings")), settings.WithAnnounceInterval(50)) if err != nil { t.Fatal(err) - } else if lastAnnouncement.Index.Height != lastHeight { - t.Fatal("announcement triggered unexpectedly") } + defer sm.Close() - // trigger an auto-announce - if err := node.MineBlocks(node.Address(), 1); err != nil { - t.Fatal(err) + idx, err := index.NewManager(node.Store, node.Chain, contracts, wm, sm, vm, index.WithLog(log.Named("index")), index.WithBatchSize(0)) // off-by-one + if err != nil { + t.Fatal("failed to create index manager:", err) } - time.Sleep(time.Second) + defer idx.Close() - // confirm the announcement - if err := node.MineBlocks(node.Address(), 2); err != nil { - t.Fatal(err) + // fund the wallet + testutil.MineBlocks(t, node, wm.Address(), 150) + testutil.WaitForSync(t, node.Chain, idx) + + settings := settings.DefaultSettings + settings.NetAddress = "foo.bar:1234" + sm.UpdateSettings(settings) + + assertAnnouncement := func(t *testing.T, expectedAddr string, height uint64) { + t.Helper() + + index, ok := node.Chain.BestIndex(height) + if !ok { + t.Fatal("failed to get index") + } + + ann, err := sm.LastAnnouncement() + if err != nil { + t.Fatal(err) + } else if ann.Address != expectedAddr { + t.Fatalf("expected address %q, got %q", expectedAddr, ann.Address) + } else if ann.Index != index { + t.Fatalf("expected index %q, got %q", index, ann.Index) + } } - time.Sleep(time.Second) - prevHeight := lastAnnouncement.Index.Height - lastAnnouncement, err = manager.LastAnnouncement() - if err != nil { - t.Fatal(err) - } else if lastAnnouncement.Index.Height <= prevHeight { - t.Fatalf("expected a new announcement after %v, got %v", prevHeight, lastAnnouncement.Index.Height) - } else if lastAnnouncement.Address != "foo.bar:1234" { - t.Fatal("announcement not updated") + // helper that mines blocks and waits for them to be processed before mining + // the next one. This is necessary because test blocks can be extremely fast + // and the host may not have time to process the broadcast before the next + // block is mined. + mineAndSync := func(t *testing.T, numBlocks int) { + t.Helper() + + // waits for each block to be processed before mining the next one + for i := 0; i < numBlocks; i++ { + testutil.MineBlocks(t, node, wm.Address(), 1) + testutil.WaitForSync(t, node.Chain, idx) + } } + + // trigger an auto-announce + mineAndSync(t, 2) + assertAnnouncement(t, "foo.bar:1234", 152) + // mine until the next announcement and confirm it + mineAndSync(t, 51) + assertAnnouncement(t, "foo.bar:1234", 203) // 152 (first confirm) + 50 (interval) + 1 (confirmation) + + // change the address + settings.NetAddress = "baz.qux:5678" + sm.UpdateSettings(settings) + + // trigger and confirm the new announcement + mineAndSync(t, 2) + assertAnnouncement(t, "baz.qux:5678", 205) } diff --git a/host/settings/certs.go b/host/settings/certs.go deleted file mode 100644 index 6f8b7d66..00000000 --- a/host/settings/certs.go +++ /dev/null @@ -1,82 +0,0 @@ -package settings - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "errors" - "fmt" - "math/big" - "net" - "os" - "time" -) - -func (m *ConfigManager) reloadCertificates() error { - if _, err := os.Stat(m.certKeyFilePath); errors.Is(err, os.ErrNotExist) { - // if the certificate files do not exist, create a temporary certificate - addr := m.settings.NetAddress - if addr == "" { - addr = m.discoveredRHPAddr - } - addr, _, err := net.SplitHostPort(addr) - if err != nil { - addr = "localhost" - } - - certificate, err := tempCertificate(addr) - if err != nil { - return fmt.Errorf("failed to create temporary certificate: %w", err) - } - m.rhp3WSTLS.Certificates = []tls.Certificate{certificate} - return nil - } else if err != nil { - return fmt.Errorf("failed to check for certificate: %w", err) - } - - // load the certificate from disk - certificate, err := tls.LoadX509KeyPair(m.certCertFilePath, m.certKeyFilePath) - if err != nil { - return fmt.Errorf("failed to load certificate: %w", err) - } - m.rhp3WSTLS.Certificates = []tls.Certificate{certificate} - return nil -} - -// RHP3TLSConfig returns the TLS config for the rhp3 WebSocket listener -func (m *ConfigManager) RHP3TLSConfig() *tls.Config { - return m.rhp3WSTLS -} - -func tempCertificate(name string) (tls.Certificate, error) { - now := time.Now() - template := &x509.Certificate{ - SerialNumber: big.NewInt(now.Unix()), - Subject: pkix.Name{ - CommonName: name, - Organization: []string{name}, - }, - NotBefore: now, - NotAfter: now.AddDate(1, 0, 0), // valid for one year - BasicConstraintsValid: true, - IsCA: true, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - } - - priv, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return tls.Certificate{}, fmt.Errorf("failed to generate private key: %w", err) - } - cert, err := x509.CreateCertificate(rand.Reader, template, template, priv.Public(), priv) - if err != nil { - return tls.Certificate{}, fmt.Errorf("failed to create certificate: %w", err) - } - - var outCert tls.Certificate - outCert.Certificate = append(outCert.Certificate, cert) - outCert.PrivateKey = priv - return outCert, nil -} diff --git a/host/settings/consts_default.go b/host/settings/consts_default.go deleted file mode 100644 index ce17d807..00000000 --- a/host/settings/consts_default.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !testing - -package settings - -const ( - autoAnnounceInterval = (144 * 180) // reannounce every 180 days -) diff --git a/host/settings/consts_testing.go b/host/settings/consts_testing.go deleted file mode 100644 index d694846b..00000000 --- a/host/settings/consts_testing.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build testing - -package settings - -const ( - autoAnnounceInterval = 100 // reannounce every 100 blocks -) diff --git a/host/settings/netaddress_default.go b/host/settings/netaddress_default.go deleted file mode 100644 index 72d6c519..00000000 --- a/host/settings/netaddress_default.go +++ /dev/null @@ -1,42 +0,0 @@ -//go:build !testing - -package settings - -import ( - "errors" - "fmt" - "net" - "strconv" -) - -func validateNetAddress(netaddress string) error { - host, port, err := net.SplitHostPort(netaddress) - if err != nil { - return fmt.Errorf("invalid net address %q: net addresses must contain a host and port: %w", netaddress, err) - } - - // Check that the host is not empty or localhost. - if host == "" { - return errors.New("empty net address") - } else if host == "localhost" { - return errors.New("net address cannot be localhost") - } - - // Check that the port is a valid number. - n, err := strconv.Atoi(port) - if err != nil { - return fmt.Errorf("failed to parse port: %w", err) - } else if n < 1 || n > 65535 { - return errors.New("port must be between 1 and 65535") - } - - // If the host is an IP address, check that it is a public IP address. - ip := net.ParseIP(host) - if ip != nil { - if ip.IsLoopback() || ip.IsPrivate() || !ip.IsGlobalUnicast() { - return fmt.Errorf("invalid net address %q: only public IP addresses allowed", host) - } - return nil - } - return nil -} diff --git a/host/settings/netaddress_testing.go b/host/settings/netaddress_testing.go deleted file mode 100644 index 2024b2f2..00000000 --- a/host/settings/netaddress_testing.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build testing - -package settings - -// enables announcements on localhost -func validateNetAddress(netaddress string) error { - return nil -} diff --git a/host/settings/options.go b/host/settings/options.go index 4f48e033..d6199cde 100644 --- a/host/settings/options.go +++ b/host/settings/options.go @@ -1,7 +1,6 @@ package settings import ( - "go.sia.tech/core/types" "go.uber.org/zap" ) @@ -16,34 +15,6 @@ func WithLog(log *zap.Logger) Option { } } -// WithStore sets the store for the settings manager. -func WithStore(s Store) Option { - return func(cm *ConfigManager) { - cm.store = s - } -} - -// WithChainManager sets the chain manager for the settings manager. -func WithChainManager(cm ChainManager) Option { - return func(c *ConfigManager) { - c.cm = cm - } -} - -// WithTransactionPool sets the transaction pool for the settings manager. -func WithTransactionPool(tp TransactionPool) Option { - return func(c *ConfigManager) { - c.tp = tp - } -} - -// WithWallet sets the wallet for the settings manager. -func WithWallet(w Wallet) Option { - return func(c *ConfigManager) { - c.wallet = w - } -} - // WithAlertManager sets the alerts manager for the settings manager. func WithAlertManager(am Alerts) Option { return func(c *ConfigManager) { @@ -51,24 +22,27 @@ func WithAlertManager(am Alerts) Option { } } -// WithHostKey sets the host key for the settings manager. -func WithHostKey(pk types.PrivateKey) Option { +// WithAnnounceInterval sets the interval at which the host should re-announce +// itself. +func WithAnnounceInterval(interval uint64) Option { return func(c *ConfigManager) { - c.hostKey = pk + c.announceInterval = interval } } -// WithRHP2Addr sets the address of the RHP2 server. -func WithRHP2Addr(addr string) Option { +// WithValidateNetAddress sets whether the settings manager should validate +// the announced net address. +func WithValidateNetAddress(validate bool) Option { return func(c *ConfigManager) { - c.discoveredRHPAddr = addr + c.validateNetAddress = validate } } -// WithCertificateFiles sets the certificate files for the settings manager. -func WithCertificateFiles(certFilePath, keyFilePath string) Option { +// WithInitialSettings sets the host's settings when the config manager is +// initialized. If this option is not provided, the default settings are used. +// If the database already contains settings, they will be used. +func WithInitialSettings(settings Settings) Option { return func(c *ConfigManager) { - c.certCertFilePath = certFilePath - c.certKeyFilePath = keyFilePath + c.initialSettings = settings } } diff --git a/host/settings/pin/options.go b/host/settings/pin/options.go index 3842ea58..7491af86 100644 --- a/host/settings/pin/options.go +++ b/host/settings/pin/options.go @@ -31,27 +31,6 @@ func WithFrequency(frequency time.Duration) Option { } } -// WithSettings sets the settings manager for the manager. -func WithSettings(s SettingsManager) Option { - return func(m *Manager) { - m.sm = s - } -} - -// WithStore sets the store for the manager. -func WithStore(s Store) Option { - return func(m *Manager) { - m.store = s - } -} - -// WithExchangeRateRetriever sets the exchange rate retriever for the manager. -func WithExchangeRateRetriever(e ExchangeRateRetriever) Option { - return func(m *Manager) { - m.explorer = e - } -} - // WithAverageRateWindow sets the window over which the manager calculates the // average exchange rate. func WithAverageRateWindow(window time.Duration) Option { diff --git a/host/settings/pin/pin.go b/host/settings/pin/pin.go index 3a8edca3..b67b7144 100644 --- a/host/settings/pin/pin.go +++ b/host/settings/pin/pin.go @@ -66,20 +66,20 @@ type ( UpdatePinnedSettings(context.Context, PinnedSettings) error } - // An ExchangeRateRetriever retrieves the current exchange rate from + // A Forex retrieves the current exchange rate from // an external source. - ExchangeRateRetriever interface { + Forex interface { SiacoinExchangeRate(ctx context.Context, currency string) (float64, error) } // A Manager manages the host's pinned settings and updates the host's // settings based on the current exchange rate. Manager struct { - log *zap.Logger - store Store - alerts Alerts - explorer ExchangeRateRetriever - sm SettingsManager + log *zap.Logger + store Store + alerts Alerts + forex Forex + sm SettingsManager frequency time.Duration rateWindow time.Duration @@ -142,7 +142,7 @@ func (m *Manager) updatePrices(ctx context.Context, force bool) error { return nil } - rate, err := m.explorer.SiacoinExchangeRate(ctx, currency) + rate, err := m.forex.SiacoinExchangeRate(ctx, currency) if err != nil { return fmt.Errorf("failed to get exchange rate: %w", err) } else if rate <= 0 { @@ -294,9 +294,14 @@ func ConvertCurrencyToSC(target decimal.Decimal, rate decimal.Decimal) (types.Cu } // NewManager creates a new pin manager. -func NewManager(opts ...Option) (*Manager, error) { +func NewManager(store Store, settings SettingsManager, f Forex, opts ...Option) (*Manager, error) { m := &Manager{ - log: zap.NewNop(), + store: store, + sm: settings, + forex: f, + + alerts: alerts.NewNop(), + log: zap.NewNop(), frequency: 5 * time.Minute, rateWindow: 6 * time.Hour, @@ -306,15 +311,7 @@ func NewManager(opts ...Option) (*Manager, error) { opt(m) } - if m.store == nil { - return nil, fmt.Errorf("store is required") - } else if m.explorer == nil { - return nil, fmt.Errorf("exchange rate retriever is required") - } else if m.sm == nil { - return nil, fmt.Errorf("settings manager is required") - } else if m.log == nil { - return nil, fmt.Errorf("logger is required") - } else if m.frequency <= 0 { + if m.frequency <= 0 { return nil, fmt.Errorf("frequency must be positive") } else if m.rateWindow <= 0 { return nil, fmt.Errorf("rate window must be positive") diff --git a/host/settings/pin/pin_test.go b/host/settings/pin/pin_test.go index 6f62086d..554217cc 100644 --- a/host/settings/pin/pin_test.go +++ b/host/settings/pin/pin_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "path/filepath" "strings" "sync" "testing" @@ -14,8 +13,7 @@ import ( "go.sia.tech/core/types" "go.sia.tech/hostd/host/settings" "go.sia.tech/hostd/host/settings/pin" - "go.sia.tech/hostd/internal/test" - "go.sia.tech/hostd/persist/sqlite" + "go.sia.tech/hostd/internal/testutil" "go.uber.org/zap/zaptest" ) @@ -113,34 +111,22 @@ func TestConvertConvertCurrencyToSC(t *testing.T) { func TestPinnedFields(t *testing.T) { log := zaptest.NewLogger(t) - db, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log) - if err != nil { - t.Fatal(err) - } - defer db.Close() + network, genesis := testutil.V1Network() + node := testutil.NewConsensusNode(t, network, genesis, log) - rr := &exchangeRateRetrieverStub{ + fr := &exchangeRateRetrieverStub{ value: 1, currency: "usd", } - node, err := test.NewNode(t.TempDir()) - if err != nil { - t.Fatal(err) - } - defer node.Close() - - sm, err := settings.NewConfigManager(settings.WithHostKey(types.GeneratePrivateKey()), settings.WithStore(db), settings.WithChainManager(node.ChainManager())) + sm, err := settings.NewConfigManager(types.GeneratePrivateKey(), node.Store, node.Chain, node.Syncer, nil) if err != nil { t.Fatal(err) } defer sm.Close() - pm, err := pin.NewManager(pin.WithAverageRateWindow(time.Minute), + pm, err := pin.NewManager(node.Store, sm, fr, pin.WithAverageRateWindow(time.Minute), pin.WithFrequency(100*time.Millisecond), - pin.WithExchangeRateRetriever(rr), - pin.WithSettings(sm), - pin.WithStore(db), pin.WithLogger(log.Named("pin"))) if err != nil { t.Fatal(err) @@ -224,34 +210,22 @@ func TestPinnedFields(t *testing.T) { func TestAutomaticUpdate(t *testing.T) { log := zaptest.NewLogger(t) - db, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log) - if err != nil { - t.Fatal(err) - } - defer db.Close() + network, genesis := testutil.V1Network() + node := testutil.NewConsensusNode(t, network, genesis, log) - rr := &exchangeRateRetrieverStub{ + fr := &exchangeRateRetrieverStub{ value: 1, currency: "usd", } - node, err := test.NewNode(t.TempDir()) - if err != nil { - t.Fatal(err) - } - defer node.Close() - - sm, err := settings.NewConfigManager(settings.WithHostKey(types.GeneratePrivateKey()), settings.WithStore(db), settings.WithChainManager(node.ChainManager())) + sm, err := settings.NewConfigManager(types.GeneratePrivateKey(), node.Store, node.Chain, node.Syncer, nil) if err != nil { t.Fatal(err) } defer sm.Close() - pm, err := pin.NewManager(pin.WithAverageRateWindow(time.Second/2), + pm, err := pin.NewManager(node.Store, sm, fr, pin.WithAverageRateWindow(time.Second/2), pin.WithFrequency(100*time.Millisecond), - pin.WithExchangeRateRetriever(rr), - pin.WithSettings(sm), - pin.WithStore(db), pin.WithLogger(log.Named("pin"))) if err != nil { t.Fatal(err) @@ -307,14 +281,14 @@ func TestAutomaticUpdate(t *testing.T) { } // update the exchange rate below the threshold - rr.updateRate(1.5) + fr.updateRate(1.5) time.Sleep(time.Second) if err := checkSettings(sm.Settings(), pin, 1); err != nil { t.Fatal(err) } // update the exchange rate to put it over the threshold - rr.updateRate(2) + fr.updateRate(2) time.Sleep(time.Second) if err := checkSettings(sm.Settings(), pin, 2); err != nil { t.Fatal(err) diff --git a/host/settings/settings.go b/host/settings/settings.go index 5579338a..8e957a9c 100644 --- a/host/settings/settings.go +++ b/host/settings/settings.go @@ -1,7 +1,6 @@ package settings import ( - "bytes" "crypto/ed25519" "crypto/tls" "errors" @@ -14,9 +13,7 @@ import ( "go.sia.tech/core/consensus" "go.sia.tech/core/types" "go.sia.tech/hostd/alerts" - "go.sia.tech/hostd/internal/chain" "go.sia.tech/hostd/internal/threadgroup" - "go.sia.tech/siad/modules" "go.uber.org/zap" "golang.org/x/time/rate" ) @@ -42,10 +39,43 @@ type ( UpdateSettings(s Settings) error LastAnnouncement() (Announcement, error) - UpdateLastAnnouncement(Announcement) error - RevertLastAnnouncement() error + } + + // ChainManager defines the interface required by the contract manager to + // interact with the consensus set. + ChainManager interface { + Tip() types.ChainIndex + TipState() consensus.State + BestIndex(height uint64) (types.ChainIndex, bool) + RecommendedFee() types.Currency + + UnconfirmedParents(txn types.Transaction) []types.Transaction + AddPoolTransactions([]types.Transaction) (known bool, err error) + + V2TransactionSet(types.ChainIndex, types.V2Transaction) (types.ChainIndex, []types.V2Transaction, error) + AddV2PoolTransactions(types.ChainIndex, []types.V2Transaction) (known bool, err error) + } + + // A Syncer broadcasts transactions to its peers + Syncer interface { + BroadcastTransactionSet([]types.Transaction) + BroadcastV2TransactionSet(types.ChainIndex, []types.V2Transaction) + } + + // A Wallet manages Siacoins and funds transactions + Wallet interface { + Address() types.Address + ReleaseInputs(txns []types.Transaction, v2txns []types.V2Transaction) + FundTransaction(txn *types.Transaction, amount types.Currency, useUnconfirmed bool) ([]types.Hash256, error) + SignTransaction(txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) - LastSettingsConsensusChange() (modules.ConsensusChangeID, uint64, error) + FundV2Transaction(txn *types.V2Transaction, amount types.Currency, useUnconfirmed bool) (types.ChainIndex, []int, error) + SignV2Inputs(txn *types.V2Transaction, toSign []int) + } + + // Alerts registers global alerts. + Alerts interface { + Register(alerts.Alert) } // Settings contains configuration options for the host. @@ -89,49 +119,24 @@ type ( Revision uint64 `json:"revision"` } - // A TransactionPool broadcasts transactions to the network. - TransactionPool interface { - AcceptTransactionSet([]types.Transaction) error - RecommendedFee() types.Currency - } - - // Alerts registers global alerts. - Alerts interface { - Register(alerts.Alert) - } - - // A ChainManager manages the current consensus state - ChainManager interface { - TipState() consensus.State - Subscribe(s modules.ConsensusSetSubscriber, ccID modules.ConsensusChangeID, cancel <-chan struct{}) error - } - - // A Wallet manages funds and signs transactions - Wallet interface { - FundTransaction(txn *types.Transaction, amount types.Currency) ([]types.Hash256, func(), error) - SignTransaction(cs consensus.State, txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) error - } - // A ConfigManager manages the host's current configuration ConfigManager struct { - hostKey types.PrivateKey - discoveredRHPAddr string - - certKeyFilePath string - certCertFilePath string + hostKey types.PrivateKey + announceInterval uint64 + validateNetAddress bool + initialSettings Settings store Store a Alerts log *zap.Logger - cm ChainManager - tp TransactionPool + chain ChainManager + syncer Syncer wallet Wallet - mu sync.Mutex // guards the following fields - settings Settings // in-memory cache of the host's settings - scanHeight uint64 // track the last block height that was scanned for announcements - lastAnnounceAttempt uint64 // debounce announcement transactions + mu sync.Mutex // guards the following fields + settings Settings // in-memory cache of the host's settings + scanHeight uint64 // track the last block height that was scanned for announcements ingressLimit *rate.Limiter egressLimit *rate.Limiter @@ -174,8 +179,6 @@ var ( } // ErrNoSettings must be returned by the store if the host has no settings yet ErrNoSettings = errors.New("no settings found") - - specifierAnnouncement = types.NewSpecifier("HostAnnouncement") ) // setRateLimit sets the bandwidth rate limit for the host @@ -224,7 +227,7 @@ func (m *ConfigManager) UpdateSettings(s Settings) error { } // if a netaddress is set, validate it - if strings.TrimSpace(s.NetAddress) != "" { + if strings.TrimSpace(s.NetAddress) != "" && m.validateNetAddress { if err := validateNetAddress(s.NetAddress); err != nil { return fmt.Errorf("failed to validate net address: %w", err) } @@ -250,37 +253,21 @@ func (m *ConfigManager) BandwidthLimiters() (ingress, egress *rate.Limiter) { return m.ingressLimit, m.egressLimit } -// DiscoveredRHP2Address returns the rhp2 address that was discovered by the gateway -func (m *ConfigManager) DiscoveredRHP2Address() string { - return m.discoveredRHPAddr -} - -func createAnnouncement(priv types.PrivateKey, netaddress string) []byte { - // encode the announcement - var buf bytes.Buffer - pub := priv.PublicKey() - enc := types.NewEncoder(&buf) - specifierAnnouncement.EncodeTo(enc) - enc.WriteString(netaddress) - pub.UnlockKey().EncodeTo(enc) - if err := enc.Flush(); err != nil { - panic(err) - } - // hash without the signature - sigHash := types.HashBytes(buf.Bytes()) - // sign - sig := priv.SignHash(sigHash) - sig.EncodeTo(enc) - if err := enc.Flush(); err != nil { - panic(err) - } - return buf.Bytes() -} - // NewConfigManager initializes a new config manager -func NewConfigManager(opts ...Option) (*ConfigManager, error) { +func NewConfigManager(hostKey types.PrivateKey, store Store, cm ChainManager, s Syncer, wm Wallet, opts ...Option) (*ConfigManager, error) { m := &ConfigManager{ + announceInterval: 144 * 90, // 90 days + validateNetAddress: true, + hostKey: hostKey, + initialSettings: DefaultSettings, + + store: store, + chain: cm, + syncer: s, + wallet: wm, + log: zap.NewNop(), + a: alerts.NewNop(), tg: threadgroup.New(), // initialize the rate limiters @@ -299,42 +286,13 @@ func NewConfigManager(opts ...Option) (*ConfigManager, error) { panic("host key invalid") } - if err := m.reloadCertificates(); err != nil { - return nil, fmt.Errorf("failed to load rhp3 WebSocket certificates: %w", err) - } - settings, err := m.store.Settings() if errors.Is(err, ErrNoSettings) { - if err := m.store.UpdateSettings(DefaultSettings); err != nil { - return nil, fmt.Errorf("failed to initialize settings: %w", err) - } - settings = DefaultSettings // use the default settings + settings = m.initialSettings } else if err != nil { return nil, fmt.Errorf("failed to load settings: %w", err) } - lastChange, height, err := m.store.LastSettingsConsensusChange() - if err != nil { - return nil, fmt.Errorf("failed to load last settings consensus change: %w", err) - } - m.scanHeight = height - - go func() { - // subscribe to consensus changes - err := m.cm.Subscribe(m, lastChange, m.tg.Done()) - if errors.Is(err, chain.ErrInvalidChangeID) { - m.log.Warn("rescanning blockchain due to unknown consensus change ID") - // reset change ID and subscribe again - if err := m.store.RevertLastAnnouncement(); err != nil { - m.log.Fatal("failed to reset wallet", zap.Error(err)) - } else if err = m.cm.Subscribe(m, modules.ConsensusChangeBeginning, m.tg.Done()); err != nil { - m.log.Fatal("failed to reset consensus change subscription", zap.Error(err)) - } - } else if err != nil && !strings.Contains(err.Error(), "ThreadGroup already stopped") { - m.log.Fatal("failed to subscribe to consensus changes", zap.Error(err)) - } - }() - m.settings = settings // update the global rate limiters from settings m.setRateLimit(settings.IngressLimit, settings.EgressLimit) diff --git a/host/settings/settings_test.go b/host/settings/settings_test.go index d5d6141d..5faa789c 100644 --- a/host/settings/settings_test.go +++ b/host/settings/settings_test.go @@ -1,66 +1,70 @@ package settings_test import ( - "path/filepath" "reflect" "testing" "go.sia.tech/core/types" - "go.sia.tech/hostd/alerts" + "go.sia.tech/coreutils/wallet" + "go.sia.tech/hostd/host/contracts" "go.sia.tech/hostd/host/settings" - "go.sia.tech/hostd/internal/test" - "go.sia.tech/hostd/persist/sqlite" - "go.sia.tech/hostd/webhooks" + "go.sia.tech/hostd/host/storage" + "go.sia.tech/hostd/index" + "go.sia.tech/hostd/internal/testutil" "go.uber.org/zap/zaptest" - "lukechampine.com/frand" ) func TestSettings(t *testing.T) { - hostKey := types.NewPrivateKeyFromSeed(frand.Bytes(32)) - dir := t.TempDir() log := zaptest.NewLogger(t) - node, err := test.NewWallet(hostKey, dir, log.Named("wallet")) + network, genesisBlock := testutil.V1Network() + hostKey := types.GeneratePrivateKey() + + node := testutil.NewConsensusNode(t, network, genesisBlock, log) + + // TODO: its unfortunate that all these managers need to be created just to + // test the auto-announce feature. + wm, err := wallet.NewSingleAddressWallet(hostKey, node.Chain, node.Store) if err != nil { - t.Fatal(err) + t.Fatal("failed to create wallet:", err) } - defer node.Close() + defer wm.Close() - db, err := sqlite.OpenDatabase(filepath.Join(dir, "hostd.db"), log.Named("sqlite")) + vm, err := storage.NewVolumeManager(node.Store, storage.WithLogger(log.Named("storage"))) if err != nil { - t.Fatal(err) + t.Fatal("failed to create volume manager:", err) } - defer db.Close() + defer vm.Close() - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) + contracts, err := contracts.NewManager(node.Store, vm, node.Chain, node.Syncer, wm, contracts.WithRejectAfter(10), contracts.WithRevisionSubmissionBuffer(5), contracts.WithLog(log)) if err != nil { - t.Fatal(err) + t.Fatal("failed to create contracts manager:", err) } + defer contracts.Close() - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - manager, err := settings.NewConfigManager(settings.WithHostKey(hostKey), - settings.WithStore(db), - settings.WithChainManager(node.ChainManager()), - settings.WithTransactionPool(node.TPool()), - settings.WithWallet(node), - settings.WithAlertManager(am), - settings.WithLog(log.Named("settings"))) + sm, err := settings.NewConfigManager(hostKey, node.Store, node.Chain, node.Syncer, wm, settings.WithLog(log.Named("settings")), settings.WithAnnounceInterval(50), settings.WithValidateNetAddress(false)) if err != nil { t.Fatal(err) } - defer manager.Close() + defer sm.Close() + + idx, err := index.NewManager(node.Store, node.Chain, contracts, wm, sm, vm, index.WithLog(log.Named("index")), index.WithBatchSize(0)) // off-by-one + if err != nil { + t.Fatal("failed to create index manager:", err) + } + defer idx.Close() - if !reflect.DeepEqual(manager.Settings(), settings.DefaultSettings) { + if !reflect.DeepEqual(sm.Settings(), settings.DefaultSettings) { t.Fatal("settings not equal to default") } - updated := manager.Settings() + updated := sm.Settings() updated.WindowSize = 100 updated.NetAddress = "localhost:10082" updated.BaseRPCPrice = types.Siacoins(1) - if err := manager.UpdateSettings(updated); err != nil { + if err := sm.UpdateSettings(updated); err != nil { t.Fatal(err) - } else if !reflect.DeepEqual(manager.Settings(), updated) { + } else if !reflect.DeepEqual(sm.Settings(), updated) { t.Fatal("settings not equal to updated") } } diff --git a/host/settings/update.go b/host/settings/update.go new file mode 100644 index 00000000..00e8319d --- /dev/null +++ b/host/settings/update.go @@ -0,0 +1,91 @@ +package settings + +import ( + "fmt" + + "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" + "go.uber.org/zap" +) + +// An UpdateStateTx is a transaction that can update the host's announcement +// state. +type UpdateStateTx interface { + LastAnnouncement() (Announcement, error) + RevertLastAnnouncement() error + SetLastAnnouncement(Announcement) error +} + +// UpdateChainState updates the host's announcement state based on the given +// chain updates. +func (cm *ConfigManager) UpdateChainState(tx UpdateStateTx, reverted []chain.RevertUpdate, applied []chain.ApplyUpdate) error { + pk := cm.hostKey.PublicKey() + lastAnnouncement, err := tx.LastAnnouncement() + if err != nil { + return fmt.Errorf("failed to get last announcement: %w", err) + } + + for _, cru := range reverted { + if cru.State.Index == lastAnnouncement.Index { + if err := tx.RevertLastAnnouncement(); err != nil { + return fmt.Errorf("failed to revert last announcement: %w", err) + } + } + } + + var nextAnnouncement *Announcement + for _, cau := range applied { + index := cau.State.Index + + chain.ForEachHostAnnouncement(cau.Block, func(hostKey types.PublicKey, announcement chain.HostAnnouncement) { + if hostKey != pk { + return + } + + nextAnnouncement = &Announcement{ + Address: announcement.NetAddress, + Index: index, + } + }) + } + + if nextAnnouncement == nil { + return nil + } + + if err := tx.SetLastAnnouncement(*nextAnnouncement); err != nil { + return fmt.Errorf("failed to set last announcement: %w", err) + } + cm.log.Debug("announcement confirmed", zap.String("netaddress", nextAnnouncement.Address), zap.Stringer("index", nextAnnouncement.Index)) + return nil +} + +// ProcessActions processes announcement actions based on the given chain index. +func (m *ConfigManager) ProcessActions(index types.ChainIndex) error { + announcement, err := m.store.LastAnnouncement() + if err != nil { + return fmt.Errorf("failed to get last announcement: %w", err) + } + + nextHeight := announcement.Index.Height + m.announceInterval + netaddress := m.Settings().NetAddress + if err := validateNetAddress(netaddress); err != nil { + if m.validateNetAddress { + return nil + } + } + + // check if a new announcement is needed + n := m.chain.TipState().Network + // re-announce if the v2 hardfork has activated and the last announcement was before activation + reannounceV2 := index.Height >= n.HardforkV2.AllowHeight && announcement.Index.Height < n.HardforkV2.AllowHeight + if !reannounceV2 && index.Height < nextHeight && announcement.Address == netaddress { + return nil + } + + // re-announce + if err := m.Announce(); err != nil { + m.log.Warn("failed to announce", zap.Error(err)) + } + return nil +} diff --git a/host/storage/consts_default.go b/host/storage/consts_default.go index 7b9e0bc2..685cb47d 100644 --- a/host/storage/consts_default.go +++ b/host/storage/consts_default.go @@ -2,10 +2,6 @@ package storage -import "time" - const ( resizeBatchSize = 64 // 256 MiB - - cleanupInterval = 15 * time.Minute ) diff --git a/host/storage/options.go b/host/storage/options.go new file mode 100644 index 00000000..6273a940 --- /dev/null +++ b/host/storage/options.go @@ -0,0 +1,27 @@ +package storage + +import "go.uber.org/zap" + +// A VolumeManagerOption configures a VolumeManager. +type VolumeManagerOption func(*VolumeManager) + +// WithLogger sets the logger for the manager. +func WithLogger(l *zap.Logger) VolumeManagerOption { + return func(s *VolumeManager) { + s.log = l + } +} + +// WithAlerter sets the alerter for the manager. +func WithAlerter(a Alerts) VolumeManagerOption { + return func(s *VolumeManager) { + s.alerts = a + } +} + +// WithCacheSize sets the sector cache size for the manager. +func WithCacheSize(cacheSize int) VolumeManagerOption { + return func(s *VolumeManager) { + s.cacheSize = cacheSize + } +} diff --git a/host/storage/storage.go b/host/storage/storage.go index 9eeb191d..5f939fbd 100644 --- a/host/storage/storage.go +++ b/host/storage/storage.go @@ -10,12 +10,10 @@ import ( "time" lru "github.com/hashicorp/golang-lru/v2" - "go.sia.tech/core/consensus" rhp2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" "go.sia.tech/hostd/alerts" "go.sia.tech/hostd/internal/threadgroup" - "go.sia.tech/siad/modules" "go.uber.org/zap" "lukechampine.com/frand" ) @@ -42,12 +40,6 @@ type ( Dismiss(...types.Hash256) } - // A ChainManager is used to get the current consensus state. - ChainManager interface { - TipState() consensus.State - Subscribe(s modules.ConsensusSetSubscriber, ccID modules.ConsensusChangeID, cancel <-chan struct{}) error - } - // A SectorLocation is a location of a sector within a volume. SectorLocation struct { ID int64 @@ -75,18 +67,17 @@ type ( VolumeManager struct { cacheHits uint64 // ensure 64-bit alignment on 32-bit systems cacheMisses uint64 + cacheSize int - a Alerts vs VolumeStore - cm ChainManager - log *zap.Logger recorder *sectorAccessRecorder - tg *threadgroup.ThreadGroup + alerts Alerts + tg *threadgroup.ThreadGroup + log *zap.Logger - mu sync.Mutex // protects the following fields - lastCleanup time.Time - volumes map[int64]*volume + mu sync.Mutex // protects the following fields + volumes map[int64]*volume // changedVolumes tracks volumes that need to be fsynced changedVolumes map[int64]bool cache *lru.Cache[types.Hash256, *[rhp2.SectorSize]byte] // Added cache @@ -138,7 +129,7 @@ func (vm *VolumeManager) loadVolumes() error { } // register an alert - vm.a.Register(alerts.Alert{ + vm.alerts.Register(alerts.Alert{ ID: frand.Entropy256(), Severity: alerts.SeverityError, Message: "Failed to open volume", @@ -213,10 +204,10 @@ func (vm *VolumeManager) growVolume(ctx context.Context, id int64, volume *volum }, Timestamp: time.Now(), } - vm.a.Register(alert) + vm.alerts.Register(alert) // dismiss the alert when the function returns. It is the caller's // responsibility to register a completion alert - defer vm.a.Dismiss(alert.ID) + defer vm.alerts.Dismiss(alert.ID) for current := oldMaxSectors; current < newMaxSectors; current += resizeBatchSize { // stop early if the context is cancelled @@ -243,7 +234,7 @@ func (vm *VolumeManager) growVolume(ctx context.Context, id int64, volume *volum // update the alert alert.Data["currentSectors"] = target - vm.a.Register(alert) + vm.alerts.Register(alert) // sleep to allow other operations to run time.Sleep(time.Millisecond) } @@ -271,10 +262,10 @@ func (vm *VolumeManager) shrinkVolume(ctx context.Context, id int64, volume *vol }, Timestamp: time.Now(), } - vm.a.Register(a) + vm.alerts.Register(a) // dismiss the alert when the function returns. It is the caller's // responsibility to register a completion alert - defer vm.a.Dismiss(a.ID) + defer vm.alerts.Dismiss(a.ID) // migrate any sectors outside of the target range. var migrated int @@ -285,7 +276,7 @@ func (vm *VolumeManager) shrinkVolume(ctx context.Context, id int64, volume *vol migrated++ // update the alert a.Data["migratedSectors"] = migrated - vm.a.Register(a) + vm.alerts.Register(a) return nil }) log.Info("migrated sectors", zap.Int("migrated", migrated), zap.Int("failed", failed)) @@ -323,7 +314,7 @@ func (vm *VolumeManager) shrinkVolume(ctx context.Context, id int64, volume *vol current = target // update the alert a.Data["currentSectors"] = current - vm.a.Register(a) + vm.alerts.Register(a) // sleep to allow other operations to run time.Sleep(time.Millisecond) } @@ -486,7 +477,7 @@ func (vm *VolumeManager) AddVolume(ctx context.Context, localPath string, maxSec alert.Message = "Volume initialized" alert.Severity = alerts.SeverityInfo } - vm.a.Register(alert) + vm.alerts.Register(alert) vol.SetStatus(VolumeStatusReady) select { case result <- err: @@ -576,7 +567,7 @@ func (vm *VolumeManager) RemoveVolume(ctx context.Context, id int64, force bool, if err != nil { alert.Data["error"] = err.Error() } - vm.a.Register(alert) + vm.alerts.Register(alert) } doMigration := func() error { @@ -715,7 +706,7 @@ func (vm *VolumeManager) ResizeVolume(ctx context.Context, id int64, maxSectors alert.Message = "Volume resized" alert.Severity = alerts.SeverityInfo } - vm.a.Register(alert) + vm.alerts.Register(alert) if resetReadOnly { // reset the volume to read-write if err := vm.vs.SetReadOnly(id, false); err != nil { @@ -823,7 +814,7 @@ func (vm *VolumeManager) Read(root types.Hash256) (*[rhp2.SectorSize]byte, error sector, err := v.ReadSector(loc.Index) if err != nil { stats := v.Stats() - vm.a.Register(alerts.Alert{ + vm.alerts.Register(alerts.Alert{ ID: v.alertID("read"), Severity: alerts.SeverityError, Message: "Failed to read sector", @@ -902,7 +893,7 @@ func (vm *VolumeManager) Write(root types.Hash256, data *[rhp2.SectorSize]byte) // write the sector to the volume if err := vol.WriteSector(data, loc.Index); err != nil { stats := vol.Stats() - vm.a.Register(alerts.Alert{ + vm.alerts.Register(alerts.Alert{ ID: vol.alertID("write"), Severity: alerts.SeverityError, Message: "Failed to write sector", @@ -953,26 +944,34 @@ func (vm *VolumeManager) ResizeCache(size uint32) { vm.cache.Resize(int(size)) } -// ProcessConsensusChange is called when the consensus set changes. -func (vm *VolumeManager) ProcessConsensusChange(cc modules.ConsensusChange) { - vm.mu.Lock() - defer vm.mu.Unlock() - delta := time.Since(vm.lastCleanup) - if delta < cleanupInterval { - return +// ProcessActions processes the actions for the given chain index. +func (vm *VolumeManager) ProcessActions(index types.ChainIndex) error { + done, err := vm.tg.Add() + if err != nil { + return err } - vm.lastCleanup = time.Now() + defer done() - go func() { - log := vm.log.Named("cleanup").With(zap.Uint64("height", uint64(cc.BlockHeight))) - if err := vm.vs.ExpireTempSectors(uint64(cc.BlockHeight)); err != nil { - log.Error("failed to expire temp sectors", zap.Error(err)) - } - }() + return vm.vs.ExpireTempSectors(index.Height) } // NewVolumeManager creates a new VolumeManager. -func NewVolumeManager(vs VolumeStore, a Alerts, cm ChainManager, log *zap.Logger, sectorCacheSize uint32) (*VolumeManager, error) { +func NewVolumeManager(vs VolumeStore, opts ...VolumeManagerOption) (*VolumeManager, error) { + vm := &VolumeManager{ + vs: vs, + + log: zap.NewNop(), + alerts: alerts.NewNop(), + tg: threadgroup.New(), + + volumes: make(map[int64]*volume), + changedVolumes: make(map[int64]bool), + } + + for _, opt := range opts { + opt(vm) + } + // Initialize cache with LRU eviction and a max capacity of 64 cache, err := lru.New[types.Hash256, *[rhp2.SectorSize]byte](64) if err != nil { @@ -980,27 +979,16 @@ func NewVolumeManager(vs VolumeStore, a Alerts, cm ChainManager, log *zap.Logger } // resize the cache, prevents an error in lru.New when initializing the // cache to 0 - cache.Resize(int(sectorCacheSize)) + cache.Resize(int(vm.cacheSize)) + vm.cache = cache - vm := &VolumeManager{ - vs: vs, - a: a, - cm: cm, - log: log, - recorder: §orAccessRecorder{ - store: vs, - log: log.Named("recorder"), - }, - - volumes: make(map[int64]*volume), - changedVolumes: make(map[int64]bool), - cache: cache, - tg: threadgroup.New(), + vm.recorder = §orAccessRecorder{ + store: vs, + log: vm.log.Named("recorder"), } + if err := vm.loadVolumes(); err != nil { return nil, err - } else if err := vm.cm.Subscribe(vm, modules.ConsensusChangeRecent, vm.tg.Done()); err != nil { - return nil, fmt.Errorf("failed to subscribe to consensus set: %w", err) } go vm.recorder.Run(vm.tg.Done()) return vm, nil diff --git a/host/storage/storage_test.go b/host/storage/storage_test.go index b4717929..d4f69b71 100644 --- a/host/storage/storage_test.go +++ b/host/storage/storage_test.go @@ -12,20 +12,12 @@ import ( rhp2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" - "go.sia.tech/hostd/alerts" "go.sia.tech/hostd/host/storage" - "go.sia.tech/hostd/internal/chain" "go.sia.tech/hostd/persist/sqlite" - "go.sia.tech/hostd/webhooks" - "go.sia.tech/siad/modules/consensus" - "go.sia.tech/siad/modules/gateway" - "go.sia.tech/siad/modules/transactionpool" "go.uber.org/zap/zaptest" "lukechampine.com/frand" ) -const sectorCacheSize = 64 - func checkFileSize(fp string, expectedSize int64) error { stat, err := os.Stat(fp) if err != nil { @@ -47,40 +39,7 @@ func TestVolumeLoad(t *testing.T) { } defer db.Close() - g, err := gateway.New(":0", false, filepath.Join(dir, "gateway")) - if err != nil { - t.Fatal(err) - } - defer g.Close() - - cs, errCh := consensus.New(g, false, filepath.Join(dir, "consensus")) - select { - case err := <-errCh: - if err != nil { - t.Fatal(err) - } - default: - } - - tp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - t.Fatal(err) - } - defer tp.Close() - - cm, err := chain.NewManager(cs, chain.NewTPool(tp)) - if err != nil { - t.Fatal(err) - } - defer cm.Close() - - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - vm, err := storage.NewVolumeManager(db, am, cm, log.Named("volumes"), sectorCacheSize) + vm, err := storage.NewVolumeManager(db, storage.WithLogger(log.Named("volumes"))) if err != nil { t.Fatal(err) } @@ -113,7 +72,7 @@ func TestVolumeLoad(t *testing.T) { } // reopen the volume manager - vm, err = storage.NewVolumeManager(db, am, cm, log.Named("volumes"), sectorCacheSize) + vm, err = storage.NewVolumeManager(db, storage.WithLogger(log.Named("volumes"))) if err != nil { t.Fatal(err) } @@ -163,40 +122,7 @@ func TestAddVolume(t *testing.T) { } defer db.Close() - g, err := gateway.New(":0", false, filepath.Join(dir, "gateway")) - if err != nil { - t.Fatal(err) - } - defer g.Close() - - cs, errCh := consensus.New(g, false, filepath.Join(dir, "consensus")) - select { - case err := <-errCh: - if err != nil { - t.Fatal(err) - } - default: - } - - tp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - t.Fatal(err) - } - defer tp.Close() - - cm, err := chain.NewManager(cs, chain.NewTPool(tp)) - if err != nil { - t.Fatal(err) - } - defer cm.Close() - - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - vm, err := storage.NewVolumeManager(db, am, cm, log.Named("volumes"), sectorCacheSize) + vm, err := storage.NewVolumeManager(db, storage.WithLogger(log.Named("volumes"))) if err != nil { t.Fatal(err) } @@ -242,41 +168,8 @@ func TestRemoveVolume(t *testing.T) { } defer db.Close() - g, err := gateway.New(":0", false, filepath.Join(dir, "gateway")) - if err != nil { - t.Fatal(err) - } - defer g.Close() - - cs, errCh := consensus.New(g, false, filepath.Join(dir, "consensus")) - select { - case err := <-errCh: - if err != nil { - t.Fatal(err) - } - default: - } - - tp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - t.Fatal(err) - } - defer tp.Close() - - cm, err := chain.NewManager(cs, chain.NewTPool(tp)) - if err != nil { - t.Fatal(err) - } - defer cm.Close() - // initialize the storage manager - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - vm, err := storage.NewVolumeManager(db, am, cm, log.Named("volumes"), 0) + vm, err := storage.NewVolumeManager(db, storage.WithLogger(log.Named("volumes"))) if err != nil { t.Fatal(err) } @@ -416,41 +309,8 @@ func TestRemoveCorrupt(t *testing.T) { } defer db.Close() - g, err := gateway.New(":0", false, filepath.Join(dir, "gateway")) - if err != nil { - t.Fatal(err) - } - defer g.Close() - - cs, errCh := consensus.New(g, false, filepath.Join(dir, "consensus")) - select { - case err := <-errCh: - if err != nil { - t.Fatal(err) - } - default: - } - - tp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - t.Fatal(err) - } - defer tp.Close() - - cm, err := chain.NewManager(cs, chain.NewTPool(tp)) - if err != nil { - t.Fatal(err) - } - defer cm.Close() - // initialize the storage manager - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - vm, err := storage.NewVolumeManager(db, am, cm, log.Named("volumes"), 0) + vm, err := storage.NewVolumeManager(db, storage.WithLogger(log.Named("volumes"))) if err != nil { t.Fatal(err) } @@ -622,41 +482,8 @@ func TestRemoveMissing(t *testing.T) { } defer db.Close() - g, err := gateway.New(":0", false, filepath.Join(dir, "gateway")) - if err != nil { - t.Fatal(err) - } - defer g.Close() - - cs, errCh := consensus.New(g, false, filepath.Join(dir, "consensus")) - select { - case err := <-errCh: - if err != nil { - t.Fatal(err) - } - default: - } - - tp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - t.Fatal(err) - } - defer tp.Close() - - cm, err := chain.NewManager(cs, chain.NewTPool(tp)) - if err != nil { - t.Fatal(err) - } - defer cm.Close() - // initialize the storage manager - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - vm, err := storage.NewVolumeManager(db, am, cm, log.Named("volumes"), 0) + vm, err := storage.NewVolumeManager(db, storage.WithLogger(log.Named("volumes"))) if err != nil { t.Fatal(err) } @@ -715,7 +542,7 @@ func TestRemoveMissing(t *testing.T) { } // reload the volume manager - vm, err = storage.NewVolumeManager(db, am, cm, log.Named("volumes"), sectorCacheSize) + vm, err = storage.NewVolumeManager(db, storage.WithLogger(log.Named("volumes"))) if err != nil { t.Fatal(err) } @@ -805,41 +632,8 @@ func TestVolumeConcurrency(t *testing.T) { } defer db.Close() - g, err := gateway.New(":0", false, filepath.Join(dir, "gateway")) - if err != nil { - t.Fatal(err) - } - defer g.Close() - - cs, errCh := consensus.New(g, false, filepath.Join(dir, "consensus")) - select { - case err := <-errCh: - if err != nil { - t.Fatal(err) - } - default: - } - - tp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - t.Fatal(err) - } - defer tp.Close() - - cm, err := chain.NewManager(cs, chain.NewTPool(tp)) - if err != nil { - t.Fatal(err) - } - defer cm.Close() - // initialize the storage manager - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - vm, err := storage.NewVolumeManager(db, am, cm, log.Named("volumes"), 0) + vm, err := storage.NewVolumeManager(db, storage.WithLogger(log.Named("volumes"))) if err != nil { t.Fatal(err) } @@ -971,41 +765,8 @@ func TestVolumeGrow(t *testing.T) { } defer db.Close() - g, err := gateway.New(":0", false, filepath.Join(dir, "gateway")) - if err != nil { - t.Fatal(err) - } - defer g.Close() - - cs, errCh := consensus.New(g, false, filepath.Join(dir, "consensus")) - select { - case err := <-errCh: - if err != nil { - t.Fatal(err) - } - default: - } - - tp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - t.Fatal(err) - } - defer tp.Close() - - cm, err := chain.NewManager(cs, chain.NewTPool(tp)) - if err != nil { - t.Fatal(err) - } - defer cm.Close() - // initialize the storage manager - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - vm, err := storage.NewVolumeManager(db, am, cm, log.Named("volumes"), 0) + vm, err := storage.NewVolumeManager(db, storage.WithLogger(log.Named("volumes"))) if err != nil { t.Fatal(err) } @@ -1096,41 +857,8 @@ func TestVolumeShrink(t *testing.T) { } defer db.Close() - g, err := gateway.New(":0", false, filepath.Join(dir, "gateway")) - if err != nil { - t.Fatal(err) - } - defer g.Close() - - cs, errCh := consensus.New(g, false, filepath.Join(dir, "consensus")) - select { - case err := <-errCh: - if err != nil { - t.Fatal(err) - } - default: - } - - tp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - t.Fatal(err) - } - defer tp.Close() - - cm, err := chain.NewManager(cs, chain.NewTPool(tp)) - if err != nil { - t.Fatal(err) - } - defer cm.Close() - // initialize the storage manager - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - vm, err := storage.NewVolumeManager(db, am, cm, log.Named("volumes"), 0) + vm, err := storage.NewVolumeManager(db, storage.WithLogger(log.Named("volumes"))) if err != nil { t.Fatal(err) } @@ -1265,41 +993,8 @@ func TestVolumeManagerReadWrite(t *testing.T) { } defer db.Close() - g, err := gateway.New(":0", false, filepath.Join(dir, "gateway")) - if err != nil { - t.Fatal(err) - } - defer g.Close() - - cs, errCh := consensus.New(g, false, filepath.Join(dir, "consensus")) - select { - case err := <-errCh: - if err != nil { - t.Fatal(err) - } - default: - } - - tp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - t.Fatal(err) - } - defer tp.Close() - - cm, err := chain.NewManager(cs, chain.NewTPool(tp)) - if err != nil { - t.Fatal(err) - } - defer cm.Close() - // initialize the storage manager - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - vm, err := storage.NewVolumeManager(db, am, cm, log.Named("volumes"), 0) + vm, err := storage.NewVolumeManager(db, storage.WithLogger(log.Named("volumes"))) if err != nil { t.Fatal(err) } @@ -1391,41 +1086,8 @@ func TestSectorCache(t *testing.T) { } defer db.Close() - g, err := gateway.New(":0", false, filepath.Join(dir, "gateway")) - if err != nil { - t.Fatal(err) - } - defer g.Close() - - cs, errCh := consensus.New(g, false, filepath.Join(dir, "consensus")) - select { - case err := <-errCh: - if err != nil { - t.Fatal(err) - } - default: - } - - tp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - t.Fatal(err) - } - defer tp.Close() - - cm, err := chain.NewManager(cs, chain.NewTPool(tp)) - if err != nil { - t.Fatal(err) - } - defer cm.Close() - // initialize the storage manager - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - vm, err := storage.NewVolumeManager(db, am, cm, log.Named("volumes"), sectors/2) // cache half the sectors + vm, err := storage.NewVolumeManager(db, storage.WithLogger(log.Named("volumes")), storage.WithCacheSize(sectors/2)) // cache half the sectors if err != nil { t.Fatal(err) } @@ -1529,39 +1191,8 @@ func BenchmarkVolumeManagerWrite(b *testing.B) { } defer db.Close() - g, err := gateway.New(":0", false, filepath.Join(dir, "gateway")) - if err != nil { - b.Fatal(err) - } - defer g.Close() - - cs, errCh := consensus.New(g, false, filepath.Join(dir, "consensus")) - select { - case err := <-errCh: - b.Fatal(err) - default: - } - - tp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - b.Fatal(err) - } - defer tp.Close() - - cm, err := chain.NewManager(cs, chain.NewTPool(tp)) - if err != nil { - b.Fatal(err) - } - defer cm.Close() - // initialize the storage manager - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - b.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - vm, err := storage.NewVolumeManager(db, am, cm, log.Named("volumes"), sectorCacheSize) + vm, err := storage.NewVolumeManager(db, storage.WithLogger(log.Named("volumes"))) if err != nil { b.Fatal(err) } @@ -1611,39 +1242,8 @@ func BenchmarkNewVolume(b *testing.B) { } defer db.Close() - g, err := gateway.New(":0", false, filepath.Join(dir, "gateway")) - if err != nil { - b.Fatal(err) - } - defer g.Close() - - cs, errCh := consensus.New(g, false, filepath.Join(dir, "consensus")) - select { - case err := <-errCh: - b.Fatal(err) - default: - } - - tp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - b.Fatal(err) - } - defer tp.Close() - - cm, err := chain.NewManager(cs, chain.NewTPool(tp)) - if err != nil { - b.Fatal(err) - } - defer cm.Close() - // initialize the storage manager - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - b.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - vm, err := storage.NewVolumeManager(db, am, cm, log.Named("volumes"), sectorCacheSize) + vm, err := storage.NewVolumeManager(db, storage.WithLogger(log.Named("volumes"))) if err != nil { b.Fatal(err) } @@ -1676,39 +1276,8 @@ func BenchmarkVolumeManagerRead(b *testing.B) { } defer db.Close() - g, err := gateway.New(":0", false, filepath.Join(dir, "gateway")) - if err != nil { - b.Fatal(err) - } - defer g.Close() - - cs, errCh := consensus.New(g, false, filepath.Join(dir, "consensus")) - select { - case err := <-errCh: - b.Fatal(err) - default: - } - - tp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - b.Fatal(err) - } - defer tp.Close() - - cm, err := chain.NewManager(cs, chain.NewTPool(tp)) - if err != nil { - b.Fatal(err) - } - defer cm.Close() - // initialize the storage manager - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - b.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - vm, err := storage.NewVolumeManager(db, am, cm, log.Named("volumes"), sectorCacheSize) + vm, err := storage.NewVolumeManager(db, storage.WithLogger(log.Named("volumes"))) if err != nil { b.Fatal(err) } @@ -1760,39 +1329,8 @@ func BenchmarkVolumeRemove(b *testing.B) { } defer db.Close() - g, err := gateway.New(":0", false, filepath.Join(dir, "gateway")) - if err != nil { - b.Fatal(err) - } - defer g.Close() - - cs, errCh := consensus.New(g, false, filepath.Join(dir, "consensus")) - select { - case err := <-errCh: - b.Fatal(err) - default: - } - - tp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - b.Fatal(err) - } - defer tp.Close() - - cm, err := chain.NewManager(cs, chain.NewTPool(tp)) - if err != nil { - b.Fatal(err) - } - defer cm.Close() - // initialize the storage manager - webhookReporter, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - b.Fatal(err) - } - - am := alerts.NewManager(webhookReporter, log.Named("alerts")) - vm, err := storage.NewVolumeManager(db, am, cm, log.Named("volumes"), sectorCacheSize) + vm, err := storage.NewVolumeManager(db, storage.WithLogger(log.Named("volumes"))) if err != nil { b.Fatal(err) } diff --git a/index/manager.go b/index/manager.go new file mode 100644 index 00000000..6b294c0f --- /dev/null +++ b/index/manager.go @@ -0,0 +1,160 @@ +package index + +import ( + "context" + "errors" + "fmt" + "sync" + + "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" + "go.sia.tech/coreutils/wallet" + "go.sia.tech/hostd/host/contracts" + "go.sia.tech/hostd/host/settings" + "go.sia.tech/hostd/internal/threadgroup" + "go.uber.org/zap" +) + +type ( + // A Store is a persistent store for the index manager. + Store interface { + // ResetChainState resets the consensus state of the store. This + // should only occur if the user has reset their consensus database to + // sync from scratch. + ResetChainState() error + UpdateChainState(func(UpdateTx) error) error + Tip() (types.ChainIndex, error) + } + + // A ContractManager manages the lifecycle and state of storage contracts + ContractManager interface { + // UpdateChainState atomically updates the blockchain state of the + // contract manager. + UpdateChainState(tx contracts.UpdateStateTx, reverted []chain.RevertUpdate, applied []chain.ApplyUpdate) error + // ProcessActions is called after the chain state is updated to + // trigger additional actions related to contract lifecycle management. + // This operation should not assume that it will be called for every + // block. During syncing, it will be called at most once per batch. + ProcessActions(index types.ChainIndex) error + } + + // A SettingsManager manages the host's settings and announcements + SettingsManager interface { + // UpdateChainState atomically updates the blockchain state of the + // settings manager. + UpdateChainState(tx settings.UpdateStateTx, reverted []chain.RevertUpdate, applied []chain.ApplyUpdate) error + // ProcessActions is called after the chain state is updated to + // trigger additional actions related to settings management. + // This operation should not assume that it will be called for every + // block. During syncing, it will be called at most once per batch. + ProcessActions(index types.ChainIndex) error + } + + // A WalletManager manages the host's UTXOs and balance + WalletManager interface { + // UpdateChainState atomically updates the blockchain state of the + // wallet manager. + UpdateChainState(tx wallet.UpdateTx, reverted []chain.RevertUpdate, applied []chain.ApplyUpdate) error + } + + // A VolumeManager manages the host's storage volumes + VolumeManager interface { + // ProcessActions is called to trigger additional actions related to + // volume management. This operation should not assume that it will be + // called for every block. During syncing, it will be called at most + // once per batch + ProcessActions(index types.ChainIndex) error + } + + // A Manager manages the state of the blockchain and indexes consensus + // changes. + Manager struct { + updateBatchSize int + + chain *chain.Manager + store Store + + contracts ContractManager + settings SettingsManager + wallet WalletManager + volumes VolumeManager + + tg *threadgroup.ThreadGroup + log *zap.Logger + + mu sync.Mutex // protects the fields below + index types.ChainIndex + } +) + +// Close stops the manager. +func (m *Manager) Close() { + m.tg.Stop() +} + +// Tip returns the current chain index. +func (m *Manager) Tip() types.ChainIndex { + m.mu.Lock() + defer m.mu.Unlock() + return m.index +} + +// NewManager creates a new Manager. +func NewManager(store Store, chain *chain.Manager, contracts ContractManager, wallet WalletManager, settings SettingsManager, volumes VolumeManager, opts ...Option) (*Manager, error) { + index, err := store.Tip() + if err != nil { + return nil, fmt.Errorf("failed to get last indexed tip: %w", err) + } + + m := &Manager{ + updateBatchSize: 100, + + chain: chain, + store: store, + + contracts: contracts, + wallet: wallet, + settings: settings, + volumes: volumes, + + index: index, + + tg: threadgroup.New(), + log: zap.NewNop(), + } + for _, opt := range opts { + opt(m) + } + + reorgCh := make(chan struct{}, 1) + reorgCh <- struct{}{} // trigger initial sync + stop := m.chain.OnReorg(func(index types.ChainIndex) { + select { + case reorgCh <- struct{}{}: + default: + } + }) + + go func() { + defer stop() + + ctx, cancel, err := m.tg.AddContext(context.Background()) + if err != nil { + m.log.Error("failed to add context", zap.Error(err)) + return + } + defer cancel() + + for { + select { + case <-ctx.Done(): + return + case <-reorgCh: + if err := m.syncDB(ctx); err != nil && !errors.Is(err, context.Canceled) { + m.log.Error("failed to sync database", zap.Error(err)) + } + } + } + }() + return m, nil +} diff --git a/index/options.go b/index/options.go new file mode 100644 index 00000000..a98402d8 --- /dev/null +++ b/index/options.go @@ -0,0 +1,20 @@ +package index + +import "go.uber.org/zap" + +// An Option is a functional option for the Manager. +type Option func(*Manager) + +// WithLog sets the logger for the Manager. +func WithLog(l *zap.Logger) Option { + return func(m *Manager) { + m.log = l + } +} + +// WithBatchSize sets the batch size for chain updates. +func WithBatchSize(bs int) Option { + return func(m *Manager) { + m.updateBatchSize = bs + } +} diff --git a/index/update.go b/index/update.go new file mode 100644 index 00000000..32560142 --- /dev/null +++ b/index/update.go @@ -0,0 +1,92 @@ +package index + +import ( + "context" + "fmt" + "strings" + + "go.sia.tech/core/types" + "go.sia.tech/coreutils/wallet" + "go.sia.tech/hostd/host/contracts" + "go.sia.tech/hostd/host/settings" + "go.uber.org/zap" +) + +// An UpdateTx is a transaction that atomically updates the state of the +// index manager. +type UpdateTx interface { + wallet.UpdateTx + contracts.UpdateStateTx + settings.UpdateStateTx + + SetLastIndex(types.ChainIndex) error +} + +func (m *Manager) syncDB(ctx context.Context) error { + log := m.log.Named("sync") + index := m.Tip() + for index != m.chain.Tip() { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + reverted, applied, err := m.chain.UpdatesSince(index, m.updateBatchSize) + if err != nil && strings.Contains(err.Error(), "missing block at index") { + log.Warn("missing block at index, resetting chain state") + // reset the consensus state. Should delete all chain related + // state from the store + if err := m.store.ResetChainState(); err != nil { + return fmt.Errorf("failed to reset consensus state: %w", err) + } + // zero out the index to force a full resync + m.mu.Lock() + m.index = types.ChainIndex{} + m.mu.Unlock() + return nil + } else if err != nil { + return fmt.Errorf("failed to get updates since %v: %w", index, err) + } else if len(reverted) == 0 && len(applied) == 0 { + return nil + } + + err = m.store.UpdateChainState(func(tx UpdateTx) error { + if err := m.wallet.UpdateChainState(tx, reverted, applied); err != nil { + return fmt.Errorf("failed to update wallet state: %w", err) + } else if err := m.contracts.UpdateChainState(tx, reverted, applied); err != nil { + return fmt.Errorf("failed to update contract state: %w", err) + } else if err := m.settings.UpdateChainState(tx, reverted, applied); err != nil { + return fmt.Errorf("failed to update settings state: %w", err) + } + + if len(applied) > 0 { + index = applied[len(applied)-1].State.Index + } else { + index = reverted[len(reverted)-1].State.Index + } + + if err := tx.SetLastIndex(index); err != nil { + return fmt.Errorf("failed to set last index: %w", err) + } + return nil + }) + if err != nil { + return fmt.Errorf("failed to update chain state: %w", err) + } + + if err := m.contracts.ProcessActions(index); err != nil { + return fmt.Errorf("failed to process contract actions: %w", err) + } else if err := m.volumes.ProcessActions(index); err != nil { + return fmt.Errorf("failed to process storage actions: %w", err) + } else if err := m.settings.ProcessActions(index); err != nil { + return fmt.Errorf("failed to process settings actions: %w", err) + } + + m.mu.Lock() + m.index = index + m.mu.Unlock() + log.Debug("synced to new chain index", zap.Stringer("index", index)) + } + return nil +} diff --git a/internal/chain/manager.go b/internal/chain/manager.go deleted file mode 100644 index 209b5dfd..00000000 --- a/internal/chain/manager.go +++ /dev/null @@ -1,186 +0,0 @@ -package chain - -import ( - "bytes" - "errors" - "fmt" - "strings" - "sync" - "time" - - "gitlab.com/NebulousLabs/encoding" - "go.sia.tech/core/consensus" - "go.sia.tech/core/types" - "go.sia.tech/hostd/build" - "go.sia.tech/siad/modules" - stypes "go.sia.tech/siad/types" -) - -const maxSyncTime = time.Hour - -var ( - // ErrBlockNotFound is returned when a block is not found. - ErrBlockNotFound = errors.New("block not found") - // ErrInvalidChangeID is returned to a subscriber when the change id is - // invalid. - ErrInvalidChangeID = errors.New("invalid change id") -) - -func convertToSiad(core types.EncoderTo, siad encoding.SiaUnmarshaler) { - var buf bytes.Buffer - e := types.NewEncoder(&buf) - core.EncodeTo(e) - e.Flush() - if err := siad.UnmarshalSia(&buf); err != nil { - panic(err) - } -} - -func convertToCore(siad encoding.SiaMarshaler, core types.DecoderFrom) { - var buf bytes.Buffer - siad.MarshalSia(&buf) - d := types.NewBufDecoder(buf.Bytes()) - core.DecodeFrom(d) - if d.Err() != nil { - panic(d.Err()) - } -} - -// A Manager manages the current state of the blockchain. -type Manager struct { - cs modules.ConsensusSet - tp *TransactionPool - network *consensus.Network - - close chan struct{} - mu sync.Mutex - tip consensus.State - synced bool -} - -// PoolTransactions returns all transactions in the transaction pool -func (m *Manager) PoolTransactions() []types.Transaction { - return m.tp.Transactions() -} - -// ProcessConsensusChange implements the modules.ConsensusSetSubscriber interface. -func (m *Manager) ProcessConsensusChange(cc modules.ConsensusChange) { - m.mu.Lock() - defer m.mu.Unlock() - m.tip = consensus.State{ - Network: m.network, - Index: types.ChainIndex{ - ID: types.BlockID(cc.AppliedBlocks[len(cc.AppliedBlocks)-1].ID()), - Height: uint64(cc.BlockHeight), - }, - } - m.synced = synced(cc.AppliedBlocks[len(cc.AppliedBlocks)-1].Timestamp) -} - -// Network returns the network name. -func (m *Manager) Network() string { - switch m.network.Name { - case "zen": - return "Zen Testnet" - case "mainnet": - return "Mainnet" - default: - return m.network.Name - } -} - -// Close closes the chain manager. -func (m *Manager) Close() error { - select { - case <-m.close: - return nil - default: - } - close(m.close) - return m.cs.Close() -} - -// Synced returns true if the chain manager is synced with the consensus set. -func (m *Manager) Synced() bool { - m.mu.Lock() - defer m.mu.Unlock() - return m.synced -} - -// BlockAtHeight returns the block at the given height. -func (m *Manager) BlockAtHeight(height uint64) (types.Block, bool) { - sb, ok := m.cs.BlockAtHeight(stypes.BlockHeight(height)) - var c types.Block - convertToCore(sb, (*types.V1Block)(&c)) - return types.Block(c), ok -} - -// IndexAtHeight return the chain index at the given height. -func (m *Manager) IndexAtHeight(height uint64) (types.ChainIndex, error) { - block, ok := m.cs.BlockAtHeight(stypes.BlockHeight(height)) - if !ok { - return types.ChainIndex{}, ErrBlockNotFound - } - return types.ChainIndex{ - ID: types.BlockID(block.ID()), - Height: height, - }, nil -} - -// TipState returns the current chain state. -func (m *Manager) TipState() consensus.State { - m.mu.Lock() - defer m.mu.Unlock() - return m.tip -} - -// AcceptBlock adds b to the consensus set. -func (m *Manager) AcceptBlock(b types.Block) error { - var sb stypes.Block - convertToSiad(types.V1Block(b), &sb) - return m.cs.AcceptBlock(sb) -} - -// Subscribe subscribes to the consensus set. -func (m *Manager) Subscribe(s modules.ConsensusSetSubscriber, ccID modules.ConsensusChangeID, cancel <-chan struct{}) error { - if err := m.cs.ConsensusSetSubscribe(s, ccID, cancel); err != nil { - if strings.Contains(err.Error(), "consensus subscription has invalid id") { - return ErrInvalidChangeID - } - return err - } - return nil -} - -func synced(timestamp stypes.Timestamp) bool { - return time.Since(time.Unix(int64(timestamp), 0)) <= maxSyncTime -} - -// NewManager creates a new chain manager. -func NewManager(cs modules.ConsensusSet, tp *TransactionPool) (*Manager, error) { - height := cs.Height() - block, ok := cs.BlockAtHeight(height) - if !ok { - return nil, fmt.Errorf("failed to get block at height %d", height) - } - n, _ := build.Network() - m := &Manager{ - cs: cs, - tp: tp, - network: n, - tip: consensus.State{ - Network: n, - Index: types.ChainIndex{ - ID: types.BlockID(block.ID()), - Height: uint64(height), - }, - }, - synced: synced(block.Timestamp), - close: make(chan struct{}), - } - - if err := cs.ConsensusSetSubscribe(m, modules.ConsensusChangeRecent, m.close); err != nil { - return nil, fmt.Errorf("failed to subscribe to consensus set: %w", err) - } - return m, nil -} diff --git a/internal/chain/tpool.go b/internal/chain/tpool.go deleted file mode 100644 index 698c9d62..00000000 --- a/internal/chain/tpool.go +++ /dev/null @@ -1,76 +0,0 @@ -package chain - -import ( - "go.sia.tech/core/types" - "go.sia.tech/siad/modules" - stypes "go.sia.tech/siad/types" -) - -// TransactionPool wraps the siad transaction pool with a more convenient API. -type TransactionPool struct { - tp modules.TransactionPool -} - -// RecommendedFee returns the recommended fee per byte. -func (tp *TransactionPool) RecommendedFee() (fee types.Currency) { - _, maxFee := tp.tp.FeeEstimation() - convertToCore(&maxFee, (*types.V1Currency)(&fee)) - return -} - -// Transactions returns the transactions in the transaction pool. -func (tp *TransactionPool) Transactions() []types.Transaction { - stxns := tp.tp.Transactions() - txns := make([]types.Transaction, len(stxns)) - for i := range txns { - convertToCore(&stxns[i], &txns[i]) - } - return txns -} - -// AcceptTransactionSet adds a transaction set to the tpool and broadcasts it to -// the network. -func (tp *TransactionPool) AcceptTransactionSet(txns []types.Transaction) error { - stxns := make([]stypes.Transaction, len(txns)) - for i := range stxns { - convertToSiad(&txns[i], &stxns[i]) - } - return tp.tp.AcceptTransactionSet(stxns) -} - -// UnconfirmedParents returns the unconfirmed parents of a transaction. -func (tp *TransactionPool) UnconfirmedParents(txn types.Transaction) ([]types.Transaction, error) { - pool := tp.Transactions() - outputToParent := make(map[types.SiacoinOutputID]*types.Transaction) - for i, txn := range pool { - for j := range txn.SiacoinOutputs { - outputToParent[txn.SiacoinOutputID(j)] = &pool[i] - } - } - var parents []types.Transaction - seen := make(map[types.TransactionID]bool) - for _, sci := range txn.SiacoinInputs { - if parent, ok := outputToParent[sci.ParentID]; ok { - if txid := parent.ID(); !seen[txid] { - seen[txid] = true - parents = append(parents, *parent) - } - } - } - return parents, nil -} - -// Subscribe subscribes to the transaction pool. -func (tp *TransactionPool) Subscribe(s modules.TransactionPoolSubscriber) { - tp.tp.TransactionPoolSubscribe(s) -} - -// Close closes the transaction pool. -func (tp *TransactionPool) Close() error { - return tp.tp.Close() -} - -// NewTPool wraps a siad transaction pool with a more convenient API. -func NewTPool(tp modules.TransactionPool) *TransactionPool { - return &TransactionPool{tp} -} diff --git a/internal/test/host.go b/internal/test/host.go deleted file mode 100644 index 18f79e52..00000000 --- a/internal/test/host.go +++ /dev/null @@ -1,288 +0,0 @@ -package test - -import ( - "context" - "fmt" - "net" - "net/http" - "path/filepath" - "time" - - crhp2 "go.sia.tech/core/rhp/v2" - crhp3 "go.sia.tech/core/rhp/v3" - "go.sia.tech/core/types" - "go.sia.tech/hostd/alerts" - "go.sia.tech/hostd/host/accounts" - "go.sia.tech/hostd/host/contracts" - "go.sia.tech/hostd/host/registry" - "go.sia.tech/hostd/host/settings" - "go.sia.tech/hostd/host/storage" - "go.sia.tech/hostd/persist/sqlite" - "go.sia.tech/hostd/rhp" - rhp2 "go.sia.tech/hostd/rhp/v2" - rhp3 "go.sia.tech/hostd/rhp/v3" - "go.sia.tech/hostd/wallet" - "go.sia.tech/hostd/webhooks" - "go.uber.org/zap" -) - -const blocksPerMonth = 144 * 30 - -type stubDataMonitor struct{} - -func (stubDataMonitor) ReadBytes(n int) {} -func (stubDataMonitor) WriteBytes(n int) {} - -// A Host is an ephemeral host that can be used for testing. -type Host struct { - *Node - - privKey types.PrivateKey - store *sqlite.Store - log *zap.Logger - wallet *wallet.SingleAddressWallet - settings *settings.ConfigManager - storage *storage.VolumeManager - registry *registry.Manager - accounts *accounts.AccountManager - contracts *contracts.ContractManager - - rhp2 *rhp2.SessionHandler - rhp3 *rhp3.SessionHandler - rhp3WS net.Listener -} - -// DefaultSettings returns the default settings for the test host -var DefaultSettings = settings.Settings{ - AcceptingContracts: true, - MaxContractDuration: blocksPerMonth * 3, - WindowSize: 144, - MaxCollateral: types.Siacoins(5000), - - ContractPrice: types.Siacoins(1).Div64(4), - - BaseRPCPrice: types.NewCurrency64(100), - SectorAccessPrice: types.NewCurrency64(100), - - CollateralMultiplier: 2.0, - StoragePrice: types.Siacoins(100).Div64(1e12).Div64(blocksPerMonth), - EgressPrice: types.Siacoins(100).Div64(1e12), - IngressPrice: types.Siacoins(100).Div64(1e12), - - PriceTableValidity: 2 * time.Minute, - - AccountExpiry: 30 * 24 * time.Hour, // 1 month - MaxAccountBalance: types.Siacoins(10), - SectorCacheSize: 64, -} - -// Close shutsdown the host -func (h *Host) Close() error { - h.rhp3WS.Close() - h.rhp2.Close() - h.rhp3.Close() - h.settings.Close() - h.wallet.Close() - h.contracts.Close() - h.registry.Close() - h.storage.Close() - h.store.Close() - h.Node.Close() - h.log.Sync() - return nil -} - -// RHP2Addr returns the address of the rhp2 listener -func (h *Host) RHP2Addr() string { - return h.rhp2.LocalAddr() -} - -// RHP3Addr returns the address of the rhp3 listener -func (h *Host) RHP3Addr() string { - return h.rhp3.LocalAddr() -} - -// RHP3WSAddr returns the address of the rhp3 WebSocket listener -func (h *Host) RHP3WSAddr() string { - return h.rhp3WS.Addr().String() -} - -// AddVolume adds a new volume to the host -func (h *Host) AddVolume(path string, size uint64) error { - result := make(chan error, 1) - if _, err := h.storage.AddVolume(context.Background(), path, size, result); err != nil { - return err - } - return <-result -} - -// UpdateSettings updates the host's configuration -func (h *Host) UpdateSettings(settings settings.Settings) error { - return h.settings.UpdateSettings(settings) -} - -// RHP2Settings returns the host's current rhp2 settings -func (h *Host) RHP2Settings() (crhp2.HostSettings, error) { - return h.rhp2.Settings() -} - -// RHP3PriceTable returns the host's current rhp3 price table -func (h *Host) RHP3PriceTable() (crhp3.HostPriceTable, error) { - return h.rhp3.PriceTable() -} - -// WalletAddress returns the host's wallet address -func (h *Host) WalletAddress() types.Address { - return h.wallet.Address() -} - -// Contracts returns the host's contract manager -func (h *Host) Contracts() *contracts.ContractManager { - return h.contracts -} - -// Storage returns the host's storage manager -func (h *Host) Storage() *storage.VolumeManager { - return h.storage -} - -// Settings returns the host's settings manager -func (h *Host) Settings() *settings.ConfigManager { - return h.settings -} - -// PublicKey returns the host's public key -func (h *Host) PublicKey() types.PublicKey { - return h.privKey.PublicKey() -} - -// Accounts returns the host's account manager -func (h *Host) Accounts() *accounts.AccountManager { - return h.accounts -} - -// Store returns the host's database -func (h *Host) Store() *sqlite.Store { - return h.store -} - -// NewHost initializes a new test host -func NewHost(privKey types.PrivateKey, dir string, node *Node, log *zap.Logger) (*Host, error) { - host, err := NewEmptyHost(privKey, dir, node, log) - if err != nil { - return nil, err - } - - result := make(chan error, 1) - if _, err := host.Storage().AddVolume(context.Background(), filepath.Join(dir, "storage.dat"), 64, result); err != nil { - return nil, fmt.Errorf("failed to add storage volume: %w", err) - } else if err := <-result; err != nil { - return nil, fmt.Errorf("failed to add storage volume: %w", err) - } - s := DefaultSettings - s.NetAddress = host.RHP2Addr() - if err := host.Settings().UpdateSettings(s); err != nil { - return nil, fmt.Errorf("failed to update host settings: %w", err) - } - return host, nil -} - -// NewEmptyHost initializes a new test host -func NewEmptyHost(privKey types.PrivateKey, dir string, node *Node, log *zap.Logger) (*Host, error) { - db, err := sqlite.OpenDatabase(filepath.Join(dir, "hostd.db"), log.Named("sqlite")) - if err != nil { - return nil, fmt.Errorf("failed to create sql store: %w", err) - } - - wallet, err := wallet.NewSingleAddressWallet(privKey, node.cm, db, log.Named("wallet")) - if err != nil { - return nil, fmt.Errorf("failed to create wallet: %w", err) - } - - wr, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - return nil, fmt.Errorf("failed to create webhook reporter: %w", err) - } - - am := alerts.NewManager(wr, log.Named("alerts")) - storage, err := storage.NewVolumeManager(db, am, node.cm, log.Named("storage"), DefaultSettings.SectorCacheSize) - if err != nil { - return nil, fmt.Errorf("failed to create storage manager: %w", err) - } - - contracts, err := contracts.NewManager(db, am, storage, node.cm, node.tp, wallet, log.Named("contracts")) - if err != nil { - return nil, fmt.Errorf("failed to create contract manager: %w", err) - } - - rhp2Listener, err := net.Listen("tcp", "localhost:0") - if err != nil { - return nil, fmt.Errorf("failed to create rhp2 listener: %w", err) - } - - rhp3Listener, err := net.Listen("tcp", "localhost:0") - if err != nil { - return nil, fmt.Errorf("failed to create rhp2 listener: %w", err) - } - - settings, err := settings.NewConfigManager(settings.WithHostKey(privKey), - settings.WithStore(db), - settings.WithChainManager(node.ChainManager()), - settings.WithTransactionPool(node.TPool()), - settings.WithWallet(wallet), - settings.WithAlertManager(am), - settings.WithLog(log.Named("settings"))) - if err != nil { - return nil, fmt.Errorf("failed to create settings manager: %w", err) - } - - registry := registry.NewManager(privKey, db, log.Named("registry")) - accounts := accounts.NewManager(db, settings) - - sessions := rhp.NewSessionReporter() - - rhp2, err := rhp2.NewSessionHandler(rhp2Listener, privKey, rhp3Listener.Addr().String(), node.cm, node.tp, wallet, contracts, settings, storage, stubDataMonitor{}, sessions, log.Named("rhp2")) - if err != nil { - return nil, fmt.Errorf("failed to create rhp2 session handler: %w", err) - } - go rhp2.Serve() - - rhp3, err := rhp3.NewSessionHandler(rhp3Listener, privKey, node.cm, node.tp, wallet, accounts, contracts, registry, storage, settings, stubDataMonitor{}, sessions, log.Named("rhp3")) - if err != nil { - return nil, fmt.Errorf("failed to create rhp3 session handler: %w", err) - } - go rhp3.Serve() - - rhp3WSListener, err := net.Listen("tcp", "localhost:0") - if err != nil { - return nil, fmt.Errorf("failed to create rhp3 websocket listener: %w", err) - } - - go func() { - rhp3WS := http.Server{ - Handler: rhp3.WebSocketHandler(), - ReadTimeout: 30 * time.Second, - } - - if err := rhp3WS.Serve(rhp3WSListener); err != nil { - return - } - }() - - return &Host{ - Node: node, - privKey: privKey, - store: db, - log: log, - wallet: wallet, - settings: settings, - storage: storage, - registry: registry, - accounts: accounts, - contracts: contracts, - - rhp2: rhp2, - rhp3: rhp3, - rhp3WS: rhp3WSListener, - }, nil -} diff --git a/internal/test/miner.go b/internal/test/miner.go deleted file mode 100644 index 374f3c0c..00000000 --- a/internal/test/miner.go +++ /dev/null @@ -1,150 +0,0 @@ -package test - -import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "sync" - "time" - - "go.sia.tech/core/types" - "go.sia.tech/siad/crypto" - "go.sia.tech/siad/modules" - stypes "go.sia.tech/siad/types" - "lukechampine.com/frand" -) - -const solveAttempts = 1e4 - -type ( - // Consensus defines a minimal interface needed by the miner to interact - // with the consensus set - Consensus interface { - AcceptBlock(types.Block) error - } - - // A Miner is a CPU miner that can mine blocks, sending the reward to a - // specified address. - Miner struct { - consensus Consensus - - mu sync.Mutex - height stypes.BlockHeight - target stypes.Target - currentBlockID stypes.BlockID - txnsets map[modules.TransactionSetID][]stypes.TransactionID - transactions []stypes.Transaction - } -) - -var errFailedToSolve = errors.New("failed to solve block") - -// ProcessConsensusChange implements modules.ConsensusSetSubscriber. -func (m *Miner) ProcessConsensusChange(cc modules.ConsensusChange) { - m.mu.Lock() - defer m.mu.Unlock() - m.target = cc.ChildTarget - m.currentBlockID = cc.AppliedBlocks[len(cc.AppliedBlocks)-1].ID() - m.height = cc.BlockHeight -} - -// ReceiveUpdatedUnconfirmedTransactions implements modules.TransactionPoolSubscriber -func (m *Miner) ReceiveUpdatedUnconfirmedTransactions(diff *modules.TransactionPoolDiff) { - m.mu.Lock() - defer m.mu.Unlock() - - reverted := make(map[stypes.TransactionID]bool) - for _, setID := range diff.RevertedTransactions { - for _, txnID := range m.txnsets[setID] { - reverted[txnID] = true - } - } - - filtered := m.transactions[:0] - for _, txn := range m.transactions { - if reverted[txn.ID()] { - continue - } - filtered = append(filtered, txn) - } - - for _, txnset := range diff.AppliedTransactions { - m.txnsets[txnset.ID] = txnset.IDs - filtered = append(filtered, txnset.Transactions...) - } - m.transactions = filtered -} - -// mineBlock attempts to mine a block and add it to the consensus set. -func (m *Miner) mineBlock(addr stypes.UnlockHash) error { - m.mu.Lock() - block := stypes.Block{ - ParentID: m.currentBlockID, - Timestamp: stypes.CurrentTimestamp(), - } - - randBytes := frand.Bytes(stypes.SpecifierLen) - randTxn := stypes.Transaction{ - ArbitraryData: [][]byte{append(modules.PrefixNonSia[:], randBytes...)}, - } - block.Transactions = append([]stypes.Transaction{randTxn}, m.transactions...) - block.MinerPayouts = append(block.MinerPayouts, stypes.SiacoinOutput{ - Value: block.CalculateSubsidy(m.height + 1), - UnlockHash: addr, - }) - target := m.target - m.mu.Unlock() - - merkleRoot := block.MerkleRoot() - header := make([]byte, 80) - copy(header, block.ParentID[:]) - binary.LittleEndian.PutUint64(header[40:48], uint64(block.Timestamp)) - copy(header[48:], merkleRoot[:]) - - var nonce uint64 - var solved bool - for i := 0; i < solveAttempts; i++ { - id := crypto.HashBytes(header) - if bytes.Compare(target[:], id[:]) >= 0 { - block.Nonce = *(*stypes.BlockNonce)(header[32:40]) - solved = true - break - } - binary.LittleEndian.PutUint64(header[32:], nonce) - nonce += stypes.ASICHardforkFactor - } - if !solved { - return errFailedToSolve - } - - var b types.Block - convertToCore(&block, (*types.V1Block)(&b)) - if err := m.consensus.AcceptBlock(b); err != nil { - return fmt.Errorf("failed to get block accepted: %w", err) - } - return nil -} - -// Mine mines n blocks, sending the reward to addr -func (m *Miner) Mine(addr types.Address, n int) error { - var err error - for mined := 1; mined <= n; { - // return the error only if the miner failed to solve the block, - // ignore any consensus related errors - if err = m.mineBlock(stypes.UnlockHash(addr)); errors.Is(err, errFailedToSolve) { - return fmt.Errorf("failed to mine block %v: %w", mined, errFailedToSolve) - } - mined++ - time.Sleep(time.Millisecond) - } - return nil -} - -// NewMiner initializes a new CPU miner -func NewMiner(consensus Consensus) *Miner { - return &Miner{ - consensus: consensus, - txnsets: make(map[modules.TransactionSetID][]stypes.TransactionID), - } -} diff --git a/internal/test/miner_test.go b/internal/test/miner_test.go deleted file mode 100644 index ea45fa18..00000000 --- a/internal/test/miner_test.go +++ /dev/null @@ -1,273 +0,0 @@ -//go:build ignore - -package test - -import ( - "encoding/json" - "fmt" - "path/filepath" - "sync" - "testing" - - "go.sia.tech/core/consensus" - "go.sia.tech/core/types" - "go.sia.tech/hostd/wallet" - "go.sia.tech/siad/modules" - mconsensus "go.sia.tech/siad/modules/consensus" - "go.sia.tech/siad/modules/gateway" - "go.sia.tech/siad/modules/transactionpool" - stypes "go.sia.tech/siad/types" - "lukechampine.com/frand" -) - -// A siacoinElement groups a SiacoinOutput and its ID together -type siacoinElement struct { - types.SiacoinOutput - ID types.SiacoinOutputID -} - -// A testWallet provides very basic wallet functionality for testing the miner -type testWallet struct { - priv types.PrivateKey - - mu sync.Mutex - height uint64 - spent map[types.SiacoinOutputID]bool - spendable map[types.SiacoinOutputID]siacoinElement -} - -// UnlockConditions is a helper to return the standard unlock conditions using -// the wallet's private key -func (tw *testWallet) UnlockConditions() types.UnlockConditions { - return wallet.StandardUnlockConditions(tw.priv.PublicKey()) -} - -// Address returns the address of the wallet -func (tw *testWallet) Address() types.Address { - return tw.UnlockConditions().UnlockHash() -} - -// Balance returns the balance of the wallet -func (tw *testWallet) Balance() types.Currency { - tw.mu.Lock() - defer tw.mu.Unlock() - - var balance types.Currency - for _, sco := range tw.spendable { - if tw.spent[sco.ID] { - continue - } - balance = balance.Add(sco.Value) - } - return balance -} - -// FundTransaction adds siacoin inputs worth at least amount to the provided -// transaction. If necessary, a change output will also be added. The inputs -// will not be used again until release is called. -func (tw *testWallet) FundAndSignTransaction(txn *types.Transaction, amount types.Currency) (func(), error) { - tw.mu.Lock() - defer tw.mu.Unlock() - if amount.IsZero() { - return func() {}, nil - } - - var added types.Currency - var spent []siacoinElement - for _, sco := range tw.spendable { - if tw.spent[sco.ID] { - continue - } - - spent = append(spent, sco) - added = added.Add(sco.Value) - if added.Cmp(amount) >= 0 { - break - } - } - // check if the sum of the inputs is greater than the fund amount - if added.Cmp(amount) < 0 { - return nil, fmt.Errorf("not enough funds") - } else if added.Cmp(amount) > 0 { - // add a change output - txn.SiacoinOutputs = append(txn.SiacoinOutputs, types.SiacoinOutput{ - Value: added.Sub(amount), - Address: tw.Address(), - }) - } - - n := len(txn.Signatures) - - // add the spent outputs and signatures to the transaction - for _, sce := range spent { - tw.spent[sce.ID] = true - txn.SiacoinInputs = append(txn.SiacoinInputs, types.SiacoinInput{ - ParentID: sce.ID, - UnlockConditions: tw.UnlockConditions(), - }) - txn.Signatures = append(txn.Signatures, types.TransactionSignature{ - ParentID: types.Hash256(sce.ID), - CoveredFields: types.CoveredFields{WholeTransaction: true}, - }) - } - - // sign all added signatures - for i := n; i < len(txn.Signatures); i++ { - cs := consensus.State{Index: types.ChainIndex{Height: tw.height}} - sig := tw.priv.SignHash(cs.WholeSigHash(*txn, txn.Signatures[i].ParentID, 0, 0, nil)) - txn.Signatures[i].Signature = sig[:] - } - - return func() { - tw.mu.Lock() - defer tw.mu.Unlock() - for _, sce := range spent { - delete(tw.spent, sce.ID) - } - }, nil -} - -// ProcessConsensusChange processes a consensus change - adding new outputs to -// the wallet and removing spent outputs -func (tw *testWallet) ProcessConsensusChange(cc modules.ConsensusChange) { - tw.mu.Lock() - defer tw.mu.Unlock() - - for _, scod := range cc.SiacoinOutputDiffs { - if scod.Direction == modules.DiffApply && types.Address(scod.SiacoinOutput.UnlockHash) == tw.Address() { - var sco types.SiacoinOutput - convertToCore(scod.SiacoinOutput, &sco) - tw.spendable[types.SiacoinOutputID(scod.ID)] = siacoinElement{ - ID: types.SiacoinOutputID(scod.ID), - SiacoinOutput: sco, - } - } else { - delete(tw.spendable, types.SiacoinOutputID(scod.ID)) - } - } - - tw.height = uint64(cc.BlockHeight) -} - -// newTestWallet returns a new test wallet with a random private key. -func newTestWallet() *testWallet { - return &testWallet{ - priv: types.GeneratePrivateKey(), - spent: make(map[types.SiacoinOutputID]bool), - spendable: make(map[types.SiacoinOutputID]siacoinElement), - } -} - -// TestMining tests that cpu mining works as expected and adds spendable outputs -// to the wallet. -func TestMining(t *testing.T) { - dir := t.TempDir() - - g, err := gateway.New("localhost:0", false, filepath.Join(dir, modules.GatewayDir)) - if err != nil { - t.Fatal("could not create gateway:", err) - } - t.Cleanup(func() { g.Close() }) - - cs, errChan := mconsensus.New(g, false, filepath.Join(dir, modules.ConsensusDir)) - if err := <-errChan; err != nil { - t.Fatal("could not create consensus set:", err) - } - go func() { - for err := range errChan { - panic(fmt.Errorf("consensus err: %w", err)) - } - }() - t.Cleanup(func() { cs.Close() }) - - tp, err := transactionpool.New(cs, g, filepath.Join(dir, modules.TransactionPoolDir)) - if err != nil { - t.Fatal("could not create tpool:", err) - } - t.Cleanup(func() { tp.Close() }) - - w := newTestWallet() - if err := cs.ConsensusSetSubscribe(w, modules.ConsensusChangeBeginning, nil); err != nil { - t.Fatal("failed to subscribe to consensus set:", err) - } - - m := NewMiner(cs) - if err := cs.ConsensusSetSubscribe(m, modules.ConsensusChangeBeginning, nil); err != nil { - t.Fatal("failed to subscribe to consensus set:", err) - } - tp.TransactionPoolSubscribe(m) - - // mine a single block - if err := m.Mine(w.Address(), 1); err != nil { - t.Fatal(err) - } - - // make sure the block height is updated - if height := cs.Height(); height != 1 { - t.Fatalf("expected height 1, got %v", height) - } - - // mine until the maturity height of the first payout is reached - if err := m.Mine(w.Address(), int(stypes.MaturityDelay)); err != nil { - t.Fatal(err) - } else if height := cs.Height(); height != stypes.MaturityDelay+1 { - t.Fatalf("expected height %v, got %v", stypes.MaturityDelay+1, height) - } - - // make sure we have the expected balance - siadExpectedBalance := stypes.CalculateCoinbase(1) - var expectedBalance types.Currency - convertToCore(siadExpectedBalance, &expectedBalance) - if balance := w.Balance(); !balance.Equals(expectedBalance) { - t.Fatalf("expected balance to be %v, got %v", expectedBalance, balance) - } - - // mine more blocks until we have lots of outputs - if err := m.Mine(w.Address(), 100); err != nil { - t.Fatal(err) - } - - // add random transactions to the tpool - added := make([]types.TransactionID, 100) - for i := range added { - amount := types.Siacoins(uint32(1 + frand.Intn(1000))) - txn := types.Transaction{ - ArbitraryData: [][]byte{append(modules.PrefixNonSia[:], frand.Bytes(16)...)}, - SiacoinOutputs: []types.SiacoinOutput{ - {Value: amount}, - }, - } - - release, err := w.FundAndSignTransaction(&txn, amount) - if err != nil { - t.Fatal(err) - } - defer release() - - if err := tp.AcceptTransactionSet([]types.Transaction{txn}); err != nil { - buf, _ := json.MarshalIndent(txn, "", " ") - t.Log(string(buf)) - t.Fatalf("failed to accept transaction %v: %v", i, err) - } - - added[i] = txn.ID() - } - - // mine a block to confirm the transactions - if err := m.Mine(w.Address(), 1); err != nil { - t.Fatal(err) - } - - // check that the correct number of transactions are in the block. A random - // transaction is added before all others. - block := cs.CurrentBlock() - if len(block.Transactions) != len(added)+1 { - t.Fatalf("expected %v transactions, got %v", len(added), len(block.Transactions)) - } - // the first transaction in the block should be ignored - for i, txn := range block.Transactions[1:] { - if txn.ID() != added[i] { - t.Fatalf("transaction %v expected ID %v, got %v", i, added[i], txn.ID()) - } - } -} diff --git a/internal/test/node.go b/internal/test/node.go deleted file mode 100644 index fba4551a..00000000 --- a/internal/test/node.go +++ /dev/null @@ -1,159 +0,0 @@ -package test - -import ( - "bytes" - "fmt" - "os" - "path/filepath" - "time" - - "gitlab.com/NebulousLabs/encoding" - "go.sia.tech/core/consensus" - "go.sia.tech/core/types" - "go.sia.tech/hostd/internal/chain" - "go.sia.tech/siad/modules" - mconsensus "go.sia.tech/siad/modules/consensus" - "go.sia.tech/siad/modules/gateway" - "go.sia.tech/siad/modules/transactionpool" - stypes "go.sia.tech/siad/types" - "go.uber.org/zap" -) - -func convertToCore(siad encoding.SiaMarshaler, core types.DecoderFrom) { - var buf bytes.Buffer - siad.MarshalSia(&buf) - d := types.NewBufDecoder(buf.Bytes()) - core.DecodeFrom(d) - if d.Err() != nil { - panic(d.Err()) - } -} - -type ( - // A Node is a base Sia node that can be extended by a Renter or Host - Node struct { - g modules.Gateway - cs modules.ConsensusSet - cm *chain.Manager - tp *chain.TransactionPool - m *Miner - } -) - -// Close closes the node -func (n *Node) Close() error { - n.tp.Close() - n.cs.Close() - n.g.Close() - return nil -} - -// GatewayAddr returns the address of the gateway -func (n *Node) GatewayAddr() string { - return string(n.g.Address()) -} - -// ConnectPeer connects the host's gateway to a peer -func (n *Node) ConnectPeer(addr string) error { - return n.g.Connect(modules.NetAddress(addr)) -} - -// TipState returns the current consensus state. -func (n *Node) TipState() consensus.State { - return n.cm.TipState() -} - -// MineBlocks mines n blocks sending the reward to address -func (n *Node) MineBlocks(address types.Address, count int) error { - return n.m.Mine(address, count) -} - -// ChainManager returns the chain manager -func (n *Node) ChainManager() *chain.Manager { - return n.cm -} - -// TPool returns the transaction pool -func (n *Node) TPool() *chain.TransactionPool { - return n.tp -} - -// NewNode creates a new Sia node and wallet with the given key -func NewNode(dir string) (*Node, error) { - g, err := gateway.New("localhost:0", false, filepath.Join(dir, "gateway")) - if err != nil { - return nil, fmt.Errorf("failed to create gateway: %w", err) - } - cs, errCh := mconsensus.New(g, false, filepath.Join(dir, "consensus")) - if err := <-errCh; err != nil { - return nil, fmt.Errorf("failed to create consensus set: %w", err) - } - - tp, err := transactionpool.New(cs, g, filepath.Join(dir, "transactionpool")) - if err != nil { - return nil, fmt.Errorf("failed to create transaction pool: %w", err) - } - ctp := chain.NewTPool(tp) - - cm, err := chain.NewManager(cs, ctp) - if err != nil { - return nil, err - } - - m := NewMiner(cm) - if err := cs.ConsensusSetSubscribe(m, modules.ConsensusChangeBeginning, nil); err != nil { - return nil, fmt.Errorf("failed to subscribe miner to consensus set: %w", err) - } - tp.TransactionPoolSubscribe(m) - return &Node{ - g: g, - cs: cs, - cm: cm, - tp: ctp, - m: m, - }, nil -} - -// NewTestingPair creates a new renter and host pair, connects them to each -// other, and funds both wallets. -func NewTestingPair(dir string, log *zap.Logger) (*Renter, *Host, error) { - hostKey, renterKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() - - node, err := NewNode(dir) - if err != nil { - return nil, nil, fmt.Errorf("failed to create node: %w", err) - } - - if err := os.MkdirAll(filepath.Join(dir, "host"), 0700); err != nil { - return nil, nil, fmt.Errorf("failed to create host dir: %w", err) - } else if err := os.MkdirAll(filepath.Join(dir, "renter"), 0700); err != nil { - return nil, nil, fmt.Errorf("failed to create renter dir: %w", err) - } - - // initialize the host - host, err := NewHost(hostKey, filepath.Join(dir, "host"), node, log.Named("host")) - if err != nil { - return nil, nil, fmt.Errorf("failed to create host: %w", err) - } - - // initialize the renter - renter, err := NewRenter(renterKey, filepath.Join(dir, "renter"), node, log.Named("renter")) - if err != nil { - return nil, nil, fmt.Errorf("failed to create renter: %w", err) - } - - // mine enough blocks to fund the host's wallet - if err := host.MineBlocks(host.WalletAddress(), int(stypes.MaturityDelay)*2); err != nil { - return nil, nil, fmt.Errorf("failed to mine blocks: %w", err) - } - // small sleep for synchronization - time.Sleep(time.Second) - - // mine enough blocks to fund the renter's wallet - if err := renter.MineBlocks(renter.WalletAddress(), int(stypes.MaturityDelay)*2); err != nil { - return nil, nil, fmt.Errorf("failed to mine blocks: %w", err) - } - // small sleep for synchronization - time.Sleep(time.Second) - return renter, host, nil -} diff --git a/internal/test/renter.go b/internal/test/renter.go deleted file mode 100644 index 076512fe..00000000 --- a/internal/test/renter.go +++ /dev/null @@ -1,220 +0,0 @@ -package test - -import ( - "context" - "fmt" - "net" - "path/filepath" - "time" - - crhp2 "go.sia.tech/core/rhp/v2" - "go.sia.tech/core/types" - rhp2 "go.sia.tech/hostd/internal/test/rhp/v2" - rhp3 "go.sia.tech/hostd/internal/test/rhp/v3" - "go.sia.tech/hostd/persist/sqlite" - "go.sia.tech/hostd/wallet" - "go.uber.org/zap" -) - -type ( - // A Renter is an ephemeral renter that can be used for testing - Renter struct { - *Node - - privKey types.PrivateKey - store *sqlite.Store - log *zap.Logger - wallet *wallet.SingleAddressWallet - } -) - -// Close shutsdown the renter -func (r *Renter) Close() error { - r.wallet.Close() - r.store.Close() - r.log.Sync() - r.Node.Close() - return nil -} - -// PrivateKey returns the renter's private key -func (r *Renter) PrivateKey() types.PrivateKey { - return r.privKey -} - -// Wallet returns the renter's wallet -func (r *Renter) Wallet() *wallet.SingleAddressWallet { - return r.wallet -} - -// NewRHP2Session creates a new session, locks a contract, and retrieves the -// host's settings -func (r *Renter) NewRHP2Session(ctx context.Context, hostAddr string, hostKey types.PublicKey, contractID types.FileContractID) (*rhp2.RHP2Session, error) { - t, err := dialTransport(ctx, hostAddr, hostKey) - if err != nil { - return nil, err - } - - session := rhp2.NewRHP2Session(t, r.privKey, crhp2.ContractRevision{}, crhp2.HostSettings{}) - - ctx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - if err := session.Refresh(ctx, 15*time.Second, r.privKey, contractID); err != nil { - return nil, fmt.Errorf("failed to refresh session: %w", err) - } - return session, nil -} - -// NewRHP3Session creates a new session -func (r *Renter) NewRHP3Session(ctx context.Context, hostAddr string, hostKey types.PublicKey) (*rhp3.Session, error) { - return rhp3.NewSession(ctx, hostKey, hostAddr, r.ChainManager(), r.Wallet()) -} - -// Settings returns the host's current settings -func (r *Renter) Settings(ctx context.Context, hostAddr string, hostKey types.PublicKey) (crhp2.HostSettings, error) { - t, err := dialTransport(ctx, hostAddr, hostKey) - if err != nil { - return crhp2.HostSettings{}, fmt.Errorf("failed to create session: %w", err) - } - defer t.Close() - settings, err := rhp2.RPCSettings(ctx, t) - if err != nil { - return crhp2.HostSettings{}, fmt.Errorf("failed to get settings: %w", err) - } - return settings, nil -} - -// FormContract forms a contract with the host -func (r *Renter) FormContract(ctx context.Context, hostAddr string, hostKey types.PublicKey, renterPayout, hostCollateral types.Currency, duration uint64) (crhp2.ContractRevision, error) { - t, err := dialTransport(ctx, hostAddr, hostKey) - if err != nil { - return crhp2.ContractRevision{}, fmt.Errorf("failed to dial transport: %w", err) - } - defer t.Close() - settings, err := rhp2.RPCSettings(ctx, t) - if err != nil { - return crhp2.ContractRevision{}, fmt.Errorf("failed to get host settings: %w", err) - } - cs := r.TipState() - contract := crhp2.PrepareContractFormation(r.privKey.PublicKey(), hostKey, renterPayout, hostCollateral, cs.Index.Height+duration, settings, r.WalletAddress()) - formationCost := crhp2.ContractFormationCost(cs, contract, settings.ContractPrice) - feeEstimate := r.TPool().RecommendedFee().Mul64(2000) - formationTxn := types.Transaction{ - MinerFees: []types.Currency{feeEstimate}, - FileContracts: []types.FileContract{contract}, - } - fundAmount := formationCost.Add(feeEstimate) - - toSign, release, err := r.wallet.FundTransaction(&formationTxn, fundAmount) - if err != nil { - return crhp2.ContractRevision{}, fmt.Errorf("failed to fund transaction: %w", err) - } - - if err := r.wallet.SignTransaction(cs, &formationTxn, toSign, explicitCoveredFields(formationTxn)); err != nil { - release() - return crhp2.ContractRevision{}, fmt.Errorf("failed to sign transaction: %w", err) - } - - revision, _, err := rhp2.RPCFormContract(ctx, t, r.privKey, []types.Transaction{formationTxn}) - if err != nil { - release() - return crhp2.ContractRevision{}, fmt.Errorf("failed to form contract: %w", err) - } - return revision, nil -} - -// WalletAddress returns the renter's wallet address -func (r *Renter) WalletAddress() types.Address { - return r.wallet.Address() -} - -// PublicKey returns the renter's public key -func (r *Renter) PublicKey() types.PublicKey { - return r.privKey.PublicKey() -} - -// dialTransport is a convenience function that connects to the specified -// host -func dialTransport(ctx context.Context, hostIP string, hostKey types.PublicKey) (_ *crhp2.Transport, err error) { - conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", hostIP) - if err != nil { - return nil, err - } - done := make(chan struct{}) - go func() { - select { - case <-done: - case <-ctx.Done(): - conn.Close() - } - }() - defer func() { - close(done) - if ctx.Err() != nil { - err = ctx.Err() - } - }() - - t, err := crhp2.NewRenterTransport(conn, hostKey) - if err != nil { - conn.Close() - return nil, err - } - return t, nil -} - -// explicitCoveredFields returns a CoveredFields that covers all elements -// present in txn. -func explicitCoveredFields(txn types.Transaction) (cf types.CoveredFields) { - for i := range txn.SiacoinInputs { - cf.SiacoinInputs = append(cf.SiacoinInputs, uint64(i)) - } - for i := range txn.SiacoinOutputs { - cf.SiacoinOutputs = append(cf.SiacoinOutputs, uint64(i)) - } - for i := range txn.FileContracts { - cf.FileContracts = append(cf.FileContracts, uint64(i)) - } - for i := range txn.FileContractRevisions { - cf.FileContractRevisions = append(cf.FileContractRevisions, uint64(i)) - } - for i := range txn.StorageProofs { - cf.StorageProofs = append(cf.StorageProofs, uint64(i)) - } - for i := range txn.SiafundInputs { - cf.SiafundInputs = append(cf.SiafundInputs, uint64(i)) - } - for i := range txn.SiafundOutputs { - cf.SiafundOutputs = append(cf.SiafundOutputs, uint64(i)) - } - for i := range txn.MinerFees { - cf.MinerFees = append(cf.MinerFees, uint64(i)) - } - for i := range txn.ArbitraryData { - cf.ArbitraryData = append(cf.ArbitraryData, uint64(i)) - } - for i := range txn.Signatures { - cf.Signatures = append(cf.Signatures, uint64(i)) - } - return -} - -// NewRenter creates a new renter for testing -func NewRenter(privKey types.PrivateKey, dir string, node *Node, log *zap.Logger) (*Renter, error) { - db, err := sqlite.OpenDatabase(filepath.Join(dir, "renter.db"), log.Named("sqlite")) - if err != nil { - return nil, fmt.Errorf("failed to create sql store: %w", err) - } - wallet, err := wallet.NewSingleAddressWallet(privKey, node.ChainManager(), db, log.Named("wallet")) - if err != nil { - return nil, fmt.Errorf("failed to create wallet: %w", err) - } - - return &Renter{ - Node: node, - privKey: privKey, - store: db, - log: log, - wallet: wallet, - }, nil -} diff --git a/internal/test/rhp/v2/rhp.go b/internal/test/rhp/v2/rhp.go deleted file mode 100644 index 89a1548b..00000000 --- a/internal/test/rhp/v2/rhp.go +++ /dev/null @@ -1,851 +0,0 @@ -package rhp - -import ( - "bufio" - "context" - "encoding/binary" - "encoding/json" - "errors" - "fmt" - "io" - "math" - "net" - "sort" - "time" - - rhp2 "go.sia.tech/core/rhp/v2" - "go.sia.tech/core/types" -) - -const ( - // minMessageSize is the minimum size of an RPC message - minMessageSize = 4096 -) - -var ( - errContractLocked = errors.New("contract is locked by another party") - errContractFinalized = errors.New("contract cannot be revised further") - errInsufficientCollateral = errors.New("insufficient collateral") - errInsufficientFunds = errors.New("insufficient funds") - errInvalidMerkleProof = errors.New("host supplied invalid Merkle proof") -) - -// RPCSettings calls the Settings RPC, returning the host's reported settings. -func RPCSettings(ctx context.Context, t *rhp2.Transport) (settings rhp2.HostSettings, err error) { - var resp rhp2.RPCSettingsResponse - if err := t.Call(rhp2.RPCSettingsID, nil, &resp); err != nil { - return rhp2.HostSettings{}, err - } else if err := json.Unmarshal(resp.Settings, &settings); err != nil { - return rhp2.HostSettings{}, fmt.Errorf("couldn't unmarshal json: %w", err) - } - - return settings, nil -} - -// RPCFormContract forms a contract with a host. -func RPCFormContract(ctx context.Context, t *rhp2.Transport, renterKey types.PrivateKey, txnSet []types.Transaction) (_ rhp2.ContractRevision, _ []types.Transaction, err error) { - // strip our signatures before sending - parents, txn := txnSet[:len(txnSet)-1], txnSet[len(txnSet)-1] - renterContractSignatures := txn.Signatures - txnSet[len(txnSet)-1].Signatures = nil - - // create request - renterPubkey := renterKey.PublicKey() - req := &rhp2.RPCFormContractRequest{ - Transactions: txnSet, - RenterKey: renterPubkey.UnlockKey(), - } - if err := t.WriteRequest(rhp2.RPCFormContractID, req); err != nil { - return rhp2.ContractRevision{}, nil, err - } - - // execute form contract RPC - var resp rhp2.RPCFormContractAdditions - if err := t.ReadResponse(&resp, 65536); err != nil { - return rhp2.ContractRevision{}, nil, err - } - - // merge host additions with txn - txn.SiacoinInputs = append(txn.SiacoinInputs, resp.Inputs...) - txn.SiacoinOutputs = append(txn.SiacoinOutputs, resp.Outputs...) - - // create initial (no-op) revision, transaction, and signature - fc := txn.FileContracts[0] - initRevision := types.FileContractRevision{ - ParentID: txn.FileContractID(0), - UnlockConditions: types.UnlockConditions{ - PublicKeys: []types.UnlockKey{ - renterPubkey.UnlockKey(), - t.HostKey().UnlockKey(), - }, - SignaturesRequired: 2, - }, - FileContract: types.FileContract{ - RevisionNumber: 1, - Filesize: fc.Filesize, - FileMerkleRoot: fc.FileMerkleRoot, - WindowStart: fc.WindowStart, - WindowEnd: fc.WindowEnd, - ValidProofOutputs: fc.ValidProofOutputs, - MissedProofOutputs: fc.MissedProofOutputs, - UnlockHash: fc.UnlockHash, - }, - } - revSig := renterKey.SignHash(hashRevision(initRevision)) - renterRevisionSig := types.TransactionSignature{ - ParentID: types.Hash256(initRevision.ParentID), - CoveredFields: types.CoveredFields{FileContractRevisions: []uint64{0}}, - PublicKeyIndex: 0, - Signature: revSig[:], - } - - // write our signatures - renterSigs := &rhp2.RPCFormContractSignatures{ - ContractSignatures: renterContractSignatures, - RevisionSignature: renterRevisionSig, - } - if err := t.WriteResponse(renterSigs); err != nil { - return rhp2.ContractRevision{}, nil, err - } - - // read the host's signatures and merge them with our own - var hostSigs rhp2.RPCFormContractSignatures - if err := t.ReadResponse(&hostSigs, minMessageSize); err != nil { - return rhp2.ContractRevision{}, nil, err - } - - txn.Signatures = append(renterContractSignatures, hostSigs.ContractSignatures...) - signedTxnSet := append(resp.Parents, append(parents, txn)...) - return rhp2.ContractRevision{ - Revision: initRevision, - Signatures: [2]types.TransactionSignature{ - renterRevisionSig, - hostSigs.RevisionSignature, - }, - }, signedTxnSet, nil -} - -// RHP2Session represents a session with a host -type RHP2Session struct { - transport *rhp2.Transport - revision rhp2.ContractRevision - key types.PrivateKey - appendRoots []types.Hash256 - settings rhp2.HostSettings - lastSeen time.Time -} - -// NewRHP2Session returns a new rhp2 session -func NewRHP2Session(t *rhp2.Transport, key types.PrivateKey, rev rhp2.ContractRevision, settings rhp2.HostSettings) *RHP2Session { - return &RHP2Session{ - transport: t, - key: key, - revision: rev, - settings: settings, - } -} - -// Append appends the sector to the contract -func (s *RHP2Session) Append(ctx context.Context, sector *[rhp2.SectorSize]byte, price, collateral types.Currency) (types.Hash256, error) { - err := s.Write(ctx, []rhp2.RPCWriteAction{{ - Type: rhp2.RPCWriteActionAppend, - Data: sector[:], - }}, price, collateral) - if err != nil { - return types.Hash256{}, err - } - return s.appendRoots[0], nil -} - -// Close closes the underlying transport -func (s *RHP2Session) Close() (err error) { - return s.closeTransport() -} - -// Delete deletes the sectors at the given indices from the contract -func (s *RHP2Session) Delete(ctx context.Context, sectorIndices []uint64, price types.Currency) error { - if len(sectorIndices) == 0 { - return nil - } - - // sort in descending order so that we can use 'range' - sort.Slice(sectorIndices, func(i, j int) bool { - return sectorIndices[i] > sectorIndices[j] - }) - - // iterate backwards from the end of the contract, swapping each "good" - // sector with one of the "bad" sectors. - var actions []rhp2.RPCWriteAction - cIndex := s.revision.NumSectors() - 1 - for _, rIndex := range sectorIndices { - if cIndex != rIndex { - // swap a "good" sector for a "bad" sector - actions = append(actions, rhp2.RPCWriteAction{ - Type: rhp2.RPCWriteActionSwap, - A: uint64(cIndex), - B: uint64(rIndex), - }) - } - cIndex-- - } - // trim all "bad" sectors - actions = append(actions, rhp2.RPCWriteAction{ - Type: rhp2.RPCWriteActionTrim, - A: uint64(len(sectorIndices)), - }) - - // request the swap+delete operation - // - // NOTE: siad hosts will accept up to 20 MiB of data in the request, - // which should be sufficient to delete up to 2.5 TiB of sector data - // at a time. - return s.Write(ctx, actions, price, types.ZeroCurrency) -} - -// HostKey returns the host's public key -func (s *RHP2Session) HostKey() types.PublicKey { return s.revision.HostKey() } - -// Read reads the given sections -func (s *RHP2Session) Read(ctx context.Context, w io.Writer, sections []rhp2.RPCReadRequestSection, price types.Currency) (err error) { - empty := true - for _, s := range sections { - empty = empty && s.Length == 0 - } - if empty || len(sections) == 0 { - return nil - } - - if !s.isRevisable() { - return errContractFinalized - } else if !s.sufficientFunds(price) { - return errInsufficientFunds - } - - // construct new revision - rev := s.revision.Revision - rev.RevisionNumber++ - newValid, newMissed := updateRevisionOutputs(&rev, price, types.ZeroCurrency) - revisionHash := hashRevision(rev) - renterSig := s.key.SignHash(revisionHash) - - // construct the request - req := &rhp2.RPCReadRequest{ - Sections: sections, - MerkleProof: true, - - RevisionNumber: rev.RevisionNumber, - ValidProofValues: newValid, - MissedProofValues: newMissed, - Signature: renterSig, - } - - var hostSig *types.Signature - if err := s.withTransport(ctx, func(transport *rhp2.Transport) error { - if err := transport.WriteRequest(rhp2.RPCReadID, req); err != nil { - return err - } - - // ensure we send RPCLoopReadStop before returning - defer transport.WriteResponse(&rhp2.RPCReadStop) - - // read all sections - for _, sec := range sections { - hostSig, err = s.readSection(w, transport, sec) - if err != nil { - return err - } - if hostSig != nil { - break // exit the loop; they won't be sending any more data - } - } - - // the host is required to send a signature; if they haven't sent one - // yet, they should send an empty ReadResponse containing just the - // signature. - if hostSig == nil { - var resp rhp2.RPCReadResponse - if err := transport.ReadResponse(&resp, 4096); err != nil { - return wrapResponseErr(err, "couldn't read signature", "host rejected Read request") - } - hostSig = &resp.Signature - } - return nil - }); err != nil { - return err - } - - // verify the host signature - if !s.HostKey().VerifyHash(revisionHash, *hostSig) { - return errors.New("host's signature is invalid") - } - s.revision.Revision = rev - s.revision.Signatures[0].Signature = renterSig[:] - s.revision.Signatures[1].Signature = hostSig[:] - - return nil -} - -// Reconnect reconnects to the host -func (s *RHP2Session) Reconnect(ctx context.Context, hostIP string, hostKey types.PublicKey, renterKey types.PrivateKey, contractID types.FileContractID) (err error) { - s.closeTransport() - - conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", hostIP) - if err != nil { - return err - } - s.transport, err = rhp2.NewRenterTransport(conn, hostKey) - if err != nil { - return err - } - - s.key = renterKey - if err = s.lock(ctx, contractID, renterKey, 10*time.Second); err != nil { - s.closeTransport() - return err - } - - if err := s.updateSettings(ctx); err != nil { - s.closeTransport() - return err - } - - s.lastSeen = time.Now() - return nil -} - -// Refresh refreshes the session -func (s *RHP2Session) Refresh(ctx context.Context, sessionTTL time.Duration, renterKey types.PrivateKey, contractID types.FileContractID) error { - if s.transport == nil { - return errors.New("no transport") - } - - if time.Since(s.lastSeen) >= sessionTTL { - // use RPCSettings as a generic "ping" - if err := s.updateSettings(ctx); err != nil { - return err - } - } - - if s.revision.ID() != contractID { - // connected, but not locking the correct contract - if s.revision.ID() != (types.FileContractID{}) { - if err := s.unlock(ctx); err != nil { - return err - } - } - if err := s.lock(ctx, contractID, renterKey, 10*time.Second); err != nil { - return err - } - - s.key = renterKey - if err := s.updateSettings(ctx); err != nil { - return err - } - } - s.lastSeen = time.Now() - return nil -} - -// RenewContract renews the contract -func (s *RHP2Session) RenewContract(ctx context.Context, txnSet []types.Transaction, finalPayment types.Currency) (_ rhp2.ContractRevision, _ []types.Transaction, err error) { - // strip our signatures before sending - parents, txn := txnSet[:len(txnSet)-1], txnSet[len(txnSet)-1] - renterContractSignatures := txn.Signatures - txnSet[len(txnSet)-1].Signatures = nil - - // construct the final revision of the old contract - finalOldRevision := s.revision.Revision - newValid, _ := updateRevisionOutputs(&finalOldRevision, finalPayment, types.ZeroCurrency) - finalOldRevision.MissedProofOutputs = finalOldRevision.ValidProofOutputs - finalOldRevision.Filesize = 0 - finalOldRevision.FileMerkleRoot = types.Hash256{} - finalOldRevision.RevisionNumber = math.MaxUint64 - - // construct the renew request - req := &rhp2.RPCRenewAndClearContractRequest{ - Transactions: txnSet, - RenterKey: s.revision.Revision.UnlockConditions.PublicKeys[0], - FinalValidProofValues: newValid, - FinalMissedProofValues: newValid, - } - - // send the request - var resp rhp2.RPCFormContractAdditions - if err := s.withTransport(ctx, func(transport *rhp2.Transport) error { - if err := transport.WriteRequest(rhp2.RPCRenewClearContractID, req); err != nil { - return err - } - return transport.ReadResponse(&resp, 65536) - }); err != nil { - return rhp2.ContractRevision{}, nil, err - } - - // merge host additions with txn - txn.SiacoinInputs = append(txn.SiacoinInputs, resp.Inputs...) - txn.SiacoinOutputs = append(txn.SiacoinOutputs, resp.Outputs...) - - // create initial (no-op) revision, transaction, and signature - fc := txn.FileContracts[0] - initRevision := types.FileContractRevision{ - ParentID: txn.FileContractID(0), - UnlockConditions: s.revision.Revision.UnlockConditions, - FileContract: types.FileContract{ - RevisionNumber: 1, - Filesize: fc.Filesize, - FileMerkleRoot: fc.FileMerkleRoot, - WindowStart: fc.WindowStart, - WindowEnd: fc.WindowEnd, - ValidProofOutputs: fc.ValidProofOutputs, - MissedProofOutputs: fc.MissedProofOutputs, - UnlockHash: fc.UnlockHash, - }, - } - revSig := s.key.SignHash(hashRevision(initRevision)) - renterRevisionSig := types.TransactionSignature{ - ParentID: types.Hash256(initRevision.ParentID), - CoveredFields: types.CoveredFields{FileContractRevisions: []uint64{0}}, - PublicKeyIndex: 0, - Signature: revSig[:], - } - - // create signatures - finalRevSig := s.key.SignHash(hashRevision(finalOldRevision)) - renterSigs := &rhp2.RPCRenewAndClearContractSignatures{ - ContractSignatures: renterContractSignatures, - RevisionSignature: renterRevisionSig, - FinalRevisionSignature: finalRevSig, - } - - // send the signatures and read the host's signatures - var hostSigs rhp2.RPCRenewAndClearContractSignatures - if err := s.withTransport(ctx, func(transport *rhp2.Transport) error { - if err := transport.WriteResponse(renterSigs); err != nil { - return err - } - return transport.ReadResponse(&hostSigs, 4096) - }); err != nil { - return rhp2.ContractRevision{}, nil, err - } - - // merge host signatures with our own - txn.Signatures = append(renterContractSignatures, hostSigs.ContractSignatures...) - signedTxnSet := append(resp.Parents, append(parents, txn)...) - return rhp2.ContractRevision{ - Revision: initRevision, - Signatures: [2]types.TransactionSignature{renterRevisionSig, hostSigs.RevisionSignature}, - }, signedTxnSet, nil -} - -// Revision returns the contract revision -func (s *RHP2Session) Revision() (rev rhp2.ContractRevision) { - b, _ := json.Marshal(s.revision) // deep copy - if err := json.Unmarshal(b, &rev); err != nil { - panic(err) // should never happen - } - return rev -} - -// RPCAppendCost returns the cost of a single append -func (s *RHP2Session) RPCAppendCost(remainingDuration uint64) (types.Currency, types.Currency, error) { - var sector [rhp2.SectorSize]byte - actions := []rhp2.RPCWriteAction{{Type: rhp2.RPCWriteActionAppend, Data: sector[:]}} - cost, err := s.settings.RPCWriteCost(actions, s.revision.Revision.Filesize/rhp2.SectorSize, remainingDuration, true) - if err != nil { - return types.ZeroCurrency, types.ZeroCurrency, err - } - price, collateral := cost.Total() - return price, collateral, nil -} - -// SectorRoots returns n roots at offset. -func (s *RHP2Session) SectorRoots(ctx context.Context, offset, n uint64, price types.Currency) (roots []types.Hash256, err error) { - if !s.isRevisable() { - return nil, errContractFinalized - } else if offset+n > s.revision.NumSectors() { - return nil, errors.New("requested range is out-of-bounds") - } else if n == 0 { - return nil, nil - } else if !s.sufficientFunds(price) { - return nil, errInsufficientFunds - } - - // construct new revision - rev := s.revision.Revision - rev.RevisionNumber++ - newValid, newMissed := updateRevisionOutputs(&rev, price, types.ZeroCurrency) - revisionHash := hashRevision(rev) - - req := &rhp2.RPCSectorRootsRequest{ - RootOffset: uint64(offset), - NumRoots: uint64(n), - - RevisionNumber: rev.RevisionNumber, - ValidProofValues: newValid, - MissedProofValues: newMissed, - Signature: s.key.SignHash(revisionHash), - } - - // execute the sector roots RPC - var resp rhp2.RPCSectorRootsResponse - err = s.withTransport(ctx, func(t *rhp2.Transport) error { - if err := t.WriteRequest(rhp2.RPCSectorRootsID, req); err != nil { - return err - } else if err := t.ReadResponse(&resp, uint64(4096+32*n)); err != nil { - readCtx := fmt.Sprintf("couldn't read %v response", rhp2.RPCSectorRootsID) - rejectCtx := fmt.Sprintf("host rejected %v request", rhp2.RPCSectorRootsID) - return wrapResponseErr(err, readCtx, rejectCtx) - } else { - return nil - } - }) - if err != nil { - return nil, err - } - - // verify the host signature - if !s.HostKey().VerifyHash(revisionHash, resp.Signature) { - return nil, errors.New("host's signature is invalid") - } - s.revision.Revision = rev - s.revision.Signatures[0].Signature = req.Signature[:] - s.revision.Signatures[1].Signature = resp.Signature[:] - - // verify the proof - if !rhp2.VerifySectorRangeProof(resp.MerkleProof, resp.SectorRoots, offset, offset+n, s.revision.NumSectors(), rev.FileMerkleRoot) { - return nil, errInvalidMerkleProof - } - return resp.SectorRoots, nil -} - -// Settings returns the host settings -func (s *RHP2Session) Settings() *rhp2.HostSettings { return &s.settings } - -// Write performs the given write actions -func (s *RHP2Session) Write(ctx context.Context, actions []rhp2.RPCWriteAction, price, collateral types.Currency) (err error) { - if !s.isRevisable() { - return errContractFinalized - } else if len(actions) == 0 { - return nil - } else if !s.sufficientFunds(price) { - return errInsufficientFunds - } else if !s.sufficientCollateral(collateral) { - return errInsufficientCollateral - } - - rev := s.revision.Revision - newFilesize := rev.Filesize - for _, action := range actions { - switch action.Type { - case rhp2.RPCWriteActionAppend: - newFilesize += rhp2.SectorSize - case rhp2.RPCWriteActionTrim: - newFilesize -= rhp2.SectorSize * action.A - } - } - - // calculate new revision outputs - newValid, newMissed := updateRevisionOutputs(&rev, price, collateral) - - // compute appended roots in parallel with I/O - precompChan := make(chan struct{}) - go func() { - s.appendRoots = s.appendRoots[:0] - for _, action := range actions { - if action.Type == rhp2.RPCWriteActionAppend { - s.appendRoots = append(s.appendRoots, rhp2.SectorRoot((*[rhp2.SectorSize]byte)(action.Data))) - } - } - close(precompChan) - }() - // ensure that the goroutine has exited before we return - defer func() { <-precompChan }() - - // create request - req := &rhp2.RPCWriteRequest{ - Actions: actions, - MerkleProof: true, - - RevisionNumber: rev.RevisionNumber + 1, - ValidProofValues: newValid, - MissedProofValues: newMissed, - } - - // send request and read merkle proof - var merkleResp rhp2.RPCWriteMerkleProof - if err := s.withTransport(ctx, func(transport *rhp2.Transport) error { - if err := transport.WriteRequest(rhp2.RPCWriteID, req); err != nil { - return err - } else if err := transport.ReadResponse(&merkleResp, 4096); err != nil { - return wrapResponseErr(err, "couldn't read Merkle proof response", "host rejected Write request") - } else { - return nil - } - }); err != nil { - return err - } - - // verify proof - proofHashes := merkleResp.OldSubtreeHashes - leafHashes := merkleResp.OldLeafHashes - oldRoot, newRoot := types.Hash256(rev.FileMerkleRoot), merkleResp.NewMerkleRoot - <-precompChan - if newFilesize > 0 && !rhp2.VerifyDiffProof(actions, s.revision.NumSectors(), proofHashes, leafHashes, oldRoot, newRoot, s.appendRoots) { - err := errInvalidMerkleProof - s.withTransport(ctx, func(transport *rhp2.Transport) error { return transport.WriteResponseErr(err) }) - return err - } - - // update revision - rev.RevisionNumber++ - rev.Filesize = newFilesize - copy(rev.FileMerkleRoot[:], newRoot[:]) - revisionHash := hashRevision(rev) - renterSig := &rhp2.RPCWriteResponse{ - Signature: s.key.SignHash(revisionHash), - } - - // exchange signatures - var hostSig rhp2.RPCWriteResponse - if err := s.withTransport(ctx, func(transport *rhp2.Transport) error { - if err := transport.WriteResponse(renterSig); err != nil { - return fmt.Errorf("couldn't write signature response: %w", err) - } else if err := transport.ReadResponse(&hostSig, 4096); err != nil { - return wrapResponseErr(err, "couldn't read signature response", "host rejected Write signature") - } else { - return nil - } - }); err != nil { - return err - } - - // verify the host signature - if !s.HostKey().VerifyHash(revisionHash, hostSig.Signature) { - return errors.New("host's signature is invalid") - } - s.revision.Revision = rev - s.revision.Signatures[0].Signature = renterSig.Signature[:] - s.revision.Signatures[1].Signature = hostSig.Signature[:] - return nil -} - -func (s *RHP2Session) closeTransport() error { - if s.transport != nil { - return s.transport.Close() - } - return nil -} - -func (s *RHP2Session) isRevisable() bool { - return s.revision.Revision.RevisionNumber < math.MaxUint64 -} - -func (s *RHP2Session) lock(ctx context.Context, id types.FileContractID, key types.PrivateKey, timeout time.Duration) (err error) { - req := &rhp2.RPCLockRequest{ - ContractID: id, - Signature: s.transport.SignChallenge(key), - Timeout: uint64(timeout.Milliseconds()), - } - - // execute lock RPC - var resp rhp2.RPCLockResponse - if err := s.withTransport(ctx, func(transport *rhp2.Transport) error { - if err := transport.Call(rhp2.RPCLockID, req, &resp); err != nil { - return err - } - transport.SetChallenge(resp.NewChallenge) - return nil - }); err != nil { - return err - } - - // verify claimed revision - if len(resp.Signatures) != 2 { - return fmt.Errorf("host returned wrong number of signatures (expected 2, got %v)", len(resp.Signatures)) - } else if len(resp.Signatures[0].Signature) != 64 || len(resp.Signatures[1].Signature) != 64 { - return errors.New("signatures on claimed revision have wrong length") - } - revHash := hashRevision(resp.Revision) - if !key.PublicKey().VerifyHash(revHash, *(*types.Signature)(resp.Signatures[0].Signature)) { - return errors.New("renter's signature on claimed revision is invalid") - } else if !s.transport.HostKey().VerifyHash(revHash, *(*types.Signature)(resp.Signatures[1].Signature)) { - return errors.New("host's signature on claimed revision is invalid") - } else if !resp.Acquired { - return errContractLocked - } else if resp.Revision.RevisionNumber == math.MaxUint64 { - return errContractFinalized - } - s.revision = rhp2.ContractRevision{ - Revision: resp.Revision, - Signatures: [2]types.TransactionSignature{resp.Signatures[0], resp.Signatures[1]}, - } - return nil -} - -func (s *RHP2Session) readSection(w io.Writer, t *rhp2.Transport, sec rhp2.RPCReadRequestSection) (hostSig *types.Signature, _ error) { - // NOTE: normally, we would call ReadResponse here to read an AEAD RPC - // message, verify the tag and decrypt, and then pass the data to - // VerifyProof. As an optimization, we instead stream the message - // through a Merkle proof verifier before verifying the AEAD tag. - // Security therefore depends on the caller of Read discarding any data - // written to w in the event that verification fails. - msgReader, err := t.RawResponse(4096 + uint64(sec.Length)) - if err != nil { - return nil, wrapResponseErr(err, "couldn't read sector data", "host rejected Read request") - } - // Read the signature, which may or may not be present. - lenbuf := make([]byte, 8) - if _, err := io.ReadFull(msgReader, lenbuf); err != nil { - return nil, fmt.Errorf("couldn't read signature len: %w", err) - } - if n := binary.LittleEndian.Uint64(lenbuf); n > 0 { - hostSig = new(types.Signature) - if _, err := io.ReadFull(msgReader, hostSig[:]); err != nil { - return nil, fmt.Errorf("couldn't read signature: %w", err) - } - } - // stream the sector data into w and the proof verifier - if _, err := io.ReadFull(msgReader, lenbuf); err != nil { - return nil, fmt.Errorf("couldn't read data len: %w", err) - } else if binary.LittleEndian.Uint64(lenbuf) != uint64(sec.Length) { - return nil, errors.New("host sent wrong amount of sector data") - } - proofStart := sec.Offset / rhp2.LeafSize - proofEnd := proofStart + sec.Length/rhp2.LeafSize - rpv := rhp2.NewRangeProofVerifier(proofStart, proofEnd) - tee := io.TeeReader(io.LimitReader(msgReader, int64(sec.Length)), &segWriter{w: w}) - // the proof verifier Reads one segment at a time, so bufio is crucial - // for performance here - if _, err := rpv.ReadFrom(bufio.NewReaderSize(tee, 1<<16)); err != nil { - return nil, fmt.Errorf("couldn't stream sector data: %w", err) - } - // read the Merkle proof - if _, err := io.ReadFull(msgReader, lenbuf); err != nil { - return nil, fmt.Errorf("couldn't read proof len: %w", err) - } - if binary.LittleEndian.Uint64(lenbuf) != uint64(rhp2.RangeProofSize(rhp2.LeavesPerSector, proofStart, proofEnd)) { - return nil, errors.New("invalid proof size") - } - proof := make([]types.Hash256, binary.LittleEndian.Uint64(lenbuf)) - for i := range proof { - if _, err := io.ReadFull(msgReader, proof[i][:]); err != nil { - return nil, fmt.Errorf("couldn't read Merkle proof: %w", err) - } - } - // verify the message tag and the Merkle proof - if err := msgReader.VerifyTag(); err != nil { - return nil, err - } - if !rpv.Verify(proof, sec.MerkleRoot) { - return nil, errInvalidMerkleProof - } - return -} - -func (s *RHP2Session) sufficientFunds(price types.Currency) bool { - return s.revision.RenterFunds().Cmp(price) >= 0 -} - -func (s *RHP2Session) sufficientCollateral(collateral types.Currency) bool { - return s.revision.Revision.MissedProofOutputs[1].Value.Cmp(collateral) >= 0 -} - -func (s *RHP2Session) updateSettings(ctx context.Context) (err error) { - var resp rhp2.RPCSettingsResponse - if err := s.withTransport(ctx, func(transport *rhp2.Transport) error { - return transport.Call(rhp2.RPCSettingsID, nil, &resp) - }); err != nil { - return err - } - - if err := json.Unmarshal(resp.Settings, &s.settings); err != nil { - return fmt.Errorf("couldn't unmarshal json: %w", err) - } - return -} - -func (s *RHP2Session) unlock(ctx context.Context) (err error) { - s.revision = rhp2.ContractRevision{} - s.key = nil - - return s.withTransport(ctx, func(transport *rhp2.Transport) error { - return transport.WriteRequest(rhp2.RPCUnlockID, nil) - }) -} - -func (s *RHP2Session) withTransport(ctx context.Context, fn func(t *rhp2.Transport) error) (err error) { - errChan := make(chan error) - go func() { - defer close(errChan) - errChan <- fn(s.transport) - }() - - select { - case err = <-errChan: - return - case <-ctx.Done(): - _ = s.transport.ForceClose() // ignore error - if err = <-errChan; err == nil { - err = ctx.Err() - } - s.transport = nil - } - return -} - -func hashRevision(rev types.FileContractRevision) types.Hash256 { - h := types.NewHasher() - rev.EncodeTo(h.E) - return h.Sum() -} - -func updateRevisionOutputs(rev *types.FileContractRevision, cost, collateral types.Currency) (valid, missed []types.Currency) { - // allocate new slices; don't want to risk accidentally sharing memory - rev.ValidProofOutputs = append([]types.SiacoinOutput(nil), rev.ValidProofOutputs...) - rev.MissedProofOutputs = append([]types.SiacoinOutput(nil), rev.MissedProofOutputs...) - - // move valid payout from renter to host - rev.ValidProofOutputs[0].Value = rev.ValidProofOutputs[0].Value.Sub(cost) - rev.ValidProofOutputs[1].Value = rev.ValidProofOutputs[1].Value.Add(cost) - - // move missed payout from renter to void - rev.MissedProofOutputs[0].Value = rev.MissedProofOutputs[0].Value.Sub(cost) - rev.MissedProofOutputs[2].Value = rev.MissedProofOutputs[2].Value.Add(cost) - - // move collateral from host to void - rev.MissedProofOutputs[1].Value = rev.MissedProofOutputs[1].Value.Sub(collateral) - rev.MissedProofOutputs[2].Value = rev.MissedProofOutputs[2].Value.Add(collateral) - - return []types.Currency{rev.ValidProofOutputs[0].Value, rev.ValidProofOutputs[1].Value}, - []types.Currency{rev.MissedProofOutputs[0].Value, rev.MissedProofOutputs[1].Value, rev.MissedProofOutputs[2].Value} -} - -func wrapResponseErr(err error, readCtx, rejectCtx string) error { - if errors.As(err, new(*rhp2.RPCError)) { - return fmt.Errorf("%s: %w", rejectCtx, err) - } - if err != nil { - return fmt.Errorf("%s: %w", readCtx, err) - } - return nil -} - -type segWriter struct { - w io.Writer - buf [rhp2.LeafSize * 64]byte - len int -} - -func (sw *segWriter) Write(p []byte) (int, error) { - lenp := len(p) - for len(p) > 0 { - n := copy(sw.buf[sw.len:], p) - sw.len += n - p = p[n:] - segs := sw.buf[:sw.len-(sw.len%rhp2.LeafSize)] - if _, err := sw.w.Write(segs); err != nil { - return 0, err - } - sw.len = copy(sw.buf[:], sw.buf[len(segs):sw.len]) - } - return lenp, nil -} diff --git a/internal/test/wallet.go b/internal/test/wallet.go deleted file mode 100644 index 892afc98..00000000 --- a/internal/test/wallet.go +++ /dev/null @@ -1,77 +0,0 @@ -package test - -import ( - "fmt" - "path/filepath" - - "go.sia.tech/core/types" - "go.sia.tech/hostd/persist/sqlite" - "go.sia.tech/hostd/wallet" - "go.uber.org/zap" -) - -// A Wallet is an ephemeral wallet that can be used for testing. -type Wallet struct { - *Node - *wallet.SingleAddressWallet - store *sqlite.Store - log *zap.Logger -} - -// Close closes the wallet. -func (w *Wallet) Close() error { - w.SingleAddressWallet.Close() - w.store.Close() - w.Node.Close() - w.log.Sync() - return nil -} - -// Store returns the wallet's store. -func (w *Wallet) Store() *sqlite.Store { - return w.store -} - -// SendSiacoins helper func to send siacoins from a wallet. -func (w *Wallet) SendSiacoins(outputs []types.SiacoinOutput) (txn types.Transaction, err error) { - var siacoinOutput types.Currency - for _, o := range outputs { - siacoinOutput = siacoinOutput.Add(o.Value) - } - txn.SiacoinOutputs = outputs - - toSign, release, err := w.FundTransaction(&txn, siacoinOutput) - if err != nil { - return types.Transaction{}, fmt.Errorf("failed to fund transaction: %w", err) - } - if err := w.SignTransaction(w.ChainManager().TipState(), &txn, toSign, types.CoveredFields{WholeTransaction: true}); err != nil { - release() - return txn, fmt.Errorf("failed to sign transaction: %w", err) - } else if err := w.tp.AcceptTransactionSet([]types.Transaction{txn}); err != nil { - release() - return txn, fmt.Errorf("failed to accept transaction set: %w", err) - } - return txn, nil -} - -// NewWallet initializes a new test wallet. -func NewWallet(privKey types.PrivateKey, dir string, log *zap.Logger) (*Wallet, error) { - node, err := NewNode(dir) - if err != nil { - return nil, fmt.Errorf("failed to create node: %w", err) - } - db, err := sqlite.OpenDatabase(filepath.Join(dir, "wallet.db"), log.Named("sqlite")) - if err != nil { - return nil, fmt.Errorf("failed to create sql store: %w", err) - } - wallet, err := wallet.NewSingleAddressWallet(privKey, node.cm, db, log.Named("wallet")) - if err != nil { - return nil, fmt.Errorf("failed to create wallet: %w", err) - } - return &Wallet{ - Node: node, - SingleAddressWallet: wallet, - log: log, - store: db, - }, nil -} diff --git a/internal/testutil/rhp/v2/rhp.go b/internal/testutil/rhp/v2/rhp.go new file mode 100644 index 00000000..af6ad293 --- /dev/null +++ b/internal/testutil/rhp/v2/rhp.go @@ -0,0 +1,596 @@ +package rhp + +import ( + "bufio" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "time" + + rhp2 "go.sia.tech/core/rhp/v2" + "go.sia.tech/core/types" +) + +const ( + // minMessageSize is the minimum size of an RPC message + minMessageSize = 4096 +) + +var ( + errContractLocked = errors.New("contract is locked by another party") + errInvalidMerkleProof = errors.New("host supplied invalid Merkle proof") +) + +type segWriter struct { + w io.Writer + buf [rhp2.LeafSize * 64]byte + len int +} + +func (sw *segWriter) Write(p []byte) (int, error) { + lenp := len(p) + for len(p) > 0 { + n := copy(sw.buf[sw.len:], p) + sw.len += n + p = p[n:] + segs := sw.buf[:sw.len-(sw.len%rhp2.LeafSize)] + if _, err := sw.w.Write(segs); err != nil { + return 0, err + } + sw.len = copy(sw.buf[:], sw.buf[len(segs):sw.len]) + } + return lenp, nil +} + +func wrapResponseErr(err error, readCtx, rejectCtx string) error { + if errors.As(err, new(*rhp2.RPCError)) { + return fmt.Errorf("%s: %w", rejectCtx, err) + } + if err != nil { + return fmt.Errorf("%s: %w", readCtx, err) + } + return nil +} + +func hashRevision(rev types.FileContractRevision) types.Hash256 { + h := types.NewHasher() + rev.EncodeTo(h.E) + return h.Sum() +} + +func revisionOutputValues(rev types.FileContractRevision) (valid, missed []types.Currency) { + valid = make([]types.Currency, len(rev.ValidProofOutputs)) + missed = make([]types.Currency, len(rev.MissedProofOutputs)) + + for i := range rev.ValidProofOutputs { + valid[i] = rev.ValidProofOutputs[i].Value + } + + for i := range rev.MissedProofOutputs { + missed[i] = rev.MissedProofOutputs[i].Value + } + return +} + +func revisionTransfer(rev types.FileContractRevision, cost, collateral types.Currency) (valid, missed []types.Currency) { + valid, missed = revisionOutputValues(rev) + if len(valid) != 2 || len(missed) != 3 { + panic("unexpected number of outputs") + } + + // move valid payout from renter to host + valid[0] = rev.ValidProofOutputs[0].Value.Sub(cost) + valid[1] = rev.ValidProofOutputs[1].Value.Add(cost) + + // move missed payout from renter to void + missed[0] = rev.MissedProofOutputs[0].Value.Sub(cost) + missed[2] = rev.MissedProofOutputs[2].Value.Add(cost) + + // move collateral from host to void + missed[1] = rev.MissedProofOutputs[1].Value.Sub(collateral) + missed[2] = missed[2].Add(collateral) + return +} + +func readSection(w io.Writer, t *rhp2.Transport, sec rhp2.RPCReadRequestSection) (hostSig *types.Signature, _ error) { + // NOTE: normally, we would call ReadResponse here to read an AEAD RPC + // message, verify the tag and decrypt, and then pass the data to + // VerifyProof. As an optimization, we instead stream the message + // through a Merkle proof verifier before verifying the AEAD tag. + // Security therefore depends on the caller of Read discarding any data + // written to w in the event that verification fails. + msgReader, err := t.RawResponse(4096 + uint64(sec.Length)) + if err != nil { + return nil, wrapResponseErr(err, "couldn't read sector data", "host rejected Read request") + } + // Read the signature, which may or may not be present. + lenbuf := make([]byte, 8) + if _, err := io.ReadFull(msgReader, lenbuf); err != nil { + return nil, fmt.Errorf("couldn't read signature len: %w", err) + } + if n := binary.LittleEndian.Uint64(lenbuf); n > 0 { + hostSig = new(types.Signature) + if _, err := io.ReadFull(msgReader, hostSig[:]); err != nil { + return nil, fmt.Errorf("couldn't read signature: %w", err) + } + } + // stream the sector data into w and the proof verifier + if _, err := io.ReadFull(msgReader, lenbuf); err != nil { + return nil, fmt.Errorf("couldn't read data len: %w", err) + } else if binary.LittleEndian.Uint64(lenbuf) != uint64(sec.Length) { + return nil, errors.New("host sent wrong amount of sector data") + } + proofStart := sec.Offset / rhp2.LeafSize + proofEnd := proofStart + sec.Length/rhp2.LeafSize + rpv := rhp2.NewRangeProofVerifier(proofStart, proofEnd) + tee := io.TeeReader(io.LimitReader(msgReader, int64(sec.Length)), &segWriter{w: w}) + // the proof verifier Reads one segment at a time, so bufio is crucial + // for performance here + if _, err := rpv.ReadFrom(bufio.NewReaderSize(tee, 1<<16)); err != nil { + return nil, fmt.Errorf("couldn't stream sector data: %w", err) + } + // read the Merkle proof + if _, err := io.ReadFull(msgReader, lenbuf); err != nil { + return nil, fmt.Errorf("couldn't read proof len: %w", err) + } + if binary.LittleEndian.Uint64(lenbuf) != uint64(rhp2.RangeProofSize(rhp2.LeavesPerSector, proofStart, proofEnd)) { + return nil, errors.New("invalid proof size") + } + proof := make([]types.Hash256, binary.LittleEndian.Uint64(lenbuf)) + for i := range proof { + if _, err := io.ReadFull(msgReader, proof[i][:]); err != nil { + return nil, fmt.Errorf("couldn't read Merkle proof: %w", err) + } + } + // verify the message tag and the Merkle proof + if err := msgReader.VerifyTag(); err != nil { + return nil, err + } + if !rpv.Verify(proof, sec.MerkleRoot) { + return nil, errInvalidMerkleProof + } + return +} + +// RPCSettings calls the Settings RPC, returning the host's reported settings. +func RPCSettings(t *rhp2.Transport) (settings rhp2.HostSettings, err error) { + var resp rhp2.RPCSettingsResponse + if err := t.Call(rhp2.RPCSettingsID, nil, &resp); err != nil { + return rhp2.HostSettings{}, err + } else if err := json.Unmarshal(resp.Settings, &settings); err != nil { + return rhp2.HostSettings{}, fmt.Errorf("couldn't unmarshal json: %w", err) + } + + return settings, nil +} + +// RPCFormContract forms a contract with a host. +func RPCFormContract(t *rhp2.Transport, renterKey types.PrivateKey, txnSet []types.Transaction) (rhp2.ContractRevision, []types.Transaction, error) { + // strip our signatures before sending + parents, txn := txnSet[:len(txnSet)-1], txnSet[len(txnSet)-1] + renterContractSignatures := txn.Signatures + txnSet[len(txnSet)-1].Signatures = nil + + // create request + renterPubkey := renterKey.PublicKey() + req := &rhp2.RPCFormContractRequest{ + Transactions: txnSet, + RenterKey: renterPubkey.UnlockKey(), + } + if err := t.WriteRequest(rhp2.RPCFormContractID, req); err != nil { + return rhp2.ContractRevision{}, nil, err + } + + // execute form contract RPC + var resp rhp2.RPCFormContractAdditions + if err := t.ReadResponse(&resp, 65536); err != nil { + return rhp2.ContractRevision{}, nil, err + } + + // merge host additions with txn + txn.SiacoinInputs = append(txn.SiacoinInputs, resp.Inputs...) + txn.SiacoinOutputs = append(txn.SiacoinOutputs, resp.Outputs...) + + // create initial (no-op) revision, transaction, and signature + fc := txn.FileContracts[0] + initRevision := types.FileContractRevision{ + ParentID: txn.FileContractID(0), + UnlockConditions: types.UnlockConditions{ + PublicKeys: []types.UnlockKey{ + renterPubkey.UnlockKey(), + t.HostKey().UnlockKey(), + }, + SignaturesRequired: 2, + }, + FileContract: types.FileContract{ + RevisionNumber: 1, + Filesize: fc.Filesize, + FileMerkleRoot: fc.FileMerkleRoot, + WindowStart: fc.WindowStart, + WindowEnd: fc.WindowEnd, + ValidProofOutputs: fc.ValidProofOutputs, + MissedProofOutputs: fc.MissedProofOutputs, + UnlockHash: fc.UnlockHash, + }, + } + revSig := renterKey.SignHash(hashRevision(initRevision)) + renterRevisionSig := types.TransactionSignature{ + ParentID: types.Hash256(initRevision.ParentID), + CoveredFields: types.CoveredFields{FileContractRevisions: []uint64{0}}, + PublicKeyIndex: 0, + Signature: revSig[:], + } + + // write our signatures + renterSigs := &rhp2.RPCFormContractSignatures{ + ContractSignatures: renterContractSignatures, + RevisionSignature: renterRevisionSig, + } + if err := t.WriteResponse(renterSigs); err != nil { + return rhp2.ContractRevision{}, nil, err + } + + // read the host's signatures and merge them with our own + var hostSigs rhp2.RPCFormContractSignatures + if err := t.ReadResponse(&hostSigs, minMessageSize); err != nil { + return rhp2.ContractRevision{}, nil, err + } + + txn.Signatures = append(renterContractSignatures, hostSigs.ContractSignatures...) + signedTxnSet := append(resp.Parents, append(parents, txn)...) + return rhp2.ContractRevision{ + Revision: initRevision, + Signatures: [2]types.TransactionSignature{ + renterRevisionSig, + hostSigs.RevisionSignature, + }, + }, signedTxnSet, nil +} + +// RPCWrite appends data to a contract. +func RPCWrite(t *rhp2.Transport, renterKey types.PrivateKey, rev *rhp2.ContractRevision, actions []rhp2.RPCWriteAction, price, collateral types.Currency) error { + newValid, newMissed := revisionTransfer(rev.Revision, price, collateral) + + // precompute append roots + var appendRoots []types.Hash256 + appendRoots = appendRoots[:0] + contractSize := rev.Revision.Filesize + for _, action := range actions { + if action.Type == rhp2.RPCWriteActionAppend { + appendRoots = append(appendRoots, rhp2.SectorRoot((*[rhp2.SectorSize]byte)(action.Data))) + } + + switch action.Type { + case rhp2.RPCWriteActionAppend: + contractSize += rhp2.SectorSize + case rhp2.RPCWriteActionTrim: + d := rhp2.SectorSize * action.A + if contractSize < d { + return fmt.Errorf("contract size too small to trim %d sectors", action.A) + } + contractSize -= d + } + } + req := &rhp2.RPCWriteRequest{ + Actions: actions, + MerkleProof: true, + + RevisionNumber: rev.Revision.RevisionNumber + 1, + ValidProofValues: newValid, + MissedProofValues: newMissed, + } + err := t.WriteRequest(rhp2.RPCWriteID, req) + if err != nil { + return fmt.Errorf("failed to write request: %w", err) + } + + var merkleResp rhp2.RPCWriteMerkleProof + if err := t.ReadResponse(&merkleResp, 4096); err != nil { + return fmt.Errorf("failed to read merkle proof response: %w", err) + } + + // verify proof + proofHashes := merkleResp.OldSubtreeHashes + leafHashes := merkleResp.OldLeafHashes + oldRoot, newRoot := rev.Revision.FileMerkleRoot, merkleResp.NewMerkleRoot + if contractSize > 0 && !rhp2.VerifyDiffProof(actions, rev.NumSectors(), proofHashes, leafHashes, oldRoot, newRoot, appendRoots) { + err := errInvalidMerkleProof + t.WriteResponseErr(err) + return err + } + + // create new revision + newRevision := rev.Revision + newRevision.RevisionNumber = req.RevisionNumber + newRevision.Filesize = contractSize + newRevision.ValidProofOutputs = make([]types.SiacoinOutput, len(newValid)) + newRevision.MissedProofOutputs = make([]types.SiacoinOutput, len(newMissed)) + copy(newRevision.FileMerkleRoot[:], newRoot[:]) + for i := range newValid { + newRevision.ValidProofOutputs[i].Address = rev.Revision.ValidProofOutputs[i].Address + newRevision.ValidProofOutputs[i].Value = newValid[i] + } + for i := range newMissed { + newRevision.MissedProofOutputs[i].Address = rev.Revision.MissedProofOutputs[i].Address + newRevision.MissedProofOutputs[i].Value = newMissed[i] + } + revisionHash := hashRevision(newRevision) + renterSig := &rhp2.RPCWriteResponse{ + Signature: renterKey.SignHash(revisionHash), + } + + if err := t.WriteResponse(renterSig); err != nil { + return fmt.Errorf("failed to write response: %w", err) + } + + var hostSig rhp2.RPCWriteResponse + if err := t.ReadResponse(&hostSig, 4096); err != nil { + return fmt.Errorf("failed to read host signature: %w", err) + } else if !rev.HostKey().VerifyHash(revisionHash, hostSig.Signature) { + return fmt.Errorf("host's signature is invalid") + } + + rev.Revision = newRevision + rev.Signatures[0].Signature = renterSig.Signature[:] + rev.Signatures[1].Signature = hostSig.Signature[:] + return nil +} + +// RPCRead reads data from a contract. +func RPCRead(t *rhp2.Transport, w io.Writer, renterKey types.PrivateKey, rev *rhp2.ContractRevision, sections []rhp2.RPCReadRequestSection, price types.Currency) error { + empty := true + for _, s := range sections { + empty = empty && s.Length == 0 + } + if empty || len(sections) == 0 { + return nil + } + + // create new revision + newValid, newMissed := revisionTransfer(rev.Revision, price, types.ZeroCurrency) + newRevision := rev.Revision + newRevision.RevisionNumber++ + newRevision.ValidProofOutputs = make([]types.SiacoinOutput, len(newValid)) + newRevision.MissedProofOutputs = make([]types.SiacoinOutput, len(newMissed)) + for i := range newValid { + newRevision.ValidProofOutputs[i].Address = rev.Revision.ValidProofOutputs[i].Address + newRevision.ValidProofOutputs[i].Value = newValid[i] + } + for i := range newMissed { + newRevision.MissedProofOutputs[i].Address = rev.Revision.MissedProofOutputs[i].Address + newRevision.MissedProofOutputs[i].Value = newMissed[i] + } + revisionHash := hashRevision(newRevision) + renterSig := renterKey.SignHash(revisionHash) + + // construct the request + req := &rhp2.RPCReadRequest{ + Sections: sections, + MerkleProof: true, + + RevisionNumber: newRevision.RevisionNumber, + ValidProofValues: newValid, + MissedProofValues: newMissed, + Signature: renterSig, + } + + if err := t.WriteRequest(rhp2.RPCReadID, req); err != nil { + return fmt.Errorf("failed to write request: %w", err) + } + defer t.WriteResponse(&rhp2.RPCReadStop) + + var hostSig *types.Signature + var err error + // read all sections + for i, sec := range sections { + hostSig, err = readSection(w, t, sec) + if err != nil { + return fmt.Errorf("failed to read section %d: %w", i, err) + } else if hostSig != nil { + break // exit the loop; they won't be sending any more data + } + } + + // the host is required to send a signature; if they haven't sent one + // yet, they should send an empty ReadResponse containing just the + // signature. + if hostSig == nil { + var resp rhp2.RPCReadResponse + if err := t.ReadResponse(&resp, 4096); err != nil { + return wrapResponseErr(err, "couldn't read signature", "host rejected Read request") + } + hostSig = &resp.Signature + } + + // verify the host signature + if !rev.HostKey().VerifyHash(revisionHash, *hostSig) { + return errors.New("host's signature is invalid") + } + rev.Revision = newRevision + rev.Signatures[0].Signature = renterSig[:] + rev.Signatures[1].Signature = hostSig[:] + return nil +} + +// RPCRenewContract renews a contract with a host. +func RPCRenewContract(t *rhp2.Transport, renterKey types.PrivateKey, rev *rhp2.ContractRevision, txnSet []types.Transaction, price types.Currency) (rhp2.ContractRevision, []types.Transaction, error) { + // strip our signatures before sending + parents, txn := txnSet[:len(txnSet)-1], txnSet[len(txnSet)-1] + renterSignatures := txn.Signatures + txnSet[len(txnSet)-1].Signatures = nil + + finalRevision := rev.Revision + finalRevision.Filesize = 0 + finalRevision.RevisionNumber = types.MaxRevisionNumber + finalRevision.FileMerkleRoot = types.Hash256{} + finalRevision.ValidProofOutputs[0].Value = finalRevision.ValidProofOutputs[0].Value.Sub(price) + finalRevision.ValidProofOutputs[1].Value = finalRevision.ValidProofOutputs[1].Value.Add(price) + finalRevision.MissedProofOutputs = finalRevision.ValidProofOutputs + + newValid, newMissed := revisionOutputValues(rev.Revision) + // construct the renew request + req := &rhp2.RPCRenewAndClearContractRequest{ + Transactions: txnSet, + RenterKey: renterKey.PublicKey().UnlockKey(), + FinalValidProofValues: newValid, + FinalMissedProofValues: newMissed, + } + + if err := t.WriteRequest(rhp2.RPCRenewClearContractID, req); err != nil { + return rhp2.ContractRevision{}, nil, fmt.Errorf("failed to write request: %w", err) + } + + var hostAdditions rhp2.RPCFormContractAdditions + if err := t.ReadResponse(&hostAdditions, 65536); err != nil { + return rhp2.ContractRevision{}, nil, fmt.Errorf("failed to read response: %w", err) + } + + // merge host additions with txn + txn.SiacoinInputs = append(txn.SiacoinInputs, hostAdditions.Inputs...) + txn.SiacoinOutputs = append(txn.SiacoinOutputs, hostAdditions.Outputs...) + + // create initial (no-op) revision, transaction, and signature + fc := txn.FileContracts[0] + initRevision := types.FileContractRevision{ + ParentID: txn.FileContractID(0), + UnlockConditions: finalRevision.UnlockConditions, + FileContract: types.FileContract{ + RevisionNumber: 1, + Filesize: fc.Filesize, + FileMerkleRoot: fc.FileMerkleRoot, + WindowStart: fc.WindowStart, + WindowEnd: fc.WindowEnd, + ValidProofOutputs: fc.ValidProofOutputs, + MissedProofOutputs: fc.MissedProofOutputs, + UnlockHash: fc.UnlockHash, + }, + } + initialRevRenterSig := renterKey.SignHash(hashRevision(initRevision)) + renterRevisionSig := types.TransactionSignature{ + ParentID: types.Hash256(initRevision.ParentID), + CoveredFields: types.CoveredFields{FileContractRevisions: []uint64{0}}, + PublicKeyIndex: 0, + Signature: initialRevRenterSig[:], + } + + // send renter signatures + finalRevSigHash := hashRevision(finalRevision) + finalRevSig := renterKey.SignHash(finalRevSigHash) + err := t.WriteResponse(&rhp2.RPCRenewAndClearContractSignatures{ + ContractSignatures: renterSignatures, + RevisionSignature: renterRevisionSig, + FinalRevisionSignature: finalRevSig, + }) + if err != nil { + return rhp2.ContractRevision{}, nil, fmt.Errorf("failed to write response: %w", err) + } + + var hostSigs rhp2.RPCRenewAndClearContractSignatures + if err := t.ReadResponse(&hostSigs, 4096); err != nil { + return rhp2.ContractRevision{}, nil, fmt.Errorf("failed to read response: %w", err) + } + + // merge host signatures with our own + txn.Signatures = append(renterSignatures, hostSigs.ContractSignatures...) + signedTxnSet := append(hostAdditions.Parents, append(parents, txn)...) + return rhp2.ContractRevision{ + Revision: initRevision, + Signatures: [2]types.TransactionSignature{renterRevisionSig, hostSigs.RevisionSignature}, + }, signedTxnSet, nil +} + +// RPCSectorRoots fetches sector roots from a host. +func RPCSectorRoots(t *rhp2.Transport, renterKey types.PrivateKey, offset, limit uint64, rev *rhp2.ContractRevision, price types.Currency) ([]types.Hash256, error) { + // create new revision + newValid, newMissed := revisionTransfer(rev.Revision, price, types.ZeroCurrency) + newRevision := rev.Revision + newRevision.RevisionNumber++ + newRevision.ValidProofOutputs = make([]types.SiacoinOutput, len(newValid)) + newRevision.MissedProofOutputs = make([]types.SiacoinOutput, len(newMissed)) + for i := range newValid { + newRevision.ValidProofOutputs[i].Address = rev.Revision.ValidProofOutputs[i].Address + newRevision.ValidProofOutputs[i].Value = newValid[i] + } + for i := range newMissed { + newRevision.MissedProofOutputs[i].Address = rev.Revision.MissedProofOutputs[i].Address + newRevision.MissedProofOutputs[i].Value = newMissed[i] + } + revisionHash := hashRevision(newRevision) + renterSig := renterKey.SignHash(revisionHash) + + // create request + req := &rhp2.RPCSectorRootsRequest{ + RootOffset: offset, + NumRoots: limit, + + RevisionNumber: newRevision.RevisionNumber, + ValidProofValues: newValid, + MissedProofValues: newMissed, + Signature: renterSig, + } + + if err := t.WriteRequest(rhp2.RPCSectorRootsID, req); err != nil { + return nil, fmt.Errorf("failed to write request: %w", err) + } + + var resp rhp2.RPCSectorRootsResponse + if err := t.ReadResponse(&resp, 4096); err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + // verify the host signature + if !rev.HostKey().VerifyHash(revisionHash, resp.Signature) { + return nil, errors.New("host's signature is invalid") + } + rev.Revision = newRevision + rev.Signatures[0].Signature = req.Signature[:] + rev.Signatures[1].Signature = resp.Signature[:] + + // verify the proof + if !rhp2.VerifySectorRangeProof(resp.MerkleProof, resp.SectorRoots, offset, offset+limit, rev.NumSectors(), rev.Revision.FileMerkleRoot) { + return nil, errInvalidMerkleProof + } + return resp.SectorRoots, nil +} + +// RPCLock locks a contract with a host. +func RPCLock(t *rhp2.Transport, renterKey types.PrivateKey, id types.FileContractID) (rhp2.ContractRevision, error) { + req := &rhp2.RPCLockRequest{ + ContractID: id, + Signature: t.SignChallenge(renterKey), + Timeout: uint64(30 * time.Second.Milliseconds()), + } + + // execute lock RPC + var resp rhp2.RPCLockResponse + if err := t.Call(rhp2.RPCLockID, req, &resp); err != nil { + return rhp2.ContractRevision{}, err + } + t.SetChallenge(resp.NewChallenge) + + if len(resp.Signatures) != 2 { + return rhp2.ContractRevision{}, fmt.Errorf("host returned wrong number of signatures (expected 2, got %v)", len(resp.Signatures)) + } else if len(resp.Signatures[0].Signature) != 64 || len(resp.Signatures[1].Signature) != 64 { + return rhp2.ContractRevision{}, errors.New("signatures on claimed revision have wrong length") + } + + revHash := hashRevision(resp.Revision) + if !renterKey.PublicKey().VerifyHash(revHash, *(*types.Signature)(resp.Signatures[0].Signature)) { + return rhp2.ContractRevision{}, errors.New("renter's signature on claimed revision is invalid") + } else if !t.HostKey().VerifyHash(revHash, *(*types.Signature)(resp.Signatures[1].Signature)) { + return rhp2.ContractRevision{}, errors.New("host's signature on claimed revision is invalid") + } else if !resp.Acquired { + return rhp2.ContractRevision{}, errContractLocked + } + return rhp2.ContractRevision{ + Revision: resp.Revision, + Signatures: [2]types.TransactionSignature{resp.Signatures[0], resp.Signatures[1]}, + }, nil +} + +// RPCUnlock unlocks a contract with a host. +func RPCUnlock(t *rhp2.Transport) error { + return t.WriteRequest(rhp2.RPCUnlockID, nil) +} diff --git a/internal/test/rhp/v3/rhp.go b/internal/testutil/rhp/v3/rhp.go similarity index 97% rename from internal/test/rhp/v3/rhp.go rename to internal/testutil/rhp/v3/rhp.go index 55f7a845..85b05091 100644 --- a/internal/test/rhp/v3/rhp.go +++ b/internal/testutil/rhp/v3/rhp.go @@ -40,13 +40,15 @@ type ( // A Wallet funds and signs transactions Wallet interface { Address() types.Address - FundTransaction(txn *types.Transaction, amount types.Currency) ([]types.Hash256, func(), error) - SignTransaction(cs consensus.State, txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) error + FundTransaction(txn *types.Transaction, amount types.Currency, unconfirmed bool) ([]types.Hash256, error) + SignTransaction(txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) + ReleaseInputs([]types.Transaction, []types.V2Transaction) } // A ChainManager is used to get the current consensus state ChainManager interface { TipState() consensus.State + Tip() types.ChainIndex } ) @@ -459,7 +461,7 @@ func (s *Session) RenewContract(revision *rhp2.ContractRevision, hostAddr types. FileContracts: []types.FileContract{renewal}, } renterCost := rhp2.ContractRenewalCost(state, renewal, pt.ContractPrice, txnFee, baseCost) - toSign, release, err := s.w.FundTransaction(&renewTxn, renterCost) + toSign, err := s.w.FundTransaction(&renewTxn, renterCost, true) if err != nil { return rhp2.ContractRevision{}, nil, fmt.Errorf("failed to fund transaction: %w", err) } @@ -471,16 +473,16 @@ func (s *Session) RenewContract(revision *rhp2.ContractRevision, hostAddr types. FinalRevisionSignature: renterKey.SignHash(clearingSigHash), } if err := stream.WriteResponse(renewReq); err != nil { - release() + s.w.ReleaseInputs([]types.Transaction{renewTxn}, nil) return rhp2.ContractRevision{}, nil, fmt.Errorf("failed to write renew request: %w", err) } var hostAdditions rhp3.RPCRenewContractHostAdditions if err := stream.ReadResponse(&hostAdditions, 4096); err != nil { - release() + s.w.ReleaseInputs([]types.Transaction{renewTxn}, nil) return rhp2.ContractRevision{}, nil, fmt.Errorf("failed to read host additions response: %w", err) } else if !s.hostKey.VerifyHash(clearingSigHash, hostAdditions.FinalRevisionSignature) { - release() + s.w.ReleaseInputs([]types.Transaction{renewTxn}, nil) return rhp2.ContractRevision{}, nil, fmt.Errorf("host final revision signature invalid") } // add the host's additions to the transaction set @@ -489,11 +491,7 @@ func (s *Session) RenewContract(revision *rhp2.ContractRevision, hostAddr types. renewTxn.SiacoinOutputs = append(renewTxn.SiacoinOutputs, hostAdditions.SiacoinOutputs...) // sign the transaction - if err := s.w.SignTransaction(state, &renewTxn, toSign, types.CoveredFields{WholeTransaction: true}); err != nil { - release() - return rhp2.ContractRevision{}, nil, fmt.Errorf("failed to sign transaction: %w", err) - } - + s.w.SignTransaction(&renewTxn, toSign, types.CoveredFields{WholeTransaction: true}) renewRevision := initialRevision(&renewTxn, s.hostKey.UnlockKey(), renterKey.PublicKey().UnlockKey()) renewSigHash := hashRevision(renewRevision) renterSig := renterKey.SignHash(renewSigHash) @@ -509,16 +507,16 @@ func (s *Session) RenewContract(revision *rhp2.ContractRevision, hostAddr types. }, } if err := stream.WriteResponse(renterSigsResp); err != nil { - release() + s.w.ReleaseInputs([]types.Transaction{renewTxn}, nil) return rhp2.ContractRevision{}, nil, fmt.Errorf("failed to write renter signatures: %w", err) } var hostSigsResp rhp3.RPCRenewSignatures if err := stream.ReadResponse(&hostSigsResp, 4096); err != nil { - release() + s.w.ReleaseInputs([]types.Transaction{renewTxn}, nil) return rhp2.ContractRevision{}, nil, fmt.Errorf("failed to read host signatures: %w", err) } else if err := validateHostRevisionSignature(hostSigsResp.RevisionSignature, renewRevision.ParentID, renewSigHash, s.hostKey); err != nil { - release() + s.w.ReleaseInputs([]types.Transaction{renewTxn}, nil) return rhp2.ContractRevision{}, nil, fmt.Errorf("invalid host revision signature: %w", err) } return rhp2.ContractRevision{ diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go new file mode 100644 index 00000000..ee7439aa --- /dev/null +++ b/internal/testutil/testutil.go @@ -0,0 +1,239 @@ +package testutil + +import ( + "context" + "net" + "path/filepath" + "testing" + "time" + + "go.sia.tech/core/consensus" + "go.sia.tech/core/gateway" + "go.sia.tech/core/types" + "go.sia.tech/coreutils" + "go.sia.tech/coreutils/chain" + "go.sia.tech/coreutils/syncer" + "go.sia.tech/coreutils/wallet" + "go.sia.tech/hostd/host/accounts" + "go.sia.tech/hostd/host/contracts" + "go.sia.tech/hostd/host/registry" + "go.sia.tech/hostd/host/settings" + "go.sia.tech/hostd/host/storage" + "go.sia.tech/hostd/index" + "go.sia.tech/hostd/persist/sqlite" + "go.uber.org/zap" +) + +type ( + // A ConsensusNode is a node with the core consensus components + ConsensusNode struct { + Store *sqlite.Store + Chain *chain.Manager + Syncer *syncer.Syncer + } + + // A HostNode is a node with the core wallet components and the host + // components + HostNode struct { + ConsensusNode + + Settings *settings.ConfigManager + Wallet *wallet.SingleAddressWallet + Contracts *contracts.Manager + Volumes *storage.VolumeManager + Indexer *index.Manager + + Accounts *accounts.AccountManager + Registry *registry.Manager + } +) + +// V1Network is a test helper that returns a consensus.Network and genesis block +// suited for testing the v1 network +func V1Network() (*consensus.Network, types.Block) { + // use a modified version of Zen + n, genesisBlock := chain.TestnetZen() + n.InitialTarget = types.BlockID{0xFF} + n.HardforkDevAddr.Height = 1 + n.HardforkTax.Height = 1 + n.HardforkStorageProof.Height = 1 + n.HardforkOak.Height = 1 + n.HardforkASIC.Height = 1 + n.HardforkFoundation.Height = 1 + n.HardforkV2.AllowHeight = 500 // comfortably above MaturityHeight + n.HardforkV2.RequireHeight = 600 + return n, genesisBlock +} + +// V2Network is a test helper that returns a consensus.Network and genesis block +// suited for testing after the v2 hardfork +func V2Network() (*consensus.Network, types.Block) { + // use a modified version of Zen + n, genesisBlock := chain.TestnetZen() + n.InitialTarget = types.BlockID{0xFF} + n.HardforkDevAddr.Height = 1 + n.HardforkTax.Height = 1 + n.HardforkStorageProof.Height = 1 + n.HardforkOak.Height = 1 + n.HardforkASIC.Height = 1 + n.HardforkFoundation.Height = 1 + n.HardforkV2.AllowHeight = 145 // just above the maturity height + n.HardforkV2.RequireHeight = 180 + return n, genesisBlock +} + +// WaitForSync is a helper to wait for the chain and indexer to sync +func WaitForSync(t *testing.T, cm *chain.Manager, idx *index.Manager) { + t.Helper() + + for { + if cm.Tip() == idx.Tip() { + break + } + time.Sleep(time.Millisecond) + } +} + +// MineBlocks is a helper to mine blocks and broadcast the headers +func MineBlocks(t *testing.T, cn *ConsensusNode, addr types.Address, n int) { + t.Helper() + + for i := 0; i < n; i++ { + b, ok := coreutils.MineBlock(cn.Chain, addr, 5*time.Second) + if !ok { + t.Fatal("failed to mine block") + } else if err := cn.Chain.AddBlocks([]types.Block{b}); err != nil { + t.Fatal(err) + } + + if b.V2 == nil { + cn.Syncer.BroadcastHeader(gateway.BlockHeader{ + ParentID: b.ParentID, + Nonce: b.Nonce, + Timestamp: b.Timestamp, + MerkleRoot: b.MerkleRoot(), + }) + } else { + cn.Syncer.BroadcastV2BlockOutline(gateway.OutlineBlock(b, cn.Chain.PoolTransactions(), cn.Chain.V2PoolTransactions())) + } + } +} + +// MineAndSync is a helper to mine blocks and wait for the index to catch up +// between each block +func MineAndSync(t *testing.T, hn *HostNode, addr types.Address, n int) { + t.Helper() + + for i := 0; i < n; i++ { + MineBlocks(t, &hn.ConsensusNode, addr, 1) + WaitForSync(t, hn.Chain, hn.Indexer) + } +} + +// NewConsensusNode initializes all of the consensus components and returns them. +// The function will clean up all resources when the test is done. +func NewConsensusNode(t *testing.T, network *consensus.Network, genesis types.Block, log *zap.Logger) *ConsensusNode { + t.Helper() + + dir := t.TempDir() + db, err := sqlite.OpenDatabase(filepath.Join(dir, "hostd.sqlite3"), log.Named("sqlite")) + if err != nil { + t.Fatal("failed to open sqlite store:", err) + } + t.Cleanup(func() { db.Close() }) + + chainDB, err := coreutils.OpenBoltChainDB(filepath.Join(dir, "consensus.db")) + if err != nil { + t.Fatal("failed to open chain db:", err) + } + t.Cleanup(func() { chainDB.Close() }) + + cs, tipState, err := chain.NewDBStore(chainDB, network, genesis) + if err != nil { + t.Fatal("failed to create chain store:", err) + } + cm := chain.NewManager(cs, tipState) + + syncerListener, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal("failed to listen:", err) + } + t.Cleanup(func() { syncerListener.Close() }) + + ps, err := sqlite.NewPeerStore(db) + if err != nil { + t.Fatal("failed to create peer store:", err) + } + + syncer := syncer.New(syncerListener, cm, ps, gateway.Header{ + GenesisID: genesis.ID(), + UniqueID: gateway.GenerateUniqueID(), + NetAddress: syncerListener.Addr().String(), + }) + go syncer.Run(context.Background()) + t.Cleanup(func() { syncer.Close() }) + + return &ConsensusNode{ + Store: db, + Chain: cm, + Syncer: syncer, + } +} + +// NewHostNode initializes all of the hostd components and returns them. The function +// will clean up all resources when the test is done. +func NewHostNode(t *testing.T, pk types.PrivateKey, network *consensus.Network, genesis types.Block, log *zap.Logger) *HostNode { + t.Helper() + + cn := NewConsensusNode(t, network, genesis, log) + + wm, err := wallet.NewSingleAddressWallet(pk, cn.Chain, cn.Store) + if err != nil { + t.Fatal("failed to create wallet:", err) + } + t.Cleanup(func() { wm.Close() }) + + vm, err := storage.NewVolumeManager(cn.Store, storage.WithLogger(log.Named("storage"))) + if err != nil { + t.Fatal("failed to create volume manager:", err) + } + t.Cleanup(func() { vm.Close() }) + + contracts, err := contracts.NewManager(cn.Store, vm, cn.Chain, cn.Syncer, wm, contracts.WithRejectAfter(10), contracts.WithRevisionSubmissionBuffer(5), contracts.WithLog(log)) + if err != nil { + t.Fatal("failed to create contracts manager:", err) + } + t.Cleanup(func() { contracts.Close() }) + + initialSettings := settings.DefaultSettings + initialSettings.AcceptingContracts = true + initialSettings.NetAddress = "127.0.0.1:9981" + initialSettings.WindowSize = 10 + sm, err := settings.NewConfigManager(pk, cn.Store, cn.Chain, cn.Syncer, wm, settings.WithAnnounceInterval(10), settings.WithValidateNetAddress(false), settings.WithInitialSettings(initialSettings)) + if err != nil { + t.Fatal(err) + } + + idx, err := index.NewManager(cn.Store, cn.Chain, contracts, wm, sm, vm, index.WithLog(log.Named("index")), index.WithBatchSize(0)) // off-by-one + if err != nil { + t.Fatal("failed to create index manager:", err) + } + t.Cleanup(func() { idx.Close() }) + + am := accounts.NewManager(cn.Store, sm) + rm := registry.NewManager(pk, cn.Store, log.Named("registry")) + t.Cleanup(func() { rm.Close() }) + + return &HostNode{ + ConsensusNode: *cn, + + Settings: sm, + Wallet: wm, + Contracts: contracts, + Volumes: vm, + Indexer: idx, + + Accounts: am, + Registry: rm, + } +} diff --git a/persist/sqlite/accounts.go b/persist/sqlite/accounts.go index fddc891a..0a3e354c 100644 --- a/persist/sqlite/accounts.go +++ b/persist/sqlite/accounts.go @@ -15,21 +15,25 @@ import ( // AccountBalance returns the balance of the account with the given ID. func (s *Store) AccountBalance(accountID rhp3.Account) (balance types.Currency, err error) { - _, balance, err = accountBalance(&dbTxn{s}, accountID) - if errors.Is(err, sql.ErrNoRows) { - return types.ZeroCurrency, nil - } + err = s.transaction(func(tx *txn) error { + _, balance, err = accountBalance(tx, accountID) + if errors.Is(err, sql.ErrNoRows) { + err = nil + return nil + } + return err + }) return } -func incrementContractAccountFunding(tx txn, accountID, contractID int64, amount types.Currency) error { +func incrementContractAccountFunding(tx *txn, accountID, contractID int64, amount types.Currency) error { var fundingValue types.Currency - err := tx.QueryRow(`SELECT amount FROM contract_account_funding WHERE contract_id=$1 AND account_id=$2`, contractID, accountID).Scan((*sqlCurrency)(&fundingValue)) + err := tx.QueryRow(`SELECT amount FROM contract_account_funding WHERE contract_id=$1 AND account_id=$2`, contractID, accountID).Scan(decode(&fundingValue)) if err != nil && !errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("failed to get fund amount: %w", err) } fundingValue = fundingValue.Add(amount) - _, err = tx.Exec(`INSERT INTO contract_account_funding (contract_id, account_id, amount) VALUES ($1, $2, $3) ON CONFLICT (contract_id, account_id) DO UPDATE SET amount=EXCLUDED.amount`, contractID, accountID, sqlCurrency(fundingValue)) + _, err = tx.Exec(`INSERT INTO contract_account_funding (contract_id, account_id, amount) VALUES ($1, $2, $3) ON CONFLICT (contract_id, account_id) DO UPDATE SET amount=EXCLUDED.amount`, contractID, accountID, encode(fundingValue)) if err != nil { return fmt.Errorf("failed to update funding source: %w", err) } @@ -38,7 +42,7 @@ func incrementContractAccountFunding(tx txn, accountID, contractID int64, amount // CreditAccountWithContract adds the specified amount to the account with the given ID. func (s *Store) CreditAccountWithContract(fund accounts.FundAccountWithContract) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { // get current balance accountID, balance, err := accountBalance(tx, fund.Account) exists := err == nil @@ -48,7 +52,7 @@ func (s *Store) CreditAccountWithContract(fund accounts.FundAccountWithContract) // update balance balance = balance.Add(fund.Amount) const query = `INSERT INTO accounts (account_id, balance, expiration_timestamp) VALUES ($1, $2, $3) ON CONFLICT (account_id) DO UPDATE SET balance=EXCLUDED.balance, expiration_timestamp=EXCLUDED.expiration_timestamp RETURNING id` - err = tx.QueryRow(query, sqlHash256(fund.Account), sqlCurrency(balance), sqlTime(fund.Expiration)).Scan(&accountID) + err = tx.QueryRow(query, encode(fund.Account), encode(balance), encode(fund.Expiration)).Scan(&accountID) if err != nil { return fmt.Errorf("failed to update balance: %w", err) } @@ -94,7 +98,7 @@ func (s *Store) CreditAccountWithContract(fund accounts.FundAccountWithContract) // ID. Returns the remaining balance of the account. func (s *Store) DebitAccount(accountID rhp3.Account, usage accounts.Usage) error { amount := usage.Total() - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { dbID, balance, err := accountBalance(tx, accountID) if err != nil { return fmt.Errorf("failed to query balance: %w", err) @@ -105,7 +109,7 @@ func (s *Store) DebitAccount(accountID rhp3.Account, usage accounts.Usage) error // update balance balance = balance.Sub(amount) const query = `UPDATE accounts SET balance=$1 WHERE id=$8 RETURNING id` - err = tx.QueryRow(query, sqlCurrency(balance), dbID).Scan(&dbID) + err = tx.QueryRow(query, encode(balance), dbID).Scan(&dbID) if err != nil { return fmt.Errorf("failed to update balance: %w", err) } else if err := updateContractUsage(tx, dbID, usage, s.log); err != nil { @@ -123,19 +127,22 @@ func (s *Store) DebitAccount(accountID rhp3.Account, usage accounts.Usage) error // Accounts returns all accounts in the database paginated. func (s *Store) Accounts(limit, offset int) (acc []accounts.Account, err error) { - rows, err := s.query(`SELECT account_id, balance, expiration_timestamp FROM accounts LIMIT $1 OFFSET $2`, limit, offset) - if err != nil { - return nil, err - } - defer rows.Close() + err = s.transaction(func(tx *txn) error { + rows, err := tx.Query(`SELECT account_id, balance, expiration_timestamp FROM accounts LIMIT $1 OFFSET $2`, limit, offset) + if err != nil { + return err + } + defer rows.Close() - for rows.Next() { - var a accounts.Account - if err := rows.Scan((*sqlHash256)(&a.ID), (*sqlCurrency)(&a.Balance), (*sqlTime)(&a.Expiration)); err != nil { - return nil, fmt.Errorf("failed to scan row: %w", err) + for rows.Next() { + var a accounts.Account + if err := rows.Scan(decode(&a.ID), decode(&a.Balance), decode(&a.Expiration)); err != nil { + return fmt.Errorf("failed to scan row: %w", err) + } + acc = append(acc, a) } - acc = append(acc, a) - } + return rows.Err() + }) return } @@ -147,30 +154,35 @@ INNER JOIN accounts a ON a.id=caf.account_id INNER JOIN contracts c ON c.id=caf.contract_id WHERE a.account_id=$1` - rows, err := s.query(query, sqlHash256(account)) - if err != nil { - return nil, err - } - defer rows.Close() + err = s.transaction(func(tx *txn) error { + rows, err := tx.Query(query, encode(account)) + if err != nil { + return err + } + defer rows.Close() - for rows.Next() { - var src accounts.FundingSource - if err := rows.Scan((*sqlHash256)(&src.AccountID), (*sqlHash256)(&src.ContractID), (*sqlCurrency)(&src.Amount)); err != nil { - return nil, fmt.Errorf("failed to scan row: %w", err) + for rows.Next() { + var src accounts.FundingSource + if err := rows.Scan(decode((*types.PublicKey)(&src.AccountID)), decode(&src.ContractID), decode(&src.Amount)); err != nil { + return fmt.Errorf("failed to scan row: %w", err) + } + srcs = append(srcs, src) } - srcs = append(srcs, src) - } + return rows.Err() + }) return } // PruneAccounts removes all accounts that have expired func (s *Store) PruneAccounts(height uint64) error { - _, err := s.exec(`DELETE FROM accounts WHERE expiration_height<$1`, height) - return err + return s.transaction(func(tx *txn) error { + _, err := tx.Exec(`DELETE FROM accounts WHERE expiration_height<$1`, height) + return err + }) } -func accountBalance(tx txn, accountID rhp3.Account) (dbID int64, balance types.Currency, err error) { - err = tx.QueryRow(`SELECT id, balance FROM accounts WHERE account_id=$1`, sqlHash256(accountID)).Scan(&dbID, (*sqlCurrency)(&balance)) +func accountBalance(tx *txn, accountID rhp3.Account) (dbID int64, balance types.Currency, err error) { + err = tx.QueryRow(`SELECT id, balance FROM accounts WHERE account_id=$1`, encode(accountID)).Scan(&dbID, decode(&balance)) return } @@ -181,7 +193,7 @@ type fundAmount struct { } // contractFunding returns all contracts that were used to fund the account. -func contractFunding(tx txn, accountID int64) (fund []fundAmount, err error) { +func contractFunding(tx *txn, accountID int64) (fund []fundAmount, err error) { rows, err := tx.Query(`SELECT id, contract_id, amount FROM contract_account_funding WHERE account_id=$1`, accountID) if err != nil { return nil, err @@ -190,7 +202,7 @@ func contractFunding(tx txn, accountID int64) (fund []fundAmount, err error) { for rows.Next() { var f fundAmount - if err := rows.Scan(&f.ID, &f.ContractID, (*sqlCurrency)(&f.Amount)); err != nil { + if err := rows.Scan(&f.ID, &f.ContractID, decode(&f.Amount)); err != nil { return nil, fmt.Errorf("failed to scan row: %w", err) } else if f.Amount.IsZero() { continue @@ -202,7 +214,7 @@ func contractFunding(tx txn, accountID int64) (fund []fundAmount, err error) { // updateContractUsage distributes account usage to the contracts that funded // the account. -func updateContractUsage(tx txn, accountID int64, usage accounts.Usage, log *zap.Logger) error { +func updateContractUsage(tx *txn, accountID int64, usage accounts.Usage, log *zap.Logger) error { funding, err := contractFunding(tx, accountID) if err != nil { return fmt.Errorf("failed to get contract funding: %w", err) @@ -274,16 +286,16 @@ func updateContractUsage(tx txn, accountID int64, usage accounts.Usage, log *zap return nil } -func setContractRemainingFunds(tx txn, contractID int64, amount types.Currency) error { - return tx.QueryRow(`UPDATE contracts SET account_funding=$1 WHERE id=$2 RETURNING id`, sqlCurrency(amount), contractID).Scan(&contractID) +func setContractRemainingFunds(tx *txn, contractID int64, amount types.Currency) error { + return tx.QueryRow(`UPDATE contracts SET account_funding=$1 WHERE id=$2 RETURNING id`, encode(amount), contractID).Scan(&contractID) } -func setContractAccountFunding(tx txn, fundingID int64, amount types.Currency) error { +func setContractAccountFunding(tx *txn, fundingID int64, amount types.Currency) error { if amount.IsZero() { _, err := tx.Exec(`DELETE FROM contract_account_funding WHERE id=$1`, fundingID) return err } - _, err := tx.Exec(`UPDATE contract_account_funding SET amount=$1 WHERE id=$2`, sqlCurrency(amount), fundingID) + _, err := tx.Exec(`UPDATE contract_account_funding SET amount=$1 WHERE id=$2`, encode(amount), fundingID) return err } diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go new file mode 100644 index 00000000..01b69373 --- /dev/null +++ b/persist/sqlite/consensus.go @@ -0,0 +1,1790 @@ +package sqlite + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + "time" + + "go.sia.tech/core/types" + "go.sia.tech/coreutils/wallet" + "go.sia.tech/hostd/host/contracts" + "go.sia.tech/hostd/host/settings" + "go.sia.tech/hostd/index" + "go.uber.org/zap" +) + +type ( + updateTx struct { + tx *txn + + relevant map[types.Hash256]bool // map to prevent rare duplicate selects + } + + contractState struct { + ID int64 + LockedCollateral types.Currency + Usage contracts.Usage + Status contracts.ContractStatus + } + + v2ContractState struct { + ID int64 + LockedCollateral types.Currency + Usage contracts.V2Usage + Status contracts.V2ContractStatus + } +) + +var _ index.UpdateTx = (*updateTx)(nil) + +// ResetChainState resets the consensus state of the store. This +// should only occur if the user has reset their consensus database to +// sync from scratch. +func (s *Store) ResetChainState() error { + return s.transaction(func(tx *txn) error { + _, err := tx.Exec(` +-- v2 contracts +DELETE FROM contracts_v2_chain_index_elements; +DELETE FROM contract_v2_state_elements; +-- wallet +DELETE FROM wallet_siacoin_elements; +DELETE FROM wallet_events; +DELETE FROM host_stats WHERE stat IN (?,?); -- reset wallet stats since they are derived from the chain +-- settings +UPDATE global_settings SET last_scanned_index=NULL, last_announce_index=NULL, last_announce_address=NULL`, metricWalletBalance, metricWalletImmatureBalance) + return err + }) +} + +// WalletStateElements returns all state elements related to the wallet. It is used +// to update the proofs of all state elements affected by the update. +func (ux *updateTx) WalletStateElements() (elements []types.StateElement, err error) { + const query = `SELECT id, merkle_proof, leaf_index FROM wallet_siacoin_elements` + rows, err := ux.tx.Query(query) + if err != nil { + return nil, fmt.Errorf("failed to query wallet state elements: %w", err) + } + defer rows.Close() + + for rows.Next() { + var se types.StateElement + if err := rows.Scan(decode(&se.ID), decode(&se.MerkleProof), decode(&se.LeafIndex)); err != nil { + return nil, fmt.Errorf("failed to scan element: %w", err) + } + elements = append(elements, se) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("scan error: %w", err) + } + return elements, nil +} + +// UpdateWalletStateElements updates the proofs of all state elements affected by the +// update. +func (ux *updateTx) UpdateWalletStateElements(elements []types.StateElement) error { + if len(elements) == 0 { + return nil + } + stmt, err := ux.tx.Prepare(`UPDATE wallet_siacoin_elements SET merkle_proof=?, leaf_index=? WHERE id=?`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer stmt.Close() + + for _, se := range elements { + if _, err := stmt.Exec(encode(se.MerkleProof), encode(se.LeafIndex), encode(se.ID)); err != nil { + return fmt.Errorf("failed to update wallet state element: %w", err) + } + } + return nil +} + +// WalletApplyIndex is called with the chain index that is being applied. +// Any transactions and siacoin elements that were created by the index +// should be added and any siacoin elements that were spent should be +// removed. +func (ux *updateTx) WalletApplyIndex(index types.ChainIndex, created, spent []types.SiacoinElement, events []wallet.Event, timestamp time.Time) error { + matureOutflow, immatureOutflow, err := deleteSiacoinElements(ux.tx, index, spent) + if err != nil { + return fmt.Errorf("failed to delete siacoin elements: %w", err) + } + matureInflow, immatureInflow, err := createSiacoinElements(ux.tx, index, created) + if err != nil { + return fmt.Errorf("failed to create siacoin elements: %w", err) + } else if err := createWalletEvents(ux.tx, events); err != nil { + return fmt.Errorf("failed to create wallet events: %w", err) + } + + // get the matured balance + matured, err := maturedSiacoinBalance(ux.tx, index) + if err != nil { + return fmt.Errorf("failed to query matured siacoin balance: %w", err) + } + // apply the maturation by adding the matured balance to the matured inflow + // and the immature outflow + matureInflow = matureInflow.Add(matured) + immatureOutflow = immatureOutflow.Add(matured) + // update the balance metrics + if err := updateBalanceMetric(ux.tx, matureInflow, matureOutflow, immatureInflow, immatureOutflow, timestamp); err != nil { + return fmt.Errorf("failed to update wallet balance: %w", err) + } + return nil +} + +// WalletRevertIndex is called with the chain index that is being reverted. +// Any transactions that were added by the index should be removed +// +// removed contains the siacoin elements that were created by the index +// and should be deleted. +// +// unspent contains the siacoin elements that were spent and should be +// recreated. They are not necessarily created by the index and should +// not be associated with it. +func (ux *updateTx) WalletRevertIndex(index types.ChainIndex, removed, unspent []types.SiacoinElement, timestamp time.Time) error { + matureOutflow, immatureOutflow, err := deleteSiacoinElements(ux.tx, index, removed) + if err != nil { + return fmt.Errorf("failed to delete siacoin elements: %w", err) + } + matureInflow, immatureInflow, err := createSiacoinElements(ux.tx, index, unspent) + if err != nil { + return fmt.Errorf("failed to create siacoin elements: %w", err) + } else if _, err := ux.tx.Exec(`DELETE FROM wallet_events WHERE chain_index=?`, encode(index)); err != nil { + return fmt.Errorf("failed to delete wallet events: %w", err) + } + + // get the matured balance + matured, err := maturedSiacoinBalance(ux.tx, index) + if err != nil { + return fmt.Errorf("failed to query matured siacoin balance: %w", err) + } + // revert the maturation by adding the matured balance to the matured outflow + // and the immature inflow + matureOutflow = matureOutflow.Add(matured) + immatureInflow = immatureInflow.Add(matured) + // update the balance metrics + if err := updateBalanceMetric(ux.tx, matureInflow, matureOutflow, immatureInflow, immatureOutflow, timestamp); err != nil { + return fmt.Errorf("failed to increment wallet balance: %w", err) + } + return nil +} + +// RevertContractChainIndexElement removes a reverted chain index from the store +func (ux *updateTx) RevertContractChainIndexElement(index types.ChainIndex) error { + _, err := ux.tx.Exec(`DELETE FROM contracts_v2_chain_index_elements WHERE height=? AND id=?`, index.Height, encode(index.ID)) + return err +} + +// ContractChainIndexElements returns chain index state elements that +// need to be updated. The elements must be ordered by height. +func (ux *updateTx) ContractChainIndexElements() (elements []types.ChainIndexElement, err error) { + rows, err := ux.tx.Query(`SELECT id, height, merkle_proof, leaf_index FROM contracts_v2_chain_index_elements ORDER BY height ASC`) + if err != nil { + return nil, fmt.Errorf("failed to query contract chain index state elements: %w", err) + } + defer rows.Close() + + for rows.Next() { + var ele types.ChainIndexElement + if err := rows.Scan(decode(&ele.ChainIndex.ID), &ele.ChainIndex.Height, decode(&ele.MerkleProof), decode(&ele.LeafIndex)); err != nil { + return nil, fmt.Errorf("failed to scan state element: %w", err) + } + ele.ID = types.Hash256(ele.ChainIndex.ID) + elements = append(elements, ele) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("failed to scan contract chain index state elements: %w", err) + } + return elements, nil +} + +// ApplyContractChainIndexElements adds or updates the merkle proof of +// chain index state elements +func (ux *updateTx) ApplyContractChainIndexElements(elements []types.ChainIndexElement) error { + if len(elements) == 0 { + return nil + } + + stmt, err := ux.tx.Prepare(`INSERT INTO contracts_v2_chain_index_elements (id, height, merkle_proof, leaf_index) VALUES (?, ?, ?, ?) ON CONFLICT (id) DO UPDATE SET merkle_proof=EXCLUDED.merkle_proof, leaf_index=EXCLUDED.leaf_index, height=EXCLUDED.height`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer stmt.Close() + + for _, se := range elements { + if _, err := stmt.Exec(encode(se.ChainIndex.ID), se.ChainIndex.Height, encode(se.MerkleProof), encode(se.LeafIndex)); err != nil { + return fmt.Errorf("failed to update contract chain index state element: %w", err) + } + } + return nil +} + +// DeleteExpiredContractChainIndexElements deletes chain index state +// elements that are no long necessary +func (ux *updateTx) DeleteExpiredContractChainIndexElements(height uint64) error { + _, err := ux.tx.Exec(`DELETE FROM contracts_v2_chain_index_elements WHERE height <= ?`, height) + return err +} + +// ContractStateElements returns all state elements from the contract +// store +func (ux *updateTx) ContractStateElements() (elements []types.StateElement, err error) { + rows, err := ux.tx.Query(`SELECT c.contract_id, cs.merkle_proof, cs.leaf_index FROM contract_v2_state_elements cs +INNER JOIN contracts_v2 c ON (c.id=cs.contract_id)`) + if err != nil { + return nil, fmt.Errorf("failed to query contract state elements: %w", err) + } + defer rows.Close() + + for rows.Next() { + var se types.StateElement + if err := rows.Scan(decode(&se.ID), decode(&se.MerkleProof), decode(&se.LeafIndex)); err != nil { + return nil, fmt.Errorf("failed to scan contract state element: %w", err) + } + elements = append(elements, se) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("failed to scan contract state elements: %w", err) + } + return elements, nil +} + +// UpdateContractStateElements updates the state elements in the host +// contract store +func (ux *updateTx) UpdateContractStateElements(elements []types.StateElement) error { + if len(elements) == 0 { + return nil + } + + stmt, err := ux.tx.Prepare(`UPDATE contract_v2_state_elements SET merkle_proof=?, leaf_index=? WHERE contract_id=(SELECT id FROM contracts_v2 WHERE contract_id=?)`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer stmt.Close() + + for _, se := range elements { + if _, err := stmt.Exec(encode(se.MerkleProof), encode(se.LeafIndex), encode(se.ID)); err != nil { + return fmt.Errorf("failed to update contract state element %q: %w", se.ID, err) + } + } + return nil +} + +// ApplyContracts applies relevant contract changes to the contract +// store +func (ux *updateTx) ApplyContracts(index types.ChainIndex, state contracts.StateChanges) error { + if err := applyContractFormation(ux.tx, state.Confirmed, ux.tx.log.Named("applyContractFormation")); err != nil { + return fmt.Errorf("failed to apply contract formation: %w", err) + } else if err := applyContractRevision(ux.tx, state.Revised); err != nil { + return fmt.Errorf("failed to apply contract revisions: %w", err) + } else if err := applySuccessfulContracts(ux.tx, index, state.Successful); err != nil { + return fmt.Errorf("failed to apply contract resolution: %w", err) + } else if err := applyFailedContracts(ux.tx, state.Failed); err != nil { + return fmt.Errorf("failed to apply contract failures: %w", err) + } + + // v2 + if err := applyV2ContractFormation(ux.tx, index, state.ConfirmedV2, ux.tx.log.Named("applyV2ContractFormation")); err != nil { + return fmt.Errorf("failed to apply v2 contract formation: %w", err) + } else if err := applyV2ContractRevision(ux.tx, state.RevisedV2); err != nil { + return fmt.Errorf("failed to apply v2 contract revisions: %w", err) + } else if err := applySuccessfulV2Contracts(ux.tx, index, contracts.V2ContractStatusSuccessful, state.SuccessfulV2); err != nil { + return fmt.Errorf("failed to apply successful v2 resolution: %w", err) + } else if err := applySuccessfulV2Contracts(ux.tx, index, contracts.V2ContractStatusFinalized, state.FinalizedV2); err != nil { + return fmt.Errorf("failed to apply v2 finalized v2 resolution: %w", err) + } else if err := applySuccessfulV2Contracts(ux.tx, index, contracts.V2ContractStatusRenewed, state.RenewedV2); err != nil { + return fmt.Errorf("failed to apply v2 renewed v2 resolution: %w", err) + } else if err := applyFailedV2Contracts(ux.tx, index, state.FailedV2); err != nil { + return fmt.Errorf("failed to apply v2 failure resolution: %w", err) + } + return nil +} + +// RevertContracts reverts relevant contract changes from the contract +// store +func (ux *updateTx) RevertContracts(index types.ChainIndex, state contracts.StateChanges) error { + if err := revertContractFormation(ux.tx, state.Confirmed); err != nil { + return fmt.Errorf("failed to revert contract formation: %w", err) + } else if err := applyContractRevision(ux.tx, state.Revised); err != nil { // note: this is correct. The previous revision is being applied + return fmt.Errorf("failed to revert contract revisions: %w", err) + } else if err := revertSuccessfulContracts(ux.tx, state.Successful); err != nil { + return fmt.Errorf("failed to revert contract resolution: %w", err) + } else if err := revertFailedContracts(ux.tx, state.Failed); err != nil { + return fmt.Errorf("failed to revert contract failures: %w", err) + } + + // v2 + if err := revertV2ContractFormation(ux.tx, state.ConfirmedV2); err != nil { + return fmt.Errorf("failed to revert v2 contract formation: %w", err) + } else if err := applyV2ContractRevision(ux.tx, state.RevisedV2); err != nil { // note: this is correct. The previous revision is being applied + return fmt.Errorf("failed to revert v2 contract revisions: %w", err) + } else if err := revertSuccessfulV2Contracts(ux.tx, contracts.V2ContractStatusSuccessful, state.SuccessfulV2); err != nil { + return fmt.Errorf("failed to revert v2 successful resolution: %w", err) + } else if err := revertSuccessfulV2Contracts(ux.tx, contracts.V2ContractStatusFinalized, state.FinalizedV2); err != nil { + return fmt.Errorf("failed to revert v2 finalized resolution: %w", err) + } else if err := revertSuccessfulV2Contracts(ux.tx, contracts.V2ContractStatusRenewed, state.RenewedV2); err != nil { + return fmt.Errorf("failed to revert v2 renewed resolution: %w", err) + } else if err := revertFailedV2Contracts(ux.tx, state.FailedV2); err != nil { + return fmt.Errorf("failed to revert v2 failure resolution: %w", err) + } + return nil +} + +// RejectContracts returns any contracts with a negotiation height +// before the provided height that have not been confirmed. +func (ux *updateTx) RejectContracts(height uint64) ([]types.FileContractID, []types.FileContractID, error) { + rejected, err := rejectContracts(ux.tx, height) + if err != nil { + return nil, nil, fmt.Errorf("failed to get rejected contracts: %w", err) + } + + rejectedV2, err := rejectV2Contracts(ux.tx, height) + if err != nil { + return nil, nil, fmt.Errorf("failed to get rejected v2 contracts: %w", err) + } + + if len(rejected) == 0 && len(rejectedV2) == 0 { + return nil, nil, nil + } + + contractState, stateDone, err := getContractStateStmt(ux.tx) + if err != nil { + return nil, nil, fmt.Errorf("failed to prepare select statement: %w", err) + } + defer stateDone() + + contractStateV2, stateDoneV2, err := getV2ContractStateStmt(ux.tx) + if err != nil { + return nil, nil, fmt.Errorf("failed to prepare select statement: %w", err) + } + defer stateDoneV2() + + incrementNumericStat, numericStatDone, err := incrementNumericStatStmt(ux.tx) + if err != nil { + return nil, nil, fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer numericStatDone() + + updateV1Status, err := ux.tx.Prepare(`UPDATE contracts SET contract_status=? WHERE id=?`) + if err != nil { + return nil, nil, fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateV1Status.Close() + + updateV2Status, err := ux.tx.Prepare(`UPDATE contracts_v2 SET contract_status=? WHERE id=?`) + if err != nil { + return nil, nil, fmt.Errorf("failed to prepare v2 update statement: %w", err) + } + defer updateV2Status.Close() + + for _, id := range rejected { + state, err := contractState(id) + if err != nil { + return nil, nil, fmt.Errorf("failed to get contract state %q: %w", id, err) + } + + if state.Status != contracts.ContractStatusPending { + // orderly applies and reverts should prevent this from happening + panic(fmt.Sprintf("unexpected contract status %v", state.Status)) + } + + // update metrics + if _, err := updateV1Status.Exec(contracts.ContractStatusRejected, state.ID); err != nil { + return nil, nil, fmt.Errorf("failed to update contract status: %w", err) + } else if err := updateStatusMetrics(state.Status, contracts.ContractStatusRejected, incrementNumericStat); err != nil { + return nil, nil, fmt.Errorf("failed to update contract metrics: %w", err) + } + } + for _, id := range rejectedV2 { + state, err := contractStateV2(id) + if err != nil { + return nil, nil, fmt.Errorf("failed to get contract state %q: %w", id, err) + } + + if state.Status != contracts.V2ContractStatusPending { + // orderly applies and reverts should prevent this from happening + panic(fmt.Sprintf("unexpected contract status %v", state.Status)) + } + + // update metrics + if _, err := updateV2Status.Exec(contracts.V2ContractStatusRejected, state.ID); err != nil { + return nil, nil, fmt.Errorf("failed to update contract status: %w", err) + } else if err := updateV2StatusMetrics(state.Status, contracts.V2ContractStatusRejected, incrementNumericStat); err != nil { + return nil, nil, fmt.Errorf("failed to update contract metrics: %w", err) + } + } + return rejected, rejectedV2, nil +} + +// ContractRelevant returns true if a contract is relevant to the host. Otherwise, +// it returns false. +func (ux *updateTx) ContractRelevant(id types.FileContractID) (relevant bool, err error) { + if ux.relevant == nil { + ux.relevant = make(map[types.Hash256]bool) + } + + if relevant, ok := ux.relevant[types.Hash256(id)]; ok { + return relevant, nil + } + + err = ux.tx.QueryRow(`SELECT 1 FROM contracts WHERE contract_id=?`, encode(id)).Scan(&relevant) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } else if err != nil { + return false, err + } + ux.relevant[types.Hash256(id)] = relevant + return +} + +// V2ContractRelevant returns true if the v2 contract is relevant to the host. +// Otherwise, it returns false. +func (ux *updateTx) V2ContractRelevant(id types.FileContractID) (relevant bool, err error) { + if ux.relevant == nil { + ux.relevant = make(map[types.Hash256]bool) + } + + if relevant, ok := ux.relevant[types.Hash256(id)]; ok { + return relevant, nil + } + + err = ux.tx.QueryRow(`SELECT 1 FROM contracts_v2 WHERE contract_id=?`, encode(id)).Scan(&relevant) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } else if err != nil { + return false, err + } + ux.relevant[types.Hash256(id)] = relevant + return +} + +func (ux *updateTx) LastAnnouncement() (announcement settings.Announcement, err error) { + var addr sql.NullString + err = ux.tx.QueryRow(`SELECT last_announce_index, last_announce_address FROM global_settings`).Scan(decodeNullable(&announcement.Index), &addr) + if addr.Valid { + announcement.Address = addr.String + } + return +} + +func (ux *updateTx) RevertLastAnnouncement() error { + _, err := ux.tx.Exec(`UPDATE global_settings SET last_announce_address=NULL, last_announce_index=NULL`) + return err +} + +func (ux *updateTx) SetLastAnnouncement(announcement settings.Announcement) error { + _, err := ux.tx.Exec(`UPDATE global_settings SET last_announce_index=?, last_announce_address=?`, encode(announcement.Index), announcement.Address) + return err +} + +func (ux *updateTx) SetLastIndex(index types.ChainIndex) error { + _, err := ux.tx.Exec(`UPDATE global_settings SET last_scanned_index=?`, encode(index)) + return err +} + +// Tip returns the last scanned chain index. +func (s *Store) Tip() (index types.ChainIndex, err error) { + err = s.transaction(func(tx *txn) error { + err = tx.QueryRow(`SELECT last_scanned_index FROM global_settings`).Scan(decodeNullable(&index)) + return err + }) + return +} + +// UpdateChainState updates the chain state with the given updates. +func (s *Store) UpdateChainState(fn func(index.UpdateTx) error) error { + return s.transaction(func(tx *txn) error { + return fn(&updateTx{tx: tx}) + }) +} + +// maturedSiacoinBalance helper to query the sum of matured siacoin elements +// for a given height. +func maturedSiacoinBalance(tx *txn, index types.ChainIndex) (inflow types.Currency, err error) { + rows, err := tx.Query(`SELECT siacoin_value FROM wallet_siacoin_elements WHERE maturity_height=?`, index.Height) + if err != nil { + return types.ZeroCurrency, fmt.Errorf("failed to query matured siacoin elements: %w", err) + } + defer rows.Close() + + for rows.Next() { + var value types.Currency + if err := rows.Scan(decode(&value)); err != nil { + return types.ZeroCurrency, fmt.Errorf("failed to scan siacoin value: %w", err) + } + inflow = inflow.Add(value) + } + if err := rows.Err(); err != nil { + return types.ZeroCurrency, fmt.Errorf("failed to iterate siacoin elements: %w", err) + } + return +} + +// createSiacoinElements helper to insert siacoin elements into the database. +func createSiacoinElements(tx *txn, index types.ChainIndex, created []types.SiacoinElement) (matureInflow, immatureInflow types.Currency, _ error) { + if len(created) == 0 { + return types.ZeroCurrency, types.ZeroCurrency, nil + } + + stmt, err := tx.Prepare(`INSERT INTO wallet_siacoin_elements (id, siacoin_value, sia_address, merkle_proof, leaf_index, maturity_height) VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT (id) DO NOTHING;`) + if err != nil { + return types.ZeroCurrency, types.ZeroCurrency, fmt.Errorf("failed to prepare insert statement: %w", err) + } + defer stmt.Close() + + for _, elem := range created { + if _, err := stmt.Exec(encode(elem.ID), encode(elem.SiacoinOutput.Value), encode(elem.SiacoinOutput.Address), encode(elem.MerkleProof), encode(elem.LeafIndex), elem.MaturityHeight); err != nil { + return types.ZeroCurrency, types.ZeroCurrency, fmt.Errorf("failed to insert siacoin element %q: %w", elem.ID, err) + } + + if elem.MaturityHeight <= index.Height { + matureInflow = matureInflow.Add(elem.SiacoinOutput.Value) + } else { + immatureInflow = immatureInflow.Add(elem.SiacoinOutput.Value) + } + } + return +} + +// deleteSiacoinElements helper to delete siacoin elements from the database. +func deleteSiacoinElements(tx *txn, index types.ChainIndex, removed []types.SiacoinElement) (matureOutflow types.Currency, immatureOutflow types.Currency, _ error) { + if len(removed) == 0 { + return types.ZeroCurrency, types.ZeroCurrency, nil + } + + stmt, err := tx.Prepare(`DELETE FROM wallet_siacoin_elements WHERE id=?`) + if err != nil { + return types.ZeroCurrency, types.ZeroCurrency, fmt.Errorf("failed to prepare delete statement: %w", err) + } + defer stmt.Close() + + for _, elem := range removed { + if res, err := stmt.Exec(encode(elem.ID)); err != nil { + return types.ZeroCurrency, types.ZeroCurrency, fmt.Errorf("failed to delete siacoin element: %w", err) + } else if n, err := res.RowsAffected(); err != nil { + return types.ZeroCurrency, types.ZeroCurrency, fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return types.ZeroCurrency, types.ZeroCurrency, fmt.Errorf("failed to delete siacoin element %q: not found", elem.ID) + } + + if elem.MaturityHeight <= index.Height { + matureOutflow = matureOutflow.Add(elem.SiacoinOutput.Value) + } else { + immatureOutflow = immatureOutflow.Add(elem.SiacoinOutput.Value) + } + } + return +} + +// createWalletEvents helper to insert wallet events into the database. +func createWalletEvents(tx *txn, events []wallet.Event) error { + if len(events) == 0 { + return nil + } + + stmt, err := tx.Prepare(`INSERT INTO wallet_events (id, chain_index, maturity_height, event_type, raw_data) VALUES (?, ?, ?, ?, ?) ON CONFLICT (id) DO NOTHING`) + if err != nil { + return fmt.Errorf("failed to prepare insert statement: %w", err) + } + defer stmt.Close() + + for _, event := range events { + buf, err := json.Marshal(event) + if err != nil { + return fmt.Errorf("failed to marshal wallet event: %w", err) + } + if _, err := stmt.Exec(encode(event.ID), encode(event.Index), event.MaturityHeight, event.Type, buf); err != nil { + return fmt.Errorf("failed to insert wallet event: %w", err) + } + } + return nil +} + +// updateBalanceMetric updates the wallet balance metric. +func updateBalanceMetric(tx *txn, matureInflow, matureOutflow, immatureInflow, immatureOutflow types.Currency, timestamp time.Time) error { + // calculate the delta for the balance and immature balance + var matureDelta types.Currency + var matureNegative bool + if n := matureInflow.Cmp(matureOutflow); n > 0 { + matureDelta = matureInflow.Sub(matureOutflow) + } else if n < 0 { + matureDelta = matureOutflow.Sub(matureInflow) + matureNegative = true + } + + var immatureDelta types.Currency + var immatureNegative bool + if n := immatureInflow.Cmp(immatureOutflow); n > 0 { + immatureDelta = immatureInflow.Sub(immatureOutflow) + } else if n < 0 { + immatureDelta = immatureOutflow.Sub(immatureInflow) + immatureNegative = true + } + + // if no change, return + if matureDelta.IsZero() && immatureDelta.IsZero() { + return nil + } + + // prepare the increment statement + increment, done, err := incrementCurrencyStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + if !matureDelta.IsZero() { + // increment the balance + if err := increment(metricWalletBalance, matureDelta, matureNegative, timestamp); err != nil { + return fmt.Errorf("failed to increment balance: %w", err) + } + } + + if !immatureDelta.IsZero() { + // increment the immature balance + if err := increment(metricWalletImmatureBalance, immatureDelta, immatureNegative, timestamp); err != nil { + return fmt.Errorf("failed to increment immature balance: %w", err) + } + } + return nil +} + +// getContractStateStmt helper to get the current state of a contract. +func getContractStateStmt(tx *txn) (func(contractID types.FileContractID) (contractState, error), func() error, error) { + stmt, err := tx.Prepare(`SELECT id, locked_collateral, risked_collateral, rpc_revenue, storage_revenue, +ingress_revenue, egress_revenue, registry_read, registry_write, contract_status +FROM contracts +WHERE contract_id=?`) + if err != nil { + return nil, nil, fmt.Errorf("failed to prepare select statement: %w", err) + } + + return func(contractID types.FileContractID) (state contractState, err error) { + err = stmt.QueryRow(encode(contractID)).Scan(&state.ID, + decode(&state.LockedCollateral), decode(&state.Usage.RiskedCollateral), decode(&state.Usage.RPCRevenue), + decode(&state.Usage.StorageRevenue), decode(&state.Usage.IngressRevenue), decode(&state.Usage.EgressRevenue), + decode(&state.Usage.RegistryRead), decode(&state.Usage.RegistryWrite), &state.Status) + return + }, stmt.Close, nil +} + +// getV2ContractStateStmt helper to get the current state of a v2 contract. +func getV2ContractStateStmt(tx *txn) (func(contractID types.FileContractID) (v2ContractState, error), func() error, error) { + stmt, err := tx.Prepare(`SELECT id, locked_collateral, risked_collateral, rpc_revenue, storage_revenue, +ingress_revenue, egress_revenue, contract_status +FROM contracts_v2 +WHERE contract_id=?`) + if err != nil { + return nil, nil, fmt.Errorf("failed to prepare select statement: %w", err) + } + + return func(contractID types.FileContractID) (state v2ContractState, err error) { + err = stmt.QueryRow(encode(contractID)).Scan(&state.ID, + decode(&state.LockedCollateral), decode(&state.Usage.RiskedCollateral), decode(&state.Usage.RPCRevenue), + decode(&state.Usage.StorageRevenue), decode(&state.Usage.IngressRevenue), decode(&state.Usage.EgressRevenue), + &state.Status) + return + }, stmt.Close, nil +} + +// updateEarnedRevenueMetrics helper to update the earned revenue metrics. +func updateEarnedRevenueMetrics(usage contracts.Usage, negative bool, fn func(stat string, delta types.Currency, negative bool, timestamp time.Time) error) error { + if err := fn(metricEarnedRPCRevenue, usage.RPCRevenue, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricEarnedRPCRevenue, err) + } else if err := fn(metricEarnedStorageRevenue, usage.StorageRevenue, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricEarnedStorageRevenue, err) + } else if err := fn(metricEarnedIngressRevenue, usage.IngressRevenue, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricEarnedIngressRevenue, err) + } else if err := fn(metricEarnedEgressRevenue, usage.EgressRevenue, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricEarnedEgressRevenue, err) + } else if err := fn(metricEarnedRegistryReadRevenue, usage.RegistryRead, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricEarnedRegistryReadRevenue, err) + } else if err := fn(metricEarnedRegistryWriteRevenue, usage.RegistryWrite, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricEarnedRegistryWriteRevenue, err) + } + return nil +} + +// updatePotentialRevenueMetrics helper to update the potential revenue metrics. +func updatePotentialRevenueMetrics(usage contracts.Usage, negative bool, fn func(stat string, delta types.Currency, negative bool, timestamp time.Time) error) error { + if err := fn(metricPotentialRPCRevenue, usage.RPCRevenue, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricPotentialRPCRevenue, err) + } else if err := fn(metricPotentialStorageRevenue, usage.StorageRevenue, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricPotentialStorageRevenue, err) + } else if err := fn(metricPotentialIngressRevenue, usage.IngressRevenue, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricPotentialIngressRevenue, err) + } else if err := fn(metricPotentialEgressRevenue, usage.EgressRevenue, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricPotentialEgressRevenue, err) + } else if err := fn(metricPotentialRegistryReadRevenue, usage.RegistryRead, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricPotentialRegistryReadRevenue, err) + } else if err := fn(metricPotentialRegistryWriteRevenue, usage.RegistryWrite, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricPotentialRegistryWriteRevenue, err) + } + return nil +} + +// updateV2EarnedRevenueMetrics helper to update the earned revenue metrics. +func updateV2EarnedRevenueMetrics(usage contracts.V2Usage, negative bool, fn func(stat string, delta types.Currency, negative bool, timestamp time.Time) error) error { + if err := fn(metricEarnedRPCRevenue, usage.RPCRevenue, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricEarnedRPCRevenue, err) + } else if err := fn(metricEarnedStorageRevenue, usage.StorageRevenue, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricEarnedStorageRevenue, err) + } else if err := fn(metricEarnedIngressRevenue, usage.IngressRevenue, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricEarnedIngressRevenue, err) + } else if err := fn(metricEarnedEgressRevenue, usage.EgressRevenue, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricEarnedEgressRevenue, err) + } + return nil +} + +// updateV2PotentialRevenueMetrics helper to update the potential revenue metrics. +func updateV2PotentialRevenueMetrics(usage contracts.V2Usage, negative bool, fn func(stat string, delta types.Currency, negative bool, timestamp time.Time) error) error { + if err := fn(metricPotentialRPCRevenue, usage.RPCRevenue, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricPotentialRPCRevenue, err) + } else if err := fn(metricPotentialStorageRevenue, usage.StorageRevenue, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricPotentialStorageRevenue, err) + } else if err := fn(metricPotentialIngressRevenue, usage.IngressRevenue, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricPotentialIngressRevenue, err) + } else if err := fn(metricPotentialEgressRevenue, usage.EgressRevenue, negative, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", metricPotentialEgressRevenue, err) + } + return nil +} + +// updateCollateralMetrics helper to update the collateral metrics. +func updateCollateralMetrics(locked, risked types.Currency, negative bool, fn func(stat string, delta types.Currency, negative bool, timestamp time.Time) error) error { + if err := fn(metricLockedCollateral, locked, negative, time.Now()); err != nil { + return fmt.Errorf("failed to set metric %q: %w", metricLockedCollateral, err) + } else if err := fn(metricRiskedCollateral, risked, negative, time.Now()); err != nil { + return fmt.Errorf("failed to set metric %q: %w", metricRiskedCollateral, err) + } + return nil +} + +// contractStatusMetric returns the metric name for a contract status. +func contractStatusMetric(status contracts.ContractStatus) string { + switch status { + case contracts.ContractStatusActive: + return metricActiveContracts + case contracts.ContractStatusRejected: + return metricRejectedContracts + case contracts.ContractStatusSuccessful: + return metricSuccessfulContracts + case contracts.ContractStatusFailed: + return metricFailedContracts + default: + panic(fmt.Sprintf("unexpected contract status: %v", status)) + } +} + +// updateStatusMetrics helper to update the contract status metrics. +func updateStatusMetrics(oldStatus, newStatus contracts.ContractStatus, fn func(stat string, delta int64, timestamp time.Time) error) error { + if oldStatus == newStatus { + return nil + } + + // pending contracts do not have a metric. + if oldStatus != contracts.ContractStatusPending { + if err := fn(contractStatusMetric(oldStatus), -1, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", contractStatusMetric(oldStatus), err) + } + } + + if err := fn(contractStatusMetric(newStatus), 1, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", contractStatusMetric(newStatus), err) + } + return nil +} + +// v2ContractStatusMetric returns the metric name for a v2 contract status. +func v2ContractStatusMetric(status contracts.V2ContractStatus) string { + switch status { + case contracts.V2ContractStatusActive: + return metricActiveContracts + case contracts.V2ContractStatusRejected: + return metricRejectedContracts + case contracts.V2ContractStatusSuccessful: + return metricSuccessfulContracts + case contracts.V2ContractStatusFailed: + return metricFailedContracts + case contracts.V2ContractStatusFinalized: + return metricFinalizedContracts + case contracts.V2ContractStatusRenewed: + return metricRenewedContracts + default: + panic(fmt.Sprintf("unexpected contract status: %v", status)) + } +} + +// updateV2StatusMetrics helper to update the v2 contract status metrics. +func updateV2StatusMetrics(oldStatus, newStatus contracts.V2ContractStatus, fn func(stat string, delta int64, timestamp time.Time) error) error { + if oldStatus == newStatus { + return nil + } + + if oldStatus != contracts.V2ContractStatusPending { + if err := fn(v2ContractStatusMetric(oldStatus), -1, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", v2ContractStatusMetric(oldStatus), err) + } + } + + if err := fn(v2ContractStatusMetric(newStatus), 1, time.Now()); err != nil { + return fmt.Errorf("failed to update metric %q: %w", v2ContractStatusMetric(newStatus), err) + } + return nil +} + +// applyContractRevision updates the confirmed revision number of a contract. +func applyContractRevision(tx *txn, revisions []types.FileContractElement) error { + if len(revisions) == 0 { + return nil + } + + stmt, err := tx.Prepare(`UPDATE contracts SET confirmed_revision_number=? WHERE contract_id=?`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer stmt.Close() + + for _, fce := range revisions { + if res, err := stmt.Exec(encode(fce.FileContract.RevisionNumber), encode(fce.ID)); err != nil { + return fmt.Errorf("failed to update contract revision %q: %w", fce.ID, err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("no rows updated: %q", fce.ID) + } + } + return nil +} + +// applyContractFormation updates the contract table with the confirmation index and new status. +func applyContractFormation(tx *txn, confirmed []types.FileContractElement, log *zap.Logger) error { + if len(confirmed) == 0 { + return nil + } + + getContractState, done, err := getContractStateStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare select statement: %w", err) + } + defer done() + + incrementNumericStat, done, err := incrementNumericStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + incrementCurrencyStat, done, err := incrementCurrencyStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + updateStmt, err := tx.Prepare(`UPDATE contracts SET formation_confirmed=true, contract_status=$1 WHERE id=$2`) + if err != nil { + return fmt.Errorf("failed to prepare confirmation statement: %w", err) + } + defer updateStmt.Close() + + for _, fce := range confirmed { + state, err := getContractState(types.FileContractID(fce.ID)) + if err != nil { + return fmt.Errorf("failed to get contract state %q: %w", fce.ID, err) + } + + // skip the update if the contract is not active, rejected, or pending. + // This should only happen during a rescan. + if state.Status != contracts.ContractStatusPending && state.Status != contracts.ContractStatusRejected && state.Status != contracts.ContractStatusActive { + log.Debug("skipping rescan state transition", zap.Stringer("contractID", fce.ID), zap.Stringer("current", state.Status)) + continue + } + + // update the contract table with the confirmation index and new status. + res, err := updateStmt.Exec(contracts.ContractStatusActive, state.ID) + if err != nil { + return fmt.Errorf("failed to update state %q: %w", fce.ID, err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("failed to update contract %q: %w", fce.ID, err) + } + + // skip the metric update if the contract is already active. + // This should only happen during a rescan + if state.Status == contracts.ContractStatusActive { + continue + } + + if err := updateCollateralMetrics(state.LockedCollateral, state.Usage.RiskedCollateral, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update collateral metrics: %w", err) + } else if err := updatePotentialRevenueMetrics(state.Usage, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update potential revenue metrics: %w", err) + } else if err := updateStatusMetrics(state.Status, contracts.ContractStatusActive, incrementNumericStat); err != nil { + return fmt.Errorf("failed to update contract metrics: %w", err) + } + } + return nil +} + +// applySuccessfulContracts updates the contract table with the resolution index +// sets the contract status to successful, and updates the revenue metrics. +func applySuccessfulContracts(tx *txn, index types.ChainIndex, successful []types.FileContractID) error { + if len(successful) == 0 { + return nil + } + + getState, done, err := getContractStateStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare select statement: %w", err) + } + defer done() + + updateStmt, err := tx.Prepare(`UPDATE contracts SET resolution_height=?, contract_status=? WHERE id=?`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateStmt.Close() + + incrementCurrencyStat, done, err := incrementCurrencyStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + incrementNumericStat, done, err := incrementNumericStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + for _, contractID := range successful { + state, err := getState(contractID) + if err != nil { + return fmt.Errorf("failed to get contract state %q: %w", contractID, err) + } + + if state.Status == contracts.ContractStatusSuccessful { + // skip update if the contract was already successful + continue + } else if state.Status != contracts.ContractStatusActive { + // panic if the contract is not active. Proper reverts should have + // ensured that this never happens. + panic(fmt.Errorf("unexpected contract state transition %q -> %q", state.Status, contracts.ContractStatusSuccessful)) + } + + // update the contract's resolution index and status + if res, err := updateStmt.Exec(index.Height, contracts.ContractStatusSuccessful, state.ID); err != nil { + return fmt.Errorf("failed to update contract %q: %w", contractID, err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("failed to update contract %q: %w", contractID, err) + } + + // update the contract status metrics + if err := updateStatusMetrics(state.Status, contracts.ContractStatusSuccessful, incrementNumericStat); err != nil { + return fmt.Errorf("failed to set contract %q status: %w", contractID, err) + } + + // subtract the usage from the potential revenue metrics and add it to the + // earned revenue metrics + if err := updatePotentialRevenueMetrics(state.Usage, true, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update potential revenue metrics: %w", err) + } else if err := updateEarnedRevenueMetrics(state.Usage, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update earned revenue metrics: %w", err) + } else if err := updateCollateralMetrics(state.LockedCollateral, state.Usage.RiskedCollateral, true, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update collateral metrics: %w", err) + } + } + return nil +} + +// applyFailedContracts sets the contract status to failed and subtracts the +// potential revenue metrics. +func applyFailedContracts(tx *txn, failed []types.FileContractID) error { + if len(failed) == 0 { + return nil + } + + getState, done, err := getContractStateStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare select statement: %w", err) + } + defer done() + + updateStmt, err := tx.Prepare(`UPDATE contracts SET resolution_height=NULL, contract_status=? WHERE id=?`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateStmt.Close() + + incrementCurrencyStat, done, err := incrementCurrencyStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + incrementNumericStat, done, err := incrementNumericStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + for _, contractID := range failed { + state, err := getState(types.FileContractID(contractID)) + if err != nil { + return fmt.Errorf("failed to get contract state %q: %w", contractID, err) + } + + if state.Status == contracts.ContractStatusFailed { + // skip update if the contract was already failed + continue + } else if state.Status != contracts.ContractStatusActive { + // panic if the contract is not active. Proper reverts should have + // ensured that this never happens. + panic(fmt.Errorf("unexpected contract state transition %q -> %q", state.Status, contracts.ContractStatusFailed)) + } + + // update the contract's resolution index and status + if res, err := updateStmt.Exec(contracts.ContractStatusFailed, state.ID); err != nil { + return fmt.Errorf("failed to update contract %q: %w", contractID, err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("failed to update contract %q: %w", contractID, err) + } + + // update the contract status metrics + if err := updateStatusMetrics(state.Status, contracts.ContractStatusFailed, incrementNumericStat); err != nil { + return fmt.Errorf("failed to set contract %q status: %w", contractID, err) + } + + // subtract the usage from the potential revenue metrics + if err := updatePotentialRevenueMetrics(state.Usage, true, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update potential revenue metrics: %w", err) + } else if err := updateCollateralMetrics(state.LockedCollateral, state.Usage.RiskedCollateral, true, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update collateral metrics: %w", err) + } + } + return nil +} + +// revertContractFormation reverts the contract formation by setting the +// confirmation index to null and the status to pending. +func revertContractFormation(tx *txn, reverted []types.FileContractElement) error { + if len(reverted) == 0 { + return nil + } + + getState, done, err := getContractStateStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare select statement: %w", err) + } + defer done() + + updateStmt, err := tx.Prepare(`UPDATE contracts SET formation_confirmed=false, contract_status=? WHERE id=?`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateStmt.Close() + + incrementNumericStat, done, err := incrementNumericStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + incrementCurrencyStat, done, err := incrementCurrencyStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + for _, fce := range reverted { + // get the current state of the contract + state, err := getState(types.FileContractID(fce.ID)) + if err != nil { + return fmt.Errorf("failed to get contract state %q: %w", fce.ID, err) + } + + if state.Status != contracts.ContractStatusActive { + // if the contract is not active, panic. Applies should have ensured + // that this never happens. + panic(fmt.Errorf("unexpected contract state transition %q -> %q", state.Status, contracts.ContractStatusPending)) + } + + // set the contract status to pending + if res, err := updateStmt.Exec(contracts.ContractStatusPending, state.ID); err != nil { + return fmt.Errorf("failed to revert contract formation %q: %w", fce.ID, err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("no rows updated %q", fce.ID) + } + + // subtract the metrics + if err := updateStatusMetrics(state.Status, contracts.ContractStatusPending, incrementNumericStat); err != nil { + return fmt.Errorf("failed to update contract metrics: %w", err) + } else if err := updateCollateralMetrics(state.LockedCollateral, state.Usage.RiskedCollateral, true, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update collateral metrics: %w", err) + } else if err := updatePotentialRevenueMetrics(state.Usage, true, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update potential revenue metrics: %w", err) + } + } + return err +} + +// revertSuccessfulContracts reverts the contract resolution by setting the +// resolution index to null, the status to active, and updating the revenue +// metrics. +func revertSuccessfulContracts(tx *txn, successful []types.FileContractID) error { + if len(successful) == 0 { + return nil + } + + getState, done, err := getContractStateStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare select statement: %w", err) + } + defer done() + + updateStmt, err := tx.Prepare(`UPDATE contracts SET resolution_height=NULL, contract_status=? WHERE id=?`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateStmt.Close() + + incrementCurrencyStat, done, err := incrementCurrencyStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + incrementNumericStat, done, err := incrementNumericStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + for _, contractID := range successful { + // get the current state of the contract + state, err := getState(contractID) + if err != nil { + return fmt.Errorf("failed to get contract state %q: %w", contractID, err) + } + + if state.Status != contracts.ContractStatusSuccessful { + // if the contract is not successful, panic. Applies should have + // ensured that this never happens. + panic(fmt.Errorf("unexpected contract state transition %q -> %q", state.Status, contracts.ContractStatusActive)) + } + + if res, err := updateStmt.Exec(encode(contractID)); err != nil { + return fmt.Errorf("failed to update contract %q: %w", contractID, err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("no rows updated: %q", contractID) + } + + // update the contract status metrics + if err := updateStatusMetrics(state.Status, contracts.ContractStatusActive, incrementNumericStat); err != nil { + return fmt.Errorf("failed to set contract %q status: %w", contractID, err) + } + + // subtract the usage from the earned revenue metrics and add it to the + // potential revenue metrics + if err := updatePotentialRevenueMetrics(state.Usage, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update potential revenue metrics: %w", err) + } else if err := updateEarnedRevenueMetrics(state.Usage, true, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update earned revenue metrics: %w", err) + } else if err := updateCollateralMetrics(state.LockedCollateral, state.Usage.RiskedCollateral, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update collateral metrics: %w", err) + } + } + return err +} + +// revertFailedContracts sets the contract status to active and adds the +// potential revenue and collateral metrics. +func revertFailedContracts(tx *txn, failed []types.FileContractID) error { + if len(failed) == 0 { + return nil + } + + getState, done, err := getContractStateStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare select statement: %w", err) + } + defer done() + + updateStmt, err := tx.Prepare(`UPDATE contracts SET resolution_height=NULL, contract_status=? WHERE id=?`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateStmt.Close() + + incrementCurrencyStat, done, err := incrementCurrencyStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + incrementNumericStat, done, err := incrementNumericStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + for _, contractID := range failed { + state, err := getState(contractID) + if err != nil { + return fmt.Errorf("failed to get contract state %q: %w", contractID, err) + } + + if state.Status != contracts.ContractStatusFailed { + // panic if the contract is not failed. Proper reverts should have + // ensured that this never happens. + panic(fmt.Errorf("unexpected contract state transition %q -> %q", state.Status, contracts.ContractStatusFailed)) + } else if state.Status == contracts.ContractStatusFailed { + // skip update, most likely rescanning + continue + } + + // update the contract's resolution index and status + if res, err := updateStmt.Exec(contracts.ContractStatusActive, state.ID); err != nil { + return fmt.Errorf("failed to update contract %q: %w", contractID, err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("failed to update contract %q: %w", contractID, err) + } + + // update the contract status metrics + if err := updateStatusMetrics(state.Status, contracts.ContractStatusActive, incrementNumericStat); err != nil { + return fmt.Errorf("failed to set contract %q status: %w", contractID, err) + } + + // add the usage back to the potential revenue metrics + if err := updatePotentialRevenueMetrics(state.Usage, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update potential revenue metrics: %w", err) + } else if err := updateCollateralMetrics(state.LockedCollateral, state.Usage.RiskedCollateral, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update collateral metrics: %w", err) + } + } + return nil +} + +// applyV2ContractFormation updates the contract table with the confirmation index +// and new status. +func applyV2ContractFormation(tx *txn, index types.ChainIndex, confirmed []types.V2FileContractElement, log *zap.Logger) error { + if len(confirmed) == 0 { + return nil + } + + getV2ContractState, done, err := getV2ContractStateStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare get state statement: %w", err) + } + defer done() + + incrementNumericStat, done, err := incrementNumericStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment numeric statement: %w", err) + } + defer done() + + incrementCurrencyStat, done, err := incrementCurrencyStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment currency statement: %w", err) + } + defer done() + + updateStmt, err := tx.Prepare(`UPDATE contracts_v2 SET confirmation_index=$1, contract_status=$2 WHERE id=$3`) + if err != nil { + return fmt.Errorf("failed to prepare update status statement: %w", err) + } + defer updateStmt.Close() + + insertElementStmt, err := tx.Prepare(`INSERT INTO contract_v2_state_elements (contract_id, leaf_index, merkle_proof, raw_contract, revision_number) VALUES (?, ?, ?, ?, ?) ON CONFLICT (contract_id) DO UPDATE SET leaf_index=EXCLUDED.leaf_index, merkle_proof=EXCLUDED.merkle_proof, raw_contract=EXCLUDED.raw_contract, revision_number=EXCLUDED.revision_number`) + if err != nil { + return fmt.Errorf("failed to prepare insert state element statement: %w", err) + } + defer insertElementStmt.Close() + + for _, fce := range confirmed { + state, err := getV2ContractState(types.FileContractID(fce.ID)) + if err != nil { + return fmt.Errorf("failed to get contract state %q: %w", fce.ID, err) + } + + if _, err := insertElementStmt.Exec(state.ID, fce.LeafIndex, encode(fce.MerkleProof), encode(fce.V2FileContract), encode(fce.V2FileContract.RevisionNumber)); err != nil { + return fmt.Errorf("failed to insert contract state element %q: %w", fce.ID, err) + } + + // skip the update if the contract is not active, rejected, or pending. + // This should only happen during a rescan. + if state.Status != contracts.V2ContractStatusPending && state.Status != contracts.V2ContractStatusRejected && state.Status != contracts.V2ContractStatusActive { + log.Debug("skipping rescan state transition", zap.Stringer("contractID", fce.ID), zap.String("current", string(state.Status))) + continue + } + + // update the contract table with the confirmation index and new status. + res, err := updateStmt.Exec(encode(index), contracts.V2ContractStatusActive, state.ID) + if err != nil { + return fmt.Errorf("failed to update state %q: %w", fce.ID, err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("failed to update contract %q: %w", fce.ID, err) + } + + // skip the metric update if the contract is already active. + // This should only happen during a rescan + if state.Status == contracts.V2ContractStatusActive { + continue + } + + if err := updateCollateralMetrics(state.LockedCollateral, state.Usage.RiskedCollateral, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update collateral metrics: %w", err) + } else if err := updateV2PotentialRevenueMetrics(state.Usage, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update potential revenue metrics: %w", err) + } else if err := updateV2StatusMetrics(state.Status, contracts.V2ContractStatusActive, incrementNumericStat); err != nil { + return fmt.Errorf("failed to update contract metrics: %w", err) + } + } + return nil +} + +// revertV2ContractFormation reverts the contract formation by setting the +// confirmation index to null and the status to pending. +func revertV2ContractFormation(tx *txn, reverted []types.V2FileContractElement) error { + if len(reverted) == 0 { + return nil + } + + getV2ContractState, done, err := getV2ContractStateStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare select statement: %w", err) + } + defer done() + + incrementNumericStat, done, err := incrementNumericStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + incrementCurrencyStat, done, err := incrementCurrencyStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + updateStmt, err := tx.Prepare(`UPDATE contracts_v2 SET confirmation_index=NULL, contract_status=? WHERE id=?`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateStmt.Close() + + deleteStmt, err := tx.Prepare(`DELETE FROM contract_v2_state_elements WHERE contract_id=?`) + if err != nil { + return fmt.Errorf("failed to prepare delete statement: %w", err) + } + defer deleteStmt.Close() + + for _, fce := range reverted { + // get the current contract state + state, err := getV2ContractState(types.FileContractID(fce.ID)) + if err != nil { + return fmt.Errorf("failed to get contract state %q: %w", fce.ID, err) + } + + // delete the state element + if _, err := deleteStmt.Exec(state.ID); err != nil { + return fmt.Errorf("failed to delete contract state element %q: %w", fce.ID, err) + } + + if state.Status != contracts.V2ContractStatusActive { + // if the contract is not active, panic. Applies should have ensured + // that this never happens. + panic(fmt.Errorf("unexpected contract state transition %q -> %q", state.Status, contracts.ContractStatusPending)) + } + + // set the contract status to pending + if res, err := updateStmt.Exec(contracts.V2ContractStatusPending, state.ID); err != nil { + return fmt.Errorf("failed to revert contract formation %q: %w", fce.ID, err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("no rows updated %q", fce.ID) + } + + // subtract the metrics + if err := updateV2StatusMetrics(state.Status, contracts.V2ContractStatusPending, incrementNumericStat); err != nil { + return fmt.Errorf("failed to update contract metrics: %w", err) + } else if err := updateCollateralMetrics(state.LockedCollateral, state.Usage.RiskedCollateral, true, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update collateral metrics: %w", err) + } else if err := updateV2PotentialRevenueMetrics(state.Usage, true, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update potential revenue metrics: %w", err) + } + } + return nil +} + +// applyV2ContractRevision updates the confirmed revision number of a contract. +func applyV2ContractRevision(tx *txn, revised []types.V2FileContractElement) error { + selectIDStmt, err := tx.Prepare(`SELECT id FROM contracts_v2 WHERE contract_id=?`) + if err != nil { + return fmt.Errorf("failed to prepare select statement: %w", err) + } + defer selectIDStmt.Close() + + updateElementStmt, err := tx.Prepare(`UPDATE contract_v2_state_elements SET leaf_index=?, merkle_proof=?, raw_contract=?, revision_number=? WHERE contract_id=?`) + if err != nil { + return fmt.Errorf("failed to prepare update element statement: %w", err) + } + defer updateElementStmt.Close() + + for _, fce := range revised { + var contractID int64 + if err := selectIDStmt.QueryRow(encode(fce.ID)).Scan(&contractID); err != nil { + return fmt.Errorf("failed to update contract revision %q: %w", fce.ID, err) + } else if _, err := updateElementStmt.Exec(fce.LeafIndex, encode(fce.MerkleProof), encode(fce.V2FileContract), encode(fce.V2FileContract.RevisionNumber), contractID); err != nil { + return fmt.Errorf("failed to update contract state element %q: %w", fce.ID, err) + } + } + return nil +} + +// applySuccessfulV2Contracts updates the contract table with the resolution index +// sets the contract status to successful, and updates the revenue metrics. +func applySuccessfulV2Contracts(tx *txn, index types.ChainIndex, status contracts.V2ContractStatus, successful []types.FileContractID) error { + if len(successful) == 0 { + return nil + } + + getState, done, err := getV2ContractStateStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare select statement: %w", err) + } + defer done() + + updateStmt, err := tx.Prepare(`UPDATE contracts_v2 SET resolution_index=?, contract_status=? WHERE id=?`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateStmt.Close() + + incrementCurrencyStat, done, err := incrementCurrencyStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + incrementNumericStat, done, err := incrementNumericStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + for _, contractID := range successful { + state, err := getState(contractID) + if err != nil { + return fmt.Errorf("failed to get contract state %q: %w", contractID, err) + } + + if state.Status == status { + // skip update if the contract was already successful. + // This should only happen during a rescan + continue + } else if state.Status != contracts.V2ContractStatusActive { + // panic if the contract is not active. Proper reverts should have + // ensured that this never happens. + panic(fmt.Errorf("unexpected contract state transition %q -> %q", state.Status, contracts.ContractStatusSuccessful)) + } + + // update the contract's resolution index and status + if res, err := updateStmt.Exec(encode(index), status, state.ID); err != nil { + return fmt.Errorf("failed to update contract %q: %w", contractID, err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("no rows updated: %q", contractID) + } + + // update the contract status metrics + if err := updateV2StatusMetrics(state.Status, status, incrementNumericStat); err != nil { + return fmt.Errorf("failed to set contract %q status: %w", contractID, err) + } + + // subtract the usage from the potential revenue metrics and add it to the + // earned revenue metrics + if err := updateV2PotentialRevenueMetrics(state.Usage, true, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update potential revenue metrics: %w", err) + } else if err := updateV2EarnedRevenueMetrics(state.Usage, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update earned revenue metrics: %w", err) + } else if err := updateCollateralMetrics(state.LockedCollateral, state.Usage.RiskedCollateral, true, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update collateral metrics: %w", err) + } + } + return nil +} + +// applyFailedV2Contracts sets the contract status to active and adds the +// potential revenue metrics. +func applyFailedV2Contracts(tx *txn, index types.ChainIndex, failed []types.FileContractID) error { + if len(failed) == 0 { + return nil + } + + getState, done, err := getV2ContractStateStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare select statement: %w", err) + } + defer done() + + updateStmt, err := tx.Prepare(`UPDATE contracts_v2 SET resolution_index=?, contract_status=? WHERE id=?`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateStmt.Close() + + incrementCurrencyStat, done, err := incrementCurrencyStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + incrementNumericStat, done, err := incrementNumericStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + for _, contractID := range failed { + state, err := getState(types.FileContractID(contractID)) + if err != nil { + return fmt.Errorf("failed to get contract state %q: %w", contractID, err) + } + + // skip update if the contract is already failed. + // This should only happen during a rescan + if state.Status == contracts.V2ContractStatusFailed { + continue + } else if state.Status != contracts.V2ContractStatusActive { + // panic if the contract is not active. Proper reverts should have + // ensured that this never happens. + panic(fmt.Errorf("unexpected contract state transition %q -> %q", state.Status, contracts.V2ContractStatusFailed)) + } + + // update the contract's resolution index and status + if res, err := updateStmt.Exec(encode(index), contracts.V2ContractStatusFailed, state.ID); err != nil { + return fmt.Errorf("failed to update contract %q: %w", contractID, err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("no rows updated: %q", contractID) + } + + // update the contract status metrics + if err := updateV2StatusMetrics(state.Status, contracts.V2ContractStatusFailed, incrementNumericStat); err != nil { + return fmt.Errorf("failed to set contract %q status: %w", contractID, err) + } + + // add the usage to the potential revenue metrics + if err := updateV2PotentialRevenueMetrics(state.Usage, true, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update potential revenue metrics: %w", err) + } else if err := updateCollateralMetrics(state.LockedCollateral, state.Usage.RiskedCollateral, true, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update collateral metrics: %w", err) + } + } + return nil +} + +// revertSuccessfulV2Contracts clears the resolution index, sets the status to +// active and updates the revenue metrics. +func revertSuccessfulV2Contracts(tx *txn, status contracts.V2ContractStatus, successful []types.FileContractID) error { + if len(successful) == 0 { + return nil + } + + getState, done, err := getV2ContractStateStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare select statement: %w", err) + } + defer done() + + updateStmt, err := tx.Prepare(`UPDATE contracts_v2 SET resolution_index=NULL, contract_status=? WHERE id=?`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateStmt.Close() + + incrementCurrencyStat, done, err := incrementCurrencyStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + incrementNumericStat, done, err := incrementNumericStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + for _, contractID := range successful { + state, err := getState(contractID) + if err != nil { + return fmt.Errorf("failed to get contract state %q: %w", contractID, err) + } + + if state.Status != status { + // panic if the contract is not active. Proper reverts should have + // ensured that this never happens. + panic(fmt.Errorf("unexpected contract state transition %q -> %q", state.Status, contracts.V2ContractStatusActive)) + } + + // update the contract's resolution index and status + if res, err := updateStmt.Exec(status, state.ID); err != nil { + return fmt.Errorf("failed to update contract %q: %w", contractID, err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("no rows updated: %q", contractID) + } + + // update the contract status metrics + if err := updateV2StatusMetrics(state.Status, contracts.V2ContractStatusActive, incrementNumericStat); err != nil { + return fmt.Errorf("failed to set contract %q status: %w", contractID, err) + } + + // add the usage to the potential revenue metrics and subtract it from the + // earned revenue metrics + if err := updateV2PotentialRevenueMetrics(state.Usage, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update potential revenue metrics: %w", err) + } else if err := updateV2EarnedRevenueMetrics(state.Usage, true, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update earned revenue metrics: %w", err) + } else if err := updateCollateralMetrics(state.LockedCollateral, state.Usage.RiskedCollateral, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update collateral metrics: %w", err) + } + } + return nil +} + +// revertFailedV2Contracts sets the contract status to active and adds the +// potential revenue and collateral metrics. +func revertFailedV2Contracts(tx *txn, failed []types.FileContractID) error { + if len(failed) == 0 { + return nil + } + + getState, done, err := getV2ContractStateStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare select statement: %w", err) + } + defer done() + + updateStmt, err := tx.Prepare(`UPDATE contracts_v2 SET resolution_index=NULL, contract_status=? WHERE id=?`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateStmt.Close() + + incrementCurrencyStat, done, err := incrementCurrencyStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + incrementNumericStat, done, err := incrementNumericStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment statement: %w", err) + } + defer done() + + for _, contractID := range failed { + state, err := getState(contractID) + if err != nil { + return fmt.Errorf("failed to get contract state %q: %w", contractID, err) + } + + if state.Status != contracts.V2ContractStatusFailed { + // panic if the contract is not failed. Proper reverts should have + // ensured that this never happens. + panic(fmt.Errorf("unexpected contract state transition %q -> %q", state.Status, contracts.V2ContractStatusFailed)) + } else if state.Status == contracts.V2ContractStatusFailed { + // skip update, most likely rescanning + continue + } + + // update the contract's resolution index and status + if res, err := updateStmt.Exec(contracts.V2ContractStatusActive, state.ID); err != nil { + return fmt.Errorf("failed to update contract %q: %w", contractID, err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("no rows updated: %q", contractID) + } + + // update the contract status metrics + if err := updateV2StatusMetrics(state.Status, contracts.V2ContractStatusActive, incrementNumericStat); err != nil { + return fmt.Errorf("failed to set contract %q status: %w", contractID, err) + } + + // add the usage back to the potential revenue metrics + if err := updateV2PotentialRevenueMetrics(state.Usage, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update potential revenue metrics: %w", err) + } else if err := updateCollateralMetrics(state.LockedCollateral, state.Usage.RiskedCollateral, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update collateral metrics: %w", err) + } + } + return nil +} + +// rejectContracts returns the ID of any contracts that are not confirmed and have +// a negotiation height less than the given height. +func rejectContracts(tx *txn, height uint64) (rejected []types.FileContractID, err error) { + rows, err := tx.Query(`SELECT contract_id FROM contracts WHERE contract_status <> $1 AND formation_confirmed=false AND negotiation_height < $2`, contracts.ContractStatusRejected, height) + if err != nil { + return nil, fmt.Errorf("failed to query contracts: %w", err) + } + defer rows.Close() + + for rows.Next() { + var id types.FileContractID + if err := rows.Scan(decode(&id)); err != nil { + return nil, fmt.Errorf("failed to scan contract: %w", err) + } + rejected = append(rejected, id) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("failed to scan contracts: %w", err) + } + return +} + +// rejectV2Contracts returns the ID of any v2 contracts that are not confirmed and have +// a negotiation height less than the given height. +func rejectV2Contracts(tx *txn, height uint64) (rejected []types.FileContractID, err error) { + rows, err := tx.Query(`SELECT contract_id FROM contracts_v2 WHERE contract_status <> $1 AND confirmation_index IS NULL AND negotiation_height < $2`, contracts.V2ContractStatusRejected, height) + if err != nil { + return nil, fmt.Errorf("failed to query contracts: %w", err) + } + defer rows.Close() + + for rows.Next() { + var id types.FileContractID + if err := rows.Scan(decode(&id)); err != nil { + return nil, fmt.Errorf("failed to scan contract: %w", err) + } + rejected = append(rejected, id) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("failed to scan contracts: %w", err) + } + return +} diff --git a/persist/sqlite/contracts.go b/persist/sqlite/contracts.go index 322597e0..a4044cad 100644 --- a/persist/sqlite/contracts.go +++ b/persist/sqlite/contracts.go @@ -10,22 +10,10 @@ import ( "go.sia.tech/core/types" "go.sia.tech/hostd/host/contracts" - "go.sia.tech/siad/modules" "go.uber.org/zap" ) type ( - // An updateContractsTxn atomically updates the contract manager's state - updateContractsTxn struct { - tx txn - } - - // A contractAction pairs a contract's ID with a lifecycle action. - contractAction struct { - ID types.FileContractID - Action string - } - contractSectorRootRef struct { dbID int64 sectorID int64 @@ -33,129 +21,30 @@ type ( } ) -// setLastChangeID sets the last processed consensus change ID. -func (u *updateContractsTxn) setLastChangeID(ccID modules.ConsensusChangeID, height uint64) error { - var dbID int64 // unused, but required by QueryRow to ensure exactly one row is updated - err := u.tx.QueryRow(`UPDATE global_settings SET contracts_last_processed_change=$1, contracts_height=$2 RETURNING id`, sqlHash256(ccID), sqlUint64(height)).Scan(&dbID) - return err -} - -// ConfirmFormation sets the formation_confirmed flag to true. -func (u *updateContractsTxn) ConfirmFormation(id types.FileContractID) error { - const query = `UPDATE contracts SET formation_confirmed=true WHERE contract_id=$1 RETURNING id;` - var dbID int64 - err := u.tx.QueryRow(query, sqlHash256(id)).Scan(&dbID) - if err != nil { - return fmt.Errorf("failed to confirm formation: %w", err) - } - - // get the contract's status - contract, err := getContract(u.tx, dbID) - if err != nil { - return fmt.Errorf("failed to get contract: %w", err) - } - - // only update the status if the contract is pending or rejected - if contract.Status != contracts.ContractStatusPending && contract.Status != contracts.ContractStatusRejected { - return nil - } +var _ contracts.ContractStore = (*Store)(nil) - if err := setContractStatus(u.tx, id, contracts.ContractStatusActive); err != nil { - return fmt.Errorf("failed to set contract status to active: %w", err) - } - // rejected contracts have already had their collateral and revenue removed, - // need to re-add it if the contract is now confirmed - if contract.Status == contracts.ContractStatusRejected { - if err := incrementCurrencyStat(u.tx, metricLockedCollateral, contract.LockedCollateral, false, time.Now()); err != nil { - return fmt.Errorf("failed to increment locked collateral stat: %w", err) - } else if err := incrementCurrencyStat(u.tx, metricRiskedCollateral, contract.Usage.RiskedCollateral, false, time.Now()); err != nil { - return fmt.Errorf("failed to increment risked collateral stat: %w", err) +func (s *Store) batchExpireContractSectors(height uint64) (expired int, removed []types.Hash256, err error) { + err = s.transaction(func(tx *txn) (err error) { + sectorIDs, err := deleteExpiredContractSectors(tx, height) + if err != nil { + return fmt.Errorf("failed to delete contract sectors: %w", err) } - } - return nil -} - -// ConfirmRevision sets the confirmed revision number. -func (u *updateContractsTxn) ConfirmRevision(revision types.FileContractRevision) error { - const query = `UPDATE contracts SET confirmed_revision_number=$1 WHERE contract_id=$2 RETURNING id;` - var dbID int64 - err := u.tx.QueryRow(query, sqlUint64(revision.RevisionNumber), sqlHash256(revision.ParentID)).Scan(&dbID) - if err != nil { - return fmt.Errorf("failed to confirm revision: %w", err) - } - return nil -} - -// ConfirmResolution sets the resolution height. -func (u *updateContractsTxn) ConfirmResolution(id types.FileContractID, height uint64) error { - const query = `UPDATE contracts SET resolution_height=$1 WHERE contract_id=$2 RETURNING id;` - var dbID int64 - if err := u.tx.QueryRow(query, height, sqlHash256(id)).Scan(&dbID); err != nil { - return fmt.Errorf("failed to confirm resolution: %w", err) - } - return nil -} - -// RevertFormation sets the formation_confirmed flag to false. -func (u *updateContractsTxn) RevertFormation(id types.FileContractID) error { - const query = `UPDATE contracts SET formation_confirmed=false WHERE contract_id=$1 RETURNING id;` - var dbID int64 - return u.tx.QueryRow(query, sqlHash256(id)).Scan(&dbID) -} - -// RevertRevision sets the confirmed revision number to 0. -func (u *updateContractsTxn) RevertRevision(id types.FileContractID) error { - const query = `UPDATE contracts SET confirmed_revision_number=$1 WHERE contract_id=$2 RETURNING id;` - var dbID int64 - return u.tx.QueryRow(query, sqlUint64(0), sqlHash256(id)).Scan(&dbID) // TODO: revert to the previous revision number -} - -// RevertResolution sets the resolution height to null -func (u *updateContractsTxn) RevertResolution(id types.FileContractID) error { - const query = `UPDATE contracts SET resolution_height=NULL WHERE contract_id=$1 RETURNING id;` - var dbID int64 - if err := u.tx.QueryRow(query, sqlHash256(id)).Scan(&dbID); err != nil { - return fmt.Errorf("failed to revert resolution: %w", err) - } - return nil -} - -// ContractRevelant returns true if the contract is relevant to the host. -func (u *updateContractsTxn) ContractRelevant(id types.FileContractID) (bool, error) { - const query = `SELECT id FROM contracts WHERE contract_id=$1` - var dbID int64 - err := u.tx.QueryRow(query, sqlHash256(id)).Scan(&dbID) - if errors.Is(err, sql.ErrNoRows) { - return false, nil - } - return err == nil, err -} + expired = len(sectorIDs) -func deleteExpiredContractSectors(tx txn, height uint64) (sectorIDs []int64, err error) { - const query = `DELETE FROM contract_sector_roots -WHERE id IN (SELECT csr.id FROM contract_sector_roots csr -INNER JOIN contracts c ON (csr.contract_id=c.id) --- past proof window or not confirmed and past the rebroadcast height -WHERE c.window_end < $1 OR c.contract_status=$2 LIMIT $3) -RETURNING sector_id;` - rows, err := tx.Query(query, height, contracts.ContractStatusRejected, sqlSectorBatchSize) - if err != nil { - return nil, err - } - defer rows.Close() - for rows.Next() { - var id int64 - if err := rows.Scan(&id); err != nil { - return nil, err + // decrement the contract metrics + if err := incrementNumericStat(tx, metricContractSectors, -len(sectorIDs), time.Now()); err != nil { + return fmt.Errorf("failed to decrement contract sectors: %w", err) } - sectorIDs = append(sectorIDs, id) - } - return sectorIDs, nil + + removed, err = pruneSectors(tx, sectorIDs) + return err + }) + return } -func (s *Store) batchExpireContractSectors(height uint64) (expired int, removed []types.Hash256, err error) { - err = s.transaction(func(tx txn) (err error) { - sectorIDs, err := deleteExpiredContractSectors(tx, height) +func (s *Store) batchExpireV2ContractSectors(height uint64) (expired int, removed []types.Hash256, err error) { + err = s.transaction(func(tx *txn) (err error) { + sectorIDs, err := deleteExpiredV2ContractSectors(tx, height) if err != nil { return fmt.Errorf("failed to delete contract sectors: %w", err) } @@ -184,7 +73,7 @@ func (s *Store) Contracts(filter contracts.ContractFilter) (contracts []contract } contractQuery := fmt.Sprintf(`SELECT c.contract_id, rt.contract_id AS renewed_to, rf.contract_id AS renewed_from, c.contract_status, c.negotiation_height, c.formation_confirmed, - c.revision_number=c.confirmed_revision_number AS revision_confirmed, c.resolution_height, c.locked_collateral, c.rpc_revenue, + COALESCE(c.revision_number=c.confirmed_revision_number, false) AS revision_confirmed, c.resolution_height, c.locked_collateral, c.rpc_revenue, c.storage_revenue, c.ingress_revenue, c.egress_revenue, c.account_funding, c.risked_collateral, c.raw_revision, c.host_sig, c.renter_sig FROM contracts c INNER JOIN contract_renters r ON (c.renter_id=r.id) @@ -196,32 +85,35 @@ INNER JOIN contract_renters r ON (c.renter_id=r.id) LEFT JOIN contracts rt ON (c.renewed_to=rt.id) LEFT JOIN contracts rf ON (c.renewed_from=rf.id) %s`, whereClause) - if err := s.queryRow(countQuery, whereParams...).Scan(&count); err != nil { - return nil, 0, fmt.Errorf("failed to query contract count: %w", err) - } - - rows, err := s.query(contractQuery, append(whereParams, filter.Limit, filter.Offset)...) - if err != nil { - return nil, 0, fmt.Errorf("failed to query contracts: %w", err) - } - defer rows.Close() + err = s.transaction(func(tx *txn) error { + if err := tx.QueryRow(countQuery, whereParams...).Scan(&count); err != nil { + return fmt.Errorf("failed to query contract count: %w", err) + } - for rows.Next() { - contract, err := scanContract(rows) + rows, err := tx.Query(contractQuery, append(whereParams, filter.Limit, filter.Offset)...) if err != nil { - return nil, 0, fmt.Errorf("failed to scan contract: %w", err) + return fmt.Errorf("failed to query contracts: %w", err) } - contracts = append(contracts, contract) - } + defer rows.Close() + + for rows.Next() { + contract, err := scanContract(rows) + if err != nil { + return fmt.Errorf("failed to scan contract: %w", err) + } + contracts = append(contracts, contract) + } + return rows.Err() + }) return } // Contract returns the contract with the given ID. func (s *Store) Contract(id types.FileContractID) (contract contracts.Contract, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { const query = `SELECT id FROM contracts WHERE contract_id=$1;` var dbID int64 - err := tx.QueryRow(query, sqlHash256(id)).Scan(&dbID) + err := tx.QueryRow(query, encode(id)).Scan(&dbID) if errors.Is(err, sql.ErrNoRows) { return contracts.ErrNotFound } else if err != nil { @@ -233,9 +125,84 @@ func (s *Store) Contract(id types.FileContractID) (contract contracts.Contract, return } +// V2ContractElement returns the latest v2 state element with the given ID. +func (s *Store) V2ContractElement(contractID types.FileContractID) (ele types.V2FileContractElement, err error) { + err = s.transaction(func(tx *txn) error { + const query = `SELECT cs.raw_contract, cs.leaf_index, cs.merkle_proof +FROM contracts_v2 c +INNER JOIN contract_v2_state_elements cs ON (c.id = cs.contract_id) +WHERE c.contract_id=?` + + err := tx.QueryRow(query, encode(contractID)).Scan(decode(&ele.V2FileContract), decode(&ele.LeafIndex), decode(&ele.MerkleProof)) + if errors.Is(err, sql.ErrNoRows) { + return contracts.ErrNotFound + } + ele.ID = types.Hash256(contractID) + return err + }) + return +} + +// V2Contract returns the contract with the given ID. +func (s *Store) V2Contract(id types.FileContractID) (contract contracts.V2Contract, err error) { + err = s.transaction(func(tx *txn) error { + const query = `SELECT c.contract_id, rt.contract_id AS renewed_to, rf.contract_id AS renewed_from, c.contract_status, c.negotiation_height, c.confirmation_index, +COALESCE(c.revision_number=cs.revision_number, false) AS revision_confirmed, c.resolution_index, c.rpc_revenue, +c.storage_revenue, c.ingress_revenue, c.egress_revenue, c.account_funding, c.risked_collateral, c.raw_revision +FROM contracts_v2 c +LEFT JOIN contract_v2_state_elements cs ON (c.id = cs.contract_id) +LEFT JOIN contracts_v2 rt ON (c.renewed_to = rt.id) +LEFT JOIN contracts_v2 rf ON (c.renewed_from = rf.id) +WHERE c.contract_id=$1;` + contract, err = scanV2Contract(tx.QueryRow(query, encode(id))) + return err + }) + return +} + +// AddV2Contract adds a new contract to the database. +func (s *Store) AddV2Contract(contract contracts.V2Contract, formationSet contracts.V2FormationTransactionSet) error { + return s.transaction(func(tx *txn) error { + _, err := insertV2Contract(tx, contract, formationSet) + return err + }) +} + +// RenewV2Contract adds a new v2 contract to the database and sets the old +// contract's renewed_from field. The old contract's sector roots are +// copied to the new contract. The status of the old contract should continue +// to be active until the renewal is confirmed +func (s *Store) RenewV2Contract(renewal contracts.V2Contract, renewalSet contracts.V2FormationTransactionSet, renewedID types.FileContractID, clearing types.V2FileContract) error { + return s.transaction(func(tx *txn) error { + // add the new contract + renewedDBID, err := insertV2Contract(tx, renewal, renewalSet) + if err != nil { + return fmt.Errorf("failed to insert renewed contract: %w", err) + } + + clearedDBID, err := updateResolvedV2Contract(tx, renewedID, clearing, renewedDBID) + if err != nil { + return fmt.Errorf("failed to resolve existing contract: %w", err) + } + + // update the renewed_from field + err = tx.QueryRow(`UPDATE contracts_v2 SET renewed_from=$1 WHERE id=$2 RETURNING id;`, clearedDBID, renewedDBID).Scan(&renewedDBID) + if err != nil { + return fmt.Errorf("failed to update renewed contract: %w", err) + } + + // move the sector roots from the old contract to the new contract + _, err = tx.Exec(`UPDATE contract_v2_sector_roots SET contract_id=$1 WHERE contract_id=$2`, renewedDBID, clearedDBID) + if err != nil { + return fmt.Errorf("failed to copy sector roots: %w", err) + } + return nil + }) +} + // AddContract adds a new contract to the database. func (s *Store) AddContract(revision contracts.SignedRevision, formationSet []types.Transaction, lockedCollateral types.Currency, initialUsage contracts.Usage, negotationHeight uint64) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { _, err := insertContract(tx, revision, formationSet, lockedCollateral, initialUsage, negotationHeight) return err }) @@ -245,7 +212,7 @@ func (s *Store) AddContract(revision contracts.SignedRevision, formationSet []ty // contract's renewed_from field. The old contract's sector roots are // copied to the new contract. func (s *Store) RenewContract(renewal contracts.SignedRevision, clearing contracts.SignedRevision, renewalTxnSet []types.Transaction, lockedCollateral types.Currency, clearingUsage, renewalUsage contracts.Usage, negotationHeight uint64) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { // add the new contract renewedDBID, err := insertContract(tx, renewal, renewalTxnSet, lockedCollateral, renewalUsage, negotationHeight) if err != nil { @@ -254,7 +221,7 @@ func (s *Store) RenewContract(renewal contracts.SignedRevision, clearing contrac clearedDBID, err := clearContract(tx, clearing, renewedDBID, clearingUsage) if err != nil { - return fmt.Errorf("faile to clear contract: %w", err) + return fmt.Errorf("failed to clear contract: %w", err) } err = tx.QueryRow(`UPDATE contracts SET renewed_from=$1 WHERE id=$2 RETURNING id;`, clearedDBID, renewedDBID).Scan(&renewedDBID) @@ -271,9 +238,175 @@ func (s *Store) RenewContract(renewal contracts.SignedRevision, clearing contrac }) } +func incrementV2ContractUsage(tx *txn, dbID int64, usage contracts.Usage) error { + const query = `SELECT rpc_revenue, storage_revenue, ingress_revenue, egress_revenue, account_funding, risked_collateral FROM contracts_v2 WHERE id=$1;` + var existing contracts.Usage + err := tx.QueryRow(query, dbID).Scan( + decode(&existing.RPCRevenue), + decode(&existing.StorageRevenue), + decode(&existing.IngressRevenue), + decode(&existing.EgressRevenue), + decode(&existing.AccountFunding), + decode(&existing.RiskedCollateral)) + if err != nil { + return fmt.Errorf("failed to get existing revenue: %w", err) + } + + total := existing.Add(usage) + if total == existing { + return nil + } + + var updatedID int64 + err = tx.QueryRow(`UPDATE contracts_v2 SET (rpc_revenue, storage_revenue, ingress_revenue, egress_revenue, account_funding, risked_collateral) = ($1, $2, $3, $4, $5, $6) WHERE id=$7 RETURNING id;`, + encode(total.RPCRevenue), + encode(total.StorageRevenue), + encode(total.IngressRevenue), + encode(total.EgressRevenue), + encode(total.AccountFunding), + encode(total.RiskedCollateral), + dbID).Scan(&updatedID) + if err != nil { + return fmt.Errorf("failed to update contract revenue: %w", err) + } + return nil +} + +func cleanupDanglingRoots(tx *txn, contractID int64, length int64) (deleted []int64, err error) { + rows, err := tx.Query(`DELETE FROM contract_sector_roots WHERE contract_id=$1 AND root_index >= $2 RETURNING sector_id`, contractID, length) + if err != nil { + return nil, fmt.Errorf("failed to cleanup dangling roots: %w", err) + } + defer rows.Close() + + used := make(map[int64]bool) + for rows.Next() { + var sectorID int64 + if err := rows.Scan(§orID); err != nil { + return nil, fmt.Errorf("failed to scan sector ID: %w", err) + } + + if used[sectorID] { + continue + } + deleted = append(deleted, sectorID) + used[sectorID] = true + } + return deleted, nil +} + +// ReviseV2Contract atomically updates a contract's revision and sectors +func (s *Store) ReviseV2Contract(id types.FileContractID, revision types.V2FileContract, roots []types.Hash256, usage contracts.Usage) error { + return s.transaction(func(tx *txn) error { + incrementCurrencyStat, done, err := incrementCurrencyStatStmt(tx) + if err != nil { + return fmt.Errorf("failed to prepare increment currency stat statement: %w", err) + } + defer done() + + const updateQuery = `UPDATE contracts_v2 SET raw_revision=?, revision_number=? WHERE contract_id=? RETURNING id, contract_status` + + var contractDBID int64 + var status contracts.V2ContractStatus + err = tx.QueryRow(updateQuery, encode(revision), encode(revision.RevisionNumber), encode(id)).Scan(&contractDBID, &status) + if err != nil { + return fmt.Errorf("failed to update contract: %w", err) + } else if err := incrementV2ContractUsage(tx, contractDBID, usage); err != nil { + return fmt.Errorf("failed to update contract usage: %w", err) + } + + // only increment metrics if the contract is active. + // If the contract is pending or some variant of successful, the metrics + // will already be handled. + if status == contracts.V2ContractStatusActive { + if err := updatePotentialRevenueMetrics(usage, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update potential revenue: %w", err) + } else if err := updateCollateralMetrics(types.ZeroCurrency, usage.RiskedCollateral, false, incrementCurrencyStat); err != nil { + return fmt.Errorf("failed to update collateral metrics: %w", err) + } + } + + selectOldSectorStmt, err := tx.Prepare(`SELECT sector_id FROM contract_v2_sector_roots WHERE contract_id=? AND root_index=?`) + if err != nil { + return fmt.Errorf("failed to prepare select old sector statement: %w", err) + } + defer selectOldSectorStmt.Close() + + selectRootIDStmt, err := tx.Prepare(`SELECT id FROM stored_sectors WHERE sector_root=?`) + if err != nil { + return fmt.Errorf("failed to prepare select root ID statement: %w", err) + } + defer selectRootIDStmt.Close() + + updateRootStmt, err := tx.Prepare(`INSERT INTO contract_v2_sector_roots (contract_id, sector_id, root_index) VALUES (?, ?, ?) ON CONFLICT (contract_id, root_index) DO UPDATE SET sector_id=excluded.sector_id`) + if err != nil { + return fmt.Errorf("failed to prepare update root statement: %w", err) + } + defer updateRootStmt.Close() + + var appended int + var deleted []int64 + seen := make(map[int64]bool) + for i, root := range roots { + // TODO: benchmark this against an exceptionally large contract. + // This is less efficient than the v1 implementation, but it leaves + // less room for update edge-cases now that all sectors are loaded + // into memory. + var newSectorID int64 + if err := selectRootIDStmt.QueryRow(encode(root)).Scan(&newSectorID); err != nil { + return fmt.Errorf("failed to get sector ID: %w", err) + } + + var oldSectorID int64 + err := selectOldSectorStmt.QueryRow(contractDBID, i).Scan(&oldSectorID) + if errors.Is(err, sql.ErrNoRows) { + // new sector + appended++ + } else if err != nil { + // db error + return fmt.Errorf("failed to get sector ID: %w", err) + } else if newSectorID == oldSectorID { + // no change + continue + } else if !seen[oldSectorID] { + // updated root + deleted = append(deleted, oldSectorID) // mark for pruning + seen[oldSectorID] = true + } + + if _, err := updateRootStmt.Exec(contractDBID, newSectorID, i); err != nil { + return fmt.Errorf("failed to update sector root: %w", err) + } + } + + cleaned, err := cleanupDanglingRoots(tx, contractDBID, int64(len(roots))) + if err != nil { + return fmt.Errorf("failed to cleanup dangling roots: %w", err) + } + for _, sectorID := range cleaned { + if seen[sectorID] { + continue + } + deleted = append(deleted, sectorID) + } + + delta := appended - len(deleted) + if err := incrementNumericStat(tx, metricContractSectors, delta, time.Now()); err != nil { + return fmt.Errorf("failed to update contract sectors: %w", err) + } + + if pruned, err := pruneSectors(tx, deleted); err != nil { + return fmt.Errorf("failed to prune sectors: %w", err) + } else if len(pruned) > 0 { + s.log.Debug("pruned sectors", zap.Int("count", len(pruned)), zap.Stringers("sectors", pruned)) + } + return nil + }) +} + // ReviseContract atomically updates a contract's revision and sectors func (s *Store) ReviseContract(revision contracts.SignedRevision, roots []types.Hash256, usage contracts.Usage, sectorChanges []contracts.SectorChange) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { // revise the contract contractID, err := reviseContract(tx, revision) if err != nil { @@ -347,183 +480,83 @@ func (s *Store) ReviseContract(revision contracts.SignedRevision, roots []types. // SectorRoots returns the sector roots for a contract. The contract must be // locked before calling. -func (s *Store) SectorRoots(contractID types.FileContractID) (roots []types.Hash256, err error) { - var dbID int64 - err = s.queryRow(`SELECT id FROM contracts WHERE contract_id=$1;`, sqlHash256(contractID)).Scan(&dbID) - if err != nil { - return nil, fmt.Errorf("failed to get contract id: %w", err) - } - - // note: OFFSET is significantly slower than using the last root_index - const query = `SELECT s.sector_root, root_index FROM contract_sector_roots c -INNER JOIN stored_sectors s ON (c.sector_id = s.id) -WHERE c.contract_id=$1 AND root_index > $2 -ORDER BY root_index ASC -LIMIT 5000` - - stmt, err := s.prepare(query) - if err != nil { - return nil, fmt.Errorf("failed to prepare query: %w", err) - } - defer stmt.Close() - - lastIndex := int64(-1) // root_index can be 0 - for { - start := time.Now() - n, err := func() (n int, err error) { - rows, err := stmt.Query(dbID, lastIndex) - if err != nil { - return 0, err - } - defer rows.Close() - - for rows.Next() { - var root types.Hash256 - - if err := rows.Scan((*sqlHash256)(&root), &lastIndex); err != nil { - return 0, fmt.Errorf("failed to scan sector root: %w", err) - } - roots = append(roots, root) - n++ - } - return n, nil - }() +func (s *Store) SectorRoots() (roots map[types.FileContractID][]types.Hash256, err error) { + err = s.transaction(func(tx *txn) error { + const query = `SELECT s.sector_root, c.contract_id FROM contract_sector_roots cr +INNER JOIN stored_sectors s ON (cr.sector_id = s.id) +INNER JOIN contracts c ON (cr.contract_id = c.id) +ORDER BY cr.contract_id, cr.root_index ASC;` + + rows, err := tx.Query(query) if err != nil { - return nil, err - } else if n < 5000 { - return roots, nil + return err } - s.log.Debug("loaded sectors", zap.Int("count", n), zap.Stringer("contractID", contractID), zap.Duration("elapsed", time.Since(start))) - } -} + defer rows.Close() -// ContractAction calls contractFn on every contract in the store that -// needs a lifecycle action performed. -func (s *Store) ContractAction(height uint64, contractFn func(types.FileContractID, uint64, string)) error { - tx := &dbTxn{s} - actions, err := rebroadcastContractActions(tx, height) - if err != nil { - return fmt.Errorf("failed to get rebroadcast actions: %w", err) - } - for _, action := range actions { - contractFn(action.ID, height, action.Action) - } - actions, err = rejectContractActions(tx, height) - if err != nil { - return fmt.Errorf("failed to get reject actions: %w", err) - } - for _, action := range actions { - contractFn(action.ID, height, action.Action) - } - actions, err = revisionContractActions(tx, height) - if err != nil { - return fmt.Errorf("failed to get revision actions: %w", err) - } - for _, action := range actions { - contractFn(action.ID, height, action.Action) - } - actions, err = resolveContractActions(tx, height) - if err != nil { - return fmt.Errorf("failed to get resolve actions: %w", err) - } - for _, action := range actions { - contractFn(action.ID, height, action.Action) - } - actions, err = expireContractActions(tx, height) - if err != nil { - return fmt.Errorf("failed to get expire actions: %w", err) - } - for _, action := range actions { - contractFn(action.ID, height, action.Action) - } - return nil -} + roots = make(map[types.FileContractID][]types.Hash256) + for rows.Next() { + var contractID types.FileContractID + var root types.Hash256 -// ContractFormationSet returns the set of transactions that were created during -// contract formation. -func (s *Store) ContractFormationSet(id types.FileContractID) ([]types.Transaction, error) { - var buf []byte - err := s.queryRow(`SELECT formation_txn_set FROM contracts WHERE contract_id=$1;`, sqlHash256(id)).Scan(&buf) - if err != nil { - return nil, fmt.Errorf("failed to query formation txn set: %w", err) - } - var txnSet []types.Transaction - if err := decodeTxnSet(buf, &txnSet); err != nil { - return nil, fmt.Errorf("failed to decode formation txn set: %w", err) - } - return txnSet, nil + if err := rows.Scan(decode(&root), decode(&contractID)); err != nil { + return fmt.Errorf("failed to scan sector root: %w", err) + } + roots[contractID] = append(roots[contractID], root) + } + return rows.Err() + }) + return } -// ExpireContract expires a contract and updates its status. Should only be used -// if the contract is active or pending. -func (s *Store) ExpireContract(id types.FileContractID, status contracts.ContractStatus) error { - return s.transaction(func(tx txn) error { - var contractID int64 - err := tx.QueryRow(`SELECT id FROM contracts WHERE contract_id=$1;`, sqlHash256(id)).Scan(&contractID) +// ContractActions returns the contract lifecycle actions for the given index. +func (s *Store) ContractActions(index types.ChainIndex, revisionBroadcastHeight uint64) (actions contracts.LifecycleActions, err error) { + err = s.transaction(func(tx *txn) error { + actions.RebroadcastFormation, err = rebroadcastContracts(tx) if err != nil { - return fmt.Errorf("failed to get contract id: %w", err) + return fmt.Errorf("failed to get formation broadcast actions: %w", err) } - // get the contract and check if the status is already set - contract, err := getContract(tx, contractID) + actions.BroadcastRevision, err = broadcastRevision(tx, index, revisionBroadcastHeight) if err != nil { - return fmt.Errorf("failed to get contract: %w", err) - } else if contract.Status == status { - return nil + return fmt.Errorf("failed to get revision broadcast actions: %w", err) + } + actions.BroadcastProof, err = proofContracts(tx, index) + if err != nil { + return fmt.Errorf("failed to get proof broadcast actions: %w", err) } - if contract.Status == contracts.ContractStatusActive || contract.Status == contracts.ContractStatusPending { - // successful, failed and rejected contracts should have already had - // their collateral removed from the metrics - if err := incrementCurrencyStat(tx, metricLockedCollateral, contract.LockedCollateral, true, time.Now()); err != nil { - return fmt.Errorf("failed to increment locked collateral stat: %w", err) - } else if err := incrementCurrencyStat(tx, metricRiskedCollateral, contract.Usage.RiskedCollateral, true, time.Now()); err != nil { - return fmt.Errorf("failed to increment risked collateral stat: %w", err) - } else if err := incrementPotentialRevenueMetrics(tx, contract.Usage, true); err != nil { - return fmt.Errorf("failed to decrement potential revenue: %w", err) - } + // v2 + actions.RebroadcastV2Formation, err = rebroadcastV2Contracts(tx) + if err != nil { + return fmt.Errorf("failed to get v2 formation broadcast actions: %w", err) } - // if the contract is successful and the final revision is confirmed, - // increment the earned revenue metrics - // - // note: if the final revision is not confirmed, the earned revenue - // may be incorrect. - if status == contracts.ContractStatusSuccessful && contract.RevisionConfirmed { - if err := incrementEarnedRevenueMetrics(tx, contract.Usage, false); err != nil { - return fmt.Errorf("failed to increment earned revenue: %w", err) - } + actions.BroadcastV2Revision, err = broadcastV2Revision(tx, index, revisionBroadcastHeight) + if err != nil { + return fmt.Errorf("failed to get v2 revision broadcast actions: %w", err) } - // update the contract status - if err := setContractStatus(tx, id, status); err != nil { - return fmt.Errorf("failed to set contract status: %w", err) + + actions.BroadcastV2Proof, err = proofV2Contracts(tx, index) + if err != nil { + return fmt.Errorf("failed to get v2 proof broadcast actions: %w", err) + } + actions.BroadcastV2Expiration, err = expireV2Contracts(tx, index) + if err != nil { + return fmt.Errorf("failed to get v2 expiration broadcast actions: %w", err) } return nil }) -} - -// LastContractChange gets the last consensus change processed by the -// contractor. -func (s *Store) LastContractChange() (id modules.ConsensusChangeID, err error) { - err = s.queryRow(`SELECT contracts_last_processed_change FROM global_settings`).Scan(nullable((*sqlHash256)(&id))) - if errors.Is(err, sql.ErrNoRows) { - return modules.ConsensusChangeBeginning, nil - } else if err != nil { - return modules.ConsensusChangeBeginning, fmt.Errorf("failed to query last contract change: %w", err) - } return } -// UpdateContractState atomically updates the contractor's state. -func (s *Store) UpdateContractState(ccID modules.ConsensusChangeID, height uint64, fn func(contracts.UpdateStateTransaction) error) error { - return s.transaction(func(tx txn) error { - utx := &updateContractsTxn{tx: tx} - if err := fn(utx); err != nil { - return err - } else if err := utx.setLastChangeID(ccID, height); err != nil { - return fmt.Errorf("failed to update last change id: %w", err) - } - return nil +// ContractChainIndexElement returns the chain index element for the given height. +func (s *Store) ContractChainIndexElement(index types.ChainIndex) (element types.ChainIndexElement, err error) { + err = s.transaction(func(tx *txn) error { + err := tx.QueryRow(`SELECT leaf_index, merkle_proof FROM contracts_v2_chain_index_elements WHERE id=? AND height=?`, encode(index.ID), index.Height).Scan(decode(&element.LeafIndex), decode(&element.MerkleProof)) + element.ChainIndex = index + element.ID = types.Hash256(index.ID) + return err }) + return } // ExpireContractSectors expires all sectors that are no longer covered by an @@ -543,9 +576,26 @@ func (s *Store) ExpireContractSectors(height uint64) error { } } -func getContract(tx txn, contractID int64) (contracts.Contract, error) { +// ExpireV2ContractSectors expires all sectors that are no longer covered by an +// active contract. +func (s *Store) ExpireV2ContractSectors(height uint64) error { + log := s.log.Named("ExpireV2ContractSectors").With(zap.Uint64("height", height)) + // delete in batches to avoid holding a lock on the database for too long + for i := 0; ; i++ { + expired, removed, err := s.batchExpireV2ContractSectors(height) + if err != nil { + return fmt.Errorf("failed to prune sectors: %w", err) + } else if expired == 0 { + return nil + } + log.Debug("removed sectors", zap.Int("expired", expired), zap.Stringers("removed", removed), zap.Int("batch", i)) + jitterSleep(time.Millisecond) // allow other transactions to run + } +} + +func getContract(tx *txn, contractID int64) (contracts.Contract, error) { const query = `SELECT c.contract_id, rt.contract_id AS renewed_to, rf.contract_id AS renewed_from, c.contract_status, c.negotiation_height, c.formation_confirmed, - c.revision_number=c.confirmed_revision_number AS revision_confirmed, c.resolution_height, c.locked_collateral, c.rpc_revenue, + COALESCE(c.revision_number=c.confirmed_revision_number, false) AS revision_confirmed, c.resolution_height, c.locked_collateral, c.rpc_revenue, c.storage_revenue, c.ingress_revenue, c.egress_revenue, c.account_funding, c.risked_collateral, c.raw_revision, c.host_sig, c.renter_sig FROM contracts c LEFT JOIN contracts rt ON (c.renewed_to = rt.id) @@ -560,9 +610,9 @@ func getContract(tx txn, contractID int64) (contracts.Contract, error) { } // appendSector appends a new sector root to a contract. -func appendSector(tx txn, contractID int64, root types.Hash256, index uint64) error { +func appendSector(tx *txn, contractID int64, root types.Hash256, index uint64) error { var sectorID int64 - err := tx.QueryRow(`INSERT INTO contract_sector_roots (contract_id, sector_id, root_index) SELECT $1, id, $2 FROM stored_sectors WHERE sector_root=$3 RETURNING sector_id`, contractID, index, sqlHash256(root)).Scan(§orID) + err := tx.QueryRow(`INSERT INTO contract_sector_roots (contract_id, sector_id, root_index) SELECT $1, id, $2 FROM stored_sectors WHERE sector_root=$3 RETURNING sector_id`, contractID, index, encode(root)).Scan(§orID) if err != nil { return err } else if err := incrementNumericStat(tx, metricContractSectors, 1, time.Now()); err != nil { @@ -572,7 +622,7 @@ func appendSector(tx txn, contractID int64, root types.Hash256, index uint64) er } // updateSector updates a contract sector root in place and returns the old sector root -func updateSector(tx txn, contractID int64, root types.Hash256, index uint64) (types.Hash256, error) { +func updateSector(tx *txn, contractID int64, root types.Hash256, index uint64) (types.Hash256, error) { row := tx.QueryRow(`SELECT csr.id, csr.sector_id, ss.sector_root FROM contract_sector_roots csr INNER JOIN stored_sectors ss ON (csr.sector_id = ss.id) @@ -583,7 +633,7 @@ WHERE contract_id=$1 AND root_index=$2`, contractID, index) } var newSectorID int64 - err = tx.QueryRow(`SELECT id FROM stored_sectors WHERE sector_root=$1`, sqlHash256(root)).Scan(&newSectorID) + err = tx.QueryRow(`SELECT id FROM stored_sectors WHERE sector_root=$1`, encode(root)).Scan(&newSectorID) if err != nil { return types.Hash256{}, fmt.Errorf("failed to get new sector id: %w", err) } @@ -604,7 +654,7 @@ WHERE contract_id=$1 AND root_index=$2`, contractID, index) } // swapSectors swaps two sector roots in a contract and returns the sector roots -func swapSectors(tx txn, contractID int64, i, j uint64) (map[types.Hash256]bool, error) { +func swapSectors(tx *txn, contractID int64, i, j uint64) (map[types.Hash256]bool, error) { if i == j { return nil, nil } @@ -660,7 +710,7 @@ ORDER BY root_index ASC;`, contractID, i, j) // trimSectors deletes the last n sector roots for a contract and returns the // deleted sector roots in order. -func trimSectors(tx txn, contractID int64, n uint64, log *zap.Logger) ([]types.Hash256, error) { +func trimSectors(tx *txn, contractID int64, n uint64, log *zap.Logger) ([]types.Hash256, error) { selectStmt, err := tx.Prepare(`SELECT csr.id, csr.sector_id, ss.sector_root FROM contract_sector_roots csr INNER JOIN stored_sectors ss ON (csr.sector_id=ss.id) WHERE csr.contract_id=$1 @@ -683,7 +733,7 @@ LIMIT 1`) var root types.Hash256 var sectorID int64 - if err := selectStmt.QueryRow(contractID).Scan(&contractSectorID, §orID, (*sqlHash256)(&root)); err != nil { + if err := selectStmt.QueryRow(contractID).Scan(&contractSectorID, §orID, decode(&root)); err != nil { return nil, fmt.Errorf("failed to get sector root: %w", err) } else if res, err := deleteStmt.Exec(contractSectorID); err != nil { return nil, fmt.Errorf("failed to delete sector root: %w", err) @@ -710,18 +760,75 @@ LIMIT 1`) return roots, nil } +func deleteExpiredContractSectors(tx *txn, height uint64) (sectorIDs []int64, err error) { + const query = `DELETE FROM contract_sector_roots +WHERE id IN (SELECT csr.id FROM contract_sector_roots csr +INNER JOIN contracts c ON (csr.contract_id=c.id) +-- past proof window or not confirmed and past the rebroadcast height +WHERE c.window_end < $1 OR c.contract_status=$2 LIMIT $3) +RETURNING sector_id;` + rows, err := tx.Query(query, height, contracts.ContractStatusRejected, sqlSectorBatchSize) + if err != nil { + return nil, err + } + defer rows.Close() + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, err + } + sectorIDs = append(sectorIDs, id) + } + return sectorIDs, nil +} + +func deleteExpiredV2ContractSectors(tx *txn, height uint64) (sectorIDs []int64, err error) { + const query = `DELETE FROM contract_v2_sector_roots +WHERE id IN (SELECT csr.id FROM contract_v2_sector_roots csr +INNER JOIN contracts_v2 c ON (csr.contract_id=c.id) +-- past proof window or not confirmed and past the rebroadcast height +WHERE c.window_end < $1 OR c.contract_status=$2 LIMIT $3) +RETURNING sector_id;` + rows, err := tx.Query(query, height, contracts.ContractStatusRejected, sqlSectorBatchSize) + if err != nil { + return nil, err + } + defer rows.Close() + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, err + } + sectorIDs = append(sectorIDs, id) + } + return sectorIDs, nil +} + +// updateResolvedV2Contract clears a contract and returns its ID +func updateResolvedV2Contract(tx *txn, contractID types.FileContractID, clearing types.V2FileContract, renewedDBID int64) (dbID int64, err error) { + // add the final usage to the contract revenue + const clearQuery = `UPDATE contracts_v2 SET (renewed_to, revision_number, raw_revision) = ($1, $2, $3) WHERE contract_id=$4 RETURNING id;` + err = tx.QueryRow(clearQuery, + renewedDBID, + encode(clearing.RevisionNumber), + encode(clearing), + encode(contractID), + ).Scan(&dbID) + return +} + // clearContract clears a contract and returns its ID -func clearContract(tx txn, revision contracts.SignedRevision, renewedDBID int64, usage contracts.Usage) (dbID int64, err error) { +func clearContract(tx *txn, revision contracts.SignedRevision, renewedDBID int64, usage contracts.Usage) (dbID int64, err error) { // get the existing contract's current usage var total contracts.Usage - err = tx.QueryRow(`SELECT id, rpc_revenue, storage_revenue, ingress_revenue, egress_revenue, account_funding, risked_collateral FROM contracts WHERE contract_id=$1`, sqlHash256(revision.Revision.ParentID)).Scan( + err = tx.QueryRow(`SELECT id, rpc_revenue, storage_revenue, ingress_revenue, egress_revenue, account_funding, risked_collateral FROM contracts WHERE contract_id=$1`, encode(revision.Revision.ParentID)).Scan( &dbID, - (*sqlCurrency)(&total.RPCRevenue), - (*sqlCurrency)(&total.StorageRevenue), - (*sqlCurrency)(&total.IngressRevenue), - (*sqlCurrency)(&total.EgressRevenue), - (*sqlCurrency)(&total.AccountFunding), - (*sqlCurrency)(&total.RiskedCollateral)) + decode(&total.RPCRevenue), + decode(&total.StorageRevenue), + decode(&total.IngressRevenue), + decode(&total.EgressRevenue), + decode(&total.AccountFunding), + decode(&total.RiskedCollateral)) if err != nil { return 0, fmt.Errorf("failed to get existing usage: %w", err) } @@ -731,258 +838,330 @@ func clearContract(tx txn, revision contracts.SignedRevision, renewedDBID int64, const clearQuery = `UPDATE contracts SET (renewed_to, revision_number, host_sig, renter_sig, raw_revision, rpc_revenue, storage_revenue, ingress_revenue, egress_revenue, account_funding, risked_collateral) = ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) WHERE id=$12 RETURNING id;` err = tx.QueryRow(clearQuery, renewedDBID, - sqlUint64(revision.Revision.RevisionNumber), - sqlHash512(revision.HostSignature), - sqlHash512(revision.RenterSignature), - encodeRevision(revision.Revision), - sqlCurrency(total.RPCRevenue), - sqlCurrency(total.StorageRevenue), - sqlCurrency(total.IngressRevenue), - sqlCurrency(total.EgressRevenue), - sqlCurrency(total.AccountFunding), - sqlCurrency(total.RiskedCollateral), + encode(revision.Revision.RevisionNumber), + encode(revision.HostSignature), + encode(revision.RenterSignature), + encode(revision.Revision), + encode(total.RPCRevenue), + encode(total.StorageRevenue), + encode(total.IngressRevenue), + encode(total.EgressRevenue), + encode(total.AccountFunding), + encode(total.RiskedCollateral), dbID, ).Scan(&dbID) return } // reviseContract revises a contract and returns its ID -func reviseContract(tx txn, revision contracts.SignedRevision) (dbID int64, err error) { +func reviseContract(tx *txn, revision contracts.SignedRevision) (dbID int64, err error) { err = tx.QueryRow(`UPDATE contracts SET (revision_number, window_start, window_end, raw_revision, host_sig, renter_sig) = ($1, $2, $3, $4, $5, $6) WHERE contract_id=$7 RETURNING id;`, - sqlUint64(revision.Revision.RevisionNumber), + encode(revision.Revision.RevisionNumber), revision.Revision.WindowStart, revision.Revision.WindowEnd, - encodeRevision(revision.Revision), - sqlHash512(revision.HostSignature), - sqlHash512(revision.RenterSignature), - sqlHash256(revision.Revision.ParentID), + encode(revision.Revision), + encode(revision.HostSignature), + encode(revision.RenterSignature), + encode(revision.Revision.ParentID), ).Scan(&dbID) return } -func incrementContractUsage(tx txn, dbID int64, usage contracts.Usage) error { - const query = `SELECT rpc_revenue, storage_revenue, ingress_revenue, egress_revenue, account_funding, risked_collateral FROM contracts WHERE id=$1;` - var total contracts.Usage - err := tx.QueryRow(query, dbID).Scan( - (*sqlCurrency)(&total.RPCRevenue), - (*sqlCurrency)(&total.StorageRevenue), - (*sqlCurrency)(&total.IngressRevenue), - (*sqlCurrency)(&total.EgressRevenue), - (*sqlCurrency)(&total.AccountFunding), - (*sqlCurrency)(&total.RiskedCollateral)) +func rebroadcastContracts(tx *txn) (rebroadcast [][]types.Transaction, err error) { + rows, err := tx.Query(`SELECT formation_txn_set FROM contracts WHERE formation_confirmed=false AND contract_status <> ?`, contracts.ContractStatusRejected) if err != nil { - return fmt.Errorf("failed to get existing revenue: %w", err) + return nil, err } - total = total.Add(usage) - var updatedID int64 - err = tx.QueryRow(`UPDATE contracts SET (rpc_revenue, storage_revenue, ingress_revenue, egress_revenue, account_funding, risked_collateral) = ($1, $2, $3, $4, $5, $6) WHERE id=$7 RETURNING id;`, - sqlCurrency(total.RPCRevenue), - sqlCurrency(total.StorageRevenue), - sqlCurrency(total.IngressRevenue), - sqlCurrency(total.EgressRevenue), - sqlCurrency(total.AccountFunding), - sqlCurrency(total.RiskedCollateral), - dbID).Scan(&updatedID) - if err != nil { - return fmt.Errorf("failed to update contract revenue: %w", err) + defer rows.Close() + + for rows.Next() { + var buf []byte + if err := rows.Scan(&buf); err != nil { + return nil, fmt.Errorf("failed to scan contract id: %w", err) + } + var formationSet []types.Transaction + if err := decodeTxnSet(buf, &formationSet); err != nil { + return nil, fmt.Errorf("failed to decode formation txn set: %w", err) + } + rebroadcast = append(rebroadcast, formationSet) } - return nil + if err := rows.Err(); err != nil { + return nil, err + } + return } -func rebroadcastContractActions(tx txn, height uint64) (actions []contractAction, _ error) { - // formation not confirmed, within rebroadcast window - const query = `SELECT contract_id FROM contracts WHERE formation_confirmed=false AND negotiation_height BETWEEN $1 AND $2` +func broadcastRevision(tx *txn, index types.ChainIndex, revisionBroadcastHeight uint64) (revisions []contracts.SignedRevision, err error) { + const query = `SELECT raw_revision, host_sig, renter_sig + FROM contracts + WHERE formation_confirmed=true AND confirmed_revision_number != revision_number AND window_start BETWEEN ? AND ?` - var minNegotiationHeight uint64 - if height >= contracts.RebroadcastBuffer { - minNegotiationHeight = height - contracts.RebroadcastBuffer - } - - rows, err := tx.Query(query, minNegotiationHeight, height) + rows, err := tx.Query(query, index.Height, revisionBroadcastHeight) if err != nil { - return nil, fmt.Errorf("failed to query contracts: %w", err) + return nil, err } defer rows.Close() for rows.Next() { - action := contractAction{ - Action: contracts.ActionBroadcastFormation, - } - if err := rows.Scan((*sqlHash256)(&action.ID)); err != nil { + var rev contracts.SignedRevision + err = rows.Scan( + decode(&rev.Revision), + decode(&rev.HostSignature), + decode(&rev.RenterSignature)) + if err != nil { return nil, fmt.Errorf("failed to scan contract: %w", err) } - actions = append(actions, action) + revisions = append(revisions, rev) + } + if err := rows.Err(); err != nil { + return nil, err } return } -func rejectContractActions(tx txn, height uint64) (actions []contractAction, _ error) { - // formation not confirmed, not rejected, outside rebroadcast window - const query = `SELECT contract_id FROM contracts WHERE formation_confirmed=false AND negotiation_height < $1 AND contract_status != $2` +func proofContracts(tx *txn, index types.ChainIndex) (revisions []contracts.SignedRevision, err error) { + const query = `SELECT raw_revision, host_sig, renter_sig + FROM contracts + WHERE formation_confirmed AND resolution_height IS NULL AND window_start <= $1 AND window_end > $1` - var maxRebroadcastHeight uint64 - if height >= contracts.RebroadcastBuffer { - maxRebroadcastHeight = height - contracts.RebroadcastBuffer + rows, err := tx.Query(query, index.Height) + if err != nil { + return nil, err } + defer rows.Close() - rows, err := tx.Query(query, maxRebroadcastHeight, contracts.ContractStatusRejected) + for rows.Next() { + contract, err := scanSignedRevision(rows) + if err != nil { + return nil, fmt.Errorf("failed to scan contract: %w", err) + } + revisions = append(revisions, contract) + } + if err := rows.Err(); err != nil { + return nil, err + } + return +} + +func rebroadcastV2Contracts(tx *txn) (rebroadcast []contracts.V2FormationTransactionSet, err error) { + rows, err := tx.Query(`SELECT formation_txn_set, formation_txn_set_basis FROM contracts_v2 WHERE confirmation_index IS NULL AND contract_status <> ?`, contracts.ContractStatusRejected) if err != nil { - return nil, fmt.Errorf("failed to query contracts: %w", err) + return nil, err } defer rows.Close() for rows.Next() { - action := contractAction{ - Action: contracts.ActionReject, + var formationSet contracts.V2FormationTransactionSet + var buf []byte + if err := rows.Scan(&buf, decode(&formationSet.Basis)); err != nil { + return nil, fmt.Errorf("failed to scan contract id: %w", err) } - if err := rows.Scan((*sqlHash256)(&action.ID)); err != nil { - return nil, fmt.Errorf("failed to scan contract: %w", err) + dec := types.NewBufDecoder(buf) + types.DecodeSlice(dec, &formationSet.TransactionSet) + if err := dec.Err(); err != nil { + return nil, fmt.Errorf("failed to decode formation txn set: %w", err) } - actions = append(actions, action) + rebroadcast = append(rebroadcast, formationSet) + } + if err := rows.Err(); err != nil { + return nil, err } return } -func revisionContractActions(tx txn, height uint64) (actions []contractAction, _ error) { - // formation confirmed, revision not confirmed, just outside proof window - const query = `SELECT contract_id FROM contracts WHERE formation_confirmed=true AND confirmed_revision_number != revision_number AND window_start BETWEEN $1 AND $2` - minRevisionHeight := height + contracts.RevisionSubmissionBuffer - rows, err := tx.Query(query, height, minRevisionHeight) +func broadcastV2Revision(tx *txn, index types.ChainIndex, revisionBroadcastHeight uint64) (revisions []types.V2FileContractRevision, err error) { + const query = `SELECT c.raw_revision, c.contract_id, cs.leaf_index, cs.merkle_proof, cs.raw_contract + FROM contracts_v2 c + INNER JOIN contract_v2_state_elements cs ON (c.id = cs.contract_id) + WHERE c.confirmation_index IS NOT NULL AND c.resolution_index IS NULL AND cs.revision_number != c.revision_number AND c.window_start BETWEEN ? AND ?` + + rows, err := tx.Query(query, index.Height, revisionBroadcastHeight) if err != nil { - return nil, fmt.Errorf("failed to query contracts: %w", err) + return nil, err } defer rows.Close() for rows.Next() { - action := contractAction{ - Action: contracts.ActionBroadcastFinalRevision, - } - if err := rows.Scan((*sqlHash256)(&action.ID)); err != nil { + var rev types.V2FileContractRevision + + err = rows.Scan(decode(&rev.Revision), + decode(&rev.Parent.ID), + decode(&rev.Parent.LeafIndex), + decode(&rev.Parent.MerkleProof), + decode(&rev.Parent.V2FileContract)) + if err != nil { return nil, fmt.Errorf("failed to scan contract: %w", err) } - actions = append(actions, action) + revisions = append(revisions, rev) + } + if err := rows.Err(); err != nil { + return nil, err } return } -func resolveContractActions(tx txn, height uint64) (actions []contractAction, _ error) { - // formation confirmed, resolution not confirmed, status active, in proof window - const query = `SELECT contract_id FROM contracts WHERE formation_confirmed=true AND resolution_height IS NULL AND window_start <= $1 AND window_end > $1` - rows, err := tx.Query(query, height) +func proofV2Contracts(tx *txn, index types.ChainIndex) (elements []types.V2FileContractElement, err error) { + const query = `SELECT c.contract_id, cs.raw_contract, cs.leaf_index, cs.merkle_proof + FROM contracts_v2 c + INNER JOIN contract_v2_state_elements cs ON (c.id = cs.contract_id) + WHERE c.confirmation_index IS NOT NULL AND c.resolution_index IS NULL AND c.window_start <= $1 AND c.window_end > $1` + + rows, err := tx.Query(query, index.Height) if err != nil { - return nil, fmt.Errorf("failed to query contracts: %w", err) + return nil, err } defer rows.Close() for rows.Next() { - action := contractAction{ - Action: contracts.ActionBroadcastResolution, - } - if err := rows.Scan((*sqlHash256)(&action.ID)); err != nil { + var fce types.V2FileContractElement + if err := rows.Scan(decode(&fce.ID), decode(&fce.V2FileContract), decode(&fce.LeafIndex), decode(&fce.MerkleProof)); err != nil { return nil, fmt.Errorf("failed to scan contract: %w", err) } - actions = append(actions, action) + elements = append(elements, fce) + } + if err := rows.Err(); err != nil { + return nil, err } return } -func expireContractActions(tx txn, height uint64) (actions []contractAction, _ error) { - const query = `SELECT contract_id FROM contracts WHERE window_end < $1 AND contract_status = $2;` - rows, err := tx.Query(query, height, contracts.ContractStatusActive) +func expireV2Contracts(tx *txn, index types.ChainIndex) (elements []types.V2FileContractElement, err error) { + const query = `SELECT c.contract_id, cs.raw_contract, cs.leaf_index, cs.merkle_proof + FROM contracts_v2 c + INNER JOIN contract_v2_state_elements cs ON (c.id = cs.contract_id) + WHERE c.resolution_index IS NULL AND c.window_end <= $1` + + rows, err := tx.Query(query, index.Height) if err != nil { - return nil, fmt.Errorf("failed to query contracts: %w", err) + return nil, err } defer rows.Close() for rows.Next() { - action := contractAction{ - Action: contracts.ActionExpire, - } - if err := rows.Scan((*sqlHash256)(&action.ID)); err != nil { + var fce types.V2FileContractElement + if err := rows.Scan(decode(&fce.ID), decode(&fce.V2FileContract), decode(&fce.LeafIndex), decode(&fce.MerkleProof)); err != nil { return nil, fmt.Errorf("failed to scan contract: %w", err) } - actions = append(actions, action) + elements = append(elements, fce) + } + if err := rows.Err(); err != nil { + return nil, err } return } -func renterDBID(tx txn, renterKey types.PublicKey) (int64, error) { +func incrementContractUsage(tx *txn, dbID int64, usage contracts.Usage) error { + const query = `SELECT rpc_revenue, storage_revenue, ingress_revenue, egress_revenue, account_funding, risked_collateral FROM contracts WHERE id=$1;` + var total contracts.Usage + err := tx.QueryRow(query, dbID).Scan( + decode(&total.RPCRevenue), + decode(&total.StorageRevenue), + decode(&total.IngressRevenue), + decode(&total.EgressRevenue), + decode(&total.AccountFunding), + decode(&total.RiskedCollateral)) + if err != nil { + return fmt.Errorf("failed to get existing revenue: %w", err) + } + total = total.Add(usage) + var updatedID int64 + err = tx.QueryRow(`UPDATE contracts SET (rpc_revenue, storage_revenue, ingress_revenue, egress_revenue, account_funding, risked_collateral) = ($1, $2, $3, $4, $5, $6) WHERE id=$7 RETURNING id;`, + encode(total.RPCRevenue), + encode(total.StorageRevenue), + encode(total.IngressRevenue), + encode(total.EgressRevenue), + encode(total.AccountFunding), + encode(total.RiskedCollateral), + dbID).Scan(&updatedID) + if err != nil { + return fmt.Errorf("failed to update contract revenue: %w", err) + } + return nil +} + +func renterDBID(tx *txn, renterKey types.PublicKey) (int64, error) { var dbID int64 - err := tx.QueryRow(`SELECT id FROM contract_renters WHERE public_key=$1;`, sqlHash256(renterKey)).Scan(&dbID) + err := tx.QueryRow(`SELECT id FROM contract_renters WHERE public_key=$1;`, encode(renterKey)).Scan(&dbID) if err == nil { return dbID, nil } else if !errors.Is(err, sql.ErrNoRows) { return 0, fmt.Errorf("failed to get renter: %w", err) } - err = tx.QueryRow(`INSERT INTO contract_renters (public_key) VALUES ($1) RETURNING id;`, sqlHash256(renterKey)).Scan(&dbID) + err = tx.QueryRow(`INSERT INTO contract_renters (public_key) VALUES ($1) RETURNING id;`, encode(renterKey)).Scan(&dbID) return dbID, err } -func insertContract(tx txn, revision contracts.SignedRevision, formationSet []types.Transaction, lockedCollateral types.Currency, initialUsage contracts.Usage, negotationHeight uint64) (dbID int64, err error) { +func insertContract(tx *txn, revision contracts.SignedRevision, formationSet []types.Transaction, lockedCollateral types.Currency, initialUsage contracts.Usage, negotationHeight uint64) (dbID int64, err error) { const query = `INSERT INTO contracts (contract_id, renter_id, locked_collateral, rpc_revenue, storage_revenue, ingress_revenue, egress_revenue, registry_read, registry_write, account_funding, risked_collateral, revision_number, negotiation_height, window_start, window_end, formation_txn_set, -raw_revision, host_sig, renter_sig, confirmed_revision_number, formation_confirmed, contract_status) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22) RETURNING id;` +raw_revision, host_sig, renter_sig, confirmed_revision_number, contract_status, formation_confirmed) VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, false) RETURNING id;` renterID, err := renterDBID(tx, revision.RenterKey()) if err != nil { return 0, fmt.Errorf("failed to get renter id: %w", err) } err = tx.QueryRow(query, - sqlHash256(revision.Revision.ParentID), + encode(revision.Revision.ParentID), renterID, - sqlCurrency(lockedCollateral), - sqlCurrency(initialUsage.RPCRevenue), - sqlCurrency(initialUsage.StorageRevenue), - sqlCurrency(initialUsage.IngressRevenue), - sqlCurrency(initialUsage.EgressRevenue), - sqlCurrency(initialUsage.RegistryRead), - sqlCurrency(initialUsage.RegistryWrite), - sqlCurrency(initialUsage.AccountFunding), - sqlCurrency(initialUsage.RiskedCollateral), - sqlUint64(revision.Revision.RevisionNumber), + encode(lockedCollateral), + encode(initialUsage.RPCRevenue), + encode(initialUsage.StorageRevenue), + encode(initialUsage.IngressRevenue), + encode(initialUsage.EgressRevenue), + encode(initialUsage.RegistryRead), + encode(initialUsage.RegistryWrite), + encode(initialUsage.AccountFunding), + encode(initialUsage.RiskedCollateral), + encode(revision.Revision.RevisionNumber), negotationHeight, // stored as int64 for queries, should never overflow revision.Revision.WindowStart, // stored as int64 for queries, should never overflow revision.Revision.WindowEnd, // stored as int64 for queries, should never overflow encodeTxnSet(formationSet), - encodeRevision(revision.Revision), - sqlHash512(revision.HostSignature), - sqlHash512(revision.RenterSignature), - sqlUint64(0), // confirmed_revision_number - false, // formation_confirmed + encode(revision.Revision), + encode(revision.HostSignature), + encode(revision.RenterSignature), + encode(0), // confirmed_revision_number contracts.ContractStatusPending, ).Scan(&dbID) if err != nil { return 0, fmt.Errorf("failed to insert contract: %w", err) } - // increment the contract count metric - if err := incrementNumericStat(tx, metricPendingContracts, 1, time.Now()); err != nil { - return 0, fmt.Errorf("failed to track pending contracts: %w", err) - } - // increment the collateral metrics - if err := incrementCurrencyStat(tx, metricLockedCollateral, lockedCollateral, false, time.Now()); err != nil { - return 0, fmt.Errorf("failed to track locked collateral: %w", err) - } else if err := incrementCurrencyStat(tx, metricRiskedCollateral, initialUsage.RiskedCollateral, false, time.Now()); err != nil { - return 0, fmt.Errorf("failed to track risked collateral: %w", err) - } - // increment the potential revenue metrics - if err := incrementPotentialRevenueMetrics(tx, initialUsage, false); err != nil { - return 0, fmt.Errorf("failed to increment potential revenue: %w", err) - } return } -func encodeRevision(fcr types.FileContractRevision) []byte { - var buf bytes.Buffer - e := types.NewEncoder(&buf) - fcr.EncodeTo(e) - e.Flush() - return buf.Bytes() -} +func insertV2Contract(tx *txn, contract contracts.V2Contract, formationSet contracts.V2FormationTransactionSet) (dbID int64, err error) { + const query = `INSERT INTO contracts_v2 (contract_id, renter_id, locked_collateral, rpc_revenue, storage_revenue, ingress_revenue, +egress_revenue, account_funding, risked_collateral, revision_number, negotiation_height, window_start, window_end, formation_txn_set, +formation_txn_set_basis, raw_revision, contract_status) VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) RETURNING id;` + renterID, err := renterDBID(tx, contract.RenterPublicKey) + if err != nil { + return 0, fmt.Errorf("failed to get renter id: %w", err) + } -func decodeRevision(b []byte, fcr *types.FileContractRevision) error { - d := types.NewBufDecoder(b) - fcr.DecodeFrom(d) - return d.Err() + err = tx.QueryRow(query, + encode(contract.ID), + renterID, + encode(contract.V2FileContract.TotalCollateral), + encode(contract.Usage.RPCRevenue), + encode(contract.Usage.StorageRevenue), + encode(contract.Usage.IngressRevenue), + encode(contract.Usage.EgressRevenue), + encode(contract.Usage.AccountFunding), + encode(contract.Usage.RiskedCollateral), + encode(contract.RevisionNumber), + contract.NegotiationHeight, // stored as int64 for queries, should never overflow + contract.V2FileContract.ProofHeight, // stored as int64 for queries, should never overflow + contract.ExpirationHeight, // stored as int64 for queries, should never overflow + encodeSlice(formationSet.TransactionSet), + encode(formationSet.Basis), + encode(contract.V2FileContract), + contracts.V2ContractStatusPending, + ).Scan(&dbID) + if err != nil { + return 0, fmt.Errorf("failed to insert contract: %w", err) + } + return } func encodeTxnSet(txns []types.Transaction) []byte { @@ -1011,28 +1190,28 @@ func buildContractFilter(filter contracts.ContractFilter) (string, []any, error) if len(filter.ContractIDs) != 0 { whereClause = append(whereClause, `c.contract_id IN (`+queryPlaceHolders(len(filter.ContractIDs))+`)`) for _, value := range filter.ContractIDs { - queryParams = append(queryParams, sqlHash256(value)) + queryParams = append(queryParams, encode(value)) } } if len(filter.RenewedFrom) != 0 { whereClause = append(whereClause, `rf.contract_id IN (`+queryPlaceHolders(len(filter.RenewedFrom))+`)`) for _, value := range filter.RenewedFrom { - queryParams = append(queryParams, sqlHash256(value)) + queryParams = append(queryParams, encode(value)) } } if len(filter.RenewedTo) != 0 { whereClause = append(whereClause, `rt.contract_id IN (`+queryPlaceHolders(len(filter.RenewedTo))+`)`) for _, value := range filter.RenewedTo { - queryParams = append(queryParams, sqlHash256(value)) + queryParams = append(queryParams, encode(value)) } } if len(filter.RenterKey) != 0 { whereClause = append(whereClause, `r.public_key IN (`+queryPlaceHolders(len(filter.RenterKey))+`)`) for _, value := range filter.RenterKey { - queryParams = append(queryParams, sqlHash256(value)) + queryParams = append(queryParams, encode(value)) } } @@ -1085,104 +1264,71 @@ func buildOrderBy(filter contracts.ContractFilter) string { } func scanContract(row scanner) (c contracts.Contract, err error) { - var revisionBuf []byte var contractID types.FileContractID - var resolutionHeight sql.NullInt64 - err = row.Scan((*sqlHash256)(&contractID), - nullable((*sqlHash256)(&c.RenewedTo)), - nullable((*sqlHash256)(&c.RenewedFrom)), + err = row.Scan(decode(&contractID), + decodeNullable(&c.RenewedTo), + decodeNullable(&c.RenewedFrom), &c.Status, &c.NegotiationHeight, &c.FormationConfirmed, &c.RevisionConfirmed, - &resolutionHeight, - (*sqlCurrency)(&c.LockedCollateral), - (*sqlCurrency)(&c.Usage.RPCRevenue), - (*sqlCurrency)(&c.Usage.StorageRevenue), - (*sqlCurrency)(&c.Usage.IngressRevenue), - (*sqlCurrency)(&c.Usage.EgressRevenue), - (*sqlCurrency)(&c.Usage.AccountFunding), - (*sqlCurrency)(&c.Usage.RiskedCollateral), - &revisionBuf, - (*sqlHash512)(&c.HostSignature), - (*sqlHash512)(&c.RenterSignature), + decodeNullable(&c.ResolutionHeight), + decode(&c.LockedCollateral), + decode(&c.Usage.RPCRevenue), + decode(&c.Usage.StorageRevenue), + decode(&c.Usage.IngressRevenue), + decode(&c.Usage.EgressRevenue), + decode(&c.Usage.AccountFunding), + decode(&c.Usage.RiskedCollateral), + decode(&c.Revision), + decode(&c.HostSignature), + decode(&c.RenterSignature), ) if err != nil { return contracts.Contract{}, fmt.Errorf("failed to scan contract: %w", err) - } else if err := decodeRevision(revisionBuf, &c.Revision); err != nil { - return contracts.Contract{}, fmt.Errorf("failed to decode revision: %w", err) } else if c.Revision.ParentID != contractID { panic("contract id mismatch") - } else if resolutionHeight.Valid { - c.ResolutionHeight = uint64(resolutionHeight.Int64) } return } -func updateContractMetrics(tx txn, current, next contracts.ContractStatus) error { - if current == next { - return nil - } - - var initialMetric, finalMetric string - switch current { - case contracts.ContractStatusPending: - initialMetric = metricPendingContracts - case contracts.ContractStatusRejected: - initialMetric = metricRejectedContracts - case contracts.ContractStatusActive: - initialMetric = metricActiveContracts - case contracts.ContractStatusSuccessful: - initialMetric = metricSuccessfulContracts - case contracts.ContractStatusFailed: - initialMetric = metricFailedContracts - default: - return fmt.Errorf("invalid prev contract status: %v", current) - } - switch next { - case contracts.ContractStatusPending: - finalMetric = metricPendingContracts - case contracts.ContractStatusRejected: - finalMetric = metricRejectedContracts - case contracts.ContractStatusActive: - finalMetric = metricActiveContracts - case contracts.ContractStatusSuccessful: - finalMetric = metricSuccessfulContracts - case contracts.ContractStatusFailed: - finalMetric = metricFailedContracts - default: - return fmt.Errorf("invalid contract status: %v", current) - } - - if err := incrementNumericStat(tx, initialMetric, -1, time.Now()); err != nil { - return fmt.Errorf("failed to decrement initial contract metric: %w", err) - } else if err := incrementNumericStat(tx, finalMetric, 1, time.Now()); err != nil { - return fmt.Errorf("failed to increment final contract metric: %w", err) +func scanV2Contract(row scanner) (c contracts.V2Contract, err error) { + err = row.Scan(decode(&c.ID), + decodeNullable(&c.RenewedTo), + decodeNullable(&c.RenewedFrom), + &c.Status, + &c.NegotiationHeight, + decodeNullable(&c.FormationIndex), + &c.RevisionConfirmed, + decodeNullable(&c.ResolutionIndex), + decode(&c.Usage.RPCRevenue), + decode(&c.Usage.StorageRevenue), + decode(&c.Usage.IngressRevenue), + decode(&c.Usage.EgressRevenue), + decode(&c.Usage.AccountFunding), + decode(&c.Usage.RiskedCollateral), + decode(&c.V2FileContract), + ) + if errors.Is(err, sql.ErrNoRows) { + err = contracts.ErrNotFound } - return nil + return } -func setContractStatus(tx txn, id types.FileContractID, status contracts.ContractStatus) error { - var current contracts.ContractStatus - if err := tx.QueryRow(`SELECT contract_status FROM contracts WHERE contract_id=$1`, sqlHash256(id)).Scan(¤t); err != nil { - return fmt.Errorf("failed to query contract status: %w", err) - } - - var dbID int64 - if err := tx.QueryRow(`UPDATE contracts SET contract_status=$1 WHERE contract_id=$2 RETURNING id;`, status, sqlHash256(id)).Scan(&dbID); err != nil { - return fmt.Errorf("failed to update contract status: %w", err) - } else if err := updateContractMetrics(tx, current, status); err != nil { - return fmt.Errorf("failed to update contract metrics: %w", err) - } - return nil +func scanSignedRevision(row scanner) (rev contracts.SignedRevision, err error) { + err = row.Scan( + decode(&rev.Revision), + decode(&rev.HostSignature), + decode(&rev.RenterSignature)) + return } func scanContractSectorRootRef(s scanner) (ref contractSectorRootRef, err error) { - err = s.Scan(&ref.dbID, &ref.sectorID, (*sqlHash256)(&ref.root)) + err = s.Scan(&ref.dbID, &ref.sectorID, decode(&ref.root)) return } -func incrementPotentialRevenueMetrics(tx txn, usage contracts.Usage, negative bool) error { +func incrementPotentialRevenueMetrics(tx *txn, usage contracts.Usage, negative bool) error { if err := incrementCurrencyStat(tx, metricPotentialRPCRevenue, usage.RPCRevenue, negative, time.Now()); err != nil { return fmt.Errorf("failed to increment rpc revenue stat: %w", err) } else if err := incrementCurrencyStat(tx, metricPotentialStorageRevenue, usage.StorageRevenue, negative, time.Now()); err != nil { @@ -1199,7 +1345,7 @@ func incrementPotentialRevenueMetrics(tx txn, usage contracts.Usage, negative bo return nil } -func incrementEarnedRevenueMetrics(tx txn, usage contracts.Usage, negative bool) error { +func incrementEarnedRevenueMetrics(tx *txn, usage contracts.Usage, negative bool) error { if err := incrementCurrencyStat(tx, metricEarnedRPCRevenue, usage.RPCRevenue, negative, time.Now()); err != nil { return fmt.Errorf("failed to increment rpc revenue stat: %w", err) } else if err := incrementCurrencyStat(tx, metricEarnedStorageRevenue, usage.StorageRevenue, negative, time.Now()); err != nil { diff --git a/persist/sqlite/contracts_test.go b/persist/sqlite/contracts_test.go index 9aac9c89..24a677c2 100644 --- a/persist/sqlite/contracts_test.go +++ b/persist/sqlite/contracts_test.go @@ -15,12 +15,37 @@ import ( ) func (s *Store) rootAtIndex(contractID types.FileContractID, rootIndex int64) (root types.Hash256, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { const query = `SELECT s.sector_root FROM contract_sector_roots csr INNER JOIN stored_sectors s ON (csr.sector_id = s.id) INNER JOIN contracts c ON (csr.contract_id = c.id) WHERE c.contract_id=$1 AND csr.root_index=$2;` - return tx.QueryRow(query, sqlHash256(contractID), rootIndex).Scan((*sqlHash256)(&root)) + return tx.QueryRow(query, encode(contractID), rootIndex).Scan(decode(&root)) + }) + return +} + +func (s *Store) dbRoots(contractID types.FileContractID) (roots []types.Hash256, err error) { + err = s.transaction(func(tx *txn) error { + const query = `SELECT s.sector_root FROM contract_sector_roots cr +INNER JOIN stored_sectors s ON (cr.sector_id = s.id) +INNER JOIN contracts c ON (cr.contract_id = c.id) +WHERE c.contract_id=$1 ORDER BY cr.root_index ASC;` + + rows, err := tx.Query(query, encode(contractID)) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var root types.Hash256 + if err := rows.Scan(decode(&root)); err != nil { + return err + } + roots = append(roots, root) + } + return rows.Err() }) return } @@ -86,7 +111,7 @@ func TestReviseContract(t *testing.T) { // checkConsistency is a helper function that verifies the expected sector // roots are consistent with the database checkConsistency := func(roots []types.Hash256, expected int) error { - dbRoot, err := db.SectorRoots(contract.Revision.ParentID) + dbRoot, err := db.dbRoots(contract.Revision.ParentID) if err != nil { return fmt.Errorf("failed to get sector roots: %w", err) } else if len(dbRoot) != expected { diff --git a/persist/sqlite/debug.go b/persist/sqlite/debug.go index b4e25f95..152ea8c9 100644 --- a/persist/sqlite/debug.go +++ b/persist/sqlite/debug.go @@ -1,4 +1,4 @@ -//go:build debug +//go:build ignore package sqlite diff --git a/persist/sqlite/encoding.go b/persist/sqlite/encoding.go new file mode 100644 index 00000000..8652679d --- /dev/null +++ b/persist/sqlite/encoding.go @@ -0,0 +1,145 @@ +package sqlite + +import ( + "bytes" + "database/sql" + "encoding/binary" + "errors" + "fmt" + "time" + + rhp3 "go.sia.tech/core/rhp/v3" + "go.sia.tech/core/types" +) + +func encode(obj any) any { + switch obj := obj.(type) { + case types.Currency: + // Currency is encoded as two 64-bit LE integers + // TODO: migrate to big-endian for sorting + buf := make([]byte, 16) + binary.LittleEndian.PutUint64(buf[:8], obj.Lo) + binary.LittleEndian.PutUint64(buf[8:], obj.Hi) + return buf + case rhp3.Account: + // rhp3 accounts are encoded as [32]byte + return obj[:] + case []types.Hash256: + var buf bytes.Buffer + e := types.NewEncoder(&buf) + types.EncodeSlice(e, obj) + e.Flush() + return buf.Bytes() + case types.EncoderTo: + var buf bytes.Buffer + e := types.NewEncoder(&buf) + obj.EncodeTo(e) + e.Flush() + return buf.Bytes() + case int: // special case for encoding contracts + if obj < 0 { + panic(fmt.Sprintf("dbEncode: cannot encode negative int %d", obj)) + } + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(obj)) + return b + case int64: // special case for encoding metrics + if obj < 0 { + panic(fmt.Sprintf("dbEncode: cannot encode negative int64 %d", obj)) + } + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(obj)) + return b + case uint64: + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, obj) + return b + case time.Time: + return obj.Unix() + default: + panic(fmt.Sprintf("dbEncode: unsupported type %T", obj)) + } +} + +type decodable struct { + v any +} + +type nullDecodable[T any] struct { + v *T + valid bool +} + +func (nd *nullDecodable[T]) Scan(src any) error { + nd.valid = false + if src == nil { + return nil + } else if err := decode(nd.v).Scan(src); err != nil { + return err + } + nd.valid = true + return nil +} + +func decodeNullable[T any](v *T) *nullDecodable[T] { + return &nullDecodable[T]{v: v} +} + +// Scan implements the sql.Scanner interface. +func (d *decodable) Scan(src any) error { + if src == nil { + return errors.New("cannot scan nil into decodable") + } + + switch src := src.(type) { + case []byte: + switch v := d.v.(type) { + case *types.Currency: + if len(src) != 16 { + return fmt.Errorf("cannot scan %d bytes into Currency", len(src)) + } + v.Lo = binary.LittleEndian.Uint64(src[:8]) + v.Hi = binary.LittleEndian.Uint64(src[8:]) + case types.DecoderFrom: + dec := types.NewBufDecoder(src) + v.DecodeFrom(dec) + return dec.Err() + case *int64: // special case for decoding metrics + *v = int64(binary.LittleEndian.Uint64(src)) + case *uint64: + *v = binary.LittleEndian.Uint64(src) + case *[]types.Hash256: + dec := types.NewBufDecoder(src) + types.DecodeSlice(dec, v) + default: + return fmt.Errorf("cannot scan %T to %T", src, d.v) + } + return nil + case int64: + switch v := d.v.(type) { + case *uint64: + *v = uint64(src) + case *time.Time: + *v = time.Unix(src, 0).UTC() + default: + return fmt.Errorf("cannot scan %T to %T", src, d.v) + } + return nil + default: + return fmt.Errorf("cannot scan %T to %T", src, d.v) + } +} + +func encodeSlice[T types.EncoderTo](v []T) []byte { + buf := bytes.NewBuffer(nil) + e := types.NewEncoder(buf) + types.EncodeSlice(e, v) + if err := e.Flush(); err != nil { + panic(err) + } + return buf.Bytes() +} + +func decode(obj any) sql.Scanner { + return &decodable{obj} +} diff --git a/persist/sqlite/init.go b/persist/sqlite/init.go index 75af4a92..7616f359 100644 --- a/persist/sqlite/init.go +++ b/persist/sqlite/init.go @@ -19,9 +19,9 @@ import ( var initDatabase string func (s *Store) initNewDatabase(target int64) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { if _, err := tx.Exec(initDatabase); err != nil { - return fmt.Errorf("failed to initialize database: %w", err) + return err } else if err := setDBVersion(tx, target); err != nil { return fmt.Errorf("failed to set initial database version: %w", err) } else if err = generateHostKey(tx); err != nil { @@ -35,7 +35,7 @@ func (s *Store) upgradeDatabase(current, target int64) error { log := s.log.Named("migrations") log.Info("migrating database", zap.Int64("current", current), zap.Int64("target", target)) - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { if _, err := tx.Exec("PRAGMA defer_foreign_keys=ON"); err != nil { return fmt.Errorf("failed to enable foreign key deferral: %w", err) } @@ -46,8 +46,10 @@ func (s *Store) upgradeDatabase(current, target int64) error { return fmt.Errorf("failed to migrate database to version %v: %w", current, err) } // check that no foreign key constraints were violated - if err := tx.QueryRow("PRAGMA foreign_key_check").Scan(); !errors.Is(err, sql.ErrNoRows) { - return fmt.Errorf("foreign key constraints are not satisfied") + if err := tx.QueryRow("PRAGMA foreign_key_check").Scan(); err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to check foreign key constraints after migration to version %v: %w", current, err) + } else if err == nil { + return fmt.Errorf("foreign key constraint violated after migration to version %v: %w", current, err) } log.Debug("migration complete", zap.Int64("current", current), zap.Int64("target", target), zap.Duration("elapsed", time.Since(start))) } @@ -63,9 +65,13 @@ func (s *Store) init() error { version := getDBVersion(s.db) switch { case version == 0: - return s.initNewDatabase(target) + if err := s.initNewDatabase(target); err != nil { + return fmt.Errorf("failed to initialize database: %w", err) + } case version < target: - return s.upgradeDatabase(version, target) + if err := s.upgradeDatabase(version, target); err != nil { + return fmt.Errorf("failed to upgrade database: %w", err) + } case version > target: return fmt.Errorf("database version %v is newer than expected %v. database downgrades are not supported", version, target) } @@ -73,7 +79,7 @@ func (s *Store) init() error { return nil } -func generateHostKey(tx txn) (err error) { +func generateHostKey(tx *txn) (err error) { key := types.NewPrivateKeyFromSeed(frand.Bytes(32)) var dbID int64 err = tx.QueryRow(`UPDATE global_settings SET host_key=? RETURNING id`, key).Scan(&dbID) diff --git a/persist/sqlite/init.sql b/persist/sqlite/init.sql index f52e52d4..0d59d98d 100644 --- a/persist/sqlite/init.sql +++ b/persist/sqlite/init.sql @@ -3,27 +3,24 @@ migrations.go */ -CREATE TABLE wallet_utxos ( +CREATE TABLE wallet_siacoin_elements ( id BLOB PRIMARY KEY, - amount BLOB NOT NULL, - unlock_hash BLOB NOT NULL + siacoin_value BLOB NOT NULL, + sia_address BLOB NOT NULL, + merkle_proof BLOB NOT NULL, + leaf_index BLOB NOT NULL, + maturity_height INTEGER NOT NULL ); -CREATE TABLE wallet_transactions ( - id INTEGER PRIMARY KEY, - transaction_id BLOB NOT NULL, - block_id BLOB NOT NULL, - inflow BLOB NOT NULL, - outflow BLOB NOT NULL, - raw_transaction BLOB NOT NULL, -- binary serialized transaction - source TEXT NOT NULL, - block_height INTEGER NOT NULL, - date_created INTEGER NOT NULL -); -CREATE INDEX wallet_transactions_date_created_index ON wallet_transactions(date_created); -CREATE INDEX wallet_transactions_block_id ON wallet_transactions(block_id); -CREATE INDEX wallet_transactions_date_created ON wallet_transactions(date_created); -CREATE INDEX wallet_transactions_block_height_id ON wallet_transactions(block_height DESC, id); +CREATE TABLE wallet_events ( + id BLOB PRIMARY KEY, + chain_index BLOB NOT NULL, + maturity_height INTEGER NOT NULL, + event_type TEXT NOT NULL, + raw_data BLOB NOT NULL +); +CREATE INDEX wallet_events_chain_index ON wallet_events(chain_index); +CREATE INDEX wallet_events_maturity_height ON wallet_events(maturity_height DESC); CREATE TABLE stored_sectors ( id INTEGER PRIMARY KEY, @@ -126,6 +123,69 @@ CREATE TABLE contract_sector_roots ( CREATE INDEX contract_sector_roots_sector_id ON contract_sector_roots(sector_id); CREATE INDEX contract_sector_roots_contract_id_root_index ON contract_sector_roots(contract_id, root_index); +CREATE TABLE contract_v2_state_elements ( + contract_id INTEGER PRIMARY KEY REFERENCES contracts_v2(id), + leaf_index BLOB NOT NULL, + merkle_proof BLOB NOT NULL, + raw_contract BLOB NOT NULL, -- binary serialized contract + revision_number BLOB NOT NULL -- for comparison +); + +CREATE TABLE contracts_v2_chain_index_elements ( + id BLOB PRIMARY KEY, + height INTEGER NOT NULL, + leaf_index BLOB NOT NULL, + merkle_proof BLOB NOT NULL +); +CREATE INDEX contracts_v2_chain_index_elements_height ON contracts_v2_chain_index_elements(height); + +CREATE TABLE contracts_v2 ( + id INTEGER PRIMARY KEY, + renter_id INTEGER NOT NULL REFERENCES contract_renters(id), + renewed_to INTEGER REFERENCES contracts_v2(id) ON DELETE SET NULL, + renewed_from INTEGER REFERENCES contracts_v2(id) ON DELETE SET NULL, + contract_id BLOB UNIQUE NOT NULL, + revision_number BLOB NOT NULL, -- stored as BLOB to support uint64_max on clearing revisions + formation_txn_set BLOB NOT NULL, -- binary serialized transaction set + formation_txn_set_basis BLOB NOT NULL, + locked_collateral BLOB NOT NULL, + rpc_revenue BLOB NOT NULL, + storage_revenue BLOB NOT NULL, + ingress_revenue BLOB NOT NULL, + egress_revenue BLOB NOT NULL, + account_funding BLOB NOT NULL, + risked_collateral BLOB NOT NULL, + raw_revision BLOB NOT NULL, -- binary serialized contract revision + confirmation_index BLOB, -- null if the contract has not been confirmed on the blockchain, otherwise the chain index of the block containing the confirmation transaction + resolution_index BLOB, -- null if the storage proof/resolution has not been confirmed on the blockchain, otherwise the chain index of the block containing the resolution transaction + negotiation_height INTEGER NOT NULL, -- determines if the formation txn should be rebroadcast or if the contract should be deleted + window_start INTEGER NOT NULL, + window_end INTEGER NOT NULL, + contract_status TEXT NOT NULL +); +CREATE INDEX contracts_v2_contract_id ON contracts_v2(contract_id); +CREATE INDEX contracts_v2_renter_id ON contracts_v2(renter_id); +CREATE INDEX contracts_v2_renewed_to ON contracts_v2(renewed_to); +CREATE INDEX contracts_v2_renewed_from ON contracts_v2(renewed_from); +CREATE INDEX contracts_v2_negotiation_height ON contracts_v2(negotiation_height); +CREATE INDEX contracts_v2_window_start ON contracts_v2(window_start); +CREATE INDEX contracts_v2_window_end ON contracts_v2(window_end); +CREATE INDEX contracts_v2_contract_status ON contracts_v2(contract_status); +CREATE INDEX contracts_v2_confirmation_index_resolution_index_window_start ON contracts_v2(confirmation_index, resolution_index, window_start); +CREATE INDEX contracts_v2_confirmation_index_resolution_index_window_end ON contracts_v2(confirmation_index, resolution_index, window_end); +CREATE INDEX contracts_v2_confirmation_index_window_start ON contracts_v2(confirmation_index, window_start); +CREATE INDEX contracts_v2_confirmation_index_negotiation_height ON contracts_v2(confirmation_index, negotiation_height); + +CREATE TABLE contract_v2_sector_roots ( + id INTEGER PRIMARY KEY, + contract_id INTEGER NOT NULl REFERENCES contracts_v2(id), + sector_id INTEGER NOT NULL REFERENCES stored_sectors(id), + root_index INTEGER NOT NULL, + UNIQUE(contract_id, root_index) +); +CREATE INDEX contract_v2_sector_roots_sector_id ON contract_v2_sector_roots(sector_id); +CREATE INDEX contract_v2_sector_roots_contract_id_root_index ON contract_v2_sector_roots(contract_id, root_index); + CREATE TABLE temp_storage_sector_roots ( id INTEGER PRIMARY KEY, sector_id INTEGER NOT NULL REFERENCES stored_sectors(id), @@ -217,20 +277,25 @@ CREATE TABLE webhooks ( secret_key TEXT UNIQUE NOT NULL ); +CREATE TABLE syncer_peers ( + peer_address TEXT PRIMARY KEY NOT NULL, + first_seen INTEGER NOT NULL +); + +CREATE TABLE syncer_bans ( + net_cidr TEXT PRIMARY KEY NOT NULL, + expiration INTEGER NOT NULL, + reason TEXT NOT NULL +); +CREATE INDEX syncer_bans_expiration_index_idx ON syncer_bans (expiration); + CREATE TABLE global_settings ( id INTEGER PRIMARY KEY NOT NULL DEFAULT 0 CHECK (id = 0), -- enforce a single row db_version INTEGER NOT NULL, -- used for migrations host_key BLOB, - last_announce_key BLOB, -- public key of the last host announcement wallet_hash BLOB, -- used to prevent wallet seed changes - wallet_last_processed_change BLOB, -- last processed consensus change for the wallet - contracts_last_processed_change BLOB, -- last processed consensus change for the contract manager - settings_last_processed_change BLOG, -- last processed consensus change for the settings manager - last_announce_id BLOB, -- chain index of the last host announcement - last_announce_height INTEGER, -- height of the last host announcement - wallet_height INTEGER, -- height of the wallet as of the last processed change - contracts_height INTEGER, -- height of the contract manager as of the last processed change - settings_height INTEGER, -- height of the settings manager as of the last processed change + last_scanned_index BLOB, -- chain index of the last scanned block + last_announce_index BLOB, -- chain index of the last host announcement last_announce_address TEXT -- address of the last host announcement ); diff --git a/persist/sqlite/metrics.go b/persist/sqlite/metrics.go index bfb16f4c..b4404d3f 100644 --- a/persist/sqlite/metrics.go +++ b/persist/sqlite/metrics.go @@ -2,6 +2,7 @@ package sqlite import ( "database/sql" + "encoding/binary" "errors" "fmt" "math" @@ -13,13 +14,17 @@ import ( const ( // contracts - metricPendingContracts = "pendingContracts" metricActiveContracts = "activeContracts" metricRejectedContracts = "rejectedContracts" metricSuccessfulContracts = "successfulContracts" metricFailedContracts = "failedContracts" - metricLockedCollateral = "lockedCollateral" - metricRiskedCollateral = "riskedCollateral" + + // v2 + metricFinalizedContracts = "finalizedContracts" + metricRenewedContracts = "renewedContracts" + + metricLockedCollateral = "lockedCollateral" + metricRiskedCollateral = "riskedCollateral" // accounts metricActiveAccounts = "activeAccounts" @@ -69,7 +74,8 @@ const ( metricCollateralMultiplier = "collateralMultiplier" // wallet - metricWalletBalance = "walletBalance" + metricWalletBalance = "walletBalance" + metricWalletImmatureBalance = "walletImmatureBalance" // potential revenue metricPotentialRPCRevenue = "potentialRPCRevenue" @@ -123,7 +129,7 @@ func (s *Store) PeriodMetrics(start time.Time, n int, interval metrics.Interval) } const query = `SELECT stat, stat_value, date_created FROM host_stats WHERE date_created BETWEEN $1 AND $2 ORDER BY date_created ASC` - rows, err := s.db.Query(query, sqlTime(start), sqlTime(end)) + rows, err := s.db.Query(query, encode(start), encode(end)) if err != nil { return nil, fmt.Errorf("failed to query metrics: %w", err) } @@ -138,7 +144,7 @@ func (s *Store) PeriodMetrics(start time.Time, n int, interval metrics.Interval) var value []byte var timestamp time.Time - if err := rows.Scan(&stat, &value, (*sqlTime)(×tamp)); err != nil { + if err := rows.Scan(&stat, &value, decode(×tamp)); err != nil { return nil, fmt.Errorf("failed to scan row: %w", err) } @@ -212,7 +218,8 @@ func (s *Store) PeriodMetrics(start time.Time, n int, interval metrics.Interval) // Metrics returns aggregate metrics for the host as of the timestamp. func (s *Store) Metrics(timestamp time.Time) (m metrics.Metrics, err error) { - const query = `SELECT s.stat, s.stat_value + err = s.transaction(func(tx *txn) error { + const query = `SELECT s.stat, s.stat_value FROM host_stats s JOIN ( SELECT stat, MAX(date_created) AS most_recent @@ -220,28 +227,31 @@ JOIN ( WHERE date_created <= $1 GROUP BY stat ) AS sub ON s.stat = sub.stat AND s.date_created = sub.most_recent;` - rows, err := s.query(query, sqlTime(timestamp)) - if err != nil { - return metrics.Metrics{}, fmt.Errorf("failed to query metrics: %w", err) - } - defer rows.Close() + rows, err := tx.Query(query, encode(timestamp)) + if err != nil { + return fmt.Errorf("failed to query metrics: %w", err) + } + defer rows.Close() - for rows.Next() { - var stat string - var value []byte + for rows.Next() { + var stat string + var value []byte - if err := rows.Scan(&stat, &value); err != nil { - return metrics.Metrics{}, fmt.Errorf("failed to scan row: %w", err) + if err := rows.Scan(&stat, &value); err != nil { + return fmt.Errorf("failed to scan row: %w", err) + } + mustParseMetricValue(stat, value, &m) } - mustParseMetricValue(stat, value, &m) - } - m.Timestamp = timestamp + m.Timestamp = timestamp + return nil + }) + return } // IncrementRHPDataUsage increments the RHP3 ingress and egress metrics. func (s *Store) IncrementRHPDataUsage(ingress, egress uint64) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { if ingress > 0 { if err := incrementNumericStat(tx, metricDataRHPIngress, int(ingress), time.Now()); err != nil { return fmt.Errorf("failed to track ingress: %w", err) @@ -258,7 +268,7 @@ func (s *Store) IncrementRHPDataUsage(ingress, egress uint64) error { // IncrementSectorStats increments the sector read, write and cache metrics. func (s *Store) IncrementSectorStats(reads, writes, cacheHit, cacheMiss uint64) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { if reads > 0 { if err := incrementNumericStat(tx, metricSectorReads, int(reads), time.Now()); err != nil { return fmt.Errorf("failed to track reads: %w", err) @@ -287,7 +297,7 @@ func (s *Store) IncrementSectorStats(reads, writes, cacheHit, cacheMiss uint64) // IncrementRegistryAccess increments the registry read and write metrics. func (s *Store) IncrementRegistryAccess(read, write uint64) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { if read > 0 { if err := incrementNumericStat(tx, metricRegistryReads, int(read), time.Now()); err != nil { return fmt.Errorf("failed to track reads: %w", err) @@ -302,20 +312,17 @@ func (s *Store) IncrementRegistryAccess(read, write uint64) error { }) } -func mustScanCurrency(b []byte) types.Currency { - var c sqlCurrency - if err := c.Scan(b); err != nil { - panic(err) +func mustScanCurrency(src []byte) (c types.Currency) { + if len(src) != 16 { + panic(fmt.Sprintf("cannot scan %d bytes into Currency", len(src))) } - return types.Currency(c) + c.Lo = binary.LittleEndian.Uint64(src[:8]) + c.Hi = binary.LittleEndian.Uint64(src[8:]) + return } func mustScanUint64(b []byte) uint64 { - var u sqlUint64 - if err := u.Scan(b); err != nil { - panic(err) - } - return uint64(u) + return binary.LittleEndian.Uint64(b) } // mustParseMetricValue parses the value of a metric from the database. @@ -339,8 +346,6 @@ func mustParseMetricValue(stat string, buf []byte, m *metrics.Metrics) { value := mustScanUint64(buf) m.Pricing.CollateralMultiplier = math.Float64frombits(value) // contracts - case metricPendingContracts: - m.Contracts.Pending = mustScanUint64(buf) case metricActiveContracts: m.Contracts.Active = mustScanUint64(buf) case metricRejectedContracts: @@ -349,6 +354,10 @@ func mustParseMetricValue(stat string, buf []byte, m *metrics.Metrics) { m.Contracts.Successful = mustScanUint64(buf) case metricFailedContracts: m.Contracts.Failed = mustScanUint64(buf) + case metricFinalizedContracts: + m.Contracts.Finalized = mustScanUint64(buf) + case metricRenewedContracts: + m.Contracts.Renewed = mustScanUint64(buf) case metricLockedCollateral: m.Contracts.LockedCollateral = mustScanCurrency(buf) case metricRiskedCollateral: @@ -419,21 +428,104 @@ func mustParseMetricValue(stat string, buf []byte, m *metrics.Metrics) { m.Revenue.Earned.RegistryWrite = mustScanCurrency(buf) // wallet case metricWalletBalance: - m.Balance = mustScanCurrency(buf) + m.Wallet.Balance = mustScanCurrency(buf) + case metricWalletImmatureBalance: + m.Wallet.ImmatureBalance = mustScanCurrency(buf) default: panic(fmt.Sprintf("unknown metric: %v", stat)) } } +// incrementNumericStatStmt tracks a numeric stat, incrementing the current value by +// delta. If the resulting value is negative, the function panics. This function +// should be used when lots of stats need to be batched together. +func incrementNumericStatStmt(tx *txn) (func(stat string, delta int64, timestamp time.Time) error, func() error, error) { + getStatStmt, err := tx.Prepare(`SELECT stat_value FROM host_stats WHERE stat=$1 AND date_created<=$2 ORDER BY date_created DESC LIMIT 1`) + if err != nil { + return nil, nil, fmt.Errorf("failed to prepare get stat statement: %w", err) + } + + insertStatStmt, err := tx.Prepare(`INSERT INTO host_stats (stat, stat_value, date_created) VALUES ($1, $2, $3) ON CONFLICT (stat, date_created) DO UPDATE SET stat_value=EXCLUDED.stat_value`) + if err != nil { + return nil, nil, fmt.Errorf("failed to prepare insert stat statement: %w", err) + } + + return func(stat string, delta int64, timestamp time.Time) error { + if delta == 0 { + return nil + } + timestamp = timestamp.Truncate(statInterval) + var current int64 + if err := getStatStmt.QueryRow(stat, encode(timestamp)).Scan(decode(¤t)); err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to query existing value: %w", err) + } + + if current+delta < 0 { + panic(fmt.Errorf("negative stat value: %v %v%v", stat, current, delta)) + } + value := current + delta + _, err = insertStatStmt.Exec(stat, encode(value), encode(timestamp)) + return err + }, func() error { + getStatStmt.Close() + insertStatStmt.Close() + return nil + }, nil +} + +// incrementCurrencyStatStmt increments a currency stat. If negative is false, the current +// value is incremented by delta. Otherwise, the value is decremented. If the +// resulting value would be negative, the function panics. This function should +// be used when lots of stats need to be batched together. +func incrementCurrencyStatStmt(tx *txn) (func(stat string, delta types.Currency, negative bool, timestamp time.Time) error, func() error, error) { + getStatStmt, err := tx.Prepare(`SELECT stat_value FROM host_stats WHERE stat=$1 AND date_created<=$2 ORDER BY date_created DESC LIMIT 1`) + if err != nil { + return nil, nil, fmt.Errorf("failed to prepare get stat statement: %w", err) + } + + insertStatStmt, err := tx.Prepare(`INSERT INTO host_stats (stat, stat_value, date_created) VALUES ($1, $2, $3) ON CONFLICT (stat, date_created) DO UPDATE SET stat_value=EXCLUDED.stat_value`) + if err != nil { + return nil, nil, fmt.Errorf("failed to prepare insert stat statement: %w", err) + } + + return func(stat string, delta types.Currency, negative bool, timestamp time.Time) error { + if delta.IsZero() { + return nil + } + timestamp = timestamp.Truncate(statInterval) + var current types.Currency + if err := getStatStmt.QueryRow(stat, encode(timestamp)).Scan(decode(¤t)); err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to query existing value: %w", err) + } + + var value types.Currency + if negative { + if current.Cmp(delta) < 0 { + panic(fmt.Errorf("negative stat value: %v %v-%v", stat, current, delta)) + } + value = current.Sub(delta) + } else { + value = current.Add(delta) + } + + _, err = insertStatStmt.Exec(stat, encode(value), encode(timestamp)) + return err + }, func() error { + getStatStmt.Close() + insertStatStmt.Close() + return nil + }, nil +} + // incrementNumericStat tracks a numeric stat, incrementing the current value by // delta. If the resulting value is negative, the function panics. -func incrementNumericStat(tx txn, stat string, delta int, timestamp time.Time) error { +func incrementNumericStat(tx *txn, stat string, delta int, timestamp time.Time) error { if delta == 0 { return nil } timestamp = timestamp.Truncate(statInterval) var current uint64 - err := tx.QueryRow(`SELECT stat_value FROM host_stats WHERE stat=$1 AND date_created<=$2 ORDER BY date_created DESC LIMIT 1`, stat, sqlTime(timestamp)).Scan((*sqlUint64)(¤t)) + err := tx.QueryRow(`SELECT stat_value FROM host_stats WHERE stat=$1 AND date_created<=$2 ORDER BY date_created DESC LIMIT 1`, stat, encode(timestamp)).Scan(decode(¤t)) if err != nil && !errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("failed to query existing value: %w", err) } @@ -446,7 +538,7 @@ func incrementNumericStat(tx txn, stat string, delta int, timestamp time.Time) e } else { value = current + uint64(delta) } - _, err = tx.Exec(`INSERT INTO host_stats (stat, stat_value, date_created) VALUES ($1, $2, $3) ON CONFLICT (stat, date_created) DO UPDATE SET stat_value=EXCLUDED.stat_value`, stat, sqlUint64(value), sqlTime(timestamp)) + _, err = tx.Exec(`INSERT INTO host_stats (stat, stat_value, date_created) VALUES ($1, $2, $3) ON CONFLICT (stat, date_created) DO UPDATE SET stat_value=EXCLUDED.stat_value`, stat, encode(value), encode(timestamp)) if err != nil { return fmt.Errorf("failed to insert stat: %w", err) } @@ -456,13 +548,13 @@ func incrementNumericStat(tx txn, stat string, delta int, timestamp time.Time) e // incrementCurrencyStat tracks a currency stat. If negative is false, the current // value is incremented by delta. Otherwise, the value is decremented. If the // resulting value would be negative, the function panics. -func incrementCurrencyStat(tx txn, stat string, delta types.Currency, negative bool, timestamp time.Time) error { +func incrementCurrencyStat(tx *txn, stat string, delta types.Currency, negative bool, timestamp time.Time) error { if delta.IsZero() { return nil } timestamp = timestamp.Truncate(statInterval) var current types.Currency - err := tx.QueryRow(`SELECT stat_value FROM host_stats WHERE stat=$1 AND date_created<=$2 ORDER BY date_created DESC LIMIT 1`, stat, sqlTime(timestamp)).Scan((*sqlCurrency)(¤t)) + err := tx.QueryRow(`SELECT stat_value FROM host_stats WHERE stat=$1 AND date_created<=$2 ORDER BY date_created DESC LIMIT 1`, stat, encode(timestamp)).Scan(decode(¤t)) if err != nil && !errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("failed to query existing value: %w", err) } @@ -472,101 +564,58 @@ func incrementCurrencyStat(tx txn, stat string, delta types.Currency, negative b } else { value = value.Add(delta) } - _, err = tx.Exec(`INSERT INTO host_stats (stat, stat_value, date_created) VALUES ($1, $2, $3) ON CONFLICT (stat, date_created) DO UPDATE SET stat_value=EXCLUDED.stat_value`, stat, sqlCurrency(value), sqlTime(timestamp)) + _, err = tx.Exec(`INSERT INTO host_stats (stat, stat_value, date_created) VALUES ($1, $2, $3) ON CONFLICT (stat, date_created) DO UPDATE SET stat_value=EXCLUDED.stat_value`, stat, encode(value), encode(timestamp)) if err != nil { return fmt.Errorf("failed to insert stat: %w", err) } return nil } -func setCurrencyStat(tx txn, stat string, value types.Currency, timestamp time.Time) error { +func setCurrencyStat(tx *txn, stat string, value types.Currency, timestamp time.Time) error { timestamp = timestamp.Truncate(statInterval) var current types.Currency - err := tx.QueryRow(`SELECT stat_value FROM host_stats WHERE stat=$1 AND date_created<=$2 ORDER BY date_created DESC LIMIT 1`, stat, sqlTime(timestamp)).Scan((*sqlCurrency)(¤t)) + err := tx.QueryRow(`SELECT stat_value FROM host_stats WHERE stat=$1 AND date_created<=$2 ORDER BY date_created DESC LIMIT 1`, stat, encode(timestamp)).Scan(decode(¤t)) if err != nil && !errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("failed to query existing value: %w", err) } else if value.Equals(current) { return nil } - _, err = tx.Exec(`INSERT INTO host_stats (stat, stat_value, date_created) VALUES ($1, $2, $3) ON CONFLICT (stat, date_created) DO UPDATE SET stat_value=EXCLUDED.stat_value`, stat, sqlCurrency(value), sqlTime(timestamp)) + _, err = tx.Exec(`INSERT INTO host_stats (stat, stat_value, date_created) VALUES ($1, $2, $3) ON CONFLICT (stat, date_created) DO UPDATE SET stat_value=EXCLUDED.stat_value`, stat, encode(value), encode(timestamp)) if err != nil { return fmt.Errorf("failed to insert stat: %w", err) } return nil } -func setNumericStat(tx txn, stat string, value uint64, timestamp time.Time) error { +func setNumericStat(tx *txn, stat string, value uint64, timestamp time.Time) error { timestamp = timestamp.Truncate(statInterval) var current uint64 - err := tx.QueryRow(`SELECT stat_value FROM host_stats WHERE stat=$1 AND date_created<=$2 ORDER BY date_created DESC LIMIT 1`, stat, sqlTime(timestamp)).Scan((*sqlUint64)(¤t)) + err := tx.QueryRow(`SELECT stat_value FROM host_stats WHERE stat=$1 AND date_created<=$2 ORDER BY date_created DESC LIMIT 1`, stat, encode(timestamp)).Scan(decode(¤t)) if err != nil && !errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("failed to query existing value: %w", err) } else if value == current { return nil } - _, err = tx.Exec(`INSERT INTO host_stats (stat, stat_value, date_created) VALUES ($1, $2, $3) ON CONFLICT (stat, date_created) DO UPDATE SET stat_value=EXCLUDED.stat_value`, stat, sqlUint64(value), sqlTime(timestamp)) + _, err = tx.Exec(`INSERT INTO host_stats (stat, stat_value, date_created) VALUES ($1, $2, $3) ON CONFLICT (stat, date_created) DO UPDATE SET stat_value=EXCLUDED.stat_value`, stat, encode(value), encode(timestamp)) if err != nil { return fmt.Errorf("failed to insert stat: %w", err) } return nil } -func setFloat64Stat(tx txn, stat string, f float64, timestamp time.Time) error { +func setFloat64Stat(tx *txn, stat string, f float64, timestamp time.Time) error { timestamp = timestamp.Truncate(statInterval) value := math.Float64bits(f) var current uint64 - err := tx.QueryRow(`SELECT stat_value FROM host_stats WHERE stat=$1 AND date_created<=$2 ORDER BY date_created DESC LIMIT 1`, stat, sqlTime(timestamp)).Scan((*sqlUint64)(¤t)) + err := tx.QueryRow(`SELECT stat_value FROM host_stats WHERE stat=$1 AND date_created<=$2 ORDER BY date_created DESC LIMIT 1`, stat, encode(timestamp)).Scan(decode(¤t)) if err != nil && !errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("failed to query existing value: %w", err) } else if value == current { return nil } - _, err = tx.Exec(`INSERT INTO host_stats (stat, stat_value, date_created) VALUES ($1, $2, $3) ON CONFLICT (stat, date_created) DO UPDATE SET stat_value=EXCLUDED.stat_value`, stat, sqlUint64(value), sqlTime(timestamp)) + _, err = tx.Exec(`INSERT INTO host_stats (stat, stat_value, date_created) VALUES ($1, $2, $3) ON CONFLICT (stat, date_created) DO UPDATE SET stat_value=EXCLUDED.stat_value`, stat, encode(value), encode(timestamp)) if err != nil { return fmt.Errorf("failed to insert stat: %w", err) } return nil } - -// reflowCurrencyStat updates all currency stats after the given timestamp. If -// negative is false, the current value is incremented by delta. Otherwise, the -// value is decremented. If the resulting value would be negative, the function -// panics. -func reflowCurrencyStat(tx txn, stat string, startTime time.Time, value types.Currency, negative bool) error { - startTime = startTime.Truncate(statInterval) - rows, err := tx.Query(`SELECT stat_value, date_created FROM host_stats WHERE stat=$1 AND date_created > $2 ORDER BY date_created ASC`, stat, sqlTime(startTime)) - if err != nil { - return fmt.Errorf("failed to query existing value: %w", err) - } - defer rows.Close() - var values []types.Currency - var timestamps []time.Time - for rows.Next() { - var v types.Currency - var timestamp time.Time - if err := rows.Scan((*sqlCurrency)(&v), (*sqlTime)(×tamp)); err != nil { - return fmt.Errorf("failed to scan row: %w", err) - } - if negative { - v = v.Sub(value) - } else { - v = v.Add(value) - } - values = append(values, v) - timestamps = append(timestamps, timestamp) - } - - stmt, err := tx.Prepare(`UPDATE host_stats SET stat_value=$1 WHERE stat=$2 AND date_created=$3 RETURNING date_created`) - if err != nil { - return fmt.Errorf("failed to prepare update statement: %w", err) - } - defer stmt.Close() - for i := range values { - var dbTime time.Time - err := stmt.QueryRow(sqlCurrency(values[i]), stat, sqlTime(timestamps[i])).Scan((*sqlTime)(&dbTime)) - if err != nil { - return fmt.Errorf("failed to update stat: %w", err) - } - } - return nil -} diff --git a/persist/sqlite/migrations.go b/persist/sqlite/migrations.go index 69149e9e..31cac350 100644 --- a/persist/sqlite/migrations.go +++ b/persist/sqlite/migrations.go @@ -10,9 +10,158 @@ import ( "go.uber.org/zap" ) +// migrateVersion29 resets the chain state to trigger a full rescan of the +// wallet to calculate the new immature balance metric. +func migrateVersion29(tx *txn, _ *zap.Logger) error { + _, err := tx.Exec(` +-- v2 contracts +DELETE FROM contracts_v2_chain_index_elements; +DELETE FROM contract_v2_state_elements; +-- wallet +DELETE FROM wallet_siacoin_elements; +DELETE FROM wallet_events; +DELETE FROM host_stats WHERE stat IN (?,?); -- reset wallet stats since they are derived from the chain +-- settings +UPDATE global_settings SET last_scanned_index=NULL, last_announce_index=NULL, last_announce_address=NULL`, metricWalletBalance, metricWalletImmatureBalance) + return err +} + +// migrateVersion28 prepares the database for version 2 +func migrateVersion28(tx *txn, log *zap.Logger) error { + _, err := tx.Exec(` +-- Drop v1 tables +DROP TABLE wallet_utxos; +DROP TABLE wallet_transactions; + +-- Create v2 tables +CREATE TABLE wallet_siacoin_elements ( + id BLOB PRIMARY KEY, + siacoin_value BLOB NOT NULL, + sia_address BLOB NOT NULL, + merkle_proof BLOB NOT NULL, + leaf_index BLOB NOT NULL, + maturity_height INTEGER NOT NULL +); + +CREATE TABLE wallet_events ( + id BLOB PRIMARY KEY, + chain_index BLOB NOT NULL, + maturity_height INTEGER NOT NULL, + event_type TEXT NOT NULL, + raw_data BLOB NOT NULL +); +CREATE INDEX wallet_events_chain_index ON wallet_events(chain_index); +CREATE INDEX wallet_events_maturity_height ON wallet_events(maturity_height DESC); + +CREATE TABLE contract_v2_state_elements ( + contract_id INTEGER PRIMARY KEY REFERENCES contracts_v2(id), + leaf_index BLOB NOT NULL, + merkle_proof BLOB NOT NULL, + raw_contract BLOB NOT NULL, -- binary serialized contract + revision_number BLOB NOT NULL -- for comparison +); + +CREATE TABLE contracts_v2_chain_index_elements ( + id BLOB PRIMARY KEY, + height INTEGER NOT NULL, + leaf_index BLOB NOT NULL, + merkle_proof BLOB NOT NULL +); + +CREATE TABLE contracts_v2 ( + id INTEGER PRIMARY KEY, + renter_id INTEGER NOT NULL REFERENCES contract_renters(id), + renewed_to INTEGER REFERENCES contracts_v2(id) ON DELETE SET NULL, + renewed_from INTEGER REFERENCES contracts_v2(id) ON DELETE SET NULL, + contract_id BLOB UNIQUE NOT NULL, + revision_number BLOB NOT NULL, -- stored as BLOB to support uint64_max on clearing revisions + formation_txn_set BLOB NOT NULL, -- binary serialized transaction set + formation_txn_set_basis BLOB NOT NULL, + locked_collateral BLOB NOT NULL, + rpc_revenue BLOB NOT NULL, + storage_revenue BLOB NOT NULL, + ingress_revenue BLOB NOT NULL, + egress_revenue BLOB NOT NULL, + account_funding BLOB NOT NULL, + registry_read BLOB NOT NULL, + registry_write BLOB NOT NULL, + risked_collateral BLOB NOT NULL, + raw_revision BLOB NOT NULL, -- binary serialized contract revision + confirmation_index BLOB, -- null if the contract has not been confirmed on the blockchain, otherwise the chain index of the block containing the confirmation transaction + resolution_index BLOB, -- null if the storage proof/resolution has not been confirmed on the blockchain, otherwise the chain index of the block containing the resolution transaction + negotiation_height INTEGER NOT NULL, -- determines if the formation txn should be rebroadcast or if the contract should be deleted + window_start INTEGER NOT NULL, + window_end INTEGER NOT NULL, + contract_status TEXT NOT NULL +); +CREATE INDEX contracts_v2_contract_id ON contracts_v2(contract_id); +CREATE INDEX contracts_v2_renter_id ON contracts_v2(renter_id); +CREATE INDEX contracts_v2_renewed_to ON contracts_v2(renewed_to); +CREATE INDEX contracts_v2_renewed_from ON contracts_v2(renewed_from); +CREATE INDEX contracts_v2_negotiation_height ON contracts_v2(negotiation_height); +CREATE INDEX contracts_v2_window_start ON contracts_v2(window_start); +CREATE INDEX contracts_v2_window_end ON contracts_v2(window_end); +CREATE INDEX contracts_v2_contract_status ON contracts_v2(contract_status); +CREATE INDEX contracts_v2_confirmation_index_resolution_index_window_start ON contracts_v2(confirmation_index, resolution_index, window_start); +CREATE INDEX contracts_v2_confirmation_index_resolution_index_window_end ON contracts_v2(confirmation_index, resolution_index, window_end); +CREATE INDEX contracts_v2_confirmation_index_window_start ON contracts_v2(confirmation_index, window_start); +CREATE INDEX contracts_v2_confirmation_index_negotiation_height ON contracts_v2(confirmation_index, negotiation_height); + +CREATE TABLE contract_v2_sector_roots ( + id INTEGER PRIMARY KEY, + contract_id INTEGER NOT NULl REFERENCES contracts_v2(id), + sector_id INTEGER NOT NULL REFERENCES stored_sectors(id), + root_index INTEGER NOT NULL, + UNIQUE(contract_id, root_index) +); +CREATE INDEX contract_v2_sector_roots_sector_id ON contract_v2_sector_roots(sector_id); +CREATE INDEX contract_v2_sector_roots_contract_id_root_index ON contract_v2_sector_roots(contract_id, root_index); + +CREATE TABLE syncer_peers ( + peer_address TEXT PRIMARY KEY NOT NULL, + first_seen INTEGER NOT NULL +); + +CREATE TABLE syncer_bans ( + net_cidr TEXT PRIMARY KEY NOT NULL, + expiration INTEGER NOT NULL, + reason TEXT NOT NULL +); +CREATE INDEX syncer_bans_expiration_index_idx ON syncer_bans (expiration); + +-- Create the new settings table. Will trigger a rescan and reannounce + +CREATE TABLE global_settings_v2 ( + id INTEGER PRIMARY KEY NOT NULL DEFAULT 0 CHECK (id = 0), -- enforce a single row + db_version INTEGER NOT NULL, -- used for migrations + host_key BLOB, + wallet_hash BLOB, -- used to prevent wallet seed changes + last_scanned_index BLOB, -- chain index of the last scanned block + last_announce_index BLOB, -- chain index of the last host announcement + last_announce_address TEXT -- address of the last host announcement +); + +-- Migrate the existing settings table +INSERT INTO global_settings_v2 (id, db_version, host_key, wallet_hash) SELECT 0, db_version, host_key, wallet_hash FROM global_settings; + +DROP TABLE global_settings; + +ALTER TABLE global_settings_v2 RENAME TO global_settings; + +-- drop pending contract metrics +DELETE FROM host_stats WHERE stat='pendingContracts';`) + if err != nil { + return fmt.Errorf("failed to migrate to version 28: %w", err) + } else if err := recalcContractMetrics(tx, log); err != nil { + // recalculate contract revenue to remove pending from the metrics + return fmt.Errorf("failed to recalculate contract metrics: %w", err) + } + return nil +} + // migrateVersion27 adds the sector_writes column to the volume_sectors table to // more evenly distribute sector writes across disks. -func migrateVersion27(tx txn, _ *zap.Logger) error { +func migrateVersion27(tx *txn, _ *zap.Logger) error { _, err := tx.Exec(`ALTER TABLE volume_sectors ADD COLUMN sector_writes INTEGER NOT NULL DEFAULT 0; DROP INDEX volume_sectors_volume_id_sector_id_volume_index_compound; DROP INDEX volume_sectors_volume_id_sector_id_volume_index_set_compound; @@ -21,7 +170,7 @@ CREATE INDEX volume_sectors_sector_writes_volume_id_sector_id_volume_index_compo } // migrateVersion26 creates the host_pinned_settings table. -func migrateVersion26(tx txn, _ *zap.Logger) error { +func migrateVersion26(tx *txn, _ *zap.Logger) error { _, err := tx.Exec(`CREATE TABLE host_pinned_settings ( id INTEGER PRIMARY KEY NOT NULL DEFAULT 0 CHECK (id = 0), -- enforce a single row currency TEXT NOT NULL, @@ -39,7 +188,7 @@ func migrateVersion26(tx txn, _ *zap.Logger) error { } // migrateVersion25 recalculates the contract and physical sectors metrics -func migrateVersion25(tx txn, log *zap.Logger) error { +func migrateVersion25(tx *txn, log *zap.Logger) error { // recalculate the contract sectors metric var contractSectorCount int64 if err := tx.QueryRow(`SELECT COUNT(*) FROM contract_sector_roots`).Scan(&contractSectorCount); err != nil { @@ -82,7 +231,7 @@ func migrateVersion25(tx txn, log *zap.Logger) error { } // migrateVersion24 combines the rhp2 and rhp3 data metrics -func migrateVersion24(tx txn, log *zap.Logger) error { +func migrateVersion24(tx *txn, log *zap.Logger) error { rows, err := tx.Query(`SELECT date_created, stat, stat_value FROM host_stats WHERE stat IN (?, ?, ?, ?) ORDER BY date_created ASC`, metricRHP2Ingress, metricRHP2Egress, metricRHP3Ingress, metricRHP3Egress) if err != nil { return fmt.Errorf("failed to query host stats: %w", err) @@ -103,7 +252,7 @@ func migrateVersion24(tx txn, log *zap.Logger) error { var timestamp time.Time var stat string var value uint64 - if err := rows.Scan((*sqlTime)(×tamp), &stat, (*sqlUint64)(&value)); err != nil { + if err := rows.Scan(decode(×tamp), &stat, decode(&value)); err != nil { return fmt.Errorf("failed to scan host stat: %w", err) } @@ -165,7 +314,7 @@ func migrateVersion24(tx txn, log *zap.Logger) error { } // migrateVersion23 creates the webhooks table. -func migrateVersion23(tx txn, _ *zap.Logger) error { +func migrateVersion23(tx *txn, _ *zap.Logger) error { const query = `CREATE TABLE webhooks ( id INTEGER PRIMARY KEY, callback_url TEXT UNIQUE NOT NULL, @@ -180,7 +329,7 @@ func migrateVersion23(tx txn, _ *zap.Logger) error { // migrateVersion22 recalculates the locked and risked collateral and the // potential and earned revenue metrics which will be bugged if the host rescans // the blockchain. -func migrateVersion22(tx txn, log *zap.Logger) error { +func migrateVersion22(tx *txn, log *zap.Logger) error { rows, err := tx.Query(`SELECT contract_status, locked_collateral, risked_collateral, rpc_revenue, storage_revenue, ingress_revenue, egress_revenue, account_funding, registry_read, registry_write FROM contracts WHERE contract_status IN (?, ?, ?);`, contracts.ContractStatusPending, contracts.ContractStatusActive, contracts.ContractStatusSuccessful) if err != nil { return fmt.Errorf("failed to query contracts: %w", err) @@ -194,7 +343,7 @@ func migrateVersion22(tx txn, log *zap.Logger) error { var lockedCollateral types.Currency var usage contracts.Usage - if err := rows.Scan(&status, (*sqlCurrency)(&lockedCollateral), (*sqlCurrency)(&usage.RiskedCollateral), (*sqlCurrency)(&usage.RPCRevenue), (*sqlCurrency)(&usage.StorageRevenue), (*sqlCurrency)(&usage.IngressRevenue), (*sqlCurrency)(&usage.EgressRevenue), (*sqlCurrency)(&usage.AccountFunding), (*sqlCurrency)(&usage.RegistryRead), (*sqlCurrency)(&usage.RegistryWrite)); err != nil { + if err := rows.Scan(&status, decode(&lockedCollateral), decode(&usage.RiskedCollateral), decode(&usage.RPCRevenue), decode(&usage.StorageRevenue), decode(&usage.IngressRevenue), decode(&usage.EgressRevenue), decode(&usage.AccountFunding), decode(&usage.RegistryRead), decode(&usage.RegistryWrite)); err != nil { return fmt.Errorf("failed to scan contract: %w", err) } @@ -241,7 +390,7 @@ func migrateVersion22(tx txn, log *zap.Logger) error { return nil } -func migrateVersion21(tx txn, _ *zap.Logger) error { +func migrateVersion21(tx *txn, _ *zap.Logger) error { const query = ` ALTER TABLE global_settings ADD COLUMN last_announce_key BLOB; ALTER TABLE global_settings ADD COLUMN settings_last_processed_change BLOB; @@ -255,13 +404,13 @@ ALTER TABLE global_settings ADD COLUMN last_announce_address TEXT; } // migrateVersion20 adds a compound index to the volume_sectors table -func migrateVersion20(tx txn, _ *zap.Logger) error { +func migrateVersion20(tx *txn, _ *zap.Logger) error { _, err := tx.Exec(`CREATE INDEX volume_sectors_volume_id_sector_id_volume_index_set_compound ON volume_sectors (volume_id, sector_id, volume_index) WHERE sector_id IS NOT NULL;`) return err } // migrateVersion19 adds a compound index to the volume_sectors table -func migrateVersion19(tx txn, _ *zap.Logger) error { +func migrateVersion19(tx *txn, _ *zap.Logger) error { const query = ` DROP INDEX storage_volumes_read_only_available; CREATE INDEX storage_volumes_id_available_read_only ON storage_volumes(id, available, read_only); @@ -273,7 +422,7 @@ CREATE INDEX volume_sectors_volume_id_sector_id_volume_index_compound ON volume_ // migrateVersion18 adds an index to the volume_sectors table to speed up // empty sector selection. -func migrateVersion18(tx txn, _ *zap.Logger) error { +func migrateVersion18(tx *txn, _ *zap.Logger) error { const query = `CREATE INDEX volume_sectors_volume_id_sector_id ON volume_sectors(volume_id, sector_id);` _, err := tx.Exec(query) return err @@ -282,7 +431,7 @@ func migrateVersion18(tx txn, _ *zap.Logger) error { // migrateVersion17 recalculates the indices of all contract sector roots. // Fixes a bug where the indices were not being properly updated if more than // one root was trimmed. -func migrateVersion17(tx txn, _ *zap.Logger) error { +func migrateVersion17(tx *txn, _ *zap.Logger) error { const query = ` -- create a temp table that contains the new indices CREATE TEMP TABLE temp_contract_sector_roots AS @@ -300,7 +449,7 @@ DROP TABLE temp_contract_sector_roots;` } // migrateVersion16 recalculates the contract and physical sector metrics. -func migrateVersion16(tx txn, _ *zap.Logger) error { +func migrateVersion16(tx *txn, _ *zap.Logger) error { // recalculate the contract sectors metric var contractSectorCount int64 if err := tx.QueryRow(`SELECT COUNT(*) FROM contract_sector_roots`).Scan(&contractSectorCount); err != nil { @@ -345,7 +494,7 @@ func migrateVersion16(tx txn, _ *zap.Logger) error { // migrateVersion15 adds the registry usage fields to the contracts table, // removes the usage fields from the accounts table, and refactors the // contract_account_funding table. -func migrateVersion15(tx txn, _ *zap.Logger) error { +func migrateVersion15(tx *txn, _ *zap.Logger) error { const query = ` -- drop the tables that are being removed or refactored DROP TABLE account_financial_records; @@ -413,7 +562,7 @@ CREATE INDEX contracts_formation_confirmed_window_start ON contracts(formation_c CREATE INDEX contracts_formation_confirmed_negotiation_height ON contracts(formation_confirmed, negotiation_height);` // one query parameter to reset the contract's tracked revenue to zero - if _, err := tx.Exec(query, sqlCurrency(types.ZeroCurrency)); err != nil { + if _, err := tx.Exec(query, encode(types.ZeroCurrency)); err != nil { return fmt.Errorf("failed to migrate contracts table: %w", err) } @@ -434,7 +583,7 @@ CREATE INDEX contracts_formation_confirmed_negotiation_height ON contracts(forma for rows.Next() { var balance types.Currency - if err := rows.Scan((*sqlCurrency)(&balance)); err != nil { + if err := rows.Scan(decode(&balance)); err != nil { return fmt.Errorf("failed to scan account balance: %w", err) } accountBalance = accountBalance.Add(balance) @@ -449,7 +598,7 @@ CREATE INDEX contracts_formation_confirmed_negotiation_height ON contracts(forma // migrateVersion14 adds the locked_sectors table, recalculates the contract // sectors metric, and recalculates the physical sectors metric. -func migrateVersion14(tx txn, _ *zap.Logger) error { +func migrateVersion14(tx *txn, _ *zap.Logger) error { // create the new locked sectors table const lockedSectorsTableQuery = `CREATE TABLE locked_sectors ( -- should be cleared at startup. currently persisted for simplicity, but may be moved to memory id INTEGER PRIMARY KEY, @@ -480,19 +629,19 @@ CREATE INDEX locked_sectors_sector_id ON locked_sectors(sector_id);` } // migrateVersion13 adds an index to the storage table to speed up location selection -func migrateVersion13(tx txn, _ *zap.Logger) error { +func migrateVersion13(tx *txn, _ *zap.Logger) error { _, err := tx.Exec(`CREATE INDEX storage_volumes_read_only_available_used_sectors ON storage_volumes(available, read_only, used_sectors);`) return err } // migrateVersion12 adds an index to the contracts table to speed up sector pruning -func migrateVersion12(tx txn, _ *zap.Logger) error { +func migrateVersion12(tx *txn, _ *zap.Logger) error { _, err := tx.Exec(`CREATE INDEX contracts_window_end ON contracts(window_end);`) return err } // migrateVersion11 recalculates the contract collateral metrics for existing contracts. -func migrateVersion11(tx txn, _ *zap.Logger) error { +func migrateVersion11(tx *txn, _ *zap.Logger) error { rows, err := tx.Query(`SELECT locked_collateral, risked_collateral FROM contracts WHERE contract_status IN (?, ?)`, contracts.ContractStatusPending, contracts.ContractStatusActive) if err != nil { return fmt.Errorf("failed to query contracts: %w", err) @@ -501,7 +650,7 @@ func migrateVersion11(tx txn, _ *zap.Logger) error { var totalLocked, totalRisked types.Currency for rows.Next() { var locked, risked types.Currency - if err := rows.Scan((*sqlCurrency)(&locked), (*sqlCurrency)(&risked)); err != nil { + if err := rows.Scan(decode(&locked), decode(&risked)); err != nil { return fmt.Errorf("failed to scan contract: %w", err) } totalLocked = totalLocked.Add(locked) @@ -517,13 +666,13 @@ func migrateVersion11(tx txn, _ *zap.Logger) error { } // migrateVersion10 drops the log_lines table. -func migrateVersion10(tx txn, _ *zap.Logger) error { +func migrateVersion10(tx *txn, _ *zap.Logger) error { _, err := tx.Exec(`DROP TABLE log_lines;`) return err } // migrateVersion9 recalculates the contract metrics for existing contracts. -func migrateVersion9(tx txn, _ *zap.Logger) error { +func migrateVersion9(tx *txn, _ *zap.Logger) error { rows, err := tx.Query(`SELECT contract_status, COUNT(*) FROM contracts GROUP BY contract_status`) if err != nil { return fmt.Errorf("failed to query contracts: %w", err) @@ -541,7 +690,7 @@ func migrateVersion9(tx txn, _ *zap.Logger) error { var metric string switch status { case contracts.ContractStatusPending: - metric = metricPendingContracts + metric = "pendingContracts" case contracts.ContractStatusRejected: metric = metricRejectedContracts case contracts.ContractStatusActive: @@ -561,7 +710,7 @@ func migrateVersion9(tx txn, _ *zap.Logger) error { // migrateVersion8 sets the initial values for the locked and risked collateral // metrics for existing hosts -func migrateVersion8(tx txn, _ *zap.Logger) error { +func migrateVersion8(tx *txn, _ *zap.Logger) error { rows, err := tx.Query(`SELECT locked_collateral, risked_collateral FROM contracts WHERE contract_status IN (?, ?)`, contracts.ContractStatusPending, contracts.ContractStatusActive) if err != nil { return fmt.Errorf("failed to query contracts: %w", err) @@ -570,7 +719,7 @@ func migrateVersion8(tx txn, _ *zap.Logger) error { var totalLocked, totalRisked types.Currency for rows.Next() { var locked, risked types.Currency - if err := rows.Scan((*sqlCurrency)(&locked), (*sqlCurrency)(&risked)); err != nil { + if err := rows.Scan(decode(&locked), decode(&risked)); err != nil { return fmt.Errorf("failed to scan contract: %w", err) } totalLocked = totalLocked.Add(locked) @@ -590,14 +739,14 @@ func migrateVersion8(tx txn, _ *zap.Logger) error { } // migrateVersion7 adds the sector_cache_size column to the host_settings table -func migrateVersion7(tx txn, _ *zap.Logger) error { +func migrateVersion7(tx *txn, _ *zap.Logger) error { _, err := tx.Exec(`ALTER TABLE host_settings ADD COLUMN sector_cache_size INTEGER NOT NULL DEFAULT 0;`) return err } // migrateVersion6 fixes a bug where the physical sectors metric was not being // properly decreased when a volume is force removed. -func migrateVersion6(tx txn, _ *zap.Logger) error { +func migrateVersion6(tx *txn, _ *zap.Logger) error { var count int64 if err := tx.QueryRow(`SELECT COUNT(id) FROM volume_sectors WHERE sector_id IS NOT NULL`).Scan(&count); err != nil { return fmt.Errorf("failed to count volume sectors: %w", err) @@ -610,7 +759,7 @@ func migrateVersion6(tx txn, _ *zap.Logger) error { // the contract sectors metric will drastically increase for existing hosts. // This is unavoidable, as we have no way of knowing how many sectors were // previously renewed. -func migrateVersion5(tx txn, _ *zap.Logger) error { +func migrateVersion5(tx *txn, _ *zap.Logger) error { var count int64 if err := tx.QueryRow(`SELECT COUNT(*) FROM contract_sector_roots`).Scan(&count); err != nil { return fmt.Errorf("failed to count contract sector roots: %w", err) @@ -620,7 +769,7 @@ func migrateVersion5(tx txn, _ *zap.Logger) error { } // migrateVersion4 changes the collateral setting to collateral_multiplier -func migrateVersion4(tx txn, _ *zap.Logger) error { +func migrateVersion4(tx *txn, _ *zap.Logger) error { const ( newSettingsSchema = `CREATE TABLE host_settings ( id INTEGER PRIMARY KEY NOT NULL DEFAULT 0 CHECK (id = 0), -- enforce a single row @@ -676,13 +825,13 @@ egress_limit, ddns_provider, ddns_update_v4, ddns_update_v6, ddns_opts, registry // migrateVersion3 adds a wallet hash to the global settings table to detect // when the private key has changed. -func migrateVersion3(tx txn, _ *zap.Logger) error { +func migrateVersion3(tx *txn, _ *zap.Logger) error { _, err := tx.Exec(`ALTER TABLE global_settings ADD COLUMN wallet_hash BLOB;`) return err } // migrateVersion2 removes the min prefix from the price columns in host_settings -func migrateVersion2(tx txn, _ *zap.Logger) error { +func migrateVersion2(tx *txn, _ *zap.Logger) error { const ( newSettingsSchema = `CREATE TABLE host_settings ( id INTEGER PRIMARY KEY NOT NULL DEFAULT 0 CHECK (id = 0), -- enforce a single row @@ -737,7 +886,7 @@ egress_limit, dyn_dns_provider, dns_update_v4, dns_update_v6, dyn_dns_opts, regi // migrations is a list of functions that are run to migrate the database from // one version to the next. Migrations are used to update existing databases to // match the schema in init.sql. -var migrations = []func(tx txn, log *zap.Logger) error{ +var migrations = []func(tx *txn, log *zap.Logger) error{ migrateVersion2, migrateVersion3, migrateVersion4, @@ -764,4 +913,6 @@ var migrations = []func(tx txn, log *zap.Logger) error{ migrateVersion25, migrateVersion26, migrateVersion27, + migrateVersion28, + migrateVersion29, } diff --git a/persist/sqlite/peers.go b/persist/sqlite/peers.go new file mode 100644 index 00000000..b5b2ee99 --- /dev/null +++ b/persist/sqlite/peers.go @@ -0,0 +1,237 @@ +package sqlite + +import ( + "database/sql" + "errors" + "fmt" + "net" + "strconv" + "strings" + "sync" + "time" + + "go.sia.tech/coreutils/syncer" + "go.uber.org/zap" +) + +// A PeerStore stores information about peers. +type PeerStore struct { + s *Store + + // session-specific peer info is stored in memory to reduce write load + // on the database + mu sync.Mutex + peerInfo map[string]syncer.PeerInfo +} + +var _ syncer.PeerStore = (*PeerStore)(nil) + +// AddPeer adds the given peer to the store. +func (ps *PeerStore) AddPeer(peer string) error { + ps.mu.Lock() + defer ps.mu.Unlock() + ps.peerInfo[peer] = syncer.PeerInfo{ + Address: peer, + FirstSeen: time.Now(), + } + return ps.s.AddPeer(peer) +} + +// Peers returns the addresses of all known peers. +func (ps *PeerStore) Peers() ([]syncer.PeerInfo, error) { + ps.mu.Lock() + defer ps.mu.Unlock() + + // copy the map to a slice + peers := make([]syncer.PeerInfo, 0, len(ps.peerInfo)) + for _, pi := range ps.peerInfo { + peers = append(peers, pi) + } + return peers, nil +} + +// UpdatePeerInfo updates the information for the given peer. +func (ps *PeerStore) UpdatePeerInfo(peer string, fn func(*syncer.PeerInfo)) error { + ps.mu.Lock() + defer ps.mu.Unlock() + if pi, ok := ps.peerInfo[peer]; !ok { + return syncer.ErrPeerNotFound + } else { + fn(&pi) + ps.peerInfo[peer] = pi + } + return nil +} + +// Ban temporarily bans the given peer. +func (ps *PeerStore) Ban(peer string, duration time.Duration, reason string) error { + return ps.s.Ban(peer, duration, reason) +} + +// Banned returns true if the peer is banned. +func (ps *PeerStore) Banned(peer string) (bool, error) { + return ps.s.Banned(peer) +} + +// PeerInfo returns the information for the given peer. +func (ps *PeerStore) PeerInfo(peer string) (syncer.PeerInfo, error) { + ps.mu.Lock() + defer ps.mu.Unlock() + if pi, ok := ps.peerInfo[peer]; ok { + return pi, nil + } + return syncer.PeerInfo{}, syncer.ErrPeerNotFound +} + +// NewPeerStore creates a new peer store using the given store. +func NewPeerStore(s *Store) (syncer.PeerStore, error) { + ps := &PeerStore{s: s, peerInfo: make(map[string]syncer.PeerInfo)} + peers, err := s.Peers() + if err != nil { + return nil, fmt.Errorf("failed to load peers: %w", err) + } + for _, pi := range peers { + ps.peerInfo[pi.Address] = pi + } + return ps, nil +} + +func scanPeerInfo(s scanner) (pi syncer.PeerInfo, err error) { + err = s.Scan(&pi.Address, decode(&pi.FirstSeen)) + return +} + +// AddPeer adds the given peer to the store. +func (s *Store) AddPeer(peer string) error { + return s.transaction(func(tx *txn) error { + const query = `INSERT INTO syncer_peers (peer_address, first_seen) VALUES ($1, $2) ON CONFLICT (peer_address) DO NOTHING` + _, err := tx.Exec(query, peer, encode(time.Now())) + return err + }) +} + +// Peers returns the addresses of all known peers. +func (s *Store) Peers() (peers []syncer.PeerInfo, _ error) { + err := s.transaction(func(tx *txn) error { + const query = `SELECT peer_address, first_seen FROM syncer_peers` + rows, err := tx.Query(query) + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + peer, err := scanPeerInfo(rows) + if err != nil { + return fmt.Errorf("failed to scan peer info: %w", err) + } + peers = append(peers, peer) + } + return rows.Err() + }) + return peers, err +} + +// normalizePeer normalizes a peer address to a CIDR subnet. +func normalizePeer(peer string) (string, error) { + host, _, err := net.SplitHostPort(peer) + if err != nil { + host = peer + } + if strings.IndexByte(host, '/') != -1 { + _, subnet, err := net.ParseCIDR(host) + if err != nil { + return "", fmt.Errorf("failed to parse CIDR: %w", err) + } + return subnet.String(), nil + } + + ip := net.ParseIP(host) + if ip == nil { + return "", errors.New("invalid IP address") + } + + var maskLen int + if ip.To4() != nil { + maskLen = 32 + } else { + maskLen = 128 + } + + _, normalized, err := net.ParseCIDR(fmt.Sprintf("%s/%d", ip.String(), maskLen)) + if err != nil { + panic("failed to parse CIDR") + } + return normalized.String(), nil +} + +// Ban temporarily bans one or more IPs. The addr should either be a single +// IP with port (e.g. 1.2.3.4:5678) or a CIDR subnet (e.g. 1.2.3.4/16). +func (s *Store) Ban(peer string, duration time.Duration, reason string) error { + address, err := normalizePeer(peer) + if err != nil { + return err + } + s.log.Debug("banning peer", zap.String("peer", address), zap.Duration("duration", duration), zap.String("reason", reason)) + return s.transaction(func(tx *txn) error { + const query = `INSERT INTO syncer_bans (net_cidr, expiration, reason) VALUES ($1, $2, $3) ON CONFLICT (net_cidr) DO UPDATE SET expiration=EXCLUDED.expiration, reason=EXCLUDED.reason` + _, err := tx.Exec(query, address, encode(time.Now().Add(duration)), reason) + return err + }) +} + +// Banned returns true if the peer is banned. +func (s *Store) Banned(peer string) (banned bool, _ error) { + // normalize the peer into a CIDR subnet + peer, err := normalizePeer(peer) + if err != nil { + return false, fmt.Errorf("failed to normalize peer: %w", err) + } + + _, subnet, err := net.ParseCIDR(peer) + if err != nil { + return false, fmt.Errorf("failed to parse CIDR: %w", err) + } + + // check all subnets from the given subnet to the max subnet length + var maxMaskLen int + if subnet.IP.To4() != nil { + maxMaskLen = 32 + } else { + maxMaskLen = 128 + } + + checkSubnets := make([]string, 0, maxMaskLen) + for i := maxMaskLen; i > 0; i-- { + _, subnet, err := net.ParseCIDR(subnet.IP.String() + "/" + strconv.Itoa(i)) + if err != nil { + panic("failed to parse CIDR") + } + checkSubnets = append(checkSubnets, subnet.String()) + } + + err = s.transaction(func(tx *txn) error { + checkSubnetStmt, err := tx.Prepare(`SELECT expiration FROM syncer_bans WHERE net_cidr = $1 ORDER BY expiration DESC LIMIT 1`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer checkSubnetStmt.Close() + + for _, subnet := range checkSubnets { + var expiration time.Time + + err := checkSubnetStmt.QueryRow(subnet).Scan(decode(&expiration)) + banned = time.Now().Before(expiration) // will return false for any sql errors, including ErrNoRows + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to check ban status: %w", err) + } else if banned { + s.log.Debug("found ban", zap.String("subnet", subnet), zap.Time("expiration", expiration)) + return nil + } + } + return nil + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return false, fmt.Errorf("failed to check ban status: %w", err) + } + return banned, nil +} diff --git a/persist/sqlite/peers_test.go b/persist/sqlite/peers_test.go new file mode 100644 index 00000000..0ef91dc8 --- /dev/null +++ b/persist/sqlite/peers_test.go @@ -0,0 +1,122 @@ +package sqlite + +import ( + "net" + "path/filepath" + "testing" + "time" + + "go.sia.tech/coreutils/syncer" + "go.uber.org/zap/zaptest" +) + +func TestAddPeer(t *testing.T) { + log := zaptest.NewLogger(t) + db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log.Named("sqlite3")) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ps, err := NewPeerStore(db) + if err != nil { + t.Fatal(err) + } + + const peer = "1.2.3.4:9981" + + if err := ps.AddPeer(peer); err != nil { + t.Fatal(err) + } + + lastConnect := time.Now().UTC().Truncate(time.Second) // stored as unix milliseconds + syncedBlocks := uint64(15) + syncDuration := 5 * time.Second + + err = ps.UpdatePeerInfo(peer, func(info *syncer.PeerInfo) { + info.LastConnect = lastConnect + info.SyncedBlocks = syncedBlocks + info.SyncDuration = syncDuration + }) + if err != nil { + t.Fatal(err) + } + + info, err := ps.PeerInfo(peer) + if err != nil { + t.Fatal(err) + } + + if !info.LastConnect.Equal(lastConnect) { + t.Errorf("expected LastConnect = %v; got %v", lastConnect, info.LastConnect) + } + if info.SyncedBlocks != syncedBlocks { + t.Errorf("expected SyncedBlocks = %d; got %d", syncedBlocks, info.SyncedBlocks) + } + if info.SyncDuration != 5*time.Second { + t.Errorf("expected SyncDuration = %s; got %s", syncDuration, info.SyncDuration) + } + + peers, err := ps.Peers() + if err != nil { + t.Fatal(err) + } else if len(peers) != 1 { + t.Fatalf("expected 1 peer; got %d", len(peers)) + } else if peerInfo := peers[0]; peerInfo.Address != peer { + t.Errorf("expected peer address = %q; got %q", peer, peerInfo.Address) + } else if peerInfo.LastConnect != lastConnect { + t.Errorf("expected LastConnect = %v; got %v", lastConnect, peerInfo.LastConnect) + } else if peerInfo.SyncedBlocks != syncedBlocks { + t.Errorf("expected SyncedBlocks = %d; got %d", syncedBlocks, peerInfo.SyncedBlocks) + } else if peerInfo.SyncDuration != syncDuration { + t.Errorf("expected SyncDuration = %s; got %s", syncDuration, peerInfo.SyncDuration) + } else if peerInfo.FirstSeen.IsZero() { + t.Errorf("expected FirstSeen to be non-zero; got %v", peerInfo.FirstSeen) + } +} + +func TestBanPeer(t *testing.T) { + log := zaptest.NewLogger(t) + db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log.Named("sqlite3")) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ps, err := NewPeerStore(db) + if err != nil { + t.Fatal(err) + } + + const peer = "1.2.3.4" + + if banned, err := ps.Banned(peer); err != nil || banned { + t.Fatal("expected peer to not be banned", err) + } + + // ban the peer + ps.Ban(peer, 5*time.Second, "test") + + if banned, err := ps.Banned(peer); err != nil || !banned { + t.Fatal("expected peer to be banned", err) + } + + // wait for the ban to expire + time.Sleep(5 * time.Second) + + if banned, err := ps.Banned(peer); err != nil || banned { + t.Fatal("expected peer to not be banned", err) + } + + // ban a subnet + _, subnet, err := net.ParseCIDR(peer + "/24") + if err != nil { + t.Fatal(err) + } + + t.Log("banning", subnet) + ps.Ban(subnet.String(), time.Second, "test") + if banned, err := ps.Banned(peer); err != nil || !banned { + t.Fatal("expected peer to be banned", err) + } +} diff --git a/persist/sqlite/recalc.go b/persist/sqlite/recalc.go index d596810a..91dc7e70 100644 --- a/persist/sqlite/recalc.go +++ b/persist/sqlite/recalc.go @@ -2,12 +2,14 @@ package sqlite import ( "fmt" + "time" "go.sia.tech/core/types" + "go.sia.tech/hostd/host/contracts" "go.uber.org/zap" ) -func checkContractAccountFunding(tx txn, log *zap.Logger) error { +func checkContractAccountFunding(tx *txn, log *zap.Logger) error { rows, err := tx.Query(`SELECT contract_id, amount FROM contract_account_funding`) if err != nil { return fmt.Errorf("failed to query contract account funding: %w", err) @@ -18,7 +20,7 @@ func checkContractAccountFunding(tx txn, log *zap.Logger) error { for rows.Next() { var contractID int64 var amount types.Currency - if err := rows.Scan(&contractID, (*sqlCurrency)(&amount)); err != nil { + if err := rows.Scan(&contractID, decode(&amount)); err != nil { return fmt.Errorf("failed to scan contract account funding: %w", err) } contractFunding[contractID] = contractFunding[contractID].Add(amount) @@ -32,7 +34,7 @@ func checkContractAccountFunding(tx txn, log *zap.Logger) error { for contractID, amount := range contractFunding { var actualAmount types.Currency - err := tx.QueryRow(`SELECT account_funding FROM contracts WHERE id=$1`, contractID).Scan((*sqlCurrency)(&actualAmount)) + err := tx.QueryRow(`SELECT account_funding FROM contracts WHERE id=$1`, contractID).Scan(decode(&actualAmount)) if err != nil { return fmt.Errorf("failed to query contract account funding: %w", err) } @@ -44,7 +46,7 @@ func checkContractAccountFunding(tx txn, log *zap.Logger) error { return nil } -func recalcContractAccountFunding(tx txn, _ *zap.Logger) error { +func recalcContractAccountFunding(tx *txn, _ *zap.Logger) error { rows, err := tx.Query(`SELECT contract_id, amount FROM contract_account_funding`) if err != nil { return fmt.Errorf("failed to query contract account funding: %w", err) @@ -55,7 +57,7 @@ func recalcContractAccountFunding(tx txn, _ *zap.Logger) error { for rows.Next() { var contractID int64 var amount types.Currency - if err := rows.Scan(&contractID, (*sqlCurrency)(&amount)); err != nil { + if err := rows.Scan(&contractID, decode(&amount)); err != nil { return fmt.Errorf("failed to scan contract account funding: %w", err) } contractFunding[contractID] = contractFunding[contractID].Add(amount) @@ -68,7 +70,7 @@ func recalcContractAccountFunding(tx txn, _ *zap.Logger) error { } for contractID, amount := range contractFunding { - res, err := tx.Exec(`UPDATE contracts SET account_funding=$1 WHERE id=$2`, sqlCurrency(amount), contractID) + res, err := tx.Exec(`UPDATE contracts SET account_funding=$1 WHERE id=$2`, encode(amount), contractID) if err != nil { return fmt.Errorf("failed to query contract account funding: %w", err) } else if rowsAffected, err := res.RowsAffected(); err != nil { @@ -80,17 +82,78 @@ func recalcContractAccountFunding(tx txn, _ *zap.Logger) error { return nil } +func recalcContractMetrics(tx *txn, log *zap.Logger) error { + rows, err := tx.Query(`SELECT contract_status, locked_collateral, risked_collateral, rpc_revenue, storage_revenue, ingress_revenue, egress_revenue, account_funding, registry_read, registry_write FROM contracts WHERE contract_status IN (?, ?);`, contracts.ContractStatusActive, contracts.ContractStatusSuccessful) + if err != nil { + return fmt.Errorf("failed to query contracts: %w", err) + } + defer rows.Close() + + var totalLocked types.Currency + var totalPending, totalEarned contracts.Usage + for rows.Next() { + var status contracts.ContractStatus + var lockedCollateral types.Currency + var usage contracts.Usage + + if err := rows.Scan(&status, decode(&lockedCollateral), decode(&usage.RiskedCollateral), decode(&usage.RPCRevenue), decode(&usage.StorageRevenue), decode(&usage.IngressRevenue), decode(&usage.EgressRevenue), decode(&usage.AccountFunding), decode(&usage.RegistryRead), decode(&usage.RegistryWrite)); err != nil { + return fmt.Errorf("failed to scan contract: %w", err) + } + + switch status { + case contracts.ContractStatusActive: + totalLocked = totalLocked.Add(lockedCollateral) + totalPending = totalPending.Add(usage) + case contracts.ContractStatusSuccessful: + totalEarned = totalEarned.Add(usage) + } + } + + log.Debug("resetting metrics", zap.Stringer("lockedCollateral", totalLocked), zap.Stringer("riskedCollateral", totalPending.RiskedCollateral)) + + if err := setCurrencyStat(tx, metricLockedCollateral, totalLocked, time.Now()); err != nil { + return fmt.Errorf("failed to increment locked collateral: %w", err) + } else if err := setCurrencyStat(tx, metricRiskedCollateral, totalPending.RiskedCollateral, time.Now()); err != nil { + return fmt.Errorf("failed to increment risked collateral: %w", err) + } else if err := setCurrencyStat(tx, metricPotentialRPCRevenue, totalPending.RPCRevenue, time.Now()); err != nil { + return fmt.Errorf("failed to increment rpc revenue: %w", err) + } else if err := setCurrencyStat(tx, metricPotentialStorageRevenue, totalPending.StorageRevenue, time.Now()); err != nil { + return fmt.Errorf("failed to increment storage revenue: %w", err) + } else if err := setCurrencyStat(tx, metricPotentialIngressRevenue, totalPending.IngressRevenue, time.Now()); err != nil { + return fmt.Errorf("failed to increment ingress revenue: %w", err) + } else if err := setCurrencyStat(tx, metricPotentialEgressRevenue, totalPending.EgressRevenue, time.Now()); err != nil { + return fmt.Errorf("failed to increment egress revenue: %w", err) + } else if err := setCurrencyStat(tx, metricPotentialRegistryReadRevenue, totalPending.RegistryRead, time.Now()); err != nil { + return fmt.Errorf("failed to increment read registry revenue: %w", err) + } else if err := setCurrencyStat(tx, metricPotentialRegistryWriteRevenue, totalPending.RegistryWrite, time.Now()); err != nil { + return fmt.Errorf("failed to increment write registry revenue: %w", err) + } else if err := setCurrencyStat(tx, metricEarnedRPCRevenue, totalEarned.RPCRevenue, time.Now()); err != nil { + return fmt.Errorf("failed to increment rpc revenue: %w", err) + } else if err := setCurrencyStat(tx, metricEarnedStorageRevenue, totalEarned.StorageRevenue, time.Now()); err != nil { + return fmt.Errorf("failed to increment storage revenue: %w", err) + } else if err := setCurrencyStat(tx, metricEarnedIngressRevenue, totalEarned.IngressRevenue, time.Now()); err != nil { + return fmt.Errorf("failed to increment ingress revenue: %w", err) + } else if err := setCurrencyStat(tx, metricEarnedEgressRevenue, totalEarned.EgressRevenue, time.Now()); err != nil { + return fmt.Errorf("failed to increment egress revenue: %w", err) + } else if err := setCurrencyStat(tx, metricEarnedRegistryReadRevenue, totalEarned.RegistryRead, time.Now()); err != nil { + return fmt.Errorf("failed to increment read registry revenue: %w", err) + } else if err := setCurrencyStat(tx, metricEarnedRegistryWriteRevenue, totalEarned.RegistryWrite, time.Now()); err != nil { + return fmt.Errorf("failed to increment write registry revenue: %w", err) + } + return nil +} + // CheckContractAccountFunding checks that the contract account funding table // is correct. func (s *Store) CheckContractAccountFunding() error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { return checkContractAccountFunding(tx, s.log) }) } // RecalcContractAccountFunding recalculates the contract account funding table. func (s *Store) RecalcContractAccountFunding() error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { return recalcContractAccountFunding(tx, s.log) }) } diff --git a/persist/sqlite/registry.go b/persist/sqlite/registry.go index af20efa7..215787fb 100644 --- a/persist/sqlite/registry.go +++ b/persist/sqlite/registry.go @@ -12,18 +12,21 @@ import ( // GetRegistryValue returns the registry value for the given key. If the key is not // found should return ErrEntryNotFound. -func (s *Store) GetRegistryValue(key rhp3.RegistryKey) (entry rhp3.RegistryValue, _ error) { - err := s.queryRow(`SELECT revision_number, entry_type, entry_data, entry_signature FROM registry_entries WHERE registry_key=$1`, sqlHash256(key.Hash())).Scan( - (*sqlUint64)(&entry.Revision), - &entry.Type, - &entry.Data, - (*sqlHash512)(&entry.Signature), - ) - if errors.Is(err, sql.ErrNoRows) { - return rhp3.RegistryValue{}, registry.ErrEntryNotFound - } else if err != nil { - return rhp3.RegistryValue{}, fmt.Errorf("failed to get registry entry: %w", err) - } +func (s *Store) GetRegistryValue(key rhp3.RegistryKey) (entry rhp3.RegistryValue, err error) { + err = s.transaction(func(tx *txn) error { + err := tx.QueryRow(`SELECT revision_number, entry_type, entry_data, entry_signature FROM registry_entries WHERE registry_key=$1`, encode(key.Hash())).Scan( + decode(&entry.Revision), + &entry.Type, + &entry.Data, + decode(&entry.Signature), + ) + if errors.Is(err, sql.ErrNoRows) { + return registry.ErrEntryNotFound + } else if err != nil { + return fmt.Errorf("failed to get registry entry: %w", err) + } + return nil + }) return } @@ -36,8 +39,8 @@ func (s *Store) SetRegistryValue(entry rhp3.RegistryEntry, expiration uint64) er ) // note: need to error when the registry is full, so can't use upsert registryKey := entry.RegistryKey.Hash() - return s.transaction(func(tx txn) error { - err := tx.QueryRow(selectQuery, sqlHash256(registryKey)).Scan((*sqlHash256)(®istryKey)) + return s.transaction(func(tx *txn) error { + err := tx.QueryRow(selectQuery, encode(registryKey)).Scan(decode(®istryKey)) if errors.Is(err, sql.ErrNoRows) { // key doesn't exist, insert it count, limit, err := registryLimits(tx) @@ -46,7 +49,7 @@ func (s *Store) SetRegistryValue(entry rhp3.RegistryEntry, expiration uint64) er } else if count >= limit { return registry.ErrNotEnoughSpace } - err = tx.QueryRow(insertQuery, sqlHash256(registryKey), sqlUint64(entry.Revision), entry.Type, sqlHash512(entry.Signature), entry.Data, sqlUint64(expiration)).Scan((*sqlHash256)(®istryKey)) + err = tx.QueryRow(insertQuery, encode(registryKey), encode(entry.Revision), entry.Type, encode(entry.Signature), entry.Data, encode(expiration)).Scan(decode(®istryKey)) if err != nil { return fmt.Errorf("failed to insert registry entry: %w", err) } else if err := incrementNumericStat(tx, metricRegistryEntries, 1, time.Now()); err != nil { @@ -56,17 +59,21 @@ func (s *Store) SetRegistryValue(entry rhp3.RegistryEntry, expiration uint64) er return fmt.Errorf("failed to get registry entry: %w", err) } // key exists, update it - return tx.QueryRow(updateQuery, sqlHash256(registryKey), sqlUint64(entry.Revision), entry.Type, sqlHash512(entry.Signature), entry.Data, sqlUint64(expiration)).Scan((*sqlHash256)(®istryKey)) + return tx.QueryRow(updateQuery, encode(registryKey), encode(entry.Revision), entry.Type, encode(entry.Signature), entry.Data, encode(expiration)).Scan(decode(®istryKey)) }) } // RegistryEntries returns the current number of entries as well as the // maximum number of entries the registry can hold. func (s *Store) RegistryEntries() (count, limit uint64, err error) { - return registryLimits(&dbTxn{s}) + err = s.transaction(func(tx *txn) error { + count, limit, err = registryLimits(tx) + return err + }) + return } -func registryLimits(tx txn) (count, limit uint64, err error) { +func registryLimits(tx *txn) (count, limit uint64, err error) { err = tx.QueryRow(`SELECT COALESCE(COUNT(re.registry_key), 0), COALESCE(hs.registry_limit, 0) FROM host_settings hs LEFT JOIN registry_entries re ON (true);`).Scan(&count, &limit) return } diff --git a/persist/sqlite/sectors.go b/persist/sqlite/sectors.go index 53937965..50c6b6e0 100644 --- a/persist/sqlite/sectors.go +++ b/persist/sqlite/sectors.go @@ -11,7 +11,7 @@ import ( "go.uber.org/zap" ) -func deleteTempSectors(tx txn, height uint64) (sectorIDs []int64, err error) { +func deleteTempSectors(tx *txn, height uint64) (sectorIDs []int64, err error) { const query = `DELETE FROM temp_storage_sector_roots WHERE id IN (SELECT id FROM temp_storage_sector_roots WHERE expiration_height <= $1 LIMIT $2) RETURNING sector_id;` @@ -33,7 +33,7 @@ RETURNING sector_id;` } func (s *Store) batchExpireTempSectors(height uint64) (expired int, pruned []types.Hash256, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { sectorIDs, err := deleteTempSectors(tx, height) if err != nil { return fmt.Errorf("failed to delete sectors: %w", err) @@ -56,7 +56,7 @@ func (s *Store) batchExpireTempSectors(height uint64) (expired int, pruned []typ // RemoveSector removes the metadata of a sector and returns its // location in the volume. func (s *Store) RemoveSector(root types.Hash256) (err error) { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { sectorID, err := sectorDBID(tx, root) if err != nil { return fmt.Errorf("failed to get sector: %w", err) @@ -86,7 +86,7 @@ func (s *Store) RemoveSector(root types.Hash256) (err error) { func (s *Store) SectorLocation(root types.Hash256) (storage.SectorLocation, func() error, error) { var lockID int64 var location storage.SectorLocation - err := s.transaction(func(tx txn) error { + err := s.transaction(func(tx *txn) error { sectorID, err := sectorDBID(tx, root) if errors.Is(err, sql.ErrNoRows) { return storage.ErrSectorNotFound @@ -107,7 +107,7 @@ func (s *Store) SectorLocation(root types.Hash256) (storage.SectorLocation, func return storage.SectorLocation{}, nil, err } unlock := func() error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { return unlockSector(tx, s.log.Named("SectorLocation"), lockID) }) } @@ -117,7 +117,7 @@ func (s *Store) SectorLocation(root types.Hash256) (storage.SectorLocation, func // AddTemporarySectors adds the roots of sectors that are temporarily stored // on the host. The sectors will be deleted after the expiration height. func (s *Store) AddTemporarySectors(sectors []storage.TempSector) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { stmt, err := tx.Prepare(`INSERT INTO temp_storage_sector_roots (sector_id, expiration_height) SELECT id, $1 FROM stored_sectors WHERE sector_root=$2 RETURNING id;`) if err != nil { return fmt.Errorf("failed to prepare query: %w", err) @@ -125,7 +125,7 @@ func (s *Store) AddTemporarySectors(sectors []storage.TempSector) error { defer stmt.Close() for _, sector := range sectors { var dbID int64 - err := stmt.QueryRow(sector.Expiration, sqlHash256(sector.Root)).Scan(&dbID) + err := stmt.QueryRow(sector.Expiration, encode(sector.Root)).Scan(&dbID) if err != nil { return fmt.Errorf("failed to add temp sector root: %w", err) } @@ -154,21 +154,9 @@ func (s *Store) ExpireTempSectors(height uint64) error { } } -// HasSector returns true if the sector is stored on the host. -func (s *Store) HasSector(root types.Hash256) (bool, error) { - var dbID int64 - err := s.queryRow(`SELECT id FROM stored_sectors WHERE sector_root=$1`, sqlHash256(root)).Scan(&dbID) - if errors.Is(err, sql.ErrNoRows) { - return false, nil - } else if err != nil { - return false, err - } - return true, nil -} - // SectorReferences returns the references, if any of a sector root func (s *Store) SectorReferences(root types.Hash256) (refs storage.SectorReference, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { dbID, err := sectorDBID(tx, root) if err != nil { return fmt.Errorf("failed to get sector id: %w", err) @@ -196,7 +184,7 @@ func (s *Store) SectorReferences(root types.Hash256) (refs storage.SectorReferen return } -func contractSectorRefs(tx txn, sectorID int64) (contractIDs []types.FileContractID, err error) { +func contractSectorRefs(tx *txn, sectorID int64) (contractIDs []types.FileContractID, err error) { rows, err := tx.Query(`SELECT DISTINCT contract_id FROM contract_sector_roots WHERE sector_id=$1;`, sectorID) if err != nil { return nil, fmt.Errorf("failed to select contracts: %w", err) @@ -205,7 +193,7 @@ func contractSectorRefs(tx txn, sectorID int64) (contractIDs []types.FileContrac for rows.Next() { var contractID types.FileContractID - if err := rows.Scan((*sqlHash256)(&contractID)); err != nil { + if err := rows.Scan(decode(&contractID)); err != nil { return nil, fmt.Errorf("failed to scan contract id: %w", err) } contractIDs = append(contractIDs, contractID) @@ -213,17 +201,17 @@ func contractSectorRefs(tx txn, sectorID int64) (contractIDs []types.FileContrac return } -func getTempStorageCount(tx txn, sectorID int64) (n int, err error) { +func getTempStorageCount(tx *txn, sectorID int64) (n int, err error) { err = tx.QueryRow(`SELECT COUNT(*) FROM temp_storage_sector_roots WHERE sector_id=$1;`, sectorID).Scan(&n) return } -func getSectorLockCount(tx txn, sectorID int64) (n int, err error) { +func getSectorLockCount(tx *txn, sectorID int64) (n int, err error) { err = tx.QueryRow(`SELECT COUNT(*) FROM locked_sectors WHERE sector_id=$1;`, sectorID).Scan(&n) return } -func incrementVolumeUsage(tx txn, volumeID int64, delta int) error { +func incrementVolumeUsage(tx *txn, volumeID int64, delta int) error { var used int64 err := tx.QueryRow(`UPDATE storage_volumes SET used_sectors=used_sectors+$1 WHERE id=$2 RETURNING used_sectors;`, delta, volumeID).Scan(&used) if err != nil { @@ -236,20 +224,26 @@ func incrementVolumeUsage(tx txn, volumeID int64, delta int) error { return nil } -func pruneSectors(tx txn, ids []int64) (pruned []types.Hash256, err error) { - hasContractRefStmt, err := tx.Prepare(`SELECT id FROM contract_sector_roots WHERE sector_id=$1 LIMIT 1`) +func pruneSectors(tx *txn, ids []int64) (pruned []types.Hash256, err error) { + hasContractRefStmt, err := tx.Prepare(`SELECT EXISTS(SELECT 1 FROM contract_sector_roots WHERE sector_id=$1)`) if err != nil { return nil, fmt.Errorf("failed to prepare contract reference query: %w", err) } defer hasContractRefStmt.Close() - hasTempRefStmt, err := tx.Prepare(`SELECT id FROM temp_storage_sector_roots WHERE sector_id=$1 LIMIT 1`) + hasV2ContractRefStmt, err := tx.Prepare(`SELECT EXISTS(SELECT 1 FROM contract_v2_sector_roots WHERE sector_id=$1)`) + if err != nil { + return nil, fmt.Errorf("failed to prepare v2 contract reference query: %w", err) + } + defer hasV2ContractRefStmt.Close() + + hasTempRefStmt, err := tx.Prepare(`SELECT EXISTS(SELECT 1 FROM temp_storage_sector_roots WHERE sector_id=$1)`) if err != nil { return nil, fmt.Errorf("failed to prepare temp reference query: %w", err) } defer hasTempRefStmt.Close() - hasLockStmt, err := tx.Prepare(`SELECT id FROM locked_sectors WHERE sector_id=$1 LIMIT 1`) + hasLockStmt, err := tx.Prepare(`SELECT EXISTS(SELECT 1 FROM locked_sectors WHERE sector_id=$1)`) if err != nil { return nil, fmt.Errorf("failed to prepare lock reference query: %w", err) } @@ -269,27 +263,32 @@ func pruneSectors(tx txn, ids []int64) (pruned []types.Hash256, err error) { volumeDelta := make(map[int64]int) for _, id := range ids { - var contractDBID int64 - err := hasContractRefStmt.QueryRow(id).Scan(&contractDBID) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + var exists bool + err := hasContractRefStmt.QueryRow(id).Scan(&exists) + if err != nil { return nil, fmt.Errorf("failed to check contract references: %w", err) - } else if err == nil { + } else if exists { continue // sector has a contract reference } - var tempDBID int64 - err = hasTempRefStmt.QueryRow(id).Scan(&tempDBID) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + err = hasV2ContractRefStmt.QueryRow(id).Scan(&exists) + if err != nil { + return nil, fmt.Errorf("failed to check v2 contract references: %w", err) + } else if exists { + continue // sector has a contract reference + } + + err = hasTempRefStmt.QueryRow(id).Scan(&exists) + if err != nil { return nil, fmt.Errorf("failed to check temp references: %w", err) - } else if err == nil { + } else if exists { continue // sector has a temp storage reference } - var lockDBID int64 - err = hasLockStmt.QueryRow(id).Scan(&lockDBID) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + err = hasLockStmt.QueryRow(id).Scan(&exists) + if err != nil { return nil, fmt.Errorf("failed to check lock references: %w", err) - } else if err == nil { + } else if exists { continue // sector is locked } @@ -302,7 +301,7 @@ func pruneSectors(tx txn, ids []int64) (pruned []types.Hash256, err error) { } var root types.Hash256 - err = deleteSectorStmt.QueryRow(id).Scan((*sqlHash256)(&root)) + err = deleteSectorStmt.QueryRow(id).Scan(decode(&root)) if err != nil && !errors.Is(err, sql.ErrNoRows) { // ignore rows not found return nil, fmt.Errorf("failed to delete sector: %w", err) } else if err == nil { @@ -322,14 +321,14 @@ func pruneSectors(tx txn, ids []int64) (pruned []types.Hash256, err error) { // lockSector locks a sector root. The lock must be released by calling // unlockSector. A sector must be locked when it is being read or written // to prevent it from being removed by prune sector. -func lockSector(tx txn, sectorDBID int64) (lockID int64, err error) { +func lockSector(tx *txn, sectorDBID int64) (lockID int64, err error) { err = tx.QueryRow(`INSERT INTO locked_sectors (sector_id) VALUES ($1) RETURNING id;`, sectorDBID).Scan(&lockID) return } // deleteLocks removes the lock records with the given ids and returns the // sector ids of the deleted locks. -func deleteLocks(tx txn, ids []int64) (sectorIDs []int64, err error) { +func deleteLocks(tx *txn, ids []int64) (sectorIDs []int64, err error) { if len(ids) == 0 { return nil, nil } @@ -352,17 +351,17 @@ func deleteLocks(tx txn, ids []int64) (sectorIDs []int64, err error) { } // unlockSector unlocks a sector root. -func unlockSector(txn txn, log *zap.Logger, lockIDs ...int64) error { +func unlockSector(tx *txn, log *zap.Logger, lockIDs ...int64) error { if len(lockIDs) == 0 { return nil } - sectorIDs, err := deleteLocks(txn, lockIDs) + sectorIDs, err := deleteLocks(tx, lockIDs) if err != nil { return fmt.Errorf("failed to delete locks: %w", err) } - pruned, err := pruneSectors(txn, sectorIDs) + pruned, err := pruneSectors(tx, sectorIDs) if err != nil { return fmt.Errorf("failed to prune sectors: %w", err) } @@ -374,7 +373,7 @@ func unlockSector(txn txn, log *zap.Logger, lockIDs ...int64) error { // IDs. The lock ids must be unlocked by unlockLocations. Volume locations // should be locked during writes to prevent the location from being written // to by another goroutine. -func lockLocations(tx txn, locations []storage.SectorLocation) (locks []int64, err error) { +func lockLocations(tx *txn, locations []storage.SectorLocation) (locks []int64, err error) { if len(locations) == 0 { return nil, nil } @@ -396,7 +395,7 @@ func lockLocations(tx txn, locations []storage.SectorLocation) (locks []int64, e // unlockLocations unlocks multiple locked sector locations. It is safe to // call multiple times. -func unlockLocations(tx txn, ids []int64) error { +func unlockLocations(tx *txn, ids []int64) error { if len(ids) == 0 { return nil } diff --git a/persist/sqlite/settings.go b/persist/sqlite/settings.go index 7a96f8bd..14b591d6 100644 --- a/persist/sqlite/settings.go +++ b/persist/sqlite/settings.go @@ -12,7 +12,6 @@ import ( "go.sia.tech/core/types" "go.sia.tech/hostd/host/settings" "go.sia.tech/hostd/host/settings/pin" - "go.sia.tech/siad/modules" "go.uber.org/zap" ) @@ -21,13 +20,17 @@ func (s *Store) PinnedSettings(context.Context) (pinned pin.PinnedSettings, err const query = `SELECT currency, threshold, storage_pinned, storage_price, ingress_pinned, ingress_price, egress_pinned, egress_price, max_collateral_pinned, max_collateral FROM host_pinned_settings;` - err = s.queryRow(query).Scan(&pinned.Currency, &pinned.Threshold, &pinned.Storage.Pinned, &pinned.Storage.Value, &pinned.Ingress.Pinned, &pinned.Ingress.Value, &pinned.Egress.Pinned, &pinned.Egress.Value, &pinned.MaxCollateral.Pinned, &pinned.MaxCollateral.Value) - if errors.Is(err, sql.ErrNoRows) { - return pin.PinnedSettings{ - Currency: "usd", - Threshold: 0.02, - }, nil - } + err = s.transaction(func(tx *txn) error { + err = tx.QueryRow(query).Scan(&pinned.Currency, &pinned.Threshold, &pinned.Storage.Pinned, &pinned.Storage.Value, &pinned.Ingress.Pinned, &pinned.Ingress.Value, &pinned.Egress.Pinned, &pinned.Egress.Value, &pinned.MaxCollateral.Pinned, &pinned.MaxCollateral.Value) + if errors.Is(err, sql.ErrNoRows) { + pinned = pin.PinnedSettings{ + Currency: "usd", + Threshold: 0.02, + } + return nil + } + return err + }) return } @@ -39,8 +42,11 @@ ON CONFLICT (id) DO UPDATE SET currency=EXCLUDED.currency, threshold=EXCLUDED.th storage_pinned=EXCLUDED.storage_pinned, storage_price=EXCLUDED.storage_price, ingress_pinned=EXCLUDED.ingress_pinned, ingress_price=EXCLUDED.ingress_price, egress_pinned=EXCLUDED.egress_pinned, egress_price=EXCLUDED.egress_price, max_collateral_pinned=EXCLUDED.max_collateral_pinned, max_collateral=EXCLUDED.max_collateral;` - _, err := s.exec(query, p.Currency, p.Threshold, p.Storage.Pinned, p.Storage.Value, p.Ingress.Pinned, p.Ingress.Value, p.Egress.Pinned, p.Egress.Value, p.MaxCollateral.Pinned, p.MaxCollateral.Value) - return err + + return s.transaction(func(tx *txn) error { + _, err := tx.Exec(query, p.Currency, p.Threshold, p.Storage.Pinned, p.Storage.Value, p.Ingress.Pinned, p.Ingress.Value, p.Egress.Pinned, p.Egress.Value, p.MaxCollateral.Pinned, p.MaxCollateral.Value) + return err + }) } // Settings returns the current host settings. @@ -52,24 +58,28 @@ func (s *Store) Settings() (config settings.Settings, err error) { max_account_balance, max_account_age, price_table_validity, max_contract_duration, window_size, ingress_limit, egress_limit, registry_limit, ddns_provider, ddns_update_v4, ddns_update_v6, ddns_opts, sector_cache_size FROM host_settings;` - err = s.queryRow(query).Scan(&config.Revision, &config.AcceptingContracts, - &config.NetAddress, (*sqlCurrency)(&config.ContractPrice), - (*sqlCurrency)(&config.BaseRPCPrice), (*sqlCurrency)(&config.SectorAccessPrice), - &config.CollateralMultiplier, (*sqlCurrency)(&config.MaxCollateral), - (*sqlCurrency)(&config.StoragePrice), (*sqlCurrency)(&config.EgressPrice), - (*sqlCurrency)(&config.IngressPrice), (*sqlCurrency)(&config.MaxAccountBalance), - &config.AccountExpiry, &config.PriceTableValidity, &config.MaxContractDuration, &config.WindowSize, - &config.IngressLimit, &config.EgressLimit, &config.MaxRegistryEntries, - &config.DDNS.Provider, &config.DDNS.IPv4, &config.DDNS.IPv6, &dyndnsBuf, &config.SectorCacheSize) - if errors.Is(err, sql.ErrNoRows) { - return settings.Settings{}, settings.ErrNoSettings - } - if dyndnsBuf != nil { - err = json.Unmarshal(dyndnsBuf, &config.DDNS.Options) - if err != nil { - return settings.Settings{}, fmt.Errorf("failed to unmarshal ddns options: %w", err) + + err = s.transaction(func(tx *txn) error { + err = tx.QueryRow(query).Scan(&config.Revision, &config.AcceptingContracts, + &config.NetAddress, decode(&config.ContractPrice), + decode(&config.BaseRPCPrice), decode(&config.SectorAccessPrice), + &config.CollateralMultiplier, decode(&config.MaxCollateral), + decode(&config.StoragePrice), decode(&config.EgressPrice), + decode(&config.IngressPrice), decode(&config.MaxAccountBalance), + &config.AccountExpiry, &config.PriceTableValidity, &config.MaxContractDuration, &config.WindowSize, + &config.IngressLimit, &config.EgressLimit, &config.MaxRegistryEntries, + &config.DDNS.Provider, &config.DDNS.IPv4, &config.DDNS.IPv6, &dyndnsBuf, &config.SectorCacheSize) + if errors.Is(err, sql.ErrNoRows) { + return settings.ErrNoSettings } - } + if dyndnsBuf != nil { + err = json.Unmarshal(dyndnsBuf, &config.DDNS.Options) + if err != nil { + return fmt.Errorf("failed to unmarshal ddns options: %w", err) + } + } + return nil + }) return } @@ -104,13 +114,13 @@ ON CONFLICT (id) DO UPDATE SET (settings_revision, } } - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { _, err := tx.Exec(query, settings.AcceptingContracts, - settings.NetAddress, sqlCurrency(settings.ContractPrice), - sqlCurrency(settings.BaseRPCPrice), sqlCurrency(settings.SectorAccessPrice), - settings.CollateralMultiplier, sqlCurrency(settings.MaxCollateral), - sqlCurrency(settings.StoragePrice), sqlCurrency(settings.EgressPrice), - sqlCurrency(settings.IngressPrice), sqlCurrency(settings.MaxAccountBalance), + settings.NetAddress, encode(settings.ContractPrice), + encode(settings.BaseRPCPrice), encode(settings.SectorAccessPrice), + settings.CollateralMultiplier, encode(settings.MaxCollateral), + encode(settings.StoragePrice), encode(settings.EgressPrice), + encode(settings.IngressPrice), encode(settings.MaxAccountBalance), settings.AccountExpiry, settings.PriceTableValidity, settings.MaxContractDuration, settings.WindowSize, settings.IngressLimit, settings.EgressLimit, settings.MaxRegistryEntries, settings.DDNS.Provider, settings.DDNS.IPv4, settings.DDNS.IPv6, dnsOptsBuf, settings.SectorCacheSize) @@ -143,7 +153,9 @@ ON CONFLICT (id) DO UPDATE SET (settings_revision, // HostKey returns the host's private key. func (s *Store) HostKey() (pk types.PrivateKey) { - err := s.queryRow(`SELECT host_key FROM global_settings WHERE id=0;`).Scan(&pk) + err := s.transaction(func(tx *txn) error { + return tx.QueryRow(`SELECT host_key FROM global_settings WHERE id=0;`).Scan(&pk) + }) if err != nil { s.log.Panic("failed to get host key", zap.Error(err), zap.Stack("stacktrace")) } else if n := len(pk); n != ed25519.PrivateKeySize { @@ -154,18 +166,15 @@ func (s *Store) HostKey() (pk types.PrivateKey) { // LastAnnouncement returns the last announcement. func (s *Store) LastAnnouncement() (ann settings.Announcement, err error) { - var height sql.NullInt64 var address sql.NullString - err = s.queryRow(`SELECT last_announce_id, last_announce_height, last_announce_address, last_announce_key FROM global_settings`). - Scan(nullable((*sqlHash256)(&ann.Index.ID)), &height, &address, nullable((*sqlHash256)(&ann.PublicKey))) + + err = s.transaction(func(tx *txn) error { + return tx.QueryRow(`SELECT last_announce_index, last_announce_address FROM global_settings`). + Scan(decodeNullable(&ann.Index), &address) + }) if errors.Is(err, sql.ErrNoRows) { return settings.Announcement{}, nil - } - - if height.Valid { - ann.Index.Height = uint64(height.Int64) - } - if address.Valid { + } else if address.Valid { ann.Address = address.String } return @@ -173,30 +182,19 @@ func (s *Store) LastAnnouncement() (ann settings.Announcement, err error) { // UpdateLastAnnouncement updates the last announcement. func (s *Store) UpdateLastAnnouncement(ann settings.Announcement) error { - const query = `UPDATE global_settings SET -last_announce_id=$1, last_announce_height=$2, last_announce_address=$3, last_announce_key=$4;` - _, err := s.exec(query, sqlHash256(ann.Index.ID), ann.Index.Height, ann.Address, sqlHash256(ann.PublicKey)) - return err + const query = `UPDATE global_settings SET last_announce_index=$1, last_announce_address=$2;` + + return s.transaction(func(tx *txn) error { + _, err := tx.Exec(query, encode(ann.Index), ann.Address) + return err + }) } // RevertLastAnnouncement reverts the last announcement. func (s *Store) RevertLastAnnouncement() error { - const query = `UPDATE global_settings SET -last_announce_id=NULL, last_announce_height=NULL, last_announce_address=NULL, last_announce_key=NULL;` - _, err := s.exec(query) - return err -} - -// LastSettingsConsensusChange returns the last processed consensus change ID of -// the settings manager -func (s *Store) LastSettingsConsensusChange() (cc modules.ConsensusChangeID, height uint64, err error) { - var nullHeight sql.NullInt64 - n := nullable((*sqlHash256)(&cc)) - err = s.queryRow(`SELECT settings_last_processed_change, settings_height FROM global_settings WHERE id=0;`).Scan(n, &nullHeight) - if errors.Is(err, sql.ErrNoRows) || !n.Valid { - return modules.ConsensusChangeRecent, 0, nil // as a special case don't scan the chain for new announcements - } else if nullHeight.Valid { - height = uint64(nullHeight.Int64) - } - return + const query = `UPDATE global_settings SET last_announce_index=NULL, last_announce_address=NULL, last_announce_key=NULL;` + return s.transaction(func(tx *txn) error { + _, err := tx.Exec(query) + return err + }) } diff --git a/persist/sqlite/sql.go b/persist/sqlite/sql.go index ce34c14f..7799a774 100644 --- a/persist/sqlite/sql.go +++ b/persist/sqlite/sql.go @@ -13,7 +13,7 @@ import ( const ( longQueryDuration = 10 * time.Millisecond - longTxnDuration = 10 * time.Millisecond + longTxnDuration = time.Second // reduce syncing spam ) type ( @@ -22,127 +22,107 @@ type ( Scan(dest ...any) error } - // A txn is an interface for executing queries within a transaction. - txn interface { - // Exec executes a query without returning any rows. The args are for - // any placeholder parameters in the query. - Exec(query string, args ...any) (sql.Result, error) - // Prepare creates a prepared statement for later queries or executions. - // Multiple queries or executions may be run concurrently from the - // returned statement. The caller must call the statement's Close method - // when the statement is no longer needed. - Prepare(query string) (*loggedStmt, error) - // Query executes a query that returns rows, typically a SELECT. The - // args are for any placeholder parameters in the query. - Query(query string, args ...any) (*loggedRows, error) - // QueryRow executes a query that is expected to return at most one row. - // QueryRow always returns a non-nil value. Errors are deferred until - // Row's Scan method is called. If the query selects no rows, the *Row's - // Scan will return ErrNoRows. Otherwise, the *Row's Scan scans the - // first selected row and discards the rest. - QueryRow(query string, args ...any) *loggedRow - } - - // A dbTxn wraps a Store and implements the txn interface. - dbTxn struct { - store *Store - } - - loggedStmt struct { + // A stmt wraps a *sql.Stmt, logging slow queries. + stmt struct { *sql.Stmt query string - log *zap.Logger + + log *zap.Logger } - loggedTxn struct { + // A txn wraps a *sql.Tx, logging slow queries. + txn struct { *sql.Tx log *zap.Logger } - loggedRow struct { + // A row wraps a *sql.Row, logging slow queries. + row struct { *sql.Row log *zap.Logger } - loggedRows struct { + // rows wraps a *sql.Rows, logging slow queries. + rows struct { *sql.Rows + log *zap.Logger } ) -func (lr *loggedRows) Next() bool { +func (r *rows) Next() bool { start := time.Now() - next := lr.Rows.Next() + next := r.Rows.Next() if dur := time.Since(start); dur > longQueryDuration { - lr.log.Debug("slow next", zap.Duration("elapsed", dur), zap.Stack("stack")) + r.log.Debug("slow next", zap.Duration("elapsed", dur), zap.Stack("stack")) } return next } -func (lr *loggedRows) Scan(dest ...any) error { +func (r *rows) Scan(dest ...any) error { start := time.Now() - err := lr.Rows.Scan(dest...) + err := r.Rows.Scan(dest...) if dur := time.Since(start); dur > longQueryDuration { - lr.log.Debug("slow scan", zap.Duration("elapsed", dur), zap.Stack("stack")) + r.log.Debug("slow scan", zap.Duration("elapsed", dur), zap.Stack("stack")) } return err } -func (lr *loggedRow) Scan(dest ...any) error { +func (r *row) Scan(dest ...any) error { start := time.Now() - err := lr.Row.Scan(dest...) + err := r.Row.Scan(dest...) if dur := time.Since(start); dur > longQueryDuration { - lr.log.Debug("slow scan", zap.Duration("elapsed", dur), zap.Stack("stack")) + r.log.Debug("slow scan", zap.Duration("elapsed", dur), zap.Stack("stack")) } return err } -func (ls *loggedStmt) Exec(args ...any) (sql.Result, error) { - return ls.ExecContext(context.Background(), args...) +func (s *stmt) Exec(args ...any) (sql.Result, error) { + return s.ExecContext(context.Background(), args...) } -func (ls *loggedStmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) { +func (s *stmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) { start := time.Now() - result, err := ls.Stmt.ExecContext(ctx, args...) + result, err := s.Stmt.ExecContext(ctx, args...) if dur := time.Since(start); dur > longQueryDuration { - ls.log.Debug("slow exec", zap.String("query", ls.query), zap.Duration("elapsed", dur), zap.Stack("stack")) + s.log.Debug("slow exec", zap.String("query", s.query), zap.Duration("elapsed", dur), zap.Stack("stack")) } return result, err } -func (ls *loggedStmt) Query(args ...any) (*sql.Rows, error) { - return ls.QueryContext(context.Background(), args...) +func (s *stmt) Query(args ...any) (*sql.Rows, error) { + return s.QueryContext(context.Background(), args...) } -func (ls *loggedStmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) { +func (s *stmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) { start := time.Now() - rows, err := ls.Stmt.QueryContext(ctx, args...) + rows, err := s.Stmt.QueryContext(ctx, args...) if dur := time.Since(start); dur > longQueryDuration { - ls.log.Debug("slow query", zap.String("query", ls.query), zap.Duration("elapsed", dur), zap.Stack("stack")) + s.log.Debug("slow query", zap.String("query", s.query), zap.Duration("elapsed", dur), zap.Stack("stack")) } return rows, err } -func (ls *loggedStmt) QueryRow(args ...any) *loggedRow { - return ls.QueryRowContext(context.Background(), args...) +func (s *stmt) QueryRow(args ...any) *row { + return s.QueryRowContext(context.Background(), args...) } -func (ls *loggedStmt) QueryRowContext(ctx context.Context, args ...any) *loggedRow { +func (s *stmt) QueryRowContext(ctx context.Context, args ...any) *row { start := time.Now() - row := ls.Stmt.QueryRowContext(ctx, args...) + r := s.Stmt.QueryRowContext(ctx, args...) if dur := time.Since(start); dur > longQueryDuration { - ls.log.Debug("slow query row", zap.String("query", ls.query), zap.Duration("elapsed", dur), zap.Stack("stack")) + s.log.Debug("slow query row", zap.String("query", s.query), zap.Duration("elapsed", dur), zap.Stack("stack")) } - return &loggedRow{row, ls.log.Named("row")} + return &row{r, s.log.Named("row")} } // Exec executes a query without returning any rows. The args are for // any placeholder parameters in the query. -func (lt *loggedTxn) Exec(query string, args ...any) (sql.Result, error) { +func (tx *txn) Exec(query string, args ...any) (sql.Result, error) { start := time.Now() - result, err := lt.Tx.Exec(query, args...) + result, err := tx.Tx.Exec(query, args...) if dur := time.Since(start); dur > longQueryDuration { - lt.log.Debug("slow exec", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + tx.log.Debug("slow exec", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) } return result, err } @@ -151,30 +131,30 @@ func (lt *loggedTxn) Exec(query string, args ...any) (sql.Result, error) { // Multiple queries or executions may be run concurrently from the // returned statement. The caller must call the statement's Close method // when the statement is no longer needed. -func (lt *loggedTxn) Prepare(query string) (*loggedStmt, error) { +func (tx *txn) Prepare(query string) (*stmt, error) { start := time.Now() - stmt, err := lt.Tx.Prepare(query) - if err != nil { + s, err := tx.Tx.Prepare(query) + if dur := time.Since(start); dur > longQueryDuration { + tx.log.Debug("slow prepare", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + } else if err != nil { return nil, err - } else if dur := time.Since(start); dur > longQueryDuration { - lt.log.Debug("slow prepare", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) } - return &loggedStmt{ - Stmt: stmt, + return &stmt{ + Stmt: s, query: query, - log: lt.log.Named("statement"), + log: tx.log.Named("statement"), }, nil } // Query executes a query that returns rows, typically a SELECT. The // args are for any placeholder parameters in the query. -func (lt *loggedTxn) Query(query string, args ...any) (*loggedRows, error) { +func (tx *txn) Query(query string, args ...any) (*rows, error) { start := time.Now() - rows, err := lt.Tx.Query(query, args...) + r, err := tx.Tx.Query(query, args...) if dur := time.Since(start); dur > longQueryDuration { - lt.log.Debug("slow query", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + tx.log.Debug("slow query", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) } - return &loggedRows{rows, lt.log.Named("rows")}, err + return &rows{r, tx.log.Named("rows")}, err } // QueryRow executes a query that is expected to return at most one row. @@ -182,42 +162,32 @@ func (lt *loggedTxn) Query(query string, args ...any) (*loggedRows, error) { // Row's Scan method is called. If the query selects no rows, the *Row's // Scan will return ErrNoRows. Otherwise, the *Row's Scan scans the // first selected row and discards the rest. -func (lt *loggedTxn) QueryRow(query string, args ...any) *loggedRow { +func (tx *txn) QueryRow(query string, args ...any) *row { start := time.Now() - row := lt.Tx.QueryRow(query, args...) + r := tx.Tx.QueryRow(query, args...) if dur := time.Since(start); dur > longQueryDuration { - lt.log.Debug("slow query row", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + tx.log.Debug("slow query row", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) } - return &loggedRow{row, lt.log.Named("row")} + return &row{r, tx.log.Named("row")} } -// Exec executes a query without returning any rows. The args are for -// any placeholder parameters in the query. -func (dt *dbTxn) Exec(query string, args ...any) (sql.Result, error) { - return dt.store.exec(query, args...) -} - -// Prepare creates a prepared statement for later queries or executions. -// Multiple queries or executions may be run concurrently from the -// returned statement. The caller must call the statement's Close method -// when the statement is no longer needed. -func (dt *dbTxn) Prepare(query string) (*loggedStmt, error) { - return dt.store.prepare(query) +// getDBVersion returns the current version of the database. +func getDBVersion(db *sql.DB) (version int64) { + // error is ignored -- the database may not have been initialized yet. + db.QueryRow(`SELECT db_version FROM global_settings;`).Scan(&version) + return } -// Query executes a query that returns rows, typically a SELECT. The -// args are for any placeholder parameters in the query. -func (dt *dbTxn) Query(query string, args ...any) (*loggedRows, error) { - return dt.store.query(query, args...) +// setDBVersion sets the current version of the database. +func setDBVersion(tx *txn, version int64) error { + const query = `UPDATE global_settings SET db_version=$1 RETURNING id;` + var dbID int64 + return tx.QueryRow(query, version).Scan(&dbID) } -// QueryRow executes a query that is expected to return at most one row. -// QueryRow always returns a non-nil value. Errors are deferred until -// Row's Scan method is called. If the query selects no rows, the *Row's -// Scan will return ErrNoRows. Otherwise, the *Row's Scan scans the -// first selected row and discards the rest. -func (dt *dbTxn) QueryRow(query string, args ...any) *loggedRow { - return dt.store.queryRow(query, args...) +// jitterSleep sleeps for a random duration between t and t*1.5. +func jitterSleep(t time.Duration) { + time.Sleep(t + time.Duration(rand.Int63n(int64(t/2)))) } func queryPlaceHolders(n int) string { @@ -245,22 +215,3 @@ func queryArgs[T any](args []T) []any { } return out } - -// getDBVersion returns the current version of the database. -func getDBVersion(db *sql.DB) (version int64) { - // error is ignored -- the database may not have been initialized yet. - db.QueryRow(`SELECT db_version FROM global_settings;`).Scan(&version) - return -} - -// setDBVersion sets the current version of the database. -func setDBVersion(tx txn, version int64) error { - const query = `UPDATE global_settings SET db_version=$1 RETURNING id;` - var dbID int64 - return tx.QueryRow(query, version).Scan(&dbID) -} - -// jitterSleep sleeps for a random duration between t and t*1.5. -func jitterSleep(t time.Duration) { - time.Sleep(t + time.Duration(rand.Int63n(int64(t/2)))) -} diff --git a/persist/sqlite/store.go b/persist/sqlite/store.go index 9bbeea1c..39ae18da 100644 --- a/persist/sqlite/store.go +++ b/persist/sqlite/store.go @@ -3,12 +3,17 @@ package sqlite import ( "database/sql" "encoding/hex" + "errors" "fmt" "math" "strings" "time" "github.com/mattn/go-sqlite3" + "go.sia.tech/coreutils/wallet" + "go.sia.tech/hostd/host/contracts" + "go.sia.tech/hostd/host/settings" + "go.sia.tech/hostd/host/storage" "go.uber.org/zap" "lukechampine.com/frand" ) @@ -21,66 +26,16 @@ type ( } ) -// exec executes a query without returning any rows. The args are for -// any placeholder parameters in the query. -func (s *Store) exec(query string, args ...any) (sql.Result, error) { - start := time.Now() - result, err := s.db.Exec(query, args...) - if dur := time.Since(start); dur > longQueryDuration { - s.log.Debug("slow exec", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) - } - return result, err -} - -// prepare creates a prepared statement for later queries or executions. -// Multiple queries or executions may be run concurrently from the -// returned statement. The caller must call the statement's Close method -// when the statement is no longer needed. -func (s *Store) prepare(query string) (*loggedStmt, error) { - start := time.Now() - stmt, err := s.db.Prepare(query) - if err != nil { - return nil, err - } else if dur := time.Since(start); dur > longQueryDuration { - s.log.Debug("slow prepare", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) - } - return &loggedStmt{ - Stmt: stmt, - query: query, - log: s.log.Named("statement"), - }, nil -} - -// query executes a query that returns rows, typically a SELECT. The -// args are for any placeholder parameters in the query. -func (s *Store) query(query string, args ...any) (*loggedRows, error) { - start := time.Now() - rows, err := s.db.Query(query, args...) - if dur := time.Since(start); dur > longQueryDuration { - s.log.Debug("slow query", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) - } - return &loggedRows{rows, s.log.Named("rows")}, err -} - -// queryRow executes a query that is expected to return at most one row. -// QueryRow always returns a non-nil value. Errors are deferred until -// Row's Scan method is called. If the query selects no rows, the *Row's -// Scan will return ErrNoRows. Otherwise, the *Row's Scan scans the -// first selected row and discards the rest. -func (s *Store) queryRow(query string, args ...any) *loggedRow { - start := time.Now() - row := s.db.QueryRow(query, args...) - if dur := time.Since(start); dur > longQueryDuration { - s.log.Debug("slow query row", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) - } - return &loggedRow{row, s.log.Named("row")} +// Close closes the underlying database. +func (s *Store) Close() error { + return s.db.Close() } // transaction executes a function within a database transaction. If the // function returns an error, the transaction is rolled back. Otherwise, the // transaction is committed. If the transaction fails due to a busy error, it is -// retried up to 15 times before returning. -func (s *Store) transaction(fn func(txn) error) error { +// retried up to 10 times before returning. +func (s *Store) transaction(fn func(*txn) error) error { var err error txnID := hex.EncodeToString(frand.Bytes(4)) log := s.log.Named("transaction").With(zap.String("id", txnID)) @@ -110,11 +65,6 @@ func (s *Store) transaction(fn func(txn) error) error { return fmt.Errorf("transaction failed (attempt %d): %w", attempt, err) } -// Close closes the underlying database. -func (s *Store) Close() error { - return s.db.Close() -} - func sqliteFilepath(fp string) string { params := []string{ fmt.Sprintf("_busy_timeout=%d", busyTimeout), @@ -129,33 +79,35 @@ func sqliteFilepath(fp string) string { // doTransaction is a helper function to execute a function within a transaction. If fn returns // an error, the transaction is rolled back. Otherwise, the transaction is // committed. -func doTransaction(db *sql.DB, log *zap.Logger, fn func(tx txn) error) error { - start := time.Now() - tx, err := db.Begin() +func doTransaction(db *sql.DB, log *zap.Logger, fn func(tx *txn) error) error { + dbtx, err := db.Begin() if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } - defer tx.Rollback() + start := time.Now() defer func() { + if err := dbtx.Rollback(); err != nil && !errors.Is(err, sql.ErrTxDone) { + log.Error("failed to rollback transaction", zap.Error(err)) + } // log the transaction if it took longer than txn duration if time.Since(start) > longTxnDuration { log.Debug("long transaction", zap.Duration("elapsed", time.Since(start)), zap.Stack("stack"), zap.Bool("failed", err != nil)) } }() - ltx := &loggedTxn{ - Tx: tx, + tx := &txn{ + Tx: dbtx, log: log, } - if err := fn(ltx); err != nil { + if err := fn(tx); err != nil { return err - } else if err = tx.Commit(); err != nil { + } else if err := tx.Commit(); err != nil { return fmt.Errorf("failed to commit transaction: %w", err) } return nil } -func clearLockedSectors(tx txn, log *zap.Logger) error { +func clearLockedSectors(tx *txn, log *zap.Logger) error { rows, err := tx.Query(`DELETE FROM locked_sectors RETURNING sector_id`) if err != nil { return err @@ -177,13 +129,13 @@ func clearLockedSectors(tx txn, log *zap.Logger) error { return nil } -func clearLockedLocations(tx txn) error { +func clearLockedLocations(tx *txn) error { _, err := tx.Exec(`DELETE FROM locked_volume_sectors`) return err } func (s *Store) clearLocks() error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { if err := clearLockedLocations(tx); err != nil { return fmt.Errorf("failed to clear locked locations: %w", err) } else if err = clearLockedSectors(tx, s.log.Named("clearLockedSectors")); err != nil { @@ -205,7 +157,7 @@ func OpenDatabase(fp string, log *zap.Logger) (*Store, error) { log: log, } if err := store.init(); err != nil { - return nil, fmt.Errorf("failed to initialize database: %w", err) + return nil, err } else if err = store.clearLocks(); err != nil { // clear any locked sectors, metadata not synced to disk is safe to // overwrite. @@ -215,3 +167,10 @@ func OpenDatabase(fp string, log *zap.Logger) (*Store, error) { log.Debug("database initialized", zap.String("sqliteVersion", sqliteVersion), zap.Int("schemaVersion", len(migrations)+1), zap.String("path", fp)) return store, nil } + +var _ interface { + wallet.SingleAddressStore + contracts.ContractStore + storage.VolumeStore + settings.Store +} = (*Store)(nil) diff --git a/persist/sqlite/store_test.go b/persist/sqlite/store_test.go index eed4a474..a7efb57f 100644 --- a/persist/sqlite/store_test.go +++ b/persist/sqlite/store_test.go @@ -24,7 +24,7 @@ func TestTransactionRetry(t *testing.T) { } defer db.Close() - err = db.transaction(func(tx txn) error { return nil }) // start a new empty transaction, should succeed immediately + err = db.transaction(func(tx *txn) error { return nil }) // start a new empty transaction, should succeed immediately if err != nil { t.Fatal(err) } @@ -34,7 +34,7 @@ func TestTransactionRetry(t *testing.T) { // start a transaction in a goroutine and hold it open for 5 seconds // this should allow for the next transaction to be retried a few times go func() { - err := db.transaction(func(tx txn) error { + err := db.transaction(func(tx *txn) error { _, err := tx.Exec(`UPDATE global_settings SET host_key=?`, `foo`) // upgrade the transaction to an exclusive lock; if err != nil { return err @@ -51,7 +51,7 @@ func TestTransactionRetry(t *testing.T) { <-ch // wait for the transaction to start - err = db.transaction(func(tx txn) error { + err = db.transaction(func(tx *txn) error { _, err = tx.Exec(`UPDATE global_settings SET host_key=?`, `bar`) // should fail and be retried if err != nil { return err @@ -73,7 +73,7 @@ func TestTransactionRetry(t *testing.T) { } defer db.Close() - err = db.transaction(func(tx txn) error { return nil }) // start a new empty transaction, should succeed immediately + err = db.transaction(func(tx *txn) error { return nil }) // start a new empty transaction, should succeed immediately if err != nil { t.Fatal(err) } @@ -81,7 +81,7 @@ func TestTransactionRetry(t *testing.T) { ch := make(chan struct{}, 1) // channel to synchronize the transaction goroutine go func() { - err := db.transaction(func(tx txn) error { + err := db.transaction(func(tx *txn) error { _, err := tx.Exec(`UPDATE global_settings SET host_key=?`, `foo`) // upgrade the transaction to an exclusive lock; if err != nil { return err @@ -98,7 +98,7 @@ func TestTransactionRetry(t *testing.T) { <-ch // wait for the transaction to start - err = db.transaction(func(tx txn) error { + err = db.transaction(func(tx *txn) error { _, err := tx.Exec(`UPDATE global_settings SET host_key=?`, `bar`) // should fail and be retried if err != nil { return err @@ -133,31 +133,32 @@ func TestClearLockedSectors(t *testing.T) { t.Fatal(err) } - checkConsistency := func(locked, temp int) error { + assertSectors := func(locked, temp int) { + t.Helper() // check that the sectors are locked - var count int - err = db.queryRow(`SELECT COUNT(*) FROM locked_volume_sectors`).Scan(&count) - if err != nil { - return fmt.Errorf("query locked sectors: %w", err) - } else if locked != count { - return fmt.Errorf("expected %v locked sectors, got %v", locked, count) - } - - // check that the temp sectors are still there - err = db.queryRow(`SELECT COUNT(*) FROM temp_storage_sector_roots`).Scan(&count) + var dbLocked, dbTemp int + err := db.transaction(func(tx *txn) error { + if err := tx.QueryRow(`SELECT COUNT(*) FROM locked_volume_sectors`).Scan(&dbLocked); err != nil { + return fmt.Errorf("query locked sectors: %w", err) + } else if err := tx.QueryRow(`SELECT COUNT(*) FROM temp_storage_sector_roots`).Scan(&dbTemp); err != nil { + return fmt.Errorf("query temp sectors: %w", err) + } + return nil + }) if err != nil { - return fmt.Errorf("query temp sectors: %w", err) - } else if temp != count { - return fmt.Errorf("expected %v temp sectors, got %v", temp, count) + t.Fatal(err) + } else if dbLocked != locked { + t.Fatalf("expected %v locked sectors, got %v", locked, dbLocked) + } else if dbTemp != temp { + t.Fatalf("expected %v temp sectors, got %v", temp, dbTemp) } m, err := db.Metrics(time.Now()) if err != nil { - return fmt.Errorf("metrics: %w", err) + t.Fatal(err) } else if m.Storage.TempSectors != uint64(temp) { - return fmt.Errorf("expected %v temp sector metrics, got %v", temp, m.Storage.TempSectors) + t.Fatalf("expected %v temp sectors, got %v", temp, m.Storage.TempSectors) } - return nil } // write temp sectors to the database @@ -184,9 +185,7 @@ func TestClearLockedSectors(t *testing.T) { } // check that the sectors have been stored and locked - if err = checkConsistency(sectors, sectors/2); err != nil { - t.Fatal(err) - } + assertSectors(sectors, sectors/2) // clear the locked sectors if err = db.clearLocks(); err != nil { @@ -194,7 +193,5 @@ func TestClearLockedSectors(t *testing.T) { } // check that all the locks were removed and half the sectors deleted - if err = checkConsistency(0, sectors/2); err != nil { - t.Fatal(err) - } + assertSectors(0, sectors/2) } diff --git a/persist/sqlite/types.go b/persist/sqlite/types.go deleted file mode 100644 index 865a8ec7..00000000 --- a/persist/sqlite/types.go +++ /dev/null @@ -1,124 +0,0 @@ -package sqlite - -import ( - "database/sql" - "database/sql/driver" - "encoding/binary" - "encoding/hex" - "fmt" - "time" - - "go.sia.tech/core/types" -) - -type ( - sqlUint64 uint64 // sqlite does not support uint64, this will marshal it as a BLOB for when we need to store the high bits - sqlCurrency types.Currency - sqlHash256 [32]byte - sqlHash512 [64]byte - sqlTime time.Time - - sqlNullable[T sql.Scanner] struct { - Value T - Valid bool - } -) - -func (sn *sqlNullable[T]) Scan(src any) error { - if src == nil { - sn.Valid = false - return nil - } else if err := sn.Value.Scan(src); err != nil { - return err - } - sn.Valid = true - return nil -} - -func (sh *sqlHash256) Scan(src any) error { - switch src := src.(type) { - case string: - hex.Decode(sh[:], []byte(src)) - case []byte: - copy(sh[:], src) - default: - return fmt.Errorf("cannot scan %T to Hash256", src) - } - return nil -} - -func (sh sqlHash256) Value() (driver.Value, error) { - return sh[:], nil -} - -func (sh *sqlHash512) Scan(src any) error { - switch src := src.(type) { - case string: - hex.Decode(sh[:], []byte(src)) - case []byte: - copy(sh[:], src) - default: - return fmt.Errorf("cannot scan %T to Hash256", src) - } - return nil -} - -func (sh sqlHash512) Value() (driver.Value, error) { - return sh[:], nil -} - -func (su *sqlUint64) Scan(src any) error { - switch src := src.(type) { - case []byte: - *su = sqlUint64(binary.LittleEndian.Uint64(src)) - default: - return fmt.Errorf("cannot scan %T to uint64", src) - } - return nil -} - -func (su sqlUint64) Value() (driver.Value, error) { - buf := make([]byte, 8) - binary.LittleEndian.PutUint64(buf, uint64(su)) - return buf, nil -} - -// Scan implements the sql.Scanner interface. -func (sc *sqlCurrency) Scan(src any) error { - buf, ok := src.([]byte) - if !ok { - return fmt.Errorf("cannot scan %T to Currency", src) - } else if len(buf) != 16 { - return fmt.Errorf("cannot scan %d bytes to Currency", len(buf)) - } - - sc.Lo = binary.LittleEndian.Uint64(buf[:8]) - sc.Hi = binary.LittleEndian.Uint64(buf[8:]) - return nil -} - -// Value implements the driver.Valuer interface. -func (sc sqlCurrency) Value() (driver.Value, error) { - buf := make([]byte, 16) - binary.LittleEndian.PutUint64(buf[:8], sc.Lo) - binary.LittleEndian.PutUint64(buf[8:], sc.Hi) - return buf, nil -} - -func (st *sqlTime) Scan(src any) error { - switch src := src.(type) { - case int64: - *st = sqlTime(time.Unix(src, 0)) - return nil - default: - return fmt.Errorf("cannot scan %T to Time", src) - } -} - -func (st sqlTime) Value() (driver.Value, error) { - return time.Time(st).Unix(), nil -} - -func nullable[T sql.Scanner](v T) *sqlNullable[T] { - return &sqlNullable[T]{Value: v} -} diff --git a/persist/sqlite/volumes.go b/persist/sqlite/volumes.go index c0fbf63e..e36597fa 100644 --- a/persist/sqlite/volumes.go +++ b/persist/sqlite/volumes.go @@ -19,7 +19,7 @@ func (s *Store) migrateSector(volumeID int64, minIndex uint64, marker int64, mig var locationLocks []int64 var sectorLock int64 var oldLoc, newLoc storage.SectorLocation - err := s.transaction(func(tx txn) (err error) { + err := s.transaction(func(tx *txn) (err error) { oldLoc, err = sectorForMigration(tx, volumeID, marker) if errors.Is(err, sql.ErrNoRows) { marker = math.MaxInt64 @@ -68,8 +68,19 @@ func (s *Store) migrateSector(volumeID int64, minIndex uint64, marker int64, mig return marker, false, nil } // unlock the locations - defer unlockLocations(&dbTxn{s}, locationLocks) - defer unlockSector(&dbTxn{s}, log.Named("unlockSector"), sectorLock) + defer func() { + err = s.transaction(func(tx *txn) error { + if err := unlockLocations(tx, locationLocks); err != nil { + return fmt.Errorf("failed to unlock sector locations: %w", err) + } else if err := unlockSector(tx, log.Named("unlock"), sectorLock); err != nil { + return fmt.Errorf("failed to unlock sector: %w", err) + } + return nil + }) + if err != nil { + log.Error("failed to unlock sectors", zap.Error(err)) + } + }() // call the migrateFn with the new location, data should be copied to the // new location and synced to disk @@ -79,7 +90,7 @@ func (s *Store) migrateSector(volumeID int64, minIndex uint64, marker int64, mig } // update the sector location in a separate transaction - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { // get the sector ID var sectorID int64 err := tx.QueryRow(`SELECT sector_id FROM volume_sectors WHERE id=$1`, oldLoc.ID).Scan(§orID) @@ -119,7 +130,7 @@ func (s *Store) migrateSector(volumeID int64, minIndex uint64, marker int64, mig return marker, true, nil } -func forceDeleteVolumeSectors(tx txn, volumeID int64) (removed, lost int64, err error) { +func forceDeleteVolumeSectors(tx *txn, volumeID int64) (removed, lost int64, err error) { const query = `DELETE FROM volume_sectors WHERE id IN (SELECT id FROM volume_sectors WHERE volume_id=$1 LIMIT $2) RETURNING sector_id IS NULL AS empty` rows, err := tx.Query(query, volumeID, sqlSectorBatchSize) @@ -143,7 +154,7 @@ func forceDeleteVolumeSectors(tx txn, volumeID int64) (removed, lost int64, err return } -func deleteVolumeSectors(tx txn, volumeID int64) (removed int64, err error) { +func deleteVolumeSectors(tx *txn, volumeID int64) (removed int64, err error) { // check that the volume is empty var dummyID int64 err = tx.QueryRow(`SELECT id FROM volume_sectors WHERE volume_id=$1 AND sector_id IS NOT NULL LIMIT 1`, volumeID).Scan(&dummyID) @@ -163,7 +174,7 @@ func deleteVolumeSectors(tx txn, volumeID int64) (removed int64, err error) { } func (s *Store) batchRemoveVolumeSectors(id int64, force bool) (removed, lost int64, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { if force { removed, lost, err = forceDeleteVolumeSectors(tx, id) if err != nil { @@ -205,48 +216,52 @@ func (s *Store) batchRemoveVolumeSectors(id int64, force bool) (removed, lost in // StorageUsage returns the number of sectors stored and the total number of sectors // available in the storage pool. func (s *Store) StorageUsage() (usedSectors, totalSectors uint64, err error) { - // nulls are not included in COUNT() -- counting sector roots is equivalent - // to counting used sectors. const query = `SELECT COALESCE(SUM(total_sectors), 0) AS total_sectors, COALESCE(SUM(used_sectors), 0) AS used_sectors FROM storage_volumes` - err = s.queryRow(query).Scan(&totalSectors, &usedSectors) + err = s.transaction(func(tx *txn) error { + return tx.QueryRow(query).Scan(&totalSectors, &usedSectors) + }) return } // Volumes returns a list of all volumes. -func (s *Store) Volumes() ([]storage.Volume, error) { +func (s *Store) Volumes() (volumes []storage.Volume, err error) { const query = `SELECT v.id, v.disk_path, v.read_only, v.available, v.total_sectors, v.used_sectors FROM storage_volumes v ORDER BY v.id ASC` - rows, err := s.query(query) - if err != nil { - return nil, fmt.Errorf("query failed: %w", err) - } - defer rows.Close() - var volumes []storage.Volume - for rows.Next() { - volume, err := scanVolume(rows) + err = s.transaction(func(tx *txn) error { + rows, err := tx.Query(query) if err != nil { - return nil, fmt.Errorf("failed to scan volume: %w", err) + return fmt.Errorf("query failed: %w", err) } - volumes = append(volumes, volume) - } - return volumes, nil + defer rows.Close() + + for rows.Next() { + volume, err := scanVolume(rows) + if err != nil { + return fmt.Errorf("failed to scan volume: %w", err) + } + volumes = append(volumes, volume) + } + return rows.Err() + }) + return } // Volume returns a volume by its ID. -func (s *Store) Volume(id int64) (storage.Volume, error) { +func (s *Store) Volume(id int64) (vol storage.Volume, err error) { const query = `SELECT v.id, v.disk_path, v.read_only, v.available, v.total_sectors, v.used_sectors FROM storage_volumes v WHERE v.id=$1` - row := s.queryRow(query, id) - vol, err := scanVolume(row) + + err = s.transaction(func(tx *txn) error { + vol, err = scanVolume(tx.QueryRow(query, id)) + return err + }) if errors.Is(err, sql.ErrNoRows) { return storage.Volume{}, storage.ErrVolumeNotFound - } else if err != nil { - return storage.Volume{}, fmt.Errorf("query failed: %w", err) } - return vol, nil + return } // StoreSector calls fn with an empty location in a writable volume. If @@ -265,7 +280,7 @@ func (s *Store) StoreSector(root types.Hash256, fn func(loc storage.SectorLocati var exists bool log := s.log.Named("StoreSector").With(zap.Stringer("root", root)) - err := s.transaction(func(tx txn) error { + err := s.transaction(func(tx *txn) error { sectorID, err := insertSectorDBID(tx, root) if err != nil { return fmt.Errorf("failed to get sector id: %w", err) @@ -320,7 +335,7 @@ func (s *Store) StoreSector(root types.Hash256, fn func(loc storage.SectorLocati log = log.With(zap.Int64("volume", location.Volume), zap.Uint64("index", location.Index)) log.Debug("stored sector") unlock := func() error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { if err := unlockLocations(tx, locationLocks); err != nil { return fmt.Errorf("failed to unlock sector location: %w", err) } else if err := unlockSector(tx, log.Named("unlock"), sectorLockID); err != nil { @@ -383,7 +398,11 @@ func (s *Store) MigrateSectors(ctx context.Context, volumeID int64, startIndex u // store. GrowVolume must be called afterwards to initialize the volume // to its desired size. func (s *Store) AddVolume(localPath string, readOnly bool) (volumeID int64, err error) { - return addVolume(&dbTxn{s}, localPath, readOnly) + err = s.transaction(func(tx *txn) error { + volumeID, err = addVolume(tx, localPath, readOnly) + return err + }) + return } // RemoveVolume removes a storage volume from the volume store. If there @@ -404,7 +423,7 @@ func (s *Store) RemoveVolume(id int64, force bool) error { jitterSleep(time.Millisecond) } - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { // check that the volume exists var volumeID int64 err := tx.QueryRow(`SELECT id FROM storage_volumes WHERE id=$1`, id).Scan(&volumeID) @@ -435,7 +454,7 @@ func (s *Store) GrowVolume(id int64, maxSectors uint64) error { panic("maxSectors must be greater than 0") // dev error } - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { return growVolume(tx, id, maxSectors) }) } @@ -447,7 +466,7 @@ func (s *Store) ShrinkVolume(id int64, maxSectors uint64) error { panic("maxSectors must be greater than 0") // dev error } - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { // check if there are any used sectors in the shrink range var usedSectors uint64 err := tx.QueryRow(`SELECT COUNT(sector_id) FROM volume_sectors WHERE volume_id=$1 AND volume_index >= $2 AND sector_id IS NOT NULL;`, id, maxSectors).Scan(&usedSectors) @@ -484,20 +503,24 @@ func (s *Store) ShrinkVolume(id int64, maxSectors uint64) error { // SetReadOnly sets the read-only flag on a volume. func (s *Store) SetReadOnly(volumeID int64, readOnly bool) error { const query = `UPDATE storage_volumes SET read_only=$1 WHERE id=$2;` - _, err := s.exec(query, readOnly, volumeID) - return err + return s.transaction(func(tx *txn) error { + _, err := tx.Exec(query, readOnly, volumeID) + return err + }) } // SetAvailable sets the available flag on a volume. func (s *Store) SetAvailable(volumeID int64, available bool) error { const query = `UPDATE storage_volumes SET available=$1 WHERE id=$2;` - _, err := s.exec(query, available, volumeID) - return err + return s.transaction(func(tx *txn) error { + _, err := tx.Exec(query, available, volumeID) + return err + }) } // sectorDBID returns the ID of a sector root in the stored_sectors table. -func sectorDBID(tx txn, root types.Hash256) (id int64, err error) { - err = tx.QueryRow(`SELECT id FROM stored_sectors WHERE sector_root=$1`, sqlHash256(root)).Scan(&id) +func sectorDBID(tx *txn, root types.Hash256) (id int64, err error) { + err = tx.QueryRow(`SELECT id FROM stored_sectors WHERE sector_root=$1`, encode(root)).Scan(&id) if errors.Is(err, sql.ErrNoRows) { err = storage.ErrSectorNotFound } @@ -507,23 +530,23 @@ func sectorDBID(tx txn, root types.Hash256) (id int64, err error) { // insertSectorDBID inserts a sector root into the stored_sectors table if it // does not already exist. If the sector root already exists, the ID is // returned. -func insertSectorDBID(tx txn, root types.Hash256) (id int64, err error) { +func insertSectorDBID(tx *txn, root types.Hash256) (id int64, err error) { id, err = sectorDBID(tx, root) if errors.Is(err, storage.ErrSectorNotFound) { // insert the sector root - err = tx.QueryRow(`INSERT INTO stored_sectors (sector_root, last_access_timestamp) VALUES ($1, $2) RETURNING id`, sqlHash256(root), sqlTime(time.Now())).Scan(&id) + err = tx.QueryRow(`INSERT INTO stored_sectors (sector_root, last_access_timestamp) VALUES ($1, $2) RETURNING id`, encode(root), encode(time.Now())).Scan(&id) return } return } -func addVolume(tx txn, localPath string, readOnly bool) (volumeID int64, err error) { +func addVolume(tx *txn, localPath string, readOnly bool) (volumeID int64, err error) { const query = `INSERT INTO storage_volumes (disk_path, read_only, used_sectors, total_sectors) VALUES (?, ?, 0, 0) RETURNING id;` err = tx.QueryRow(query, localPath, readOnly).Scan(&volumeID) return } -func growVolume(tx txn, id int64, maxSectors uint64) error { +func growVolume(tx *txn, id int64, maxSectors uint64) error { var nextIndex uint64 err := tx.QueryRow(`SELECT total_sectors FROM storage_volumes WHERE id=?;`, id).Scan(&nextIndex) if err != nil { @@ -555,7 +578,7 @@ func growVolume(tx txn, id int64, maxSectors uint64) error { } // sectorLocation returns the location of a sector. -func sectorLocation(tx txn, sectorID int64, root types.Hash256) (loc storage.SectorLocation, err error) { +func sectorLocation(tx *txn, sectorID int64, root types.Hash256) (loc storage.SectorLocation, err error) { const query = `SELECT v.id, v.volume_id, v.volume_index FROM volume_sectors v WHERE v.sector_id=$1` @@ -571,7 +594,7 @@ WHERE v.sector_id=$1` // emptyLocation returns an empty location in a writable volume. If there is no // space available, ErrNotEnoughStorage is returned. -func emptyLocation(tx txn) (loc storage.SectorLocation, err error) { +func emptyLocation(tx *txn) (loc storage.SectorLocation, err error) { const query = `SELECT vs.id, vs.volume_id, vs.volume_index FROM volume_sectors vs INDEXED BY volume_sectors_sector_writes_volume_id_sector_id_volume_index_compound LEFT JOIN locked_volume_sectors lvs ON (lvs.volume_sector_id=vs.id) @@ -592,7 +615,7 @@ func emptyLocation(tx txn) (loc storage.SectorLocation, err error) { // emptyLocationForMigration returns an empty location in a writable volume. If there is no // space available, ErrNotEnoughStorage is returned. -func emptyLocationForMigration(tx txn, volumeID int64) (loc storage.SectorLocation, err error) { +func emptyLocationForMigration(tx *txn, volumeID int64) (loc storage.SectorLocation, err error) { const query = `SELECT vs.id, vs.volume_id, vs.volume_index FROM volume_sectors vs INDEXED BY volume_sectors_sector_writes_volume_id_sector_id_volume_index_compound LEFT JOIN locked_volume_sectors lvs ON (lvs.volume_sector_id=vs.id) @@ -613,7 +636,7 @@ func emptyLocationForMigration(tx txn, volumeID int64) (loc storage.SectorLocati // sectorForMigration returns the location of the first occupied sector in the // volume starting at minIndex and greater than marker. -func sectorForMigration(tx txn, volumeID int64, marker int64) (loc storage.SectorLocation, err error) { +func sectorForMigration(tx *txn, volumeID int64, marker int64) (loc storage.SectorLocation, err error) { const query = `SELECT vs.id, vs.volume_id, vs.volume_index, s.sector_root FROM volume_sectors vs INNER JOIN stored_sectors s ON (s.id=vs.sector_id) @@ -621,14 +644,14 @@ func sectorForMigration(tx txn, volumeID int64, marker int64) (loc storage.Secto ORDER BY vs.volume_index ASC LIMIT 1` - err = tx.QueryRow(query, volumeID, marker).Scan(&loc.ID, &loc.Volume, &loc.Index, (*sqlHash256)(&loc.Root)) + err = tx.QueryRow(query, volumeID, marker).Scan(&loc.ID, &loc.Volume, &loc.Index, decode(&loc.Root)) return } // locationWithinVolume returns an empty location within the same volume as // the given volumeID. If there is no space in the volume, ErrNotEnoughStorage // is returned. -func locationWithinVolume(tx txn, volumeID int64, maxIndex uint64) (loc storage.SectorLocation, err error) { +func locationWithinVolume(tx *txn, volumeID int64, maxIndex uint64) (loc storage.SectorLocation, err error) { const query = `SELECT vs.id, vs.volume_id, vs.volume_index FROM volume_sectors vs WHERE vs.sector_id IS NULL AND vs.id NOT IN (SELECT volume_sector_id FROM locked_volume_sectors) diff --git a/persist/sqlite/volumes_test.go b/persist/sqlite/volumes_test.go index b0538fcd..dc308075 100644 --- a/persist/sqlite/volumes_test.go +++ b/persist/sqlite/volumes_test.go @@ -677,7 +677,7 @@ func TestPrune(t *testing.T) { // lock the remaining sectors var locks []int64 for _, root := range lockedSectors { - err := db.transaction(func(tx txn) error { + err := db.transaction(func(tx *txn) error { sectorID, err := sectorDBID(tx, root) if err != nil { return err @@ -760,7 +760,10 @@ func TestPrune(t *testing.T) { } // unlock locked sectors - if err := unlockSector(&dbTxn{db}, log.Named("unlockSector"), locks...); err != nil { + err = db.transaction(func(tx *txn) error { + return unlockSector(tx, log.Named("unlockSector"), locks...) + }) + if err != nil { t.Fatal(err) } diff --git a/persist/sqlite/wallet.go b/persist/sqlite/wallet.go index a7a6dc3b..1fc87e94 100644 --- a/persist/sqlite/wallet.go +++ b/persist/sqlite/wallet.go @@ -1,167 +1,78 @@ package sqlite import ( - "bytes" "database/sql" + "encoding/json" "errors" "fmt" - "time" "go.sia.tech/core/types" - "go.sia.tech/hostd/wallet" - "go.sia.tech/siad/modules" + "go.sia.tech/coreutils/wallet" ) -func encodeTransaction(txn wallet.Transaction) []byte { - var buf bytes.Buffer - e := types.NewEncoder(&buf) - txn.EncodeTo(e) - e.Flush() - return buf.Bytes() -} - -func decodeTransaction(b []byte, txn *wallet.Transaction) error { - d := types.NewBufDecoder(b) - txn.DecodeFrom(d) - return d.Err() -} - -// An updateWalletTxn atomically updates the wallet -type updateWalletTxn struct { - tx txn -} - -// setLastChange sets the last processed consensus change. -func (tx *updateWalletTxn) setLastChange(id modules.ConsensusChangeID, height uint64) error { - var dbID int64 // unused, but required by QueryRow to ensure exactly one row is updated - err := tx.tx.QueryRow(`UPDATE global_settings SET wallet_last_processed_change=$1, wallet_height=$2 RETURNING id`, sqlHash256(id), height).Scan(&dbID) - return err -} - -// AddSiacoinElement adds a spendable siacoin output to the wallet. -func (tx *updateWalletTxn) AddSiacoinElement(utxo wallet.SiacoinElement) error { - _, err := tx.tx.Exec(`INSERT INTO wallet_utxos (id, amount, unlock_hash) VALUES (?, ?, ?)`, sqlHash256(utxo.ID), sqlCurrency(utxo.Value), sqlHash256(utxo.Address)) - return err -} - -// RemoveSiacoinElement removes a spendable siacoin output from the wallet -// either due to a spend or a reorg. -func (tx *updateWalletTxn) RemoveSiacoinElement(id types.SiacoinOutputID) error { - err := tx.tx.QueryRow(`DELETE FROM wallet_utxos WHERE id=? RETURNING id`, sqlHash256(id)).Scan((*sqlHash256)(&id)) - return err -} - -// AddWalletDelta adds the delta to the wallet balance metric. -func (tx updateWalletTxn) AddWalletDelta(value types.Currency, timestamp time.Time) error { - if err := incrementCurrencyStat(tx.tx, metricWalletBalance, value, false, timestamp); err != nil { - return fmt.Errorf("failed to increment wallet balance: %w", err) - } else if err := reflowCurrencyStat(tx.tx, metricWalletBalance, timestamp, value, false); err != nil { - return fmt.Errorf("failed to reflow wallet balance: %w", err) - } - return nil -} - -// SubWalletDelta subtracts the delta from the wallet balance metric. -func (tx updateWalletTxn) SubWalletDelta(value types.Currency, timestamp time.Time) error { - if err := incrementCurrencyStat(tx.tx, metricWalletBalance, value, true, timestamp); err != nil { - return fmt.Errorf("failed to increment wallet balance: %w", err) - } else if err := reflowCurrencyStat(tx.tx, metricWalletBalance, timestamp, value, true); err != nil { - return fmt.Errorf("failed to reflow wallet balance: %w", err) - } - return nil -} - -// AddTransaction adds a transaction to the wallet. -func (tx *updateWalletTxn) AddTransaction(txn wallet.Transaction) error { - const query = `INSERT INTO wallet_transactions (transaction_id, block_id, block_height, source, inflow, outflow, raw_transaction, date_created) VALUES (?, ?, ?, ?, ?, ?, ?, ?)` - _, err := tx.tx.Exec(query, - sqlHash256(txn.ID), - sqlHash256(txn.Index.ID), - txn.Index.Height, - txn.Source, - sqlCurrency(txn.Inflow), - sqlCurrency(txn.Outflow), - encodeTransaction(txn), - sqlTime(txn.Timestamp), - ) - return err -} - -// RevertBlock removes all transactions that occurred within the block from the -// wallet. -func (tx *updateWalletTxn) RevertBlock(blockID types.BlockID) error { - _, err := tx.tx.Exec(`DELETE FROM wallet_transactions WHERE block_id=?`, sqlHash256(blockID)) - return err -} - -// LastWalletChange gets the last consensus change processed by the wallet. -func (s *Store) LastWalletChange() (id modules.ConsensusChangeID, height uint64, err error) { - var nullHeight sql.NullInt64 - err = s.queryRow(`SELECT wallet_last_processed_change, wallet_height FROM global_settings`).Scan(nullable((*sqlHash256)(&id)), &nullHeight) - if errors.Is(err, sql.ErrNoRows) { - return modules.ConsensusChangeBeginning, 0, nil - } else if err != nil { - return modules.ConsensusChangeBeginning, 0, fmt.Errorf("failed to query last wallet change: %w", err) - } - height = uint64(nullHeight.Int64) - return -} +var _ wallet.SingleAddressStore = (*Store)(nil) // UnspentSiacoinElements returns the spendable siacoin outputs in the wallet. -func (s *Store) UnspentSiacoinElements() (utxos []wallet.SiacoinElement, err error) { - rows, err := s.query(`SELECT id, amount, unlock_hash FROM wallet_utxos`) - if err != nil { - return nil, fmt.Errorf("failed to query unspent siacoin elements: %w", err) - } - defer rows.Close() - for rows.Next() { - var utxo wallet.SiacoinElement - if err := rows.Scan((*sqlHash256)(&utxo.ID), (*sqlCurrency)(&utxo.Value), (*sqlHash256)(&utxo.Address)); err != nil { - return nil, fmt.Errorf("failed to scan unspent siacoin element: %w", err) +func (s *Store) UnspentSiacoinElements() (utxos []types.SiacoinElement, err error) { + err = s.transaction(func(tx *txn) error { + rows, err := tx.Query(`SELECT id, siacoin_value, sia_address, leaf_index, merkle_proof, maturity_height FROM wallet_siacoin_elements`) + if err != nil { + return fmt.Errorf("failed to query unspent siacoin elements: %w", err) } - utxos = append(utxos, utxo) - } - return utxos, nil -} - -// Transactions returns a paginated list of transactions ordered by block height -// descending. If no transactions are found, (nil, nil) is returned. -func (s *Store) Transactions(limit, offset int) (txns []wallet.Transaction, err error) { - rows, err := s.query(`SELECT transaction_id, block_id, block_height, source, inflow, outflow, raw_transaction, date_created FROM wallet_transactions ORDER BY block_height DESC, id ASC LIMIT ? OFFSET ?`, limit, offset) - if err != nil { - return nil, fmt.Errorf("failed to query transactions: %w", err) - } - defer rows.Close() - for rows.Next() { - var txn wallet.Transaction - var buf []byte - if err := rows.Scan((*sqlHash256)(&txn.ID), (*sqlHash256)(&txn.Index.ID), &txn.Index.Height, &txn.Source, (*sqlCurrency)(&txn.Inflow), (*sqlCurrency)(&txn.Outflow), &buf, (*sqlTime)(&txn.Timestamp)); err != nil { - return nil, fmt.Errorf("failed to scan transaction: %w", err) - } else if err := decodeTransaction(buf, &txn); err != nil { - return nil, fmt.Errorf("failed to unmarshal transaction data: %w", err) + defer rows.Close() + for rows.Next() { + var se types.SiacoinElement + if err := rows.Scan(decode(&se.ID), decode(&se.SiacoinOutput.Value), decode(&se.SiacoinOutput.Address), decode(&se.LeafIndex), decode(&se.MerkleProof), &se.MaturityHeight); err != nil { + return fmt.Errorf("failed to scan unspent siacoin element: %w", err) + } + utxos = append(utxos, se) } - txns = append(txns, txn) - } + if err := rows.Err(); err != nil { + return fmt.Errorf("failed to iterate unspent siacoin elements: %w", err) + } + return nil + }) return } -// TransactionCount returns the total number of transactions in the wallet. -func (s *Store) TransactionCount() (count uint64, err error) { - err = s.queryRow(`SELECT COUNT(*) FROM wallet_transactions`).Scan(&count) +// WalletEventCount returns the total number of events relevant to the +// wallet. +func (s *Store) WalletEventCount() (count uint64, err error) { + err = s.transaction(func(tx *txn) error { + err := tx.QueryRow(`SELECT COUNT(*) FROM wallet_events`).Scan(&count) + return err + }) return } -// UpdateWallet begins an update transaction on the wallet store. -func (s *Store) UpdateWallet(ccID modules.ConsensusChangeID, height uint64, fn func(wallet.UpdateTransaction) error) error { - return s.transaction(func(tx txn) error { - utx := &updateWalletTxn{tx} - if err := fn(utx); err != nil { - return err - } else if err := utx.setLastChange(ccID, height); err != nil { - return fmt.Errorf("failed to set last wallet change: %w", err) +// WalletEvents returns a paginated list of transactions ordered by +// maturity height, descending. If no more transactions are available, +// (nil, nil) should be returned. +func (s *Store) WalletEvents(offset, limit int) (events []wallet.Event, err error) { + err = s.transaction(func(tx *txn) error { + rows, err := tx.Query(`SELECT raw_data FROM wallet_events ORDER BY maturity_height DESC LIMIT ? OFFSET ?`, limit, offset) + if err != nil { + return fmt.Errorf("failed to query wallet events: %w", err) + } + defer rows.Close() + + for rows.Next() { + var buf []byte + if err := rows.Scan(&buf); err != nil { + return fmt.Errorf("failed to scan wallet event: %w", err) + } + var event wallet.Event + if err := json.Unmarshal(buf, &event); err != nil { + return fmt.Errorf("failed to unmarshal wallet event: %w", err) + } + events = append(events, event) + } + if err := rows.Err(); err != nil { + return fmt.Errorf("failed to iterate wallet events: %w", err) } return nil }) + return } // VerifyWalletKey checks that the wallet seed matches the seed hash. @@ -169,30 +80,15 @@ func (s *Store) UpdateWallet(ccID modules.ConsensusChangeID, height uint64, fn f // to rescan. func (s *Store) VerifyWalletKey(seedHash types.Hash256) error { var buf []byte - err := s.queryRow(`SELECT wallet_hash FROM global_settings`).Scan(&buf) - if err == nil && buf == nil { - _, err := s.exec(`UPDATE global_settings SET wallet_hash=?`, sqlHash256(seedHash)) // wallet not initialized, set seed hash - return err - } else if err != nil { - return fmt.Errorf("failed to query wallet seed hash: %w", err) - } else if seedHash != *(*types.Hash256)(buf) { - return wallet.ErrDifferentSeed - } - return nil -} - -// ResetWallet resets the wallet to its initial state. This is used when a -// consensus subscription error occurs. -func (s *Store) ResetWallet(seedHash types.Hash256) error { - return s.transaction(func(tx txn) error { - if _, err := tx.Exec(`DELETE FROM wallet_utxos`); err != nil { - return fmt.Errorf("failed to delete wallet utxos: %w", err) - } else if _, err := tx.Exec(`DELETE FROM wallet_transactions`); err != nil { - return fmt.Errorf("failed to delete wallet transactions: %w", err) - } else if _, err := tx.Exec(`DELETE FROM host_stats WHERE stat=$1`, metricWalletBalance); err != nil { - return fmt.Errorf("failed to delete wallet metrics: %w", err) - } else if _, err := tx.Exec(`UPDATE global_settings SET wallet_last_processed_change=NULL, wallet_height=NULL, wallet_hash=?`, sqlHash256(seedHash)); err != nil { - return fmt.Errorf("failed to reset wallet settings: %w", err) + return s.transaction(func(tx *txn) error { + err := tx.QueryRow(`SELECT wallet_hash FROM global_settings`).Scan(&buf) + if errors.Is(err, sql.ErrNoRows) { + _, err := tx.Exec(`UPDATE global_settings SET wallet_hash=?`, encode(seedHash)) // wallet not initialized, set seed hash + return err + } else if err != nil { + return fmt.Errorf("failed to query wallet seed hash: %w", err) + } else if seedHash != *(*types.Hash256)(buf) { + return wallet.ErrDifferentSeed } return nil }) diff --git a/persist/sqlite/wallet_test.go b/persist/sqlite/wallet_test.go new file mode 100644 index 00000000..07f50060 --- /dev/null +++ b/persist/sqlite/wallet_test.go @@ -0,0 +1,70 @@ +package sqlite_test + +import ( + "context" + "testing" + "time" + + "go.sia.tech/core/types" + "go.sia.tech/hostd/internal/testutil" + "go.sia.tech/hostd/persist/sqlite" + "go.uber.org/zap/zaptest" +) + +func TestWalletMetrics(t *testing.T) { + log := zaptest.NewLogger(t) + network, genesis := testutil.V2Network() + n1 := testutil.NewConsensusNode(t, network, genesis, log.Named("node1")) + + h1 := testutil.NewHostNode(t, types.GeneratePrivateKey(), network, genesis, log.Named("host")) + + if _, err := h1.Syncer.Connect(context.Background(), n1.Syncer.Addr()); err != nil { + t.Fatal(err) + } + + mineAndSync := func(t *testing.T, cn *testutil.ConsensusNode, addr types.Address, n int) { + t.Helper() + + for i := 0; i < n; i++ { + testutil.MineBlocks(t, cn, addr, 1) + testutil.WaitForSync(t, cn.Chain, h1.Indexer) + } + } + + assertWalletMetrics := func(t *testing.T, db *sqlite.Store, mature types.Currency, immature types.Currency) { + t.Helper() + + m, err := db.Metrics(time.Now()) + if err != nil { + t.Fatal(err) + } else if !m.Wallet.Balance.Equals(mature) { + t.Fatalf("expected mature balance %v, got %v", mature, m.Wallet.Balance) + } else if !m.Wallet.ImmatureBalance.Equals(immature) { + t.Fatalf("expected immature balance %v, got %v", immature, m.Wallet.ImmatureBalance) + } + } + + var expectedMature types.Currency + expectedImmature := n1.Chain.TipState().BlockReward() + + // mine a single block to get the first block reward + mineAndSync(t, n1, h1.Wallet.Address(), 1) + assertWalletMetrics(t, h1.Store, expectedMature, expectedImmature) + + // mine until the first block reward matures + mineAndSync(t, n1, types.VoidAddress, 144) + expectedMature = expectedImmature + expectedImmature = types.ZeroCurrency + assertWalletMetrics(t, h1.Store, expectedMature, expectedImmature) + + // mine a secondary chain to reorg the first chain + n2 := testutil.NewConsensusNode(t, network, genesis, log.Named("node2")) + testutil.MineBlocks(t, n2, types.VoidAddress, 250) + + t.Log("connecting peer 2") + if _, err := h1.Syncer.Connect(context.Background(), n2.Syncer.Addr()); err != nil { + t.Fatal(err) + } + testutil.WaitForSync(t, n2.Chain, h1.Indexer) + assertWalletMetrics(t, h1.Store, types.ZeroCurrency, types.ZeroCurrency) +} diff --git a/persist/sqlite/webhooks.go b/persist/sqlite/webhooks.go index cca63b74..fdb55272 100644 --- a/persist/sqlite/webhooks.go +++ b/persist/sqlite/webhooks.go @@ -6,40 +6,49 @@ import ( "go.sia.tech/hostd/webhooks" ) -// RegisterWebHook registers a new webhook. -func (s *Store) RegisterWebHook(url, secret string, scopes []string) (id int64, err error) { - err = s.queryRow("INSERT INTO webhooks (callback_url, secret_key, scopes) VALUES (?, ?, ?) RETURNING id", url, secret, strings.Join(scopes, ",")).Scan(&id) +// RegisterWebhook registers a new webhook. +func (s *Store) RegisterWebhook(url, secret string, scopes []string) (id int64, err error) { + err = s.transaction(func(tx *txn) error { + return tx.QueryRow("INSERT INTO webhooks (callback_url, secret_key, scopes) VALUES (?, ?, ?) RETURNING id", url, secret, strings.Join(scopes, ",")).Scan(&id) + }) return } -// UpdateWebHook updates a webhook. -func (s *Store) UpdateWebHook(id int64, url string, scopes []string) error { - var dbID int64 - return s.queryRow(`UPDATE webhooks SET callback_url = ?, scopes = ? WHERE id = ? RETURNING id`, url, strings.Join(scopes, ","), id).Scan(&dbID) +// UpdateWebhook updates a webhook. +func (s *Store) UpdateWebhook(id int64, url string, scopes []string) error { + return s.transaction(func(tx *txn) error { + var dbID int64 + return tx.QueryRow(`UPDATE webhooks SET callback_url = ?, scopes = ? WHERE id = ? RETURNING id`, url, strings.Join(scopes, ","), id).Scan(&dbID) + }) } -// RemoveWebHook removes a webhook. -func (s *Store) RemoveWebHook(id int64) error { - _, err := s.exec("DELETE FROM webhooks WHERE id = ?", id) - return err +// RemoveWebhook removes a webhook. +func (s *Store) RemoveWebhook(id int64) error { + return s.transaction(func(tx *txn) error { + _, err := tx.Exec("DELETE FROM webhooks WHERE id = ?", id) + return err + }) } -// WebHooks returns all webhooks. -func (s *Store) WebHooks() ([]webhooks.WebHook, error) { - rows, err := s.query("SELECT id, callback_url, secret_key, scopes FROM webhooks") - if err != nil { - return nil, err - } - defer rows.Close() - var hooks []webhooks.WebHook - for rows.Next() { - var hook webhooks.WebHook - var scopes string - if err := rows.Scan(&hook.ID, &hook.CallbackURL, &hook.SecretKey, &scopes); err != nil { - return nil, err +// Webhooks returns all webhooks. +func (s *Store) Webhooks() (hooks []webhooks.Webhook, err error) { + err = s.transaction(func(tx *txn) error { + rows, err := tx.Query("SELECT id, callback_url, secret_key, scopes FROM webhooks") + if err != nil { + return err } - hook.Scopes = strings.Split(scopes, ",") - hooks = append(hooks, hook) - } - return hooks, nil + defer rows.Close() + + for rows.Next() { + var hook webhooks.Webhook + var scopes string + if err := rows.Scan(&hook.ID, &hook.CallbackURL, &hook.SecretKey, &scopes); err != nil { + return err + } + hook.Scopes = strings.Split(scopes, ",") + hooks = append(hooks, hook) + } + return rows.Err() + }) + return } diff --git a/rhp/conn.go b/rhp/conn.go index d1d8e555..a4385c3a 100644 --- a/rhp/conn.go +++ b/rhp/conn.go @@ -15,7 +15,7 @@ type ( WriteBytes(n int) } - // An Conn wraps a net.Conn to track the amount of data read and written and + // A Conn wraps a net.Conn to track the amount of data read and written and // limit bandwidth usage. Conn struct { net.Conn @@ -23,8 +23,17 @@ type ( monitor DataMonitor rl, wl *rate.Limiter } + + // A noOpMonitor is a DataMonitor that does nothing. + noOpMonitor struct{} ) +// ReadBytes implements DataMonitor +func (noOpMonitor) ReadBytes(n int) {} + +// WriteBytes implements DataMonitor +func (noOpMonitor) WriteBytes(n int) {} + // Usage returns the amount of data read and written by the connection. func (c *Conn) Usage() (read, written uint64) { read = atomic.LoadUint64(&c.r) @@ -66,3 +75,8 @@ func NewConn(c net.Conn, m DataMonitor, rl, wl *rate.Limiter) *Conn { wl: wl, } } + +// NewNoOpMonitor initializes a new NoOpMonitor. +func NewNoOpMonitor() DataMonitor { + return noOpMonitor{} +} diff --git a/rhp/contracts.go b/rhp/contracts.go index 7c64a3b0..18dd0166 100644 --- a/rhp/contracts.go +++ b/rhp/contracts.go @@ -74,7 +74,7 @@ func HashRevision(rev types.FileContractRevision) types.Hash256 { // InitialRevision returns the first revision of a file contract formation // transaction. -func InitialRevision(formationTxn *types.Transaction, hostPubKey, renterPubKey types.UnlockKey) types.FileContractRevision { +func InitialRevision(formationTxn types.Transaction, hostPubKey, renterPubKey types.UnlockKey) types.FileContractRevision { fc := formationTxn.FileContracts[0] return types.FileContractRevision{ ParentID: formationTxn.FileContractID(0), diff --git a/rhp/v2/options.go b/rhp/v2/options.go new file mode 100644 index 00000000..c7f9a6de --- /dev/null +++ b/rhp/v2/options.go @@ -0,0 +1,41 @@ +package rhp + +import ( + "go.sia.tech/core/types" + "go.sia.tech/hostd/host/contracts" + "go.sia.tech/hostd/rhp" + "go.uber.org/zap" +) + +// A SessionHandlerOption is a functional option for session handlers. +type SessionHandlerOption func(*SessionHandler) + +// WithLog sets the logger for the session handler. +func WithLog(l *zap.Logger) SessionHandlerOption { + return func(s *SessionHandler) { + s.log = l + } +} + +// WithSessionReporter sets the session reporter for the session handler. +func WithSessionReporter(r SessionReporter) SessionHandlerOption { + return func(s *SessionHandler) { + s.sessions = r + } +} + +// WithDataMonitor sets the data monitor for the session handler. +func WithDataMonitor(m rhp.DataMonitor) SessionHandlerOption { + return func(s *SessionHandler) { + s.monitor = m + } +} + +type noopSessionReporter struct{} + +func (noopSessionReporter) StartSession(conn *rhp.Conn, proto string, version int) (sessionID rhp.UID, end func()) { + return rhp.UID{}, func() {} +} +func (noopSessionReporter) StartRPC(sessionID rhp.UID, rpc types.Specifier) (rpcID rhp.UID, end func(contracts.Usage, error)) { + return rhp.UID{}, func(contracts.Usage, error) {} +} diff --git a/rhp/v2/rhp.go b/rhp/v2/rhp.go index 0268413e..3db917e3 100644 --- a/rhp/v2/rhp.go +++ b/rhp/v2/rhp.go @@ -46,7 +46,7 @@ type ( ReviseContract(contractID types.FileContractID) (*contracts.ContractUpdater, error) // SectorRoots returns the sector roots of the contract with the given ID. - SectorRoots(id types.FileContractID) ([]types.Hash256, error) + SectorRoots(id types.FileContractID) []types.Hash256 } // A StorageManager manages the storage of sectors on disk. @@ -65,25 +65,29 @@ type ( // A ChainManager provides access to the current state of the blockchain. ChainManager interface { + Tip() types.ChainIndex TipState() consensus.State + UnconfirmedParents(txn types.Transaction) []types.Transaction + AddPoolTransactions([]types.Transaction) (known bool, err error) + AddV2PoolTransactions(types.ChainIndex, []types.V2Transaction) (known bool, err error) + } + + // A Syncer broadcasts transactions to the network + Syncer interface { + BroadcastTransactionSet([]types.Transaction) + BroadcastV2TransactionSet(types.ChainIndex, []types.V2Transaction) } // A Wallet manages funds and signs transactions Wallet interface { Address() types.Address - FundTransaction(txn *types.Transaction, amount types.Currency) ([]types.Hash256, func(), error) - SignTransaction(cs consensus.State, txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) error - } - - // A TransactionPool broadcasts transactions to the network. - TransactionPool interface { - AcceptTransactionSet([]types.Transaction) error - RecommendedFee() types.Currency + FundTransaction(txn *types.Transaction, amount types.Currency, unconfirmed bool) ([]types.Hash256, error) + SignTransaction(txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) + ReleaseInputs(txn []types.Transaction, v2txn []types.V2Transaction) } // A SettingsReporter reports the host's current configuration. SettingsReporter interface { - DiscoveredRHP2Address() string Settings() settings.Settings BandwidthLimiters() (ingress, egress *rate.Limiter) } @@ -104,8 +108,8 @@ type ( monitor rhp.DataMonitor tg *threadgroup.ThreadGroup - cm ChainManager - tpool TransactionPool + chain ChainManager + syncer Syncer wallet Wallet contracts ContractManager @@ -128,6 +132,13 @@ func (sh *SessionHandler) rpcLoop(sess *session, log *zap.Logger) error { return fmt.Errorf("failed to read RPC ID: %w", err) } + cs := sh.chain.TipState() + // disable rhp2 after v2 require height + if cs.Index.Height >= cs.Network.HardforkV2.RequireHeight { + sess.t.WriteResponseErr(ErrV2Hardfork) + return ErrV2Hardfork + } + rpcFn, ok := map[types.Specifier]func(*session, *zap.Logger) (contracts.Usage, error){ rhp2.RPCFormContractID: sh.rpcFormContract, rhp2.RPCRenewClearContractID: sh.rpcRenewAndClearContract, @@ -207,15 +218,6 @@ func (sh *SessionHandler) Settings() (rhp2.HostSettings, error) { return rhp2.HostSettings{}, fmt.Errorf("failed to get storage usage: %w", err) } - netaddr := settings.NetAddress - if netaddr == "" { - netaddr = sh.settings.DiscoveredRHP2Address() - } - // if the net address is still empty, return an error - if netaddr == "" { - return rhp2.HostSettings{}, errors.New("no net address found") - } - return rhp2.HostSettings{ // build info Release: "hostd " + build.Version(), @@ -225,7 +227,7 @@ func (sh *SessionHandler) Settings() (rhp2.HostSettings, error) { // host info Address: sh.wallet.Address(), SiaMuxPort: sh.rhp3Port, - NetAddress: netaddr, + NetAddress: settings.NetAddress, TotalStorage: totalSectors * rhp2.SectorSize, RemainingStorage: (totalSectors - usedSectors) * rhp2.SectorSize, @@ -285,7 +287,7 @@ func (sh *SessionHandler) LocalAddr() string { } // NewSessionHandler creates a new RHP2 SessionHandler -func NewSessionHandler(l net.Listener, hostKey types.PrivateKey, rhp3Addr string, cm ChainManager, tpool TransactionPool, wallet Wallet, contracts ContractManager, settings SettingsReporter, storage StorageManager, monitor rhp.DataMonitor, sessions SessionReporter, log *zap.Logger) (*SessionHandler, error) { +func NewSessionHandler(l net.Listener, hostKey types.PrivateKey, rhp3Addr string, cm ChainManager, s Syncer, wallet Wallet, contracts ContractManager, settings SettingsReporter, storage StorageManager, opts ...SessionHandlerOption) (*SessionHandler, error) { _, rhp3Port, err := net.SplitHostPort(rhp3Addr) if err != nil { return nil, fmt.Errorf("failed to parse rhp3 addr: %w", err) @@ -293,20 +295,26 @@ func NewSessionHandler(l net.Listener, hostKey types.PrivateKey, rhp3Addr string sh := &SessionHandler{ privateKey: hostKey, - tg: threadgroup.New(), rhp3Port: rhp3Port, listener: l, - monitor: monitor, - cm: cm, - tpool: tpool, - wallet: wallet, + + chain: cm, + syncer: s, + wallet: wallet, contracts: contracts, - sessions: sessions, settings: settings, storage: storage, - log: log, + + log: zap.NewNop(), + monitor: rhp.NewNoOpMonitor(), + sessions: noopSessionReporter{}, + + tg: threadgroup.New(), + } + for _, opt := range opts { + opt(sh) } return sh, nil } diff --git a/rhp/v2/rpc.go b/rhp/v2/rpc.go index d46bd35c..bb057acd 100644 --- a/rhp/v2/rpc.go +++ b/rhp/v2/rpc.go @@ -11,9 +11,9 @@ import ( rhp2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" + "go.sia.tech/coreutils/wallet" "go.sia.tech/hostd/host/contracts" "go.sia.tech/hostd/rhp" - "go.sia.tech/hostd/wallet" "go.uber.org/zap" ) @@ -35,6 +35,14 @@ var ( // ErrNotAcceptingContracts is returned when the host is not accepting // contracts. ErrNotAcceptingContracts = errors.New("host is not accepting contracts") + + // ErrV2Hardfork is returned when a renter tries to form or renew a contract + // after the v2 hardfork has been activated. + ErrV2Hardfork = errors.New("hardfork v2 is active") + + // ErrAfterV2Hardfork is returned when a renter tries to form or renew a + // contract that ends after the v2 hardfork has been activated. + ErrAfterV2Hardfork = errors.New("proof window after hardfork v2 activation") ) func (sh *SessionHandler) rpcSettings(s *session, log *zap.Logger) (contracts.Usage, error) { @@ -115,6 +123,13 @@ func (sh *SessionHandler) rpcUnlock(s *session, log *zap.Logger) (contracts.Usag // rpcFormContract is an RPC that forms a contract between a renter and the // host. func (sh *SessionHandler) rpcFormContract(s *session, log *zap.Logger) (contracts.Usage, error) { + cs := sh.chain.TipState() + // prevent forming v1 contracts after the allow height + if cs.Index.Height >= cs.Network.HardforkV2.AllowHeight { + s.t.WriteResponseErr(ErrV2Hardfork) + return contracts.Usage{}, ErrV2Hardfork + } + if !sh.settings.Settings().AcceptingContracts { s.t.WriteResponseErr(ErrNotAcceptingContracts) return contracts.Usage{}, ErrNotAcceptingContracts @@ -143,13 +158,21 @@ func (sh *SessionHandler) rpcFormContract(s *session, log *zap.Logger) (contract s.t.WriteResponseErr(ErrHostInternalError) return contracts.Usage{}, fmt.Errorf("failed to get host settings: %w", err) } - currentHeight := sh.cm.TipState().Index.Height + // get the contract from the transaction set - formationTxn := &formationTxnSet[len(formationTxnSet)-1] + formationTxn, formationTxnSet := formationTxnSet[len(formationTxnSet)-1], formationTxnSet[:len(formationTxnSet)-1] + fc := formationTxn.FileContracts[0] + + // prevent forming contracts that end after the v2 hardfork + if fc.WindowStart >= cs.Network.HardforkV2.RequireHeight { + err := ErrAfterV2Hardfork + s.t.WriteResponseErr(err) + return contracts.Usage{}, err + } // validate the contract formation fields. note: the v1 contract type // does not contain the public keys or signatures. - hostCollateral, err := validateContractFormation(formationTxn.FileContracts[0], hostPub.UnlockKey(), renterPub.UnlockKey(), currentHeight, settings) + hostCollateral, err := validateContractFormation(formationTxn.FileContracts[0], hostPub.UnlockKey(), renterPub.UnlockKey(), sh.chain.Tip().Height, settings) if err != nil { err := fmt.Errorf("contract rejected: validation failed: %w", err) s.t.WriteResponseErr(err) @@ -158,7 +181,7 @@ func (sh *SessionHandler) rpcFormContract(s *session, log *zap.Logger) (contract // calculate the host's collateral and add the inputs to the transaction renterInputs, renterOutputs := len(formationTxn.SiacoinInputs), len(formationTxn.SiacoinOutputs) - toSign, release, err := sh.wallet.FundTransaction(formationTxn, hostCollateral) + toSign, err := sh.wallet.FundTransaction(&formationTxn, hostCollateral, false) if err != nil { remoteErr := ErrHostInternalError if errors.Is(err, wallet.ErrNotEnoughFunds) { @@ -179,17 +202,17 @@ func (sh *SessionHandler) rpcFormContract(s *session, log *zap.Logger) (contract Outputs: formationTxn.SiacoinOutputs[renterOutputs:], } if err := s.writeResponse(hostAdditionsResp, 30*time.Second); err != nil { - release() + sh.wallet.ReleaseInputs(append(formationTxnSet, formationTxn), nil) return contracts.Usage{}, fmt.Errorf("failed to write host additions: %w", err) } // read and validate the renter's signatures var renterSignaturesResp rhp2.RPCFormContractSignatures if err := s.readResponse(&renterSignaturesResp, 10*minMessageSize, 30*time.Second); err != nil { - release() + sh.wallet.ReleaseInputs(append(formationTxnSet, formationTxn), nil) return contracts.Usage{}, fmt.Errorf("failed to read renter signatures: %w", err) } else if err := validateRenterRevisionSignature(renterSignaturesResp.RevisionSignature, initialRevision.ParentID, sigHash, renterPub); err != nil { - release() + sh.wallet.ReleaseInputs(append(formationTxnSet, formationTxn), nil) err := fmt.Errorf("contract rejected: validation failed: %w", err) s.t.WriteResponseErr(err) return contracts.Usage{}, err @@ -199,18 +222,16 @@ func (sh *SessionHandler) rpcFormContract(s *session, log *zap.Logger) (contract formationTxn.Signatures = renterSignaturesResp.ContractSignatures // sign and broadcast the formation transaction - if err = sh.wallet.SignTransaction(sh.cm.TipState(), formationTxn, toSign, types.CoveredFields{WholeTransaction: true}); err != nil { - release() - s.t.WriteResponseErr(ErrHostInternalError) - return contracts.Usage{}, fmt.Errorf("failed to sign formation transaction: %w", err) - } else if err = sh.tpool.AcceptTransactionSet(formationTxnSet); err != nil { - release() + sh.wallet.SignTransaction(&formationTxn, toSign, types.CoveredFields{WholeTransaction: true}) + + formationTxnSet = append(formationTxnSet, formationTxn) + if _, err := sh.chain.AddPoolTransactions(formationTxnSet); err != nil { + sh.wallet.ReleaseInputs(formationTxnSet, nil) err = fmt.Errorf("failed to broadcast formation transaction: %w", err) - buf, _ := json.Marshal(formationTxnSet) - log.Error("failed to broadcast formation transaction", zap.Error(err), zap.String("txnset", string(buf))) s.t.WriteResponseErr(err) return contracts.Usage{}, err } + sh.syncer.BroadcastTransactionSet(formationTxnSet) signedRevision := contracts.SignedRevision{ Revision: initialRevision, @@ -240,7 +261,13 @@ func (sh *SessionHandler) rpcFormContract(s *session, log *zap.Logger) (contract // rpcRenewAndClearContract is an RPC that renews a contract and clears the // existing contract func (sh *SessionHandler) rpcRenewAndClearContract(s *session, log *zap.Logger) (contracts.Usage, error) { - state := sh.cm.TipState() + cs := sh.chain.TipState() + // prevent renewing v1 contracts after the allow height + if cs.Index.Height >= cs.Network.HardforkV2.AllowHeight { + s.t.WriteResponseErr(ErrV2Hardfork) + return contracts.Usage{}, ErrV2Hardfork + } + settings, err := sh.Settings() if err != nil { s.t.WriteResponseErr(ErrHostInternalError) @@ -253,7 +280,7 @@ func (sh *SessionHandler) rpcRenewAndClearContract(s *session, log *zap.Logger) hostUnlockKey := sh.privateKey.PublicKey().UnlockKey() // make sure the current contract is revisable - if err := s.ContractRevisable(state.Index.Height); err != nil { + if err := s.ContractRevisable(); err != nil { err := fmt.Errorf("contract not revisable: %w", err) s.t.WriteResponseErr(err) return contracts.Usage{}, err @@ -277,10 +304,16 @@ func (sh *SessionHandler) rpcRenewAndClearContract(s *session, log *zap.Logger) s.t.WriteResponseErr(err) return contracts.Usage{}, err } - renewalParents := renewalTxnSet[:len(renewalTxnSet)-1] - renewalTxn := renewalTxnSet[len(renewalTxnSet)-1] + renewalTxn, renewalParents := renewalTxnSet[len(renewalTxnSet)-1], renewalTxnSet[:len(renewalTxnSet)-1] renewedContract := renewalTxn.FileContracts[0] + // prevent forming contracts that end after the v2 hardfork + if renewedContract.WindowStart >= cs.Network.HardforkV2.RequireHeight { + err := ErrAfterV2Hardfork + s.t.WriteResponseErr(err) + return contracts.Usage{}, err + } + existingRevision := s.contract.Revision clearingRevision, err := rhp.ClearingRevision(existingRevision, req.FinalValidProofValues) if err != nil { @@ -316,7 +349,7 @@ func (sh *SessionHandler) rpcRenewAndClearContract(s *session, log *zap.Logger) } // validate the renewal - baseRevenue, riskedCollateral, lockedCollateral, err := validateContractRenewal(existingRevision, renewedContract, hostUnlockKey, req.RenterKey, baseRevenue, baseCollateral, state.Index.Height, settings) + baseRevenue, riskedCollateral, lockedCollateral, err := validateContractRenewal(existingRevision, renewedContract, hostUnlockKey, req.RenterKey, baseRevenue, baseCollateral, sh.chain.Tip().Height, settings) if err != nil { err = fmt.Errorf("invalid contract renewal: %w", err) s.t.WriteResponseErr(err) @@ -329,7 +362,7 @@ func (sh *SessionHandler) rpcRenewAndClearContract(s *session, log *zap.Logger) } renterInputs, renterOutputs := len(renewalTxn.SiacoinInputs), len(renewalTxn.SiacoinOutputs) - toSign, release, err := sh.wallet.FundTransaction(&renewalTxn, lockedCollateral) + toSign, err := sh.wallet.FundTransaction(&renewalTxn, lockedCollateral, false) if err != nil { remoteErr := ErrHostInternalError if errors.Is(err, wallet.ErrNotEnoughFunds) { @@ -345,37 +378,32 @@ func (sh *SessionHandler) rpcRenewAndClearContract(s *session, log *zap.Logger) Outputs: renewalTxn.SiacoinOutputs[renterOutputs:], } if err = s.writeResponse(hostAdditionsResp, 30*time.Second); err != nil { - release() + sh.wallet.ReleaseInputs(append(renewalParents, renewalTxn), nil) return contracts.Usage{}, fmt.Errorf("failed to write host additions: %w", err) } // read the renter's signatures for the renewal var renterSigsResp rhp2.RPCRenewAndClearContractSignatures if err = s.readResponse(&renterSigsResp, minMessageSize, 30*time.Second); err != nil { - release() + sh.wallet.ReleaseInputs(append(renewalParents, renewalTxn), nil) return contracts.Usage{}, fmt.Errorf("failed to read renter signatures: %w", err) } else if len(renterSigsResp.RevisionSignature.Signature) != 64 { - release() + sh.wallet.ReleaseInputs(append(renewalParents, renewalTxn), nil) return contracts.Usage{}, fmt.Errorf("invalid renter signature length: %w", ErrInvalidRenterSignature) } // add the renter's signatures to the formation transaction renewalTxn.Signatures = append(renewalTxn.Signatures, renterSigsResp.ContractSignatures...) // sign the transaction - if err = sh.wallet.SignTransaction(state, &renewalTxn, toSign, types.CoveredFields{WholeTransaction: true}); err != nil { - release() - s.t.WriteResponseErr(ErrHostInternalError) - return contracts.Usage{}, fmt.Errorf("failed to sign renewal transaction: %w", err) - } - + sh.wallet.SignTransaction(&renewalTxn, toSign, types.CoveredFields{WholeTransaction: true}) // create the initial revision - initialRevision := rhp.InitialRevision(&renewalTxn, hostUnlockKey, req.RenterKey) + initialRevision := rhp.InitialRevision(renewalTxn, hostUnlockKey, req.RenterKey) // verify the clearing revision signature clearingRevSigHash := rhp.HashRevision(clearingRevision) // important: verify using the existing contract's renter key if !s.contract.RenterKey().VerifyHash(clearingRevSigHash, renterSigsResp.FinalRevisionSignature) { - release() + sh.wallet.ReleaseInputs(append(renewalParents, renewalTxn), nil) err := fmt.Errorf("failed to verify clearing revision signature: %w", ErrInvalidRenterSignature) s.t.WriteResponseErr(err) return contracts.Usage{}, err @@ -385,7 +413,7 @@ func (sh *SessionHandler) rpcRenewAndClearContract(s *session, log *zap.Logger) renewalSigHash := rhp.HashRevision(initialRevision) renterRenewalSig := *(*types.Signature)(renterSigsResp.RevisionSignature.Signature) if !renterKey.VerifyHash(renewalSigHash, renterRenewalSig) { - release() + sh.wallet.ReleaseInputs(append(renewalParents, renewalTxn), nil) err := fmt.Errorf("failed to verify renewal revision signature: %w", ErrInvalidRenterSignature) s.t.WriteResponseErr(err) return contracts.Usage{}, err @@ -402,14 +430,16 @@ func (sh *SessionHandler) rpcRenewAndClearContract(s *session, log *zap.Logger) HostSignature: sh.privateKey.SignHash(renewalSigHash), } - // broadcast the transaction + // validate & broadcast the transaction renewalTxnSet = append(renewalParents, renewalTxn) - if err = sh.tpool.AcceptTransactionSet(renewalTxnSet); err != nil { - release() + if _, err = sh.chain.AddPoolTransactions(renewalTxnSet); err != nil { + sh.wallet.ReleaseInputs(renewalTxnSet, nil) err = fmt.Errorf("failed to broadcast renewal transaction: %w", err) s.t.WriteResponseErr(err) return contracts.Usage{}, err } + sh.syncer.BroadcastTransactionSet(renewalTxnSet) + // update the existing contract and add the renewed contract to the store if err := sh.contracts.RenewContract(signedRenewal, signedClearing, renewalTxnSet, lockedCollateral, clearingUsage, renewalUsage); err != nil { s.t.WriteResponseErr(ErrHostInternalError) @@ -419,7 +449,7 @@ func (sh *SessionHandler) rpcRenewAndClearContract(s *session, log *zap.Logger) // send the host signatures to the renter hostSigsResp := &rhp2.RPCRenewAndClearContractSignatures{ ContractSignatures: renewalTxn.Signatures[len(renterSigsResp.ContractSignatures):], - RevisionSignature: signedRenewal.Signatures()[0], + RevisionSignature: signedRenewal.Signatures()[1], FinalRevisionSignature: signedClearing.HostSignature, } return clearingUsage.Add(renewalUsage), s.writeResponse(hostSigsResp, 30*time.Second) @@ -427,8 +457,7 @@ func (sh *SessionHandler) rpcRenewAndClearContract(s *session, log *zap.Logger) // rpcSectorRoots returns the Merkle roots of the sectors in a contract func (sh *SessionHandler) rpcSectorRoots(s *session, log *zap.Logger) (contracts.Usage, error) { - currentHeight := sh.cm.TipState().Index.Height - if err := s.ContractRevisable(currentHeight); err != nil { + if err := s.ContractRevisable(); err != nil { err := fmt.Errorf("contract not revisable: %w", err) s.t.WriteResponseErr(err) return contracts.Usage{}, err @@ -493,11 +522,8 @@ func (sh *SessionHandler) rpcSectorRoots(s *session, log *zap.Logger) (contracts return contracts.Usage{}, err } - roots, err := sh.contracts.SectorRoots(s.contract.Revision.ParentID) - if err != nil { - s.t.WriteResponseErr(ErrHostInternalError) - return contracts.Usage{}, fmt.Errorf("failed to get sector roots: %w", err) - } else if uint64(len(roots)) != contractSectors { + roots := sh.contracts.SectorRoots(s.contract.Revision.ParentID) + if uint64(len(roots)) != contractSectors { s.t.WriteResponseErr(ErrHostInternalError) return contracts.Usage{}, fmt.Errorf("inconsistent sector roots: expected %v, got %v", contractSectors, len(roots)) } @@ -536,9 +562,9 @@ func (sh *SessionHandler) rpcSectorRoots(s *session, log *zap.Logger) (contracts } func (sh *SessionHandler) rpcWrite(s *session, log *zap.Logger) (contracts.Usage, error) { - currentHeight := sh.cm.TipState().Index.Height + currentHeight := sh.chain.Tip().Height // get the locked contract and check that it is revisable - if err := s.ContractRevisable(currentHeight); err != nil { + if err := s.ContractRevisable(); err != nil { err := fmt.Errorf("contract not revisable: %w", err) s.t.WriteResponseErr(err) return contracts.Usage{}, err @@ -736,9 +762,8 @@ func (sh *SessionHandler) rpcWrite(s *session, log *zap.Logger) (contracts.Usage } func (sh *SessionHandler) rpcRead(s *session, log *zap.Logger) (contracts.Usage, error) { - currentHeight := sh.cm.TipState().Index.Height // get the locked contract and check that it is revisable - if err := s.ContractRevisable(currentHeight); err != nil { + if err := s.ContractRevisable(); err != nil { err := fmt.Errorf("contract not revisable: %w", err) s.t.WriteResponseErr(err) return contracts.Usage{}, err diff --git a/rhp/v2/rpc_test.go b/rhp/v2/rpc_test.go index 4e421de7..7d780416 100644 --- a/rhp/v2/rpc_test.go +++ b/rhp/v2/rpc_test.go @@ -3,661 +3,511 @@ package rhp_test import ( "bytes" "context" - "io" + "errors" + "net" "path/filepath" "reflect" + "runtime" "testing" - "time" - rhp2 "go.sia.tech/core/rhp/v2" + crhp2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" "go.sia.tech/coreutils/wallet" - "go.sia.tech/hostd/internal/test" + "go.sia.tech/hostd/internal/testutil" + rpc2 "go.sia.tech/hostd/internal/testutil/rhp/v2" + "go.sia.tech/hostd/rhp" + rhp2 "go.sia.tech/hostd/rhp/v2" "go.uber.org/goleak" "go.uber.org/zap/zaptest" "lukechampine.com/frand" ) +func dialHost(t *testing.T, hostKey types.PublicKey, netaddr string) *crhp2.Transport { + t.Helper() + + conn, err := net.Dial("tcp", netaddr) + if err != nil { + t.Fatal("failed to dial host", err) + } + t.Cleanup(func() { conn.Close() }) + transport, err := crhp2.NewRenterTransport(conn, hostKey) + if err != nil { + t.Fatal("failed to create transport", err) + } + return transport +} + func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } func TestSettings(t *testing.T) { log := zaptest.NewLogger(t) - renter, host, err := test.NewTestingPair(t.TempDir(), log) + hostKey := types.GeneratePrivateKey() + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) + + s := node.Settings.Settings() + s.NetAddress = "localhost:9983" + if err := node.Settings.UpdateSettings(s); err != nil { + t.Fatal(err) + } + + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + + sh, err := rhp2.NewSessionHandler(l, hostKey, "localhost:9983", node.Chain, node.Syncer, node.Wallet, node.Contracts, node.Settings, node.Volumes, rhp2.WithLog(log)) if err != nil { t.Fatal(err) } - defer renter.Close() - defer host.Close() + defer sh.Close() + go sh.Serve() - hostSettings, err := host.RHP2Settings() + transport := dialHost(t, hostKey.PublicKey(), l.Addr().String()) + defer transport.Close() + + settings, err := rpc2.RPCSettings(transport) if err != nil { t.Fatal(err) } - renterSettings, err := renter.Settings(context.Background(), host.RHP2Addr(), host.PublicKey()) + expected, err := sh.Settings() if err != nil { t.Fatal(err) - } else if !reflect.DeepEqual(hostSettings, renterSettings) { - t.Errorf("host settings mismatch") + } + + if !reflect.DeepEqual(settings, expected) { + t.Fatal("settings mismatch") } } func TestUploadDownload(t *testing.T) { log := zaptest.NewLogger(t) - renter, host, err := test.NewTestingPair(t.TempDir(), log) - if err != nil { + renterKey, hostKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) + + // set the host to accept contracts + s := node.Settings.Settings() + s.AcceptingContracts = true + s.NetAddress = "localhost:9983" + if err := node.Settings.UpdateSettings(s); err != nil { t.Fatal(err) } - defer renter.Close() - defer host.Close() - // form a contract - contract, err := renter.FormContract(context.Background(), host.RHP2Addr(), host.PublicKey(), types.Siacoins(10), types.Siacoins(20), 200) - if err != nil { + // initialize a storage volume + res := make(chan error) + if _, err := node.Volumes.AddVolume(context.Background(), filepath.Join(t.TempDir(), "storage.dat"), 10, res); err != nil { + t.Fatal(err) + } else if err := <-res; err != nil { t.Fatal(err) } - session, err := renter.NewRHP2Session(context.Background(), host.RHP2Addr(), host.PublicKey(), contract.ID()) + // fund the wallet + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) + + l, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatal(err) } - defer session.Close() + defer l.Close() - // generate a sector - var sector [rhp2.SectorSize]byte - frand.Read(sector[:256]) - sectorRoot := rhp2.SectorRoot(§or) - - // calculate the remaining duration of the contract - var remainingDuration uint64 - contractExpiration := uint64(session.Revision().Revision.WindowEnd) - currentHeight := renter.TipState().Index.Height - if contractExpiration < currentHeight { - t.Fatal("contract expired") - } - // upload the sector - remainingDuration = contractExpiration - currentHeight - price, collateral, err := session.RPCAppendCost(remainingDuration) + sh, err := rhp2.NewSessionHandler(l, hostKey, "localhost:9983", node.Chain, node.Syncer, node.Wallet, node.Contracts, node.Settings, node.Volumes, rhp2.WithLog(log)) if err != nil { t.Fatal(err) } + defer sh.Close() + go sh.Serve() - writtenRoot, err := session.Append(context.Background(), §or, price, collateral) + transport := dialHost(t, hostKey.PublicKey(), l.Addr().String()) + defer transport.Close() + + settings, err := rpc2.RPCSettings(transport) if err != nil { t.Fatal(err) - } else if writtenRoot != sectorRoot { - t.Fatal("sector root mismatch") } - // check the host's sector roots matches the sector we just uploaded - price, _ = session.Settings().RPCSectorRootsCost(0, 1).Total() - roots, err := session.SectorRoots(context.Background(), 0, 1, price) + fc := crhp2.PrepareContractFormation(renterKey.PublicKey(), hostKey.PublicKey(), types.Siacoins(10), types.Siacoins(20), node.Chain.Tip().Height+200, settings, node.Wallet.Address()) + formationCost := crhp2.ContractFormationCost(node.Chain.TipState(), fc, settings.ContractPrice) + txn := types.Transaction{ + FileContracts: []types.FileContract{fc}, + } + toSign, err := node.Wallet.FundTransaction(&txn, formationCost, true) if err != nil { t.Fatal(err) - } else if roots[0] != sectorRoot { - t.Fatal("sector root mismatch") } + node.Wallet.SignTransaction(&txn, toSign, wallet.ExplicitCoveredFields(txn)) + formationSet := append(node.Chain.UnconfirmedParents(txn), txn) - // check that the revision fields are correct - revision := session.Revision().Revision - switch { - case revision.Filesize != rhp2.SectorSize: - t.Fatal("wrong filesize") - case revision.FileMerkleRoot != sectorRoot: - t.Fatal("wrong merkle root") + revision, _, err := rpc2.RPCFormContract(transport, renterKey, formationSet) + if err != nil { + t.Fatal(err) + } else if _, err := rpc2.RPCLock(transport, renterKey, revision.ID()); err != nil { + t.Fatal(err) } + defer rpc2.RPCUnlock(transport) - sections := []rhp2.RPCReadRequestSection{ + var sector [crhp2.SectorSize]byte + frand.Read(sector[:256]) + appendAction := []crhp2.RPCWriteAction{ { - MerkleRoot: writtenRoot, - Offset: 0, - Length: rhp2.SectorSize, + Type: crhp2.RPCWriteActionAppend, + Data: sector[:], }, } + root := crhp2.SectorRoot(§or) - // calculate the price - cost, err := session.Settings().RPCReadCost(sections, true) + err = rpc2.RPCWrite(transport, renterKey, &revision, appendAction, types.Siacoins(1), types.ZeroCurrency) // just overpay if err != nil { t.Fatal(err) + } else if revision.Revision.FileMerkleRoot != root { + t.Fatal("root mismatch") } - price, _ = cost.Total() - var buf bytes.Buffer - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := session.Read(ctx, &buf, sections, price); err != nil { + roots, err := rpc2.RPCSectorRoots(transport, renterKey, 0, 1, &revision, types.Siacoins(1)) // just overpay + if err != nil { t.Fatal(err) + } else if len(roots) != 1 || roots[0] != root { + t.Fatal("root mismatch") } - if !bytes.Equal(buf.Bytes(), sector[:]) { + readSection := []crhp2.RPCReadRequestSection{ + { + MerkleRoot: root, + Offset: 0, + Length: crhp2.SectorSize, + }, + } + buf := bytes.NewBuffer(make([]byte, 0, crhp2.SectorSize)) + err = rpc2.RPCRead(transport, buf, renterKey, &revision, readSection, types.Siacoins(1)) // just overpay + if err != nil { + t.Fatal(err) + } else if !bytes.Equal(buf.Bytes(), sector[:]) { t.Fatal("sector mismatch") } } func TestRenew(t *testing.T) { log := zaptest.NewLogger(t) - renter, host, err := test.NewTestingPair(t.TempDir(), log) - if err != nil { + renterKey, hostKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) + + // set the host to accept contracts + s := node.Settings.Settings() + s.AcceptingContracts = true + s.NetAddress = "localhost:9983" + if err := node.Settings.UpdateSettings(s); err != nil { t.Fatal(err) } - defer renter.Close() - defer host.Close() - t.Run("empty contract", func(t *testing.T) { - state := renter.TipState() - // form a contract - origin, err := renter.FormContract(context.Background(), host.RHP2Addr(), host.PublicKey(), types.Siacoins(10), types.Siacoins(20), state.Index.Height+200) - if err != nil { - t.Fatal(err) - } - - session, err := renter.NewRHP2Session(context.Background(), host.RHP2Addr(), host.PublicKey(), origin.ID()) - if err != nil { - t.Fatal(err) - } - defer session.Close() - - // mine a few blocks into the contract - if err := host.MineBlocks(host.WalletAddress(), 10); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) - - renewHeight := origin.Revision.WindowEnd + 10 - settings := *session.Settings() - current := session.Revision().Revision - additionalCollateral := rhp2.ContractRenewalCollateral(current.FileContract, 1<<22, settings, renter.TipState().Index.Height, renewHeight) - renewed, basePrice := rhp2.PrepareContractRenewal(current, renter.WalletAddress(), types.Siacoins(10), additionalCollateral, settings, renewHeight) - renewalTxn := types.Transaction{ - FileContracts: []types.FileContract{renewed}, - } + // initialize a storage volume + res := make(chan error) + if _, err := node.Volumes.AddVolume(context.Background(), filepath.Join(t.TempDir(), "storage.dat"), 10, res); err != nil { + t.Fatal(err) + } else if err := <-res; err != nil { + t.Fatal(err) + } - cost := rhp2.ContractRenewalCost(state, renewed, settings.ContractPrice, types.ZeroCurrency, basePrice) - toSign, release, err := renter.Wallet().FundTransaction(&renewalTxn, cost) - if err != nil { - t.Fatal(err) - } + // fund the wallet + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) - if err := renter.Wallet().SignTransaction(host.TipState(), &renewalTxn, toSign, wallet.ExplicitCoveredFields(renewalTxn)); err != nil { - release() - t.Fatal(err) - } + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() - renewal, _, err := session.RenewContract(context.Background(), []types.Transaction{renewalTxn}, settings.BaseRPCPrice) - if err != nil { - release() - t.Fatal(err) - } + sh, err := rhp2.NewSessionHandler(l, hostKey, "localhost:9983", node.Chain, node.Syncer, node.Wallet, node.Contracts, node.Settings, node.Volumes, rhp2.WithLog(log)) + if err != nil { + t.Fatal(err) + } + defer sh.Close() + go sh.Serve() - // mine a block to confirm the revision - if err := host.MineBlocks(host.WalletAddress(), 1); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) + transport := dialHost(t, hostKey.PublicKey(), l.Addr().String()) + defer transport.ForceClose() - old, err := host.Contracts().Contract(origin.ID()) - if err != nil { - t.Fatal(err) - } else if old.Revision.Filesize != 0 { - t.Fatal("filesize mismatch") - } else if old.Revision.FileMerkleRoot != (types.Hash256{}) { - t.Fatal("merkle root mismatch") - } else if old.RenewedTo != renewal.ID() { - t.Fatal("renewed to mismatch") - } else if !old.Usage.RPCRevenue.Equals(settings.ContractPrice.Add(settings.BaseRPCPrice)) { - t.Fatalf("expected rpc revenue to equal contract price + base rpc price %d, got %d", settings.ContractPrice.Add(settings.BaseRPCPrice), old.Usage.RPCRevenue) - } + settings, err := rpc2.RPCSettings(transport) + if err != nil { + t.Fatal(err) + } - contract, err := host.Contracts().Contract(renewal.ID()) - if err != nil { - t.Fatal(err) - } else if contract.Revision.Filesize != origin.Revision.Filesize { - t.Fatal("filesize mismatch") - } else if contract.Revision.FileMerkleRoot != origin.Revision.FileMerkleRoot { - t.Fatal("merkle root mismatch") - } else if !contract.LockedCollateral.Equals(additionalCollateral) { - t.Fatalf("locked collateral mismatch: expected %d, got %d", additionalCollateral, contract.LockedCollateral) - } else if !contract.Usage.RiskedCollateral.IsZero() { - t.Fatalf("expected zero risked collateral, got %d", contract.Usage.RiskedCollateral) - } else if !contract.Usage.RPCRevenue.Equals(settings.ContractPrice) { - t.Fatalf("expected %d RPC revenue, got %d", settings.ContractPrice, contract.Usage.RPCRevenue) - } else if contract.RenewedFrom != origin.ID() { - t.Fatalf("expected renewed from %s, got %s", origin.ID(), contract.RenewedFrom) - } - }) + formContract := func(t *testing.T, duration uint64) crhp2.ContractRevision { + t.Helper() - t.Run("drained contract", func(t *testing.T) { - // form a contract - state := renter.TipState() - origin, err := renter.FormContract(context.Background(), host.RHP2Addr(), host.PublicKey(), types.Siacoins(10), types.Siacoins(20), state.Index.Height+200) - if err != nil { - t.Fatal(err) + fc := crhp2.PrepareContractFormation(renterKey.PublicKey(), hostKey.PublicKey(), types.Siacoins(10), types.Siacoins(20), node.Chain.Tip().Height+duration, settings, node.Wallet.Address()) + formationCost := crhp2.ContractFormationCost(node.Chain.TipState(), fc, settings.ContractPrice) + txn := types.Transaction{ + FileContracts: []types.FileContract{fc}, } - - session, err := renter.NewRHP2Session(context.Background(), host.RHP2Addr(), host.PublicKey(), origin.ID()) + toSign, err := node.Wallet.FundTransaction(&txn, formationCost, true) if err != nil { - t.Fatal(err) + t.Fatal("failed to fund formation txn:", err) } - defer session.Close() + node.Wallet.SignTransaction(&txn, toSign, wallet.ExplicitCoveredFields(txn)) + formationSet := append(node.Chain.UnconfirmedParents(txn), txn) - // generate a sector - var sector [rhp2.SectorSize]byte - frand.Read(sector[:256]) - sectorRoot := rhp2.SectorRoot(§or) - - // calculate the remaining duration of the contract - var remainingDuration uint64 - contractExpiration := uint64(session.Revision().Revision.WindowEnd) - currentHeight := renter.TipState().Index.Height - if contractExpiration < currentHeight { - t.Fatal("contract expired") - } - // upload the sector - remainingDuration = contractExpiration - currentHeight - _, collateral, err := session.RPCAppendCost(remainingDuration) - if err != nil { - t.Fatal(err) - } - // overpay for the sector, leaving a few hastings for the renewal - remainingValue := types.NewCurrency64(25) - price := origin.Revision.ValidRenterPayout().Sub(remainingValue) - writtenRoot, err := session.Append(context.Background(), §or, price, collateral) + revision, _, err := rpc2.RPCFormContract(transport, renterKey, formationSet) if err != nil { - t.Fatal(err) - } else if writtenRoot != sectorRoot { - t.Fatal("sector root mismatch") + t.Fatal("failed to form contract:", err) } + return revision + } - // mine a few blocks into the contract - if err := host.MineBlocks(host.WalletAddress(), 10); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) + renewContract := func(t *testing.T, revision crhp2.ContractRevision, windowEnd uint64) crhp2.ContractRevision { + current := revision.Revision - settings := *session.Settings() - renewHeight := origin.Revision.WindowEnd + 10 - current := session.Revision().Revision - additionalCollateral := rhp2.ContractRenewalCollateral(current.FileContract, 1<<22, settings, renter.TipState().Index.Height, renewHeight) - renewed, basePrice := rhp2.PrepareContractRenewal(session.Revision().Revision, renter.WalletAddress(), types.Siacoins(10), additionalCollateral, settings, renewHeight) + additionalCollateral := crhp2.ContractRenewalCollateral(current.FileContract, 1<<22, settings, node.Chain.Tip().Height, windowEnd) + renewed, basePrice := crhp2.PrepareContractRenewal(current, node.Wallet.Address(), types.Siacoins(10), additionalCollateral, settings, windowEnd) renewalTxn := types.Transaction{ FileContracts: []types.FileContract{renewed}, } - cost := rhp2.ContractRenewalCost(state, renewed, settings.ContractPrice, types.ZeroCurrency, basePrice) - toSign, release, err := renter.Wallet().FundTransaction(&renewalTxn, cost) + cost := crhp2.ContractRenewalCost(node.Chain.TipState(), renewed, settings.ContractPrice, types.ZeroCurrency, basePrice) + toSign, err := node.Wallet.FundTransaction(&renewalTxn, cost, true) if err != nil { - t.Fatal(err) + t.Fatal("failed to fund formation txn:", err) } + node.Wallet.SignTransaction(&renewalTxn, toSign, wallet.ExplicitCoveredFields(renewalTxn)) + renewalSet := append(node.Chain.UnconfirmedParents(renewalTxn), renewalTxn) - if err := renter.Wallet().SignTransaction(host.TipState(), &renewalTxn, toSign, wallet.ExplicitCoveredFields(renewalTxn)); err != nil { - release() - t.Fatal(err) + renewal, _, err := rpc2.RPCRenewContract(transport, renterKey, &revision, renewalSet, settings.BaseRPCPrice) + if err != nil { + t.Fatal("failed to renew contract:", err) } + return renewal + } - // try to renew the contract without paying the remaining value, should fail - if _, _, err := session.RenewContract(context.Background(), []types.Transaction{renewalTxn}, types.ZeroCurrency); err == nil { - release() - t.Fatal("expected renewal to fail") - } else if err := session.Close(); err != nil { - t.Fatal(err) - } + assertContract := func(t *testing.T, rev crhp2.ContractRevision, expectedRevisionNumber uint64, expectedRoot types.Hash256, expectedFilesize uint64) { + t.Helper() - // previous session was closed by the RPC failure, create a new one - session, err = renter.NewRHP2Session(context.Background(), host.RHP2Addr(), host.PublicKey(), origin.ID()) - if err != nil { - t.Fatal(err) + if rev.Revision.FileMerkleRoot != expectedRoot { + t.Fatalf("expected root %v, got %v", expectedRoot, rev.Revision.FileMerkleRoot) + } else if rev.Revision.Filesize != expectedFilesize { + t.Fatalf("expected filesize %d, got %d", expectedFilesize, rev.Revision.Filesize) + } else if rev.Revision.RevisionNumber != expectedRevisionNumber { + t.Fatalf("expected revision number %d, got %d", expectedRevisionNumber, rev.Revision.RevisionNumber) } - defer session.Close() - renewal, _, err := session.RenewContract(context.Background(), []types.Transaction{renewalTxn}, remainingValue) - if err != nil { - t.Fatal(err) + sigHash := rhp.HashRevision(rev.Revision) + if !hostKey.PublicKey().VerifyHash(sigHash, types.Signature(rev.Signatures[1].Signature)) { + t.Fatal("host signature invalid") + } else if !renterKey.PublicKey().VerifyHash(sigHash, types.Signature(rev.Signatures[0].Signature)) { + t.Fatal("renter signature invalid") } + } - expectedExchange := settings.ContractPrice.Add(settings.BaseRPCPrice).Add(remainingValue) // contract price + upload sector base RPC price + remaining value in contract for renewal - old, err := host.Contracts().Contract(origin.ID()) - if err != nil { + t.Run("empty contract", func(t *testing.T) { + fc := formContract(t, 145) + // mine to confirm the contract + testutil.MineAndSync(t, node, node.Wallet.Address(), 5) + if _, err := rpc2.RPCLock(transport, renterKey, fc.ID()); err != nil { t.Fatal(err) - } else if old.Revision.Filesize != 0 { - t.Fatal("filesize mismatch") - } else if old.Revision.FileMerkleRoot != (types.Hash256{}) { - t.Fatal("merkle root mismatch") - } else if old.RenewedTo != renewal.ID() { - t.Fatal("renewed to mismatch") - } else if !old.Usage.RPCRevenue.Equals(expectedExchange) { // only 25 hastings should remain in the contract - t.Fatalf("expected rpc revenue to equal contract price + base rpc price %d, got %d", expectedExchange, old.Usage.RPCRevenue) } + defer rpc2.RPCUnlock(transport) - contract, err := host.Contracts().Contract(renewal.ID()) - if err != nil { - t.Fatal(err) - } else if contract.Revision.Filesize != current.Filesize { - t.Fatal("filesize mismatch") - } else if contract.Revision.FileMerkleRoot != current.FileMerkleRoot { - t.Fatal("merkle root mismatch") - } else if contract.LockedCollateral.Cmp(additionalCollateral) <= 0 { - t.Fatalf("locked collateral mismatch: expected at least %d, got %d", additionalCollateral, contract.LockedCollateral) - } else if !contract.Usage.RPCRevenue.Equals(settings.ContractPrice) { - t.Fatalf("expected %d RPC revenue, got %d", settings.ContractPrice, contract.Usage.RPCRevenue) - } else if contract.RenewedFrom != origin.ID() { - t.Fatalf("expected renewed from %s, got %s", origin.ID(), contract.RenewedFrom) - } + assertContract(t, fc, 1, types.Hash256{}, 0) + renewal := renewContract(t, fc, fc.Revision.WindowEnd+10) + assertContract(t, renewal, 1, types.Hash256{}, 0) }) - t.Run("non-empty contract", func(t *testing.T) { - // form a contract - state := renter.TipState() - origin, err := renter.FormContract(context.Background(), host.RHP2Addr(), host.PublicKey(), types.Siacoins(10), types.Siacoins(20), state.Index.Height+200) - if err != nil { - t.Fatal(err) - } - - session, err := renter.NewRHP2Session(context.Background(), host.RHP2Addr(), host.PublicKey(), origin.ID()) - if err != nil { + // note: rhp2 contracts could not be renewed if they did not have any funds. + t.Run("refresh contract", func(t *testing.T) { + fc := formContract(t, 145) + // mine to confirm the contract + testutil.MineAndSync(t, node, node.Wallet.Address(), 5) + if _, err := rpc2.RPCLock(transport, renterKey, fc.ID()); err != nil { t.Fatal(err) } - defer session.Close() + defer rpc2.RPCUnlock(transport) - // generate a sector - var sector [rhp2.SectorSize]byte + // upload a sector and pay the entire contract value + // minus the cost of renewal + var sector [crhp2.SectorSize]byte frand.Read(sector[:256]) - sectorRoot := rhp2.SectorRoot(§or) - - // calculate the remaining duration of the contract - var remainingDuration uint64 - contractExpiration := uint64(session.Revision().Revision.WindowEnd) - currentHeight := renter.TipState().Index.Height - if contractExpiration < currentHeight { - t.Fatal("contract expired") + appendAction := []crhp2.RPCWriteAction{ + { + Type: crhp2.RPCWriteActionAppend, + Data: sector[:], + }, } - // upload the sector - remainingDuration = contractExpiration - currentHeight - price, collateral, err := session.RPCAppendCost(remainingDuration) - if err != nil { - t.Fatal(err) - } - writtenRoot, err := session.Append(context.Background(), §or, price, collateral) + root := crhp2.SectorRoot(§or) + + err = rpc2.RPCWrite(transport, renterKey, &fc, appendAction, fc.RenterFunds().Sub(settings.BaseRPCPrice), types.ZeroCurrency) if err != nil { t.Fatal(err) - } else if writtenRoot != sectorRoot { - t.Fatal("sector root mismatch") } + assertContract(t, fc, 2, root, crhp2.SectorSize) - // mine a few blocks into the contract - if err := host.MineBlocks(host.WalletAddress(), 10); err != nil { + renewal := renewContract(t, fc, fc.Revision.WindowEnd) + assertContract(t, renewal, 1, root, crhp2.SectorSize) + }) + + t.Run("renew contract", func(t *testing.T) { + fc := formContract(t, 145) + // mine to confirm the contract + testutil.MineAndSync(t, node, node.Wallet.Address(), 5) + if _, err := rpc2.RPCLock(transport, renterKey, fc.ID()); err != nil { t.Fatal(err) } - time.Sleep(100 * time.Millisecond) + defer rpc2.RPCUnlock(transport) - settings := *session.Settings() - renewHeight := origin.Revision.WindowEnd + 10 - current := session.Revision().Revision - additionalCollateral := rhp2.ContractRenewalCollateral(current.FileContract, 1<<22, settings, renter.TipState().Index.Height, renewHeight) - renewed, basePrice := rhp2.PrepareContractRenewal(session.Revision().Revision, renter.WalletAddress(), types.Siacoins(10), additionalCollateral, settings, renewHeight) - renewalTxn := types.Transaction{ - FileContracts: []types.FileContract{renewed}, + // upload a sector and pay the entire contract value + // minus the cost of renewal + var roots []types.Hash256 + var sector [crhp2.SectorSize]byte + frand.Read(sector[:256]) + appendAction := []crhp2.RPCWriteAction{ + { + Type: crhp2.RPCWriteActionAppend, + Data: sector[:], + }, } + root := crhp2.SectorRoot(§or) + roots = append(roots, root) - cost := rhp2.ContractRenewalCost(state, renewed, settings.ContractPrice, types.ZeroCurrency, basePrice) - toSign, discard, err := renter.Wallet().FundTransaction(&renewalTxn, cost) + err = rpc2.RPCWrite(transport, renterKey, &fc, appendAction, fc.RenterFunds().Sub(settings.BaseRPCPrice), types.ZeroCurrency) if err != nil { t.Fatal(err) } - defer discard() + assertContract(t, fc, 2, crhp2.MetaRoot(roots), uint64(len(roots))*crhp2.SectorSize) - if err := renter.Wallet().SignTransaction(host.TipState(), &renewalTxn, toSign, wallet.ExplicitCoveredFields(renewalTxn)); err != nil { - t.Fatal(err) - } + renewal := renewContract(t, fc, fc.Revision.WindowEnd+10) + assertContract(t, renewal, 1, crhp2.MetaRoot(roots), uint64(len(roots))*crhp2.SectorSize) - renewal, _, err := session.RenewContract(context.Background(), []types.Transaction{renewalTxn}, settings.BaseRPCPrice) - if err != nil { + if err := rpc2.RPCUnlock(transport); err != nil { + t.Fatal(err) + } else if _, err := rpc2.RPCLock(transport, renterKey, renewal.ID()); err != nil { t.Fatal(err) } - expectedExchange := settings.ContractPrice.Add(settings.BaseRPCPrice).Add(settings.BaseRPCPrice) // contract price + upload sector base RPC price + renewal base RPC price - old, err := host.Contracts().Contract(origin.ID()) - if err != nil { - t.Fatal(err) - } else if old.Revision.Filesize != 0 { - t.Fatal("filesize mismatch") - } else if old.Revision.FileMerkleRoot != (types.Hash256{}) { - t.Fatal("merkle root mismatch") - } else if old.RenewedTo != renewal.ID() { - t.Fatal("renewed to mismatch") - } else if !old.Usage.RPCRevenue.Equals(expectedExchange) { - t.Fatalf("expected rpc revenue to equal contract price + base rpc price %d, got %d", expectedExchange, old.Usage.RPCRevenue) + // upload a new sector + frand.Read(sector[:256]) + appendAction = []crhp2.RPCWriteAction{ + { + Type: crhp2.RPCWriteActionAppend, + Data: sector[:], + }, } + root = crhp2.SectorRoot(§or) + roots = append(roots, root) - contract, err := host.Contracts().Contract(renewal.ID()) + err = rpc2.RPCWrite(transport, renterKey, &renewal, appendAction, types.Siacoins(1), types.ZeroCurrency) if err != nil { t.Fatal(err) - } else if contract.Revision.Filesize != current.Filesize { - t.Fatal("filesize mismatch") - } else if contract.Revision.FileMerkleRoot != current.FileMerkleRoot { - t.Fatal("merkle root mismatch") - } else if contract.LockedCollateral.Cmp(additionalCollateral) <= 0 { - t.Fatalf("locked collateral mismatch: expected at least %d, got %d", additionalCollateral, contract.LockedCollateral) - } else if !contract.Usage.RPCRevenue.Equals(settings.ContractPrice) { - t.Fatalf("expected %d RPC revenue, got %d", settings.ContractPrice, contract.Usage.RPCRevenue) - } else if contract.RenewedFrom != origin.ID() { - t.Fatalf("expected renewed from %s, got %s", origin.ID(), contract.RenewedFrom) } + assertContract(t, renewal, 2, crhp2.MetaRoot(roots), uint64(len(roots))*crhp2.SectorSize) }) } -func BenchmarkUpload(b *testing.B) { - log := zaptest.NewLogger(b) - renter, host, err := test.NewTestingPair(b.TempDir(), log) - if err != nil { - b.Fatal(err) +func TestRPCV2(t *testing.T) { + log := zaptest.NewLogger(t) + renterKey, hostKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() + network, genesis := testutil.V2Network() + network.HardforkV2.AllowHeight = 180 + network.HardforkV2.RequireHeight = 200 + node := testutil.NewHostNode(t, hostKey, network, genesis, log) + + // initialize a storage volume + res := make(chan error) + if _, err := node.Volumes.AddVolume(context.Background(), filepath.Join(t.TempDir(), "storage.dat"), 10, res); err != nil { + t.Fatal(err) + } else if err := <-res; err != nil { + t.Fatal(err) } - defer renter.Close() - defer host.Close() - // form a contract - contract, err := renter.FormContract(context.Background(), host.RHP2Addr(), host.PublicKey(), types.Siacoins(10), types.Siacoins(20), 200) - if err != nil { - b.Fatal(err) - } + // fund the wallet + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) - session, err := renter.NewRHP2Session(context.Background(), host.RHP2Addr(), host.PublicKey(), contract.ID()) + l, err := net.Listen("tcp", "localhost:0") if err != nil { - b.Fatal(err) + t.Fatal(err) } - defer session.Close() + defer l.Close() - // calculate the remaining duration of the contract - var remainingDuration uint64 - contractExpiration := uint64(session.Revision().Revision.WindowEnd) - currentHeight := renter.TipState().Index.Height - if contractExpiration < currentHeight { - b.Fatal("contract expired") - } - // calculate the cost of uploading a sector - remainingDuration = contractExpiration - currentHeight - price, collateral, err := session.RPCAppendCost(remainingDuration) + sh, err := rhp2.NewSessionHandler(l, hostKey, "localhost:9983", node.Chain, node.Syncer, node.Wallet, node.Contracts, node.Settings, node.Volumes, rhp2.WithLog(log)) if err != nil { - b.Fatal(err) - } - - // generate b.N sectors - sectors := make([][rhp2.SectorSize]byte, b.N) - for i := range sectors { - frand.Read(sectors[i][:256]) + t.Fatal(err) } + defer sh.Close() + go sh.Serve() - b.ResetTimer() - b.ReportAllocs() - b.SetBytes(rhp2.SectorSize) - - // upload b.N sectors - for i := 0; i < b.N; i++ { - sector := sectors[i] + t.Run("ends after require height", func(t *testing.T) { + transport := dialHost(t, hostKey.PublicKey(), l.Addr().String()) + defer transport.Close() - // upload the sector - if _, err := session.Append(context.Background(), §or, price, collateral); err != nil { - b.Fatal(err) + settings, err := rpc2.RPCSettings(transport) + if err != nil { + t.Fatal(err) } - } -} -func BenchmarkDownload(b *testing.B) { - log := zaptest.NewLogger(b) - renter, host, err := test.NewTestingPair(b.TempDir(), log) - if err != nil { - b.Fatal(err) - } - defer renter.Close() - defer host.Close() - - if err := host.AddVolume(filepath.Join(b.TempDir(), "storage.dat"), uint64(b.N)); err != nil { - b.Fatal(err) - } - - // form a contract - contract, err := renter.FormContract(context.Background(), host.RHP2Addr(), host.PublicKey(), types.Siacoins(10), types.Siacoins(20), 200) - if err != nil { - b.Fatal(err) - } - - // mine a block to confirm the contract - if err := host.MineBlocks(host.WalletAddress(), 1); err != nil { - b.Fatal(err) - } - - session, err := renter.NewRHP2Session(context.Background(), host.RHP2Addr(), host.PublicKey(), contract.ID()) - if err != nil { - b.Fatal(err) - } - defer session.Close() - - // calculate the remaining duration of the contract - var remainingDuration uint64 - contractExpiration := uint64(session.Revision().Revision.WindowEnd) - currentHeight := renter.TipState().Index.Height - if contractExpiration < currentHeight { - b.Fatal("contract expired") - } - remainingDuration = contractExpiration - currentHeight - - var uploaded []types.Hash256 - // upload b.N sectors - for i := 0; i < b.N; i++ { - // generate a sector - var sector [rhp2.SectorSize]byte - frand.Read(sector[:256]) - - // upload the sector - session.Settings().RPCWriteCost([]rhp2.RPCWriteAction{{Type: rhp2.RPCWriteActionAppend}}, uint64(b.N), remainingDuration, true) - price, collateral, err := session.RPCAppendCost(remainingDuration) - if err != nil { - b.Fatal(err) + // try to form a v1 contract that ends after the require height + fc := crhp2.PrepareContractFormation(renterKey.PublicKey(), hostKey.PublicKey(), types.Siacoins(10), types.Siacoins(20), network.HardforkV2.RequireHeight, settings, node.Wallet.Address()) + formationCost := crhp2.ContractFormationCost(node.Chain.TipState(), fc, settings.ContractPrice) + txn := types.Transaction{ + FileContracts: []types.FileContract{fc}, } - root, err := session.Append(context.Background(), §or, price, collateral) + toSign, err := node.Wallet.FundTransaction(&txn, formationCost, true) if err != nil { - b.Fatal(err) + t.Fatal(err) } - uploaded = append(uploaded, root) - } + node.Wallet.SignTransaction(&txn, toSign, wallet.ExplicitCoveredFields(txn)) + formationSet := append(node.Chain.UnconfirmedParents(txn), txn) - b.ReportAllocs() - b.ResetTimer() - b.SetBytes(rhp2.SectorSize) - - for _, root := range uploaded { - // download the sector - sections := []rhp2.RPCReadRequestSection{{ - MerkleRoot: root, - Offset: 0, - Length: rhp2.SectorSize, - }} - - cost, err := session.Settings().RPCReadCost(sections, true) - if err != nil { - b.Fatal(err) - } - price, _ := cost.Total() - if err := session.Read(context.Background(), io.Discard, sections, price); err != nil { - b.Fatal(err) + if _, _, err := rpc2.RPCFormContract(transport, renterKey, formationSet); !errors.Is(err, rhp2.ErrAfterV2Hardfork) { + t.Fatalf("expected ErrV2Hardfork, got %v", err) } - } -} - -func TestSectorRoots(t *testing.T) { - log := zaptest.NewLogger(t) - renter, host, err := test.NewTestingPair(t.TempDir(), log) - if err != nil { - t.Fatal(err) - } - defer renter.Close() - defer host.Close() - - // form a contract - contract, err := renter.FormContract(context.Background(), host.RHP2Addr(), host.PublicKey(), types.Siacoins(10), types.Siacoins(20), 200) - if err != nil { - t.Fatal(err) - } - - session, err := renter.NewRHP2Session(context.Background(), host.RHP2Addr(), host.PublicKey(), contract.ID()) - if err != nil { - t.Fatal(err) - } - defer session.Close() - - // calculate the remaining duration of the contract - var remainingDuration uint64 - contractExpiration := uint64(session.Revision().Revision.WindowEnd) - currentHeight := renter.TipState().Index.Height - if contractExpiration < currentHeight { - t.Fatal("contract expired") - } - // calculate the cost of uploading a sector - remainingDuration = contractExpiration - currentHeight + }) - // upload a few sectors - sectors := make([][rhp2.SectorSize]byte, 5) - for i := range sectors { - frand.Read(sectors[i][:256]) - } + t.Run("form after allow height", func(t *testing.T) { + // mine until the allow height + testutil.MineAndSync(t, node, node.Wallet.Address(), int(network.HardforkV2.AllowHeight-node.Chain.Tip().Height)) - for i := 0; i < len(sectors); i++ { - sector := sectors[i] + transport := dialHost(t, hostKey.PublicKey(), l.Addr().String()) + defer transport.Close() - price, collateral, err := session.RPCAppendCost(remainingDuration) + settings, err := rpc2.RPCSettings(transport) if err != nil { t.Fatal(err) } - // upload the sector - if _, err := session.Append(context.Background(), §or, price, collateral); err != nil { + // try to form a v1 contract after the allow height + fc := crhp2.PrepareContractFormation(renterKey.PublicKey(), hostKey.PublicKey(), types.Siacoins(10), types.Siacoins(20), node.Chain.Tip().Height+10, settings, node.Wallet.Address()) + formationCost := crhp2.ContractFormationCost(node.Chain.TipState(), fc, settings.ContractPrice) + txn := types.Transaction{ + FileContracts: []types.FileContract{fc}, + } + toSign, err := node.Wallet.FundTransaction(&txn, formationCost, true) + if err != nil { t.Fatal(err) } - } + node.Wallet.SignTransaction(&txn, toSign, wallet.ExplicitCoveredFields(txn)) + formationSet := append(node.Chain.UnconfirmedParents(txn), txn) - // fetch sectors one-by-one and compare - for i := 0; i < len(sectors); i++ { - price, _ := session.Settings().RPCSectorRootsCost(uint64(i), 1).Total() - root, err := session.SectorRoots(context.Background(), uint64(i), 1, price) - if err != nil { - t.Fatalf("root %d error: %s", i, err) - } else if len(root) != 1 { - t.Fatal("expected 1 sector root") - } else if root[0] != rhp2.SectorRoot(§ors[i]) { - t.Fatal("sector root mismatch") + _, _, err = rpc2.RPCFormContract(transport, renterKey, formationSet) + if runtime.GOOS != "windows" && !errors.Is(err, rhp2.ErrV2Hardfork) { // windows responds with wsarecv rather than the error + t.Fatalf("expected ErrV2Hardfork, got %v", err) + } else if runtime.GOOS == "windows" && err == nil { + t.Fatal("expected windows error, got nil") } - } + }) - // fetch all sectors at once and compare - price, _ := session.Settings().RPCSectorRootsCost(0, uint64(len(sectors))).Total() - roots, err := session.SectorRoots(context.Background(), 0, uint64(len(sectors)), price) - if err != nil { - t.Fatal(err) - } - for i := range roots { - if roots[i] != rhp2.SectorRoot(§ors[i]) { - t.Fatal("sector root mismatch") + t.Run("rpc after require height", func(t *testing.T) { + // mine until the require height + testutil.MineAndSync(t, node, node.Wallet.Address(), int(network.HardforkV2.RequireHeight-node.Chain.Tip().Height)) + + transport := dialHost(t, hostKey.PublicKey(), l.Addr().String()) + defer transport.Close() + + _, err := rpc2.RPCSettings(transport) + if runtime.GOOS != "windows" && !errors.Is(err, rhp2.ErrV2Hardfork) { // windows responds with wsarecv rather than the error + t.Fatalf("expected ErrV2Hardfork, got %v", err) + } else if runtime.GOOS == "windows" && err == nil { + t.Fatal("expected windows error, got nil") } - } + }) } diff --git a/rhp/v2/session.go b/rhp/v2/session.go index 0e39b1c7..9ab785cd 100644 --- a/rhp/v2/session.go +++ b/rhp/v2/session.go @@ -41,14 +41,12 @@ func (s *session) writeResponse(resp rhp2.ProtocolObject, timeout time.Duration) // ContractRevisable returns an error if a contract is not locked or can't be // revised. A contract is revisable if the revision number is not the max uint64 // value and it is not close to the proof window. -func (s *session) ContractRevisable(height uint64) error { +func (s *session) ContractRevisable() error { switch { case s.contract.Revision.ParentID == (types.FileContractID{}): return ErrNoContractLocked case s.contract.Revision.RevisionNumber == types.MaxRevisionNumber: return ErrContractRevisionLimit - case s.contract.Revision.WindowStart-contracts.RevisionSubmissionBuffer < height: - return ErrContractExpired } return nil } diff --git a/rhp/v3/options.go b/rhp/v3/options.go new file mode 100644 index 00000000..4e2b5f0b --- /dev/null +++ b/rhp/v3/options.go @@ -0,0 +1,41 @@ +package rhp + +import ( + "go.sia.tech/core/types" + "go.sia.tech/hostd/host/contracts" + "go.sia.tech/hostd/rhp" + "go.uber.org/zap" +) + +// SessionHandlerOption is a functional option for session handlers. +type SessionHandlerOption func(*SessionHandler) + +// WithLog sets the logger for the session handler. +func WithLog(l *zap.Logger) SessionHandlerOption { + return func(s *SessionHandler) { + s.log = l + } +} + +// WithSessionReporter sets the session reporter for the session handler. +func WithSessionReporter(r SessionReporter) SessionHandlerOption { + return func(s *SessionHandler) { + s.sessions = r + } +} + +// WithDataMonitor sets the data monitor for the session handler. +func WithDataMonitor(m rhp.DataMonitor) SessionHandlerOption { + return func(s *SessionHandler) { + s.monitor = m + } +} + +type noopSessionReporter struct{} + +func (noopSessionReporter) StartSession(conn *rhp.Conn, proto string, version int) (sessionID rhp.UID, end func()) { + return rhp.UID{}, func() {} +} +func (noopSessionReporter) StartRPC(sessionID rhp.UID, rpc types.Specifier) (rpcID rhp.UID, end func(contracts.Usage, error)) { + return rhp.UID{}, func(contracts.Usage, error) {} +} diff --git a/rhp/v3/pricetable.go b/rhp/v3/pricetable.go index b6ce5fc1..44f406e8 100644 --- a/rhp/v3/pricetable.go +++ b/rhp/v3/pricetable.go @@ -112,7 +112,7 @@ func (sh *SessionHandler) PriceTable() (rhp3.HostPriceTable, error) { return rhp3.HostPriceTable{}, fmt.Errorf("failed to get registry entries: %w", err) } - fee := sh.tpool.RecommendedFee() + fee := sh.chain.RecommendedFee() currentHeight := sh.chain.TipState().Index.Height oneHasting := types.NewCurrency64(1) return rhp3.HostPriceTable{ diff --git a/rhp/v3/rhp.go b/rhp/v3/rhp.go index 2ef4edd5..bb775ac6 100644 --- a/rhp/v3/rhp.go +++ b/rhp/v3/rhp.go @@ -82,20 +82,26 @@ type ( // A ChainManager provides access to the current state of the blockchain. ChainManager interface { + Tip() types.ChainIndex TipState() consensus.State + UnconfirmedParents(txn types.Transaction) []types.Transaction + AddPoolTransactions([]types.Transaction) (known bool, err error) + AddV2PoolTransactions(types.ChainIndex, []types.V2Transaction) (known bool, err error) + RecommendedFee() types.Currency + } + + // A Syncer broadcasts transactions to the network + Syncer interface { + BroadcastTransactionSet([]types.Transaction) + BroadcastV2TransactionSet(types.ChainIndex, []types.V2Transaction) } // A Wallet manages funds and signs transactions Wallet interface { Address() types.Address - FundTransaction(txn *types.Transaction, amount types.Currency) ([]types.Hash256, func(), error) - SignTransaction(cs consensus.State, txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) error - } - - // A TransactionPool broadcasts transactions to the network. - TransactionPool interface { - AcceptTransactionSet([]types.Transaction) error - RecommendedFee() types.Currency + FundTransaction(txn *types.Transaction, amount types.Currency, unconfirmed bool) ([]types.Hash256, error) + SignTransaction(txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) + ReleaseInputs(txn []types.Transaction, v2txn []types.V2Transaction) } // A SettingsReporter reports the host's current configuration. @@ -116,20 +122,21 @@ type ( privateKey types.PrivateKey listener net.Listener - monitor rhp.DataMonitor - tg *threadgroup.ThreadGroup accounts AccountManager contracts ContractManager - sessions SessionReporter registry RegistryManager storage StorageManager - log *zap.Logger + settings SettingsReporter + + chain ChainManager + syncer Syncer + wallet Wallet - chain ChainManager - settings SettingsReporter - tpool TransactionPool - wallet Wallet + log *zap.Logger + sessions SessionReporter + monitor rhp.DataMonitor + tg *threadgroup.ThreadGroup priceTables *priceTableManager } @@ -263,6 +270,13 @@ func (sh *SessionHandler) Serve() error { return } + cs := sh.chain.TipState() + // disable rhp3 after v2 require height + if cs.Index.Height >= cs.Network.HardforkV2.RequireHeight { + stream.WriteResponseErr(ErrV2Hardfork) + return + } + go sh.handleHostStream(stream, sessionID, log) } }() @@ -275,27 +289,31 @@ func (sh *SessionHandler) LocalAddr() string { } // NewSessionHandler creates a new SessionHandler -func NewSessionHandler(l net.Listener, hostKey types.PrivateKey, chain ChainManager, tpool TransactionPool, wallet Wallet, accounts AccountManager, contracts ContractManager, registry RegistryManager, storage StorageManager, settings SettingsReporter, monitor rhp.DataMonitor, sessions SessionReporter, log *zap.Logger) (*SessionHandler, error) { +func NewSessionHandler(l net.Listener, hostKey types.PrivateKey, chain ChainManager, syncer Syncer, wallet Wallet, accounts AccountManager, contracts ContractManager, registry RegistryManager, storage StorageManager, settings SettingsReporter, opts ...SessionHandlerOption) (*SessionHandler, error) { sh := &SessionHandler{ privateKey: hostKey, listener: l, - monitor: monitor, - tg: threadgroup.New(), chain: chain, - tpool: tpool, + syncer: syncer, wallet: wallet, accounts: accounts, contracts: contracts, - sessions: sessions, registry: registry, settings: settings, storage: storage, - log: log, + + log: zap.NewNop(), + monitor: rhp.NewNoOpMonitor(), + sessions: noopSessionReporter{}, + tg: threadgroup.New(), priceTables: newPriceTableManager(), } + for _, opt := range opts { + opt(sh) + } return sh, nil } diff --git a/rhp/v3/rpc.go b/rhp/v3/rpc.go index 6003466b..f65b5ea1 100644 --- a/rhp/v3/rpc.go +++ b/rhp/v3/rpc.go @@ -12,10 +12,10 @@ import ( rhp3 "go.sia.tech/core/rhp/v3" "go.sia.tech/core/types" + "go.sia.tech/coreutils/wallet" "go.sia.tech/hostd/host/accounts" "go.sia.tech/hostd/host/contracts" "go.sia.tech/hostd/rhp" - "go.sia.tech/hostd/wallet" "go.uber.org/zap" "lukechampine.com/frand" ) @@ -40,6 +40,14 @@ var ( // ErrNotAcceptingContracts is returned when the host is not accepting // contracts. ErrNotAcceptingContracts = errors.New("host is not accepting contracts") + + // ErrV2Hardfork is returned when a renter tries to form or renew a contract + // after the v2 hardfork has been activated. + ErrV2Hardfork = errors.New("hardfork v2 is active") + + // ErrAfterV2Hardfork is returned when a renter tries to form or renew a + // contract that ends after the v2 hardfork has been activated. + ErrAfterV2Hardfork = errors.New("proof window after hardfork v2 activation") ) // handleRPCPriceTable sends the host's price table to the renter. @@ -239,6 +247,13 @@ func (sh *SessionHandler) handleRPCLatestRevision(s *rhp3.Stream, log *zap.Logge } func (sh *SessionHandler) handleRPCRenew(s *rhp3.Stream, log *zap.Logger) (contracts.Usage, error) { + cs := sh.chain.TipState() + // prevent renewing v1 contracts after the allow height + if cs.Index.Height >= cs.Network.HardforkV2.AllowHeight { + s.WriteResponseErr(ErrV2Hardfork) + return contracts.Usage{}, ErrV2Hardfork + } + s.SetDeadline(time.Now().Add(2 * time.Minute)) if !sh.settings.Settings().AcceptingContracts { s.WriteResponseErr(ErrNotAcceptingContracts) @@ -287,6 +302,12 @@ func (sh *SessionHandler) handleRPCRenew(s *rhp3.Stream, log *zap.Logger) (contr clearingRevision := renewalTxn.FileContractRevisions[0] renewal := renewalTxn.FileContracts[0] + // prevent forming v1 contracts with proof windows after the v2 hardfork + if renewal.WindowStart >= cs.Network.HardforkV2.RequireHeight { + s.WriteResponseErr(ErrAfterV2Hardfork) + return contracts.Usage{}, ErrAfterV2Hardfork + } + // lock the existing contract ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -337,7 +358,7 @@ func (sh *SessionHandler) handleRPCRenew(s *rhp3.Stream, log *zap.Logger) (contr return contracts.Usage{}, err } renterInputs, renterOutputs := len(renewalTxn.SiacoinInputs), len(renewalTxn.SiacoinOutputs) - toSign, release, err := sh.wallet.FundTransaction(&renewalTxn, lockedCollateral) + toSign, err := sh.wallet.FundTransaction(&renewalTxn, lockedCollateral, false) if err != nil { remoteErr := ErrHostInternalError if errors.Is(err, wallet.ErrNotEnoughFunds) { @@ -353,21 +374,21 @@ func (sh *SessionHandler) handleRPCRenew(s *rhp3.Stream, log *zap.Logger) (contr FinalRevisionSignature: signedClearingRevision.HostSignature, } if err := s.WriteResponse(hostAdditions); err != nil { - release() + sh.wallet.ReleaseInputs([]types.Transaction{renewalTxn}, nil) return contracts.Usage{}, fmt.Errorf("failed to write host additions: %w", err) } var renterSigsResp rhp3.RPCRenewSignatures if err := s.ReadRequest(&renterSigsResp, 10*maxRequestSize); err != nil { - release() + sh.wallet.ReleaseInputs([]types.Transaction{renewalTxn}, nil) return contracts.Usage{}, fmt.Errorf("failed to read renter signatures: %w", err) } // create the initial revision and verify the renter's signature - renewalRevision := rhp.InitialRevision(&renewalTxn, hostUnlockKey, req.RenterKey) + renewalRevision := rhp.InitialRevision(renewalTxn, hostUnlockKey, req.RenterKey) renewalSigHash := rhp.HashRevision(renewalRevision) if err := validateRenterRevisionSignature(renterSigsResp.RevisionSignature, renewalRevision.ParentID, renewalSigHash, renterKey); err != nil { - release() + sh.wallet.ReleaseInputs([]types.Transaction{renewalTxn}, nil) err := fmt.Errorf("failed to verify renter revision signature: %w", ErrInvalidRenterSignature) s.WriteResponseErr(err) return contracts.Usage{}, err @@ -401,14 +422,10 @@ func (sh *SessionHandler) handleRPCRenew(s *rhp3.Stream, log *zap.Logger) (contr renterSigs := len(renewalTxn.Signatures) // sign and broadcast the transaction - if err := sh.wallet.SignTransaction(sh.chain.TipState(), &renewalTxn, toSign, types.CoveredFields{WholeTransaction: true}); err != nil { - release() - s.WriteResponseErr(fmt.Errorf("failed to sign renewal transaction: %w", ErrHostInternalError)) - return contracts.Usage{}, fmt.Errorf("failed to sign renewal transaction: %w", err) - } + sh.wallet.SignTransaction(&renewalTxn, toSign, types.CoveredFields{WholeTransaction: true}) renewalTxnSet := append(parents, renewalTxn) - if err := sh.tpool.AcceptTransactionSet(renewalTxnSet); err != nil { - release() + if _, err := sh.chain.AddPoolTransactions(renewalTxnSet); err != nil { + sh.wallet.ReleaseInputs([]types.Transaction{renewalTxn}, nil) err = fmt.Errorf("failed to broadcast renewal transaction: %w", err) s.WriteResponseErr(err) return contracts.Usage{}, err diff --git a/rhp/v3/rpc_test.go b/rhp/v3/rpc_test.go index dc345c23..43ed1292 100644 --- a/rhp/v3/rpc_test.go +++ b/rhp/v3/rpc_test.go @@ -3,62 +3,165 @@ package rhp_test import ( "bytes" "context" + "errors" + "net" "path/filepath" "reflect" "strings" "testing" - "time" - rhp2 "go.sia.tech/core/rhp/v2" - rhp3 "go.sia.tech/core/rhp/v3" + crhp2 "go.sia.tech/core/rhp/v2" + crhp3 "go.sia.tech/core/rhp/v3" "go.sia.tech/core/types" - "go.sia.tech/hostd/host/settings" + "go.sia.tech/coreutils/chain" + "go.sia.tech/coreutils/wallet" + "go.sia.tech/hostd/host/contracts" "go.sia.tech/hostd/host/storage" - "go.sia.tech/hostd/internal/test" - proto3 "go.sia.tech/hostd/internal/test/rhp/v3" - "go.sia.tech/hostd/rhp/v3" + "go.sia.tech/hostd/internal/testutil" + proto2 "go.sia.tech/hostd/internal/testutil/rhp/v2" + proto3 "go.sia.tech/hostd/internal/testutil/rhp/v3" + "go.sia.tech/hostd/rhp" + rhp2 "go.sia.tech/hostd/rhp/v2" + rhp3 "go.sia.tech/hostd/rhp/v3" + "go.uber.org/zap" "go.uber.org/zap/zaptest" "lukechampine.com/frand" ) -func TestPriceTable(t *testing.T) { - log := zaptest.NewLogger(t) - renter, host, err := test.NewTestingPair(t.TempDir(), log) +func formContract(t *testing.T, cm *chain.Manager, wm *wallet.SingleAddressWallet, hostAddr string, renterKey types.PrivateKey, hostKey types.PublicKey, duration uint64) crhp2.ContractRevision { + t.Helper() + + conn, err := net.Dial("tcp", hostAddr) + if err != nil { + t.Fatal("failed to dial host", err) + } + defer conn.Close() + + transport, err := crhp2.NewRenterTransport(conn, hostKey) + if err != nil { + t.Fatal("failed to create transport", err) + } + defer transport.Close() + + settings, err := proto2.RPCSettings(transport) + if err != nil { + t.Fatal("failed to get settings", err) + } + + fc := crhp2.PrepareContractFormation(renterKey.PublicKey(), hostKey, types.Siacoins(1000), types.Siacoins(1000), cm.Tip().Height+duration, settings, wm.Address()) + formationCost := crhp2.ContractFormationCost(cm.TipState(), fc, settings.ContractPrice) + txn := types.Transaction{ + FileContracts: []types.FileContract{fc}, + } + toSign, err := wm.FundTransaction(&txn, formationCost, true) + if err != nil { + t.Fatal("failed to fund formation txn:", err) + } + wm.SignTransaction(&txn, toSign, wallet.ExplicitCoveredFields(txn)) + formationSet := append(cm.UnconfirmedParents(txn), txn) + + revision, _, err := proto2.RPCFormContract(transport, renterKey, formationSet) + if err != nil { + t.Fatal("failed to form contract:", err) + } + return revision +} + +func setupRHP3Host(t *testing.T, node *testutil.HostNode, hostKey types.PrivateKey, maxStorage uint64, log *zap.Logger) (*rhp2.SessionHandler, *rhp3.SessionHandler) { + // start the RHP2 listener for forming contracts + rhp2Listener, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatal(err) } - defer renter.Close() - defer host.Close() + t.Cleanup(func() { rhp2Listener.Close() }) - pt, err := host.RHP3PriceTable() + // start the RHP3 listener for the actual test + rhp3Listener, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatal(err) } + t.Cleanup(func() { rhp3Listener.Close() }) + + // set the host to accept contracts + s := node.Settings.Settings() + s.AcceptingContracts = true + s.MaxCollateral = types.Siacoins(100000) + s.MaxAccountBalance = types.Siacoins(100000) + s.StoragePrice = types.NewCurrency64(1) + s.ContractPrice = types.NewCurrency64(1) + s.EgressPrice = types.NewCurrency64(1) + s.IngressPrice = types.NewCurrency64(1) + s.BaseRPCPrice = types.NewCurrency64(1) + s.NetAddress = rhp3Listener.Addr().String() + if err := node.Settings.UpdateSettings(s); err != nil { + t.Fatal(err) + } + + // initialize a storage volume + res := make(chan error) + if _, err := node.Volumes.AddVolume(context.Background(), filepath.Join(t.TempDir(), "storage.dat"), maxStorage, res); err != nil { + t.Fatal(err) + } else if err := <-res; err != nil { + t.Fatal(err) + } - session, err := renter.NewRHP3Session(context.Background(), host.RHP3Addr(), host.PublicKey()) + sh2, err := rhp2.NewSessionHandler(rhp2Listener, hostKey, rhp3Listener.Addr().String(), node.Chain, node.Syncer, node.Wallet, node.Contracts, node.Settings, node.Volumes, rhp2.WithLog(log.Named("rhp2"))) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { sh2.Close() }) + go sh2.Serve() + + sh3, err := rhp3.NewSessionHandler(rhp3Listener, hostKey, node.Chain, node.Syncer, node.Wallet, node.Accounts, node.Contracts, node.Registry, node.Volumes, node.Settings, rhp3.WithLog(log.Named("rhp3"))) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { sh3.Close() }) + go sh3.Serve() + + return sh2, sh3 +} + +func TestPriceTable(t *testing.T) { + log := zaptest.NewLogger(t) + hostKey, renterKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) + + // fund the wallet + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) + + // start the node + sh2, sh3 := setupRHP3Host(t, node, hostKey, 10, log) + + // create a RHP3 session + session, err := proto3.NewSession(context.Background(), hostKey.PublicKey(), sh3.LocalAddr(), node.Chain, node.Wallet) if err != nil { t.Fatal(err) } defer session.Close() + pt, err := sh3.PriceTable() + if err != nil { + t.Fatal(err) + } + retrieved, err := session.ScanPriceTable() if err != nil { t.Fatal(err) } // clear the UID field pt.UID = retrieved.UID + // check that the price tables match if !reflect.DeepEqual(pt, retrieved) { t.Fatal("price tables don't match") } - // pay for a price table using a contract payment - revision, err := renter.FormContract(context.Background(), host.RHP2Addr(), host.PublicKey(), types.Siacoins(10), types.Siacoins(20), 200) - if err != nil { - t.Fatal(err) - } + // form a contract + revision := formContract(t, node.Chain, node.Wallet, sh2.LocalAddr(), renterKey, hostKey.PublicKey(), 200) - account := rhp3.Account(renter.PublicKey()) - payment := proto3.ContractPayment(&revision, renter.PrivateKey(), account) + account := crhp3.Account(renterKey.PublicKey()) + payment := proto3.ContractPayment(&revision, renterKey, account) retrieved, err = session.RegisterPriceTable(payment) if err != nil { @@ -76,7 +179,7 @@ func TestPriceTable(t *testing.T) { t.Fatal(err) } - payment = proto3.AccountPayment(account, renter.PrivateKey()) + payment = proto3.AccountPayment(account, renterKey) // pay for a price table using an account retrieved, err = session.RegisterPriceTable(payment) if err != nil { @@ -91,27 +194,29 @@ func TestPriceTable(t *testing.T) { func TestAppendSector(t *testing.T) { log := zaptest.NewLogger(t) - renter, host, err := test.NewTestingPair(t.TempDir(), log) - if err != nil { - t.Fatal(err) - } - defer renter.Close() - defer host.Close() + hostKey, renterKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) + + // fund the wallet + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) + + // start the node + sh2, sh3 := setupRHP3Host(t, node, hostKey, 10, log) - session, err := renter.NewRHP3Session(context.Background(), host.RHP3Addr(), host.PublicKey()) + // create a RHP3 session + session, err := proto3.NewSession(context.Background(), hostKey.PublicKey(), sh3.LocalAddr(), node.Chain, node.Wallet) if err != nil { t.Fatal(err) } defer session.Close() - revision, err := renter.FormContract(context.Background(), host.RHP2Addr(), host.PublicKey(), types.Siacoins(50), types.Siacoins(100), 200) - if err != nil { - t.Fatal(err) - } + // form a contract to upload sectors + revision := formContract(t, node.Chain, node.Wallet, sh2.LocalAddr(), renterKey, hostKey.PublicKey(), 200) // register the price table - account := rhp3.Account(renter.PublicKey()) - payment := proto3.ContractPayment(&revision, renter.PrivateKey(), account) + account := crhp3.Account(renterKey.PublicKey()) + payment := proto3.ContractPayment(&revision, renterKey, account) pt, err := session.RegisterPriceTable(payment) if err != nil { t.Fatal(err) @@ -126,27 +231,27 @@ func TestAppendSector(t *testing.T) { var roots []types.Hash256 for i := 0; i < 10; i++ { // calculate the cost of the upload - cost, _ := pt.BaseCost().Add(pt.AppendSectorCost(revision.Revision.WindowEnd - renter.TipState().Index.Height)).Total() + cost, _ := pt.BaseCost().Add(pt.AppendSectorCost(revision.Revision.WindowEnd - node.Chain.Tip().Height)).Total() if cost.IsZero() { t.Fatal("cost is zero") } - var sector [rhp2.SectorSize]byte + var sector [crhp2.SectorSize]byte frand.Read(sector[:256]) - root := rhp2.SectorRoot(§or) + root := crhp2.SectorRoot(§or) roots = append(roots, root) - if _, err = session.AppendSector(§or, &revision, renter.PrivateKey(), payment, cost); err != nil { + if _, err = session.AppendSector(§or, &revision, renterKey, payment, cost); err != nil { t.Fatal(err) } // check that the contract merkle root matches - if revision.Revision.FileMerkleRoot != rhp2.MetaRoot(roots) { + if revision.Revision.FileMerkleRoot != crhp2.MetaRoot(roots) { t.Fatal("contract merkle root doesn't match") } // download the sector - cost, _ = pt.BaseCost().Add(pt.ReadSectorCost(rhp2.SectorSize)).Total() - downloaded, _, err := session.ReadSector(root, 0, rhp2.SectorSize, payment, cost) + cost, _ = pt.BaseCost().Add(pt.ReadSectorCost(crhp2.SectorSize)).Total() + downloaded, _, err := session.ReadSector(root, 0, crhp2.SectorSize, payment, cost) if err != nil { t.Fatal(err) } else if !bytes.Equal(downloaded, sector[:]) { @@ -155,11 +260,11 @@ func TestAppendSector(t *testing.T) { } // assert ReadSector exposes ErrSectorNotFound - cost, _ := pt.BaseCost().Add(pt.ReadSectorCost(rhp2.SectorSize)).Total() - _, _, err = session.ReadSector(types.Hash256{}, 0, rhp2.SectorSize, payment, cost) + cost, _ := pt.BaseCost().Add(pt.ReadSectorCost(crhp2.SectorSize)).Total() + _, _, err = session.ReadSector(types.Hash256{}, 0, crhp2.SectorSize, payment, cost) if err == nil { t.Fatal("expected error when reading nil sector") - } else if strings.Contains(err.Error(), rhp.ErrHostInternalError.Error()) { + } else if strings.Contains(err.Error(), rhp3.ErrHostInternalError.Error()) { t.Fatal("unexpected internal error", err) } else if !strings.Contains(err.Error(), storage.ErrSectorNotFound.Error()) { t.Fatal("expected storage.ErrSectorNotFound", err) @@ -168,30 +273,29 @@ func TestAppendSector(t *testing.T) { func TestStoreSector(t *testing.T) { log := zaptest.NewLogger(t) - renter, host, err := test.NewTestingPair(t.TempDir(), log) - if err != nil { - t.Fatal(err) - } - defer renter.Close() - defer host.Close() + hostKey, renterKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) + + // fund the wallet + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) - // Resize cache to 0 sectors - host.Storage().ResizeCache(0) + // start the node + sh2, sh3 := setupRHP3Host(t, node, hostKey, 10, log) - session, err := renter.NewRHP3Session(context.Background(), host.RHP3Addr(), host.PublicKey()) + // create a RHP3 session + session, err := proto3.NewSession(context.Background(), hostKey.PublicKey(), sh3.LocalAddr(), node.Chain, node.Wallet) if err != nil { t.Fatal(err) } defer session.Close() - revision, err := renter.FormContract(context.Background(), host.RHP2Addr(), host.PublicKey(), types.Siacoins(50), types.Siacoins(100), 200) - if err != nil { - t.Fatal(err) - } + // form a contract to upload sectors + revision := formContract(t, node.Chain, node.Wallet, sh2.LocalAddr(), renterKey, hostKey.PublicKey(), 200) - account := rhp3.Account(renter.PublicKey()) + account := crhp3.Account(renterKey.PublicKey()) // register the price table - payment := proto3.ContractPayment(&revision, renter.PrivateKey(), account) + payment := proto3.ContractPayment(&revision, renterKey, account) pt, err := session.RegisterPriceTable(payment) if err != nil { t.Fatal(err) @@ -204,21 +308,21 @@ func TestStoreSector(t *testing.T) { } // upload a sector - payment = proto3.AccountPayment(account, renter.PrivateKey()) + payment = proto3.AccountPayment(account, renterKey) // calculate the cost of the upload usage := pt.StoreSectorCost(10) cost, _ := usage.Total() - var sector [rhp2.SectorSize]byte + var sector [crhp2.SectorSize]byte frand.Read(sector[:256]) - root := rhp2.SectorRoot(§or) + root := crhp2.SectorRoot(§or) if err = session.StoreSector(§or, 10, payment, cost); err != nil { t.Fatal(err) } // download the sector - usage = pt.ReadSectorCost(rhp2.SectorSize) + usage = pt.ReadSectorCost(crhp2.SectorSize) cost, _ = usage.Total() - downloaded, _, err := session.ReadSector(root, 0, rhp2.SectorSize, payment, cost) + downloaded, _, err := session.ReadSector(root, 0, crhp2.SectorSize, payment, cost) if err != nil { t.Fatal(err) } else if !bytes.Equal(downloaded, sector[:]) { @@ -226,15 +330,12 @@ func TestStoreSector(t *testing.T) { } // mine until the sector expires - if err := host.MineBlocks(types.VoidAddress, 10); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) // sync time + testutil.MineAndSync(t, node, node.Wallet.Address(), 10) // check that the sector was deleted - usage = pt.ReadSectorCost(rhp2.SectorSize) + usage = pt.ReadSectorCost(crhp2.SectorSize) cost, _ = usage.Total() - _, _, err = session.ReadSector(root, 0, rhp2.SectorSize, payment, cost) + _, _, err = session.ReadSector(root, 0, crhp2.SectorSize, payment, cost) if err == nil { t.Fatal("expected error when reading sector") } @@ -242,26 +343,28 @@ func TestStoreSector(t *testing.T) { func TestReadSectorOffset(t *testing.T) { log := zaptest.NewLogger(t) - renter, host, err := test.NewTestingPair(t.TempDir(), log) - if err != nil { - t.Fatal(err) - } - defer renter.Close() - defer host.Close() + hostKey, renterKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) + + // fund the wallet + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) - session, err := renter.NewRHP3Session(context.Background(), host.RHP3Addr(), host.PublicKey()) + // start the node + sh2, sh3 := setupRHP3Host(t, node, hostKey, 10, log) + + // create a RHP3 session + session, err := proto3.NewSession(context.Background(), hostKey.PublicKey(), sh3.LocalAddr(), node.Chain, node.Wallet) if err != nil { t.Fatal(err) } defer session.Close() - revision, err := renter.FormContract(context.Background(), host.RHP2Addr(), host.PublicKey(), types.Siacoins(100), types.Siacoins(200), 200) - if err != nil { - t.Fatal(err) - } + // form a contract to upload sectors + revision := formContract(t, node.Chain, node.Wallet, sh2.LocalAddr(), renterKey, hostKey.PublicKey(), 200) - account := rhp3.Account(renter.PublicKey()) - payment := proto3.ContractPayment(&revision, renter.PrivateKey(), account) + account := crhp3.Account(renterKey.PublicKey()) + payment := proto3.ContractPayment(&revision, renterKey, account) // register the price table pt, err := session.RegisterPriceTable(payment) if err != nil { @@ -274,18 +377,18 @@ func TestReadSectorOffset(t *testing.T) { t.Fatal(err) } - cost, _ := pt.BaseCost().Add(pt.AppendSectorCost(revision.Revision.WindowEnd - renter.TipState().Index.Height)).Total() - var sectors [][rhp2.SectorSize]byte + cost, _ := pt.BaseCost().Add(pt.AppendSectorCost(revision.Revision.WindowEnd - node.Chain.Tip().Height)).Total() + var sectors [][crhp2.SectorSize]byte for i := 0; i < 5; i++ { // upload a few sectors - payment = proto3.AccountPayment(account, renter.PrivateKey()) + payment = proto3.AccountPayment(account, renterKey) // calculate the cost of the upload if cost.IsZero() { t.Fatal("cost is zero") } - var sector [rhp2.SectorSize]byte + var sector [crhp2.SectorSize]byte frand.Read(sector[:256]) - _, err = session.AppendSector(§or, &revision, renter.PrivateKey(), payment, cost) + _, err = session.AppendSector(§or, &revision, renterKey, payment, cost) if err != nil { t.Fatal(err) } @@ -294,7 +397,7 @@ func TestReadSectorOffset(t *testing.T) { // download the sector cost, _ = pt.BaseCost().Add(pt.ReadOffsetCost(256)).Total() - downloaded, _, err := session.ReadOffset(rhp2.SectorSize*3+64, 256, revision.ID(), payment, cost) + downloaded, _, err := session.ReadOffset(crhp2.SectorSize*3+64, 256, revision.ID(), payment, cost) if err != nil { t.Fatal(err) } else if !bytes.Equal(downloaded, sectors[3][64:64+256]) { @@ -304,63 +407,66 @@ func TestReadSectorOffset(t *testing.T) { func TestRenew(t *testing.T) { log := zaptest.NewLogger(t) - renter, host, err := test.NewTestingPair(t.TempDir(), log) + hostKey, renterKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() + network, genesis := testutil.V1Network() + node := testutil.NewHostNode(t, hostKey, network, genesis, log) + + // fund the wallet + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) + + // start the node + sh2, sh3 := setupRHP3Host(t, node, hostKey, 10, log) + + // create a RHP3 session + session, err := proto3.NewSession(context.Background(), hostKey.PublicKey(), sh3.LocalAddr(), node.Chain, node.Wallet) if err != nil { t.Fatal(err) } - defer renter.Close() - defer host.Close() + defer session.Close() - t.Run("empty contract", func(t *testing.T) { - state := renter.TipState() - // form a contract - origin, err := renter.FormContract(context.Background(), host.RHP2Addr(), host.PublicKey(), types.Siacoins(10), types.Siacoins(20), state.Index.Height+200) - if err != nil { - t.Fatal(err) - } + account := crhp3.Account(renterKey.PublicKey()) - settings, err := renter.Settings(context.Background(), host.RHP2Addr(), host.PublicKey()) + assertContractUsage := func(t *testing.T, id types.FileContractID, lockedCollateral types.Currency, expected contracts.Usage) { + t.Helper() + + contract, err := node.Contracts.Contract(id) if err != nil { t.Fatal(err) + } else if !contract.LockedCollateral.Equals(lockedCollateral) { + t.Fatalf("expected locked collateral %v, got %v", lockedCollateral, contract.LockedCollateral) } - // mine a few blocks into the contract - if err := host.MineBlocks(host.WalletAddress(), 10); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) + uv := reflect.ValueOf(&contract.Usage).Elem() + ev := reflect.ValueOf(&expected).Elem() - session, err := renter.NewRHP3Session(context.Background(), host.RHP3Addr(), host.PublicKey()) - if err != nil { - t.Fatal(err) + for i := 0; i < uv.NumField(); i++ { + va := ev.Field(i).Interface().(types.Currency) + vb := uv.Field(i).Interface().(types.Currency) + if !va.Equals(vb) { + t.Fatalf("field %v: expected %v, got %v", uv.Type().Field(i).Name, va, vb) + } } - defer session.Close() + } - account := rhp3.Account(renter.PublicKey()) - payment := proto3.ContractPayment(&origin, renter.PrivateKey(), account) - // register a price table to use for the renewal - pt, err := session.RegisterPriceTable(payment) - if err != nil { - t.Fatal(err) - } + assertRenewal := func(t *testing.T, parentID types.FileContractID, renewal crhp2.ContractRevision, expectedRevisionNumber uint64, expectedRoot types.Hash256, expectedFilesize uint64) { + t.Helper() - state = renter.TipState() - renewHeight := origin.Revision.WindowEnd + 10 - renterFunds := types.Siacoins(10) - additionalCollateral := types.Siacoins(20) - renewal, _, err := session.RenewContract(&origin, settings.Address, renter.PrivateKey(), renterFunds, additionalCollateral, renewHeight) - if err != nil { - t.Fatal(err) + if renewal.Revision.FileMerkleRoot != expectedRoot { + t.Fatalf("expected root %v, got %v", expectedRoot, renewal.Revision.FileMerkleRoot) + } else if renewal.Revision.Filesize != expectedFilesize { + t.Fatalf("expected filesize %d, got %d", expectedFilesize, renewal.Revision.Filesize) + } else if renewal.Revision.RevisionNumber != expectedRevisionNumber { + t.Fatalf("expected revision number %d, got %d", expectedRevisionNumber, renewal.Revision.RevisionNumber) } - // mine a block to confirm the revision - if err := host.MineBlocks(host.WalletAddress(), 1); err != nil { - t.Fatal(err) + sigHash := rhp.HashRevision(renewal.Revision) + if !hostKey.PublicKey().VerifyHash(sigHash, types.Signature(renewal.Signatures[1].Signature)) { + t.Fatal("host signature invalid") + } else if !renterKey.PublicKey().VerifyHash(sigHash, types.Signature(renewal.Signatures[0].Signature)) { + t.Fatal("renter signature invalid") } - time.Sleep(100 * time.Millisecond) - expectedRevenue := pt.ContractPrice.Add(pt.UpdatePriceTableCost) - old, err := host.Contracts().Contract(origin.ID()) + old, err := node.Contracts.Contract(parentID) if err != nil { t.Fatal(err) } else if old.Revision.Filesize != 0 { @@ -369,274 +475,172 @@ func TestRenew(t *testing.T) { t.Fatal("merkle root mismatch") } else if old.RenewedTo != renewal.ID() { t.Fatal("renewed to mismatch") - } else if !old.Usage.RPCRevenue.Equals(expectedRevenue) { - t.Fatalf("expected old contract rpc revenue to equal %d, got %d", expectedRevenue, old.Usage.RPCRevenue) } - contract, err := host.Contracts().Contract(renewal.ID()) + renewed, err := node.Contracts.Contract(renewal.ID()) if err != nil { t.Fatal(err) - } else if contract.Revision.Filesize != origin.Revision.Filesize { + } else if renewed.Revision.Filesize != expectedFilesize { t.Fatal("filesize mismatch") - } else if contract.Revision.FileMerkleRoot != origin.Revision.FileMerkleRoot { + } else if renewed.Revision.FileMerkleRoot != expectedRoot { t.Fatal("merkle root mismatch") - } else if !contract.LockedCollateral.Equals(additionalCollateral) { - t.Fatalf("locked collateral mismatch: expected %d, got %d", additionalCollateral, contract.LockedCollateral) - } else if !contract.Usage.RiskedCollateral.IsZero() { - t.Fatalf("expected zero risked collateral, got %d", contract.Usage.RiskedCollateral) - } else if !contract.Usage.RPCRevenue.Equals(pt.ContractPrice) { - t.Fatalf("expected %d RPC revenue, got %d", settings.ContractPrice, contract.Usage.RPCRevenue) - } else if !contract.Usage.StorageRevenue.Equals(pt.RenewContractCost) { // renew contract cost is treated as storage revenue because it is burned - t.Fatalf("expected %d storage revenue, got %d", pt.RenewContractCost, contract.Usage.StorageRevenue) - } else if contract.RenewedFrom != origin.ID() { - t.Fatalf("expected renewed from %s, got %s", origin.ID(), contract.RenewedFrom) + } else if renewed.RenewedFrom != parentID { + t.Fatalf("expected renewed from %s, got %s", parentID, renewed.RenewedFrom) } - }) + } - t.Run("non-empty contract", func(t *testing.T) { + t.Run("empty contract", func(t *testing.T) { // form a contract - state := renter.TipState() - origin, err := renter.FormContract(context.Background(), host.RHP2Addr(), host.PublicKey(), types.Siacoins(10), types.Siacoins(20), state.Index.Height+200) - if err != nil { - t.Fatal(err) - } + origin := formContract(t, node.Chain, node.Wallet, sh2.LocalAddr(), renterKey, hostKey.PublicKey(), 200) - settings, err := renter.Settings(context.Background(), host.RHP2Addr(), host.PublicKey()) + testutil.MineAndSync(t, node, node.Wallet.Address(), 5) + + payment := proto3.ContractPayment(&origin, renterKey, account) + + // register a price table to use for the renewal + pt, err := session.RegisterPriceTable(payment) if err != nil { t.Fatal(err) } - session, err := renter.NewRHP3Session(context.Background(), host.RHP3Addr(), host.PublicKey()) + renewHeight := origin.Revision.WindowEnd + 10 + renterFunds := types.Siacoins(10) + additionalCollateral := types.Siacoins(20) + renewal, _, err := session.RenewContract(&origin, node.Wallet.Address(), renterKey, renterFunds, additionalCollateral, renewHeight) if err != nil { t.Fatal(err) } - defer session.Close() - account := rhp3.Account(renter.PublicKey()) - payment := proto3.ContractPayment(&origin, renter.PrivateKey(), account) + // mine a block to confirm the revision + testutil.MineAndSync(t, node, node.Wallet.Address(), 1) + + assertRenewal(t, origin.ID(), renewal, 1, origin.Revision.FileMerkleRoot, origin.Revision.Filesize) + assertContractUsage(t, renewal.ID(), additionalCollateral, contracts.Usage{ + RPCRevenue: pt.ContractPrice, + StorageRevenue: pt.RenewContractCost, // renew contract cost is included because it is burned on failure + }) + }) + + t.Run("drained contract", func(t *testing.T) { + // form a contract + origin := formContract(t, node.Chain, node.Wallet, sh2.LocalAddr(), renterKey, hostKey.PublicKey(), 200) + + testutil.MineAndSync(t, node, node.Wallet.Address(), 5) + + payment := proto3.ContractPayment(&origin, renterKey, account) + // register a price table to use for the renewal pt, err := session.RegisterPriceTable(payment) if err != nil { t.Fatal(err) } - // fund an account leaving no funds for the renewal - if _, err := session.FundAccount(account, payment, origin.Revision.ValidRenterPayout().Sub(pt.FundAccountCost)); err != nil { - t.Fatal(err) - } - - // generate a sector - var sector [rhp2.SectorSize]byte - frand.Read(sector[:256]) - - // calculate the remaining duration of the contract + // upload a sector var remainingDuration uint64 contractExpiration := uint64(origin.Revision.WindowEnd) - currentHeight := renter.TipState().Index.Height + currentHeight := node.Chain.Tip().Height if contractExpiration < currentHeight { t.Fatal("contract expired") } - payment = proto3.AccountPayment(account, renter.PrivateKey()) + // generate a sector + var sector [crhp2.SectorSize]byte + frand.Read(sector[:256]) - // upload the sector remainingDuration = contractExpiration - currentHeight usage := pt.BaseCost().Add(pt.AppendSectorCost(remainingDuration)) cost, _ := usage.Total() - if _, err := session.AppendSector(§or, &origin, renter.PrivateKey(), payment, cost); err != nil { + if _, err := session.AppendSector(§or, &origin, renterKey, payment, cost); err != nil { t.Fatal(err) } - // mine a few blocks into the contract - if err := host.MineBlocks(host.WalletAddress(), 10); err != nil { + // fund the account leaving no funds for the renewal + if _, err := session.FundAccount(account, payment, origin.Revision.ValidRenterPayout().Sub(pt.FundAccountCost)); err != nil { t.Fatal(err) } - time.Sleep(100 * time.Millisecond) - state = renter.TipState() + // mine a few blocks into the contract + testutil.MineAndSync(t, node, node.Wallet.Address(), 5) + renewHeight := origin.Revision.WindowEnd + 10 renterFunds := types.Siacoins(10) additionalCollateral := types.Siacoins(20) - renewal, _, err := session.RenewContract(&origin, settings.Address, renter.PrivateKey(), renterFunds, additionalCollateral, renewHeight) + renewal, _, err := session.RenewContract(&origin, node.Wallet.Address(), renterKey, renterFunds, additionalCollateral, renewHeight) if err != nil { t.Fatal(err) } extension := renewal.Revision.WindowEnd - origin.Revision.WindowEnd - baseStorageRevenue := pt.RenewContractCost.Add(pt.WriteStoreCost.Mul64(origin.Revision.Filesize).Mul64(extension)) // renew contract cost is included because it is burned on failure - baseRiskedCollateral := settings.Collateral.Mul64(extension).Mul64(origin.Revision.Filesize) - - expectedExchange := pt.ContractPrice.Add(pt.FundAccountCost).Add(pt.UpdatePriceTableCost).Add(usage.Base) - old, err := host.Contracts().Contract(origin.ID()) - if err != nil { - t.Fatal(err) - } else if old.Revision.Filesize != 0 { - t.Fatal("filesize mismatch") - } else if old.Revision.FileMerkleRoot != (types.Hash256{}) { - t.Fatal("merkle root mismatch") - } else if old.RenewedTo != renewal.ID() { - t.Fatal("renewed to mismatch") - } else if !old.Usage.RPCRevenue.Equals(expectedExchange) { // renewal renew goes on the new contract - t.Fatalf("expected rpc revenue to equal contract price + fund account cost %d, got %d", expectedExchange, old.Usage.RPCRevenue) - } - - contract, err := host.Contracts().Contract(renewal.ID()) - if err != nil { - t.Fatal(err) - } else if contract.Revision.Filesize != origin.Revision.Filesize { - t.Fatal("filesize mismatch") - } else if contract.Revision.FileMerkleRoot != origin.Revision.FileMerkleRoot { - t.Fatal("merkle root mismatch") - } else if contract.LockedCollateral.Cmp(additionalCollateral) <= 0 { - t.Fatalf("locked collateral mismatch: expected at least %d, got %d", additionalCollateral, contract.LockedCollateral) - } else if !contract.Usage.RPCRevenue.Equals(pt.ContractPrice) { - t.Fatalf("expected %d RPC revenue, got %d", pt.ContractPrice, contract.Usage.RPCRevenue) - } else if !contract.Usage.RiskedCollateral.Equals(baseRiskedCollateral) { - t.Fatalf("expected %d risked collateral, got %d", baseRiskedCollateral, contract.Usage.RiskedCollateral) - } else if !contract.Usage.StorageRevenue.Equals(baseStorageRevenue) { - t.Fatalf("expected %d storage revenue, got %d", baseStorageRevenue, contract.Usage.StorageRevenue) - } else if contract.RenewedFrom != origin.ID() { - t.Fatalf("expected renewed from %s, got %s", origin.ID(), contract.RenewedFrom) - } + baseRiskedCollateral := pt.CollateralCost.Mul64(extension).Mul64(origin.Revision.Filesize) + assertContractUsage(t, renewal.Revision.ParentID, additionalCollateral.Add(baseRiskedCollateral), contracts.Usage{ + RPCRevenue: pt.ContractPrice, + RiskedCollateral: baseRiskedCollateral, + StorageRevenue: pt.RenewContractCost.Add(pt.WriteStoreCost.Mul64(origin.Revision.Filesize).Mul64(extension)), // renew contract cost is included because it is burned on failure + }) + assertRenewal(t, origin.ID(), renewal, 1, origin.Revision.FileMerkleRoot, origin.Revision.Filesize) }) } -func BenchmarkAppendSector(b *testing.B) { - log := zaptest.NewLogger(b) - renter, host, err := test.NewTestingPair(b.TempDir(), log) - if err != nil { - b.Fatal(err) - } - defer renter.Close() - defer host.Close() - - session, err := renter.NewRHP3Session(context.Background(), host.RHP3Addr(), host.PublicKey()) - if err != nil { - b.Fatal(err) - } - defer session.Close() - - revision, err := renter.FormContract(context.Background(), host.RHP2Addr(), host.PublicKey(), types.Siacoins(50), types.Siacoins(100), 200) - if err != nil { - b.Fatal(err) - } - - account := rhp3.Account(renter.PublicKey()) - // register the price table - payment := proto3.ContractPayment(&revision, renter.PrivateKey(), account) - pt, err := session.RegisterPriceTable(payment) - if err != nil { - b.Fatal(err) - } - - // fund an account - if _, err = session.FundAccount(account, payment, types.Siacoins(10)); err != nil { - b.Fatal(err) - } - - // upload a sector - payment = proto3.AccountPayment(account, renter.PrivateKey()) - // calculate the cost of the upload - cost, _ := pt.BaseCost().Add(pt.AppendSectorCost(revision.Revision.WindowEnd - renter.TipState().Index.Height)).Total() - if cost.IsZero() { - b.Fatal("cost is zero") - } - - var sectors [][rhp2.SectorSize]byte - for i := 0; i < b.N; i++ { - var sector [rhp2.SectorSize]byte - frand.Read(sector[:256]) - sectors = append(sectors, sector) - } - - b.ResetTimer() - b.ReportAllocs() - b.SetBytes(rhp2.SectorSize) - - for i := 0; i < b.N; i++ { - _, err = session.AppendSector(§ors[i], &revision, renter.PrivateKey(), payment, cost) - if err != nil { - b.Fatal(err) - } - } -} - -func BenchmarkReadSector(b *testing.B) { - log := zaptest.NewLogger(b) - renter, host, err := test.NewTestingPair(b.TempDir(), log) - if err != nil { - b.Fatal(err) - } - defer renter.Close() - defer host.Close() - - s := settings.DefaultSettings - s.MaxAccountBalance = types.Siacoins(100) - s.MaxCollateral = types.Siacoins(10000) - s.EgressPrice = types.ZeroCurrency - s.IngressPrice = types.ZeroCurrency +func TestRPCV2(t *testing.T) { + log := zaptest.NewLogger(t) + renterKey, hostKey := types.GeneratePrivateKey(), types.GeneratePrivateKey() + network, genesis := testutil.V2Network() + network.HardforkV2.AllowHeight = 180 + network.HardforkV2.RequireHeight = 200 + node := testutil.NewHostNode(t, hostKey, network, genesis, log) + + // set the host to accept contracts + s := node.Settings.Settings() s.AcceptingContracts = true - if err := host.UpdateSettings(s); err != nil { - b.Fatal(err) + s.NetAddress = "localhost:9983" + if err := node.Settings.UpdateSettings(s); err != nil { + t.Fatal(err) } - if err := host.AddVolume(filepath.Join(b.TempDir(), "data.dat"), uint64(b.N)); err != nil { - b.Fatal(err) + // initialize a storage volume + res := make(chan error) + if _, err := node.Volumes.AddVolume(context.Background(), filepath.Join(t.TempDir(), "storage.dat"), 10, res); err != nil { + t.Fatal(err) + } else if err := <-res; err != nil { + t.Fatal(err) } - session, err := renter.NewRHP3Session(context.Background(), host.RHP3Addr(), host.PublicKey()) - if err != nil { - b.Fatal(err) - } - defer session.Close() + // fund the wallet + testutil.MineAndSync(t, node, node.Wallet.Address(), 150) - revision, err := renter.FormContract(context.Background(), host.RHP2Addr(), host.PublicKey(), types.Siacoins(500), types.Siacoins(1000), 200) - if err != nil { - b.Fatal(err) - } + // start the node + sh2, sh3 := setupRHP3Host(t, node, hostKey, 10, log) - account := rhp3.Account(renter.PublicKey()) - // register the price table - payment := proto3.ContractPayment(&revision, renter.PrivateKey(), account) - pt, err := session.RegisterPriceTable(payment) - if err != nil { - b.Fatal(err) - } + // form a contract that expires before the hardfork + origin := formContract(t, node.Chain, node.Wallet, sh2.LocalAddr(), renterKey, hostKey.PublicKey(), 20) + // mine a block to confirm the contract + testutil.MineAndSync(t, node, node.Wallet.Address(), 1) - // fund an account - _, err = session.FundAccount(account, payment, types.Siacoins(100)) + // create a RHP3 session + session, err := proto3.NewSession(context.Background(), hostKey.PublicKey(), sh3.LocalAddr(), node.Chain, node.Wallet) if err != nil { - b.Fatal(err) - } - - // upload a sector - payment = proto3.AccountPayment(account, renter.PrivateKey()) - // calculate the cost of the upload - cost, _ := pt.BaseCost().Add(pt.AppendSectorCost(revision.Revision.WindowEnd - renter.TipState().Index.Height)).Total() - if cost.IsZero() { - b.Fatal("cost is zero") - } - - var roots []types.Hash256 - for i := 0; i < b.N; i++ { - var sector [rhp2.SectorSize]byte - frand.Read(sector[:256]) - - _, err = session.AppendSector(§or, &revision, renter.PrivateKey(), payment, cost) - if err != nil { - b.Fatal(err) - } - roots = append(roots, rhp2.SectorRoot(§or)) + t.Fatal(err) } + defer session.Close() - b.ResetTimer() - b.ReportAllocs() - b.SetBytes(rhp2.SectorSize) - - for i := 0; i < b.N; i++ { - _, _, err = session.ReadSector(roots[i], 0, rhp2.SectorSize, payment, cost) - if err != nil { - b.Fatal(err) - } + // try to renew the contract with an ending after the hardfork + renewHeight := network.HardforkV2.RequireHeight + renterFunds := types.Siacoins(10) + additionalCollateral := types.Siacoins(20) + _, _, err = session.RenewContract(&origin, node.Wallet.Address(), renterKey, renterFunds, additionalCollateral, renewHeight) + if !errors.Is(err, rhp3.ErrAfterV2Hardfork) { + t.Fatalf("expected after v2 hardfork error, got %v", err) + } + + // mine to activate the v2 hardfork + testutil.MineAndSync(t, node, node.Wallet.Address(), int(network.HardforkV2.AllowHeight-node.Chain.Tip().Height)) + + // try to renew the contract with an end height before the require height, but after the hardfork activation + renewHeight = origin.Revision.WindowEnd + 10 + renterFunds = types.Siacoins(10) + additionalCollateral = types.Siacoins(20) + _, _, err = session.RenewContract(&origin, node.Wallet.Address(), renterKey, renterFunds, additionalCollateral, renewHeight) + if !errors.Is(err, rhp3.ErrV2Hardfork) { + t.Fatalf("expected v2 hardfork error, got %v", err) } } diff --git a/rhp/v3/websocket_test.go b/rhp/v3/websocket_test.go deleted file mode 100644 index e257bfe7..00000000 --- a/rhp/v3/websocket_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package rhp_test - -import ( - "context" - "encoding/json" - "testing" - - rhp3 "go.sia.tech/core/rhp/v3" - "go.sia.tech/hostd/internal/test" - "go.uber.org/goleak" - "go.uber.org/zap/zaptest" - "nhooyr.io/websocket" -) - -func TestMain(m *testing.M) { - goleak.VerifyTestMain(m) -} - -func TestWebSockets(t *testing.T) { - log := zaptest.NewLogger(t) - renter, host, err := test.NewTestingPair(t.TempDir(), log) - if err != nil { - t.Fatal(err) - } - defer renter.Close() - defer host.Close() - - c, _, err := websocket.Dial(context.Background(), "ws://"+host.RHP3WSAddr()+"/ws", nil) - if err != nil { - t.Fatal(err) - } - defer c.Close(websocket.StatusNormalClosure, "") - - conn := websocket.NetConn(context.Background(), c, websocket.MessageBinary) - transport, err := rhp3.NewRenterTransport(conn, host.PublicKey()) - if err != nil { - t.Fatal(err) - } - defer transport.Close() - - stream := transport.DialStream() - defer stream.Close() - - if err := stream.WriteRequest(rhp3.RPCUpdatePriceTableID, nil); err != nil { - t.Fatal(err) - } - var resp rhp3.RPCUpdatePriceTableResponse - if err := stream.ReadResponse(&resp, 4096); err != nil { - t.Fatal(err) - } - var pt rhp3.HostPriceTable - if err := json.Unmarshal(resp.PriceTableJSON, &pt); err != nil { - t.Fatal(err) - } -} diff --git a/wallet/persist.go b/wallet/persist.go deleted file mode 100644 index c881278b..00000000 --- a/wallet/persist.go +++ /dev/null @@ -1,46 +0,0 @@ -package wallet - -import ( - "time" - - "go.sia.tech/core/types" - "go.sia.tech/siad/modules" -) - -type ( - // An UpdateTransaction atomically updates the wallet store - UpdateTransaction interface { - AddSiacoinElement(SiacoinElement) error - RemoveSiacoinElement(types.SiacoinOutputID) error - AddTransaction(Transaction) error - RevertBlock(types.BlockID) error - - AddWalletDelta(value types.Currency, timestamp time.Time) error - SubWalletDelta(value types.Currency, timestamp time.Time) error - } - - // A SingleAddressStore stores the state of a single-address wallet. - // Implementations are assumed to be thread safe. - SingleAddressStore interface { - // LastWalletChange returns the consensus change ID and block height of - // the last wallet change. - LastWalletChange() (id modules.ConsensusChangeID, height uint64, err error) - // UnspentSiacoinElements returns a list of all unspent siacoin outputs - UnspentSiacoinElements() ([]SiacoinElement, error) - // Transactions returns a paginated list of transactions ordered by - // block height, descending. If no more transactions are available, - // (nil, nil) should be returned. - Transactions(limit, offset int) ([]Transaction, error) - // TransactionCount returns the total number of transactions in the - // wallet. - TransactionCount() (uint64, error) - UpdateWallet(ccID modules.ConsensusChangeID, height uint64, fn func(UpdateTransaction) error) error - // ResetWallet resets the wallet to its initial state. This is used when a - // consensus subscription error occurs. - ResetWallet(seedHash types.Hash256) error - // VerifyWalletKey checks that the wallet seed matches the existing seed - // hash. This detects if the user's recovery phrase has changed and the - // wallet needs to rescan. - VerifyWalletKey(seedHash types.Hash256) error - } -) diff --git a/wallet/wallet.go b/wallet/wallet.go deleted file mode 100644 index c2ad84fb..00000000 --- a/wallet/wallet.go +++ /dev/null @@ -1,805 +0,0 @@ -package wallet - -import ( - "bytes" - "context" - "errors" - "fmt" - "sort" - "strings" - "sync" - "sync/atomic" - "time" - - "gitlab.com/NebulousLabs/encoding" - "go.sia.tech/core/consensus" - "go.sia.tech/core/types" - "go.sia.tech/hostd/internal/chain" - "go.sia.tech/hostd/internal/threadgroup" - "go.sia.tech/siad/modules" - stypes "go.sia.tech/siad/types" - "go.uber.org/zap" -) - -const ( - // transactionDefragThreshold is the number of utxos at which the wallet - // will attempt to defrag itself by including small utxos in transactions. - transactionDefragThreshold = 30 - // maxInputsForDefrag is the maximum number of inputs a transaction can - // have before the wallet will stop adding inputs - maxInputsForDefrag = 30 - // maxDefragUTXOs is the maximum number of utxos that will be added to a - // transaction when defragging - maxDefragUTXOs = 10 -) - -// transaction sources indicate the source of a transaction. Transactions can -// either be created by sending Siacoins between unlock hashes or they can be -// created by consensus (e.g. a miner payout, a siafund claim, or a contract). -const ( - TxnSourceTransaction TransactionSource = "transaction" - TxnSourceMinerPayout TransactionSource = "miner" - TxnSourceSiafundClaim TransactionSource = "siafundClaim" - TxnSourceContract TransactionSource = "contract" - TxnSourceFoundationPayout TransactionSource = "foundation" -) - -var ( - // ErrNotEnoughFunds is returned when there are not enough unspent outputs - // to fund a transaction. - ErrNotEnoughFunds = errors.New("not enough funds") -) - -type ( - // A TransactionSource is a string indicating the source of a transaction. - TransactionSource string - - // A ChainManager manages the current state of the blockchain. - ChainManager interface { - TipState() consensus.State - BlockAtHeight(height uint64) (types.Block, bool) - PoolTransactions() []types.Transaction - Subscribe(subscriber modules.ConsensusSetSubscriber, ccID modules.ConsensusChangeID, cancel <-chan struct{}) error - } - - // A SiacoinElement is a SiacoinOutput along with its ID. - SiacoinElement struct { - types.SiacoinOutput - ID types.SiacoinOutputID - } - - // A Transaction is an on-chain transaction relevant to a particular wallet, - // paired with useful metadata. - Transaction struct { - ID types.TransactionID `json:"id"` - Index types.ChainIndex `json:"index"` - Transaction types.Transaction `json:"transaction"` - Inflow types.Currency `json:"inflow"` - Outflow types.Currency `json:"outflow"` - Source TransactionSource `json:"source"` - Timestamp time.Time `json:"timestamp"` - } - - // A SingleAddressWallet is a hot wallet that manages the outputs controlled by - // a single address. - SingleAddressWallet struct { - scanHeight uint64 // ensure 64-bit alignment on 32-bit systems - - priv types.PrivateKey - addr types.Address - - cm ChainManager - store SingleAddressStore - log *zap.Logger - tg *threadgroup.ThreadGroup - - mu sync.Mutex - // locked is a set of siacoin output IDs locked by FundTransaction. They - // will be released either by explicitly calling release for unused - // transactions or expiring after 3 hours. - locked map[types.SiacoinOutputID]time.Time - } -) - -// ErrDifferentSeed is returned when a different seed is provided to -// NewSingleAddressWallet than was used to initialize the wallet -var ErrDifferentSeed = errors.New("seed differs from wallet seed") - -// EncodeTo implements types.EncoderTo. -func (txn Transaction) EncodeTo(e *types.Encoder) { - txn.ID.EncodeTo(e) - txn.Index.EncodeTo(e) - txn.Transaction.EncodeTo(e) - (*types.V1Currency)(&txn.Inflow).EncodeTo(e) - (*types.V1Currency)(&txn.Outflow).EncodeTo(e) - e.WriteString(string(txn.Source)) - e.WriteTime(txn.Timestamp) -} - -// DecodeFrom implements types.DecoderFrom. -func (txn *Transaction) DecodeFrom(d *types.Decoder) { - txn.ID.DecodeFrom(d) - txn.Index.DecodeFrom(d) - txn.Transaction.DecodeFrom(d) - (*types.V1Currency)(&txn.Inflow).DecodeFrom(d) - (*types.V1Currency)(&txn.Outflow).DecodeFrom(d) - txn.Source = TransactionSource(d.ReadString()) - txn.Timestamp = d.ReadTime() -} - -func transactionIsRelevant(txn types.Transaction, addr types.Address) bool { - for i := range txn.SiacoinInputs { - if txn.SiacoinInputs[i].UnlockConditions.UnlockHash() == addr { - return true - } - } - for i := range txn.SiacoinOutputs { - if txn.SiacoinOutputs[i].Address == addr { - return true - } - } - for i := range txn.SiafundInputs { - if txn.SiafundInputs[i].UnlockConditions.UnlockHash() == addr { - return true - } - if txn.SiafundInputs[i].ClaimAddress == addr { - return true - } - } - for i := range txn.SiafundOutputs { - if txn.SiafundOutputs[i].Address == addr { - return true - } - } - for i := range txn.FileContracts { - for _, sco := range txn.FileContracts[i].ValidProofOutputs { - if sco.Address == addr { - return true - } - } - for _, sco := range txn.FileContracts[i].MissedProofOutputs { - if sco.Address == addr { - return true - } - } - } - for i := range txn.FileContractRevisions { - for _, sco := range txn.FileContractRevisions[i].ValidProofOutputs { - if sco.Address == addr { - return true - } - } - for _, sco := range txn.FileContractRevisions[i].MissedProofOutputs { - if sco.Address == addr { - return true - } - } - } - return false -} - -// isLocked returns whether an output is currently locked by FundTransaction. -func (sw *SingleAddressWallet) isLocked(id types.SiacoinOutputID) bool { - return sw.locked[id].After(time.Now()) -} - -func (sw *SingleAddressWallet) tpoolRelevant(unspentElements []SiacoinElement) (relevant []Transaction, created map[types.SiacoinOutputID]types.SiacoinElement, spent map[types.SiacoinOutputID]bool) { - txns := sw.cm.PoolTransactions() - created = make(map[types.SiacoinOutputID]types.SiacoinElement) - spent = make(map[types.SiacoinOutputID]bool) - - utxos := make(map[types.SiacoinOutputID]types.SiacoinElement, len(unspentElements)) - for _, sce := range unspentElements { - utxos[sce.ID] = types.SiacoinElement{ - StateElement: types.StateElement{ - ID: types.Hash256(sce.ID), - }, - SiacoinOutput: sce.SiacoinOutput, - } - } - - timestamp := time.Now() - for _, txn := range txns { - var inflow, outflow types.Currency - for _, sci := range txn.SiacoinInputs { - if sci.UnlockConditions.UnlockHash() == sw.addr { - spent[types.SiacoinOutputID(sci.ParentID)] = true - outflow = outflow.Add(utxos[types.SiacoinOutputID(sci.ParentID)].SiacoinOutput.Value) - delete(created, sci.ParentID) - } - } - for i, sco := range txn.SiacoinOutputs { - if sco.Address == sw.addr { - outputID := txn.SiacoinOutputID(i) - utxos[outputID] = types.SiacoinElement{ - StateElement: types.StateElement{ - ID: types.Hash256(outputID), - }, - SiacoinOutput: sco, - } - created[outputID] = utxos[outputID] - inflow = inflow.Add(sco.Value) - } - } - - if !transactionIsRelevant(txn, sw.addr) { - continue - } - relevant = append(relevant, Transaction{ - ID: txn.ID(), - Transaction: txn, - Inflow: inflow, - Outflow: outflow, - Source: TxnSourceTransaction, - Timestamp: timestamp, - }) - } - return -} - -// Close closes the wallet -func (sw *SingleAddressWallet) Close() error { - sw.tg.Stop() - return nil -} - -// Address returns the address of the wallet. -func (sw *SingleAddressWallet) Address() types.Address { - return sw.addr -} - -// UnlockConditions returns the unlock conditions of the wallet. -func (sw *SingleAddressWallet) UnlockConditions() types.UnlockConditions { - return types.StandardUnlockConditions(sw.priv.PublicKey()) -} - -// Balance returns the balance of the wallet. -func (sw *SingleAddressWallet) Balance() (spendable, confirmed, unconfirmed types.Currency, err error) { - done, err := sw.tg.Add() - if err != nil { - return types.ZeroCurrency, types.ZeroCurrency, types.ZeroCurrency, err - } - defer done() - - outputs, err := sw.store.UnspentSiacoinElements() - if err != nil { - return types.ZeroCurrency, types.ZeroCurrency, types.ZeroCurrency, fmt.Errorf("failed to get unspent outputs: %w", err) - } - sw.mu.Lock() - defer sw.mu.Unlock() - - _, tpoolUTXOs, spent := sw.tpoolRelevant(outputs) - - for _, sco := range outputs { - confirmed = confirmed.Add(sco.Value) - if !sw.isLocked(sco.ID) && !spent[sco.ID] { - spendable = spendable.Add(sco.Value) - } - } - - for _, sco := range tpoolUTXOs { - if spent[types.SiacoinOutputID(sco.ID)] { - continue - } - unconfirmed = unconfirmed.Add(sco.SiacoinOutput.Value) - } - return -} - -// Transactions returns a paginated list of transactions, ordered by block -// height descending. If no more transactions are available, (nil, nil) is -// returned. -func (sw *SingleAddressWallet) Transactions(limit, offset int) ([]Transaction, error) { - done, err := sw.tg.Add() - if err != nil { - return nil, err - } - defer done() - return sw.store.Transactions(limit, offset) -} - -// TransactionCount returns the total number of transactions in the wallet. -func (sw *SingleAddressWallet) TransactionCount() (uint64, error) { - done, err := sw.tg.Add() - if err != nil { - return 0, err - } - defer done() - return sw.store.TransactionCount() -} - -// FundTransaction adds siacoin inputs worth at least amount to the provided -// transaction. If necessary, a change output will also be added. The inputs -// will not be available to future calls to FundTransaction unless ReleaseInputs -// is called. -func (sw *SingleAddressWallet) FundTransaction(txn *types.Transaction, amount types.Currency) ([]types.Hash256, func(), error) { - done, err := sw.tg.Add() - if err != nil { - return nil, nil, err - } - defer done() - - if amount.IsZero() { - return nil, func() {}, nil - } - - sw.mu.Lock() - defer sw.mu.Unlock() - - utxos, err := sw.store.UnspentSiacoinElements() - if err != nil { - return nil, nil, err - } - - _, _, tpoolSpent := sw.tpoolRelevant(utxos) - - // remove locked and spent outputs - usableUTXOs := utxos[:0] - for _, sce := range utxos { - if sw.isLocked(sce.ID) || tpoolSpent[types.SiacoinOutputID(sce.ID)] { - continue - } - usableUTXOs = append(usableUTXOs, sce) - } - - // sort by value, descending - sort.Slice(usableUTXOs, func(i, j int) bool { - return usableUTXOs[i].Value.Cmp(usableUTXOs[j].Value) > 0 - }) - - // fund the transaction using the largest utxos first - var selected []SiacoinElement - var inputSum types.Currency - for i, sce := range usableUTXOs { - if inputSum.Cmp(amount) >= 0 { - usableUTXOs = usableUTXOs[i:] - break - } - selected = append(selected, sce) - inputSum = inputSum.Add(sce.Value) - } - - // if the transaction can't be funded, return an error - if inputSum.Cmp(amount) < 0 { - return nil, nil, ErrNotEnoughFunds - } - - // check if remaining utxos should be defragged - txnInputs := len(txn.SiacoinInputs) + len(selected) - if len(usableUTXOs) > transactionDefragThreshold && txnInputs < maxInputsForDefrag { - // add the smallest utxos to the transaction - defraggable := usableUTXOs - if len(defraggable) > maxDefragUTXOs { - defraggable = defraggable[len(defraggable)-maxDefragUTXOs:] - } - for i := len(defraggable) - 1; i >= 0; i-- { - if txnInputs >= maxInputsForDefrag { - break - } - - sce := defraggable[i] - selected = append(selected, sce) - inputSum = inputSum.Add(sce.Value) - txnInputs++ - } - } - - // add a change output if necessary - if inputSum.Cmp(amount) > 0 { - txn.SiacoinOutputs = append(txn.SiacoinOutputs, types.SiacoinOutput{ - Value: inputSum.Sub(amount), - Address: sw.addr, - }) - } - - toSign := make([]types.Hash256, len(selected)) - for i, sce := range selected { - txn.SiacoinInputs = append(txn.SiacoinInputs, types.SiacoinInput{ - ParentID: types.SiacoinOutputID(sce.ID), - UnlockConditions: types.StandardUnlockConditions(sw.priv.PublicKey()), - }) - toSign[i] = types.Hash256(sce.ID) - sw.locked[sce.ID] = time.Now().Add(3 * time.Hour) - } - - release := func() { - sw.mu.Lock() - defer sw.mu.Unlock() - for _, id := range toSign { - delete(sw.locked, types.SiacoinOutputID(id)) - } - } - return toSign, release, nil -} - -// SignTransaction adds a signature to each of the specified inputs. -func (sw *SingleAddressWallet) SignTransaction(cs consensus.State, txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) error { - done, err := sw.tg.Add() - if err != nil { - return err - } - defer done() - - for _, id := range toSign { - var h types.Hash256 - if cf.WholeTransaction { - h = cs.WholeSigHash(*txn, id, 0, 0, cf.Signatures) - } else { - h = cs.PartialSigHash(*txn, cf) - } - sig := sw.priv.SignHash(h) - txn.Signatures = append(txn.Signatures, types.TransactionSignature{ - ParentID: id, - CoveredFields: cf, - PublicKeyIndex: 0, - Signature: sig[:], - }) - } - return nil -} - -// ScanHeight returns the block height the wallet has scanned to. -func (sw *SingleAddressWallet) ScanHeight() uint64 { - return atomic.LoadUint64(&sw.scanHeight) -} - -// UnconfirmedTransactions returns all unconfirmed transactions relevant to the -// wallet. -func (sw *SingleAddressWallet) UnconfirmedTransactions() ([]Transaction, error) { - sw.mu.Lock() - defer sw.mu.Unlock() - - utxos, err := sw.store.UnspentSiacoinElements() - if err != nil { - return nil, fmt.Errorf("failed to get unspent outputs: %w", err) - } - - relevant, _, _ := sw.tpoolRelevant(utxos) - return relevant, nil -} - -// ProcessConsensusChange implements modules.ConsensusSetSubscriber. -func (sw *SingleAddressWallet) ProcessConsensusChange(cc modules.ConsensusChange) { - done, err := sw.tg.Add() - if err != nil { - return - } - defer done() - - sw.log.Debug("processing consensus change", zap.Int("applied", len(cc.AppliedBlocks)), zap.Int("reverted", len(cc.RevertedBlocks))) - start := time.Now() - - // create payout transactions for each matured siacoin output. Each diff - // should correspond to an applied block. This is done outside of the - // database transaction to reduce lock contention. - appliedPayoutTxns := make([][]Transaction, len(cc.AppliedDiffs)) - // calculate the block height of the first applied diff - blockHeight := uint64(cc.BlockHeight) - uint64(len(cc.AppliedBlocks)) + 1 - for i := 0; i < len(cc.AppliedDiffs); i, blockHeight = i+1, blockHeight+1 { - var block types.Block - convertToCore(cc.AppliedBlocks[i], (*types.V1Block)(&block)) - - diff := cc.AppliedDiffs[i] - index := types.ChainIndex{ - ID: block.ID(), - Height: blockHeight, - } - // determine the source of each delayed output - delayedOutputSources := make(map[types.SiacoinOutputID]TransactionSource) - if blockHeight > uint64(stypes.MaturityDelay) { - matureHeight := blockHeight - uint64(stypes.MaturityDelay) - // get the block that has matured - matureBlock, ok := sw.cm.BlockAtHeight(matureHeight) - if !ok { - sw.log.Error("failed to get matured block", zap.Uint64("height", blockHeight), zap.Uint64("maturedHeight", matureHeight)) - sw.Close() - return - } - matureID := matureBlock.ID() - delayedOutputSources[matureID.FoundationOutputID()] = TxnSourceFoundationPayout - for i := range matureBlock.MinerPayouts { - delayedOutputSources[matureID.MinerOutputID(i)] = TxnSourceMinerPayout - } - for _, txn := range matureBlock.Transactions { - for _, output := range txn.SiafundInputs { - delayedOutputSources[output.ParentID.ClaimOutputID()] = TxnSourceSiafundClaim - } - } - } - - for _, dsco := range diff.DelayedSiacoinOutputDiffs { - // if a delayed output is reverted in an applied diff, the - // output has matured -- add a payout transaction. - if types.Address(dsco.SiacoinOutput.UnlockHash) != sw.addr || dsco.Direction != modules.DiffRevert { - continue - } - // contract payouts are harder to identify, any unknown output - // ID is assumed to be a contract payout. - var source TransactionSource - if s, ok := delayedOutputSources[types.SiacoinOutputID(dsco.ID)]; ok { - source = s - } else { - source = TxnSourceContract - } - // append the payout transaction to the diff - var utxo types.SiacoinOutput - convertToCore(dsco.SiacoinOutput, (*types.V1SiacoinOutput)(&utxo)) - sce := SiacoinElement{ - ID: types.SiacoinOutputID(dsco.ID), - SiacoinOutput: utxo, - } - appliedPayoutTxns[i] = append(appliedPayoutTxns[i], payoutTransaction(sce, index, source, block.Timestamp)) - } - } - - spentOutputs := make(map[types.SiacoinOutputID]types.SiacoinOutput) - for _, applied := range cc.AppliedDiffs { - for _, diff := range applied.SiacoinOutputDiffs { - if diff.Direction == modules.DiffRevert { - var so types.SiacoinOutput - convertToCore(diff.SiacoinOutput, (*types.V1SiacoinOutput)(&so)) - spentOutputs[types.SiacoinOutputID(diff.ID)] = so - } - } - } - - // begin a database transaction to update the wallet state - err = sw.store.UpdateWallet(cc.ID, uint64(cc.BlockHeight), func(tx UpdateTransaction) error { - // add new siacoin outputs and remove spent or reverted siacoin outputs - for _, diff := range cc.SiacoinOutputDiffs { - if types.Address(diff.SiacoinOutput.UnlockHash) != sw.addr { - continue - } - if diff.Direction == modules.DiffApply { - var sco types.SiacoinOutput - convertToCore(diff.SiacoinOutput, (*types.V1SiacoinOutput)(&sco)) - err := tx.AddSiacoinElement(SiacoinElement{ - SiacoinOutput: sco, - ID: types.SiacoinOutputID(diff.ID), - }) - sw.log.Debug("added utxo", zap.String("id", diff.ID.String()), zap.String("value", sco.Value.ExactString()), zap.String("address", sco.Address.String())) - if err != nil { - return fmt.Errorf("failed to add siacoin element %v: %w", diff.ID, err) - } - } else { - err := tx.RemoveSiacoinElement(types.SiacoinOutputID(diff.ID)) - if err != nil { - return fmt.Errorf("failed to remove siacoin element %v: %w", diff.ID, err) - } - sw.log.Debug("removed utxo", zap.String("id", diff.ID.String()), zap.String("value", diff.SiacoinOutput.Value.String()), zap.String("address", diff.SiacoinOutput.UnlockHash.String())) - } - } - - // revert blocks -- will also revert all transactions and payout transactions - for _, reverted := range cc.RevertedBlocks { - blockID := types.BlockID(reverted.ID()) - if err := tx.RevertBlock(blockID); err != nil { - return fmt.Errorf("failed to revert block %v: %w", blockID, err) - } - } - - for i, diff := range cc.RevertedDiffs { - blockTimestamp := time.Unix(int64(cc.RevertedBlocks[i].Timestamp), 0) - for _, sco := range diff.SiacoinOutputDiffs { - var addr types.Address - copy(addr[:], sco.SiacoinOutput.UnlockHash[:]) - if addr != sw.addr { - continue - } - - var value types.Currency - convertToCore(sco.SiacoinOutput.Value, (*types.V1Currency)(&value)) - switch sco.Direction { - case modules.DiffApply: - if err := tx.AddWalletDelta(value, blockTimestamp); err != nil { - return fmt.Errorf("failed to add wallet delta: %w", err) - } - case modules.DiffRevert: - if err := tx.SubWalletDelta(value, blockTimestamp); err != nil { - return fmt.Errorf("failed to sub wallet delta: %w", err) - } - } - } - } - - for i, diff := range cc.AppliedDiffs { - blockTimestamp := time.Unix(int64(cc.AppliedBlocks[i].Timestamp), 0) - for _, sco := range diff.SiacoinOutputDiffs { - var addr types.Address - copy(addr[:], sco.SiacoinOutput.UnlockHash[:]) - if addr != sw.addr { - continue - } - - var value types.Currency - convertToCore(sco.SiacoinOutput.Value, (*types.V1Currency)(&value)) - switch sco.Direction { - case modules.DiffApply: - if err := tx.AddWalletDelta(value, blockTimestamp); err != nil { - return fmt.Errorf("failed to add wallet delta: %w", err) - } - case modules.DiffRevert: - if err := tx.SubWalletDelta(value, blockTimestamp); err != nil { - return fmt.Errorf("failed to sub wallet delta: %w", err) - } - } - } - } - - // calculate the block height of the first applied block - blockHeight = uint64(cc.BlockHeight) - uint64(len(cc.AppliedBlocks)) + 1 - // apply transactions - for i := 0; i < len(cc.AppliedBlocks); i, blockHeight = i+1, blockHeight+1 { - var block types.Block - convertToCore(cc.AppliedBlocks[i], (*types.V1Block)(&block)) - - index := types.ChainIndex{ - ID: block.ID(), - Height: blockHeight, - } - - // apply actual transactions -- only relevant transactions should be - // added to the database - for _, txn := range block.Transactions { - if !transactionIsRelevant(txn, sw.addr) { - continue - } - var inflow, outflow types.Currency - for _, out := range txn.SiacoinOutputs { - if out.Address == sw.addr { - inflow = inflow.Add(out.Value) - } - } - for _, in := range txn.SiacoinInputs { - if in.UnlockConditions.UnlockHash() == sw.addr { - so, ok := spentOutputs[in.ParentID] - if !ok { - panic("spent output not found") - } - outflow = outflow.Add(so.Value) - } - } - - // ignore transactions that don't affect the wallet (e.g. - // "setup" transactions) - if inflow.Equals(outflow) { - continue - } - - err := tx.AddTransaction(Transaction{ - ID: txn.ID(), - Index: index, - Inflow: inflow, - Outflow: outflow, - Source: TxnSourceTransaction, - Transaction: txn, - Timestamp: block.Timestamp, - }) - if err != nil { - return fmt.Errorf("failed to add transaction %v: %w", txn.ID(), err) - } - } - - // apply payout transactions -- all transactions should be relevant - // to the wallet - for _, txn := range appliedPayoutTxns[i] { - if err := tx.AddTransaction(txn); err != nil { - return fmt.Errorf("failed to add payout transaction %v: %w", txn.ID, err) - } - } - } - return nil - }) - if err != nil { - sw.log.Panic("failed to update wallet", zap.Error(err), zap.String("changeID", cc.ID.String()), zap.Uint64("height", uint64(cc.BlockHeight))) - } - - atomic.StoreUint64(&sw.scanHeight, uint64(cc.BlockHeight)) - sw.log.Debug("applied consensus change", zap.String("changeID", cc.ID.String()), zap.Int("applied", len(cc.AppliedBlocks)), zap.Int("reverted", len(cc.RevertedBlocks)), zap.Uint64("height", uint64(cc.BlockHeight)), zap.Duration("elapsed", time.Since(start)), zap.String("address", sw.addr.String())) -} - -// payoutTransaction wraps a delayed siacoin output in a transaction for display -// in the wallet. -func payoutTransaction(output SiacoinElement, index types.ChainIndex, source TransactionSource, timestamp time.Time) Transaction { - return Transaction{ - ID: types.TransactionID(output.ID), - Index: index, - Transaction: types.Transaction{ - SiacoinOutputs: []types.SiacoinOutput{output.SiacoinOutput}, - }, - Inflow: output.Value, - Source: source, - Timestamp: timestamp, - } -} - -// convertToCore converts a siad type to an equivalent core type. -func convertToCore(siad encoding.SiaMarshaler, core types.DecoderFrom) { - var buf bytes.Buffer - siad.MarshalSia(&buf) - d := types.NewBufDecoder(buf.Bytes()) - core.DecodeFrom(d) - if d.Err() != nil { - panic(d.Err()) - } -} - -// NewSingleAddressWallet returns a new SingleAddressWallet using the provided private key and store. -func NewSingleAddressWallet(priv types.PrivateKey, cm ChainManager, store SingleAddressStore, log *zap.Logger) (*SingleAddressWallet, error) { - changeID, scanHeight, err := store.LastWalletChange() - if err != nil { - return nil, fmt.Errorf("failed to get last wallet change: %w", err) - } - - seedHash := types.HashBytes(priv[:]) - if err := store.VerifyWalletKey(seedHash); errors.Is(err, ErrDifferentSeed) { - changeID = modules.ConsensusChangeBeginning - scanHeight = 0 - if err := store.ResetWallet(seedHash); err != nil { - return nil, fmt.Errorf("failed to reset wallet: %w", err) - } - log.Info("wallet reset due to seed change") - } else if err != nil { - return nil, fmt.Errorf("failed to verify wallet key: %w", err) - } - - sw := &SingleAddressWallet{ - priv: priv, - scanHeight: scanHeight, - - store: store, - cm: cm, - log: log, - tg: threadgroup.New(), - - addr: types.StandardUnlockHash(priv.PublicKey()), - // locked is a set of siacoin output IDs locked by FundTransaction. They - // will be released either by calling Release for unused transactions or - // being confirmed in a block. - locked: make(map[types.SiacoinOutputID]time.Time), - } - - go func() { - ctx, cancel, err := sw.tg.AddContext(context.Background()) - if err != nil { - sw.log.Error("failed to add context", zap.Error(err)) - return - } - defer cancel() - - t := time.NewTicker(3 * time.Hour) - defer t.Stop() - - for { - select { - case <-t.C: - sw.mu.Lock() - for id, expiration := range sw.locked { - if expiration.Before(time.Now()) { - delete(sw.locked, id) - } - } - sw.mu.Unlock() - case <-ctx.Done(): - return - } - } - }() - - go func() { - // note: start in goroutine to avoid blocking startup - err := cm.Subscribe(sw, changeID, sw.tg.Done()) - if errors.Is(err, chain.ErrInvalidChangeID) { - sw.log.Warn("rescanning blockchain due to unknown consensus change ID") - // reset change ID and subscribe again - if err := store.ResetWallet(seedHash); err != nil { - sw.log.Fatal("failed to reset wallet", zap.Error(err)) - } else if err = cm.Subscribe(sw, modules.ConsensusChangeBeginning, sw.tg.Done()); err != nil { - sw.log.Fatal("failed to reset consensus change subscription", zap.Error(err)) - } - } else if err != nil && !strings.Contains(err.Error(), "ThreadGroup already stopped") { - sw.log.Fatal("failed to subscribe to consensus set", zap.Error(err)) - } - }() - return sw, nil -} diff --git a/wallet/wallet_test.go b/wallet/wallet_test.go deleted file mode 100644 index 286a94eb..00000000 --- a/wallet/wallet_test.go +++ /dev/null @@ -1,511 +0,0 @@ -package wallet_test - -import ( - "encoding/json" - "sort" - "testing" - "time" - - "go.sia.tech/core/types" - "go.sia.tech/hostd/internal/test" - "go.sia.tech/hostd/wallet" - stypes "go.sia.tech/siad/types" - "go.uber.org/zap/zaptest" - "lukechampine.com/frand" -) - -func TestWallet(t *testing.T) { - log := zaptest.NewLogger(t) - w, err := test.NewWallet(types.GeneratePrivateKey(), t.TempDir(), log.Named("wallet")) - if err != nil { - t.Fatal(err) - } - defer w.Close() - - _, balance, _, err := w.Balance() - if err != nil { - t.Fatal(err) - } else if !balance.Equals(types.ZeroCurrency) { - t.Fatalf("expected zero balance, got %v", balance) - } - - initialState := w.TipState() - - // mine a block to fund the wallet - if err := w.MineBlocks(w.Address(), 1); err != nil { - t.Fatal(err) - } - - // the outputs have not matured yet - _, balance, _, err = w.Balance() - if err != nil { - t.Fatal(err) - } else if !balance.Equals(types.ZeroCurrency) { - t.Fatalf("expected zero balance, got %v", balance) - } else if m, err := w.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if !m.Balance.Equals(types.ZeroCurrency) { - t.Fatalf("expected zero balance, got %d", m.Balance) - } - - // mine until the first output has matured - if err := w.MineBlocks(types.VoidAddress, int(stypes.MaturityDelay)); err != nil { - t.Fatal(err) - } - time.Sleep(500 * time.Millisecond) // sleep for consensus sync - - // check the wallet's reported balance - expectedBalance := initialState.BlockReward() - _, balance, _, err = w.Balance() - if err != nil { - t.Fatal(err) - } else if !balance.Equals(expectedBalance) { - t.Fatalf("expected %d balance, got %d", expectedBalance, balance) - } else if m, err := w.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if !m.Balance.Equals(expectedBalance) { - t.Fatalf("expected %d balance, got %d", expectedBalance, m.Balance) - } - - // check that the wallet has a single transaction - count, err := w.TransactionCount() - if err != nil { - t.Fatal(err) - } else if count != 1 { - t.Fatalf("expected 1 transaction, got %v", count) - } - - // check that the payout transaction was created - txns, err := w.Transactions(100, 0) - if err != nil { - t.Fatal(err) - } else if len(txns) != 1 { - t.Fatalf("expected 1 transaction, got %v", len(txns)) - } else if txns[0].Source != wallet.TxnSourceMinerPayout { - t.Fatalf("expected miner payout, got %v", txns[0].Source) - } - - // split the wallet's balance into 20 outputs - splitOutputs := make([]types.SiacoinOutput, 20) - for i := range splitOutputs { - splitOutputs[i] = types.SiacoinOutput{ - Value: expectedBalance.Div64(20), - Address: w.Address(), - } - } - if txn, err := w.SendSiacoins(splitOutputs); err != nil { - buf, _ := json.MarshalIndent(txn, "", " ") - t.Log(string(buf)) - t.Fatal(err) - } - - // check that the wallet's spendable balance and unconfirmed balance are - // correct - spendable, balance, unconfirmed, err := w.Balance() - if err != nil { - t.Fatal(err) - } else if !balance.Equals(expectedBalance) { - t.Fatalf("expected %v balance, got %v", expectedBalance, balance) - } else if !spendable.Equals(types.ZeroCurrency) { - t.Fatalf("expected zero spendable balance, got %v", spendable) - } else if !unconfirmed.Equals(expectedBalance) { - t.Fatalf("expected %v unconfirmed balance, got %v", expectedBalance, unconfirmed) - } - - // mine another block to confirm the transaction - if err := w.MineBlocks(types.VoidAddress, 1); err != nil { - t.Fatal(err) - } - time.Sleep(500 * time.Millisecond) - - // check that the wallet's balance is the same - _, balance, unconfirmed, err = w.Balance() - if err != nil { - t.Fatal(err) - } else if !balance.Equals(expectedBalance) { - t.Fatalf("expected %v balance, got %v", expectedBalance, balance) - } else if !unconfirmed.Equals(types.ZeroCurrency) { - t.Fatalf("expected zero unconfirmed balance, got %v", unconfirmed) - } else if m, err := w.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if !m.Balance.Equals(expectedBalance) { - t.Fatalf("expected %d balance, got %d", expectedBalance, m.Balance) - } - - // check that the wallet only has one transaction. The split transaction - // does not count since inflow = outflow - count, err = w.TransactionCount() - if err != nil { - t.Fatal(err) - } else if count != 1 { - t.Fatalf("expected 1 transactions, got %v", count) - } - - // send all the outputs to the burn address individually - var sentTransactions []types.Transaction - for i := 0; i < 20; i++ { - txn, err := w.SendSiacoins([]types.SiacoinOutput{ - {Value: expectedBalance.Div64(20)}, - }) - if err != nil { - t.Fatal(err) - } - sentTransactions = append(sentTransactions, txn) - } - - time.Sleep(250 * time.Millisecond) // sleep for tpool sync - // check that the wallet's spendable balance and unconfirmed balance are - // correct - spendable, balance, unconfirmed, err = w.Balance() - if err != nil { - t.Fatal(err) - } else if !balance.Equals(expectedBalance) { - t.Fatalf("expected %v balance, got %v", expectedBalance, balance) - } else if !spendable.Equals(types.ZeroCurrency) { - t.Fatalf("expected zero spendable balance, got %v", spendable) - } else if !unconfirmed.Equals(types.ZeroCurrency) { - t.Fatalf("expected zero unconfirmed balance, got %v", unconfirmed) - } - - // mine another block to confirm the transactions - if err := w.MineBlocks(types.VoidAddress, 1); err != nil { - t.Fatal(err) - } - time.Sleep(500 * time.Millisecond) - - // check that the wallet now has 21 transactions, 1 + 20 void transactions - count, err = w.TransactionCount() - if err != nil { - t.Fatal(err) - } else if count != 21 { - t.Fatalf("expected 21 transactions, got %v", count) - } else if m, err := w.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if !m.Balance.Equals(types.ZeroCurrency) { - t.Fatalf("expected %d balance, got %d", types.ZeroCurrency, m.Balance) - } - - // check that the paginated transactions are in the proper order - for i := 0; i < 20; i++ { - expectedTxn := sentTransactions[i] - txns, err := w.Transactions(1, i) - if err != nil { - t.Fatal(err) - } else if len(txns) != 1 { - t.Fatalf("expected 1 transaction, got %v", len(txns)) - } else if txns[0].Transaction.ID() != expectedTxn.ID() { - t.Fatalf("expected transaction %v, got %v", expectedTxn.ID(), txns[0].Transaction.ID()) - } else if txns[0].Source != wallet.TxnSourceTransaction { - t.Fatalf("expected transaction source, got %v", txns[0].Source) - } - } - - // start a new node to trigger a reorg - w2, err := test.NewWallet(types.GeneratePrivateKey(), t.TempDir(), log.Named("wallet2")) - if err != nil { - t.Fatal(err) - } - defer w2.Close() - - // mine enough blocks on the second node to trigger a reorg - if err := w2.MineBlocks(types.Address{}, int(stypes.MaturityDelay)*4); err != nil { - t.Fatal(err) - } - - // connect the nodes. node1 should begin reverting its blocks - if err := w.ConnectPeer(w2.GatewayAddr()); err != nil { - t.Fatal(err) - } - for i := 0; i < 100; i++ { - if w.TipState().Index.ID == w2.TipState().Index.ID { - break - } - time.Sleep(time.Second) - } - if w.TipState().Index.ID != w2.TipState().Index.ID { - t.Fatal("nodes are not synced") - } - - // check that the wallet's balance is back to 0 - _, balance, _, err = w.Balance() - if err != nil { - t.Fatal(err) - } else if !balance.Equals(types.ZeroCurrency) { - t.Fatalf("expected zero balance, got %v", balance) - } else if m, err := w.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if !m.Balance.Equals(types.ZeroCurrency) { - t.Fatalf("expected %d balance, got %d", types.ZeroCurrency, m.Balance) - } - - // check that all transactions have been deleted - txns, err = w.Transactions(100, 0) - if err != nil { - t.Fatal(err) - } else if len(txns) != 0 { - t.Fatalf("expected 0 transactions, got %v", len(txns)) - } -} - -func TestWalletReset(t *testing.T) { - log := zaptest.NewLogger(t) - dir := t.TempDir() - w, err := test.NewWallet(types.GeneratePrivateKey(), dir, log.Named("wallet")) - if err != nil { - t.Fatal(err) - } - defer w.Close() - - _, balance, _, err := w.Balance() - if err != nil { - t.Fatal(err) - } else if !balance.IsZero() { - t.Fatalf("expected zero balance, got %v", balance) - } - - // mine until the wallet has funds - if err := w.MineBlocks(w.Address(), int(stypes.MaturityDelay)*2); err != nil { - t.Fatal(err) - } - time.Sleep(time.Second) // sleep for sync - - height := w.ScanHeight() - - // check that the wallet has UTXOs and transactions - _, balance, _, err = w.Balance() - if err != nil { - t.Fatal(err) - } else if balance.IsZero() { - t.Fatal("expected non-zero balance") - } else if txns, err := w.Transactions(100, 0); err != nil { - t.Fatal(err) - } else if len(txns) == 0 { - t.Fatal("expected transactions") - } - - m, err := w.Store().Metrics(time.Now()) - if err != nil { - t.Fatal(err) - } else if m.Balance.IsZero() { - t.Fatal("expected non-zero balance") - } - - // close the wallet and trigger a reset by using a different private key - w.Close() - - w, err = test.NewWallet(types.GeneratePrivateKey(), dir, log.Named("wallet")) - if err != nil { - t.Fatal(err) - } - defer w.Close() - - // wait for the wallet to resync - for i := 0; i < 100; i++ { - if current := w.ScanHeight(); current == height { - break - } - time.Sleep(time.Second) // sleep for sync - } - if current := w.ScanHeight(); current != height { - t.Fatalf("expected scan height %v, got %v", height, current) - } - - // check that the wallet has no UTXOs or transactions - _, balance, _, err = w.Balance() - if err != nil { - t.Fatal(err) - } else if !balance.IsZero() { - t.Fatalf("expected zero balance, got %v", balance) - } else if txns, err := w.Transactions(100, 0); err != nil { - t.Fatal(err) - } else if len(txns) != 0 { - t.Fatal("expected no transactions") - } - - m, err = w.Store().Metrics(time.Now()) - if err != nil { - t.Fatal(err) - } else if !m.Balance.IsZero() { - t.Fatal("expected zero balance") - } -} - -func TestWalletUTXOSelection(t *testing.T) { - log := zaptest.NewLogger(t) - w, err := test.NewWallet(types.GeneratePrivateKey(), t.TempDir(), log.Named("wallet")) - if err != nil { - t.Fatal(err) - } - defer w.Close() - - _, balance, _, err := w.Balance() - if err != nil { - t.Fatal(err) - } else if !balance.Equals(types.ZeroCurrency) { - t.Fatalf("expected zero balance, got %v", balance) - } - - // mine until the wallet has 100 mature outputs - if err := w.MineBlocks(w.Address(), 100+int(stypes.MaturityDelay)); err != nil { - t.Fatal(err) - } - - time.Sleep(time.Second) // sleep for consensus sync - - // check that the expected utxos were used - utxos, err := w.Store().UnspentSiacoinElements() - if err != nil { - t.Fatal(err) - } else if len(utxos) != 100 { - t.Fatalf("expected 100 utxos, got %v", len(utxos)) - } - sort.Slice(utxos, func(i, j int) bool { - return utxos[i].Value.Cmp(utxos[j].Value) > 0 - }) - - // send a transaction to the burn address - sendAmount := types.Siacoins(10) - minerFee := types.Siacoins(1) - txn := types.Transaction{ - MinerFees: []types.Currency{minerFee}, - SiacoinOutputs: []types.SiacoinOutput{ - {Address: types.VoidAddress, Value: sendAmount}, - }, - } - - fundAmount := sendAmount.Add(minerFee) - toSign, release, err := w.FundTransaction(&txn, fundAmount) - if err != nil { - t.Fatal(err) - } - - if len(txn.SiacoinInputs) != 11 { - t.Fatalf("expected 10 additional defrag inputs, got %v", len(toSign)-1) - } else if len(txn.SiacoinOutputs) != 2 { - t.Fatalf("expected a change output, got %v", len(txn.SiacoinOutputs)) - } - - // check that the expected UTXOs were added - spent := []wallet.SiacoinElement{utxos[0]} - rem := utxos[90:] - for i := len(rem) - 1; i >= 0; i-- { - spent = append(spent, rem[i]) - } - - for i := range txn.SiacoinInputs { - if txn.SiacoinInputs[i].ParentID != spent[i].ID { - t.Fatalf("expected input %v to spend %v, got %v", i, spent[i].ID, txn.SiacoinInputs[i].ParentID) - } - } - - if err := w.SignTransaction(w.TipState(), &txn, toSign, types.CoveredFields{WholeTransaction: true}); err != nil { - release() - t.Fatal(err) - } else if err := w.TPool().AcceptTransactionSet([]types.Transaction{txn}); err != nil { - release() - t.Fatal(err) - } -} - -func TestTransactionUnconfirmedValue(t *testing.T) { - log := zaptest.NewLogger(t) - w, err := test.NewWallet(types.GeneratePrivateKey(), t.TempDir(), log.Named("wallet")) - if err != nil { - t.Fatal(err) - } - defer w.Close() - - _, balance, _, err := w.Balance() - if err != nil { - t.Fatal(err) - } else if !balance.Equals(types.ZeroCurrency) { - t.Fatalf("expected zero balance, got %v", balance) - } - - initialState := w.TipState() - - // mine a block to fund the wallet - if err := w.MineBlocks(w.Address(), 1); err != nil { - t.Fatal(err) - } - - // the outputs have not matured yet - _, balance, _, err = w.Balance() - if err != nil { - t.Fatal(err) - } else if !balance.Equals(types.ZeroCurrency) { - t.Fatalf("expected zero balance, got %v", balance) - } else if m, err := w.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if !m.Balance.Equals(types.ZeroCurrency) { - t.Fatalf("expected zero balance, got %d", m.Balance) - } - - // mine until the first output has matured - if err := w.MineBlocks(types.VoidAddress, int(stypes.MaturityDelay)); err != nil { - t.Fatal(err) - } - time.Sleep(500 * time.Millisecond) // sleep for consensus sync - - // check the wallet's reported balance - expectedBalance := initialState.BlockReward() - _, balance, _, err = w.Balance() - if err != nil { - t.Fatal(err) - } else if !balance.Equals(expectedBalance) { - t.Fatalf("expected %d balance, got %d", expectedBalance, balance) - } else if m, err := w.Store().Metrics(time.Now()); err != nil { - t.Fatal(err) - } else if !m.Balance.Equals(expectedBalance) { - t.Fatalf("expected %d balance, got %d", expectedBalance, m.Balance) - } - - // check that the wallet has a single transaction - count, err := w.TransactionCount() - if err != nil { - t.Fatal(err) - } else if count != 1 { - t.Fatalf("expected 1 transaction, got %v", count) - } - - // check that the payout transaction was created - txns, err := w.Transactions(100, 0) - if err != nil { - t.Fatal(err) - } else if len(txns) != 1 { - t.Fatalf("expected 1 transaction, got %v", len(txns)) - } else if txns[0].Source != wallet.TxnSourceMinerPayout { - t.Fatalf("expected miner payout, got %v", txns[0].Source) - } else if !txns[0].Inflow.Equals(expectedBalance) { - t.Fatalf("expected %v inflow, got %v", expectedBalance, txns[0].Inflow) - } - - // create a transaction sending half of the balance to the void - sendAmount := types.Siacoins(uint32(frand.Uint64n(50000))) - _, err = w.SendSiacoins([]types.SiacoinOutput{ - {Address: types.VoidAddress, Value: sendAmount}, - }) - if err != nil { - t.Fatal(err) - } - - expectedInflow := expectedBalance.Sub(sendAmount) - expectedOutflow := expectedBalance - - unconfirmed, err := w.UnconfirmedTransactions() - if err != nil { - t.Fatal(err) - } else if len(unconfirmed) != 1 { - t.Fatalf("expected 1 unconfirmed transaction, got %v", len(unconfirmed)) - } else if !unconfirmed[0].Inflow.Equals(expectedInflow) { - t.Fatalf("expected %v inflow, got %v", expectedBalance.Div64(2), unconfirmed[0].Inflow) - } else if !unconfirmed[0].Outflow.Equals(expectedOutflow) { - t.Fatalf("expected %v outflow, got %v", expectedBalance.Div64(2), unconfirmed[0].Outflow) - } - - // this is reversed, but currency can't handle negatives - value := unconfirmed[0].Outflow.Sub(unconfirmed[0].Inflow) - if !value.Equals(sendAmount) { - t.Fatalf("expected %v value, got %v", sendAmount, value) - } -} diff --git a/webhooks/noop.go b/webhooks/noop.go new file mode 100644 index 00000000..54e7abcb --- /dev/null +++ b/webhooks/noop.go @@ -0,0 +1,23 @@ +package webhooks + +// A WebhookBroadcaster broadcasts events to webhooks. +type WebhookBroadcaster interface { + BroadcastEvent(event string, scope string, data any) error + BroadcastToWebhook(hookID int64, event string, scope string, data any) error +} + +// A NoOpBroadcaster is a WebhookBroadcaster that does nothing. +type NoOpBroadcaster struct{} + +// BroadcastEvent implements WebhookBroadcaster. +func (NoOpBroadcaster) BroadcastEvent(event string, scope string, data any) error { return nil } + +// BroadcastToWebhook implements WebhookBroadcaster. +func (NoOpBroadcaster) BroadcastToWebhook(hookID int64, event string, scope string, data any) error { + return nil +} + +// NewNop returns a new NoOpBroadcaster. +func NewNop() NoOpBroadcaster { + return NoOpBroadcaster{} +} diff --git a/webhooks/webhooks.go b/webhooks/webhooks.go index cddb6b47..0c338922 100644 --- a/webhooks/webhooks.go +++ b/webhooks/webhooks.go @@ -36,8 +36,8 @@ type ( hooks map[int64]bool } - // A WebHook is a callback that is invoked when an event occurs. - WebHook struct { + // A Webhook is a callback that is invoked when an event occurs. + Webhook struct { ID int64 `json:"id"` CallbackURL string `json:"callbackURL"` SecretKey string `json:"secretKey"` @@ -47,7 +47,7 @@ type ( // A UID is a unique identifier for an event. UID [32]byte - // An Event is a notification sent to a WebHook callback. + // An Event is a notification sent to a Webhook callback. Event struct { ID UID `json:"id"` Event string `json:"event"` @@ -55,33 +55,35 @@ type ( Data any `json:"data"` } - // A Store stores and retrieves WebHooks. + // A Store stores and retrieves Webhooks. Store interface { - RegisterWebHook(url, secret string, scopes []string) (int64, error) - UpdateWebHook(id int64, url string, scopes []string) error - RemoveWebHook(id int64) error - WebHooks() ([]WebHook, error) + RegisterWebhook(url, secret string, scopes []string) (int64, error) + UpdateWebhook(id int64, url string, scopes []string) error + RemoveWebhook(id int64) error + Webhooks() ([]Webhook, error) } - // A Manager manages WebHook subscribers and broadcasts events + // A Manager manages Webhook subscribers and broadcasts events Manager struct { store Store log *zap.Logger tg *threadgroup.ThreadGroup mu sync.Mutex - hooks map[int64]WebHook + hooks map[int64]Webhook scopes *scope } ) +var _ WebhookBroadcaster = (*Manager)(nil) + // Close closes the Manager. func (m *Manager) Close() error { m.tg.Stop() return nil } -func (m *Manager) findMatchingHooks(s string) (hooks []WebHook) { +func (m *Manager) findMatchingHooks(s string) (hooks []Webhook) { // recursively match hooks var match func(scopeParts []string, parent *scope) match = func(scopeParts []string, parent *scope) { @@ -139,8 +141,8 @@ func (m *Manager) removeHookScopes(id int64) { remove(m.scopes) } -// WebHooks returns all registered WebHooks. -func (m *Manager) WebHooks() (hooks []WebHook, _ error) { +// Webhooks returns all registered Webhooks. +func (m *Manager) Webhooks() (hooks []Webhook, _ error) { m.mu.Lock() defer m.mu.Unlock() for _, hook := range m.hooks { @@ -149,26 +151,26 @@ func (m *Manager) WebHooks() (hooks []WebHook, _ error) { return } -// RegisterWebHook registers a new WebHook. -func (m *Manager) RegisterWebHook(url string, scopes []string) (WebHook, error) { +// RegisterWebhook registers a new Webhook. +func (m *Manager) RegisterWebhook(url string, scopes []string) (Webhook, error) { done, err := m.tg.Add() if err != nil { - return WebHook{}, err + return Webhook{}, err } defer done() secret := hex.EncodeToString(frand.Bytes(16)) // register the hook in the database - id, err := m.store.RegisterWebHook(url, secret, scopes) + id, err := m.store.RegisterWebhook(url, secret, scopes) if err != nil { - return WebHook{}, fmt.Errorf("failed to register WebHook: %w", err) + return Webhook{}, fmt.Errorf("failed to register Webhook: %w", err) } m.mu.Lock() defer m.mu.Unlock() // add the hook to the in-memory map - hook := WebHook{ + hook := Webhook{ ID: id, CallbackURL: url, SecretKey: secret, @@ -180,8 +182,8 @@ func (m *Manager) RegisterWebHook(url string, scopes []string) (WebHook, error) return hook, nil } -// RemoveWebHook removes a registered WebHook. -func (m *Manager) RemoveWebHook(id int64) error { +// RemoveWebhook removes a registered Webhook. +func (m *Manager) RemoveWebhook(id int64) error { done, err := m.tg.Add() if err != nil { return err @@ -189,7 +191,7 @@ func (m *Manager) RemoveWebHook(id int64) error { defer done() // remove the hook from the database - if err := m.store.RemoveWebHook(id); err != nil { + if err := m.store.RemoveWebhook(id); err != nil { return err } @@ -201,18 +203,18 @@ func (m *Manager) RemoveWebHook(id int64) error { return nil } -// UpdateWebHook updates the URL and scopes of a registered WebHook. -func (m *Manager) UpdateWebHook(id int64, url string, scopes []string) (WebHook, error) { +// UpdateWebhook updates the URL and scopes of a registered Webhook. +func (m *Manager) UpdateWebhook(id int64, url string, scopes []string) (Webhook, error) { done, err := m.tg.Add() if err != nil { - return WebHook{}, err + return Webhook{}, err } defer done() // update the hook in the database - err = m.store.UpdateWebHook(id, url, scopes) + err = m.store.UpdateWebhook(id, url, scopes) if err != nil { - return WebHook{}, err + return Webhook{}, err } m.mu.Lock() @@ -220,7 +222,7 @@ func (m *Manager) UpdateWebHook(id int64, url string, scopes []string) (WebHook, // update the hook in the in-memory map hook, ok := m.hooks[id] if !ok { - panic("UpdateWebHook called on nonexistent WebHook") // developer error + panic("UpdateWebhook called on nonexistent Webhook") // developer error } hook.CallbackURL = url hook.Scopes = scopes @@ -232,10 +234,10 @@ func (m *Manager) UpdateWebHook(id int64, url string, scopes []string) (WebHook, return hook, nil } -func sendEventData(ctx context.Context, hook WebHook, buf []byte) error { +func sendEventData(ctx context.Context, hook Webhook, buf []byte) error { req, err := http.NewRequestWithContext(ctx, "POST", hook.CallbackURL, bytes.NewReader(buf)) if err != nil { - return fmt.Errorf("failed to create WebHook request: %w", err) + return fmt.Errorf("failed to create Webhook request: %w", err) } // set the secret key and content type @@ -255,7 +257,7 @@ func sendEventData(ctx context.Context, hook WebHook, buf []byte) error { return nil } -// BroadcastToWebhook sends an event to a specific WebHook subscriber. +// BroadcastToWebhook sends an event to a specific Webhook subscriber. func (m *Manager) BroadcastToWebhook(hookID int64, event string, scope string, data any) error { done, err := m.tg.Add() if err != nil { @@ -281,7 +283,7 @@ func (m *Manager) BroadcastToWebhook(hookID int64, event string, scope string, d hook, ok := m.hooks[hookID] if !ok { - return fmt.Errorf("webhook not found") + return fmt.Errorf("Webhook not found") } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) @@ -291,13 +293,13 @@ func (m *Manager) BroadcastToWebhook(hookID int64, event string, scope string, d start := time.Now() if err := sendEventData(ctx, hook, buf); err != nil { - return fmt.Errorf("failed to send webhook event: %w", err) + return fmt.Errorf("failed to send Webhook event: %w", err) } - log.Debug("sent webhook event", zap.Duration("elapsed", time.Since(start))) + log.Debug("sent Webhook event", zap.Duration("elapsed", time.Since(start))) return nil } -// BroadcastEvent sends an event to all registered WebHooks that match the +// BroadcastEvent sends an event to all registered Webhooks that match the // event's scope. func (m *Manager) BroadcastEvent(event string, scope string, data any) error { done, err := m.tg.Add() @@ -326,7 +328,7 @@ func (m *Manager) BroadcastEvent(event string, scope string, data any) error { hooks := m.findMatchingHooks(scope) for _, hook := range hooks { - go func(hook WebHook) { + go func(hook Webhook) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -334,29 +336,29 @@ func (m *Manager) BroadcastEvent(event string, scope string, data any) error { start := time.Now() if err := sendEventData(ctx, hook, buf); err != nil { - log.Error("failed to send webhook event", zap.Error(err)) + log.Error("failed to send Webhook event", zap.Error(err)) return } - log.Debug("sent webhook event", zap.Duration("elapsed", time.Since(start))) + log.Debug("sent Webhook event", zap.Duration("elapsed", time.Since(start))) }(hook) } return nil } -// NewManager creates a new WebHook Manager +// NewManager creates a new Webhook Manager func NewManager(store Store, log *zap.Logger) (*Manager, error) { m := &Manager{ store: store, log: log, tg: threadgroup.New(), - hooks: make(map[int64]WebHook), + hooks: make(map[int64]Webhook), scopes: &scope{children: make(map[string]*scope), hooks: make(map[int64]bool)}, } - _, err := store.WebHooks() + _, err := store.Webhooks() if err != nil { - return nil, fmt.Errorf("failed to load WebHooks: %w", err) + return nil, fmt.Errorf("failed to load Webhooks: %w", err) } return m, nil } diff --git a/webhooks/webhooks_test.go b/webhooks/webhooks_test.go index 32e37721..820c6e15 100644 --- a/webhooks/webhooks_test.go +++ b/webhooks/webhooks_test.go @@ -23,20 +23,20 @@ type jsonEvent struct { Error error `json:"-"` } -func registerWebhook(t testing.TB, wr *webhooks.Manager, scopes []string) (webhooks.WebHook, <-chan jsonEvent, error) { +func registerWebhook(t testing.TB, wr *webhooks.Manager, scopes []string) (webhooks.Webhook, <-chan jsonEvent, error) { // create a listener for the webhook l, err := net.Listen("tcp", ":0") if err != nil { - return webhooks.WebHook{}, nil, fmt.Errorf("failed to create listener: %w", err) + return webhooks.Webhook{}, nil, fmt.Errorf("failed to create listener: %w", err) } t.Cleanup(func() { l.Close() }) // add a webhook - hook, err := wr.RegisterWebHook("http://"+l.Addr().String(), scopes) + hook, err := wr.RegisterWebhook("http://"+l.Addr().String(), scopes) if err != nil { - return webhooks.WebHook{}, nil, fmt.Errorf("failed to register webhook: %w", err) + return webhooks.Webhook{}, nil, fmt.Errorf("failed to register webhook: %w", err) } // create an http server to listen for the webhook @@ -65,7 +65,7 @@ func registerWebhook(t testing.TB, wr *webhooks.Manager, scopes []string) (webho return hook, recv, nil } -func TestWebHooks(t *testing.T) { +func TestWebhooks(t *testing.T) { log := zaptest.NewLogger(t) db, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "hostd.db"), log.Named("sqlite")) @@ -126,10 +126,10 @@ func TestWebHooks(t *testing.T) { } // update the webhook to have the "all scope" - hook, err = wr.UpdateWebHook(hook.ID, hook.CallbackURL, []string{"all"}) + hook, err = wr.UpdateWebhook(hook.ID, hook.CallbackURL, []string{"all"}) if err != nil { t.Fatal(err) - } else if hooks, err := wr.WebHooks(); err != nil { + } else if hooks, err := wr.Webhooks(); err != nil { t.Fatal(err) } else if len(hooks) != 1 { t.Fatal("expected 1 webhook") @@ -145,9 +145,9 @@ func TestWebHooks(t *testing.T) { } // unregister the webhook - if err := wr.RemoveWebHook(hook.ID); err != nil { + if err := wr.RemoveWebhook(hook.ID); err != nil { t.Fatal(err) - } else if hooks, err := wr.WebHooks(); err != nil { + } else if hooks, err := wr.Webhooks(); err != nil { t.Fatal(err) } else if len(hooks) != 0 { t.Fatal("expected no webhooks")