From edb8d2715c18d45930d17a4b905903fdb4b4ab36 Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Tue, 14 Jan 2025 15:45:55 -0800 Subject: [PATCH 01/13] Create TEE Requisition Fetcher. --- .../requisitions/BUILD.bazel | 19 +++++ .../requisitions/RequisitionFetcher.kt | 85 +++++++++++++++++++ .../requisitions/v1alpha/BUILD.bazel | 22 +++++ .../requisitions/v1alpha/kingdom_config.proto | 41 +++++++++ 4 files changed, 167 insertions(+) create mode 100644 src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel create mode 100644 src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt create mode 100644 src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/BUILD.bazel create mode 100644 src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/kingdom_config.proto diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel new file mode 100644 index 00000000000..f9dc3ddac16 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel @@ -0,0 +1,19 @@ +load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") + +package(default_visibility = ["//visibility:public"]) + +kt_jvm_library( + name = "requisition_fetcher", + srcs = ["RequisitionFetcher.kt"], + deps = [ + "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", + "//src/main/proto/wfa/measurement/api/v2alpha:requisitions_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha:kingdom_config_kt_jvm_proto", + "@wfa_common_jvm//imports/java/com/google/cloud/storage", + "@wfa_common_jvm//imports/java/com/google/protobuf", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/gcs", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt new file mode 100644 index 00000000000..9bff0715f1c --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt @@ -0,0 +1,85 @@ +/* + * Copyright 2025 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.wfanet.measurement.securecomputation.requisitions + +import com.google.cloud.storage.StorageOptions +import io.grpc.StatusException +import java.io.File +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.forEach +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.toList +import org.wfanet.measurement.api.v2alpha.ListRequisitionsResponse +//import org.wfanet.measurement.api.v2alpha.ListRequisitionsRequestKt +//import org.wfanet.measurement.api.v2alpha.Measurement +import org.wfanet.measurement.api.v2alpha.Requisition +import org.wfanet.measurement.common.crypto.readCertificateCollection +//import org.wfanet.measurement.common.crypto.readCertificateCollection +import org.wfanet.measurement.common.grpc.buildTlsChannel +import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCoroutineStub +import org.wfanet.measurement.api.v2alpha.listRequisitionsRequest +import org.wfanet.measurement.api.withAuthenticationKey +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources +import org.wfanet.measurement.gcloud.gcs.GcsStorageClient +//import org.wfanet.measurement.api.v2alpha.listRequisitionsRequest +import org.wfanet.measurement.securecomputation.requisitions.v1alpha.KingdomConfig + +class RequisitionFetcher( + val config: KingdomConfig, + val blobUri: String // Output location to write requisitions to +) { + + suspend fun fetchRequisitions(): Flow { + val publicChannel = buildTlsChannel( + config.publicApiTarget, + readCertificateCollection(requireNotNull(File(config.certCollectionPath))), + config.publicApiCertHost, + ) + val requisitionsStub = RequisitionsCoroutineStub(publicChannel) + return requisitionsStub + .withAuthenticationKey(config.apiAuthenticationKey) + .listResources { pageToken -> + val response: ListRequisitionsResponse = + try { + listRequisitions(listRequisitionsRequest { this.pageToken = pageToken }) + } catch (e: StatusException) { + throw Exception("Unable to list requisitions.", e) + } + ResourceList(response.requisitionsList, response.nextPageToken) + } + .flattenConcat() + } + + suspend fun storeRequisitions(requisitions: Flow) { + val storageClient = GcsStorageClient( + StorageOptions.newBuilder().setProjectId(config.googleCloudStorageProject).build().service, + config.googleCloudStorageBucket + ) + storageClient.writeBlob(blobUri, requisitions.map { it.toByteString() }) + } + + suspend fun run() { + val requisitions = fetchRequisitions() + storeRequisitions(requisitions) + } +} + +// question: who are we fetching requisitions for? is it an mc? how does this fit into the datawatcher architecture? + + + diff --git a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/BUILD.bazel b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/BUILD.bazel new file mode 100644 index 00000000000..9e62463de0f --- /dev/null +++ b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/BUILD.bazel @@ -0,0 +1,22 @@ +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") +load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_proto_library") + +package(default_visibility = ["//visibility:public"]) + +IMPORT_PREFIX = "/src/main/proto" + +proto_library( + name = "kingdom_config_proto", + srcs = ["kingdom_config.proto"], + strip_import_prefix = IMPORT_PREFIX, + deps = [ + "@com_google_googleapis//google/api:field_behavior_proto", + "@com_google_googleapis//google/api:resource_proto", + "@com_google_protobuf//:any_proto", + ], +) + +kt_jvm_proto_library( + name = "kingdom_config_kt_jvm_proto", + deps = [":kingdom_config_proto"], +) diff --git a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/kingdom_config.proto b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/kingdom_config.proto new file mode 100644 index 00000000000..7f1dab68cde --- /dev/null +++ b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/kingdom_config.proto @@ -0,0 +1,41 @@ +// Copyright 2025 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.securecomputation.requisitions.v1alpha; + +import "google/protobuf/any.proto"; +import "google/api/field_behavior.proto"; +import "google/api/resource.proto"; + +option java_package = "org.wfanet.measurement.securecomputation.requisitions.v1alpha"; +option java_multiple_files = true; +option java_outer_classname = "KingdomConfigProto"; + + + +message KingdomConfig { + string public_api_cert_host = 1; + + string public_api_target = 2; + + string cert_collection_path = 3; + + string api_authentication_key = 4; + + string google_cloud_storage_project = 5; + + string google_cloud_storage_bucket = 6; +} From c63ad08a870fa2cd012adeebd187f1dd23180c27 Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Thu, 16 Jan 2025 07:40:50 -0800 Subject: [PATCH 02/13] Add data_provider field to KingdomConfig to be set as parent field in listRequisitions() call. --- .../securecomputation/requisitions/RequisitionFetcher.kt | 5 ++++- .../requisitions/v1alpha/kingdom_config.proto | 6 ++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt index 9bff0715f1c..11ef5554d87 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt @@ -56,7 +56,10 @@ class RequisitionFetcher( .listResources { pageToken -> val response: ListRequisitionsResponse = try { - listRequisitions(listRequisitionsRequest { this.pageToken = pageToken }) + listRequisitions(listRequisitionsRequest { + parent = config.dataProvider + this.pageToken = pageToken + }) } catch (e: StatusException) { throw Exception("Unable to list requisitions.", e) } diff --git a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/kingdom_config.proto b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/kingdom_config.proto index 7f1dab68cde..28f44295c7b 100644 --- a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/kingdom_config.proto +++ b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/kingdom_config.proto @@ -35,7 +35,9 @@ message KingdomConfig { string api_authentication_key = 4; - string google_cloud_storage_project = 5; + string data_provider = 5; - string google_cloud_storage_bucket = 6; + string google_cloud_storage_project = 6; + + string google_cloud_storage_bucket = 7; } From d868898310038e3632dba0eda1b8370142cab227 Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Tue, 28 Jan 2025 14:34:25 -0800 Subject: [PATCH 03/13] Update logic to filter out requisitions that have already been pulled in. Create storage config. --- .../requisitions/BUILD.bazel | 3 + .../requisitions/RequisitionFetcher.kt | 58 ++++++++++--------- .../requisitions/v1alpha/BUILD.bazel | 17 ++++++ .../requisitions/v1alpha/kingdom_config.proto | 4 -- .../requisitions/v1alpha/storage_config.proto | 39 +++++++++++++ 5 files changed, 90 insertions(+), 31 deletions(-) create mode 100644 src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/storage_config.proto diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel index f9dc3ddac16..dd0ecd64c51 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel @@ -10,7 +10,10 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/proto/wfa/measurement/api/v2alpha:requisitions_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha:kingdom_config_kt_jvm_proto", + "//src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha:storage_config_kt_jvm_proto", + "@wfa_common_jvm//imports/java/com/google/cloud/functions:functions_framework_api", "@wfa_common_jvm//imports/java/com/google/cloud/storage", + "@wfa_common_jvm//imports/java/com/google/events", "@wfa_common_jvm//imports/java/com/google/protobuf", "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt index 11ef5554d87..e40fbcca7c0 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt @@ -19,15 +19,10 @@ import com.google.cloud.storage.StorageOptions import io.grpc.StatusException import java.io.File import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.forEach import kotlinx.coroutines.flow.map -import kotlinx.coroutines.flow.toList import org.wfanet.measurement.api.v2alpha.ListRequisitionsResponse -//import org.wfanet.measurement.api.v2alpha.ListRequisitionsRequestKt -//import org.wfanet.measurement.api.v2alpha.Measurement import org.wfanet.measurement.api.v2alpha.Requisition import org.wfanet.measurement.common.crypto.readCertificateCollection -//import org.wfanet.measurement.common.crypto.readCertificateCollection import org.wfanet.measurement.common.grpc.buildTlsChannel import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCoroutineStub import org.wfanet.measurement.api.v2alpha.listRequisitionsRequest @@ -36,28 +31,44 @@ import org.wfanet.measurement.common.api.grpc.ResourceList import org.wfanet.measurement.common.api.grpc.flattenConcat import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.gcloud.gcs.GcsStorageClient -//import org.wfanet.measurement.api.v2alpha.listRequisitionsRequest import org.wfanet.measurement.securecomputation.requisitions.v1alpha.KingdomConfig +import org.wfanet.measurement.securecomputation.requisitions.v1alpha.StorageConfig +import com.google.cloud.functions.CloudEventsFunction +import io.cloudevents.CloudEvent +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.filter +import kotlinx.coroutines.runBlocking + +// 1. Polls for new requisitions +// 2. Stores new requisitions into Google Cloud Storage class RequisitionFetcher( - val config: KingdomConfig, - val blobUri: String // Output location to write requisitions to -) { + private val kingdomConfig: KingdomConfig, + private val storageConfig: StorageConfig, +): CloudEventsFunction { + + override fun accept(event: CloudEvent) { + runBlocking { + storeRequisitions(fetchRequisitions()) + } + } - suspend fun fetchRequisitions(): Flow { + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. + private suspend fun fetchRequisitions(): Flow { val publicChannel = buildTlsChannel( - config.publicApiTarget, - readCertificateCollection(requireNotNull(File(config.certCollectionPath))), - config.publicApiCertHost, + kingdomConfig.publicApiTarget, + readCertificateCollection(File(kingdomConfig.certCollectionPath)), + kingdomConfig.publicApiCertHost, ) val requisitionsStub = RequisitionsCoroutineStub(publicChannel) + return requisitionsStub - .withAuthenticationKey(config.apiAuthenticationKey) + .withAuthenticationKey(kingdomConfig.apiAuthenticationKey) .listResources { pageToken -> val response: ListRequisitionsResponse = try { listRequisitions(listRequisitionsRequest { - parent = config.dataProvider + parent = kingdomConfig.dataProvider this.pageToken = pageToken }) } catch (e: StatusException) { @@ -66,23 +77,16 @@ class RequisitionFetcher( ResourceList(response.requisitionsList, response.nextPageToken) } .flattenConcat() + .filter { it.updateTime.seconds > storageConfig.lastUpdate.seconds } } - suspend fun storeRequisitions(requisitions: Flow) { + private suspend fun storeRequisitions(requisitions: Flow) { val storageClient = GcsStorageClient( - StorageOptions.newBuilder().setProjectId(config.googleCloudStorageProject).build().service, - config.googleCloudStorageBucket + StorageOptions.newBuilder().setProjectId(storageConfig.project).build().service, + storageConfig.bucket ) - storageClient.writeBlob(blobUri, requisitions.map { it.toByteString() }) - } - - suspend fun run() { - val requisitions = fetchRequisitions() - storeRequisitions(requisitions) + storageClient.writeBlob(storageConfig.blobUri, requisitions.map { it.toByteString() }) } } -// question: who are we fetching requisitions for? is it an mc? how does this fit into the datawatcher architecture? - - diff --git a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/BUILD.bazel b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/BUILD.bazel index 9e62463de0f..fb5ca5d0c42 100644 --- a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/BUILD.bazel +++ b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/BUILD.bazel @@ -20,3 +20,20 @@ kt_jvm_proto_library( name = "kingdom_config_kt_jvm_proto", deps = [":kingdom_config_proto"], ) + +proto_library( + name = "storage_config_proto", + srcs = ["storage_config.proto"], + strip_import_prefix = IMPORT_PREFIX, + deps = [ + "@com_google_googleapis//google/api:field_behavior_proto", + "@com_google_googleapis//google/api:resource_proto", + "@com_google_protobuf//:any_proto", + "@com_google_protobuf//:timestamp_proto", + ], +) + +kt_jvm_proto_library( + name = "storage_config_kt_jvm_proto", + deps = [":storage_config_proto"], +) diff --git a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/kingdom_config.proto b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/kingdom_config.proto index 28f44295c7b..40c11e39431 100644 --- a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/kingdom_config.proto +++ b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/kingdom_config.proto @@ -36,8 +36,4 @@ message KingdomConfig { string api_authentication_key = 4; string data_provider = 5; - - string google_cloud_storage_project = 6; - - string google_cloud_storage_bucket = 7; } diff --git a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/storage_config.proto b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/storage_config.proto new file mode 100644 index 00000000000..2229f79e109 --- /dev/null +++ b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/storage_config.proto @@ -0,0 +1,39 @@ +// Copyright 2025 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.securecomputation.requisitions.v1alpha; + +import "google/protobuf/any.proto"; +import "google/protobuf/timestamp.proto"; +import "google/api/field_behavior.proto"; +import "google/api/resource.proto"; + + +option java_package = "org.wfanet.measurement.securecomputation.requisitions.v1alpha"; +option java_multiple_files = true; +option java_outer_classname = "StorageConfigProto"; + + + +message StorageConfig { + string project = 1; + + string bucket = 2; + + string blob_uri = 3; + + google.protobuf.Timestamp last_update = 4; +} From 360f8c0a6a8dcda6e4a97c650c3646023e73981d Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Wed, 29 Jan 2025 14:59:33 -0800 Subject: [PATCH 04/13] Add test file to be implemented later in this PR. Add filter to listRequisitions() method to only retrieve unfulfilled requisitions. Move filter function so that it does not succeed the call to flattenConcat(). --- .../requisitions/RequisitionFetcher.kt | 17 +++--- .../requisitions/BUILD.bazel | 27 ++++++++++ .../requisitions/RequisitionFetcherTest.kt | 53 +++++++++++++++++++ 3 files changed, 90 insertions(+), 7 deletions(-) create mode 100644 src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel create mode 100644 src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt index e40fbcca7c0..a8e5e22ae21 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt @@ -33,11 +33,12 @@ import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.gcloud.gcs.GcsStorageClient import org.wfanet.measurement.securecomputation.requisitions.v1alpha.KingdomConfig import org.wfanet.measurement.securecomputation.requisitions.v1alpha.StorageConfig -import com.google.cloud.functions.CloudEventsFunction -import io.cloudevents.CloudEvent +import com.google.cloud.functions.HttpFunction +import com.google.cloud.functions.HttpRequest +import com.google.cloud.functions.HttpResponse import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.flow.filter import kotlinx.coroutines.runBlocking +import org.wfanet.measurement.api.v2alpha.ListRequisitionsRequestKt // 1. Polls for new requisitions @@ -45,12 +46,13 @@ import kotlinx.coroutines.runBlocking class RequisitionFetcher( private val kingdomConfig: KingdomConfig, private val storageConfig: StorageConfig, -): CloudEventsFunction { +): HttpFunction { - override fun accept(event: CloudEvent) { + override fun service(request: HttpRequest, response: HttpResponse) { runBlocking { storeRequisitions(fetchRequisitions()) } + response.writer.write("New requisitions persisted in GCS") } @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. @@ -60,6 +62,7 @@ class RequisitionFetcher( readCertificateCollection(File(kingdomConfig.certCollectionPath)), kingdomConfig.publicApiCertHost, ) + val requisitionsStub = RequisitionsCoroutineStub(publicChannel) return requisitionsStub @@ -70,14 +73,14 @@ class RequisitionFetcher( listRequisitions(listRequisitionsRequest { parent = kingdomConfig.dataProvider this.pageToken = pageToken + filter = ListRequisitionsRequestKt.filter { states += Requisition.State.UNFULFILLED } }) } catch (e: StatusException) { throw Exception("Unable to list requisitions.", e) } - ResourceList(response.requisitionsList, response.nextPageToken) + ResourceList(response.requisitionsList.filter { it.updateTime.seconds > storageConfig.lastUpdate.seconds }, response.nextPageToken) } .flattenConcat() - .filter { it.updateTime.seconds > storageConfig.lastUpdate.seconds } } private suspend fun storeRequisitions(requisitions: Flow) { diff --git a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel new file mode 100644 index 00000000000..4ebd3d6fa57 --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel @@ -0,0 +1,27 @@ +load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library", "kt_jvm_test") + +package( + default_testonly = True, + default_visibility = ["//visibility:public"], +) + +kt_jvm_test( + name = "requisition_fetcher_test", + srcs = ["RequisitionFetcherTest.kt"], + test_class = "org.wfanet.measurement.securecomputation.requisitions.RequisitionFetcherTest", + deps = [ + "//src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions:requisition_fetcher", + "//src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha:kingdom_config_kt_jvm_proto", + "//src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha:storage_config_kt_jvm_proto", + "@wfa_common_jvm//imports/java/com/google/cloud/functions:functions_framework_api", + "@wfa_common_jvm//imports/java/com/google/cloud/storage", + "@wfa_common_jvm//imports/java/com/google/events", + "@wfa_common_jvm//imports/java/com/google/protobuf", + "@wfa_common_jvm//imports/java/org/junit", + "@wfa_common_jvm//imports/java/org/mockito", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_common_jvm//imports/kotlin/org/mockito/kotlin", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/gcs", + ], +) diff --git a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt new file mode 100644 index 00000000000..f353eb14d5f --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt @@ -0,0 +1,53 @@ +package org.wfanet.measurement.securecomputation.requisitions + +import com.google.cloud.functions.HttpRequest +import com.google.cloud.functions.HttpResponse +import com.google.cloud.storage.Blob +import com.google.cloud.storage.Bucket +import com.google.cloud.storage.Storage +import com.google.protobuf.Timestamp +import java.io.BufferedWriter +import java.io.PrintWriter +import java.io.StringWriter +import java.net.HttpURLConnection +import java.net.URL +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.mockito.Mockito.* +import org.mockito.kotlin.whenever +import org.wfanet.measurement.gcloud.gcs.GcsFromFlags +import org.wfanet.measurement.gcloud.gcs.GcsStorageClient +import org.wfanet.measurement.securecomputation.requisitions.v1alpha.kingdomConfig +import org.wfanet.measurement.securecomputation.requisitions.v1alpha.storageConfig +import org.wfanet.measurement.storage.testing.InMemoryStorageClient + +@RunWith(JUnit4::class) +class RequisitionFetcherTest { + + @Test + fun `fetch new requisitions`() { + val request: HttpRequest = mock() + val response: HttpResponse = mock() + + //@TODO(jojijacob): Implement test + + val kingdomConfig = kingdomConfig { + publicApiTarget = "target" + publicApiCertHost = "cert-host" + certCollectionPath = "collection-path" + apiAuthenticationKey = "authentication-key" + dataProvider = "dataprovider" + } + + val storageConfig = storageConfig { + project = "project-id" + bucket = "bucket" + blobUri = "https://storage.googleapis.com/bucket/blob" + lastUpdate = Timestamp.getDefaultInstance() + } + val fetcher = RequisitionFetcher(kingdomConfig, storageConfig) + fetcher.service(request, response) + } + +} From fa5181bb8d59f5a49619c6a8cb0bb4867c64c122 Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Thu, 30 Jan 2025 09:13:38 -0800 Subject: [PATCH 05/13] Update logic in RequisitionFetcher to only store new requisitions in GCS bucket by checking what is already stored --- .../requisitions/RequisitionFetcher.kt | 17 ++++++++++++----- .../requisitions/v1alpha/storage_config.proto | 4 ---- .../requisitions/RequisitionFetcherTest.kt | 16 ---------------- 3 files changed, 12 insertions(+), 25 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt index a8e5e22ae21..b533aec3892 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt @@ -19,7 +19,6 @@ import com.google.cloud.storage.StorageOptions import io.grpc.StatusException import java.io.File import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.map import org.wfanet.measurement.api.v2alpha.ListRequisitionsResponse import org.wfanet.measurement.api.v2alpha.Requisition import org.wfanet.measurement.common.crypto.readCertificateCollection @@ -40,7 +39,6 @@ import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.runBlocking import org.wfanet.measurement.api.v2alpha.ListRequisitionsRequestKt - // 1. Polls for new requisitions // 2. Stores new requisitions into Google Cloud Storage class RequisitionFetcher( @@ -52,7 +50,7 @@ class RequisitionFetcher( runBlocking { storeRequisitions(fetchRequisitions()) } - response.writer.write("New requisitions persisted in GCS") + response.writer.write("New requisitions stored in GCS bucket: ${storageConfig.bucket}") } @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. @@ -70,6 +68,7 @@ class RequisitionFetcher( .listResources { pageToken -> val response: ListRequisitionsResponse = try { + // Retrieve all UNFULFILLED requisitions for a given EDP listRequisitions(listRequisitionsRequest { parent = kingdomConfig.dataProvider this.pageToken = pageToken @@ -78,7 +77,7 @@ class RequisitionFetcher( } catch (e: StatusException) { throw Exception("Unable to list requisitions.", e) } - ResourceList(response.requisitionsList.filter { it.updateTime.seconds > storageConfig.lastUpdate.seconds }, response.nextPageToken) + ResourceList(response.requisitionsList, response.nextPageToken) } .flattenConcat() } @@ -88,7 +87,15 @@ class RequisitionFetcher( StorageOptions.newBuilder().setProjectId(storageConfig.project).build().service, storageConfig.bucket ) - storageClient.writeBlob(storageConfig.blobUri, requisitions.map { it.toByteString() }) + + // Only stores the requisition if it does not already exist in the GCS bucket by checking if the blob URI(created + // using the requisition name, ensuring uniqueness) is populated. + requisitions.collect { requisition -> + val blobUri = "gs://${storageConfig.bucket}/${requisition.name}" + if(storageClient.getBlob(blobUri) != null) { + storageClient.writeBlob(blobUri, requisition.toByteString()) + } + } } } diff --git a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/storage_config.proto b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/storage_config.proto index 2229f79e109..96d3179ff1c 100644 --- a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/storage_config.proto +++ b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/storage_config.proto @@ -32,8 +32,4 @@ message StorageConfig { string project = 1; string bucket = 2; - - string blob_uri = 3; - - google.protobuf.Timestamp last_update = 4; } diff --git a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt index f353eb14d5f..3991e4f0f47 100644 --- a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt @@ -2,25 +2,12 @@ package org.wfanet.measurement.securecomputation.requisitions import com.google.cloud.functions.HttpRequest import com.google.cloud.functions.HttpResponse -import com.google.cloud.storage.Blob -import com.google.cloud.storage.Bucket -import com.google.cloud.storage.Storage -import com.google.protobuf.Timestamp -import java.io.BufferedWriter -import java.io.PrintWriter -import java.io.StringWriter -import java.net.HttpURLConnection -import java.net.URL import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.mockito.Mockito.* -import org.mockito.kotlin.whenever -import org.wfanet.measurement.gcloud.gcs.GcsFromFlags -import org.wfanet.measurement.gcloud.gcs.GcsStorageClient import org.wfanet.measurement.securecomputation.requisitions.v1alpha.kingdomConfig import org.wfanet.measurement.securecomputation.requisitions.v1alpha.storageConfig -import org.wfanet.measurement.storage.testing.InMemoryStorageClient @RunWith(JUnit4::class) class RequisitionFetcherTest { @@ -43,11 +30,8 @@ class RequisitionFetcherTest { val storageConfig = storageConfig { project = "project-id" bucket = "bucket" - blobUri = "https://storage.googleapis.com/bucket/blob" - lastUpdate = Timestamp.getDefaultInstance() } val fetcher = RequisitionFetcher(kingdomConfig, storageConfig) fetcher.service(request, response) } - } From 7da86b36701a53451c1ba5394ef1a8fd0ad304e5 Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Thu, 30 Jan 2025 13:54:48 -0800 Subject: [PATCH 06/13] Change architecture to mirror that of RequisitionFulfiller --- .../requisitions/BUILD.bazel | 17 ++- .../requisitions/RequisitionFetcher.kt | 106 ++++++++---------- .../requisitions/RequisitionFetcherRunner.kt | 100 +++++++++++++++++ .../requisitions/v1alpha/BUILD.bazel | 39 ------- .../requisitions/v1alpha/kingdom_config.proto | 39 ------- .../requisitions/v1alpha/storage_config.proto | 35 ------ .../requisitions/BUILD.bazel | 12 -- .../requisitions/RequisitionFetcherTest.kt | 24 ---- 8 files changed, 154 insertions(+), 218 deletions(-) create mode 100644 src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt delete mode 100644 src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/BUILD.bazel delete mode 100644 src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/kingdom_config.proto delete mode 100644 src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/storage_config.proto diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel index dd0ecd64c51..fdf7d116918 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel @@ -6,17 +6,16 @@ kt_jvm_library( name = "requisition_fetcher", srcs = ["RequisitionFetcher.kt"], deps = [ - "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/proto/wfa/measurement/api/v2alpha:requisitions_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha:kingdom_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha:storage_config_kt_jvm_proto", - "@wfa_common_jvm//imports/java/com/google/cloud/functions:functions_framework_api", - "@wfa_common_jvm//imports/java/com/google/cloud/storage", - "@wfa_common_jvm//imports/java/com/google/events", - "@wfa_common_jvm//imports/java/com/google/protobuf", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/gcs", ], ) + +kt_jvm_library( + name = "requisition_fetcher_runner", + srcs = ["RequisitionFetcherRunner.kt"], + deps = [ + ":requisition_fetcher", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt index b533aec3892..3f846423537 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt @@ -15,88 +15,74 @@ */ package org.wfanet.measurement.securecomputation.requisitions -import com.google.cloud.storage.StorageOptions import io.grpc.StatusException -import java.io.File -import kotlinx.coroutines.flow.Flow -import org.wfanet.measurement.api.v2alpha.ListRequisitionsResponse import org.wfanet.measurement.api.v2alpha.Requisition -import org.wfanet.measurement.common.crypto.readCertificateCollection -import org.wfanet.measurement.common.grpc.buildTlsChannel import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCoroutineStub import org.wfanet.measurement.api.v2alpha.listRequisitionsRequest -import org.wfanet.measurement.api.withAuthenticationKey -import org.wfanet.measurement.common.api.grpc.ResourceList -import org.wfanet.measurement.common.api.grpc.flattenConcat -import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.gcloud.gcs.GcsStorageClient -import org.wfanet.measurement.securecomputation.requisitions.v1alpha.KingdomConfig -import org.wfanet.measurement.securecomputation.requisitions.v1alpha.StorageConfig -import com.google.cloud.functions.HttpFunction -import com.google.cloud.functions.HttpRequest -import com.google.cloud.functions.HttpResponse -import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.runBlocking import org.wfanet.measurement.api.v2alpha.ListRequisitionsRequestKt +import org.wfanet.measurement.api.v2alpha.Measurement +import org.wfanet.measurement.common.throttler.Throttler +import java.util.logging.Logger // 1. Polls for new requisitions // 2. Stores new requisitions into Google Cloud Storage class RequisitionFetcher( - private val kingdomConfig: KingdomConfig, - private val storageConfig: StorageConfig, -): HttpFunction { + private val requisitionsStub: RequisitionsCoroutineStub, + private val gcsStorageClient: GcsStorageClient, + private val gcsBucket: String, + private val dataProviderName: String, + private val throttler: Throttler, + ) { - override fun service(request: HttpRequest, response: HttpResponse) { - runBlocking { - storeRequisitions(fetchRequisitions()) - } - response.writer.write("New requisitions stored in GCS bucket: ${storageConfig.bucket}") + suspend fun run() { + throttler.loopOnReady { executeRequisitionFetchingWorkflow() } } - @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. - private suspend fun fetchRequisitions(): Flow { - val publicChannel = buildTlsChannel( - kingdomConfig.publicApiTarget, - readCertificateCollection(File(kingdomConfig.certCollectionPath)), - kingdomConfig.publicApiCertHost, - ) + suspend fun executeRequisitionFetchingWorkflow() { + logger.info("Executing requisitionFetchingWorkflow for $dataProviderName...") + + val requisitions = fetchRequisitions() - val requisitionsStub = RequisitionsCoroutineStub(publicChannel) + if (requisitions.isEmpty()) { + logger.fine("No unfulfilled requisitions for $dataProviderName. Polling again later...") + return + } + + storeRequisitions(requisitions) + } - return requisitionsStub - .withAuthenticationKey(kingdomConfig.apiAuthenticationKey) - .listResources { pageToken -> - val response: ListRequisitionsResponse = - try { - // Retrieve all UNFULFILLED requisitions for a given EDP - listRequisitions(listRequisitionsRequest { - parent = kingdomConfig.dataProvider - this.pageToken = pageToken - filter = ListRequisitionsRequestKt.filter { states += Requisition.State.UNFULFILLED } - }) - } catch (e: StatusException) { - throw Exception("Unable to list requisitions.", e) - } - ResourceList(response.requisitionsList, response.nextPageToken) + private suspend fun fetchRequisitions(): List { + val request = listRequisitionsRequest { + parent = dataProviderName + filter = ListRequisitionsRequestKt.filter { + states += Requisition.State.UNFULFILLED + measurementStates += Measurement.State.AWAITING_REQUISITION_FULFILLMENT } - .flattenConcat() + } + + try { + return requisitionsStub.listRequisitions(request).requisitionsList + } catch (e: StatusException) { + throw Exception("Error listing requisitions", e) + } } - private suspend fun storeRequisitions(requisitions: Flow) { - val storageClient = GcsStorageClient( - StorageOptions.newBuilder().setProjectId(storageConfig.project).build().service, - storageConfig.bucket - ) + private suspend fun storeRequisitions(requisitions: List) { + for (requisition in requisitions) { + val blobUri = "gs://${gcsBucket}/${requisition.name}" - // Only stores the requisition if it does not already exist in the GCS bucket by checking if the blob URI(created - // using the requisition name, ensuring uniqueness) is populated. - requisitions.collect { requisition -> - val blobUri = "gs://${storageConfig.bucket}/${requisition.name}" - if(storageClient.getBlob(blobUri) != null) { - storageClient.writeBlob(blobUri, requisition.toByteString()) + // Only stores the requisition if it does not already exist in the GCS bucket by checking if the blob URI(created + // using the requisition name, ensuring uniqueness) is populated. + if(gcsStorageClient.getBlob(blobUri) != null) { + gcsStorageClient.writeBlob(blobUri, requisition.toByteString()) } } } + + companion object { + val logger: Logger = Logger.getLogger(this::class.java.name) + } } diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt new file mode 100644 index 00000000000..ae7292a5dc6 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt @@ -0,0 +1,100 @@ +package org.wfanet.measurement.securecomputation.requisitions + +import com.google.cloud.storage.StorageOptions +import java.time.Clock +import java.time.Duration +import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCoroutineStub +import org.wfanet.measurement.common.crypto.SigningCerts +import org.wfanet.measurement.common.grpc.TlsFlags +import org.wfanet.measurement.common.grpc.buildMutualTlsChannel +import org.wfanet.measurement.common.throttler.MinimumIntervalThrottler +import org.wfanet.measurement.gcloud.gcs.GcsStorageClient +import picocli.CommandLine + +class RequisitionFetcherRunner: Runnable { + @CommandLine.Option( + names = ["--kingdom-public-api-target"], + description = ["gRPC target (authority) of the Kingdom public API server"], + required = true, + ) + private lateinit var target: String + + @CommandLine.Option( + names = ["--kingdom-public-api-cert-host"], + description = + [ + "Expected hostname (DNS-ID) in the Kingdom public API server's TLS certificate.", + "This overrides derivation of the TLS DNS-ID from --kingdom-public-api-target.", + ], + required = true, + ) + private lateinit var certHost: String + + @CommandLine.Option( + names = ["--gcs-project-id"], + description = + [ + "Project ID for the GCS instance where the new requisitions will be stored." + ], + required = true, + ) + private lateinit var gcsProjectId: String + + @CommandLine.Option( + names = ["--gcs-bucket"], + description = + [ + "Name of the bucket within the GCS instance where the new requisitions will be stored" + ], + required = true, + ) + private lateinit var gcsBucket: String + + @CommandLine.Option( + names = ["--data-provider-resource-name"], + description = + [ + "The public API resource name of the data provider for which requisitions will be fetched." + ], + required = true, + ) + private lateinit var dataProviderResourceName: String + + @CommandLine.Option( + names = ["--throttler-minimum-interval"], + description = ["Minimum throttle interval"], + defaultValue = "300s", // 5 minutes + ) + private lateinit var throttlerMinimumInterval: Duration + + @CommandLine.Mixin + lateinit var tlsFlags: TlsFlags + private set + + + + override fun run() { + val clientCerts = + SigningCerts.fromPemFiles( + certificateFile = tlsFlags.certFile, + privateKeyFile = tlsFlags.privateKeyFile, + trustedCertCollectionFile = tlsFlags.certCollectionFile, + ) + + val publicChannel = buildMutualTlsChannel( + target, + clientCerts, + certHost, + ) + + val requisitionsStub = RequisitionsCoroutineStub(publicChannel) + val gcsStorageClient = GcsStorageClient( + StorageOptions.newBuilder().setProjectId(gcsProjectId).build().service, + gcsBucket + ) + + val throttler = MinimumIntervalThrottler(Clock.systemUTC(), throttlerMinimumInterval) + + RequisitionFetcher(requisitionsStub, gcsStorageClient, dataProviderResourceName, gcsBucket, throttler) + } +} diff --git a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/BUILD.bazel b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/BUILD.bazel deleted file mode 100644 index fb5ca5d0c42..00000000000 --- a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/BUILD.bazel +++ /dev/null @@ -1,39 +0,0 @@ -load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_proto_library") - -package(default_visibility = ["//visibility:public"]) - -IMPORT_PREFIX = "/src/main/proto" - -proto_library( - name = "kingdom_config_proto", - srcs = ["kingdom_config.proto"], - strip_import_prefix = IMPORT_PREFIX, - deps = [ - "@com_google_googleapis//google/api:field_behavior_proto", - "@com_google_googleapis//google/api:resource_proto", - "@com_google_protobuf//:any_proto", - ], -) - -kt_jvm_proto_library( - name = "kingdom_config_kt_jvm_proto", - deps = [":kingdom_config_proto"], -) - -proto_library( - name = "storage_config_proto", - srcs = ["storage_config.proto"], - strip_import_prefix = IMPORT_PREFIX, - deps = [ - "@com_google_googleapis//google/api:field_behavior_proto", - "@com_google_googleapis//google/api:resource_proto", - "@com_google_protobuf//:any_proto", - "@com_google_protobuf//:timestamp_proto", - ], -) - -kt_jvm_proto_library( - name = "storage_config_kt_jvm_proto", - deps = [":storage_config_proto"], -) diff --git a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/kingdom_config.proto b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/kingdom_config.proto deleted file mode 100644 index 40c11e39431..00000000000 --- a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/kingdom_config.proto +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2025 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.securecomputation.requisitions.v1alpha; - -import "google/protobuf/any.proto"; -import "google/api/field_behavior.proto"; -import "google/api/resource.proto"; - -option java_package = "org.wfanet.measurement.securecomputation.requisitions.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "KingdomConfigProto"; - - - -message KingdomConfig { - string public_api_cert_host = 1; - - string public_api_target = 2; - - string cert_collection_path = 3; - - string api_authentication_key = 4; - - string data_provider = 5; -} diff --git a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/storage_config.proto b/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/storage_config.proto deleted file mode 100644 index 96d3179ff1c..00000000000 --- a/src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha/storage_config.proto +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2025 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.securecomputation.requisitions.v1alpha; - -import "google/protobuf/any.proto"; -import "google/protobuf/timestamp.proto"; -import "google/api/field_behavior.proto"; -import "google/api/resource.proto"; - - -option java_package = "org.wfanet.measurement.securecomputation.requisitions.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "StorageConfigProto"; - - - -message StorageConfig { - string project = 1; - - string bucket = 2; -} diff --git a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel index 4ebd3d6fa57..b6fc25280d6 100644 --- a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel @@ -10,18 +10,6 @@ kt_jvm_test( srcs = ["RequisitionFetcherTest.kt"], test_class = "org.wfanet.measurement.securecomputation.requisitions.RequisitionFetcherTest", deps = [ - "//src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions:requisition_fetcher", - "//src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha:kingdom_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/securecomputation/requisitions/v1alpha:storage_config_kt_jvm_proto", - "@wfa_common_jvm//imports/java/com/google/cloud/functions:functions_framework_api", - "@wfa_common_jvm//imports/java/com/google/cloud/storage", - "@wfa_common_jvm//imports/java/com/google/events", - "@wfa_common_jvm//imports/java/com/google/protobuf", "@wfa_common_jvm//imports/java/org/junit", - "@wfa_common_jvm//imports/java/org/mockito", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//imports/kotlin/org/mockito/kotlin", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/gcs", ], ) diff --git a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt index 3991e4f0f47..24b80d0efb2 100644 --- a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt @@ -1,37 +1,13 @@ package org.wfanet.measurement.securecomputation.requisitions -import com.google.cloud.functions.HttpRequest -import com.google.cloud.functions.HttpResponse import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 -import org.mockito.Mockito.* -import org.wfanet.measurement.securecomputation.requisitions.v1alpha.kingdomConfig -import org.wfanet.measurement.securecomputation.requisitions.v1alpha.storageConfig @RunWith(JUnit4::class) class RequisitionFetcherTest { - @Test fun `fetch new requisitions`() { - val request: HttpRequest = mock() - val response: HttpResponse = mock() - //@TODO(jojijacob): Implement test - - val kingdomConfig = kingdomConfig { - publicApiTarget = "target" - publicApiCertHost = "cert-host" - certCollectionPath = "collection-path" - apiAuthenticationKey = "authentication-key" - dataProvider = "dataprovider" - } - - val storageConfig = storageConfig { - project = "project-id" - bucket = "bucket" - } - val fetcher = RequisitionFetcher(kingdomConfig, storageConfig) - fetcher.service(request, response) } } From 696255334e1c5c03969fdebd4fac6864d12305e8 Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Tue, 4 Feb 2025 08:34:45 -0800 Subject: [PATCH 07/13] Restructure code so runner implements HttpFunction interface and calls RequisitionFetcher. Adds test. --- src/main/k8s/testing/secretfiles/BUILD.bazel | 1 + .../requisitions/BUILD.bazel | 13 ++ .../requisitions/RequisitionFetcher.kt | 9 +- .../requisitions/RequisitionFetcherRunner.kt | 122 ++++------- .../requisitions/BUILD.bazel | 25 +++ .../requisitions/RequisitionFetcherTest.kt | 192 +++++++++++++++++- 6 files changed, 275 insertions(+), 87 deletions(-) diff --git a/src/main/k8s/testing/secretfiles/BUILD.bazel b/src/main/k8s/testing/secretfiles/BUILD.bazel index 8a8341dd2fb..ad90d4b05be 100644 --- a/src/main/k8s/testing/secretfiles/BUILD.bazel +++ b/src/main/k8s/testing/secretfiles/BUILD.bazel @@ -14,6 +14,7 @@ package( "//src/test/kotlin/org/wfanet/measurement/kingdom/batch:__subpackages__", "//src/test/kotlin/org/wfanet/measurement/loadtest:__subpackages__", "//src/test/kotlin/org/wfanet/measurement/reporting:__subpackages__", + "//src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions:__subpackages__", "//src/test/kotlin/org/wfanet/panelmatch/integration:__subpackages__", ], ) diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel index fdf7d116918..664a389675a 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel @@ -7,11 +7,18 @@ kt_jvm_library( srcs = ["RequisitionFetcher.kt"], deps = [ "//src/main/proto/wfa/measurement/api/v2alpha:requisitions_service_kt_jvm_grpc_proto", + "@wfa_common_jvm//imports/java/com/google/cloud/functions:functions_framework_api", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/gcs", ], ) +java_binary( + name = "RequisitionFetcher", + main_class = "org.wfanet.measurement.securecomputation.requisitions.RequisitionFetcher", + runtime_deps = [":requisition_fetcher"], +) + kt_jvm_library( name = "requisition_fetcher_runner", srcs = ["RequisitionFetcherRunner.kt"], @@ -19,3 +26,9 @@ kt_jvm_library( ":requisition_fetcher", ], ) + +java_binary( + name = "RequisitionFetcherRunner", + main_class = "org.wfanet.measurement.securecomputation.requisitions.RequisitionFetcherRunner", + runtime_deps = [":requisition_fetcher_runner"], +) diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt index 3f846423537..72ce0dc88a5 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt @@ -22,7 +22,6 @@ import org.wfanet.measurement.api.v2alpha.listRequisitionsRequest import org.wfanet.measurement.gcloud.gcs.GcsStorageClient import org.wfanet.measurement.api.v2alpha.ListRequisitionsRequestKt import org.wfanet.measurement.api.v2alpha.Measurement -import org.wfanet.measurement.common.throttler.Throttler import java.util.logging.Logger // 1. Polls for new requisitions @@ -32,13 +31,7 @@ class RequisitionFetcher( private val gcsStorageClient: GcsStorageClient, private val gcsBucket: String, private val dataProviderName: String, - private val throttler: Throttler, ) { - - suspend fun run() { - throttler.loopOnReady { executeRequisitionFetchingWorkflow() } - } - suspend fun executeRequisitionFetchingWorkflow() { logger.info("Executing requisitionFetchingWorkflow for $dataProviderName...") @@ -74,7 +67,7 @@ class RequisitionFetcher( // Only stores the requisition if it does not already exist in the GCS bucket by checking if the blob URI(created // using the requisition name, ensuring uniqueness) is populated. - if(gcsStorageClient.getBlob(blobUri) != null) { + if(gcsStorageClient.getBlob(blobUri) == null) { gcsStorageClient.writeBlob(blobUri, requisition.toByteString()) } } diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt index ae7292a5dc6..d957790480c 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt @@ -1,100 +1,66 @@ package org.wfanet.measurement.securecomputation.requisitions +import com.google.cloud.functions.HttpFunction +import com.google.cloud.functions.HttpRequest +import com.google.cloud.functions.HttpResponse import com.google.cloud.storage.StorageOptions -import java.time.Clock -import java.time.Duration +import java.io.File +import kotlinx.coroutines.runBlocking import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCoroutineStub import org.wfanet.measurement.common.crypto.SigningCerts -import org.wfanet.measurement.common.grpc.TlsFlags import org.wfanet.measurement.common.grpc.buildMutualTlsChannel -import org.wfanet.measurement.common.throttler.MinimumIntervalThrottler +import org.wfanet.measurement.common.toByteArray import org.wfanet.measurement.gcloud.gcs.GcsStorageClient -import picocli.CommandLine -class RequisitionFetcherRunner: Runnable { - @CommandLine.Option( - names = ["--kingdom-public-api-target"], - description = ["gRPC target (authority) of the Kingdom public API server"], - required = true, - ) - private lateinit var target: String +class RequisitionFetcherRunner: HttpFunction { + override fun service(request: HttpRequest?, response: HttpResponse?) { + val clientCerts = runBlocking { + getClientCerts() + } - @CommandLine.Option( - names = ["--kingdom-public-api-cert-host"], - description = - [ - "Expected hostname (DNS-ID) in the Kingdom public API server's TLS certificate.", - "This overrides derivation of the TLS DNS-ID from --kingdom-public-api-target.", - ], - required = true, - ) - private lateinit var certHost: String + val publicChannel = buildMutualTlsChannel( + System.getenv("TARGET"), + clientCerts, + System.getenv("CERT_HOST"), + ) - @CommandLine.Option( - names = ["--gcs-project-id"], - description = - [ - "Project ID for the GCS instance where the new requisitions will be stored." - ], - required = true, - ) - private lateinit var gcsProjectId: String + val requisitionsStub = RequisitionsCoroutineStub(publicChannel) + val requisitionsStorageClient = GcsStorageClient( + StorageOptions.newBuilder().setProjectId(System.getenv("REQUISITIONS_GCS_PROJECT_ID")).build().service, + System.getenv("REQUISITIONS_GCS_BUCKET") + ) - @CommandLine.Option( - names = ["--gcs-bucket"], - description = - [ - "Name of the bucket within the GCS instance where the new requisitions will be stored" - ], - required = true, - ) - private lateinit var gcsBucket: String + val fetcher = RequisitionFetcher(requisitionsStub, requisitionsStorageClient, System.getenv("DATAPROVIDER_NAME"), System.getenv("GCS_BUCKET")) + runBlocking { + fetcher.executeRequisitionFetchingWorkflow() + } + } - @CommandLine.Option( - names = ["--data-provider-resource-name"], - description = - [ - "The public API resource name of the data provider for which requisitions will be fetched." - ], - required = true, - ) - private lateinit var dataProviderResourceName: String + private suspend fun getClientCerts(): SigningCerts { + val authenticationStorageClient = GcsStorageClient( + StorageOptions.newBuilder().setProjectId(System.getenv("AUTHENTICATION_GCS_PROJECT_ID")).build().service, + System.getenv("AUTHENTICATION_GCS_BUCKET") + ) - @CommandLine.Option( - names = ["--throttler-minimum-interval"], - description = ["Minimum throttle interval"], - defaultValue = "300s", // 5 minutes - ) - private lateinit var throttlerMinimumInterval: Duration + val certBlob = authenticationStorageClient.getBlob("gs://${System.getenv("AUTHENTICATION_GCS_BUCKET")}/${System.getenv("CERT_FILE_PATH")}") + val privateKeyBlob = authenticationStorageClient.getBlob("gs://${System.getenv("AUTHENTICATION_GCS_BUCKET")}/${System.getenv("PRIVATE_KEY_FILE_PATH")}") + val certCollectionBlob = authenticationStorageClient.getBlob("gs://${System.getenv("AUTHENTICATION_GCS_BUCKET")}/${System.getenv("CERT_COLLECTION_FILE_PATH")}") - @CommandLine.Mixin - lateinit var tlsFlags: TlsFlags - private set + val certFile = File.createTempFile("cert", ".pem") + certFile.writeBytes(checkNotNull(certBlob).read().toByteArray()) + val privateKeyFile = File.createTempFile("private_key", ".key") + privateKeyFile.writeBytes(checkNotNull(privateKeyBlob).read().toByteArray()) - override fun run() { - val clientCerts = - SigningCerts.fromPemFiles( - certificateFile = tlsFlags.certFile, - privateKeyFile = tlsFlags.privateKeyFile, - trustedCertCollectionFile = tlsFlags.certCollectionFile, - ) + val certCollectionFile = File.createTempFile("cert_collection", ".pem") + certCollectionFile.writeBytes(checkNotNull(certCollectionBlob).read().toByteArray()) - val publicChannel = buildMutualTlsChannel( - target, - clientCerts, - certHost, - ) - val requisitionsStub = RequisitionsCoroutineStub(publicChannel) - val gcsStorageClient = GcsStorageClient( - StorageOptions.newBuilder().setProjectId(gcsProjectId).build().service, - gcsBucket + return SigningCerts.fromPemFiles( + certificateFile = certFile, + privateKeyFile = privateKeyFile, + trustedCertCollectionFile = certCollectionFile, ) - - val throttler = MinimumIntervalThrottler(Clock.systemUTC(), throttlerMinimumInterval) - - RequisitionFetcher(requisitionsStub, gcsStorageClient, dataProviderResourceName, gcsBucket, throttler) } } diff --git a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel index b6fc25280d6..7df5ee44762 100644 --- a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel @@ -8,8 +8,33 @@ package( kt_jvm_test( name = "requisition_fetcher_test", srcs = ["RequisitionFetcherTest.kt"], + data = [ + "//src/main/k8s/testing/secretfiles:all_configs", + "//src/main/k8s/testing/secretfiles:all_der_files", + "//src/main/k8s/testing/secretfiles:all_tink_keysets", + "//src/main/k8s/testing/secretfiles:edp_trusted_certs.pem", + ], test_class = "org.wfanet.measurement.securecomputation.requisitions.RequisitionFetcherTest", deps = [ + "//src/main/kotlin/org/wfanet/measurement/api/v2alpha/testing", + "//src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions:requisition_fetcher", + "//src/main/proto/wfa/measurement/api/v2alpha:requisitions_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/api/v2alpha/event_templates/testing:test_event_kt_jvm_proto", + "@wfa_common_jvm//imports/java/com/google/cloud/storage/contrib/nio", + "@wfa_common_jvm//imports/java/com/google/common/truth", + "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", "@wfa_common_jvm//imports/java/org/junit", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/testing", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/tink", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing", + "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/duchy", + "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/measurementconsumer", ], ) + +java_test( + name = "RequisitionFetcherTest", + test_class = "org.wfanet.measurement.securecomputation.requisitions.RequisitionFetcherTest", + runtime_deps = [":requisition_fetcher_test"], +) diff --git a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt index 24b80d0efb2..e53dc0a87b1 100644 --- a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt @@ -1,13 +1,203 @@ package org.wfanet.measurement.securecomputation.requisitions +import com.google.cloud.storage.contrib.nio.testing.LocalStorageHelper +import com.google.protobuf.ByteString +import com.google.protobuf.kotlin.toByteString +import java.nio.file.Path +import java.nio.file.Paths +import kotlin.random.Random +import kotlinx.coroutines.runBlocking +import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 +import org.mockito.kotlin.any +import org.wfanet.measurement.api.v2alpha.Certificate +import org.wfanet.measurement.api.v2alpha.DataProviderCertificateKey +import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey +import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt +import org.wfanet.measurement.api.v2alpha.ProtocolConfig +import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt +import org.wfanet.measurement.api.v2alpha.Requisition +import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt +import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt +import org.wfanet.measurement.api.v2alpha.certificate +import org.wfanet.measurement.api.v2alpha.differentialPrivacyParams +import org.wfanet.measurement.api.v2alpha.event_templates.testing.Person +import org.wfanet.measurement.api.v2alpha.listRequisitionsResponse +import org.wfanet.measurement.api.v2alpha.measurementSpec +import org.wfanet.measurement.api.v2alpha.protocolConfig +import org.wfanet.measurement.api.v2alpha.requisition +import org.wfanet.measurement.api.v2alpha.requisitionSpec +import org.wfanet.measurement.common.crypto.Hashing +import org.wfanet.measurement.common.crypto.SigningKeyHandle +import org.wfanet.measurement.common.crypto.subjectKeyIdentifier +import org.wfanet.measurement.common.crypto.tink.TinkPrivateKeyHandle +import org.wfanet.measurement.common.crypto.tink.loadPrivateKey +import org.wfanet.measurement.common.crypto.tink.loadPublicKey +import org.wfanet.measurement.common.getRuntimePath +import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule +import org.wfanet.measurement.common.grpc.testing.mockService +import org.wfanet.measurement.common.identity.externalIdToApiId +import org.wfanet.measurement.common.pack +import org.wfanet.measurement.common.readByteString +import org.wfanet.measurement.consent.client.common.toEncryptionPublicKey +import org.wfanet.measurement.consent.client.measurementconsumer.encryptRequisitionSpec +import org.wfanet.measurement.consent.client.measurementconsumer.signMeasurementSpec +import org.wfanet.measurement.consent.client.measurementconsumer.signRequisitionSpec +import org.wfanet.measurement.gcloud.gcs.GcsStorageClient +import com.google.common.truth.Truth.assertThat +import org.wfanet.measurement.common.toByteArray @RunWith(JUnit4::class) class RequisitionFetcherTest { + + private val requisitionsServiceMock: RequisitionsGrpcKt.RequisitionsCoroutineImplBase = mockService { + onBlocking { listRequisitions(any()) } + .thenReturn(listRequisitionsResponse { requisitions += REQUISITION }) + } + + @get:Rule + val grpcTestServerRule = GrpcTestServerRule { + addService(requisitionsServiceMock) + } + private val requisitionsStub: RequisitionsGrpcKt.RequisitionsCoroutineStub by lazy { + RequisitionsGrpcKt.RequisitionsCoroutineStub(grpcTestServerRule.channel) + } @Test fun `fetch new requisitions`() { - //@TODO(jojijacob): Implement test + val storage = LocalStorageHelper.getOptions().service + val storageClient = GcsStorageClient(storage, BUCKET) + val fetcher = RequisitionFetcher(requisitionsStub, storageClient, BUCKET, DATA_PROVIDER_NAME) + var persistedRequisition: ByteString? + runBlocking { + fetcher.executeRequisitionFetchingWorkflow() + persistedRequisition = storageClient.getBlob("gs://${BUCKET}/${REQUISITION.name}")?.read()?.toByteArray()?.toByteString() + } + println(persistedRequisition) + assertThat(REQUISITION.toByteString()).isEqualTo(persistedRequisition) + } + + companion object { + private const val BUCKET = "requisition-storage-test-bucket" + private const val MC_ID = "mc" + private const val MC_NAME = "measurementConsumers/$MC_ID" + private const val PDP_DISPLAY_NAME = "pdp1" + private val SECRET_FILES_PATH: Path = + checkNotNull( + getRuntimePath( + Paths.get("wfa_measurement_system", "src", "main", "k8s", "testing", "secretfiles") + ) + ) + private const val PDP_ID = "somePopulationDataProvider" + private const val PDP_NAME = "dataProviders/$PDP_ID" + + private val MEASUREMENT_CONSUMER_CERTIFICATE_DER = + SECRET_FILES_PATH.resolve("edp_trusted_certs.pem").toFile().readByteString() + private const val MEASUREMENT_CONSUMER_NAME = "measurementConsumers/AAAAAAAAAHs" + private const val MEASUREMENT_NAME = "$MC_NAME/measurements/BBBBBBBBBHs" + private const val MEASUREMENT_CONSUMER_CERTIFICATE_NAME = + "$MEASUREMENT_CONSUMER_NAME/certificates/AAAAAAAAAcg" + private val MEASUREMENT_CONSUMER_CERTIFICATE = certificate { + name = MEASUREMENT_CONSUMER_CERTIFICATE_NAME + x509Der = MEASUREMENT_CONSUMER_CERTIFICATE_DER + } + + private const val DATA_PROVIDER_NAME = "dataProviders/AAAAAAAAAHs" + + private const val MODEL_PROVIDER_NAME = "modelProviders/AAAAAAAAAHs" + private const val MODEL_SUITE_NAME = "$MODEL_PROVIDER_NAME/modelSuites/AAAAAAAAAHs" + + private const val MODEL_LINE_NAME = "${MODEL_SUITE_NAME}/modelLines/AAAAAAAAAHs" + + private val MC_SIGNING_KEY: SigningKeyHandle = + loadSigningKey("${MC_ID}_cs_cert.der", "${MC_ID}_cs_private.der") + private val PDP_SIGNING_KEY: SigningKeyHandle = + loadSigningKey("${PDP_DISPLAY_NAME}_cs_cert.der", "${PDP_DISPLAY_NAME}_cs_private.der") + private val DATA_PROVIDER_CERTIFICATE_KEY: DataProviderCertificateKey = + DataProviderCertificateKey(PDP_ID, externalIdToApiId(8L)) + + private val DATA_PROVIDER_CERTIFICATE: Certificate = certificate { + name = DATA_PROVIDER_CERTIFICATE_KEY.toName() + x509Der = PDP_SIGNING_KEY.certificate.encoded.toByteString() + subjectKeyIdentifier = PDP_SIGNING_KEY.certificate.subjectKeyIdentifier!! + } + + private val MC_PUBLIC_KEY: EncryptionPublicKey = + loadPublicKey(SECRET_FILES_PATH.resolve("mc_enc_public.tink").toFile()) + .toEncryptionPublicKey() + private val MC_PRIVATE_KEY: TinkPrivateKeyHandle = + loadPrivateKey(SECRET_FILES_PATH.resolve("mc_enc_private.tink").toFile()) + private val DATA_PROVIDER_PUBLIC_KEY: EncryptionPublicKey = + loadPublicKey(SECRET_FILES_PATH.resolve("${PDP_DISPLAY_NAME}_enc_public.tink").toFile()) + .toEncryptionPublicKey() + + private val REQUISITION_SPEC = requisitionSpec { + population = + RequisitionSpecKt.population { + filter = RequisitionSpecKt.eventFilter { + expression = "person.age_group == ${Person.AgeGroup.YEARS_18_TO_34_VALUE}" + } + } + measurementPublicKey = MC_PUBLIC_KEY.pack() + nonce = Random.nextLong() + } + private val ENCRYPTED_REQUISITION_SPEC = + encryptRequisitionSpec( + signRequisitionSpec(REQUISITION_SPEC, MC_SIGNING_KEY), + DATA_PROVIDER_PUBLIC_KEY, + ) + + private val OUTPUT_DP_PARAMS = differentialPrivacyParams { + epsilon = 1.0 + delta = 1E-12 + } + private val MEASUREMENT_SPEC = measurementSpec { + measurementPublicKey = MC_PUBLIC_KEY.pack() + reachAndFrequency = MeasurementSpecKt.reachAndFrequency { + reachPrivacyParams = OUTPUT_DP_PARAMS + frequencyPrivacyParams = OUTPUT_DP_PARAMS + maximumFrequency = 10 + } + vidSamplingInterval = MeasurementSpecKt.vidSamplingInterval { + start = 0.0f + width = 1.0f + } + nonceHashes += Hashing.hashSha256(REQUISITION_SPEC.nonce) + modelLine = MODEL_LINE_NAME + } + + private val REQUISITION = requisition { + name = "${PDP_NAME}/requisitions/foo" + measurement = MEASUREMENT_NAME + state = Requisition.State.UNFULFILLED + measurementConsumerCertificate = MEASUREMENT_CONSUMER_CERTIFICATE_NAME + measurementSpec = signMeasurementSpec(MEASUREMENT_SPEC, MC_SIGNING_KEY) + encryptedRequisitionSpec = ENCRYPTED_REQUISITION_SPEC + protocolConfig = protocolConfig { + protocols += + ProtocolConfigKt.protocol { + direct = + ProtocolConfigKt.direct { + deterministicCountDistinct = + ProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + deterministicDistribution = + ProtocolConfig.Direct.DeterministicDistribution.getDefaultInstance() + } + } + } + dataProviderCertificate = DATA_PROVIDER_CERTIFICATE.name + dataProviderPublicKey = DATA_PROVIDER_PUBLIC_KEY.pack() + } + + private fun loadSigningKey( + certDerFileName: String, + privateKeyDerFileName: String, + ): SigningKeyHandle { + return org.wfanet.measurement.common.crypto.testing.loadSigningKey( + SECRET_FILES_PATH.resolve(certDerFileName).toFile(), + SECRET_FILES_PATH.resolve(privateKeyDerFileName).toFile(), + ) + } } } From 341ec9e120dc33433eedeb4c2123bad0e6253f0c Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Tue, 4 Feb 2025 08:43:32 -0800 Subject: [PATCH 08/13] Lint. --- .../requisitions/RequisitionFetcher.kt | 24 +++---- .../requisitions/RequisitionFetcherRunner.kt | 70 ++++++++++++------- .../requisitions/RequisitionFetcherTest.kt | 52 ++++++++------ 3 files changed, 85 insertions(+), 61 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt index 72ce0dc88a5..82e5f347305 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt @@ -16,13 +16,13 @@ package org.wfanet.measurement.securecomputation.requisitions import io.grpc.StatusException +import java.util.logging.Logger +import org.wfanet.measurement.api.v2alpha.ListRequisitionsRequestKt +import org.wfanet.measurement.api.v2alpha.Measurement import org.wfanet.measurement.api.v2alpha.Requisition import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCoroutineStub import org.wfanet.measurement.api.v2alpha.listRequisitionsRequest import org.wfanet.measurement.gcloud.gcs.GcsStorageClient -import org.wfanet.measurement.api.v2alpha.ListRequisitionsRequestKt -import org.wfanet.measurement.api.v2alpha.Measurement -import java.util.logging.Logger // 1. Polls for new requisitions // 2. Stores new requisitions into Google Cloud Storage @@ -31,7 +31,7 @@ class RequisitionFetcher( private val gcsStorageClient: GcsStorageClient, private val gcsBucket: String, private val dataProviderName: String, - ) { +) { suspend fun executeRequisitionFetchingWorkflow() { logger.info("Executing requisitionFetchingWorkflow for $dataProviderName...") @@ -48,10 +48,11 @@ class RequisitionFetcher( private suspend fun fetchRequisitions(): List { val request = listRequisitionsRequest { parent = dataProviderName - filter = ListRequisitionsRequestKt.filter { - states += Requisition.State.UNFULFILLED - measurementStates += Measurement.State.AWAITING_REQUISITION_FULFILLMENT - } + filter = + ListRequisitionsRequestKt.filter { + states += Requisition.State.UNFULFILLED + measurementStates += Measurement.State.AWAITING_REQUISITION_FULFILLMENT + } } try { @@ -65,9 +66,10 @@ class RequisitionFetcher( for (requisition in requisitions) { val blobUri = "gs://${gcsBucket}/${requisition.name}" - // Only stores the requisition if it does not already exist in the GCS bucket by checking if the blob URI(created + // Only stores the requisition if it does not already exist in the GCS bucket by checking if + // the blob URI(created // using the requisition name, ensuring uniqueness) is populated. - if(gcsStorageClient.getBlob(blobUri) == null) { + if (gcsStorageClient.getBlob(blobUri) == null) { gcsStorageClient.writeBlob(blobUri, requisition.toByteString()) } } @@ -77,5 +79,3 @@ class RequisitionFetcher( val logger: Logger = Logger.getLogger(this::class.java.name) } } - - diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt index d957790480c..3fa5c755112 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt @@ -12,40 +12,59 @@ import org.wfanet.measurement.common.grpc.buildMutualTlsChannel import org.wfanet.measurement.common.toByteArray import org.wfanet.measurement.gcloud.gcs.GcsStorageClient -class RequisitionFetcherRunner: HttpFunction { +class RequisitionFetcherRunner : HttpFunction { override fun service(request: HttpRequest?, response: HttpResponse?) { - val clientCerts = runBlocking { - getClientCerts() - } + val clientCerts = runBlocking { getClientCerts() } - val publicChannel = buildMutualTlsChannel( - System.getenv("TARGET"), - clientCerts, - System.getenv("CERT_HOST"), - ) + val publicChannel = + buildMutualTlsChannel( + System.getenv("TARGET"), + clientCerts, + System.getenv("CERT_HOST"), + ) val requisitionsStub = RequisitionsCoroutineStub(publicChannel) - val requisitionsStorageClient = GcsStorageClient( - StorageOptions.newBuilder().setProjectId(System.getenv("REQUISITIONS_GCS_PROJECT_ID")).build().service, - System.getenv("REQUISITIONS_GCS_BUCKET") - ) + val requisitionsStorageClient = + GcsStorageClient( + StorageOptions.newBuilder() + .setProjectId(System.getenv("REQUISITIONS_GCS_PROJECT_ID")) + .build() + .service, + System.getenv("REQUISITIONS_GCS_BUCKET") + ) - val fetcher = RequisitionFetcher(requisitionsStub, requisitionsStorageClient, System.getenv("DATAPROVIDER_NAME"), System.getenv("GCS_BUCKET")) - runBlocking { - fetcher.executeRequisitionFetchingWorkflow() - } + val fetcher = + RequisitionFetcher( + requisitionsStub, + requisitionsStorageClient, + System.getenv("DATAPROVIDER_NAME"), + System.getenv("GCS_BUCKET") + ) + runBlocking { fetcher.executeRequisitionFetchingWorkflow() } } private suspend fun getClientCerts(): SigningCerts { - val authenticationStorageClient = GcsStorageClient( - StorageOptions.newBuilder().setProjectId(System.getenv("AUTHENTICATION_GCS_PROJECT_ID")).build().service, - System.getenv("AUTHENTICATION_GCS_BUCKET") - ) - - val certBlob = authenticationStorageClient.getBlob("gs://${System.getenv("AUTHENTICATION_GCS_BUCKET")}/${System.getenv("CERT_FILE_PATH")}") - val privateKeyBlob = authenticationStorageClient.getBlob("gs://${System.getenv("AUTHENTICATION_GCS_BUCKET")}/${System.getenv("PRIVATE_KEY_FILE_PATH")}") - val certCollectionBlob = authenticationStorageClient.getBlob("gs://${System.getenv("AUTHENTICATION_GCS_BUCKET")}/${System.getenv("CERT_COLLECTION_FILE_PATH")}") + val authenticationStorageClient = + GcsStorageClient( + StorageOptions.newBuilder() + .setProjectId(System.getenv("AUTHENTICATION_GCS_PROJECT_ID")) + .build() + .service, + System.getenv("AUTHENTICATION_GCS_BUCKET") + ) + val certBlob = + authenticationStorageClient.getBlob( + "gs://${System.getenv("AUTHENTICATION_GCS_BUCKET")}/${System.getenv("CERT_FILE_PATH")}" + ) + val privateKeyBlob = + authenticationStorageClient.getBlob( + "gs://${System.getenv("AUTHENTICATION_GCS_BUCKET")}/${System.getenv("PRIVATE_KEY_FILE_PATH")}" + ) + val certCollectionBlob = + authenticationStorageClient.getBlob( + "gs://${System.getenv("AUTHENTICATION_GCS_BUCKET")}/${System.getenv("CERT_COLLECTION_FILE_PATH")}" + ) val certFile = File.createTempFile("cert", ".pem") certFile.writeBytes(checkNotNull(certBlob).read().toByteArray()) @@ -56,7 +75,6 @@ class RequisitionFetcherRunner: HttpFunction { val certCollectionFile = File.createTempFile("cert_collection", ".pem") certCollectionFile.writeBytes(checkNotNull(certCollectionBlob).read().toByteArray()) - return SigningCerts.fromPemFiles( certificateFile = certFile, privateKeyFile = privateKeyFile, diff --git a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt index e53dc0a87b1..5fc494ce413 100644 --- a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt @@ -1,6 +1,7 @@ package org.wfanet.measurement.securecomputation.requisitions import com.google.cloud.storage.contrib.nio.testing.LocalStorageHelper +import com.google.common.truth.Truth.assertThat import com.google.protobuf.ByteString import com.google.protobuf.kotlin.toByteString import java.nio.file.Path @@ -41,26 +42,23 @@ import org.wfanet.measurement.common.grpc.testing.mockService import org.wfanet.measurement.common.identity.externalIdToApiId import org.wfanet.measurement.common.pack import org.wfanet.measurement.common.readByteString +import org.wfanet.measurement.common.toByteArray import org.wfanet.measurement.consent.client.common.toEncryptionPublicKey import org.wfanet.measurement.consent.client.measurementconsumer.encryptRequisitionSpec import org.wfanet.measurement.consent.client.measurementconsumer.signMeasurementSpec import org.wfanet.measurement.consent.client.measurementconsumer.signRequisitionSpec import org.wfanet.measurement.gcloud.gcs.GcsStorageClient -import com.google.common.truth.Truth.assertThat -import org.wfanet.measurement.common.toByteArray @RunWith(JUnit4::class) class RequisitionFetcherTest { - private val requisitionsServiceMock: RequisitionsGrpcKt.RequisitionsCoroutineImplBase = mockService { - onBlocking { listRequisitions(any()) } - .thenReturn(listRequisitionsResponse { requisitions += REQUISITION }) - } + private val requisitionsServiceMock: RequisitionsGrpcKt.RequisitionsCoroutineImplBase = + mockService { + onBlocking { listRequisitions(any()) } + .thenReturn(listRequisitionsResponse { requisitions += REQUISITION }) + } - @get:Rule - val grpcTestServerRule = GrpcTestServerRule { - addService(requisitionsServiceMock) - } + @get:Rule val grpcTestServerRule = GrpcTestServerRule { addService(requisitionsServiceMock) } private val requisitionsStub: RequisitionsGrpcKt.RequisitionsCoroutineStub by lazy { RequisitionsGrpcKt.RequisitionsCoroutineStub(grpcTestServerRule.channel) } @@ -72,7 +70,12 @@ class RequisitionFetcherTest { var persistedRequisition: ByteString? runBlocking { fetcher.executeRequisitionFetchingWorkflow() - persistedRequisition = storageClient.getBlob("gs://${BUCKET}/${REQUISITION.name}")?.read()?.toByteArray()?.toByteString() + persistedRequisition = + storageClient + .getBlob("gs://${BUCKET}/${REQUISITION.name}") + ?.read() + ?.toByteArray() + ?.toByteString() } println(persistedRequisition) assertThat(REQUISITION.toByteString()).isEqualTo(persistedRequisition) @@ -135,9 +138,10 @@ class RequisitionFetcherTest { private val REQUISITION_SPEC = requisitionSpec { population = RequisitionSpecKt.population { - filter = RequisitionSpecKt.eventFilter { - expression = "person.age_group == ${Person.AgeGroup.YEARS_18_TO_34_VALUE}" - } + filter = + RequisitionSpecKt.eventFilter { + expression = "person.age_group == ${Person.AgeGroup.YEARS_18_TO_34_VALUE}" + } } measurementPublicKey = MC_PUBLIC_KEY.pack() nonce = Random.nextLong() @@ -154,15 +158,17 @@ class RequisitionFetcherTest { } private val MEASUREMENT_SPEC = measurementSpec { measurementPublicKey = MC_PUBLIC_KEY.pack() - reachAndFrequency = MeasurementSpecKt.reachAndFrequency { - reachPrivacyParams = OUTPUT_DP_PARAMS - frequencyPrivacyParams = OUTPUT_DP_PARAMS - maximumFrequency = 10 - } - vidSamplingInterval = MeasurementSpecKt.vidSamplingInterval { - start = 0.0f - width = 1.0f - } + reachAndFrequency = + MeasurementSpecKt.reachAndFrequency { + reachPrivacyParams = OUTPUT_DP_PARAMS + frequencyPrivacyParams = OUTPUT_DP_PARAMS + maximumFrequency = 10 + } + vidSamplingInterval = + MeasurementSpecKt.vidSamplingInterval { + start = 0.0f + width = 1.0f + } nonceHashes += Hashing.hashSha256(REQUISITION_SPEC.nonce) modelLine = MODEL_LINE_NAME } From 8fcad68d950b7126a5dd0dd52b536bb41b96ad5e Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Tue, 4 Feb 2025 10:15:25 -0800 Subject: [PATCH 09/13] Lint. Remove unneeded bazel deps. --- .../requisitions/RequisitionFetcherRunner.kt | 12 +++---- .../requisitions/BUILD.bazel | 3 -- .../requisitions/RequisitionFetcherTest.kt | 32 ++++++------------- 3 files changed, 13 insertions(+), 34 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt index 3fa5c755112..0216569d867 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt @@ -17,11 +17,7 @@ class RequisitionFetcherRunner : HttpFunction { val clientCerts = runBlocking { getClientCerts() } val publicChannel = - buildMutualTlsChannel( - System.getenv("TARGET"), - clientCerts, - System.getenv("CERT_HOST"), - ) + buildMutualTlsChannel(System.getenv("TARGET"), clientCerts, System.getenv("CERT_HOST")) val requisitionsStub = RequisitionsCoroutineStub(publicChannel) val requisitionsStorageClient = @@ -30,7 +26,7 @@ class RequisitionFetcherRunner : HttpFunction { .setProjectId(System.getenv("REQUISITIONS_GCS_PROJECT_ID")) .build() .service, - System.getenv("REQUISITIONS_GCS_BUCKET") + System.getenv("REQUISITIONS_GCS_BUCKET"), ) val fetcher = @@ -38,7 +34,7 @@ class RequisitionFetcherRunner : HttpFunction { requisitionsStub, requisitionsStorageClient, System.getenv("DATAPROVIDER_NAME"), - System.getenv("GCS_BUCKET") + System.getenv("REQUISITIONS_GCS_BUCKET"), ) runBlocking { fetcher.executeRequisitionFetchingWorkflow() } } @@ -50,7 +46,7 @@ class RequisitionFetcherRunner : HttpFunction { .setProjectId(System.getenv("AUTHENTICATION_GCS_PROJECT_ID")) .build() .service, - System.getenv("AUTHENTICATION_GCS_BUCKET") + System.getenv("AUTHENTICATION_GCS_BUCKET"), ) val certBlob = diff --git a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel index 7df5ee44762..76a6f5ab4ce 100644 --- a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel @@ -22,11 +22,8 @@ kt_jvm_test( "//src/main/proto/wfa/measurement/api/v2alpha/event_templates/testing:test_event_kt_jvm_proto", "@wfa_common_jvm//imports/java/com/google/cloud/storage/contrib/nio", "@wfa_common_jvm//imports/java/com/google/common/truth", - "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", "@wfa_common_jvm//imports/java/org/junit", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/tink", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing", "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/duchy", "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/measurementconsumer", diff --git a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt index 5fc494ce413..c592d9c454b 100644 --- a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt @@ -2,7 +2,6 @@ package org.wfanet.measurement.securecomputation.requisitions import com.google.cloud.storage.contrib.nio.testing.LocalStorageHelper import com.google.common.truth.Truth.assertThat -import com.google.protobuf.ByteString import com.google.protobuf.kotlin.toByteString import java.nio.file.Path import java.nio.file.Paths @@ -33,15 +32,12 @@ import org.wfanet.measurement.api.v2alpha.requisitionSpec import org.wfanet.measurement.common.crypto.Hashing import org.wfanet.measurement.common.crypto.SigningKeyHandle import org.wfanet.measurement.common.crypto.subjectKeyIdentifier -import org.wfanet.measurement.common.crypto.tink.TinkPrivateKeyHandle -import org.wfanet.measurement.common.crypto.tink.loadPrivateKey import org.wfanet.measurement.common.crypto.tink.loadPublicKey import org.wfanet.measurement.common.getRuntimePath import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule import org.wfanet.measurement.common.grpc.testing.mockService import org.wfanet.measurement.common.identity.externalIdToApiId import org.wfanet.measurement.common.pack -import org.wfanet.measurement.common.readByteString import org.wfanet.measurement.common.toByteArray import org.wfanet.measurement.consent.client.common.toEncryptionPublicKey import org.wfanet.measurement.consent.client.measurementconsumer.encryptRequisitionSpec @@ -51,7 +47,6 @@ import org.wfanet.measurement.gcloud.gcs.GcsStorageClient @RunWith(JUnit4::class) class RequisitionFetcherTest { - private val requisitionsServiceMock: RequisitionsGrpcKt.RequisitionsCoroutineImplBase = mockService { onBlocking { listRequisitions(any()) } @@ -62,22 +57,21 @@ class RequisitionFetcherTest { private val requisitionsStub: RequisitionsGrpcKt.RequisitionsCoroutineStub by lazy { RequisitionsGrpcKt.RequisitionsCoroutineStub(grpcTestServerRule.channel) } + @Test - fun `fetch new requisitions`() { + fun `fetch new requisitions and store in GCS bucket`() { val storage = LocalStorageHelper.getOptions().service val storageClient = GcsStorageClient(storage, BUCKET) val fetcher = RequisitionFetcher(requisitionsStub, storageClient, BUCKET, DATA_PROVIDER_NAME) - var persistedRequisition: ByteString? - runBlocking { + val persistedRequisition = runBlocking { fetcher.executeRequisitionFetchingWorkflow() - persistedRequisition = - storageClient - .getBlob("gs://${BUCKET}/${REQUISITION.name}") - ?.read() - ?.toByteArray() - ?.toByteString() + storageClient + .getBlob("gs://${BUCKET}/${REQUISITION.name}") + ?.read() + ?.toByteArray() + ?.toByteString() } - println(persistedRequisition) + assertThat(REQUISITION.toByteString()).isEqualTo(persistedRequisition) } @@ -95,16 +89,10 @@ class RequisitionFetcherTest { private const val PDP_ID = "somePopulationDataProvider" private const val PDP_NAME = "dataProviders/$PDP_ID" - private val MEASUREMENT_CONSUMER_CERTIFICATE_DER = - SECRET_FILES_PATH.resolve("edp_trusted_certs.pem").toFile().readByteString() private const val MEASUREMENT_CONSUMER_NAME = "measurementConsumers/AAAAAAAAAHs" private const val MEASUREMENT_NAME = "$MC_NAME/measurements/BBBBBBBBBHs" private const val MEASUREMENT_CONSUMER_CERTIFICATE_NAME = "$MEASUREMENT_CONSUMER_NAME/certificates/AAAAAAAAAcg" - private val MEASUREMENT_CONSUMER_CERTIFICATE = certificate { - name = MEASUREMENT_CONSUMER_CERTIFICATE_NAME - x509Der = MEASUREMENT_CONSUMER_CERTIFICATE_DER - } private const val DATA_PROVIDER_NAME = "dataProviders/AAAAAAAAAHs" @@ -129,8 +117,6 @@ class RequisitionFetcherTest { private val MC_PUBLIC_KEY: EncryptionPublicKey = loadPublicKey(SECRET_FILES_PATH.resolve("mc_enc_public.tink").toFile()) .toEncryptionPublicKey() - private val MC_PRIVATE_KEY: TinkPrivateKeyHandle = - loadPrivateKey(SECRET_FILES_PATH.resolve("mc_enc_private.tink").toFile()) private val DATA_PROVIDER_PUBLIC_KEY: EncryptionPublicKey = loadPublicKey(SECRET_FILES_PATH.resolve("${PDP_DISPLAY_NAME}_enc_public.tink").toFile()) .toEncryptionPublicKey() From d59f4aa79fd7a5adaa338877cc828d4f5c9d5c4d Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Tue, 4 Feb 2025 11:25:56 -0800 Subject: [PATCH 10/13] Move initialization to clients and certs to companion object so it is only done once, as opposed to every time the method is invoked. --- .../requisitions/RequisitionFetcherRunner.kt | 73 ++++++++++--------- 1 file changed, 38 insertions(+), 35 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt index 0216569d867..b5e2df9e12b 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt @@ -14,6 +14,10 @@ import org.wfanet.measurement.gcloud.gcs.GcsStorageClient class RequisitionFetcherRunner : HttpFunction { override fun service(request: HttpRequest?, response: HttpResponse?) { + runBlocking { requisitionFetcher.executeRequisitionFetchingWorkflow() } + } + + companion object { val clientCerts = runBlocking { getClientCerts() } val publicChannel = @@ -29,52 +33,51 @@ class RequisitionFetcherRunner : HttpFunction { System.getenv("REQUISITIONS_GCS_BUCKET"), ) - val fetcher = + val requisitionFetcher = RequisitionFetcher( requisitionsStub, requisitionsStorageClient, System.getenv("DATAPROVIDER_NAME"), System.getenv("REQUISITIONS_GCS_BUCKET"), ) - runBlocking { fetcher.executeRequisitionFetchingWorkflow() } - } - private suspend fun getClientCerts(): SigningCerts { - val authenticationStorageClient = - GcsStorageClient( - StorageOptions.newBuilder() - .setProjectId(System.getenv("AUTHENTICATION_GCS_PROJECT_ID")) - .build() - .service, - System.getenv("AUTHENTICATION_GCS_BUCKET"), - ) + private suspend fun getClientCerts(): SigningCerts { + val authenticationStorageClient = + GcsStorageClient( + StorageOptions.newBuilder() + .setProjectId(System.getenv("AUTHENTICATION_GCS_PROJECT_ID")) + .build() + .service, + System.getenv("AUTHENTICATION_GCS_BUCKET"), + ) - val certBlob = - authenticationStorageClient.getBlob( - "gs://${System.getenv("AUTHENTICATION_GCS_BUCKET")}/${System.getenv("CERT_FILE_PATH")}" - ) - val privateKeyBlob = - authenticationStorageClient.getBlob( - "gs://${System.getenv("AUTHENTICATION_GCS_BUCKET")}/${System.getenv("PRIVATE_KEY_FILE_PATH")}" - ) - val certCollectionBlob = - authenticationStorageClient.getBlob( - "gs://${System.getenv("AUTHENTICATION_GCS_BUCKET")}/${System.getenv("CERT_COLLECTION_FILE_PATH")}" - ) + val certBlob = + authenticationStorageClient.getBlob( + "gs://${System.getenv("AUTHENTICATION_GCS_BUCKET")}/${System.getenv("CERT_FILE_PATH")}" + ) + val privateKeyBlob = + authenticationStorageClient.getBlob( + "gs://${System.getenv("AUTHENTICATION_GCS_BUCKET")}/${System.getenv("PRIVATE_KEY_FILE_PATH")}" + ) + val certCollectionBlob = + authenticationStorageClient.getBlob( + "gs://${System.getenv("AUTHENTICATION_GCS_BUCKET")}/${System.getenv("CERT_COLLECTION_FILE_PATH")}" + ) - val certFile = File.createTempFile("cert", ".pem") - certFile.writeBytes(checkNotNull(certBlob).read().toByteArray()) + val certFile = File.createTempFile("cert", ".pem") + certFile.writeBytes(checkNotNull(certBlob).read().toByteArray()) - val privateKeyFile = File.createTempFile("private_key", ".key") - privateKeyFile.writeBytes(checkNotNull(privateKeyBlob).read().toByteArray()) + val privateKeyFile = File.createTempFile("private_key", ".key") + privateKeyFile.writeBytes(checkNotNull(privateKeyBlob).read().toByteArray()) - val certCollectionFile = File.createTempFile("cert_collection", ".pem") - certCollectionFile.writeBytes(checkNotNull(certCollectionBlob).read().toByteArray()) + val certCollectionFile = File.createTempFile("cert_collection", ".pem") + certCollectionFile.writeBytes(checkNotNull(certCollectionBlob).read().toByteArray()) - return SigningCerts.fromPemFiles( - certificateFile = certFile, - privateKeyFile = privateKeyFile, - trustedCertCollectionFile = certCollectionFile, - ) + return SigningCerts.fromPemFiles( + certificateFile = certFile, + privateKeyFile = privateKeyFile, + trustedCertCollectionFile = certCollectionFile, + ) + } } } From f3e66e7dd15fa996c06045fc41247e29a85507e6 Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Tue, 4 Feb 2025 11:31:59 -0800 Subject: [PATCH 11/13] Add license headers --- .../securecomputation/requisitions/BUILD.bazel | 1 + .../requisitions/RequisitionFetcher.kt | 1 + .../requisitions/RequisitionFetcherRunner.kt | 16 ++++++++++++++++ .../requisitions/RequisitionFetcherTest.kt | 16 ++++++++++++++++ 4 files changed, 34 insertions(+) diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel index 664a389675a..a055fe71321 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_java//java:defs.bzl", "java_binary") load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") package(default_visibility = ["//visibility:public"]) diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt index 82e5f347305..56ed9aca23a 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcher.kt @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.wfanet.measurement.securecomputation.requisitions import io.grpc.StatusException diff --git a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt index b5e2df9e12b..70b95974ebc 100644 --- a/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt +++ b/src/main/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherRunner.kt @@ -1,3 +1,19 @@ +/* + * Copyright 2025 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.wfanet.measurement.securecomputation.requisitions import com.google.cloud.functions.HttpFunction diff --git a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt index c592d9c454b..83c40130beb 100644 --- a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/RequisitionFetcherTest.kt @@ -1,3 +1,19 @@ +/* + * Copyright 2025 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.wfanet.measurement.securecomputation.requisitions import com.google.cloud.storage.contrib.nio.testing.LocalStorageHelper From f9564cc7ac217a1364ead550766a5d97eae152b8 Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Tue, 4 Feb 2025 11:35:36 -0800 Subject: [PATCH 12/13] Add missing imports to test build file --- .../measurement/securecomputation/requisitions/BUILD.bazel | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel index 76a6f5ab4ce..981b0fee9e5 100644 --- a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel @@ -1,4 +1,5 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library", "kt_jvm_test") +load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") +load("@rules_java//java:defs.bzl", "java_test") package( default_testonly = True, From 6c32b167648c576004f12a21902f40d83bb73776 Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Tue, 4 Feb 2025 12:02:54 -0800 Subject: [PATCH 13/13] Lint --- .../measurement/securecomputation/requisitions/BUILD.bazel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel index 981b0fee9e5..3f5418ba821 100644 --- a/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/securecomputation/requisitions/BUILD.bazel @@ -1,5 +1,5 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") load("@rules_java//java:defs.bzl", "java_test") +load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") package( default_testonly = True,