-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add bandit support #25
Changes from all commits
5ff2f5e
c5497f6
ba25612
9a601ee
31e193a
000dd9c
bb747be
3b9adb3
52e499a
2be2085
e952177
de7bbea
743fc49
51070bc
e69d9aa
abb2848
de1c745
a52d153
14bab7f
5384bd2
efed6f4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ plugins { | |
} | ||
|
||
group = 'cloud.eppo' | ||
version = '2.1.0-SNAPSHOT' | ||
version = '3.0.0-SNAPSHOT' | ||
ext.isReleaseVersion = !version.endsWith("SNAPSHOT") | ||
|
||
dependencies { | ||
|
@@ -23,6 +23,7 @@ dependencies { | |
testImplementation 'commons-io:commons-io:2.11.0' | ||
testImplementation "com.google.truth:truth:1.4.4" | ||
testImplementation 'org.mockito:mockito-core:4.11.0' | ||
testImplementation 'org.mockito:mockito-inline:4.11.0' | ||
} | ||
|
||
test { | ||
|
@@ -140,14 +141,15 @@ tasks.withType(PublishToMavenRepository) { | |
} | ||
} | ||
|
||
signing { | ||
sign publishing.publications.mavenJava | ||
if (System.env['CI']) { | ||
useInMemoryPgpKeys(System.env.GPG_PRIVATE_KEY, System.env.GPG_PASSPHRASE) | ||
if (!project.gradle.startParameter.taskNames.contains('publishToMavenLocal')) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't sign the maven package if publishing to local maven repository (for multi-repository development) |
||
signing { | ||
sign publishing.publications.mavenJava | ||
if (System.env['CI']) { | ||
useInMemoryPgpKeys(System.env.GPG_PRIVATE_KEY, System.env.GPG_PASSPHRASE) | ||
} | ||
} | ||
} | ||
|
||
|
||
javadoc { | ||
failOnError = false | ||
options.addStringOption('Xdoclint:none', '-quiet') | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
package cloud.eppo; | ||
|
||
import cloud.eppo.ufc.dto.DiscriminableAttributes; | ||
|
||
public class BanditEvaluationResult { | ||
|
||
private final String flagKey; | ||
private final String subjectKey; | ||
private final DiscriminableAttributes subjectAttributes; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fancy word! I like it |
||
private final String actionKey; | ||
private final DiscriminableAttributes actionAttributes; | ||
private final double actionScore; | ||
private final double actionWeight; | ||
private final double gamma; | ||
private final double optimalityGap; | ||
|
||
public BanditEvaluationResult( | ||
String flagKey, | ||
String subjectKey, | ||
DiscriminableAttributes subjectAttributes, | ||
String actionKey, | ||
DiscriminableAttributes actionAttributes, | ||
double actionScore, | ||
double actionWeight, | ||
double gamma, | ||
double optimalityGap) { | ||
this.flagKey = flagKey; | ||
this.subjectKey = subjectKey; | ||
this.subjectAttributes = subjectAttributes; | ||
this.actionKey = actionKey; | ||
this.actionAttributes = actionAttributes; | ||
this.actionScore = actionScore; | ||
this.actionWeight = actionWeight; | ||
this.gamma = gamma; | ||
this.optimalityGap = optimalityGap; | ||
} | ||
|
||
public String getFlagKey() { | ||
return flagKey; | ||
} | ||
|
||
public String getSubjectKey() { | ||
return subjectKey; | ||
} | ||
|
||
public DiscriminableAttributes getSubjectAttributes() { | ||
return subjectAttributes; | ||
} | ||
|
||
public String getActionKey() { | ||
return actionKey; | ||
} | ||
|
||
public DiscriminableAttributes getActionAttributes() { | ||
return actionAttributes; | ||
} | ||
|
||
public double getActionScore() { | ||
return actionScore; | ||
} | ||
|
||
public double getActionWeight() { | ||
return actionWeight; | ||
} | ||
|
||
public double getGamma() { | ||
return gamma; | ||
} | ||
|
||
public double getOptimalityGap() { | ||
return optimalityGap; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
package cloud.eppo; | ||
|
||
import cloud.eppo.ufc.dto.*; | ||
import java.util.*; | ||
import java.util.stream.Collectors; | ||
|
||
public class BanditEvaluator { | ||
|
||
private static final int BANDIT_ASSIGNMENT_SHARDS = 10000; // hard-coded for now | ||
|
||
public static BanditEvaluationResult evaluateBandit( | ||
String flagKey, | ||
String subjectKey, | ||
DiscriminableAttributes subjectAttributes, | ||
Actions actions, | ||
BanditModelData modelData) { | ||
Map<String, Double> actionScores = scoreActions(subjectAttributes, actions, modelData); | ||
Map<String, Double> actionWeights = | ||
weighActions(actionScores, modelData.getGamma(), modelData.getActionProbabilityFloor()); | ||
String selectedActionKey = selectAction(flagKey, subjectKey, actionWeights); | ||
|
||
// Compute optimality gap in terms of score | ||
double topScore = | ||
actionScores.values().stream().mapToDouble(Double::doubleValue).max().orElse(0); | ||
double optimalityGap = topScore - actionScores.get(selectedActionKey); | ||
|
||
return new BanditEvaluationResult( | ||
flagKey, | ||
subjectKey, | ||
subjectAttributes, | ||
selectedActionKey, | ||
actions.get(selectedActionKey), | ||
actionScores.get(selectedActionKey), | ||
actionWeights.get(selectedActionKey), | ||
modelData.getGamma(), | ||
optimalityGap); | ||
} | ||
|
||
private static Map<String, Double> scoreActions( | ||
DiscriminableAttributes subjectAttributes, Actions actions, BanditModelData modelData) { | ||
return actions.entrySet().stream() | ||
.collect( | ||
Collectors.toMap( | ||
Map.Entry::getKey, | ||
e -> { | ||
String actionName = e.getKey(); | ||
DiscriminableAttributes actionAttributes = e.getValue(); | ||
|
||
// get all coefficients known to the model for this action | ||
BanditCoefficients banditCoefficients = | ||
modelData.getCoefficients().get(actionName); | ||
|
||
if (banditCoefficients == null) { | ||
// Unknown action; return the default action score | ||
return modelData.getDefaultActionScore(); | ||
} | ||
|
||
// Score the action using the provided attributes | ||
double actionScore = banditCoefficients.getIntercept(); | ||
actionScore += | ||
scoreContextForCoefficients( | ||
actionAttributes.getNumericAttributes(), | ||
banditCoefficients.getActionNumericCoefficients()); | ||
actionScore += | ||
scoreContextForCoefficients( | ||
actionAttributes.getCategoricalAttributes(), | ||
banditCoefficients.getActionCategoricalCoefficients()); | ||
actionScore += | ||
scoreContextForCoefficients( | ||
subjectAttributes.getNumericAttributes(), | ||
banditCoefficients.getSubjectNumericCoefficients()); | ||
actionScore += | ||
scoreContextForCoefficients( | ||
subjectAttributes.getCategoricalAttributes(), | ||
banditCoefficients.getSubjectCategoricalCoefficients()); | ||
|
||
return actionScore; | ||
})); | ||
} | ||
|
||
private static double scoreContextForCoefficients( | ||
Attributes attributes, Map<String, ? extends BanditAttributeCoefficients> coefficients) { | ||
|
||
double totalScore = 0.0; | ||
|
||
for (BanditAttributeCoefficients attributeCoefficients : coefficients.values()) { | ||
EppoValue contextValue = attributes.get(attributeCoefficients.getAttributeKey()); | ||
// The coefficient implementation knows how to score | ||
double attributeScore = attributeCoefficients.scoreForAttributeValue(contextValue); | ||
Comment on lines
+88
to
+89
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a bit different than other SDKs; the logic for how to score a given attribute value is collocated with the type of attribute (i.e. |
||
totalScore += attributeScore; | ||
} | ||
|
||
return totalScore; | ||
} | ||
|
||
private static Map<String, Double> weighActions( | ||
Map<String, Double> actionScores, double gamma, double actionProbabilityFloor) { | ||
Double highestScore = null; | ||
String highestScoredAction = null; | ||
for (Map.Entry<String, Double> actionScore : actionScores.entrySet()) { | ||
if (highestScore == null | ||
|| actionScore.getValue() > highestScore | ||
|| actionScore | ||
.getValue() | ||
.equals(highestScore) // note: we break ties for scores by action name | ||
&& actionScore.getKey().compareTo(highestScoredAction) < 0) { | ||
highestScore = actionScore.getValue(); | ||
highestScoredAction = actionScore.getKey(); | ||
} | ||
} | ||
|
||
// Weigh all the actions using their score | ||
Map<String, Double> actionWeights = new HashMap<>(); | ||
double totalNonHighestWeight = 0.0; | ||
for (Map.Entry<String, Double> actionScore : actionScores.entrySet()) { | ||
if (actionScore.getKey().equals(highestScoredAction)) { | ||
// The highest scored action is weighed at the end | ||
continue; | ||
} | ||
|
||
// Compute weight (probability) | ||
double unboundedProbability = | ||
1 / (actionScores.size() + (gamma * (highestScore - actionScore.getValue()))); | ||
double minimumProbability = actionProbabilityFloor / actionScores.size(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @typotter bandit probability floor now normalized There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
double boundedProbability = Math.max(unboundedProbability, minimumProbability); | ||
totalNonHighestWeight += boundedProbability; | ||
|
||
actionWeights.put(actionScore.getKey(), boundedProbability); | ||
} | ||
|
||
// Weigh the highest scoring action (defensively preventing a negative probability) | ||
double weightForHighestScore = Math.max(1 - totalNonHighestWeight, 0); | ||
actionWeights.put(highestScoredAction, weightForHighestScore); | ||
return actionWeights; | ||
} | ||
|
||
private static String selectAction( | ||
String flagKey, String subjectKey, Map<String, Double> actionWeights) { | ||
// Deterministically "shuffle" the actions | ||
// This way as action weights shift, a bunch of users who were on the edge of one action won't | ||
// all be shifted to the same new action at the same time. | ||
List<String> shuffledActionKeys = | ||
actionWeights.keySet().stream() | ||
.sorted( | ||
Comparator.comparingInt( | ||
(String actionKey) -> | ||
ShardUtils.getShard( | ||
flagKey + "-" + subjectKey + "-" + actionKey, | ||
BANDIT_ASSIGNMENT_SHARDS)) | ||
.thenComparing(actionKey -> actionKey)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
.collect(Collectors.toList()); | ||
|
||
// Select action from the shuffled actions, based on weight | ||
double assignedShard = | ||
ShardUtils.getShard(flagKey + "-" + subjectKey, BANDIT_ASSIGNMENT_SHARDS); | ||
double assignmentWeightThreshold = assignedShard / (double) BANDIT_ASSIGNMENT_SHARDS; | ||
double cumulativeWeight = 0; | ||
String assignedAction = null; | ||
for (String actionKey : shuffledActionKeys) { | ||
cumulativeWeight += actionWeights.get(actionKey); | ||
if (cumulativeWeight > assignmentWeightThreshold) { | ||
assignedAction = actionKey; | ||
break; | ||
} | ||
} | ||
return assignedAction; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,8 @@ | |
|
||
import cloud.eppo.ufc.dto.FlagConfig; | ||
import java.io.IOException; | ||
import java.util.HashSet; | ||
import java.util.Set; | ||
import okhttp3.Response; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
@@ -12,28 +14,44 @@ public class ConfigurationRequestor { | |
|
||
private final EppoHttpClient client; | ||
private final ConfigurationStore configurationStore; | ||
private final Set<String> loadedBanditModelVersions; | ||
|
||
public ConfigurationRequestor(ConfigurationStore configurationStore, EppoHttpClient client) { | ||
this.configurationStore = configurationStore; | ||
this.client = client; | ||
this.loadedBanditModelVersions = new HashSet<>(); | ||
} | ||
|
||
// TODO: async loading for android | ||
public void load() { | ||
log.debug("Fetching configuration"); | ||
Response response = client.get("/api/flag-config/v1/config"); | ||
String flagConfigurationJsonString = requestBody("/api/flag-config/v1/config"); | ||
configurationStore.setFlagsFromJsonString(flagConfigurationJsonString); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you need to worry about threading here and lock reads until the bandit models are set below? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good thinking! The underlying implementation is a thread-safe ConcurrentHashMap, which is what we are using in the Android SDK without (known) issues There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The ConcurrentHashMap underlying will keep reads and writes happening in an orderly fashion, but there's no lock around the setting of the flags and the subsequent setting of the bandit parameters below, so, there's a brief instant the config store could be accessed while it has incomplete configuration set. If a caller manages to evaluate a bandit between those two instructions, the SDK will have incomplete data and may not have the required bandit. The scenario is probably unlikely, contrived and probably nearly impossible to test but I'm curious how wrong things would go if this happened (user would get returned the bandit ket as the variation with action=null which the devs should already have handling for in their app). |
||
|
||
Set<String> neededModelVersions = configurationStore.banditModelVersions(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible for more than one bandit to have the same model version string? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, model versions are unique to each bandit. Under the hood, this version ID--for now--is the auto incrementing primary key in a model_versions table that has versions for all bandits (as it includes the bandit id) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the clarification |
||
boolean needBanditParameters = !loadedBanditModelVersions.containsAll(neededModelVersions); | ||
Comment on lines
+31
to
+32
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Taking advantage of the latest and greatest UFC feature: bandit model versions. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shooting for bleeding edge all at once eh? I like your style 🤠 |
||
if (needBanditParameters) { | ||
String banditParametersJsonString = requestBody("/api/flag-config/v1/bandits"); | ||
configurationStore.setBanditParametersFromJsonString(banditParametersJsonString); | ||
// Record the model versions that we just loaded, so we can compare when the store is later | ||
// updated | ||
loadedBanditModelVersions.clear(); | ||
loadedBanditModelVersions.addAll(configurationStore.banditModelVersions()); | ||
} | ||
} | ||
|
||
private String requestBody(String route) { | ||
Response response = client.get(route); | ||
if (!response.isSuccessful() || response.body() == null) { | ||
throw new RuntimeException("Failed to fetch from " + route); | ||
} | ||
try { | ||
if (!response.isSuccessful()) { | ||
throw new RuntimeException("Failed to fetch configuration"); | ||
} | ||
configurationStore.setFlagsFromJsonString(response.body().string()); | ||
return response.body().string(); | ||
} catch (IOException e) { | ||
// TODO: better exception handling? | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
// TODO: async loading for android | ||
|
||
public FlagConfig getConfiguration(String flagKey) { | ||
return configurationStore.getFlag(flagKey); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needed for mocking static methods (such as
evaluateBandit()
)