psa_storage.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. """Knowledge about the PSA key store as implemented in Mbed TLS.
  2. Note that if you need to make a change that affects how keys are
  3. stored, this may indicate that the key store is changing in a
  4. backward-incompatible way! Think carefully about backward compatibility
  5. before changing how test data is constructed or validated.
  6. """
  7. # Copyright The Mbed TLS Contributors
  8. # SPDX-License-Identifier: Apache-2.0
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License"); you may
  11. # not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  18. # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import re
  22. import struct
  23. from typing import Dict, List, Optional, Set, Union
  24. import unittest
  25. from . import c_build_helper
  26. class Expr:
  27. """Representation of a C expression with a known or knowable numerical value."""
  28. def __init__(self, content: Union[int, str]):
  29. if isinstance(content, int):
  30. digits = 8 if content > 0xffff else 4
  31. self.string = '{0:#0{1}x}'.format(content, digits + 2)
  32. self.value_if_known = content #type: Optional[int]
  33. else:
  34. self.string = content
  35. self.unknown_values.add(self.normalize(content))
  36. self.value_if_known = None
  37. value_cache = {} #type: Dict[str, int]
  38. """Cache of known values of expressions."""
  39. unknown_values = set() #type: Set[str]
  40. """Expressions whose values are not present in `value_cache` yet."""
  41. def update_cache(self) -> None:
  42. """Update `value_cache` for expressions registered in `unknown_values`."""
  43. expressions = sorted(self.unknown_values)
  44. values = c_build_helper.get_c_expression_values(
  45. 'unsigned long', '%lu',
  46. expressions,
  47. header="""
  48. #include <psa/crypto.h>
  49. """,
  50. include_path=['include']) #type: List[str]
  51. for e, v in zip(expressions, values):
  52. self.value_cache[e] = int(v, 0)
  53. self.unknown_values.clear()
  54. @staticmethod
  55. def normalize(string: str) -> str:
  56. """Put the given C expression in a canonical form.
  57. This function is only intended to give correct results for the
  58. relatively simple kind of C expression typically used with this
  59. module.
  60. """
  61. return re.sub(r'\s+', r'', string)
  62. def value(self) -> int:
  63. """Return the numerical value of the expression."""
  64. if self.value_if_known is None:
  65. if re.match(r'([0-9]+|0x[0-9a-f]+)\Z', self.string, re.I):
  66. return int(self.string, 0)
  67. normalized = self.normalize(self.string)
  68. if normalized not in self.value_cache:
  69. self.update_cache()
  70. self.value_if_known = self.value_cache[normalized]
  71. return self.value_if_known
  72. Exprable = Union[str, int, Expr]
  73. """Something that can be converted to a C expression with a known numerical value."""
  74. def as_expr(thing: Exprable) -> Expr:
  75. """Return an `Expr` object for `thing`.
  76. If `thing` is already an `Expr` object, return it. Otherwise build a new
  77. `Expr` object from `thing`. `thing` can be an integer or a string that
  78. contains a C expression.
  79. """
  80. if isinstance(thing, Expr):
  81. return thing
  82. else:
  83. return Expr(thing)
  84. class Key:
  85. """Representation of a PSA crypto key object and its storage encoding.
  86. """
  87. LATEST_VERSION = 0
  88. """The latest version of the storage format."""
  89. def __init__(self, *,
  90. version: Optional[int] = None,
  91. id: Optional[int] = None, #pylint: disable=redefined-builtin
  92. lifetime: Exprable = 'PSA_KEY_LIFETIME_PERSISTENT',
  93. type: Exprable, #pylint: disable=redefined-builtin
  94. bits: int,
  95. usage: Exprable, alg: Exprable, alg2: Exprable,
  96. material: bytes #pylint: disable=used-before-assignment
  97. ) -> None:
  98. self.version = self.LATEST_VERSION if version is None else version
  99. self.id = id #pylint: disable=invalid-name #type: Optional[int]
  100. self.lifetime = as_expr(lifetime) #type: Expr
  101. self.type = as_expr(type) #type: Expr
  102. self.bits = bits #type: int
  103. self.usage = as_expr(usage) #type: Expr
  104. self.alg = as_expr(alg) #type: Expr
  105. self.alg2 = as_expr(alg2) #type: Expr
  106. self.material = material #type: bytes
  107. MAGIC = b'PSA\000KEY\000'
  108. @staticmethod
  109. def pack(
  110. fmt: str,
  111. *args: Union[int, Expr]
  112. ) -> bytes: #pylint: disable=used-before-assignment
  113. """Pack the given arguments into a byte string according to the given format.
  114. This function is similar to `struct.pack`, but with the following differences:
  115. * All integer values are encoded with standard sizes and in
  116. little-endian representation. `fmt` must not include an endianness
  117. prefix.
  118. * Arguments can be `Expr` objects instead of integers.
  119. * Only integer-valued elements are supported.
  120. """
  121. return struct.pack('<' + fmt, # little-endian, standard sizes
  122. *[arg.value() if isinstance(arg, Expr) else arg
  123. for arg in args])
  124. def bytes(self) -> bytes:
  125. """Return the representation of the key in storage as a byte array.
  126. This is the content of the PSA storage file. When PSA storage is
  127. implemented over stdio files, this does not include any wrapping made
  128. by the PSA-storage-over-stdio-file implementation.
  129. Note that if you need to make a change in this function,
  130. this may indicate that the key store is changing in a
  131. backward-incompatible way! Think carefully about backward
  132. compatibility before making any change here.
  133. """
  134. header = self.MAGIC + self.pack('L', self.version)
  135. if self.version == 0:
  136. attributes = self.pack('LHHLLL',
  137. self.lifetime, self.type, self.bits,
  138. self.usage, self.alg, self.alg2)
  139. material = self.pack('L', len(self.material)) + self.material
  140. else:
  141. raise NotImplementedError
  142. return header + attributes + material
  143. def hex(self) -> str:
  144. """Return the representation of the key as a hexadecimal string.
  145. This is the hexadecimal representation of `self.bytes`.
  146. """
  147. return self.bytes().hex()
  148. def location_value(self) -> int:
  149. """The numerical value of the location encoded in the key's lifetime."""
  150. return self.lifetime.value() >> 8
  151. class TestKey(unittest.TestCase):
  152. # pylint: disable=line-too-long
  153. """A few smoke tests for the functionality of the `Key` class."""
  154. def test_numerical(self):
  155. key = Key(version=0,
  156. id=1, lifetime=0x00000001,
  157. type=0x2400, bits=128,
  158. usage=0x00000300, alg=0x05500200, alg2=0x04c01000,
  159. material=b'@ABCDEFGHIJKLMNO')
  160. expected_hex = '505341004b45590000000000010000000024800000030000000250050010c00410000000404142434445464748494a4b4c4d4e4f'
  161. self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
  162. self.assertEqual(key.hex(), expected_hex)
  163. def test_names(self):
  164. length = 0xfff8 // 8 # PSA_MAX_KEY_BITS in bytes
  165. key = Key(version=0,
  166. id=1, lifetime='PSA_KEY_LIFETIME_PERSISTENT',
  167. type='PSA_KEY_TYPE_RAW_DATA', bits=length*8,
  168. usage=0, alg=0, alg2=0,
  169. material=b'\x00' * length)
  170. expected_hex = '505341004b45590000000000010000000110f8ff000000000000000000000000ff1f0000' + '00' * length
  171. self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
  172. self.assertEqual(key.hex(), expected_hex)
  173. def test_defaults(self):
  174. key = Key(type=0x1001, bits=8,
  175. usage=0, alg=0, alg2=0,
  176. material=b'\x2a')
  177. expected_hex = '505341004b455900000000000100000001100800000000000000000000000000010000002a'
  178. self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
  179. self.assertEqual(key.hex(), expected_hex)