diff --git a/.gitignore b/.gitignore index 6dae5f80..98261f14 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,8 @@ __pycache__/ *.prof # C extensions *.so +tmp/ +algo_arch_implem_v1_train/ # Distribution / packaging .Python diff --git a/notebooks/dataset_api.ipynb b/notebooks/dataset_api.ipynb new file mode 100644 index 00000000..9d505494 --- /dev/null +++ b/notebooks/dataset_api.ipynb @@ -0,0 +1,279 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from scaaml.io import Dataset\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# specify dummy data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "root_path = './'\n", + "architecture = 'arch'\n", + "implementation = 'implem'\n", + "algorithm = 'algo'\n", + "version = 1\n", + "minfo = {\n", + " \"trace1\": {\n", + " \"type\": \"power\",\n", + " \"len\": 1024,\n", + " }\n", + "}\n", + "\n", + "apinfo = {\n", + " \"key\": {\n", + " \"len\": 16,\n", + " \"max_val\": 256\n", + " }\n", + "}\n", + "chip_id = 1 # which chip this was captured on\n", + "comment= \"this is a test\" \n", + "purpose = \"train\"\n", + "example_per_shard = 1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## generate fake data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "key = np.random.randint(0, 255, 16)\n", + "key2 = np.random.randint(0, 255, 16)\n", + "trace1 = np.random.rand(1024)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Creating dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "## init" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33m[Warning] Path exist, some files might be over-written\u001b[0m\n", + "\u001b[32mDataset path: algo_arch_implem_v1_train\u001b[0m\n" + ] + } + ], + "source": [ + "ds = Dataset(root_path=root_path,\n", + " architecture=architecture,\n", + " implementation=implementation,\n", + " algorithm=algorithm,\n", + " version=version,\n", + " purpose=purpose,\n", + " comment=comment,\n", + " chip_id=chip_id,\n", + " examples_per_shard=example_per_shard,\n", + " measurements_info=minfo,\n", + " attack_points_info=apinfo)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## writing shard" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train\n", + "defaultdict(, {})\n" + ] + } + ], + "source": [ + "ds.new_shard(key, 1, split='train')\n", + "ds.write_example({\"key\": key,\n", + " #\"sub_byte_in\": key\n", + " }, {\"trace1\": trace1})\n", + "ds.close_shard()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using a dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_path = './algo_arch_implem_v1_train'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# display dataset info" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[36m[Dataset Summary]\u001b[0m\n", + 
"\u001b[33mInfo\u001b[0m\n", + "-------------- --------------\n", + "architecture arch\n", + "implementation implem\n", + "algorithm algo\n", + "version 1\n", + "chip_id 1\n", + "comment this is a test\n", + "purpose train\n", + "compression GZIP\n", + "-------------- --------------\n", + "\u001b[33m\n", + "Attack Points\u001b[0m\n", + "ap len max_val\n", + "---- ----- ---------\n", + "key 16 256\n", + "\u001b[35m\n", + "Measurements\u001b[0m\n", + "name type len\n", + "------ ------ -----\n", + "trace1 power 1024\n", + "\u001b[32m\n", + "Content\u001b[0m\n", + "split num_keys num_examples\n", + "------- ---------- --------------\n", + "train 1 1\n" + ] + } + ], + "source": [ + "\n", + "Dataset.summary(dataset_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[35mreloading algo_arch_implem_v1_train\\info.json\u001b[0m\n", + "\u001b[32mDataset path: algo_arch_implem_v1_train\\algo_arch_implem_v1_train\u001b[0m\n" + ] + } + ], + "source": [ + "trace_len = 1024\n", + "\n", + "train_ds, inputs, outputs = Dataset.as_tfdataset(dataset_path, \n", + " split='train', \n", + " attack_points='key',\n", + " traces='trace1',\n", + " traces_max_len=trace_len,\n", + " bytes=1,\n", + " shards=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x -> dict_keys(['trace1']) (32, 32, 32)\n", + "y -> dict_keys(['key_1']) (32, 256)\n" + ] + } + ], + "source": [ + "for batch in train_ds.take(1):\n", + " print('x ->', batch[0].keys(), batch[0]['trace1'].shape)\n", + " print('y ->', batch[1].keys(), batch[1]['key_1'].shape)" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "f4b45c82ab6242cd8608f401505676f0fc37e99e72925447dbc1e4dcd37ea533" + }, + "kernelspec": { + "display_name": "Python 3.8.10 64-bit ('venv': venv)", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/scaaml/io/README.md b/scaaml/io/README.md new file mode 100644 index 00000000..e69de29b diff --git a/scaaml/io/__init__.py b/scaaml/io/__init__.py new file mode 100644 index 00000000..50344402 --- /dev/null +++ b/scaaml/io/__init__.py @@ -0,0 +1 @@ +from .dataset import Dataset # noqa diff --git a/scaaml/io/dataset.py b/scaaml/io/dataset.py new file mode 100644 index 00000000..9458b056 --- /dev/null +++ b/scaaml/io/dataset.py @@ -0,0 +1,480 @@ +"Build and load tensorFlow dataset Record wrapper" +import math +import json +import os +import tensorflow as tf +from typing import Dict, List, Union +from pathlib import Path +from termcolor import cprint +from collections import defaultdict +from tqdm.auto import tqdm +from tabulate import tabulate +from scaaml.utils import bytelist_to_hex +from time import time +from .utils import sha256sum +from .shard import Shard + + +class Dataset(): + def __init__( + self, + root_path: str, + architecture: str, + implementation: str, + algorithm: str, + version: int, + chip_id: int, + purpose: str, + comment: str, + examples_per_shard: int, + measurements_info: Dict, + attack_points_info: Dict, + compression: str = "GZIP", + shards_list: defaultdict = None, + keys_per_split: 
defaultdict = None,
+            examples_per_split: defaultdict = None,
+            capture_info: dict = None,
+            min_values: Dict[str, int] = None,
+            max_values: Dict[str, int] = None,
+    ) -> None:
+        self.root_path = root_path
+        self.architecture = architecture
+        self.implementation = implementation
+        self.algorithm = algorithm
+        self.version = version
+        self.compression = compression
+        self.chip_id = chip_id
+        self.purpose = purpose
+        self.comment = comment
+
+        # avoid mutable default arguments: fall back to fresh containers
+        self.capture_info = capture_info or {}
+        self.measurements_info = measurements_info
+        self.attack_points_info = attack_points_info
+
+        if purpose not in ['train', 'holdout']:
+            raise ValueError("Invalid purpose", purpose)
+
+        # create directory -- check if it's empty
+        self.slug = "%s_%s_%s_v%s_%s" % (algorithm, architecture,
+                                         implementation, version, purpose)
+        self.path = Path(self.root_path) / self.slug
+        if self.path.exists():
+            cprint("[Warning] Path exist, some files might be over-written",
+                   'yellow')
+        else:
+            # create path if needed
+            self.path.mkdir(parents=True)
+            Path(self.path / 'train').mkdir()
+            Path(self.path / 'test').mkdir()
+            Path(self.path / 'holdout').mkdir()
+
+        cprint("Dataset path: %s" % self.path, 'green')
+
+        # current shard tracking
+        self.curr_shard_key = None  # current shard key
+        self.shard_key = None
+        self.shard_path = None
+        self.shard_split = None
+        self.shard_part = None
+        self.shard_relative_path = None  # for the shard list
+        self.curr_shard = None  # current shard object
+
+        # counters -- must be passed as params to allow reload.
+        self.shards_list = shards_list or defaultdict(list)
+        self.keys_per_split = keys_per_split or defaultdict(int)
+        self.examples_per_split = examples_per_split or defaultdict(int)
+        self.examples_per_shard = examples_per_shard
+        self.min_values = min_values or {}
+        self.max_values = max_values or {}
+        for k in measurements_info.keys():
+            # init only if not already set (e.g. when reloading)
+            if k not in self.min_values:
+                self.min_values[k] = math.inf
+                self.max_values[k] = 0
+
+        # write config
+        self._write_config()
+
+    def new_shard(self, key: list, part: int, split: str):
+        """Start a new shard for a given key.
+
+        Args:
+            key: the key that was used to create the measurements.
+
+            part: which part of the capture of a given key this shard
+              represents. Captures are split into parts to make it easy to
+              restrict the number of traces used per key.
+
+            split: the split the shard belongs to {train, test, holdout}
+        """
+        # finalize previous shard if needed
+        if self.curr_shard:
+            self.close_shard()
+
+        if split not in ['train', 'test', 'holdout']:
+            raise ValueError("Invalid split, must be: {train, test, holdout}")
+
+        if part < 1 or part > 10:
+            raise ValueError("Invalid part value -- must be in [1, 10]")
+
+        self.shard_split = split
+        self.shard_part = part
+        self.shard_key = bytelist_to_hex(key, spacer='')
+
+        # shard name
+        fname = "%s_%s.tfrec" % (self.shard_key, self.shard_part)
+        self.shard_relative_path = "%s/%s" % (split, fname)
+        self.shard_path = str(self.path / self.shard_relative_path)
+
+        # new shard
+        self.curr_shard = Shard(self.shard_path,
+                                attack_points_info=self.attack_points_info,
+                                measurements_info=self.measurements_info,
+                                compression=self.compression)
+
+    def write_example(self, attack_points: Dict, measurements: Dict):
+        self.curr_shard.write(attack_points, measurements)
+
+    def close_shard(self):
+        # close the shard and collect its statistics
+        stats = self.curr_shard.close()
+
+        # update min/max values
+        for k, v in stats['min_values'].items():
+            self.min_values[k] = min(self.min_values[k], v)
+
+        for k, v in stats['max_values'].items():
+            self.max_values[k] = max(self.max_values[k], v)
+
+        # update stats
+        self.examples_per_split[self.shard_split] += stats['examples']
+        self.keys_per_split[self.shard_split] += 1
+
+        # record in shard list
+        self.shards_list[self.shard_split].append({
+            "path": str(self.shard_relative_path),
+            "examples": stats['examples'],
+            "sha256": sha256sum(self.shard_path),
+            "key": self.shard_key
+        })
+
+        # update config
+        self._write_config()
+        self.curr_shard = None
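+
+    # A minimal write-flow sketch (mirrors tests/io/test_dataset.py and
+    # notebooks/dataset_api.ipynb; the variable names are illustrative,
+    # not part of the API):
+    #
+    #   ds = Dataset(root_path='./', ..., measurements_info=minfo,
+    #                attack_points_info=apinfo)
+    #   ds.new_shard(key, part=1, split='train')
+    #   ds.write_example({"key": key}, {"trace1": trace1})
+    #   ds.close_shard()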
+
+    @staticmethod
+    def as_tfdataset(dataset_path: str,
+                     split: str,
+                     attack_points: Union[List, str],
+                     traces: Union[List, str],
+                     bytes: Union[List, int],
+                     shards: int,
+                     traces_max_len: int = None,
+                     trace_block_size: int = 1,
+                     batch_size: int = 32,
+                     prefetch: int = 10,
+                     file_parallelism: int = 1,
+                     parallelism: int = os.cpu_count(),
+                     shuffle: int = 1000
+                     ) -> Union[tf.data.Dataset, Dict, Dict]:
+        """Return the dataset as a tf.data.Dataset along with the input and
+        output specifications.
+        """
+        if traces_max_len % trace_block_size:
+            raise ValueError(
+                "traces_max_len must be a multiple of trace_block_size")
+        trace_seq_len = traces_max_len // trace_block_size
+
+        # boxing
+        if isinstance(traces, str):
+            traces = [traces]
+        if isinstance(bytes, int):
+            bytes = [bytes]
+        if isinstance(attack_points, str):
+            attack_points = [attack_points]
+
+        # loading info
+        dpath = Path(dataset_path)
+        dataset = Dataset.from_config(dataset_path)
+
+        if split not in dataset.keys_per_split:
+            raise ValueError("Unknown split -- see Dataset.summary() for list")
+
+        # TF_FEATURES construction: must contain all features and be global
+        tf_features = {}  # what is decoded
+        for name, ipt in dataset.measurements_info.items():
+            tf_features[name] = tf.io.FixedLenFeature([ipt['len']], tf.float32)
+        for name, ap in dataset.attack_points_info.items():
+            tf_features[name] = tf.io.FixedLenFeature([ap['len']], tf.int64)
+
+        # decoding function
+        def from_tfrecord(tfrecord):
+            rec = tf.io.parse_single_example(tfrecord, tf_features)
+            return rec
+
+        # inputs construction
+        inputs = {}  # model inputs
+        for name in traces:
+            # copy the info dict so the dataset metadata is not mutated
+            inputs[name] = dict(dataset.measurements_info[name])
+
+            inputs[name]['min'] = tf.constant(dataset.min_values[name])
+            inputs[name]['max'] = tf.constant(dataset.max_values[name])
+            inputs[name]['delta'] = tf.constant(inputs[name]['max'] -
+                                                inputs[name]['min'])
+
+        # outputs construction
+        outputs = {}  # model outputs
+        for name in attack_points:
+            for b in bytes:
+                n = "%s_%s" % (name, b)
+                # copy the info dict so each byte gets its own 'ap'/'byte'
+                outputs[n] = dict(dataset.attack_points_info[name])
+                outputs[n]['ap'] = name
+                outputs[n]['byte'] = b
+
+        # processing function
+        def process_record(rec):
+            "process the tf record to get it ready for learning"
+            x = {}
+            # normalize the traces
+            for name, data in inputs.items():
+                trace = rec[name]
+
+                # truncate if needed
+                if traces_max_len:
+                    trace = trace[:traces_max_len]
+
+                # rescale to [-1, 1]
+                trace = 2 * ((trace - data['min']) / (data['delta'])) - 1
+
+                # reshape
+                trace = tf.reshape(trace, (trace_seq_len, trace_block_size))
+
+                # assign
+                x[name] = trace
+
+            # one-hot encode the output for each attack point / byte
+            y = {}
+            for name, data in outputs.items():
+                v = tf.one_hot(rec[data['ap']][data['byte']], data['max_val'])
+                y[name] = v
+
+            return (x, y)
+
+        # collect and truncate the shard list of a given split. This is done
+        # before anything else so that only the first n shards need to be
+        # downloaded.
+        shards_list = dataset.shards_list[split]
+        if shards:
+            shards_list = shards_list[:shards]
+        shards_paths = [str(dpath / s['path']) for s in shards_list]
+        num_shards = len(shards_paths)
+
+        # dataset creation: shuffle the shard order
+        ds = tf.data.Dataset.from_tensor_slices(shards_paths)
+        ds = ds.shuffle(num_shards)
+        # This is the tricky part: we use the interleave function to do the
+        # sampling as requested by the user. This is not the standard use of
+        # the function, nor an obvious way to do it, but it is by far the
+        # fastest and most compatible way to do so; for once we favor those
+        # factors over readability. deterministic=False is not an error, it
+        # is what allows us to create random batches.
+        ds = ds.interleave(
+            lambda x: tf.data.TFRecordDataset(x,
+                                              compression_type=dataset.compression),  # noqa
+            cycle_length=num_shards,
+            block_length=1,
+            num_parallel_calls=file_parallelism,
+            deterministic=False)
+        # decode to records
+        ds = ds.map(from_tfrecord, num_parallel_calls=parallelism)
+        # process them
+        ds = ds.map(process_record, num_parallel_calls=parallelism)
+
+        # randomize
+        ds = ds.shuffle(shuffle)
+
+        # batching with repeat
+        ds = ds.repeat()
+        ds = ds.batch(batch_size)
+        ds = ds.prefetch(prefetch)
+
+        return ds, inputs, outputs
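+
+    # A minimal read-flow sketch (matches the call in
+    # notebooks/dataset_api.ipynb; shapes depend on the arguments used):
+    #
+    #   train_ds, inputs, outputs = Dataset.as_tfdataset(
+    #       dataset_path, split='train', attack_points='key',
+    #       traces='trace1', traces_max_len=1024, bytes=1, shards=1)
+    #   for x, y in train_ds.take(1):
+    #       pass  # x['trace1'] and y['key_1'] are batched tensors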
of a given shard""" + fpath = Dataset._get_config_path(dataset_path) + config = json.loads(open(fpath).read()) + spath = str(Path(fpath) / config['shards_list'][split][shard_id]['path']) + cprint("Reading shard %s" % spath, 'cyan') + s = Shard(spath, + attack_points_info=config['attack_points_info'], + measurements_info=config['measurements_info'], + compression=config['compression']) + data = s.read(num=num_example) + print(data) + return(data) + + def check(self): + """Check the dataset integrity""" + # check examples are balances + seen_keys = {} # use to ensure keys are not reused + + for split, expected_examples in self.examples_per_split.items(): + slist = self.shards_list[split] + # checking we have the rigt number of shards + if len(slist) != self.keys_per_split[split]: + raise ValueError("Num shards in shard_list != self.shards") + + pb = tqdm(total=len(slist), desc="Checking %s split" % split) + actual_examples = 0 + for sinfo in slist: + + # no key reuse + if sinfo['key'] in seen_keys: + raise ValueError("Duplicate key", sinfo['key']) + else: + seen_keys[sinfo['key']] = 1 + + actual_examples += sinfo['examples'] + shard_path = self.path / sinfo['path'] + sh = sha256sum(shard_path) + if sh != sinfo['sha256']: + raise ValueError(sinfo['path'], "SHA256 miss-match") + pb.update() + + pb.close() + + if actual_examples != expected_examples: + raise ValueError("sum example don't match top_examples") + + def _write_config(self): + config = { + "architecture": self.architecture, + "implementation": self.implementation, + "algorithm": self.algorithm, + "version": self.version, + "chip_id": self.chip_id, + "comment": self.comment, + "purpose": self.purpose, + "compression": self.compression, + "shards_list": self.shards_list, + "keys_per_split": self.keys_per_split, + "examples_per_shard": self.examples_per_shard, + "examples_per_split": self.examples_per_split, + "capture_info": self.capture_info, + "measurements_info": self.measurements_info, + "attack_points_info": self.attack_points_info, + "min_values": self.min_values, + "max_values": self.max_values, + } + + with open(self._get_config_path(self.path), 'w+') as o: + o.write(json.dumps(config)) + + @staticmethod + def from_config(dataset_path: str): + dpath = Path(dataset_path) + fpath = Dataset._get_config_path(dataset_path) + cprint("reloading %s" % fpath, 'magenta') + config = json.loads(open(fpath).read()) + return Dataset( + root_path=str(dpath), + architecture=config['architecture'], + implementation=config['implementation'], + algorithm=config['algorithm'], + version=config['version'], + comment=config['comment'], + purpose=config['purpose'], + chip_id=config['chip_id'], + measurements_info=config['measurements_info'], + attack_points_info=config['attack_points_info'], + capture_info=config['capture_info'], + compression=config['compression'], + shards_list=config['shards_list'], + keys_per_split=config['keys_per_split'], + examples_per_split=config['examples_per_split'], + examples_per_shard=config['examples_per_shard'], + min_values=config['min_values'], + max_values=config['max_values'], + ) + + @staticmethod + def _get_config_path(path): + return str(Path(path) / 'info.json') + + @staticmethod + def cleanup_shards(dataset_path): + "remove non_existing shards from the config" + dpath = Path(dataset_path) + fpath = Dataset._get_config_path(dataset_path) + config = json.loads(open(fpath).read()) + stats = [] + new_shards_list = defaultdict(list) + for split, slist in config['shards_list'].items(): + kept = 0 + removed = 0 + 
diff --git a/scaaml/io/shard.py b/scaaml/io/shard.py
new file mode 100644
index 00000000..ac3bb359
--- /dev/null
+++ b/scaaml/io/shard.py
@@ -0,0 +1,148 @@
+import math
+import tensorflow as tf
+from typing import Dict, List
+from .tfdata import int64_feature, float_feature
+
+
+class Shard():
+    """A shard contains N measurements pertaining to the same key"""
+    def __init__(self, path: str, attack_points_info: Dict,
+                 measurements_info: Dict, compression: str) -> None:
+        self.path = path
+        self.attack_points_info = attack_points_info
+        self.measurements_info = measurements_info
+        self.compression = compression
+
+        # writer, created lazily on the first write
+        self.has_writer = False
+        self.writer = None
+
+        # counters
+        self.examples = 0
+        self.min_values = {}
+        self.max_values = {}
+        for k in measurements_info.keys():
+            self.min_values[k] = math.inf
+            self.max_values[k] = 0
+
+        # build and cache the tf feature format
+        self.features = self._build_tffeature()
+
+    def write(self, attack_points: Dict[str, List[int]],
+              measurements: Dict[str, List[float]]):
+        """Write an example on disk as a TFRecord
+
+        Args:
+            attack_points: Attack points values.
+            measurements: Measurements values.
+        """
+        with tf.device('/cpu:0'):
+            # open the writer if needed
+            # !do not put this in __init__ to avoid erasing the file on read
+            if not self.has_writer:
+                self.writer = tf.io.TFRecordWriter(self.path, self.compression)
+                self.has_writer = True
+            example = self._to_tfrecord(attack_points, measurements)
+            self.writer.write(example)
+            self.examples += 1
+
+    def read(self, num=10) -> tf.data.Dataset:
+        """Open and read N examples from the shard"""
+        _shard = tf.data.TFRecordDataset(self.path,
+                                         compression_type=self.compression)
+        data = _shard.map(self._from_tfrecord)
+        return data.take(num)
+
+    def close(self) -> Dict:
+        "Close the shard and return its statistics"
+        if not self.writer:
+            raise ValueError("Trying to close a shard that was not open")
+
+        self.writer.close()
+        return {
+            "examples": self.examples,
+            "min_values": self.min_values,
+            "max_values": self.max_values
+        }
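+
+    # Sketch of how Dataset drives this class (see scaaml/io/dataset.py;
+    # Shard is an internal helper and is not meant to be used directly):
+    #
+    #   shard = Shard(path, attack_points_info=apinfo,
+    #                 measurements_info=minfo, compression="GZIP")
+    #   shard.write({"key": key}, {"trace1": trace1})
+    #   stats = shard.close()  # examples, min_values, max_values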
+
+    def _to_tfrecord(self, attack_points, measurements):
+        """Convert example data into a serialized tf.train.Example
+
+        Args:
+            attack_points: attack points data
+            measurements: measurements data
+
+        Returns:
+            the serialized tf.train.Example
+        """
+        # check there are no unexpected values
+        for k in attack_points:
+            if k not in self.attack_points_info:
+                raise ValueError("Attack point", k, "not specified")
+
+        for k in measurements:
+            if k not in self.measurements_info:
+                raise ValueError("Measurement", k, "not specified")
+
+        feature = {}
+        # attack points as integers
+        for ap_name, info in self.attack_points_info.items():
+            expected_len = info['len']
+            ap_value = attack_points[ap_name]
+
+            # check that we get the length specified in the info
+            if len(ap_value) != expected_len:
+                raise ValueError(ap_name, len(ap_value),
+                                 "doesn't have the expected length",
+                                 expected_len)
+
+            # convert
+            feature[ap_name] = int64_feature(ap_value)
+
+        # measurements as floats
+        for mname, info in self.measurements_info.items():
+            expected_len = info['len']
+            measurement = measurements[mname]
+
+            # check that the measurement length matches the info
+            if len(measurement) != expected_len:
+                raise ValueError(mname, "doesn't have the expected length")
+
+            # track min and max values
+            self.min_values[mname] = min(self.min_values[mname],
+                                         float(tf.reduce_min(measurement)))
+            self.max_values[mname] = max(self.max_values[mname],
+                                         float(tf.reduce_max(measurement)))
+
+            # convert
+            feature[mname] = float_feature(measurement)
+
+        tffeats = tf.train.Features(feature=feature)
+        record = tf.train.Example(features=tffeats)
+        return record.SerializeToString()
+
+    def _from_tfrecord(self, tfrecord):
+        """Convert a tfrecord back to a dictionary
+
+        Args:
+            tfrecord: the tfrecord to parse
+        Returns:
+            reloaded example as dictionary
+        """
+        return tf.io.parse_single_example(tfrecord, self._build_tffeature())
+
+    def _build_tffeature(self):
+        "Build the tf feature dictionary based on the metadata"
+        features = {}
+
+        # attack points
+        for k, info in self.attack_points_info.items():
+            flen = info['len']
+            features[k] = tf.io.FixedLenFeature([flen], tf.int64)
+
+        # measurements
+        for k, info in self.measurements_info.items():
+            flen = info['len']
+            features[k] = tf.io.FixedLenFeature([flen], tf.float32)
+
+        return features
diff --git a/scaaml/io/tfdata.py b/scaaml/io/tfdata.py
new file mode 100644
index 00000000..74611b03
--- /dev/null
+++ b/scaaml/io/tfdata.py
@@ -0,0 +1,19 @@
+import tensorflow as tf
+
+
+def bytes_feature(value):
+    """Returns a bytes_list from a string / byte."""
+    if isinstance(value, type(tf.constant(0))):
+        # BytesList won't unpack a string from an EagerTensor.
+        value = value.numpy()
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
+
+
+def float_feature(value):
+    """Returns a float_list from a float / double."""
+    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
+
+
+def int64_feature(value):
+    """Returns an int64_list from a bool / enum / int / uint."""
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
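+
+
+# Sketch of how these helpers are combined in scaaml/io/shard.py (the
+# feature names below are illustrative, not fixed by this module):
+#
+#   feature = {"key": int64_feature(key_bytes),
+#              "trace1": float_feature(trace_values)}
+#   example = tf.train.Example(features=tf.train.Features(feature=feature))
+#   serialized = example.SerializeToString()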
diff --git a/scaaml/io/utils.py b/scaaml/io/utils.py
new file mode 100644
index 00000000..1838817d
--- /dev/null
+++ b/scaaml/io/utils.py
@@ -0,0 +1,12 @@
+import hashlib
+
+
+def sha256sum(filename):
+    "Compute the sha256 of a given file"
+    h = hashlib.sha256()
+    b = bytearray(128 * 1024)
+    mv = memoryview(b)
+    with open(filename, 'rb', buffering=0) as f:
+        for n in iter(lambda: f.readinto(mv), 0):
+            h.update(mv[:n])
+    return h.hexdigest()
diff --git a/scaaml/utils.py b/scaaml/utils.py
index e9c0b545..29fb81a7 100644
--- a/scaaml/utils.py
+++ b/scaaml/utils.py
@@ -33,12 +33,12 @@ def pretty_hex(val):
     return s.upper()
 
 
-def bytelist_to_hex(lst):
+def bytelist_to_hex(lst, spacer=' '):
     h = []
     for e in lst:
         h.append(pretty_hex(e))
-    return " ".join(h)
+    return spacer.join(h)
 
 
 def hex_display(lst, prefix="", color='green'):
diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py
new file mode 100644
index 00000000..2f8656a1
--- /dev/null
+++ b/tests/io/test_dataset.py
@@ -0,0 +1,68 @@
+import numpy as np
+from scaaml.io import Dataset
+
+
+def test_basic_workflow(tmp_path):
+    root_path = tmp_path
+    architecture = 'arch'
+    implementation = 'implem'
+    algorithm = 'algo'
+    version = 1
+    minfo = {
+        # TODO: test that a missing measurement raises an error
+        # TODO: test that an extra measurement raises an error
+        "trace1": {
+            "type": "power",
+            "len": 1024,
+        }
+    }
+    apinfo = {
+        "key": {
+            "len": 16,
+            "max_val": 256
+        },
+
+        # TODO: test that a missing attack point raises an error
+        # TODO: test that an extra attack point raises an error
+        # "sub_byte_in": {
+        #     "len": 16,
+        #     "max_val": 256
+        # }
+    }
+
+    chip_id = 1
+    comment = "this is a test"
+    purpose = "train"
+    example_per_shard = 1
+
+    key = np.random.randint(0, 255, 16)
+    key2 = np.random.randint(0, 255, 16)
+    trace1 = np.random.rand(1024)
+
+    ds = Dataset(root_path=root_path,
+                 architecture=architecture,
+                 implementation=implementation,
+                 algorithm=algorithm,
+                 version=version,
+                 purpose=purpose,
+                 comment=comment,
+                 chip_id=chip_id,
+                 examples_per_shard=example_per_shard,
+                 measurements_info=minfo,
+                 attack_points_info=apinfo)
+
+    ds.new_shard(key, 1, 'train')
+    ds.write_example({"key": key}, {"trace1": trace1})
+    ds.close_shard()
+
+    # second shard, with a different key, in the test split
+    ds.new_shard(key2, 1, 'test')
+    ds.write_example({"key": key2}, {"trace1": trace1})
+    ds.close_shard()
+
+    # check dataset integrity and consistency
+    ds.check()
+    slug = ds.slug
+    # reload from the on-disk config and inspect the first train shard
+    ds2 = Dataset.from_config(root_path / slug)
+    ds2.inspect(root_path / slug, 'train', 0, num_example=1)
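+
+
+# Hedged sketch for one of the TODO comments above (not part of the original
+# change): writing an attack point that was not declared in
+# attack_points_info is expected to trigger the ValueError raised by
+# Shard._to_tfrecord.
+def test_extra_attack_point_raises(tmp_path):
+    minfo = {"trace1": {"type": "power", "len": 1024}}
+    apinfo = {"key": {"len": 16, "max_val": 256}}
+    key = np.random.randint(0, 255, 16)
+    trace1 = np.random.rand(1024)
+
+    ds = Dataset(root_path=tmp_path,
+                 architecture='arch',
+                 implementation='implem',
+                 algorithm='algo',
+                 version=1,
+                 purpose='train',
+                 comment='this is a test',
+                 chip_id=1,
+                 examples_per_shard=1,
+                 measurements_info=minfo,
+                 attack_points_info=apinfo)
+    ds.new_shard(key, 1, 'train')
+    try:
+        ds.write_example({"key": key, "unexpected_ap": key},
+                         {"trace1": trace1})
+        raise AssertionError("undeclared attack point should raise ValueError")
+    except ValueError:
+        pass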