bignum_common.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. """Common features for bignum in test generation framework."""
  2. # Copyright The Mbed TLS Contributors
  3. # SPDX-License-Identifier: Apache-2.0
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License"); you may
  6. # not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  13. # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from abc import abstractmethod
  17. import enum
  18. from typing import Iterator, List, Tuple, TypeVar, Any
  19. from itertools import chain
  20. from . import test_case
  21. from . import test_data_generation
  22. from .bignum_data import INPUTS_DEFAULT, MODULI_DEFAULT
# Generic element type for helpers such as combination_pairs().
T = TypeVar('T') #pylint: disable=invalid-name
  24. def invmod(a: int, n: int) -> int:
  25. """Return inverse of a to modulo n.
  26. Equivalent to pow(a, -1, n) in Python 3.8+. Implementation is equivalent
  27. to long_invmod() in CPython.
  28. """
  29. b, c = 1, 0
  30. while n:
  31. q, r = divmod(a, n)
  32. a, b, c, n = n, c, b - q*c, r
  33. # at this point a is the gcd of the original inputs
  34. if a == 1:
  35. return b
  36. raise ValueError("Not invertible")
  37. def invmod_positive(a: int, n: int) -> int:
  38. """Return a non-negative inverse of a to modulo n."""
  39. inv = invmod(a, n)
  40. return inv if inv >= 0 else inv + n
  41. def hex_to_int(val: str) -> int:
  42. """Implement the syntax accepted by mbedtls_test_read_mpi().
  43. This is a superset of what is accepted by mbedtls_test_read_mpi_core().
  44. """
  45. if val in ['', '-']:
  46. return 0
  47. return int(val, 16)
  48. def quote_str(val: str) -> str:
  49. return "\"{}\"".format(val)
  50. def bound_mpi(val: int, bits_in_limb: int) -> int:
  51. """First number exceeding number of limbs needed for given input value."""
  52. return bound_mpi_limbs(limbs_mpi(val, bits_in_limb), bits_in_limb)
  53. def bound_mpi_limbs(limbs: int, bits_in_limb: int) -> int:
  54. """First number exceeding maximum of given number of limbs."""
  55. bits = bits_in_limb * limbs
  56. return 1 << bits
  57. def limbs_mpi(val: int, bits_in_limb: int) -> int:
  58. """Return the number of limbs required to store value."""
  59. return (val.bit_length() + bits_in_limb - 1) // bits_in_limb
  60. def combination_pairs(values: List[T]) -> List[Tuple[T, T]]:
  61. """Return all pair combinations from input values."""
  62. return [(x, y) for x in values for y in values]
  63. def hex_digits_for_limb(limbs: int, bits_in_limb: int) -> int:
  64. """ Retrun the hex digits need for a number of limbs. """
  65. return 2 * (limbs * bits_in_limb // 8)
  66. class OperationCommon(test_data_generation.BaseTest):
  67. """Common features for bignum binary operations.
  68. This adds functionality common in binary operation tests.
  69. Attributes:
  70. symbol: Symbol to use for the operation in case description.
  71. input_values: List of values to use as test case inputs. These are
  72. combined to produce pairs of values.
  73. input_cases: List of tuples containing pairs of test case inputs. This
  74. can be used to implement specific pairs of inputs.
  75. unique_combinations_only: Boolean to select if test case combinations
  76. must be unique. If True, only A,B or B,A would be included as a test
  77. case. If False, both A,B and B,A would be included.
  78. input_style: Controls the way how test data is passed to the functions
  79. in the generated test cases. "variable" passes them as they are
  80. defined in the python source. "arch_split" pads the values with
  81. zeroes depending on the architecture/limb size. If this is set,
  82. test cases are generated for all architectures.
  83. arity: the number of operands for the operation. Currently supported
  84. values are 1 and 2.
  85. """
  86. symbol = ""
  87. input_values = INPUTS_DEFAULT # type: List[str]
  88. input_cases = [] # type: List[Any]
  89. unique_combinations_only = False
  90. input_styles = ["variable", "fixed", "arch_split"] # type: List[str]
  91. input_style = "variable" # type: str
  92. limb_sizes = [32, 64] # type: List[int]
  93. arities = [1, 2]
  94. arity = 2
  95. suffix = False # for arity = 1, symbol can be prefix (default) or suffix
  96. def __init__(self, val_a: str, val_b: str = "0", bits_in_limb: int = 32) -> None:
  97. self.val_a = val_a
  98. self.val_b = val_b
  99. # Setting the int versions here as opposed to making them @properties
  100. # provides earlier/more robust input validation.
  101. self.int_a = hex_to_int(val_a)
  102. self.int_b = hex_to_int(val_b)
  103. if bits_in_limb not in self.limb_sizes:
  104. raise ValueError("Invalid number of bits in limb!")
  105. if self.input_style == "arch_split":
  106. self.dependencies = ["MBEDTLS_HAVE_INT{:d}".format(bits_in_limb)]
  107. self.bits_in_limb = bits_in_limb
  108. @property
  109. def boundary(self) -> int:
  110. if self.arity == 1:
  111. return self.int_a
  112. elif self.arity == 2:
  113. return max(self.int_a, self.int_b)
  114. raise ValueError("Unsupported number of operands!")
  115. @property
  116. def limb_boundary(self) -> int:
  117. return bound_mpi(self.boundary, self.bits_in_limb)
  118. @property
  119. def limbs(self) -> int:
  120. return limbs_mpi(self.boundary, self.bits_in_limb)
  121. @property
  122. def hex_digits(self) -> int:
  123. return hex_digits_for_limb(self.limbs, self.bits_in_limb)
  124. def format_arg(self, val: str) -> str:
  125. if self.input_style not in self.input_styles:
  126. raise ValueError("Unknown input style!")
  127. if self.input_style == "variable":
  128. return val
  129. else:
  130. return val.zfill(self.hex_digits)
  131. def format_result(self, res: int) -> str:
  132. res_str = '{:x}'.format(res)
  133. return quote_str(self.format_arg(res_str))
  134. @property
  135. def arg_a(self) -> str:
  136. return self.format_arg(self.val_a)
  137. @property
  138. def arg_b(self) -> str:
  139. if self.arity == 1:
  140. raise AttributeError("Operation is unary and doesn't have arg_b!")
  141. return self.format_arg(self.val_b)
  142. def arguments(self) -> List[str]:
  143. args = [quote_str(self.arg_a)]
  144. if self.arity == 2:
  145. args.append(quote_str(self.arg_b))
  146. return args + self.result()
  147. def description(self) -> str:
  148. """Generate a description for the test case.
  149. If not set, case_description uses the form A `symbol` B, where symbol
  150. is used to represent the operation. Descriptions of each value are
  151. generated to provide some context to the test case.
  152. """
  153. if not self.case_description:
  154. if self.arity == 1:
  155. format_string = "{1:x} {0}" if self.suffix else "{0} {1:x}"
  156. self.case_description = format_string.format(
  157. self.symbol, self.int_a
  158. )
  159. elif self.arity == 2:
  160. self.case_description = "{:x} {} {:x}".format(
  161. self.int_a, self.symbol, self.int_b
  162. )
  163. return super().description()
  164. @property
  165. def is_valid(self) -> bool:
  166. return True
  167. @abstractmethod
  168. def result(self) -> List[str]:
  169. """Get the result of the operation.
  170. This could be calculated during initialization and stored as `_result`
  171. and then returned, or calculated when the method is called.
  172. """
  173. raise NotImplementedError
  174. @classmethod
  175. def get_value_pairs(cls) -> Iterator[Tuple[str, str]]:
  176. """Generator to yield pairs of inputs.
  177. Combinations are first generated from all input values, and then
  178. specific cases provided.
  179. """
  180. if cls.arity == 1:
  181. yield from ((a, "0") for a in cls.input_values)
  182. elif cls.arity == 2:
  183. if cls.unique_combinations_only:
  184. yield from combination_pairs(cls.input_values)
  185. else:
  186. yield from (
  187. (a, b)
  188. for a in cls.input_values
  189. for b in cls.input_values
  190. )
  191. else:
  192. raise ValueError("Unsupported number of operands!")
  193. @classmethod
  194. def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
  195. if cls.input_style not in cls.input_styles:
  196. raise ValueError("Unknown input style!")
  197. if cls.arity not in cls.arities:
  198. raise ValueError("Unsupported number of operands!")
  199. if cls.input_style == "arch_split":
  200. test_objects = (cls(a, b, bits_in_limb=bil)
  201. for a, b in cls.get_value_pairs()
  202. for bil in cls.limb_sizes)
  203. special_cases = (cls(*args, bits_in_limb=bil) # type: ignore
  204. for args in cls.input_cases
  205. for bil in cls.limb_sizes)
  206. else:
  207. test_objects = (cls(a, b)
  208. for a, b in cls.get_value_pairs())
  209. special_cases = (cls(*args) for args in cls.input_cases)
  210. yield from (valid_test_object.create_test_case()
  211. for valid_test_object in filter(
  212. lambda test_object: test_object.is_valid,
  213. chain(test_objects, special_cases)
  214. )
  215. )
  216. class ModulusRepresentation(enum.Enum):
  217. """Representation selector of a modulus."""
  218. # Numerical values aligned with the type mbedtls_mpi_mod_rep_selector
  219. INVALID = 0
  220. MONTGOMERY = 2
  221. OPT_RED = 3
  222. def symbol(self) -> str:
  223. """The C symbol for this representation selector."""
  224. return 'MBEDTLS_MPI_MOD_REP_' + self.name
  225. @classmethod
  226. def supported_representations(cls) -> List['ModulusRepresentation']:
  227. """Return all representations that are supported in positive test cases."""
  228. return [cls.MONTGOMERY, cls.OPT_RED]
  229. class ModOperationCommon(OperationCommon):
  230. #pylint: disable=abstract-method
  231. """Target for bignum mod_raw test case generation."""
  232. moduli = MODULI_DEFAULT # type: List[str]
  233. montgomery_form_a = False
  234. disallow_zero_a = False
  235. def __init__(self, val_n: str, val_a: str, val_b: str = "0",
  236. bits_in_limb: int = 64) -> None:
  237. super().__init__(val_a=val_a, val_b=val_b, bits_in_limb=bits_in_limb)
  238. self.val_n = val_n
  239. # Setting the int versions here as opposed to making them @properties
  240. # provides earlier/more robust input validation.
  241. self.int_n = hex_to_int(val_n)
  242. def to_montgomery(self, val: int) -> int:
  243. return (val * self.r) % self.int_n
  244. def from_montgomery(self, val: int) -> int:
  245. return (val * self.r_inv) % self.int_n
  246. def convert_from_canonical(self, canonical: int,
  247. rep: ModulusRepresentation) -> int:
  248. """Convert values from canonical representation to the given representation."""
  249. if rep is ModulusRepresentation.MONTGOMERY:
  250. return self.to_montgomery(canonical)
  251. elif rep is ModulusRepresentation.OPT_RED:
  252. return canonical
  253. else:
  254. raise ValueError('Modulus representation not supported: {}'
  255. .format(rep.name))
  256. @property
  257. def boundary(self) -> int:
  258. return self.int_n
  259. @property
  260. def arg_a(self) -> str:
  261. if self.montgomery_form_a:
  262. value_a = self.to_montgomery(self.int_a)
  263. else:
  264. value_a = self.int_a
  265. return self.format_arg('{:x}'.format(value_a))
  266. @property
  267. def arg_n(self) -> str:
  268. return self.format_arg(self.val_n)
  269. def format_arg(self, val: str) -> str:
  270. return super().format_arg(val).zfill(self.hex_digits)
  271. def arguments(self) -> List[str]:
  272. return [quote_str(self.arg_n)] + super().arguments()
  273. @property
  274. def r(self) -> int: # pylint: disable=invalid-name
  275. l = limbs_mpi(self.int_n, self.bits_in_limb)
  276. return bound_mpi_limbs(l, self.bits_in_limb)
  277. @property
  278. def r_inv(self) -> int:
  279. return invmod(self.r, self.int_n)
  280. @property
  281. def r2(self) -> int: # pylint: disable=invalid-name
  282. return pow(self.r, 2)
  283. @property
  284. def is_valid(self) -> bool:
  285. if self.int_a >= self.int_n:
  286. return False
  287. if self.disallow_zero_a and self.int_a == 0:
  288. return False
  289. if self.arity == 2 and self.int_b >= self.int_n:
  290. return False
  291. return True
  292. def description(self) -> str:
  293. """Generate a description for the test case.
  294. It uses the form A `symbol` B mod N, where symbol is used to represent
  295. the operation.
  296. """
  297. if not self.case_description:
  298. return super().description() + " mod {:x}".format(self.int_n)
  299. return super().description()
  300. @classmethod
  301. def input_cases_args(cls) -> Iterator[Tuple[Any, Any, Any]]:
  302. if cls.arity == 1:
  303. yield from ((n, a, "0") for a, n in cls.input_cases)
  304. elif cls.arity == 2:
  305. yield from ((n, a, b) for a, b, n in cls.input_cases)
  306. else:
  307. raise ValueError("Unsupported number of operands!")
  308. @classmethod
  309. def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
  310. if cls.input_style not in cls.input_styles:
  311. raise ValueError("Unknown input style!")
  312. if cls.arity not in cls.arities:
  313. raise ValueError("Unsupported number of operands!")
  314. if cls.input_style == "arch_split":
  315. test_objects = (cls(n, a, b, bits_in_limb=bil)
  316. for n in cls.moduli
  317. for a, b in cls.get_value_pairs()
  318. for bil in cls.limb_sizes)
  319. special_cases = (cls(*args, bits_in_limb=bil)
  320. for args in cls.input_cases_args()
  321. for bil in cls.limb_sizes)
  322. else:
  323. test_objects = (cls(n, a, b)
  324. for n in cls.moduli
  325. for a, b in cls.get_value_pairs())
  326. special_cases = (cls(*args) for args in cls.input_cases_args())
  327. yield from (valid_test_object.create_test_case()
  328. for valid_test_object in filter(
  329. lambda test_object: test_object.is_valid,
  330. chain(test_objects, special_cases)
  331. ))
  332. # BEGIN MERGE SLOT 1
  333. # END MERGE SLOT 1
  334. # BEGIN MERGE SLOT 2
  335. # END MERGE SLOT 2
  336. # BEGIN MERGE SLOT 3
  337. # END MERGE SLOT 3
  338. # BEGIN MERGE SLOT 4
  339. # END MERGE SLOT 4
  340. # BEGIN MERGE SLOT 5
  341. # END MERGE SLOT 5
  342. # BEGIN MERGE SLOT 6
  343. # END MERGE SLOT 6
  344. # BEGIN MERGE SLOT 7
  345. # END MERGE SLOT 7
  346. # BEGIN MERGE SLOT 8
  347. # END MERGE SLOT 8
  348. # BEGIN MERGE SLOT 9
  349. # END MERGE SLOT 9
  350. # BEGIN MERGE SLOT 10
  351. # END MERGE SLOT 10