diff --git a/AuthServ/AuthManifest.h b/AuthServ/AuthManifest.h index fcbf0cf..388da81 100644 --- a/AuthServ/AuthManifest.h +++ b/AuthServ/AuthManifest.h @@ -27,6 +27,9 @@ namespace DS struct AuthFileInfo { AuthFileInfo() : m_fileSize() { } + AuthFileInfo(ST::string filename, uint32_t fileSize) + : m_filename(std::move(filename)), m_fileSize(fileSize) + { } AuthFileInfo(const AuthFileInfo&) = delete; AuthFileInfo& operator=(const AuthFileInfo&) = delete; @@ -48,6 +51,11 @@ namespace DS size_t fileCount() const { return m_files.size(); } + void addFile(ST::string filename, uint32_t fileSize) + { + m_files.emplace_back(std::move(filename), fileSize); + } + private: std::vector m_files; }; diff --git a/AuthServ/AuthServer.cpp b/AuthServ/AuthServer.cpp index 5e071b8..f81ccce 100755 --- a/AuthServ/AuthServer.cpp +++ b/AuthServ/AuthServer.cpp @@ -17,13 +17,17 @@ #include "AuthServer_Private.h" #include "AuthManifest.h" +#include "SDL/DescriptorDb.h" #include "Types/BitVector.h" #include "Types/Uuid.h" #include "settings.h" #include "errors.h" + #include #include #include +#include +#include #define NODE_SIZE_MAX (4 * 1024 * 1024) @@ -566,7 +570,33 @@ void cb_fileList(AuthServer_Private& client) ST::string mfsname = ST::format("{}{}_{}.list", DS::Settings::AuthRoot(), directory, fileext); DS::AuthManifest mfs; - DS::NetResultCode result = mfs.loadManifest(mfsname.c_str()); + DS::NetResultCode result = DS::e_NetPending; + + // Special case: SDL files + // For production shards, we expect for them to be listed in the secure preloader manifest. + // If that hasn't been done, don't worry about the SDL lists - just use the SDL files that + // DS would load on start up. + if (directory.compare_i("SDL") == 0 && fileext.compare_i("sdl") == 0) { + auto populateSdl = [&mfs](const ST::string& path) { + struct stat sbuf; + if (stat(path.c_str(), &sbuf) < 0) + throw DS::SystemError("[Auth] Unable to stat SDL file", strerror(errno)); + ST::string filename = path.after_last('/'); + mfs.addFile(ST::format("SDL\\{}", filename), sbuf.st_size); + return true; + }; + try { + SDL::DescriptorDb::ForDescriptorFiles(DS::Settings::SdlPath(), std::move(populateSdl)); + result = DS::e_NetSuccess; + } catch (const DS::SystemError& err) { + fputs(err.what(), stderr); + result = DS::e_NetInternalError; + } + } else { + result = mfs.loadManifest(mfsname.c_str()); + } + + DS_ASSERT(result != DS::e_NetPending); client.m_buffer.write(result); if (result != DS::e_NetSuccess) { @@ -594,6 +624,7 @@ void cb_downloadStart(AuthServer_Private& client) // Download filename ST::string filename = DS::CryptRecvString(client.m_sock, client.m_crypt); + filename = filename.replace("\\", "/"); // Ensure filename is jailed to our data path if (filename.find("..") != -1) { @@ -604,12 +635,18 @@ void cb_downloadStart(AuthServer_Private& client) SEND_REPLY(); return; } - filename = filename.replace("\\", "/"); - filename = DS::Settings::AuthRoot() + filename; - DS::FileStream* stream = new DS::FileStream(); + // Special case: SDL files from the server' SDL directory. + ST_ssize_t slashPos = filename.find_last('/'); + if (slashPos != -1 && filename.left(slashPos).compare_i("SDL") == 0 && filename.after_last('.').compare_i("sdl") == 0) { + filename = DS::Settings::SdlPath() + filename.substr(slashPos); + } else { + filename = DS::Settings::AuthRoot() + filename; + } + + auto fileStream = std::make_unique(); try { - stream->open(filename.c_str(), "rb"); + fileStream->open(filename.c_str(), "rb"); } catch (const DS::FileIOException& ex) { ST::printf(stderr, "[Auth] Could not open file {}: {}\n[Auth] Requested by {}\n", filename, ex.what(), DS::SockIpAddress(client.m_sock)); @@ -618,10 +655,31 @@ void cb_downloadStart(AuthServer_Private& client) client.m_buffer.write(0); // Chunk offset client.m_buffer.write(0); // Data packet size SEND_REPLY(); - delete stream; return; } + // All auth downloads must be encrypted. + std::unique_ptr stream; + if (!DS::EncryptedStream::CheckEncryption(fileStream.get()).has_value()) { + auto bufStream = std::make_unique(); + { + DS::EncryptedStream encStream(bufStream.get(), DS::EncryptedStream::Mode::e_write, + DS::EncryptedStream::Type::e_xxtea, + DS::Settings::DroidKey()); + uint8_t buf[CHUNK_SIZE]; + while (fileStream->tell() < fileStream->size()) { + ssize_t nread = fileStream->readBytes(buf, sizeof(buf)); + DS_ASSERT(nread >= 0); + encStream.writeBytes(buf, nread); + } + } + bufStream->seek(0, SEEK_SET); + stream = std::move(bufStream); + } else { + stream = std::move(fileStream); + } + + DS_ASSERT(stream); client.m_buffer.write(DS::e_NetSuccess); client.m_buffer.write(stream->size()); client.m_buffer.write(stream->tell()); @@ -631,12 +689,11 @@ void cb_downloadStart(AuthServer_Private& client) client.m_buffer.write(CHUNK_SIZE); stream->readBytes(data, CHUNK_SIZE); client.m_buffer.writeBytes(data, CHUNK_SIZE); - client.m_downloads[transId] = stream; + client.m_downloads[transId] = std::move(stream); } else { client.m_buffer.write(stream->size()); stream->readBytes(data, stream->size()); client.m_buffer.writeBytes(data, stream->size()); - delete stream; } SEND_REPLY(); @@ -670,7 +727,6 @@ void cb_downloadNext(AuthServer_Private& client) client.m_buffer.write(bytesLeft); fi->second->readBytes(data, bytesLeft); client.m_buffer.writeBytes(data, bytesLeft); - delete fi->second; client.m_downloads.erase(fi); } diff --git a/AuthServ/AuthServer_Private.h b/AuthServ/AuthServer_Private.h index bb4b084..520093b 100644 --- a/AuthServ/AuthServer_Private.h +++ b/AuthServ/AuthServer_Private.h @@ -26,6 +26,7 @@ #include #include #include +#include enum AuthServer_MsgIds { @@ -105,18 +106,9 @@ struct AuthServer_Private : public AuthClient_Private uint32_t m_acctFlags; AuthServer_PlayerInfo m_player; uint32_t m_ageNodeId; - std::map m_downloads; + std::map> m_downloads; AuthServer_Private() : m_serverChallenge(0), m_acctFlags(0), m_ageNodeId(0) { } - - ~AuthServer_Private() - { - while (!m_downloads.empty()) { - auto item = m_downloads.begin(); - delete item->second; - m_downloads.erase(item); - } - } }; extern std::list s_authClients; diff --git a/SDL/DescriptorDb.cpp b/SDL/DescriptorDb.cpp index f89514e..843f17b 100644 --- a/SDL/DescriptorDb.cpp +++ b/SDL/DescriptorDb.cpp @@ -30,46 +30,39 @@ static int sel_sdl(const dirent* de) SDL::DescriptorDb::descmap_t SDL::DescriptorDb::s_descriptors; -bool SDL::DescriptorDb::LoadDescriptors(const char* sdlpath) +bool SDL::DescriptorDb::LoadDescriptorsFromFile(const ST::string& filename) { - dirent** dirls; - int count = scandir(sdlpath, &dirls, &sel_sdl, &alphasort); - if (count < 0) { - ST::printf(stderr, "[SDL] Error reading SDL descriptors: {}\n", strerror(errno)); - return false; - } - if (count == 0) { - fputs("[SDL] Warning: No SDL descriptors found!\n", stderr); - free(dirls); - return true; - } - SDL::Parser parser; - for (int i=0; id_name); - if (parser.open(filename.c_str())) { - std::list descriptors = parser.parse(); - for (auto it = descriptors.begin(); it != descriptors.end(); ++it) { + if (parser.open(filename.c_str())) { + std::list descriptors = parser.parse(); + for (auto it = descriptors.begin(); it != descriptors.end(); ++it) { #ifdef DEBUG - descmap_t::iterator namei = s_descriptors.find(it->m_name); - if (namei != s_descriptors.end()) { - if (namei->second.find(it->m_version) != namei->second.end()) { - ST::printf(stderr, "[SDL] Warning: Duplicate descriptor version for {}\n", - it->m_name); - } + descmap_t::iterator namei = s_descriptors.find(it->m_name); + if (namei != s_descriptors.end()) { + if (namei->second.find(it->m_version) != namei->second.end()) { + ST::printf(stderr, "[SDL] Warning: Duplicate descriptor version for {}\n", + it->m_name); } + } #endif - s_descriptors[it->m_name][it->m_version] = *it; + s_descriptors[it->m_name][it->m_version] = *it; - // Keep the highest version in -1 - if (s_descriptors[it->m_name][-1].m_version < it->m_version) - s_descriptors[it->m_name][-1] = *it; - } + // Keep the highest version in -1 + if (s_descriptors[it->m_name][-1].m_version < it->m_version) + s_descriptors[it->m_name][-1] = *it; } - parser.close(); - free(dirls[i]); } - free(dirls); + return true; +} + +bool SDL::DescriptorDb::LoadDescriptors(const char* sdlpath) +{ + try { + ForDescriptorFiles(sdlpath, LoadDescriptorsFromFile); + } catch (const DS::SystemError& err) { + fputs(err.what(), stderr); + return false; + } return true; } @@ -124,3 +117,27 @@ bool SDL::DescriptorDb::ForLatestDescriptors(descfunc_t functor) return true; } +bool SDL::DescriptorDb::ForDescriptorFiles(const char* sdlpath, filefunc_t functor) +{ + dirent** dirls; + int count = scandir(sdlpath, &dirls, &sel_sdl, &alphasort); + + DS_ASSERT(count > 0); + if (count == 0) + fputs("[SDL] Warning: No SDL descriptors found!\n", stderr); + if (count < 0) + throw DS::SystemError("[SDL] Error scanning for SDL files", strerror(errno)); + + bool retval = true; + for (int i = 0; i < count; i++) { + if (!functor(ST::format("{}/{}", sdlpath, dirls[i]->d_name))) { + retval = false; + break; + } + } + + for (int i = 0; i < count; i++) + free(dirls[i]); + free(dirls); + return retval; +} diff --git a/SDL/DescriptorDb.h b/SDL/DescriptorDb.h index bf1ffbb..e9d3e85 100644 --- a/SDL/DescriptorDb.h +++ b/SDL/DescriptorDb.h @@ -100,17 +100,21 @@ namespace SDL { public: typedef std::function descfunc_t; + typedef std::function filefunc_t; static bool LoadDescriptors(const char* sdlpath); static StateDescriptor* FindDescriptor(const ST::string& name, int version); static StateDescriptor* FindLatestDescriptor(const ST::string& name); static bool ForLatestDescriptors(descfunc_t functor); + static bool ForDescriptorFiles(const char* sdlpath, filefunc_t functor); private: DescriptorDb() = delete; DescriptorDb(const DescriptorDb&) = delete; ~DescriptorDb() = delete; + static bool LoadDescriptorsFromFile(const ST::string& path); + typedef std::unordered_map versionmap_t; typedef std::unordered_map descmap_t; static descmap_t s_descriptors; diff --git a/SDL/SdlParser.cpp b/SDL/SdlParser.cpp index da03eb2..586a09a 100644 --- a/SDL/SdlParser.cpp +++ b/SDL/SdlParser.cpp @@ -16,6 +16,8 @@ ******************************************************************************/ #include "SdlParser.h" +#include "settings.h" +#include "streams.h" #include @@ -27,25 +29,35 @@ static const char* s_toknames[] = { "", "", "", "", }; -bool SDL::Parser::open(const char* filename) +DS::Stream* SDL::Parser::stream() const { - char sanitycheck[12]; + if (m_encStream) + return m_encStream; + else + return m_fileStream; +} - m_file = fopen(filename, "r"); - if (!m_file) { - ST::printf(stderr, "[SDL] Error opening file {} for reading\n", filename); +bool SDL::Parser::open(const char* filename) +{ + m_fileStream = new DS::FileStream(); + try { + m_fileStream->open(filename, "r"); + } catch (DS::FileIOException& ex) { + ST::printf(stderr, "[SDL] Error opening file {} for reading: {}\n", filename, ex.what()); + close(); return false; } - memset(sanitycheck, 0, sizeof(sanitycheck)); - fread(sanitycheck, 1, 12, m_file); - fseek(m_file, 0, SEEK_SET); - if (memcmp(sanitycheck, "whatdoyousee", 12) == 0 - || memcmp(sanitycheck, "notthedroids", 12) == 0 - || memcmp(sanitycheck, "BriceIsSmart", 12) == 0) { - fputs("[SDL] Error: DirtSand does not support encrypted SDL sources\n", stderr); - fputs("[SDL] Please decrypt your SDL files and re-start DirtSand\n", stderr); - ST::printf(stderr, "[SDL] Error in file: {}\n", filename); - return false; + + if (DS::EncryptedStream::CheckEncryption(m_fileStream).has_value()) { + try { + m_encStream = new DS::EncryptedStream(m_fileStream, DS::EncryptedStream::Mode::e_read); + } catch (DS::FileIOException& ex) { + ST::printf(stderr, "[SDL] Error opening file {} for reading: {}\n", filename, ex.what()); + close(); + return false; + } + if (m_encStream->getEncType() == DS::EncryptedStream::Type::e_xxtea) + m_encStream->setKeys(DS::Settings::DroidKey()); } m_filename = filename; @@ -54,6 +66,16 @@ bool SDL::Parser::open(const char* filename) return true; } +void SDL::Parser::close() +{ + delete m_encStream; + m_encStream = nullptr; + delete m_fileStream; + m_fileStream = nullptr; + m_filename.clear(); + m_lineno = -1; +} + static SDL::TokenType str_to_toktype(const ST::string& str) { if (str == "STATEDESC") @@ -112,7 +134,7 @@ SDL::Token SDL::Parser::next() while (m_buffer.empty()) { Token tokbuf; char lnbuf[4096]; - if (!fgets(reinterpret_cast(lnbuf), 4096, m_file)) { + if (!m_fileStream->readLine(lnbuf, sizeof(lnbuf))) { tokbuf.m_type = e_TokEof; m_buffer.push_back(tokbuf); break; diff --git a/SDL/SdlParser.h b/SDL/SdlParser.h index acda607..5996667 100644 --- a/SDL/SdlParser.h +++ b/SDL/SdlParser.h @@ -20,9 +20,15 @@ #include "DescriptorDb.h" #include "strings.h" -#include #include +namespace DS +{ + class EncryptedStream; + class FileStream; + class Stream; +} + namespace SDL { enum TokenType @@ -50,18 +56,11 @@ namespace SDL class Parser { public: - Parser() : m_file(), m_lineno(-1) { } + Parser() : m_fileStream(), m_encStream(), m_lineno(-1) { } ~Parser() { close(); } bool open(const char* filename); - void close() - { - if (m_file) - fclose(m_file); - m_file = nullptr; - m_filename.clear(); - m_lineno = -1; - } + void close(); const char* filename() const { return m_filename.c_str(); } @@ -71,10 +70,13 @@ namespace SDL std::list parse(); private: - FILE* m_file; + DS::FileStream* m_fileStream; + DS::EncryptedStream* m_encStream; ST::string m_filename; long m_lineno; std::list m_buffer; + + DS::Stream* stream() const; }; } diff --git a/Tests/CMakeLists.txt b/Tests/CMakeLists.txt index 63acdb6..ce0aa50 100644 --- a/Tests/CMakeLists.txt +++ b/Tests/CMakeLists.txt @@ -8,6 +8,7 @@ FetchContent_MakeAvailable(Catch2) set(test_SOURCES main.cpp + Test_EncryptedStream.cpp Test_Location.cpp Test_SDL.cpp Test_ShaHash.cpp diff --git a/Tests/Test_EncryptedStream.cpp b/Tests/Test_EncryptedStream.cpp new file mode 100644 index 0000000..afb9f6f --- /dev/null +++ b/Tests/Test_EncryptedStream.cpp @@ -0,0 +1,131 @@ +/****************************************************************************** + * This file is part of dirtsand. * + * * + * dirtsand is free software: you can redistribute it and/or modify * + * it under the terms of the GNU Affero General Public License as * + * published by the Free Software Foundation, either version 3 of the * + * License, or (at your option) any later version. * + * * + * dirtsand is distributed in the hope that it will be useful, * + * but WITHOUT ANY WARRANTY; without even the implied warranty of * + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * + * GNU Affero General Public License for more details. * + * * + * You should have received a copy of the GNU Affero General Public License * + * along with dirtsand. If not, see . * + ******************************************************************************/ + +#include "streams.h" + +#include +#include + +TEST_CASE("EncryptedStream known values", "[streams]") +{ + const char result[] = "Hello, world!"; + constexpr size_t resultsz = sizeof(result) - 1; + + SECTION("xxtea") { + uint32_t keys[] = { 0x31415926, 0x53589793, 0x23846264, 0x33832795 }; + uint8_t buff[] = { + 0x6E, 0x6F, 0x74, 0x74, 0x68, 0x65, 0x64, 0x72, 0x6F, 0x69, 0x64, 0x73, + 0x0D, 0x00, 0x00, 0x00, 0x93, 0xBD, 0x71, 0x93, 0xA4, 0x40, 0xC2, 0x6A, + 0x37, 0xD1, 0xA7, 0x9E, 0xEA, 0x93, 0x45, 0xC9 + }; + + DS::BufferStream base(buff, sizeof(buff)); + DS::EncryptedStream stream(&base, DS::EncryptedStream::Mode::e_read, std::nullopt, keys); + REQUIRE(stream.getEncType() == DS::EncryptedStream::Type::e_xxtea); + REQUIRE(stream.size() == resultsz); + + char test[sizeof(result)]; + stream.readBytes(test, resultsz); + test[sizeof(test) - 1] = 0; + CAPTURE(test); + REQUIRE(memcmp(test, result, resultsz) == 0); + + REQUIRE(stream.atEof()); + } + + SECTION("tea") { + uint8_t buff[] = { + 0x77, 0x68, 0x61, 0x74, 0x64, 0x6F, 0x79, 0x6F, 0x75, 0x73, 0x65, 0x65, + 0x0D, 0x00, 0x00, 0x00, 0xAC, 0xC1, 0xA6, 0xB6, 0xDC, 0x33, 0x95, 0x0E, + 0x99, 0x18, 0xAE, 0xFC, 0x9C, 0xD3, 0x00, 0xB9 + }; + + DS::BufferStream base(buff, sizeof(buff)); + DS::EncryptedStream stream(&base, DS::EncryptedStream::Mode::e_read); + REQUIRE(stream.getEncType() == DS::EncryptedStream::Type::e_tea); + REQUIRE(stream.size() == resultsz); + + char test[sizeof(result)]; + stream.readBytes(test, resultsz); + test[sizeof(test) - 1] = 0; + CAPTURE(test); + REQUIRE(memcmp(test, result, resultsz) == 0); + + REQUIRE(stream.atEof()); + } +} + +#define WRITE_STRING(_stream, _str) \ + _stream.writeBytes(_str, sizeof(_str) - 1); + +#define CHECK_STRING(_stream, _str) \ + { \ + constexpr size_t bufsz = sizeof(_str) - 1; \ + uint8_t buf[bufsz]; \ + _stream.readBytes(buf, bufsz); \ + REQUIRE(memcmp(buf, _str, bufsz) == 0); \ + } // + +TEST_CASE("EncryptedStream round-trip", "[streams]",) { + auto type = GENERATE( + DS::EncryptedStream::Type::e_xxtea, + DS::EncryptedStream::Type::e_tea + ); + + DS::BufferStream base; + { + DS::EncryptedStream stream(&base, DS::EncryptedStream::Mode::e_write, type); + WRITE_STRING(stream, "Small"); // Purposefully take up less than a full block + WRITE_STRING(stream, "!! "); // Complete the block from the previous write + WRITE_STRING(stream, "BlockSZ!"); // A full block + WRITE_STRING(stream, " ... And finally, something longer than a single block!"); + } + base.seek(0, SEEK_SET); + { + DS::EncryptedStream stream(&base, DS::EncryptedStream::Mode::e_read); + CHECK_STRING(stream, "Small"); + CHECK_STRING(stream, "!! "); + CHECK_STRING(stream, "BlockSZ!"); + CHECK_STRING(stream, " ... And finally, something longer than a single block!"); + } + base.seek(0, SEEK_SET); + { + DS::EncryptedStream stream(&base, DS::EncryptedStream::Mode::e_read); + CHECK_STRING(stream, "Small!! BlockSZ! ... And finally, something longer than a single block!"); + } +} + +TEST_CASE("EncryptedStream Magic Strings", "[streams]") +{ + SECTION("whatdoyousee") { + const char buffer[] { "whatdoyousee\x0\x0\x0\x0" }; + DS::BufferStream base(buffer, sizeof(buffer)); + REQUIRE(DS::EncryptedStream::CheckEncryption(&base) == DS::EncryptedStream::Type::e_tea); + } + + SECTION("BriceIsSmart") { + const char buffer[] { "BriceIsSmart\x0\x0\x0\x0" }; + DS::BufferStream base(buffer, sizeof(buffer)); + REQUIRE(DS::EncryptedStream::CheckEncryption(&base) == DS::EncryptedStream::Type::e_tea); + } + + SECTION("notthedroids") { + const char buffer[] { "notthedroids\x0\x0\x0\x0" }; + DS::BufferStream base(buffer, sizeof(buffer)); + REQUIRE(DS::EncryptedStream::CheckEncryption(&base) == DS::EncryptedStream::Type::e_xxtea); + } +} diff --git a/streams.cpp b/streams.cpp index e6c6396..fbe9542 100644 --- a/streams.cpp +++ b/streams.cpp @@ -22,6 +22,27 @@ #include #include +bool DS::Stream::readLine(void* buffer, size_t count) +{ + DS_ASSERT(count >= 1); + char* outp = reinterpret_cast(buffer); + char* endp = outp + count - 1; + bool eof = false; + + while (outp < endp) { + ssize_t nread = readBytes(outp, 1); + if (nread == 0) { + eof = true; + break; + } + char c = *outp++; + if (c == '\n') + break; + } + *outp = 0; + return !eof; +} + ST::string DS::Stream::readString(size_t length, DS::StringType format) { if (format == e_StringUTF16) { @@ -258,3 +279,254 @@ DS::Blob DS::HexDecode(const ST::string& value) ST::hex_decode(value, result, resultLen); return Blob::Steal(result, resultLen); } + +DS::EncryptedStream::EncryptedStream( + DS::Stream* base, DS::EncryptedStream::Mode mode, + std::optional type, const uint32_t* keys +) : m_base(base), m_buffer(), m_key(), m_pos(), m_size(), + m_type(type.has_value() ? type.value() : DS::EncryptedStream::Type::e_tea), + m_mode(mode) +{ + DS_ASSERT(base != nullptr); + DS_ASSERT(base->tell() == 0); + + static constexpr uint32_t kTeaKey[] { + 0x6c0a5452, + 0x03827d0f, + 0x3a170b92, + 0x16db7fc2 + }; + if (keys) { + memcpy(m_key, keys, sizeof(m_key)); + } else { + static_assert(sizeof(kTeaKey) == sizeof(m_key)); + memcpy(m_key, kTeaKey, sizeof(m_key)); + } + + uint8_t header[16]{}; + DS_ASSERT(!(!type.has_value() && mode == Mode::e_write)); + switch (mode) { + case Mode::e_read: + base->readBytes(header, sizeof(header)); + if (memcmp(header, "whatdoyousee", 12) == 0 || memcmp(header, "BriceIsSmart", 12) == 0) + m_type = DS::EncryptedStream::Type::e_tea; + else if (memcmp(header, "notthedroids", 12) == 0) + m_type = DS::EncryptedStream::Type::e_xxtea; + else + throw DS::FileIOException("Unknown EncryptedString magic"); + DS_ASSERT(!type.has_value() || type.value() == m_type); + m_size = reinterpret_cast(header)[3]; + break; + + case Mode::e_write: + // Write out some temporary junk for the header for now. + uint8_t header[16]{}; + base->writeBytes(header, sizeof(header)); + break; + } +} + +DS::EncryptedStream::~EncryptedStream() +{ + close(); +} + +void DS::EncryptedStream::close() +{ + if (m_base == nullptr) + return; + + if (m_mode == Mode::e_write) { + if (m_pos % sizeof(m_buffer) != 0) + cryptFlush(); + m_base->seek(0, SEEK_SET); + switch (m_type) { + case Type::e_xxtea: + m_base->writeBytes("notthedroids", 12); + break; + case Type::e_tea: + m_base->writeBytes("whatdoyousee", 12); + break; + } + m_base->write(m_size); + } + + m_base = nullptr; +} + +std::optional DS::EncryptedStream::CheckEncryption(const char* filename) +{ + DS::FileStream stream; + stream.open(filename, "r"); + return CheckEncryption(&stream); +} + +std::optional DS::EncryptedStream::CheckEncryption(DS::Stream* stream) +{ + uint32_t pos = stream->tell(); + if (pos != 0) + stream->seek(0, SEEK_SET); + uint8_t header[12]; + stream->readBytes(header, sizeof(header)); + stream->seek(pos, SEEK_SET); + + if (memcmp(header, "whatdoyousee", sizeof(header)) == 0 || memcmp(header, "BriceIsSmart", sizeof(header)) == 0) + return DS::EncryptedStream::Type::e_tea; + if (memcmp(header, "notthedroids", sizeof(header)) == 0) + return DS::EncryptedStream::Type::e_xxtea; + return std::nullopt; +} + +void DS::EncryptedStream::setKeys(const uint32_t* keys) +{ + memcpy(m_key, keys, sizeof(m_key)); +} + +void DS::EncryptedStream::xxteaDecipher(uint32_t* buf, uint32_t num) const +{ + uint32_t key = ((52 / num) + 6) * 0x9E3779B9; + while (key != 0) { + uint32_t xorkey = (key >> 2) & 3; + uint32_t numloop = num - 1; + while (numloop != 0) { + buf[numloop] -= + (((buf[numloop - 1] << 4) ^ (buf[numloop - 1] >> 3)) + + ((buf[numloop - 1] >> 5) ^ (buf[numloop - 1] << 2))) ^ + ((m_key[(numloop & 3) ^ xorkey] ^ buf[numloop - 1]) + + (key ^ buf[numloop - 1])); + numloop--; + } + buf[0] -= + (((buf[num - 1] << 4) ^ (buf[num - 1] >> 3)) + + ((buf[num - 1] >> 5) ^ (buf[num - 1] << 2))) ^ + ((m_key[(numloop & 3) ^ xorkey] ^ buf[num - 1]) + + (key ^ buf[num - 1])); + key += 0x61C88647; + } +} + +void DS::EncryptedStream::xxteaEncipher(uint32_t* buf, uint32_t num) const +{ + uint32_t key = 0; + uint32_t count = (52 / num) + 6; + while (count != 0) { + key -= 0x61C88647; + uint32_t xorkey = (key >> 2) & 3; + uint32_t numloop = 0; + while (numloop != num - 1) { + buf[numloop] += + (((buf[numloop + 1] << 4) ^ (buf[numloop + 1] >> 3)) + + ((buf[numloop + 1] >> 5) ^ (buf[numloop + 1] << 2))) ^ + ((m_key[(numloop & 3) ^ xorkey] ^ buf[numloop + 1]) + + (key ^ buf[numloop + 1])); + numloop++; + } + buf[num - 1] += + (((buf[0] << 4) ^ (buf[0] >> 3)) + + ((buf[0] >> 5) ^ (buf[0] << 2))) ^ + ((m_key[(numloop & 3) ^ xorkey] ^ buf[0]) + + (key ^ buf[0])); + count--; + } +} + +void DS::EncryptedStream::teaDecipher(uint32_t* buf) const +{ + uint32_t second = buf[1], first = buf[0], key = 0xC6EF3720; + + for (size_t i = 0; i < 32; i++) { + second -= (((first >> 5) ^ (first << 4)) + first) + ^ (m_key[(key >> 11) & 3] + key); + key -= 0x9E3779B9; + first -= (((second >> 5) ^ (second << 4)) + second) + ^ (m_key[key & 3] + key); + } + buf[0] = first; + buf[1] = second; +} + +void DS::EncryptedStream::teaEncipher(uint32_t* buf) const +{ + uint32_t first = buf[0], second = buf[1], key = 0; + + for (size_t i = 0; i < 32; i++) { + first += (((second >> 5) ^ (second << 4)) + second) + ^ (m_key[key & 3] + key); + key += 0x9E3779B9; + second += (((first >> 5) ^ (first << 4)) + first) + ^ (m_key[(key >> 11) & 3] + key); + } + buf[1] = second; + buf[0] = first; +} + +void DS::EncryptedStream::cryptFlush() +{ + switch (m_type) { + case Type::e_xxtea: + xxteaEncipher(reinterpret_cast(m_buffer), 2); + break; + case Type::e_tea: + teaEncipher(reinterpret_cast(m_buffer)); + break; + } + m_base->writeBytes(m_buffer, sizeof(m_buffer)); + memset(m_buffer, 0, sizeof(m_buffer)); +} + +ssize_t DS::EncryptedStream::readBytes(void* buffer, size_t count) +{ + if (m_mode != Mode::e_read) + throw FileIOException("EncryptedStream instance is not readable"); + + size_t bp = 0; + size_t lp = m_pos % sizeof(m_buffer); + while (bp < count) { + if (lp == 0) { + m_base->readBytes(m_buffer, sizeof(m_buffer)); + switch (m_type) { + case Type::e_xxtea: + xxteaDecipher(reinterpret_cast(m_buffer), 2); + break; + case Type::e_tea: + teaDecipher(reinterpret_cast(m_buffer)); + break; + } + } + if (lp + (count - bp) >= sizeof(m_buffer)) { + memcpy(reinterpret_cast(buffer) + bp, m_buffer + lp, sizeof(m_buffer) - lp); + bp += sizeof(m_buffer) - lp; + lp = 0; + } else { + memcpy(reinterpret_cast(buffer) + bp, m_buffer + lp, count - bp); + bp = count; + } + } + + m_pos += count; + return count; +} + +ssize_t DS::EncryptedStream::writeBytes(const void* buffer, size_t count) +{ + if (m_mode != Mode::e_write) + throw DS::FileIOException("EncryptedStream instance is not writeable"); + + size_t bp = 0; + size_t lp = m_pos % sizeof(m_buffer); + while (bp < count) { + if (lp + (count - bp) >= sizeof(m_buffer)) { + memcpy(m_buffer + lp, reinterpret_cast(buffer) + bp, sizeof(m_buffer) - lp); + bp += sizeof(m_buffer) - lp; + cryptFlush(); + lp = 0; + } else { + memcpy(m_buffer + lp, reinterpret_cast(buffer) + bp, count - bp); + bp = count; + } + } + + m_pos += count; + m_size = std::max(m_size, m_pos); + return count; +} diff --git a/streams.h b/streams.h index cf6f8ca..61c5b37 100644 --- a/streams.h +++ b/streams.h @@ -23,6 +23,7 @@ #include #include #include +#include namespace DS { @@ -63,6 +64,8 @@ namespace DS return static_cast(value); } + bool readLine(void* buffer, size_t count); + ST::string readString(size_t length, DS::StringType format = e_StringRAW8); ST::string readSafeString(DS::StringType format = e_StringRAW8); @@ -272,6 +275,64 @@ namespace DS Blob Base64Decode(const ST::string& value); Blob HexDecode(const ST::string& value); + + class EncryptedStream : public Stream + { + public: + enum class Mode + { + e_read, + e_write, + }; + + enum class Type + { + e_xxtea, + e_tea, + }; + + protected: + Stream* m_base; + uint8_t m_buffer[8]; + uint32_t m_key[4]; + uint32_t m_pos; + uint32_t m_size; + Type m_type; + Mode m_mode; + + void xxteaDecipher(uint32_t* buf, uint32_t num) const; + void xxteaEncipher(uint32_t* buf, uint32_t num) const; + void teaDecipher(uint32_t* buf) const; + void teaEncipher(uint32_t* buf) const; + void cryptFlush(); + + public: + EncryptedStream(Stream* base, Mode mode, std::optional type = std::nullopt, const uint32_t* keys = nullptr); + ~EncryptedStream() override; + + EncryptedStream(const EncryptedStream&) = delete; + EncryptedStream(EncryptedStream&&) = delete; + + static std::optional CheckEncryption(const char* filename); + static std::optional CheckEncryption(DS::Stream* stream); + + void close(); + + Type getEncType() const { return m_type; } + void setKeys(const uint32_t* keys); + + ssize_t readBytes(void* buffer, size_t count) override; + ssize_t writeBytes(const void* buffer, size_t count) override; + + uint32_t tell() const override { return m_pos; } + void seek(int32_t offset, int whence) override { throw FileIOException("not supported"); } + uint32_t size() const override { return m_size; } + bool atEof() override { return m_pos == m_size; } + void flush() override { m_base->flush(); } + + EncryptedStream& operator =(const EncryptedStream& copy) = delete; + EncryptedStream& operator =(EncryptedStream&& move) = delete; + }; } #endif