updatesigntool.cpp 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
#include "openmpt/all/BuildSettings.hpp"
#include "mpt/base/span.hpp"
#include "mpt/crypto/hash.hpp"
#include "mpt/crypto/jwk.hpp"
#include "mpt/environment/environment.hpp"
#include "mpt/exception_text/exception_text.hpp"
#include "mpt/io/base.hpp"
#include "mpt/io/io.hpp"
#include "mpt/io/io_stdstream.hpp"
#include "mpt/out_of_memory/out_of_memory.hpp"
#include "mpt/string/types.hpp"
#include "mpt/string_transcode/transcode.hpp"
#include "mpt/uuid/uuid.hpp"
#include "mpt/uuid_namespace/uuid_namespace.hpp"
#include "../common/mptBaseMacros.h"
#include "../common/mptBaseTypes.h"
#include "../common/mptBaseUtils.h"
#include "../common/mptStringFormat.h"
#include "../common/mptPathString.h"
#include "../common/mptFileIO.h"
#include "../common/Logging.h"
#include "../common/misc_util.h"
#include <array>
#include <cstddef>
#include <exception>
#include <iostream>
#include <locale>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>
  29. #if defined(MPT_BUILD_MSVC)
  30. #if MPT_COMPILER_MSVC || MPT_COMPILER_CLANG
  31. #pragma comment(lib, "bcrypt.lib")
  32. #pragma comment(lib, "ncrypt.lib")
  33. #pragma comment(lib, "rpcrt4.lib")
  34. #pragma comment(lib, "shlwapi.lib")
  35. #pragma comment(lib, "winmm.lib")
  36. #endif // MPT_COMPILER_MSVC || MPT_COMPILER_CLANG
  37. #endif // MPT_BUILD_MSVC
  38. OPENMPT_NAMESPACE_BEGIN
  39. using namespace mpt::uuid_literals;
  40. #if defined(MPT_ASSERT_HANDLER_NEEDED) && defined(MPT_BUILD_UPDATESIGNTOOL)
  41. MPT_NOINLINE void AssertHandler(const mpt::source_location &loc, const char *expr, const char *msg)
  42. {
  43. if(msg)
  44. {
  45. mpt::log::GlobalLogger().SendLogMessage(loc, LogError, "ASSERT",
  46. MPT_USTRING("ASSERTION FAILED: ") + mpt::transcode<mpt::ustring>(mpt::common_encoding::ascii, msg) + MPT_USTRING(" (") + mpt::transcode<mpt::ustring>(mpt::common_encoding::ascii, expr) + MPT_USTRING(")")
  47. );
  48. } else
  49. {
  50. mpt::log::GlobalLogger().SendLogMessage(loc, LogError, "ASSERT",
  51. MPT_USTRING("ASSERTION FAILED: ") + mpt::transcode<mpt::ustring>(mpt::common_encoding::ascii, expr)
  52. );
  53. }
  54. }
  55. #endif
  56. namespace updatesigntool {
  57. static mpt::ustring get_keyname(mpt::ustring keyname)
  58. {
  59. if(keyname == MPT_USTRING("auto"))
  60. {
  61. constexpr mpt::UUID ns = "9a88e12a-a132-4215-8bd0-3a002da65373"_uuid;
  62. mpt::ustring computername = mpt::getenv(MPT_USTRING("COMPUTERNAME")).value_or(MPT_USTRING(""));
  63. mpt::ustring username = mpt::getenv(MPT_USTRING("USERNAME")).value_or(MPT_USTRING(""));
  64. mpt::ustring name = MPT_UFORMAT("host={} user={}")(computername, username);
  65. mpt::UUID uuid = mpt::UUIDRFC4122NamespaceV5(ns, name);
  66. keyname = MPT_UFORMAT("OpenMPT Update Signing Key {}")(uuid);
  67. }
  68. return keyname;
  69. }
  70. static void main(const std::vector<mpt::ustring> &args)
  71. {
  72. try
  73. {
  74. if(args.size() < 2)
  75. {
  76. throw std::invalid_argument("Usage: updatesigntool [dumpkey|sign] ...");
  77. }
  78. if(args[1] == MPT_USTRING(""))
  79. {
  80. throw std::invalid_argument("Usage: updatesigntool [dumpkey|sign] ...");
  81. } else if(args[1] == MPT_USTRING("dumpkey"))
  82. {
  83. if(args.size() != 4)
  84. {
  85. throw std::invalid_argument("Usage: updatesigntool dumpkey KEYNAME FILENAME");
  86. }
  87. mpt::ustring keyname = get_keyname(args[2]);
  88. mpt::ustring filename = args[3];
  89. mpt::crypto::keystore keystore(mpt::crypto::keystore::domain::user);
  90. mpt::crypto::asymmetric::rsassa_pss<>::managed_private_key key(keystore, keyname);
  91. mpt::SafeOutputFile sfo(mpt::PathString::FromUnicode(filename));
  92. mpt::ofstream & fo = sfo.stream();
  93. mpt::IO::WriteText(fo, mpt::transcode<std::string>(mpt::common_encoding::utf8, key.get_public_key_data().as_jwk()));
  94. fo.flush();
  95. } else if(args[1] == MPT_USTRING("sign"))
  96. {
  97. if(args.size() != 6)
  98. {
  99. throw std::invalid_argument("Usage: updatesigntool sign [raw|jws_compact|jws] KEYNAME INPUTFILENAME OUTPUTFILENAME");
  100. }
  101. mpt::ustring mode = args[2];
  102. mpt::ustring keyname = get_keyname(args[3]);
  103. mpt::ustring inputfilename = args[4];
  104. mpt::ustring outputfilename = args[5];
  105. mpt::crypto::keystore keystore(mpt::crypto::keystore::domain::user);
  106. mpt::crypto::asymmetric::rsassa_pss<>::managed_private_key key(keystore, keyname);
  107. std::vector<std::byte> data;
  108. {
  109. mpt::ifstream fi(mpt::PathString::FromUnicode(inputfilename), std::ios::binary);
  110. fi.imbue(std::locale::classic());
  111. fi.exceptions(std::ios::badbit);
  112. while(!mpt::IO::IsEof(fi))
  113. {
  114. std::array<std::byte, mpt::IO::BUFFERSIZE_TINY> buf;
  115. mpt::append(data, mpt::IO::ReadRaw(fi, mpt::as_span(buf)));
  116. }
  117. }
  118. if(mode == MPT_USTRING(""))
  119. {
  120. throw std::invalid_argument("Usage: updatesigntool sign [raw|jws_compact|jws] KEYNAME INPUTFILENAME OUTPUTFILENAME");
  121. } else if(mode == MPT_USTRING("raw"))
  122. {
  123. std::vector<std::byte> signature = key.sign(mpt::as_span(data));
  124. mpt::SafeOutputFile sfo(mpt::PathString::FromUnicode(outputfilename));
  125. mpt::ofstream & fo = sfo.stream();
  126. mpt::IO::WriteRaw(fo, mpt::as_span(signature));
  127. fo.flush();
  128. } else if(mode == MPT_USTRING("jws_compact"))
  129. {
  130. mpt::ustring signature = key.jws_compact_sign(mpt::as_span(data));
  131. mpt::SafeOutputFile sfo(mpt::PathString::FromUnicode(outputfilename));
  132. mpt::ofstream & fo = sfo.stream();
  133. mpt::IO::WriteText(fo, mpt::transcode<std::string>(mpt::common_encoding::utf8, signature));
  134. fo.flush();
  135. } else if(mode == MPT_USTRING("jws"))
  136. {
  137. mpt::ustring signature = key.jws_sign(mpt::as_span(data));
  138. mpt::SafeOutputFile sfo(mpt::PathString::FromUnicode(outputfilename));
  139. mpt::ofstream & fo = sfo.stream();
  140. mpt::IO::WriteText(fo, mpt::transcode<std::string>(mpt::common_encoding::utf8, signature));
  141. fo.flush();
  142. } else
  143. {
  144. throw std::invalid_argument("Usage: updatesigntool sign [raw|jws_compact|jws] KEYNAME INPUTFILENAME OUTPUTFILENAME");
  145. }
  146. } else
  147. {
  148. throw std::invalid_argument("Usage: updatesigntool [dumpkey|sign] ...");
  149. }
  150. } catch(const std::exception &e)
  151. {
  152. std::cerr << mpt::get_exception_text<std::string>(e) << std::endl;
  153. throw;
  154. }
  155. }
  156. } // namespace updatesigntool
  157. OPENMPT_NAMESPACE_END
  158. #if defined(WIN32) && defined(UNICODE)
  159. int wmain(int argc, wchar_t *argv[])
  160. #else
  161. int main(int argc, char *argv[])
  162. #endif
  163. {
  164. std::locale::global(std::locale(""));
  165. std::vector<mpt::ustring> args;
  166. for(int arg = 0; arg < argc; ++arg)
  167. {
  168. #if defined(WIN32) && defined(UNICODE)
  169. args.push_back(mpt::transcode<mpt::ustring>(argv[arg]));
  170. #else
  171. args.push_back(mpt::transcode<mpt::ustring>(mpt::logical_encoding::locale, argv[arg]));
  172. #endif
  173. }
  174. try
  175. {
  176. OPENMPT_NAMESPACE::updatesigntool::main(args);
  177. } catch(...)
  178. {
  179. return 1;
  180. }
  181. return 0;
  182. }