Add bandit support #25

Merged
merged 21 commits into from
Aug 8, 2024
14 changes: 8 additions & 6 deletions build.gradle
@@ -6,7 +6,7 @@ plugins {
}

group = 'cloud.eppo'
version = '2.1.0-SNAPSHOT'
version = '3.0.0-SNAPSHOT'
ext.isReleaseVersion = !version.endsWith("SNAPSHOT")

dependencies {
@@ -23,6 +23,7 @@ dependencies {
testImplementation 'commons-io:commons-io:2.11.0'
testImplementation "com.google.truth:truth:1.4.4"
testImplementation 'org.mockito:mockito-core:4.11.0'
testImplementation 'org.mockito:mockito-inline:4.11.0'
Contributor Author

Needed for mocking static methods (such as evaluateBandit())

}

test {
@@ -140,14 +141,15 @@ tasks.withType(PublishToMavenRepository) {
}
}

signing {
sign publishing.publications.mavenJava
if (System.env['CI']) {
useInMemoryPgpKeys(System.env.GPG_PRIVATE_KEY, System.env.GPG_PASSPHRASE)
if (!project.gradle.startParameter.taskNames.contains('publishToMavenLocal')) {
Contributor Author

Don't sign the Maven package when publishing to the local Maven repository (for multi-repository development).

signing {
sign publishing.publications.mavenJava
if (System.env['CI']) {
useInMemoryPgpKeys(System.env.GPG_PRIVATE_KEY, System.env.GPG_PASSPHRASE)
}
}
}


javadoc {
failOnError = false
options.addStringOption('Xdoclint:none', '-quiet')
73 changes: 73 additions & 0 deletions src/main/java/cloud/eppo/BanditEvaluationResult.java
@@ -0,0 +1,73 @@
package cloud.eppo;

import cloud.eppo.ufc.dto.DiscriminableAttributes;

public class BanditEvaluationResult {

private final String flagKey;
private final String subjectKey;
private final DiscriminableAttributes subjectAttributes;
Contributor Author

DiscriminableAttributes is an interface I introduced for attributes that can be broken up (discriminated) into numeric and categorical. You'll see it in action later; in short, it abstracts converting broken-out attributes to combined ones and vice versa.

Fancy word! I like it

private final String actionKey;
private final DiscriminableAttributes actionAttributes;
private final double actionScore;
private final double actionWeight;
private final double gamma;
private final double optimalityGap;

public BanditEvaluationResult(
String flagKey,
String subjectKey,
DiscriminableAttributes subjectAttributes,
String actionKey,
DiscriminableAttributes actionAttributes,
double actionScore,
double actionWeight,
double gamma,
double optimalityGap) {
this.flagKey = flagKey;
this.subjectKey = subjectKey;
this.subjectAttributes = subjectAttributes;
this.actionKey = actionKey;
this.actionAttributes = actionAttributes;
this.actionScore = actionScore;
this.actionWeight = actionWeight;
this.gamma = gamma;
this.optimalityGap = optimalityGap;
}

public String getFlagKey() {
return flagKey;
}

public String getSubjectKey() {
return subjectKey;
}

public DiscriminableAttributes getSubjectAttributes() {
return subjectAttributes;
}

public String getActionKey() {
return actionKey;
}

public DiscriminableAttributes getActionAttributes() {
return actionAttributes;
}

public double getActionScore() {
return actionScore;
}

public double getActionWeight() {
return actionWeight;
}

public double getGamma() {
return gamma;
}

public double getOptimalityGap() {
return optimalityGap;
}
}
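
The review comment on DiscriminableAttributes describes an abstraction that can present attributes either split by type or combined. The interface itself is not part of this excerpt, so the following is only a minimal sketch of the idea. The names `SplitAttributes` and `CombinedAttributes`, the use of plain `Map`s, and the rule that every non-numeric value counts as categorical are all assumptions made for illustration, not the SDK's actual API.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical stand-in for the interface described in the review comment.
// Names and types are illustrative assumptions, not the SDK's API.
interface SplitAttributes {
  Map<String, Double> numeric();

  Map<String, String> categorical();

  // Merge both groups back into one combined attribute map.
  default Map<String, Object> combined() {
    Map<String, Object> all = new LinkedHashMap<String, Object>(numeric());
    all.putAll(categorical());
    return all;
  }
}

// Combined attributes: the split is inferred from each value's runtime type.
final class CombinedAttributes implements SplitAttributes {
  private final Map<String, Object> values;

  CombinedAttributes(Map<String, Object> values) {
    this.values = new LinkedHashMap<String, Object>(values);
  }

  @Override
  public Map<String, Double> numeric() {
    Map<String, Double> out = new LinkedHashMap<>();
    for (Map.Entry<String, Object> e : values.entrySet()) {
      if (e.getValue() instanceof Number) {
        out.put(e.getKey(), ((Number) e.getValue()).doubleValue());
      }
    }
    return out;
  }

  @Override
  public Map<String, String> categorical() {
    Map<String, String> out = new LinkedHashMap<>();
    for (Map.Entry<String, Object> e : values.entrySet()) {
      // Assumption: any non-null, non-numeric value is treated as categorical
      if (e.getValue() != null && !(e.getValue() instanceof Number)) {
        out.put(e.getKey(), e.getValue().toString());
      }
    }
    return out;
  }
}
```

With this shape, code that scores a bandit can ask for `numeric()` and `categorical()` without caring whether the caller supplied one combined map or two separate ones.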
168 changes: 168 additions & 0 deletions src/main/java/cloud/eppo/BanditEvaluator.java
@@ -0,0 +1,168 @@
package cloud.eppo;

import cloud.eppo.ufc.dto.*;
import java.util.*;
import java.util.stream.Collectors;

public class BanditEvaluator {

private static final int BANDIT_ASSIGNMENT_SHARDS = 10000; // hard-coded for now

public static BanditEvaluationResult evaluateBandit(
String flagKey,
String subjectKey,
DiscriminableAttributes subjectAttributes,
Actions actions,
BanditModelData modelData) {
Map<String, Double> actionScores = scoreActions(subjectAttributes, actions, modelData);
Map<String, Double> actionWeights =
weighActions(actionScores, modelData.getGamma(), modelData.getActionProbabilityFloor());
String selectedActionKey = selectAction(flagKey, subjectKey, actionWeights);

// Compute optimality gap in terms of score
double topScore =
actionScores.values().stream().mapToDouble(Double::doubleValue).max().orElse(0);
double optimalityGap = topScore - actionScores.get(selectedActionKey);

return new BanditEvaluationResult(
flagKey,
subjectKey,
subjectAttributes,
selectedActionKey,
actions.get(selectedActionKey),
actionScores.get(selectedActionKey),
actionWeights.get(selectedActionKey),
modelData.getGamma(),
optimalityGap);
}

private static Map<String, Double> scoreActions(
DiscriminableAttributes subjectAttributes, Actions actions, BanditModelData modelData) {
return actions.entrySet().stream()
.collect(
Collectors.toMap(
Map.Entry::getKey,
e -> {
String actionName = e.getKey();
DiscriminableAttributes actionAttributes = e.getValue();

// get all coefficients known to the model for this action
BanditCoefficients banditCoefficients =
modelData.getCoefficients().get(actionName);

if (banditCoefficients == null) {
// Unknown action; return the default action score
return modelData.getDefaultActionScore();
}

// Score the action using the provided attributes
double actionScore = banditCoefficients.getIntercept();
actionScore +=
scoreContextForCoefficients(
actionAttributes.getNumericAttributes(),
banditCoefficients.getActionNumericCoefficients());
actionScore +=
scoreContextForCoefficients(
actionAttributes.getCategoricalAttributes(),
banditCoefficients.getActionCategoricalCoefficients());
actionScore +=
scoreContextForCoefficients(
subjectAttributes.getNumericAttributes(),
banditCoefficients.getSubjectNumericCoefficients());
actionScore +=
scoreContextForCoefficients(
subjectAttributes.getCategoricalAttributes(),
banditCoefficients.getSubjectCategoricalCoefficients());

return actionScore;
}));
}

private static double scoreContextForCoefficients(
Attributes attributes, Map<String, ? extends BanditAttributeCoefficients> coefficients) {

double totalScore = 0.0;

for (BanditAttributeCoefficients attributeCoefficients : coefficients.values()) {
EppoValue contextValue = attributes.get(attributeCoefficients.getAttributeKey());
// The coefficient implementation knows how to score
double attributeScore = attributeCoefficients.scoreForAttributeValue(contextValue);
Comment on lines +88 to +89
Contributor Author

This differs a bit from other SDKs: the logic for scoring a given attribute value is co-located with the attribute type (i.e. BanditNumericAttributeCoefficients and BanditCategoricalAttributeCoefficients).

totalScore += attributeScore;
}

return totalScore;
}
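
The comment on scoreContextForCoefficients notes that each coefficient type knows how to score its own attribute value. The coefficient classes are not shown in this excerpt, so the sketch below only illustrates that pattern. Every name here is hypothetical, and the scoring rules (coefficient times value for numeric attributes, a per-value lookup for categorical ones, and a fallback coefficient when the attribute is missing) are assumptions rather than the SDK's confirmed behavior.

```java
import java.util.List;
import java.util.Map;

// Hypothetical stand-ins for the SDK's coefficient types. The field names and
// the missing-value handling are assumptions made for this sketch.
interface AttributeCoefficients {
  String attributeKey();

  // value may be null when the attribute is absent from the context
  double scoreFor(Object value);
}

final class NumericCoefficients implements AttributeCoefficients {
  private final String key;
  private final double coefficient;
  private final double missingValueCoefficient;

  NumericCoefficients(String key, double coefficient, double missingValueCoefficient) {
    this.key = key;
    this.coefficient = coefficient;
    this.missingValueCoefficient = missingValueCoefficient;
  }

  public String attributeKey() {
    return key;
  }

  public double scoreFor(Object value) {
    // Assumed rule: numeric attributes contribute coefficient * value
    return value instanceof Number
        ? coefficient * ((Number) value).doubleValue()
        : missingValueCoefficient;
  }
}

final class CategoricalCoefficients implements AttributeCoefficients {
  private final String key;
  private final Map<String, Double> valueCoefficients;
  private final double missingValueCoefficient;

  CategoricalCoefficients(
      String key, Map<String, Double> valueCoefficients, double missingValueCoefficient) {
    this.key = key;
    this.valueCoefficients = valueCoefficients;
    this.missingValueCoefficient = missingValueCoefficient;
  }

  public String attributeKey() {
    return key;
  }

  public double scoreFor(Object value) {
    // Assumed rule: each known category value has its own coefficient
    Double known = value == null ? null : valueCoefficients.get(value.toString());
    return known != null ? known : missingValueCoefficient;
  }
}

final class ContextScorer {
  // The caller sums contributions without knowing how each type scores.
  static double score(Map<String, Object> attributes, List<AttributeCoefficients> coefficients) {
    double total = 0.0;
    for (AttributeCoefficients c : coefficients) {
      total += c.scoreFor(attributes.get(c.attributeKey()));
    }
    return total;
  }
}
```

The benefit of this layout is that the summing loop stays identical for every attribute type; adding a new kind of coefficient only means adding another implementation.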

private static Map<String, Double> weighActions(
Map<String, Double> actionScores, double gamma, double actionProbabilityFloor) {
Double highestScore = null;
String highestScoredAction = null;
for (Map.Entry<String, Double> actionScore : actionScores.entrySet()) {
if (highestScore == null
|| actionScore.getValue() > highestScore
|| actionScore
.getValue()
.equals(highestScore) // note: we break ties for scores by action name
&& actionScore.getKey().compareTo(highestScoredAction) < 0) {
highestScore = actionScore.getValue();
highestScoredAction = actionScore.getKey();
}
}

// Weigh all the actions using their score
Map<String, Double> actionWeights = new HashMap<>();
double totalNonHighestWeight = 0.0;
for (Map.Entry<String, Double> actionScore : actionScores.entrySet()) {
if (actionScore.getKey().equals(highestScoredAction)) {
// The highest scored action is weighed at the end
continue;
}

// Compute weight (probability)
double unboundedProbability =
1 / (actionScores.size() + (gamma * (highestScore - actionScore.getValue())));
double minimumProbability = actionProbabilityFloor / actionScores.size();
Contributor Author

@typotter bandit probability floor now normalized

Collaborator

👍

double boundedProbability = Math.max(unboundedProbability, minimumProbability);
totalNonHighestWeight += boundedProbability;

actionWeights.put(actionScore.getKey(), boundedProbability);
}

// Weigh the highest scoring action (defensively preventing a negative probability)
double weightForHighestScore = Math.max(1 - totalNonHighestWeight, 0);
actionWeights.put(highestScoredAction, weightForHighestScore);
return actionWeights;
}
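
To make the weighting in weighActions concrete, here is a small standalone restatement of the same arithmetic using plain maps. The class name `WeighActionsSketch` is made up for this example. As a worked case: with three actions scored 1.0, 0.5, and 0.0, gamma = 1, and a floor of 0, the two non-top actions get 1 / (3 + 0.5) ≈ 0.286 and 1 / (3 + 1) = 0.25, which leaves roughly 0.464 for the top action.

```java
import java.util.HashMap;
import java.util.Map;

// Illustrative restatement of the weighting logic in the diff, using plain maps.
final class WeighActionsSketch {

  static Map<String, Double> weigh(Map<String, Double> scores, double gamma, double floor) {
    // Find the top-scoring action, breaking score ties by action name
    String topAction = null;
    double topScore = Double.NEGATIVE_INFINITY;
    for (Map.Entry<String, Double> e : scores.entrySet()) {
      double score = e.getValue();
      boolean winsTie =
          topAction != null && score == topScore && e.getKey().compareTo(topAction) < 0;
      if (topAction == null || score > topScore || winsTie) {
        topAction = e.getKey();
        topScore = score;
      }
    }

    int n = scores.size();
    double minimumProbability = floor / n; // the floor is normalized by the action count
    Map<String, Double> weights = new HashMap<>();
    double totalNonTop = 0.0;
    for (Map.Entry<String, Double> e : scores.entrySet()) {
      if (e.getKey().equals(topAction)) {
        continue; // the top action is weighed last
      }
      double unbounded = 1.0 / (n + gamma * (topScore - e.getValue()));
      double bounded = Math.max(unbounded, minimumProbability);
      totalNonTop += bounded;
      weights.put(e.getKey(), bounded);
    }

    // The top action receives the remaining probability mass (never negative)
    weights.put(topAction, Math.max(1.0 - totalNonTop, 0.0));
    return weights;
  }
}
```

A larger gamma shrinks the weight of lower-scoring actions faster, while the normalized floor keeps every action's probability at or above `floor / n`.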

private static String selectAction(
String flagKey, String subjectKey, Map<String, Double> actionWeights) {
// Deterministically "shuffle" the actions
// This way as action weights shift, a bunch of users who were on the edge of one action won't
// all be shifted to the same new action at the same time.
List<String> shuffledActionKeys =
actionWeights.keySet().stream()
.sorted(
Comparator.comparingInt(
(String actionKey) ->
ShardUtils.getShard(
flagKey + "-" + subjectKey + "-" + actionKey,
BANDIT_ASSIGNMENT_SHARDS))
.thenComparing(actionKey -> actionKey))
Collaborator

👍

.collect(Collectors.toList());

// Select action from the shuffled actions, based on weight
double assignedShard =
ShardUtils.getShard(flagKey + "-" + subjectKey, BANDIT_ASSIGNMENT_SHARDS);
double assignmentWeightThreshold = assignedShard / (double) BANDIT_ASSIGNMENT_SHARDS;
double cumulativeWeight = 0;
String assignedAction = null;
for (String actionKey : shuffledActionKeys) {
cumulativeWeight += actionWeights.get(actionKey);
if (cumulativeWeight > assignmentWeightThreshold) {
assignedAction = actionKey;
break;
}
}
return assignedAction;
}
}
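
The selection step in selectAction can be tried out in isolation with the sketch below. ShardUtils.getShard is not shown in this diff, so the `shard` method here is a stand-in that hashes with MD5 and reduces the first 32 bits modulo the shard count; that choice is an assumption and may not match the SDK's hash. The class name `SelectActionSketch` is likewise hypothetical.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

final class SelectActionSketch {
  static final int TOTAL_SHARDS = 10000;

  // Stand-in for ShardUtils.getShard (an assumption, not the SDK's implementation)
  static int shard(String input, int totalShards) {
    try {
      byte[] d = MessageDigest.getInstance("MD5").digest(input.getBytes(StandardCharsets.UTF_8));
      long first32Bits =
          ((d[0] & 0xFFL) << 24) | ((d[1] & 0xFFL) << 16) | ((d[2] & 0xFFL) << 8) | (d[3] & 0xFFL);
      return (int) (first32Bits % totalShards);
    } catch (NoSuchAlgorithmException e) {
      throw new IllegalStateException(e); // MD5 is a required algorithm on the Java platform
    }
  }

  static String select(String flagKey, String subjectKey, Map<String, Double> weights) {
    // Deterministic per-subject "shuffle": order actions by a hash that includes the action key
    List<String> keys = new ArrayList<>(weights.keySet());
    keys.sort(
        Comparator.comparingInt(
                (String k) -> shard(flagKey + "-" + subjectKey + "-" + k, TOTAL_SHARDS))
            .thenComparing(Comparator.<String>naturalOrder()));

    // Walk the cumulative weights until they pass the subject's threshold in [0, 1)
    double threshold = shard(flagKey + "-" + subjectKey, TOTAL_SHARDS) / (double) TOTAL_SHARDS;
    double cumulative = 0.0;
    for (String k : keys) {
      cumulative += weights.get(k);
      if (cumulative > threshold) {
        return k;
      }
    }
    return null; // as in the diff, reachable only if the weights never exceed the threshold
  }
}
```

Because both the ordering and the threshold depend only on the flag key, subject key, and action keys, the same subject gets the same action for as long as the weights are unchanged.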
34 changes: 26 additions & 8 deletions src/main/java/cloud/eppo/ConfigurationRequestor.java
@@ -2,6 +2,8 @@

import cloud.eppo.ufc.dto.FlagConfig;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;
import okhttp3.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -12,28 +14,44 @@ public class ConfigurationRequestor {

private final EppoHttpClient client;
private final ConfigurationStore configurationStore;
private final Set<String> loadedBanditModelVersions;

public ConfigurationRequestor(ConfigurationStore configurationStore, EppoHttpClient client) {
this.configurationStore = configurationStore;
this.client = client;
this.loadedBanditModelVersions = new HashSet<>();
}

// TODO: async loading for android
public void load() {
log.debug("Fetching configuration");
Response response = client.get("/api/flag-config/v1/config");
String flagConfigurationJsonString = requestBody("/api/flag-config/v1/config");
configurationStore.setFlagsFromJsonString(flagConfigurationJsonString);
Collaborator

Do you need to worry about threading here and lock reads until the bandit models are set below?

Contributor Author

Good thinking! The underlying implementation is a thread-safe ConcurrentHashMap, which is what we are using in the Android SDK without (known) issues

Collaborator

The underlying ConcurrentHashMap keeps individual reads and writes orderly, but there is no lock spanning the setting of the flags and the subsequent setting of the bandit parameters below. So there is a brief window in which the config store can be read while it holds an incomplete configuration. If a caller evaluates a bandit between those two instructions, the SDK will have incomplete data and may not have the required bandit. The scenario is unlikely, contrived, and probably nearly impossible to test, but I'm curious how wrong things would go if it happened (the user would get the bandit key returned as the variation with action=null, which the devs should already handle in their app).
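
One common way to close the window described in this thread is to build a complete, immutable snapshot and publish it in a single step, so readers see either the old configuration or the new one and never a mix. The sketch below is not what this PR does; `ConfigSnapshot` and `SnapshotStore` are hypothetical types, and the configuration is reduced to two string maps to keep the example short.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical types for illustration; none of these exist in the SDK.
final class ConfigSnapshot {
  final Map<String, String> flags;
  final Map<String, String> banditParameters;

  ConfigSnapshot(Map<String, String> flags, Map<String, String> banditParameters) {
    // Defensive copies keep the snapshot immutable after construction
    this.flags = Collections.unmodifiableMap(new HashMap<String, String>(flags));
    this.banditParameters =
        Collections.unmodifiableMap(new HashMap<String, String>(banditParameters));
  }
}

final class SnapshotStore {
  private final AtomicReference<ConfigSnapshot> current =
      new AtomicReference<>(
          new ConfigSnapshot(new HashMap<String, String>(), new HashMap<String, String>()));

  // Readers always get a complete snapshot: either the previous one or the new one
  ConfigSnapshot read() {
    return current.get();
  }

  // The writer assembles flags and bandit parameters first, then swaps them in together
  void publish(Map<String, String> flags, Map<String, String> banditParameters) {
    current.set(new ConfigSnapshot(flags, banditParameters));
  }
}
```

The trade-off is that the flags cannot be published until the bandit parameters have also been fetched, which delays flag updates slightly in exchange for consistency.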


Set<String> neededModelVersions = configurationStore.banditModelVersions();
Collaborator

Is it possible for more than one bandit to have the same model version string?

Contributor Author

No, model versions are unique to each bandit.

Under the hood, this version ID is, for now, the auto-incrementing primary key in a model_versions table that holds versions for all bandits (the table includes the bandit ID).

Collaborator

Thanks for the clarification

boolean needBanditParameters = !loadedBanditModelVersions.containsAll(neededModelVersions);
Comment on lines +31 to +32
Contributor Author

Taking advantage of the latest and greatest UFC feature: bandit model versions.
We only request bandit parameters when a model version in use has not been loaded yet, rather than on every poll whenever one or more bandits exist.

Collaborator

Shooting for bleeding edge all at once eh? I like your style 🤠

if (needBanditParameters) {
String banditParametersJsonString = requestBody("/api/flag-config/v1/bandits");
configurationStore.setBanditParametersFromJsonString(banditParametersJsonString);
// Record the model versions that we just loaded, so we can compare when the store is later
// updated
loadedBanditModelVersions.clear();
loadedBanditModelVersions.addAll(configurationStore.banditModelVersions());
}
}
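
The version check in load() can be isolated into a small helper to see how it behaves across successive polls. This is a simplified sketch: `BanditModelGate` is a hypothetical class, the fetch is passed in as a `Runnable`, and the loaded set is recorded from the needed versions directly instead of being re-read from the configuration store as the diff does.

```java
import java.util.HashSet;
import java.util.Set;

// Hypothetical helper illustrating the "fetch only when a version is missing" check.
final class BanditModelGate {
  private final Set<String> loadedVersions = new HashSet<>();

  // Returns true if a fetch was triggered, false if everything needed was already loaded
  boolean refreshIfNeeded(Set<String> neededVersions, Runnable fetchBanditParameters) {
    if (loadedVersions.containsAll(neededVersions)) {
      return false; // every model version referenced by the flags is already loaded
    }
    fetchBanditParameters.run();
    // Remember what was just loaded so the next poll can skip the request
    loadedVersions.clear();
    loadedVersions.addAll(neededVersions);
    return true;
  }
}
```

Note that `containsAll` on an empty needed set is always true, so a configuration with no bandits never triggers a bandit parameter request.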

private String requestBody(String route) {
Response response = client.get(route);
if (!response.isSuccessful() || response.body() == null) {
throw new RuntimeException("Failed to fetch from " + route);
}
try {
if (!response.isSuccessful()) {
throw new RuntimeException("Failed to fetch configuration");
}
configurationStore.setFlagsFromJsonString(response.body().string());
return response.body().string();
} catch (IOException e) {
// TODO: better exception handling?
throw new RuntimeException(e);
}
}

// TODO: async loading for android

public FlagConfig getConfiguration(String flagKey) {
return configurationStore.getFlag(flagKey);
}