Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Simple TEE VID Labeling App #1991

Open
wants to merge 9 commits into
base: stevenwarejones_data_watcher
Choose a base branch
from
Open
11 changes: 11 additions & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,17 @@ bazel_dep(
version = "0.5.0",
repo_name = "wfa_virtual_people_common",
)
# Virtual People core serving module, exposed to BUILD files as the
# `wfa_virtual_people_core_serving` repository.
bazel_dep(
name = "virtual-people-core-serving",
version = "0.2.1",
repo_name = "wfa_virtual_people_core_serving",
)
# DO_NOT_SUBMIT (world-federation-of-advertisers/virtual-people-core-serving#79)
# Temporary override: fetches the module from the source archive of commit
# 3fd624714c92c505a36840fdecf05233eb747e5f instead of the `version` declared
# above. No `integrity` value is set for the archive.
# NOTE(review): presumably to be removed once the referenced upstream change is
# released - confirm before submitting.
archive_override(
module_name = "virtual-people-core-serving",
strip_prefix = "virtual-people-core-serving-3fd624714c92c505a36840fdecf05233eb747e5f",
urls = "https://github.com/world-federation-of-advertisers/virtual-people-core-serving/archive/3fd624714c92c505a36840fdecf05233eb747e5f.tar.gz",
)
bazel_dep(
name = "grpc-gateway",
version = "2.18.1",
Expand Down
6 changes: 6 additions & 0 deletions MODULE.bazel.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library")

# Package defaults: every target in this package is testonly and publicly
# visible unless it overrides these attributes.
# NOTE(review): `default_testonly = True` means non-testonly targets cannot
# depend on targets in this package - confirm this is intended for the app
# library.
package(
default_testonly = True,
default_visibility = ["//visibility:public"],
)

# Library containing VidLabelerApp, the TEE application that labels events with
# virtual person IDs.
kt_jvm_library(
    name = "vid_labeler_app",
    srcs = ["VidLabelerApp.kt"],
    deps = [
        "//src/main/kotlin/org/wfanet/measurement/securecomputation/teesdk:base_tee_application",
        "//src/main/proto/wfa/measurement/securecomputation/datawatcher/v1alpha:data_watcher_config_kt_jvm_proto",
        "//src/main/proto/wfa/measurement/securecomputation/teeapps/v1alpha:tee_app_config_kt_jvm_proto",
        "@wfa_common_jvm//imports/java/com/google/protobuf",
        "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core",
        # NOTE(review): VidLabelerApp.kt has no direct import from the
        # gcloud.pubsub package - confirm this dependency is still needed.
        "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/pubsub:google_pub_sub_client",
        # VidLabelerApp.kt imports QueueSubscriber directly, so it is listed as
        # a direct dependency rather than relied on transitively.
        "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/queue:queue_subscriber",
        "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/storage:mesos_recordio_storage_client",
        # NOTE(review): VidLabelerApp.kt has no direct import from the
        # storage.testing package - confirm this dependency is still needed.
        "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/storage/testing",
        "@wfa_virtual_people_common//src/main/proto/wfa/virtual_people/common:event_kt_jvm_proto",
        "@wfa_virtual_people_common//src/main/proto/wfa/virtual_people/common:model_kt_jvm_proto",
        "@wfa_virtual_people_core_serving//src/main/kotlin/org/wfanet/virtualpeople/core/labeler",
    ],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright 2024 The Cross-Media Measurement Authors
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*/

package org.wfanet.measurement.securecomputation.vidlabeling

import com.google.protobuf.Parser
import com.google.protobuf.TextFormat
import com.google.protobuf.kotlin.toByteStringUtf8
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.reduce
import org.wfanet.measurement.queue.QueueSubscriber
import org.wfanet.measurement.securecomputation.datawatcher.v1alpha.DataWatcherConfig.TriggeredApp
import org.wfanet.measurement.securecomputation.teeapps.v1alpha.TeeAppConfig
import org.wfanet.measurement.securecomputation.teesdk.BaseTeeApplication
import org.wfanet.measurement.storage.MesosRecordIoStorageClient
import org.wfanet.measurement.storage.StorageClient
import org.wfanet.virtualpeople.common.CompiledNode
import org.wfanet.virtualpeople.common.LabelerInput
import org.wfanet.virtualpeople.common.LabelerOutput
import org.wfanet.virtualpeople.core.labeler.Labeler

/**
 * TEE application that labels events with virtual person IDs (VIDs).
 *
 * For each [TriggeredApp] message handed to [runWork], the app loads the VID model named in the
 * message's [TeeAppConfig], builds a [Labeler] from it, labels every record of the blob at
 * `message.path`, and writes the labeled records to the corresponding output blob.
 *
 * @param storageClient storage used to read the model and input blobs and to write the output blob
 * @param queueName passed to [BaseTeeApplication] as its subscription ID
 * @param queueSubscriber subscriber passed through to [BaseTeeApplication]
 * @param parser parser for [TriggeredApp] messages, passed through to [BaseTeeApplication]
 */
class VidLabelerApp(
  private val storageClient: StorageClient,
  queueName: String,
  queueSubscriber: QueueSubscriber,
  parser: Parser<TriggeredApp>
) :
  BaseTeeApplication<TriggeredApp>(
    subscriptionId = queueName,
    queueSubscriber = queueSubscriber,
    parser = parser
  ) {

  /**
   * Labels every record of the blob at [inputBlobKey] and writes the results to [outputBlobKey].
   *
   * Each input record is parsed as a text-format [LabelerInput]; each output record is the text
   * format of the resulting [LabelerOutput].
   *
   * Currently, labels events using a single thread. Consider using a different approach for
   * faster labeling.
   *
   * TODO: Read and write using serialized ByteString rather than TextProto once the
   *   MesosStorageClient bug is fixed.
   *
   * @throws IllegalArgumentException if no blob exists at [inputBlobKey]
   */
  private suspend fun labelPath(
    inputBlobKey: String,
    outputBlobKey: String,
    labeler: Labeler,
    storageClient: MesosRecordIoStorageClient
  ) {
    val inputBlob =
      storageClient.getBlob(inputBlobKey)
        ?: throw IllegalArgumentException("Input blob does not exist")

    // Built once and reused for every record instead of once per record.
    val textFormatParser = TextFormat.Parser.newBuilder().build()

    val outputFlow =
      inputBlob.read().map { record ->
        val labelerInput: LabelerInput =
          LabelerInput.newBuilder()
            .also { builder -> textFormatParser.merge(record.toStringUtf8(), builder) }
            .build()
        val labelerOutput: LabelerOutput = labeler.label(input = labelerInput)
        labelerOutput.toString().toByteStringUtf8()
      }

    storageClient.writeBlob(outputBlobKey, outputFlow)
  }

  /**
   * Reads the blob at [modelBlobKey] in full and parses it as a text-format [CompiledNode].
   *
   * @throws IllegalArgumentException if no blob exists at [modelBlobKey]
   */
  private suspend fun readTextProtoModel(modelBlobKey: String): CompiledNode {
    val modelBlob =
      requireNotNull(storageClient.getBlob(modelBlobKey)) {
        "Model blob does not exist: $modelBlobKey"
      }
    // The blob is delivered in chunks; join them before decoding so that multi-byte characters
    // split across chunk boundaries are decoded correctly.
    val modelText = modelBlob.read().reduce { acc, chunk -> acc.concat(chunk) }.toStringUtf8()
    return CompiledNode.newBuilder()
      .also { builder -> TextFormat.Parser.newBuilder().build().merge(modelText, builder) }
      .build()
  }

  /**
   * Handles a single [TriggeredApp] message: loads the model, labels the blob at `message.path`
   * and writes the result under the configured output base path.
   *
   * @throws IllegalArgumentException if the unpacked [TeeAppConfig] is not a VID labeling config,
   *   if no model format is set, or if a referenced blob does not exist
   */
  override suspend fun runWork(message: TriggeredApp) {
    val teeAppConfig = message.config.unpack(TeeAppConfig::class.java)
    // `require` rather than `assert`: JVM assertions are disabled unless -ea is passed, so an
    // `assert` here would let a config of the wrong work type through unchecked.
    require(teeAppConfig.workTypeCase == TeeAppConfig.WorkTypeCase.VID_LABELING_CONFIG) {
      "Expected a VID labeling config but got work type: ${teeAppConfig.workTypeCase}"
    }
    val vidLabelingConfig = teeAppConfig.vidLabelingConfig
    val compiledNode: CompiledNode =
      @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA")
      when (vidLabelingConfig.modelFormatCase) {
        TeeAppConfig.VidLabelingConfig.ModelFormatCase.MODEL_BLOB_TEXT_PROTO_PATH ->
          readTextProtoModel(vidLabelingConfig.modelBlobTextProtoPath)
        TeeAppConfig.VidLabelingConfig.ModelFormatCase.MODEL_BLOB_RIEGELI_PATH ->
          TODO("Currently Unsupported")
        TeeAppConfig.VidLabelingConfig.ModelFormatCase.MODEL_LINE -> TODO("Currently Unsupported")
        TeeAppConfig.VidLabelingConfig.ModelFormatCase.MODELFORMAT_NOT_SET ->
          throw IllegalArgumentException(
            "Invalid model format: ${vidLabelingConfig.modelFormatCase}"
          )
      }

    val labeler = Labeler.build(compiledNode)

    // The output key is the output base path followed by the input key with the input base path
    // stripped. `removePrefix` returns the key unchanged when it does not start with the input
    // base path, in which case the whole input key is appended.
    val inputBlobKey = message.path
    val outputBlobKey =
      vidLabelingConfig.outputBasePath + inputBlobKey.removePrefix(vidLabelingConfig.inputBasePath)

    val mesosRecordIoStorageClient = MesosRecordIoStorageClient(storageClient)
    labelPath(inputBlobKey, outputBlobKey, labeler, mesosRecordIoStorageClient)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
load("@rules_proto//proto:defs.bzl", "proto_library")
load(
"@wfa_rules_kotlin_jvm//kotlin:defs.bzl",
"kt_jvm_grpc_proto_library",
"kt_jvm_proto_library",
)

# All targets in this package are publicly visible by default.
package(default_visibility = ["//visibility:public"])

# Prefix stripped from proto import paths (passed as `strip_import_prefix`
# below).
IMPORT_PREFIX = "/src/main/proto"

# Resources and shared message types.

# Proto definitions for TeeAppConfig.
#
# The deps mirror the imports of tee_app_config.proto, which imports only
# google/api/field_behavior.proto and google/api/resource.proto.
proto_library(
    name = "tee_app_config_proto",
    srcs = ["tee_app_config.proto"],
    strip_import_prefix = IMPORT_PREFIX,
    deps = [
        "@com_google_googleapis//google/api:field_behavior_proto",
        "@com_google_googleapis//google/api:resource_proto",
    ],
)

# Kotlin/JVM bindings generated from :tee_app_config_proto.
kt_jvm_proto_library(
name = "tee_app_config_kt_jvm_proto",
deps = [":tee_app_config_proto"],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright 2025 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package wfa.measurement.securecomputation.teeapps.v1alpha;

import "google/api/field_behavior.proto";
import "google/api/resource.proto";

option java_package = "org.wfanet.measurement.securecomputation.teeapps.v1alpha";
option java_multiple_files = true;
option java_outer_classname = "TeeAppConfigProto";

// Configuration for a unit of work run by a TEE application. At most one work
// type is set.
message TeeAppConfig {
// Configuration for VID labeling work.
message VidLabelingConfig {
// Base path of the input blobs. VidLabelerApp strips this prefix from the
// path of the blob being labeled when deriving the output blob key.
string input_base_path = 1;
// Base path of the output blobs. VidLabelerApp prepends this to the input
// blob path with `input_base_path` removed.
string output_base_path = 2;
// Where the VID model comes from.
// NOTE(review): this oneof name is not lower_snake_case. Renaming it would
// change the generated not-set enum constant (currently
// MODELFORMAT_NOT_SET), so callers would need updating at the same time.
oneof ModelFormat {
// Blob key of a model stored as a text-format CompiledNode proto. This is
// the only format VidLabelerApp currently handles.
string model_blob_text_proto_path = 3;
// Presumably the blob key of a model stored in Riegeli format. Not yet
// supported by VidLabelerApp.
string model_blob_riegeli_path = 4;
// Reference to a `halo.wfanet.org/ModelLine` resource. Not yet supported
// by VidLabelerApp.
// NOTE(review): this field is annotated REQUIRED but is a member of a
// oneof, so it cannot be set together with the other model formats -
// confirm the intended field behavior.
string model_line = 5 [
(google.api.resource_reference).type = "halo.wfanet.org/ModelLine",
(google.api.field_behavior) = REQUIRED,
(google.api.field_behavior) = IMMUTABLE
];
}
}
// Configuration for reach-and-frequency work. No fields are defined yet.
message ReachAndFrequencyConfig {
// TODO
}
// The kind of work this config describes.
oneof WorkType {
VidLabelingConfig vid_labeling_config = 1;
ReachAndFrequencyConfig reach_and_frequency_config = 2;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test")

# Package defaults: every target in this package is testonly and publicly
# visible unless it overrides these attributes.
package(
default_testonly = True,
default_visibility = ["//visibility:public"],
)

# Test for VidLabelerApp. The labeler integration test data from the
# virtual-people-core-serving repository is made available to the test at
# runtime via `data`.
# NOTE(review): `test_class` is in package
# org.wfanet.measurement.securecomputation.teeapps.vidlabeling, while
# VidLabelerApp.kt declares package
# org.wfanet.measurement.securecomputation.vidlabeling - confirm which package
# is intended.
kt_jvm_test(
name = "VidLabelerAppTest",
srcs = ["VidLabelerAppTest.kt"],
data = [
"@wfa_virtual_people_core_serving//src/main/resources/labeler:labeler_integration_test_data",
],
test_class = "org.wfanet.measurement.securecomputation.teeapps.vidlabeling.VidLabelerAppTest",
deps = [
"//src/main/kotlin/org/wfanet/measurement/securecomputation/controlplane/v1alpha:google_pub_sub_work_items_service",
"//src/main/kotlin/org/wfanet/measurement/securecomputation/teeapps/vidlabeling:vid_labeler_app",
"//src/main/proto/wfa/measurement/securecomputation/controlplane/v1alpha:work_item_kt_jvm_proto",
"//src/main/proto/wfa/measurement/securecomputation/controlplane/v1alpha:work_items_service_kt_jvm_grpc_proto",
"@wfa_common_jvm//imports/java/com/google/common/truth",
"@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto",
"@wfa_common_jvm//imports/kotlin/kotlin/test",
"@wfa_common_jvm//imports/kotlin/org/mockito/kotlin",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/pubsub:google_pub_sub_client",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/pubsub:publisher",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/pubsub:subscriber",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/pubsub/testing:google_pub_sub_emulator_client",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/pubsub/testing:google_pub_sub_emulator_provider",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/queue:queue_subscriber",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/storage/testing",
"@wfa_virtual_people_common//src/main/proto/wfa/virtual_people/common:event_kt_jvm_proto",
"@wfa_virtual_people_common//src/main/proto/wfa/virtual_people/common:model_kt_jvm_proto",
"@wfa_virtual_people_core_serving//src/main/kotlin/org/wfanet/virtualpeople/core/labeler",
],
)
Loading
Loading