Skip to content

Commit

Permalink
Skip LLM codemods when no service is available
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Jul 11, 2024
1 parent af04044 commit 47220e4
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public static List<Class<? extends CodeChanger>> asList() {
JDBCResourceLeakCodemod.class,
JEXLInjectionCodemod.class,
JSPScriptletXSSCodemod.class,
// LogFailedLoginCodemod.class,
LimitReadlineCodemod.class,
MavenSecureURLCodemod.class,
OutputResourceLeakCodemod.class,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package io.codemodder.plugins.llm;

import com.google.inject.AbstractModule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Provides configured LLM services. */
public final class LLMServiceModule extends AbstractModule {

private static final String OPENAI_KEY_NAME = "CODEMODDER_OPENAI_API_KEY";
private static final String AZURE_OPENAI_KEY_NAME = "CODEMODDER_AZURE_OPENAI_API_KEY";
private static final String AZURE_OPENAI_ENDPOINT = "CODEMODDER_AZURE_OPENAI_ENDPOINT";
private static final Logger logger = LoggerFactory.getLogger(LLMServiceModule.class);

@Override
protected void configure() {
Expand All @@ -22,19 +25,19 @@ protected void configure() {
+ " must be set");
}
if (azureOpenAIKey != null) {
logger.info("Using Azure OpenAI service with endpoint {}", azureOpenAIEndpoint);
bind(OpenAIService.class)
.toProvider(() -> OpenAIService.fromAzureOpenAI(azureOpenAIKey, azureOpenAIEndpoint));
return;
}

bind(OpenAIService.class).toProvider(() -> OpenAIService.fromOpenAI(getOpenAIToken()));
}

private String getOpenAIToken() {
final var openAIKey = System.getenv(OPENAI_KEY_NAME);
if (openAIKey == null) {
throw new IllegalArgumentException(OPENAI_KEY_NAME + " environment variable must be set");
if (openAIKey != null) {
logger.info("Using OpenAI service");
bind(OpenAIService.class).toProvider(() -> OpenAIService.fromOpenAI(openAIKey));
}
return openAIKey;

logger.info("No LLM service available");
bind(OpenAIService.class).toProvider(OpenAIService::noServiceAvailable);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class OpenAIService {
private final OpenAIClient api;
private static final int TIMEOUT_SECONDS = 90;
private final ModelMapper modelMapper;
private boolean serviceAvailable = true;

private static OpenAIClientBuilder builder(final KeyCredential key) {
HttpClientOptions clientOptions = new HttpClientOptions();
Expand All @@ -31,6 +32,12 @@ private static OpenAIClientBuilder builder(final KeyCredential key) {
.credential(key);
}

OpenAIService(final boolean serviceAvailable) {
this.serviceAvailable = serviceAvailable;
this.modelMapper = null;
this.api = null;
}

OpenAIService(final ModelMapper mapper, final KeyCredential key) {
this.modelMapper = mapper;
this.api = builder(key).buildClient();
Expand Down Expand Up @@ -66,6 +73,19 @@ public static OpenAIService fromAzureOpenAI(final String token, final String end
Objects.requireNonNull(endpoint));
}

public static OpenAIService noServiceAvailable() {
return new OpenAIService(false);
}

/**
* Returns whether the service is available.
*
* @return whether the service is available
*/
public boolean isServiceAvailable() {
return serviceAvailable;
}

/**
* Gets the completion for the given messages.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.codemodder.plugins.llm;

import io.codemodder.RuleSarif;
import io.codemodder.SarifPluginRawFileChanger;
import java.util.Objects;

public abstract class SarifPluginLLMCodemod extends SarifPluginRawFileChanger {
protected final OpenAIService openAI;

public SarifPluginLLMCodemod(RuleSarif sarif, final OpenAIService openAI) {
super(sarif);
this.openAI = Objects.requireNonNull(openAI);
}

@Override
public boolean shouldRun() {
return openAI.isServiceAvailable();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,13 @@
* </ol>
*/
public abstract class SarifToLLMForBinaryVerificationAndFixingCodemod
extends SarifPluginRawFileChanger {
extends SarifPluginLLMCodemod {

private final OpenAIService openAI;
private final Model model;

protected SarifToLLMForBinaryVerificationAndFixingCodemod(
final RuleSarif sarif, final OpenAIService openAI, final Model model) {
super(sarif);
this.openAI = Objects.requireNonNull(openAI);
super(sarif, openAI);
this.model = Objects.requireNonNull(model);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,10 @@
* <p>To accomplish that, we need the analysis to "bucket" the code into one of the above
* categories.
*/
public abstract class SarifToLLMForMultiOutcomeCodemod extends SarifPluginRawFileChanger {
public abstract class SarifToLLMForMultiOutcomeCodemod extends SarifPluginLLMCodemod {

private static final Logger logger =
LoggerFactory.getLogger(SarifToLLMForMultiOutcomeCodemod.class);
private final OpenAIService openAI;
private final List<LLMRemediationOutcome> remediationOutcomes;
private final Model categorizationModel;
private final Model codeChangingModel;
Expand All @@ -65,8 +64,7 @@ protected SarifToLLMForMultiOutcomeCodemod(
final List<LLMRemediationOutcome> remediationOutcomes,
final Model categorizationModel,
final Model codeChangingModel) {
super(sarif);
this.openAI = Objects.requireNonNull(openAI);
super(sarif, openAI);
this.remediationOutcomes = Objects.requireNonNull(remediationOutcomes);
if (remediationOutcomes.size() < 2) {
throw new IllegalArgumentException("must have 2+ remediation outcome");
Expand All @@ -78,7 +76,7 @@ protected SarifToLLMForMultiOutcomeCodemod(
@Override
public CodemodFileScanningResult onFileFound(
final CodemodInvocationContext context, final List<Result> results) {
logger.info("processing: {}", context.path());
logger.debug("processing: {}", context.path());

List<CodemodChange> changes = new ArrayList<>();
for (Result result : results) {
Expand Down

0 comments on commit 47220e4

Please sign in to comment.