From dfcfa49354e3048c91e1af9a845e149321cb2733 Mon Sep 17 00:00:00 2001 From: Colin McCloskey Date: Wed, 30 Aug 2023 15:02:01 +0200 Subject: [PATCH] Adding skeleton AWS Llama classes --- pom.xml | 5 + .../llm/AmazonRequestSignatureV4Utils.java | 135 ++++++++++++++++++ .../meta/chatbridge/llm/LlamaAWSHandler.java | 34 +++++ .../meta/chatbridge/llm/LlamaTokenizer.java | 44 ++++++ .../chatbridge/store/LLMContextManager.java | 21 +++ 5 files changed, 239 insertions(+) create mode 100644 src/main/java/com/meta/chatbridge/llm/AmazonRequestSignatureV4Utils.java create mode 100644 src/main/java/com/meta/chatbridge/llm/LlamaAWSHandler.java create mode 100644 src/main/java/com/meta/chatbridge/llm/LlamaTokenizer.java create mode 100644 src/main/java/com/meta/chatbridge/store/LLMContextManager.java diff --git a/pom.xml b/pom.xml index 758f43c..357e8a8 100644 --- a/pom.xml +++ b/pom.xml @@ -70,6 +70,11 @@ log4j-slf4j2-impl 2.20.0 + + com.knuddels + jtokkit + 0.6.1 + diff --git a/src/main/java/com/meta/chatbridge/llm/AmazonRequestSignatureV4Utils.java b/src/main/java/com/meta/chatbridge/llm/AmazonRequestSignatureV4Utils.java new file mode 100644 index 0000000..6032e2e --- /dev/null +++ b/src/main/java/com/meta/chatbridge/llm/AmazonRequestSignatureV4Utils.java @@ -0,0 +1,135 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.chatbridge.llm; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.util.*; +import java.util.Map.Entry; +import java.util.stream.Collectors; + +import javax.crypto.Mac; +import javax.crypto.spec.SecretKeySpec; + +/** + * Copyright 2020 Alex Vasiliev, licensed under the Apache 2.0 license: https://opensource.org/licenses/Apache-2.0 + */ +public class AmazonRequestSignatureV4Utils { + + /** + * Generates signing headers for HTTP request in accordance with Amazon AWS API Signature version 4 process. + *

+ * Following steps outlined here: docs.aws.amazon.com + *

+ * This method takes many arguments as read-only, but adds necessary headers to @{code headers} argument, which is a map. + * The caller should make sure those parameters are copied to the actual request object. + *

+ * The ISO8601 date parameter can be created by making a call to:
+ * - {@code java.time.format.DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'").format(ZonedDateTime.now(ZoneOffset.UTC))}
+ * or, if you prefer joda:
+ * - {@code org.joda.time.format.ISODateTimeFormat.basicDateTimeNoMillis().print(DateTime.now().withZone(DateTimeZone.UTC))} + * + * @param method - HTTP request method, (GET|POST|DELETE|PUT|...), e.g., {@link java.net.HttpURLConnection#getRequestMethod()} + * @param host - URL host, e.g., {@link java.net.URL#getHost()}. + * @param path - URL path, e.g., {@link java.net.URL#getPath()}. + * @param query - URL query, (parameters in sorted order, see the AWS spec) e.g., {@link java.net.URL#getQuery()}. + * @param headers - HTTP request header map. This map is going to have entries added to it by this method. Initially populated with + * headers to be included in the signature. Like often compulsory 'Host' header. e.g., {@link java.net.HttpURLConnection#getRequestProperties()}. + * @param body - The binary request body, for requests like POST. + * @param isoDateTime - The time and date of the request in ISO8601 basic format, see comment above. + * @param awsIdentity - AWS Identity, e.g., "AKIAJTOUYS27JPVRDUYQ" + * @param awsSecret - AWS Secret Key, e.g., "I8Q2hY819e+7KzBnkXj66n1GI9piV+0p3dHglAzQ" + * @param awsRegion - AWS Region, e.g., "us-east-1" + * @param awsService - AWS Service, e.g., "route53" + */ + public static void calculateAuthorizationHeaders( + String method, String host, String path, String query, Map headers, + byte[] body, + String isoDateTime, + String awsIdentity, String awsSecret, String awsRegion, String awsService + ) { + try { + String bodySha256 = hex(sha256(body)); + String isoJustDate = isoDateTime.substring(0, 8); // Cut the date portion of a string like '20150830T123600Z'; + + headers.put("Host", host); + headers.put("X-Amz-Content-Sha256", bodySha256); + headers.put("X-Amz-Date", isoDateTime); + + // (1) https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html + List canonicalRequestLines = new ArrayList<>(); + canonicalRequestLines.add(method); + canonicalRequestLines.add(path); + canonicalRequestLines.add(query); + List hashedHeaders = new ArrayList<>(); + List headerKeysSorted = headers.keySet().stream().sorted(Comparator.comparing(e -> e.toLowerCase(Locale.US))).collect(Collectors.toList()); + for (String key : headerKeysSorted) { + hashedHeaders.add(key.toLowerCase(Locale.US)); + canonicalRequestLines.add(key.toLowerCase(Locale.US) + ":" + normalizeSpaces(headers.get(key))); + } + canonicalRequestLines.add(null); // new line required after headers + String signedHeaders = hashedHeaders.stream().collect(Collectors.joining(";")); + canonicalRequestLines.add(signedHeaders); + canonicalRequestLines.add(bodySha256); + String canonicalRequestBody = canonicalRequestLines.stream().map(line -> line == null ? "" : line).collect(Collectors.joining("\n")); + String canonicalRequestHash = hex(sha256(canonicalRequestBody.getBytes(StandardCharsets.UTF_8))); + + // (2) https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html + List stringToSignLines = new ArrayList<>(); + stringToSignLines.add("AWS4-HMAC-SHA256"); + stringToSignLines.add(isoDateTime); + String credentialScope = isoJustDate + "/" + awsRegion + "/" + awsService + "/aws4_request"; + stringToSignLines.add(credentialScope); + stringToSignLines.add(canonicalRequestHash); + String stringToSign = stringToSignLines.stream().collect(Collectors.joining("\n")); + + // (3) https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html + byte[] kDate = hmac(("AWS4" + awsSecret).getBytes(StandardCharsets.UTF_8), isoJustDate); + byte[] kRegion = hmac(kDate, awsRegion); + byte[] kService = hmac(kRegion, awsService); + byte[] kSigning = hmac(kService, "aws4_request"); + String signature = hex(hmac(kSigning, stringToSign)); + + String authParameter = "AWS4-HMAC-SHA256 Credential=" + awsIdentity + "/" + credentialScope + ", SignedHeaders=" + signedHeaders + ", Signature=" + signature; + headers.put("Authorization", authParameter); + + } catch (Exception e) { + if (e instanceof RuntimeException) { + throw (RuntimeException) e; + } else { + throw new IllegalStateException(e); + } + } + } + + private static String normalizeSpaces(String value) { + return value.replaceAll("\\s+", " ").trim(); + } + + public static String hex(byte[] a) { + StringBuilder sb = new StringBuilder(a.length * 2); + for(byte b: a) { + sb.append(String.format("%02x", b)); + } + return sb.toString(); + } + + private static byte[] sha256(byte[] bytes) throws Exception { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + digest.update(bytes); + return digest.digest(); + } + + public static byte[] hmac(byte[] key, String msg) throws Exception { + Mac mac = Mac.getInstance("HmacSHA256"); + mac.init(new SecretKeySpec(key, "HmacSHA256")); + return mac.doFinal(msg.getBytes(StandardCharsets.UTF_8)); + } + +} diff --git a/src/main/java/com/meta/chatbridge/llm/LlamaAWSHandler.java b/src/main/java/com/meta/chatbridge/llm/LlamaAWSHandler.java new file mode 100644 index 0000000..eca1605 --- /dev/null +++ b/src/main/java/com/meta/chatbridge/llm/LlamaAWSHandler.java @@ -0,0 +1,34 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.chatbridge.llm; + +import com.meta.chatbridge.message.Message; +import com.meta.chatbridge.store.LLMContextManager; +import com.meta.chatbridge.store.MessageStack; + +public class LlamaAWSHandler implements LLMHandler { + + private final LLMContextManager context; + + public LlamaAWSHandler(LLMContextManager context) { + this.context = context; + } + + @Override + public Message handle(MessageStack messageStack) { +// Take history +// get number of tokens and truncate as needed +// Pass to LLM +// Return response + Message message = (Message) messageStack.messages().get(0); + + return null; + + } +} diff --git a/src/main/java/com/meta/chatbridge/llm/LlamaTokenizer.java b/src/main/java/com/meta/chatbridge/llm/LlamaTokenizer.java new file mode 100644 index 0000000..f2edbd0 --- /dev/null +++ b/src/main/java/com/meta/chatbridge/llm/LlamaTokenizer.java @@ -0,0 +1,44 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.chatbridge.llm; + +import com.knuddels.jtokkit.Encodings; +import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.EncodingRegistry; +import com.knuddels.jtokkit.api.EncodingType; + +import java.util.*; + +public class LlamaTokenizer { + + private final int MAX_TOKENS = 4096; + private final int MAX_RESPONSE_TOKENS = 1024; + private final int MAX_CONTEXT_TOKENS = 1536; + + private final Encoding tokenizer; + + public LlamaTokenizer() { + EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); + tokenizer = registry.getEncoding(EncodingType.CL100K_BASE); + } + + + /** + * Returns a count of the tokens from the context string + * + * @param contextString The context documents + */ + private int getContextString(String contextString) { + var tokenCount = 0; + + tokenCount += tokenizer.encode(contextString + "\n---\n").size(); // Set the rest of the message here as system, etc + + return tokenCount; + } +} diff --git a/src/main/java/com/meta/chatbridge/store/LLMContextManager.java b/src/main/java/com/meta/chatbridge/store/LLMContextManager.java new file mode 100644 index 0000000..c9f32ea --- /dev/null +++ b/src/main/java/com/meta/chatbridge/store/LLMContextManager.java @@ -0,0 +1,21 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.chatbridge.store; + +public class LLMContextManager { + private static String context = ""; + + public static void setContext(String newContext) { + context = newContext; + } + + public static String getContext() { + return context; + } +}